faenet.transforms

Module Contents

Classes

FrameAveraging – Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).

Transform – Base class for all transforms.

class faenet.transforms.FrameAveraging(frame_averaging=None, fa_method=None)

Bases: Transform

Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).

Computes new atomic positions (fa_pos) for all data points, as well as new unit cell attributes (fa_cell) for crystal structures, when applicable. The rotation matrix (fa_rot) used for the frame averaging is also stored.
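The sketch below illustrates how such frames can be built, assuming the PCA-based construction described in the FAENet paper: centre the positions, take the eigenvectors of their covariance matrix, and enumerate the sign ambiguity of each eigenvector, which gives 2³ = 8 candidate frames in 3D. It uses NumPy rather than PyTorch, and the function name pca_frames is hypothetical; it is not the library's implementation.

```python
import itertools

import numpy as np


def pca_frames(pos: np.ndarray):
    """Enumerate PCA-based frames for an (N, 3) array of positions.

    Hypothetical helper (not the faenet API). Returns a list of
    (fa_pos, fa_rot) pairs, one per sign combination of the principal
    axes, i.e. 2**3 = 8 frames in 3D.
    """
    # Centre the point cloud so the frames are translation invariant.
    centred = pos - pos.mean(axis=0, keepdims=True)
    # Covariance matrix of the positions, up to a 1/N factor (3 x 3, symmetric).
    cov = centred.T @ centred
    # eigh returns eigenvalues in ascending order and eigenvectors as columns.
    _, eigvecs = np.linalg.eigh(cov)
    frames = []
    for signs in itertools.product([1.0, -1.0], repeat=3):
        # Flip the sign of each eigenvector (column): each choice is one frame.
        rot = eigvecs * np.array(signs)
        # Express the centred positions in that frame.
        frames.append((centred @ rot, rot))
    return frames
```

Because every sign combination is enumerated, rotating, reflecting or translating the input only permutes the returned frames, so averaging over all of them is invariant. This holds when the covariance eigenvalues are distinct; degenerate (highly symmetric) structures need extra care.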

Parameters:
  • frame_averaging (str) – Transform method used: 2D FA, 3D FA, Data Augmentation or no FA, respectively denoted by “2D”, “3D”, “DA” and “” (empty string).

  • fa_method (str) – The frame averaging technique used. “stochastic” samples one frame at random (at each epoch), “det” chooses one frame deterministically, and “all” uses all frames. The prefix “se3-” selects the SE(3)-equivariant version of the method. “” means that no frame averaging is used. Allowed values: “”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”.
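A minimal sketch of how the fa_method strings could map to a choice among candidate frames. It assumes the candidates are given as orthogonal matrices (for instance the 8 sign-flipped PCA frames of 3D FA) and that the “se3-” prefix keeps only proper rotations (determinant +1), which gives equivariance to rotations but not to reflections. The function select_frames and the "first candidate" rule used for “det” are hypothetical illustrations, not the library's logic.

```python
import numpy as np


def select_frames(rots, fa_method, rng=None):
    """Return the indices of the frames to use for a given fa_method.

    rots: list of (d, d) orthogonal matrices, one per candidate frame.
    Hypothetical helper; the "det" rule below is an assumption.
    """
    if fa_method == "":
        return []  # no frame averaging
    base = fa_method
    candidates = list(range(len(rots)))
    if fa_method.startswith("se3-"):
        base = fa_method[len("se3-"):]
        # Keep proper rotations only (det = +1); reflections are dropped.
        candidates = [i for i in candidates if np.linalg.det(rots[i]) > 0]
    if base == "all":
        return candidates
    if base == "det":
        # Hypothetical deterministic rule: always take the first candidate.
        return candidates[:1]
    if base == "stochastic":
        # A fresh random choice each call, e.g. once per epoch.
        rng = rng if rng is not None else np.random.default_rng()
        return [int(rng.choice(candidates))]
    raise ValueError(f"unknown fa_method: {fa_method!r}")
```

With the 8 sign-flip candidates of 3D FA, “all” keeps 8 frames while “se3-all” keeps the 4 with determinant +1.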

Returns:

The updated data object, with new position (and, when applicable, unit cell) attributes and the rotation matrices used for the frame averaging transform.

Return type:

(data.Data)
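For crystal structures, the unit cell has to be expressed in the same frame as the positions. A minimal sketch, assuming the cell stores one lattice vector per row (a convention assumed here, not confirmed by this page): lattice vectors are displacements, so they are multiplied by the same rotation matrix as the positions but are not centred. The helper names rotate_cell and frame_with_cell are hypothetical.

```python
import numpy as np


def rotate_cell(cell: np.ndarray, rot: np.ndarray) -> np.ndarray:
    """Rotate a (3, 3) unit cell whose rows are lattice vectors.

    Lattice vectors are displacements, so they are rotated with the same
    matrix as the positions but are not centred.
    """
    return cell @ rot


def frame_with_cell(pos: np.ndarray, cell: np.ndarray, rot: np.ndarray):
    """Apply one frame to positions and cell; returns (fa_pos, fa_cell)."""
    centred = pos - pos.mean(axis=0, keepdims=True)
    return centred @ rot, rotate_cell(cell, rot)
```

A quick consistency check is that fractional coordinates, centred_pos @ inv(cell), are the same before and after the frame is applied, and that the lattice vector lengths are unchanged.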

__call__(data)

The only requirement for the data is to have a pos attribute.

class faenet.transforms.Transform

Base class for all transforms.

abstract __call__(data)

__str__()

Return str(self).
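Because __call__ is abstract, a concrete transform only has to implement it. The sketch below re-declares a stand-in for Transform with Python's abc module so that it runs without faenet installed, and uses types.SimpleNamespace with a NumPy array in place of a PyG Data object. The CenterPositions subclass and the string returned by __str__ are hypothetical.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace

import numpy as np


class Transform(ABC):
    """Stand-in mirroring the documented interface (not the library code)."""

    @abstractmethod
    def __call__(self, data):
        ...

    def __str__(self):
        # Hypothetical format: the real __str__ may render differently.
        return f"{self.__class__.__name__}()"


class CenterPositions(Transform):
    """Hypothetical transform: translate data.pos so its centroid is at 0."""

    def __call__(self, data):
        # Like FrameAveraging, it only needs a pos attribute on data.
        data.pos = data.pos - data.pos.mean(axis=0, keepdims=True)
        return data


data = SimpleNamespace(pos=np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]))
data = CenterPositions()(data)
print(str(CenterPositions()))  # prints "CenterPositions()"
print(data.pos.tolist())  # prints [[-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
```

Instantiating the base class directly raises TypeError, since its abstract __call__ is not implemented.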