from faenet.frame_averaging import (
frame_averaging_2D,
frame_averaging_3D,
data_augmentation,
)
[docs]class FrameAveraging(Transform):
r"""Frame Averaging (FA) Transform for (PyG) Data objects (e.g. 3D atomic graphs).
Computes new atomic positions (`fa_pos`) for all datapoints, as well as
new unit cells (`fa_cell`) attributes for crystal structures, when applicable.
The rotation matrix (`fa_rot`) used for the frame averaging is also stored.
Args:
frame_averaging (str): Transform method used.
Can be 2D FA, 3D FA, Data Augmentation or no FA, respectively denoted by
(`"2D"`, `"3D"`, `"DA"`, `""`)
fa_method (str): the actual frame averaging technique used.
"stochastic" refers to sampling one frame at random (at each epoch),
"det" to chosing deterministically one frame, and "all" to using all frames.
The prefix "se3-" refers to the SE(3) equivariant version of the method.
"" means that no frame averaging is used.
(`""`, `"stochastic"`, `"all"`, `"det"`, `"se3-stochastic"`, `"se3-all"`, `"se3-det"`)
Returns:
(data.Data): updated data object with new positions (+ unit cell) attributes
and the rotation matrices used for the frame averaging transform.
"""
def __init__(self, frame_averaging=None, fa_method=None):
self.fa_method = (
"stochastic" if (fa_method is None or fa_method == "") else fa_method
)
self.frame_averaging = "" if frame_averaging is None else frame_averaging
self.inactive = not self.frame_averaging
assert self.frame_averaging in {
"",
"2D",
"3D",
"DA",
}
assert self.fa_method in {
"",
"stochastic",
"det",
"all",
"se3-stochastic",
"se3-det",
"se3-all",
}
if self.frame_averaging:
if self.frame_averaging == "2D":
self.fa_func = frame_averaging_2D
elif self.frame_averaging == "3D":
self.fa_func = frame_averaging_3D
elif self.frame_averaging == "DA":
self.fa_func = data_augmentation
else:
raise ValueError(f"Unknown frame averaging: {self.frame_averaging}")
[docs] def __call__(self, data):
"""The only requirement for the data is to have a `pos` attribute."""
if self.inactive:
return data
elif self.frame_averaging == "DA":
return self.fa_func(data, self.fa_method)
else:
data.fa_pos, data.fa_cell, data.fa_rot = self.fa_func(
data.pos, data.cell if hasattr(data, "cell") else None, self.fa_method
)
return data