Source code for faenet.model

"""
FAENet: Frame Averaging Equivariant graph neural Network 
Simple, scalable and expressive model for property prediction on 3D atomic systems.
"""
from typing import Dict, Optional, Union

import torch
from torch import nn
from torch.nn import Embedding, Linear
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.norm import GraphNorm
from torch_scatter import scatter

from faenet.base_model import BaseModel
from faenet.embedding import PhysEmbedding
from faenet.force_decoder import ForceDecoder
from faenet.utils import GaussianSmearing, swish, pbc_preprocess, base_preprocess


[docs]class EmbeddingBlock(nn.Module): """Initialise atom and edge representations.""" def __init__( self, num_gaussians, num_filters, hidden_channels, tag_hidden_channels, pg_hidden_channels, phys_hidden_channels, phys_embeds, act, second_layer_MLP, ): super().__init__() self.act = act self.use_tag = tag_hidden_channels > 0 self.use_pg = pg_hidden_channels > 0 self.use_mlp_phys = phys_hidden_channels > 0 and phys_embeds self.second_layer_MLP = second_layer_MLP # --- Node embedding --- # Phys embeddings self.phys_emb = PhysEmbedding( props=phys_embeds, props_grad=phys_hidden_channels > 0, pg=self.use_pg ) # With MLP if self.use_mlp_phys: self.phys_lin = Linear(self.phys_emb.n_properties, phys_hidden_channels) else: phys_hidden_channels = self.phys_emb.n_properties # Period + group embeddings if self.use_pg: self.period_embedding = Embedding( self.phys_emb.period_size, pg_hidden_channels ) self.group_embedding = Embedding( self.phys_emb.group_size, pg_hidden_channels ) # Tag embedding if tag_hidden_channels: self.tag_embedding = Embedding(3, tag_hidden_channels) # Main embedding self.emb = Embedding( 85, hidden_channels - tag_hidden_channels - phys_hidden_channels - 2 * pg_hidden_channels, ) # MLP self.lin = Linear(hidden_channels, hidden_channels) if self.second_layer_MLP: self.lin_2 = Linear(hidden_channels, hidden_channels) # --- Edge embedding --- self.lin_e1 = Linear(3, num_filters // 2) # r_ij self.lin_e12 = Linear(num_gaussians, num_filters - (num_filters // 2)) # d_ij if self.second_layer_MLP: self.lin_e2 = Linear(num_filters, num_filters) self.reset_parameters()
[docs] def reset_parameters(self): self.emb.reset_parameters() if self.use_mlp_phys: nn.init.xavier_uniform_(self.phys_lin.weight) if self.use_tag: self.tag_embedding.reset_parameters() if self.use_pg: self.period_embedding.reset_parameters() self.group_embedding.reset_parameters() nn.init.xavier_uniform_(self.lin.weight) self.lin.bias.data.fill_(0) nn.init.xavier_uniform_(self.lin_e1.weight) self.lin_e1.bias.data.fill_(0) if self.second_layer_MLP: nn.init.xavier_uniform_(self.lin_2.weight) self.lin_2.bias.data.fill_(0) nn.init.xavier_uniform_(self.lin_e2.weight) self.lin_e2.bias.data.fill_(0)
[docs] def forward(self, z, rel_pos, edge_attr, tag=None, subnodes=None): """Forward pass of the Embedding block. Called in FAENet to generate initial atom and edge representations. Args: z (tensor): atomic numbers. (num_atoms, ) rel_pos (tensor): relative atomic positions. (num_edges, 3) edge_attr (tensor): RBF of pairwise distances. (num_edges, num_gaussians) tag (tensor, optional): atom information specific to OCP. Defaults to None. Returns: (tensor, tensor): atom embeddings, edge embeddings """ # --- Edge embedding -- rel_pos = self.lin_e1(rel_pos) # r_ij edge_attr = self.lin_e12(edge_attr) # d_ij e = torch.cat((rel_pos, edge_attr), dim=1) e = self.act(e) # can comment out if self.second_layer_MLP: # e = self.lin_e2(e) e = self.act(self.lin_e2(e)) # --- Node embedding -- # Create atom embeddings based on its characteristic number h = self.emb(z) if self.phys_emb.device != h.device: self.phys_emb = self.phys_emb.to(h.device) # Concat tag embedding if self.use_tag: h_tag = self.tag_embedding(tag) h = torch.cat((h, h_tag), dim=1) # Concat physics embeddings if self.phys_emb.n_properties > 0: h_phys = self.phys_emb.properties[z] if self.use_mlp_phys: h_phys = self.phys_lin(h_phys) h = torch.cat((h, h_phys), dim=1) # Concat period & group embedding if self.use_pg: h_period = self.period_embedding(self.phys_emb.period[z]) h_group = self.group_embedding(self.phys_emb.group[z]) h = torch.cat((h, h_period, h_group), dim=1) # MLP h = self.act(self.lin(h)) if self.second_layer_MLP: h = self.act(self.lin_2(h)) return h, e
[docs]class InteractionBlock(MessagePassing): """Updates atom representations through custom message passing.""" def __init__( self, hidden_channels, num_filters, act, mp_type, complex_mp, graph_norm, ): super(InteractionBlock, self).__init__() self.act = act self.mp_type = mp_type self.hidden_channels = hidden_channels self.complex_mp = complex_mp self.graph_norm = graph_norm if graph_norm: self.graph_norm = GraphNorm( hidden_channels if "updown" not in self.mp_type else num_filters ) if self.mp_type == "simple": self.lin_h = nn.Linear(hidden_channels, hidden_channels) elif self.mp_type == "updownscale": self.lin_geom = nn.Linear(num_filters, num_filters) self.lin_down = nn.Linear(hidden_channels, num_filters) self.lin_up = nn.Linear(num_filters, hidden_channels) elif self.mp_type == "updownscale_base": self.lin_geom = nn.Linear(num_filters + 2 * hidden_channels, num_filters) self.lin_down = nn.Linear(hidden_channels, num_filters) self.lin_up = nn.Linear(num_filters, hidden_channels) elif self.mp_type == "updown_local_env": self.lin_down = nn.Linear(hidden_channels, num_filters) self.lin_geom = nn.Linear(num_filters, num_filters) self.lin_up = nn.Linear(2 * num_filters, hidden_channels) else: # base self.lin_geom = nn.Linear( num_filters + 2 * hidden_channels, hidden_channels ) self.lin_h = nn.Linear(hidden_channels, hidden_channels) if self.complex_mp: self.other_mlp = nn.Linear(hidden_channels, hidden_channels) self.reset_parameters()
[docs] def reset_parameters(self): if self.mp_type != "simple": nn.init.xavier_uniform_(self.lin_geom.weight) self.lin_geom.bias.data.fill_(0) if self.complex_mp: nn.init.xavier_uniform_(self.other_mlp.weight) self.other_mlp.bias.data.fill_(0) if self.mp_type in {"updownscale", "updownscale_base", "updown_local_env"}: nn.init.xavier_uniform_(self.lin_up.weight) self.lin_up.bias.data.fill_(0) nn.init.xavier_uniform_(self.lin_down.weight) self.lin_down.bias.data.fill_(0) else: nn.init.xavier_uniform_(self.lin_h.weight) self.lin_h.bias.data.fill_(0)
[docs] def forward(self, h, edge_index, e): """Forward pass of the Interaction block. Called in FAENet forward pass to update atom representations. Args: h (tensor): atom embedddings. (num_atoms, hidden_channels) edge_index (tensor): adjacency matrix. (2, num_edges) e (tensor): edge embeddings. (num_edges, num_filters) Returns: (tensor): updated atom embeddings """ # Define edge embedding if self.mp_type in {"base", "updownscale_base"}: e = torch.cat([e, h[edge_index[0]], h[edge_index[1]]], dim=1) if self.mp_type in { "updownscale", "base", "updownscale_base", }: e = self.act(self.lin_geom(e)) # --- Message Passing block -- if self.mp_type == "updownscale" or self.mp_type == "updownscale_base": h = self.act(self.lin_down(h)) # downscale node rep. h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: h = self.act(self.graph_norm(h)) h = self.act(self.lin_up(h)) # upscale node rep. elif self.mp_type == "updown_local_env": h = self.act(self.lin_down(h)) chi = self.propagate(edge_index, x=h, W=e, local_env=True) e = self.lin_geom(e) h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: h = self.act(self.graph_norm(h)) h = torch.cat((h, chi), dim=1) h = self.lin_up(h) elif self.mp_type in {"base", "simple"}: h = self.propagate(edge_index, x=h, W=e) # propagate if self.graph_norm: h = self.act(self.graph_norm(h)) h = self.act(self.lin_h(h)) else: raise ValueError("mp_type provided does not exist") if self.complex_mp: h = self.act(self.other_mlp(h)) return h
[docs] def message(self, x_j, W, local_env=None): if local_env is not None: return W else: return x_j * W
[docs]class OutputBlock(nn.Module): """Compute task-specific predictions from final atom representations.""" def __init__(self, energy_head, hidden_channels, act, out_dim=1): super().__init__() self.energy_head = energy_head self.act = act self.lin1 = Linear(hidden_channels, hidden_channels // 2) self.lin2 = Linear(hidden_channels // 2, out_dim) if self.energy_head == "weighted-av-final-embeds": self.w_lin = Linear(hidden_channels, 1)
[docs] def reset_parameters(self): nn.init.xavier_uniform_(self.lin1.weight) self.lin1.bias.data.fill_(0) nn.init.xavier_uniform_(self.lin2.weight) self.lin2.bias.data.fill_(0) if self.energy_head == "weighted-av-final-embeds": nn.init.xavier_uniform_(self.w_lin.weight) self.w_lin.bias.data.fill_(0)
[docs] def forward(self, h, edge_index, edge_weight, batch, alpha): """Forward pass of the Output block. Called in FAENet to make prediction from final atom representations. Args: h (tensor): atom representations. (num_atoms, hidden_channels) edge_index (tensor): adjacency matrix. (2, num_edges) edge_weight (tensor): edge weights. (num_edges, ) batch (tensor): batch indices. (num_atoms, ) alpha (tensor): atom attention weights for late energy head. (num_atoms, ) Returns: (tensor): graph-level representation (e.g. energy prediction) """ if self.energy_head == "weighted-av-final-embeds": alpha = self.w_lin(h) # MLP h = self.lin1(h) h = self.act(h) h = self.lin2(h) if self.energy_head in { "weighted-av-initial-embeds", "weighted-av-final-embeds", }: h = h * alpha # Global pooling out = scatter(h, batch, dim=0, reduce="add") return out
[docs]class FAENet(BaseModel): r"""Non-symmetry preserving GNN model for 3D atomic systems, called FAENet: Frame Averaging Equivariant Network. Args: cutoff (float): Cutoff distance for interatomic interactions. (default: :obj:`6.0`) preprocess (callable): Pre-processing function for the data. This function should accept a data object as input and return a tuple containing the following: atomic numbers, batch indices, final adjacency, relative positions, pairwise distances. Examples of valid preprocessing functions include `pbc_preprocess`, `base_preprocess`, or custom functions. act (str): Activation function (default: `swish`) max_num_neighbors (int): The maximum number of neighbors to collect for each node within the :attr:`cutoff` distance. (default: `40`) hidden_channels (int): Hidden embedding size. (default: `128`) tag_hidden_channels (int): Hidden tag embedding size. (default: :obj:`32`) pg_hidden_channels (int): Hidden period and group embedding size. (default: :obj:`32`) phys_embeds (bool): Do we include fixed physics-aware embeddings. (default: :obj: `True`) phys_hidden_channels (int): Hidden size of learnable physics-aware embeddings. (default: :obj:`0`) num_interactions (int): The number of interaction (i.e. message passing) blocks. (default: :obj:`4`) num_gaussians (int): The number of gaussians :math:`\mu` to encode distance info. (default: :obj:`50`) num_filters (int): The size of convolutional filters. (default: :obj:`128`) second_layer_MLP (bool): Use 2-layers MLP at the end of the Embedding block. (default: :obj:`False`) skip_co (str): Add a skip connection between each interaction block and energy-head. (`False`, `"add"`, `"concat"`, `"concat_atom"`) mp_type (str): Specificies the Message Passing type of the interaction block. (`"base"`, `"updownscale_base"`, `"updownscale"`, `"updown_local_env"`, `"simple"`): graph_norm (bool): Whether to apply batch norm after every linear layer. (default: :obj:`True`) complex_mp (bool); Whether to add a second layer MLP at the end of each Interaction (default: :obj:`True`) energy_head (str): Method to compute energy prediction from atom representations. (`None`, `"weighted-av-initial-embeds"`, `"weighted-av-final-embeds"`) out_dim (int): size of the output tensor for graph-level predicted properties ("energy") Allows to predict multiple properties at the same time. (default: :obj:`1`) pred_as_dict (bool): Set to False to return a (property) prediction tensor. By default, predictions are returned as a dictionary with several keys (e.g. energy, forces) (default: :obj:`True`) regress_forces (str): Specifies if we predict forces or not, and how do we predict them. (`None` or `""`, `"direct"`, `"direct_with_gradient_target"`) force_decoder_type (str): Specifies the type of force decoder (`"simple"`, `"mlp"`, `"res"`, `"res_updown"`) force_decoder_model_config (dict): contains information about the for decoder architecture (e.g. number of layers, hidden size). """ def __init__( self, cutoff: float = 6.0, preprocess: Union[str, callable] = "pbc_preprocess", act: str = "swish", max_num_neighbors: int = 40, hidden_channels: int = 128, tag_hidden_channels: int = 32, pg_hidden_channels: int = 32, phys_embeds: bool = True, phys_hidden_channels: int = 0, num_interactions: int = 4, num_gaussians: int = 50, num_filters: int = 128, second_layer_MLP: bool = True, skip_co: str = "concat", mp_type: str = "updownscale_base", graph_norm: bool = True, complex_mp: bool = False, energy_head: Optional[str] = None, out_dim: int = 1, pred_as_dict: bool = True, regress_forces: Optional[str] = None, force_decoder_type: Optional[str] = "mlp", force_decoder_model_config: Optional[dict] = {"hidden_channels": 128}, ): super(FAENet, self).__init__() self.act = act self.complex_mp = complex_mp self.cutoff = cutoff self.energy_head = energy_head self.force_decoder_type = force_decoder_type self.force_decoder_model_config = force_decoder_model_config self.graph_norm = graph_norm self.hidden_channels = hidden_channels self.max_num_neighbors = max_num_neighbors self.mp_type = mp_type self.num_filters = num_filters self.num_gaussians = num_gaussians self.num_interactions = num_interactions self.pg_hidden_channels = pg_hidden_channels self.phys_embeds = phys_embeds self.phys_hidden_channels = phys_hidden_channels self.regress_forces = regress_forces self.second_layer_MLP = second_layer_MLP self.skip_co = skip_co self.tag_hidden_channels = tag_hidden_channels self.preprocess = preprocess self.pred_as_dict = pred_as_dict if isinstance(self.preprocess, str): self.preprocess = eval(self.preprocess) if not isinstance(self.regress_forces, str): assert self.regress_forces is False or self.regress_forces is None, ( "regress_forces must be a string " + "('', 'direct', 'direct_with_gradient_target') or False or None" ) self.regress_forces = "" if self.mp_type == "simple": self.num_filters = self.hidden_channels self.act = ( (getattr(nn.functional, self.act) if self.act != "swish" else swish) if isinstance(self.act, str) else self.act ) assert callable(self.act), ( "act must be a callable function or a string " + "describing that function in torch.nn.functional" ) # Gaussian Basis self.distance_expansion = GaussianSmearing(0.0, self.cutoff, self.num_gaussians) # Embedding block self.embed_block = EmbeddingBlock( self.num_gaussians, self.num_filters, self.hidden_channels, self.tag_hidden_channels, self.pg_hidden_channels, self.phys_hidden_channels, self.phys_embeds, self.act, self.second_layer_MLP, ) # Interaction block self.interaction_blocks = nn.ModuleList( [ InteractionBlock( self.hidden_channels, self.num_filters, self.act, self.mp_type, self.complex_mp, self.graph_norm, ) for _ in range(self.num_interactions) ] ) # Output block self.output_block = OutputBlock( self.energy_head, self.hidden_channels, self.act, out_dim ) # Energy head if self.energy_head == "weighted-av-initial-embeds": self.w_lin = Linear(self.hidden_channels, 1) # Force head self.decoder = ( ForceDecoder( self.force_decoder_type, self.hidden_channels, self.force_decoder_model_config, self.act, ) if "direct" in self.regress_forces else None ) # Skip co if self.skip_co == "concat": self.mlp_skip_co = Linear(out_dim * (self.num_interactions + 1), out_dim) elif self.skip_co == "concat_atom": self.mlp_skip_co = Linear( ((self.num_interactions + 1) * self.hidden_channels), self.hidden_channels, ) # FAENet's forward pass in done in BaseModel, inherited here. # It uses forces_forward() and energy_forward() defined below.
[docs] def forces_forward(self, preds): """Predicts forces for 3D atomic systems. Can be utilised to predict any atom-level property. Args: preds (dict): dictionnary with final atomic representations (hidden_state) and predicted properties (e.g. energy) for each graph Returns: (dict): additional predicted properties, at an atom-level (e.g. forces) """ if self.decoder: return self.decoder(preds["hidden_state"])
[docs] def energy_forward(self, data, preproc=True): """Predicts any graph-level property (e.g. energy) for 3D atomic systems. Args: data (data.Batch): Batch of graphs data objects. preproc (bool): Whether to apply (any given) preprocessing to the graph. Default to True. Returns: (dict): predicted properties for each graph (key: "energy") and final atomic representations (key: "hidden_state") """ # Pre-process data (e.g. pbc, cutoff graph, etc.) # Should output all necessary attributes, in correct format. if preproc: z, batch, edge_index, rel_pos, edge_weight = self.preprocess( data, self.cutoff, self.max_num_neighbors ) else: rel_pos = data.pos[data.edge_index[0]] - data.pos[data.edge_index[1]] z, batch, edge_index, rel_pos, edge_weight = ( data.atomic_numbers.long(), data.batch, data.edge_index, rel_pos, rel_pos.norm(dim=-1), ) edge_attr = self.distance_expansion(edge_weight) # RBF of pairwise distances assert z.dim() == 1 and z.dtype == torch.long # Embedding block h, e = self.embed_block( z, rel_pos, edge_attr, data.tags if hasattr(data, "tags") else None ) # Compute atom weights for late energy head if self.energy_head == "weighted-av-initial-embeds": alpha = self.w_lin(h) else: alpha = None # Interaction blocks energy_skip_co = [] for interaction in self.interaction_blocks: if self.skip_co == "concat_atom": energy_skip_co.append(h) elif self.skip_co: energy_skip_co.append( self.output_block(h, edge_index, edge_weight, batch, alpha) ) h = h + interaction(h, edge_index, e) # Atom skip-co if self.skip_co == "concat_atom": energy_skip_co.append(h) h = self.act(self.mlp_skip_co(torch.cat(energy_skip_co, dim=1))) energy = self.output_block(h, edge_index, edge_weight, batch, alpha) # Skip-connection energy_skip_co.append(energy) if self.skip_co == "concat": energy = self.mlp_skip_co(torch.cat(energy_skip_co, dim=1)) elif self.skip_co == "add": energy = sum(energy_skip_co) preds = {"energy": energy, "hidden_state": h} return preds