faenet.fa_forward
Module Contents
Functions
| model_forward(batch, model, frame_averaging[, mode, crystal_task]) | Perform a model forward pass when frame averaging is applied. |
- faenet.fa_forward.model_forward(batch, model, frame_averaging, mode='train', crystal_task=True)
Perform a model forward pass when frame averaging is applied.
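The frame-averaging step can be illustrated with a minimal, pure-Python sketch for a single graph. It assumes that each frame i provides positions fa_pos[i] = pos @ fa_rot[i] (row vectors, orthogonal fa_rot[i]), and that the model returns a scalar energy plus per-atom forces expressed in that frame. Energies are rotation invariant and are averaged directly. Forces are first rotated back with the transpose of the frame's rotation and then averaged. The helper names (frame_averaged_forward, matmul, transpose) and the model signature are hypothetical. The sketch ignores batching, the unit cell used when crystal_task is True, and the mode argument, and it is not the faenet implementation.

```python
from typing import Any, Callable, Dict, List, Sequence

Matrix = List[List[float]]  # row-major: a list of rows


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x k) matrix by a (k x m) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m: Matrix) -> Matrix:
    """Return the transpose of a matrix."""
    return [list(row) for row in zip(*m)]


def frame_averaged_forward(
    fa_pos: Sequence[Matrix],
    fa_rot: Sequence[Matrix],
    model: Callable[[Matrix], Dict[str, Any]],
) -> Dict[str, Any]:
    """Hypothetical helper: average one graph's predictions over its frames.

    fa_pos[i] is the (n_atoms x 3) position matrix in frame i and fa_rot[i]
    the 3x3 orthogonal matrix that produced it (pos_i = pos @ rot_i).
    `model` maps positions to {"energy": float, "forces": n_atoms x 3}.
    """
    energies, forces = [], []
    for pos, rot in zip(fa_pos, fa_rot):
        preds = model(pos)
        # Energy is invariant to the frame, so it is collected as-is.
        energies.append(preds["energy"])
        # Forces come out in the rotated frame; since rot is orthogonal,
        # multiplying by its transpose maps them back to the original frame.
        forces.append(matmul(preds["forces"], transpose(rot)))
    n = len(energies)
    mean_energy = sum(energies) / n
    mean_forces = [
        [sum(f[atom][c] for f in forces) / n for c in range(3)]
        for atom in range(len(forces[0]))
    ]
    return {"energy": mean_energy, "forces": mean_forces}
```

With a toy rotation-invariant model (energy as the sum of squared coordinates, forces as its negative gradient), every frame yields the same energy, and the back-rotated, averaged forces match those computed directly in the original frame.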
- Parameters:
batch (data.Batch) – batch of graphs with the following attributes:
  - pos: original atom positions
  - batch: batch indices (the graph in the batch to which each atom belongs)
  - fa_pos, fa_cell, fa_rot: frame-averaged positions, unit cells and rotation matrices
model – model instance
frame_averaging (str) – symmetry-preserving method already applied to the batch: “2D”, “3D”, “DA” (data augmentation), or “” (none).
mode (str, optional) – model mode, either “train” or “eval”. Defaults to “train”.
crystal_task (bool, optional) – Whether crystals (molecules) are considered. If they are, the 3x3 unit cell is also transformed by frame averaging and is expected as an attribute of the batch. (default: True)
- Returns:
dictionary of model prediction tensors for “energy” and “forces”.
- Return type:
(dict)
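Assuming fa_rot holds one rotation per graph for each frame while forces are predicted per atom, each atom needs the rotation of the graph it belongs to, which is the mapping the batch attribute provides. The hypothetical helper below (pure Python; assumed shapes: forces as n_atoms x 3, one 3x3 orthogonal matrix per graph) looks up each atom's rotation through the batch index and applies its transpose to map the forces of one frame back to the original frame. With PyTorch tensors, indexing the rotation tensor with the batch vector is one way to perform the same per-atom expansion.

```python
from typing import List, Sequence

Matrix = List[List[float]]  # row-major 3x3 rotation


def rotate_forces_back(
    forces: Sequence[Sequence[float]],
    rot_per_graph: Sequence[Matrix],
    batch_index: Sequence[int],
) -> List[List[float]]:
    """Hypothetical helper: map one frame's per-atom forces to the original frame.

    rot_per_graph[g] is the orthogonal rotation used for graph g in this
    frame, and batch_index[a] is the graph that atom a belongs to.
    """
    out = []
    for f, g in zip(forces, batch_index):
        r = rot_per_graph[g]  # rotation of the graph owning this atom
        # Row vector times R^T: component c is sum_k f[k] * R[c][k].
        out.append([sum(f[k] * r[c][k] for k in range(3)) for c in range(3)])
    return out
```

Keeping the lookup per atom lets graphs of different sizes share one flat force array, which matches how batched graphs store atoms contiguously.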