faenet.fa_forward

Module Contents

Functions

model_forward(batch, model, frame_averaging[, mode, ...])

Perform a model forward pass when frame averaging is applied.

faenet.fa_forward.model_forward(batch, model, frame_averaging, mode='train', crystal_task=True)

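Perform a model forward pass when frame averaging is applied. The model is run once per frame and the per-frame predictions are averaged: energies directly, forces after being rotated back to the original coordinate system with the corresponding rotation matrix (fa_rot).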
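The averaging can be written out as follows. This is a minimal formulation under an assumption the docstring does not state: each frame is taken to be the centered position matrix X (n atoms by 3, one row per atom, centroid \bar{x}) multiplied on the right by an orthogonal matrix U_i, stored in fa_rot[i], with fa_pos[i] = X^{(i)}. E_θ and F_θ denote the model's energy and per-atom force outputs, and \mathcal{F} is the set of frames.

```latex
\begin{aligned}
X^{(i)} &= \bigl(X - \mathbf{1}\,\bar{x}^{\top}\bigr)\,U_i, \qquad U_i^{\top}U_i = I_3, \\
\hat{E}(X) &= \frac{1}{|\mathcal{F}|}\sum_{i=1}^{|\mathcal{F}|} E_\theta\bigl(X^{(i)}\bigr), \\
\hat{F}(X) &= \frac{1}{|\mathcal{F}|}\sum_{i=1}^{|\mathcal{F}|} F_\theta\bigl(X^{(i)}\bigr)\,U_i^{\top}.
\end{aligned}
```

Because each U_i is orthogonal, right-multiplying the frame-i forces by U_i^T expresses them back in the original coordinates. The energy is rotation-invariant and is averaged without any transformation.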

Parameters:
  • batch (data.Batch) – batch of graphs with the following attributes:
      - original atom positions (pos)
      - batch indices, i.e. the graph in the batch that each atom belongs to (batch)
      - frame-averaged positions, cells and rotation matrices (fa_pos, fa_cell, fa_rot)

  • model – model instance whose forward pass is applied to the batch

  • frame_averaging (str) – symmetry-preserving method already applied to the batch: “2D” or “3D” frame averaging, “DA” (data augmentation), or “” (none).

  • mode (str, optional) – model mode, either “train” or “eval”. Defaults to “train”.

  • crystal_task (bool, optional) – whether the inputs are crystals (True) rather than molecules (False). For crystals, the 3x3 unit cell is also transformed by frame averaging and is expected as a batch attribute. Defaults to True.
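The per-frame loop implied by these parameters can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the faenet implementation: it assumes fa_pos[i] holds every atom position in frame i, fa_rot[i] holds one 3x3 matrix per graph such that a frame position equals pos @ fa_rot[i][g], and the model is a hypothetical callable taking (positions, batch indices) and returning per-graph energies and per-atom forces. The names to_frame, rotate_back, fa_forward_sketch and toy_model are illustrative only; unit cells (fa_cell) and the mode argument are left out.

```python
from typing import Callable, Dict, List, Sequence

Vec3 = List[float]
Mat3 = Sequence[Sequence[float]]
Preds = Dict[str, list]


def to_frame(v: Sequence[float], rot: Mat3) -> Vec3:
    """Row vector times matrix, v @ rot (used here only to build demo frames)."""
    return [sum(v[k] * rot[k][j] for k in range(3)) for j in range(3)]


def rotate_back(v: Sequence[float], rot: Mat3) -> Vec3:
    """Row vector times transposed matrix, v @ rot.T; undoes v @ rot when rot is orthogonal."""
    return [sum(v[k] * rot[j][k] for k in range(3)) for j in range(3)]


def fa_forward_sketch(
    fa_pos: Sequence[Sequence[Sequence[float]]],  # [frame][atom] -> xyz (assumed layout)
    fa_rot: Sequence[Sequence[Mat3]],             # [frame][graph] -> 3x3 matrix (assumed layout)
    batch: Sequence[int],                         # [atom] -> index of its graph
    model: Callable[[Sequence[Sequence[float]], Sequence[int]], Preds],  # hypothetical signature
) -> Preds:
    """Run `model` once per frame and average energies and back-rotated forces."""
    n_frames = len(fa_pos)
    energy_sum = None
    force_sum = None
    for pos_i, rot_i in zip(fa_pos, fa_rot):
        preds = model(pos_i, batch)
        # Energy is rotation-invariant: accumulate it unchanged (one value per graph).
        e = preds["energy"]
        energy_sum = list(e) if energy_sum is None else [a + b for a, b in zip(energy_sum, e)]
        # Forces are equivariant: rotate each atom's force back with its own graph's matrix.
        f = [rotate_back(f_atom, rot_i[g]) for f_atom, g in zip(preds["forces"], batch)]
        force_sum = f if force_sum is None else [
            [a + b for a, b in zip(u, v)] for u, v in zip(force_sum, f)
        ]
    return {
        "energy": [e / n_frames for e in energy_sum],
        "forces": [[c / n_frames for c in f_atom] for f_atom in force_sum],
    }


def toy_model(pos, batch):
    """Hypothetical equivariant model: energy = sum of |x|^2 per graph, force = -2x per atom."""
    energy = [0.0] * (max(batch) + 1)
    for x, g in zip(pos, batch):
        energy[g] += sum(c * c for c in x)
    return {"energy": energy, "forces": [[-2.0 * c for c in x] for x in pos]}


I3 = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
RZ = [[0, -1, 0], [1, 0, 0], [0, 0, 1]]  # 90 degree rotation about z
pos = [[1.0, 2.0, 3.0], [0.0, 1.0, 0.0], [2.0, 0.0, 1.0]]
batch = [0, 0, 1]  # atoms 0 and 1 form graph 0, atom 2 forms graph 1
fa_rot = [[I3, RZ], [RZ, I3]]  # two frames, one matrix per graph in each frame
fa_pos = [[to_frame(x, rot_i[g]) for x, g in zip(pos, batch)] for rot_i in fa_rot]

preds = fa_forward_sketch(fa_pos, fa_rot, batch, toy_model)
print(preds["energy"])  # [15.0, 5.0]
print(preds["forces"])
```

With this equivariant toy model the averaged forces equal -2 times the original positions regardless of the frame rotations, and the per-graph energies are unchanged, which is the behaviour frame averaging is meant to guarantee.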

Returns:

dict of model predictions containing the “energy” and “forces” tensors.

Return type:

dict