faenet.fa_forward

Module Contents

Functions

model_forward(batch, model, frame_averaging[, mode, ...])

Performs a model forward pass when frame averaging is applied.

faenet.fa_forward.model_forward(batch, model, frame_averaging, mode='train', crystal_task=True)

Performs a model forward pass when frame averaging is applied, so that predictions are invariant (graph-level properties such as energy) or equivariant (node-level properties such as forces).
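To illustrate the idea on a single graph, here is a minimal PyTorch sketch. It assumes 3D frames built from a PCA of the centred atom positions, enumerating all 2^3 eigenvector sign choices. The helpers pca_frames, toy_model and fa_forward_sketch are hypothetical and not part of faenet. toy_model is deliberately not rotation invariant, so any invariance in the averaged output comes from frame averaging itself: the graph-level prediction is averaged directly, while the node-level prediction is rotated back out of each frame before averaging.

```python
import itertools

import torch


def pca_frames(pos: torch.Tensor) -> list[torch.Tensor]:
    """Hypothetical 3D frame construction: PCA axes with all sign flips."""
    centred = pos - pos.mean(dim=0, keepdim=True)
    # Columns of eigvecs are the principal axes (ascending eigenvalues).
    _, eigvecs = torch.linalg.eigh(centred.T @ centred)
    frames = []
    for signs in itertools.product((1.0, -1.0), repeat=3):
        # Eigenvector signs are ambiguous, so enumerate all 2**3 choices.
        frames.append(eigvecs * torch.tensor(signs, dtype=pos.dtype))
    return frames


def toy_model(pos: torch.Tensor) -> dict:
    """Hypothetical model that is NOT rotation invariant on its own."""
    x, y, z = pos.unbind(dim=1)
    energy = (x**2 + 2 * y**2 + 3 * z**2 + x * y + z).sum()
    # Forces are the negative gradient of the toy energy.
    forces = -torch.stack([2 * x + y, 4 * y + x, 6 * z + torch.ones_like(z)], dim=1)
    return {"energy": energy, "forces": forces}


def fa_forward_sketch(pos: torch.Tensor, model) -> dict:
    """Average model predictions over all frames of one graph (sketch)."""
    centred = pos - pos.mean(dim=0, keepdim=True)  # removes translations
    energies, forces = [], []
    for rot in pca_frames(pos):
        preds = model(centred @ rot)  # positions expressed in this frame
        energies.append(preds["energy"])  # invariant: average as is
        forces.append(preds["forces"] @ rot.T)  # equivariant: rotate back
    return {
        "energy": torch.stack(energies).mean(),
        "forces": torch.stack(forces).mean(dim=0),
    }
```

Rotating (and translating) the input leaves the averaged energy unchanged and rotates the averaged forces accordingly. Enumerating every sign choice keeps the set of frames identical under rotation, at the cost of one forward pass per frame.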

Parameters:
  • batch (data.Batch) –

    batch of graphs with attributes:

    • pos: original atom positions

    • batch: batch indices (the graph in the batch to which each atom belongs)

    • fa_pos, fa_cell, fa_rot: frame-averaged atom positions, unit cells and rotation matrices

  • model – ML model instance

  • frame_averaging (str) – symmetry-preserving transform that has already been applied to batch. One of 2D FA, 3D FA, data augmentation or none, denoted by “2D”, “3D”, “DA” and “” respectively.

  • mode (str, optional) – model mode: train or inference. (default: "train")

  • crystal_task (bool, optional) – Whether the inputs are crystals (True) rather than molecules (False). If True, the 3x3 unit cell is also transformed by frame averaging and is expected as a batch attribute. (default: True)

Returns:

dictionary of model predictions:
  • energy (torch.Tensor): any graph-level property (e.g. energy)

  • forces (torch.Tensor): any node-level property (e.g. forces)

Return type:

(dict)
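The control flow described by the parameters above can be sketched for a whole batch as follows. This is a hypothetical re-implementation (model_forward_sketch is not faenet's code). It assumes that batch exposes the listed attributes, that model(batch, mode=...) returns a dict with "energy" and optionally "forces", that fa_rot[i] holds one 3x3 matrix per graph, and that frame coordinates were obtained by multiplying centred positions by that matrix. Under that convention, node-level forces are mapped back by selecting each atom's rotation through the batch index vector.

```python
from copy import deepcopy

import torch


def model_forward_sketch(batch, model, frame_averaging, mode="train", crystal_task=True):
    """Hypothetical sketch of a frame-averaged forward pass over a batch."""
    # No FA ("") or data augmentation ("DA"): the transform already acted
    # on batch.pos, so one ordinary forward pass is enough.
    if frame_averaging in ("", "DA"):
        return model(batch, mode=mode)

    # "2D" / "3D" FA: one forward pass per frame, then average.
    original_pos = batch.pos
    original_cell = batch.cell if crystal_task else None
    energies, forces = [], []
    try:
        for i in range(len(batch.fa_pos)):
            batch.pos = batch.fa_pos[i]
            if crystal_task:
                batch.cell = batch.fa_cell[i]  # unit cell is framed too
            preds = model(deepcopy(batch), mode=mode)
            energies.append(preds["energy"])  # graph-level: average as is
            if preds.get("forces") is not None:
                # Assumed convention: frame coords = centred pos @ fa_rot.
                # Pick each atom's graph rotation via the batch indices,
                # then rotate back: f_world = f_frame @ R^T.
                rot = batch.fa_rot[i][batch.batch]  # (num_atoms, 3, 3)
                f = torch.bmm(preds["forces"].unsqueeze(1), rot.transpose(1, 2))
                forces.append(f.squeeze(1))
    finally:
        # Restore the original geometry so the caller's batch is unchanged.
        batch.pos = original_pos
        if crystal_task:
            batch.cell = original_cell

    out = {"energy": torch.stack(energies).mean(dim=0)}
    if forces:
        out["forces"] = torch.stack(forces).mean(dim=0)
    return out
```

Running the model on a copy of the batch for each frame and restoring pos and cell in a finally block leaves the caller's batch unchanged even if the model raises an exception.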