faenet.model
FAENet: Frame Averaging Equivariant graph neural Network. A simple, scalable and expressive model for property prediction on 3D atomic systems.
Module Contents
Classes
EmbeddingBlock – Initialise atom and edge representations.
FAENet – Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.
InteractionBlock – Updates atom representations through custom message passing.
OutputBlock – Compute task-specific predictions from final atom representations.
- class faenet.model.EmbeddingBlock(num_gaussians, num_filters, hidden_channels, tag_hidden_channels, pg_hidden_channels, phys_hidden_channels, phys_embeds, act, second_layer_MLP)[source]
Bases: torch.nn.Module
Initialise atom and edge representations.
- forward(z, rel_pos, edge_attr, tag=None, subnodes=None)[source]
Forward pass of the Embedding block. Called in FAENet to generate initial atom and edge representations.
- Parameters:
z (tensor) – atomic numbers. (num_atoms, )
rel_pos (tensor) – relative atomic positions. (num_edges, 3)
edge_attr (tensor) – RBF of pairwise distances. (num_edges, num_gaussians)
tag (tensor, optional) – atom information specific to OCP (Open Catalyst Project). Defaults to None.
- Returns:
atom embeddings, edge embeddings
- Return type:
(tensor, tensor)
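To make the input and output shapes above concrete, here is a minimal sketch of an embedding step of this kind. It is not the faenet implementation: the class name ToyEmbeddingBlock, the max_z and num_tags arguments, and the design (an atomic-number embedding concatenated with a tag embedding that is zero-padded when tag is None, plus a linear projection of [rel_pos, edge_attr] to num_filters edge features) are assumptions for illustration. Physics-aware and period/group embeddings are omitted.

```python
import torch
from torch import nn


class ToyEmbeddingBlock(nn.Module):
    """Hypothetical, simplified stand-in for EmbeddingBlock (not the faenet code)."""

    def __init__(self, num_gaussians, num_filters, hidden_channels,
                 tag_hidden_channels=32, max_z=100, num_tags=3):
        super().__init__()
        # Learnable embedding per atomic number, leaving room for the tag part.
        self.z_emb = nn.Embedding(max_z, hidden_channels - tag_hidden_channels)
        # Per-atom tag embedding (num_tags=3 is an assumed vocabulary size).
        self.tag_emb = nn.Embedding(num_tags, tag_hidden_channels)
        self.tag_hidden_channels = tag_hidden_channels
        # Edge features from relative positions (3 values) plus RBF features.
        self.edge_lin = nn.Linear(3 + num_gaussians, num_filters)
        self.act = nn.SiLU()  # SiLU is the "swish" activation

    def forward(self, z, rel_pos, edge_attr, tag=None):
        h = self.z_emb(z)
        if tag is None:
            # No tags: pad with zeros so atom embeddings keep hidden_channels width.
            tag_part = h.new_zeros(h.size(0), self.tag_hidden_channels)
        else:
            tag_part = self.tag_emb(tag)
        h = self.act(torch.cat([h, tag_part], dim=-1))            # (num_atoms, hidden_channels)
        e = self.act(self.edge_lin(torch.cat([rel_pos, edge_attr], dim=-1)))  # (num_edges, num_filters)
        return h, e


z = torch.tensor([1, 6, 8])        # 3 atoms
rel_pos = torch.randn(4, 3)        # 4 edges
edge_attr = torch.randn(4, 50)     # RBF features with num_gaussians=50
block = ToyEmbeddingBlock(num_gaussians=50, num_filters=128, hidden_channels=128)
h, e = block(z, rel_pos, edge_attr)
print(h.shape, e.shape)  # torch.Size([3, 128]) torch.Size([4, 128])
```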
- class faenet.model.FAENet(cutoff=6.0, preprocess='pbc_preprocess', act='swish', max_num_neighbors=40, hidden_channels=128, tag_hidden_channels=32, pg_hidden_channels=32, phys_embeds=True, phys_hidden_channels=0, num_interactions=4, num_gaussians=50, num_filters=128, second_layer_MLP=True, skip_co='concat', mp_type='updownscale_base', graph_norm=True, complex_mp=False, energy_head=None, out_dim=1, pred_as_dict=True, regress_forces=None, force_decoder_type='mlp', force_decoder_model_config={'hidden_channels': 128})[source]
Bases: faenet.base_model.BaseModel
Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.
- Parameters:
cutoff (float) – Cutoff distance for interatomic interactions. (default: 6.0)
preprocess (callable or str) – Pre-processing function for the data. This function should accept a data object as input and return a tuple containing the following: atomic numbers, batch indices, final adjacency, relative positions, pairwise distances. Examples of valid preprocessing functions include pbc_preprocess, base_preprocess, or custom functions. (default: “pbc_preprocess”)
act (str) – Activation function. (default: “swish”)
max_num_neighbors (int) – The maximum number of neighbors to collect for each node within the cutoff distance. (default: 40)
hidden_channels (int) – Hidden embedding size. (default: 128)
tag_hidden_channels (int) – Hidden tag embedding size. (default: 32)
pg_hidden_channels (int) – Hidden period and group embedding size. (default: 32)
phys_embeds (bool) – Whether to include fixed physics-aware embeddings. (default: True)
phys_hidden_channels (int) – Hidden size of learnable physics-aware embeddings. (default: 0)
num_interactions (int) – The number of interaction (i.e. message passing) blocks. (default: 4)
num_gaussians (int) – The number of Gaussians \(\mu\) used to encode distance information. (default: 50)
num_filters (int) – The size of convolutional filters. (default: 128)
second_layer_MLP (bool) – Whether to use a 2-layer MLP at the end of the Embedding block. (default: True)
skip_co (str) – Type of skip connection between each interaction block and the energy head (False, “add”, “concat”, “concat_atom”). (default: “concat”)
mp_type (str) – Specifies the message passing type of the interaction block (“base”, “updownscale_base”, “updownscale”, “updown_local_env”, “simple”). (default: “updownscale_base”)
graph_norm (bool) – Whether to apply batch norm after every linear layer. (default: True)
complex_mp (bool) – (default: False)
energy_head (str) – Method to compute the energy prediction from atom representations (None, “weighted-av-initial-embeds”, “weighted-av-final-embeds”). (default: None)
out_dim (int) – Size of the output tensor for graph-level predicted properties (“energy”). Allows predicting multiple properties at the same time. (default: 1)
pred_as_dict (bool) – Set to False to return a (property) prediction tensor. By default, predictions are returned as a dictionary with several keys (e.g. energy, forces). (default: True)
regress_forces (str) – Whether to predict forces and, if so, how (None or “”, “direct”, “direct_with_gradient_target”). (default: None)
force_decoder_type (str) – Specifies the type of force decoder (“simple”, “mlp”, “res”, “res_updown”). (default: “mlp”)
force_decoder_model_config (dict) – Contains information about the force decoder architecture (e.g. number of layers, hidden size). (default: {'hidden_channels': 128})
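The preprocess contract described above (atomic numbers, batch indices, adjacency, relative positions, pairwise distances) and the cutoff, max_num_neighbors and num_gaussians parameters can be illustrated with a small non-periodic sketch. toy_preprocess and gaussian_rbf are hypothetical helpers, not faenet functions. The first builds a radius graph inside each graph without periodic boundary conditions (so it resembles a base preprocessing step, not pbc_preprocess), keeps at most max_num_neighbors closest neighbours per atom, and assumes a source-to-target edge_index layout with rel_pos = pos[source] - pos[target]. The second applies SchNet-style Gaussian smearing, assumed here as the RBF that produces edge_attr.

```python
import torch


def toy_preprocess(z, pos, batch, cutoff=6.0, max_num_neighbors=40):
    """Hypothetical non-periodic preprocessing: radius graph, no PBC."""
    n = pos.size(0)
    dist = torch.cdist(pos, pos)                              # (n, n) pairwise distances
    same_graph = batch.unsqueeze(0) == batch.unsqueeze(1)     # never connect across graphs
    no_self = ~torch.eye(n, dtype=torch.bool)                 # drop self-loops
    mask = same_graph & no_self & (dist <= cutoff)
    # Row i holds candidate neighbours of target atom i; invalid pairs become +inf.
    masked = dist.masked_fill(~mask, float("inf"))
    k = min(max_num_neighbors, n)
    vals, idx = masked.topk(k, dim=1, largest=False)          # k closest per atom
    keep = torch.isfinite(vals)                               # discard padding (+inf) slots
    target = torch.arange(n).unsqueeze(1).expand_as(idx)[keep]
    source = idx[keep]
    edge_index = torch.stack([source, target])                # (2, num_edges)
    rel_pos = pos[source] - pos[target]                       # (num_edges, 3), sign convention assumed
    edge_weight = rel_pos.norm(dim=-1)                        # (num_edges,)
    return z, batch, edge_index, rel_pos, edge_weight


def gaussian_rbf(dist, cutoff=6.0, num_gaussians=50):
    """SchNet-style Gaussian smearing: (num_edges,) -> (num_edges, num_gaussians)."""
    mu = torch.linspace(0.0, cutoff, num_gaussians)           # evenly spaced centres mu_k
    coeff = -0.5 / (mu[1] - mu[0]) ** 2
    return torch.exp(coeff * (dist.unsqueeze(-1) - mu) ** 2)


z = torch.tensor([8, 1, 1, 6])
pos = torch.tensor([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0],
                    [-0.24, 0.93, 0.0], [0.0, 0.0, 0.0]])
batch = torch.tensor([0, 0, 0, 1])   # a 3-atom graph and a 1-atom graph
z, batch, edge_index, rel_pos, d = toy_preprocess(z, pos, batch)
edge_attr = gaussian_rbf(d)
print(edge_index.shape, edge_attr.shape)  # torch.Size([2, 6]) torch.Size([6, 50])
```

The isolated atom in the second graph gets no edges even though it sits at the same coordinates as an atom of the first graph, because the batch mask forbids cross-graph pairs.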
- energy_forward(data, preproc=True)[source]
Predicts any graph-level property (e.g. energy) for 3D atomic systems.
- Parameters:
data (data.Batch) – Batch of graph data objects.
preproc (bool) – Whether to apply (any given) preprocessing to the graph. Defaults to True.
- Returns:
predicted properties for each graph (key: “energy”) and final atomic representations (key: “hidden_state”)
- Return type:
(dict)
- forces_forward(preds)[source]
Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property.
- Parameters:
preds (dict) – dictionary with final atomic representations (hidden_state) and predicted properties (e.g. energy) for each graph
- Returns:
additional predicted properties, at an atom-level (e.g. forces)
- Return type:
(dict)
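energy_forward and forces_forward split prediction into a graph-level stage, which also returns the final atom representations, and an atom-level stage that decodes them. The sketch below mimics only that dictionary contract. ToyTwoStageModel is hypothetical: unlike the real energy_forward, it takes precomputed atom features and batch indices instead of a data batch, replaces the interaction blocks with a single linear layer, sums per-atom outputs into graph-level predictions, and uses a two-layer MLP (an assumed reading of force_decoder_type="mlp" with a hidden size of 128) that maps hidden_state to one 3-vector per atom.

```python
import torch
from torch import nn


class ToyTwoStageModel(nn.Module):
    """Hypothetical model mirroring the energy_forward / forces_forward split."""

    def __init__(self, hidden_channels=128, out_dim=1, force_hidden=128):
        super().__init__()
        # Stand-in for embedding + interaction blocks.
        self.atom_mlp = nn.Linear(hidden_channels, hidden_channels)
        self.energy_lin = nn.Linear(hidden_channels, out_dim)
        # MLP force decoder: hidden_state -> (fx, fy, fz) per atom.
        self.force_decoder = nn.Sequential(
            nn.Linear(hidden_channels, force_hidden),
            nn.SiLU(),
            nn.Linear(force_hidden, 3),
        )

    def energy_forward(self, h, batch):
        h = nn.functional.silu(self.atom_mlp(h))                # final atom representations
        per_atom = self.energy_lin(h)                           # (num_atoms, out_dim)
        num_graphs = int(batch.max()) + 1
        energy = per_atom.new_zeros(num_graphs, per_atom.size(1))
        energy.index_add_(0, batch, per_atom)                   # sum atom outputs per graph
        return {"energy": energy, "hidden_state": h}

    def forces_forward(self, preds):
        # Atom-level prediction decoded from the representations returned above.
        return {"forces": self.force_decoder(preds["hidden_state"])}


model = ToyTwoStageModel()
h = torch.randn(5, 128)                 # stand-in for initial atom embeddings
batch = torch.tensor([0, 0, 0, 1, 1])   # two graphs: 3 atoms and 2 atoms
preds = model.energy_forward(h, batch)
preds.update(model.forces_forward(preds))
print(preds["energy"].shape, preds["forces"].shape)  # torch.Size([2, 1]) torch.Size([5, 3])
```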
- class faenet.model.InteractionBlock(hidden_channels, num_filters, act, mp_type, complex_mp, graph_norm)[source]
Bases: torch_geometric.nn.MessagePassing
Updates atom representations through custom message passing.
- forward(h, edge_index, e)[source]
Forward pass of the Interaction block. Called in FAENet forward pass to update atom representations.
- Parameters:
h (tensor) – atom embeddings. (num_atoms, hidden_channels)
edge_index (tensor) – adjacency matrix. (2, num_edges)
e (tensor) – edge embeddings. (num_edges, num_filters)
- Returns:
updated atom embeddings
- Return type:
(tensor)
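A generic message passing update with the same inputs and output shape as InteractionBlock.forward can be sketched in plain PyTorch, using index_add_ instead of torch_geometric.nn.MessagePassing. ToyInteractionBlock is hypothetical and does not reproduce any of the mp_type variants: it projects atom features to num_filters, multiplies each source atom's features element-wise by a transformed edge embedding, sums the messages at each target atom (assuming edge_index rows are source, target), and adds the result back to h as a residual.

```python
import torch
from torch import nn


class ToyInteractionBlock(nn.Module):
    """Hypothetical simplified interaction block with sum aggregation."""

    def __init__(self, hidden_channels, num_filters):
        super().__init__()
        self.lin_in = nn.Linear(hidden_channels, num_filters)
        self.lin_edge = nn.Linear(num_filters, num_filters)
        self.lin_out = nn.Linear(num_filters, hidden_channels)
        self.act = nn.SiLU()

    def forward(self, h, edge_index, e):
        src, dst = edge_index                                 # assumed (source, target) rows
        x = self.lin_in(h)                                    # (num_atoms, num_filters)
        # Message j -> i: source features gated element-wise by the edge embedding.
        msg = x[src] * self.lin_edge(e)                       # (num_edges, num_filters)
        agg = torch.zeros_like(x).index_add_(0, dst, msg)     # sum messages per target atom
        # Residual update keeps the shape (num_atoms, hidden_channels).
        return h + self.act(self.lin_out(agg))


block = ToyInteractionBlock(hidden_channels=128, num_filters=64)
h = torch.randn(4, 128)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])   # edges 0->1, 1->2, 2->3
e = torch.randn(3, 64)
out = block(h, edge_index, e)
print(out.shape)  # torch.Size([4, 128])
```

Atom 0 receives no messages here, so its update reduces to the activation of the output layer's bias.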
- class faenet.model.OutputBlock(energy_head, hidden_channels, act, out_dim=1)[source]
Bases: torch.nn.Module
Compute task-specific predictions from final atom representations.
- forward(h, edge_index, edge_weight, batch, alpha)[source]
Forward pass of the Output block. Called in FAENet to make prediction from final atom representations.
- Parameters:
h (tensor) – atom representations. (num_atoms, hidden_channels)
edge_index (tensor) – adjacency matrix. (2, num_edges)
edge_weight (tensor) – edge weights. (num_edges, )
batch (tensor) – batch indices. (num_atoms, )
alpha (tensor) – atom attention weights for late energy head. (num_atoms, )
- Returns:
graph-level representation (e.g. energy prediction)
- Return type:
(tensor)
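The readout step can be sketched as a per-atom MLP followed by a per-graph sum over the batch indices. ToyOutputBlock is hypothetical: it ignores edge_index and edge_weight, and it treats alpha as a per-atom scaling factor applied before pooling. This is an assumption about how the attention weights enter the weighted-average energy heads, not the library's documented formula.

```python
import torch
from torch import nn


class ToyOutputBlock(nn.Module):
    """Hypothetical readout: per-atom MLP, optional alpha weighting, per-graph sum."""

    def __init__(self, hidden_channels, out_dim=1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels // 2),
            nn.SiLU(),
            nn.Linear(hidden_channels // 2, out_dim),
        )

    def forward(self, h, batch, alpha=None):
        out = self.mlp(h)                               # (num_atoms, out_dim)
        if alpha is not None:
            out = out * alpha.view(-1, 1)               # assumed: scale each atom's contribution
        num_graphs = int(batch.max()) + 1
        pooled = out.new_zeros(num_graphs, out.size(1))
        return pooled.index_add_(0, batch, out)         # (num_graphs, out_dim)


block = ToyOutputBlock(hidden_channels=128, out_dim=1)
h = torch.randn(4, 128)
batch = torch.tensor([0, 0, 1, 1])
plain = block(h, batch)
weighted = block(h, batch, alpha=torch.ones(4))
masked = block(h, batch, alpha=torch.tensor([1.0, 1.0, 0.0, 0.0]))
print(plain.shape)  # torch.Size([2, 1])
```

With alpha set to all ones the result matches the unweighted readout, and zeroing the weights of one graph's atoms zeroes that graph's prediction.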