faenet.fa_forward
Module Contents
Functions
- faenet.fa_forward.model_forward(batch, model, frame_averaging, mode='train', crystal_task=True)
Performs a model forward pass when frame averaging is applied, making it possible to obtain equivariant or invariant predictions.
- Parameters:
batch (data.Batch) –
batch of graphs with attributes:
pos: original atom positions
batch: batch indices (the graph in the batch to which each atom belongs)
fa_pos, fa_cell, fa_rot: per-frame positions, unit cells and rotation matrices produced by frame averaging
model – ML model instance
frame_averaging (str) – symmetry-preserving transform (already applied to the batch). One of 2D FA, 3D FA, Data Augmentation or None, denoted respectively by (“2D”, “3D”, “DA”, “”).
mode (str, optional) – model mode: train or inference. (default: "train")
crystal_task (bool, optional) – Whether crystals (as opposed to molecules) are considered. If so, the unit cell (3x3) is also transformed by frame averaging and is expected as an attribute. (default: True)
- Returns:
- dictionary of model predictions:
energy (torch.Tensor): any graph-level property (e.g. energy)
forces (torch.Tensor): any node-level property (e.g. forces)
- Return type:
(dict)
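The averaging described above can be sketched in plain Python. This is a minimal, framework-free illustration under stated assumptions, not the faenet implementation: `model_forward_sketch`, `toy_model` and the matrix helpers are hypothetical names. It assumes each frame's positions are `pos @ fa_rot[i]` (row-vector convention), averages the energy directly because it should be invariant, and rotates the forces predicted in frame i back with the transpose of `fa_rot[i]` before averaging so that they follow the input orientation. Batching, `mode` and the unit-cell handling controlled by `crystal_task` are omitted.

```python
import math
from typing import Callable, Dict, List

Vec = List[float]
Mat = List[List[float]]


def row_times_mat(v: Vec, m: Mat) -> Vec:
    """Multiply a 1x3 row vector by a 3x3 matrix."""
    return [sum(v[k] * m[k][j] for k in range(3)) for j in range(3)]


def transpose(m: Mat) -> Mat:
    """Transpose a 3x3 matrix (the inverse of a rotation matrix)."""
    return [[m[i][j] for i in range(3)] for j in range(3)]


def rot_z(theta: float) -> Mat:
    """Rotation about the z axis, used here as a stand-in for a frame."""
    c, s = math.cos(theta), math.sin(theta)
    return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]


def toy_model(pos: List[Vec]) -> Dict[str, object]:
    """Hypothetical model: energy = sum of |p|^2, forces = -2 * p."""
    energy = sum(x * x for p in pos for x in p)
    forces = [[-2.0 * x for x in p] for p in pos]
    return {"energy": energy, "forces": forces}


def model_forward_sketch(
    pos: List[Vec],
    fa_rot: List[Mat],
    model: Callable[[List[Vec]], Dict[str, object]],
    frame_averaging: str,
) -> Dict[str, object]:
    # Assumption: "DA" and "" need a single pass on the (possibly augmented) positions.
    if frame_averaging not in ("2D", "3D"):
        return model(pos)

    energies, forces = [], []
    for rot in fa_rot:
        # Express the atom positions in this frame.
        frame_pos = [row_times_mat(p, rot) for p in pos]
        preds = model(frame_pos)
        energies.append(preds["energy"])
        # Rotate the frame's force predictions back to the input orientation.
        rot_t = transpose(rot)
        forces.append([row_times_mat(f, rot_t) for f in preds["forces"]])

    n_frames = len(fa_rot)
    avg_forces = [
        [sum(frame[a][j] for frame in forces) / n_frames for j in range(3)]
        for a in range(len(pos))
    ]
    return {"energy": sum(energies) / n_frames, "forces": avg_forces}


# Usage: two atoms, three illustrative frames.
pos = [[1.0, 0.0, 0.0], [0.0, 2.0, 1.0]]
frames = [rot_z(0.0), rot_z(math.pi / 2), rot_z(math.pi)]
out = model_forward_sketch(pos, frames, toy_model, "3D")
print(out["energy"], out["forces"])
```

Because the toy model is already equivariant, every frame gives an energy of about 6.0 and back-rotated forces equal to -2 * pos, which checks that the back-rotation is consistent. With a model that is not equivariant, the per-frame predictions differ, and it is the averaging over a proper set of frames that provides invariance or equivariance.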