faenet.transforms
Module Contents
Classes
FrameAveraging | Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).
Transform | Base class for all transforms.
- class faenet.transforms.FrameAveraging(frame_averaging=None, fa_method=None)
Bases: Transform
Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).
Computes new atomic positions (fa_pos) for all data points and, for crystal structures, a new unit cell attribute (fa_cell) when applicable. The rotation matrix (fa_rot) used for frame averaging is also stored.
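Frames of this kind are commonly built from the principal axes of the centered atomic positions, where each axis is only defined up to its sign. The NumPy sketch below illustrates that construction for the 3D case under this assumption; `compute_frames` is a hypothetical helper, not part of faenet, and it may differ from how the library actually computes fa_pos and fa_rot.

```python
import itertools

import numpy as np


def compute_frames(pos: np.ndarray, se3: bool = False):
    """Hypothetical sketch: PCA-based frames for a 3D point cloud.

    Returns a list of (rotation, projected_positions) pairs, one per frame.
    """
    # Center the positions so the frames do not depend on translations.
    centered = pos - pos.mean(axis=0)
    # 3x3 covariance (scatter) matrix of the centered positions.
    cov = centered.T @ centered
    # Columns of eigvecs are the principal axes (eigh sorts eigenvalues ascending).
    _, eigvecs = np.linalg.eigh(cov)

    frames = []
    # Each axis is only defined up to sign: enumerate all 2^3 sign flips.
    for signs in itertools.product([1.0, -1.0], repeat=3):
        rot = eigvecs * np.array(signs)  # multiply column j by signs[j]
        # SE(3) variant: keep proper rotations only (determinant +1).
        if se3 and np.linalg.det(rot) < 0:
            continue
        # Express the centered positions in this frame.
        frames.append((rot, centered @ rot))
    return frames
```

Passing `se3=True` keeps only the four sign combinations whose matrix has determinant +1. Because the frames are derived from the data itself, rotating or translating the input leaves the set of projected positions unchanged (assuming distinct eigenvalues), which is the property frame averaging relies on.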
- Parameters:
frame_averaging (str) – Transform method to use: 2D FA, 3D FA, data augmentation, or no FA, denoted respectively by (“2D”, “3D”, “DA”, “”)
fa_method (str) – The frame averaging technique to apply: “stochastic” samples one frame at random (at each epoch), “det” chooses one frame deterministically, and “all” uses all frames. The prefix “se3-” denotes the SE(3)-equivariant version of each method, and “” means that no frame averaging is used. Accepted values: (“”, “stochastic”, “all”, “det”, “se3-stochastic”, “se3-all”, “se3-det”)
- Returns:
The updated data object, with new position (and, when applicable, unit cell) attributes and the rotation matrices used for the frame averaging transform.
- Return type:
(data.Data)
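The fa_method options can be read as a rule for which candidate frames to keep. The plain-Python sketch below illustrates that reading; `select_frames` and `det3` are hypothetical helpers, the candidate frames are assumed to be given as 3x3 rotation matrices, and picking the first frame for “det” is an assumption rather than documented library behavior.

```python
import random
from typing import List, Optional, Sequence

Matrix = List[List[float]]


def det3(m: Matrix) -> float:
    """Determinant of a 3x3 matrix given as nested lists."""
    return (
        m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
        - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
        + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0])
    )


def select_frames(
    rotations: Sequence[Matrix], fa_method: str, rng: Optional[random.Random] = None
) -> List[Matrix]:
    """Hypothetical sketch: choose which candidate frames to use for fa_method."""
    valid = {"", "stochastic", "all", "det", "se3-stochastic", "se3-all", "se3-det"}
    if fa_method not in valid:
        raise ValueError(f"unknown fa_method: {fa_method!r}")
    if fa_method == "":
        return []  # no frame averaging: no frames are selected
    candidates = list(rotations)
    if fa_method.startswith("se3-"):
        # SE(3) variant: keep proper rotations only (determinant +1).
        candidates = [r for r in candidates if det3(r) > 0]
        fa_method = fa_method[len("se3-"):]
    if fa_method == "all":
        return candidates
    if fa_method == "det":
        # Deterministic choice: always the first candidate (an assumption).
        return candidates[:1]
    # "stochastic": one frame drawn uniformly at random on each call (e.g. each epoch).
    rng = rng or random.Random()
    return [rng.choice(candidates)]
```

Under this reading, “all” multiplies the cost of a forward pass by the number of frames, while “stochastic” and “det” keep a single frame per call.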