faenet.model

FAENet: Frame Averaging Equivariant graph neural Network. Simple, scalable and expressive model for property prediction on 3D atomic systems.
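Frame averaging makes an unconstrained model equivariant (or invariant) by evaluating it on several canonical "frames" of the input and averaging the outputs. The NumPy sketch below illustrates the idea for 3D positions, assuming PCA-based frames: positions are centered, the eigenvectors of their covariance matrix define the axes, and every sign choice of those eigenvectors gives one frame (2^3 = 8 frames). It is an illustration under these assumptions, not FAENet's own implementation; `compute_frames`, `frame_average` and `toy_model` are hypothetical names.

```python
import itertools

import numpy as np


def compute_frames(pos):
    """Project centered positions onto every sign choice of the PCA axes.

    pos: (num_atoms, 3) array of atomic positions.
    Returns a list of 8 arrays of shape (num_atoms, 3), one per frame.
    """
    centered = pos - pos.mean(axis=0)        # removes translations
    cov = centered.T @ centered              # (3, 3) unnormalised covariance
    # eigh returns orthonormal eigenvectors as columns, eigenvalues ascending
    _, eigvecs = np.linalg.eigh(cov)
    frames = []
    for signs in itertools.product([1.0, -1.0], repeat=3):
        # Eigenvectors are only defined up to sign: flip columns per sign choice
        frames.append(centered @ (eigvecs * np.array(signs)))
    return frames


def frame_average(model, pos):
    """Average a non-invariant scalar model over all frames of pos."""
    return float(np.mean([model(p) for p in compute_frames(pos)]))


def toy_model(p):
    # Deliberately depends on the coordinate axes, so it is not invariant by itself
    return p[:, 0].max() + 0.5 * p[:, 1].max() ** 2


rng = np.random.default_rng(0)
pos = rng.normal(size=(6, 3))
rotation, _ = np.linalg.qr(rng.normal(size=(3, 3)))    # random orthogonal matrix
moved = pos @ rotation.T + np.array([1.0, -2.0, 0.5])  # rotate, then translate
print(np.isclose(frame_average(toy_model, pos), frame_average(toy_model, moved)))  # True
```

Because the set of eight projected frames is identical for the original and the rotated, translated system, the averaged output agrees up to floating-point error. This relies on the covariance eigenvalues being distinct; degenerate cases need extra handling.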

Module Contents

Classes

EmbeddingBlock

Initialise atom and edge representations.

FAENet

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.

InteractionBlock

Updates atom representations through custom message passing.

OutputBlock

Compute task-specific predictions from final atom representations.

class faenet.model.EmbeddingBlock(num_gaussians, num_filters, hidden_channels, tag_hidden_channels, pg_hidden_channels, phys_hidden_channels, phys_embeds, act, second_layer_MLP)

Bases: torch.nn.Module

Initialise atom and edge representations.

forward(z, rel_pos, edge_attr, tag=None, subnodes=None)

Forward pass of the Embedding block. Called in FAENet to generate initial atom and edge representations.

Parameters:
  • z (tensor) – atomic numbers. (num_atoms, )

  • rel_pos (tensor) – relative atomic positions. (num_edges, 3)

  • edge_attr (tensor) – RBF of pairwise distances. (num_edges, num_gaussians)

  • tag (tensor, optional) – atom tags specific to OCP (Open Catalyst Project) data. Defaults to None.

Returns:

atom embeddings, edge embeddings

Return type:

(tensor, tensor)
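The edge_attr input is described as an RBF of pairwise distances with shape (num_edges, num_gaussians). A minimal NumPy sketch of one common choice, SchNet-style Gaussian smearing with evenly spaced centers \(\mu\) between 0 and the cutoff, is shown below. That FAENet uses exactly this basis and width is an assumption, and `gaussian_rbf` is a hypothetical helper.

```python
import numpy as np


def gaussian_rbf(distances, cutoff=6.0, num_gaussians=50):
    """Expand each distance into num_gaussians Gaussian features.

    distances: (num_edges,) array of pairwise distances.
    Returns a (num_edges, num_gaussians) array.
    """
    # Evenly spaced Gaussian centers (the mu values) between 0 and the cutoff
    mu = np.linspace(0.0, cutoff, num_gaussians)
    # Width tied to the spacing between neighbouring centers
    coeff = -0.5 / (mu[1] - mu[0]) ** 2
    diff = distances[:, None] - mu[None, :]
    return np.exp(coeff * diff ** 2)


d = np.array([0.0, 1.5, 6.0])
edge_attr = gaussian_rbf(d)
print(edge_attr.shape)  # (3, 50)
```

Each distance activates mostly the Gaussians whose centers lie close to it, which gives the network a smooth, fixed-size encoding of a scalar distance.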

reset_parameters()
class faenet.model.FAENet(cutoff=6.0, preprocess='pbc_preprocess', act='swish', max_num_neighbors=40, hidden_channels=128, tag_hidden_channels=32, pg_hidden_channels=32, phys_embeds=True, phys_hidden_channels=0, num_interactions=4, num_gaussians=50, num_filters=128, second_layer_MLP=True, skip_co='concat', mp_type='updownscale_base', graph_norm=True, complex_mp=False, energy_head=None, out_dim=1, pred_as_dict=True, regress_forces=None, force_decoder_type='mlp', force_decoder_model_config={'hidden_channels': 128})

Bases: faenet.base_model.BaseModel

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.

Parameters:
  • cutoff (float) – Cutoff distance for interatomic interactions. (default: 6.0)

  • preprocess (str or callable) – Pre-processing function for the data. This function should accept a data object as input and return a tuple containing the following: atomic numbers, batch indices, final adjacency (edge index), relative positions, pairwise distances. Examples of valid preprocessing functions include pbc_preprocess, base_preprocess, or custom functions. (default: “pbc_preprocess”)

  • act (str) – Activation function. (default: “swish”)

  • max_num_neighbors (int) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 40)

  • hidden_channels (int) – Hidden embedding size. (default: 128)

  • tag_hidden_channels (int) – Hidden tag embedding size. (default: 32)

  • pg_hidden_channels (int) – Hidden period and group embedding size. (default: 32)

  • phys_embeds (bool) – Whether to include fixed physics-aware embeddings. (default: True)

  • phys_hidden_channels (int) – Hidden size of learnable physics-aware embeddings. (default: 0)

  • num_interactions (int) – The number of interaction (i.e. message passing) blocks. (default: 4)

  • num_gaussians (int) – The number of Gaussians \(\mu\) used to encode interatomic distances. (default: 50)

  • num_filters (int) – The size of convolutional filters. (default: 128)

  • second_layer_MLP (bool) – Whether to use a 2-layer MLP at the end of the Embedding block. (default: True)

  • skip_co (str) – Adds a skip connection between each interaction block and the energy head. (False, “add”, “concat”, “concat_atom”; default: “concat”)

  • mp_type (str) – Specifies the message passing type of the interaction block. (“base”, “updownscale_base”, “updownscale”, “updown_local_env”, “simple”; default: “updownscale_base”)

  • graph_norm (bool) – Whether to apply batch norm after every linear layer. (default: True)

  • complex_mp (bool) – (default: False)

  • energy_head (str) – Method to compute the energy prediction from atom representations. (None, “weighted-av-initial-embeds”, “weighted-av-final-embeds”; default: None)

  • out_dim (int) – Size of the output tensor for graph-level predicted properties (“energy”). Setting it above 1 allows predicting several properties at the same time. (default: 1)

  • pred_as_dict (bool) – Set to False to return a (property) prediction tensor. By default, predictions are returned as a dictionary with several keys (e.g. energy, forces). (default: True)

  • regress_forces (str) – Whether and how to predict forces. (None or “”, “direct”, “direct_with_gradient_target”; default: None)

  • force_decoder_type (str) – Specifies the type of force decoder. (“simple”, “mlp”, “res”, “res_updown”; default: “mlp”)

  • force_decoder_model_config (dict) – Contains information about the force decoder architecture (e.g. number of layers, hidden size). (default: {'hidden_channels': 128})
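The preprocess parameter expects a function that takes a data object and returns atomic numbers, batch indices, the adjacency (edge index), relative positions and pairwise distances. The NumPy sketch below is a hypothetical non-periodic preprocessing function that builds a radius graph from cutoff and max_num_neighbors. The `SimpleData` container, its attribute names (`atomic_numbers`, `pos`, `batch`), the edge direction (neighbor as source, center atom as target) and the sign of the relative positions are all assumptions for illustration; this is not the library's base_preprocess.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class SimpleData:
    """Hypothetical stand-in for a batched graph data object."""

    atomic_numbers: np.ndarray  # (num_atoms,)
    pos: np.ndarray             # (num_atoms, 3)
    batch: np.ndarray           # (num_atoms,) graph index of each atom


def radius_preprocess(data, cutoff=6.0, max_num_neighbors=40):
    """Non-periodic radius graph, returned in the tuple order preprocess expects.

    Returns (z, batch, edge_index, rel_pos, edge_weight) with shapes
    (num_atoms,), (num_atoms,), (2, num_edges), (num_edges, 3), (num_edges,).
    """
    src, dst = [], []
    for i in range(len(data.pos)):
        # Candidate neighbors: the other atoms of the same graph
        cand = np.nonzero(data.batch == data.batch[i])[0]
        cand = cand[cand != i]
        dist = np.linalg.norm(data.pos[cand] - data.pos[i], axis=1)
        within = dist <= cutoff
        cand, dist = cand[within], dist[within]
        # Keep at most max_num_neighbors of the closest neighbors
        cand = cand[np.argsort(dist, kind="stable")][:max_num_neighbors]
        src.extend(cand.tolist())        # edge source j (neighbor)
        dst.extend([i] * len(cand))      # edge target i (center atom)
    edge_index = np.array([src, dst], dtype=np.int64).reshape(2, -1)
    rel_pos = data.pos[edge_index[0]] - data.pos[edge_index[1]]
    edge_weight = np.linalg.norm(rel_pos, axis=1)
    return data.atomic_numbers, data.batch, edge_index, rel_pos, edge_weight


data = SimpleData(
    atomic_numbers=np.array([8, 1, 1, 6]),
    pos=np.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0], [10.0, 0.0, 0.0]]),
    batch=np.array([0, 0, 0, 1]),
)
z, batch, edge_index, rel_pos, edge_weight = radius_preprocess(data, cutoff=2.0)
print(edge_index.shape)  # (2, 6)
```

Here the three atoms of the first graph are all within 2.0 of each other, giving 6 directed edges, while the lone atom of the second graph gets none. Atoms from different graphs are never connected because candidates are filtered by batch index.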

energy_forward(data, preproc=True)

Predicts any graph-level property (e.g. energy) for 3D atomic systems.

Parameters:
  • data (data.Batch) – Batch of graph data objects.

  • preproc (bool) – Whether to apply (any given) preprocessing to the graph. Defaults to True.

Returns:

predicted properties for each graph (key: “energy”) and final atomic representations (key: “hidden_state”)

Return type:

(dict)
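energy_forward returns one prediction per graph under “energy” together with the final atom representations under “hidden_state”. The NumPy sketch below shows how per-atom outputs can be pooled into graph-level predictions using the batch index vector and packaged in that dictionary format. The linear readout and plain sum pooling are illustrative assumptions (the actual pooling depends on energy_head), and `pool_energy` is a hypothetical name.

```python
import numpy as np


def pool_energy(hidden_state, batch, weight, num_graphs):
    """Sum a linear per-atom readout over the atoms of each graph.

    hidden_state: (num_atoms, hidden_channels) final atom representations.
    batch: (num_atoms,) graph index of each atom.
    weight: (hidden_channels, out_dim) readout matrix.
    """
    per_atom = hidden_state @ weight                  # (num_atoms, out_dim)
    energy = np.zeros((num_graphs, weight.shape[1]))
    np.add.at(energy, batch, per_atom)                # scatter-sum per graph
    return {"energy": energy, "hidden_state": hidden_state}


h = np.ones((5, 4))
batch = np.array([0, 0, 0, 1, 1])
preds = pool_energy(h, batch, weight=np.full((4, 1), 0.5), num_graphs=2)
print(preds["energy"].ravel())  # [6. 4.]
```

np.add.at is used instead of fancy-index assignment because it accumulates correctly when several atoms share the same graph index.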

forces_forward(preds)

Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property.

Parameters:

preds (dict) – dictionary with final atomic representations (key: “hidden_state”) and predicted properties (e.g. energy) for each graph

Returns:

additional predicted properties, at an atom-level (e.g. forces)

Return type:

(dict)
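forces_forward turns the atom representations stored under “hidden_state” into atom-level outputs such as forces. A minimal NumPy sketch of a “direct” MLP force decoder is below: a two-layer MLP applied independently to each atom, producing a 3-component vector per atom. The layer sizes (chosen to match the default force_decoder_model_config hidden size of 128), the swish activation and the random weights are assumptions; `mlp_force_decoder` is hypothetical and does not reproduce the library's “mlp” decoder exactly.

```python
import numpy as np


def swish(x):
    # swish / SiLU activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))


def mlp_force_decoder(preds, w1, b1, w2, b2):
    """Return a copy of preds with per-atom force predictions added.

    preds["hidden_state"]: (num_atoms, hidden_channels)
    w1: (hidden_channels, decoder_hidden), w2: (decoder_hidden, 3)
    """
    h = preds["hidden_state"]
    forces = swish(h @ w1 + b1) @ w2 + b2    # (num_atoms, 3), applied atom-wise
    return {**preds, "forces": forces}


rng = np.random.default_rng(0)
hidden_channels, decoder_hidden = 128, 128
preds = {"energy": np.zeros((2, 1)), "hidden_state": rng.normal(size=(5, hidden_channels))}
out = mlp_force_decoder(
    preds,
    w1=rng.normal(size=(hidden_channels, decoder_hidden)) * 0.1,
    b1=np.zeros(decoder_hidden),
    w2=rng.normal(size=(decoder_hidden, 3)) * 0.1,
    b2=np.zeros(3),
)
print(out["forces"].shape)  # (5, 3)
```

Since the decoder only reads “hidden_state”, the same pattern can output any atom-level property by changing the size of the last layer.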

class faenet.model.InteractionBlock(hidden_channels, num_filters, act, mp_type, complex_mp, graph_norm)

Bases: torch_geometric.nn.MessagePassing

Updates atom representations through custom message passing.

forward(h, edge_index, e)

Forward pass of the Interaction block. Called in FAENet forward pass to update atom representations.

Parameters:
  • h (tensor) – atom embeddings. (num_atoms, hidden_channels)

  • edge_index (tensor) – adjacency in edge-index (COO) format. (2, num_edges)

  • e (tensor) – edge embeddings. (num_edges, num_filters)

Returns:

updated atom embeddings

Return type:

(tensor)
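As a rough illustration of message passing with these shapes, the NumPy sketch below implements a SchNet-style continuous-filter update: each message is the source atom's embedding multiplied element-wise by its edge embedding (the message(x_j, W) signature suggests filter-weighted messages of this kind), messages are summed at the target atom, and a residual update is applied. It assumes hidden_channels equals num_filters, follows the source-to-target convention of torch_geometric, and does not reproduce any specific mp_type; `interaction_step` is hypothetical.

```python
import numpy as np


def interaction_step(h, edge_index, e):
    """One simplified message-passing update.

    h: (num_atoms, hidden_channels) atom embeddings.
    edge_index: (2, num_edges) with rows (source j, target i).
    e: (num_edges, num_filters) edge embeddings; here num_filters == hidden_channels.
    """
    src, dst = edge_index
    messages = h[src] * e                # filter-weighted message from j to i
    agg = np.zeros_like(h)
    np.add.at(agg, dst, messages)        # sum incoming messages per target atom
    return h + agg                       # residual update


h = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
edge_index = np.array([[1, 2, 0], [0, 0, 1]])
e = np.ones((3, 2))
updated = interaction_step(h, edge_index, e)
print(updated)
```

Atom 0 receives messages from atoms 1 and 2, atom 1 from atom 0, and atom 2 from nobody, so only the first two rows change.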

message(x_j, W, local_env=None)
reset_parameters()
class faenet.model.OutputBlock(energy_head, hidden_channels, act, out_dim=1)

Bases: torch.nn.Module

Compute task-specific predictions from final atom representations.

forward(h, edge_index, edge_weight, batch, alpha)

Forward pass of the Output block. Called in FAENet to make predictions from final atom representations.

Parameters:
  • h (tensor) – atom representations. (num_atoms, hidden_channels)

  • edge_index (tensor) – adjacency in edge-index (COO) format. (2, num_edges)

  • edge_weight (tensor) – edge weights. (num_edges, )

  • batch (tensor) – batch indices. (num_atoms, )

  • alpha (tensor) – atom attention weights for late energy head. (num_atoms, )

Returns:

graph-level representation (e.g. energy prediction)

Return type:

(tensor)
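For a late energy head, alpha holds one attention weight per atom. One plausible reading of a “weighted average” readout, sketched below in NumPy, is a per-graph average of per-atom outputs weighted by alpha and normalised within each graph. The normalisation and the linear per-atom readout are assumptions, and `weighted_readout` is hypothetical.

```python
import numpy as np


def weighted_readout(h, batch, alpha, weight, num_graphs):
    """Alpha-weighted average of per-atom predictions within each graph.

    h: (num_atoms, hidden_channels), batch: (num_atoms,), alpha: (num_atoms,)
    weight: (hidden_channels, out_dim) per-atom linear readout.
    """
    per_atom = h @ weight                                 # (num_atoms, out_dim)
    num = np.zeros((num_graphs, weight.shape[1]))
    np.add.at(num, batch, alpha[:, None] * per_atom)      # weighted sums per graph
    den = np.zeros(num_graphs)
    np.add.at(den, batch, alpha)                          # total weight per graph
    return num / den[:, None]


h = np.array([[1.0], [3.0], [10.0]])
batch = np.array([0, 0, 1])
alpha = np.array([1.0, 3.0, 2.0])
energy = weighted_readout(h, batch, alpha, weight=np.ones((1, 1)), num_graphs=2)
print(energy.ravel())
```

For the first graph this gives (1·1 + 3·3) / (1 + 3) = 2.5, so atoms with larger alpha dominate the graph-level prediction.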

reset_parameters()