faenet.eval
Module Contents
Functions
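| eval_model_symmetries(loader, model, frame_averaging, fa_method, device, task_name[, crystal_task]) | Test rotation and reflection invariance & equivariance of GNNs |
| reflect_graph(batch, frame_averaging, fa_method[, reflection]) | Reflect all graphs in a batch |
| rotate_graph(batch, frame_averaging, fa_method[, rotation]) | Rotate all graphs in a batch |
| transform_batch(batch, frame_averaging, fa_method[, neighbors]) | Apply a transformation to a batch of graphs |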
- faenet.eval.eval_model_symmetries(loader, model, frame_averaging, fa_method, device, task_name, crystal_task=True)
Test rotation and reflection invariance & equivariance of GNNs
- Parameters:
loader (data) – dataloader
model – model instance
frame_averaging (str) – frame averaging (“2D”, “3D”), data augmentation (“DA”) or none (“”)
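fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)
device – device on which the evaluation is run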
task_name (str) – the targeted task (“energy”, “forces”)
crystal_task (bool) – True if the inputs are crystals (i.e. have a unit cell), False if they are molecules (default: True)
- Returns:
metrics measuring the invariance/equivariance of energy/force predictions
- Return type:
(dict)
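The kind of check eval_model_symmetries performs can be illustrated on a toy model. The sketch below is not faenet's implementation: it works on a NumPy position array instead of a data.Batch, uses a hypothetical pairwise-distance "energy" (toy_energy) with its analytic forces (toy_forces), and the metric names in the returned dict are illustrative. It rotates the input with a random proper rotation, then measures how much the energy changes (invariance) and how far the forces are from rotating along with the input (equivariance).

```python
import numpy as np


def random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Draw a uniformly random 3x3 proper rotation (det = +1)."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # fix column signs so q is uniform over O(3)
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # flip one axis to turn a reflection into a rotation
    return q


def toy_energy(pos: np.ndarray) -> float:
    """Hypothetical invariant 'energy': sum of squared pairwise distances."""
    diff = pos[:, None, :] - pos[None, :, :]
    return float((diff ** 2).sum())


def toy_forces(pos: np.ndarray) -> np.ndarray:
    """Analytic forces of toy_energy, i.e. its negative gradient."""
    n = pos.shape[0]
    return -4.0 * (n * pos - pos.sum(axis=0))


def symmetry_metrics(pos: np.ndarray, rng: np.random.Generator) -> dict:
    """Compare predictions on the original and a randomly rotated input."""
    rot = random_rotation(rng)
    rotated = pos @ rot.T  # apply x -> R x to every row (atom)
    energy_diff = abs(toy_energy(rotated) - toy_energy(pos))
    # Equivariance: forces on the rotated input should equal the rotated forces.
    force_diff = float(np.abs(toy_forces(rotated) - toy_forces(pos) @ rot.T).mean())
    return {"energy_invariance": energy_diff, "force_equivariance": force_diff}


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    pos = rng.normal(size=(5, 3))  # 5 atoms in 3D
    # Both values should be ~0, up to floating-point error.
    print(symmetry_metrics(pos, rng))
```

For a real GNN the two numbers are generally not zero; how close they get is what distinguishes exact invariance/equivariance (e.g. full frame averaging) from approximate approaches such as data augmentation.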
- faenet.eval.reflect_graph(batch, frame_averaging, fa_method, reflection=None)
Reflect all graphs in a batch
- Parameters:
batch (data.Batch) – batch of graphs
frame_averaging (str) – Transform method used (“2D”, “3D”, “DA”, “”)
fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)
reflection (str, optional) – type of reflection applied (default: None)
- Returns:
reflected batch sample and the orthogonal (reflection) matrix used to reflect it
- Return type:
(dict)
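A reflection is an orthogonal matrix with determinant -1; one standard way to build it is the Householder matrix I - 2nnᵀ for a unit normal n of the mirror plane. The sketch below works on a plain (N, 3) NumPy array and assumes, hypothetically, that reflection names an axis ("x", "y" or "z") whose orthogonal plane is the mirror, and that None means a random mirror plane. The accepted values of reflection and the keys of the returned dict are not documented above, so reflect_positions, householder and the "pos"/"matrix" keys are illustrative only.

```python
import numpy as np

AXES = {"x": 0, "y": 1, "z": 2}


def householder(normal: np.ndarray) -> np.ndarray:
    """Reflection across the plane orthogonal to `normal` (det = -1)."""
    n = normal / np.linalg.norm(normal)
    return np.eye(3) - 2.0 * np.outer(n, n)


def reflect_positions(pos, reflection=None, rng=None):
    """Hypothetical stand-in for reflect_graph on a plain (N, 3) array."""
    if reflection is None:
        if rng is None:
            rng = np.random.default_rng()
        normal = rng.normal(size=3)  # random mirror plane
    else:
        normal = np.eye(3)[AXES[reflection]]  # e.g. "z" maps z -> -z
    mat = householder(normal)
    return {"pos": pos @ mat.T, "matrix": mat}


if __name__ == "__main__":
    pos = np.array([[0.0, 0.0, 1.0], [1.0, 2.0, 3.0]])
    out = reflect_positions(pos, "z")
    print(out["pos"])  # z coordinates change sign, x and y are unchanged
```

Because a reflection preserves all pairwise distances but flips handedness, a model can be invariant to rotations yet still change its output under reflection; testing both transformations separates the two cases.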
- faenet.eval.rotate_graph(batch, frame_averaging, fa_method, rotation=None)
Rotate all graphs in a batch
- Parameters:
batch (data.Batch) – batch of graphs.
frame_averaging (str) – Transform method used (“2D”, “3D”, “DA”, “”)
fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)
rotation (str, optional) – type of rotation applied (“z”, “x”, “y”, None) (default: None)
- Returns:
rotated batch sample and rotation matrix used to rotate it
- Return type:
(dict)
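The rotation options can be sketched on a plain (N, 3) NumPy array. The sketch assumes, without confirmation from the documentation, that “x”, “y” and “z” mean a rotation by a random angle about that axis and that None means an arbitrary random 3D rotation; rotate_positions, axis_rotation and the "pos"/"matrix" keys are hypothetical names. A rotation about z is presumably the relevant test for 2D frame averaging, which only handles rotations in the x–y plane.

```python
import numpy as np


def axis_rotation(axis: str, angle: float) -> np.ndarray:
    """Rotation matrix about the x, y or z axis by `angle` radians."""
    c, s = np.cos(angle), np.sin(angle)
    if axis == "x":
        return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])
    if axis == "y":
        return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])
    if axis == "z":
        return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    raise ValueError(f"unknown rotation axis: {axis!r}")


def random_rotation(rng: np.random.Generator) -> np.ndarray:
    """Uniformly random proper rotation (det = +1) via QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))  # make the QR factor uniform over O(3)
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]  # map reflections onto rotations
    return q


def rotate_positions(pos, rotation=None, rng=None):
    """Hypothetical stand-in for rotate_graph on a plain (N, 3) array."""
    if rng is None:
        rng = np.random.default_rng()
    if rotation is None:
        mat = random_rotation(rng)  # arbitrary 3D rotation
    else:
        mat = axis_rotation(rotation, rng.uniform(0.0, 2.0 * np.pi))
    return {"pos": pos @ mat.T, "matrix": mat}
```

For example, rotate_positions(pos, "z") leaves every z coordinate unchanged while mixing x and y, and the returned matrix is what a caller would apply to predicted forces when checking equivariance.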
- faenet.eval.transform_batch(batch, frame_averaging, fa_method, neighbors=None)
Apply a transformation to a batch of graphs
- Parameters:
batch (data.Batch) – batch of data.Data objects.
frame_averaging (str) – Transform method used (“2D”, “3D”, “DA”, “”)
fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)
neighbors (list, optional) – list containing the number of edges in each graph of the batch (default: None)
- Returns:
transformed batch sample
- Return type:
(data.Batch)
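In the frame averaging literature, 3D frames are commonly built from the principal axes (PCA eigenvectors) of the centered atom positions. Each eigenvector is only defined up to sign, which gives 8 candidate frames, 4 of which have determinant +1 (the SE(3) case). The sketch below illustrates that construction under these assumptions, with "all" keeping every frame and "stochastic" sampling one. It is not transform_batch itself: it works on a single NumPy position array, omits the “det” method, 2D frames, data augmentation and the neighbors argument, and pca_frames and frame_averaging_views are hypothetical names.

```python
import itertools

import numpy as np


def pca_frames(pos: np.ndarray, se3: bool = False):
    """Center `pos` and build PCA frames, one per sign choice of the axes."""
    centered = pos - pos.mean(axis=0)
    _, eigvecs = np.linalg.eigh(centered.T @ centered)  # columns = principal axes
    frames = []
    for signs in itertools.product([1.0, -1.0], repeat=3):
        frame = eigvecs * np.array(signs)  # flip the sign of selected axes
        if se3 and np.linalg.det(frame) < 0:
            continue  # SE(3) variants keep only proper rotations
        frames.append(frame)
    return centered, frames


def frame_averaging_views(pos, method="all", se3=False, rng=None):
    """Hypothetical helper: positions expressed in the selected frame(s)."""
    centered, frames = pca_frames(pos, se3=se3)
    if method == "stochastic":
        if rng is None:
            rng = np.random.default_rng()
        frames = [frames[rng.integers(len(frames))]]  # sample a single frame
    elif method != "all":
        raise ValueError(f"method not covered by this sketch: {method!r}")
    return [centered @ frame for frame in frames]
```

When every frame is kept, rotating or translating the input only permutes the list of projected views (assuming distinct PCA eigenvalues), so averaging a model's outputs over them yields an invariant prediction; sampling one frame trades that exactness for a single forward pass.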