Source code for faenet.frame_averaging

import random
from copy import deepcopy
from itertools import product
from faenet.utils import RandomRotate

import torch


def compute_frames(
    eigenvec, pos, cell, fa_method="stochastic", pos_3D=None, det_index=0
):
    """Compute the `frames` of a graph, i.e. canonical PCA-aligned representations.

    Each candidate frame flips the signs of the principal directions (the
    columns of ``eigenvec``) and projects the centered positions onto them.

    Args:
        eigenvec (tensor): eigenvector matrix; its columns are the principal
            directions (d x d, with d = pos.shape[1]).
        pos (tensor): centered position matrix (n x d).
        cell (tensor): cell vectors, right-multiplied by the frame's rotation
            matrix. None if there is no periodic cell.
        fa_method (str): the Frame Averaging (FA) technique used to select
            frames: stochastic-FA ("stochastic"), deterministic-FA ("det"),
            Full-FA ("all"), or their SE(3) variants ("se3-stochastic",
            "se3-det", "se3-all"), which keep only rotations whose
            determinant is +1.
        pos_3D (tensor, optional): for 2D FA, the atoms' third coordinate.
            It is appended unchanged to the projected positions, and the
            2x2 rotation is embedded in a 3x3 one acting as identity on it.
        det_index (int): index of the frame returned by the "det" methods.

    Returns:
        (list, list, list): projected positions, projected cells (None
        entries when ``cell`` is None) and rotation matrices (each with a
        leading dimension of size 1). The "all" methods return every kept
        frame; the other methods return a single frame.

    Raises:
        ValueError: if ``fa_method`` is not one of the supported methods.
    """
    valid_methods = {
        "all",
        "stochastic",
        "det",
        "se3-all",
        "se3-stochastic",
        "se3-det",
    }
    if fa_method not in valid_methods:
        raise ValueError(f"Unknown frame averaging method: {fa_method!r}")
    se3 = fa_method.startswith("se3-")
    dim = pos.shape[1]  # differentiates between the 2D and 3D cases

    # Candidate sign flips, created with eigenvec's dtype/device so that the
    # products below never mix CPU/GPU tensors or float precisions.
    if fa_method in {"det", "se3-det"}:
        # Deterministic FA: orient each eigenvector so its entries sum to >= 0.
        col_sum = torch.sum(eigenvec, dim=0)
        ones = torch.ones_like(col_sum)
        plus_minus_list = [torch.where(col_sum >= 0, ones, -ones)]
    else:
        plus_minus_list = [
            torch.tensor(signs, dtype=eigenvec.dtype, device=eigenvec.device)
            for signs in product([1, -1], repeat=dim)
        ]

    all_fa_pos, all_cell, all_rots = [], [], []
    for pm in plus_minus_list:
        # Broadcasting multiplies column j (eigenvector j) by pm[j].
        new_eigenvec = pm * eigenvec
        fa_pos = pos @ new_eigenvec
        if pos_3D is not None:
            fa_pos = torch.cat((fa_pos, pos_3D.unsqueeze(1)), dim=1)
            full_eigenvec = torch.eye(
                3, dtype=new_eigenvec.dtype, device=new_eigenvec.device
            )
            full_eigenvec[:2, :2] = new_eigenvec
            new_eigenvec = full_eigenvec
        fa_cell = cell @ new_eigenvec if cell is not None else None

        # SE(3) variants only keep proper rotations (determinant +1).
        if se3:
            det = torch.linalg.det(new_eigenvec)
            if not torch.allclose(det, torch.ones_like(det), atol=1e-03):
                continue

        all_fa_pos.append(fa_pos)
        all_cell.append(fa_cell)
        all_rots.append(new_eigenvec.unsqueeze(0))

    # Rare case where no candidate is positive orthogonal: fall back to the
    # last frame computed so that at least one frame is always returned.
    if not all_fa_pos:
        all_fa_pos.append(fa_pos)
        all_cell.append(fa_cell)
        all_rots.append(new_eigenvec.unsqueeze(0))

    # Return frame(s) depending on fa_method
    if fa_method in {"all", "se3-all"}:
        return all_fa_pos, all_cell, all_rots
    if fa_method in {"det", "se3-det"}:
        index = det_index
    else:
        index = random.randint(0, len(all_fa_pos) - 1)
    return [all_fa_pos[index]], [all_cell[index]], [all_rots[index]]
