faenet.frame_averaging
Module Contents
Functions
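check_constraints(eigenval, eigenvec[, dim]) | Check that the requirements for frame averaging are satisfied.
compute_frames(eigenvec, pos, cell[, fa_method, ...]) | Compute all frames for a given graph, i.e. all possible canonical representations of the 3D graph.
data_augmentation(g[, d]) | Data augmentation: randomly rotated graphs are added in the dataloader transform.
frame_averaging_2D(pos[, cell, fa_method, check]) | Computes new positions for the graph atoms using frame averaging (2D case).
frame_averaging_3D(pos[, cell, fa_method, check]) | Computes new positions for the graph atoms using frame averaging (3D case).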
- faenet.frame_averaging.check_constraints(eigenval, eigenvec, dim=3)
Check that the requirements for frame averaging are satisfied
- Parameters:
eigenval (tensor) – eigenvalues
eigenvec (tensor) – eigenvectors
dim (int) – dimension of the frame averaging: 2 (2D) or 3 (3D). Default: 3.
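A minimal sketch of what such requirements might look like, written as a hypothetical NumPy function rather than the library's own code. The name `check_constraints_sketch`, the 0.9 eigenvalue-ratio threshold and the tolerance are assumptions, and it assumes the eigenvalues are sorted in ascending order (as `numpy.linalg.eigh` returns them). It flags nearly degenerate eigenvalues, which make the principal axes ill-defined, and an eigenvector matrix that is not orthonormal.

```python
import numpy as np


def check_constraints_sketch(eigenval, eigenvec, dim=3, ratio=0.9, atol=1e-3):
    """Hypothetical frame averaging sanity checks; returns a list of issues found.

    Assumes eigenval is sorted in ascending order, as numpy.linalg.eigh returns it.
    The ratio threshold and the tolerance are illustrative assumptions.
    """
    issues = []
    # Nearly equal eigenvalues make the principal axes, and hence the frames, unstable.
    for i in range(dim - 1):
        if eigenval[i + 1] <= 0 or eigenval[i] / eigenval[i + 1] > ratio:
            issues.append(f"eigenvalues {i} and {i + 1} are nearly degenerate")
    # The eigenvectors (columns) should form an orthonormal basis: V^T V = I.
    if not np.allclose(eigenvec.T @ eigenvec, np.eye(dim), atol=atol):
        issues.append("eigenvectors are not orthonormal")
    return issues
```

Returning a list of issues instead of raising keeps the check optional, which fits the `check=False` default of the frame averaging functions.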
- faenet.frame_averaging.compute_frames(eigenvec, pos, cell, fa_method='stochastic', pos_3D=None, det_index=0)
Compute all frames for a given graph, i.e. all possible canonical representations of the 3D graph with respect to Euclidean transformations.
- Parameters:
eigenvec (tensor) – eigenvectors matrix
pos (tensor) – centered position vector
cell (tensor) – unit cell matrix of size d×d
fa_method (str) – the Frame Averaging (FA) inspired technique chosen to select frames: stochastic-FA (stochastic), deterministic-FA (det), Full-FA (all) or SE(3)-FA (se3).
pos_3D (tensor, optional) – for 2D FA, the atoms’ third (z) position coordinate. Default: None.
- Returns:
3D position tensors of the projected representations, one per selected frame
- Return type:
(list)
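A minimal sketch of frame enumeration, assuming NumPy arrays instead of tensors and a hypothetical name `compute_frames_sketch`. Because each principal axis is only defined up to its sign, a d-dimensional PCA yields 2^d candidate frames: `all` keeps them all, `stochastic` samples one, and `det` picks one by a fixed rule. The column-sum sign rule used for `det` and the separate `se3` flag (which keeps only frames with determinant +1) are illustrative assumptions, not the library's exact behavior.

```python
import itertools
import random

import numpy as np


def compute_frames_sketch(eigenvec, pos, cell=None, fa_method="stochastic", se3=False):
    """Hypothetical sketch of frame enumeration for PCA-based frame averaging.

    eigenvec: (d, d) matrix whose columns are the principal axes.
    pos: (n, d) centered positions; cell: optional (d, d) unit cell.
    Returns three lists: projected positions, projected cells, frame matrices.
    """
    dim = pos.shape[1]
    # Each principal axis is defined only up to sign: 2^d candidate sign patterns.
    signs = [np.array(s) for s in itertools.product([1.0, -1.0], repeat=dim)]
    if fa_method == "det":
        # Illustrative deterministic rule (assumption): non-negative column sums.
        s = np.where(eigenvec.sum(axis=0) >= 0, 1.0, -1.0)
        if se3 and np.linalg.det(eigenvec * s) < 0:
            s[-1] = -s[-1]  # hypothetical tie-break to obtain det = +1
        signs = [s]
    elif se3:
        # SE(3)-style filtering: keep only proper rotations (determinant +1).
        signs = [s for s in signs if np.linalg.det(eigenvec * s) > 0]
    frames = [eigenvec * s for s in signs]  # flips the sign of selected columns
    if fa_method == "stochastic":
        frames = [random.choice(frames)]  # one random frame per call
    elif fa_method not in ("all", "det"):
        raise ValueError(f"unsupported fa_method in this sketch: {fa_method}")
    fa_pos = [pos @ r for r in frames]
    fa_cell = [None if cell is None else cell @ r for r in frames]
    return fa_pos, fa_cell, frames
```

Flipping one column flips the sign of the determinant, so exactly half of the 2^d frames survive the `se3` filter.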
- faenet.frame_averaging.data_augmentation(g, d=3, *args)
Data augmentation: randomly rotated graphs are added in the dataloader transform.
- Parameters:
g (data.Data) – single graph
d (int) – dimension of the data augmentation (DA) rotation: 2 (around the z-axis) or 3. Default: 3.
rotation (str, optional) – axis around which to rotate the graph. Defaults to ‘z’.
- Returns:
rotated graph
- Return type:
(data.Data)
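A minimal sketch of rotation-based data augmentation, assuming the graph is represented by NumPy arrays of atom positions and an optional unit cell rather than a `data.Data` object; `random_rotation` and `data_augmentation_sketch` are hypothetical names. For d = 3 it samples a uniformly random rotation (QR decomposition of a Gaussian matrix with a sign correction); for d = 2 it rotates around the z-axis only.

```python
import numpy as np


def random_rotation(d=3, rng=None):
    """Sample a random rotation: uniform over SO(3) if d == 3, else around the z-axis."""
    rng = np.random.default_rng() if rng is None else rng
    if d == 3:
        # QR of a Gaussian matrix, with a sign fix on R's diagonal, gives a
        # uniformly distributed orthogonal matrix.
        q, r = np.linalg.qr(rng.normal(size=(3, 3)))
        q = q * np.sign(np.diag(r))
        if np.linalg.det(q) < 0:
            q[:, 0] = -q[:, 0]  # turn a reflection into a proper rotation
        return q
    theta = rng.uniform(-np.pi, np.pi)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def data_augmentation_sketch(pos, cell=None, d=3, rng=None):
    """Rotate (n, 3) atom positions and, if given, a (3, 3) unit cell by one random rotation."""
    rot = random_rotation(d, rng)
    new_pos = pos @ rot.T  # row vectors: x' = R x
    new_cell = None if cell is None else cell @ rot.T  # rows are lattice vectors
    return new_pos, new_cell
```

A rotation preserves all interatomic distances, so each augmented copy describes the same structure seen from a different orientation.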
- faenet.frame_averaging.frame_averaging_2D(pos, cell=None, fa_method='stochastic', check=False)
Computes new positions for the graph atoms using frame averaging, which itself builds on the PCA of atom positions. 2D case: the atoms are projected onto the plane orthogonal to the z-axis (the xy-plane), and frame averaging is applied in that plane only. Motivation: in some settings the z-axis is fixed, so invariance is only needed with respect to transformations in the xy-plane.
- Parameters:
pos (tensor) – positions of atoms in the graph
cell (tensor) – unit cell of the graph, or None if there are no periodic boundary conditions (PBC).
fa_method (str) – FA method used (stochastic, det, all, se3)
check (bool) – check if constraints are satisfied. Default: False.
- Returns:
updated atom positions (tensor), updated unit cell (tensor), and the rotation matrix used (PCA) (tensor)
- Return type:
(tensor, tensor, tensor)
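A minimal sketch of the 2D case, assuming NumPy arrays, the `all` method and no unit cell; `frame_averaging_2d_sketch` is a hypothetical name. Only the x and y coordinates are centered and projected onto their two principal axes, giving 2^2 = 4 sign-ambiguous frames, while the z coordinate is passed through unchanged.

```python
import itertools

import numpy as np


def frame_averaging_2d_sketch(pos):
    """Hypothetical 2D FA: PCA frames in the xy-plane, z coordinate left unchanged.

    pos: (n, 3) atom positions. Returns all 2^2 = 4 framed (n, 3) position arrays
    (the "all" method); unit-cell handling is omitted in this sketch.
    """
    xy = pos[:, :2] - pos[:, :2].mean(axis=0)  # center in the xy-plane only
    _, eigenvec = np.linalg.eigh(xy.T @ xy)  # 2x2 principal axes (columns)
    frames = []
    for signs in itertools.product([1.0, -1.0], repeat=2):
        rot = eigenvec * np.array(signs)  # each axis is defined up to sign
        frames.append(np.column_stack([xy @ rot, pos[:, 2]]))
    return frames
```

As a set, the returned frames do not change when the input is rotated around the z-axis or translated in the xy-plane, provided the two eigenvalues are distinct.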
- faenet.frame_averaging.frame_averaging_3D(pos, cell=None, fa_method='stochastic', check=False)
Computes new positions for the graph atoms using frame averaging, which itself builds on the PCA of atom positions. Base case for 3D inputs.
- Parameters:
pos (tensor) – positions of atoms in the graph
cell (tensor) – unit cell of the graph, or None if there are no periodic boundary conditions (PBC).
fa_method (str) – FA method used (stochastic, det, all, se3-all, se3-det, se3-stochastic)
check (bool) – check if constraints are satisfied. Default: False.
- Returns:
updated atom positions (tensor), updated unit cell (tensor), and the rotation matrix used (PCA) (tensor)
- Return type:
(tensor, tensor, tensor)
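A minimal end-to-end sketch of 3D frame averaging, assuming NumPy arrays and a hypothetical name `frame_averaging_3d_sketch`: positions are centered, the principal axes come from the eigendecomposition of the 3×3 covariance-like matrix, and each of the 2^3 sign choices defines one frame. The column-sum rule for `det`, and the omission of the SE(3) variants and of `check`, are simplifications of this sketch.

```python
import itertools
import random

import numpy as np


def frame_averaging_3d_sketch(pos, cell=None, fa_method="stochastic"):
    """Hypothetical PCA-based 3D frame averaging (methods: "all", "stochastic", "det").

    pos: (n, 3) atom positions; cell: optional (3, 3) unit cell (rows = lattice vectors).
    Returns three lists with one entry per selected frame: framed positions,
    framed cells (None when no cell is given) and 3x3 frame matrices.
    """
    centered = pos - pos.mean(axis=0)  # removes translations
    _, eigenvec = np.linalg.eigh(centered.T @ centered)  # columns = principal axes
    # Sign ambiguity of the three axes gives 2^3 = 8 candidate frames.
    signs = [np.array(s) for s in itertools.product([1.0, -1.0], repeat=3)]
    if fa_method == "det":
        # Illustrative deterministic choice (assumption): non-negative column sums.
        signs = [np.where(eigenvec.sum(axis=0) >= 0, 1.0, -1.0)]
    elif fa_method == "stochastic":
        signs = [random.choice(signs)]  # sample one frame per call
    elif fa_method != "all":
        raise ValueError(f"unsupported fa_method in this sketch: {fa_method}")
    rots = [eigenvec * s for s in signs]  # flip the sign of selected columns
    fa_pos = [centered @ r for r in rots]
    fa_cell = [None if cell is None else cell @ r for r in rots]
    return fa_pos, fa_cell, rots
```

With `fa_method="all"`, rotating, reflecting or translating the input only permutes the returned frames (assuming distinct eigenvalues), which is where the E(3) invariance of frame averaging comes from.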