Source code for faenet.transforms

from faenet.frame_averaging import (
    frame_averaging_2D,
    frame_averaging_3D,
    data_augmentation,
)


[docs]class Transform: """Base class for all transforms."""
[docs] def __call__(self, data): raise NotImplementedError
[docs] def __str__(self): name = self.__class__.__name__ items = [ f"{k}={v}" for k, v in self.__dict__.items() if not callable(v) and k != "inactive" ] s = f"{name}({', '.join(items)})" if self.inactive: s = f"[inactive] {s}" return s
[docs]class FrameAveraging(Transform): r"""Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs). Computes new atomic positions (`fa_pos`) for all datapoints, as well as new unit cells (`fa_cell`) attributes for crystal structures, when applicable. The rotation matrix (`fa_rot`) used for the frame averaging is also stored. Args: frame_averaging (str): Transform method used. Can be 2D FA, 3D FA, Data Augmentation or no FA, respectively denoted by (`"2D"`, `"3D"`, `"DA"`, `""`) fa_method (str): the actual frame averaging technique used. "stochastic" refers to sampling one frame at random (at each epoch), "det" to chosing deterministically one frame, and "all" to using all frames. The prefix "se3-" refers to the SE(3) equivariant version of the method. "" means that no frame averaging is used. (`""`, `"stochastic"`, `"all"`, `"det"`, `"se3-stochastic"`, `"se3-all"`, `"se3-det"`) Returns: (data.Data): updated data object with new positions (+ unit cell) attributes and the rotation matrices used for the frame averaging transform. """ def __init__(self, frame_averaging=None, fa_method=None): self.fa_method = ( "stochastic" if (fa_method is None or fa_method == "") else fa_method ) self.frame_averaging = "" if frame_averaging is None else frame_averaging self.inactive = not self.frame_averaging assert self.frame_averaging in { "", "2D", "3D", "DA", } assert self.fa_method in { "", "stochastic", "det", "all", "se3-stochastic", "se3-det", "se3-all", } if self.frame_averaging: if self.frame_averaging == "2D": self.fa_func = frame_averaging_2D elif self.frame_averaging == "3D": self.fa_func = frame_averaging_3D elif self.frame_averaging == "DA": self.fa_func = data_augmentation else: raise ValueError(f"Unknown frame averaging: {self.frame_averaging}")
[docs] def __call__(self, data): """The only requirement for the data is to have a `pos` attribute.""" if self.inactive: return data elif self.frame_averaging == "DA": return self.fa_func(data, self.fa_method) else: data.fa_pos, data.fa_cell, data.fa_rot = self.fa_func( data.pos, data.cell if hasattr(data, "cell") else None, self.fa_method ) return data