faenet.base_model

Module Contents

Classes

BaseModel

class faenet.base_model.BaseModel(**kwargs)

Bases: torch.nn.Module

property num_params
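num_params is listed without a description. Assuming it follows the common PyTorch convention of summing numel() over self.parameters(), an equivalent standalone computation looks like this (count_params is a hypothetical helper, not part of faenet):

```python
from torch import nn


def count_params(module: nn.Module) -> int:
    # Sum the number of scalar entries over every parameter tensor,
    # including the parameters of nested submodules.
    return sum(p.numel() for p in module.parameters())


layer = nn.Linear(3, 2)  # weight has 2 x 3 = 6 entries, bias has 2
print(count_params(layer))  # 8
```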

abstract energy_forward(data, preproc=True)

Forward pass for energy prediction.
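Because energy_forward is abstract, each concrete model supplies its own. The sketch below shows the general shape of such an implementation under stated assumptions: ToyEnergyModel is hypothetical and does not subclass the real BaseModel, data is taken to be a dict holding an (n_atoms, 3) "pos" tensor, the returned preds is assumed to be a dict with an "energy" entry, and centering the positions stands in for whatever preproc actually does.

```python
import torch
from torch import nn


class ToyEnergyModel(nn.Module):
    """Hypothetical stand-in showing the shape of an energy_forward method."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(3, hidden), nn.SiLU(), nn.Linear(hidden, 1))

    def energy_forward(self, data, preproc=True):
        pos = data["pos"]  # assumed: (n_atoms, 3) atom positions
        if preproc:
            # Hypothetical preprocessing: center the structure on its centroid,
            # which makes the predicted energy invariant to translations.
            pos = pos - pos.mean(dim=0, keepdim=True)
        per_atom = self.mlp(pos)  # (n_atoms, 1) per-atom energy contributions
        return {"energy": per_atom.sum(), "hidden": per_atom}


model = ToyEnergyModel()
preds = model.energy_forward({"pos": torch.randn(5, 3)})
print(preds["energy"].shape)  # torch.Size([])
```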

forces_as_energy_grad(pos, energy)

Computes forces from the gradient of the predicted energy with respect to atom positions.

Parameters:
  • pos (tensor) – atom positions

  • energy (tensor) – predicted energy

Returns:

negative gradient of the energy w.r.t. atom positions

Return type:

forces (tensor)
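A minimal sketch of computing forces from an energy with torch.autograd.grad, using the physical convention F = -dE/dpos. It is a standalone function, not the library's method; the create_graph flag and the harmonic test potential are illustrative assumptions. For E = 0.5 * k * |pos|^2 the exact force is -k * pos, which gives an easy correctness check.

```python
import torch


def forces_as_energy_grad(pos: torch.Tensor, energy: torch.Tensor,
                          create_graph: bool = False) -> torch.Tensor:
    # grad_outputs=ones_like(energy) lets `energy` be a scalar or a batch of
    # per-structure energies; pos must have been created with requires_grad=True.
    grad = torch.autograd.grad(
        energy,
        pos,
        grad_outputs=torch.ones_like(energy),
        create_graph=create_graph,  # set True if the forces feed a training loss
    )[0]
    return -grad  # forces point down the energy gradient


# Harmonic toy potential E = 0.5 * k * |pos|^2
k = 2.0
pos = torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.5]], requires_grad=True)
energy = 0.5 * k * (pos ** 2).sum()
forces = forces_as_energy_grad(pos, energy)
print(torch.allclose(forces, -k * pos.detach()))  # True
```

Keeping create_graph=False outside training avoids retaining a second-order graph that is only needed when the forces themselves are differentiated in a loss.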

abstract forces_forward(preds)

Forward pass for force prediction.
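forces_forward takes the preds produced by the energy pass, which suggests a direct force head that reuses intermediate features. A hypothetical sketch, assuming preds is a dict carrying per-atom features of shape (n_atoms, hidden) under a "hidden" key (DirectForceHead and that key are illustrative, not faenet names):

```python
import torch
from torch import nn


class DirectForceHead(nn.Module):
    """Hypothetical head mapping per-atom features to 3D force vectors."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.lin = nn.Linear(hidden, 3)

    def forces_forward(self, preds):
        # Assumes preds["hidden"] holds per-atom features, shape (n_atoms, hidden).
        return self.lin(preds["hidden"])  # (n_atoms, 3) predicted forces


head = DirectForceHead(hidden=16)
forces = head.forces_forward({"hidden": torch.randn(5, 16)})
print(forces.shape)  # torch.Size([5, 3])
```

Predicting forces directly skips the backward pass through the energy, at the cost of forces that are not guaranteed to be the gradient of any energy.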

forward(data, mode='train', preproc=True)

Main forward pass.
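The page does not describe how forward combines the two abstract methods. The sketch below is one hypothetical way to do it: ToyModel, the regress_forces attribute with values "from_energy" and "direct", the dict-based data, and reading mode="train" as "keep the autograd graph so the forces can enter a loss" are all assumptions, not documented faenet behavior.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical model illustrating one possible forward() dispatch."""

    def __init__(self, regress_forces: str = "from_energy", hidden: int = 8):
        super().__init__()
        self.regress_forces = regress_forces  # assumed: "from_energy" or "direct"
        self.embed = nn.Linear(3, hidden)
        self.energy_head = nn.Linear(hidden, 1)
        self.force_head = nn.Linear(hidden, 3)

    def energy_forward(self, data, preproc=True):
        pos = data["pos"]
        if preproc:
            pos = pos - pos.mean(dim=0, keepdim=True)  # hypothetical preprocessing
        h = torch.tanh(self.embed(pos))
        return {"energy": self.energy_head(h).sum(), "hidden": h}

    def forces_forward(self, preds):
        return self.force_head(preds["hidden"])

    def forward(self, data, mode="train", preproc=True):
        if self.regress_forces == "from_energy":
            data["pos"].requires_grad_(True)  # needed to differentiate E w.r.t. pos
        preds = self.energy_forward(data, preproc)
        if self.regress_forces == "direct":
            preds["forces"] = self.forces_forward(preds)
        else:
            grad = torch.autograd.grad(
                preds["energy"], data["pos"], create_graph=(mode == "train")
            )[0]
            preds["forces"] = -grad  # F = -dE/dpos
        return preds


out = ToyModel()({"pos": torch.randn(4, 3)})
print(out["forces"].shape)  # torch.Size([4, 3])
```

Because the toy energy is translation invariant, its gradient-based forces sum to zero over atoms, a useful sanity check for any "from_energy" setup.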

reset_parameters()

Resets all learnable parameters of the module.
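The page does not say how the reset is carried out. A common PyTorch pattern, shown here as an assumption with a hypothetical helper reset_all_parameters, is to call reset_parameters() on every submodule that defines it:

```python
import torch
from torch import nn


def reset_all_parameters(model: nn.Module) -> None:
    # Re-initialize every submodule that defines reset_parameters()
    # (nn.Linear, nn.Embedding, nn.LayerNorm, ...). `model` itself is skipped
    # so this helper can be called from inside model.reset_parameters()
    # without recursing forever.
    for module in model.modules():
        if module is not model and hasattr(module, "reset_parameters"):
            module.reset_parameters()


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
before = model[0].weight.detach().clone()
reset_all_parameters(model)
print(torch.equal(before, model[0].weight))
```

Layers without learnable state, such as nn.ReLU, have no reset_parameters() and are simply skipped.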