faenet

Package Contents

Classes

FAENet

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet.

FrameAveraging

Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).

Functions

frame_averaging_2D(pos[, cell, fa_method, check])

Computes new positions for the graph atoms, based on frame averaging built on PCA.

frame_averaging_3D(pos[, cell, fa_method, check])

Computes new positions for the graph atoms using PCA.

model_forward(batch, model, frame_averaging[, mode, ...])

Perform a forward pass of the model when frame averaging is applied.

Attributes

__version__

class faenet.FAENet(cutoff=6.0, act='swish', preprocess='pbc_preprocess', complex_mp=False, max_num_neighbors=40, num_gaussians=50, num_filters=128, hidden_channels=128, tag_hidden_channels=32, pg_hidden_channels=32, phys_hidden_channels=0, phys_embeds=True, num_interactions=4, mp_type='updownscale_base', graph_norm=True, second_layer_MLP=True, skip_co='concat', energy_head=None, regress_forces=None, force_decoder_type='mlp', force_decoder_model_config={'hidden_channels': 128})

Bases: faenet.base_model.BaseModel

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.

Parameters:
  • cutoff (float) – Cutoff distance for interatomic interactions. (default: 6.0)

  • preprocess (str or callable) – Pre-processing function for the data, given either as a function or by name. This function should accept a data object as input and return a tuple containing the following: atomic numbers, batch indices, final adjacency, relative positions, pairwise distances. Examples of valid preprocessing functions include pbc_preprocess, base_preprocess, or custom functions. (default: “pbc_preprocess”)

  • act (str) – Activation function. (default: swish)

  • max_num_neighbors (int) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 40)

  • hidden_channels (int) – Hidden embedding size. (default: 128)

  • tag_hidden_channels (int) – Hidden tag embedding size. (default: 32)

  • pg_hidden_channels (int) – Hidden period and group embedding size. (default: 32)

  • phys_embeds (bool) – Whether to include fixed physics-aware embeddings. (default: True)

  • phys_hidden_channels (int) – Hidden size of learnable physics-aware embeddings. (default: 0)

  • num_interactions (int) – The number of interaction (i.e. message passing) blocks. (default: 4)

  • num_gaussians (int) – The number of Gaussians \(\mu\) used to encode distance information. (default: 50)

  • num_filters (int) – The size of convolutional filters. (default: 128)

  • second_layer_MLP (bool) – Whether to use a 2-layer MLP at the end of the Embedding block. (default: True)

  • skip_co (str) – Skip connection between each interaction block and the energy head; False disables it. (False, “add”, “concat”, “concat_atom”; default: “concat”)

  • mp_type (str) – Specifies the Message Passing type of the interaction block. (“base”, “updownscale_base”, “updownscale”, “updown_local_env”, “simple”; default: “updownscale_base”)

  • graph_norm (bool) – Whether to apply batch norm after every linear layer. (default: True)

  • complex_mp (bool) – (default: False)

  • energy_head (str) – Method to compute energy prediction from atom representations. (None, “weighted-av-initial-embeds”, “weighted-av-final-embeds”; default: None)

  • regress_forces (str) – Specifies whether forces are predicted and, if so, how. (None or “”, “direct”, “direct_with_gradient_target”; default: None)

  • force_decoder_type (str) – Specifies the type of force decoder. (“simple”, “mlp”, “res”, “res_updown”; default: “mlp”)

  • force_decoder_model_config (dict) – Contains information about the force decoder architecture (e.g. number of layers, hidden size). (default: {'hidden_channels': 128})
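
The preprocess contract above can be illustrated with a hypothetical custom function. This is a minimal NumPy sketch, not the library's pbc_preprocess or base_preprocess: it assumes the function is called as preprocess(data, cutoff, max_num_neighbors), ignores periodic boundary conditions, reads atomic_numbers, pos and batch from a plain namespace instead of a PyG Data object, and follows the PyG convention that edge_index[0] holds the neighbour and edge_index[1] the centre atom. The name naive_preprocess is hypothetical.

```python
from types import SimpleNamespace

import numpy as np


def naive_preprocess(data, cutoff, max_num_neighbors):
    """Hypothetical radius-graph preprocess for non-periodic graphs.

    Assumed call signature: (data, cutoff, max_num_neighbors); no PBC handling.
    Returns the documented tuple: (atomic numbers, batch indices,
    edge_index, relative positions, pairwise distances).
    """
    pos, batch = np.asarray(data.pos, dtype=float), np.asarray(data.batch)
    num_atoms = len(pos)
    src, dst = [], []
    for i in range(num_atoms):
        # Candidate neighbours: every other atom belonging to the same graph.
        others = np.where((batch == batch[i]) & (np.arange(num_atoms) != i))[0]
        dists = np.linalg.norm(pos[others] - pos[i], axis=1)
        mask = dists <= cutoff
        # Nearest atoms first, truncated to max_num_neighbors.
        kept = others[mask][np.argsort(dists[mask])][:max_num_neighbors]
        src.extend(kept.tolist())
        dst.extend([i] * len(kept))
    edge_index = np.array([src, dst], dtype=int)  # shape (2, num_edges)
    rel_pos = pos[edge_index[0]] - pos[edge_index[1]]  # neighbour minus centre atom
    distances = np.linalg.norm(rel_pos, axis=1)
    return np.asarray(data.atomic_numbers), batch, edge_index, rel_pos, distances


# Toy water molecule: one graph, so every atom has batch index 0.
water = SimpleNamespace(
    atomic_numbers=np.array([8, 1, 1]),
    pos=np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]),
    batch=np.array([0, 0, 0]),
)
z, batch, edge_index, rel_pos, distances = naive_preprocess(
    water, cutoff=1.0, max_num_neighbors=40
)
print(edge_index)
```

The quadratic loop keeps the sketch readable; a practical version would rely on a radius-graph routine such as radius_graph from torch_cluster.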

energy_forward(data, preproc=True)

Predicts any graph-level properties (e.g. energy) for 3D atomic systems.

Parameters:
  • data (data.Batch) – Batch of graph datapoints.

  • preproc (bool) – Whether to preprocess the graph. Defaults to True.

Returns:

predicted properties for each graph (e.g. energy)

Return type:

dict
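
Graph-level predictions are commonly obtained by pooling atom-level outputs per graph with the batch indices. The sketch below shows a plain sum readout in NumPy under that assumption; it is a generic illustration rather than FAENet's energy head (the weighted-average options of energy_head are not reproduced), and sum_readout is a hypothetical name.

```python
import numpy as np


def sum_readout(atom_values, batch, num_graphs):
    """Sum per-atom contributions into one value per graph.

    batch[k] is the index of the graph that atom k belongs to.
    """
    out = np.zeros(num_graphs)
    # np.add.at accumulates correctly even when indices repeat.
    np.add.at(out, np.asarray(batch), np.asarray(atom_values, dtype=float))
    return out


# Two graphs in one batch: atoms 0-1 belong to graph 0, atoms 2-4 to graph 1.
energies = sum_readout([1.0, 2.0, 3.0, 4.0, 5.0], batch=[0, 0, 1, 1, 1], num_graphs=2)
print(energies)  # per-graph sums: 3.0 and 12.0
```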

forces_forward(preds)

Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property.

Parameters:

preds (dict) – Dictionary with predicted properties for each graph.

Returns:

additional predicted properties, at an atom-level (e.g. forces)

Return type:

dict

class faenet.FrameAveraging(frame_averaging=None, fa_method=None)

Bases: Transform

Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).

Computes new atomic positions (fa_pos) for all datapoints, as well as new unit cells (fa_cell) attributes for crystal structures, when applicable. The rotation matrix (fa_rot) used for the frame averaging is also stored.

Parameters:
  • frame_averaging (str) – Transform method used. Can be 2D FA, 3D FA, Data Augmentation or no FA, respectively denoted by (“2D”, “3D”, “DA”, “”)

  • fa_method (str) – The actual frame averaging technique used. “stochastic” refers to sampling one frame at random (at each epoch), “det” to choosing one frame deterministically, and “all” to using all frames. The prefix “se3-” refers to the SE(3) equivariant version of the method. “” means that no frame averaging is used. (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)

Returns:

updated data object with new positions (+ unit cell) attributes and the rotation matrices used for the frame averaging transform.

Return type:

(data.Data)

__call__(data)

The only requirement for the data is to have a pos attribute.
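
The fa_method options listed above differ in which candidate PCA frames are kept. A minimal stdlib sketch, assuming each frame is the eigenvector matrix with an independent sign flip per principal axis (2³ = 8 candidates in 3D, 2² = 4 in 2D): select_sign_flips is a hypothetical helper, its “det” rule (take the first candidate) is a placeholder rather than the library's criterion, and the empty method “” (no frame averaging) is rejected instead of handled.

```python
import itertools
import math
import random


def select_sign_flips(fa_method, dim=3, base_det=1.0, rng=random):
    """Choose PCA axis sign flips for a hypothetical frame-averaging step.

    A candidate frame is U @ diag(signs), where the columns of U are the PCA
    eigenvectors and det(U) == base_det (+1 or -1). Flipping each of the `dim`
    axes independently gives 2**dim candidate frames.
    """
    candidates = list(itertools.product((1, -1), repeat=dim))
    if fa_method.startswith("se3-"):
        # det(U @ diag(s)) = det(U) * prod(s): keep proper rotations only.
        candidates = [s for s in candidates if base_det * math.prod(s) > 0]
        fa_method = fa_method[len("se3-"):]
    if fa_method == "all":
        return candidates
    if fa_method == "det":
        # Placeholder deterministic rule: always the first candidate.
        return candidates[:1]
    if fa_method == "stochastic":
        # One frame drawn at random, e.g. re-drawn at every epoch.
        return [rng.choice(candidates)]
    raise ValueError(f"unsupported fa_method: {fa_method!r}")


for method in ("all", "se3-all", "stochastic", "se3-det"):
    print(method, len(select_sign_flips(method)))
```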

faenet.__version__

faenet.frame_averaging_2D(pos, cell=None, fa_method='stochastic', check=False)

Computes new positions for the graph atoms, based on frame averaging built on PCA.

Parameters:
  • pos (tensor) – positions of atoms in the graph

  • cell (tensor) – unit cell of the graph. None if no pbc.

  • fa_method (str) – FA method used (stochastic, det, all, se3)

  • check (bool) – check if constraints are satisfied. Default: False.

Returns:

  • updated atom positions (tensor)

  • updated unit cell (tensor)

  • the rotation matrix used (PCA) (tensor)

Return type:

tensor
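
A minimal NumPy sketch of the “all” variant of 2D frame averaging, assuming PCA is applied to the x and y coordinates only, so the z axis is left unchanged, and that each frame is stored as a 3 × 3 rotation applied as pos @ rot. The helper name frame_averaging_2d_sketch is hypothetical, and the cell argument is not handled.

```python
import itertools

import numpy as np


def frame_averaging_2d_sketch(pos):
    """Return all four 2D PCA frames as (fa_pos, rotation) pairs.

    Assumption: only x and y are rotated; the z coordinate is left unchanged.
    """
    pos = np.asarray(pos, dtype=float)
    centred = pos.copy()
    centred[:, :2] -= pos[:, :2].mean(axis=0)  # centre in the xy-plane only
    xy = centred[:, :2]
    eigval, eigvec = np.linalg.eigh(xy.T @ xy)  # 2x2 covariance (unnormalised)
    eigvec = eigvec[:, np.argsort(eigval)[::-1]]  # principal axis first
    frames = []
    for signs in itertools.product((1.0, -1.0), repeat=2):
        rot = np.eye(3)
        rot[:2, :2] = eigvec * np.array(signs)  # flip the sign of each column
        frames.append((centred @ rot, rot))
    return frames


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))
frames = frame_averaging_2d_sketch(pos)
print(len(frames))
```

In every frame the xy covariance is diagonal and interatomic distances are preserved, since each rot is orthogonal.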

faenet.frame_averaging_3D(pos, cell=None, fa_method='stochastic', check=False)

Computes new positions for the graph atoms using PCA.

Parameters:
  • pos (tensor) – positions of atoms in the graph

  • cell (tensor) – unit cell of the graph. None if no pbc.

  • fa_method (str) – FA method used (stochastic, det, all, se3-all, se3-det, se3-stochastic)

  • check (bool) – check if constraints are satisfied. Default: False.

Returns:

  • updated atom positions (tensor)

  • updated unit cell (tensor)

  • the rotation matrix used (PCA) (tensor)

Return type:

tensor
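
A minimal NumPy sketch of the “all” and “se3-all” options, assuming positions are centred before PCA and that each frame is the matrix of principal axes with one sign per axis, applied as pos @ rot. The helper name frame_averaging_3d_sketch is hypothetical, and the cell argument is ignored. The usage lines check the property frame averaging relies on: rotating the input leaves the set of framed positions unchanged.

```python
import itertools

import numpy as np


def frame_averaging_3d_sketch(pos, se3=False):
    """Return PCA frames as (fa_pos, rotation) pairs ('all' or 'se3-all')."""
    pos = np.asarray(pos, dtype=float)
    centred = pos - pos.mean(axis=0)
    eigval, eigvec = np.linalg.eigh(centred.T @ centred)
    eigvec = eigvec[:, np.argsort(eigval)[::-1]]  # axes by decreasing variance
    frames = []
    for signs in itertools.product((1.0, -1.0), repeat=3):
        rot = eigvec * np.array(signs)  # flip the sign of each principal axis
        if se3 and np.linalg.det(rot) < 0:
            continue  # keep proper rotations only (det = +1)
        frames.append((centred @ rot, rot))
    return frames


rng = np.random.default_rng(0)
pos = rng.normal(size=(8, 3))
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
if np.linalg.det(q) < 0:
    q[:, 0] *= -1  # make it a proper rotation (det = +1)
frames = frame_averaging_3d_sketch(pos)
rotated = frame_averaging_3d_sketch(pos @ q.T)
# Each frame of the rotated input coincides with some frame of the original.
print(all(any(np.allclose(a, b) for b, _ in frames) for a, _ in rotated))
```

With a reflection (det(q) = -1) the se3=True frame sets no longer coincide, which is why the “se3-” variants target SE(3) rather than E(3).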

faenet.model_forward(batch, model, frame_averaging, mode='train', crystal_task=True)

Perform a forward pass of the model when frame averaging is applied.

Parameters:
  • batch (data.Batch) – batch of graphs with attributes:
      - original atom positions (pos)
      - batch indices, i.e. the graph in the batch to which each atom belongs (batch)
      - frame averaged positions, cell and rotation matrices

  • model – model instance

  • frame_averaging (str) – symmetry preserving method (already) applied (“2D”, “3D”, “DA”, “”)

  • mode (str, optional) – model mode. Defaults to “train”. (“train”, “inference”)

  • crystal_task (bool, optional) – Whether crystals (as opposed to molecules) are considered. If so, the unit cell is also frame averaged and is expected as an attribute. (default: True)

Returns:

model predictions for “energy” and “forces”.

Return type:

(dict)
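
The averaging step performed when several frames are used can be sketched as follows, assuming frames are stored as (fa_pos, rot) pairs with fa_pos = pos @ rot: energies are invariant and are averaged as is, while force vectors are expressed in each frame's axes and are rotated back with rot.T before averaging. average_over_frames, rot_z and toy_model are hypothetical names; the toy model is already rotation-equivariant, so the result can be checked against the exact answer.

```python
import numpy as np


def average_over_frames(model, frames):
    """Combine per-frame predictions into frame-averaged ones.

    `frames` holds (fa_pos, rot) pairs with fa_pos = pos @ rot (assumed).
    """
    energies, forces = [], []
    for fa_pos, rot in frames:
        energy, frame_forces = model(fa_pos)
        energies.append(energy)  # invariant: no back-rotation needed
        forces.append(frame_forces @ rot.T)  # back to the original axes
    return {"energy": float(np.mean(energies)), "forces": np.mean(forces, axis=0)}


def rot_z(theta):
    """Rotation by `theta` radians about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def toy_model(p):
    # Toy harmonic "model": energy sum of |x|^2, forces -grad = -2x.
    return float(np.sum(p**2)), -2.0 * p


rng = np.random.default_rng(1)
pos = rng.normal(size=(5, 3))
frames = [(pos @ rot_z(t), rot_z(t)) for t in (0.0, 0.7, 2.1)]
preds = average_over_frames(toy_model, frames)
print(preds["energy"], preds["forces"].shape)
```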