Holmes
commited on
Commit
·
ca7299e
1
Parent(s):
1af230e
test
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +134 -3
- analysis/Ramachandran_plot.py +99 -0
- analysis/__pycache__/Ramachandran_plot.cpython-310.pyc +0 -0
- analysis/__pycache__/merge_pred_pdb.cpython-310.pyc +0 -0
- analysis/__pycache__/metrics.cpython-310.pyc +0 -0
- analysis/__pycache__/utils.cpython-310.pyc +0 -0
- analysis/__pycache__/utils.cpython-38.pyc +0 -0
- analysis/eval_result.py +66 -0
- analysis/merge_pred_pdb.py +45 -0
- analysis/metrics.py +54 -0
- analysis/pca_analyse.py +116 -0
- analysis/src/__init__.py +0 -0
- analysis/src/__pycache__/__init__.cpython-310.pyc +0 -0
- analysis/src/__pycache__/__init__.cpython-37.pyc +0 -0
- analysis/src/__pycache__/__init__.cpython-39.pyc +0 -0
- analysis/src/__pycache__/eval.cpython-310.pyc +0 -0
- analysis/src/__pycache__/eval.cpython-37.pyc +0 -0
- analysis/src/__pycache__/eval.cpython-39.pyc +0 -0
- analysis/src/common/__init__.py +0 -0
- analysis/src/common/__pycache__/__init__.cpython-310.pyc +0 -0
- analysis/src/common/__pycache__/__init__.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/all_atom.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/data_transforms.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/geo_utils.cpython-310.pyc +0 -0
- analysis/src/common/__pycache__/geo_utils.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc +0 -0
- analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/protein.cpython-310.pyc +0 -0
- analysis/src/common/__pycache__/protein.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/residue_constants.cpython-310.pyc +0 -0
- analysis/src/common/__pycache__/residue_constants.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc +0 -0
- analysis/src/common/__pycache__/rotation3d.cpython-39.pyc +0 -0
- analysis/src/common/all_atom.py +219 -0
- analysis/src/common/data_transforms.py +1194 -0
- analysis/src/common/geo_utils.py +155 -0
- analysis/src/common/pdb_utils.py +353 -0
- analysis/src/common/protein.py +289 -0
- analysis/src/common/residue_constants.py +897 -0
- analysis/src/common/rigid_utils.py +1451 -0
- analysis/src/common/rotation3d.py +596 -0
- analysis/src/data/__init__.py +0 -0
- analysis/src/data/__pycache__/__init__.cpython-39.pyc +0 -0
- analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc +0 -0
- analysis/src/data/components/__init__.py +0 -0
- analysis/src/data/components/__pycache__/__init__.cpython-39.pyc +0 -0
- analysis/src/data/components/__pycache__/dataset.cpython-39.pyc +0 -0
- analysis/src/data/components/dataset.py +321 -0
- analysis/src/data/protein_datamodule.py +242 -0
- analysis/src/eval.py +217 -0
README.md
CHANGED
|
@@ -1,3 +1,134 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# P2DFlow
|
| 2 |
+
|
| 3 |
+
> ## ℹ️ Version 2 of the P2DFlow code is coming soon. It will align with the revised paper and be easier to use and modify (the current version still runs correctly).
|
| 4 |
+
|
| 5 |
+
P2DFlow is a protein ensemble generative model with SE(3) flow matching based on ESMFold. The ensembles generated by P2DFlow can aid in understanding protein functions across various scenarios.
|
| 6 |
+
|
| 7 |
+
Technical details and evaluation results are provided in our paper:
|
| 8 |
+
* [P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching](https://arxiv.org/abs/2411.17196)
|
| 9 |
+
|
| 10 |
+
<p align="center">
|
| 11 |
+
<img src="resources/workflow.jpg" width="600"/>
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+

|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Table of Contents
|
| 18 |
+
1. [Installation](#Installation)
|
| 19 |
+
2. [Prepare Dataset](#Prepare-Dataset)
|
| 20 |
+
3. [Model weights](#Model-weights)
|
| 21 |
+
4. [Training](#Training)
|
| 22 |
+
5. [Inference](#Inference)
|
| 23 |
+
6. [Evaluation](#Evaluation)
|
| 24 |
+
7. [License](#License)
|
| 25 |
+
8. [Citation](#Citation)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## Installation
|
| 29 |
+
In an environment with cuda 11.7, run:
|
| 30 |
+
```
|
| 31 |
+
conda env create -f environment.yml
|
| 32 |
+
```
|
| 33 |
+
To activate the environment, run:
|
| 34 |
+
```
|
| 35 |
+
conda activate P2DFlow
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Prepare Dataset
|
| 39 |
+
#### (tips: If you want to use the data we have preprocessed, please go directly to `3. Process selected dataset`; if you prefer to process the data from scratch or work with your own data, please start from the beginning)
|
| 40 |
+
|
| 41 |
+
#### 1. Download raw ATLAS dataset
|
| 42 |
+
(i) Download the `Analysis & MDs` dataset from [ATLAS](https://www.dsimb.inserm.fr/ATLAS/), or you can use `./dataset/download.py` by running:
|
| 43 |
+
```
|
| 44 |
+
python ./dataset/download.py
|
| 45 |
+
```
|
| 46 |
+
We will use `.pdb` and `.xtc` files for the following calculation.
|
| 47 |
+
|
| 48 |
+
#### 2. Calculate the 'approximate energy' and select representative structures
|
| 49 |
+
(i) Use `gaussian_kde` to calculate the 'approximate energy' (you need to put all the files above in `./dataset`, including `ATLAS_filename.txt`, which lists the filenames of all proteins):
|
| 50 |
+
```
|
| 51 |
+
python ./dataset/traj_analyse.py
|
| 52 |
+
```
|
| 53 |
+
And you will get `traj_info.csv`.
|
| 54 |
+
|
| 55 |
+
(ii) Select representative structures at equal intervals based on the 'approximate energy':
|
| 56 |
+
```
|
| 57 |
+
python ./dataset/md_select.py
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
#### 3. Process selected dataset
|
| 61 |
+
|
| 62 |
+
(i) Download the selected dataset (or get it from the two steps above) from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `selected_dataset.tar`, and decompress it using:
|
| 63 |
+
```
|
| 64 |
+
tar -xvf selected_dataset.tar
|
| 65 |
+
```
|
| 66 |
+
(ii) Preprocess `.pdb` files to get `.pkl` files:
|
| 67 |
+
```
|
| 68 |
+
python ./data/process_pdb_files.py --pdb_dir ${pdb_dir} --write_dir ${write_dir}
|
| 69 |
+
```
|
| 70 |
+
And you will get `metadata.csv`.
|
| 71 |
+
|
| 72 |
+
then compute node representation and pair representation using ESM-2 (`csv_path` is the path of `metadata.csv`):
|
| 73 |
+
```
|
| 74 |
+
python ./data/cal_repr.py --csv_path ${csv_path}
|
| 75 |
+
```
|
| 76 |
+
then compute predicted static structure using ESMFold (`csv_path` is the path of `metadata.csv`):
|
| 77 |
+
```
|
| 78 |
+
python ./data/cal_static_structure.py --csv_path ${csv_path}
|
| 79 |
+
```
|
| 80 |
+
(iii) Provide the necessary `.csv` files for training
|
| 81 |
+
|
| 82 |
+
If you are using the data we have preprocessed, download the `.csv` files from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filenames are `train_dataset.csv` and `train_dataset_energy.csv`(they correspond to `csv_path` and `energy_csv_path` in `./configs/base.yaml` during training).
|
| 83 |
+
|
| 84 |
+
Or, if you are using your own data, you can get `metadata.csv` from step 3 (it corresponds to `csv_path` in `./configs/base.yaml` during training, and you need to split off a subset of it as the training dataset), and get `traj_info.csv` from step 2 (it corresponds to `energy_csv_path`).
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
## Model weights
|
| 89 |
+
Download the pretrained checkpoint from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing) whose filename is `pretrained.ckpt`, and put it into `./weights` folder. You can use the pretrained weight for inference.
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
## Training
|
| 93 |
+
To train P2DFlow, firstly make sure you have prepared the dataset according to `Prepare Dataset`, and put it in the right folder, then modify `./configs/base.yaml` (especially for `csv_path` and `energy_csv_path`). After this, you can run:
|
| 94 |
+
```
|
| 95 |
+
python experiments/train_se3_flows.py
|
| 96 |
+
```
|
| 97 |
+
And you will get the checkpoints in `./ckpt`.
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
## Inference
|
| 101 |
+
To run inference for specified protein sequences, first modify `./inference/valid_seq.csv` and `./configs/inference.yaml` (especially `validset_path`), then run:
|
| 102 |
+
```
|
| 103 |
+
python experiments/inference_se3_flows.py
|
| 104 |
+
```
|
| 105 |
+
And you will get the results in `./inference_outputs/weights/`.
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
## Evaluation
|
| 109 |
+
To evaluate metrics related to fidelity and dynamics, specify paths in `./analysis/eval_test.py`, then run:
|
| 110 |
+
```
|
| 111 |
+
python ./analysis/eval_test.py
|
| 112 |
+
```
|
| 113 |
+
To evaluate PCA, specify paths in `./analysis/pca_analyse.py`, then run:
|
| 114 |
+
```
|
| 115 |
+
python ./analysis/pca_analyse.py
|
| 116 |
+
```
|
| 117 |
+
To draw the Ramachandran plots, specify paths in `./analysis/Ramachandran_plot.py`, then run:
|
| 118 |
+
```
|
| 119 |
+
python ./analysis/Ramachandran_plot.py
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
## License
|
| 123 |
+
This project is licensed under the terms of the GPL-3.0 license.
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## Citation
|
| 127 |
+
```
|
| 128 |
+
@article{jin2024p2dflow,
|
| 129 |
+
title={P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching},
|
| 130 |
+
author={Yaowei Jin and Qi Huang and Ziyang Song and Mingyue Zheng and Dan Teng and Qian Shi},
|
| 131 |
+
journal={arXiv preprint arXiv:2411.17196},
|
| 132 |
+
year={2024}
|
| 133 |
+
}
|
| 134 |
+
```
|
analysis/Ramachandran_plot.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import MDAnalysis as mda
|
| 2 |
+
import numpy as np
|
| 3 |
+
from MDAnalysis.analysis.dihedrals import Ramachandran
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def ramachandran_eval(all_paths, pdb_file, output_dir):
    """Measure how much of each predicted Ramachandran region the reference covers.

    For every directory in ``all_paths`` the file ``pdb_file`` is loaded with
    MDAnalysis and its backbone dihedral angles are collected.  The angles of
    the first directory act as the reference.  For each remaining directory
    the angle plane is rasterised into a grid, and the returned ratio is

        (# grid cells occupied by BOTH reference and prediction)
        / (# grid cells occupied by the prediction)

    Args:
        all_paths: Sequence of directories.  ``all_paths[0]`` is the
            reference; every other entry is a prediction to score.  Each
            directory must contain ``pdb_file``.
        pdb_file: File name (not a path) looked up inside every directory.
        output_dir: Currently unused; kept so existing callers keep working.

    Returns:
        dict mapping ``os.path.basename(path)`` of every non-reference
        directory to its ratio, plus the key ``'file'`` holding ``pdb_file``.
        A prediction that occupies no grid cell at all gets ``nan``.
    """
    # Collect one (n_points, 2) angle array per directory.
    # presumably the raw result is (n_frames, n_residues, 2) with phi/psi in
    # degrees -- TODO confirm against the MDAnalysis Ramachandran docs.
    angle_sets = []
    for dirpath in all_paths:
        universe = mda.Universe(os.path.join(dirpath, pdb_file))
        protein = universe.select_atoms('protein')

        rama = Ramachandran(protein)
        rama.run()
        angle_sets.append(rama.results.angles.reshape([-1, 2]))

    # Grid definition.  NOTE: ``grid_size`` edges give ``grid_size - 1`` bins
    # per axis; this is kept as-is so the scores stay comparable with
    # previously produced results.
    grid_size = 360
    x_bins = np.linspace(-180, 180, grid_size)
    y_bins = np.linspace(-180, 180, grid_size)

    # The reference occupancy mask does not depend on the prediction, so it
    # is computed once instead of once per prediction.
    ref_points = angle_sets[0]
    ref_hist, _, _ = np.histogram2d(
        ref_points[:, 0], ref_points[:, 1], bins=[x_bins, y_bins])
    ref_mask = ref_hist > 0

    result_tmp = {}
    for dirpath, pred_points in zip(all_paths[1:], angle_sets[1:]):
        # 2D histogram of the predicted angles on the same grid.
        pred_hist, _, _ = np.histogram2d(
            pred_points[:, 0], pred_points[:, 1], bins=[x_bins, y_bins])
        # Boolean occupancy: does at least one point fall into the cell?
        pred_mask = pred_hist > 0

        intersection = np.logical_and(ref_mask, pred_mask).sum()
        occupied = pred_mask.sum()
        # Make the "prediction occupies nothing" case explicit instead of
        # relying on a NumPy divide-by-zero warning.
        val_ratio = intersection / occupied if occupied else float('nan')

        name = os.path.basename(dirpath)
        print(name, "val_ratio:", val_ratio)
        result_tmp[name] = val_ratio

    result_tmp['file'] = pdb_file
    return result_tmp
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
    # Directory name of the P2DFlow run being scored; it is also used as the
    # column name in the output table and as a suffix of the CSV file name.
    key1 = 'P2DFlow_epoch19'

    # The first entry is the reference ensemble; every following entry is a
    # set of predictions scored against it by ``ramachandran_eval``.
    all_paths = [
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/ATLAS_valid",
        # "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/esm_n_pred",
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/alphaflow_pred",
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Str2Str_pred",
        f'/cluster/home/shiqian/frame-flow-test1/valid/evaluate/{key1}',
    ]
    output_dir = '/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Ramachandran'
    os.makedirs(output_dir, exist_ok=True)

    # One column per scored directory, derived from ``all_paths`` so the
    # table can never drift out of sync with the list above.
    results = {'file': []}
    for pred_path in all_paths[1:]:
        results[os.path.basename(pred_path)] = []

    for file in os.listdir(all_paths[0]):
        # Raw string: '\.' in a normal string literal is an invalid escape.
        if not re.search(r'\.pdb', file):
            continue

        print(file)
        result_tmp = ramachandran_eval(
            all_paths=all_paths,
            pdb_file=file,
            output_dir=output_dir,
        )
        for key, column in results.items():
            column.append(result_tmp[key])

    out_total_df = pd.DataFrame(results)
    out_total_df.to_csv(os.path.join(output_dir, f'Ramachandran_plot_validity_{key1}.csv'), index=False)
|
analysis/__pycache__/Ramachandran_plot.cpython-310.pyc
ADDED
|
Binary file (2.25 kB). View file
|
|
|
analysis/__pycache__/merge_pred_pdb.cpython-310.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
analysis/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (2.07 kB). View file
|
|
|
analysis/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.2 kB). View file
|
|
|
analysis/__pycache__/utils.cpython-38.pyc
ADDED
|
Binary file (2.18 kB). View file
|
|
|
analysis/eval_result.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import sys
|
| 4 |
+
sys.path.append('./analysis')
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from src.eval import evaluate_prediction
|
| 9 |
+
from merge_pred_pdb import merge_pdb_full
|
| 10 |
+
from Ramachandran_plot import ramachandran_eval
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--pred_org_dir", type=str, default="./inference_outputs/weights/pretrained/2025-03-13_10-08")
    parser.add_argument("--valid_csv_file", type=str, default="./inference/valid_seq.csv")
    parser.add_argument("--pred_merge_dir", type=str, default="./inference/test/pred_merge_results")
    parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
    parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")

    args = parser.parse_args()

    # Step 1: merge the per-sample prediction files into one PDB per protein.
    merge_pdb_full(args.pred_org_dir, args.valid_csv_file, args.pred_merge_dir)

    # Step 2: run the evaluation on the merged predictions.
    evaluate_prediction(args.pred_merge_dir, args.target_dir, args.crystal_dir)

    # Step 3: Ramachandran-plot validity.  The first path is the reference,
    # every following path is scored against it.
    all_paths = [
        args.target_dir,
        args.pred_merge_dir,
    ]
    results = {}
    for file in os.listdir(all_paths[0]):
        # Raw string: '\.' in a normal string literal is an invalid escape.
        if not re.search(r'\.pdb', file):
            continue

        print(file)
        result_tmp = ramachandran_eval(
            all_paths=all_paths,
            pdb_file=file,
            output_dir=args.pred_merge_dir,
        )

        for pred_path in all_paths[1:]:
            key_name = os.path.basename(pred_path)
            # The original test ``key_name is results.keys()`` was always
            # False, so every file overwrote the previous one and only the
            # last file's ratio reached the CSV.  Accumulate instead.
            results.setdefault(key_name, []).append(result_tmp[key_name])

    rp_csv_path = os.path.join(args.pred_merge_dir, 'Ramachandran_plot_validity.csv')
    out_total_df = pd.DataFrame(results)
    out_total_df.to_csv(rp_csv_path, index=False)
    print(f"RP results saved to {rp_csv_path}")
|
| 65 |
+
|
| 66 |
+
|
analysis/merge_pred_pdb.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from Bio.PDB import PDBParser, PDBIO
|
| 5 |
+
|
| 6 |
+
def merge_pdb(work_dir, new_file, ref_pdb):
    """Merge all sampled PDB files of one protein into a multi-model PDB file.

    Every sub-directory of ``work_dir`` whose name contains ``ref_pdb`` is
    scanned for files whose name starts with ``sample`` and contains
    ``.pdb``.  The models of all those files are appended to the structure of
    the first file found and the result is written to ``new_file``.

    Args:
        work_dir: Directory containing one sub-directory per inference run.
        new_file: Path of the merged PDB file to write.
        ref_pdb: Protein name used to select sub-directories.  It is matched
            literally (regex metacharacters are escaped).

    Returns:
        None.  If no sample file is found, nothing is written.
    """
    # Compile once; ``re.escape`` keeps names containing '.', '+', '(' ... from
    # being interpreted as regex syntax.
    dir_pattern = re.compile(".*" + re.escape(ref_pdb))
    sample_pattern = re.compile(r"sample.*\.pdb")

    # First collect the matching files, so that nothing is parsed (and no
    # parser is built) when there is nothing to merge.
    sample_files = []
    for pdb_dir in os.listdir(work_dir):
        pdb_dir_full = os.path.join(work_dir, pdb_dir)
        if not (os.path.isdir(pdb_dir_full) and dir_pattern.match(pdb_dir)):
            continue
        for pdb_file in os.listdir(pdb_dir_full):
            if sample_pattern.match(pdb_file):
                sample_files.append((pdb_file, os.path.join(pdb_dir_full, pdb_file)))

    if not sample_files:
        return

    parser = PDBParser()
    structures = [parser.get_structure(name, path) for name, path in sample_files]
    print(ref_pdb, len(structures), "files")

    # Re-number the models of every further structure consecutively and
    # attach them to the first structure.
    # NOTE(review): assumes each sample file holds a single model, so the
    # first structure only owns model 0 -- confirm for multi-model inputs.
    new_structure = structures[0]
    count = 0
    for structure in structures[1:]:
        for model in structure:
            count += 1
            model.id = count
            new_structure.add(model)

    io = PDBIO()
    io.set_structure(new_structure)
    io.save(new_file)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def merge_pdb_full(inference_dir_f, valid_csv, output_dir):
    """Merge the sampled PDB files of every protein listed in a CSV file.

    Args:
        inference_dir_f: Directory handed to ``merge_pdb`` as its work dir.
        valid_csv: CSV file with a ``file`` column of protein names.
        output_dir: Directory (created if missing) that receives one
            ``<name>.pdb`` per protein.
    """
    os.makedirs(output_dir, exist_ok=True)
    protein_names = pd.read_csv(valid_csv)['file']
    for protein_name in protein_names:
        merged_path = os.path.join(output_dir, protein_name + ".pdb")
        merge_pdb(inference_dir_f, merged_path, protein_name)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
analysis/metrics.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Metrics. """
|
| 2 |
+
import mdtraj as md
|
| 3 |
+
import numpy as np
|
| 4 |
+
from openfold.np import residue_constants
|
| 5 |
+
from tmtools import tm_align
|
| 6 |
+
from data import utils as du
|
| 7 |
+
|
| 8 |
+
def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
    """Align two structures with ``tm_align`` and return both TM-scores.

    Returns:
        Tuple ``(tm_norm_chain1, tm_norm_chain2)`` taken from the alignment
        result.
    """
    alignment = tm_align(pos_1, pos_2, seq_1, seq_2)
    score_chain1 = alignment.tm_norm_chain1
    score_chain2 = alignment.tm_norm_chain2
    return score_chain1, score_chain2
|
| 11 |
+
|
| 12 |
+
def calc_mdtraj_metrics(pdb_path):
    """Compute secondary-structure fractions and radius of gyration of a PDB.

    The file is loaded with mdtraj; simplified DSSP codes ('C', 'H', 'E') are
    averaged over the whole DSSP array, and the radius of gyration of the
    first frame is taken.

    Returns:
        dict with keys ``non_coil_percent`` (helix + strand),
        ``coil_percent``, ``helix_percent``, ``strand_percent`` and
        ``radius_of_gyration``.  If an ``IndexError`` is raised while
        computing, the error is printed and every value is ``0.0``.
    """
    try:
        trajectory = md.load(pdb_path)
        dssp_codes = md.compute_dssp(trajectory, simplified=True)
        coil_fraction = np.mean(dssp_codes == 'C')
        helix_fraction = np.mean(dssp_codes == 'H')
        strand_fraction = np.mean(dssp_codes == 'E')
        radius_of_gyration = md.compute_rg(trajectory)[0]
    except IndexError as e:
        print('Error in calc_mdtraj_metrics: {}'.format(e))
        coil_fraction = 0.0
        helix_fraction = 0.0
        strand_fraction = 0.0
        radius_of_gyration = 0.0
        non_coil_fraction = 0.0
    else:
        non_coil_fraction = helix_fraction + strand_fraction

    metrics = {
        'non_coil_percent': non_coil_fraction,
        'coil_percent': coil_fraction,
        'helix_percent': helix_fraction,
        'strand_percent': strand_fraction,
        'radius_of_gyration': radius_of_gyration,
    }
    return metrics
|
| 35 |
+
|
| 36 |
+
def calc_aligned_rmsd(pos_1, pos_2):
    """Superimpose ``pos_1`` onto ``pos_2`` and return their mean deviation.

    The first element returned by ``du.rigid_transform_3D(pos_1, pos_2)`` is
    used as the aligned copy of ``pos_1``.

    NOTE: despite the name, the value returned is the arithmetic mean of the
    per-point Euclidean distances, not a root-mean-square.
    """
    transform_result = du.rigid_transform_3D(pos_1, pos_2)
    aligned = transform_result[0]
    per_point_dist = np.linalg.norm(aligned - pos_2, axis=-1)
    return np.mean(per_point_dist)
|
| 39 |
+
|
| 40 |
+
def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
    """Check consecutive CA-CA distances and count CA-CA clashes.

    Args:
        ca_pos: Array of CA coordinates, one row per residue in chain order
            (indexed as ``ca_pos[:, None, :]``, so it must be 2-D).
        bond_tol: Tolerance added to ``residue_constants.ca_ca``; a
            consecutive distance below that sum counts as valid.
        clash_tol: Pairs closer than this distance count as clashes.

    Returns:
        dict with ``ca_ca_deviation`` (mean absolute deviation of consecutive
        distances from ``residue_constants.ca_ca``), ``ca_ca_valid_percent``
        (fraction of valid consecutive distances) and ``num_ca_ca_clashes``.
    """
    ideal_ca_ca = residue_constants.ca_ca

    # Distance of each residue to its predecessor; the first entry pairs the
    # first residue with the last one (wrap-around of np.roll) and is dropped.
    previous_pos = np.roll(ca_pos, 1, axis=0)
    consecutive_dists = np.linalg.norm(ca_pos - previous_pos, axis=-1)[1:]
    deviation = np.mean(np.abs(consecutive_dists - ideal_ca_ca))
    valid_fraction = np.mean(consecutive_dists < (ideal_ca_ca + bond_tol))

    # All-against-all distances; keep the strictly positive upper-triangle
    # entries so every pair is counted once.
    pairwise = np.linalg.norm(
        ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
    upper_triangle = np.triu(pairwise, k=0)
    pair_dists = pairwise[np.where(upper_triangle > 0)]
    clash_mask = pair_dists < clash_tol

    return {
        'ca_ca_deviation': deviation,
        'ca_ca_valid_percent': valid_fraction,
        'num_ca_ca_clashes': np.sum(clash_mask),
    }
|
analysis/pca_analyse.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import MDAnalysis as mda
|
| 5 |
+
from MDAnalysis.analysis import pca, align, rms
|
| 6 |
+
import numpy as np
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import seaborn as sns
|
| 9 |
+
import warnings
|
| 10 |
+
import argparse
|
| 11 |
+
warnings.filterwarnings("ignore")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def cal_PCA(md_pdb_path,ref_path,pred_pdb_path,n_components = 2):
    """Project an MD ensemble and predicted ensembles onto the MD principal components.

    The MD ensemble is aligned to the reference structure on the N/CA/C
    atoms, a PCA is fitted on it, and both the MD frames and every predicted
    ensemble (aligned the same way) are projected onto the leading
    components.  For each ensemble a CSV with the projections and a PNG
    scatter plot of PC1 vs. PC2 are written next to the input file.

    Args:
        md_pdb_path: Multi-model PDB of the MD ensemble (used as both
            topology and trajectory).
        ref_path: PDB of the structure everything is aligned to.
        pred_pdb_path: dict mapping a method name to the multi-model PDB of
            its predicted ensemble.
        n_components: Number of principal components to keep.  The plots use
            the columns ``PC1`` and ``PC2``, so at least 2 are required.
    """
    backbone_sel = 'name CA or name C or name N'

    print("")
    print('filename=',os.path.basename(ref_path))

    u = mda.Universe(md_pdb_path, md_pdb_path)
    u_ref = mda.Universe(ref_path, ref_path)

    # Align the MD frames to the reference in memory.
    align.AlignTraj(u,
                    u_ref,
                    select=backbone_sel,
                    in_memory=True).run()

    # The frames are already aligned above, hence align=False.
    pc = pca.PCA(u,
                 select=backbone_sel,
                 align=False, mean=None,
                 n_components=n_components,
                 ).run()

    backbone = u.select_atoms(backbone_sel)
    n_bb = len(backbone)
    print('There are {} backbone atoms in the analysis'.format(n_bb))

    for i in range(n_components):
        print(f"Cumulated variance {i+1}: {pc.cumulated_variance[i]:.3f}")

    transformed = pc.transform(backbone, n_components=n_components)
    print(transformed.shape)

    pc_columns = ['PC{}'.format(i+1) for i in range(n_components)]
    df = pd.DataFrame(transformed, columns=pc_columns)

    output_dir = os.path.dirname(md_pdb_path)
    output_filename = os.path.basename(md_pdb_path).split('.')[0]
    df.to_csv(os.path.join(output_dir, f'{output_filename}_md_pca.csv'))

    plt.scatter(df['PC1'],df['PC2'],marker='o')
    # Save BEFORE show(): on interactive backends show() hands the figure to
    # the GUI and a later savefig() writes an empty image.
    plt.savefig(os.path.join(output_dir, f'{output_filename}_md_pca.png'))
    plt.show()
    # NOTE(review): the figure is not cleared here, so on non-interactive
    # backends the first prediction plot also contains the MD points.  This
    # matches the original behaviour -- confirm that the overlay is intended.

    for k,v in pred_pdb_path.items():
        u_pred = mda.Universe(v, v)
        align.AlignTraj(u_pred,
                        u_ref,
                        select=backbone_sel,
                        in_memory=True).run()
        pred_backbone = u_pred.select_atoms(backbone_sel)
        pred_transformed = pc.transform(pred_backbone, n_components=n_components)

        df = pd.DataFrame(pred_transformed, columns=pc_columns)

        output_dir = os.path.dirname(v)
        output_filename = os.path.basename(v).split('.')[0]
        df.to_csv(os.path.join(output_dir, f'{output_filename}_{k}_pca.csv'))

        plt.scatter(df['PC1'],df['PC2'],marker='o')
        plt.savefig(os.path.join(output_dir, f'{output_filename}_{k}_pca.png'))
        plt.show()
        plt.clf()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument("--pred_pdb_dir", type=str, default="./inference/test/pred_merge_results")
    parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
    parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")

    args = parser.parse_args()

    # Method name -> directory holding that method's merged predictions.
    # Add further entries (e.g. other methods) here to analyse them as well.
    pred_pdb_path_org={
        'P2DFlow':args.pred_pdb_dir,
    }
    md_pdb_path_org = args.target_dir
    ref_path_org = args.crystal_dir

    for file in os.listdir(md_pdb_path_org):
        # Raw string: '\.' in a normal string literal is an invalid escape.
        if not re.search(r'\.pdb',file):
            continue

        # The same file name is expected in every directory.
        pred_pdb_path = {
            method: os.path.join(method_dir, file)
            for method, method_dir in pred_pdb_path_org.items()
        }
        md_pdb_path = os.path.join(md_pdb_path_org, file)
        ref_path = os.path.join(ref_path_org, file)
        cal_PCA(md_pdb_path,ref_path,pred_pdb_path)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
|
analysis/src/__init__.py
ADDED
|
File without changes
|
analysis/src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (138 Bytes). View file
|
|
|
analysis/src/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (132 Bytes). View file
|
|
|
analysis/src/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (136 Bytes). View file
|
|
|
analysis/src/__pycache__/eval.cpython-310.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
analysis/src/__pycache__/eval.cpython-37.pyc
ADDED
|
Binary file (4.61 kB). View file
|
|
|
analysis/src/__pycache__/eval.cpython-39.pyc
ADDED
|
Binary file (4.95 kB). View file
|
|
|
analysis/src/common/__init__.py
ADDED
|
File without changes
|
analysis/src/common/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
analysis/src/common/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
analysis/src/common/__pycache__/all_atom.cpython-39.pyc
ADDED
|
Binary file (5.15 kB). View file
|
|
|
analysis/src/common/__pycache__/data_transforms.cpython-39.pyc
ADDED
|
Binary file (26.9 kB). View file
|
|
|
analysis/src/common/__pycache__/geo_utils.cpython-310.pyc
ADDED
|
Binary file (5.02 kB). View file
|
|
|
analysis/src/common/__pycache__/geo_utils.cpython-39.pyc
ADDED
|
Binary file (5 kB). View file
|
|
|
analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
analysis/src/common/__pycache__/protein.cpython-310.pyc
ADDED
|
Binary file (7.46 kB). View file
|
|
|
analysis/src/common/__pycache__/protein.cpython-39.pyc
ADDED
|
Binary file (7.45 kB). View file
|
|
|
analysis/src/common/__pycache__/residue_constants.cpython-310.pyc
ADDED
|
Binary file (23.7 kB). View file
|
|
|
analysis/src/common/__pycache__/residue_constants.cpython-39.pyc
ADDED
|
Binary file (23.2 kB). View file
|
|
|
analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc
ADDED
|
Binary file (41.4 kB). View file
|
|
|
analysis/src/common/__pycache__/rotation3d.cpython-39.pyc
ADDED
|
Binary file (17 kB). View file
|
|
|
analysis/src/common/all_atom.py
ADDED
|
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utilities for calculating all atom representations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from src.common import residue_constants as rc
|
| 8 |
+
from src.common.data_transforms import atom37_to_torsion_angles
|
| 9 |
+
from src.common.rigid_utils import Rigid, Rotation
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Lookup tables built once at import time from the residue constants
# (which originate from OpenFold/AlphaFold2), stored as tensors.
IDEALIZED_POS37 = torch.tensor(rc.restype_atom37_rigid_group_positions)
# True wherever at least one coordinate of the idealized position is non-zero.
IDEALIZED_POS37_MASK = IDEALIZED_POS37.any(dim=-1)
IDEALIZED_POS = torch.tensor(rc.restype_atom14_rigid_group_positions)
DEFAULT_FRAMES = torch.tensor(rc.restype_rigid_group_default_frame)
ATOM_MASK = torch.tensor(rc.restype_atom14_mask)
GROUP_IDX = torch.tensor(rc.restype_atom14_to_rigid_group)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
):
    """Compose backbone frames with torsion rotations into 8 frames per residue.

    Args:
        r: Backbone frames, one per residue (batch dims ``[*, N]`` according
            to the shape comments below).
        alpha: Torsion values stored as pairs on the last axis, with the seven
            torsions on the second-to-last axis. Index 0 / 1 of each pair fill
            the sine-like / cosine-like slots of the rotation matrix below.
            NOTE(review): confirm the (sin, cos) ordering against the caller.
        aatype: Residue-type indices used to index ``DEFAULT_FRAMES``; it must
            be usable as an index into that (CPU) table.

    Returns:
        A ``Rigid`` with a trailing axis of 8 frames per residue, each composed
        with ``r`` so that it maps into the global frame.
    """
    # [*, N, 8, 4, 4]
    default_4x4 = DEFAULT_FRAMES[aatype, ...].to(r.device)

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3]    translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    # Identity rotation for the backbone group: pair (0, 1).
    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat(
        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

    all_frames = default_r.compose(all_rots)

    # Frames 5..7 (chi2..chi4) are expressed relative to the previous chi
    # frame, so they are chained onto chi1 (frame 4) to reach the backbone.
    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def prot_to_torsion_angles(aatype, atom37, atom37_mask):
    """Calculate torsion angle features from protein features.

    Packs the three inputs into the feature dict expected by
    ``atom37_to_torsion_angles`` and returns the two entries it produces.

    Returns:
        Tuple ``(torsion_angles, torsion_mask)`` read from the
        ``'torsion_angles_sin_cos'`` and ``'torsion_angles_mask'`` keys.
    """
    prot_feats = {
        'aatype': aatype,
        'all_atom_positions': atom37,
        'all_atom_mask': atom37_mask,
    }
    # atom37_to_torsion_angles is curried: call once to get the transform.
    torsion_angles_feats = atom37_to_torsion_angles()(prot_feats)
    torsion_angles = torsion_angles_feats['torsion_angles_sin_cos']
    torsion_mask = torsion_angles_feats['torsion_angles_mask']
    return torsion_angles, torsion_mask
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def frames_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
):
    """Convert frames to their idealized all atom representation.

    Args:
        r: All rigid groups, eight per residue (``[..., N, 8]`` frames).
        aatype: Residue types. [..., N]

    Returns:
        Tensor of atom14 positions, ``[..., N, 14, 3]`` per the shape comments
        below, with slots that are masked out in ``ATOM_MASK`` for the residue
        type set to zero.
    """

    # [*, N, 14]
    group_mask = GROUP_IDX[aatype, ...]

    # [*, N, 14, 8]
    group_mask = torch.nn.functional.one_hot(
        group_mask,
        num_classes=DEFAULT_FRAMES.shape[-3],
    ).to(r.device)

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    # The one-hot mask leaves a single non-zero frame per atom, so the sum
    # over the group axis selects that atom's frame.
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1]
    frame_atom_mask = ATOM_MASK[aatype, ...].unsqueeze(-1).to(r.device)

    # [*, N, 14, 3]
    frame_null_pos = IDEALIZED_POS[aatype, ...].to(r.device)
    pred_positions = t_atoms_to_global.apply(frame_null_pos)
    pred_positions = pred_positions * frame_atom_mask

    return pred_positions
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def compute_backbone(bb_rigids, psi_torsions, aatype=None, device=None):
    """Build backbone atom coordinates from backbone frames and psi torsions.

    Args:
        bb_rigids: Backbone frames (an object exposing ``.shape`` and
            ``.device`` and accepted by ``torsion_angles_to_frames``).
        psi_torsions: Psi torsion values; one extra axis is inserted before
            the last and tiled 7 times so every torsion slot receives it.
        aatype: Optional residue-type indices. When omitted, all residues use
            type index 0.
        device: Target device; defaults to ``bb_rigids.device``.

    Returns:
        Tuple ``(atom37_bb_pos, atom37_mask, aatype, atom14_pos)`` where only
        the first five atom37 slots are filled and ``aatype`` has been moved
        to ``device``.
    """
    if device is None:
        device = bb_rigids.device

    torsion_angles = torch.tile(
        psi_torsions[..., None, :],
        (1,) * len(bb_rigids.shape) + (7, 1),
    ).to(device)

    # aatype must be on cpu for initializing the tensor by indexing
    if aatype is None:
        # Build from the shape: torch.zeros_like only accepts tensors, and
        # bb_rigids is a frame container rather than a tensor.
        aatype = torch.zeros(bb_rigids.shape, dtype=torch.long)
    else:
        aatype = aatype.cpu()

    all_frames = torsion_angles_to_frames(
        bb_rigids,
        torsion_angles,
        aatype,
    )
    atom14_pos = frames_to_atom14_pos(
        all_frames,
        aatype,
    )
    atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=device)
    # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
    # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
    atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
    atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :]
    atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :]
    # An atom counts as present when any of its coordinates is non-zero.
    atom37_mask = torch.any(atom37_bb_pos, dim=-1)
    return atom37_bb_pos, atom37_mask, aatype.to(device), atom14_pos
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def calculate_neighbor_angles(R_ac, R_ab):
    """Calculate angles between atoms c <- a -> b.

    Parameters
    ----------
        R_ac: Tensor, shape = (..., 3)
            Vector from atom a to c.
        R_ab: Tensor, shape = (..., 3)
            Vector from atom a to b.

    Returns
    -------
        angle_cab: Tensor, shape = (...,)
            Angle between atoms c <- a -> b.
    """
    # cos(alpha) = (u * v) / (|u|*|v|)
    x = torch.sum(R_ac * R_ab, dim=-1)
    # sin(alpha) = |u x v| / (|u|*|v|)
    # dim is explicit: without it torch.cross picks the first size-3 axis,
    # which is the batch axis whenever N == 3.
    y = torch.cross(R_ac, R_ab, dim=-1).norm(dim=-1)
    # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
    # (new_tensor keeps the floor on y's dtype and device)
    y = torch.max(y, y.new_tensor(1e-9))
    angle = torch.atan2(y, x)
    return angle
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def vector_projection(R_ab, P_n):
    """
    Project the vector R_ab onto a plane with normal vector P_n.

    Parameters
    ----------
        R_ab: Tensor, shape = (..., 3)
            Vector from atom a to b.
        P_n: Tensor, shape = (..., 3)
            Normal vector of a plane onto which to project R_ab.
            It does not need to be normalised; a zero vector yields NaN/inf.

    Returns
    -------
        R_ab_proj: Tensor, shape = (..., 3)
            Projected vector (orthogonal to P_n).
    """
    along_normal = torch.sum(R_ab * P_n, dim=-1)
    normal_sq = torch.sum(P_n * P_n, dim=-1)
    # Subtract the component of R_ab that lies along the normal.
    return R_ab - (along_normal / normal_sq)[..., None] * P_n
|
analysis/src/common/data_transforms.py
ADDED
|
@@ -0,0 +1,1194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import itertools
|
| 17 |
+
from functools import reduce, wraps
|
| 18 |
+
from operator import add
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from src.common import residue_constants as rc
|
| 24 |
+
from src.common.rigid_utils import Rotation, Rigid
|
| 25 |
+
from src.utils.tensor_utils import (
|
| 26 |
+
tree_map,
|
| 27 |
+
tensor_tree_map,
|
| 28 |
+
batched_gather,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Placeholder strings used in shape schemas; make_fixed_size() maps each one
# to a concrete target size.
NUM_RES = "num residues placeholder"
NUM_MSA_SEQ = "msa placeholder"
NUM_EXTRA_SEQ = "extra msa placeholder"
NUM_TEMPLATES = "num templates placeholder"

# Feature keys that are indexed along the MSA sequence axis (dim 0); the
# sampling/cropping transforms select rows from every key listed here.
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def cast_to_64bit_ints(protein):
    """Upcast every int32 tensor in ``protein`` to int64, in place.

    Args:
        protein: Feature dict whose values expose ``.dtype``.

    Returns:
        The same dict object, for chaining.
    """
    # We keep all ints as int64. Only existing keys are reassigned, so the
    # dict does not change size while it is being iterated.
    for k, v in protein.items():
        if v.dtype == torch.int32:
            protein[k] = v.type(torch.int64)

    return protein
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def make_one_hot(x, num_classes):
    """Return a float one-hot encoding of integer tensor ``x``.

    Args:
        x: int64 tensor of class indices in ``[0, num_classes)``.
        num_classes: Size of the new trailing axis.

    Returns:
        Tensor of shape ``(*x.shape, num_classes)`` on ``x``'s device.
    """
    # Allocate on x's device: scatter_ requires index and target to match.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
    return x_one_hot
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def make_seq_mask(protein):
    """Add an all-ones float32 ``seq_mask`` shaped like ``protein["aatype"]``."""
    protein["seq_mask"] = torch.ones(
        protein["aatype"].shape, dtype=torch.float32
    )
    return protein
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def make_template_mask(protein):
    """Add an all-ones float32 ``template_mask`` with one entry per template.

    The number of templates is the size of dim 0 of
    ``protein["template_aatype"]``.
    """
    protein["template_mask"] = torch.ones(
        protein["template_aatype"].shape[0], dtype=torch.float32
    )
    return protein
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def curry1(f):
    """Supply all arguments but the first.

    The decorated function is called with every argument except the first
    and returns a one-argument callable that receives the first argument.
    """
    @wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)

    return fc
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def make_all_atom_aatype(protein):
    """Alias ``protein["aatype"]`` under the ``all_atom_aatype`` key (no copy)."""
    protein["all_atom_aatype"] = protein["aatype"]
    return protein
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def fix_templates_aatype(protein):
    """Convert one-hot template residue types to indices in this module's order.

    When at least one template is present, ``protein["template_aatype"]`` is
    reduced with argmax over its last axis and then remapped through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. With zero templates the dict is
    returned untouched.
    """
    # Map one-hot to indices
    num_templates = protein["template_aatype"].shape[0]
    if num_templates > 0:
        protein["template_aatype"] = torch.argmax(
            protein["template_aatype"], dim=-1
        )
        # Map hhsearch-aatype to our aatype.
        new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
        # Same permutation table for every template row; built on the
        # index tensor's device because gather requires them to match.
        new_order = torch.tensor(
            new_order_list,
            dtype=torch.int64,
            device=protein["template_aatype"].device,
        ).expand(num_templates, -1)
        protein["template_aatype"] = torch.gather(
            new_order, 1, index=protein["template_aatype"]
        )

    return protein
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Remaps ``protein["msa"]`` through ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``
    and applies the matching column permutation to every feature whose key
    contains ``"profile"``.

    Raises:
        AssertionError: if a profile feature's last axis is not 20, 21 or 22.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    # One copy of the mapping per MSA column, laid out so gather along dim 0
    # looks up new_order[msa[i, j], j].
    new_order = torch.tensor(
        [new_order_list] * protein["msa"].shape[1],
        dtype=protein["msa"].dtype,
        device=protein["msa"].device,
    ).transpose(0, 1)
    protein["msa"] = torch.gather(new_order, 0, protein["msa"])

    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            # torch.Size has no TensorFlow-style as_list(); index it directly.
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            # torch.dot only accepts 1-D tensors; the permutation is a matrix
            # product over the last axis of the profile.
            perm = torch.from_numpy(perm_matrix[:num_dim, :num_dim]).to(
                dtype=protein[k].dtype, device=protein[k].device
            )
            protein[k] = torch.matmul(protein[k], perm)

    return protein
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    ``aatype`` is collapsed from one-hot to indices, a trailing size-1 axis is
    dropped from the listed features, and ``seq_length`` / ``num_alignments``
    are reduced to their first element.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
    for k in [
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    ]:
        if k in protein:
            final_dim = protein[k].shape[-1]
            if isinstance(final_dim, int) and final_dim == 1:
                # Features may be torch tensors or numpy arrays.
                if torch.is_tensor(protein[k]):
                    protein[k] = torch.squeeze(protein[k], dim=-1)
                else:
                    protein[k] = np.squeeze(protein[k], axis=-1)

    for k in ["seq_length", "num_alignments"]:
        if k in protein:
            protein[k] = protein[k][0]

    return protein
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'.

    Each MSA entry that is not a gap (index 21) and each ``aatype`` entry is
    independently replaced by index 20 with probability
    ``replace_proportion``.
    """
    x_idx = 20
    gap_idx = 21

    # Random masks are drawn on the feature's own device so torch.where does
    # not mix devices.
    msa_mask = (
        torch.rand(protein["msa"].shape, device=protein["msa"].device)
        < replace_proportion
    )
    msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
    protein["msa"] = torch.where(
        msa_mask,
        torch.ones_like(protein["msa"]) * x_idx,
        protein["msa"],
    )

    aatype_mask = (
        torch.rand(protein["aatype"].shape, device=protein["aatype"].device)
        < replace_proportion
    )
    protein["aatype"] = torch.where(
        aatype_mask,
        torch.ones_like(protein["aatype"]) * x_idx,
        protein["aatype"],
    )
    return protein
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample MSA randomly; remaining sequences are stored as `extra_*`.

    Row 0 is always kept first. The other rows are shuffled and the first
    ``max_seq`` rows of the resulting order are retained.

    Args:
        protein: Feature dict containing ``"msa"``.
        max_seq: Maximum number of rows to keep.
        keep_extra: When true, unselected rows are stored under
            ``"extra_" + key`` for each MSA feature present.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            global RNG is used.
    """
    num_seq = protein["msa"].shape[0]
    device = protein["msa"].device

    # A freshly constructed torch.Generator starts from a fixed default seed,
    # so one is created only when the caller asks for reproducibility.
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    shuffled = torch.randperm(num_seq - 1, generator=g, device=device) + 1
    index_order = torch.cat(
        (torch.zeros(1, dtype=shuffled.dtype, device=device), shuffled), dim=0
    )
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@curry1
def add_distillation_flag(protein, distillation):
    """Store ``distillation`` under the ``'is_distillation'`` key."""
    protein['is_distillation'] = distillation
    return protein
|
| 214 |
+
|
| 215 |
+
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA to ``max_seq`` rows for distillation examples only.

    Examples whose ``"is_distillation"`` flag is not 1 are returned unchanged;
    unselected rows are discarded rather than kept as ``extra_*``.
    """
    if protein["is_distillation"] == 1:
        protein = sample_msa(max_seq, keep_extra=False)(protein)
    return protein
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keep at most ``max_extra_msa`` rows of the ``extra_*`` features."""
    num_seq = protein["extra_msa"].shape[0]
    num_sel = min(max_extra_msa, num_seq)
    # index_select needs the index on the same device as the data.
    select_indices = torch.randperm(
        num_seq, device=protein["extra_msa"].device
    )[:num_sel]
    for k in MSA_FEATURE_NAMES:
        if "extra_" + k in protein:
            protein["extra_" + k] = torch.index_select(
                protein["extra_" + k], 0, select_indices
            )

    return protein
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def delete_extra_msa(protein):
    """Remove every ``extra_*`` MSA feature from ``protein``, in place."""
    for k in MSA_FEATURE_NAMES:
        if "extra_" + k in protein:
            del protein["extra_" + k]
    return protein
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Delete randomly placed contiguous blocks of MSA rows, keeping row 0.

    Args:
        protein: Feature dict containing ``"msa"``.
        config: Object with ``msa_fraction_per_block`` (block length as a
            fraction of the MSA depth), ``num_blocks`` and
            ``randomize_num_blocks`` (draw the block count uniformly from
            ``0..num_blocks`` instead of using ``num_blocks``).

    Returns:
        ``protein`` with every MSA feature restricted to the surviving rows.
    """
    num_seq = protein["msa"].shape[0]
    if num_seq == 0:
        return protein

    # int() truncation equals floor for the non-negative product.
    block_num_seq = int(num_seq * config.msa_fraction_per_block)

    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)).item())
    else:
        nb = config.num_blocks

    # Each block covers rows start .. start + block_num_seq - 1, clipped to
    # the valid row range.
    del_block_starts = torch.randint(0, num_seq, (nb,))
    del_blocks = del_block_starts[:, None] + torch.arange(block_num_seq)
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1)

    keep_mask = torch.ones(num_seq, dtype=torch.bool)
    keep_mask[del_blocks.reshape(-1)] = False
    # Make sure we keep the original sequence
    keep_mask[0] = True
    keep_indices = torch.nonzero(keep_mask).squeeze(-1)

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(
                protein[k], 0, keep_indices.to(protein[k].device)
            )

    return protein
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra MSA row to the sampled MSA row it agrees with most.

    Agreement is a weighted count of matching one-hot entries over 23
    classes: the first 21 classes weigh 1, class 21 weighs
    ``gap_agreement_weight`` and class 22 weighs 0. The result is written to
    ``protein["extra_cluster_assignment"]`` as int64 indices.
    """
    device = protein["msa"].device
    weights = torch.cat(
        [
            torch.ones(21, device=device),
            gap_agreement_weight * torch.ones(1, device=device),
            torch.zeros(1, device=device),
        ],
        0,
    )

    # Make agreement score as weighted Hamming distance
    msa_one_hot = make_one_hot(protein["msa"], 23)
    sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
    extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
    extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
    # in an optimized fashion to avoid possible memory or computation blowup.
    agreement = torch.matmul(
        torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
        torch.reshape(
            sample_one_hot * weights, [num_seq, num_res * 23]
        ).transpose(0, 1),
    )

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
        torch.int64
    )

    return protein
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor.
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument.
    """
    assert (
        len(segment_ids.shape) == 1 and
        segment_ids.shape[0] == data.shape[0]
    )
    # Broadcast the per-row segment id over all trailing axes of data so it
    # can serve as the scatter index.
    segment_ids = segment_ids.view(
        segment_ids.shape[0], *((1,) * len(data.shape[1:]))
    )
    segment_ids = segment_ids.expand(data.shape)
    shape = [num_segments] + list(data.shape[1:])
    # Accumulate in float on data's device, then cast back to data's dtype.
    tensor = torch.zeros(*shape, device=data.device).scatter_add_(
        0, segment_ids, data.float()
    )
    tensor = tensor.type(data.dtype)
    return tensor
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster.

    Uses ``protein["extra_cluster_assignment"]`` to pool the extra MSA rows
    into their assigned sampled row, and writes ``cluster_profile`` and
    ``cluster_deletion_mean``.
    """
    num_seq = protein["msa"].shape[0]

    def csum(x):
        # Sum rows of x that share a cluster assignment.
        return unsorted_segment_sum(
            x, protein["extra_cluster_assignment"], num_seq
        )

    mask = protein["extra_msa_mask"]
    # The 1e-6 keeps the later divisions finite when a position is fully masked.
    mask_counts = 1e-6 + protein["msa_mask"] + csum(mask)  # Include center

    msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
    msa_sum += make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
    del msa_sum

    del_sum = csum(mask * protein["extra_deletion_matrix"])
    del_sum += protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = del_sum / mask_counts
    del del_sum

    return protein
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded.

    Adds ``msa_mask`` (same shape as ``msa``) and ``msa_row_mask`` (one entry
    per MSA row), both float32.
    """
    protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
    protein["msa_row_mask"] = torch.ones(
        protein["msa"].shape[0], dtype=torch.float32
    )
    return protein
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    For glycine the CA position (and mask) is used; for every other residue
    type the CB position (and mask) is used.

    Returns:
        ``(pseudo_beta, pseudo_beta_mask)`` when ``all_atom_mask`` is given,
        otherwise just ``pseudo_beta``.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    # The trailing singleton axis broadcasts the choice over x, y, z.
    pseudo_beta = torch.where(
        is_gly[..., None],
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is not None:
        pseudo_beta_mask = torch.where(
            is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
        )
        return pseudo_beta, pseudo_beta_mask
    else:
        return pseudo_beta
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask.

    Args:
        protein: Feature dict.
        prefix: ``""`` for the target or ``"template_"`` for templates;
            selects which input keys are read and which output keys are
            written.
    """
    assert prefix in ["", "template_"]
    (
        protein[prefix + "pseudo_beta"],
        protein[prefix + "pseudo_beta_mask"],
    ) = pseudo_beta_fn(
        protein["template_aatype" if prefix else "aatype"],
        protein[prefix + "all_atom_positions"],
        protein["template_all_atom_mask" if prefix else "all_atom_mask"],
    )
    return protein
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
@curry1
def add_constant_field(protein, key, value):
    """Store ``value`` as a tensor under ``protein[key]``."""
    protein[key] = torch.tensor(value)
    return protein
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def shaped_categorical(probs, epsilon=1e-10):
    """Sample one class index per leading position of ``probs``.

    Args:
        probs: Tensor whose last axis holds unnormalised class probabilities.
        epsilon: Added to every probability before sampling.

    Returns:
        Integer tensor of shape ``probs.shape[:-1]``.
    """
    ds = probs.shape
    num_classes = ds[-1]
    # Flatten all leading axes so a single batched Categorical can be sampled.
    distribution = torch.distributions.categorical.Categorical(
        torch.reshape(probs + epsilon, [-1, num_classes])
    )
    counts = distribution.sample()
    return torch.reshape(counts, ds[:-1])
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present.

    The profile is the mean over MSA rows (dim 0) of the 22-class one-hot
    encoding of ``protein["msa"]``.
    """
    if "hhblits_profile" in protein:
        return protein

    # Compute the profile for every residue (over all MSA sequences).
    msa_one_hot = make_one_hot(protein["msa"], 22)

    protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
    return protein
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    A fraction ``replace_fraction`` of MSA positions is resampled from a
    mixture of a uniform amino-acid distribution, the HHblits profile, the
    original token, and a mask token (class 22) that receives the remaining
    probability. Writes ``bert_mask``, ``true_msa`` and the corrupted ``msa``.

    Raises:
        AssertionError: if the three configured probabilities sum to more
            than 1.
    """
    device = protein["msa"].device

    # Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0], dtype=torch.float32, device=device
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(protein["msa"], 22)
    )

    # Put all remaining probability on [MASK] which is a new column
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0
    # (0, 1) appends one entry at the end of the last axis only.
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, (0, 1), value=mask_prob
    )

    sh = protein["msa"].shape
    mask_position = torch.rand(sh, device=device) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, protein["msa"])

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = protein["msa"]
    protein["msa"] = bert_msa

    return protein
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Zero-pad every feature up to the fixed sizes implied by its schema.

    Each entry of ``shape_schema[name]`` is either a symbolic dimension
    (``NUM_RES``, ``NUM_MSA_SEQ``, ``NUM_EXTRA_SEQ``, ``NUM_TEMPLATES``) or
    something else. Symbolic dimensions with a non-zero requested size are
    padded at the end up to that size; all other dimensions keep their
    current extent. ``"extra_cluster_assignment"`` is left untouched.

    Returns:
        The same ``protein`` dict with padded tensors.
    """
    requested = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for name, tensor in protein.items():
        # Don't transfer this to the accelerator.
        if name == "extra_cluster_assignment":
            continue

        current = list(tensor.shape)
        schema = shape_schema[name]
        assert len(current) == len(schema), (
            f"Rank mismatch between shape and shape schema for "
            f"{name}: {current} vs {schema}"
        )

        # A requested size of 0 / None means "keep the current extent".
        target = [
            requested.get(symbol, None) or extent
            for extent, symbol in zip(current, schema)
        ]

        # torch.nn.functional.pad expects (left, right) pairs starting from
        # the LAST dimension, hence the reversal.
        pad_spec = [
            amount
            for extent, wanted in reversed(list(zip(current, target)))
            for amount in (0, wanted - extent)
        ]
        if pad_spec:
            padded = torch.nn.functional.pad(tensor, pad_spec)
            protein[name] = torch.reshape(padded, target)

    return protein
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@curry1
def make_msa_feat(protein):
    """Assemble the concatenated ``msa_feat`` and ``target_feat`` tensors.

    ``target_feat`` is the domain-break flag followed by a 21-way one-hot of
    ``aatype``. ``msa_feat`` is a 23-way one-hot of the MSA, a has-deletion
    flag and a squashed deletion count, optionally followed by the cluster
    profile and the squashed mean cluster deletion count. When
    ``"extra_deletion_matrix"`` is present, the matching has-deletion and
    squashed-value features for the extra MSA are stored as well.

    Returns:
        The same ``protein`` dict with the new features added.
    """

    def _squash(counts):
        # Maps deletion counts through atan(x / 3) * 2 / pi.
        return torch.atan(counts / 3.0) * (2.0 / np.pi)

    # Whether there is a domain break. Always zero for chains, but keeping
    # for compatibility with domain datasets.
    domain_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    target_parts = [
        torch.unsqueeze(domain_break, dim=-1),
        make_one_hot(protein["aatype"], 21),  # Everyone gets the original sequence.
    ]

    msa_onehot = make_one_hot(protein["msa"], 23)
    deletions = protein["deletion_matrix"]
    msa_parts = [
        msa_onehot,
        torch.unsqueeze(torch.clip(deletions, 0.0, 1.0), dim=-1),
        torch.unsqueeze(_squash(deletions), dim=-1),
    ]

    if "cluster_profile" in protein:
        cluster_deletion = _squash(protein["cluster_deletion_mean"])
        msa_parts.append(protein["cluster_profile"])
        msa_parts.append(torch.unsqueeze(cluster_deletion, dim=-1))

    if "extra_deletion_matrix" in protein:
        extra_deletions = protein["extra_deletion_matrix"]
        protein["extra_has_deletion"] = torch.clip(extra_deletions, 0.0, 1.0)
        protein["extra_deletion_value"] = _squash(extra_deletions)

    protein["msa_feat"] = torch.cat(msa_parts, dim=-1)
    protein["target_feat"] = torch.cat(target_parts, dim=-1)
    return protein
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the entries named in ``feature_list``."""
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
@curry1
def crop_templates(protein, max_templates):
    """Keep at most ``max_templates`` leading entries of every template feature.

    Only keys starting with ``"template_"`` are touched; they are sliced
    along their first dimension. The dict is modified in place and returned.
    """
    template_keys = [name for name in protein if name.startswith("template_")]
    for name in template_keys:
        protein[name] = protein[name][:max_templates]
    return protein
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def make_atom14_masks(protein):
    """Add atom14 <-> atom37 index maps and existence masks to ``protein``.

    Per-residue-type lookup tables (20 residue types from ``rc.restypes``
    plus a trailing all-zero 'UNK' row) are built first and then indexed by
    ``protein["aatype"]``.

    Writes:
        ``atom14_atom_exists``: float mask of atom14 slots that hold an atom.
        ``residx_atom14_to_atom37``: atom37 index for each atom14 slot.
        ``residx_atom37_to_atom14``: atom14 slot for each atom37 index.
        ``atom37_atom_exists``: float mask of atom37 entries that exist.

    Returns:
        The same ``protein`` dict.
    """
    device = protein["aatype"].device

    atom14_to_atom37_table = []
    atom37_to_atom14_table = []
    atom14_mask_table = []

    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        slot_of = {atom: slot for slot, atom in enumerate(names14)}

        # Empty names mark unused atom14 slots; they map to index 0.
        atom14_to_atom37_table.append(
            [(rc.atom_order[atom] if atom else 0) for atom in names14]
        )
        atom37_to_atom14_table.append(
            [slot_of.get(atom, 0) for atom in rc.atom_types]
        )
        atom14_mask_table.append(
            [(1.0 if atom else 0.0) for atom in names14]
        )

    # Add dummy mapping for restype 'UNK'
    atom14_to_atom37_table.append([0] * 14)
    atom37_to_atom14_table.append([0] * 37)
    atom14_mask_table.append([0.0] * 14)

    atom14_to_atom37_table = torch.tensor(
        atom14_to_atom37_table, dtype=torch.int32, device=device
    )
    atom37_to_atom14_table = torch.tensor(
        atom37_to_atom14_table, dtype=torch.int32, device=device
    )
    atom14_mask_table = torch.tensor(
        atom14_mask_table, dtype=torch.float32, device=device
    )
    residue_types = protein['aatype'].to(torch.long)

    # (residx, atom14) --> atom37 mapping and the matching existence mask.
    per_res_atom14_to_atom37 = atom14_to_atom37_table[residue_types]
    protein["atom14_atom_exists"] = atom14_mask_table[residue_types]
    protein["residx_atom14_to_atom37"] = per_res_atom14_to_atom37.long()

    # Gather indices for mapping back from atom37 to atom14.
    per_res_atom37_to_atom14 = atom37_to_atom14_table[residue_types]
    protein["residx_atom37_to_atom14"] = per_res_atom37_to_atom14.long()

    # atom37 existence mask, one row per residue type (row 20 stays zero).
    atom37_mask_table = torch.zeros(
        [21, 37], dtype=torch.float32, device=device
    )
    for type_index, letter in enumerate(rc.restypes):
        for atom in rc.residue_atoms[rc.restype_1to3[letter]]:
            atom37_mask_table[type_index, rc.atom_order[atom]] = 1

    protein["atom37_atom_exists"] = atom37_mask_table[residue_types]

    return protein
|
| 647 |
+
|
| 648 |
+
|
| 649 |
+
def make_atom14_masks_np(batch):
    """NumPy front-end for ``make_atom14_masks``.

    Converts the ``np.ndarray`` leaves of ``batch`` to tensors, runs
    ``make_atom14_masks`` and converts the resulting tensors back to arrays.
    """
    as_tensors = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
    with_masks = make_atom14_masks(as_tensors)
    return tensor_tree_map(lambda tensor: np.array(tensor), with_masks)
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
def make_atom14_positions(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Reads ``atom14_atom_exists``, ``residx_atom14_to_atom37``,
    ``all_atom_mask``, ``all_atom_positions`` and ``aatype`` from
    ``protein`` and writes:

        ``atom14_gt_exists``: atom14 mask of atoms with a known position.
        ``atom14_gt_positions``: known positions gathered into atom14 order
            (zeroed where the mask is 0).
        ``atom14_alt_gt_positions`` / ``atom14_alt_gt_exists``: the same
            data with the atoms listed in ``rc.residue_atom_renaming_swaps``
            swapped.
        ``atom14_atom_is_ambiguous``: mask of atoms taking part in a swap.

    Returns:
        The same ``protein`` dict.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            all_atom_positions,
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for some amino acids, provide
    # alternative ground truth coordinates where the naming is swapped.
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms; identity unless overridden.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=all_atom_mask.dtype,
            device=all_atom_mask.device,
        )
        for res in restype_3
    }
    atom14_slots = torch.arange(14, device=all_atom_mask.device)
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        atom14_names = rc.restype_name_to_atom14_names[resname]
        correspondences = torch.arange(14, device=all_atom_mask.device)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Build the permutation matrix once per residue, after all swaps
        # are recorded, instead of re-creating it for every swap pair.
        renaming_matrix = all_atom_mask.new_zeros((14, 14))
        renaming_matrix[atom14_slots, correspondences] = 1.0
        all_matrices[resname] = renaming_matrix
    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask. shape: (21, 14).
    restype_atom14_is_ambiguous = all_atom_mask.new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        atom14_names = rc.restype_name_to_atom14_names[resname]
        for atom_name1, atom_name2 in swap.items():
            restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name1)] = 1
            restype_atom14_is_ambiguous[restype, atom14_names.index(atom_name2)] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def atom37_to_frames(protein, eps=1e-8):
    """Compute the 8 per-residue rigid-group frames from atom37 coordinates.

    Reads ``aatype``, ``all_atom_positions`` and ``all_atom_mask`` and
    writes ``rigidgroups_gt_frames``, ``rigidgroups_gt_exists``,
    ``rigidgroups_group_exists``, ``rigidgroups_group_is_ambiguous`` and
    ``rigidgroups_alt_gt_frames``.

    Args:
        protein: feature dict (modified in place).
        eps: forwarded to ``Rigid.from_3_points``.

    Returns:
        The same ``protein`` dict.
    """
    aatype = protein["aatype"]
    positions = protein["all_atom_positions"]
    atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])
    leading_ones = (1,) * batch_dims

    # Names of the three atoms that define each of the 8 groups, per
    # residue type. Groups 0 and 3 are the same for every residue type;
    # groups 4..7 come from the chi-angle definitions.
    base_atom_names = np.full([21, 8, 3], "", dtype=object)
    base_atom_names[:, 0, :] = ["C", "CA", "N"]
    base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for type_index, letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[letter]
        for chi in range(4):
            if rc.chi_angles_mask[type_index][chi]:
                chi_atoms = rc.chi_angles_atoms[resname][chi]
                base_atom_names[type_index, chi + 4, :] = chi_atoms[1:]

    # Which groups exist for each residue type.
    group_mask_table = atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    group_mask_table[..., 0] = 1
    group_mask_table[..., 3] = 1
    group_mask_table[..., :20, 4:] = atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Translate atom names to atom37 indices ("" maps to index 0).
    name_to_index = rc.atom_order.copy()
    name_to_index[""] = 0
    to_index = np.vectorize(lambda atom: name_to_index[atom])
    base_atom37_table = aatype.new_tensor(to_index(base_atom_names))
    base_atom37_table = base_atom37_table.view(
        *leading_ones, *base_atom37_table.shape
    )

    base_atom37_idx = batched_gather(
        base_atom37_table,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        positions,
        base_atom37_idx,
        dim=-2,
        no_batch_dims=len(positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        group_mask_table,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    base_atoms_present = batched_gather(
        atom_mask,
        base_atom37_idx,
        dim=-1,
        no_batch_dims=len(atom_mask.shape[:-1]),
    )
    # A frame exists only if all three of its defining atoms are present.
    gt_exists = base_atoms_present.min(dim=-1).values * group_exists

    # Extra rotation applied to group 0 only: flips the x and z axes.
    group_fix = torch.eye(3, dtype=atom_mask.dtype, device=aatype.device)
    group_fix = torch.tile(group_fix, (*leading_ones, 8, 1, 1))
    group_fix[..., 0, 0, 0] = -1
    group_fix[..., 0, 2, 2] = -1
    gt_frames = gt_frames.compose(Rigid(Rotation(rot_mats=group_fix), None))

    ambiguity_table = atom_mask.new_zeros(*leading_ones, 21, 8)
    ambiguity_rot_table = torch.eye(
        3, dtype=atom_mask.dtype, device=aatype.device
    )
    ambiguity_rot_table = torch.tile(
        ambiguity_rot_table,
        (*leading_ones, 21, 8, 1, 1),
    )

    # For residues with renaming swaps, the last defined chi group is
    # ambiguous; its alternative frame flips the y and z axes.
    for resname in rc.residue_atom_renaming_swaps:
        type_index = rc.restype_order[rc.restype_3to1[resname]]
        last_chi = int(sum(rc.chi_angles_mask[type_index]) - 1)
        ambiguity_table[..., type_index, last_chi + 4] = 1
        ambiguity_rot_table[..., type_index, last_chi + 4, 1, 1] = -1
        ambiguity_rot_table[..., type_index, last_chi + 4, 2, 2] = -1

    group_is_ambiguous = batched_gather(
        ambiguity_table,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    ambiguity_rot = batched_gather(
        ambiguity_rot_table,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(Rotation(rot_mats=ambiguity_rot), None)
    )

    protein["rigidgroups_gt_frames"] = gt_frames.to_tensor_4x4()
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = group_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames.to_tensor_4x4()

    return protein
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
def get_chi_atom_indices():
    """Return the atom indices needed to compute chi angles per residue type.

    Returns:
        A nested list of shape [residue_types=21, chis=4, atoms=4]. The
        residue types follow ``rc.restypes`` with the unknown residue type
        at the end. Chi angles that are not defined for a residue are
        filled with ``[0, 0, 0, 0]``.
    """
    per_restype = []
    for letter in rc.restypes:
        chi_definitions = rc.chi_angles_atoms[rc.restype_1to3[letter]]
        defined = [
            [rc.atom_order[atom] for atom in chi_atoms]
            for chi_atoms in chi_definitions
        ]
        # Pad up to four chi angles for residues that define fewer.
        padding = [[0, 0, 0, 0] for _ in range(4 - len(defined))]
        per_restype.append(defined + padding)

    per_restype.append([[0, 0, 0, 0]] * 4)  # For UNKNOWN residue.

    return per_restype
|
| 922 |
+
|
| 923 |
+
|
| 924 |
+
@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert atom37 coordinates to torsion angles.

    The seven torsions per residue are, in order: pre-omega, phi, psi and
    the four chi angles. This function is extremely sensitive to floating
    point imprecisions and should be run with double precision whenever
    possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    pos = protein[prefix + "all_atom_positions"]
    mask = protein[prefix + "all_atom_mask"]

    # Anything above 20 is treated as the unknown residue type.
    aatype = torch.clamp(aatype, max=20)

    # Shift positions and mask by one residue (zero-padded at the start) so
    # that index i holds the data of residue i - 1.
    zero_pos = pos.new_zeros([*pos.shape[:-3], 1, 37, 3])
    prev_pos = torch.cat([zero_pos, pos[..., :-1, :, :]], dim=-3)

    zero_mask = mask.new_zeros([*mask.shape[:-2], 1, 37])
    prev_mask = torch.cat([zero_mask, mask[..., :-1, :]], dim=-2)

    # Four atoms per backbone torsion, taken from atom37 slots 0..4 of the
    # previous and current residue.
    pre_omega_atom_pos = torch.cat(
        [prev_pos[..., 1:3, :], pos[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_pos[..., 2:3, :], pos[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [pos[..., :3, :], pos[..., 4:5, :]],
        dim=-2,
    )

    # A backbone torsion is valid only if all four of its atoms are present.
    pre_omega_mask = torch.prod(
        prev_mask[..., 1:3], dim=-1
    ) * torch.prod(mask[..., :2], dim=-1)
    phi_mask = prev_mask[..., 2] * torch.prod(
        mask[..., :3], dim=-1, dtype=mask.dtype
    )
    psi_mask = (
        torch.prod(mask[..., :3], dim=-1, dtype=mask.dtype)
        * mask[..., 4]
    )

    chi_index_table = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    chi_atom_idx = chi_index_table[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        pos, chi_atom_idx, -2, len(chi_atom_idx.shape[:-2])
    )

    # Per-residue-type chi mask with an all-zero row for the unknown type.
    chi_defined = list(rc.chi_angles_mask)
    chi_defined.append([0.0, 0.0, 0.0, 0.0])
    chi_defined = mask.new_tensor(chi_defined)

    chis_mask = chi_defined[aatype, :]

    chi_atoms_present = batched_gather(
        mask,
        chi_atom_idx,
        dim=-1,
        no_batch_dims=len(chi_atom_idx.shape[:-2]),
    )
    chi_atoms_present = torch.prod(
        chi_atoms_present, dim=-1, dtype=chi_atoms_present.dtype
    )
    chis_mask = chis_mask * chi_atoms_present

    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Frame built from the first three atoms; the torsion is read off the
    # fourth atom's coordinates expressed in that frame.
    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_local = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    sin_cos = torch.stack(
        [fourth_atom_local[..., 2], fourth_atom_local[..., 1]], dim=-1
    )

    # Normalise to unit length (with a small constant for stability).
    norm = torch.sqrt(
        torch.sum(
            torch.square(sin_cos),
            dim=-1,
            dtype=sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    sin_cos = sin_cos / norm

    # Flip the sign of the third torsion (psi) only.
    torsion_sign = mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )
    lead = len(sin_cos.shape[:-2])
    sin_cos = sin_cos * torsion_sign.reshape((1,) * lead + (7, 1))

    chi_is_ambiguous = sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # +1 for the three backbone torsions, -1 for pi-periodic chi angles.
    mirror = torch.cat(
        [
            mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_sin_cos = sin_cos * mirror[..., None]

    protein[prefix + "torsion_angles_sin_cos"] = sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
def get_backbone_frames(protein):
    """Expose rigid group 0 as the backbone frame and backbone mask.

    Copies index 0 along the group axis of ``rigidgroups_gt_frames`` into
    ``backbone_rigid_tensor`` and of ``rigidgroups_gt_exists`` into
    ``backbone_rigid_mask``. Returns the same dict.
    """
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    all_frames = protein["rigidgroups_gt_frames"]
    protein["backbone_rigid_tensor"] = all_frames[..., 0, :, :]

    all_exists = protein["rigidgroups_gt_exists"]
    protein["backbone_rigid_mask"] = all_exists[..., 0]

    return protein
|
| 1101 |
+
|
| 1102 |
+
|
| 1103 |
+
def get_chi_angles(protein):
    """Extract the chi-angle slice (torsions 3 onward) of the torsion features.

    Writes ``chi_angles_sin_cos`` and ``chi_mask``, both cast to the dtype
    of ``all_atom_mask``. Returns the same dict.
    """
    target_dtype = protein["all_atom_mask"].dtype

    chi_sin_cos = protein["torsion_angles_sin_cos"][..., 3:, :]
    protein["chi_angles_sin_cos"] = chi_sin_cos.to(target_dtype)

    chi_mask = protein["torsion_angles_mask"][..., 3:]
    protein["chi_mask"] = chi_mask.to(target_dtype)

    return protein
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Every feature listed in ``shape_schema`` that either has a ``NUM_RES``
    dimension or has "template" in its name is sliced: ``NUM_RES``
    dimensions to one shared residue window, and the first dimension of
    template features to at most ``max_templates`` entries.

    Args:
        protein: feature dict (modified in place); must hold ``seq_length``.
        crop_size: maximum number of residues to keep.
        max_templates: maximum number of templates to keep.
        shape_schema: maps feature name to its symbolic shape.
        subsample_templates: if true (and templates exist), permute the
            templates and pick a random template start offset.
        seed: optional seed for the crop's private random generator.

    Returns:
        The same ``protein`` dict with ``seq_length`` updated.
    """
    device = protein["seq_length"].device

    # We want each ensemble to be cropped the same way
    # NOTE(review): when seed is None the generator is left at whatever
    # state torch.Generator() starts in -- confirm that is intended.
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)

    # Work with a plain int so all crop arithmetic below is integer-only.
    seq_length = int(protein["seq_length"])

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(seq_length, crop_size)

    def _randint(lower, upper):
        # Uniform integer in [lower, upper], both ends inclusive.
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=device,
            generator=g,
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=device, generator=g
        )
    else:
        templates_crop_start = 0

    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            # Local names on purpose: do not overwrite the crop_size argument.
            if i == 0 and k.startswith("template"):
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(dim_crop_start, dim_crop_start + dim_crop_size))
        # Multidimensional indexing must use a tuple, not a list, of slices.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein
|
analysis/src/common/geo_utils.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for geometric operations (torch only).
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def rots_mul_vecs(m, v):
    """(Batch) Apply rotation matrices ``m`` ([*, 3, 3]) to vectors ``v`` ([*, 3])."""
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    # One output component per matrix row, written out explicitly.
    rotated = [
        m[..., row, 0] * vx + m[..., row, 1] * vy + m[..., row, 2] * vz
        for row in range(3)
    ]
    return torch.stack(rotated, dim=-1)
|
| 14 |
+
|
| 15 |
+
def distance(p, eps=1e-10):
    """Return the distance between the two points stacked along dim=-2.

    ``p`` has shape [*, 2, 3]; ``eps`` is added under the square root.
    """
    first, second = p[..., 0, :], p[..., 1, :]
    squared = torch.sum((first - second) ** 2, dim=-1)
    return (eps + squared) ** 0.5
|
| 19 |
+
|
| 20 |
+
def dihedral(p, eps=1e-10):
    """Return the (cos, sin) encoding of the dihedral of four points.

    Args:
        p: [*, 4, 3] quadruple of points stacked along dim=-2.
        eps: added under each square root when computing vector lengths.

    Returns:
        [*, 2] tensor holding cosine and sine, in that order.
    """

    def _length(vec):
        return (eps + torch.sum(vec ** 2, dim=-1)) ** 0.5

    # Consecutive bond vectors, each [*, 3].
    bond_1 = p[..., 1, :] - p[..., 0, :]
    bond_2 = p[..., 2, :] - p[..., 1, :]
    bond_3 = p[..., 3, :] - p[..., 2, :]

    # Normals of the two planes, each [*, 3].
    normal_a = torch.cross(bond_1, bond_2, dim=-1)
    normal_b = torch.cross(bond_2, bond_3, dim=-1)

    # Lengths, each [*].
    bond_2_len = _length(bond_2)
    normal_a_len = _length(normal_a)
    normal_b_len = _length(normal_b)

    cos_enc = torch.einsum('...d,...d->...', normal_a, normal_b) / (
        normal_a_len * normal_b_len
    )
    normals_cross = torch.cross(normal_a, normal_b, dim=-1)
    sin_enc = torch.einsum('...d,...d->...', bond_2, normals_cross) / (
        bond_2_len * normal_a_len * normal_b_len
    )

    return torch.stack([cos_enc, sin_enc], dim=-1)
|
| 43 |
+
|
| 44 |
+
def calc_distogram(
    pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int
) -> torch.Tensor:
    """One-hot bin all pairwise distances between points.

    Bin lower edges are ``num_bins`` evenly spaced values from ``min_bin``
    to ``max_bin``; each bin's upper edge is the next lower edge, and the
    last bin extends to 1e8. Both comparisons are strict, so a distance
    that equals a bin edge exactly (or is not above ``min_bin``) falls into
    no bin and yields an all-zero row.

    Args:
        pos: [*, L, 3] point coordinates.
        min_bin: lower edge of the first bin.
        max_bin: lower edge of the last bin.
        num_bins: number of bins.

    Returns:
        [*, L, L, num_bins] tensor of 0/1 values with ``pos``'s dtype.
    """
    # [*, L, L, 1] pairwise Euclidean distances.
    dists_2d = torch.linalg.norm(
        pos[..., :, None, :] - pos[..., None, :, :], dim=-1
    )[..., None]
    lower = torch.linspace(
        min_bin,
        max_bin,
        num_bins,
        device=pos.device)
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
    distogram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
    return distogram
|
| 57 |
+
|
| 58 |
+
def rmsd(xyz1, xyz2):
    """Shorthand for `squared_deviation(xyz1, xyz2, 'rmsd')`."""
    return squared_deviation(xyz1, xyz2, reduction='rmsd')
|
| 61 |
+
|
| 62 |
+
def squared_deviation(xyz1, xyz2, reduction='none'):
    """Squared point-wise deviation between two point clouds after alignment.

    `xyz1` is rigidly aligned onto `xyz2` with `_find_rigid_alignment` before
    the deviation is measured.

    Args:
        xyz1: (*, L, 3), to be transformed. Tensor or array-like.
        xyz2: (*, L, 3), the reference. Tensor or array-like.
        reduction: 'none' returns the per-point squared deviation,
            'rmsd' returns the root of its mean over L.

    Returns:
        (*, L) for reduction='none', (*,) for reduction='rmsd'. A numpy array
        is returned when `xyz1` was not a tensor, otherwise a tensor.

    Raises:
        NotImplementedError: for any other `reduction`.
    """
    map_to_np = not torch.is_tensor(xyz1)
    # as_tensor is a no-op for tensors, so both inputs can be normalised
    # unconditionally; this also covers a tensor/array mix.
    xyz1 = torch.as_tensor(xyz1)
    xyz2 = torch.as_tensor(xyz2, dtype=xyz1.dtype, device=xyz1.device)

    R, t = _find_rigid_alignment(xyz1, xyz2)  # (*, 3, 3), (*, 3)

    # x @ R^T is the row-vector form of (R @ x^T)^T. The translation must be
    # broadcast over the point axis (-2): unsqueeze(0) is only correct for
    # unbatched input and mis-broadcasts (B, 3) against (B, L, 3).
    xyz1_aligned = torch.matmul(xyz1, R.transpose(-2, -1)) + t.unsqueeze(-2)

    sd = ((xyz1_aligned - xyz2) ** 2).sum(dim=-1)  # (*, L)

    assert sd.shape == xyz1.shape[:-1]
    if reduction == 'none':
        pass
    elif reduction == 'rmsd':
        sd = torch.sqrt(sd.mean(dim=-1))
    else:
        raise NotImplementedError(f"Unsupported reduction '{reduction}'.")

    return sd.numpy() if map_to_np else sd
|
| 97 |
+
|
| 98 |
+
def _find_rigid_alignment(src, tgt):
|
| 99 |
+
"""Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
|
| 100 |
+
https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
|
| 101 |
+
|
| 102 |
+
See: https://en.wikipedia.org/wiki/Kabsch_algorithm
|
| 103 |
+
|
| 104 |
+
2-D or 3-D registration with known correspondences.
|
| 105 |
+
Registration occurs in the zero centered coordinate system, and then
|
| 106 |
+
must be transported back.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
|
| 110 |
+
tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
|
| 111 |
+
Returns:
|
| 112 |
+
R: optimal rotation (*, 3, 3)
|
| 113 |
+
t: optimal translation (*, 3)
|
| 114 |
+
|
| 115 |
+
Test on rotation + translation and on rotation + translation + reflection
|
| 116 |
+
>>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
|
| 117 |
+
>>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
|
| 118 |
+
>>> B = (R0.mm(A.T)).T
|
| 119 |
+
>>> t0 = torch.tensor([3., 3.])
|
| 120 |
+
>>> B += t0
|
| 121 |
+
>>> R, t = find_rigid_alignment(A, B)
|
| 122 |
+
>>> A_aligned = (R.mm(A.T)).T + t
|
| 123 |
+
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
|
| 124 |
+
>>> rmsd
|
| 125 |
+
tensor(3.7064e-07)
|
| 126 |
+
>>> B *= torch.tensor([-1., 1.])
|
| 127 |
+
>>> R, t = find_rigid_alignment(A, B)
|
| 128 |
+
>>> A_aligned = (R.mm(A.T)).T + t
|
| 129 |
+
>>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
|
| 130 |
+
>>> rmsd
|
| 131 |
+
tensor(3.7064e-07)
|
| 132 |
+
"""
|
| 133 |
+
assert src.shape[-2] > 1
|
| 134 |
+
src_com = src.mean(dim=-2, keepdim=True)
|
| 135 |
+
tgt_com = tgt.mean(dim=-2, keepdim=True)
|
| 136 |
+
src_centered = src - src_com
|
| 137 |
+
tgt_centered = tgt - tgt_com
|
| 138 |
+
|
| 139 |
+
# Covariance matrix
|
| 140 |
+
|
| 141 |
+
# H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3
|
| 142 |
+
H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered)
|
| 143 |
+
|
| 144 |
+
U, S, V = torch.svd(H)
|
| 145 |
+
# Rotation matrix
|
| 146 |
+
|
| 147 |
+
# R = V.bmm(U.transpose(-2,-1))
|
| 148 |
+
R = torch.matmul(V, U.transpose(-2, -1))
|
| 149 |
+
|
| 150 |
+
# Translation vector
|
| 151 |
+
|
| 152 |
+
# t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
|
| 153 |
+
t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
|
| 154 |
+
|
| 155 |
+
return R, t.squeeze(-2) # (B, 3, 3), (B, 3)
|
analysis/src/common/pdb_utils.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for operating PDB files.
|
| 2 |
+
"""
|
| 3 |
+
import os
|
| 4 |
+
import re
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from collections import OrderedDict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
import biotite.structure as struct
|
| 12 |
+
from biotite.structure.io.pdb import PDBFile
|
| 13 |
+
|
| 14 |
+
from src.common import protein
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def write_pdb_string(pdb_string: str, save_to: str):
    """Dump `pdb_string` verbatim into the file at `save_to` (overwriting it)."""
    with open(save_to, 'w') as handle:
        handle.write(pdb_string)
|
| 21 |
+
|
| 22 |
+
def read_pdb_to_string(pdb_file):
    """Read a PDB file into one string, keeping only structural records.

    Only lines starting with 'END', 'TER', 'MODEL' or 'ATOM' are kept
    ('ENDMDL' matches the 'END' prefix); every other line is dropped.
    """
    kept_prefixes = ('END', 'TER', 'MODEL', 'ATOM')
    with open(pdb_file, 'r') as handle:
        return ''.join(row for row in handle if row.startswith(kept_prefixes))
|
| 31 |
+
|
| 32 |
+
def merge_pdbfiles(input, output_file, verbose=True):
    """Merge several PDB files into one multi-model PDB file, in order.

    Models are renumbered consecutively from 1 across all inputs. Only ATOM
    and TER records are copied; every output line is padded to 80 columns.

    Args:
        input: either a directory (all '*.pdb' files inside it are merged in
            sorted filename order) or a list/tuple of PDB file paths (merged
            in the given order). The name shadows the builtin but is kept for
            keyword callers.
        output_file: path of the merged PDB file; its parent directory is
            created if needed.
        verbose: show a tqdm progress bar and print a summary.

    Raises:
        TypeError: if `input` is neither a str nor a list/tuple.
    """
    if isinstance(input, str):
        # os.listdir order is arbitrary; sort so the merge is reproducible.
        pdb_files = [os.path.join(input, f) for f in sorted(os.listdir(input)) if f.endswith('.pdb')]
    elif isinstance(input, (list, tuple)):
        pdb_files = list(input)
    else:
        raise TypeError(f"input must be a directory path or a list of files, got {type(input).__name__}.")

    # dirname is '' for a bare filename, which os.makedirs rejects.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    model_number = 0
    pdb_lines = []
    _iter = tqdm(pdb_files, desc='Merging PDBs') if verbose else pdb_files

    def _open_model():
        # MODEL record: serial number right-justified in columns 11-14.
        pdb_lines.append(f"MODEL     {model_number:>4}")

    for pdb_file in _iter:
        with open(pdb_file, 'r') as pdb:
            lines = pdb.readlines()

        single_model = not any(line.startswith(('MODEL', 'ENDMDL')) for line in lines)

        if single_model:
            model_number += 1
            _open_model()
            pdb_lines.extend(
                line.strip() for line in lines if line.startswith(('TER', 'ATOM')))
            pdb_lines.append("ENDMDL")
            continue

        # Multi-model file: track whether a model is open *within this file*
        # so exactly one ENDMDL is emitted per MODEL, whatever preceded it.
        model_open = False
        for line in lines:
            if line.startswith('MODEL'):
                if model_open:
                    pdb_lines.append("ENDMDL")
                model_number += 1
                _open_model()
                model_open = True
            elif line.startswith('ENDMDL'):
                if model_open:
                    pdb_lines.append("ENDMDL")
                    model_open = False
            elif line.startswith(('TER', 'ATOM')):
                if not model_open:
                    # Coordinates outside any MODEL record: start a model.
                    model_number += 1
                    _open_model()
                    model_open = True
                pdb_lines.append(line.strip())
        if model_open:
            pdb_lines.append("ENDMDL")

    pdb_lines.append('END')
    pdb_str = '\n'.join(_line.ljust(80) for _line in pdb_lines) + '\n'
    with open(output_file, 'w') as fo:
        fo.write(pdb_str)

    if verbose:
        print(f"Merged {len(pdb_files)} PDB files into {output_file} with {model_number} models.")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def split_pdbfile(pdb_file, output_dir=None, suffix='index', verbose=True):
    """Split a multi-model PDB file into one PDB string per model.

    Models are expected to be wrapped by 'MODEL' and 'ENDMDL'. Only ATOM and
    TER records are kept and each model string is closed with 'END'. A final
    model that is not followed by an END/ENDMDL record is still emitted.

    Args:
        pdb_file: path of the PDB file to split.
        output_dir: if given, each model is also written to
            `<output_dir>/<basename>_<i>.pdb` with `i` counting from 0.
        suffix: naming scheme of the written files; only 'index' is supported.
        verbose: print a warning for unrecognised lines and a summary.

    Returns:
        List of PDB strings, one per non-empty model.
    """
    assert os.path.exists(pdb_file), f"File {pdb_file} does not exist."
    assert suffix == 'index', 'Only support [suffix=index] for now.'

    base = os.path.splitext(os.path.basename(pdb_file))[0]
    if output_dir is not None:  # also dump to output_dir
        os.makedirs(output_dir, exist_ok=True)

    pdb_strs = []

    def _flush(model_lines):
        # Finalise one model: terminate it, optionally write it, collect it.
        if not model_lines:
            return
        body = ''.join(
            row if row.endswith('\n') else row + '\n' for row in model_lines)
        pdb_string = body + 'END\n'
        if output_dir is not None:
            _save_to = os.path.join(output_dir, f'{base}_{len(pdb_strs)}.pdb')
            with open(_save_to, 'w') as fo:
                fo.write(pdb_string)
        pdb_strs.append(pdb_string)

    current = []
    with open(pdb_file, 'r') as fi:
        for line in fi:
            if line.startswith('MODEL'):
                current = []
            elif line.startswith(('ATOM', 'TER')):
                current.append(line)
            elif line.startswith('END'):  # covers 'ENDMDL' too
                _flush(current)
                current = []
            elif verbose:
                print(f"Warning: line '{line}' is not recognized. Skip.")
    # Do not drop a trailing model that lacks its END/ENDMDL record.
    _flush(current)

    if verbose:
        print(f">>> Split pdb {pdb_file} into {len(pdb_strs)}/{len(pdb_strs)} structures.")
    return pdb_strs
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def stratify_sample_pdbfile(input_path, output_path, n_max_sample=1000, end_at=0, verbose=True):
    """Subsample the models of a multi-model PDB file at a regular stride.

    Only ATOM and TER records are kept. The sampled models are renumbered
    from 1, padded to 80 columns and written to `output_path`.

    Args:
        input_path: multi-model PDB file to read.
        output_path: file to write; must not exist yet.
        n_max_sample: maximum number of models to keep (>= 1).
        end_at: if > 0, only the first `end_at` models are considered.
        verbose: print a warning for unrecognised lines and a summary.

    Raises:
        ValueError: if `n_max_sample` is smaller than 1.
    """
    assert os.path.exists(input_path), f"File {input_path} does not exist."
    assert not os.path.exists(output_path), f"Output path {output_path} already exists."
    if n_max_sample < 1:
        raise ValueError(f"n_max_sample must be >= 1, got {n_max_sample}.")

    pdb_strs = []
    model_lines = []
    with open(input_path, 'r') as fi:
        for line in fi:
            if line.startswith('MODEL'):
                model_lines = []
            elif line.startswith(('ATOM', 'TER')):
                model_lines.append(line.strip())
            elif line.startswith('END'):  # covers 'ENDMDL' too
                if model_lines:  # skip empty model
                    model_lines.append('ENDMDL')
                    # Pad all lines to 80 characters, add terminating newline.
                    pdb_strs.append(
                        '\n'.join(_line.ljust(80) for _line in model_lines) + '\n')
                    model_lines = []
            elif verbose:
                print(f"Warning: line '{line}' is not recognized. Skip.")
            if end_at > 0 and len(pdb_strs) >= end_at:
                break

    # Clamp to the models actually read: an `end_at` beyond the file length
    # would otherwise inflate the stride and return too few samples.
    end = min(end_at, len(pdb_strs)) if end_at > 0 else len(pdb_strs)

    # sample evenly
    if end > n_max_sample:
        interleave_step = end // n_max_sample  # floor
        sampled_pdb_strs = pdb_strs[:end][::interleave_step][:n_max_sample]
    else:
        sampled_pdb_strs = pdb_strs[:end]

    # MODEL record: serial number right-justified in columns 11-14.
    chunks = []
    for idx, pdb_str in enumerate(sampled_pdb_strs, start=1):  # renumber models
        chunks.append(f"MODEL     {idx:>4}".ljust(80) + '\n')
        chunks.append(pdb_str)
    chunks.append('END'.ljust(80) + '\n')

    write_pdb_string(''.join(chunks), save_to=output_path)
    if verbose:
        print(f">>> Split pdb {input_path} into {len(sampled_pdb_strs)}/{n_max_sample} structures.")
    return
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def protein_with_default_params(
    atom_positions: np.ndarray,
    atom_mask: np.ndarray,
    aatype: Optional[np.ndarray] = None,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    residue_index: Optional[np.ndarray] = None,
):
    """Build a `protein.Protein`, filling every omitted field with a default.

    Args:
        atom_positions: array of shape (n, 37, 3).
        atom_mask: passed through unchanged.
        aatype: defaults to `np.zeros(n, dtype=int)`.
        b_factors: defaults to `np.zeros([n, 37])`.
        chain_index: defaults to `np.zeros(n)`.
        residue_index: defaults to `np.arange(n) + 1`.

    Provided optional arrays with a leading axis of size 1 (and more than one
    axis) are passed through `np.squeeze` first.
    """
    assert atom_positions.ndim == 3
    assert atom_positions.shape[-1] == 3
    assert atom_positions.shape[-2] == 37
    n = atom_positions.shape[0]

    def _drop_singleton_batch(arr):
        if arr.shape[0] == 1 and len(arr.shape) > 1:
            return np.squeeze(arr)
        return arr

    if residue_index is None:
        residue_index = np.arange(n) + 1
    else:
        residue_index = _drop_singleton_batch(residue_index)

    if chain_index is None:
        chain_index = np.zeros(n)
    else:
        chain_index = _drop_singleton_batch(chain_index)

    if b_factors is None:
        b_factors = np.zeros([n, 37])
    else:
        b_factors = _drop_singleton_batch(b_factors)

    if aatype is None:
        aatype = np.zeros(n, dtype=int)
    else:
        aatype = _drop_singleton_batch(aatype)

    return protein.Protein(
        atom_positions=atom_positions,
        atom_mask=atom_mask,
        aatype=aatype,
        residue_index=residue_index,
        chain_index=chain_index,
        b_factors=b_factors,
    )
|
| 205 |
+
|
| 206 |
+
def atom37_to_pdb(
    save_to: str,
    atom_positions: np.ndarray,
    aatype: Optional[np.ndarray] = None,
    b_factors: Optional[np.ndarray] = None,
    chain_index: Optional[np.ndarray] = None,
    residue_index: Optional[np.ndarray] = None,
    overwrite: bool = False,
    no_indexing: bool = True,
):
    """Write atom37 coordinates to a PDB file and return the path written.

    Args:
        save_to: target path. With `no_indexing=True` it is used verbatim and
            any existing file is replaced.
        atom_positions: (L, 37, 3) for one model, or (M, L, 37, 3) for M
            models written as MODEL 1..M. Atoms whose coordinates sum to
            (almost) zero in absolute value are treated as absent.
        aatype, b_factors, chain_index, residue_index: forwarded to
            `protein_with_default_params`.
        overwrite: only used with `no_indexing=False`; when True the index
            suffix is always `_1`, otherwise it is one more than the largest
            existing `<name>_<k>.pdb` in the target directory.
        no_indexing: when False, `_<index>.pdb` replaces the `.pdb` suffix.

    Raises:
        ValueError: if `atom_positions` is neither 3- nor 4-dimensional.
    """
    # configure save path
    if not no_indexing:
        # Remove only a trailing '.pdb': str.strip('.pdb') strips a character
        # set and str.replace would also hit '.pdb' inside directory names.
        root = save_to[:-len('.pdb')] if save_to.endswith('.pdb') else save_to
        if overwrite:
            max_existing_idx = 0
        else:
            file_dir = os.path.dirname(save_to) or '.'
            indexed = re.compile(re.escape(os.path.basename(root)) + r'_(\d+)\.pdb')
            hits = (indexed.fullmatch(x) for x in os.listdir(file_dir))
            max_existing_idx = max([int(m.group(1)) for m in hits if m] + [0])
        save_to = f'{root}_{max_existing_idx + 1}.pdb'

    def _model_string(pos37, model):
        # One (L, 37, 3) frame -> PDB text for a single MODEL.
        atom_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
        prot = protein_with_default_params(
            pos37, atom_mask, aatype=aatype, b_factors=b_factors,
            chain_index=chain_index, residue_index=residue_index
        )
        return protein.to_pdb(prot, model=model, add_end=False)

    with open(save_to, 'w') as f:
        if atom_positions.ndim == 4:
            for mi, pos37 in enumerate(atom_positions):
                f.write(_model_string(pos37, mi + 1))
        elif atom_positions.ndim == 3:
            f.write(_model_string(atom_positions, 1))
        else:
            raise ValueError(f'Invalid positions shape {atom_positions.shape}')
        # Terminate the END record like every other line of the file.
        f.write('END'.ljust(80) + '\n')

    return save_to
|
| 254 |
+
|
| 255 |
+
def extract_backbone_coords_from_pdb(pdb_path: str, target_atoms: Optional[list] = None):
    """Extract backbone atom coordinates from every model of a PDB file.

    Args:
        pdb_path: path of the PDB file, read with biotite's `PDBFile`.
        target_atoms: backbone atom names to extract; defaults to `["CA"]`.
            `struct.filter_backbone` keeps the "N", "CA" and "C" atoms of
            amino acids, so other names yield no coordinates.

    Returns:
        Array of shape (B, L, 3) when one atom name is requested, otherwise
        (B, L, na, 3), with B the number of models.
    """
    # None default instead of a shared mutable list.
    target_atoms = ["CA"] if target_atoms is None else list(target_atoms)

    structure = PDBFile.read(pdb_path)
    structure_list = structure.get_structure()

    coords_list = []
    for b_idx in range(structure.get_model_count()):
        chain = structure_list[b_idx]
        backbone_atoms = chain[struct.filter_backbone(chain)]

        # Boolean-mask selection per atom name; keeps file order.
        per_atom = [
            backbone_atoms.coord[backbone_atoms.atom_name == name]
            for name in target_atoms
        ]
        if len(target_atoms) == 1:
            ret_coords = per_atom[0]  # L, 3
        else:
            ret_coords = np.stack(per_atom, axis=1)  # L, na, 3

        coords_list.append(ret_coords)

    return np.stack(coords_list, axis=0)  # B, L, na, 3 or B, L, 3 (ca only)
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def extract_backbone_coords_from_pdb_dir(pdb_dir: str):
    """Concatenate the coordinates of every '*.pdb' file in a directory.

    Files are processed in sorted filename order so the model order of the
    result is reproducible (`os.listdir` order is arbitrary).

    Raises:
        ValueError: if the directory contains no '.pdb' file.
    """
    pdb_names = sorted(f for f in os.listdir(pdb_dir) if f.endswith('.pdb'))
    if not pdb_names:
        raise ValueError(f"No .pdb files found in {pdb_dir}.")
    return np.concatenate([
        extract_backbone_coords_from_pdb(os.path.join(pdb_dir, f))
        for f in pdb_names
    ], axis=0)
|
| 290 |
+
|
| 291 |
+
def extract_backbone_coords_from_npy(npy_path: str):
    """Load coordinates stored in a `.npy` file via `np.load`."""
    coords = np.load(npy_path)
    return coords
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def extract_backbone_coords(input_path: str,
                            max_n_model: Optional[int] = None,
                            ):
    """Extract backbone coordinates from a PDB file, a `.npy` file or a directory.

    Args:
        input_path (str): A `.pdb` file, a `.npy` file, or a directory of
            PDB files; dispatched in that order.
        max_n_model (int): If given and positive, at most this many leading
            models are returned.

    Raises:
        ValueError: if `input_path` matches none of the supported kinds.
    """
    assert os.path.exists(input_path), f"File {input_path} does not exist."

    if input_path.endswith('.pdb'):
        loader = extract_backbone_coords_from_pdb
    elif input_path.endswith('.npy'):
        loader = extract_backbone_coords_from_npy
    elif os.path.isdir(input_path):
        loader = extract_backbone_coords_from_pdb_dir
    else:
        raise ValueError(f"Unrecognized input path {input_path}.")
    coords = loader(input_path)

    truncate = max_n_model is not None and len(coords) > max_n_model > 0
    return coords[:max_n_model] if truncate else coords
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == '__main__':
    import argparse

    def get_argparser():
        """Parse the command line and return the populated namespace."""
        parser = argparse.ArgumentParser(description='Main script for pdb processing.')
        parser.add_argument("input", type=str, help="The generic path to sampled pdb directory / pdb file.")
        parser.add_argument("-m", "--mode", type=str, help="The mode of processing.",
                            default="split")
        parser.add_argument("-o", "--output", type=str, help="The output directory for processed pdb files.",
                            default=None)

        args = parser.parse_args()
        # 'split' and 'stratify' pass the output path straight to os calls,
        # which raise TypeError on None; fail early with a usage message.
        if args.mode in ("split", "stratify") and args.output is None:
            parser.error(f"--output is required for mode '{args.mode}'.")
        return args

    args = get_argparser()

    # ad hoc functions
    def split_pdbs(args):
        """Split one multi-model PDB file into per-model files."""
        os.makedirs(args.output, exist_ok=True)
        _ = split_pdbfile(pdb_file=args.input,
                          output_dir=args.output)

    def merge_pdbs(args):
        """Merge a directory of PDB files into one multi-model file."""
        output = args.output or f"{args.input}_all.pdb"
        merge_pdbfiles(input=args.input,
                       output_file=output)

    if args.mode == "split":
        split_pdbs(args)
    elif args.mode == "merge":
        merge_pdbs(args)
    elif args.mode == "stratify":
        stratify_sample_pdbfile(input_path=args.input, output_path=args.output)
    else:
        raise ValueError(f"Unrecognized mode {args.mode}.")
|
analysis/src/common/protein.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Protein data type."""
|
| 16 |
+
import dataclasses
|
| 17 |
+
import io
|
| 18 |
+
from typing import Any, Mapping, Optional
|
| 19 |
+
|
| 20 |
+
from Bio.PDB import PDBParser
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
from src.common import residue_constants
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Type aliases for model inputs and (nested) model outputs.
FeatureDict = Mapping[str, np.ndarray]
ModelOutput = Mapping[str, Any]  # Is a nested dict.

# Complete sequence of chain IDs supported by the PDB format:
# upper-case letters, lower-case letters, then digits.
PDB_CHAIN_IDS = (
    'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    'abcdefghijklmnopqrstuvwxyz'
    '0123456789'
)
PDB_MAX_CHAINS = len(PDB_CHAIN_IDS)  # := 62.
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclasses.dataclass(frozen=True)
class Protein:
    """Immutable container for a protein structure."""

    # Atom coordinates in angstroms. Atom types follow
    # residue_constants.atom_types, i.e. the first three are N, CA, CB.
    atom_positions: np.ndarray  # [num_res, num_atom_type, 3]

    # Integer amino-acid type per residue, between 0 and 20 (20 is 'X').
    aatype: np.ndarray  # [num_res]

    # Binary float mask: 1.0 where an atom is present, 0.0 otherwise.
    # This should be used for loss masking.
    atom_mask: np.ndarray  # [num_res, num_atom_type]

    # Residue index as used in PDB; not necessarily continuous or 0-indexed.
    residue_index: np.ndarray  # [num_res]

    # 0-indexed number of the chain each residue belongs to.
    chain_index: np.ndarray  # [num_res]

    # B-factors (temperature factors) in sq. angstroms, representing the
    # displacement of the residue from its ground truth mean value.
    b_factors: np.ndarray  # [num_res, num_atom_type]

    def __post_init__(self):
        # Reject structures with more chains than PDB chain IDs exist.
        n_chains = len(np.unique(self.chain_index))
        if n_chains > PDB_MAX_CHAINS:
            raise ValueError(
                f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
                'because these cannot be written to PDB format.')

    def to_dict(self):
        """Return the fields as a plain dict (via `dataclasses.asdict`)."""
        return dataclasses.asdict(self)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
    """Parse the text of a single-model PDB file into a `Protein`.

    WARNING: All non-standard residue types will be converted into UNK. All
      non-standard atoms will be ignored.

    Args:
      pdb_str: The contents of the pdb file
      chain_id: If chain_id is specified (e.g. A), then only that chain
        is parsed. Otherwise all chains are parsed.

    Returns:
      A new `Protein` parsed from the pdb contents.

    Raises:
      ValueError: if the file does not hold exactly one model, or if a
        residue carries an insertion code.
    """
    structure = PDBParser(QUIET=True).get_structure('none', io.StringIO(pdb_str))
    models = list(structure.get_models())
    if len(models) != 1:
        raise ValueError(
            f'Only single model PDBs are supported. Found {len(models)} models.')
    model = models[0]

    num_atom_types = residue_constants.atom_type_num

    atom_positions, atom_mask, b_factors = [], [], []
    aatype, residue_index, chain_ids = [], [], []

    for chain in model:
        if chain_id is not None and chain.id != chain_id:
            continue
        for res in chain:
            if res.id[2] != ' ':
                raise ValueError(
                    f'PDB contains an insertion code at chain {chain.id} and residue '
                    f'index {res.id[1]}. These are not supported.')
            res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
            restype_idx = residue_constants.restype_order.get(
                res_shortname, residue_constants.restype_num)

            pos = np.zeros((num_atom_types, 3))
            mask = np.zeros((num_atom_types,))
            res_b_factors = np.zeros((num_atom_types,))
            for atom in res:
                if atom.name not in residue_constants.atom_types:
                    continue
                slot = residue_constants.atom_order[atom.name]
                pos[slot] = atom.coord
                mask[slot] = 1.
                res_b_factors[slot] = atom.bfactor

            if np.sum(mask) < 0.5:
                # No known atom positions are reported for the residue: skip it.
                continue

            aatype.append(restype_idx)
            atom_positions.append(pos)
            atom_mask.append(mask)
            residue_index.append(res.id[1])
            chain_ids.append(chain.id)
            b_factors.append(res_b_factors)

    # Chain IDs are usually characters so map these to ints.
    chain_id_mapping = {cid: n for n, cid in enumerate(np.unique(chain_ids))}
    chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])

    return Protein(
        atom_positions=np.array(atom_positions),
        atom_mask=np.array(atom_mask),
        aatype=np.array(aatype),
        residue_index=np.array(residue_index),
        chain_index=chain_index,
        b_factors=np.array(b_factors))
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
|
| 147 |
+
chain_end = 'TER'
|
| 148 |
+
return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
|
| 149 |
+
f'{chain_name:>1}{residue_index:>4}')
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def to_pdb(prot: Protein, model=1, add_end=True) -> str:
    """Converts a `Protein` instance to a PDB string.

    Args:
      prot: The protein to convert to PDB.
      model: Serial number written in the MODEL record.
      add_end: Whether to append an END record after ENDMDL.

    Returns:
      PDB string, every line padded to 80 characters.

    Raises:
      ValueError: if an aatype exceeds `residue_constants.restype_num` or a
        chain index has no PDB chain ID.
    """
    restypes = residue_constants.restypes + ['X']
    res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
    atom_types = residue_constants.atom_types

    pdb_lines = []

    atom_mask = prot.atom_mask
    aatype = prot.aatype
    atom_positions = prot.atom_positions
    residue_index = prot.residue_index.astype(int)
    chain_index = prot.chain_index.astype(int)
    b_factors = prot.b_factors

    if np.any(aatype > residue_constants.restype_num):
        raise ValueError('Invalid aatypes.')

    # Construct a mapping from chain integer indices to chain ID strings.
    chain_ids = {}
    for i in np.unique(chain_index):  # np.unique gives sorted output.
        if i >= PDB_MAX_CHAINS:
            raise ValueError(
                f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
        chain_ids[i] = PDB_CHAIN_IDS[i]

    # MODEL record: the serial number starts at column 11.
    pdb_lines.append(f'MODEL     {model}')
    atom_index = 1
    last_chain_index = chain_index[0]
    # Add all atom sites.
    for i in range(aatype.shape[0]):
        # Close the previous chain if in a multichain PDB.
        if last_chain_index != chain_index[i]:
            pdb_lines.append(_chain_end(
                atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
                residue_index[i - 1]))
            last_chain_index = chain_index[i]
            atom_index += 1  # Atom index increases at the TER symbol.

        res_name_3 = res_1to3(aatype[i])
        for atom_name, pos, mask, b_factor in zip(
                atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
            if mask < 0.5:
                continue

            # skip CB for GLY
            if res_name_3 == 'GLY' and atom_name == 'CB':
                continue

            record_type = 'ATOM'
            name = atom_name if len(atom_name) == 4 else f' {atom_name}'
            alt_loc = ''
            insertion_code = ''
            occupancy = 1.00
            element = atom_name[0]  # Protein supports only C, N, O, S, this works.
            charge = ''
            # PDB is a columnar format, every space matters here!
            # Columns 28-30 (3 blanks) precede x at 31-38, and columns
            # 67-76 (10 blanks) precede the element symbol at 77-78.
            atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
                         f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
                         f'{residue_index[i]:>4}{insertion_code:>1}   '
                         f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
                         f'{occupancy:>6.2f}{b_factor:>6.2f}          '
                         f'{element:>2}{charge:>2}')
            pdb_lines.append(atom_line)
            atom_index += 1

    # Close the final chain.
    pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
                                chain_ids[chain_index[-1]], residue_index[-1]))
    pdb_lines.append('ENDMDL')
    if add_end:
        pdb_lines.append('END')

    # Pad all lines to 80 characters.
    pdb_lines = [line.ljust(80) for line in pdb_lines]
    return '\n'.join(pdb_lines) + '\n'  # Add terminating newline.
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def ideal_atom_mask(prot: Protein) -> np.ndarray:
  """Returns the atom mask implied by the protein's amino-acid sequence.

  `Protein.atom_mask` typically reflects the atoms that are reported in the
  PDB. This function instead looks up, for every residue type in the
  sequence, the heavy atoms that should be present.

  Args:
    prot: `Protein` whose fields are `numpy.ndarray` objects.

  Returns:
    The rows of `residue_constants.STANDARD_ATOM_MASK` selected by
    `prot.aatype`.
  """
  standard_mask = residue_constants.STANDARD_ATOM_MASK
  return standard_mask[prot.aatype]
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def from_prediction(
    features: FeatureDict,
    result: ModelOutput,
    b_factors: Optional[np.ndarray] = None,
    remove_leading_feature_dimension: bool = True) -> Protein:
  """Builds a `Protein` from model input features and a model prediction.

  Args:
    features: Dictionary holding model inputs. Reads 'aatype',
      'residue_index' and, when present, 'asym_id'.
    result: Dictionary holding model outputs. Reads
      `result['structure_module']`, from which 'final_atom_positions' and
      'final_atom_mask' are taken.
    b_factors: (Optional) B-factors to use for the protein. When None, zeros
      shaped like the final atom mask are used.
    remove_leading_feature_dimension: Whether to remove the leading dimension
      of the `features` values.

  Returns:
    A protein instance.
  """
  structure = result['structure_module']

  def _feature(name: str) -> np.ndarray:
    """Fetches a feature, dropping its leading dimension when requested."""
    value = features[name]
    if remove_leading_feature_dimension:
      return value[0]
    return value

  # Without 'asym_id', every residue is assigned chain index 0.
  if 'asym_id' in features:
    chain_index = _feature('asym_id')
  else:
    chain_index = np.zeros_like(_feature('aatype'))

  if b_factors is None:
    b_factors = np.zeros_like(structure['final_atom_mask'])

  protein_fields = dict(
      aatype=_feature('aatype'),
      atom_positions=structure['final_atom_positions'],
      atom_mask=structure['final_atom_mask'],
      # NOTE(review): presumably shifts 0-based feature indices to 1-based
      # residue numbering -- confirm against the feature pipeline.
      residue_index=_feature('residue_index') + 1,
      chain_index=chain_index,
      b_factors=b_factors)
  return Protein(**protein_fields)
|
analysis/src/common/residue_constants.py
ADDED
|
@@ -0,0 +1,897 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Constants used in AlphaFold."""
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import functools
|
| 19 |
+
import os
|
| 20 |
+
from typing import List, Mapping, Tuple
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tree
|
| 24 |
+
|
| 25 |
+
# Internal import (35fd).
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Distance from one CA to next CA [trans configuration: omega = 180].
ca_ca = 3.80209737096

# Atoms defining each side-chain chi dihedral, keyed by three-letter residue
# name. Every inner list holds the four atom names of one dihedral.
# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
# chi angles so their chi angle lists are empty.
chi_angles_atoms = {
    'ALA': [],
    # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
    'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
            ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
    'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
    'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
    'CYS': [['N', 'CA', 'CB', 'SG']],
    'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
            ['CB', 'CG', 'CD', 'OE1']],
    'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
            ['CB', 'CG', 'CD', 'OE1']],
    'GLY': [],
    'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
    'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
    'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
    'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
            ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
    'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
            ['CB', 'CG', 'SD', 'CE']],
    'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
    'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
    'SER': [['N', 'CA', 'CB', 'OG']],
    'THR': [['N', 'CA', 'CB', 'OG1']],
    'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
    'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
    'VAL': [['N', 'CA', 'CB', 'CG1']],
}
|
| 62 |
+
|
| 63 |
+
# If chi angles given in fixed-length array, this matrix determines how to mask
# them for each AA type. The order is as per restype_order (see below).
# Each row is [chi1, chi2, chi3, chi4]; 1.0 marks a chi angle the residue has.
# NOTE: this table has 20 rows (no trailing UNK row), unlike chi_pi_periodic.
chi_angles_mask = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [1.0, 1.0, 1.0, 1.0],  # ARG
    [1.0, 1.0, 0.0, 0.0],  # ASN
    [1.0, 1.0, 0.0, 0.0],  # ASP
    [1.0, 0.0, 0.0, 0.0],  # CYS
    [1.0, 1.0, 1.0, 0.0],  # GLN
    [1.0, 1.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [1.0, 1.0, 0.0, 0.0],  # HIS
    [1.0, 1.0, 0.0, 0.0],  # ILE
    [1.0, 1.0, 0.0, 0.0],  # LEU
    [1.0, 1.0, 1.0, 1.0],  # LYS
    [1.0, 1.0, 1.0, 0.0],  # MET
    [1.0, 1.0, 0.0, 0.0],  # PHE
    [1.0, 1.0, 0.0, 0.0],  # PRO
    [1.0, 0.0, 0.0, 0.0],  # SER
    [1.0, 0.0, 0.0, 0.0],  # THR
    [1.0, 1.0, 0.0, 0.0],  # TRP
    [1.0, 1.0, 0.0, 0.0],  # TYR
    [1.0, 0.0, 0.0, 0.0],  # VAL
]
|
| 87 |
+
|
| 88 |
+
# The following chi angles are pi periodic: they can be rotated by a multiple
# of pi without affecting the structure.
# Each row is [chi1, chi2, chi3, chi4]; 1.0 marks a pi-periodic chi angle.
# Rows follow the residue order given in the trailing comments and end with
# an extra all-zero UNK row (21 rows in total).
chi_pi_periodic = [
    [0.0, 0.0, 0.0, 0.0],  # ALA
    [0.0, 0.0, 0.0, 0.0],  # ARG
    [0.0, 0.0, 0.0, 0.0],  # ASN
    [0.0, 1.0, 0.0, 0.0],  # ASP
    [0.0, 0.0, 0.0, 0.0],  # CYS
    [0.0, 0.0, 0.0, 0.0],  # GLN
    [0.0, 0.0, 1.0, 0.0],  # GLU
    [0.0, 0.0, 0.0, 0.0],  # GLY
    [0.0, 0.0, 0.0, 0.0],  # HIS
    [0.0, 0.0, 0.0, 0.0],  # ILE
    [0.0, 0.0, 0.0, 0.0],  # LEU
    [0.0, 0.0, 0.0, 0.0],  # LYS
    [0.0, 0.0, 0.0, 0.0],  # MET
    [0.0, 1.0, 0.0, 0.0],  # PHE
    [0.0, 0.0, 0.0, 0.0],  # PRO
    [0.0, 0.0, 0.0, 0.0],  # SER
    [0.0, 0.0, 0.0, 0.0],  # THR
    [0.0, 0.0, 0.0, 0.0],  # TRP
    [0.0, 1.0, 0.0, 0.0],  # TYR
    [0.0, 0.0, 0.0, 0.0],  # VAL
    [0.0, 0.0, 0.0, 0.0],  # UNK
]
|
| 113 |
+
|
| 114 |
+
# Atom positions relative to the 8 rigid groups, defined by the pre-omega, phi,
# psi and chi angles:
# 0: 'backbone group',
# 1: 'pre-omega-group', (empty)
# 2: 'phi-group', (currently empty, because it defines only hydrogens)
# 3: 'psi-group',
# 4,5,6,7: 'chi1,2,3,4-group'
# The atom positions are relative to the axis-end-atom of the corresponding
# rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
# is defined such that the dihedral-angle-defining atom (the last entry in
# chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
# Keyed by three-letter residue name.
# format: [atomname, group_idx, rel_position]
rigid_group_atom_positions = {
    'ALA': [
        ['N', 0, (-0.525, 1.363, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, -0.000, -0.000)],
        ['CB', 0, (-0.529, -0.774, -1.205)],
        ['O', 3, (0.627, 1.062, 0.000)],
    ],
    'ARG': [
        ['N', 0, (-0.524, 1.362, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, -0.000, -0.000)],
        ['CB', 0, (-0.524, -0.778, -1.209)],
        ['O', 3, (0.626, 1.062, 0.000)],
        ['CG', 4, (0.616, 1.390, -0.000)],
        ['CD', 5, (0.564, 1.414, 0.000)],
        ['NE', 6, (0.539, 1.357, -0.000)],
        ['NH1', 7, (0.206, 2.301, 0.000)],
        ['NH2', 7, (2.078, 0.978, -0.000)],
        ['CZ', 7, (0.758, 1.093, -0.000)],
    ],
    'ASN': [
        ['N', 0, (-0.536, 1.357, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, -0.000, -0.000)],
        ['CB', 0, (-0.531, -0.787, -1.200)],
        ['O', 3, (0.625, 1.062, 0.000)],
        ['CG', 4, (0.584, 1.399, 0.000)],
        ['ND2', 5, (0.593, -1.188, 0.001)],
        ['OD1', 5, (0.633, 1.059, 0.000)],
    ],
    'ASP': [
        ['N', 0, (-0.525, 1.362, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.527, 0.000, -0.000)],
        ['CB', 0, (-0.526, -0.778, -1.208)],
        ['O', 3, (0.626, 1.062, -0.000)],
        ['CG', 4, (0.593, 1.398, -0.000)],
        ['OD1', 5, (0.610, 1.091, 0.000)],
        ['OD2', 5, (0.592, -1.101, -0.003)],
    ],
    'CYS': [
        ['N', 0, (-0.522, 1.362, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.524, 0.000, 0.000)],
        ['CB', 0, (-0.519, -0.773, -1.212)],
        ['O', 3, (0.625, 1.062, -0.000)],
        ['SG', 4, (0.728, 1.653, 0.000)],
    ],
    'GLN': [
        ['N', 0, (-0.526, 1.361, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, 0.000, 0.000)],
        ['CB', 0, (-0.525, -0.779, -1.207)],
        ['O', 3, (0.626, 1.062, -0.000)],
        ['CG', 4, (0.615, 1.393, 0.000)],
        ['CD', 5, (0.587, 1.399, -0.000)],
        ['NE2', 6, (0.593, -1.189, -0.001)],
        ['OE1', 6, (0.634, 1.060, 0.000)],
    ],
    'GLU': [
        ['N', 0, (-0.528, 1.361, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, -0.000, -0.000)],
        ['CB', 0, (-0.526, -0.781, -1.207)],
        ['O', 3, (0.626, 1.062, 0.000)],
        ['CG', 4, (0.615, 1.392, 0.000)],
        ['CD', 5, (0.600, 1.397, 0.000)],
        ['OE1', 6, (0.607, 1.095, -0.000)],
        ['OE2', 6, (0.589, -1.104, -0.001)],
    ],
    'GLY': [
        ['N', 0, (-0.572, 1.337, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.517, -0.000, -0.000)],
        ['O', 3, (0.626, 1.062, -0.000)],
    ],
    'HIS': [
        ['N', 0, (-0.527, 1.360, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, 0.000, 0.000)],
        ['CB', 0, (-0.525, -0.778, -1.208)],
        ['O', 3, (0.625, 1.063, 0.000)],
        ['CG', 4, (0.600, 1.370, -0.000)],
        ['CD2', 5, (0.889, -1.021, 0.003)],
        ['ND1', 5, (0.744, 1.160, -0.000)],
        ['CE1', 5, (2.030, 0.851, 0.002)],
        ['NE2', 5, (2.145, -0.466, 0.004)],
    ],
    'ILE': [
        ['N', 0, (-0.493, 1.373, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.527, -0.000, -0.000)],
        ['CB', 0, (-0.536, -0.793, -1.213)],
        ['O', 3, (0.627, 1.062, -0.000)],
        ['CG1', 4, (0.534, 1.437, -0.000)],
        ['CG2', 4, (0.540, -0.785, -1.199)],
        ['CD1', 5, (0.619, 1.391, 0.000)],
    ],
    'LEU': [
        ['N', 0, (-0.520, 1.363, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, -0.000, -0.000)],
        ['CB', 0, (-0.522, -0.773, -1.214)],
        ['O', 3, (0.625, 1.063, -0.000)],
        ['CG', 4, (0.678, 1.371, 0.000)],
        ['CD1', 5, (0.530, 1.430, -0.000)],
        ['CD2', 5, (0.535, -0.774, 1.200)],
    ],
    'LYS': [
        ['N', 0, (-0.526, 1.362, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, 0.000, 0.000)],
        ['CB', 0, (-0.524, -0.778, -1.208)],
        ['O', 3, (0.626, 1.062, -0.000)],
        ['CG', 4, (0.619, 1.390, 0.000)],
        ['CD', 5, (0.559, 1.417, 0.000)],
        ['CE', 6, (0.560, 1.416, 0.000)],
        ['NZ', 7, (0.554, 1.387, 0.000)],
    ],
    'MET': [
        ['N', 0, (-0.521, 1.364, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, 0.000, 0.000)],
        ['CB', 0, (-0.523, -0.776, -1.210)],
        ['O', 3, (0.625, 1.062, -0.000)],
        ['CG', 4, (0.613, 1.391, -0.000)],
        ['SD', 5, (0.703, 1.695, 0.000)],
        ['CE', 6, (0.320, 1.786, -0.000)],
    ],
    'PHE': [
        ['N', 0, (-0.518, 1.363, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.524, 0.000, -0.000)],
        ['CB', 0, (-0.525, -0.776, -1.212)],
        ['O', 3, (0.626, 1.062, -0.000)],
        ['CG', 4, (0.607, 1.377, 0.000)],
        ['CD1', 5, (0.709, 1.195, -0.000)],
        ['CD2', 5, (0.706, -1.196, 0.000)],
        ['CE1', 5, (2.102, 1.198, -0.000)],
        ['CE2', 5, (2.098, -1.201, -0.000)],
        ['CZ', 5, (2.794, -0.003, -0.001)],
    ],
    'PRO': [
        ['N', 0, (-0.566, 1.351, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.527, -0.000, 0.000)],
        ['CB', 0, (-0.546, -0.611, -1.293)],
        ['O', 3, (0.621, 1.066, 0.000)],
        ['CG', 4, (0.382, 1.445, 0.0)],
        # Previous CD entry, kept for reference:
        # ['CD', 5, (0.427, 1.440, 0.0)],
        ['CD', 5, (0.477, 1.424, 0.0)],  # manually made angle 2 degrees larger
    ],
    'SER': [
        ['N', 0, (-0.529, 1.360, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, -0.000, -0.000)],
        ['CB', 0, (-0.518, -0.777, -1.211)],
        ['O', 3, (0.626, 1.062, -0.000)],
        ['OG', 4, (0.503, 1.325, 0.000)],
    ],
    'THR': [
        ['N', 0, (-0.517, 1.364, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.526, 0.000, -0.000)],
        ['CB', 0, (-0.516, -0.793, -1.215)],
        ['O', 3, (0.626, 1.062, 0.000)],
        ['CG2', 4, (0.550, -0.718, -1.228)],
        ['OG1', 4, (0.472, 1.353, 0.000)],
    ],
    'TRP': [
        ['N', 0, (-0.521, 1.363, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.525, -0.000, 0.000)],
        ['CB', 0, (-0.523, -0.776, -1.212)],
        ['O', 3, (0.627, 1.062, 0.000)],
        ['CG', 4, (0.609, 1.370, -0.000)],
        ['CD1', 5, (0.824, 1.091, 0.000)],
        ['CD2', 5, (0.854, -1.148, -0.005)],
        ['CE2', 5, (2.186, -0.678, -0.007)],
        ['CE3', 5, (0.622, -2.530, -0.007)],
        ['NE1', 5, (2.140, 0.690, -0.004)],
        ['CH2', 5, (3.028, -2.890, -0.013)],
        ['CZ2', 5, (3.283, -1.543, -0.011)],
        ['CZ3', 5, (1.715, -3.389, -0.011)],
    ],
    'TYR': [
        ['N', 0, (-0.522, 1.362, 0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.524, -0.000, -0.000)],
        ['CB', 0, (-0.522, -0.776, -1.213)],
        ['O', 3, (0.627, 1.062, -0.000)],
        ['CG', 4, (0.607, 1.382, -0.000)],
        ['CD1', 5, (0.716, 1.195, -0.000)],
        ['CD2', 5, (0.713, -1.194, -0.001)],
        ['CE1', 5, (2.107, 1.200, -0.002)],
        ['CE2', 5, (2.104, -1.201, -0.003)],
        ['OH', 5, (4.168, -0.002, -0.005)],
        ['CZ', 5, (2.791, -0.001, -0.003)],
    ],
    'VAL': [
        ['N', 0, (-0.494, 1.373, -0.000)],
        ['CA', 0, (0.000, 0.000, 0.000)],
        ['C', 0, (1.527, -0.000, -0.000)],
        ['CB', 0, (-0.533, -0.795, -1.213)],
        ['O', 3, (0.627, 1.062, -0.000)],
        ['CG1', 4, (0.540, 1.429, -0.000)],
        ['CG2', 4, (0.533, -0.776, 1.203)],
    ],
}
|
| 336 |
+
|
| 337 |
+
# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
# Keyed by three-letter residue name; there is no 'UNK' entry.
residue_atoms = {
    'ALA': ['C', 'CA', 'CB', 'N', 'O'],
    'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
    'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
    'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
    'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
    'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
    'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
    'GLY': ['C', 'CA', 'N', 'O'],
    'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
    'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
    'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
    'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
    'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
    'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
    'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
    'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
    'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
    'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
            'CH2', 'N', 'NE1', 'O'],
    'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
            'OH'],
    'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
}
|
| 362 |
+
|
| 363 |
+
# Naming swaps for ambiguous atom names.
# Due to symmetries in the amino acids the naming of atoms is ambiguous in
# 4 of the 20 amino acids.
# (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
# in LEU, VAL and ARG can be resolved by using the 3d constellations of
# the 'ambiguous' atoms and their neighbours)
# Maps three-letter residue name -> {atom name: interchangeable atom name}.
# Each pair is listed in one direction only.
residue_atom_renaming_swaps = {
    'ASP': {'OD1': 'OD2'},
    'GLU': {'OE1': 'OE2'},
    'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
    'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
}
|
| 375 |
+
|
| 376 |
+
# Van der Waals radii [Angstrom] of the atoms (from Wikipedia), keyed by
# element symbol.
van_der_waals_radius = {
    'C': 1.7,
    'N': 1.55,
    'O': 1.52,
    'S': 1.8,
}
|
| 383 |
+
|
| 384 |
+
# Record of a bond between two named atoms: its length and standard deviation.
Bond = collections.namedtuple(
    'Bond', 'atom1_name atom2_name length stddev')
# Record of the angle spanned by three named atoms, with its standard
# deviation. Note the third field is spelled `atom3name` (no underscore).
BondAngle = collections.namedtuple(
    'BondAngle', 'atom1_name atom2_name atom3name angle_rad stddev')
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@functools.lru_cache(maxsize=None)
def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
                                          Mapping[str, List[Bond]],
                                          Mapping[str, List[BondAngle]]]:
  """Parses stereo_chemical_props.txt into per-residue lookup tables.

  Reads literature values for bond lengths and bond angles from the
  `stereo_chemical_props.txt` file located next to this module, and turns
  every bond angle into the length of the opposite edge of the triangle
  ("residue_virtual_bonds").

  The result is memoized with `functools.lru_cache`, so repeated calls return
  the same dict objects.

  Returns:
    residue_bonds: Dict that maps resname -> list of Bond tuples.
    residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
    residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
  """
  props_path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt')
  with open(props_path, 'rt') as props_file:
    rows = iter(props_file.read().splitlines())

  # Section 1: bond lengths, terminated by a line containing only '-'.
  residue_bonds = {}
  next(rows)  # Skip header line.
  for row in rows:
    if row.strip() == '-':
      break
    atom_pair, resname, length, stddev = row.split()
    first, second = atom_pair.split('-')
    residue_bonds.setdefault(resname, []).append(
        Bond(first, second, float(length), float(stddev)))
  residue_bonds['UNK'] = []

  # Section 2: bond angles in degrees, stored in radians; same terminator.
  residue_bond_angles = {}
  next(rows)  # Skip empty line.
  next(rows)  # Skip header line.
  for row in rows:
    if row.strip() == '-':
      break
    atom_triple, resname, angle_degree, stddev_degree = row.split()
    first, middle, last = atom_triple.split('-')
    residue_bond_angles.setdefault(resname, []).append(
        BondAngle(first, middle, last,
                  float(angle_degree) / 180. * np.pi,
                  float(stddev_degree) / 180. * np.pi))
  residue_bond_angles['UNK'] = []

  def _bond_key(name_a, name_b):
    """Order-independent key for looking up the bond between two atoms."""
    return '-'.join(sorted([name_a, name_b]))

  # Translate bond angles into distances ("virtual bonds").
  residue_virtual_bonds = {}
  for resname, angles in residue_bond_angles.items():
    # Fast lookup of this residue's bonds by their atom pair.
    bond_by_key = {
        _bond_key(b.atom1_name, b.atom2_name): b
        for b in residue_bonds[resname]}
    virtual_bonds = []
    residue_virtual_bonds[resname] = virtual_bonds
    for angle in angles:
      side_a = bond_by_key[_bond_key(angle.atom1_name, angle.atom2_name)]
      side_b = bond_by_key[_bond_key(angle.atom2_name, angle.atom3name)]
      len_a = side_a.length
      len_b = side_b.length

      # Distance between atom1 and atom3 from the law of cosines:
      # c^2 = a^2 + b^2 - 2ab*cos(gamma).
      gamma = angle.angle_rad
      opposite = np.sqrt(len_a**2 + len_b**2
                         - 2 * len_a * len_b * np.cos(gamma))

      # Propagation of uncertainty assuming uncorrelated errors.
      half_inverse = 0.5 / opposite
      d_gamma = (2 * len_a * len_b * np.sin(gamma)) * half_inverse
      d_side_a = (2 * len_a - 2 * len_b * np.cos(gamma)) * half_inverse
      d_side_b = (2 * len_b - 2 * len_a * np.cos(gamma)) * half_inverse
      spread = np.sqrt((d_gamma * angle.stddev)**2 +
                       (d_side_a * side_a.stddev)**2 +
                       (d_side_b * side_b.stddev)**2)
      virtual_bonds.append(
          Bond(angle.atom1_name, angle.atom3name, opposite, spread))

  return (residue_bonds,
          residue_virtual_bonds,
          residue_bond_angles)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# Between-residue bond lengths for general bonds (first element) and for Proline
# (second element).
between_res_bond_length_c_n = [1.329, 1.341]
between_res_bond_length_stddev_c_n = [0.014, 0.016]

# Between-residue cos_angles.
# NOTE(review): each list looks like [cosine of the angle, its standard
# deviation], with the angle in degrees given in the trailing comment --
# confirm against the callers.
between_res_cos_angles_c_n_ca = [-0.5203, 0.0353]  # degrees: 121.352 +- 2.315
between_res_cos_angles_ca_c_n = [-0.4473, 0.0311]  # degrees: 116.568 +- 1.995
|
| 489 |
+
|
| 490 |
+
# Canonical atom ordering, used when atom data has to be stored in a format
# that requires a fixed atom data size for every residue (e.g. a numpy array).
atom_types = [
    'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
    'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
    'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
    'CZ3', 'NZ', 'OXT'
]
atom_type_num = len(atom_types)  # := 37.
# Reverse lookup: atom name -> its position in `atom_types`.
atom_order = dict(zip(atom_types, range(atom_type_num)))
|
| 500 |
+
|
| 501 |
+
# A compact atom encoding with 14 columns.
# Maps three-letter residue name (plus 'UNK') to a list of exactly 14 atom
# names; '' marks an unused column.
# pylint: disable=line-too-long
# pylint: disable=bad-whitespace
restype_name_to_atom14_names = {
    'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
    'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
    'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
    'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
    'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
    'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
    'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
    'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
    'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
    'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
    'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
    'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
    'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
    'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
    'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
    'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
    'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
    'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
    'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
    'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
    'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],

}
# pylint: enable=line-too-long
# pylint: enable=bad-whitespace
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
# Standard residue order used when an amino-acid type is encoded as a number.
# Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
restypes = list('ARNDCQEGHILKMFPSTWYV')
restype_num = len(restypes)  # := 20.
restype_order = dict(zip(restypes, range(restype_num)))
unk_restype_index = restype_num  # Catch-all index for unknown restypes.

# The same ordering with the unknown residue 'X' appended at the end.
restypes_with_x = [*restypes, 'X']
restype_order_with_x = dict(zip(restypes_with_x, range(len(restypes_with_x))))
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
def sequence_to_onehot(
        sequence: str,
        mapping: Mapping[str, int],
        map_unknown_to_x: bool = False) -> np.ndarray:
    """One-hot encodes an amino acid sequence using the supplied mapping.

    Args:
      sequence: An amino acid sequence, one character per residue.
      mapping: A dictionary from amino acid characters to integer column ids.
        Several characters may share an id, but the set of ids must be exactly
        0..max_id with no gaps.
      map_unknown_to_x: If True, an upper-case alphabetic character that is
        missing from the mapping is encoded as 'X'; any other character raises
        a ValueError. The mapping must then contain 'X' (it is looked up for
        every residue, so a missing 'X' raises KeyError). If False, every
        character must be a key of the mapping or a KeyError is raised.

    Returns:
      An integer numpy array of shape (len(sequence), max_id + 1) holding a
      single 1 in each row.

    Raises:
      ValueError: If the mapping ids are not exactly 0..max_id, or if
        map_unknown_to_x is True and the sequence contains a character that is
        not an upper-case letter.
    """
    num_entries = max(mapping.values()) + 1

    distinct_ids = sorted(set(mapping.values()))
    if distinct_ids != list(range(num_entries)):
        raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
                         'without any gaps. Got: %s' % sorted(mapping.values()))

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)

    for row, residue in enumerate(sequence):
        if not map_unknown_to_x:
            # Strict mode: unknown characters surface as KeyError.
            column = mapping[residue]
        elif residue.isalpha() and residue.isupper():
            column = mapping.get(residue, mapping['X'])
        else:
            raise ValueError(f'Invalid character in the sequence: {residue}')
        one_hot_arr[row, column] = 1

    return one_hot_arr
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
restype_1to3 = {
|
| 590 |
+
'A': 'ALA',
|
| 591 |
+
'R': 'ARG',
|
| 592 |
+
'N': 'ASN',
|
| 593 |
+
'D': 'ASP',
|
| 594 |
+
'C': 'CYS',
|
| 595 |
+
'Q': 'GLN',
|
| 596 |
+
'E': 'GLU',
|
| 597 |
+
'G': 'GLY',
|
| 598 |
+
'H': 'HIS',
|
| 599 |
+
'I': 'ILE',
|
| 600 |
+
'L': 'LEU',
|
| 601 |
+
'K': 'LYS',
|
| 602 |
+
'M': 'MET',
|
| 603 |
+
'F': 'PHE',
|
| 604 |
+
'P': 'PRO',
|
| 605 |
+
'S': 'SER',
|
| 606 |
+
'T': 'THR',
|
| 607 |
+
'W': 'TRP',
|
| 608 |
+
'Y': 'TYR',
|
| 609 |
+
'V': 'VAL',
|
| 610 |
+
}
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
# NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
|
| 614 |
+
# 1-to-1 mapping of 3 letter names to one letter names. The latter contains
|
| 615 |
+
# many more, and less common, three letter names as keys and maps many of these
|
| 616 |
+
# to the same one letter name (including 'X' and 'U' which we don't use here).
|
| 617 |
+
restype_3to1 = {v: k for k, v in restype_1to3.items()}
|
| 618 |
+
|
| 619 |
+
# Define a restype name for all unknown residues.
|
| 620 |
+
unk_restype = 'UNK'
|
| 621 |
+
|
| 622 |
+
resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
|
| 623 |
+
resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
# The mapping here uses hhblits convention, so that B is mapped to D, J and O
|
| 627 |
+
# are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
|
| 628 |
+
# remaining 20 amino acids are kept in alphabetical order.
|
| 629 |
+
# There are 2 non-amino acid codes, X (representing any amino acid) and
|
| 630 |
+
# "-" representing a missing amino acid in an alignment. The id for these
|
| 631 |
+
# codes is put at the end (20 and 21) so that they can easily be ignored if
|
| 632 |
+
# desired.
|
| 633 |
+
HHBLITS_AA_TO_ID = {
|
| 634 |
+
'A': 0,
|
| 635 |
+
'B': 2,
|
| 636 |
+
'C': 1,
|
| 637 |
+
'D': 2,
|
| 638 |
+
'E': 3,
|
| 639 |
+
'F': 4,
|
| 640 |
+
'G': 5,
|
| 641 |
+
'H': 6,
|
| 642 |
+
'I': 7,
|
| 643 |
+
'J': 20,
|
| 644 |
+
'K': 8,
|
| 645 |
+
'L': 9,
|
| 646 |
+
'M': 10,
|
| 647 |
+
'N': 11,
|
| 648 |
+
'O': 20,
|
| 649 |
+
'P': 12,
|
| 650 |
+
'Q': 13,
|
| 651 |
+
'R': 14,
|
| 652 |
+
'S': 15,
|
| 653 |
+
'T': 16,
|
| 654 |
+
'U': 1,
|
| 655 |
+
'V': 17,
|
| 656 |
+
'W': 18,
|
| 657 |
+
'X': 20,
|
| 658 |
+
'Y': 19,
|
| 659 |
+
'Z': 3,
|
| 660 |
+
'-': 21,
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
# Partial inversion of HHBLITS_AA_TO_ID.
|
| 664 |
+
ID_TO_HHBLITS_AA = {
|
| 665 |
+
0: 'A',
|
| 666 |
+
1: 'C', # Also U.
|
| 667 |
+
2: 'D', # Also B.
|
| 668 |
+
3: 'E', # Also Z.
|
| 669 |
+
4: 'F',
|
| 670 |
+
5: 'G',
|
| 671 |
+
6: 'H',
|
| 672 |
+
7: 'I',
|
| 673 |
+
8: 'K',
|
| 674 |
+
9: 'L',
|
| 675 |
+
10: 'M',
|
| 676 |
+
11: 'N',
|
| 677 |
+
12: 'P',
|
| 678 |
+
13: 'Q',
|
| 679 |
+
14: 'R',
|
| 680 |
+
15: 'S',
|
| 681 |
+
16: 'T',
|
| 682 |
+
17: 'V',
|
| 683 |
+
18: 'W',
|
| 684 |
+
19: 'Y',
|
| 685 |
+
20: 'X', # Includes J and O.
|
| 686 |
+
21: '-',
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
restypes_with_x_and_gap = restypes + ['X', '-']
|
| 690 |
+
MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
|
| 691 |
+
restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
|
| 692 |
+
for i in range(len(restypes_with_x_and_gap)))
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
def _make_standard_atom_mask() -> np.ndarray:
    """Returns [num_res_types, num_atom_types] mask array.

    Row i corresponds to restypes[i]; an entry is 1 where that residue type
    lists the atom in `residue_atoms`. One extra trailing row is allocated for
    the unknown residue type and is left as all zeros.
    """
    # +1 to account for unknown (all 0s).
    mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
    for row, letter in enumerate(restypes):
        three_letter = restype_1to3[letter]
        for atom_name in residue_atoms[three_letter]:
            mask[row, atom_order[atom_name]] = 1
    return mask
|
| 706 |
+
|
| 707 |
+
|
| 708 |
+
STANDARD_ATOM_MASK = _make_standard_atom_mask()
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
# A one hot representation for the first and second atoms defining the axis
|
| 712 |
+
# of rotation for each chi-angle in each residue.
|
| 713 |
+
def chi_angle_atom(atom_index: int) -> np.ndarray:
    """Define chi-angle rigid groups via one-hot representations.

    Args:
      atom_index: Position within each chi-angle atom tuple of
        `chi_angles_atoms` whose atom should be encoded.

    Returns:
      An array with one leading entry per residue in `restypes` plus a final
      all-zero entry for residue `X`. After the transpose the trailing axes
      are (atom type, chi slot).
    """
    identity = np.eye(atom_type_num)

    padded_indices = {}
    for resname, chi_groups in chi_angles_atoms.items():
        picked = [atom_types.index(group[atom_index]) for group in chi_groups]
        # NOTE: unused chi slots are padded with -1. Indexing the identity with
        # -1 selects its LAST row, so padded slots come out one-hot on the
        # final atom type rather than all-zero. Kept as-is for compatibility.
        padded_indices[resname] = picked + [-1] * (4 - len(picked))

    per_residue = [
        identity[padded_indices[restype_1to3[letter]]] for letter in restypes
    ]
    per_residue.append(np.zeros([4, atom_type_num]))  # Add zeros for residue `X`.

    stacked = np.stack(per_residue, axis=0)
    return np.transpose(stacked, [0, 2, 1])
|
| 733 |
+
|
| 734 |
+
chi_atom_1_one_hot = chi_angle_atom(1)
|
| 735 |
+
chi_atom_2_one_hot = chi_angle_atom(2)
|
| 736 |
+
|
| 737 |
+
# An array like chi_angles_atoms but using indices rather than names.
|
| 738 |
+
chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
|
| 739 |
+
chi_angles_atom_indices = tree.map_structure(
|
| 740 |
+
lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
|
| 741 |
+
chi_angles_atom_indices = np.array([
|
| 742 |
+
chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
|
| 743 |
+
for chi_atoms in chi_angles_atom_indices])
|
| 744 |
+
|
| 745 |
+
# Mapping from (res_name, atom_name) pairs to the atom's chi group index
|
| 746 |
+
# and atom index within that group.
|
| 747 |
+
chi_groups_for_atom = collections.defaultdict(list)
|
| 748 |
+
for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
|
| 749 |
+
for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
|
| 750 |
+
for atom_i, atom in enumerate(chi_group):
|
| 751 |
+
chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
|
| 752 |
+
chi_groups_for_atom = dict(chi_groups_for_atom)
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
def _make_rigid_transformation_4x4(ex, ey, translation):
|
| 756 |
+
"""Create a rigid 4x4 transformation matrix from two axes and transl."""
|
| 757 |
+
# Normalize ex.
|
| 758 |
+
ex_normalized = ex / np.linalg.norm(ex)
|
| 759 |
+
|
| 760 |
+
# make ey perpendicular to ex
|
| 761 |
+
ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
|
| 762 |
+
ey_normalized /= np.linalg.norm(ey_normalized)
|
| 763 |
+
|
| 764 |
+
# compute ez as cross product
|
| 765 |
+
eznorm = np.cross(ex_normalized, ey_normalized)
|
| 766 |
+
m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
|
| 767 |
+
m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
|
| 768 |
+
return m
|
| 769 |
+
|
| 770 |
+
|
| 771 |
+
# create an array with (restype, atomtype) --> rigid_group_idx
|
| 772 |
+
# and an array with (restype, atomtype, coord) for the atom positions
|
| 773 |
+
# and compute affine transformation matrices (4,4) from one rigid group to the
|
| 774 |
+
# previous group
|
| 775 |
+
restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
|
| 776 |
+
restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
|
| 777 |
+
restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
|
| 778 |
+
restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
|
| 779 |
+
restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
|
| 780 |
+
restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
|
| 781 |
+
restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
def _make_rigid_group_constants():
    """Fill the module-level rigid-group arrays in place.

    First pass: for every residue type and every entry of
    `rigid_group_atom_positions`, record the atom's rigid group index, a
    presence mask and its position, in both the atom37 and atom14 layouts.

    Second pass: fill `restype_rigid_group_default_frame` with one 4x4 frame
    per rigid group (0 backbone, 1 pre-omega, 2 phi, 3 psi, 4-7 chi1-chi4).
    Rows for chi angles that a residue does not have are left as zeros.
    """
    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        for atomname, group_idx, atom_position in rigid_group_atom_positions[resname]:
            # atom37 layout: column given by the global atom ordering.
            atom37_idx = atom_order[atomname]
            restype_atom37_to_rigid_group[restype, atom37_idx] = group_idx
            restype_atom37_mask[restype, atom37_idx] = 1
            restype_atom37_rigid_group_positions[restype, atom37_idx, :] = atom_position

            # atom14 layout: column given by the per-residue atom name list.
            atom14_idx = restype_name_to_atom14_names[resname].index(atomname)
            restype_atom14_to_rigid_group[restype, atom14_idx] = group_idx
            restype_atom14_mask[restype, atom14_idx] = 1
            restype_atom14_rigid_group_positions[restype, atom14_idx, :] = atom_position

    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_positions = {}
        for name, _, pos in rigid_group_atom_positions[resname]:
            atom_positions[name] = np.array(pos)

        # View into the global array; writes below land in the module constant.
        frames = restype_rigid_group_default_frame[restype]

        # backbone to backbone is the identity transform
        frames[0] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        frames[1] = np.eye(4)

        # phi-frame to backbone
        frames[2] = _make_rigid_transformation_4x4(
            ex=atom_positions['N'] - atom_positions['CA'],
            ey=np.array([1., 0., 0.]),
            translation=atom_positions['N'])

        # psi-frame to backbone
        frames[3] = _make_rigid_transformation_4x4(
            ex=atom_positions['C'] - atom_positions['CA'],
            ey=atom_positions['CA'] - atom_positions['N'],
            translation=atom_positions['C'])

        # chi1-frame to backbone
        if chi_angles_mask[restype][0]:
            base = [atom_positions[name] for name in chi_angles_atoms[resname][0]]
            frames[4] = _make_rigid_transformation_4x4(
                ex=base[2] - base[1],
                ey=base[0] - base[1],
                translation=base[2])

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # luckily all rotation axes for the next frame start at (0,0,0) of the
        # previous frame
        for chi_idx in (1, 2, 3):
            if not chi_angles_mask[restype][chi_idx]:
                continue
            axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
            axis_end = atom_positions[axis_end_atom_name]
            frames[4 + chi_idx] = _make_rigid_transformation_4x4(
                ex=axis_end,
                ey=np.array([-1., 0., 0.]),
                translation=axis_end)
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
_make_rigid_group_constants()
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
def make_atom14_dists_bounds(overlap_tolerance=1.5,
                             bond_length_tolerance_factor=15):
    """Compute upper and lower bounds for bonds to assess violations.

    Args:
      overlap_tolerance: Amount subtracted from the sum of two atoms' van der
        Waals radii to obtain the lower (clash) bound for a non-bonded pair.
      bond_length_tolerance_factor: Number of standard deviations around the
        ideal length allowed for bonded (and virtually bonded) pairs.

    Returns:
      A dict with keys 'lower_bound', 'upper_bound' and 'stddev', each a
      float32 array of shape (21, 14, 14) indexed by
      (residue type, atom14 index, atom14 index).
    """
    lower_bounds = np.zeros([21, 14, 14], np.float32)
    upper_bounds = np.zeros([21, 14, 14], np.float32)
    stddevs = np.zeros([21, 14, 14], np.float32)
    residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()

    for restype, restype_letter in enumerate(restypes):
        resname = restype_1to3[restype_letter]
        atom_list = restype_name_to_atom14_names[resname]
        # Empty names mark unused atom14 slots; they take no part in any pair.
        named_atoms = [(idx, name) for idx, name in enumerate(atom_list) if name]

        # create lower and upper bounds for clashes
        for idx_a, name_a in named_atoms:
            # Radius is looked up by element, i.e. the first letter of the name.
            radius_a = van_der_waals_radius[name_a[0]]
            for idx_b, name_b in named_atoms:
                if idx_a == idx_b:
                    continue
                radius_b = van_der_waals_radius[name_b[0]]
                lower = radius_a + radius_b - overlap_tolerance
                upper = 1e10
                lower_bounds[restype, idx_a, idx_b] = lower
                lower_bounds[restype, idx_b, idx_a] = lower
                upper_bounds[restype, idx_a, idx_b] = upper
                upper_bounds[restype, idx_b, idx_a] = upper

        # overwrite lower and upper bounds for bonds and angles
        for bond in residue_bonds[resname] + residue_virtual_bonds[resname]:
            idx_a = atom_list.index(bond.atom1_name)
            idx_b = atom_list.index(bond.atom2_name)
            slack = bond_length_tolerance_factor * bond.stddev
            lower = bond.length - slack
            upper = bond.length + slack
            lower_bounds[restype, idx_a, idx_b] = lower
            lower_bounds[restype, idx_b, idx_a] = lower
            upper_bounds[restype, idx_a, idx_b] = upper
            upper_bounds[restype, idx_b, idx_a] = upper
            stddevs[restype, idx_a, idx_b] = bond.stddev
            stddevs[restype, idx_b, idx_a] = bond.stddev

    return {'lower_bound': lower_bounds,  # shape (21,14,14)
            'upper_bound': upper_bounds,  # shape (21,14,14)
            'stddev': stddevs,  # shape (21,14,14)
            }
|
analysis/src/common/rigid_utils.py
ADDED
|
@@ -0,0 +1,1451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2021 AlQuraishi Laboratory
|
| 2 |
+
# Copyright 2021 DeepMind Technologies Limited
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Tuple, Any, Sequence, Callable, Optional
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from src.common import rotation3d
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def rot_matmul(
    a: torch.Tensor,
    b: torch.Tensor
) -> torch.Tensor:
    """
        Multiplies two batches of 3x3 rotation matrices.

        Every output entry is spelled out as an explicit sum of three
        element-wise products instead of calling a matmul routine; the
        hand-written form is kept to avoid AMP downcasting.

        Args:
            a: [*, 3, 3] left multiplicand
            b: [*, 3, 3] right multiplicand
        Returns:
            [*, 3, 3] tensor holding the product ab
    """
    def _entry(i: int, j: int) -> torch.Tensor:
        # Dot product of row i of `a` with column j of `b`.
        return (
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
        )

    rows = [
        torch.stack([_entry(i, j) for j in range(3)], dim=-1)
        for i in range(3)
    ]

    return torch.stack(rows, dim=-2)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def rot_vec_mul(
    r: torch.Tensor,
    t: torch.Tensor
) -> torch.Tensor:
    """
        Applies a rotation to a vector. The matrix-vector product is written
        out by hand to avoid AMP downcasting.

        Args:
            r: [*, 3, 3] rotation matrices
            t: [*, 3] coordinate tensors
        Returns:
            [*, 3] rotated coordinates
    """
    x, y, z = (t[..., axis] for axis in range(3))
    rotated = [
        r[..., row, 0] * x + r[..., row, 1] * y + r[..., row, 2] * z
        for row in range(3)
    ]
    return torch.stack(rotated, dim=-1)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def identity_rot_mats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
        Builds a batch of identity rotation matrices.

        Args:
            batch_dims: leading batch dimensions of the result
            dtype: dtype forwarded to torch.eye
            device: device forwarded to torch.eye
            requires_grad: requires_grad flag forwarded to torch.eye
        Returns:
            [*batch_dims, 3, 3] tensor in which every trailing 3x3 block is
            the identity. The result is an expanded view of a single 3x3
            matrix rather than a per-batch copy.
    """
    eye = torch.eye(
        3, dtype=dtype, device=device, requires_grad=requires_grad
    )
    # Prepend one singleton axis per batch dimension, then broadcast them out.
    singleton_axes = (1,) * len(batch_dims)
    return eye.view(*singleton_axes, 3, 3).expand(*batch_dims, -1, -1)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def identity_trans(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
        Builds a batch of zero translation vectors.

        Args:
            batch_dims: leading batch dimensions of the result
            dtype: dtype forwarded to torch.zeros
            device: device forwarded to torch.zeros
            requires_grad: requires_grad flag forwarded to torch.zeros
        Returns:
            [*batch_dims, 3] tensor of zeros
    """
    shape = (*batch_dims, 3)
    return torch.zeros(
        shape, dtype=dtype, device=device, requires_grad=requires_grad
    )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def identity_quats(
    batch_dims: Tuple[int],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
    requires_grad: bool = True,
) -> torch.Tensor:
    """
        Builds a batch of identity quaternions (1, 0, 0, 0).

        Args:
            batch_dims: leading batch dimensions of the result
            dtype: dtype forwarded to torch.zeros
            device: device forwarded to torch.zeros
            requires_grad: requires_grad flag forwarded to torch.zeros
        Returns:
            [*batch_dims, 4] tensor whose first component is 1 and whose
            remaining components are 0
    """
    shape = (*batch_dims, 4)
    identity = torch.zeros(
        shape, dtype=dtype, device=device, requires_grad=requires_grad
    )

    # The in-place write must happen under no_grad: autograd rejects in-place
    # modification of a leaf tensor that requires grad.
    with torch.no_grad():
        identity[..., 0] = 1

    return identity
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
_quat_elements = ["a", "b", "c", "d"]
|
| 161 |
+
_qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
|
| 162 |
+
_qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def _to_mat(pairs):
    """
        Builds a 4x4 coefficient matrix from (key, value) pairs.

        Args:
            pairs: iterable of (key, value) where key is one of the two-letter
                strings in `_qtr_ind_dict` (e.g. "ab") and value is the
                coefficient to store
        Returns:
            A (4, 4) numpy array that is zero except at the positions named
            by the keys. A key's flat index in `_qtr_ind_dict` is split into
            (row, column) with row = index // 4 and column = index % 4.
    """
    coeffs = np.zeros((4, 4))
    for key, value in pairs:
        row, col = divmod(_qtr_ind_dict[key], 4)
        coeffs[row][col] = value

    return coeffs
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
_QTR_MAT = np.zeros((4, 4, 3, 3))
|
| 176 |
+
_QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
|
| 177 |
+
_QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
|
| 178 |
+
_QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
|
| 179 |
+
_QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
|
| 180 |
+
_QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
|
| 181 |
+
_QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
|
| 182 |
+
_QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
|
| 183 |
+
_QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
|
| 184 |
+
_QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
    """
        Converts a quaternion to a rotation matrix.

        Each rotation-matrix entry is a fixed linear combination of the
        pairwise products of quaternion components; the coefficients come
        from the module-level `_QTR_MAT` table.

        Args:
            quat: [*, 4] quaternions
        Returns:
            [*, 3, 3] rotation matrices
    """
    # [*, 4, 4] outer product of every quaternion with itself
    pair_products = quat[..., None] * quat[..., None, :]

    # [4, 4, 3, 3] coefficient table, cast to the dtype/device of the input
    coeffs = pair_products.new_tensor(_QTR_MAT, requires_grad=False)

    # [*, 4, 4, 3, 3] after adding one singleton axis per batch dimension
    batch_axes = (1,) * len(pair_products.shape[:-2])
    coeffs = coeffs.view(batch_axes + coeffs.shape)
    weighted = pair_products[..., None, None] * coeffs

    # [*, 3, 3] sum over both quaternion-product axes
    return torch.sum(weighted, dim=(-3, -4))
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def rot_to_quat(
    rot: torch.Tensor,
) -> torch.Tensor:
    """
        Converts rotation matrices to quaternions via an eigendecomposition.

        A symmetric 4x4 matrix is assembled from the entries of each rotation
        matrix, and the eigenvector belonging to its largest eigenvalue is
        returned (torch.linalg.eigh sorts eigenvalues in ascending order).
        The sign of an eigenvector is arbitrary, so the result may be either
        q or -q.

        Args:
            rot: [*, 3, 3] rotation matrices
        Returns:
            [*, 4] quaternions
        Raises:
            ValueError: If the last two dimensions of rot are not (3, 3)
    """
    if rot.shape[-2:] != (3, 3):
        raise ValueError("Input rotation is incorrectly shaped")

    # Unpack the nine entries as [*]-shaped tensors.
    entries = [[rot[..., i, j] for j in range(3)] for i in range(3)]
    [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = entries

    k = [
        [xx + yy + zz, zy - yz, xz - zx, yx - xy],
        [zy - yz, xx - yy - zz, xy + yx, xz + zx],
        [xz - zx, xy + yx, yy - xx - zz, yz + zy],
        [yx - xy, xz + zx, yz + zy, zz - xx - yy],
    ]

    # [*, 4, 4]
    k = (1. / 3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)

    _, vectors = torch.linalg.eigh(k)

    # Eigenvectors are stored as columns; the last one has the largest
    # eigenvalue.
    return vectors[..., -1]
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# _QUAT_MULTIPLY[i, j, k] is the coefficient of q1[i] * q2[j] in component k
# of the quaternion product q1 * q2 (component 0 is the scalar part).
_QUAT_MULTIPLY = np.zeros((4, 4, 4))
_QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0],
                           [0, -1, 0, 0],
                           [0, 0, -1, 0],
                           [0, 0, 0, -1]]

_QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0],
                           [1, 0, 0, 0],
                           [0, 0, 0, 1],
                           [0, 0, -1, 0]]

_QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0],
                           [0, 0, 0, -1],
                           [1, 0, 0, 0],
                           [0, 1, 0, 0]]

_QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1],
                           [0, 0, 1, 0],
                           [0, -1, 0, 0],
                           [1, 0, 0, 0]]

# Same table with the scalar slot of the second operand removed, for
# right-multiplying by a pure-vector quaternion (0, x, y, z). Shape [4, 3, 4].
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]


def quat_multiply(quat1, quat2):
    """
        Multiply a quaternion by another quaternion.

        Args:
            quat1: [*, 4] left operand
            quat2: [*, 4] right operand
        Returns:
            [*, 4] product quat1 * quat2, in the dtype and on the device of
            quat1
    """
    # [4, 4, 4] coefficient table
    mat = quat1.new_tensor(_QUAT_MULTIPLY)
    reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)

    # [*, 4, 4, 4] terms, reduced over both operand-component axes
    return torch.sum(
        reshaped_mat *
        quat1[..., :, None, None] *
        quat2[..., None, :, None],
        dim=(-3, -2)
    )
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def quat_multiply_by_vec(quat, vec):
    """
        Multiply a quaternion by a pure-vector quaternion.

        Uses _QUAT_MULTIPLY_BY_VEC, i.e. the quaternion product table with
        the scalar slot of the second operand removed, so vec supplies only
        the three vector components.

        Args:
            quat: [*, 4] quaternions (left operand)
            vec: [*, 3] vector parts of the right operand
        Returns:
            [*, 4] product, in the dtype and on the device of quat
    """
    # [4, 3, 4] coefficient table
    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
    reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)

    # [*, 4, 3, 4] terms, reduced over both operand-component axes
    return torch.sum(
        reshaped_mat *
        quat[..., :, None, None] *
        vec[..., None, :, None],
        dim=(-3, -2)
    )
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def invert_rot_mat(rot_mat: torch.Tensor) -> torch.Tensor:
    """
        Inverts rotation matrices by transposing the last two dimensions.

        The result is a view of the input, not a copy, and is generally not
        contiguous.

        Args:
            rot_mat: [*, 3, 3] rotation matrices
        Returns:
            [*, 3, 3] transposed matrices
    """
    return rot_mat.transpose(-1, -2)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def invert_quat(quat: torch.Tensor) -> torch.Tensor:
    """
        Computes the multiplicative inverse of a quaternion.

        The inverse is the conjugate (vector part negated) divided by the
        squared norm. The input tensor is not modified.

        Args:
            quat: [*, 4] quaternions
        Returns:
            [*, 4] inverted quaternions
    """
    # Conjugate: keep the scalar component, flip the sign of the other three.
    conjugate = torch.cat([quat[..., :1], -quat[..., 1:]], dim=-1)
    squared_norm = torch.sum(quat ** 2, dim=-1, keepdim=True)
    return conjugate / squared_norm
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class Rotation:
    """
        A 3D rotation. Depending on how the object is initialized, the
        rotation is represented by either a rotation matrix or a
        quaternion, though both formats are made available by helper functions.
        To simplify gradient computation, the underlying format of the
        rotation cannot be changed in-place. Like Rigid, the class is designed
        to mimic the behavior of a torch Tensor, almost as if each Rotation
        object were a tensor of rotations, in one format or another.
    """
    def __init__(self,
        rot_mats: Optional[torch.Tensor] = None,
        quats: Optional[torch.Tensor] = None,
        normalize_quats: bool = True,
    ):
        """
            Args:
                rot_mats:
                    A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
                    quats
                quats:
                    A [*, 4] quaternion. Mutually exclusive with rot_mats. If
                    normalize_quats is not True, must be a unit quaternion
                normalize_quats:
                    If quats is specified, whether to normalize quats
            Raises:
                ValueError:
                    If neither or both inputs are given, or if the given
                    input is incorrectly shaped
        """
        # Exactly one of the two representations must be provided.
        if (rot_mats is None) == (quats is None):
            raise ValueError("Exactly one input argument must be specified")

        if ((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
                (quats is not None and quats.shape[-1] != 4)):
            raise ValueError(
                "Incorrectly shaped rotation matrix or quaternion"
            )

        # Force full-precision
        if quats is not None:
            quats = quats.type(torch.float32)
        if rot_mats is not None:
            rot_mats = rot_mats.type(torch.float32)

        if quats is not None and normalize_quats:
            quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)

        self._rot_mats = rot_mats
        self._quats = quats

    @staticmethod
    def identity(
        shape,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        requires_grad: bool = True,
        fmt: str = "quat",
    ) -> "Rotation":
        """
            Returns an identity Rotation.

            Args:
                shape:
                    The "shape" of the resulting Rotation object. See documentation
                    for the shape property
                dtype:
                    The torch dtype for the rotation
                device:
                    The torch device for the new rotation
                requires_grad:
                    Whether the underlying tensors in the new rotation object
                    should require gradient computation
                fmt:
                    One of "quat" or "rot_mat". Determines the underlying format
                    of the new object's rotation
            Returns:
                A new identity rotation
            Raises:
                ValueError: If fmt is neither "quat" nor "rot_mat"
        """
        if fmt == "rot_mat":
            rot_mats = identity_rot_mats(
                shape, dtype, device, requires_grad,
            )
            return Rotation(rot_mats=rot_mats, quats=None)
        elif fmt == "quat":
            quats = identity_quats(shape, dtype, device, requires_grad)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError(f"Invalid format: {fmt}")

    # Magic methods

    def __getitem__(self, index: Any) -> "Rotation":
        """
            Allows torch-style indexing over the virtual shape of the rotation
            object. See documentation for the shape property.

            Args:
                index:
                    A torch index. E.g. (1, 3, 2), or (slice(None,))
            Returns:
                The indexed rotation
        """
        if not isinstance(index, tuple):
            index = (index,)

        # Trailing full slices keep the 3x3 / 4 payload dimensions intact.
        if self._rot_mats is not None:
            rot_mats = self._rot_mats[index + (slice(None), slice(None))]
            return Rotation(rot_mats=rot_mats)
        elif self._quats is not None:
            quats = self._quats[index + (slice(None),)]
            return Rotation(quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __mul__(self,
        right: torch.Tensor,
    ) -> "Rotation":
        """
            Pointwise left multiplication of the rotation with a tensor. Can be
            used to e.g. mask the Rotation.

            Args:
                right:
                    The tensor multiplicand
            Returns:
                The product
            Raises:
                TypeError: If right is not a torch.Tensor
        """
        if not isinstance(right, torch.Tensor):
            raise TypeError("The other multiplicand must be a Tensor")

        if self._rot_mats is not None:
            rot_mats = self._rot_mats * right[..., None, None]
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = self._quats * right[..., None]
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def __rmul__(self,
        left: torch.Tensor,
    ) -> "Rotation":
        """
            Reverse pointwise multiplication of the rotation with a tensor.

            Args:
                left:
                    The left multiplicand
            Returns:
                The product
        """
        return self.__mul__(left)

    # Properties

    @property
    def shape(self) -> torch.Size:
        """
            Returns the virtual shape of the rotation object. This shape is
            defined as the batch dimensions of the underlying rotation matrix
            or quaternion. If the Rotation was initialized with a [10, 3, 3]
            rotation matrix tensor, for example, the resulting shape would be
            [10].

            Returns:
                The virtual shape of the rotation object
        """
        if self._quats is not None:
            return self._quats.shape[:-1]
        elif self._rot_mats is not None:
            return self._rot_mats.shape[:-2]
        else:
            raise ValueError("Both rotations are None")

    @property
    def dtype(self) -> torch.dtype:
        """
            Returns the dtype of the underlying rotation.

            Returns:
                The dtype of the underlying rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats.dtype
        elif self._quats is not None:
            return self._quats.dtype
        else:
            raise ValueError("Both rotations are None")

    @property
    def device(self) -> torch.device:
        """
            The device of the underlying rotation

            Returns:
                The device of the underlying rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats.device
        elif self._quats is not None:
            return self._quats.device
        else:
            raise ValueError("Both rotations are None")

    @property
    def requires_grad(self) -> bool:
        """
            Returns the requires_grad property of the underlying rotation

            Returns:
                The requires_grad property of the underlying tensor
        """
        if self._rot_mats is not None:
            return self._rot_mats.requires_grad
        elif self._quats is not None:
            return self._quats.requires_grad
        else:
            raise ValueError("Both rotations are None")

    def get_rot_mats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a rotation matrix tensor.

            If the Rotation stores quaternions, they are converted with
            quat_to_rot on every call; the result is not cached.

            Returns:
                The rotation as a rotation matrix tensor
        """
        if self._rot_mats is not None:
            return self._rot_mats
        if self._quats is None:
            raise ValueError("Both rotations are None")

        return quat_to_rot(self._quats)

    def get_quats(self) -> torch.Tensor:
        """
            Returns the underlying rotation as a quaternion tensor.

            If the Rotation stores rotation matrices, they are converted
            with rotation3d.matrix_to_quaternion on every call; the result
            is not cached.

            Returns:
                The rotation as a quaternion tensor.
        """
        if self._quats is not None:
            return self._quats
        if self._rot_mats is None:
            raise ValueError("Both rotations are None")

        # NOTE: rotation3d.matrix_to_quaternion is used here in place of the
        # eigh-based rot_to_quat defined in this module.
        return rotation3d.matrix_to_quaternion(self._rot_mats)

    def get_cur_rot(self) -> torch.Tensor:
        """
            Return the underlying rotation in its current form

            Returns:
                The stored rotation
        """
        if self._rot_mats is not None:
            return self._rot_mats
        elif self._quats is not None:
            return self._quats
        else:
            raise ValueError("Both rotations are None")

    def get_rotvec(self, eps=1e-6) -> torch.Tensor:
        """
            Return the underlying axis-angle rotation vector.

            Follows scipy's implementation:
            https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402

            Args:
                eps:
                    Small offset added inside the sine of the large-angle
                    scale to keep its denominator away from zero
            Returns:
                The stored rotation as an axis-angle vector.
        """
        quat = self.get_quats()
        # w > 0 to ensure 0 <= angle <= pi
        flip = (quat[..., :1] < 0).float()
        quat = (-1 * quat) * flip + (1 - flip) * quat

        angle = 2 * torch.atan2(
            torch.linalg.norm(quat[..., 1:], dim=-1),
            quat[..., 0]
        )

        # Series expansion of angle / sin(angle / 2) for small angles.
        angle2 = angle * angle
        small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
        large_angle_scales = angle / torch.sin(angle / 2 + eps)

        # Blend the two scales with a 0/1 mask rather than branching.
        small_angles = (angle <= 1e-3).float()
        rot_vec_scale = (
            small_angle_scales * small_angles
            + (1 - small_angles) * large_angle_scales
        )
        rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
        return rot_vec

    # Rotation functions

    def compose_q_update_vec(self,
        q_update_vec: torch.Tensor,
        normalize_quats: bool = True,
        update_mask: Optional[torch.Tensor] = None,
    ) -> "Rotation":
        """
            Returns a new quaternion Rotation after updating the current
            object's underlying rotation with a quaternion update, formatted
            as a [*, 3] tensor whose final three columns represent x, y, z such
            that (1, x, y, z) is the desired (not necessarily unit) quaternion
            update.

            Args:
                q_update_vec:
                    A [*, 3] quaternion update tensor
                normalize_quats:
                    Whether to normalize the output quaternion
                update_mask:
                    Optional tensor multiplied into the [*, 4] quaternion
                    update before it is added, so it must broadcast against
                    that shape. Zero entries leave the quaternion unchanged
            Returns:
                An updated Rotation
        """
        quats = self.get_quats()
        quat_update = quat_multiply_by_vec(quats, q_update_vec)
        if update_mask is not None:
            quat_update = quat_update * update_mask

        # q * (1, x, y, z) = q + q * (0, x, y, z)
        new_quats = quats + quat_update
        return Rotation(
            rot_mats=None,
            quats=new_quats,
            normalize_quats=normalize_quats,
        )

    def compose_r(self, r) -> "Rotation":
        """
            Compose the rotation matrices of the current Rotation object with
            those of another.

            Args:
                r:
                    An update rotation object
            Returns:
                An updated rotation object
        """
        r1 = self.get_rot_mats()
        r2 = r.get_rot_mats()
        new_rot_mats = rot_matmul(r1, r2)
        return Rotation(rot_mats=new_rot_mats, quats=None)

    def compose_q(self, r, normalize_quats: bool = True) -> "Rotation":
        """
            Compose the quaternions of the current Rotation object with those
            of another.

            Either operand that stores rotation matrices is first converted
            to quaternions by get_quats.

            Args:
                r:
                    An update rotation object
                normalize_quats:
                    Whether to normalize the output quaternion
            Returns:
                An updated rotation object
        """
        q1 = self.get_quats()
        q2 = r.get_quats()
        new_quats = quat_multiply(q1, q2)
        return Rotation(
            rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
        )

    def apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            Apply the current Rotation as a rotation matrix to a set of 3D
            coordinates.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] rotated points
        """
        rot_mats = self.get_rot_mats()
        return rot_vec_mul(rot_mats, pts)

    def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
        """
            The inverse of the apply() method.

            Args:
                pts:
                    A [*, 3] set of points
            Returns:
                [*, 3] inverse-rotated points
        """
        rot_mats = self.get_rot_mats()
        inv_rot_mats = invert_rot_mat(rot_mats)
        return rot_vec_mul(inv_rot_mats, pts)

    def invert(self) -> "Rotation":
        """
            Returns the inverse of the current Rotation.

            Returns:
                The inverse of the current Rotation
        """
        if self._rot_mats is not None:
            return Rotation(
                rot_mats=invert_rot_mat(self._rot_mats),
                quats=None
            )
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=invert_quat(self._quats),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    # "Tensor" stuff

    def unsqueeze(self,
        dim: int,
    ) -> "Rotation":
        """
            Analogous to torch.unsqueeze. The dimension is relative to the
            shape of the Rotation object.

            Args:
                dim: A positive or negative dimension index.
            Returns:
                The unsqueezed Rotation.
            Raises:
                ValueError: If dim >= len(self.shape)
        """
        if dim >= len(self.shape):
            raise ValueError("Invalid dimension")

        # Negative dims are shifted past the trailing payload dimensions
        # (two for matrices, one for quaternions).
        if self._rot_mats is not None:
            rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    @staticmethod
    def cat(
        rs,
        dim: int,
    ) -> "Rotation":
        """
            Concatenates rotations along one of the batch dimensions. Analogous
            to torch.cat().

            Note that the output of this operation is always a rotation matrix,
            regardless of the format of input rotations.

            Args:
                rs:
                    A list of rotation objects
                dim:
                    The dimension along which the rotations should be
                    concatenated
            Returns:
                A concatenated Rotation object in rotation matrix format
        """
        rot_mats = [r.get_rot_mats() for r in rs]
        rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)

        return Rotation(rot_mats=rot_mats, quats=None)

    def map_tensor_fn(self,
        fn
    ) -> "Rotation":
        """
            Apply a Tensor -> Tensor function to underlying rotation tensors,
            mapping over the rotation dimension(s). Can be used e.g. to sum out
            a one-hot batch dimension.

            Args:
                fn:
                    A Tensor -> Tensor function to be mapped over the Rotation
            Returns:
                The transformed Rotation object
        """
        if self._rot_mats is not None:
            # reshape rather than view: the stored matrices may be
            # non-contiguous (e.g. the transposed view produced by invert()),
            # in which case view() raises.
            rot_mats = self._rot_mats.reshape(
                self._rot_mats.shape[:-2] + (9,)
            )
            rot_mats = torch.stack(
                list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
            )
            rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
            return Rotation(rot_mats=rot_mats, quats=None)
        elif self._quats is not None:
            quats = torch.stack(
                list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
            )
            return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
        else:
            raise ValueError("Both rotations are None")

    def cuda(self) -> "Rotation":
        """
            Analogous to the cuda() method of torch Tensors

            Returns:
                A copy of the Rotation in CUDA memory
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.cuda(),
                normalize_quats=False
            )
        else:
            raise ValueError("Both rotations are None")

    def to(self,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "Rotation":
        """
            Analogous to the to() method of torch Tensors

            Note that the Rotation constructor casts its input to float32,
            so the returned object is float32 regardless of dtype.

            Args:
                device:
                    A torch device
                dtype:
                    A torch dtype
            Returns:
                A copy of the Rotation using the new device and dtype
        """
        if self._rot_mats is not None:
            return Rotation(
                rot_mats=self._rot_mats.to(device=device, dtype=dtype),
                quats=None,
            )
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.to(device=device, dtype=dtype),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")

    def detach(self) -> "Rotation":
        """
            Returns a copy of the Rotation whose underlying Tensor has been
            detached from its torch graph.

            Returns:
                A copy of the Rotation whose underlying Tensor has been detached
                from its torch graph
        """
        if self._rot_mats is not None:
            return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
        elif self._quats is not None:
            return Rotation(
                rot_mats=None,
                quats=self._quats.detach(),
                normalize_quats=False,
            )
        else:
            raise ValueError("Both rotations are None")
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
class Rigid:
|
| 857 |
+
"""
|
| 858 |
+
A class representing a rigid transformation. Little more than a wrapper
|
| 859 |
+
around two objects: a Rotation object and a [*, 3] translation
|
| 860 |
+
Designed to behave approximately like a single torch tensor with the
|
| 861 |
+
shape of the shared batch dimensions of its component parts.
|
| 862 |
+
"""
|
| 863 |
+
def __init__(self,
|
| 864 |
+
rots: Optional[Rotation],
|
| 865 |
+
trans: Optional[torch.Tensor],
|
| 866 |
+
):
|
| 867 |
+
"""
|
| 868 |
+
Args:
|
| 869 |
+
rots: A [*, 3, 3] rotation tensor
|
| 870 |
+
trans: A corresponding [*, 3] translation tensor
|
| 871 |
+
"""
|
| 872 |
+
# (we need device, dtype, etc. from at least one input)
|
| 873 |
+
|
| 874 |
+
batch_dims, dtype, device, requires_grad = None, None, None, None
|
| 875 |
+
if(trans is not None):
|
| 876 |
+
batch_dims = trans.shape[:-1]
|
| 877 |
+
dtype = trans.dtype
|
| 878 |
+
device = trans.device
|
| 879 |
+
requires_grad = trans.requires_grad
|
| 880 |
+
elif(rots is not None):
|
| 881 |
+
batch_dims = rots.shape
|
| 882 |
+
dtype = rots.dtype
|
| 883 |
+
device = rots.device
|
| 884 |
+
requires_grad = rots.requires_grad
|
| 885 |
+
else:
|
| 886 |
+
raise ValueError("At least one input argument must be specified")
|
| 887 |
+
|
| 888 |
+
if(rots is None):
|
| 889 |
+
rots = Rotation.identity(
|
| 890 |
+
batch_dims, dtype, device, requires_grad,
|
| 891 |
+
)
|
| 892 |
+
elif(trans is None):
|
| 893 |
+
trans = identity_trans(
|
| 894 |
+
batch_dims, dtype, device, requires_grad,
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
if((rots.shape != trans.shape[:-1]) or
|
| 898 |
+
(rots.device != trans.device)):
|
| 899 |
+
raise ValueError("Rots and trans incompatible")
|
| 900 |
+
|
| 901 |
+
# Force full precision. Happens to the rotations automatically.
|
| 902 |
+
trans = trans.type(torch.float32)
|
| 903 |
+
|
| 904 |
+
self._rots = rots
|
| 905 |
+
self._trans = trans
|
| 906 |
+
|
| 907 |
+
@staticmethod
|
| 908 |
+
def identity(
|
| 909 |
+
shape: Tuple[int],
|
| 910 |
+
dtype: Optional[torch.dtype] = None,
|
| 911 |
+
device: Optional[torch.device] = None,
|
| 912 |
+
requires_grad: bool = True,
|
| 913 |
+
fmt: str = "quat",
|
| 914 |
+
):
|
| 915 |
+
"""
|
| 916 |
+
Constructs an identity transformation.
|
| 917 |
+
|
| 918 |
+
Args:
|
| 919 |
+
shape:
|
| 920 |
+
The desired shape
|
| 921 |
+
dtype:
|
| 922 |
+
The dtype of both internal tensors
|
| 923 |
+
device:
|
| 924 |
+
The device of both internal tensors
|
| 925 |
+
requires_grad:
|
| 926 |
+
Whether grad should be enabled for the internal tensors
|
| 927 |
+
Returns:
|
| 928 |
+
The identity transformation
|
| 929 |
+
"""
|
| 930 |
+
return Rigid(
|
| 931 |
+
Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
|
| 932 |
+
identity_trans(shape, dtype, device, requires_grad),
|
| 933 |
+
)
|
| 934 |
+
|
| 935 |
+
def __getitem__(self,
|
| 936 |
+
index: Any,
|
| 937 |
+
):
|
| 938 |
+
"""
|
| 939 |
+
Indexes the affine transformation with PyTorch-style indices.
|
| 940 |
+
The index is applied to the shared dimensions of both the rotation
|
| 941 |
+
and the translation.
|
| 942 |
+
|
| 943 |
+
E.g.::
|
| 944 |
+
|
| 945 |
+
r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
|
| 946 |
+
t = Rigid(r, torch.rand(10, 10, 3))
|
| 947 |
+
indexed = t[3, 4:6]
|
| 948 |
+
assert(indexed.shape == (2,))
|
| 949 |
+
assert(indexed.get_rots().shape == (2,))
|
| 950 |
+
assert(indexed.get_trans().shape == (2, 3))
|
| 951 |
+
|
| 952 |
+
Args:
|
| 953 |
+
index: A standard torch tensor index. E.g. 8, (10, None, 3),
|
| 954 |
+
or (3, slice(0, 1, None))
|
| 955 |
+
Returns:
|
| 956 |
+
The indexed tensor
|
| 957 |
+
"""
|
| 958 |
+
if type(index) != tuple:
|
| 959 |
+
index = (index,)
|
| 960 |
+
|
| 961 |
+
return Rigid(
|
| 962 |
+
self._rots[index],
|
| 963 |
+
self._trans[index + (slice(None),)],
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
def __mul__(self,
|
| 967 |
+
right: torch.Tensor,
|
| 968 |
+
):
|
| 969 |
+
"""
|
| 970 |
+
Pointwise left multiplication of the transformation with a tensor.
|
| 971 |
+
Can be used to e.g. mask the Rigid.
|
| 972 |
+
|
| 973 |
+
Args:
|
| 974 |
+
right:
|
| 975 |
+
The tensor multiplicand
|
| 976 |
+
Returns:
|
| 977 |
+
The product
|
| 978 |
+
"""
|
| 979 |
+
if not(isinstance(right, torch.Tensor)):
|
| 980 |
+
raise TypeError("The other multiplicand must be a Tensor")
|
| 981 |
+
|
| 982 |
+
new_rots = self._rots * right
|
| 983 |
+
new_trans = self._trans * right[..., None]
|
| 984 |
+
|
| 985 |
+
return Rigid(new_rots, new_trans)
|
| 986 |
+
|
| 987 |
+
def __rmul__(self,
|
| 988 |
+
left: torch.Tensor,
|
| 989 |
+
):
|
| 990 |
+
"""
|
| 991 |
+
Reverse pointwise multiplication of the transformation with a
|
| 992 |
+
tensor.
|
| 993 |
+
|
| 994 |
+
Args:
|
| 995 |
+
left:
|
| 996 |
+
The left multiplicand
|
| 997 |
+
Returns:
|
| 998 |
+
The product
|
| 999 |
+
"""
|
| 1000 |
+
return self.__mul__(left)
|
| 1001 |
+
|
| 1002 |
+
@property
|
| 1003 |
+
def shape(self) -> torch.Size:
|
| 1004 |
+
"""
|
| 1005 |
+
Returns the shape of the shared dimensions of the rotation and
|
| 1006 |
+
the translation.
|
| 1007 |
+
|
| 1008 |
+
Returns:
|
| 1009 |
+
The shape of the transformation
|
| 1010 |
+
"""
|
| 1011 |
+
s = self._trans.shape[:-1]
|
| 1012 |
+
return s
|
| 1013 |
+
|
| 1014 |
+
@property
|
| 1015 |
+
def device(self) -> torch.device:
|
| 1016 |
+
"""
|
| 1017 |
+
Returns the device on which the Rigid's tensors are located.
|
| 1018 |
+
|
| 1019 |
+
Returns:
|
| 1020 |
+
The device on which the Rigid's tensors are located
|
| 1021 |
+
"""
|
| 1022 |
+
return self._trans.device
|
| 1023 |
+
|
| 1024 |
+
def get_rots(self) -> Rotation:
|
| 1025 |
+
"""
|
| 1026 |
+
Getter for the rotation.
|
| 1027 |
+
|
| 1028 |
+
Returns:
|
| 1029 |
+
The rotation object
|
| 1030 |
+
"""
|
| 1031 |
+
return self._rots
|
| 1032 |
+
|
| 1033 |
+
def get_trans(self) -> torch.Tensor:
|
| 1034 |
+
"""
|
| 1035 |
+
Getter for the translation.
|
| 1036 |
+
|
| 1037 |
+
Returns:
|
| 1038 |
+
The stored translation
|
| 1039 |
+
"""
|
| 1040 |
+
return self._trans
|
| 1041 |
+
|
| 1042 |
+
def compose_q_update_vec(self,
|
| 1043 |
+
q_update_vec: torch.Tensor,
|
| 1044 |
+
update_mask: torch.Tensor=None,
|
| 1045 |
+
):
|
| 1046 |
+
"""
|
| 1047 |
+
Composes the transformation with a quaternion update vector of
|
| 1048 |
+
shape [*, 6], where the final 6 columns represent the x, y, and
|
| 1049 |
+
z values of a quaternion of form (1, x, y, z) followed by a 3D
|
| 1050 |
+
translation.
|
| 1051 |
+
|
| 1052 |
+
Args:
|
| 1053 |
+
q_vec: The quaternion update vector.
|
| 1054 |
+
Returns:
|
| 1055 |
+
The composed transformation.
|
| 1056 |
+
"""
|
| 1057 |
+
q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
|
| 1058 |
+
new_rots = self._rots.compose_q_update_vec(
|
| 1059 |
+
q_vec, update_mask=update_mask)
|
| 1060 |
+
|
| 1061 |
+
trans_update = self._rots.apply(t_vec)
|
| 1062 |
+
if update_mask is not None:
|
| 1063 |
+
trans_update = trans_update * update_mask
|
| 1064 |
+
new_translation = self._trans + trans_update
|
| 1065 |
+
|
| 1066 |
+
return Rigid(new_rots, new_translation)
|
| 1067 |
+
|
| 1068 |
+
def compose(self,
|
| 1069 |
+
r,
|
| 1070 |
+
):
|
| 1071 |
+
"""
|
| 1072 |
+
Composes the current rigid object with another.
|
| 1073 |
+
|
| 1074 |
+
Args:
|
| 1075 |
+
r:
|
| 1076 |
+
Another Rigid object
|
| 1077 |
+
Returns:
|
| 1078 |
+
The composition of the two transformations
|
| 1079 |
+
"""
|
| 1080 |
+
new_rot = self._rots.compose_r(r._rots)
|
| 1081 |
+
new_trans = self._rots.apply(r._trans) + self._trans
|
| 1082 |
+
return Rigid(new_rot, new_trans)
|
| 1083 |
+
|
| 1084 |
+
def compose_r(self,
|
| 1085 |
+
rot,
|
| 1086 |
+
order='right'
|
| 1087 |
+
):
|
| 1088 |
+
"""
|
| 1089 |
+
Composes the current rigid object with another.
|
| 1090 |
+
|
| 1091 |
+
Args:
|
| 1092 |
+
r:
|
| 1093 |
+
Another Rigid object
|
| 1094 |
+
order:
|
| 1095 |
+
Order in which to perform rotation multiplication.
|
| 1096 |
+
Returns:
|
| 1097 |
+
The composition of the two transformations
|
| 1098 |
+
"""
|
| 1099 |
+
if order == 'right':
|
| 1100 |
+
new_rot = self._rots.compose_r(rot)
|
| 1101 |
+
elif order == 'left':
|
| 1102 |
+
new_rot = rot.compose_r(self._rots)
|
| 1103 |
+
else:
|
| 1104 |
+
raise ValueError(f'Unrecognized multiplication order: {order}')
|
| 1105 |
+
return Rigid(new_rot, self._trans)
|
| 1106 |
+
|
| 1107 |
+
def apply(self,
|
| 1108 |
+
pts: torch.Tensor,
|
| 1109 |
+
) -> torch.Tensor:
|
| 1110 |
+
"""
|
| 1111 |
+
Applies the transformation to a coordinate tensor.
|
| 1112 |
+
|
| 1113 |
+
Args:
|
| 1114 |
+
pts: A [*, 3] coordinate tensor.
|
| 1115 |
+
Returns:
|
| 1116 |
+
The transformed points.
|
| 1117 |
+
"""
|
| 1118 |
+
rotated = self._rots.apply(pts)
|
| 1119 |
+
return rotated + self._trans
|
| 1120 |
+
|
| 1121 |
+
def invert_apply(self,
|
| 1122 |
+
pts: torch.Tensor
|
| 1123 |
+
) -> torch.Tensor:
|
| 1124 |
+
"""
|
| 1125 |
+
Applies the inverse of the transformation to a coordinate tensor.
|
| 1126 |
+
|
| 1127 |
+
Args:
|
| 1128 |
+
pts: A [*, 3] coordinate tensor
|
| 1129 |
+
Returns:
|
| 1130 |
+
The transformed points.
|
| 1131 |
+
"""
|
| 1132 |
+
pts = pts - self._trans
|
| 1133 |
+
return self._rots.invert_apply(pts)
|
| 1134 |
+
|
| 1135 |
+
def invert(self):
    """
        Builds the inverse transformation.

        The inverse rotation is taken from the stored rotation; the new
        translation is the old one rotated by that inverse and negated.

        Returns:
            The inverse transformation.
    """
    inverse_rot = self._rots.invert()
    inverse_trans = -1 * inverse_rot.apply(self._trans)
    return Rigid(inverse_rot, inverse_trans)
|
| 1146 |
+
|
| 1147 |
+
def map_tensor_fn(self,
    fn
):
    """
        Maps a Tensor -> Tensor function over the underlying rotation and
        translation tensors.

        The rotation is delegated to its own map_tensor_fn. The
        translation is split along its last dimension, fn is applied to
        each slice, and the results are stacked back on that dimension.

        Args:
            fn:
                A Tensor -> Tensor function to be mapped over the Rigid
        Returns:
            The transformed Rigid object
    """
    mapped_rots = self._rots.map_tensor_fn(fn)
    mapped_slices = [fn(piece) for piece in torch.unbind(self._trans, dim=-1)]
    mapped_trans = torch.stack(mapped_slices, dim=-1)
    return Rigid(mapped_rots, mapped_trans)
|
| 1168 |
+
|
| 1169 |
+
def to_tensor_4x4(self) -> torch.Tensor:
    """
        Packs the transformation into a homogenous transformation tensor.

        The rotation matrices fill the upper-left 3x3 block, the
        translation fills the first three rows of the last column, and
        the bottom-right entry is 1.

        Returns:
            A [*, 4, 4] homogenous transformation tensor
    """
    homogenous = self._trans.new_zeros((*self.shape, 4, 4))
    homogenous[..., 3, 3] = 1
    homogenous[..., :3, 3] = self._trans
    homogenous[..., :3, :3] = self._rots.get_rot_mats()
    return homogenous
|
| 1181 |
+
|
| 1182 |
+
@staticmethod
def from_tensor_4x4(
    t: torch.Tensor
):
    """
        Builds a transformation from a homogenous transformation tensor.

        The upper-left 3x3 block is used as the rotation matrices and the
        first three rows of the last column as the translation.

        Args:
            t: [*, 4, 4] homogenous transformation tensor
        Returns:
            T object with shape [*]
        Raises:
            ValueError: If the last two dimensions of t are not (4, 4).
    """
    if t.shape[-2:] != (4, 4):
        raise ValueError("Incorrectly shaped input tensor")

    rotation = Rotation(rot_mats=t[..., :3, :3], quats=None)
    translation = t[..., :3, 3]
    return Rigid(rotation, translation)
|
| 1202 |
+
|
| 1203 |
+
def to_tensor_7(self) -> torch.Tensor:
    """
        Packs the transformation into a tensor with 7 final columns: the
        first four hold the quaternion, the last three the translation.

        Returns:
            A [*, 7] tensor representation of the transformation
    """
    packed = self._trans.new_zeros((*self.shape, 7))
    packed[..., 4:] = self._trans
    packed[..., :4] = self._rots.get_quats()
    return packed
|
| 1216 |
+
|
| 1217 |
+
@staticmethod
def from_tensor_7(
    t: torch.Tensor,
    normalize_quats: bool = False,
):
    """
        Builds a transformation from a tensor with 7 final columns: the
        first four are read as the quaternion, the last three as the
        translation.

        Args:
            t: [*, 7] tensor
            normalize_quats:
                Forwarded unchanged to the Rotation constructor
        Returns:
            A transformation object built from the two slices
        Raises:
            ValueError: If the last dimension of t is not 7.
    """
    if t.shape[-1] != 7:
        raise ValueError("Incorrectly shaped input tensor")

    quat_part = t[..., :4]
    trans_part = t[..., 4:]
    rotation = Rotation(
        rot_mats=None,
        quats=quat_part,
        normalize_quats=normalize_quats,
    )
    return Rigid(rotation, trans_part)
|
| 1234 |
+
|
| 1235 |
+
@staticmethod
def from_3_points(
    p_neg_x_axis: torch.Tensor,
    origin: torch.Tensor,
    p_xy_plane: torch.Tensor,
    eps: float = 1e-8
):
    """
        Implements algorithm 21. Builds transformations from sets of 3
        points with the Gram-Schmidt procedure, working on the x/y/z
        components separately.

        Args:
            p_neg_x_axis: [*, 3] coordinates
            origin: [*, 3] coordinates used as frame origins
            p_xy_plane: [*, 3] coordinates
            eps: Small epsilon added under each square root
        Returns:
            A transformation object of shape [*]
    """
    neg_x_comps = torch.unbind(p_neg_x_axis, dim=-1)
    origin_comps = torch.unbind(origin, dim=-1)
    xy_comps = torch.unbind(p_xy_plane, dim=-1)

    # First basis vector: from p_neg_x_axis towards the origin, normalized.
    axis0 = [o - p for o, p in zip(origin_comps, neg_x_comps)]
    axis1 = [p - o for p, o in zip(xy_comps, origin_comps)]

    axis0_norm = torch.sqrt(sum(c * c for c in axis0) + eps)
    axis0 = [c / axis0_norm for c in axis0]

    # Second basis vector: remove the component along axis0, normalize.
    projection = sum(a * b for a, b in zip(axis0, axis1))
    axis1 = [b - a * projection for a, b in zip(axis0, axis1)]
    axis1_norm = torch.sqrt(sum(c * c for c in axis1) + eps)
    axis1 = [c / axis1_norm for c in axis1]

    # Third basis vector: cross product axis0 x axis1.
    axis2 = [
        axis0[1] * axis1[2] - axis0[2] * axis1[1],
        axis0[2] * axis1[0] - axis0[0] * axis1[2],
        axis0[0] * axis1[1] - axis0[1] * axis1[0],
    ]

    # Interleave so that the basis vectors become the matrix columns.
    flat = torch.stack(
        [c for row in zip(axis0, axis1, axis2) for c in row], dim=-1
    )
    rot_mats = flat.reshape(flat.shape[:-1] + (3, 3))

    return Rigid(
        Rotation(rot_mats=rot_mats, quats=None),
        torch.stack(origin_comps, dim=-1),
    )
|
| 1279 |
+
|
| 1280 |
+
def unsqueeze(self,
    dim: int,
):
    """
        Analogous to torch.unsqueeze. The dimension is relative to the
        shared dimensions of the rotation/translation.

        Args:
            dim: A positive or negative dimension index.
        Returns:
            The unsqueezed transformation.
        Raises:
            ValueError: If dim is not smaller than len(self.shape).
    """
    if dim >= len(self.shape):
        raise ValueError("Invalid dimension")

    # A negative index is decremented by one for the translation tensor.
    trans_dim = dim if dim >= 0 else dim - 1
    new_rots = self._rots.unsqueeze(dim)
    new_trans = self._trans.unsqueeze(trans_dim)
    return Rigid(new_rots, new_trans)
|
| 1298 |
+
|
| 1299 |
+
@staticmethod
def cat(
    ts,
    dim: int,
):
    """
        Concatenates transformations along an existing dimension
        (rotations via Rotation.cat, translations via torch.cat).

        Args:
            ts:
                A list of T objects
            dim:
                The dimension along which the transformations should be
                concatenated
        Returns:
            A concatenated transformation object
    """
    joined_rots = Rotation.cat([item._rots for item in ts], dim)

    # A negative index is decremented by one for the translation tensors.
    trans_dim = dim if dim >= 0 else dim - 1
    joined_trans = torch.cat([item._trans for item in ts], dim=trans_dim)

    return Rigid(joined_rots, joined_trans)
|
| 1322 |
+
|
| 1323 |
+
def apply_rot_fn(self, fn):
    """
        Runs a Rotation -> Rotation function on the stored rotation
        object and keeps the stored translation.

        Args:
            fn: A function of type Rotation -> Rotation
        Returns:
            A transformation object with a transformed rotation.
    """
    new_rots = fn(self._rots)
    return Rigid(new_rots, self._trans)
|
| 1334 |
+
|
| 1335 |
+
def apply_trans_fn(self, fn):
    """
        Runs a Tensor -> Tensor function on the stored translation and
        keeps the stored rotation.

        Args:
            fn:
                A function of type Tensor -> Tensor to be applied to the
                translation
        Returns:
            A transformation object with a transformed translation.
    """
    new_trans = fn(self._trans)
    return Rigid(self._rots, new_trans)
|
| 1347 |
+
|
| 1348 |
+
def scale_translation(self, trans_scale_factor: float):
    """
        Multiplies the translation by a constant factor.

        Args:
            trans_scale_factor:
                The constant factor
        Returns:
            A transformation object with a scaled translation.
    """
    def _scale(trans):
        return trans * trans_scale_factor

    return self.apply_trans_fn(_scale)
|
| 1360 |
+
|
| 1361 |
+
def stop_rot_gradient(self):
    """
        Detaches the underlying rotation object by calling its detach()
        method; the translation is left as is.

        Returns:
            A transformation object with detached rotations
    """
    def _detach(rotation):
        return rotation.detach()

    return self.apply_rot_fn(_detach)
|
| 1370 |
+
|
| 1371 |
+
@staticmethod
def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
    """
        Returns a transformation object from reference coordinates.

        The frame is built from three successive rotations after moving
        the carbon alpha to the origin: one about z that zeroes the y
        component of C, one about y that zeroes the z component of C,
        and one about x that zeroes the z component of N.

        Note that this method does not take care of symmetries. If you
        provide the atom positions in the non-standard way, the N atom will
        end up not at [-0.527250, 1.359329, 0.0] but instead at
        [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
        your code.

        Args:
            n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
            ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
            c_xyz: A [*, 3] tensor of carbon xyz coordinates.
            eps: Small value added under each square root.
        Returns:
            A transformation object. After applying the translation and
            rotation to the reference backbone, the coordinates will
            approximately equal to the input coordinates.
    """
    # Shift everything so that the carbon alpha sits at the origin.
    translation = -1 * ca_xyz
    n_xyz = n_xyz + translation
    c_xyz = c_xyz + translation

    # First rotation (about z): zero the y component of C.
    c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
    sin_c1 = -c_y / norm
    cos_c1 = c_x / norm

    c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
    c1_rots[..., 0, 0] = cos_c1
    c1_rots[..., 0, 1] = -1 * sin_c1
    c1_rots[..., 1, 0] = sin_c1
    c1_rots[..., 1, 1] = cos_c1
    c1_rots[..., 2, 2] = 1

    # Second rotation (about y): zero the z component of C.
    norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
    sin_c2 = c_z / norm
    cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm

    c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    c2_rots[..., 0, 0] = cos_c2
    c2_rots[..., 0, 2] = sin_c2
    c2_rots[..., 1, 1] = 1
    # The bottom row belongs to c2_rots. Writing it into c1_rots would
    # leave c2_rots with an all-zero last row and clobber c1_rots[2, 2].
    c2_rots[..., 2, 0] = -1 * sin_c2
    c2_rots[..., 2, 2] = cos_c2

    c_rots = rot_matmul(c2_rots, c1_rots)
    n_xyz = rot_vec_mul(c_rots, n_xyz)

    # Third rotation (about x): zero the z component of N.
    _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
    norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
    sin_n = -n_z / norm
    cos_n = n_y / norm

    n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
    n_rots[..., 0, 0] = 1
    n_rots[..., 1, 1] = cos_n
    n_rots[..., 1, 2] = -1 * sin_n
    n_rots[..., 2, 1] = sin_n
    n_rots[..., 2, 2] = cos_n

    rots = rot_matmul(n_rots, c_rots)

    # The accumulated rotation maps input -> reference; transpose it and
    # flip the translation back to get the reference -> input transform.
    rots = rots.transpose(-1, -2)
    translation = -1 * translation

    rot_obj = Rotation(rot_mats=rots, quats=None)

    return Rigid(rot_obj, translation)
|
| 1443 |
+
|
| 1444 |
+
def cuda(self):
    """
        Moves the transformation object to GPU memory by calling cuda()
        on both the rotation object and the translation tensor.

        Returns:
            A version of the transformation on GPU
    """
    gpu_rots = self._rots.cuda()
    gpu_trans = self._trans.cuda()
    return Rigid(gpu_rots, gpu_trans)
|
analysis/src/common/rotation3d.py
ADDED
|
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Optional, Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
|
| 12 |
+
Device = Union[str, torch.device]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
The transformation matrices returned from the functions in this file assume
|
| 17 |
+
the points on which the transformation will be applied are column vectors.
|
| 18 |
+
i.e. the R matrix is structured as
|
| 19 |
+
|
| 20 |
+
R = [
|
| 21 |
+
[Rxx, Rxy, Rxz],
|
| 22 |
+
[Ryx, Ryy, Ryz],
|
| 23 |
+
[Rzx, Rzy, Rzz],
|
| 24 |
+
] # (3, 3)
|
| 25 |
+
|
| 26 |
+
This matrix can be applied to column vectors by post multiplication
|
| 27 |
+
by the points e.g.
|
| 28 |
+
|
| 29 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
| 30 |
+
transformed_points = R * points
|
| 31 |
+
|
| 32 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
| 33 |
+
can be transposed and pre multiplied by the points:
|
| 34 |
+
|
| 35 |
+
e.g.
|
| 36 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
| 37 |
+
transformed_points = points * R.transpose(1, 0)
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions into rotation matrices.

    The input does not have to be normalized: every term is scaled by
    2 / |q|^2.

    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    w, x, y, z = torch.unbind(quaternions, -1)
    scale = 2.0 / (quaternions * quaternions).sum(-1)

    # Nine entries in row-major order.
    entries = (
        1 - scale * (y * y + z * z),
        scale * (x * y - z * w),
        scale * (x * z + y * w),
        scale * (x * y + z * w),
        1 - scale * (x * x + z * z),
        scale * (y * z - x * w),
        scale * (x * z - y * w),
        scale * (y * z + x * w),
        1 - scale * (x * x + y * y),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 74 |
+
"""
|
| 75 |
+
Return a tensor where each element has the absolute value taken from the,
|
| 76 |
+
corresponding element of a, with sign taken from the corresponding
|
| 77 |
+
element of b. This is like the standard copysign floating-point operation,
|
| 78 |
+
but is not careful about negative 0 and NaN.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
a: source tensor.
|
| 82 |
+
b: tensor whose signs will be used, of the same shape as a.
|
| 83 |
+
|
| 84 |
+
Returns:
|
| 85 |
+
Tensor of the same shape as a with the signs of b.
|
| 86 |
+
"""
|
| 87 |
+
signs_differ = (a < 0) != (b < 0)
|
| 88 |
+
return torch.where(signs_differ, -a, a)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 92 |
+
"""
|
| 93 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 94 |
+
but with a zero subgradient where x is 0.
|
| 95 |
+
"""
|
| 96 |
+
ret = torch.zeros_like(x)
|
| 97 |
+
positive_mask = x > 0
|
| 98 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 99 |
+
return ret
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
    """
    Turn rotation matrices into quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part first, as tensor of shape (..., 4).

    Raises:
        ValueError: If the last two dimensions of matrix are not 3 x 3.
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    lead_shape = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(lead_shape + (9,)), dim=-1
    )

    # Twice the magnitude of each quaternion component (r, i, j, k).
    diagonal_terms = [
        1.0 + m00 + m11 + m22,
        1.0 + m00 - m11 - m22,
        1.0 - m00 + m11 - m22,
        1.0 - m00 - m11 + m22,
    ]
    q_abs = _sqrt_positive_part(torch.stack(diagonal_terms, dim=-1))

    # Row n is the desired quaternion multiplied by its n-th component.
    candidate_rows = [
        [q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01],
        [m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20],
        [m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21],
        [m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2],
    ]
    quat_by_rijk = torch.stack(
        [torch.stack(row, dim=-1) for row in candidate_rows], dim=-2
    )

    # The denominator is floored at 0.1; the exact level does not matter
    # because a candidate with a small q_abs is never the one selected.
    floor = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(floor))

    # Up to sign and numerical error all candidates agree; keep the
    # best-conditioned one, i.e. the one with the largest denominator.
    best = F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5
    return candidates[best, :].reshape(lead_shape + (4,))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
"""
|
| 166 |
+
Return the rotation matrices for one of the rotations about an axis
|
| 167 |
+
of which Euler angles describe, for each value of the angle given.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
axis: Axis label "X" or "Y or "Z".
|
| 171 |
+
angle: any shape tensor of Euler angles in radians
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
cos = torch.cos(angle)
|
| 178 |
+
sin = torch.sin(angle)
|
| 179 |
+
one = torch.ones_like(angle)
|
| 180 |
+
zero = torch.zeros_like(angle)
|
| 181 |
+
|
| 182 |
+
if axis == "X":
|
| 183 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
| 184 |
+
elif axis == "Y":
|
| 185 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
| 186 |
+
elif axis == "Z":
|
| 187 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
| 188 |
+
else:
|
| 189 |
+
raise ValueError("letter must be either X, Y or Z.")
|
| 190 |
+
|
| 191 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Turn Euler angles in radians into rotation matrices.

    One single-axis rotation is built per angle, and the three matrices
    are multiplied in the order given by the convention.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: If the angle tensor or the convention is malformed.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    # The middle axis may not repeat either of its neighbours.
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    per_axis_angles = torch.unbind(euler_angles, -1)
    first, second, third = [
        _axis_angle_rotation(axis_label, axis_angle)
        for axis_label, axis_angle in zip(convention, per_axis_angles)
    ]
    return torch.matmul(torch.matmul(first, second), third)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def _angle_from_tan(
|
| 224 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
| 225 |
+
) -> torch.Tensor:
|
| 226 |
+
"""
|
| 227 |
+
Extract the first or third Euler angle from the two members of
|
| 228 |
+
the matrix which are positive constant times its sine and cosine.
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
| 232 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
| 233 |
+
convention.
|
| 234 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
| 235 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
| 236 |
+
which means the relevant entries are in the same row of the
|
| 237 |
+
rotation matrix. If not, they are in the same column.
|
| 238 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Euler Angles in radians for each matrix in data as a tensor
|
| 242 |
+
of shape (...).
|
| 243 |
+
"""
|
| 244 |
+
|
| 245 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
| 246 |
+
if horizontal:
|
| 247 |
+
i2, i1 = i1, i2
|
| 248 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
| 249 |
+
if horizontal == even:
|
| 250 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
| 251 |
+
if tait_bryan:
|
| 252 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
| 253 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _index_from_letter(letter: str) -> int:
|
| 257 |
+
if letter == "X":
|
| 258 |
+
return 0
|
| 259 |
+
if letter == "Y":
|
| 260 |
+
return 1
|
| 261 |
+
if letter == "Z":
|
| 262 |
+
return 2
|
| 263 |
+
raise ValueError("letter must be either X, Y or Z.")
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Turn rotation matrices into Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: If the convention or the matrix shape is malformed.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    # The middle axis may not repeat either of its neighbours.
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first_idx = _index_from_letter(convention[0])
    last_idx = _index_from_letter(convention[2])
    tait_bryan = first_idx != last_idx

    # The middle angle comes from one matrix entry: asin when the outer
    # axes differ, acos when they coincide.
    if tait_bryan:
        sign = -1.0 if first_idx - last_idx in [-1, 2] else 1.0
        central_angle = torch.asin(matrix[..., first_idx, last_idx] * sign)
    else:
        central_angle = torch.acos(matrix[..., first_idx, first_idx])

    first_angle = _angle_from_tan(
        convention[0], convention[1], matrix[..., last_idx], False, tait_bryan
    )
    last_angle = _angle_from_tan(
        convention[2], convention[1], matrix[..., first_idx, :], True, tait_bryan
    )
    return torch.stack((first_angle, central_angle, last_angle), -1)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def random_quaternions(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random quaternions representing rotations,
    i.e. versors with nonnegative real part.

    Normal samples are divided by their norm, with the norm given the
    sign of the real part so that the real part comes out nonnegative.

    Args:
        n: Number of quaternions in a batch to return.
        dtype: Type to return.
        device: Desired device of returned tensor. Default:
            uses the current device for the default tensor type.

    Returns:
        Quaternions as tensor of shape (N, 4).
    """
    if isinstance(device, str):
        device = torch.device(device)

    samples = torch.randn((n, 4), dtype=dtype, device=device)
    squared_norm = (samples * samples).sum(1)
    signed_norm = _copysign(torch.sqrt(squared_norm), samples[:, 0])
    return samples / signed_norm[:, None]
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
def random_rotations(
    n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate random rotations as 3x3 rotation matrices, by drawing
    random quaternions and converting them.

    Args:
        n: Number of rotation matrices in a batch to return.
        dtype: Type to return.
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type.

    Returns:
        Rotation matrices as tensor of shape (n, 3, 3).
    """
    return quaternion_to_matrix(random_quaternions(n, dtype=dtype, device=device))
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def random_rotation(
    dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
) -> torch.Tensor:
    """
    Generate a single random 3x3 rotation matrix, taken as the only
    element of a batch of one.

    Args:
        dtype: Type to return
        device: Device of returned tensor. Default: if None,
            uses the current device for the default tensor type

    Returns:
        Rotation matrix as tensor of shape (3, 3).
    """
    batch_of_one = random_rotations(1, dtype, device)
    return batch_of_one[0]
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Bring a unit quaternion into standard form, i.e. the form in which
    the real part is non negative. Quaternions with a negative real part
    are negated as a whole.

    Args:
        quaternions: Quaternions with real part first,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    real_is_negative = quaternions[..., 0:1] < 0
    return torch.where(real_is_negative, -quaternions, quaternions)
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Compute the Hamilton product of two quaternions, without
    standardizing the sign of the result.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions shape (..., 4).
    """
    a_r, a_i, a_j, a_k = torch.unbind(a, -1)
    b_r, b_i, b_j, b_k = torch.unbind(b, -1)

    components = (
        a_r * b_r - a_i * b_i - a_j * b_j - a_k * b_k,
        a_r * b_i + a_i * b_r + a_j * b_k - a_k * b_j,
        a_r * b_j - a_i * b_k + a_j * b_r + a_k * b_i,
        a_r * b_k + a_i * b_j - a_j * b_i + a_k * b_r,
    )
    return torch.stack(components, -1)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Multiply two quaternions representing rotations, returning the quaternion
    representing their composition, i.e. the versor with nonnegative real part.
    Usual torch rules for broadcasting apply.

    Args:
        a: Quaternions as tensor of shape (..., 4), real part first.
        b: Quaternions as tensor of shape (..., 4), real part first.

    Returns:
        The product of a and b, a tensor of quaternions of shape (..., 4).
    """
    return standardize_quaternion(quaternion_raw_multiply(a, b))
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
    """
    Return the inverse rotation of a unit quaternion.

    The inverse is obtained by conjugation, which is only a true inverse
    for versors (unit quaternions).

    Args:
        quaternion: Unit quaternions as tensor of shape (..., 4), real
            part first.

    Returns:
        The conjugated quaternions, tensor of shape (..., 4).
    """
    # Keep the real part and flip the sign of the three imaginary parts.
    conjugate_signs = torch.tensor([1, -1, -1, -1], device=quaternion.device)
    return torch.mul(quaternion, conjugate_signs)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
    """
    Rotate 3D points by the rotation a quaternion represents.

    Each point is embedded as a quaternion with zero real part, sandwiched
    between the quaternion and its inverse, and the imaginary part of the
    result is returned. Usual torch broadcasting rules apply.

    Args:
        quaternion: Tensor of quaternions, real part first, of shape (..., 4).
        point: Tensor of 3D points of shape (..., 3).

    Returns:
        Tensor of rotated points of shape (..., 3).

    Raises:
        ValueError: If the last dimension of `point` is not 3.
    """
    if point.size(-1) != 3:
        raise ValueError(f"Points are not in 3D, {point.shape}.")

    # Embed the point as a quaternion whose real component is zero.
    zero_real = point.new_zeros(point.shape[:-1] + (1,))
    embedded = torch.cat((zero_real, point), -1)

    # q * p * q^-1
    rotated_left = quaternion_raw_multiply(quaternion, embedded)
    sandwiched = quaternion_raw_multiply(rotated_left, quaternion_invert(quaternion))
    return sandwiched[..., 1:]
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Turn axis/angle rotation vectors into rotation matrices.

    The conversion goes through the quaternion representation.

    Args:
        axis_angle: Rotation vectors of shape (..., 3). The direction is
            the rotation axis and the magnitude is the anticlockwise
            rotation angle in radians.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    as_quaternion = axis_angle_to_quaternion(axis_angle)
    return quaternion_to_matrix(as_quaternion)
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
    """
    Turn rotation matrices into axis/angle rotation vectors.

    The conversion goes through the quaternion representation.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        Rotation vectors of shape (..., 3). The direction is the rotation
        axis and the magnitude is the anticlockwise rotation angle in
        radians.
    """
    as_quaternion = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(as_quaternion)
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
    """
    Turn axis/angle rotation vectors into quaternions.

    Args:
        axis_angle: Rotation vectors of shape (..., 3). The direction is
            the rotation axis and the magnitude is the anticlockwise
            rotation angle in radians.

    Returns:
        Quaternions with real part first, as tensor of shape (..., 4).
    """
    theta = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
    half_theta = theta * 0.5

    threshold = 1e-6
    is_tiny = theta.abs() < threshold
    is_regular = ~is_tiny

    # scale = sin(theta / 2) / theta, filled in separately for the two
    # regimes so that no division by (near) zero is ever evaluated.
    scale = torch.empty_like(theta)
    scale[is_regular] = torch.sin(half_theta[is_regular]) / theta[is_regular]
    # Taylor expansion near zero: sin(x/2) ~ x/2 - (x/2)^3/6,
    # hence sin(x/2)/x ~ 1/2 - x^2/48.
    scale[is_tiny] = 0.5 - (theta[is_tiny] * theta[is_tiny]) / 48

    real_part = torch.cos(half_theta)
    imaginary_part = axis_angle * scale
    return torch.cat([real_part, imaginary_part], dim=-1)
|
| 523 |
+
|
| 524 |
+
|
| 525 |
+
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Turn quaternions into axis/angle rotation vectors.

    Args:
        quaternions: Quaternions with real part first, as tensor of
            shape (..., 4).

    Returns:
        Rotation vectors of shape (..., 3). The direction is the rotation
        axis and the magnitude is the anticlockwise rotation angle in
        radians.
    """
    imaginary = quaternions[..., 1:]
    imaginary_norm = torch.norm(imaginary, p=2, dim=-1, keepdim=True)
    half_theta = torch.atan2(imaginary_norm, quaternions[..., :1])
    theta = 2 * half_theta

    threshold = 1e-6
    is_tiny = theta.abs() < threshold
    is_regular = ~is_tiny

    # scale = sin(theta / 2) / theta, filled in separately for the two
    # regimes so that no division by (near) zero is ever evaluated.
    scale = torch.empty_like(theta)
    scale[is_regular] = torch.sin(half_theta[is_regular]) / theta[is_regular]
    # Taylor expansion near zero: sin(x/2) ~ x/2 - (x/2)^3/6,
    # hence sin(x/2)/x ~ 1/2 - x^2/48.
    scale[is_tiny] = 0.5 - (theta[is_tiny] * theta[is_tiny]) / 48

    return imaginary / scale
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """
    Build rotation matrices from the 6D representation of Zhou et al. [1].

    The two 3-vectors packed in `d6` are orthonormalized with Gram--Schmidt
    (Section B of [1]) and completed with their cross product; the three
    resulting vectors become the rows of the output matrix.

    Args:
        d6: 6D rotation representation, of size (*, 6)

    Returns:
        batch of rotation matrices of size (*, 3, 3)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    first_raw = d6[..., :3]
    second_raw = d6[..., 3:]

    row0 = F.normalize(first_raw, dim=-1)
    # Remove from the second vector its component along row0, then normalize.
    projection = (row0 * second_raw).sum(-1, keepdim=True)
    row1 = F.normalize(second_raw - projection * row0, dim=-1)
    row2 = torch.cross(row0, row1, dim=-1)

    return torch.stack((row0, row1, row2), dim=-2)
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
    """
    Encode rotation matrices in the 6D representation of Zhou et al. [1].

    The encoding keeps the first two rows of each matrix and flattens them;
    the last row is dropped. Note that the 6D representation is not unique.

    Args:
        matrix: batch of rotation matrices of size (*, 3, 3)

    Returns:
        6D rotation representation, of size (*, 6)

    [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
    On the Continuity of Rotation Representations in Neural Networks.
    IEEE Conference on Computer Vision and Pattern Recognition, 2019.
    Retrieved from http://arxiv.org/abs/1812.07035
    """
    leading_shape = matrix.shape[:-2]
    top_two_rows = matrix[..., :2, :].clone()
    return top_two_rows.reshape(*leading_shape, 6)
|
analysis/src/data/__init__.py
ADDED
|
File without changes
|
analysis/src/data/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (141 Bytes). View file
|
|
|
analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
analysis/src/data/components/__init__.py
ADDED
|
File without changes
|
analysis/src/data/components/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (152 Bytes). View file
|
|
|
analysis/src/data/components/__pycache__/dataset.cpython-39.pyc
ADDED
|
Binary file (10.1 kB). View file
|
|
|
analysis/src/data/components/dataset.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Protein dataset class."""
|
| 2 |
+
import os
|
| 3 |
+
import pickle
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from glob import glob
|
| 6 |
+
from typing import Optional, Sequence, List, Union
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
import tree
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pandas as pd
|
| 13 |
+
import torch
|
| 14 |
+
|
| 15 |
+
from src.common import residue_constants, data_transforms, rigid_utils, protein
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
CA_IDX = residue_constants.atom_order['CA']
|
| 19 |
+
DTYPE_MAPPING = {
|
| 20 |
+
'aatype': torch.long,
|
| 21 |
+
'atom_positions': torch.double,
|
| 22 |
+
'atom_mask': torch.double,
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class ProteinFeatureTransform:
    """Callable that turns a raw per-chain feature dict into model inputs.

    Calling an instance runs, in order: `patch_feats`, optionally
    `strip_ends`, optionally `random_truncate`, optionally
    `recenter_and_scale_coords`, then `map_to_tensors` and
    `protein_data_transform`.
    """
    def __init__(self,
        unit: Optional[str] = 'angstrom',
        truncate_length: Optional[int] = None,
        strip_missing_residues: bool = True,
        recenter_and_scale: bool = True,
        eps: float = 1e-8,
    ):
        """Configure the transform.

        Args:
            unit: Target coordinate unit; 'angstrom' keeps coordinates as
                they are (scale 1.0), 'nm' / 'nanometer' multiplies them
                by 0.1.
            truncate_length: If given, chains longer than this are randomly
                cropped to this many residues. Must be positive.
            strip_missing_residues: Whether to run `strip_ends`.
            recenter_and_scale: Whether to run `recenter_and_scale_coords`.
            eps: Small constant added to the denominator when averaging
                CA positions.

        Raises:
            ValueError: If `unit` is not one of the supported names.
        """
        if unit == 'angstrom':
            self.coordinate_scale = 1.0
        elif unit in ('nm', 'nanometer'):
            # Fixed: this branch used to assign a misspelled attribute
            # (`coordiante_scale`), leaving `coordinate_scale` undefined.
            self.coordinate_scale = 0.1
        else:
            raise ValueError(f"Invalid unit: {unit}")

        if truncate_length is not None:
            assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
        self.truncate_length = truncate_length

        self.strip_missing_residues = strip_missing_residues
        self.recenter_and_scale = recenter_and_scale
        self.eps = eps

    def __call__(self, chain_feats):
        """Apply the configured pipeline to `chain_feats` and return the result."""
        chain_feats = self.patch_feats(chain_feats)

        if self.strip_missing_residues:
            chain_feats = self.strip_ends(chain_feats)

        if self.truncate_length is not None:
            chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)

        # Recenter and scale atom positions
        if self.recenter_and_scale:
            chain_feats = self.recenter_and_scale_coords(
                chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)

        # Map to torch Tensor
        chain_feats = self.map_to_tensors(chain_feats)
        # Add extra features from AF2
        chain_feats = self.protein_data_transform(chain_feats)

        # ** refer to line 170 in pdb_data_loader.py **
        return chain_feats

    @staticmethod
    def patch_feats(chain_feats):
        """Add mask / index features derived from the existing ones.

        Adds 'seq_mask' and 'residue_mask' (both the CA column of
        'atom_mask'), 'residue_idx' ('residue_index' shifted so its minimum
        is 0), and zero-filled 'fixed_mask' and 'sc_ca_t'. The input dict is
        updated in place and returned.
        """
        seq_mask = chain_feats['atom_mask'][:, CA_IDX]  # a little hack here
        # residue_idx = np.arange(seq_mask.shape[0], dtype=np.int64)
        # start from 0, possibly has chain break
        residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index'])
        patch_feats = {
            'seq_mask': seq_mask,
            'residue_mask': seq_mask,
            'residue_idx': residue_idx,
            'fixed_mask': np.zeros_like(seq_mask),
            'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
        }
        chain_feats.update(patch_feats)
        return chain_feats

    @staticmethod
    def strip_ends(chain_feats):
        """Crop every feature to the span between the first and last modeled residue.

        A residue counts as modeled when its 'aatype' differs from 20
        (presumably the unknown-residue index; confirm against
        residue_constants).
        """
        # Strip missing residues on both ends
        modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
        min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
        chain_feats = tree.map_structure(
            lambda x: x[min_idx : (max_idx+1)], chain_feats)
        return chain_feats

    @staticmethod
    def random_truncate(chain_feats, max_len):
        """Randomly crop all features to a contiguous window of `max_len` residues.

        Chains with at most `max_len` residues are returned unchanged.
        """
        L = chain_feats['aatype'].shape[0]
        if L > max_len:
            # Randomly truncate
            start = np.random.randint(0, L - max_len + 1)
            end = start + max_len
            chain_feats = tree.map_structure(
                lambda x: x[start : end], chain_feats)
        return chain_feats

    @staticmethod
    def map_to_tensors(chain_feats):
        """Convert every value to a torch tensor and cast the keys listed in DTYPE_MAPPING."""
        chain_feats = {k: torch.as_tensor(v) for k,v in chain_feats.items()}
        # Alter dtype
        for k, dtype in DTYPE_MAPPING.items():
            if k in chain_feats:
                chain_feats[k] = chain_feats[k].type(dtype)
        return chain_feats

    @staticmethod
    def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
        """Center 'atom_positions' on the mean CA position, scale, and re-mask.

        The center is the sum of all CA positions divided by
        (sum of 'seq_mask' + eps).
        NOTE(review): the numerator is not masked, so this assumes the CA
        coordinates of masked residues are already zero. TODO confirm.
        """
        # recenter and scale atom positions
        bb_pos = chain_feats['atom_positions'][:, CA_IDX]
        bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
        centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
        scaled_pos = centered_pos * coordinate_scale
        # Zero out the positions of atoms that are masked.
        chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
        return chain_feats

    @staticmethod
    def protein_data_transform(chain_feats):
        """Run the `data_transforms` feature pipeline on `chain_feats`.

        'atom_positions' / 'atom_mask' are temporarily exposed under the
        'all_atom_positions' / 'all_atom_mask' keys while the transforms
        run; those two alias keys are removed again before returning.
        """
        chain_feats.update(
            {
                "all_atom_positions": chain_feats["atom_positions"],
                "all_atom_mask": chain_feats["atom_mask"],
            }
        )
        chain_feats = data_transforms.atom37_to_frames(chain_feats)
        chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
        chain_feats = data_transforms.get_backbone_frames(chain_feats)
        chain_feats = data_transforms.get_chi_angles(chain_feats)
        chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
        chain_feats = data_transforms.make_atom14_masks(chain_feats)
        chain_feats = data_transforms.make_atom14_positions(chain_feats)

        # Drop the temporary alias keys added above.
        chain_feats.pop("all_atom_positions")
        chain_feats.pop("all_atom_mask")
        return chain_feats
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
class MetadataFilter:
    """Callable that narrows a metadata DataFrame by configured criteria.

    Every criterion left as None is skipped. Extra keyword arguments are
    accepted and ignored.
    """
    def __init__(self,
        min_len: Optional[int] = None,
        max_len: Optional[int] = None,
        min_chains: Optional[int] = None,
        max_chains: Optional[int] = None,
        min_resolution: Optional[int] = None,
        max_resolution: Optional[int] = None,
        include_structure_method: Optional[List[str]] = None,
        include_oligomeric_detail: Optional[List[str]] = None,
        **kwargs,
    ):
        """Store the inclusive bounds and allowed-value lists used by `__call__`."""
        self.min_len, self.max_len = min_len, max_len
        self.min_chains, self.max_chains = min_chains, max_chains
        self.min_resolution, self.max_resolution = min_resolution, max_resolution
        self.include_structure_method = include_structure_method
        self.include_oligomeric_detail = include_oligomeric_detail

    def __call__(self, df):
        """Return the rows of `df` satisfying every configured criterion.

        Range criteria are inclusive and read the 'raw_seq_len',
        'num_chains' and 'resolution' columns; the two `include_*` lists
        are matched against the columns of the same name. A one-line
        summary is printed.
        """
        n_before = len(df)

        # (column, inclusive lower bound, inclusive upper bound)
        range_rules = (
            ('raw_seq_len', self.min_len, self.max_len),
            ('num_chains', self.min_chains, self.max_chains),
            ('resolution', self.min_resolution, self.max_resolution),
        )
        for column, lower, upper in range_rules:
            if lower is not None:
                df = df[df[column] >= lower]
            if upper is not None:
                df = df[df[column] <= upper]

        # (column, allowed values)
        membership_rules = (
            ('include_structure_method', self.include_structure_method),
            ('include_oligomeric_detail', self.include_oligomeric_detail),
        )
        for column, allowed in membership_rules:
            if allowed is not None:
                df = df[df[column].isin(allowed)]

        print(f">>> Filter out {len(df)} samples out of {n_before} by the metadata filter")
        return df
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class RandomAccessProteinDataset(torch.utils.data.Dataset):
    """Random access to pickle protein objects of dataset.

    Each loaded sample is a dict with
    dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])

    Note that each value is a ndarray in shape (L, *), for example:
        'atom_positions': (L, 37, 3)
    """
    def __init__(self,
        path_to_dataset: Union[Path, str],
        path_to_seq_embedding: Optional[Path] = None,
        metadata_filter: Optional[MetadataFilter] = None,
        training: bool = True,
        transform: Optional[ProteinFeatureTransform] = None,
        suffix: Optional[str] = '.pkl',
        accession_code_fillter: Optional[Sequence[str]] = None,
        **kwargs,
    ):
        """Collect the list of sample files.

        Args:
            path_to_dataset: A metadata .csv file (must provide the
                'modeled_seq_len' and 'processed_complex_path' columns),
                a directory containing `*<suffix>` files, or a glob
                pattern. A leading '~' is expanded.
            path_to_seq_embedding: Optional directory with one
                `<accession_code>.pt` embedding file per sample.
            metadata_filter: Optional filter applied to the csv metadata;
                only used when `path_to_dataset` is a csv file.
            training: Stored on the instance; not used yet.
            transform: Optional callable applied to each loaded sample.
            suffix: Sample file type, '.pkl' or '.pdb' (the leading dot is
                optional).
            accession_code_fillter: If non-empty, keep only files whose
                base name (without extension) is in this collection.
            **kwargs: Ignored.
        """
        super().__init__()
        path_to_dataset = os.path.expanduser(path_to_dataset)
        suffix = suffix if suffix.startswith('.') else '.' + suffix
        assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"

        if os.path.isfile(path_to_dataset):  # path to csv file
            assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
            self._df = pd.read_csv(path_to_dataset)
            # Fixed: `sort_values` returns a new frame; the result used to be
            # discarded, so the intended longest-first ordering never applied.
            self._df = self._df.sort_values('modeled_seq_len', ascending=False)
            if metadata_filter:
                self._df = metadata_filter(self._df)
            self._data = self._df['processed_complex_path'].tolist()
        elif os.path.isdir(path_to_dataset):  # path to directory
            self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
            assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
        else:  # path as glob pattern
            _pattern = path_to_dataset
            self._data = sorted(glob(_pattern))
            assert len(self._data) > 0, f"No files found in '{_pattern}'"

        if accession_code_fillter and len(accession_code_fillter) > 0:
            # Build the lookup set once instead of scanning the whole filter
            # list for every path.
            allowed_codes = set(accession_code_fillter)
            self._data = [
                p for p in self._data
                if os.path.splitext(os.path.basename(p))[0] in allowed_codes
            ]

        self.data = np.asarray(self._data)
        self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
            if path_to_seq_embedding is not None else None
        self.suffix = suffix
        self.transform = transform
        self.training = training  # not implemented yet

    @property
    def num_samples(self):
        """Number of sample files in the dataset."""
        return len(self.data)

    def len(self):
        """Alias of `__len__`."""
        return self.__len__()

    def __len__(self):
        return self.num_samples

    def get(self, idx):
        """Alias of `__getitem__`."""
        return self.__getitem__(idx)

    # NOTE(review): `lru_cache` on a method keys on `self` and keeps the
    # instance alive for the cache's lifetime; it also caches the
    # *transformed* sample, so a random transform is not re-drawn for an
    # index that is still cached, and callers share the same dict object.
    @lru_cache(maxsize=100)
    def __getitem__(self, idx):
        """Load the sample at `idx`, apply the transform and return it as a dict.

        The dict additionally holds 'accession_code' (file base name
        without extension) and, when `path_to_seq_embedding` is set,
        'seq_emb'.
        """
        data_path = self.data[idx]
        accession_code = os.path.splitext(os.path.basename(data_path))[0]

        if self.suffix == '.pkl':
            # Load pickled protein.
            # NOTE: unpickling can execute arbitrary code; only point the
            # dataset at trusted files.
            with open(data_path, 'rb') as f:
                data_object = pickle.load(f)
        elif self.suffix == '.pdb':
            # Load pdb file
            with open(data_path, 'r') as f:
                pdb_string = f.read()
            data_object = protein.from_pdb_string(pdb_string).to_dict()

        # Apply data transform
        if self.transform is not None:
            data_object = self.transform(data_object)

        # Get sequence embedding if have
        if self.path_to_seq_embedding is not None:
            embed_dict = torch.load(
                os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
            )
            data_object.update(
                {
                    'seq_emb': embed_dict['representations'][33].float(),
                }  # 33 is for ESM650M
            )

        data_object['accession_code'] = accession_code
        return data_object  # dict of arrays
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
class PretrainPDBDataset(RandomAccessProteinDataset):
    """`RandomAccessProteinDataset` whose metadata filter and transform are required arguments."""
    def __init__(self,
        path_to_dataset: str,
        metadata_filter: MetadataFilter,
        transform: ProteinFeatureTransform,
        **kwargs,
    ):
        # Everything is forwarded unchanged to the base class; any extra
        # keyword arguments are passed through as well.
        super().__init__(
            path_to_dataset=path_to_dataset,
            metadata_filter=metadata_filter,
            transform=transform,
            **kwargs,
        )
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class SamplingPDBDataset(RandomAccessProteinDataset):
    """`RandomAccessProteinDataset` over a directory of structure files, without metadata filtering.

    Defaults to '.pdb' files and `training=False`.
    """
    def __init__(self,
        path_to_dataset: str,
        training: bool = False,
        suffix: str = '.pdb',
        transform: Optional[ProteinFeatureTransform] = None,
        accession_code_fillter: Optional[Sequence[str]] = None,
    ):
        # Expand "~" before validating: the base class expands the path
        # itself, so checking the unexpanded string here used to reject
        # home-relative directories that the base class would accept.
        path_to_dataset = os.path.expanduser(path_to_dataset)
        assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
        super().__init__(
            path_to_dataset=path_to_dataset,
            training=training,
            suffix=suffix,
            transform=transform,
            accession_code_fillter=accession_code_fillter,
            metadata_filter=None,
        )
|
| 321 |
+
|
analysis/src/data/protein_datamodule.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Optional, Tuple, List, Sequence
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
|
| 5 |
+
from lightning import LightningDataModule
|
| 6 |
+
from hydra.utils import instantiate
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class BatchTensorConverter:
    """Collate a list of per-sample dicts into one batch dict.

    Tensor-valued entries are padded and stacked; all other entries are
    gathered into plain Python lists.
    """
    def __init__(self, target_keys: Optional[List] = None):
        # Keys to pad-and-stack; None means "every key whose value in the
        # first sample is a tensor".
        self.target_keys = target_keys

    def __call__(self, raw_batch: Sequence[Dict[str, object]]):
        """Collate `raw_batch`; the key set is taken from its first element."""
        first_sample = raw_batch[0]

        # Only do for Tensor
        if self.target_keys is not None:
            tensor_keys = self.target_keys
        else:
            tensor_keys = [
                name for name, value in first_sample.items() if torch.is_tensor(value)
            ]
        # Non-array, for example string, int
        passthrough_keys = [name for name in first_sample if name not in tensor_keys]

        collated = {
            name: self.collate_dense_tensors(
                [sample[name] for sample in raw_batch], pad_v=0.0
            )
            for name in tensor_keys
        }
        # Non-tensor values are returned as-is, one list entry per sample.
        collated.update(
            (name, [sample[name] for sample in raw_batch]) for name in passthrough_keys
        )
        return collated

    @staticmethod
    def collate_dense_tensors(samples: Sequence, pad_v: float = 0.0):
        """
        Pad a list of same-rank tensors to a common shape and stack them.

        Given N tensors of shapes (d_i1, ..., d_iK), returns a tensor of
        shape (N, max_i d_i1, ..., max_i d_iK) in which sample i occupies
        the leading corner of slot i and the remainder is `pad_v`.
        An empty input yields an empty tensor.

        Raises:
            RuntimeError: If the samples do not all have the same number
                of dimensions.
        """
        if len(samples) == 0:
            return torch.Tensor()

        ranks = [x.dim() for x in samples]
        if len(set(ranks)) != 1:
            raise RuntimeError(
                f"Samples has varying dimensions: {ranks}"
            )

        # assumes all on same device (unpacking fails otherwise)
        (device,) = {x.device for x in samples}
        padded_shape = [max(extents) for extents in zip(*(x.shape for x in samples))]

        batch = torch.empty(
            len(samples), *padded_shape, dtype=samples[0].dtype, device=device
        ).fill_(pad_v)
        for position, sample in enumerate(samples):
            corner = (position,) + tuple(slice(0, extent) for extent in sample.shape)
            batch[corner] = sample
        return batch
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class ProteinDataModule(LightningDataModule):
|
| 61 |
+
"""`LightningDataModule` for a single protein dataset,
|
| 62 |
+
for pretrain or finetune purpose.
|
| 63 |
+
|
| 64 |
+
### To be revised.###
|
| 65 |
+
|
| 66 |
+
The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
|
| 67 |
+
It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
|
| 68 |
+
fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
|
| 69 |
+
while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
|
| 70 |
+
technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of
|
| 71 |
+
mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
|
| 72 |
+
|
| 73 |
+
A `LightningDataModule` implements 7 key methods:
|
| 74 |
+
|
| 75 |
+
```python
|
| 76 |
+
def prepare_data(self):
|
| 77 |
+
# Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
|
| 78 |
+
# Download data, pre-process, split, save to disk, etc...
|
| 79 |
+
|
| 80 |
+
def setup(self, stage):
|
| 81 |
+
# Things to do on every process in DDP.
|
| 82 |
+
# Load data, set variables, etc...
|
| 83 |
+
|
| 84 |
+
def train_dataloader(self):
|
| 85 |
+
# return train dataloader
|
| 86 |
+
|
| 87 |
+
def val_dataloader(self):
|
| 88 |
+
# return validation dataloader
|
| 89 |
+
|
| 90 |
+
def test_dataloader(self):
|
| 91 |
+
# return test dataloader
|
| 92 |
+
|
| 93 |
+
def predict_dataloader(self):
|
| 94 |
+
# return predict dataloader
|
| 95 |
+
|
| 96 |
+
def teardown(self, stage):
|
| 97 |
+
# Called on every process in DDP.
|
| 98 |
+
# Clean up after fit or test.
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
This allows you to share a full dataset without explaining how to download,
|
| 102 |
+
split, transform and process the data.
|
| 103 |
+
|
| 104 |
+
Read the docs:
|
| 105 |
+
https://lightning.ai/docs/pytorch/latest/data/datamodule.html
|
| 106 |
+
"""
|
| 107 |
+
|
| 108 |
+
def __init__(
|
| 109 |
+
self,
|
| 110 |
+
dataset: torch.utils.data.Dataset,
|
| 111 |
+
batch_size: int = 64,
|
| 112 |
+
generator_seed: int = 42,
|
| 113 |
+
train_val_split: Tuple[float, float] = (0.95, 0.05),
|
| 114 |
+
num_workers: int = 0,
|
| 115 |
+
pin_memory: bool = False,
|
| 116 |
+
shuffle: bool = False,
|
| 117 |
+
) -> None:
|
| 118 |
+
"""Initialize a `MNISTDataModule`.
|
| 119 |
+
|
| 120 |
+
:param data_dir: The data directory. Defaults to `"data/"`.
|
| 121 |
+
:param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`.
|
| 122 |
+
:param batch_size: The batch size. Defaults to `64`.
|
| 123 |
+
:param num_workers: The number of workers. Defaults to `0`.
|
| 124 |
+
:param pin_memory: Whether to pin memory. Defaults to `False`.
|
| 125 |
+
"""
|
| 126 |
+
super().__init__()
|
| 127 |
+
|
| 128 |
+
# this line allows to access init params with 'self.hparams' attribute
|
| 129 |
+
# also ensures init params will be stored in ckpt
|
| 130 |
+
self.save_hyperparameters(logger=False)
|
| 131 |
+
|
| 132 |
+
self.dataset = dataset
|
| 133 |
+
|
| 134 |
+
self.data_train: Optional[Dataset] = None
|
| 135 |
+
self.data_val: Optional[Dataset] = None
|
| 136 |
+
self.data_test: Optional[Dataset] = None
|
| 137 |
+
|
| 138 |
+
self.batch_size_per_device = batch_size
|
| 139 |
+
|
| 140 |
+
def prepare_data(self) -> None:
|
| 141 |
+
"""Download data if needed. Lightning ensures that `self.prepare_data()` is called only
|
| 142 |
+
within a single process on CPU, so you can safely add your downloading logic within. In
|
| 143 |
+
case of multi-node training, the execution of this hook depends upon
|
| 144 |
+
`self.prepare_data_per_node()`.
|
| 145 |
+
|
| 146 |
+
Do not use it to assign state (self.x = y).
|
| 147 |
+
"""
|
| 148 |
+
pass
|
| 149 |
+
|
| 150 |
+
def setup(self, stage: Optional[str] = None) -> None:
|
| 151 |
+
"""Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
|
| 152 |
+
|
| 153 |
+
This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
|
| 154 |
+
`trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
|
| 155 |
+
`self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
|
| 156 |
+
`self.setup()` once the data is prepared and available for use.
|
| 157 |
+
|
| 158 |
+
:param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
|
| 159 |
+
"""
|
| 160 |
+
# Divide batch size by the number of devices.
|
| 161 |
+
if self.trainer is not None:
|
| 162 |
+
if self.hparams.batch_size % self.trainer.world_size != 0:
|
| 163 |
+
raise RuntimeError(
|
| 164 |
+
f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
|
| 165 |
+
)
|
| 166 |
+
self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
|
| 167 |
+
|
| 168 |
+
# load and split datasets only if not loaded already
|
| 169 |
+
if stage == 'fit' and not self.data_train and not self.data_val:
|
| 170 |
+
# dataset = ConcatDataset(datasets=[trainset, testset])
|
| 171 |
+
self.data_train, self.data_val = random_split(
|
| 172 |
+
dataset=self.dataset,
|
| 173 |
+
lengths=self.hparams.train_val_split,
|
| 174 |
+
generator=torch.Generator().manual_seed(self.hparams.generator_seed),
|
| 175 |
+
)
|
| 176 |
+
elif stage in ('predict', 'test'):
|
| 177 |
+
self.data_test = self.dataset
|
| 178 |
+
else:
|
| 179 |
+
raise NotImplementedError(f"Stage {stage} not implemented.")
|
| 180 |
+
|
| 181 |
+
def _dataloader_template(self, dataset: Dataset[Any]) -> DataLoader[Any]:
    """Wrap a dataset in a `DataLoader` configured from this datamodule's settings.

    :param dataset: The dataset to load from.
    :return: A dataloader using the per-device batch size and the stored `num_workers`,
        `pin_memory` and `shuffle` hyperparameters.
    """
    loader_kwargs = dict(
        dataset=dataset,
        # Collator described in the original code as: list of dicts -> dict of tensors.
        collate_fn=BatchTensorConverter(),
        batch_size=self.batch_size_per_device,
        num_workers=self.hparams.num_workers,
        pin_memory=self.hparams.pin_memory,
        shuffle=self.hparams.shuffle,
    )
    return DataLoader(**loader_kwargs)
|
| 196 |
+
|
| 197 |
+
def train_dataloader(self) -> DataLoader[Any]:
    """Build the dataloader over the training split.

    :return: The dataloader that `_dataloader_template` creates for `self.data_train`.
    """
    train_split = self.data_train
    return self._dataloader_template(train_split)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def val_dataloader(self) -> DataLoader[Any]:
    """Build the dataloader over the validation split.

    :return: The dataloader that `_dataloader_template` creates for `self.data_val`.
    """
    val_split = self.data_val
    return self._dataloader_template(val_split)
|
| 211 |
+
|
| 212 |
+
def test_dataloader(self) -> DataLoader[Any]:
|
| 213 |
+
"""Create and return the test dataloader.
|
| 214 |
+
|
| 215 |
+
:return: The test dataloader.
|
| 216 |
+
"""
|
| 217 |
+
return self._dataloader_template(self.data_test)
|
| 218 |
+
|
| 219 |
+
def teardown(self, stage: Optional[str] = None) -> None:
|
| 220 |
+
"""Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
|
| 221 |
+
`trainer.test()`, and `trainer.predict()`.
|
| 222 |
+
|
| 223 |
+
:param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
|
| 224 |
+
Defaults to ``None``.
|
| 225 |
+
"""
|
| 226 |
+
pass
|
| 227 |
+
|
| 228 |
+
def state_dict(self) -> Dict[Any, Any]:
    """Produce the datamodule state to store in a checkpoint.

    :return: An empty dictionary -- this datamodule saves no extra state.
    """
    saved_state: Dict[Any, Any] = dict()
    return saved_state
|
| 234 |
+
|
| 235 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    """Restore datamodule state when a checkpoint is loaded; a no-op because
    `state_dict()` saves nothing.

    :param state_dict: The datamodule state returned by `self.state_dict()`.
    """
    return None
|
| 242 |
+
|
analysis/src/eval.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Tuple
|
| 2 |
+
import os
|
| 3 |
+
from time import strftime
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import torch
|
| 8 |
+
# import hydra
|
| 9 |
+
# import rootutils
|
| 10 |
+
# from lightning import LightningDataModule, LightningModule, Trainer
|
| 11 |
+
# from lightning.pytorch.loggers import Logger
|
| 12 |
+
from omegaconf import DictConfig
|
| 13 |
+
|
| 14 |
+
# rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 15 |
+
|
| 16 |
+
# ------------------------------------------------------------------------------------ #
|
| 17 |
+
# the setup_root above is equivalent to:
|
| 18 |
+
# - adding project root dir to PYTHONPATH
|
| 19 |
+
# (so you don't need to force user to install project as a package)
|
| 20 |
+
# (necessary before importing any local modules e.g. `from src import utils`)
|
| 21 |
+
# - setting up PROJECT_ROOT environment variable
|
| 22 |
+
# (which is used as a base for paths in "configs/paths/default.yaml")
|
| 23 |
+
# (this way all filepaths are the same no matter where you run the code)
|
| 24 |
+
# - loading environment variables from ".env" in root dir
|
| 25 |
+
#
|
| 26 |
+
# you can remove it if you:
|
| 27 |
+
# 1. either install project as a package or move entry files to project root dir
|
| 28 |
+
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
| 29 |
+
#
|
| 30 |
+
# more info: https://github.com/ashleve/rootutils
|
| 31 |
+
# ------------------------------------------------------------------------------------ #
|
| 32 |
+
|
| 33 |
+
from src.utils import (
|
| 34 |
+
RankedLogger,
|
| 35 |
+
extras,
|
| 36 |
+
instantiate_loggers,
|
| 37 |
+
log_hyperparameters,
|
| 38 |
+
task_wrapper,
|
| 39 |
+
checkpoint_utils,
|
| 40 |
+
plot_utils,
|
| 41 |
+
)
|
| 42 |
+
from src.common.pdb_utils import extract_backbone_coords
|
| 43 |
+
from src.metrics import metrics
|
| 44 |
+
from src.common.geo_utils import _find_rigid_alignment
|
| 45 |
+
|
| 46 |
+
# Module-level logger used by the evaluation helpers below.
# NOTE(review): `rank_zero_only=True` presumably restricts output to the rank-0 process -- confirm in `src.utils`.
log = RankedLogger(__name__, rank_zero_only=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def evaluate_prediction(pred_dir: str, target_dir: str = None, crystal_dir: str = None, tag: str = None):
    """Evaluate prediction results based on pdb files.

    For every `<name>.pdb` in `target_dir` that also exists in `pred_dir`, each metric in the
    table below is computed and the `'pred'` entry of its result is recorded. The per-target
    table is written to `metrics_<tag>_<timestamp>.csv` inside `pred_dir`.

    :param pred_dir: Directory with predicted pdb files; also receives the csv output.
    :param target_dir: Directory with reference pdb files. If it is ``None`` or not a directory,
        evaluation is skipped.
    :param crystal_dir: Directory with crystal-structure pdb files, needed only by the `pro_*`
        metrics. If it is ``None``, not a directory, or lacks the file for a target, those
        metrics are skipped (with a warning) instead of crashing.
    :param tag: Label embedded in the csv file name. Defaults to `"dev"`.
    :return: An empty dict when evaluation is skipped, otherwise the per-metric mean over all
        evaluated targets, rounded to 4 decimals.
    """
    if target_dir is None or not os.path.isdir(target_dir):
        log.warning(f"target_dir {target_dir} does not exist. Skip evaluation.")
        return {}

    assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory."

    # Only real `.pdb` entries are targets; strip the extension (not any ".pdb" substring) and
    # sort so the row order of the csv does not depend on directory listing order.
    targets = sorted(
        os.path.splitext(entry)[0]
        for entry in os.listdir(target_dir)
        if entry.endswith(".pdb")
    )
    output_dir = pred_dir
    tag = tag if tag is not None else "dev"
    timestamp = strftime("%m%d-%H-%M")

    fns = {
        'val_clash': metrics.validity,
        'val_bond': metrics.bonding_validity,
        'js_pwd': metrics.js_pwd,
        'js_rg': metrics.js_rg,
        # 'js_tica_pos': metrics.js_tica_pos,
        'w2_rmwd': metrics.w2_rmwd,
        # 'div_rmsd': metrics.div_rmsd,
        'div_rmsf': metrics.div_rmsf,
        'pro_w_contacks': metrics.pro_w_contacts,
        'pro_t_contacks': metrics.pro_t_contacts,
        # 'pro_c_contacks': metrics.pro_c_contacts,
    }
    eval_res = {k: {} for k in fns}

    has_crystal_dir = crystal_dir is not None and os.path.isdir(crystal_dir)
    if not has_crystal_dir:
        log.warning(f"crystal_dir {crystal_dir} does not exist. Skip 'pro_*' metrics.")

    print(f"total_md_num = {len(targets)}")
    for count, target in enumerate(targets, start=1):
        print("")
        print(count, target)
        pred_file = os.path.join(pred_dir, f"{target}.pdb")
        if not os.path.isfile(pred_file):
            # No prediction for this target: leave it out of the table.
            continue

        target_file = os.path.join(target_dir, f"{target}.pdb")
        ca_coords = {
            'target': extract_backbone_coords(target_file),
            'pred': extract_backbone_coords(pred_file),
        }

        # The crystal structure is optional; without it the `pro_*` metrics cannot run.
        cry_target_file = None
        cry_ca_coords = None
        if has_crystal_dir:
            candidate = os.path.join(crystal_dir, f"{target}.pdb")
            if os.path.isfile(candidate):
                cry_target_file = candidate
                # Only the first entry of the extracted coordinates is used as the reference.
                cry_ca_coords = extract_backbone_coords(cry_target_file)[0]
            else:
                log.warning(f"crystal file {candidate} does not exist. Skip 'pro_*' metrics for {target}.")

        for f_name, func in fns.items():
            if f_name.startswith('pro_') and cry_ca_coords is None:
                continue
            print(f_name)

            if f_name == 'w2_rmwd':
                # Rigidly align every frame of both ensembles onto the first target frame.
                # The aligned arrays replace the entries of `ca_coords`, so metrics evaluated
                # after this one see the aligned coordinates as well.
                v_ref = torch.as_tensor(ca_coords['target'][0])
                for k, v in ca_coords.items():
                    v = torch.as_tensor(v)  # (250,356,3) per the original note -- TODO confirm
                    for idx in range(v.shape[0]):
                        R, t = _find_rigid_alignment(v[idx], v_ref)
                        v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
                    ca_coords[k] = v.numpy()

            if f_name.startswith('js_'):
                res = func(ca_coords, ref_key='target')
            elif f_name == 'pro_c_contacks':
                res = func(target_file, pred_file, cry_target_file)
            elif f_name.startswith('pro_'):
                res = func(ca_coords, cry_ca_coords)
            else:
                res = func(ca_coords)

            if f_name == 'js_tica' or f_name == 'js_tica_pos':
                # TICA results are currently computed but neither recorded nor plotted.
                continue
            eval_res[f_name][target] = res['pred']

    csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv")
    df = pd.DataFrame.from_dict(eval_res)  # row = target, col = metric name
    df.to_csv(csv_save_to)
    print(f"metrics saved to {csv_save_to}")
    mean_metrics = np.around(df.mean(), decimals=4)

    return mean_metrics
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# @task_wrapper
|
| 143 |
+
# def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 144 |
+
# """Sample on a test set and report evaluation metrics.
|
| 145 |
+
|
| 146 |
+
# This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
|
| 147 |
+
# failure. Useful for multiruns, saving info about the crash, etc.
|
| 148 |
+
|
| 149 |
+
# :param cfg: DictConfig configuration composed by Hydra.
|
| 150 |
+
# :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
|
| 151 |
+
# """
|
| 152 |
+
# # assert cfg.ckpt_path
|
| 153 |
+
# pred_dir = cfg.get("pred_dir")
|
| 154 |
+
# if pred_dir and os.path.isdir(pred_dir):
|
| 155 |
+
# log.info(f"Found pre-computed prediction directory {pred_dir}.")
|
| 156 |
+
# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
|
| 157 |
+
# return metric_dict, None
|
| 158 |
+
|
| 159 |
+
# log.info(f"Instantiating datamodule <{cfg.data._target_}>")
|
| 160 |
+
# datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
|
| 161 |
+
|
| 162 |
+
# log.info(f"Instantiating model <{cfg.model._target_}>")
|
| 163 |
+
# model: LightningModule = hydra.utils.instantiate(cfg.model)
|
| 164 |
+
|
| 165 |
+
# log.info("Instantiating loggers...")
|
| 166 |
+
# logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
|
| 167 |
+
|
| 168 |
+
# log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
|
| 169 |
+
# trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
|
| 170 |
+
|
| 171 |
+
# object_dict = {
|
| 172 |
+
# "cfg": cfg,
|
| 173 |
+
# "datamodule": datamodule,
|
| 174 |
+
# "model": model,
|
| 175 |
+
# "logger": logger,
|
| 176 |
+
# "trainer": trainer,
|
| 177 |
+
# }
|
| 178 |
+
|
| 179 |
+
# if logger:
|
| 180 |
+
# log.info("Logging hyperparameters!")
|
| 181 |
+
# log_hyperparameters(object_dict)
|
| 182 |
+
|
| 183 |
+
# # Load checkpoint manually.
|
| 184 |
+
# model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path)
|
| 185 |
+
|
| 186 |
+
# # log.info("Starting testing!")
|
| 187 |
+
# # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
|
| 188 |
+
|
| 189 |
+
# # Get dataloader for prediction.
|
| 190 |
+
# datamodule.setup(stage="predict")
|
| 191 |
+
# dataloaders = datamodule.test_dataloader()
|
| 192 |
+
|
| 193 |
+
# log.info("Starting predictions.")
|
| 194 |
+
# pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1]
|
| 195 |
+
|
| 196 |
+
# # metric_dict = trainer.callback_metrics
|
| 197 |
+
# log.info("Starting evaluations.")
|
| 198 |
+
# metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
|
| 199 |
+
|
| 200 |
+
# return metric_dict, object_dict
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
|
| 204 |
+
# def main(cfg: DictConfig) -> None:
|
| 205 |
+
# """Main entry point for evaluation.
|
| 206 |
+
|
| 207 |
+
# :param cfg: DictConfig configuration composed by Hydra.
|
| 208 |
+
# """
|
| 209 |
+
# # apply extra utilities
|
| 210 |
+
# # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
|
| 211 |
+
# extras(cfg)
|
| 212 |
+
|
| 213 |
+
# evaluate(cfg)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# if __name__ == "__main__":
|
| 217 |
+
# main()
|