faenet.base_model

Module Contents

Classes

BaseModel

Base class for ML models applied to 3D atomic systems.

class faenet.base_model.BaseModel(**kwargs)

Bases: torch.nn.Module

Base class for ML models applied to 3D atomic systems.

property num_params
abstract energy_forward(data, preproc=True)

Forward pass for energy prediction.

forces_as_energy_grad(pos, energy)

Computes forces from the gradient of the energy with respect to atom positions.

Parameters:
  • pos (tensor) – 3D atom positions

  • energy (tensor) – system’s predicted energy

Returns:

forces as the energy gradient w.r.t. atom positions

Return type:

(tensor)
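The idea behind forces_as_energy_grad can be sketched with torch.autograd.grad. This is a minimal illustration, not the library's implementation: the helper name forces_from_energy is hypothetical, and the negative sign (the usual physics convention F = -dE/dpos) and the create_graph=True flag are assumptions that should be checked against the faenet source.

```python
import torch


def forces_from_energy(pos: torch.Tensor, energy: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: forces as minus the energy gradient w.r.t. positions.

    The sign convention is an assumption (standard physics, F = -dE/dpos).
    """
    grad = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=True,  # assumption: keep the graph so a force loss can be backpropagated
    )[0]
    return -grad


# Toy harmonic energy E = 0.5 * sum(pos**2), whose gradient is simply pos.
pos = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], requires_grad=True)
energy = 0.5 * (pos ** 2).sum()
forces = forces_from_energy(pos, energy)
```

Note that pos must have requires_grad set before the energy is computed, otherwise autograd has no path from energy back to the positions.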

abstract forces_forward(preds)

Forward pass for force prediction.

forward(data, mode='train', preproc=True)

Main forward pass.

Parameters:
  • data (Data) – input data object, with 3D atom positions (pos)

  • mode (str) – train or inference mode. Default: 'train'.

  • preproc (bool) – whether to preprocess the input graph or point cloud (periodic boundary conditions, cutoff graph). Default: True.

Returns:

predicted energy, forces and final atomic hidden states

Return type:

(dict)
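To make the relationship between energy_forward, forces_forward and forward concrete, here is a toy stand-in that mirrors the documented signatures. It is a sketch only: ToyModel, its layers, and the dictionary keys "energy", "hidden_state" and "forces" are all hypothetical, and it subclasses torch.nn.Module directly rather than BaseModel, whose constructor arguments are not documented here.

```python
from types import SimpleNamespace

import torch


class ToyModel(torch.nn.Module):
    """Hypothetical stand-in mirroring the documented interface; not faenet code."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Linear(3, hidden)
        self.energy_head = torch.nn.Linear(hidden, 1)
        self.force_head = torch.nn.Linear(hidden, 3)

    def energy_forward(self, data, preproc=True):
        # preproc is accepted for interface parity but ignored in this toy.
        h = torch.tanh(self.embed(data.pos))  # per-atom hidden states
        energy = self.energy_head(h).sum()    # one scalar for the whole system
        return {"energy": energy, "hidden_state": h}

    def forces_forward(self, preds):
        # Direct force prediction from the hidden states.
        return self.force_head(preds["hidden_state"])

    def forward(self, data, mode="train", preproc=True):
        # mode is unused here; a real model may branch on it.
        preds = self.energy_forward(data, preproc)
        preds["forces"] = self.forces_forward(preds)
        return preds


# Any object with a pos attribute works for this toy; real inputs are Data objects.
data = SimpleNamespace(pos=torch.randn(5, 3))
out = ToyModel()(data, mode="inference", preproc=False)
```

The returned dict follows the documented description (energy, forces, final atomic hidden states), but the exact keys used by the library may differ.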

reset_parameters()

Resets all learnable parameters of the module.
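A common way to implement a parameter reset and a parameter count on a torch.nn.Module is sketched below. TinyNet is hypothetical, and both method bodies are assumptions about typical behaviour: the page does not say whether num_params counts only trainable parameters or how BaseModel performs the reset.

```python
import torch


class TinyNet(torch.nn.Module):
    """Hypothetical module used only to illustrate the two members."""

    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(3, 4)
        self.lin2 = torch.nn.Linear(4, 1)

    @property
    def num_params(self) -> int:
        # Assumption: count trainable parameters only.
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def reset_parameters(self):
        # Re-initialise every direct child that knows how to reset itself.
        for module in self.children():
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()


net = TinyNet()
before = net.lin1.weight.detach().clone()
net.reset_parameters()
after = net.lin1.weight.detach().clone()
```

Because torch.nn.Linear.reset_parameters draws fresh random weights, before and after differ, which is useful when reusing one model instance across several training runs.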