faenet.model
Code for the Scalable Frame Averaging (Rotation Invariant) GNN
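The module title refers to frame averaging. This technique makes a model that is not rotation invariant behave as one: the model is evaluated in a few canonical frames derived from the input, and the results are averaged. The sketch below is a hypothetical, pure-Python illustration of the PCA-based construction. Frames come from the eigenvectors of the coordinate covariance, and every sign choice that yields a proper rotation is kept. The names jacobi_eigh, centre, det3, pca_frames and frame_average are illustrative only and are not part of faenet. Its implementation may differ, for example by sampling a single frame instead of averaging over all of them. The sketch assumes the covariance has distinct eigenvalues; degenerate point clouds would need extra handling.

```python
import math
from itertools import product
from typing import Callable, List, Sequence, Tuple

Vec = Tuple[float, float, float]
Matrix = List[List[float]]


def jacobi_eigh(a: Matrix, sweeps: int = 50) -> Tuple[List[float], Matrix]:
    """Eigen-decompose a symmetric 3x3 matrix with cyclic Jacobi rotations.

    Returns (eigenvalues, V); column k of V is the eigenvector of eigenvalue k.
    """
    n = 3
    a = [row[:] for row in a]
    v = [[float(i == j) for j in range(n)] for i in range(n)]
    for _ in range(sweeps):
        if sum(a[p][q] ** 2 for p in range(n) for q in range(n) if p != q) < 1e-30:
            break
        for p in range(n - 1):
            for q in range(p + 1, n):
                if a[p][q] == 0.0:
                    continue
                # Rotation angle that zeroes the off-diagonal entry a[p][q]
                theta = (a[q][q] - a[p][p]) / (2.0 * a[p][q])
                t = math.copysign(1.0, theta) / (abs(theta) + math.sqrt(theta * theta + 1.0))
                c = 1.0 / math.sqrt(t * t + 1.0)
                s = t * c
                for k in range(n):  # A <- A P (columns p and q)
                    akp, akq = a[k][p], a[k][q]
                    a[k][p], a[k][q] = c * akp - s * akq, s * akp + c * akq
                for k in range(n):  # A <- P^T A (rows p and q)
                    apk, aqk = a[p][k], a[q][k]
                    a[p][k], a[q][k] = c * apk - s * aqk, s * apk + c * aqk
                for k in range(n):  # V <- V P accumulates the eigenvectors
                    vkp, vkq = v[k][p], v[k][q]
                    v[k][p], v[k][q] = c * vkp - s * vkq, s * vkp + c * vkq
    return [a[i][i] for i in range(n)], v


def centre(points: Sequence[Vec]) -> Matrix:
    """Subtract the centroid so that the frames are translation invariant."""
    n = len(points)
    mean = [sum(p[d] for p in points) / n for d in range(3)]
    return [[p[d] - mean[d] for d in range(3)] for p in points]


def det3(m: Matrix) -> float:
    return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


def pca_frames(points: Sequence[Vec]) -> List[Matrix]:
    """Return the 4 proper (det = +1) PCA frames; each row is one frame axis."""
    x = centre(points)
    cov = [[sum(r[i] * r[j] for r in x) / len(x) for j in range(3)] for i in range(3)]
    eigvals, v = jacobi_eigh(cov)
    order = sorted(range(3), key=lambda k: -eigvals[k])  # largest variance first
    axes = [[v[row][k] for row in range(3)] for k in order]
    frames = []
    for signs in product((1.0, -1.0), repeat=3):  # eigenvector signs are ambiguous
        frame = [[s * comp for comp in axis] for s, axis in zip(signs, axes)]
        if det3(frame) > 0:  # keep rotations, drop reflections
            frames.append(frame)
    return frames


def frame_average(model: Callable[[List[Vec]], float], points: Sequence[Vec]) -> float:
    """Average a non-invariant model over all PCA frames of the input."""
    x = centre(points)
    frames = pca_frames(points)
    total = 0.0
    for frame in frames:
        coords = [tuple(sum(a * b for a, b in zip(axis, r)) for axis in frame) for r in x]
        total += model(coords)
    return total / len(frames)
```

Each frame is built from the point cloud itself, so rotating or translating the input moves the frames along with it. The projected coordinates therefore do not change, and neither does the averaged output.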
Module Contents
Classes
Initialise atom and edge representations
Non-symmetry preserving GNN model for 3D atomic systems, called FAENet
Updates atom representations through message passing
Compute task-specific predictions from final atom representations
- class faenet.model.EmbeddingBlock(num_gaussians, num_filters, hidden_channels, tag_hidden_channels, pg_hidden_channels, phys_hidden_channels, phys_embeds, act, second_layer_MLP)
Bases: torch.nn.Module
Initialise atom and edge representations
- forward(z, rel_pos, edge_attr, tag=None, subnodes=None)
- reset_parameters()
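EmbeddingBlock takes num_gaussians, which the FAENet parameter list describes as the number of Gaussians \(\mu\) used to encode distance information. A common way to do this is the Gaussian smearing used in SchNet. The pure-Python function below sketches that idea for a single distance. The name gaussian_smearing and the exact width rule are assumptions for illustration. They are not EmbeddingBlock's documented behaviour.

```python
import math
from typing import List


def gaussian_smearing(distance: float, cutoff: float = 6.0, num_gaussians: int = 50) -> List[float]:
    """Expand one interatomic distance into num_gaussians radial basis features.

    Centres mu_k are evenly spaced on [0, cutoff]; every Gaussian shares a
    width tied to the spacing between neighbouring centres.
    """
    if num_gaussians < 2:
        raise ValueError("num_gaussians must be at least 2")
    step = cutoff / (num_gaussians - 1)   # spacing between consecutive centres
    coeff = -0.5 / step ** 2              # feature_k = exp(coeff * (d - mu_k)^2)
    centres = [k * step for k in range(num_gaussians)]
    return [math.exp(coeff * (distance - mu) ** 2) for mu in centres]
```

For example, gaussian_smearing(1.5) returns 50 features. With the default cutoff, the largest one is at index 12, the centre nearest to 1.5.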
- class faenet.model.FAENet(cutoff=6.0, act='swish', preprocess='pbc_preprocess', complex_mp=False, max_num_neighbors=40, num_gaussians=50, num_filters=128, hidden_channels=128, tag_hidden_channels=32, pg_hidden_channels=32, phys_hidden_channels=0, phys_embeds=True, num_interactions=4, mp_type='updownscale_base', graph_norm=True, second_layer_MLP=True, skip_co='concat', energy_head=None, regress_forces=None, force_decoder_type='mlp', force_decoder_model_config={'hidden_channels': 128})
Bases: faenet.base_model.BaseModel
Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network.
- Parameters:
cutoff (float) – Cutoff distance for interatomic interactions. (default: 6.0)
preprocess (callable or str) – Pre-processing function for the data. It should accept a data object as input and return a tuple containing the atomic numbers, batch indices, final adjacency, relative positions and pairwise distances. Examples of valid pre-processing functions include pbc_preprocess, base_preprocess, or custom functions. (default: pbc_preprocess)
act (str) – Activation function. (default: swish)
max_num_neighbors (int) – Maximum number of neighbors to collect for each node within the cutoff distance. (default: 40)
hidden_channels (int) – Hidden embedding size. (default: 128)
tag_hidden_channels (int) – Hidden tag embedding size. (default: 32)
pg_hidden_channels (int) – Hidden period and group embedding size. (default: 32)
phys_embeds (bool) – Whether to include fixed physics-aware embeddings. (default: True)
phys_hidden_channels (int) – Hidden size of learnable physics-aware embeddings. (default: 0)
num_interactions (int) – Number of interaction (i.e. message passing) blocks. (default: 4)
num_gaussians (int) – Number of Gaussians \(\mu\) used to encode distance information. (default: 50)
num_filters (int) – Size of the convolutional filters. (default: 128)
second_layer_MLP (bool) – Whether to use a 2-layer MLP at the end of the embedding block. (default: True)
skip_co (str) – Type of skip connection between each interaction block and the energy head (False, “add”, “concat”, “concat_atom”). (default: “concat”)
mp_type (str) – Message passing type of the interaction block (“base”, “updownscale_base”, “updownscale”, “updown_local_env”, “simple”). (default: “updownscale_base”)
graph_norm (bool) – Whether to apply batch norm after every linear layer. (default: True)
complex_mp (bool) – (default: False)
energy_head (str) – Method used to compute the energy prediction from atom representations (None, “weighted-av-initial-embeds”, “weighted-av-final-embeds”). (default: None)
regress_forces (str) – Whether to predict forces and, if so, how (None or “”, “direct”, “direct_with_gradient_target”). (default: None)
force_decoder_type (str) – Type of force decoder (“simple”, “mlp”, “res”, “res_updown”). (default: “mlp”)
force_decoder_model_config (dict) – Force decoder architecture settings (e.g. number of layers, hidden size). (default: {'hidden_channels': 128})
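The preprocess parameter fixes a contract. The function takes a data object and returns atomic numbers, batch indices, adjacency, relative positions and pairwise distances. The sketch below shows a hypothetical custom function that honours this contract for non-periodic systems, using plain Python lists. ToyData is a made-up stand-in for a PyTorch Geometric data object. The edge direction (neighbour to centre atom) and the sign of rel_pos are also assumptions. faenet's own pbc_preprocess and base_preprocess operate on batched tensors and may differ in detail. The all-pairs search is quadratic in the number of atoms and is meant only for illustration; a library routine such as torch_cluster's radius_graph is the usual choice.

```python
import math
from dataclasses import dataclass
from typing import List, Tuple

Vec = Tuple[float, float, float]


@dataclass
class ToyData:
    """Hypothetical stand-in for a PyG data object (non-periodic system)."""
    atomic_numbers: List[int]
    pos: List[Vec]
    batch: List[int]  # index of the graph each atom belongs to


def radius_preprocess(data: ToyData, cutoff: float = 6.0, max_num_neighbors: int = 40):
    """Return (atomic numbers, batch, edge_index, rel_pos, distances)."""
    sources: List[int] = []
    targets: List[int] = []
    rel_pos: List[Vec] = []
    distances: List[float] = []
    n = len(data.pos)
    for i in range(n):
        # Candidate neighbours: other atoms of the same graph within the cutoff
        candidates = []
        for j in range(n):
            if j == i or data.batch[j] != data.batch[i]:
                continue
            d = math.dist(data.pos[i], data.pos[j])
            if d < cutoff:
                candidates.append((d, j))
        # Keep only the max_num_neighbors closest neighbours of atom i
        candidates.sort()
        for d, j in candidates[:max_num_neighbors]:
            sources.append(j)  # message flows from neighbour j ...
            targets.append(i)  # ... to centre atom i
            rel_pos.append(tuple(pj - pi for pj, pi in zip(data.pos[j], data.pos[i])))
            distances.append(d)
    return data.atomic_numbers, data.batch, [sources, targets], rel_pos, distances
```

Consider a three-atom molecule plus an isolated atom in a second graph, with cutoff=2.0. The function returns six directed edges and none for the isolated atom. Setting max_num_neighbors=1 keeps one edge per bonded atom.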
- energy_forward(data, preproc=True)
Predicts any graph-level properties (e.g. energy) for 3D atomic systems.
- Parameters:
data (data.Batch) – Batch of graph data points.
preproc (bool) – Whether to preprocess the graph. Defaults to True.
- Returns:
predicted properties for each graph (e.g. energy)
- Return type:
dict
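energy_forward returns a dict of graph-level predictions. One common readout for such properties sums per-atom contributions within each graph using the batch indices. The sketch below shows that scatter-add step in plain Python. The function name pool_energy, the "energy" key and the use of sum pooling are all assumptions. FAENet's energy_head options, such as the weighted averages listed in its parameters, may combine atom representations differently.

```python
from typing import Dict, List


def pool_energy(atom_energies: List[float], batch: List[int], num_graphs: int) -> Dict[str, List[float]]:
    """Sum per-atom energy contributions into one prediction per graph.

    Hypothetical readout: the "energy" key and sum pooling are assumptions.
    """
    if len(atom_energies) != len(batch):
        raise ValueError("need exactly one batch index per atom")
    energy = [0.0] * num_graphs
    for e, g in zip(atom_energies, batch):
        energy[g] += e  # scatter-add each atom's contribution into its graph
    return {"energy": energy}
```

For example, pool_energy([-1.0, -0.5, -0.5, -2.0], [0, 0, 0, 1], 2) returns {"energy": [-2.0, -2.0]}.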
- forces_forward(preds)
Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property.
- Parameters:
preds (dict) – Dictionary of predicted properties for each graph.
- Returns:
additional predicted properties, at an atom-level (e.g. forces)
- Return type:
dict
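forces_forward turns the prediction dict into atom-level outputs. The “direct” option of regress_forces suggests that forces are decoded directly from atom representations rather than obtained as energy gradients. The sketch below illustrates that idea with a tiny two-layer MLP with swish activation, applied to each atom's hidden state. The hidden_state and forces keys, the function names and the random initialisation are all hypothetical. In FAENet the decoder is selected by force_decoder_type and is presumably a trained torch.nn.Module.

```python
import math
import random
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def swish(x: float) -> float:
    """x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def linear(x: List[float], weight: Matrix, bias: List[float]) -> List[float]:
    """y = W x + b with W of shape (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def init_layer(in_f: int, out_f: int, rng: random.Random) -> Tuple[Matrix, List[float]]:
    bound = 1.0 / math.sqrt(in_f)
    weight = [[rng.uniform(-bound, bound) for _ in range(in_f)] for _ in range(out_f)]
    return weight, [0.0] * out_f  # zero biases


def direct_force_head(preds: Dict[str, Matrix], hidden_channels: int = 128, seed: int = 0) -> Dict[str, Matrix]:
    """Decode one 3D vector per atom from per-atom hidden states."""
    h = preds["hidden_state"]  # assumed key, shape (num_atoms, dim)
    rng = random.Random(seed)
    w1, b1 = init_layer(len(h[0]), hidden_channels, rng)
    w2, b2 = init_layer(hidden_channels, 3, rng)
    forces = [linear([swish(z) for z in linear(x, w1, b1)], w2, b2) for x in h]
    return {"forces": forces}  # shape (num_atoms, 3)
```

A fixed seed keeps the example reproducible; a real decoder would learn these weights during training. Because the same MLP is applied to every atom independently, the head would also work for any other per-atom property that has a different output size.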