faenet.model

Code of the Scalable Frame Averaging (Rotation Invariant) GNN

Module Contents

Classes

EmbeddingBlock

Initialise atom and edge representations

FAENet

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.

InteractionBlock

Updates atom representations through message passing

OutputBlock

Compute task-specific predictions from final atom representations

class faenet.model.EmbeddingBlock(num_gaussians, num_filters, hidden_channels, tag_hidden_channels, pg_hidden_channels, phys_hidden_channels, phys_embeds, act, second_layer_MLP)

Bases: torch.nn.Module

Initialise atom and edge representations

forward(z, rel_pos, edge_attr, tag=None, subnodes=None)
reset_parameters()
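The edge representations depend on num_gaussians, which sets how many Gaussian basis functions encode each interatomic distance. A minimal sketch of such an expansion follows. It assumes centers evenly spaced on [0, cutoff] and a width tied to the center spacing, and it uses plain Python instead of tensors. The function name gaussian_expansion is hypothetical and is not part of faenet.

```python
import math


def gaussian_expansion(distance, cutoff=6.0, num_gaussians=50):
    """Expand one scalar distance on a set of Gaussian basis functions."""
    # Assumption: centers mu are evenly spaced on [0, cutoff].
    step = cutoff / (num_gaussians - 1)
    centers = [i * step for i in range(num_gaussians)]
    # Assumption: the Gaussian width equals the spacing between centers.
    coeff = -0.5 / step ** 2
    return [math.exp(coeff * (distance - mu) ** 2) for mu in centers]


# One feature vector of length num_gaussians per edge distance.
features = gaussian_expansion(1.5)
```

The largest entry sits at the center closest to the input distance, so nearby distances produce similar, smoothly varying feature vectors.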
class faenet.model.FAENet(cutoff=6.0, act='swish', preprocess='pbc_preprocess', complex_mp=False, max_num_neighbors=40, num_gaussians=50, num_filters=128, hidden_channels=128, tag_hidden_channels=32, pg_hidden_channels=32, phys_hidden_channels=0, phys_embeds=True, num_interactions=4, mp_type='updownscale_base', graph_norm=True, second_layer_MLP=True, skip_co='concat', energy_head=None, regress_forces=None, force_decoder_type='mlp', force_decoder_model_config={'hidden_channels': 128})

Bases: faenet.base_model.BaseModel

Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.

Parameters:
  • cutoff (float) – Cutoff distance for interatomic interactions. (default: 6.0)

  • preprocess (callable) – Pre-processing function for the data. This function should accept a data object as input and return a tuple containing the following: atomic numbers, batch indices, final adjacency, relative positions, pairwise distances. Examples of valid preprocessing functions include pbc_preprocess, base_preprocess, or custom functions.

  • act (str) – Activation function (default: swish)

  • max_num_neighbors (int) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 40)

  • hidden_channels (int) – Hidden embedding size. (default: 128)

  • tag_hidden_channels (int) – Hidden tag embedding size. (default: 32)

  • pg_hidden_channels (int) – Hidden period and group embedding size. (default: 32)

  • phys_embeds (bool) – Whether to include fixed physics-aware embeddings. (default: True)

  • phys_hidden_channels (int) – Hidden size of learnable physics-aware embeddings. (default: 0)

  • num_interactions (int) – The number of interaction (i.e. message passing) blocks. (default: 4)

  • num_gaussians (int) – The number of Gaussians \(\mu\) used to encode distance information. (default: 50)

  • num_filters (int) – The size of convolutional filters. (default: 128)

  • second_layer_MLP (bool) – Use a 2-layer MLP at the end of the Embedding block. (default: True)

  • skip_co (str) – Add a skip connection between each interaction block and the energy head. (False, “add”, “concat”, “concat_atom”) (default: “concat”)

  • mp_type (str) – Specifies the message passing type of the interaction block. (“base”, “updownscale_base”, “updownscale”, “updown_local_env”, “simple”) (default: “updownscale_base”)

  • graph_norm (bool) – Whether to apply batch norm after every linear layer. (default: True)

  • complex_mp (bool) – (default: False)

  • energy_head (str) – Method to compute energy prediction from atom representations. (None, “weighted-av-initial-embeds”, “weighted-av-final-embeds”)

  • regress_forces (str) – Specifies whether forces are predicted and, if so, how. (None or “”, “direct”, “direct_with_gradient_target”)

  • force_decoder_type (str) – Specifies the type of force decoder (“simple”, “mlp”, “res”, “res_updown”)

  • force_decoder_model_config (dict) – Contains information about the force decoder architecture (e.g. number of layers, hidden size).
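Several of the parameters above accept only a fixed set of values. The sketch below collects a subset of the keyword arguments with the defaults from the class signature and checks the enumerated options against the documented choices. The helper check_options and the ALLOWED table are hypothetical and are not part of faenet.

```python
# Keyword arguments mirroring defaults from the FAENet signature (a subset).
faenet_kwargs = {
    "cutoff": 6.0,
    "act": "swish",
    "preprocess": "pbc_preprocess",
    "max_num_neighbors": 40,
    "num_gaussians": 50,
    "num_filters": 128,
    "hidden_channels": 128,
    "num_interactions": 4,
    "mp_type": "updownscale_base",
    "skip_co": "concat",
    "energy_head": None,
    "regress_forces": None,
    "force_decoder_type": "mlp",
    "force_decoder_model_config": {"hidden_channels": 128},
}

# Allowed values, copied from the parameter descriptions.
ALLOWED = {
    "skip_co": {False, "add", "concat", "concat_atom"},
    "mp_type": {"base", "updownscale_base", "updownscale",
                "updown_local_env", "simple"},
    "energy_head": {None, "weighted-av-initial-embeds",
                    "weighted-av-final-embeds"},
    "regress_forces": {None, "", "direct", "direct_with_gradient_target"},
    "force_decoder_type": {"simple", "mlp", "res", "res_updown"},
}


def check_options(kwargs):
    """Hypothetical helper: reject option values that are not documented."""
    for name, allowed in ALLOWED.items():
        if name in kwargs and kwargs[name] not in allowed:
            choices = sorted(repr(value) for value in allowed)
            raise ValueError(f"{name}={kwargs[name]!r} is not one of {choices}")
    return kwargs


validated = check_options(faenet_kwargs)
```

Assuming faenet is installed, the validated dictionary could then be passed on as FAENet(**validated). Running the check first surfaces a misspelled option name before any model is built.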

energy_forward(data, preproc=True)

Predicts any graph-level properties (e.g. energy) for 3D atomic systems.

Parameters:
  • data (data.Batch) – Batch of graph data points.

  • preproc (bool) – Whether to preprocess the graph. (default: True)

Returns:

predicted properties for each graph (e.g. energy)

Return type:

dict
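When preproc is True, the model calls its preprocess function, which is documented as returning atomic numbers, batch indices, final adjacency, relative positions and pairwise distances. The sketch below illustrates that five-item contract with a custom function in plain Python. The field names on the data object (atomic_numbers, batch, pos, edge_index), the function name simple_preprocess and the direction of the relative position vectors are all assumptions made for illustration.

```python
import math
from types import SimpleNamespace


def simple_preprocess(data):
    """Hypothetical preprocessing function returning the five documented items."""
    sources, targets = data.edge_index
    # Assumption: relative position is pos[source] - pos[target] for each edge.
    rel_pos = [
        tuple(data.pos[i][k] - data.pos[j][k] for k in range(3))
        for i, j in zip(sources, targets)
    ]
    # Pairwise distance is the Euclidean norm of each relative position.
    distances = [math.sqrt(sum(c * c for c in vec)) for vec in rel_pos]
    return data.atomic_numbers, data.batch, data.edge_index, rel_pos, distances


# Toy two-atom system (H and O) connected in both directions.
toy = SimpleNamespace(
    atomic_numbers=[1, 8],
    batch=[0, 0],
    pos=[(0.0, 0.0, 0.0), (0.0, 0.0, 0.96)],
    edge_index=([0, 1], [1, 0]),
)
z, batch, adjacency, rel_pos, distances = simple_preprocess(toy)
```

A real preprocessing function would operate on tensors and, for periodic systems, account for periodic boundary conditions, which this sketch omits.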

forces_forward(preds)

Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property.

Parameters:

preds (dict) – Dictionary with predicted properties for each graph.

Returns:

additional predicted properties, at an atom-level (e.g. forces)

Return type:

dict

class faenet.model.InteractionBlock(hidden_channels, num_filters, act, mp_type, complex_mp, graph_norm)

Bases: torch_geometric.nn.MessagePassing

Updates atom representations through message passing

forward(h, edge_index, e)
message(x_j, W, local_env=None)
reset_parameters()
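To illustrate the message passing step, the sketch below sums filtered neighbor features into each target atom, in plain Python. It assumes sum aggregation, an edge_index given as a (sources, targets) pair, and a message formed by multiplying neighbor features elementwise with a per-edge filter W. The function name propagate is hypothetical, and the actual block may differ depending on mp_type.

```python
def propagate(h, edge_index, W):
    """Sum filtered neighbor features into each target atom."""
    sources, targets = edge_index
    out = [[0.0] * len(h[0]) for _ in h]
    for edge, (j, i) in enumerate(zip(sources, targets)):
        # Message: features of source atom j scaled elementwise by the edge filter.
        message = [x * w for x, w in zip(h[j], W[edge])]
        # Aggregation: accumulate the message on target atom i.
        out[i] = [acc + m for acc, m in zip(out[i], message)]
    return out


h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]      # one feature row per atom
edge_index = ([1, 2, 0], [0, 0, 1])           # (sources, targets)
W = [[1.0, 1.0], [0.5, 0.5], [2.0, 0.0]]      # one filter row per edge
updated = propagate(h, edge_index, W)
```

Atom 0 receives messages from atoms 1 and 2, atom 1 receives one from atom 0, and atom 2 has no incoming edge, so its aggregated row stays at zero.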
class faenet.model.OutputBlock(energy_head, hidden_channels, act)

Bases: torch.nn.Module

Compute task-specific predictions from final atom representations

forward(h, edge_index, edge_weight, batch, alpha)
reset_parameters()
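Graph-level predictions such as energy are obtained by combining per-atom outputs according to the batch index of each atom. The sketch below assumes sum pooling and an optional per-atom weight alpha, loosely mirroring the weighted-average energy heads. The function name pool_per_graph is hypothetical, and the exact pooling used by OutputBlock is not specified here.

```python
def pool_per_graph(atom_values, batch, alpha=None):
    """Sum (optionally weighted) per-atom scalars into one value per graph."""
    num_graphs = max(batch) + 1
    totals = [0.0] * num_graphs
    for idx, (value, graph) in enumerate(zip(atom_values, batch)):
        # Assumption: alpha, when given, holds one weight per atom.
        weight = 1.0 if alpha is None else alpha[idx]
        totals[graph] += weight * value
    return totals


atom_values = [0.5, -1.0, 2.0, 0.25]   # per-atom scalar outputs
batch = [0, 0, 1, 1]                   # graph index of each atom
energies = pool_per_graph(atom_values, batch)
weighted = pool_per_graph(atom_values, batch, alpha=[1.0, 0.0, 0.5, 2.0])
```

Each entry of the result corresponds to one graph in the batch, which is why the batch vector has to be passed alongside the atom representations.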