def check_constraints(eigenval, eigenvec, dim=3):
    """Check that the requirements for frame averaging are satisfied.

    Prints a message (it does not raise) for each violated requirement:
    nearly equal eigenvalues, non-orthonormal eigenvectors, or an
    eigenvector matrix whose determinant is not 1.

    Args:
        eigenval (tensor): eigenvalues; presumably sorted in descending
            order, as done by both callers in this module.
        eigenvec (tensor): eigenvector matrix (dim x dim).
        dim (int): 2D or 3D frame averaging.
    """
    # Check eigenvalues are different: a ratio of consecutive eigenvalues
    # above 0.90 means two principal axes are nearly degenerate.
    if any(eigenval[i + 1] / eigenval[i] > 0.90 for i in range(dim - 1)):
        print("Eigenvalues are quite similar")

    # Reference tensors share eigenvec's dtype/device: torch.allclose raises
    # when dtypes differ (e.g. float64 eigenvectors vs a float32 identity).
    identity = torch.eye(dim, dtype=eigenvec.dtype, device=eigenvec.device)

    # Check eigenvectors are orthonormal
    if not torch.allclose(eigenvec @ eigenvec.T, identity, atol=1e-03):
        print("Matrix not orthogonal")

    # Check determinant of eigenvectors is 1
    det = torch.linalg.det(eigenvec)
    if not torch.allclose(det, torch.ones_like(det), atol=1e-03):
        print("Determinant is not 1")
def frame_averaging_3D(pos, cell=None, fa_method="stochastic", check=False):
    """Frame-average a 3D graph using the PCA of its atom positions.

    The positions are centered and their covariance is eigendecomposed.
    The frames are then built from the principal directions, ordered by
    decreasing eigenvalue.

    Args:
        pos (tensor): positions of atoms in the graph.
        cell (tensor): unit cell of the graph. None if no pbc.
        fa_method (str): FA method used (stochastic, det, all, se3-all,
            se3-det, se3-stochastic).
        check (bool): whether to check (and print) FA constraint violations.
            Default: False.

    Returns:
        (list): projected atom positions, one tensor per selected frame.
        (list): projected unit cells (None entries when ``cell`` is None).
        (list): rotation matrices used (PCA-based), one per selected frame.
    """
    centered = pos - pos.mean(dim=0, keepdim=True)
    covariance = centered.t() @ centered

    # torch.linalg.eigh yields eigenvalues in ascending order: flip to descending.
    eigenval, eigenvec = torch.linalg.eigh(covariance)
    order = torch.argsort(eigenval, descending=True)
    eigenval, eigenvec = eigenval[order], eigenvec[:, order]

    if check:
        check_constraints(eigenval, eigenvec, 3)

    # The frames only rotate/reflect positions, so distances are preserved
    # and need no update.
    return compute_frames(eigenvec, centered, cell, fa_method)
def frame_averaging_2D(pos, cell=None, fa_method="stochastic", check=False):
    """Frame-average a graph in 2D using the PCA of its atoms' x/y positions.

    2D case: the PCA is computed on the plane orthogonal to the z-axis
    (the first two coordinates, centered). The z coordinate is left
    untouched (not centered) and appended to the projected positions.
    Motivation: sometimes the z-axis is not the most relevant one
    (e.g. it is fixed).

    Args:
        pos (tensor): positions of atoms in the graph; columns 0-1 are used
            for the PCA and column 2 is kept as-is.
        cell (tensor): unit cell of the graph. None if no pbc.
        fa_method (str): FA method used (stochastic, det, all, se3-all,
            se3-det, se3-stochastic).
        check (bool): whether to check (and print) FA constraint violations.
            Default: False.

    Returns:
        (list): projected atom positions, one tensor per selected frame.
        (list): projected unit cells (None entries when ``cell`` is None).
        (list): rotation matrices used (PCA-based), one per selected frame.
    """
    # Compute centroid and covariance of the x/y coordinates only
    pos_2D = pos[:, :2] - pos[:, :2].mean(dim=0, keepdim=True)
    C = torch.matmul(pos_2D.t(), pos_2D)

    # Eigendecomposition, sorted by decreasing eigenvalue
    eigenval, eigenvec = torch.linalg.eigh(C)
    idx = eigenval.argsort(descending=True)
    eigenval = eigenval[idx]
    eigenvec = eigenvec[:, idx]

    if check:
        # The eigendecomposition is 2x2, so the constraints must be checked
        # with dim=2 (dim=3 would index a non-existent third eigenvalue).
        check_constraints(eigenval, eigenvec, 2)

    # Compute all frames; the raw z coordinate is passed through unchanged
    fa_pos, fa_cell, fa_rot = compute_frames(
        eigenvec, pos_2D, cell, fa_method, pos[:, 2]
    )
    # No need to update distances, they are preserved.
    return fa_pos, fa_cell, fa_rot
def data_augmentation(g, d=3, *args):
    """Randomly rotate a graph (data augmentation used in the dataloader transform).

    Args:
        g (data.Data): single graph.
        d (int): dimension of the rotation: 3 rotates around all three axes;
            any other value rotates around the z-axis only (2D).
        *args: ignored by this function.

    Returns:
        (data.Data): rotated graph (first element returned by the
        RandomRotate transform).
    """
    # Random rotation sampled within [-180, 180] around the selected axes.
    axes = [0, 1, 2] if d == 3 else [2]
    rotate = RandomRotate([-180, 180], axes)
    rotated_graph, _, _ = rotate(g)
    return rotated_graph