faenet.base_model
Module Contents
Classes
BaseModel – Base class for ML models applied to 3D atomic systems.
- class faenet.base_model.BaseModel(**kwargs)
Bases: torch.nn.Module
Base class for ML models applied to 3D atomic systems.
- forces_as_energy_grad(pos, energy)
Computes forces as the negative gradient of the energy with respect to atom positions.
- Parameters:
pos (tensor) – 3D atom positions
energy (tensor) – system’s predicted energy
- Returns:
forces, i.e. the negative gradient of the energy with respect to atom positions
- Return type:
(tensor)
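A minimal sketch of this technique, assuming the usual convention F = -∇E (FAENet's own implementation is not reproduced here). It is written as a standalone function rather than a method, and the harmonic toy energy is purely illustrative; the core is a single `torch.autograd.grad` call.

```python
import torch


def forces_as_energy_grad(pos: torch.Tensor, energy: torch.Tensor) -> torch.Tensor:
    """Sketch (standalone, not the library method): return F = -dE/dpos, pos shape [num_atoms, 3]."""
    grad = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),  # also works for a batch of per-system energies
        create_graph=True,  # keep the graph so a loss on the forces can be backpropagated
    )[0]
    return -grad


# Toy check: harmonic wells at the origin, E = 0.5 * sum_i |r_i|^2, so F_i = -r_i.
pos = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, -1.0]], requires_grad=True)
energy = 0.5 * (pos ** 2).sum()
forces = forces_as_energy_grad(pos, energy)
```

`create_graph=True` makes the returned forces differentiable, which is what allows training on a force loss; at pure inference time it only costs extra memory.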
- forward(data, mode='train', preproc=True)
Main forward pass.
- Parameters:
data (Data) – input data object, with 3D atom positions (pos)
mode (str) – train or inference mode
preproc (bool) – Whether to preprocess the input graph or point cloud (apply periodic boundary conditions and build the cutoff graph). Default: True.
- Returns:
predicted energy, forces and final atomic hidden states
- Return type:
(dict)
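To illustrate the documented forward(data, mode, preproc) contract, here is a hypothetical toy module; it is not FAENet code. ToyPairModel, its exp(-d) pair potential, the build_cutoff_graph helper, the way mode controls create_graph, and the output keys "energy", "forces" and "hidden_state" are all assumptions. A types.SimpleNamespace stands in for the Data object, and periodic boundary conditions are not handled.

```python
import types

import torch
from torch import nn


class ToyPairModel(nn.Module):
    """Hypothetical toy model following the forward(data, mode, preproc) contract.

    Not FAENet code: the exp(-d) pair potential and the output dict keys are assumptions.
    """

    def __init__(self, cutoff: float = 3.0):
        super().__init__()
        self.cutoff = cutoff
        self.scale = nn.Parameter(torch.tensor(1.0))  # single learnable weight

    def build_cutoff_graph(self, pos: torch.Tensor) -> torch.Tensor:
        # Ordered pairs (i, j), i != j, closer than the cutoff; PBC is NOT handled here.
        dist = torch.cdist(pos.detach(), pos.detach())
        self_loops = torch.eye(len(pos), dtype=torch.bool, device=pos.device)
        return ((dist < self.cutoff) & ~self_loops).nonzero().t()  # shape [2, num_edges]

    def forward(self, data, mode: str = "train", preproc: bool = True) -> dict:
        # Forces need autograd even if the caller wrapped us in torch.no_grad().
        with torch.enable_grad():
            pos = data.pos
            if not pos.requires_grad:
                pos = pos.clone().requires_grad_(True)
            # preproc=True: build the cutoff graph; otherwise use the given data.edge_index.
            edge_index = self.build_cutoff_graph(pos) if preproc else data.edge_index
            src, dst = edge_index
            dist = (pos[src] - pos[dst]).norm(dim=-1)
            pair_energy = self.scale * torch.exp(-dist)
            # Toy per-atom "hidden state": summed contributions of edges ending at each atom.
            hidden = torch.zeros(len(pos), dtype=pos.dtype, device=pos.device)
            hidden = hidden.index_add(0, dst, pair_energy)
            energy = 0.5 * pair_energy.sum()  # each unordered pair appears twice
            # F = -dE/dpos; keep the graph in training mode so a force loss can backprop.
            forces = -torch.autograd.grad(energy, pos, create_graph=(mode == "train"))[0]
        return {"energy": energy, "forces": forces, "hidden_state": hidden.unsqueeze(-1)}


# Usage: three atoms, with types.SimpleNamespace standing in for a Data object.
model = ToyPairModel()
data = types.SimpleNamespace(pos=torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]))
with torch.no_grad():
    out = model(data, mode="inference")
print(out["energy"].item(), out["forces"].shape, out["hidden_state"].shape)
```

Because the toy energy is a sum of pairwise terms, the predicted forces sum to zero over all atoms, a quick sanity check for any energy-conserving force head.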