faenet.eval

Module Contents

Functions

eval_model_symmetries(loader, model, frame_averaging, ...)

Test rotation and reflection invariance & equivariance of GNNs

reflect_graph(batch, frame_averaging, fa_method[, ...])

Reflect all graphs in a batch

rotate_graph(batch, frame_averaging, fa_method[, rotation])

Rotate all graphs in a batch

transform_batch(batch, frame_averaging, fa_method[, ...])

Apply a transformation to a batch of graphs

faenet.eval.eval_model_symmetries(loader, model, frame_averaging, fa_method, device, task_name, crystal_task=True)[source]

Test rotation and reflection invariance & equivariance of GNNs

Parameters:
  • loader (data) – dataloader

  • model – model instance

  • frame_averaging (str) – frame averaging (“2D”, “3D”), data augmentation (“DA”) or none (“”)

  • fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)

  • device – device on which the evaluation is run

  • task_name (str) – the targeted task (“energy”, “forces”)

  • crystal_task (bool) – whether we have a crystal (i.e. a unit cell) or a molecule

Returns:

metrics measuring invariance/equivariance of energy/force predictions

Return type:

(dict)
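To make the idea of an invariance/equivariance metric concrete, here is a minimal, standalone sketch in plain Python. It does not call faenet: the toy "models", the helper names, and the dictionary keys `energy_diff` and `forces_diff` are all assumptions chosen for illustration, not the keys returned by `eval_model_symmetries`. Energy is expected to be invariant under rotation, while forces are expected to rotate with the input.

```python
import math

def rotate_z(points, theta):
    """Rotate a list of (x, y, z) points by angle theta around the z axis."""
    c, s = math.cos(theta), math.sin(theta)
    return [(c * x - s * y, s * x + c * y, z) for x, y, z in points]

def toy_energy(points):
    """Hypothetical invariant 'model': sum of all pairwise distances."""
    total = 0.0
    for i in range(len(points)):
        for j in range(i + 1, len(points)):
            total += math.dist(points[i], points[j])
    return total

def toy_forces(points):
    """Hypothetical equivariant 'model': vector from each atom to the centroid."""
    n = len(points)
    centroid = tuple(sum(p[k] for p in points) / n for k in range(3))
    return [tuple(centroid[k] - p[k] for k in range(3)) for p in points]

def symmetry_metrics(points, theta):
    """Return energy-invariance and force-equivariance errors for a z rotation."""
    rotated = rotate_z(points, theta)
    # Invariance: the energy of the rotated input should match the original.
    energy_diff = abs(toy_energy(points) - toy_energy(rotated))
    # Equivariance: forces of the rotated input should equal the rotated forces.
    expected = rotate_z(toy_forces(points), theta)
    actual = toy_forces(rotated)
    forces_diff = max(
        abs(a - e) for fa, fe in zip(actual, expected) for a, e in zip(fa, fe)
    )
    return {"energy_diff": energy_diff, "forces_diff": forces_diff}

positions = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.5)]
metrics = symmetry_metrics(positions, math.pi / 3)
```

For a model that respects the symmetry, both values should be zero up to floating-point error; larger values indicate that the property is violated.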

faenet.eval.reflect_graph(batch, frame_averaging, fa_method, reflection=None)[source]

Reflect all graphs in a batch

Parameters:
  • batch (data.Batch) – batch of graphs

  • frame_averaging (str) – Transform method used (“2D”, “3D”, “DA”, “”)

  • fa_method (str) – FA method used (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)

  • reflection (str, optional) – type of reflection applied. (default: None)

Returns:

reflected batch sample and the reflection matrix used to reflect it

Return type:

(dict)
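As a rough illustration of what a reflection matrix does to atom positions, the sketch below builds a Householder reflection across a plane in plain Python. It is independent of faenet: the helper names are hypothetical, and how `reflect_graph` actually chooses or builds its matrix is not shown here.

```python
import math

def reflection_matrix(normal):
    """Householder matrix I - 2 n n^T reflecting across the plane with this normal."""
    norm = math.sqrt(sum(c * c for c in normal))
    n = [c / norm for c in normal]
    return [
        [(1.0 if r == k else 0.0) - 2.0 * n[r] * n[k] for k in range(3)]
        for r in range(3)
    ]

def determinant(m):
    """Determinant of a 3x3 matrix given as nested lists."""
    (a, b, c), (d, e, f), (g, h, i) = m
    return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g)

def apply(matrix, point):
    """Multiply a 3x3 matrix by a column vector (x, y, z)."""
    return tuple(sum(matrix[r][k] * point[k] for k in range(3)) for r in range(3))

# Mirror across the xy plane (normal along z).
mirror = reflection_matrix((0.0, 0.0, 1.0))
a, b = (1.0, 2.0, 3.0), (-1.0, 0.5, 2.0)
ra, rb = apply(mirror, a), apply(mirror, b)
```

A reflection has determinant -1 (a rotation has +1), and it preserves interatomic distances, which is why energy predictions are expected to stay unchanged under it.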

faenet.eval.rotate_graph(batch, frame_averaging, fa_method, rotation=None)[source]

Rotate all graphs in a batch

Parameters:
  • batch (data.Batch) – batch of graphs.

  • frame_averaging (str) – Transform method used. (“2D”, “3D”, “DA”, “”)

  • fa_method (str) – FA method used. (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)

  • rotation (str, optional) – axis of the rotation applied (“z”, “x”, “y”, or None). (default: None)

Returns:

rotated batch sample and rotation matrix used to rotate it

Return type:

(dict)
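The axis names “x”, “y” and “z” correspond to the standard rotation matrices about the coordinate axes. The following plain-Python sketch shows those matrices under a column-vector convention; the function names are hypothetical, and neither the convention nor the handling of `rotation=None` inside `rotate_graph` is confirmed here.

```python
import math

def rotation_matrix(axis, theta):
    """Standard 3x3 rotation matrix about a coordinate axis (column-vector convention)."""
    c, s = math.cos(theta), math.sin(theta)
    if axis == "x":
        return [[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]]
    if axis == "y":
        return [[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]]
    if axis == "z":
        return [[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]]
    raise ValueError(f"unknown axis: {axis!r}")

def apply(matrix, point):
    """Multiply a 3x3 matrix by a column vector (x, y, z)."""
    return tuple(sum(matrix[r][k] * point[k] for k in range(3)) for r in range(3))

# A quarter turn about z sends the x unit vector to the y unit vector.
rotated = apply(rotation_matrix("z", math.pi / 2), (1.0, 0.0, 0.0))
```

Returning the matrix alongside the rotated positions, as the documented return value suggests, lets a caller rotate predicted forces with the same matrix when checking equivariance.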

faenet.eval.transform_batch(batch, frame_averaging, fa_method, neighbors=None)[source]

Apply a transformation to a batch of graphs

Parameters:
  • batch (data.Batch) – batch of data.Data objects.

  • frame_averaging (str) – Transform method used.

  • fa_method (str) – FA method used.

  • neighbors (list, optional) – list containing the number of edges in each graph of the batch. (default: None)

Returns:

transformed batch sample

Return type:

(data.Batch)
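The `neighbors` argument is described as the number of edges in each graph of the batch. As a hedged sketch of how such a list could be derived, the example below counts edges per graph from a plain-list edge index and a node-to-graph assignment, mirroring the `edge_index` and `batch` attributes of a PyTorch Geometric batch. The function name and the use of Python lists instead of tensors are assumptions made for illustration.

```python
def edges_per_graph(edge_index, node_to_graph, num_graphs):
    """Count edges per graph, attributing each edge to the graph of its source node."""
    counts = [0] * num_graphs
    for src in edge_index[0]:
        counts[node_to_graph[src]] += 1
    return counts

# Two graphs batched together: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
node_to_graph = [0, 0, 0, 1, 1]
# Directed edges stored as [sources, targets], as in a 2 x E edge index.
edge_index = [[0, 1, 1, 2, 3, 4],
              [1, 0, 2, 1, 4, 3]]
neighbors = edges_per_graph(edge_index, node_to_graph, num_graphs=2)
```

Here `neighbors` evaluates to `[4, 2]`: four directed edges in the first graph and two in the second.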