Holmes committed on
Commit ca7299e · 1 Parent(s): 1af230e
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. README.md +134 -3
  2. analysis/Ramachandran_plot.py +99 -0
  3. analysis/__pycache__/Ramachandran_plot.cpython-310.pyc +0 -0
  4. analysis/__pycache__/merge_pred_pdb.cpython-310.pyc +0 -0
  5. analysis/__pycache__/metrics.cpython-310.pyc +0 -0
  6. analysis/__pycache__/utils.cpython-310.pyc +0 -0
  7. analysis/__pycache__/utils.cpython-38.pyc +0 -0
  8. analysis/eval_result.py +66 -0
  9. analysis/merge_pred_pdb.py +45 -0
  10. analysis/metrics.py +54 -0
  11. analysis/pca_analyse.py +116 -0
  12. analysis/src/__init__.py +0 -0
  13. analysis/src/__pycache__/__init__.cpython-310.pyc +0 -0
  14. analysis/src/__pycache__/__init__.cpython-37.pyc +0 -0
  15. analysis/src/__pycache__/__init__.cpython-39.pyc +0 -0
  16. analysis/src/__pycache__/eval.cpython-310.pyc +0 -0
  17. analysis/src/__pycache__/eval.cpython-37.pyc +0 -0
  18. analysis/src/__pycache__/eval.cpython-39.pyc +0 -0
  19. analysis/src/common/__init__.py +0 -0
  20. analysis/src/common/__pycache__/__init__.cpython-310.pyc +0 -0
  21. analysis/src/common/__pycache__/__init__.cpython-39.pyc +0 -0
  22. analysis/src/common/__pycache__/all_atom.cpython-39.pyc +0 -0
  23. analysis/src/common/__pycache__/data_transforms.cpython-39.pyc +0 -0
  24. analysis/src/common/__pycache__/geo_utils.cpython-310.pyc +0 -0
  25. analysis/src/common/__pycache__/geo_utils.cpython-39.pyc +0 -0
  26. analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc +0 -0
  27. analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc +0 -0
  28. analysis/src/common/__pycache__/protein.cpython-310.pyc +0 -0
  29. analysis/src/common/__pycache__/protein.cpython-39.pyc +0 -0
  30. analysis/src/common/__pycache__/residue_constants.cpython-310.pyc +0 -0
  31. analysis/src/common/__pycache__/residue_constants.cpython-39.pyc +0 -0
  32. analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc +0 -0
  33. analysis/src/common/__pycache__/rotation3d.cpython-39.pyc +0 -0
  34. analysis/src/common/all_atom.py +219 -0
  35. analysis/src/common/data_transforms.py +1194 -0
  36. analysis/src/common/geo_utils.py +155 -0
  37. analysis/src/common/pdb_utils.py +353 -0
  38. analysis/src/common/protein.py +289 -0
  39. analysis/src/common/residue_constants.py +897 -0
  40. analysis/src/common/rigid_utils.py +1451 -0
  41. analysis/src/common/rotation3d.py +596 -0
  42. analysis/src/data/__init__.py +0 -0
  43. analysis/src/data/__pycache__/__init__.cpython-39.pyc +0 -0
  44. analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc +0 -0
  45. analysis/src/data/components/__init__.py +0 -0
  46. analysis/src/data/components/__pycache__/__init__.cpython-39.pyc +0 -0
  47. analysis/src/data/components/__pycache__/dataset.cpython-39.pyc +0 -0
  48. analysis/src/data/components/dataset.py +321 -0
  49. analysis/src/data/protein_datamodule.py +242 -0
  50. analysis/src/eval.py +217 -0
README.md CHANGED
@@ -1,3 +1,134 @@
- ---
- license: gpl-3.0
- ---
# P2DFlow

> ## ℹ️ Version 2 of the P2DFlow code will be released soon to align with the revised paper and to make the code easier to use and modify (the current version runs correctly).

P2DFlow is a protein ensemble generative model with SE(3) flow matching based on ESMFold. The ensembles generated by P2DFlow can aid in understanding protein function across various scenarios.

Technical details and evaluation results are provided in our paper:
* [P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching](https://arxiv.org/abs/2411.17196)

<p align="center">
<img src="resources/workflow.jpg" width="600"/>
</p>

![P2DFlow](resources/gen_example.gif)


## Table of Contents
1. [Installation](#Installation)
2. [Prepare Dataset](#Prepare-Dataset)
3. [Model weights](#Model-weights)
4. [Training](#Training)
5. [Inference](#Inference)
6. [Evaluation](#Evaluation)
7. [License](#License)
8. [Citation](#Citation)


## Installation
In an environment with CUDA 11.7, run:
```
conda env create -f environment.yml
```
To activate the environment, run:
```
conda activate P2DFlow
```

## Prepare Dataset
#### (Tip: if you want to use the data we have preprocessed, go directly to `3. Process selected dataset`; if you prefer to process the data from scratch or to work with your own data, start from the beginning.)

#### 1. Download raw ATLAS dataset
(i) Download the `Analysis & MDs` dataset from [ATLAS](https://www.dsimb.inserm.fr/ATLAS/), or use `./dataset/download.py` by running:
```
python ./dataset/download.py
```
We will use the `.pdb` and `.xtc` files for the following calculations; they can be loaded as sketched below.
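As a quick sanity check, a downloaded topology/trajectory pair can be inspected with MDAnalysis, the package used by the analysis scripts in this commit; the file names below are placeholders only:
```
import MDAnalysis as mda

# Placeholder file names; substitute the .pdb/.xtc pair you downloaded from ATLAS.
u = mda.Universe("protein.pdb", "protein_R1.xtc")
protein = u.select_atoms("protein")
print(f"{len(protein.residues)} residues, {len(u.trajectory)} frames")
```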

#### 2. Calculate the 'approximate energy' and select representative structures
(i) Use `gaussian_kde` to calculate the 'approximate energy' (you need to put all of the files above in `./dataset`, including `ATLAS_filename.txt`, which lists the filenames of all proteins):
```
python ./dataset/traj_analyse.py
```
You will get `traj_info.csv`.

(ii) Select representative structures at equal intervals based on the 'approximate energy' (see the sketch after this step):
```
python ./dataset/md_select.py
```
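The authoritative implementations are `./dataset/traj_analyse.py` and `./dataset/md_select.py`; the snippet below is only a minimal illustration of the idea, assuming a single per-frame descriptor (the real scripts use their own featurization):
```
import numpy as np
from scipy.stats import gaussian_kde

rng = np.random.default_rng(0)
feature = rng.normal(size=1000)          # stand-in for a per-frame descriptor

# 'Approximate energy' as the negative log of the KDE density estimate.
kde = gaussian_kde(feature)
energy = -np.log(kde(feature))

# Pick representative frames at (roughly) equal energy intervals.
n_select = 10
targets = np.linspace(energy.min(), energy.max(), n_select)
selected = [int(np.argmin(np.abs(energy - t))) for t in targets]
print(sorted(set(selected)))
```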

#### 3. Process selected dataset

(i) Download the selected dataset (or produce it with the two steps above) from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing); the filename is `selected_dataset.tar`. Decompress it with:
```
tar -xvf selected_dataset.tar
```
(ii) Preprocess the `.pdb` files to get `.pkl` files:
```
python ./data/process_pdb_files.py --pdb_dir ${pdb_dir} --write_dir ${write_dir}
```
You will get `metadata.csv`.

Then compute the node and pair representations using ESM-2 (`csv_path` is the path to `metadata.csv`):
```
python ./data/cal_repr.py --csv_path ${csv_path}
```
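The script `./data/cal_repr.py` is the authoritative implementation. As a rough sketch of what per-residue (node) embeddings from ESM-2 look like via the `fair-esm` package: the checkpoint chosen here is an arbitrary small one, and the pair representations used for training are produced by the script itself and are not reproduced in this sketch.
```
import torch
import esm

# Example only: a small ESM-2 checkpoint; cal_repr.py may use a different one.
model, alphabet = esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()

data = [("example", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
_, _, tokens = batch_converter(data)
with torch.no_grad():
    out = model(tokens, repr_layers=[12])
node_repr = out["representations"][12][0, 1:-1]  # per-residue embeddings (BOS/EOS removed)
print(node_repr.shape)
```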
Then compute the predicted static structure using ESMFold (`csv_path` is the path to `metadata.csv`):
```
python ./data/cal_static_structure.py --csv_path ${csv_path}
```
(iii) Provide the necessary `.csv` files for training

If you are using the data we have preprocessed, download the `.csv` files from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing); their filenames are `train_dataset.csv` and `train_dataset_energy.csv` (they correspond to `csv_path` and `energy_csv_path` in `./configs/base.yaml` during training).

Or, if you are using your own data, use `metadata.csv` from step 3 (it corresponds to `csv_path` in `./configs/base.yaml` during training; you need to split a training subset from it, e.g. as sketched below) and `traj_info.csv` from step 2 (it corresponds to `energy_csv_path`).
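A minimal sketch of splitting a training subset out of `metadata.csv`; the split ratio and output filename are illustrative, and no particular column layout is assumed:
```
import pandas as pd

metadata = pd.read_csv("metadata.csv")
train = metadata.sample(frac=0.9, random_state=0)  # illustrative 90% training split
train.to_csv("train_dataset.csv", index=False)
```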



## Model weights
Download the pretrained checkpoint from [Google Drive](https://drive.google.com/drive/folders/11mdVfMi2rpVn7nNG2mQAGA5sNXCKePZj?usp=sharing); the filename is `pretrained.ckpt`. Put it into the `./weights` folder. You can use the pretrained weights for inference.


## Training
To train P2DFlow, first make sure you have prepared the dataset according to `Prepare Dataset` and placed it in the right folder, then modify `./configs/base.yaml` (especially `csv_path` and `energy_csv_path`). After this, run:
```
python experiments/train_se3_flows.py
```
The checkpoints will be written to `./ckpt`.


## Inference
To run inference for a specified protein sequence, first modify `./inference/valid_seq.csv` and `./configs/inference.yaml` (especially `validset_path`), then run:
```
python experiments/inference_se3_flows.py
```
The results will be written to `./inference_outputs/weights/`. (A sketch of how `valid_seq.csv` can be assembled follows below.)
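A minimal sketch of building `./inference/valid_seq.csv` with pandas. The `file` column is the name read by `analysis/merge_pred_pdb.py`; the other column name and the example values are purely illustrative:
```
import pandas as pd

valid_seq = pd.DataFrame({
    "file": ["example_protein"],                        # used to name merged outputs
    "seq": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],        # illustrative column name
})
valid_seq.to_csv("./inference/valid_seq.csv", index=False)
```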


## Evaluation
To evaluate metrics related to fidelity and dynamics, specify the paths in `./analysis/eval_test.py`, then run:
```
python ./analysis/eval_test.py
```
To run the PCA analysis, specify the paths in `./analysis/pca_analyse.py`, then run:
```
python ./analysis/pca_analyse.py
```
To draw the Ramachandran plots, specify the paths in `./analysis/Ramachandran_plot.py`, then run:
```
python ./analysis/Ramachandran_plot.py
```

## License
This project is licensed under the terms of the GPL-3.0 license.


## Citation
```
@article{jin2024p2dflow,
  title={P2DFlow: A Protein Ensemble Generative Model with SE(3) Flow Matching},
  author={Yaowei Jin and Qi Huang and Ziyang Song and Mingyue Zheng and Dan Teng and Qian Shi},
  journal={arXiv preprint arXiv:2411.17196},
  year={2024}
}
```
analysis/Ramachandran_plot.py ADDED
@@ -0,0 +1,99 @@
import MDAnalysis as mda
import numpy as np
from MDAnalysis.analysis.dihedrals import Ramachandran
import os
import re
import pandas as pd
import matplotlib.pyplot as plt


def ramachandran_eval(all_paths, pdb_file, output_dir):
    angle_results_all = []

    for dirpath in all_paths:
        pdb_path = os.path.join(dirpath, pdb_file)

        u = mda.Universe(pdb_path)
        protein = u.select_atoms('protein')
        # print('There are {} residues in the protein'.format(len(protein.residues)))

        ramachandran = Ramachandran(protein)
        ramachandran.run()
        angle_results = ramachandran.results.angles
        # print(angle_results.shape)

        # ramachandran.plot(color='black', marker='.')

        angle_results_all.append(angle_results.reshape([-1, 2]))

        # df = pd.DataFrame(angle_results.reshape([-1, 2]))
        # df.to_csv(os.path.join(output_dir, os.path.basename(dirpath)+'_'+pdb_file.split('.')[0]+'.csv'), index=False)

    points1 = angle_results_all[0]
    grid_size = 360  # size of the (phi, psi) grid
    x_bins = np.linspace(-180, 180, grid_size)
    y_bins = np.linspace(-180, 180, grid_size)
    result_tmp = {}
    for idx in range(len(angle_results_all[1:])):
        idx = idx + 1
        points2 = angle_results_all[idx]

        # Use 2D histograms to count how each set of points falls on the grid.
        hist1, _, _ = np.histogram2d(points1[:, 0], points1[:, 1], bins=[x_bins, y_bins])
        hist2, _, _ = np.histogram2d(points2[:, 0], points2[:, 1], bins=[x_bins, y_bins])

        # Convert the histograms to booleans marking whether any point falls into a cell.
        mask1 = hist1 > 0
        mask2 = hist2 > 0

        intersection = np.logical_and(mask1, mask2).sum()
        all_mask2 = mask2.sum()
        val_ratio = intersection / all_mask2
        print(os.path.basename(all_paths[idx]), "val_ratio:", val_ratio)

        result_tmp[os.path.basename(all_paths[idx])] = val_ratio
    result_tmp['file'] = pdb_file

    return result_tmp


if __name__ == "__main__":
    key1 = 'P2DFlow_epoch19'
    all_paths = [
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/ATLAS_valid",
        # "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/esm_n_pred",
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/alphaflow_pred",
        "/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Str2Str_pred",
        f'/cluster/home/shiqian/frame-flow-test1/valid/evaluate/{key1}',
    ]
    output_dir = '/cluster/home/shiqian/frame-flow-test1/valid/evaluate/Ramachandran'
    os.makedirs(output_dir, exist_ok=True)
    results = {
        'file': [],
        # 'esm_n_pred': [],
        'alphaflow_pred': [],
        'Str2Str_pred': [],
        key1: [],
    }
    for file in os.listdir(all_paths[0]):
        if re.search(r'\.pdb', file):
            pdb_file = file
            print(file)
            result_tmp = ramachandran_eval(
                all_paths=all_paths,
                pdb_file=pdb_file,
                output_dir=output_dir
            )
            for key in results.keys():
                results[key].append(result_tmp[key])

    out_total_df = pd.DataFrame(results)
    out_total_df.to_csv(os.path.join(output_dir, f'Ramachandran_plot_validity_{key1}.csv'), index=False)
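A toy illustration of the grid-overlap ratio computed in `ramachandran_eval` above: the reference (phi, psi) set occupies some cells of the grid, and `val_ratio` is the fraction of occupied prediction cells that are also occupied by the reference. The random angles below are synthetic, not real dihedrals.
```
import numpy as np

rng = np.random.default_rng(0)
reference = rng.uniform(-180, 180, size=(5000, 2))   # toy reference angles
prediction = rng.uniform(-180, 180, size=(500, 2))   # toy predicted angles

bins = np.linspace(-180, 180, 360)
hist_ref, _, _ = np.histogram2d(reference[:, 0], reference[:, 1], bins=[bins, bins])
hist_pred, _, _ = np.histogram2d(prediction[:, 0], prediction[:, 1], bins=[bins, bins])

mask_ref, mask_pred = hist_ref > 0, hist_pred > 0
val_ratio = np.logical_and(mask_ref, mask_pred).sum() / mask_pred.sum()
print(f"val_ratio = {val_ratio:.3f}")
```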
analysis/__pycache__/Ramachandran_plot.cpython-310.pyc ADDED
Binary file (2.25 kB). View file
 
analysis/__pycache__/merge_pred_pdb.cpython-310.pyc ADDED
Binary file (1.28 kB). View file
 
analysis/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (2.07 kB). View file
 
analysis/__pycache__/utils.cpython-310.pyc ADDED
Binary file (2.2 kB). View file
 
analysis/__pycache__/utils.cpython-38.pyc ADDED
Binary file (2.18 kB). View file
 
analysis/eval_result.py ADDED
@@ -0,0 +1,66 @@
import os
import re
import sys
sys.path.append('./analysis')
import argparse

import pandas as pd
from src.eval import evaluate_prediction
from merge_pred_pdb import merge_pdb_full
from Ramachandran_plot import ramachandran_eval

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument("--pred_org_dir", type=str, default="./inference_outputs/weights/pretrained/2025-03-13_10-08")
    parser.add_argument("--valid_csv_file", type=str, default="./inference/valid_seq.csv")
    parser.add_argument("--pred_merge_dir", type=str, default="./inference/test/pred_merge_results")
    parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
    parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")

    args = parser.parse_args()

    # merge pdb
    pred_org_dir = args.pred_org_dir
    valid_csv_file = args.valid_csv_file
    pred_merge_dir = args.pred_merge_dir
    merge_pdb_full(pred_org_dir, valid_csv_file, pred_merge_dir)

    # cal_eval
    pred_merge_dir = args.pred_merge_dir
    target_dir = args.target_dir
    crystal_dir = args.crystal_dir
    evaluate_prediction(pred_merge_dir, target_dir, crystal_dir)

    # cal_RP
    all_paths = [
        args.target_dir,
        args.pred_merge_dir,
    ]
    results = {}
    for file in os.listdir(all_paths[0]):
        if re.search(r'\.pdb', file):
            pdb_file = file
            print(file)
            result_tmp = ramachandran_eval(
                all_paths=all_paths,
                pdb_file=pdb_file,
                output_dir=args.pred_merge_dir,
            )

            for pred_paths in all_paths[1:]:
                key_name = os.path.basename(pred_paths)
                if key_name in results.keys():
                    results[key_name].append(result_tmp[key_name])
                else:
                    results[key_name] = [result_tmp[key_name]]

    out_total_df = pd.DataFrame(results)
    out_total_df.to_csv(os.path.join(args.pred_merge_dir, 'Ramachandran_plot_validity.csv'), index=False)
    print(f"RP results saved to {os.path.join(args.pred_merge_dir, 'Ramachandran_plot_validity.csv')}")
analysis/merge_pred_pdb.py ADDED
@@ -0,0 +1,45 @@
import os
import re
import pandas as pd
from Bio.PDB import PDBParser, PDBIO

def merge_pdb(work_dir, new_file, ref_pdb):
    parser = PDBParser()
    structures = []
    for pdb_dir in os.listdir(work_dir):
        pattern = ".*" + ref_pdb
        pdb_dir_full = os.path.join(work_dir, pdb_dir)
        if os.path.isdir(pdb_dir_full) and re.match(pattern, pdb_dir):
            for pdb_file in os.listdir(pdb_dir_full):
                if re.match(r"sample.*\.pdb", pdb_file):
                    structure = parser.get_structure(pdb_file, os.path.join(work_dir, pdb_dir, pdb_file))
                    structures.append(structure)

    if len(structures) == 0:
        return
    print(ref_pdb, len(structures), "files")

    new_structure = structures[0]
    count = 0
    for structure in structures[1:]:
        for model in structure:
            count += 1
            # print(dir(model))
            model.id = count
            new_structure.add(model)

    io = PDBIO()
    io.set_structure(new_structure)
    io.save(new_file)


def merge_pdb_full(inference_dir_f, valid_csv, output_dir):
    os.makedirs(output_dir, exist_ok=True)
    valid_set = pd.read_csv(valid_csv)
    for filename in valid_set['file']:
        output_file = os.path.join(output_dir, filename + ".pdb")
        merge_pdb(inference_dir_f, output_file, filename)
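A minimal usage sketch for `merge_pdb_full`, assuming the default directory layout from `analysis/eval_result.py` and that `analysis/` is on `PYTHONPATH`:
```
from merge_pred_pdb import merge_pdb_full

# Collect all sample*.pdb files under each per-protein inference folder into
# one multi-model PDB per protein listed in valid_seq.csv.
merge_pdb_full(
    inference_dir_f="./inference_outputs/weights/pretrained/2025-03-13_10-08",
    valid_csv="./inference/valid_seq.csv",
    output_dir="./inference/test/pred_merge_results",
)
```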
analysis/metrics.py ADDED
@@ -0,0 +1,54 @@
""" Metrics. """
import mdtraj as md
import numpy as np
from openfold.np import residue_constants
from tmtools import tm_align
from data import utils as du

def calc_tm_score(pos_1, pos_2, seq_1, seq_2):
    tm_results = tm_align(pos_1, pos_2, seq_1, seq_2)
    return tm_results.tm_norm_chain1, tm_results.tm_norm_chain2

def calc_mdtraj_metrics(pdb_path):
    try:
        traj = md.load(pdb_path)
        pdb_ss = md.compute_dssp(traj, simplified=True)
        pdb_coil_percent = np.mean(pdb_ss == 'C')
        pdb_helix_percent = np.mean(pdb_ss == 'H')
        pdb_strand_percent = np.mean(pdb_ss == 'E')
        pdb_ss_percent = pdb_helix_percent + pdb_strand_percent
        pdb_rg = md.compute_rg(traj)[0]
    except IndexError as e:
        print('Error in calc_mdtraj_metrics: {}'.format(e))
        pdb_ss_percent = 0.0
        pdb_coil_percent = 0.0
        pdb_helix_percent = 0.0
        pdb_strand_percent = 0.0
        pdb_rg = 0.0
    return {
        'non_coil_percent': pdb_ss_percent,
        'coil_percent': pdb_coil_percent,
        'helix_percent': pdb_helix_percent,
        'strand_percent': pdb_strand_percent,
        'radius_of_gyration': pdb_rg,
    }

def calc_aligned_rmsd(pos_1, pos_2):
    aligned_pos_1 = du.rigid_transform_3D(pos_1, pos_2)[0]
    return np.mean(np.linalg.norm(aligned_pos_1 - pos_2, axis=-1))

def calc_ca_ca_metrics(ca_pos, bond_tol=0.1, clash_tol=1.0):
    ca_bond_dists = np.linalg.norm(
        ca_pos - np.roll(ca_pos, 1, axis=0), axis=-1)[1:]
    ca_ca_dev = np.mean(np.abs(ca_bond_dists - residue_constants.ca_ca))
    ca_ca_valid = np.mean(ca_bond_dists < (residue_constants.ca_ca + bond_tol))

    ca_ca_dists2d = np.linalg.norm(
        ca_pos[:, None, :] - ca_pos[None, :, :], axis=-1)
    inter_dists = ca_ca_dists2d[np.where(np.triu(ca_ca_dists2d, k=0) > 0)]
    clashes = inter_dists < clash_tol
    return {
        'ca_ca_deviation': ca_ca_dev,
        'ca_ca_valid_percent': ca_ca_valid,
        'num_ca_ca_clashes': np.sum(clashes),
    }
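A minimal usage sketch for the Cα geometry check above, assuming the repository root and `analysis/` are on `PYTHONPATH` so that this module's own imports (openfold, tmtools, `data.utils`) resolve:
```
import numpy as np
from metrics import calc_ca_ca_metrics

# Synthetic CA trace: 20 residues spaced 3.8 Å apart along x,
# close to the ideal CA-CA distance, so deviation should be ~0 and clashes 0.
ca_pos = np.stack([np.arange(20) * 3.8, np.zeros(20), np.zeros(20)], axis=-1)
print(calc_ca_ca_metrics(ca_pos))
```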
analysis/pca_analyse.py ADDED
@@ -0,0 +1,116 @@
import os
import re
import pandas as pd
import MDAnalysis as mda
from MDAnalysis.analysis import pca, align, rms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import argparse
warnings.filterwarnings("ignore")


def cal_PCA(md_pdb_path, ref_path, pred_pdb_path, n_components=2):
    print("")
    print('filename=', os.path.basename(ref_path))

    u = mda.Universe(md_pdb_path, md_pdb_path)
    u_ref = mda.Universe(ref_path, ref_path)

    aligner = align.AlignTraj(u,
                              u_ref,
                              select='name CA or name C or name N',
                              in_memory=True).run()

    pc = pca.PCA(u,
                 select='name CA or name C or name N',
                 align=False, mean=None,
                 # n_components=None,
                 n_components=n_components,
                 ).run()

    backbone = u.select_atoms('name CA or name C or name N')
    n_bb = len(backbone)
    print('There are {} backbone atoms in the analysis'.format(n_bb))

    for i in range(n_components):
        print(f"Cumulated variance {i+1}: {pc.cumulated_variance[i]:.3f}")

    transformed = pc.transform(backbone, n_components=n_components)

    print(transformed.shape)  # (3000, 2)

    df = pd.DataFrame(transformed,
                      columns=['PC{}'.format(i+1) for i in range(n_components)])

    plt.scatter(df['PC1'], df['PC2'], marker='o')
    plt.show()

    output_dir = os.path.dirname(md_pdb_path)
    output_filename = os.path.basename(md_pdb_path).split('.')[0]

    df.to_csv(os.path.join(output_dir, f'{output_filename}_md_pca.csv'))
    plt.savefig(os.path.join(output_dir, f'{output_filename}_md_pca.png'))

    for k, v in pred_pdb_path.items():
        u_pred = mda.Universe(v, v)
        aligner = align.AlignTraj(u_pred,
                                  u_ref,
                                  select='name CA or name C or name N',
                                  in_memory=True).run()
        pred_backbone = u_pred.select_atoms('name CA or name C or name N')
        pred_transformed = pc.transform(pred_backbone, n_components=n_components)

        df = pd.DataFrame(pred_transformed,
                          columns=['PC{}'.format(i+1) for i in range(n_components)])

        plt.scatter(df['PC1'], df['PC2'], marker='o')
        plt.show()

        output_dir = os.path.dirname(v)
        output_filename = os.path.basename(v).split('.')[0]
        df.to_csv(os.path.join(output_dir, f'{output_filename}_{k}_pca.csv'))
        plt.savefig(os.path.join(output_dir, f'{output_filename}_{k}_pca.png'))
        plt.clf()


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    parser.add_argument("--pred_pdb_dir", type=str, default="./inference/test/pred_merge_results")
    parser.add_argument("--target_dir", type=str, default="./inference/test/target_dir")
    parser.add_argument("--crystal_dir", type=str, default="./inference/test/crystal_dir")

    args = parser.parse_args()

    pred_pdb_path_org = {
        'P2DFlow': args.pred_pdb_dir,
    }
    md_pdb_path_org = args.target_dir
    ref_path_org = args.crystal_dir

    for file in os.listdir(md_pdb_path_org):
        if re.search(r'\.pdb', file):
            pred_pdb_path = {
                'P2DFlow': '',
                # 'alphaflow': '',
                # 'Str2Str': '',
            }
            for k, v in pred_pdb_path.items():
                pred_pdb_path[k] = os.path.join(pred_pdb_path_org[k], file)
            md_pdb_path = os.path.join(md_pdb_path_org, file)
            ref_path = os.path.join(ref_path_org, file)
            cal_PCA(md_pdb_path, ref_path, pred_pdb_path)
analysis/src/__init__.py ADDED
File without changes
analysis/src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (138 Bytes). View file
 
analysis/src/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (132 Bytes). View file
 
analysis/src/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (136 Bytes). View file
 
analysis/src/__pycache__/eval.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
analysis/src/__pycache__/eval.cpython-37.pyc ADDED
Binary file (4.61 kB). View file
 
analysis/src/__pycache__/eval.cpython-39.pyc ADDED
Binary file (4.95 kB). View file
 
analysis/src/common/__init__.py ADDED
File without changes
analysis/src/common/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
analysis/src/common/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (143 Bytes). View file
 
analysis/src/common/__pycache__/all_atom.cpython-39.pyc ADDED
Binary file (5.15 kB). View file
 
analysis/src/common/__pycache__/data_transforms.cpython-39.pyc ADDED
Binary file (26.9 kB). View file
 
analysis/src/common/__pycache__/geo_utils.cpython-310.pyc ADDED
Binary file (5.02 kB). View file
 
analysis/src/common/__pycache__/geo_utils.cpython-39.pyc ADDED
Binary file (5 kB). View file
 
analysis/src/common/__pycache__/pdb_utils.cpython-310.pyc ADDED
Binary file (10.5 kB). View file
 
analysis/src/common/__pycache__/pdb_utils.cpython-39.pyc ADDED
Binary file (10.4 kB). View file
 
analysis/src/common/__pycache__/protein.cpython-310.pyc ADDED
Binary file (7.46 kB). View file
 
analysis/src/common/__pycache__/protein.cpython-39.pyc ADDED
Binary file (7.45 kB). View file
 
analysis/src/common/__pycache__/residue_constants.cpython-310.pyc ADDED
Binary file (23.7 kB). View file
 
analysis/src/common/__pycache__/residue_constants.cpython-39.pyc ADDED
Binary file (23.2 kB). View file
 
analysis/src/common/__pycache__/rigid_utils.cpython-39.pyc ADDED
Binary file (41.4 kB). View file
 
analysis/src/common/__pycache__/rotation3d.cpython-39.pyc ADDED
Binary file (17 kB). View file
 
analysis/src/common/all_atom.py ADDED
@@ -0,0 +1,219 @@
"""
Utilities for calculating all atom representations.
"""

import torch

from src.common import residue_constants as rc
from src.common.data_transforms import atom37_to_torsion_angles
from src.common.rigid_utils import Rigid, Rotation


# Residue Constants from OpenFold/AlphaFold2.
IDEALIZED_POS37 = torch.tensor(rc.restype_atom37_rigid_group_positions)
IDEALIZED_POS37_MASK = torch.any(IDEALIZED_POS37, axis=-1)
IDEALIZED_POS = torch.tensor(rc.restype_atom14_rigid_group_positions)
DEFAULT_FRAMES = torch.tensor(rc.restype_rigid_group_default_frame)
ATOM_MASK = torch.tensor(rc.restype_atom14_mask)
GROUP_IDX = torch.tensor(rc.restype_atom14_to_rigid_group)


def torsion_angles_to_frames(
    r: Rigid,
    alpha: torch.Tensor,
    aatype: torch.Tensor,
):
    # [*, N, 8, 4, 4]
    default_4x4 = DEFAULT_FRAMES[aatype, ...].to(r.device)

    # [*, N, 8] transformations, i.e.
    #   One [*, N, 8, 3, 3] rotation matrix and
    #   One [*, N, 8, 3]    translation matrix
    default_r = r.from_tensor_4x4(default_4x4)

    bb_rot = alpha.new_zeros((*((1,) * len(alpha.shape[:-1])), 2))
    bb_rot[..., 1] = 1

    # [*, N, 8, 2]
    alpha = torch.cat(
        [bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], dim=-2
    )

    # [*, N, 8, 3, 3]
    # Produces rotation matrices of the form:
    # [
    #   [1, 0  , 0  ],
    #   [0, a_2,-a_1],
    #   [0, a_1, a_2]
    # ]
    # This follows the original code rather than the supplement, which uses
    # different indices.

    all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
    all_rots[..., 0, 0] = 1
    all_rots[..., 1, 1] = alpha[..., 1]
    all_rots[..., 1, 2] = -alpha[..., 0]
    all_rots[..., 2, 1:] = alpha

    all_rots = Rigid(Rotation(rot_mats=all_rots), None)

    all_frames = default_r.compose(all_rots)

    chi2_frame_to_frame = all_frames[..., 5]
    chi3_frame_to_frame = all_frames[..., 6]
    chi4_frame_to_frame = all_frames[..., 7]

    chi1_frame_to_bb = all_frames[..., 4]
    chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame)
    chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
    chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)

    all_frames_to_bb = Rigid.cat(
        [
            all_frames[..., :5],
            chi2_frame_to_bb.unsqueeze(-1),
            chi3_frame_to_bb.unsqueeze(-1),
            chi4_frame_to_bb.unsqueeze(-1),
        ],
        dim=-1,
    )

    all_frames_to_global = r[..., None].compose(all_frames_to_bb)

    return all_frames_to_global


def prot_to_torsion_angles(aatype, atom37, atom37_mask):
    """Calculate torsion angle features from protein features."""
    prot_feats = {
        'aatype': aatype,
        'all_atom_positions': atom37,
        'all_atom_mask': atom37_mask,
    }
    torsion_angles_feats = atom37_to_torsion_angles()(prot_feats)
    torsion_angles = torsion_angles_feats['torsion_angles_sin_cos']
    torsion_mask = torsion_angles_feats['torsion_angles_mask']
    return torsion_angles, torsion_mask


def frames_to_atom14_pos(
    r: Rigid,
    aatype: torch.Tensor,
):
    """Convert frames to their idealized all atom representation.

    Args:
        r: All rigid groups. [..., N, 8, 3]
        aatype: Residue types. [..., N]

    Returns:

    """

    # [*, N, 14]
    group_mask = GROUP_IDX[aatype, ...]

    # [*, N, 14, 8]
    group_mask = torch.nn.functional.one_hot(
        group_mask,
        num_classes=DEFAULT_FRAMES.shape[-3],
    ).to(r.device)

    # [*, N, 14, 8]
    t_atoms_to_global = r[..., None, :] * group_mask

    # [*, N, 14]
    t_atoms_to_global = t_atoms_to_global.map_tensor_fn(
        lambda x: torch.sum(x, dim=-1)
    )

    # [*, N, 14, 1]
    frame_atom_mask = ATOM_MASK[aatype, ...].unsqueeze(-1).to(r.device)

    # [*, N, 14, 3]
    frame_null_pos = IDEALIZED_POS[aatype, ...].to(r.device)
    pred_positions = t_atoms_to_global.apply(frame_null_pos)
    pred_positions = pred_positions * frame_atom_mask

    return pred_positions


def compute_backbone(bb_rigids, psi_torsions, aatype=None, device=None):
    if device is None:
        device = bb_rigids.device

    torsion_angles = torch.tile(
        psi_torsions[..., None, :],
        tuple([1 for _ in range(len(bb_rigids.shape))]) + (7, 1)
    ).to(device)

    # aatype must be on cpu for initializing the tensor by indexing
    if aatype is None:
        aatype = torch.zeros_like(bb_rigids).cpu().long()
    else:
        aatype = aatype.cpu()

    all_frames = torsion_angles_to_frames(
        bb_rigids,
        torsion_angles,
        aatype,
    )
    atom14_pos = frames_to_atom14_pos(
        all_frames,
        aatype,
    )
    atom37_bb_pos = torch.zeros(bb_rigids.shape + (37, 3), device=device)
    # atom14 bb order = ['N', 'CA', 'C', 'O', 'CB']
    # atom37 bb order = ['N', 'CA', 'C', 'CB', 'O']
    atom37_bb_pos[..., :3, :] = atom14_pos[..., :3, :]
    atom37_bb_pos[..., 3, :] = atom14_pos[..., 4, :]
    atom37_bb_pos[..., 4, :] = atom14_pos[..., 3, :]
    atom37_mask = torch.any(atom37_bb_pos, axis=-1)
    return atom37_bb_pos, atom37_mask, aatype.to(device), atom14_pos


def calculate_neighbor_angles(R_ac, R_ab):
    """Calculate angles between atoms c <- a -> b.

    Parameters
    ----------
    R_ac: Tensor, shape = (N,3)
        Vector from atom a to c.
    R_ab: Tensor, shape = (N,3)
        Vector from atom a to b.

    Returns
    -------
    angle_cab: Tensor, shape = (N,)
        Angle between atoms c <- a -> b.
    """
    # cos(alpha) = (u * v) / (|u|*|v|)
    x = torch.sum(R_ac * R_ab, dim=1)  # shape = (N,)
    # sin(alpha) = |u x v| / (|u|*|v|)
    y = torch.cross(R_ac, R_ab).norm(dim=-1)  # shape = (N,)
    # avoid that for y == (0,0,0) the gradient wrt. y becomes NaN
    y = torch.max(y, torch.tensor(1e-9))
    angle = torch.atan2(y, x)
    return angle


def vector_projection(R_ab, P_n):
    """
    Project the vector R_ab onto a plane with normal vector P_n.

    Parameters
    ----------
    R_ab: Tensor, shape = (N,3)
        Vector from atom a to b.
    P_n: Tensor, shape = (N,3)
        Normal vector of a plane onto which to project R_ab.

    Returns
    -------
    R_ab_proj: Tensor, shape = (N,3)
        Projected vector (orthogonal to P_n).
    """
    a_x_b = torch.sum(R_ab * P_n, dim=-1)
    b_x_b = torch.sum(P_n * P_n, dim=-1)
    return R_ab - (a_x_b / b_x_b)[:, None] * P_n
analysis/src/common/data_transforms.py ADDED
@@ -0,0 +1,1194 @@
1
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import itertools
17
+ from functools import reduce, wraps
18
+ from operator import add
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from src.common import residue_constants as rc
24
+ from src.common.rigid_utils import Rotation, Rigid
25
+ from src.utils.tensor_utils import (
26
+ tree_map,
27
+ tensor_tree_map,
28
+ batched_gather,
29
+ )
30
+
31
+ NUM_RES = "num residues placeholder"
32
+ NUM_MSA_SEQ = "msa placeholder"
33
+ NUM_EXTRA_SEQ = "extra msa placeholder"
34
+ NUM_TEMPLATES = "num templates placeholder"
35
+
36
+ MSA_FEATURE_NAMES = [
37
+ "msa",
38
+ "deletion_matrix",
39
+ "msa_mask",
40
+ "msa_row_mask",
41
+ "bert_mask",
42
+ "true_msa",
43
+ ]
44
+
45
+
46
+ def cast_to_64bit_ints(protein):
47
+ # We keep all ints as int64
48
+ for k, v in protein.items():
49
+ if v.dtype == torch.int32:
50
+ protein[k] = v.type(torch.int64)
51
+
52
+ return protein
53
+
54
+
55
+ def make_one_hot(x, num_classes):
56
+ x_one_hot = torch.zeros(*x.shape, num_classes)
57
+ x_one_hot.scatter_(-1, x.unsqueeze(-1), 1)
58
+ return x_one_hot
59
+
60
+
61
+ def make_seq_mask(protein):
62
+ protein["seq_mask"] = torch.ones(
63
+ protein["aatype"].shape, dtype=torch.float32
64
+ )
65
+ return protein
66
+
67
+
68
+ def make_template_mask(protein):
69
+ protein["template_mask"] = torch.ones(
70
+ protein["template_aatype"].shape[0], dtype=torch.float32
71
+ )
72
+ return protein
73
+
74
+
75
+ def curry1(f):
76
+ """Supply all arguments but the first."""
77
+ @wraps(f)
78
+ def fc(*args, **kwargs):
79
+ return lambda x: f(x, *args, **kwargs)
80
+
81
+ return fc
82
+
83
+
84
+ def make_all_atom_aatype(protein):
85
+ protein["all_atom_aatype"] = protein["aatype"]
86
+ return protein
87
+
88
+
89
+ def fix_templates_aatype(protein):
90
+ # Map one-hot to indices
91
+ num_templates = protein["template_aatype"].shape[0]
92
+ if(num_templates > 0):
93
+ protein["template_aatype"] = torch.argmax(
94
+ protein["template_aatype"], dim=-1
95
+ )
96
+ # Map hhsearch-aatype to our aatype.
97
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
98
+ new_order = torch.tensor(new_order_list, dtype=torch.int64).expand(
99
+ num_templates, -1
100
+ )
101
+ protein["template_aatype"] = torch.gather(
102
+ new_order, 1, index=protein["template_aatype"]
103
+ )
104
+
105
+ return protein
106
+
107
+
108
+ def correct_msa_restypes(protein):
109
+ """Correct MSA restype to have the same order as rc."""
110
+ new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
111
+ new_order = torch.tensor(
112
+ [new_order_list] * protein["msa"].shape[1], dtype=protein["msa"].dtype
113
+ ).transpose(0, 1)
114
+ protein["msa"] = torch.gather(new_order, 0, protein["msa"])
115
+
116
+ perm_matrix = np.zeros((22, 22), dtype=np.float32)
117
+ perm_matrix[range(len(new_order_list)), new_order_list] = 1.0
118
+
119
+ for k in protein:
120
+ if "profile" in k:
121
+ num_dim = protein[k].shape.as_list()[-1]
122
+ assert num_dim in [
123
+ 20,
124
+ 21,
125
+ 22,
126
+ ], "num_dim for %s out of expected range: %s" % (k, num_dim)
127
+ protein[k] = torch.dot(protein[k], perm_matrix[:num_dim, :num_dim])
128
+
129
+ return protein
130
+
131
+
132
+ def squeeze_features(protein):
133
+ """Remove singleton and repeated dimensions in protein features."""
134
+ protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)
135
+ for k in [
136
+ "domain_name",
137
+ "msa",
138
+ "num_alignments",
139
+ "seq_length",
140
+ "sequence",
141
+ "superfamily",
142
+ "deletion_matrix",
143
+ "resolution",
144
+ "between_segment_residues",
145
+ "residue_index",
146
+ "template_all_atom_mask",
147
+ ]:
148
+ if k in protein:
149
+ final_dim = protein[k].shape[-1]
150
+ if isinstance(final_dim, int) and final_dim == 1:
151
+ if torch.is_tensor(protein[k]):
152
+ protein[k] = torch.squeeze(protein[k], dim=-1)
153
+ else:
154
+ protein[k] = np.squeeze(protein[k], axis=-1)
155
+
156
+ for k in ["seq_length", "num_alignments"]:
157
+ if k in protein:
158
+ protein[k] = protein[k][0]
159
+
160
+ return protein
161
+
162
+
163
+ @curry1
164
+ def randomly_replace_msa_with_unknown(protein, replace_proportion):
165
+ """Replace a portion of the MSA with 'X'."""
166
+ msa_mask = torch.rand(protein["msa"].shape) < replace_proportion
167
+ x_idx = 20
168
+ gap_idx = 21
169
+ msa_mask = torch.logical_and(msa_mask, protein["msa"] != gap_idx)
170
+ protein["msa"] = torch.where(
171
+ msa_mask,
172
+ torch.ones_like(protein["msa"]) * x_idx,
173
+ protein["msa"]
174
+ )
175
+ aatype_mask = torch.rand(protein["aatype"].shape) < replace_proportion
176
+
177
+ protein["aatype"] = torch.where(
178
+ aatype_mask,
179
+ torch.ones_like(protein["aatype"]) * x_idx,
180
+ protein["aatype"],
181
+ )
182
+ return protein
183
+
184
+
185
+ @curry1
186
+ def sample_msa(protein, max_seq, keep_extra, seed=None):
187
+ """Sample MSA randomly, remaining sequences are stored are stored as `extra_*`."""
188
+ num_seq = protein["msa"].shape[0]
189
+ g = torch.Generator(device=protein["msa"].device)
190
+ if seed is not None:
191
+ g.manual_seed(seed)
192
+ shuffled = torch.randperm(num_seq - 1, generator=g) + 1
193
+ index_order = torch.cat((torch.tensor([0]), shuffled), dim=0)
194
+ num_sel = min(max_seq, num_seq)
195
+ sel_seq, not_sel_seq = torch.split(
196
+ index_order, [num_sel, num_seq - num_sel]
197
+ )
198
+
199
+ for k in MSA_FEATURE_NAMES:
200
+ if k in protein:
201
+ if keep_extra:
202
+ protein["extra_" + k] = torch.index_select(
203
+ protein[k], 0, not_sel_seq
204
+ )
205
+ protein[k] = torch.index_select(protein[k], 0, sel_seq)
206
+
207
+ return protein
208
+
209
+
210
+ @curry1
211
+ def add_distillation_flag(protein, distillation):
212
+ protein['is_distillation'] = distillation
213
+ return protein
214
+
215
+ @curry1
216
+ def sample_msa_distillation(protein, max_seq):
217
+ if(protein["is_distillation"] == 1):
218
+ protein = sample_msa(max_seq, keep_extra=False)(protein)
219
+ return protein
220
+
221
+
222
+ @curry1
223
+ def crop_extra_msa(protein, max_extra_msa):
224
+ num_seq = protein["extra_msa"].shape[0]
225
+ num_sel = min(max_extra_msa, num_seq)
226
+ select_indices = torch.randperm(num_seq)[:num_sel]
227
+ for k in MSA_FEATURE_NAMES:
228
+ if "extra_" + k in protein:
229
+ protein["extra_" + k] = torch.index_select(
230
+ protein["extra_" + k], 0, select_indices
231
+ )
232
+
233
+ return protein
234
+
235
+
236
+ def delete_extra_msa(protein):
237
+ for k in MSA_FEATURE_NAMES:
238
+ if "extra_" + k in protein:
239
+ del protein["extra_" + k]
240
+ return protein
241
+
242
+
243
+ # Not used in inference
244
+ @curry1
245
+ def block_delete_msa(protein, config):
246
+ num_seq = protein["msa"].shape[0]
247
+ block_num_seq = torch.floor(
248
+ torch.tensor(num_seq, dtype=torch.float32)
249
+ * config.msa_fraction_per_block
250
+ ).to(torch.int32)
251
+
252
+ if config.randomize_num_blocks:
253
+ nb = torch.distributions.uniform.Uniform(
254
+ 0, config.num_blocks + 1
255
+ ).sample()
256
+ else:
257
+ nb = config.num_blocks
258
+
259
+ del_block_starts = torch.distributions.Uniform(0, num_seq).sample(nb)
260
+ del_blocks = del_block_starts[:, None] + torch.range(block_num_seq)
261
+ del_blocks = torch.clip(del_blocks, 0, num_seq - 1)
262
+ del_indices = torch.unique(torch.sort(torch.reshape(del_blocks, [-1])))[0]
263
+
264
+ # Make sure we keep the original sequence
265
+ combined = torch.cat((torch.range(1, num_seq)[None], del_indices[None]))
266
+ uniques, counts = combined.unique(return_counts=True)
267
+ difference = uniques[counts == 1]
268
+ intersection = uniques[counts > 1]
269
+ keep_indices = torch.squeeze(difference, 0)
270
+
271
+ for k in MSA_FEATURE_NAMES:
272
+ if k in protein:
273
+ protein[k] = torch.gather(protein[k], keep_indices)
274
+
275
+ return protein
276
+
277
+
278
+ @curry1
279
+ def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
280
+ weights = torch.cat(
281
+ [torch.ones(21), gap_agreement_weight * torch.ones(1), torch.zeros(1)],
282
+ 0,
283
+ )
284
+
285
+ # Make agreement score as weighted Hamming distance
286
+ msa_one_hot = make_one_hot(protein["msa"], 23)
287
+ sample_one_hot = protein["msa_mask"][:, :, None] * msa_one_hot
288
+ extra_msa_one_hot = make_one_hot(protein["extra_msa"], 23)
289
+ extra_one_hot = protein["extra_msa_mask"][:, :, None] * extra_msa_one_hot
290
+
291
+ num_seq, num_res, _ = sample_one_hot.shape
292
+ extra_num_seq, _, _ = extra_one_hot.shape
293
+
294
+ # Compute tf.einsum('mrc,nrc,c->mn', sample_one_hot, extra_one_hot, weights)
295
+ # in an optimized fashion to avoid possible memory or computation blowup.
296
+ agreement = torch.matmul(
297
+ torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23]),
298
+ torch.reshape(
299
+ sample_one_hot * weights, [num_seq, num_res * 23]
300
+ ).transpose(0, 1),
301
+ )
302
+
303
+ # Assign each sequence in the extra sequences to the closest MSA sample
304
+ protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
305
+ torch.int64
306
+ )
307
+
308
+ return protein
309
+
310
+
311
+ def unsorted_segment_sum(data, segment_ids, num_segments):
312
+ """
313
+ Computes the sum along segments of a tensor. Similar to
314
+ tf.unsorted_segment_sum, but only supports 1-D indices.
315
+
316
+ :param data: A tensor whose segments are to be summed.
317
+ :param segment_ids: The 1-D segment indices tensor.
318
+ :param num_segments: The number of segments.
319
+ :return: A tensor of same data type as the data argument.
320
+ """
321
+ assert (
322
+ len(segment_ids.shape) == 1 and
323
+ segment_ids.shape[0] == data.shape[0]
324
+ )
325
+ segment_ids = segment_ids.view(
326
+ segment_ids.shape[0], *((1,) * len(data.shape[1:]))
327
+ )
328
+ segment_ids = segment_ids.expand(data.shape)
329
+ shape = [num_segments] + list(data.shape[1:])
330
+ tensor = torch.zeros(*shape).scatter_add_(0, segment_ids, data.float())
331
+ tensor = tensor.type(data.dtype)
332
+ return tensor
333
+
334
+
335
+ @curry1
336
+ def summarize_clusters(protein):
337
+ """Produce profile and deletion_matrix_mean within each cluster."""
338
+ num_seq = protein["msa"].shape[0]
339
+
340
+ def csum(x):
341
+ return unsorted_segment_sum(
342
+ x, protein["extra_cluster_assignment"], num_seq
343
+ )
344
+
345
+ mask = protein["extra_msa_mask"]
346
+ mask_counts = 1e-6 + protein["msa_mask"] + csum(mask) # Include center
347
+
348
+ msa_sum = csum(mask[:, :, None] * make_one_hot(protein["extra_msa"], 23))
349
+ msa_sum += make_one_hot(protein["msa"], 23) # Original sequence
350
+ protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
351
+ del msa_sum
352
+
353
+ del_sum = csum(mask * protein["extra_deletion_matrix"])
354
+ del_sum += protein["deletion_matrix"] # Original sequence
355
+ protein["cluster_deletion_mean"] = del_sum / mask_counts
356
+ del del_sum
357
+
358
+ return protein
359
+
360
+
361
+ def make_msa_mask(protein):
362
+ """Mask features are all ones, but will later be zero-padded."""
363
+ protein["msa_mask"] = torch.ones(protein["msa"].shape, dtype=torch.float32)
364
+ protein["msa_row_mask"] = torch.ones(
365
+ (protein["msa"].shape[0]), dtype=torch.float32
366
+ )
367
+ return protein
368
+
369
+
370
+ def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
371
+ """Create pseudo beta features."""
372
+ is_gly = torch.eq(aatype, rc.restype_order["G"])
373
+ ca_idx = rc.atom_order["CA"]
374
+ cb_idx = rc.atom_order["CB"]
375
+ pseudo_beta = torch.where(
376
+ torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3]),
377
+ all_atom_positions[..., ca_idx, :],
378
+ all_atom_positions[..., cb_idx, :],
379
+ )
380
+
381
+ if all_atom_mask is not None:
382
+ pseudo_beta_mask = torch.where(
383
+ is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
384
+ )
385
+ return pseudo_beta, pseudo_beta_mask
386
+ else:
387
+ return pseudo_beta
388
+
389
+
390
+ @curry1
391
+ def make_pseudo_beta(protein, prefix=""):
392
+ """Create pseudo-beta (alpha for glycine) position and mask."""
393
+ assert prefix in ["", "template_"]
394
+ (
395
+ protein[prefix + "pseudo_beta"],
396
+ protein[prefix + "pseudo_beta_mask"],
397
+ ) = pseudo_beta_fn(
398
+ protein["template_aatype" if prefix else "aatype"],
399
+ protein[prefix + "all_atom_positions"],
400
+ protein["template_all_atom_mask" if prefix else "all_atom_mask"],
401
+ )
402
+ return protein
403
+
404
+
405
+ @curry1
406
+ def add_constant_field(protein, key, value):
407
+ protein[key] = torch.tensor(value)
408
+ return protein
409
+
410
+
411
+ def shaped_categorical(probs, epsilon=1e-10):
412
+ ds = probs.shape
413
+ num_classes = ds[-1]
414
+ distribution = torch.distributions.categorical.Categorical(
415
+ torch.reshape(probs + epsilon, [-1, num_classes])
416
+ )
417
+ counts = distribution.sample()
418
+ return torch.reshape(counts, ds[:-1])
419
+
420
+
421
+ def make_hhblits_profile(protein):
422
+ """Compute the HHblits MSA profile if not already present."""
423
+ if "hhblits_profile" in protein:
424
+ return protein
425
+
426
+ # Compute the profile for every residue (over all MSA sequences).
427
+ msa_one_hot = make_one_hot(protein["msa"], 22)
428
+
429
+ protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
430
+ return protein
431
+
432
+
433
+ @curry1
434
+ def make_masked_msa(protein, config, replace_fraction):
435
+ """Create data for BERT on raw MSA."""
436
+ # Add a random amino acid uniformly.
437
+ random_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
438
+
439
+ categorical_probs = (
440
+ config.uniform_prob * random_aa
441
+ + config.profile_prob * protein["hhblits_profile"]
442
+ + config.same_prob * make_one_hot(protein["msa"], 22)
443
+ )
444
+
445
+ # Put all remaining probability on [MASK] which is a new column
446
+ pad_shapes = list(
447
+ reduce(add, [(0, 0) for _ in range(len(categorical_probs.shape))])
448
+ )
449
+ pad_shapes[1] = 1
450
+ mask_prob = (
451
+ 1.0 - config.profile_prob - config.same_prob - config.uniform_prob
452
+ )
453
+ assert mask_prob >= 0.0
454
+ categorical_probs = torch.nn.functional.pad(
455
+ categorical_probs, pad_shapes, value=mask_prob
456
+ )
457
+
458
+ sh = protein["msa"].shape
459
+ mask_position = torch.rand(sh) < replace_fraction
460
+
461
+ bert_msa = shaped_categorical(categorical_probs)
462
+ bert_msa = torch.where(mask_position, bert_msa, protein["msa"])
463
+
464
+ # Mix real and masked MSA
465
+ protein["bert_mask"] = mask_position.to(torch.float32)
466
+ protein["true_msa"] = protein["msa"]
467
+ protein["msa"] = bert_msa
468
+
469
+ return protein
470
+
471
+
472
+ @curry1
473
+ def make_fixed_size(
474
+ protein,
475
+ shape_schema,
476
+ msa_cluster_size,
477
+ extra_msa_size,
478
+ num_res=0,
479
+ num_templates=0,
480
+ ):
481
+ """Guess at the MSA and sequence dimension to make fixed size."""
482
+ pad_size_map = {
483
+ NUM_RES: num_res,
484
+ NUM_MSA_SEQ: msa_cluster_size,
485
+ NUM_EXTRA_SEQ: extra_msa_size,
486
+ NUM_TEMPLATES: num_templates,
487
+ }
488
+
489
+ for k, v in protein.items():
490
+ # Don't transfer this to the accelerator.
491
+ if k == "extra_cluster_assignment":
492
+ continue
493
+ shape = list(v.shape)
494
+ schema = shape_schema[k]
495
+ msg = "Rank mismatch between shape and shape schema for"
496
+ assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"
497
+ pad_size = [
498
+ pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
499
+ ]
500
+
501
+ padding = [(0, p - v.shape[i]) for i, p in enumerate(pad_size)]
502
+ padding.reverse()
503
+ padding = list(itertools.chain(*padding))
504
+ if padding:
505
+ protein[k] = torch.nn.functional.pad(v, padding)
506
+ protein[k] = torch.reshape(protein[k], pad_size)
507
+
508
+ return protein
509
+
510
+
511
+ @curry1
512
+ def make_msa_feat(protein):
513
+ """Create and concatenate MSA features."""
514
+ # Whether there is a domain break. Always zero for chains, but keeping for
515
+ # compatibility with domain datasets.
516
+ has_break = torch.clip(
517
+ protein["between_segment_residues"].to(torch.float32), 0, 1
518
+ )
519
+ aatype_1hot = make_one_hot(protein["aatype"], 21)
520
+
521
+ target_feat = [
522
+ torch.unsqueeze(has_break, dim=-1),
523
+ aatype_1hot, # Everyone gets the original sequence.
524
+ ]
525
+
526
+ msa_1hot = make_one_hot(protein["msa"], 23)
527
+ has_deletion = torch.clip(protein["deletion_matrix"], 0.0, 1.0)
528
+ deletion_value = torch.atan(protein["deletion_matrix"] / 3.0) * (
529
+ 2.0 / np.pi
530
+ )
531
+
532
+ msa_feat = [
533
+ msa_1hot,
534
+ torch.unsqueeze(has_deletion, dim=-1),
535
+ torch.unsqueeze(deletion_value, dim=-1),
536
+ ]
537
+
538
+ if "cluster_profile" in protein:
539
+ deletion_mean_value = torch.atan(
540
+ protein["cluster_deletion_mean"] / 3.0
541
+ ) * (2.0 / np.pi)
542
+ msa_feat.extend(
543
+ [
544
+ protein["cluster_profile"],
545
+ torch.unsqueeze(deletion_mean_value, dim=-1),
546
+ ]
547
+ )
548
+
549
+ if "extra_deletion_matrix" in protein:
550
+ protein["extra_has_deletion"] = torch.clip(
551
+ protein["extra_deletion_matrix"], 0.0, 1.0
552
+ )
553
+ protein["extra_deletion_value"] = torch.atan(
554
+ protein["extra_deletion_matrix"] / 3.0
555
+ ) * (2.0 / np.pi)
556
+
557
+ protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
558
+ protein["target_feat"] = torch.cat(target_feat, dim=-1)
559
+ return protein
560
+
561
+
562
+ @curry1
563
+ def select_feat(protein, feature_list):
564
+ return {k: v for k, v in protein.items() if k in feature_list}
565
+
566
+
567
+ @curry1
568
+ def crop_templates(protein, max_templates):
569
+ for k, v in protein.items():
570
+ if k.startswith("template_"):
571
+ protein[k] = v[:max_templates]
572
+ return protein
573
+
574
+
575
+ def make_atom14_masks(protein):
576
+ """Construct denser atom positions (14 dimensions instead of 37)."""
577
+ restype_atom14_to_atom37 = []
578
+ restype_atom37_to_atom14 = []
579
+ restype_atom14_mask = []
580
+
581
+ for rt in rc.restypes:
582
+ atom_names = rc.restype_name_to_atom14_names[rc.restype_1to3[rt]]
583
+ restype_atom14_to_atom37.append(
584
+ [(rc.atom_order[name] if name else 0) for name in atom_names]
585
+ )
586
+ atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
587
+ restype_atom37_to_atom14.append(
588
+ [
589
+ (atom_name_to_idx14[name] if name in atom_name_to_idx14 else 0)
590
+ for name in rc.atom_types
591
+ ]
592
+ )
593
+
594
+ restype_atom14_mask.append(
595
+ [(1.0 if name else 0.0) for name in atom_names]
596
+ )
597
+
598
+ # Add dummy mapping for restype 'UNK'
599
+ restype_atom14_to_atom37.append([0] * 14)
600
+ restype_atom37_to_atom14.append([0] * 37)
601
+ restype_atom14_mask.append([0.0] * 14)
602
+
603
+ restype_atom14_to_atom37 = torch.tensor(
604
+ restype_atom14_to_atom37,
605
+ dtype=torch.int32,
606
+ device=protein["aatype"].device,
607
+ )
608
+ restype_atom37_to_atom14 = torch.tensor(
609
+ restype_atom37_to_atom14,
610
+ dtype=torch.int32,
611
+ device=protein["aatype"].device,
612
+ )
613
+ restype_atom14_mask = torch.tensor(
614
+ restype_atom14_mask,
615
+ dtype=torch.float32,
616
+ device=protein["aatype"].device,
617
+ )
618
+ protein_aatype = protein['aatype'].to(torch.long)
619
+
620
+ # create the mapping for (residx, atom14) --> atom37, i.e. an array
621
+ # with shape (num_res, 14) containing the atom37 indices for this protein
622
+ residx_atom14_to_atom37 = restype_atom14_to_atom37[protein_aatype]
623
+ residx_atom14_mask = restype_atom14_mask[protein_aatype]
624
+
625
+ protein["atom14_atom_exists"] = residx_atom14_mask
626
+ protein["residx_atom14_to_atom37"] = residx_atom14_to_atom37.long()
627
+
628
+ # create the gather indices for mapping back
629
+ residx_atom37_to_atom14 = restype_atom37_to_atom14[protein_aatype]
630
+ protein["residx_atom37_to_atom14"] = residx_atom37_to_atom14.long()
631
+
632
+ # create the corresponding mask
633
+ restype_atom37_mask = torch.zeros(
634
+ [21, 37], dtype=torch.float32, device=protein["aatype"].device
635
+ )
636
+ for restype, restype_letter in enumerate(rc.restypes):
637
+ restype_name = rc.restype_1to3[restype_letter]
638
+ atom_names = rc.residue_atoms[restype_name]
639
+ for atom_name in atom_names:
640
+ atom_type = rc.atom_order[atom_name]
641
+ restype_atom37_mask[restype, atom_type] = 1
642
+
643
+ residx_atom37_mask = restype_atom37_mask[protein_aatype]
644
+ protein["atom37_atom_exists"] = residx_atom37_mask
645
+
646
+ return protein
647
+
648
+
649
+ def make_atom14_masks_np(batch):
650
+ batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
651
+ out = make_atom14_masks(batch)
652
+ out = tensor_tree_map(lambda t: np.array(t), out)
653
+ return out
654
+
655
+
656
+ def make_atom14_positions(protein):
657
+ """Constructs denser atom positions (14 dimensions instead of 37)."""
658
+ residx_atom14_mask = protein["atom14_atom_exists"]
659
+ residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
660
+
661
+ # Create a mask for known ground truth positions.
662
+ residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
663
+ protein["all_atom_mask"],
664
+ residx_atom14_to_atom37,
665
+ dim=-1,
666
+ no_batch_dims=len(protein["all_atom_mask"].shape[:-1]),
667
+ )
668
+
669
+ # Gather the ground truth positions.
670
+ residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
671
+ batched_gather(
672
+ protein["all_atom_positions"],
673
+ residx_atom14_to_atom37,
674
+ dim=-2,
675
+ no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
676
+ )
677
+ )
678
+
679
+ protein["atom14_atom_exists"] = residx_atom14_mask
680
+ protein["atom14_gt_exists"] = residx_atom14_gt_mask
681
+ protein["atom14_gt_positions"] = residx_atom14_gt_positions
682
+
683
+ # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
684
+ # alternative ground truth coordinates where the naming is swapped
685
+ restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
686
+ restype_3 += ["UNK"]
687
+
688
+ # Matrices for renaming ambiguous atoms.
689
+ all_matrices = {
690
+ res: torch.eye(
691
+ 14,
692
+ dtype=protein["all_atom_mask"].dtype,
693
+ device=protein["all_atom_mask"].device,
694
+ )
695
+ for res in restype_3
696
+ }
697
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
698
+ correspondences = torch.arange(
699
+ 14, device=protein["all_atom_mask"].device
700
+ )
701
+ for source_atom_swap, target_atom_swap in swap.items():
702
+ source_index = rc.restype_name_to_atom14_names[resname].index(
703
+ source_atom_swap
704
+ )
705
+ target_index = rc.restype_name_to_atom14_names[resname].index(
706
+ target_atom_swap
707
+ )
708
+ correspondences[source_index] = target_index
709
+ correspondences[target_index] = source_index
710
+ renaming_matrix = protein["all_atom_mask"].new_zeros((14, 14))
711
+ for index, correspondence in enumerate(correspondences):
712
+ renaming_matrix[index, correspondence] = 1.0
713
+ all_matrices[resname] = renaming_matrix
714
+ renaming_matrices = torch.stack(
715
+ [all_matrices[restype] for restype in restype_3]
716
+ )
717
+
718
+ # Pick the transformation matrices for the given residue sequence
719
+ # shape (num_res, 14, 14).
720
+ renaming_transform = renaming_matrices[protein["aatype"]]
721
+
722
+ # Apply it to the ground truth positions. shape (num_res, 14, 3).
723
+ alternative_gt_positions = torch.einsum(
724
+ "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
725
+ )
726
+ protein["atom14_alt_gt_positions"] = alternative_gt_positions
727
+
728
+ # Create the mask for the alternative ground truth (differs from the
729
+ # ground truth mask, if only one of the atoms in an ambiguous pair has a
730
+ # ground truth position).
731
+ alternative_gt_mask = torch.einsum(
732
+ "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
733
+ )
734
+ protein["atom14_alt_gt_exists"] = alternative_gt_mask
735
+
736
+ # Create an ambiguous atoms mask. shape: (21, 14).
737
+ restype_atom14_is_ambiguous = protein["all_atom_mask"].new_zeros((21, 14))
738
+ for resname, swap in rc.residue_atom_renaming_swaps.items():
739
+ for atom_name1, atom_name2 in swap.items():
740
+ restype = rc.restype_order[rc.restype_3to1[resname]]
741
+ atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
742
+ atom_name1
743
+ )
744
+ atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
745
+ atom_name2
746
+ )
747
+ restype_atom14_is_ambiguous[restype, atom_idx1] = 1
748
+ restype_atom14_is_ambiguous[restype, atom_idx2] = 1
749
+
750
+ # From this create an ambiguous_mask for the given sequence.
751
+ protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
752
+ protein["aatype"]
753
+ ]
754
+
755
+ return protein
756
+
757
+
758
+ def atom37_to_frames(protein, eps=1e-8):
759
+ aatype = protein["aatype"]
760
+ all_atom_positions = protein["all_atom_positions"]
761
+ all_atom_mask = protein["all_atom_mask"]
762
+
763
+ batch_dims = len(aatype.shape[:-1])
764
+
765
+ restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
766
+ restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
767
+ restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]
768
+
769
+ for restype, restype_letter in enumerate(rc.restypes):
770
+ resname = rc.restype_1to3[restype_letter]
771
+ for chi_idx in range(4):
772
+ if rc.chi_angles_mask[restype][chi_idx]:
773
+ names = rc.chi_angles_atoms[resname][chi_idx]
774
+ restype_rigidgroup_base_atom_names[
775
+ restype, chi_idx + 4, :
776
+ ] = names[1:]
777
+
778
+ restype_rigidgroup_mask = all_atom_mask.new_zeros(
779
+ (*aatype.shape[:-1], 21, 8),
780
+ )
781
+ restype_rigidgroup_mask[..., 0] = 1
782
+ restype_rigidgroup_mask[..., 3] = 1
783
+ restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
784
+ rc.chi_angles_mask
785
+ )
786
+
787
+ lookuptable = rc.atom_order.copy()
788
+ lookuptable[""] = 0
789
+ lookup = np.vectorize(lambda x: lookuptable[x])
790
+ restype_rigidgroup_base_atom37_idx = lookup(
791
+ restype_rigidgroup_base_atom_names,
792
+ )
793
+ restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
794
+ restype_rigidgroup_base_atom37_idx,
795
+ )
796
+ restype_rigidgroup_base_atom37_idx = (
797
+ restype_rigidgroup_base_atom37_idx.view(
798
+ *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
799
+ )
800
+ )
801
+
802
+ residx_rigidgroup_base_atom37_idx = batched_gather(
803
+ restype_rigidgroup_base_atom37_idx,
804
+ aatype,
805
+ dim=-3,
806
+ no_batch_dims=batch_dims,
807
+ )
808
+
809
+ base_atom_pos = batched_gather(
810
+ all_atom_positions,
811
+ residx_rigidgroup_base_atom37_idx,
812
+ dim=-2,
813
+ no_batch_dims=len(all_atom_positions.shape[:-2]),
814
+ )
815
+
816
+ gt_frames = Rigid.from_3_points(
817
+ p_neg_x_axis=base_atom_pos[..., 0, :],
818
+ origin=base_atom_pos[..., 1, :],
819
+ p_xy_plane=base_atom_pos[..., 2, :],
820
+ eps=eps,
821
+ )
822
+
823
+ group_exists = batched_gather(
824
+ restype_rigidgroup_mask,
825
+ aatype,
826
+ dim=-2,
827
+ no_batch_dims=batch_dims,
828
+ )
829
+
830
+ gt_atoms_exist = batched_gather(
831
+ all_atom_mask,
832
+ residx_rigidgroup_base_atom37_idx,
833
+ dim=-1,
834
+ no_batch_dims=len(all_atom_mask.shape[:-1]),
835
+ )
836
+ gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists
837
+
838
+ rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
839
+ rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
840
+ rots[..., 0, 0, 0] = -1
841
+ rots[..., 0, 2, 2] = -1
842
+ rots = Rotation(rot_mats=rots)
843
+
844
+ gt_frames = gt_frames.compose(Rigid(rots, None))
845
+
846
+ restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
847
+ *((1,) * batch_dims), 21, 8
848
+ )
849
+ restype_rigidgroup_rots = torch.eye(
850
+ 3, dtype=all_atom_mask.dtype, device=aatype.device
851
+ )
852
+ restype_rigidgroup_rots = torch.tile(
853
+ restype_rigidgroup_rots,
854
+ (*((1,) * batch_dims), 21, 8, 1, 1),
855
+ )
856
+
857
+ for resname, _ in rc.residue_atom_renaming_swaps.items():
858
+ restype = rc.restype_order[rc.restype_3to1[resname]]
859
+ chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
860
+ restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
861
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
862
+ restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1
863
+
864
+ residx_rigidgroup_is_ambiguous = batched_gather(
865
+ restype_rigidgroup_is_ambiguous,
866
+ aatype,
867
+ dim=-2,
868
+ no_batch_dims=batch_dims,
869
+ )
870
+
871
+ residx_rigidgroup_ambiguity_rot = batched_gather(
872
+ restype_rigidgroup_rots,
873
+ aatype,
874
+ dim=-4,
875
+ no_batch_dims=batch_dims,
876
+ )
877
+
878
+ residx_rigidgroup_ambiguity_rot = Rotation(
879
+ rot_mats=residx_rigidgroup_ambiguity_rot
880
+ )
881
+ alt_gt_frames = gt_frames.compose(
882
+ Rigid(residx_rigidgroup_ambiguity_rot, None)
883
+ )
884
+
885
+ gt_frames_tensor = gt_frames.to_tensor_4x4()
886
+ alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()
887
+
888
+ protein["rigidgroups_gt_frames"] = gt_frames_tensor
889
+ protein["rigidgroups_gt_exists"] = gt_exists
890
+ protein["rigidgroups_group_exists"] = group_exists
891
+ protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
892
+ protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor
893
+
894
+ return protein
895
+
896
+
897
+ def get_chi_atom_indices():
898
+ """Returns atom indices needed to compute chi angles for all residue types.
899
+
900
+ Returns:
901
+ A tensor of shape [residue_types=21, chis=4, atoms=4]. The residue types are
902
+ in the order specified in rc.restypes + unknown residue type
903
+ at the end. For chi angles which are not defined on the residue, the
+ corresponding atom indices default to 0.
905
+ """
906
+ chi_atom_indices = []
907
+ for residue_name in rc.restypes:
908
+ residue_name = rc.restype_1to3[residue_name]
909
+ residue_chi_angles = rc.chi_angles_atoms[residue_name]
910
+ atom_indices = []
911
+ for chi_angle in residue_chi_angles:
912
+ atom_indices.append([rc.atom_order[atom] for atom in chi_angle])
913
+ for _ in range(4 - len(atom_indices)):
914
+ atom_indices.append(
915
+ [0, 0, 0, 0]
916
+ ) # For chi angles not defined on the AA.
917
+ chi_atom_indices.append(atom_indices)
918
+
919
+ chi_atom_indices.append([[0, 0, 0, 0]] * 4) # For UNKNOWN residue.
920
+
921
+ return chi_atom_indices
922
+
923
+
924
+ @curry1
925
+ def atom37_to_torsion_angles(
926
+ protein,
927
+ prefix="",
928
+ ):
929
+ """
930
+ Convert coordinates to torsion angles.
931
+
932
+ This function is extremely sensitive to floating point imprecisions
933
+ and should be run with double precision whenever possible.
934
+
935
+ Args:
936
+ Dict containing:
937
+ * (prefix)aatype:
938
+ [*, N_res] residue indices
939
+ * (prefix)all_atom_positions:
940
+ [*, N_res, 37, 3] atom positions (in atom37
941
+ format)
942
+ * (prefix)all_atom_mask:
943
+ [*, N_res, 37] atom position mask
944
+ Returns:
945
+ The same dictionary updated with the following features:
946
+
947
+ "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
948
+ Torsion angles
949
+ "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
950
+ Alternate torsion angles (accounting for 180-degree symmetry)
951
+ "(prefix)torsion_angles_mask" ([*, N_res, 7])
952
+ Torsion angles mask
953
+ """
954
+ aatype = protein[prefix + "aatype"]
955
+ all_atom_positions = protein[prefix + "all_atom_positions"]
956
+ all_atom_mask = protein[prefix + "all_atom_mask"]
957
+
958
+ aatype = torch.clamp(aatype, max=20)
959
+
960
+ pad = all_atom_positions.new_zeros(
961
+ [*all_atom_positions.shape[:-3], 1, 37, 3]
962
+ )
963
+ prev_all_atom_positions = torch.cat(
964
+ [pad, all_atom_positions[..., :-1, :, :]], dim=-3
965
+ )
966
+
967
+ pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
968
+ prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)
969
+
970
+ pre_omega_atom_pos = torch.cat(
971
+ [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
972
+ dim=-2,
973
+ )
974
+ phi_atom_pos = torch.cat(
975
+ [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
976
+ dim=-2,
977
+ )
978
+ psi_atom_pos = torch.cat(
979
+ [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
980
+ dim=-2,
981
+ )
982
+
983
+ pre_omega_mask = torch.prod(
984
+ prev_all_atom_mask[..., 1:3], dim=-1
985
+ ) * torch.prod(all_atom_mask[..., :2], dim=-1)
986
+ phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
987
+ all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
988
+ )
989
+ psi_mask = (
990
+ torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
991
+ * all_atom_mask[..., 4]
992
+ )
993
+
994
+ chi_atom_indices = torch.as_tensor(
995
+ get_chi_atom_indices(), device=aatype.device
996
+ )
997
+
998
+ atom_indices = chi_atom_indices[..., aatype, :, :]
999
+ chis_atom_pos = batched_gather(
1000
+ all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
1001
+ )
1002
+
1003
+ chi_angles_mask = list(rc.chi_angles_mask)
1004
+ chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
1005
+ chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)
1006
+
1007
+ chis_mask = chi_angles_mask[aatype, :]
1008
+
1009
+ chi_angle_atoms_mask = batched_gather(
1010
+ all_atom_mask,
1011
+ atom_indices,
1012
+ dim=-1,
1013
+ no_batch_dims=len(atom_indices.shape[:-2]),
1014
+ )
1015
+ chi_angle_atoms_mask = torch.prod(
1016
+ chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
1017
+ )
1018
+ chis_mask = chis_mask * chi_angle_atoms_mask
1019
+
1020
+ torsions_atom_pos = torch.cat(
1021
+ [
1022
+ pre_omega_atom_pos[..., None, :, :],
1023
+ phi_atom_pos[..., None, :, :],
1024
+ psi_atom_pos[..., None, :, :],
1025
+ chis_atom_pos,
1026
+ ],
1027
+ dim=-3,
1028
+ )
1029
+
1030
+ torsion_angles_mask = torch.cat(
1031
+ [
1032
+ pre_omega_mask[..., None],
1033
+ phi_mask[..., None],
1034
+ psi_mask[..., None],
1035
+ chis_mask,
1036
+ ],
1037
+ dim=-1,
1038
+ )
1039
+
1040
+ torsion_frames = Rigid.from_3_points(
1041
+ torsions_atom_pos[..., 1, :],
1042
+ torsions_atom_pos[..., 2, :],
1043
+ torsions_atom_pos[..., 0, :],
1044
+ eps=1e-8,
1045
+ )
1046
+
1047
+ fourth_atom_rel_pos = torsion_frames.invert().apply(
1048
+ torsions_atom_pos[..., 3, :]
1049
+ )
1050
+
1051
+ torsion_angles_sin_cos = torch.stack(
1052
+ [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
1053
+ )
1054
+
1055
+ denom = torch.sqrt(
1056
+ torch.sum(
1057
+ torch.square(torsion_angles_sin_cos),
1058
+ dim=-1,
1059
+ dtype=torsion_angles_sin_cos.dtype,
1060
+ keepdims=True,
1061
+ )
1062
+ + 1e-8
1063
+ )
1064
+ torsion_angles_sin_cos = torsion_angles_sin_cos / denom
1065
+
1066
+ torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
1067
+ [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
1068
+ )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]
1069
+
1070
+ chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
1071
+ rc.chi_pi_periodic,
1072
+ )[aatype, ...]
1073
+
1074
+ mirror_torsion_angles = torch.cat(
1075
+ [
1076
+ all_atom_mask.new_ones(*aatype.shape, 3),
1077
+ 1.0 - 2.0 * chi_is_ambiguous,
1078
+ ],
1079
+ dim=-1,
1080
+ )
1081
+
1082
+ alt_torsion_angles_sin_cos = (
1083
+ torsion_angles_sin_cos * mirror_torsion_angles[..., None]
1084
+ )
1085
+
1086
+ protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
1087
+ protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
1088
+ protein[prefix + "torsion_angles_mask"] = torsion_angles_mask
1089
+
1090
+ return protein
1091
+
1092
+
1093
+ def get_backbone_frames(protein):
1094
+ # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
1095
+ protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
1096
+ ..., 0, :, :
1097
+ ]
1098
+ protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]
1099
+
1100
+ return protein
1101
+
1102
+
1103
+ def get_chi_angles(protein):
1104
+ dtype = protein["all_atom_mask"].dtype
1105
+ protein["chi_angles_sin_cos"] = (
1106
+ protein["torsion_angles_sin_cos"][..., 3:, :]
1107
+ ).to(dtype)
1108
+ protein["chi_mask"] = protein["torsion_angles_mask"][..., 3:].to(dtype)
1109
+
1110
+ return protein
1111
+
1112
+
1113
+ @curry1
1114
+ def random_crop_to_size(
1115
+ protein,
1116
+ crop_size,
1117
+ max_templates,
1118
+ shape_schema,
1119
+ subsample_templates=False,
1120
+ seed=None,
1121
+ ):
1122
+ """Crop randomly to `crop_size`, or keep as is if shorter than that."""
1123
+ # We want each ensemble to be cropped the same way
1124
+ g = torch.Generator(device=protein["seq_length"].device)
1125
+ if seed is not None:
1126
+ g.manual_seed(seed)
1127
+
1128
+ seq_length = protein["seq_length"]
1129
+
1130
+ if "template_mask" in protein:
1131
+ num_templates = protein["template_mask"].shape[-1]
1132
+ else:
1133
+ num_templates = 0
1134
+
1135
+ # No need to subsample templates if there aren't any
1136
+ subsample_templates = subsample_templates and num_templates
1137
+
1138
+ num_res_crop_size = min(int(seq_length), crop_size)
1139
+
1140
+ def _randint(lower, upper):
1141
+ return int(torch.randint(
1142
+ lower,
1143
+ upper + 1,
1144
+ (1,),
1145
+ device=protein["seq_length"].device,
1146
+ generator=g,
1147
+ )[0])
1148
+
1149
+ if subsample_templates:
1150
+ templates_crop_start = _randint(0, num_templates)
1151
+ templates_select_indices = torch.randperm(
1152
+ num_templates, device=protein["seq_length"].device, generator=g
1153
+ )
1154
+ else:
1155
+ templates_crop_start = 0
1156
+
1157
+ num_templates_crop_size = min(
1158
+ num_templates - templates_crop_start, max_templates
1159
+ )
1160
+
1161
+ n = seq_length - num_res_crop_size
1162
+ if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
1163
+ right_anchor = n
1164
+ else:
1165
+ x = _randint(0, n)
1166
+ right_anchor = n - x
1167
+
1168
+ num_res_crop_start = _randint(0, right_anchor)
1169
+
1170
+ for k, v in protein.items():
1171
+ if k not in shape_schema or (
1172
+ "template" not in k and NUM_RES not in shape_schema[k]
1173
+ ):
1174
+ continue
1175
+
1176
+ # randomly permute the templates before cropping them.
1177
+ if k.startswith("template") and subsample_templates:
1178
+ v = v[templates_select_indices]
1179
+
1180
+ slices = []
1181
+ for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
1182
+ is_num_res = dim_size == NUM_RES
1183
+ if i == 0 and k.startswith("template"):
1184
+ crop_size = num_templates_crop_size
1185
+ crop_start = templates_crop_start
1186
+ else:
1187
+ crop_start = num_res_crop_start if is_num_res else 0
1188
+ crop_size = num_res_crop_size if is_num_res else dim
1189
+ slices.append(slice(crop_start, crop_start + crop_size))
1190
+ protein[k] = v[slices]
1191
+
1192
+ protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
1193
+
1194
+ return protein
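The atom14/atom37 transforms above reduce to per-residue lookup tables applied with gathers. Below is a minimal self-contained sketch of that gather pattern in plain PyTorch; the shapes and index values are illustrative only, and it deliberately avoids the repo's `batched_gather` helper.

import torch

num_res = 3
atom37_positions = torch.randn(num_res, 37, 3)                 # per-residue atom37 coordinates

# Toy atom14 -> atom37 index table: map the first four atom14 slots to
# N, CA, C, O (atom37 indices 0, 1, 2, 4) and mask out the remaining slots.
residx_atom14_to_atom37 = torch.zeros(num_res, 14, dtype=torch.long)
residx_atom14_to_atom37[:, :4] = torch.tensor([0, 1, 2, 4])
atom14_mask = torch.zeros(num_res, 14)
atom14_mask[:, :4] = 1.0

# Gather atom37 coordinates into the denser atom14 layout, zeroing absent atoms.
idx = residx_atom14_to_atom37[..., None].expand(-1, -1, 3)     # [num_res, 14, 3]
atom14_positions = torch.gather(atom37_positions, dim=-2, index=idx) * atom14_mask[..., None]
print(atom14_positions.shape)                                   # torch.Size([3, 14, 3])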
analysis/src/common/geo_utils.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Utility functions for geometric operations (torch only).
3
+ """
4
+ import torch
5
+
6
+
7
+ def rots_mul_vecs(m, v):
8
+ """(Batch) Apply rotations 'm' to vectors 'v'."""
9
+ return torch.stack([
10
+ m[..., 0, 0] * v[..., 0] + m[..., 0, 1] * v[..., 1] + m[..., 0, 2] * v[..., 2],
11
+ m[..., 1, 0] * v[..., 0] + m[..., 1, 1] * v[..., 1] + m[..., 1, 2] * v[..., 2],
12
+ m[..., 2, 0] * v[..., 0] + m[..., 2, 1] * v[..., 1] + m[..., 2, 2] * v[..., 2],
13
+ ], dim=-1)
14
+
15
+ def distance(p, eps=1e-10):
16
+ """Calculate distance between a pair of points (dim=-2)."""
17
+ # [*, 2, 3]
18
+ return (eps + torch.sum((p[..., 0, :] - p[..., 1, :]) ** 2, dim=-1)) ** 0.5
19
+
20
+ def dihedral(p, eps=1e-10):
21
+ """Calculate dihedral angle between a quadruple of points (dim=-2)."""
22
+ # p: [*, 4, 3]
23
+
24
+ # [*, 3]
25
+ u1 = p[..., 1, :] - p[..., 0, :]
26
+ u2 = p[..., 2, :] - p[..., 1, :]
27
+ u3 = p[..., 3, :] - p[..., 2, :]
28
+
29
+ # [*, 3]
30
+ u1xu2 = torch.cross(u1, u2, dim=-1)
31
+ u2xu3 = torch.cross(u2, u3, dim=-1)
32
+
33
+ # [*]
34
+ u2_norm = (eps + torch.sum(u2 ** 2, dim=-1)) ** 0.5
35
+ u1xu2_norm = (eps + torch.sum(u1xu2 ** 2, dim=-1)) ** 0.5
36
+ u2xu3_norm = (eps + torch.sum(u2xu3 ** 2, dim=-1)) ** 0.5
37
+
38
+ # [*]
39
+ cos_enc = torch.einsum('...d,...d->...', u1xu2, u2xu3)/ (u1xu2_norm * u2xu3_norm)
40
+ sin_enc = torch.einsum('...d,...d->...', u2, torch.cross(u1xu2, u2xu3, dim=-1)) / (u2_norm * u1xu2_norm * u2xu3_norm)
41
+
42
+ return torch.stack([cos_enc, sin_enc], dim=-1)
43
+
44
+ def calc_distogram(pos: torch.Tensor, min_bin: float, max_bin: float, num_bins: int):
45
+ # pos: [*, L, 3]
46
+ dists_2d = torch.linalg.norm(
47
+ pos[..., :, None, :] - pos[..., None, :, :], axis=-1
48
+ )[..., None]
49
+ lower = torch.linspace(
50
+ min_bin,
51
+ max_bin,
52
+ num_bins,
53
+ device=pos.device)
54
+ upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
55
+ distogram = ((dists_2d > lower) * (dists_2d < upper)).type(pos.dtype)
56
+ return distogram
57
+
58
+ def rmsd(xyz1, xyz2):
59
+ """ Abbreviation for squared_deviation(xyz1, xyz2, 'rmsd') """
60
+ return squared_deviation(xyz1, xyz2, 'rmsd')
61
+
62
+ def squared_deviation(xyz1, xyz2, reduction='none'):
63
+ """Squared point-wise deviation between two point clouds after alignment.
64
+
65
+ Args:
66
+ xyz1: (*, L, 3), to be transformed
67
+ xyz2: (*, L, 3), the reference
68
+
69
+ Returns:
70
+ (*,) RMSD values if reduction='rmsd', or (*, L) per-point squared deviations if reduction='none'
71
+ """
72
+ map_to_np = False
73
+ if not torch.is_tensor(xyz1):
74
+ map_to_np = True
75
+ xyz1 = torch.as_tensor(xyz1)
76
+ xyz2 = torch.as_tensor(xyz2)
77
+
78
+ R, t = _find_rigid_alignment(xyz1, xyz2)
79
+
80
+ # print(R.shape, t.shape) # B, 3, 3 & B, 3
81
+
82
+ # xyz1_aligned = (R.bmm(xyz1.transpose(-2,-1))).transpose(-2,-1) + t.unsqueeze(1)
83
+ xyz1_aligned = (torch.matmul(R, xyz1.transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
84
+
85
+ sd = ((xyz1_aligned - xyz2)**2).sum(dim=-1) # (*, L)
86
+
87
+ assert sd.shape == xyz1.shape[:-1]
88
+ if reduction == 'none':
89
+ pass
90
+ elif reduction == 'rmsd':
91
+ sd = torch.sqrt(sd.mean(dim=-1))
92
+ else:
93
+ raise NotImplementedError()
94
+
95
+ sd = sd.numpy() if map_to_np else sd
96
+ return sd
97
+
98
+ def _find_rigid_alignment(src, tgt):
99
+ """Inspired by https://research.pasteur.fr/en/member/guillaume-bouvier/;
100
+ https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
101
+
102
+ See: https://en.wikipedia.org/wiki/Kabsch_algorithm
103
+
104
+ 2-D or 3-D registration with known correspondences.
105
+ Registration occurs in the zero centered coordinate system, and then
106
+ must be transported back.
107
+
108
+ Args:
109
+ src: Torch tensor of shape (*, L, 3) -- Point Cloud to Align (source)
110
+ tgt: Torch tensor of shape (*, L, 3) -- Reference Point Cloud (target)
111
+ Returns:
112
+ R: optimal rotation (*, 3, 3)
113
+ t: optimal translation (*, 3)
114
+
115
+ Test on rotation + translation and on rotation + translation + reflection
116
+ >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]], dtype=torch.float)
117
+ >>> R0 = torch.tensor([[np.cos(60), -np.sin(60)], [np.sin(60), np.cos(60)]], dtype=torch.float)
118
+ >>> B = (R0.mm(A.T)).T
119
+ >>> t0 = torch.tensor([3., 3.])
120
+ >>> B += t0
121
+ >>> R, t = _find_rigid_alignment(A, B)
122
+ >>> A_aligned = (R.mm(A.T)).T + t
123
+ >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
124
+ >>> rmsd
125
+ tensor(3.7064e-07)
126
+ >>> B *= torch.tensor([-1., 1.])
127
+ >>> R, t = _find_rigid_alignment(A, B)
128
+ >>> A_aligned = (R.mm(A.T)).T + t
129
+ >>> rmsd = torch.sqrt(((A_aligned - B)**2).sum(axis=1).mean())
130
+ >>> rmsd
131
+ tensor(3.7064e-07)
132
+ """
133
+ assert src.shape[-2] > 1
134
+ src_com = src.mean(dim=-2, keepdim=True)
135
+ tgt_com = tgt.mean(dim=-2, keepdim=True)
136
+ src_centered = src - src_com
137
+ tgt_centered = tgt - tgt_com
138
+
139
+ # Covariance matrix
140
+
141
+ # H = src_centered.transpose(-2,-1).bmm(tgt_centered) # *, 3, 3
142
+ H = torch.matmul(src_centered.transpose(-2,-1), tgt_centered)
143
+
144
+ U, S, V = torch.svd(H)
145
+ # Rotation matrix
146
+
147
+ # R = V.bmm(U.transpose(-2,-1))
148
+ R = torch.matmul(V, U.transpose(-2, -1))
149
+
150
+ # Translation vector
151
+
152
+ # t = tgt_com - R.bmm(src_com.transpose(-2,-1)).transpose(-2,-1)
153
+ t = tgt_com - torch.matmul(R, src_com.transpose(-2, -1)).transpose(-2, -1)
154
+
155
+ return R, t.squeeze(-2) # (B, 3, 3), (B, 3)
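A small usage sketch for the helpers above. It assumes `analysis/src` is on the Python path so that `src.common.geo_utils` imports as in the repo; the random coordinates and the rotation are placeholders.

import math
import torch
from src.common.geo_utils import dihedral, rmsd, calc_distogram

coords = torch.randn(10, 3)                     # toy 10-point CA trace

# Dihedral of four consecutive points -> [cos, sin] encoding, shape [2].
print(dihedral(coords[:4]))

# RMSD between the trace and a rigidly moved copy should be close to zero after alignment.
c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
moved = coords @ R.T + torch.tensor([1.0, 2.0, 3.0])
print(rmsd(moved, coords))                      # ~0

# One-hot distance binning over all pairs: shape [10, 10, 22].
print(calc_distogram(coords, min_bin=2.0, max_bin=22.0, num_bins=22).shape)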
analysis/src/common/pdb_utils.py ADDED
@@ -0,0 +1,353 @@
1
+ """Utility functions for operating PDB files.
2
+ """
3
+ import os
4
+ import re
5
+ from typing import Optional
6
+ from collections import OrderedDict
7
+
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ import biotite.structure as struct
12
+ from biotite.structure.io.pdb import PDBFile
13
+
14
+ from src.common import protein
15
+
16
+
17
+ def write_pdb_string(pdb_string: str, save_to: str):
18
+ """Write pdb string to file"""
19
+ with open(save_to, 'w') as f:
20
+ f.write(pdb_string)
21
+
22
+ def read_pdb_to_string(pdb_file):
23
+ """Read PDB file as pdb string. Convenient API"""
24
+ with open(pdb_file, 'r') as fi:
25
+ pdb_string = ''
26
+ for line in fi:
27
+ if line.startswith('END') or line.startswith('TER') \
28
+ or line.startswith('MODEL') or line.startswith('ATOM'):
29
+ pdb_string += line
30
+ return pdb_string
31
+
32
+ def merge_pdbfiles(input, output_file, verbose=True):
+ """Merge a list of PDB files (or every .pdb file in a directory) into one multi-model PDB file, in order."""
34
+ if isinstance(input, str):
35
+ pdb_files = [os.path.join(input, f) for f in os.listdir(input) if f.endswith('.pdb')]
36
+ elif isinstance(input, list):
37
+ pdb_files = input
38
+
39
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
40
+
41
+ model_number = 0
42
+ pdb_lines = []
43
+ if verbose:
44
+ _iter = tqdm(pdb_files, desc='Merging PDBs')
45
+ else:
46
+ _iter = pdb_files
47
+
48
+ for pdb_file in _iter:
49
+ with open(pdb_file, 'r') as pdb:
50
+ lines = pdb.readlines()
51
+ single_model = True
52
+
53
+ for line in lines:
54
+ if line.startswith('MODEL') or line.startswith('ENDMDL'):
55
+ single_model = False
56
+ break
57
+
58
+ if single_model: # single model
59
+ model_number += 1
60
+ pdb_lines.append(f"MODEL {model_number}")
61
+ for line in lines:
62
+ if line.startswith('TER') or line.startswith('ATOM'):
63
+ pdb_lines.append(line.strip())
64
+ pdb_lines.append("ENDMDL")
65
+ else: # multiple models
66
+ for line in lines:
67
+ if line.startswith('MODEL'):
68
+ model_number += 1
69
+ if model_number > 1:
70
+ pdb_lines.append("ENDMDL")
71
+ pdb_lines.append(f"MODEL {model_number}")
72
+ elif line.startswith('END'):
73
+ continue
74
+ elif line.startswith('TER') or line.startswith('ATOM'):
75
+ pdb_lines.append(line.strip())
76
+ pdb_lines.append('ENDMDL')
77
+ pdb_lines.append('END')
78
+ pdb_lines = [_line.ljust(80) for _line in pdb_lines]
79
+ pdb_str = '\n'.join(pdb_lines) + '\n'
80
+ with open(output_file, 'w') as fo:
81
+ fo.write(pdb_str)
82
+
83
+ if verbose:
84
+ print(f"Merged {len(pdb_files)} PDB files into {output_file} with {model_number} models.")
85
+
86
+
87
+ def split_pdbfile(pdb_file, output_dir=None, suffix='index', verbose=True):
88
+ """Split a PDB file into multiple PDB files in output_dir.
89
+ Assumes that each model is wrapped by 'MODEL' and 'ENDMDL'.
90
+ """
91
+ assert os.path.exists(pdb_file), f"File {pdb_file} does not exist."
92
+ assert suffix == 'index', 'Only support [suffix=index] for now.'
93
+
94
+ if output_dir is not None: # also dump to output_dir
95
+ os.makedirs(output_dir, exist_ok=True)
96
+ base = os.path.splitext(os.path.basename(pdb_file))[0]
97
+
98
+ i = 0
99
+ pdb_strs = []
100
+ pdb_string = ''
101
+ with open(pdb_file, 'r') as fi:
102
+ # pdb_string = ''
103
+ for line in fi:
104
+ if line.startswith('MODEL'):
105
+ pdb_string = ''
106
+ elif line.startswith('ATOM') or line.startswith('TER'):
107
+ pdb_string += line
108
+ elif line.startswith('ENDMDL') or line.startswith('END'):
109
+ if pdb_string == '': continue
110
+ pdb_string += 'END\n'
111
+ if output_dir is not None:
112
+ _save_to = os.path.join(output_dir, f'{base}_{i}.pdb') if suffix == 'index' else None
113
+ with open(_save_to, 'w') as fo:
114
+ fo.write(pdb_string)
115
+ pdb_strs.append(pdb_string)
116
+ pdb_string = ''
117
+ i += 1
118
+ else:
119
+ if verbose:
120
+ print(f"Warning: line '{line}' is not recognized. Skip.")
121
+ if verbose:
122
+ print(f">>> Split pdb {pdb_file} into {i}/{len(pdb_strs)} structures.")
123
+ return pdb_strs
124
+
125
+
126
+ def stratify_sample_pdbfile(input_path, output_path, n_max_sample=1000, end_at=0, verbose=True):
+ """Evenly subsample up to n_max_sample models from a multi-model PDB file and write them to output_path."""
128
+ assert os.path.exists(input_path), f"File {input_path} does not exist."
129
+ assert not os.path.exists(output_path), f"Output path {output_path} already exists."
130
+
131
+ i = 0
132
+ pdb_strs = []
133
+ with open(input_path, 'r') as fi:
134
+ # pdb_string = ''
135
+ pdb_lines_per_model = []
136
+ for line in fi:
137
+ if line.startswith('MODEL'):
138
+ pdb_lines_per_model = []
139
+ elif line.startswith('ATOM') or line.startswith('TER'):
140
+ pdb_lines_per_model.append(line.strip())
141
+ elif line.startswith('ENDMDL') or line.startswith('END'):
142
+ if pdb_lines_per_model == []: continue # skip empty model
143
+ # wrap up the model
144
+ pdb_lines_per_model.append('ENDMDL')
145
+ # Pad all lines to 80 characters.
146
+ pdb_lines_per_model = [_line.ljust(80) for _line in pdb_lines_per_model]
147
+ pdb_str_per_model = '\n'.join(pdb_lines_per_model) + '\n' # Add terminating newline.
148
+ pdb_strs.append(pdb_str_per_model)
149
+ # reset
150
+ pdb_lines_per_model = []
151
+ i += 1
152
+ else:
153
+ if verbose:
154
+ print(f"Warning: line '{line}' is not recognized. Skip.")
155
+ if end_at > 0 and i > end_at:
156
+ break
157
+ end = end_at if end_at > 0 else len(pdb_strs)
158
+
159
+ # sample evenly
160
+ if end > n_max_sample:
161
+ interleave_step = int(end // n_max_sample) # floor
162
+ sampled_pdb_strs = pdb_strs[:end][::interleave_step][:n_max_sample]
163
+ else:
164
+ sampled_pdb_strs = pdb_strs[:end]
165
+
166
+ output_str = ''
167
+ for i, pdb_str in enumerate(sampled_pdb_strs): # renumber models
168
+ output_str += f"MODEL {i+1}".ljust(80) + '\n'
169
+ output_str += pdb_str
170
+ output_str = output_str + ('END'.ljust(80) + '\n')
171
+
172
+ write_pdb_string(output_str, save_to=output_path)
173
+ if verbose:
174
+ print(f">>> Sampled {len(sampled_pdb_strs)}/{n_max_sample} structures from {input_path}.")
175
+ return
176
+
177
+
178
+ def protein_with_default_params(
179
+ atom_positions: np.ndarray,
180
+ atom_mask: np.ndarray,
181
+ aatype: Optional[np.ndarray] = None,
182
+ b_factors: Optional[np.ndarray] = None,
183
+ chain_index: Optional[np.ndarray] = None,
184
+ residue_index: Optional[np.ndarray] = None,
185
+ ):
186
+ assert atom_positions.ndim == 3
187
+ assert atom_positions.shape[-1] == 3
188
+ assert atom_positions.shape[-2] == 37
189
+ n = atom_positions.shape[0]
190
+ sqz = lambda x: np.squeeze(x) if x.shape[0] == 1 and len(x.shape) > 1 else x
191
+
192
+ residue_index = np.arange(n) + 1 if residue_index is None else sqz(residue_index)
193
+ chain_index = np.zeros(n) if chain_index is None else sqz(chain_index)
194
+ b_factors = np.zeros([n, 37]) if b_factors is None else sqz(b_factors)
195
+ aatype = np.zeros(n, dtype=int) if aatype is None else sqz(aatype)
196
+
197
+ return protein.Protein(
198
+ atom_positions=atom_positions,
199
+ atom_mask=atom_mask,
200
+ aatype=aatype,
201
+ residue_index=residue_index,
202
+ chain_index=chain_index,
203
+ b_factors=b_factors
204
+ )
205
+
206
+ def atom37_to_pdb(
207
+ save_to: str,
208
+ atom_positions: np.ndarray,
209
+ aatype: Optional[np.ndarray] = None,
210
+ b_factors: Optional[np.ndarray] = None,
211
+ chain_index: Optional[np.ndarray] = None,
212
+ residue_index: Optional[np.ndarray] = None,
213
+ overwrite: bool = False,
214
+ no_indexing: bool = True,
215
+ ):
216
+ # configure save path
217
+ if overwrite:
218
+ max_existing_idx = 0
219
+ else:
220
+ file_dir = os.path.dirname(save_to)
221
+ file_name = os.path.splitext(os.path.basename(save_to))[0]  # str.strip('.pdb') would trim characters, not the suffix
222
+ existing_files = [x for x in os.listdir(file_dir) if file_name in x]
223
+ max_existing_idx = max([
224
+ int(re.findall(r'_(\d+).pdb', x)[0]) for x in existing_files if re.findall(r'_(\d+).pdb', x)
225
+ if re.findall(r'_(\d+).pdb', x)] + [0])
226
+ if not no_indexing:
+ save_to = save_to.replace('.pdb', '') + f'_{max_existing_idx+1}.pdb'
230
+
231
+ with open(save_to, 'w') as f:
232
+ if atom_positions.ndim == 4:
233
+ for mi, pos37 in enumerate(atom_positions):
234
+ atom_mask = np.sum(np.abs(pos37), axis=-1) > 1e-7
235
+ prot = protein_with_default_params(
236
+ pos37, atom_mask, aatype=aatype, b_factors=b_factors,
237
+ chain_index=chain_index, residue_index=residue_index
238
+ )
239
+ pdb_str = protein.to_pdb(prot, model=mi+1, add_end=False)
240
+ f.write(pdb_str)
241
+ elif atom_positions.ndim == 3:
242
+ atom_mask = np.sum(np.abs(atom_positions), axis=-1) > 1e-7
243
+ prot = protein_with_default_params(
244
+ atom_positions, atom_mask, aatype=aatype, b_factors=b_factors,
245
+ chain_index=chain_index, residue_index=residue_index
246
+ )
247
+ pdb_str = protein.to_pdb(prot, model=1, add_end=False)
248
+ f.write(pdb_str)
249
+ else:
250
+ raise ValueError(f'Invalid positions shape {atom_positions.shape}')
251
+ f.write('END')
252
+
253
+ return save_to
254
+
255
+ def extract_backbone_coords_from_pdb(pdb_path: str, target_atoms: Optional[list] = ["CA"]):
256
+ structure = PDBFile.read(pdb_path)
257
+ structure_list = structure.get_structure()
258
+
259
+ coords_list = []
260
+ for b_idx in range(structure.get_model_count()):
261
+ chain = structure_list[b_idx]
262
+
263
+ backbone_atoms = chain[struct.filter_backbone(chain)] # This includes the “N”, “CA” and “C” atoms of amino acids.
264
+ ret_coords = OrderedDict()
265
+ # init dict
266
+ for k in target_atoms:
267
+ ret_coords[k] = []
268
+
269
+ for c in backbone_atoms:
270
+ if c.atom_name in ret_coords:
271
+ ret_coords[c.atom_name].append(c.coord)
272
+
273
+ ret_coords = [np.vstack(v) for k,v in ret_coords.items()]
274
+ if len(target_atoms) == 1:
275
+ ret_coords = ret_coords[0] # L, 3
276
+ else:
277
+ ret_coords = np.stack(ret_coords, axis=1) # L, na, 3
278
+
279
+ coords_list.append(ret_coords)
280
+
281
+ coords_list = np.stack(coords_list, axis=0) # B, L, na, 3 or B, L, 3 (ca only)
282
+ return coords_list
283
+
284
+
285
+ def extract_backbone_coords_from_pdb_dir(pdb_dir: str):
286
+ return np.concatenate([
287
+ extract_backbone_coords_from_pdb(os.path.join(pdb_dir, f))
288
+ for f in os.listdir(pdb_dir) if f.endswith('.pdb')
289
+ ], axis=0)
290
+
291
+ def extract_backbone_coords_from_npy(npy_path: str):
292
+ return np.load(npy_path)
293
+
294
+
295
+ def extract_backbone_coords(input_path: str,
296
+ max_n_model: Optional[int] = None,
297
+ ):
298
+ """Extract backbone coordinates from PDB file.
299
+
300
+ Args:
301
+ input_path (str): The path to the PDB file.
302
+ ca_only (bool): Whether to extract only CA coordinates.
304
+ """
305
+ assert os.path.exists(input_path), f"File {input_path} does not exist."
306
+ if input_path.endswith('.pdb'):
307
+ coords = extract_backbone_coords_from_pdb(input_path)
308
+ elif input_path.endswith('.npy'):
309
+ coords = extract_backbone_coords_from_npy(input_path)
310
+ elif os.path.isdir(input_path):
311
+ coords = extract_backbone_coords_from_pdb_dir(input_path)
312
+ else:
313
+ raise ValueError(f"Unrecognized input path {input_path}.")
314
+
315
+ if max_n_model is not None and len(coords) > max_n_model > 0:
316
+ coords = coords[:max_n_model]
317
+ return coords
318
+
319
+
320
+
321
+ if __name__ == '__main__':
322
+ import argparse
323
+ def get_argparser():
324
+ parser = argparse.ArgumentParser(description='Main script for pdb processing.')
325
+ parser.add_argument("input", type=str, help="The generic path to sampled pdb directory / pdb file.")
326
+ parser.add_argument("-m", "--mode", type=str, help="The mode of processing.",
327
+ default="split")
328
+ parser.add_argument("-o", "--output", type=str, help="The output directory for processed pdb files.",
329
+ default=None)
330
+
331
+ args = parser.parse_args()
332
+ return args
333
+ args = get_argparser()
334
+
335
+ # ad hoc functions
336
+ def split_pdbs(args):
337
+ os.makedirs(args.output, exist_ok=True)
338
+ _ = split_pdbfile(pdb_file=args.input,
339
+ output_dir=args.output)
340
+
341
+ def merge_pdbs(args):
342
+ output = args.output or f"{args.input}_all.pdb"
343
+ merge_pdbfiles(input=args.input,
344
+ output_file=output)
345
+
346
+ if args.mode == "split":
347
+ split_pdbs(args)
348
+ elif args.mode == "merge":
349
+ merge_pdbs(args)
350
+ elif args.mode == "stratify":
351
+ stratify_sample_pdbfile(input_path=args.input, output_path=args.output)
352
+ else:
353
+ raise ValueError(f"Unrecognized mode {args.mode}.")
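A rough usage sketch for the module above, using the Python API rather than the `__main__` CLI. It assumes the script is run from a directory where `src.common` is importable; all file paths are hypothetical placeholders.

from src.common.pdb_utils import merge_pdbfiles, split_pdbfile, extract_backbone_coords

# Merge every .pdb in a sample directory into one multi-model file, then split it back out.
merge_pdbfiles("samples/", "merged/all.pdb")
split_pdbfile("merged/all.pdb", output_dir="split/")

# CA coordinates of every model, shape [n_models, L, 3].
ca_coords = extract_backbone_coords("merged/all.pdb")
print(ca_coords.shape)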
analysis/src/common/protein.py ADDED
@@ -0,0 +1,289 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Protein data type."""
16
+ import dataclasses
17
+ import io
18
+ from typing import Any, Mapping, Optional
19
+
20
+ from Bio.PDB import PDBParser
21
+ import numpy as np
22
+
23
+ from src.common import residue_constants
24
+
25
+
26
+ FeatureDict = Mapping[str, np.ndarray]
27
+ ModelOutput = Mapping[str, Any] # Is a nested dict.
28
+
29
+ # Complete sequence of chain IDs supported by the PDB format.
30
+ PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
31
+ PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62.
32
+
33
+
34
+ @dataclasses.dataclass(frozen=True)
35
+ class Protein:
36
+ """Protein structure representation."""
37
+
38
+ # Cartesian coordinates of atoms in angstroms. The atom types correspond to
39
+ # residue_constants.atom_types, i.e. the first three are N, CA, C.
40
+ atom_positions: np.ndarray # [num_res, num_atom_type, 3]
41
+
42
+ # Amino-acid type for each residue represented as an integer between 0 and
43
+ # 20, where 20 is 'X'.
44
+ aatype: np.ndarray # [num_res]
45
+
46
+ # Binary float mask to indicate presence of a particular atom. 1.0 if an atom
47
+ # is present and 0.0 if not. This should be used for loss masking.
48
+ atom_mask: np.ndarray # [num_res, num_atom_type]
49
+
50
+ # Residue index as used in PDB. It is not necessarily continuous or 0-indexed.
51
+ residue_index: np.ndarray # [num_res]
52
+
53
+ # 0-indexed number corresponding to the chain in the protein that this residue
54
+ # belongs to.
55
+ chain_index: np.ndarray # [num_res]
56
+
57
+ # B-factors, or temperature factors, of each residue (in sq. angstroms units),
58
+ # representing the displacement of the residue from its ground truth mean
59
+ # value.
60
+ b_factors: np.ndarray # [num_res, num_atom_type]
61
+
62
+ def __post_init__(self):
63
+ if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS:
64
+ raise ValueError(
65
+ f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains '
66
+ 'because these cannot be written to PDB format.')
67
+
68
+ def to_dict(self):
69
+ return dataclasses.asdict(self)
70
+
71
+
72
+ def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein:
73
+ """Takes a PDB string and constructs a Protein object.
74
+
75
+ WARNING: All non-standard residue types will be converted into UNK. All
76
+ non-standard atoms will be ignored.
77
+
78
+ Args:
79
+ pdb_str: The contents of the pdb file
80
+ chain_id: If chain_id is specified (e.g. A), then only that chain
81
+ is parsed. Otherwise all chains are parsed.
82
+
83
+ Returns:
84
+ A new `Protein` parsed from the pdb contents.
85
+ """
86
+ pdb_fh = io.StringIO(pdb_str)
87
+ parser = PDBParser(QUIET=True)
88
+ structure = parser.get_structure('none', pdb_fh)
89
+ models = list(structure.get_models())
90
+ if len(models) != 1:
91
+ raise ValueError(
92
+ f'Only single model PDBs are supported. Found {len(models)} models.')
93
+ model = models[0]
94
+
95
+ atom_positions = []
96
+ aatype = []
97
+ atom_mask = []
98
+ residue_index = []
99
+ chain_ids = []
100
+ b_factors = []
101
+
102
+ for chain in model:
103
+ if chain_id is not None and chain.id != chain_id:
104
+ continue
105
+ for res in chain:
106
+ if res.id[2] != ' ':
107
+ raise ValueError(
108
+ f'PDB contains an insertion code at chain {chain.id} and residue '
109
+ f'index {res.id[1]}. These are not supported.')
110
+ res_shortname = residue_constants.restype_3to1.get(res.resname, 'X')
111
+ restype_idx = residue_constants.restype_order.get(
112
+ res_shortname, residue_constants.restype_num)
113
+ pos = np.zeros((residue_constants.atom_type_num, 3))
114
+ mask = np.zeros((residue_constants.atom_type_num,))
115
+ res_b_factors = np.zeros((residue_constants.atom_type_num,))
116
+ for atom in res:
117
+ if atom.name not in residue_constants.atom_types:
118
+ continue
119
+ pos[residue_constants.atom_order[atom.name]] = atom.coord
120
+ mask[residue_constants.atom_order[atom.name]] = 1.
121
+ res_b_factors[residue_constants.atom_order[atom.name]] = atom.bfactor
122
+ if np.sum(mask) < 0.5:
123
+ # If no known atom positions are reported for the residue then skip it.
124
+ continue
125
+ aatype.append(restype_idx)
126
+ atom_positions.append(pos)
127
+ atom_mask.append(mask)
128
+ residue_index.append(res.id[1])
129
+ chain_ids.append(chain.id)
130
+ b_factors.append(res_b_factors)
131
+
132
+ # Chain IDs are usually characters so map these to ints.
133
+ unique_chain_ids = np.unique(chain_ids)
134
+ chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)}
135
+ chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids])
136
+
137
+ return Protein(
138
+ atom_positions=np.array(atom_positions),
139
+ atom_mask=np.array(atom_mask),
140
+ aatype=np.array(aatype),
141
+ residue_index=np.array(residue_index),
142
+ chain_index=chain_index,
143
+ b_factors=np.array(b_factors))
144
+
145
+
146
+ def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
147
+ chain_end = 'TER'
148
+ return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
149
+ f'{chain_name:>1}{residue_index:>4}')
150
+
151
+
152
+ def to_pdb(prot: Protein, model=1, add_end=True) -> str:
153
+ """Converts a `Protein` instance to a PDB string.
154
+
155
+ Args:
156
+ prot: The protein to convert to PDB.
+ model: The model number to write in the MODEL record.
+ add_end: Whether to append an END record.
157
+
158
+ Returns:
159
+ PDB string.
160
+ """
161
+ restypes = residue_constants.restypes + ['X']
162
+ res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK')
163
+ atom_types = residue_constants.atom_types
164
+
165
+ pdb_lines = []
166
+
167
+ atom_mask = prot.atom_mask
168
+ aatype = prot.aatype
169
+ atom_positions = prot.atom_positions
170
+ residue_index = prot.residue_index.astype(int)
171
+ chain_index = prot.chain_index.astype(int)
172
+ b_factors = prot.b_factors
173
+
174
+ if np.any(aatype > residue_constants.restype_num):
175
+ raise ValueError('Invalid aatypes.')
176
+
177
+ # Construct a mapping from chain integer indices to chain ID strings.
178
+ chain_ids = {}
179
+ for i in np.unique(chain_index): # np.unique gives sorted output.
180
+ if i >= PDB_MAX_CHAINS:
181
+ raise ValueError(
182
+ f'The PDB format supports at most {PDB_MAX_CHAINS} chains.')
183
+ chain_ids[i] = PDB_CHAIN_IDS[i]
184
+
185
+ pdb_lines.append(f'MODEL {model}')
186
+ atom_index = 1
187
+ last_chain_index = chain_index[0]
188
+ # Add all atom sites.
189
+ for i in range(aatype.shape[0]):
190
+ # Close the previous chain if in a multichain PDB.
191
+ if last_chain_index != chain_index[i]:
192
+ pdb_lines.append(_chain_end(
193
+ atom_index, res_1to3(aatype[i - 1]), chain_ids[chain_index[i - 1]],
194
+ residue_index[i - 1]))
195
+ last_chain_index = chain_index[i]
196
+ atom_index += 1 # Atom index increases at the TER symbol.
197
+
198
+ res_name_3 = res_1to3(aatype[i])
199
+ for atom_name, pos, mask, b_factor in zip(
200
+ atom_types, atom_positions[i], atom_mask[i], b_factors[i]):
201
+ if mask < 0.5:
202
+ continue
203
+
204
+ # skip CB for GLY
205
+ if res_name_3 == 'GLY' and atom_name == 'CB':
206
+ continue
207
+
208
+ record_type = 'ATOM'
209
+ name = atom_name if len(atom_name) == 4 else f' {atom_name}'
210
+ alt_loc = ''
211
+ insertion_code = ''
212
+ occupancy = 1.00
213
+ element = atom_name[0] # Protein supports only C, N, O, S, this works.
214
+ charge = ''
215
+ # PDB is a columnar format, every space matters here!
216
+ atom_line = (f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}'
217
+ f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}'
218
+ f'{residue_index[i]:>4}{insertion_code:>1} '
219
+ f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}'
220
+ f'{occupancy:>6.2f}{b_factor:>6.2f} '
221
+ f'{element:>2}{charge:>2}')
222
+ pdb_lines.append(atom_line)
223
+ atom_index += 1
224
+
225
+ # Close the final chain.
226
+ pdb_lines.append(_chain_end(atom_index, res_1to3(aatype[-1]),
227
+ chain_ids[chain_index[-1]], residue_index[-1]))
228
+ pdb_lines.append('ENDMDL')
229
+ if add_end:
230
+ pdb_lines.append('END')
231
+
232
+ # Pad all lines to 80 characters.
233
+ pdb_lines = [line.ljust(80) for line in pdb_lines]
234
+ return '\n'.join(pdb_lines) + '\n' # Add terminating newline.
235
+
236
+
237
+ def ideal_atom_mask(prot: Protein) -> np.ndarray:
238
+ """Computes an ideal atom mask.
239
+
240
+ `Protein.atom_mask` typically is defined according to the atoms that are
241
+ reported in the PDB. This function computes a mask according to heavy atoms
242
+ that should be present in the given sequence of amino acids.
243
+
244
+ Args:
245
+ prot: `Protein` whose fields are `numpy.ndarray` objects.
246
+
247
+ Returns:
248
+ An ideal atom mask.
249
+ """
250
+ return residue_constants.STANDARD_ATOM_MASK[prot.aatype]
251
+
252
+
253
+ def from_prediction(
254
+ features: FeatureDict,
255
+ result: ModelOutput,
256
+ b_factors: Optional[np.ndarray] = None,
257
+ remove_leading_feature_dimension: bool = True) -> Protein:
258
+ """Assembles a protein from a prediction.
259
+
260
+ Args:
261
+ features: Dictionary holding model inputs.
262
+ result: Dictionary holding model outputs.
263
+ b_factors: (Optional) B-factors to use for the protein.
264
+ remove_leading_feature_dimension: Whether to remove the leading dimension
265
+ of the `features` values.
266
+
267
+ Returns:
268
+ A protein instance.
269
+ """
270
+ fold_output = result['structure_module']
271
+
272
+ def _maybe_remove_leading_dim(arr: np.ndarray) -> np.ndarray:
273
+ return arr[0] if remove_leading_feature_dimension else arr
274
+
275
+ if 'asym_id' in features:
276
+ chain_index = _maybe_remove_leading_dim(features['asym_id'])
277
+ else:
278
+ chain_index = np.zeros_like(_maybe_remove_leading_dim(features['aatype']))
279
+
280
+ if b_factors is None:
281
+ b_factors = np.zeros_like(fold_output['final_atom_mask'])
282
+
283
+ return Protein(
284
+ aatype=_maybe_remove_leading_dim(features['aatype']),
285
+ atom_positions=fold_output['final_atom_positions'],
286
+ atom_mask=fold_output['final_atom_mask'],
287
+ residue_index=_maybe_remove_leading_dim(features['residue_index']) + 1,
288
+ chain_index=chain_index,
289
+ b_factors=b_factors)
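A minimal round-trip sketch for the `Protein` container above. It assumes `analysis/src` is importable and that `input.pdb` is a placeholder for a single-model PDB file.

from src.common import protein

with open("input.pdb") as fh:                      # hypothetical single-model PDB
    prot = protein.from_pdb_string(fh.read())

print(prot.atom_positions.shape)                   # (num_res, 37, 3)
print(prot.aatype.shape, prot.atom_mask.shape)     # (num_res,), (num_res, 37)

# Write it back out as MODEL 1 of a new file.
with open("roundtrip.pdb", "w") as fh:
    fh.write(protein.to_pdb(prot, model=1, add_end=True))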
analysis/src/common/residue_constants.py ADDED
@@ -0,0 +1,897 @@
1
+ # Copyright 2021 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Constants used in AlphaFold."""
16
+
17
+ import collections
18
+ import functools
19
+ import os
20
+ from typing import List, Mapping, Tuple
21
+
22
+ import numpy as np
23
+ import tree
24
+
25
+ # Internal import (35fd).
26
+
27
+
28
+ # Distance from one CA to next CA [trans configuration: omega = 180].
29
+ ca_ca = 3.80209737096
30
+
31
+ # Format: The list for each AA type contains chi1, chi2, chi3, chi4 in
32
+ # this order (or a relevant subset from chi1 onwards). ALA and GLY don't have
33
+ # chi angles so their chi angle lists are empty.
34
+ chi_angles_atoms = {
35
+ 'ALA': [],
36
+ # Chi5 in arginine is always 0 +- 5 degrees, so ignore it.
37
+ 'ARG': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
38
+ ['CB', 'CG', 'CD', 'NE'], ['CG', 'CD', 'NE', 'CZ']],
39
+ 'ASN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
40
+ 'ASP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'OD1']],
41
+ 'CYS': [['N', 'CA', 'CB', 'SG']],
42
+ 'GLN': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
43
+ ['CB', 'CG', 'CD', 'OE1']],
44
+ 'GLU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
45
+ ['CB', 'CG', 'CD', 'OE1']],
46
+ 'GLY': [],
47
+ 'HIS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'ND1']],
48
+ 'ILE': [['N', 'CA', 'CB', 'CG1'], ['CA', 'CB', 'CG1', 'CD1']],
49
+ 'LEU': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
50
+ 'LYS': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD'],
51
+ ['CB', 'CG', 'CD', 'CE'], ['CG', 'CD', 'CE', 'NZ']],
52
+ 'MET': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'SD'],
53
+ ['CB', 'CG', 'SD', 'CE']],
54
+ 'PHE': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
55
+ 'PRO': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD']],
56
+ 'SER': [['N', 'CA', 'CB', 'OG']],
57
+ 'THR': [['N', 'CA', 'CB', 'OG1']],
58
+ 'TRP': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
59
+ 'TYR': [['N', 'CA', 'CB', 'CG'], ['CA', 'CB', 'CG', 'CD1']],
60
+ 'VAL': [['N', 'CA', 'CB', 'CG1']],
61
+ }
62
+
63
+ # If chi angles given in fixed-length array, this matrix determines how to mask
64
+ # them for each AA type. The order is as per restype_order (see below).
65
+ chi_angles_mask = [
66
+ [0.0, 0.0, 0.0, 0.0], # ALA
67
+ [1.0, 1.0, 1.0, 1.0], # ARG
68
+ [1.0, 1.0, 0.0, 0.0], # ASN
69
+ [1.0, 1.0, 0.0, 0.0], # ASP
70
+ [1.0, 0.0, 0.0, 0.0], # CYS
71
+ [1.0, 1.0, 1.0, 0.0], # GLN
72
+ [1.0, 1.0, 1.0, 0.0], # GLU
73
+ [0.0, 0.0, 0.0, 0.0], # GLY
74
+ [1.0, 1.0, 0.0, 0.0], # HIS
75
+ [1.0, 1.0, 0.0, 0.0], # ILE
76
+ [1.0, 1.0, 0.0, 0.0], # LEU
77
+ [1.0, 1.0, 1.0, 1.0], # LYS
78
+ [1.0, 1.0, 1.0, 0.0], # MET
79
+ [1.0, 1.0, 0.0, 0.0], # PHE
80
+ [1.0, 1.0, 0.0, 0.0], # PRO
81
+ [1.0, 0.0, 0.0, 0.0], # SER
82
+ [1.0, 0.0, 0.0, 0.0], # THR
83
+ [1.0, 1.0, 0.0, 0.0], # TRP
84
+ [1.0, 1.0, 0.0, 0.0], # TYR
85
+ [1.0, 0.0, 0.0, 0.0], # VAL
86
+ ]
87
+
88
+ # The following chi angles are pi periodic: they can be rotated by a multiple
89
+ # of pi without affecting the structure.
90
+ chi_pi_periodic = [
91
+ [0.0, 0.0, 0.0, 0.0], # ALA
92
+ [0.0, 0.0, 0.0, 0.0], # ARG
93
+ [0.0, 0.0, 0.0, 0.0], # ASN
94
+ [0.0, 1.0, 0.0, 0.0], # ASP
95
+ [0.0, 0.0, 0.0, 0.0], # CYS
96
+ [0.0, 0.0, 0.0, 0.0], # GLN
97
+ [0.0, 0.0, 1.0, 0.0], # GLU
98
+ [0.0, 0.0, 0.0, 0.0], # GLY
99
+ [0.0, 0.0, 0.0, 0.0], # HIS
100
+ [0.0, 0.0, 0.0, 0.0], # ILE
101
+ [0.0, 0.0, 0.0, 0.0], # LEU
102
+ [0.0, 0.0, 0.0, 0.0], # LYS
103
+ [0.0, 0.0, 0.0, 0.0], # MET
104
+ [0.0, 1.0, 0.0, 0.0], # PHE
105
+ [0.0, 0.0, 0.0, 0.0], # PRO
106
+ [0.0, 0.0, 0.0, 0.0], # SER
107
+ [0.0, 0.0, 0.0, 0.0], # THR
108
+ [0.0, 0.0, 0.0, 0.0], # TRP
109
+ [0.0, 1.0, 0.0, 0.0], # TYR
110
+ [0.0, 0.0, 0.0, 0.0], # VAL
111
+ [0.0, 0.0, 0.0, 0.0], # UNK
112
+ ]
113
+
114
+ # Atoms positions relative to the 8 rigid groups, defined by the pre-omega, phi,
115
+ # psi and chi angles:
116
+ # 0: 'backbone group',
117
+ # 1: 'pre-omega-group', (empty)
118
+ # 2: 'phi-group', (currently empty, because it defines only hydrogens)
119
+ # 3: 'psi-group',
120
+ # 4,5,6,7: 'chi1,2,3,4-group'
121
+ # The atom positions are relative to the axis-end-atom of the corresponding
122
+ # rotation axis. The x-axis is in direction of the rotation axis, and the y-axis
123
+ # is defined such that the dihedral-angle-defining atom (the last entry in
124
+ # chi_angles_atoms above) is in the xy-plane (with a positive y-coordinate).
125
+ # format: [atomname, group_idx, rel_position]
126
+ rigid_group_atom_positions = {
127
+ 'ALA': [
128
+ ['N', 0, (-0.525, 1.363, 0.000)],
129
+ ['CA', 0, (0.000, 0.000, 0.000)],
130
+ ['C', 0, (1.526, -0.000, -0.000)],
131
+ ['CB', 0, (-0.529, -0.774, -1.205)],
132
+ ['O', 3, (0.627, 1.062, 0.000)],
133
+ ],
134
+ 'ARG': [
135
+ ['N', 0, (-0.524, 1.362, -0.000)],
136
+ ['CA', 0, (0.000, 0.000, 0.000)],
137
+ ['C', 0, (1.525, -0.000, -0.000)],
138
+ ['CB', 0, (-0.524, -0.778, -1.209)],
139
+ ['O', 3, (0.626, 1.062, 0.000)],
140
+ ['CG', 4, (0.616, 1.390, -0.000)],
141
+ ['CD', 5, (0.564, 1.414, 0.000)],
142
+ ['NE', 6, (0.539, 1.357, -0.000)],
143
+ ['NH1', 7, (0.206, 2.301, 0.000)],
144
+ ['NH2', 7, (2.078, 0.978, -0.000)],
145
+ ['CZ', 7, (0.758, 1.093, -0.000)],
146
+ ],
147
+ 'ASN': [
148
+ ['N', 0, (-0.536, 1.357, 0.000)],
149
+ ['CA', 0, (0.000, 0.000, 0.000)],
150
+ ['C', 0, (1.526, -0.000, -0.000)],
151
+ ['CB', 0, (-0.531, -0.787, -1.200)],
152
+ ['O', 3, (0.625, 1.062, 0.000)],
153
+ ['CG', 4, (0.584, 1.399, 0.000)],
154
+ ['ND2', 5, (0.593, -1.188, 0.001)],
155
+ ['OD1', 5, (0.633, 1.059, 0.000)],
156
+ ],
157
+ 'ASP': [
158
+ ['N', 0, (-0.525, 1.362, -0.000)],
159
+ ['CA', 0, (0.000, 0.000, 0.000)],
160
+ ['C', 0, (1.527, 0.000, -0.000)],
161
+ ['CB', 0, (-0.526, -0.778, -1.208)],
162
+ ['O', 3, (0.626, 1.062, -0.000)],
163
+ ['CG', 4, (0.593, 1.398, -0.000)],
164
+ ['OD1', 5, (0.610, 1.091, 0.000)],
165
+ ['OD2', 5, (0.592, -1.101, -0.003)],
166
+ ],
167
+ 'CYS': [
168
+ ['N', 0, (-0.522, 1.362, -0.000)],
169
+ ['CA', 0, (0.000, 0.000, 0.000)],
170
+ ['C', 0, (1.524, 0.000, 0.000)],
171
+ ['CB', 0, (-0.519, -0.773, -1.212)],
172
+ ['O', 3, (0.625, 1.062, -0.000)],
173
+ ['SG', 4, (0.728, 1.653, 0.000)],
174
+ ],
175
+ 'GLN': [
176
+ ['N', 0, (-0.526, 1.361, -0.000)],
177
+ ['CA', 0, (0.000, 0.000, 0.000)],
178
+ ['C', 0, (1.526, 0.000, 0.000)],
179
+ ['CB', 0, (-0.525, -0.779, -1.207)],
180
+ ['O', 3, (0.626, 1.062, -0.000)],
181
+ ['CG', 4, (0.615, 1.393, 0.000)],
182
+ ['CD', 5, (0.587, 1.399, -0.000)],
183
+ ['NE2', 6, (0.593, -1.189, -0.001)],
184
+ ['OE1', 6, (0.634, 1.060, 0.000)],
185
+ ],
186
+ 'GLU': [
187
+ ['N', 0, (-0.528, 1.361, 0.000)],
188
+ ['CA', 0, (0.000, 0.000, 0.000)],
189
+ ['C', 0, (1.526, -0.000, -0.000)],
190
+ ['CB', 0, (-0.526, -0.781, -1.207)],
191
+ ['O', 3, (0.626, 1.062, 0.000)],
192
+ ['CG', 4, (0.615, 1.392, 0.000)],
193
+ ['CD', 5, (0.600, 1.397, 0.000)],
194
+ ['OE1', 6, (0.607, 1.095, -0.000)],
195
+ ['OE2', 6, (0.589, -1.104, -0.001)],
196
+ ],
197
+ 'GLY': [
198
+ ['N', 0, (-0.572, 1.337, 0.000)],
199
+ ['CA', 0, (0.000, 0.000, 0.000)],
200
+ ['C', 0, (1.517, -0.000, -0.000)],
201
+ ['O', 3, (0.626, 1.062, -0.000)],
202
+ ],
203
+ 'HIS': [
204
+ ['N', 0, (-0.527, 1.360, 0.000)],
205
+ ['CA', 0, (0.000, 0.000, 0.000)],
206
+ ['C', 0, (1.525, 0.000, 0.000)],
207
+ ['CB', 0, (-0.525, -0.778, -1.208)],
208
+ ['O', 3, (0.625, 1.063, 0.000)],
209
+ ['CG', 4, (0.600, 1.370, -0.000)],
210
+ ['CD2', 5, (0.889, -1.021, 0.003)],
211
+ ['ND1', 5, (0.744, 1.160, -0.000)],
212
+ ['CE1', 5, (2.030, 0.851, 0.002)],
213
+ ['NE2', 5, (2.145, -0.466, 0.004)],
214
+ ],
215
+ 'ILE': [
216
+ ['N', 0, (-0.493, 1.373, -0.000)],
217
+ ['CA', 0, (0.000, 0.000, 0.000)],
218
+ ['C', 0, (1.527, -0.000, -0.000)],
219
+ ['CB', 0, (-0.536, -0.793, -1.213)],
220
+ ['O', 3, (0.627, 1.062, -0.000)],
221
+ ['CG1', 4, (0.534, 1.437, -0.000)],
222
+ ['CG2', 4, (0.540, -0.785, -1.199)],
223
+ ['CD1', 5, (0.619, 1.391, 0.000)],
224
+ ],
225
+ 'LEU': [
226
+ ['N', 0, (-0.520, 1.363, 0.000)],
227
+ ['CA', 0, (0.000, 0.000, 0.000)],
228
+ ['C', 0, (1.525, -0.000, -0.000)],
229
+ ['CB', 0, (-0.522, -0.773, -1.214)],
230
+ ['O', 3, (0.625, 1.063, -0.000)],
231
+ ['CG', 4, (0.678, 1.371, 0.000)],
232
+ ['CD1', 5, (0.530, 1.430, -0.000)],
233
+ ['CD2', 5, (0.535, -0.774, 1.200)],
234
+ ],
235
+ 'LYS': [
236
+ ['N', 0, (-0.526, 1.362, -0.000)],
237
+ ['CA', 0, (0.000, 0.000, 0.000)],
238
+ ['C', 0, (1.526, 0.000, 0.000)],
239
+ ['CB', 0, (-0.524, -0.778, -1.208)],
240
+ ['O', 3, (0.626, 1.062, -0.000)],
241
+ ['CG', 4, (0.619, 1.390, 0.000)],
242
+ ['CD', 5, (0.559, 1.417, 0.000)],
243
+ ['CE', 6, (0.560, 1.416, 0.000)],
244
+ ['NZ', 7, (0.554, 1.387, 0.000)],
245
+ ],
246
+ 'MET': [
247
+ ['N', 0, (-0.521, 1.364, -0.000)],
248
+ ['CA', 0, (0.000, 0.000, 0.000)],
249
+ ['C', 0, (1.525, 0.000, 0.000)],
250
+ ['CB', 0, (-0.523, -0.776, -1.210)],
251
+ ['O', 3, (0.625, 1.062, -0.000)],
252
+ ['CG', 4, (0.613, 1.391, -0.000)],
253
+ ['SD', 5, (0.703, 1.695, 0.000)],
254
+ ['CE', 6, (0.320, 1.786, -0.000)],
255
+ ],
256
+ 'PHE': [
257
+ ['N', 0, (-0.518, 1.363, 0.000)],
258
+ ['CA', 0, (0.000, 0.000, 0.000)],
259
+ ['C', 0, (1.524, 0.000, -0.000)],
260
+ ['CB', 0, (-0.525, -0.776, -1.212)],
261
+ ['O', 3, (0.626, 1.062, -0.000)],
262
+ ['CG', 4, (0.607, 1.377, 0.000)],
263
+ ['CD1', 5, (0.709, 1.195, -0.000)],
264
+ ['CD2', 5, (0.706, -1.196, 0.000)],
265
+ ['CE1', 5, (2.102, 1.198, -0.000)],
266
+ ['CE2', 5, (2.098, -1.201, -0.000)],
267
+ ['CZ', 5, (2.794, -0.003, -0.001)],
268
+ ],
269
+ 'PRO': [
270
+ ['N', 0, (-0.566, 1.351, -0.000)],
271
+ ['CA', 0, (0.000, 0.000, 0.000)],
272
+ ['C', 0, (1.527, -0.000, 0.000)],
273
+ ['CB', 0, (-0.546, -0.611, -1.293)],
274
+ ['O', 3, (0.621, 1.066, 0.000)],
275
+ ['CG', 4, (0.382, 1.445, 0.0)],
276
+ # ['CD', 5, (0.427, 1.440, 0.0)],
277
+ ['CD', 5, (0.477, 1.424, 0.0)], # manually made angle 2 degrees larger
278
+ ],
279
+ 'SER': [
280
+ ['N', 0, (-0.529, 1.360, -0.000)],
281
+ ['CA', 0, (0.000, 0.000, 0.000)],
282
+ ['C', 0, (1.525, -0.000, -0.000)],
283
+ ['CB', 0, (-0.518, -0.777, -1.211)],
284
+ ['O', 3, (0.626, 1.062, -0.000)],
285
+ ['OG', 4, (0.503, 1.325, 0.000)],
286
+ ],
287
+ 'THR': [
288
+ ['N', 0, (-0.517, 1.364, 0.000)],
289
+ ['CA', 0, (0.000, 0.000, 0.000)],
290
+ ['C', 0, (1.526, 0.000, -0.000)],
291
+ ['CB', 0, (-0.516, -0.793, -1.215)],
292
+ ['O', 3, (0.626, 1.062, 0.000)],
293
+ ['CG2', 4, (0.550, -0.718, -1.228)],
294
+ ['OG1', 4, (0.472, 1.353, 0.000)],
295
+ ],
296
+ 'TRP': [
297
+ ['N', 0, (-0.521, 1.363, 0.000)],
298
+ ['CA', 0, (0.000, 0.000, 0.000)],
299
+ ['C', 0, (1.525, -0.000, 0.000)],
300
+ ['CB', 0, (-0.523, -0.776, -1.212)],
301
+ ['O', 3, (0.627, 1.062, 0.000)],
302
+ ['CG', 4, (0.609, 1.370, -0.000)],
303
+ ['CD1', 5, (0.824, 1.091, 0.000)],
304
+ ['CD2', 5, (0.854, -1.148, -0.005)],
305
+ ['CE2', 5, (2.186, -0.678, -0.007)],
306
+ ['CE3', 5, (0.622, -2.530, -0.007)],
307
+ ['NE1', 5, (2.140, 0.690, -0.004)],
308
+ ['CH2', 5, (3.028, -2.890, -0.013)],
309
+ ['CZ2', 5, (3.283, -1.543, -0.011)],
310
+ ['CZ3', 5, (1.715, -3.389, -0.011)],
311
+ ],
312
+ 'TYR': [
313
+ ['N', 0, (-0.522, 1.362, 0.000)],
314
+ ['CA', 0, (0.000, 0.000, 0.000)],
315
+ ['C', 0, (1.524, -0.000, -0.000)],
316
+ ['CB', 0, (-0.522, -0.776, -1.213)],
317
+ ['O', 3, (0.627, 1.062, -0.000)],
318
+ ['CG', 4, (0.607, 1.382, -0.000)],
319
+ ['CD1', 5, (0.716, 1.195, -0.000)],
320
+ ['CD2', 5, (0.713, -1.194, -0.001)],
321
+ ['CE1', 5, (2.107, 1.200, -0.002)],
322
+ ['CE2', 5, (2.104, -1.201, -0.003)],
323
+ ['OH', 5, (4.168, -0.002, -0.005)],
324
+ ['CZ', 5, (2.791, -0.001, -0.003)],
325
+ ],
326
+ 'VAL': [
327
+ ['N', 0, (-0.494, 1.373, -0.000)],
328
+ ['CA', 0, (0.000, 0.000, 0.000)],
329
+ ['C', 0, (1.527, -0.000, -0.000)],
330
+ ['CB', 0, (-0.533, -0.795, -1.213)],
331
+ ['O', 3, (0.627, 1.062, -0.000)],
332
+ ['CG1', 4, (0.540, 1.429, -0.000)],
333
+ ['CG2', 4, (0.533, -0.776, 1.203)],
334
+ ],
335
+ }
336
+
337
+ # A list of atoms (excluding hydrogen) for each AA type. PDB naming convention.
338
+ residue_atoms = {
339
+ 'ALA': ['C', 'CA', 'CB', 'N', 'O'],
340
+ 'ARG': ['C', 'CA', 'CB', 'CG', 'CD', 'CZ', 'N', 'NE', 'O', 'NH1', 'NH2'],
341
+ 'ASP': ['C', 'CA', 'CB', 'CG', 'N', 'O', 'OD1', 'OD2'],
342
+ 'ASN': ['C', 'CA', 'CB', 'CG', 'N', 'ND2', 'O', 'OD1'],
343
+ 'CYS': ['C', 'CA', 'CB', 'N', 'O', 'SG'],
344
+ 'GLU': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O', 'OE1', 'OE2'],
345
+ 'GLN': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'NE2', 'O', 'OE1'],
346
+ 'GLY': ['C', 'CA', 'N', 'O'],
347
+ 'HIS': ['C', 'CA', 'CB', 'CG', 'CD2', 'CE1', 'N', 'ND1', 'NE2', 'O'],
348
+ 'ILE': ['C', 'CA', 'CB', 'CG1', 'CG2', 'CD1', 'N', 'O'],
349
+ 'LEU': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'N', 'O'],
350
+ 'LYS': ['C', 'CA', 'CB', 'CG', 'CD', 'CE', 'N', 'NZ', 'O'],
351
+ 'MET': ['C', 'CA', 'CB', 'CG', 'CE', 'N', 'O', 'SD'],
352
+ 'PHE': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O'],
353
+ 'PRO': ['C', 'CA', 'CB', 'CG', 'CD', 'N', 'O'],
354
+ 'SER': ['C', 'CA', 'CB', 'N', 'O', 'OG'],
355
+ 'THR': ['C', 'CA', 'CB', 'CG2', 'N', 'O', 'OG1'],
356
+ 'TRP': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'CZ2', 'CZ3',
357
+ 'CH2', 'N', 'NE1', 'O'],
358
+ 'TYR': ['C', 'CA', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'N', 'O',
359
+ 'OH'],
360
+ 'VAL': ['C', 'CA', 'CB', 'CG1', 'CG2', 'N', 'O']
361
+ }
362
+
363
+ # Naming swaps for ambiguous atom names.
364
+ # Due to symmetries in the amino acids the naming of atoms is ambiguous in
365
+ # 4 of the 20 amino acids.
366
+ # (The LDDT paper lists 7 amino acids as ambiguous, but the naming ambiguities
367
+ # in LEU, VAL and ARG can be resolved by using the 3d constellations of
368
+ # the 'ambiguous' atoms and their neighbours)
369
+ residue_atom_renaming_swaps = {
370
+ 'ASP': {'OD1': 'OD2'},
371
+ 'GLU': {'OE1': 'OE2'},
372
+ 'PHE': {'CD1': 'CD2', 'CE1': 'CE2'},
373
+ 'TYR': {'CD1': 'CD2', 'CE1': 'CE2'},
374
+ }
375
+
376
+ # Van der Waals radii [Angstroem] of the atoms (from Wikipedia)
377
+ van_der_waals_radius = {
378
+ 'C': 1.7,
379
+ 'N': 1.55,
380
+ 'O': 1.52,
381
+ 'S': 1.8,
382
+ }
383
+
384
+ Bond = collections.namedtuple(
385
+ 'Bond', ['atom1_name', 'atom2_name', 'length', 'stddev'])
386
+ BondAngle = collections.namedtuple(
387
+ 'BondAngle',
388
+ ['atom1_name', 'atom2_name', 'atom3name', 'angle_rad', 'stddev'])
389
+
390
+
391
+ @functools.lru_cache(maxsize=None)
392
+ def load_stereo_chemical_props() -> Tuple[Mapping[str, List[Bond]],
393
+ Mapping[str, List[Bond]],
394
+ Mapping[str, List[BondAngle]]]:
395
+ """Load stereo_chemical_props.txt into a nice structure.
396
+
397
+ Load literature values for bond lengths and bond angles and translate
398
+ bond angles into the length of the opposite edge of the triangle
399
+ ("residue_virtual_bonds").
400
+
401
+ Returns:
402
+ residue_bonds: Dict that maps resname -> list of Bond tuples.
403
+ residue_virtual_bonds: Dict that maps resname -> list of Bond tuples.
404
+ residue_bond_angles: Dict that maps resname -> list of BondAngle tuples.
405
+ """
406
+ stereo_chemical_props_path = os.path.join(
407
+ os.path.dirname(os.path.abspath(__file__)), 'stereo_chemical_props.txt'
408
+ )
409
+ with open(stereo_chemical_props_path, 'rt') as f:
410
+ stereo_chemical_props = f.read()
411
+ lines_iter = iter(stereo_chemical_props.splitlines())
412
+ # Load bond lengths.
413
+ residue_bonds = {}
414
+ next(lines_iter) # Skip header line.
415
+ for line in lines_iter:
416
+ if line.strip() == '-':
417
+ break
418
+ bond, resname, length, stddev = line.split()
419
+ atom1, atom2 = bond.split('-')
420
+ if resname not in residue_bonds:
421
+ residue_bonds[resname] = []
422
+ residue_bonds[resname].append(
423
+ Bond(atom1, atom2, float(length), float(stddev)))
424
+ residue_bonds['UNK'] = []
425
+
426
+ # Load bond angles.
427
+ residue_bond_angles = {}
428
+ next(lines_iter) # Skip empty line.
429
+ next(lines_iter) # Skip header line.
430
+ for line in lines_iter:
431
+ if line.strip() == '-':
432
+ break
433
+ bond, resname, angle_degree, stddev_degree = line.split()
434
+ atom1, atom2, atom3 = bond.split('-')
435
+ if resname not in residue_bond_angles:
436
+ residue_bond_angles[resname] = []
437
+ residue_bond_angles[resname].append(
438
+ BondAngle(atom1, atom2, atom3,
439
+ float(angle_degree) / 180. * np.pi,
440
+ float(stddev_degree) / 180. * np.pi))
441
+ residue_bond_angles['UNK'] = []
442
+
443
+ def make_bond_key(atom1_name, atom2_name):
444
+ """Unique key to lookup bonds."""
445
+ return '-'.join(sorted([atom1_name, atom2_name]))
446
+
447
+ # Translate bond angles into distances ("virtual bonds").
448
+ residue_virtual_bonds = {}
449
+ for resname, bond_angles in residue_bond_angles.items():
450
+ # Create a fast lookup dict for bond lengths.
451
+ bond_cache = {}
452
+ for b in residue_bonds[resname]:
453
+ bond_cache[make_bond_key(b.atom1_name, b.atom2_name)] = b
454
+ residue_virtual_bonds[resname] = []
455
+ for ba in bond_angles:
456
+ bond1 = bond_cache[make_bond_key(ba.atom1_name, ba.atom2_name)]
457
+ bond2 = bond_cache[make_bond_key(ba.atom2_name, ba.atom3name)]
458
+
459
+ # Compute distance between atom1 and atom3 using the law of cosines
460
+ # c^2 = a^2 + b^2 - 2ab*cos(gamma).
461
+ gamma = ba.angle_rad
462
+ length = np.sqrt(bond1.length**2 + bond2.length**2
463
+ - 2 * bond1.length * bond2.length * np.cos(gamma))
464
+
465
+ # Propagation of uncertainty assuming uncorrelated errors.
466
+ dl_outer = 0.5 / length
467
+ dl_dgamma = (2 * bond1.length * bond2.length * np.sin(gamma)) * dl_outer
468
+ dl_db1 = (2 * bond1.length - 2 * bond2.length * np.cos(gamma)) * dl_outer
469
+ dl_db2 = (2 * bond2.length - 2 * bond1.length * np.cos(gamma)) * dl_outer
470
+ stddev = np.sqrt((dl_dgamma * ba.stddev)**2 +
471
+ (dl_db1 * bond1.stddev)**2 +
472
+ (dl_db2 * bond2.stddev)**2)
473
+ residue_virtual_bonds[resname].append(
474
+ Bond(ba.atom1_name, ba.atom3name, length, stddev))
475
+
476
+ return (residue_bonds,
477
+ residue_virtual_bonds,
478
+ residue_bond_angles)
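As a side note, the law-of-cosines step above can be checked in isolation; the sketch below uses assumed bond lengths and an assumed angle rather than values parsed from stereo_chemical_props.txt, and is not part of the diff.

import numpy as np  # numpy is already imported at the top of this module

b1, b2 = 1.459, 1.525          # assumed N-CA and CA-C bond lengths (Angstrom)
gamma = 111.0 / 180.0 * np.pi  # assumed N-CA-C bond angle in radians
virtual = np.sqrt(b1**2 + b2**2 - 2 * b1 * b2 * np.cos(gamma))
print(round(float(virtual), 3))  # ~2.46, the N...C "virtual bond" length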
479
+
480
+
481
+ # Between-residue bond lengths for general bonds (first element) and for Proline
482
+ # (second element).
483
+ between_res_bond_length_c_n = [1.329, 1.341]
484
+ between_res_bond_length_stddev_c_n = [0.014, 0.016]
485
+
486
+ # Between-residue cos_angles.
487
+ between_res_cos_angles_c_n_ca = [-0.5203, 0.0353] # degrees: 121.352 +- 2.315
488
+ between_res_cos_angles_ca_c_n = [-0.4473, 0.0311] # degrees: 116.568 +- 1.995
489
+
490
+ # This mapping is used when we need to store atom data in a format that requires
491
+ # fixed atom data size for every residue (e.g. a numpy array).
492
+ atom_types = [
493
+ 'N', 'CA', 'C', 'CB', 'O', 'CG', 'CG1', 'CG2', 'OG', 'OG1', 'SG', 'CD',
494
+ 'CD1', 'CD2', 'ND1', 'ND2', 'OD1', 'OD2', 'SD', 'CE', 'CE1', 'CE2', 'CE3',
495
+ 'NE', 'NE1', 'NE2', 'OE1', 'OE2', 'CH2', 'NH1', 'NH2', 'OH', 'CZ', 'CZ2',
496
+ 'CZ3', 'NZ', 'OXT'
497
+ ]
498
+ atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}
499
+ atom_type_num = len(atom_types) # := 37.
500
+
501
+ # A compact atom encoding with 14 columns
502
+ # pylint: disable=line-too-long
503
+ # pylint: disable=bad-whitespace
504
+ restype_name_to_atom14_names = {
505
+ 'ALA': ['N', 'CA', 'C', 'O', 'CB', '', '', '', '', '', '', '', '', ''],
506
+ 'ARG': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2', '', '', ''],
507
+ 'ASN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2', '', '', '', '', '', ''],
508
+ 'ASP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2', '', '', '', '', '', ''],
509
+ 'CYS': ['N', 'CA', 'C', 'O', 'CB', 'SG', '', '', '', '', '', '', '', ''],
510
+ 'GLN': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2', '', '', '', '', ''],
511
+ 'GLU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2', '', '', '', '', ''],
512
+ 'GLY': ['N', 'CA', 'C', 'O', '', '', '', '', '', '', '', '', '', ''],
513
+ 'HIS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2', '', '', '', ''],
514
+ 'ILE': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1', '', '', '', '', '', ''],
515
+ 'LEU': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', '', '', '', '', '', ''],
516
+ 'LYS': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ', '', '', '', '', ''],
517
+ 'MET': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE', '', '', '', '', '', ''],
518
+ 'PHE': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', '', '', ''],
519
+ 'PRO': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', '', '', '', '', '', '', ''],
520
+ 'SER': ['N', 'CA', 'C', 'O', 'CB', 'OG', '', '', '', '', '', '', '', ''],
521
+ 'THR': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2', '', '', '', '', '', '', ''],
522
+ 'TRP': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2'],
523
+ 'TYR': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH', '', ''],
524
+ 'VAL': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', '', '', '', '', '', '', ''],
525
+ 'UNK': ['', '', '', '', '', '', '', '', '', '', '', '', '', ''],
526
+
527
+ }
528
+ # pylint: enable=line-too-long
529
+ # pylint: enable=bad-whitespace
530
+
531
+
532
+ # This is the standard residue order when coding AA type as a number.
533
+ # Reproduce it by taking 3-letter AA codes and sorting them alphabetically.
534
+ restypes = [
535
+ 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
536
+ 'S', 'T', 'W', 'Y', 'V'
537
+ ]
538
+ restype_order = {restype: i for i, restype in enumerate(restypes)}
539
+ restype_num = len(restypes) # := 20.
540
+ unk_restype_index = restype_num # Catch-all index for unknown restypes.
541
+
542
+ restypes_with_x = restypes + ['X']
543
+ restype_order_with_x = {restype: i for i, restype in enumerate(restypes_with_x)}
544
+
545
+
546
+ def sequence_to_onehot(
547
+ sequence: str,
548
+ mapping: Mapping[str, int],
549
+ map_unknown_to_x: bool = False) -> np.ndarray:
550
+ """Maps the given sequence into a one-hot encoded matrix.
551
+
552
+ Args:
553
+ sequence: An amino acid sequence.
554
+ mapping: A dictionary mapping amino acids to integers.
555
+ map_unknown_to_x: If True, any amino acid that is not in the mapping will be
556
+ mapped to the unknown amino acid 'X'. If the mapping doesn't contain
557
+ amino acid 'X', an error will be thrown. If False, any amino acid not in
558
+ the mapping will throw an error.
559
+
560
+ Returns:
561
+ A numpy array of shape (seq_len, num_unique_aas) with one-hot encoding of
562
+ the sequence.
563
+
564
+ Raises:
565
+ ValueError: If the mapping doesn't contain values from 0 to
566
+ num_unique_aas - 1 without any gaps.
567
+ """
568
+ num_entries = max(mapping.values()) + 1
569
+
570
+ if sorted(set(mapping.values())) != list(range(num_entries)):
571
+ raise ValueError('The mapping must have values from 0 to num_unique_aas-1 '
572
+ 'without any gaps. Got: %s' % sorted(mapping.values()))
573
+
574
+ one_hot_arr = np.zeros((len(sequence), num_entries), dtype=int)
575
+
576
+ for aa_index, aa_type in enumerate(sequence):
577
+ if map_unknown_to_x:
578
+ if aa_type.isalpha() and aa_type.isupper():
579
+ aa_id = mapping.get(aa_type, mapping['X'])
580
+ else:
581
+ raise ValueError(f'Invalid character in the sequence: {aa_type}')
582
+ else:
583
+ aa_id = mapping[aa_type]
584
+ one_hot_arr[aa_index, aa_id] = 1
585
+
586
+ return one_hot_arr
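A minimal usage sketch (not part of the diff), relying on the restype_order_with_x mapping defined above:

onehot = sequence_to_onehot('ACDZ', restype_order_with_x, map_unknown_to_x=True)
print(onehot.shape)             # (4, 21)
print(int(onehot[3].argmax()))  # 20 -> 'Z' is not in the mapping, so it maps to 'X'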
587
+
588
+
589
+ restype_1to3 = {
590
+ 'A': 'ALA',
591
+ 'R': 'ARG',
592
+ 'N': 'ASN',
593
+ 'D': 'ASP',
594
+ 'C': 'CYS',
595
+ 'Q': 'GLN',
596
+ 'E': 'GLU',
597
+ 'G': 'GLY',
598
+ 'H': 'HIS',
599
+ 'I': 'ILE',
600
+ 'L': 'LEU',
601
+ 'K': 'LYS',
602
+ 'M': 'MET',
603
+ 'F': 'PHE',
604
+ 'P': 'PRO',
605
+ 'S': 'SER',
606
+ 'T': 'THR',
607
+ 'W': 'TRP',
608
+ 'Y': 'TYR',
609
+ 'V': 'VAL',
610
+ }
611
+
612
+
613
+ # NB: restype_3to1 differs from Bio.PDB.protein_letters_3to1 by being a simple
614
+ # 1-to-1 mapping of 3 letter names to one letter names. The latter contains
615
+ # many more, and less common, three letter names as keys and maps many of these
616
+ # to the same one letter name (including 'X' and 'U' which we don't use here).
617
+ restype_3to1 = {v: k for k, v in restype_1to3.items()}
618
+
619
+ # Define a restype name for all unknown residues.
620
+ unk_restype = 'UNK'
621
+
622
+ resnames = [restype_1to3[r] for r in restypes] + [unk_restype]
623
+ resname_to_idx = {resname: i for i, resname in enumerate(resnames)}
624
+
625
+
626
+ # The mapping here uses hhblits convention, so that B is mapped to D, J and O
627
+ # are mapped to X, U is mapped to C, and Z is mapped to E. Other than that the
628
+ # remaining 20 amino acids are kept in alphabetical order.
629
+ # There are 2 non-amino acid codes, X (representing any amino acid) and
630
+ # "-" representing a missing amino acid in an alignment. The id for these
631
+ # codes is put at the end (20 and 21) so that they can easily be ignored if
632
+ # desired.
633
+ HHBLITS_AA_TO_ID = {
634
+ 'A': 0,
635
+ 'B': 2,
636
+ 'C': 1,
637
+ 'D': 2,
638
+ 'E': 3,
639
+ 'F': 4,
640
+ 'G': 5,
641
+ 'H': 6,
642
+ 'I': 7,
643
+ 'J': 20,
644
+ 'K': 8,
645
+ 'L': 9,
646
+ 'M': 10,
647
+ 'N': 11,
648
+ 'O': 20,
649
+ 'P': 12,
650
+ 'Q': 13,
651
+ 'R': 14,
652
+ 'S': 15,
653
+ 'T': 16,
654
+ 'U': 1,
655
+ 'V': 17,
656
+ 'W': 18,
657
+ 'X': 20,
658
+ 'Y': 19,
659
+ 'Z': 3,
660
+ '-': 21,
661
+ }
662
+
663
+ # Partial inversion of HHBLITS_AA_TO_ID.
664
+ ID_TO_HHBLITS_AA = {
665
+ 0: 'A',
666
+ 1: 'C', # Also U.
667
+ 2: 'D', # Also B.
668
+ 3: 'E', # Also Z.
669
+ 4: 'F',
670
+ 5: 'G',
671
+ 6: 'H',
672
+ 7: 'I',
673
+ 8: 'K',
674
+ 9: 'L',
675
+ 10: 'M',
676
+ 11: 'N',
677
+ 12: 'P',
678
+ 13: 'Q',
679
+ 14: 'R',
680
+ 15: 'S',
681
+ 16: 'T',
682
+ 17: 'V',
683
+ 18: 'W',
684
+ 19: 'Y',
685
+ 20: 'X', # Includes J and O.
686
+ 21: '-',
687
+ }
688
+
689
+ restypes_with_x_and_gap = restypes + ['X', '-']
690
+ MAP_HHBLITS_AATYPE_TO_OUR_AATYPE = tuple(
691
+ restypes_with_x_and_gap.index(ID_TO_HHBLITS_AA[i])
692
+ for i in range(len(restypes_with_x_and_gap)))
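A brief sketch (not part of the diff) of how the lookup table above remaps an hhblits-encoded sequence to this file's alphabetical encoding:

hhblits_ids = np.array([HHBLITS_AA_TO_ID[a] for a in 'ACD'])     # [0, 1, 2]
our_ids = np.take(MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, hhblits_ids)
print(our_ids)  # [0 4 3] -> positions of A, C, D in the alphabetical restypes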
693
+
694
+
695
+ def _make_standard_atom_mask() -> np.ndarray:
696
+ """Returns [num_res_types, num_atom_types] mask array."""
697
+ # +1 to account for unknown (all 0s).
698
+ mask = np.zeros([restype_num + 1, atom_type_num], dtype=int)
699
+ for restype, restype_letter in enumerate(restypes):
700
+ restype_name = restype_1to3[restype_letter]
701
+ atom_names = residue_atoms[restype_name]
702
+ for atom_name in atom_names:
703
+ atom_type = atom_order[atom_name]
704
+ mask[restype, atom_type] = 1
705
+ return mask
706
+
707
+
708
+ STANDARD_ATOM_MASK = _make_standard_atom_mask()
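As a quick illustration (not part of the diff), glycine's row of STANDARD_ATOM_MASK flags only the four backbone atoms:

gly_row = STANDARD_ATOM_MASK[restype_order['G']]
print(int(gly_row.sum()))                             # 4
print([atom_types[i] for i in np.where(gly_row)[0]])  # ['N', 'CA', 'C', 'O']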
709
+
710
+
711
+ # A one hot representation for the first and second atoms defining the axis
712
+ # of rotation for each chi-angle in each residue.
713
+ def chi_angle_atom(atom_index: int) -> np.ndarray:
714
+ """Define chi-angle rigid groups via one-hot representations."""
715
+ chi_angles_index = {}
716
+ one_hots = []
717
+
718
+ for k, v in chi_angles_atoms.items():
719
+ indices = [atom_types.index(s[atom_index]) for s in v]
720
+ indices.extend([-1]*(4-len(indices)))
721
+ chi_angles_index[k] = indices
722
+
723
+ for r in restypes:
724
+ res3 = restype_1to3[r]
725
+ one_hot = np.eye(atom_type_num)[chi_angles_index[res3]]
726
+ one_hots.append(one_hot)
727
+
728
+ one_hots.append(np.zeros([4, atom_type_num])) # Add zeros for residue `X`.
729
+ one_hot = np.stack(one_hots, axis=0)
730
+ one_hot = np.transpose(one_hot, [0, 2, 1])
731
+
732
+ return one_hot
733
+
734
+ chi_atom_1_one_hot = chi_angle_atom(1)
735
+ chi_atom_2_one_hot = chi_angle_atom(2)
736
+
737
+ # An array like chi_angles_atoms but using indices rather than names.
738
+ chi_angles_atom_indices = [chi_angles_atoms[restype_1to3[r]] for r in restypes]
739
+ chi_angles_atom_indices = tree.map_structure(
740
+ lambda atom_name: atom_order[atom_name], chi_angles_atom_indices)
741
+ chi_angles_atom_indices = np.array([
742
+ chi_atoms + ([[0, 0, 0, 0]] * (4 - len(chi_atoms)))
743
+ for chi_atoms in chi_angles_atom_indices])
744
+
745
+ # Mapping from (res_name, atom_name) pairs to the atom's chi group index
746
+ # and atom index within that group.
747
+ chi_groups_for_atom = collections.defaultdict(list)
748
+ for res_name, chi_angle_atoms_for_res in chi_angles_atoms.items():
749
+ for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res):
750
+ for atom_i, atom in enumerate(chi_group):
751
+ chi_groups_for_atom[(res_name, atom)].append((chi_group_i, atom_i))
752
+ chi_groups_for_atom = dict(chi_groups_for_atom)
753
+
754
+
755
+ def _make_rigid_transformation_4x4(ex, ey, translation):
756
+ """Create a rigid 4x4 transformation matrix from two axes and transl."""
757
+ # Normalize ex.
758
+ ex_normalized = ex / np.linalg.norm(ex)
759
+
760
+ # make ey perpendicular to ex
761
+ ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
762
+ ey_normalized /= np.linalg.norm(ey_normalized)
763
+
764
+ # compute ez as cross product
765
+ eznorm = np.cross(ex_normalized, ey_normalized)
766
+ m = np.stack([ex_normalized, ey_normalized, eznorm, translation]).transpose()
767
+ m = np.concatenate([m, [[0., 0., 0., 1.]]], axis=0)
768
+ return m
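A small usage sketch (not part of the diff): build a frame whose x-axis points along +y of the parent frame and whose origin sits at (1, 0, 0).

frame = _make_rigid_transformation_4x4(
    ex=np.array([0., 1., 0.]),
    ey=np.array([1., 0., 0.]),
    translation=np.array([1., 0., 0.]))
print(frame.shape)   # (4, 4); the last row is [0, 0, 0, 1]
print(frame[:3, 3])  # [1. 0. 0.] -- the translation column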
769
+
770
+
771
+ # create an array with (restype, atomtype) --> rigid_group_idx
772
+ # and an array with (restype, atomtype, coord) for the atom positions
773
+ # and compute affine transformation matrices (4,4) from one rigid group to the
774
+ # previous group
775
+ restype_atom37_to_rigid_group = np.zeros([21, 37], dtype=int)
776
+ restype_atom37_mask = np.zeros([21, 37], dtype=np.float32)
777
+ restype_atom37_rigid_group_positions = np.zeros([21, 37, 3], dtype=np.float32)
778
+ restype_atom14_to_rigid_group = np.zeros([21, 14], dtype=int)
779
+ restype_atom14_mask = np.zeros([21, 14], dtype=np.float32)
780
+ restype_atom14_rigid_group_positions = np.zeros([21, 14, 3], dtype=np.float32)
781
+ restype_rigid_group_default_frame = np.zeros([21, 8, 4, 4], dtype=np.float32)
782
+
783
+
784
+ def _make_rigid_group_constants():
785
+ """Fill the arrays above."""
786
+ for restype, restype_letter in enumerate(restypes):
787
+ resname = restype_1to3[restype_letter]
788
+ for atomname, group_idx, atom_position in rigid_group_atom_positions[
789
+ resname]:
790
+ atomtype = atom_order[atomname]
791
+ restype_atom37_to_rigid_group[restype, atomtype] = group_idx
792
+ restype_atom37_mask[restype, atomtype] = 1
793
+ restype_atom37_rigid_group_positions[restype, atomtype, :] = atom_position
794
+
795
+ atom14idx = restype_name_to_atom14_names[resname].index(atomname)
796
+ restype_atom14_to_rigid_group[restype, atom14idx] = group_idx
797
+ restype_atom14_mask[restype, atom14idx] = 1
798
+ restype_atom14_rigid_group_positions[restype,
799
+ atom14idx, :] = atom_position
800
+
801
+ for restype, restype_letter in enumerate(restypes):
802
+ resname = restype_1to3[restype_letter]
803
+ atom_positions = {name: np.array(pos) for name, _, pos
804
+ in rigid_group_atom_positions[resname]}
805
+
806
+ # backbone to backbone is the identity transform
807
+ restype_rigid_group_default_frame[restype, 0, :, :] = np.eye(4)
808
+
809
+ # pre-omega-frame to backbone (currently dummy identity matrix)
810
+ restype_rigid_group_default_frame[restype, 1, :, :] = np.eye(4)
811
+
812
+ # phi-frame to backbone
813
+ mat = _make_rigid_transformation_4x4(
814
+ ex=atom_positions['N'] - atom_positions['CA'],
815
+ ey=np.array([1., 0., 0.]),
816
+ translation=atom_positions['N'])
817
+ restype_rigid_group_default_frame[restype, 2, :, :] = mat
818
+
819
+ # psi-frame to backbone
820
+ mat = _make_rigid_transformation_4x4(
821
+ ex=atom_positions['C'] - atom_positions['CA'],
822
+ ey=atom_positions['CA'] - atom_positions['N'],
823
+ translation=atom_positions['C'])
824
+ restype_rigid_group_default_frame[restype, 3, :, :] = mat
825
+
826
+ # chi1-frame to backbone
827
+ if chi_angles_mask[restype][0]:
828
+ base_atom_names = chi_angles_atoms[resname][0]
829
+ base_atom_positions = [atom_positions[name] for name in base_atom_names]
830
+ mat = _make_rigid_transformation_4x4(
831
+ ex=base_atom_positions[2] - base_atom_positions[1],
832
+ ey=base_atom_positions[0] - base_atom_positions[1],
833
+ translation=base_atom_positions[2])
834
+ restype_rigid_group_default_frame[restype, 4, :, :] = mat
835
+
836
+ # chi2-frame to chi1-frame
837
+ # chi3-frame to chi2-frame
838
+ # chi4-frame to chi3-frame
839
+ # luckily all rotation axes for the next frame start at (0,0,0) of the
840
+ # previous frame
841
+ for chi_idx in range(1, 4):
842
+ if chi_angles_mask[restype][chi_idx]:
843
+ axis_end_atom_name = chi_angles_atoms[resname][chi_idx][2]
844
+ axis_end_atom_position = atom_positions[axis_end_atom_name]
845
+ mat = _make_rigid_transformation_4x4(
846
+ ex=axis_end_atom_position,
847
+ ey=np.array([-1., 0., 0.]),
848
+ translation=axis_end_atom_position)
849
+ restype_rigid_group_default_frame[restype, 4 + chi_idx, :, :] = mat
850
+
851
+
852
+ _make_rigid_group_constants()
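Once the tables are filled, they can be spot-checked; a sketch (not part of the diff):

print(int(restype_atom14_mask[0].sum()))    # 5 -> ALA has N, CA, C, O, CB
print(restype_atom14_to_rigid_group[0, 3])  # 3 -> ALA's O atom belongs to the psi group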
853
+
854
+
855
+ def make_atom14_dists_bounds(overlap_tolerance=1.5,
856
+ bond_length_tolerance_factor=15):
857
+ """compute upper and lower bounds for bonds to assess violations."""
858
+ restype_atom14_bond_lower_bound = np.zeros([21, 14, 14], np.float32)
859
+ restype_atom14_bond_upper_bound = np.zeros([21, 14, 14], np.float32)
860
+ restype_atom14_bond_stddev = np.zeros([21, 14, 14], np.float32)
861
+ residue_bonds, residue_virtual_bonds, _ = load_stereo_chemical_props()
862
+ for restype, restype_letter in enumerate(restypes):
863
+ resname = restype_1to3[restype_letter]
864
+ atom_list = restype_name_to_atom14_names[resname]
865
+
866
+ # create lower and upper bounds for clashes
867
+ for atom1_idx, atom1_name in enumerate(atom_list):
868
+ if not atom1_name:
869
+ continue
870
+ atom1_radius = van_der_waals_radius[atom1_name[0]]
871
+ for atom2_idx, atom2_name in enumerate(atom_list):
872
+ if (not atom2_name) or atom1_idx == atom2_idx:
873
+ continue
874
+ atom2_radius = van_der_waals_radius[atom2_name[0]]
875
+ lower = atom1_radius + atom2_radius - overlap_tolerance
876
+ upper = 1e10
877
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
878
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
879
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
880
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
881
+
882
+ # overwrite lower and upper bounds for bonds and angles
883
+ for b in residue_bonds[resname] + residue_virtual_bonds[resname]:
884
+ atom1_idx = atom_list.index(b.atom1_name)
885
+ atom2_idx = atom_list.index(b.atom2_name)
886
+ lower = b.length - bond_length_tolerance_factor * b.stddev
887
+ upper = b.length + bond_length_tolerance_factor * b.stddev
888
+ restype_atom14_bond_lower_bound[restype, atom1_idx, atom2_idx] = lower
889
+ restype_atom14_bond_lower_bound[restype, atom2_idx, atom1_idx] = lower
890
+ restype_atom14_bond_upper_bound[restype, atom1_idx, atom2_idx] = upper
891
+ restype_atom14_bond_upper_bound[restype, atom2_idx, atom1_idx] = upper
892
+ restype_atom14_bond_stddev[restype, atom1_idx, atom2_idx] = b.stddev
893
+ restype_atom14_bond_stddev[restype, atom2_idx, atom1_idx] = b.stddev
894
+ return {'lower_bound': restype_atom14_bond_lower_bound, # shape (21,14,14)
895
+ 'upper_bound': restype_atom14_bond_upper_bound, # shape (21,14,14)
896
+ 'stddev': restype_atom14_bond_stddev, # shape (21,14,14)
897
+ }
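A usage sketch (not part of the diff); it assumes stereo_chemical_props.txt sits next to this module, as load_stereo_chemical_props expects:

bounds = make_atom14_dists_bounds(overlap_tolerance=1.5,
                                  bond_length_tolerance_factor=12)
print(bounds['lower_bound'].shape)  # (21, 14, 14)
print(bounds['upper_bound'].shape)  # (21, 14, 14)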
analysis/src/common/rigid_utils.py ADDED
@@ -0,0 +1,1451 @@
+ # Copyright 2021 AlQuraishi Laboratory
2
+ # Copyright 2021 DeepMind Technologies Limited
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Tuple, Any, Sequence, Callable, Optional
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from src.common import rotation3d
22
+
23
+
24
+ def rot_matmul(
25
+ a: torch.Tensor,
26
+ b: torch.Tensor
27
+ ) -> torch.Tensor:
28
+ """
29
+ Performs matrix multiplication of two rotation matrix tensors. Written
30
+ out by hand to avoid AMP downcasting.
31
+
32
+ Args:
33
+ a: [*, 3, 3] left multiplicand
34
+ b: [*, 3, 3] right multiplicand
35
+ Returns:
36
+ The product ab
37
+ """
38
+ row_1 = torch.stack(
39
+ [
40
+ a[..., 0, 0] * b[..., 0, 0]
41
+ + a[..., 0, 1] * b[..., 1, 0]
42
+ + a[..., 0, 2] * b[..., 2, 0],
43
+ a[..., 0, 0] * b[..., 0, 1]
44
+ + a[..., 0, 1] * b[..., 1, 1]
45
+ + a[..., 0, 2] * b[..., 2, 1],
46
+ a[..., 0, 0] * b[..., 0, 2]
47
+ + a[..., 0, 1] * b[..., 1, 2]
48
+ + a[..., 0, 2] * b[..., 2, 2],
49
+ ],
50
+ dim=-1,
51
+ )
52
+ row_2 = torch.stack(
53
+ [
54
+ a[..., 1, 0] * b[..., 0, 0]
55
+ + a[..., 1, 1] * b[..., 1, 0]
56
+ + a[..., 1, 2] * b[..., 2, 0],
57
+ a[..., 1, 0] * b[..., 0, 1]
58
+ + a[..., 1, 1] * b[..., 1, 1]
59
+ + a[..., 1, 2] * b[..., 2, 1],
60
+ a[..., 1, 0] * b[..., 0, 2]
61
+ + a[..., 1, 1] * b[..., 1, 2]
62
+ + a[..., 1, 2] * b[..., 2, 2],
63
+ ],
64
+ dim=-1,
65
+ )
66
+ row_3 = torch.stack(
67
+ [
68
+ a[..., 2, 0] * b[..., 0, 0]
69
+ + a[..., 2, 1] * b[..., 1, 0]
70
+ + a[..., 2, 2] * b[..., 2, 0],
71
+ a[..., 2, 0] * b[..., 0, 1]
72
+ + a[..., 2, 1] * b[..., 1, 1]
73
+ + a[..., 2, 2] * b[..., 2, 1],
74
+ a[..., 2, 0] * b[..., 0, 2]
75
+ + a[..., 2, 1] * b[..., 1, 2]
76
+ + a[..., 2, 2] * b[..., 2, 2],
77
+ ],
78
+ dim=-1,
79
+ )
80
+
81
+ return torch.stack([row_1, row_2, row_3], dim=-2)
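A quick sanity check (not part of the diff): the hand-written product should agree with torch.matmul on arbitrary 3x3 batches.

a = torch.randn(5, 3, 3)
b = torch.randn(5, 3, 3)
print(torch.allclose(rot_matmul(a, b), torch.matmul(a, b), atol=1e-5))  # True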
82
+
83
+
84
+ def rot_vec_mul(
85
+ r: torch.Tensor,
86
+ t: torch.Tensor
87
+ ) -> torch.Tensor:
88
+ """
89
+ Applies a rotation to a vector. Written out by hand to avoid AMP
90
+ downcasting.
91
+
92
+ Args:
93
+ r: [*, 3, 3] rotation matrices
94
+ t: [*, 3] coordinate tensors
95
+ Returns:
96
+ [*, 3] rotated coordinates
97
+ """
98
+ x = t[..., 0]
99
+ y = t[..., 1]
100
+ z = t[..., 2]
101
+ return torch.stack(
102
+ [
103
+ r[..., 0, 0] * x + r[..., 0, 1] * y + r[..., 0, 2] * z,
104
+ r[..., 1, 0] * x + r[..., 1, 1] * y + r[..., 1, 2] * z,
105
+ r[..., 2, 0] * x + r[..., 2, 1] * y + r[..., 2, 2] * z,
106
+ ],
107
+ dim=-1,
108
+ )
109
+
110
+
111
+ def identity_rot_mats(
112
+ batch_dims: Tuple[int],
113
+ dtype: Optional[torch.dtype] = None,
114
+ device: Optional[torch.device] = None,
115
+ requires_grad: bool = True,
116
+ ) -> torch.Tensor:
117
+ rots = torch.eye(
118
+ 3, dtype=dtype, device=device, requires_grad=requires_grad
119
+ )
120
+ rots = rots.view(*((1,) * len(batch_dims)), 3, 3)
121
+ rots = rots.expand(*batch_dims, -1, -1)
122
+
123
+ return rots
124
+
125
+
126
+ def identity_trans(
127
+ batch_dims: Tuple[int],
128
+ dtype: Optional[torch.dtype] = None,
129
+ device: Optional[torch.device] = None,
130
+ requires_grad: bool = True,
131
+ ) -> torch.Tensor:
132
+ trans = torch.zeros(
133
+ (*batch_dims, 3),
134
+ dtype=dtype,
135
+ device=device,
136
+ requires_grad=requires_grad
137
+ )
138
+ return trans
139
+
140
+
141
+ def identity_quats(
142
+ batch_dims: Tuple[int],
143
+ dtype: Optional[torch.dtype] = None,
144
+ device: Optional[torch.device] = None,
145
+ requires_grad: bool = True,
146
+ ) -> torch.Tensor:
147
+ quat = torch.zeros(
148
+ (*batch_dims, 4),
149
+ dtype=dtype,
150
+ device=device,
151
+ requires_grad=requires_grad
152
+ )
153
+
154
+ with torch.no_grad():
155
+ quat[..., 0] = 1
156
+
157
+ return quat
158
+
159
+
160
+ _quat_elements = ["a", "b", "c", "d"]
161
+ _qtr_keys = [l1 + l2 for l1 in _quat_elements for l2 in _quat_elements]
162
+ _qtr_ind_dict = {key: ind for ind, key in enumerate(_qtr_keys)}
163
+
164
+
165
+ def _to_mat(pairs):
166
+ mat = np.zeros((4, 4))
167
+ for pair in pairs:
168
+ key, value = pair
169
+ ind = _qtr_ind_dict[key]
170
+ mat[ind // 4][ind % 4] = value
171
+
172
+ return mat
173
+
174
+
175
+ _QTR_MAT = np.zeros((4, 4, 3, 3))
176
+ _QTR_MAT[..., 0, 0] = _to_mat([("aa", 1), ("bb", 1), ("cc", -1), ("dd", -1)])
177
+ _QTR_MAT[..., 0, 1] = _to_mat([("bc", 2), ("ad", -2)])
178
+ _QTR_MAT[..., 0, 2] = _to_mat([("bd", 2), ("ac", 2)])
179
+ _QTR_MAT[..., 1, 0] = _to_mat([("bc", 2), ("ad", 2)])
180
+ _QTR_MAT[..., 1, 1] = _to_mat([("aa", 1), ("bb", -1), ("cc", 1), ("dd", -1)])
181
+ _QTR_MAT[..., 1, 2] = _to_mat([("cd", 2), ("ab", -2)])
182
+ _QTR_MAT[..., 2, 0] = _to_mat([("bd", 2), ("ac", -2)])
183
+ _QTR_MAT[..., 2, 1] = _to_mat([("cd", 2), ("ab", 2)])
184
+ _QTR_MAT[..., 2, 2] = _to_mat([("aa", 1), ("bb", -1), ("cc", -1), ("dd", 1)])
185
+
186
+
187
+ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
188
+ """
189
+ Converts a quaternion to a rotation matrix.
190
+
191
+ Args:
192
+ quat: [*, 4] quaternions
193
+ Returns:
194
+ [*, 3, 3] rotation matrices
195
+ """
196
+ # [*, 4, 4]
197
+ quat = quat[..., None] * quat[..., None, :]
198
+
199
+ # [4, 4, 3, 3]
200
+ mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
201
+
202
+ # [*, 4, 4, 3, 3]
203
+ shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
204
+ quat = quat[..., None, None] * shaped_qtr_mat
205
+
206
+ # [*, 3, 3]
207
+ return torch.sum(quat, dim=(-3, -4))
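A minimal sanity check (not part of the diff): the identity quaternion (1, 0, 0, 0) should yield the 3x3 identity matrix.

identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0])
print(torch.allclose(quat_to_rot(identity_quat), torch.eye(3)))  # True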
208
+
209
+
210
+ def rot_to_quat(
211
+ rot: torch.Tensor,
212
+ ):
213
+ if(rot.shape[-2:] != (3, 3)):
214
+ raise ValueError("Input rotation is incorrectly shaped")
215
+
216
+ rot = [[rot[..., i, j] for j in range(3)] for i in range(3)]
217
+ [[xx, xy, xz], [yx, yy, yz], [zx, zy, zz]] = rot
218
+
219
+ k = [
220
+ [ xx + yy + zz, zy - yz, xz - zx, yx - xy,],
221
+ [ zy - yz, xx - yy - zz, xy + yx, xz + zx,],
222
+ [ xz - zx, xy + yx, yy - xx - zz, yz + zy,],
223
+ [ yx - xy, xz + zx, yz + zy, zz - xx - yy,]
224
+ ]
225
+
226
+ k = (1./3.) * torch.stack([torch.stack(t, dim=-1) for t in k], dim=-2)
227
+
228
+ _, vectors = torch.linalg.eigh(k)
229
+ return vectors[..., -1]
230
+
231
+
232
+ _QUAT_MULTIPLY = np.zeros((4, 4, 4))
233
+ _QUAT_MULTIPLY[:, :, 0] = [[ 1, 0, 0, 0],
234
+ [ 0,-1, 0, 0],
235
+ [ 0, 0,-1, 0],
236
+ [ 0, 0, 0,-1]]
237
+
238
+ _QUAT_MULTIPLY[:, :, 1] = [[ 0, 1, 0, 0],
239
+ [ 1, 0, 0, 0],
240
+ [ 0, 0, 0, 1],
241
+ [ 0, 0,-1, 0]]
242
+
243
+ _QUAT_MULTIPLY[:, :, 2] = [[ 0, 0, 1, 0],
244
+ [ 0, 0, 0,-1],
245
+ [ 1, 0, 0, 0],
246
+ [ 0, 1, 0, 0]]
247
+
248
+ _QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
249
+ [ 0, 0, 1, 0],
250
+ [ 0,-1, 0, 0],
251
+ [ 1, 0, 0, 0]]
252
+
253
+ _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
254
+
255
+
256
+ def quat_multiply(quat1, quat2):
257
+ """Multiply a quaternion by another quaternion."""
258
+ mat = quat1.new_tensor(_QUAT_MULTIPLY)
259
+ reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
260
+ return torch.sum(
261
+ reshaped_mat *
262
+ quat1[..., :, None, None] *
263
+ quat2[..., None, :, None],
264
+ dim=(-3, -2)
265
+ )
266
+
267
+
268
+ def quat_multiply_by_vec(quat, vec):
269
+ """Multiply a quaternion by a pure-vector quaternion."""
270
+ mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
271
+ reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
272
+ return torch.sum(
273
+ reshaped_mat *
274
+ quat[..., :, None, None] *
275
+ vec[..., None, :, None],
276
+ dim=(-3, -2)
277
+ )
278
+
279
+
280
+ def invert_rot_mat(rot_mat: torch.Tensor):
281
+ return rot_mat.transpose(-1, -2)
282
+
283
+
284
+ def invert_quat(quat: torch.Tensor):
285
+ quat_prime = quat.clone()
286
+ quat_prime[..., 1:] *= -1
287
+ inv = quat_prime / torch.sum(quat ** 2, dim=-1, keepdim=True)
288
+ return inv
289
+
290
+
291
+ class Rotation:
292
+ """
293
+ A 3D rotation. Depending on how the object is initialized, the
294
+ rotation is represented by either a rotation matrix or a
295
+ quaternion, though both formats are made available by helper functions.
296
+ To simplify gradient computation, the underlying format of the
297
+ rotation cannot be changed in-place. Like Rigid, the class is designed
298
+ to mimic the behavior of a torch Tensor, almost as if each Rotation
299
+ object were a tensor of rotations, in one format or another.
300
+ """
301
+ def __init__(self,
302
+ rot_mats: Optional[torch.Tensor] = None,
303
+ quats: Optional[torch.Tensor] = None,
304
+ normalize_quats: bool = True,
305
+ ):
306
+ """
307
+ Args:
308
+ rot_mats:
309
+ A [*, 3, 3] rotation matrix tensor. Mutually exclusive with
310
+ quats
311
+ quats:
312
+ A [*, 4] quaternion. Mutually exclusive with rot_mats. If
313
+ normalize_quats is not True, must be a unit quaternion
314
+ normalize_quats:
315
+ If quats is specified, whether to normalize quats
316
+ """
317
+ if((rot_mats is None and quats is None) or
318
+ (rot_mats is not None and quats is not None)):
319
+ raise ValueError("Exactly one input argument must be specified")
320
+
321
+ if((rot_mats is not None and rot_mats.shape[-2:] != (3, 3)) or
322
+ (quats is not None and quats.shape[-1] != 4)):
323
+ raise ValueError(
324
+ "Incorrectly shaped rotation matrix or quaternion"
325
+ )
326
+
327
+ # Force full-precision
328
+ if(quats is not None):
329
+ quats = quats.type(torch.float32)
330
+ if(rot_mats is not None):
331
+ rot_mats = rot_mats.type(torch.float32)
332
+
333
+ if(quats is not None and normalize_quats):
334
+ quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True)
335
+
336
+ self._rot_mats = rot_mats
337
+ self._quats = quats
338
+
339
+ @staticmethod
340
+ def identity(
341
+ shape,
342
+ dtype: Optional[torch.dtype] = None,
343
+ device: Optional[torch.device] = None,
344
+ requires_grad: bool = True,
345
+ fmt: str = "quat",
346
+ ):
347
+ """
348
+ Returns an identity Rotation.
349
+
350
+ Args:
351
+ shape:
352
+ The "shape" of the resulting Rotation object. See documentation
353
+ for the shape property
354
+ dtype:
355
+ The torch dtype for the rotation
356
+ device:
357
+ The torch device for the new rotation
358
+ requires_grad:
359
+ Whether the underlying tensors in the new rotation object
360
+ should require gradient computation
361
+ fmt:
362
+ One of "quat" or "rot_mat". Determines the underlying format
363
+ of the new object's rotation
364
+ Returns:
365
+ A new identity rotation
366
+ """
367
+ if(fmt == "rot_mat"):
368
+ rot_mats = identity_rot_mats(
369
+ shape, dtype, device, requires_grad,
370
+ )
371
+ return Rotation(rot_mats=rot_mats, quats=None)
372
+ elif(fmt == "quat"):
373
+ quats = identity_quats(shape, dtype, device, requires_grad)
374
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
375
+ else:
376
+ raise ValueError(f"Invalid format: {fmt}")
377
+
378
+ # Magic methods
379
+
380
+ def __getitem__(self, index: Any):
381
+ """
382
+ Allows torch-style indexing over the virtual shape of the rotation
383
+ object. See documentation for the shape property.
384
+
385
+ Args:
386
+ index:
387
+ A torch index. E.g. (1, 3, 2), or (slice(None,))
388
+ Returns:
389
+ The indexed rotation
390
+ """
391
+ if type(index) != tuple:
392
+ index = (index,)
393
+
394
+ if(self._rot_mats is not None):
395
+ rot_mats = self._rot_mats[index + (slice(None), slice(None))]
396
+ return Rotation(rot_mats=rot_mats)
397
+ elif(self._quats is not None):
398
+ quats = self._quats[index + (slice(None),)]
399
+ return Rotation(quats=quats, normalize_quats=False)
400
+ else:
401
+ raise ValueError("Both rotations are None")
402
+
403
+ def __mul__(self,
404
+ right: torch.Tensor,
405
+ ):
406
+ """
407
+ Pointwise left multiplication of the rotation with a tensor. Can be
408
+ used to e.g. mask the Rotation.
409
+
410
+ Args:
411
+ right:
412
+ The tensor multiplicand
413
+ Returns:
414
+ The product
415
+ """
416
+ if not(isinstance(right, torch.Tensor)):
417
+ raise TypeError("The other multiplicand must be a Tensor")
418
+
419
+ if(self._rot_mats is not None):
420
+ rot_mats = self._rot_mats * right[..., None, None]
421
+ return Rotation(rot_mats=rot_mats, quats=None)
422
+ elif(self._quats is not None):
423
+ quats = self._quats * right[..., None]
424
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
425
+ else:
426
+ raise ValueError("Both rotations are None")
427
+
428
+ def __rmul__(self,
429
+ left: torch.Tensor,
430
+ ):
431
+ """
432
+ Reverse pointwise multiplication of the rotation with a tensor.
433
+
434
+ Args:
435
+ left:
436
+ The left multiplicand
437
+ Returns:
438
+ The product
439
+ """
440
+ return self.__mul__(left)
441
+
442
+ # Properties
443
+
444
+ @property
445
+ def shape(self) -> torch.Size:
446
+ """
447
+ Returns the virtual shape of the rotation object. This shape is
448
+ defined as the batch dimensions of the underlying rotation matrix
449
+ or quaternion. If the Rotation was initialized with a [10, 3, 3]
450
+ rotation matrix tensor, for example, the resulting shape would be
451
+ [10].
452
+
453
+ Returns:
454
+ The virtual shape of the rotation object
455
+ """
456
+ s = None
457
+ if(self._quats is not None):
458
+ s = self._quats.shape[:-1]
459
+ else:
460
+ s = self._rot_mats.shape[:-2]
461
+
462
+ return s
463
+
464
+ @property
465
+ def dtype(self) -> torch.dtype:
466
+ """
467
+ Returns the dtype of the underlying rotation.
468
+
469
+ Returns:
470
+ The dtype of the underlying rotation
471
+ """
472
+ if(self._rot_mats is not None):
473
+ return self._rot_mats.dtype
474
+ elif(self._quats is not None):
475
+ return self._quats.dtype
476
+ else:
477
+ raise ValueError("Both rotations are None")
478
+
479
+ @property
480
+ def device(self) -> torch.device:
481
+ """
482
+ The device of the underlying rotation
483
+
484
+ Returns:
485
+ The device of the underlying rotation
486
+ """
487
+ if(self._rot_mats is not None):
488
+ return self._rot_mats.device
489
+ elif(self._quats is not None):
490
+ return self._quats.device
491
+ else:
492
+ raise ValueError("Both rotations are None")
493
+
494
+ @property
495
+ def requires_grad(self) -> bool:
496
+ """
497
+ Returns the requires_grad property of the underlying rotation
498
+
499
+ Returns:
500
+ The requires_grad property of the underlying tensor
501
+ """
502
+ if(self._rot_mats is not None):
503
+ return self._rot_mats.requires_grad
504
+ elif(self._quats is not None):
505
+ return self._quats.requires_grad
506
+ else:
507
+ raise ValueError("Both rotations are None")
508
+
509
+ def get_rot_mats(self) -> torch.Tensor:
510
+ """
511
+ Returns the underlying rotation as a rotation matrix tensor.
512
+
513
+ Returns:
514
+ The rotation as a rotation matrix tensor
515
+ """
516
+ rot_mats = self._rot_mats
517
+ if(rot_mats is None):
518
+ if(self._quats is None):
519
+ raise ValueError("Both rotations are None")
520
+ else:
521
+ rot_mats = quat_to_rot(self._quats)
522
+
523
+ return rot_mats
524
+
525
+ def get_quats(self) -> torch.Tensor:
526
+ """
527
+ Returns the underlying rotation as a quaternion tensor.
528
+
529
+ Depending on whether the Rotation was initialized with a
530
+ quaternion, this function may call torch.linalg.eigh.
531
+
532
+ Returns:
533
+ The rotation as a quaternion tensor.
534
+ """
535
+ quats = self._quats
536
+ if(quats is None):
537
+ if(self._rot_mats is None):
538
+ raise ValueError("Both rotations are None")
539
+ else:
540
+ # quats = rot_to_quat(self._rot_mats)
541
+ quats = rotation3d.matrix_to_quaternion(self._rot_mats)
542
+
543
+ return quats
544
+
545
+ def get_cur_rot(self) -> torch.Tensor:
546
+ """
547
+ Return the underlying rotation in its current form
548
+
549
+ Returns:
550
+ The stored rotation
551
+ """
552
+ if(self._rot_mats is not None):
553
+ return self._rot_mats
554
+ elif(self._quats is not None):
555
+ return self._quats
556
+ else:
557
+ raise ValueError("Both rotations are None")
558
+
559
+ def get_rotvec(self, eps=1e-6) -> torch.Tensor:
560
+ """
561
+ Return the underlying axis-angle rotation vector.
562
+
563
+ Follows scipy's implementation:
564
+ https://github.com/scipy/scipy/blob/HEAD/scipy/spatial/transform/_rotation.pyx#L1385-L1402
565
+
566
+ Returns:
567
+ The stored rotation as an axis-angle vector.
568
+ """
569
+ quat = self.get_quats()
570
+ # w > 0 to ensure 0 <= angle <= pi
571
+ flip = (quat[..., :1] < 0).float()
572
+ quat = (-1 * quat) * flip + (1 - flip) * quat
573
+
574
+ angle = 2 * torch.atan2(
575
+ torch.linalg.norm(quat[..., 1:], dim=-1),
576
+ quat[..., 0]
577
+ )
578
+
579
+ angle2 = angle * angle
580
+ small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
581
+ large_angle_scales = angle / torch.sin(angle / 2 + eps)
582
+
583
+ small_angles = (angle <= 1e-3).float()
584
+ rot_vec_scale = small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
585
+ rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
586
+ return rot_vec
587
+
588
+ # Rotation functions
589
+
590
+ def compose_q_update_vec(self,
591
+ q_update_vec: torch.Tensor,
592
+ normalize_quats: bool = True,
593
+ update_mask: torch.Tensor = None,
594
+ ):
595
+ """
596
+ Returns a new quaternion Rotation after updating the current
597
+ object's underlying rotation with a quaternion update, formatted
598
+ as a [*, 3] tensor whose final three columns represent x, y, z such
599
+ that (1, x, y, z) is the desired (not necessarily unit) quaternion
600
+ update.
601
+
602
+ Args:
603
+ q_update_vec:
604
+ A [*, 3] quaternion update tensor
605
+ normalize_quats:
606
+ Whether to normalize the output quaternion
607
+ Returns:
608
+ An updated Rotation
609
+ """
610
+ quats = self.get_quats()
611
+ quat_update = quat_multiply_by_vec(quats, q_update_vec)
612
+ if update_mask is not None:
613
+ quat_update = quat_update * update_mask
614
+ new_quats = quats + quat_update
615
+ return Rotation(
616
+ rot_mats=None,
617
+ quats=new_quats,
618
+ normalize_quats=normalize_quats,
619
+ )
620
+
621
+ def compose_r(self, r):
622
+ """
623
+ Compose the rotation matrices of the current Rotation object with
624
+ those of another.
625
+
626
+ Args:
627
+ r:
628
+ An update rotation object
629
+ Returns:
630
+ An updated rotation object
631
+ """
632
+ r1 = self.get_rot_mats()
633
+ r2 = r.get_rot_mats()
634
+ new_rot_mats = rot_matmul(r1, r2)
635
+ return Rotation(rot_mats=new_rot_mats, quats=None)
636
+
637
+ def compose_q(self, r, normalize_quats: bool = True):
638
+ """
639
+ Compose the quaternions of the current Rotation object with those
640
+ of another.
641
+
642
+ Depending on whether either Rotation was initialized with
643
+ quaternions, this function may call torch.linalg.eigh.
644
+
645
+ Args:
646
+ r:
647
+ An update rotation object
648
+ Returns:
649
+ An updated rotation object
650
+ """
651
+ q1 = self.get_quats()
652
+ q2 = r.get_quats()
653
+ new_quats = quat_multiply(q1, q2)
654
+ return Rotation(
655
+ rot_mats=None, quats=new_quats, normalize_quats=normalize_quats
656
+ )
657
+
658
+ def apply(self, pts: torch.Tensor) -> torch.Tensor:
659
+ """
660
+ Apply the current Rotation as a rotation matrix to a set of 3D
661
+ coordinates.
662
+
663
+ Args:
664
+ pts:
665
+ A [*, 3] set of points
666
+ Returns:
667
+ [*, 3] rotated points
668
+ """
669
+ rot_mats = self.get_rot_mats()
670
+ return rot_vec_mul(rot_mats, pts)
671
+
672
+ def invert_apply(self, pts: torch.Tensor) -> torch.Tensor:
673
+ """
674
+ The inverse of the apply() method.
675
+
676
+ Args:
677
+ pts:
678
+ A [*, 3] set of points
679
+ Returns:
680
+ [*, 3] inverse-rotated points
681
+ """
682
+ rot_mats = self.get_rot_mats()
683
+ inv_rot_mats = invert_rot_mat(rot_mats)
684
+ return rot_vec_mul(inv_rot_mats, pts)
685
+
686
+ def invert(self) :
687
+ """
688
+ Returns the inverse of the current Rotation.
689
+
690
+ Returns:
691
+ The inverse of the current Rotation
692
+ """
693
+ if(self._rot_mats is not None):
694
+ return Rotation(
695
+ rot_mats=invert_rot_mat(self._rot_mats),
696
+ quats=None
697
+ )
698
+ elif(self._quats is not None):
699
+ return Rotation(
700
+ rot_mats=None,
701
+ quats=invert_quat(self._quats),
702
+ normalize_quats=False,
703
+ )
704
+ else:
705
+ raise ValueError("Both rotations are None")
706
+
707
+ # "Tensor" stuff
708
+
709
+ def unsqueeze(self,
710
+ dim: int,
711
+ ):
712
+ """
713
+ Analogous to torch.unsqueeze. The dimension is relative to the
714
+ shape of the Rotation object.
715
+
716
+ Args:
717
+ dim: A positive or negative dimension index.
718
+ Returns:
719
+ The unsqueezed Rotation.
720
+ """
721
+ if dim >= len(self.shape):
722
+ raise ValueError("Invalid dimension")
723
+
724
+ if(self._rot_mats is not None):
725
+ rot_mats = self._rot_mats.unsqueeze(dim if dim >= 0 else dim - 2)
726
+ return Rotation(rot_mats=rot_mats, quats=None)
727
+ elif(self._quats is not None):
728
+ quats = self._quats.unsqueeze(dim if dim >= 0 else dim - 1)
729
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
730
+ else:
731
+ raise ValueError("Both rotations are None")
732
+
733
+ @staticmethod
734
+ def cat(
735
+ rs,
736
+ dim: int,
737
+ ):
738
+ """
739
+ Concatenates rotations along one of the batch dimensions. Analogous
740
+ to torch.cat().
741
+
742
+ Note that the output of this operation is always a rotation matrix,
743
+ regardless of the format of input rotations.
744
+
745
+ Args:
746
+ rs:
747
+ A list of rotation objects
748
+ dim:
749
+ The dimension along which the rotations should be
750
+ concatenated
751
+ Returns:
752
+ A concatenated Rotation object in rotation matrix format
753
+ """
754
+ rot_mats = [r.get_rot_mats() for r in rs]
755
+ rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2)
756
+
757
+ return Rotation(rot_mats=rot_mats, quats=None)
758
+
759
+ def map_tensor_fn(self,
760
+ fn
761
+ ):
762
+ """
763
+ Apply a Tensor -> Tensor function to underlying rotation tensors,
764
+ mapping over the rotation dimension(s). Can be used e.g. to sum out
765
+ a one-hot batch dimension.
766
+
767
+ Args:
768
+ fn:
769
+ A Tensor -> Tensor function to be mapped over the Rotation
770
+ Returns:
771
+ The transformed Rotation object
772
+ """
773
+ if(self._rot_mats is not None):
774
+ rot_mats = self._rot_mats.view(self._rot_mats.shape[:-2] + (9,))
775
+ rot_mats = torch.stack(
776
+ list(map(fn, torch.unbind(rot_mats, dim=-1))), dim=-1
777
+ )
778
+ rot_mats = rot_mats.view(rot_mats.shape[:-1] + (3, 3))
779
+ return Rotation(rot_mats=rot_mats, quats=None)
780
+ elif(self._quats is not None):
781
+ quats = torch.stack(
782
+ list(map(fn, torch.unbind(self._quats, dim=-1))), dim=-1
783
+ )
784
+ return Rotation(rot_mats=None, quats=quats, normalize_quats=False)
785
+ else:
786
+ raise ValueError("Both rotations are None")
787
+
788
+ def cuda(self):
789
+ """
790
+ Analogous to the cuda() method of torch Tensors
791
+
792
+ Returns:
793
+ A copy of the Rotation in CUDA memory
794
+ """
795
+ if(self._rot_mats is not None):
796
+ return Rotation(rot_mats=self._rot_mats.cuda(), quats=None)
797
+ elif(self._quats is not None):
798
+ return Rotation(
799
+ rot_mats=None,
800
+ quats=self._quats.cuda(),
801
+ normalize_quats=False
802
+ )
803
+ else:
804
+ raise ValueError("Both rotations are None")
805
+
806
+ def to(self,
807
+ device: Optional[torch.device],
808
+ dtype: Optional[torch.dtype]
809
+ ):
810
+ """
811
+ Analogous to the to() method of torch Tensors
812
+
813
+ Args:
814
+ device:
815
+ A torch device
816
+ dtype:
817
+ A torch dtype
818
+ Returns:
819
+ A copy of the Rotation using the new device and dtype
820
+ """
821
+ if(self._rot_mats is not None):
822
+ return Rotation(
823
+ rot_mats=self._rot_mats.to(device=device, dtype=dtype),
824
+ quats=None,
825
+ )
826
+ elif(self._quats is not None):
827
+ return Rotation(
828
+ rot_mats=None,
829
+ quats=self._quats.to(device=device, dtype=dtype),
830
+ normalize_quats=False,
831
+ )
832
+ else:
833
+ raise ValueError("Both rotations are None")
834
+
835
+ def detach(self):
836
+ """
837
+ Returns a copy of the Rotation whose underlying Tensor has been
838
+ detached from its torch graph.
839
+
840
+ Returns:
841
+ A copy of the Rotation whose underlying Tensor has been detached
842
+ from its torch graph
843
+ """
844
+ if(self._rot_mats is not None):
845
+ return Rotation(rot_mats=self._rot_mats.detach(), quats=None)
846
+ elif(self._quats is not None):
847
+ return Rotation(
848
+ rot_mats=None,
849
+ quats=self._quats.detach(),
850
+ normalize_quats=False,
851
+ )
852
+ else:
853
+ raise ValueError("Both rotations are None")
854
+
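A usage sketch for the class above (not part of the diff): wrap a batch of rotation matrices and apply them to points.

rots = Rotation(rot_mats=torch.eye(3).expand(10, 3, 3), quats=None)
points = torch.randn(10, 3)
rotated = rots.apply(points)            # identity rotations leave the points unchanged
print(rots.shape, rotated.shape)        # torch.Size([10]) torch.Size([10, 3])
print(torch.allclose(rotated, points))  # True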
855
+
856
+ class Rigid:
857
+ """
858
+ A class representing a rigid transformation. Little more than a wrapper
859
+ around two objects: a Rotation object and a [*, 3] translation.
860
+ Designed to behave approximately like a single torch tensor with the
861
+ shape of the shared batch dimensions of its component parts.
862
+ """
863
+ def __init__(self,
864
+ rots: Optional[Rotation],
865
+ trans: Optional[torch.Tensor],
866
+ ):
867
+ """
868
+ Args:
869
+ rots: A Rotation object with batch shape [*]
870
+ trans: A corresponding [*, 3] translation tensor
871
+ """
872
+ # (we need device, dtype, etc. from at least one input)
873
+
874
+ batch_dims, dtype, device, requires_grad = None, None, None, None
875
+ if(trans is not None):
876
+ batch_dims = trans.shape[:-1]
877
+ dtype = trans.dtype
878
+ device = trans.device
879
+ requires_grad = trans.requires_grad
880
+ elif(rots is not None):
881
+ batch_dims = rots.shape
882
+ dtype = rots.dtype
883
+ device = rots.device
884
+ requires_grad = rots.requires_grad
885
+ else:
886
+ raise ValueError("At least one input argument must be specified")
887
+
888
+ if(rots is None):
889
+ rots = Rotation.identity(
890
+ batch_dims, dtype, device, requires_grad,
891
+ )
892
+ elif(trans is None):
893
+ trans = identity_trans(
894
+ batch_dims, dtype, device, requires_grad,
895
+ )
896
+
897
+ if((rots.shape != trans.shape[:-1]) or
898
+ (rots.device != trans.device)):
899
+ raise ValueError("Rots and trans incompatible")
900
+
901
+ # Force full precision. Happens to the rotations automatically.
902
+ trans = trans.type(torch.float32)
903
+
904
+ self._rots = rots
905
+ self._trans = trans
906
+
907
+ @staticmethod
908
+ def identity(
909
+ shape: Tuple[int],
910
+ dtype: Optional[torch.dtype] = None,
911
+ device: Optional[torch.device] = None,
912
+ requires_grad: bool = True,
913
+ fmt: str = "quat",
914
+ ):
915
+ """
916
+ Constructs an identity transformation.
917
+
918
+ Args:
919
+ shape:
920
+ The desired shape
921
+ dtype:
922
+ The dtype of both internal tensors
923
+ device:
924
+ The device of both internal tensors
925
+ requires_grad:
926
+ Whether grad should be enabled for the internal tensors
927
+ Returns:
928
+ The identity transformation
929
+ """
930
+ return Rigid(
931
+ Rotation.identity(shape, dtype, device, requires_grad, fmt=fmt),
932
+ identity_trans(shape, dtype, device, requires_grad),
933
+ )
934
+
935
+ def __getitem__(self,
936
+ index: Any,
937
+ ):
938
+ """
939
+ Indexes the affine transformation with PyTorch-style indices.
940
+ The index is applied to the shared dimensions of both the rotation
941
+ and the translation.
942
+
943
+ E.g.::
944
+
945
+ r = Rotation(rot_mats=torch.rand(10, 10, 3, 3), quats=None)
946
+ t = Rigid(r, torch.rand(10, 10, 3))
947
+ indexed = t[3, 4:6]
948
+ assert(indexed.shape == (2,))
949
+ assert(indexed.get_rots().shape == (2,))
950
+ assert(indexed.get_trans().shape == (2, 3))
951
+
952
+ Args:
953
+ index: A standard torch tensor index. E.g. 8, (10, None, 3),
954
+ or (3, slice(0, 1, None))
955
+ Returns:
956
+ The indexed tensor
957
+ """
958
+ if type(index) != tuple:
959
+ index = (index,)
960
+
961
+ return Rigid(
962
+ self._rots[index],
963
+ self._trans[index + (slice(None),)],
964
+ )
965
+
966
+ def __mul__(self,
967
+ right: torch.Tensor,
968
+ ):
969
+ """
970
+ Pointwise left multiplication of the transformation with a tensor.
971
+ Can be used to e.g. mask the Rigid.
972
+
973
+ Args:
974
+ right:
975
+ The tensor multiplicand
976
+ Returns:
977
+ The product
978
+ """
979
+ if not(isinstance(right, torch.Tensor)):
980
+ raise TypeError("The other multiplicand must be a Tensor")
981
+
982
+ new_rots = self._rots * right
983
+ new_trans = self._trans * right[..., None]
984
+
985
+ return Rigid(new_rots, new_trans)
986
+
987
+ def __rmul__(self,
988
+ left: torch.Tensor,
989
+ ):
990
+ """
991
+ Reverse pointwise multiplication of the transformation with a
992
+ tensor.
993
+
994
+ Args:
995
+ left:
996
+ The left multiplicand
997
+ Returns:
998
+ The product
999
+ """
1000
+ return self.__mul__(left)
1001
+
1002
+ @property
1003
+ def shape(self) -> torch.Size:
1004
+ """
1005
+ Returns the shape of the shared dimensions of the rotation and
1006
+ the translation.
1007
+
1008
+ Returns:
1009
+ The shape of the transformation
1010
+ """
1011
+ s = self._trans.shape[:-1]
1012
+ return s
1013
+
1014
+ @property
1015
+ def device(self) -> torch.device:
1016
+ """
1017
+ Returns the device on which the Rigid's tensors are located.
1018
+
1019
+ Returns:
1020
+ The device on which the Rigid's tensors are located
1021
+ """
1022
+ return self._trans.device
1023
+
1024
+ def get_rots(self) -> Rotation:
1025
+ """
1026
+ Getter for the rotation.
1027
+
1028
+ Returns:
1029
+ The rotation object
1030
+ """
1031
+ return self._rots
1032
+
1033
+ def get_trans(self) -> torch.Tensor:
1034
+ """
1035
+ Getter for the translation.
1036
+
1037
+ Returns:
1038
+ The stored translation
1039
+ """
1040
+ return self._trans
1041
+
1042
+ def compose_q_update_vec(self,
1043
+ q_update_vec: torch.Tensor,
1044
+ update_mask: torch.Tensor=None,
1045
+ ):
1046
+ """
1047
+ Composes the transformation with a quaternion update vector of
1048
+ shape [*, 6], where the first three columns are the x, y, and
1049
+ z values of a quaternion update of form (1, x, y, z) and the
1050
+ last three are a 3D translation.
1051
+
1052
+ Args:
1053
+ q_update_vec: The [*, 6] quaternion update vector.
1054
+ Returns:
1055
+ The composed transformation.
1056
+ """
1057
+ q_vec, t_vec = q_update_vec[..., :3], q_update_vec[..., 3:]
1058
+ new_rots = self._rots.compose_q_update_vec(
1059
+ q_vec, update_mask=update_mask)
1060
+
1061
+ trans_update = self._rots.apply(t_vec)
1062
+ if update_mask is not None:
1063
+ trans_update = trans_update * update_mask
1064
+ new_translation = self._trans + trans_update
1065
+
1066
+ return Rigid(new_rots, new_translation)
1067
+
1068
+ def compose(self,
1069
+ r,
1070
+ ):
1071
+ """
1072
+ Composes the current rigid object with another.
1073
+
1074
+ Args:
1075
+ r:
1076
+ Another Rigid object
1077
+ Returns:
1078
+ The composition of the two transformations
1079
+ """
1080
+ new_rot = self._rots.compose_r(r._rots)
1081
+ new_trans = self._rots.apply(r._trans) + self._trans
1082
+ return Rigid(new_rot, new_trans)
1083
+
1084
+ def compose_r(self,
1085
+ rot,
1086
+ order='right'
1087
+ ):
1088
+ """
1089
+ Composes the rotation of the current rigid object with another rotation.
1090
+
1091
+ Args:
1092
+ rot:
1093
+ A Rotation object
1094
+ order:
1095
+ Order in which to perform rotation multiplication.
1096
+ Returns:
1097
+ A Rigid with the composed rotation and the original translation
1098
+ """
1099
+ if order == 'right':
1100
+ new_rot = self._rots.compose_r(rot)
1101
+ elif order == 'left':
1102
+ new_rot = rot.compose_r(self._rots)
1103
+ else:
1104
+ raise ValueError(f'Unrecognized multiplication order: {order}')
1105
+ return Rigid(new_rot, self._trans)
1106
+
1107
+ def apply(self,
1108
+ pts: torch.Tensor,
1109
+ ) -> torch.Tensor:
1110
+ """
1111
+ Applies the transformation to a coordinate tensor.
1112
+
1113
+ Args:
1114
+ pts: A [*, 3] coordinate tensor.
1115
+ Returns:
1116
+ The transformed points.
1117
+ """
1118
+ rotated = self._rots.apply(pts)
1119
+ return rotated + self._trans
1120
+
1121
+ def invert_apply(self,
1122
+ pts: torch.Tensor
1123
+ ) -> torch.Tensor:
1124
+ """
1125
+ Applies the inverse of the transformation to a coordinate tensor.
1126
+
1127
+ Args:
1128
+ pts: A [*, 3] coordinate tensor
1129
+ Returns:
1130
+ The transformed points.
1131
+ """
1132
+ pts = pts - self._trans
1133
+ return self._rots.invert_apply(pts)
1134
+
1135
+ def invert(self):
1136
+ """
1137
+ Inverts the transformation.
1138
+
1139
+ Returns:
1140
+ The inverse transformation.
1141
+ """
1142
+ rot_inv = self._rots.invert()
1143
+ trn_inv = rot_inv.apply(self._trans)
1144
+
1145
+ return Rigid(rot_inv, -1 * trn_inv)
1146
+
1147
+ def map_tensor_fn(self,
1148
+ fn
1149
+ ):
1150
+ """
1151
+ Apply a Tensor -> Tensor function to underlying translation and
1152
+ rotation tensors, mapping over the translation/rotation dimensions
1153
+ respectively.
1154
+
1155
+ Args:
1156
+ fn:
1157
+ A Tensor -> Tensor function to be mapped over the Rigid
1158
+ Returns:
1159
+ The transformed Rigid object
1160
+ """
1161
+ new_rots = self._rots.map_tensor_fn(fn)
1162
+ new_trans = torch.stack(
1163
+ list(map(fn, torch.unbind(self._trans, dim=-1))),
1164
+ dim=-1
1165
+ )
1166
+
1167
+ return Rigid(new_rots, new_trans)
1168
+
1169
+ def to_tensor_4x4(self) -> torch.Tensor:
1170
+ """
1171
+ Converts a transformation to a homogenous transformation tensor.
1172
+
1173
+ Returns:
1174
+ A [*, 4, 4] homogenous transformation tensor
1175
+ """
1176
+ tensor = self._trans.new_zeros((*self.shape, 4, 4))
1177
+ tensor[..., :3, :3] = self._rots.get_rot_mats()
1178
+ tensor[..., :3, 3] = self._trans
1179
+ tensor[..., 3, 3] = 1
1180
+ return tensor
1181
+
1182
+ @staticmethod
1183
+ def from_tensor_4x4(
1184
+ t: torch.Tensor
1185
+ ):
1186
+ """
1187
+ Constructs a transformation from a homogenous transformation
1188
+ tensor.
1189
+
1190
+ Args:
1191
+ t: [*, 4, 4] homogenous transformation tensor
1192
+ Returns:
1193
+ A Rigid object with shape [*]
1194
+ """
1195
+ if(t.shape[-2:] != (4, 4)):
1196
+ raise ValueError("Incorrectly shaped input tensor")
1197
+
1198
+ rots = Rotation(rot_mats=t[..., :3, :3], quats=None)
1199
+ trans = t[..., :3, 3]
1200
+
1201
+ return Rigid(rots, trans)
1202
+
1203
+ def to_tensor_7(self) -> torch.Tensor:
1204
+ """
1205
+ Converts a transformation to a tensor with 7 final columns, four
1206
+ for the quaternion followed by three for the translation.
1207
+
1208
+ Returns:
1209
+ A [*, 7] tensor representation of the transformation
1210
+ """
1211
+ tensor = self._trans.new_zeros((*self.shape, 7))
1212
+ tensor[..., :4] = self._rots.get_quats()
1213
+ tensor[..., 4:] = self._trans
1214
+
1215
+ return tensor
1216
+
1217
+ @staticmethod
1218
+ def from_tensor_7(
1219
+ t: torch.Tensor,
1220
+ normalize_quats: bool = False,
1221
+ ):
1222
+ if(t.shape[-1] != 7):
1223
+ raise ValueError("Incorrectly shaped input tensor")
1224
+
1225
+ quats, trans = t[..., :4], t[..., 4:]
1226
+
1227
+ rots = Rotation(
1228
+ rot_mats=None,
1229
+ quats=quats,
1230
+ normalize_quats=normalize_quats
1231
+ )
1232
+
1233
+ return Rigid(rots, trans)
1234
+
1235
+ @staticmethod
1236
+ def from_3_points(
1237
+ p_neg_x_axis: torch.Tensor,
1238
+ origin: torch.Tensor,
1239
+ p_xy_plane: torch.Tensor,
1240
+ eps: float = 1e-8
1241
+ ):
1242
+ """
1243
+ Implements algorithm 21. Constructs transformations from sets of 3
1244
+ points using the Gram-Schmidt algorithm.
1245
+
1246
+ Args:
1247
+ p_neg_x_axis: [*, 3] coordinates
1248
+ origin: [*, 3] coordinates used as frame origins
1249
+ p_xy_plane: [*, 3] coordinates
1250
+ eps: Small epsilon value
1251
+ Returns:
1252
+ A transformation object of shape [*]
1253
+ """
1254
+ p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1)
1255
+ origin = torch.unbind(origin, dim=-1)
1256
+ p_xy_plane = torch.unbind(p_xy_plane, dim=-1)
1257
+
1258
+ e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)]
1259
+ e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)]
1260
+
1261
+ denom = torch.sqrt(sum((c * c for c in e0)) + eps)
1262
+ e0 = [c / denom for c in e0]
1263
+ dot = sum((c1 * c2 for c1, c2 in zip(e0, e1)))
1264
+ e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)]
1265
+ denom = torch.sqrt(sum((c * c for c in e1)) + eps)
1266
+ e1 = [c / denom for c in e1]
1267
+ e2 = [
1268
+ e0[1] * e1[2] - e0[2] * e1[1],
1269
+ e0[2] * e1[0] - e0[0] * e1[2],
1270
+ e0[0] * e1[1] - e0[1] * e1[0],
1271
+ ]
1272
+
1273
+ rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1)
1274
+ rots = rots.reshape(rots.shape[:-1] + (3, 3))
1275
+
1276
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1277
+
1278
+ return Rigid(rot_obj, torch.stack(origin, dim=-1))
1279
+
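# Illustrative usage sketch (hypothetical tensors, not from the original file):
# frames built by from_3_points() should round-trip arbitrary points through
# apply() / invert_apply(), e.g.
#
#   p1, origin, p2 = torch.rand(5, 3), torch.rand(5, 3), torch.rand(5, 3)
#   frames = Rigid.from_3_points(p1, origin, p2)   # batch shape [5]
#   pts = torch.rand(5, 3)
#   assert torch.allclose(frames.invert_apply(frames.apply(pts)), pts, atol=1e-5)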
1280
+ def unsqueeze(self,
1281
+ dim: int,
1282
+ ):
1283
+ """
1284
+ Analogous to torch.unsqueeze. The dimension is relative to the
1285
+ shared dimensions of the rotation/translation.
1286
+
1287
+ Args:
1288
+ dim: A positive or negative dimension index.
1289
+ Returns:
1290
+ The unsqueezed transformation.
1291
+ """
1292
+ if dim >= len(self.shape):
1293
+ raise ValueError("Invalid dimension")
1294
+ rots = self._rots.unsqueeze(dim)
1295
+ trans = self._trans.unsqueeze(dim if dim >= 0 else dim - 1)
1296
+
1297
+ return Rigid(rots, trans)
1298
+
1299
+ @staticmethod
1300
+ def cat(
1301
+ ts,
1302
+ dim: int,
1303
+ ):
1304
+ """
1305
+ Concatenates transformations along one of the batch dimensions.
1306
+
1307
+ Args:
1308
+ ts:
1309
+ A list of Rigid objects
1310
+ dim:
1311
+ The dimension along which the transformations should be
1312
+ concatenated
1313
+ Returns:
1314
+ A concatenated transformation object
1315
+ """
1316
+ rots = Rotation.cat([t._rots for t in ts], dim)
1317
+ trans = torch.cat(
1318
+ [t._trans for t in ts], dim=dim if dim >= 0 else dim - 1
1319
+ )
1320
+
1321
+ return Rigid(rots, trans)
1322
+
1323
+ def apply_rot_fn(self, fn):
1324
+ """
1325
+ Applies a Rotation -> Rotation function to the stored rotation
1326
+ object.
1327
+
1328
+ Args:
1329
+ fn: A function of type Rotation -> Rotation
1330
+ Returns:
1331
+ A transformation object with a transformed rotation.
1332
+ """
1333
+ return Rigid(fn(self._rots), self._trans)
1334
+
1335
+ def apply_trans_fn(self, fn):
1336
+ """
1337
+ Applies a Tensor -> Tensor function to the stored translation.
1338
+
1339
+ Args:
1340
+ fn:
1341
+ A function of type Tensor -> Tensor to be applied to the
1342
+ translation
1343
+ Returns:
1344
+ A transformation object with a transformed translation.
1345
+ """
1346
+ return Rigid(self._rots, fn(self._trans))
1347
+
1348
+ def scale_translation(self, trans_scale_factor: float):
1349
+ """
1350
+ Scales the translation by a constant factor.
1351
+
1352
+ Args:
1353
+ trans_scale_factor:
1354
+ The constant factor
1355
+ Returns:
1356
+ A transformation object with a scaled translation.
1357
+ """
1358
+ fn = lambda t: t * trans_scale_factor
1359
+ return self.apply_trans_fn(fn)
1360
+
1361
+ def stop_rot_gradient(self):
1362
+ """
1363
+ Detaches the underlying rotation object
1364
+
1365
+ Returns:
1366
+ A transformation object with detached rotations
1367
+ """
1368
+ fn = lambda r: r.detach()
1369
+ return self.apply_rot_fn(fn)
1370
+
1371
+ @staticmethod
1372
+ def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20):
1373
+ """
1374
+ Returns a transformation object from reference coordinates.
1375
+
1376
+ Note that this method does not take care of symmetries. If you
1377
+ provide the atom positions in the non-standard way, the N atom will
1378
+ end up not at [-0.527250, 1.359329, 0.0] but instead at
1379
+ [-0.527250, -1.359329, 0.0]. You need to take care of such cases in
1380
+ your code.
1381
+
1382
+ Args:
1383
+ n_xyz: A [*, 3] tensor of nitrogen xyz coordinates.
1384
+ ca_xyz: A [*, 3] tensor of carbon alpha xyz coordinates.
1385
+ c_xyz: A [*, 3] tensor of carbon xyz coordinates.
1386
+ Returns:
1387
+ A transformation object. After applying the translation and
1388
+ rotation to the reference backbone, the coordinates will be
1389
+ approximately equal to the input coordinates.
1390
+ """
1391
+ translation = -1 * ca_xyz
1392
+ n_xyz = n_xyz + translation
1393
+ c_xyz = c_xyz + translation
1394
+
1395
+ c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)]
1396
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2)
1397
+ sin_c1 = -c_y / norm
1398
+ cos_c1 = c_x / norm
1399
+ zeros = sin_c1.new_zeros(sin_c1.shape)
1400
+ ones = sin_c1.new_ones(sin_c1.shape)
1401
+
1402
+ c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3))
1403
+ c1_rots[..., 0, 0] = cos_c1
1404
+ c1_rots[..., 0, 1] = -1 * sin_c1
1405
+ c1_rots[..., 1, 0] = sin_c1
1406
+ c1_rots[..., 1, 1] = cos_c1
1407
+ c1_rots[..., 2, 2] = 1
1408
+
1409
+ norm = torch.sqrt(eps + c_x ** 2 + c_y ** 2 + c_z ** 2)
1410
+ sin_c2 = c_z / norm
1411
+ cos_c2 = torch.sqrt(c_x ** 2 + c_y ** 2) / norm
1412
+
1413
+ c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1414
+ c2_rots[..., 0, 0] = cos_c2
1415
+ c2_rots[..., 0, 2] = sin_c2
1416
+ c2_rots[..., 1, 1] = 1
1417
+ c2_rots[..., 2, 0] = -1 * sin_c2
1418
+ c2_rots[..., 2, 2] = cos_c2
1419
+
1420
+ c_rots = rot_matmul(c2_rots, c1_rots)
1421
+ n_xyz = rot_vec_mul(c_rots, n_xyz)
1422
+
1423
+ _, n_y, n_z = [n_xyz[..., i] for i in range(3)]
1424
+ norm = torch.sqrt(eps + n_y ** 2 + n_z ** 2)
1425
+ sin_n = -n_z / norm
1426
+ cos_n = n_y / norm
1427
+
1428
+ n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3))
1429
+ n_rots[..., 0, 0] = 1
1430
+ n_rots[..., 1, 1] = cos_n
1431
+ n_rots[..., 1, 2] = -1 * sin_n
1432
+ n_rots[..., 2, 1] = sin_n
1433
+ n_rots[..., 2, 2] = cos_n
1434
+
1435
+ rots = rot_matmul(n_rots, c_rots)
1436
+
1437
+ rots = rots.transpose(-1, -2)
1438
+ translation = -1 * translation
1439
+
1440
+ rot_obj = Rotation(rot_mats=rots, quats=None)
1441
+
1442
+ return Rigid(rot_obj, translation)
1443
+
1444
+ def cuda(self):
1445
+ """
1446
+ Moves the transformation object to GPU memory
1447
+
1448
+ Returns:
1449
+ A version of the transformation on GPU
1450
+ """
1451
+ return Rigid(self._rots.cuda(), self._trans.cuda())
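For orientation, the `Rotation`/`Rigid` classes above appear to follow the OpenFold-style rigid-transformation utilities. Below is a minimal sketch, assuming `analysis/` is on `PYTHONPATH` so the repo's `src.*` imports resolve; the batch shape and tensors are illustrative, not taken from the original code.

import torch
from src.common.rigid_utils import Rigid, Rotation

# Identity transform with batch shape (2, 10)
t = Rigid.identity((2, 10), requires_grad=False)

# A second transform: identity rotations plus random translations
r = Rigid(
    Rotation(rot_mats=torch.eye(3).expand(2, 10, 3, 3), quats=None),
    torch.rand(2, 10, 3),
)

# Compose, then round-trip through the [*, 4, 4] homogeneous representation
composed = t.compose(r)
recovered = Rigid.from_tensor_4x4(composed.to_tensor_4x4())
assert torch.allclose(recovered.get_trans(), composed.get_trans(), atol=1e-6)

The `cuda()`, `to()` and `detach()` helpers defined above let the rotations and the full transform be moved and detached much like plain tensors.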
analysis/src/common/rotation3d.py ADDED
@@ -0,0 +1,596 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Union
8
+
9
+ import torch
10
+ from torch.nn import functional as F
11
+
12
+ Device = Union[str, torch.device]
13
+
14
+
15
+ """
16
+ The transformation matrices returned from the functions in this file assume
17
+ the points on which the transformation will be applied are column vectors.
18
+ i.e. the R matrix is structured as
19
+
20
+ R = [
21
+ [Rxx, Rxy, Rxz],
22
+ [Ryx, Ryy, Ryz],
23
+ [Rzx, Rzy, Rzz],
24
+ ] # (3, 3)
25
+
26
+ This matrix can be applied to column vectors by post multiplication
27
+ by the points e.g.
28
+
29
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
30
+ transformed_points = R * points
31
+
32
+ To apply the same matrix to points which are row vectors, the R matrix
33
+ can be transposed and pre multiplied by the points:
34
+
35
+ e.g.
36
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
37
+ transformed_points = points * R.transpose(1, 0)
38
+ """
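# Illustrative check of the column-vector convention described above
# (hypothetical values, not from the original file):
#
#   R = quaternion_to_matrix(torch.tensor([1.0, 0.0, 0.0, 0.0]))  # identity rotation
#   points = torch.tensor([[0.0], [1.0], [2.0]])                  # (3, 1) column vector
#   transformed = R @ points                                       # post-multiplication, (3, 1)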
39
+
40
+
41
+ def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
42
+ """
43
+ Convert rotations given as quaternions to rotation matrices.
44
+
45
+ Args:
46
+ quaternions: quaternions with real part first,
47
+ as tensor of shape (..., 4).
48
+
49
+ Returns:
50
+ Rotation matrices as tensor of shape (..., 3, 3).
51
+ """
52
+ r, i, j, k = torch.unbind(quaternions, -1)
53
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
54
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
55
+
56
+ o = torch.stack(
57
+ (
58
+ 1 - two_s * (j * j + k * k),
59
+ two_s * (i * j - k * r),
60
+ two_s * (i * k + j * r),
61
+ two_s * (i * j + k * r),
62
+ 1 - two_s * (i * i + k * k),
63
+ two_s * (j * k - i * r),
64
+ two_s * (i * k - j * r),
65
+ two_s * (j * k + i * r),
66
+ 1 - two_s * (i * i + j * j),
67
+ ),
68
+ -1,
69
+ )
70
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
71
+
72
+
73
+ def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
74
+ """
75
+ Return a tensor where each element has the absolute value taken from the,
76
+ corresponding element of a, with sign taken from the corresponding
77
+ element of b. This is like the standard copysign floating-point operation,
78
+ but is not careful about negative 0 and NaN.
79
+
80
+ Args:
81
+ a: source tensor.
82
+ b: tensor whose signs will be used, of the same shape as a.
83
+
84
+ Returns:
85
+ Tensor of the same shape as a with the signs of b.
86
+ """
87
+ signs_differ = (a < 0) != (b < 0)
88
+ return torch.where(signs_differ, -a, a)
89
+
90
+
91
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
92
+ """
93
+ Returns torch.sqrt(torch.max(0, x))
94
+ but with a zero subgradient where x is 0.
95
+ """
96
+ ret = torch.zeros_like(x)
97
+ positive_mask = x > 0
98
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
99
+ return ret
100
+
101
+
102
+ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
103
+ """
104
+ Convert rotations given as rotation matrices to quaternions.
105
+
106
+ Args:
107
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
108
+
109
+ Returns:
110
+ quaternions with real part first, as tensor of shape (..., 4).
111
+ """
112
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
113
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
114
+
115
+ batch_dim = matrix.shape[:-2]
116
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
117
+ matrix.reshape(batch_dim + (9,)), dim=-1
118
+ )
119
+
120
+ q_abs = _sqrt_positive_part(
121
+ torch.stack(
122
+ [
123
+ 1.0 + m00 + m11 + m22,
124
+ 1.0 + m00 - m11 - m22,
125
+ 1.0 - m00 + m11 - m22,
126
+ 1.0 - m00 - m11 + m22,
127
+ ],
128
+ dim=-1,
129
+ )
130
+ )
131
+
132
+ # we produce the desired quaternion multiplied by each of r, i, j, k
133
+ quat_by_rijk = torch.stack(
134
+ [
135
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
136
+ # `int`.
137
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
138
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
139
+ # `int`.
140
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
141
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
142
+ # `int`.
143
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
144
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
145
+ # `int`.
146
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
147
+ ],
148
+ dim=-2,
149
+ )
150
+
151
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
152
+ # the candidate won't be picked.
153
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
154
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
155
+
156
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
157
+ # forall i; we pick the best-conditioned one (with the largest denominator)
158
+
159
+ return quat_candidates[
160
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
161
+ ].reshape(batch_dim + (4,))
162
+
163
+
164
+ def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
165
+ """
166
+ Return the rotation matrices for rotations about one of the coordinate
167
+ axes of an Euler-angle convention, for each value of the angle given.
168
+
169
+ Args:
170
+ axis: Axis label "X", "Y", or "Z".
171
+ angle: any shape tensor of Euler angles in radians
172
+
173
+ Returns:
174
+ Rotation matrices as tensor of shape (..., 3, 3).
175
+ """
176
+
177
+ cos = torch.cos(angle)
178
+ sin = torch.sin(angle)
179
+ one = torch.ones_like(angle)
180
+ zero = torch.zeros_like(angle)
181
+
182
+ if axis == "X":
183
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
184
+ elif axis == "Y":
185
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
186
+ elif axis == "Z":
187
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
188
+ else:
189
+ raise ValueError("letter must be either X, Y or Z.")
190
+
191
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
192
+
193
+
194
+ def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
195
+ """
196
+ Convert rotations given as Euler angles in radians to rotation matrices.
197
+
198
+ Args:
199
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
200
+ convention: Convention string of three uppercase letters from
201
+ {"X", "Y", and "Z"}.
202
+
203
+ Returns:
204
+ Rotation matrices as tensor of shape (..., 3, 3).
205
+ """
206
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
207
+ raise ValueError("Invalid input euler angles.")
208
+ if len(convention) != 3:
209
+ raise ValueError("Convention must have 3 letters.")
210
+ if convention[1] in (convention[0], convention[2]):
211
+ raise ValueError(f"Invalid convention {convention}.")
212
+ for letter in convention:
213
+ if letter not in ("X", "Y", "Z"):
214
+ raise ValueError(f"Invalid letter {letter} in convention string.")
215
+ matrices = [
216
+ _axis_angle_rotation(c, e)
217
+ for c, e in zip(convention, torch.unbind(euler_angles, -1))
218
+ ]
219
+ # return functools.reduce(torch.matmul, matrices)
220
+ return torch.matmul(torch.matmul(matrices[0], matrices[1]), matrices[2])
221
+
222
+
223
+ def _angle_from_tan(
224
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
225
+ ) -> torch.Tensor:
226
+ """
227
+ Extract the first or third Euler angle from the two members of
228
+ the matrix which are positive constant times its sine and cosine.
229
+
230
+ Args:
231
+ axis: Axis label "X", "Y", or "Z" for the angle we are finding.
233
+ other_axis: Axis label "X", "Y", or "Z" for the middle axis in the
233
+ convention.
234
+ data: Rotation matrices as tensor of shape (..., 3, 3).
235
+ horizontal: Whether we are looking for the angle for the third axis,
236
+ which means the relevant entries are in the same row of the
237
+ rotation matrix. If not, they are in the same column.
238
+ tait_bryan: Whether the first and third axes in the convention differ.
239
+
240
+ Returns:
241
+ Euler Angles in radians for each matrix in data as a tensor
242
+ of shape (...).
243
+ """
244
+
245
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
246
+ if horizontal:
247
+ i2, i1 = i1, i2
248
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
249
+ if horizontal == even:
250
+ return torch.atan2(data[..., i1], data[..., i2])
251
+ if tait_bryan:
252
+ return torch.atan2(-data[..., i2], data[..., i1])
253
+ return torch.atan2(data[..., i2], -data[..., i1])
254
+
255
+
256
+ def _index_from_letter(letter: str) -> int:
257
+ if letter == "X":
258
+ return 0
259
+ if letter == "Y":
260
+ return 1
261
+ if letter == "Z":
262
+ return 2
263
+ raise ValueError("letter must be either X, Y or Z.")
264
+
265
+
266
+ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
267
+ """
268
+ Convert rotations given as rotation matrices to Euler angles in radians.
269
+
270
+ Args:
271
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
272
+ convention: Convention string of three uppercase letters.
273
+
274
+ Returns:
275
+ Euler angles in radians as tensor of shape (..., 3).
276
+ """
277
+ if len(convention) != 3:
278
+ raise ValueError("Convention must have 3 letters.")
279
+ if convention[1] in (convention[0], convention[2]):
280
+ raise ValueError(f"Invalid convention {convention}.")
281
+ for letter in convention:
282
+ if letter not in ("X", "Y", "Z"):
283
+ raise ValueError(f"Invalid letter {letter} in convention string.")
284
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
285
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
286
+ i0 = _index_from_letter(convention[0])
287
+ i2 = _index_from_letter(convention[2])
288
+ tait_bryan = i0 != i2
289
+ if tait_bryan:
290
+ central_angle = torch.asin(
291
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
292
+ )
293
+ else:
294
+ central_angle = torch.acos(matrix[..., i0, i0])
295
+
296
+ o = (
297
+ _angle_from_tan(
298
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
299
+ ),
300
+ central_angle,
301
+ _angle_from_tan(
302
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
303
+ ),
304
+ )
305
+ return torch.stack(o, -1)
306
+
307
+
308
+ def random_quaternions(
309
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
310
+ ) -> torch.Tensor:
311
+ """
312
+ Generate random quaternions representing rotations,
313
+ i.e. versors with nonnegative real part.
314
+
315
+ Args:
316
+ n: Number of quaternions in a batch to return.
317
+ dtype: Type to return.
318
+ device: Desired device of returned tensor. Default:
319
+ uses the current device for the default tensor type.
320
+
321
+ Returns:
322
+ Quaternions as tensor of shape (N, 4).
323
+ """
324
+ if isinstance(device, str):
325
+ device = torch.device(device)
326
+ o = torch.randn((n, 4), dtype=dtype, device=device)
327
+ s = (o * o).sum(1)
328
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
329
+ return o
330
+
331
+
332
+ def random_rotations(
333
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
334
+ ) -> torch.Tensor:
335
+ """
336
+ Generate random rotations as 3x3 rotation matrices.
337
+
338
+ Args:
339
+ n: Number of rotation matrices in a batch to return.
340
+ dtype: Type to return.
341
+ device: Device of returned tensor. Default: if None,
342
+ uses the current device for the default tensor type.
343
+
344
+ Returns:
345
+ Rotation matrices as tensor of shape (n, 3, 3).
346
+ """
347
+ quaternions = random_quaternions(n, dtype=dtype, device=device)
348
+ return quaternion_to_matrix(quaternions)
349
+
350
+
351
+ def random_rotation(
352
+ dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
353
+ ) -> torch.Tensor:
354
+ """
355
+ Generate a single random 3x3 rotation matrix.
356
+
357
+ Args:
358
+ dtype: Type to return
359
+ device: Device of returned tensor. Default: if None,
360
+ uses the current device for the default tensor type
361
+
362
+ Returns:
363
+ Rotation matrix as tensor of shape (3, 3).
364
+ """
365
+ return random_rotations(1, dtype, device)[0]
366
+
367
+
368
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
369
+ """
370
+ Convert a unit quaternion to a standard form: one in which the real
371
+ part is non negative.
372
+
373
+ Args:
374
+ quaternions: Quaternions with real part first,
375
+ as tensor of shape (..., 4).
376
+
377
+ Returns:
378
+ Standardized quaternions as tensor of shape (..., 4).
379
+ """
380
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
381
+
382
+
383
+ def quaternion_raw_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
384
+ """
385
+ Multiply two quaternions.
386
+ Usual torch rules for broadcasting apply.
387
+
388
+ Args:
389
+ a: Quaternions as tensor of shape (..., 4), real part first.
390
+ b: Quaternions as tensor of shape (..., 4), real part first.
391
+
392
+ Returns:
393
+ The product of a and b, a tensor of quaternions shape (..., 4).
394
+ """
395
+ aw, ax, ay, az = torch.unbind(a, -1)
396
+ bw, bx, by, bz = torch.unbind(b, -1)
397
+ ow = aw * bw - ax * bx - ay * by - az * bz
398
+ ox = aw * bx + ax * bw + ay * bz - az * by
399
+ oy = aw * by - ax * bz + ay * bw + az * bx
400
+ oz = aw * bz + ax * by - ay * bx + az * bw
401
+ return torch.stack((ow, ox, oy, oz), -1)
402
+
403
+
404
+ def quaternion_multiply(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
405
+ """
406
+ Multiply two quaternions representing rotations, returning the quaternion
407
+ representing their composition, i.e. the versor with nonnegative real part.
408
+ Usual torch rules for broadcasting apply.
409
+
410
+ Args:
411
+ a: Quaternions as tensor of shape (..., 4), real part first.
412
+ b: Quaternions as tensor of shape (..., 4), real part first.
413
+
414
+ Returns:
415
+ The product of a and b, a tensor of quaternions of shape (..., 4).
416
+ """
417
+ ab = quaternion_raw_multiply(a, b)
418
+ return standardize_quaternion(ab)
419
+
420
+
421
+ def quaternion_invert(quaternion: torch.Tensor) -> torch.Tensor:
422
+ """
423
+ Given a quaternion representing rotation, get the quaternion representing
424
+ its inverse.
425
+
426
+ Args:
427
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
428
+ first, which must be versors (unit quaternions).
429
+
430
+ Returns:
431
+ The inverse, a tensor of quaternions of shape (..., 4).
432
+ """
433
+
434
+ scaling = torch.tensor([1, -1, -1, -1], device=quaternion.device)
435
+ return quaternion * scaling
436
+
437
+
438
+ def quaternion_apply(quaternion: torch.Tensor, point: torch.Tensor) -> torch.Tensor:
439
+ """
440
+ Apply the rotation given by a quaternion to a 3D point.
441
+ Usual torch rules for broadcasting apply.
442
+
443
+ Args:
444
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
445
+ point: Tensor of 3D points of shape (..., 3).
446
+
447
+ Returns:
448
+ Tensor of rotated points of shape (..., 3).
449
+ """
450
+ if point.size(-1) != 3:
451
+ raise ValueError(f"Points are not in 3D, {point.shape}.")
452
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
453
+ point_as_quaternion = torch.cat((real_parts, point), -1)
454
+ out = quaternion_raw_multiply(
455
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
456
+ quaternion_invert(quaternion),
457
+ )
458
+ return out[..., 1:]
459
+
460
+
461
+ def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor:
462
+ """
463
+ Convert rotations given as axis/angle to rotation matrices.
464
+
465
+ Args:
466
+ axis_angle: Rotations given as a vector in axis angle form,
467
+ as a tensor of shape (..., 3), where the magnitude is
468
+ the angle turned anticlockwise in radians around the
469
+ vector's direction.
470
+
471
+ Returns:
472
+ Rotation matrices as tensor of shape (..., 3, 3).
473
+ """
474
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
475
+
476
+
477
+ def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
478
+ """
479
+ Convert rotations given as rotation matrices to axis/angle.
480
+
481
+ Args:
482
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
483
+
484
+ Returns:
485
+ Rotations given as a vector in axis angle form, as a tensor
486
+ of shape (..., 3), where the magnitude is the angle
487
+ turned anticlockwise in radians around the vector's
488
+ direction.
489
+ """
490
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
491
+
492
+
493
+ def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor:
494
+ """
495
+ Convert rotations given as axis/angle to quaternions.
496
+
497
+ Args:
498
+ axis_angle: Rotations given as a vector in axis angle form,
499
+ as a tensor of shape (..., 3), where the magnitude is
500
+ the angle turned anticlockwise in radians around the
501
+ vector's direction.
502
+
503
+ Returns:
504
+ quaternions with real part first, as tensor of shape (..., 4).
505
+ """
506
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
507
+ half_angles = angles * 0.5
508
+ eps = 1e-6
509
+ small_angles = angles.abs() < eps
510
+ sin_half_angles_over_angles = torch.empty_like(angles)
511
+ sin_half_angles_over_angles[~small_angles] = (
512
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
513
+ )
514
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
515
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
516
+ sin_half_angles_over_angles[small_angles] = (
517
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
518
+ )
519
+ quaternions = torch.cat(
520
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
521
+ )
522
+ return quaternions
523
+
524
+
525
+ def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
526
+ """
527
+ Convert rotations given as quaternions to axis/angle.
528
+
529
+ Args:
530
+ quaternions: quaternions with real part first,
531
+ as tensor of shape (..., 4).
532
+
533
+ Returns:
534
+ Rotations given as a vector in axis angle form, as a tensor
535
+ of shape (..., 3), where the magnitude is the angle
536
+ turned anticlockwise in radians around the vector's
537
+ direction.
538
+ """
539
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
540
+ half_angles = torch.atan2(norms, quaternions[..., :1])
541
+ angles = 2 * half_angles
542
+ eps = 1e-6
543
+ small_angles = angles.abs() < eps
544
+ sin_half_angles_over_angles = torch.empty_like(angles)
545
+ sin_half_angles_over_angles[~small_angles] = (
546
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
547
+ )
548
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
549
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
550
+ sin_half_angles_over_angles[small_angles] = (
551
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
552
+ )
553
+ return quaternions[..., 1:] / sin_half_angles_over_angles
554
+
555
+
556
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
557
+ """
558
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
559
+ using Gram--Schmidt orthogonalization per Section B of [1].
560
+ Args:
561
+ d6: 6D rotation representation, of size (*, 6)
562
+
563
+ Returns:
564
+ batch of rotation matrices of size (*, 3, 3)
565
+
566
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
567
+ On the Continuity of Rotation Representations in Neural Networks.
568
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
569
+ Retrieved from http://arxiv.org/abs/1812.07035
570
+ """
571
+
572
+ a1, a2 = d6[..., :3], d6[..., 3:]
573
+ b1 = F.normalize(a1, dim=-1)
574
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
575
+ b2 = F.normalize(b2, dim=-1)
576
+ b3 = torch.cross(b1, b2, dim=-1)
577
+ return torch.stack((b1, b2, b3), dim=-2)
578
+
579
+
580
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
581
+ """
582
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
583
+ by dropping the last row. Note that 6D representation is not unique.
584
+ Args:
585
+ matrix: batch of rotation matrices of size (*, 3, 3)
586
+
587
+ Returns:
588
+ 6D rotation representation, of size (*, 6)
589
+
590
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
591
+ On the Continuity of Rotation Representations in Neural Networks.
592
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
593
+ Retrieved from http://arxiv.org/abs/1812.07035
594
+ """
595
+ batch_dim = matrix.size()[:-2]
596
+ return matrix[..., :2, :].clone().reshape(batch_dim + (6,))
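For orientation, these helpers mirror the PyTorch3D rotation-conversion API. A minimal round-trip sanity check is sketched below, assuming `analysis/` is on `PYTHONPATH` so the module imports as `src.common.rotation3d`; the batch size and tolerances are illustrative.

import torch
from src.common import rotation3d

R = rotation3d.random_rotations(4)              # (4, 3, 3)
q = rotation3d.matrix_to_quaternion(R)          # (4, 4), real part first
aa = rotation3d.quaternion_to_axis_angle(q)     # (4, 3)
R_back = rotation3d.axis_angle_to_matrix(aa)    # (4, 3, 3)
assert torch.allclose(R, R_back, atol=1e-4)

# Rotating points with the matrix (column-vector convention) and with the
# quaternion should agree.
p = torch.rand(4, 3)
assert torch.allclose(
    torch.einsum("bij,bj->bi", R, p),
    rotation3d.quaternion_apply(q, p),
    atol=1e-4,
)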
analysis/src/data/__init__.py ADDED
File without changes
analysis/src/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (141 Bytes).
 
analysis/src/data/__pycache__/protein_datamodule.cpython-39.pyc ADDED
Binary file (10.8 kB).
 
analysis/src/data/components/__init__.py ADDED
File without changes
analysis/src/data/components/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes).
 
analysis/src/data/components/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (10.1 kB).
 
analysis/src/data/components/dataset.py ADDED
@@ -0,0 +1,321 @@
1
+ """Protein dataset class."""
2
+ import os
3
+ import pickle
4
+ from pathlib import Path
5
+ from glob import glob
6
+ from typing import Optional, Sequence, List, Union
7
+ from functools import lru_cache
8
+ import tree
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+
15
+ from src.common import residue_constants, data_transforms, rigid_utils, protein
16
+
17
+
18
+ CA_IDX = residue_constants.atom_order['CA']
19
+ DTYPE_MAPPING = {
20
+ 'aatype': torch.long,
21
+ 'atom_positions': torch.double,
22
+ 'atom_mask': torch.double,
23
+ }
24
+
25
+
26
+ class ProteinFeatureTransform:
27
+ def __init__(self,
28
+ unit: Optional[str] = 'angstrom',
29
+ truncate_length: Optional[int] = None,
30
+ strip_missing_residues: bool = True,
31
+ recenter_and_scale: bool = True,
32
+ eps: float = 1e-8,
33
+ ):
34
+ if unit == 'angstrom':
35
+ self.coordinate_scale = 1.0
36
+ elif unit in ('nm', 'nanometer'):
37
+ self.coordinate_scale = 0.1
38
+ else:
39
+ raise ValueError(f"Invalid unit: {unit}")
40
+
41
+ if truncate_length is not None:
42
+ assert truncate_length > 0, f"Invalid truncate_length: {truncate_length}"
43
+ self.truncate_length = truncate_length
44
+
45
+ self.strip_missing_residues = strip_missing_residues
46
+ self.recenter_and_scale = recenter_and_scale
47
+ self.eps = eps
48
+
49
+ def __call__(self, chain_feats):
50
+ chain_feats = self.patch_feats(chain_feats)
51
+
52
+ if self.strip_missing_residues:
53
+ chain_feats = self.strip_ends(chain_feats)
54
+
55
+ if self.truncate_length is not None:
56
+ chain_feats = self.random_truncate(chain_feats, max_len=self.truncate_length)
57
+
58
+ # Recenter and scale atom positions
59
+ if self.recenter_and_scale:
60
+ chain_feats = self.recenter_and_scale_coords(chain_feats, coordinate_scale=self.coordinate_scale, eps=self.eps)
61
+
62
+ # Map to torch Tensor
63
+ chain_feats = self.map_to_tensors(chain_feats)
64
+ # Add extra features from AF2
65
+ chain_feats = self.protein_data_transform(chain_feats)
66
+
67
+ # ** refer to line 170 in pdb_data_loader.py **
68
+ return chain_feats
69
+
70
+ @staticmethod
71
+ def patch_feats(chain_feats):
72
+ seq_mask = chain_feats['atom_mask'][:, CA_IDX] # a little hack here
73
+ # residue_idx = np.arange(seq_mask.shape[0], dtype=np.int64)
74
+ residue_idx = chain_feats['residue_index'] - np.min(chain_feats['residue_index']) # start from 0, possibly has chain break
75
+ patch_feats = {
76
+ 'seq_mask': seq_mask,
77
+ 'residue_mask': seq_mask,
78
+ 'residue_idx': residue_idx,
79
+ 'fixed_mask': np.zeros_like(seq_mask),
80
+ 'sc_ca_t': np.zeros(seq_mask.shape + (3, )),
81
+ }
82
+ chain_feats.update(patch_feats)
83
+ return chain_feats
84
+
85
+ @staticmethod
86
+ def strip_ends(chain_feats):
87
+ # Strip missing residues on both ends
88
+ modeled_idx = np.where(chain_feats['aatype'] != 20)[0]
89
+ min_idx, max_idx = np.min(modeled_idx), np.max(modeled_idx)
90
+ chain_feats = tree.map_structure(
91
+ lambda x: x[min_idx : (max_idx+1)], chain_feats)
92
+ return chain_feats
93
+
94
+ @staticmethod
95
+ def random_truncate(chain_feats, max_len):
96
+ L = chain_feats['aatype'].shape[0]
97
+ if L > max_len:
98
+ # Randomly truncate
99
+ start = np.random.randint(0, L - max_len + 1)
100
+ end = start + max_len
101
+ chain_feats = tree.map_structure(
102
+ lambda x: x[start : end], chain_feats)
103
+ return chain_feats
104
+
105
+ @staticmethod
106
+ def map_to_tensors(chain_feats):
107
+ chain_feats = {k: torch.as_tensor(v) for k,v in chain_feats.items()}
108
+ # Alter dtype
109
+ for k, dtype in DTYPE_MAPPING.items():
110
+ if k in chain_feats:
111
+ chain_feats[k] = chain_feats[k].type(dtype)
112
+ return chain_feats
113
+
114
+ @staticmethod
115
+ def recenter_and_scale_coords(chain_feats, coordinate_scale, eps=1e-8):
116
+ # recenter and scale atom positions
117
+ bb_pos = chain_feats['atom_positions'][:, CA_IDX]
118
+ bb_center = np.sum(bb_pos, axis=0) / (np.sum(chain_feats['seq_mask']) + eps)
119
+ centered_pos = chain_feats['atom_positions'] - bb_center[None, None, :]
120
+ scaled_pos = centered_pos * coordinate_scale
121
+ chain_feats['atom_positions'] = scaled_pos * chain_feats['atom_mask'][..., None]
122
+ return chain_feats
123
+
124
+ @staticmethod
125
+ def protein_data_transform(chain_feats):
126
+ chain_feats.update(
127
+ {
128
+ "all_atom_positions": chain_feats["atom_positions"],
129
+ "all_atom_mask": chain_feats["atom_mask"],
130
+ }
131
+ )
132
+ chain_feats = data_transforms.atom37_to_frames(chain_feats)
133
+ chain_feats = data_transforms.atom37_to_torsion_angles("")(chain_feats)
134
+ chain_feats = data_transforms.get_backbone_frames(chain_feats)
135
+ chain_feats = data_transforms.get_chi_angles(chain_feats)
136
+ chain_feats = data_transforms.make_pseudo_beta("")(chain_feats)
137
+ chain_feats = data_transforms.make_atom14_masks(chain_feats)
138
+ chain_feats = data_transforms.make_atom14_positions(chain_feats)
139
+
140
+ # Add convenient key
141
+ chain_feats.pop("all_atom_positions")
142
+ chain_feats.pop("all_atom_mask")
143
+ return chain_feats
144
+
145
+
146
+ class MetadataFilter:
147
+ def __init__(self,
148
+ min_len: Optional[int] = None,
149
+ max_len: Optional[int] = None,
150
+ min_chains: Optional[int] = None,
151
+ max_chains: Optional[int] = None,
152
+ min_resolution: Optional[int] = None,
153
+ max_resolution: Optional[int] = None,
154
+ include_structure_method: Optional[List[str]] = None,
155
+ include_oligomeric_detail: Optional[List[str]] = None,
156
+ **kwargs,
157
+ ):
158
+ self.min_len = min_len
159
+ self.max_len = max_len
160
+ self.min_chains = min_chains
161
+ self.max_chains = max_chains
162
+ self.min_resolution = min_resolution
163
+ self.max_resolution = max_resolution
164
+ self.include_structure_method = include_structure_method
165
+ self.include_oligomeric_detail = include_oligomeric_detail
166
+
167
+ def __call__(self, df):
168
+ _pre_filter_len = len(df)
169
+ if self.min_len is not None:
170
+ df = df[df['raw_seq_len'] >= self.min_len]
171
+ if self.max_len is not None:
172
+ df = df[df['raw_seq_len'] <= self.max_len]
173
+ if self.min_chains is not None:
174
+ df = df[df['num_chains'] >= self.min_chains]
175
+ if self.max_chains is not None:
176
+ df = df[df['num_chains'] <= self.max_chains]
177
+ if self.min_resolution is not None:
178
+ df = df[df['resolution'] >= self.min_resolution]
179
+ if self.max_resolution is not None:
180
+ df = df[df['resolution'] <= self.max_resolution]
181
+ if self.include_structure_method is not None:
182
+ df = df[df['include_structure_method'].isin(self.include_structure_method)]
183
+ if self.include_oligomeric_detail is not None:
184
+ df = df[df['include_oligomeric_detail'].isin(self.include_oligomeric_detail)]
185
+
186
+ print(f">>> Metadata filter kept {len(df)} of {_pre_filter_len} samples")
187
+ return df
188
+
189
+
190
+ class RandomAccessProteinDataset(torch.utils.data.Dataset):
191
+ """Random access to pickle protein objects of dataset.
192
+
193
+ dict_keys(['atom_positions', 'aatype', 'atom_mask', 'residue_index', 'chain_index', 'b_factors'])
194
+
195
+ Note that each value is a ndarray in shape (L, *), for example:
196
+ 'atom_positions': (L, 37, 3)
197
+ """
198
+ def __init__(self,
199
+ path_to_dataset: Union[Path, str],
200
+ path_to_seq_embedding: Optional[Path] = None,
201
+ metadata_filter: Optional[MetadataFilter] = None,
202
+ training: bool = True,
203
+ transform: Optional[ProteinFeatureTransform] = None,
204
+ suffix: Optional[str] = '.pkl',
205
+ accession_code_fillter: Optional[Sequence[str]] = None,
206
+ **kwargs,
207
+ ):
208
+ super().__init__()
209
+ path_to_dataset = os.path.expanduser(path_to_dataset)
210
+ suffix = suffix if suffix.startswith('.') else '.' + suffix
211
+ assert suffix in ('.pkl', '.pdb'), f"Invalid suffix: {suffix}"
212
+
213
+ if os.path.isfile(path_to_dataset): # path to csv file
214
+ assert path_to_dataset.endswith('.csv'), f"Invalid file extension: {path_to_dataset} (have to be .csv)"
215
+ self._df = pd.read_csv(path_to_dataset)
216
+ self._df = self._df.sort_values('modeled_seq_len', ascending=False)
217
+ if metadata_filter:
218
+ self._df = metadata_filter(self._df)
219
+ self._data = self._df['processed_complex_path'].tolist()
220
+ elif os.path.isdir(path_to_dataset): # path to directory
221
+ self._data = sorted(glob(os.path.join(path_to_dataset, '*' + suffix)))
222
+ assert len(self._data) > 0, f"No {suffix} file found in '{path_to_dataset}'"
223
+ else: # path as glob pattern
224
+ _pattern = path_to_dataset
225
+ self._data = sorted(glob(_pattern))
226
+ assert len(self._data) > 0, f"No files found in '{_pattern}'"
227
+
228
+ if accession_code_fillter and len(accession_code_fillter) > 0:
229
+ self._data = [p for p in self._data
230
+ if np.isin(os.path.splitext(os.path.basename(p))[0], accession_code_fillter)
231
+ ]
232
+
233
+ self.data = np.asarray(self._data)
234
+ self.path_to_seq_embedding = os.path.expanduser(path_to_seq_embedding) \
235
+ if path_to_seq_embedding is not None else None
236
+ self.suffix = suffix
237
+ self.transform = transform
238
+ self.training = training # not implemented yet
239
+
240
+
241
+ @property
242
+ def num_samples(self):
243
+ return len(self.data)
244
+
245
+ def len(self):
246
+ return self.__len__()
247
+
248
+ def __len__(self):
249
+ return self.num_samples
250
+
251
+ def get(self, idx):
252
+ return self.__getitem__(idx)
253
+
254
+ @lru_cache(maxsize=100)
255
+ def __getitem__(self, idx):
256
+ """return single pyg.Data() instance
257
+ """
258
+ data_path = self.data[idx]
259
+ accession_code = os.path.splitext(os.path.basename(data_path))[0]
260
+
261
+ if self.suffix == '.pkl':
262
+ # Load pickled protein
263
+ with open(data_path, 'rb') as f:
264
+ data_object = pickle.load(f)
265
+ elif self.suffix == '.pdb':
266
+ # Load pdb file
267
+ with open(data_path, 'r') as f:
268
+ pdb_string = f.read()
269
+ data_object = protein.from_pdb_string(pdb_string).to_dict()
270
+
271
+ # Apply data transform
272
+ if self.transform is not None:
273
+ data_object = self.transform(data_object)
274
+
275
+ # Get sequence embedding if have
276
+ if self.path_to_seq_embedding is not None:
277
+ embed_dict = torch.load(
278
+ os.path.join(self.path_to_seq_embedding, f"{accession_code}.pt")
279
+ )
280
+ data_object.update(
281
+ {
282
+ 'seq_emb': embed_dict['representations'][33].float(),
283
+ } # 33 is for ESM650M
284
+ )
285
+
286
+ data_object['accession_code'] = accession_code
287
+ return data_object # dict of arrays
288
+
289
+
290
+
291
+ class PretrainPDBDataset(RandomAccessProteinDataset):
292
+ def __init__(self,
293
+ path_to_dataset: str,
294
+ metadata_filter: MetadataFilter,
295
+ transform: ProteinFeatureTransform,
296
+ **kwargs,
297
+ ):
298
+ super(PretrainPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
299
+ metadata_filter=metadata_filter,
300
+ transform=transform,
301
+ **kwargs,
302
+ )
303
+
304
+
305
+ class SamplingPDBDataset(RandomAccessProteinDataset):
306
+ def __init__(self,
307
+ path_to_dataset: str,
308
+ training: bool = False,
309
+ suffix: str = '.pdb',
310
+ transform: Optional[ProteinFeatureTransform] = None,
311
+ accession_code_fillter: Optional[Sequence[str]] = None,
312
+ ):
313
+ assert os.path.isdir(path_to_dataset), f"Invalid path (expected to be directory): {path_to_dataset}"
314
+ super(SamplingPDBDataset, self).__init__(path_to_dataset=path_to_dataset,
315
+ training=training,
316
+ suffix=suffix,
317
+ transform=transform,
318
+ accession_code_fillter=accession_code_fillter,
319
+ metadata_filter=None,
320
+ )
321
+
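For orientation, a minimal sketch of wiring the transform and dataset classes above together over a directory of PDB files. The directory path is hypothetical, and `analysis/` is assumed to be on `PYTHONPATH` so the `src.*` imports resolve.

from src.data.components.dataset import ProteinFeatureTransform, SamplingPDBDataset

# Hypothetical directory containing one .pdb file per structure
dataset = SamplingPDBDataset(
    path_to_dataset="./example_pdbs",
    transform=ProteinFeatureTransform(unit="angstrom", strip_missing_residues=True),
)
sample = dataset[0]
print(sample["accession_code"], sample["aatype"].shape)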
analysis/src/data/protein_datamodule.py ADDED
@@ -0,0 +1,242 @@
1
+ from typing import Any, Dict, Optional, Tuple, List, Sequence
2
+
3
+ import torch
4
+ from torch.utils.data import ConcatDataset, DataLoader, Dataset, random_split
5
+ from lightning import LightningDataModule
6
+ from hydra.utils import instantiate
7
+
8
+
9
+ class BatchTensorConverter:
10
+ """Callable to convert an unprocessed (labels + strings) batch to a
11
+ processed (labels + tensor) batch.
12
+ """
13
+ def __init__(self, target_keys: Optional[List] = None):
14
+ self.target_keys = target_keys
15
+
16
+ def __call__(self, raw_batch: Sequence[Dict[str, object]]):
17
+ B = len(raw_batch)
18
+ # Only do for Tensor
19
+ target_keys = self.target_keys \
20
+ if self.target_keys is not None else [k for k,v in raw_batch[0].items() if torch.is_tensor(v)]
21
+ # Non-array, for example string, int
22
+ non_array_keys = [k for k in raw_batch[0] if k not in target_keys]
23
+ collated_batch = dict()
24
+ for k in target_keys:
25
+ collated_batch[k] = self.collate_dense_tensors([d[k] for d in raw_batch], pad_v=0.0)
26
+ for k in non_array_keys: # return non-array keys as is
27
+ collated_batch[k] = [d[k] for d in raw_batch]
28
+ return collated_batch
29
+
30
+ @staticmethod
31
+ def collate_dense_tensors(samples: Sequence, pad_v: float = 0.0):
32
+ """
33
+ Takes a list of tensors with the following dimensions:
34
+ [(d_11, ..., d_1K),
35
+ (d_21, ..., d_2K),
36
+ ...,
37
+ (d_N1, ..., d_NK)]
38
+ and stack + pads them into a single tensor of:
39
+ (N, max_i=1,N { d_i1 }, ..., max_i=1,N {diK})
40
+ """
41
+ if len(samples) == 0:
42
+ return torch.Tensor()
43
+ if len(set(x.dim() for x in samples)) != 1:
44
+ raise RuntimeError(
45
+ f"Samples has varying dimensions: {[x.dim() for x in samples]}"
46
+ )
47
+ (device,) = tuple(set(x.device for x in samples)) # assumes all on same device
48
+ max_shape = [max(lst) for lst in zip(*[x.shape for x in samples])]
49
+ result = torch.empty(
50
+ len(samples), *max_shape, dtype=samples[0].dtype, device=device
51
+ )
52
+ result.fill_(pad_v)
53
+ for i in range(len(samples)):
54
+ result_i = result[i]
55
+ t = samples[i]
56
+ result_i[tuple(slice(0, k) for k in t.shape)] = t
57
+ return result
58
+
59
+
60
+ class ProteinDataModule(LightningDataModule):
61
+ """`LightningDataModule` for a single protein dataset,
62
+ for pretrain or finetune purpose.
63
+
64
+ ### To be revised.###
65
+
66
+ The MNIST database of handwritten digits has a training set of 60,000 examples, and a test set of 10,000 examples.
67
+ It is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a
68
+ fixed-size image. The original black and white images from NIST were size normalized to fit in a 20x20 pixel box
69
+ while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing
70
+ technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of
71
+ mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.
72
+
73
+ A `LightningDataModule` implements 7 key methods:
74
+
75
+ ```python
76
+ def prepare_data(self):
77
+ # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
78
+ # Download data, pre-process, split, save to disk, etc...
79
+
80
+ def setup(self, stage):
81
+ # Things to do on every process in DDP.
82
+ # Load data, set variables, etc...
83
+
84
+ def train_dataloader(self):
85
+ # return train dataloader
86
+
87
+ def val_dataloader(self):
88
+ # return validation dataloader
89
+
90
+ def test_dataloader(self):
91
+ # return test dataloader
92
+
93
+ def predict_dataloader(self):
94
+ # return predict dataloader
95
+
96
+ def teardown(self, stage):
97
+ # Called on every process in DDP.
98
+ # Clean up after fit or test.
99
+ ```
100
+
101
+ This allows you to share a full dataset without explaining how to download,
102
+ split, transform and process the data.
103
+
104
+ Read the docs:
105
+ https://lightning.ai/docs/pytorch/latest/data/datamodule.html
106
+ """
107
+
108
+     def __init__(
+         self,
+         dataset: torch.utils.data.Dataset,
+         batch_size: int = 64,
+         generator_seed: int = 42,
+         train_val_split: Tuple[float, float] = (0.95, 0.05),
+         num_workers: int = 0,
+         pin_memory: bool = False,
+         shuffle: bool = False,
+     ) -> None:
+         """Initialize a `ProteinDataModule`.
+
+         :param dataset: The full dataset to wrap; it is split into train/val in `setup("fit")`.
+         :param batch_size: The batch size. Defaults to `64`.
+         :param generator_seed: Seed for the random train/val split. Defaults to `42`.
+         :param train_val_split: Fractions for the train/val split. Defaults to `(0.95, 0.05)`.
+         :param num_workers: The number of dataloader workers. Defaults to `0`.
+         :param pin_memory: Whether to pin memory. Defaults to `False`.
+         :param shuffle: Whether the dataloaders shuffle samples. Defaults to `False`.
+         """
+         super().__init__()
+
+         # this line allows accessing init params with the 'self.hparams' attribute
+         # and also ensures init params will be stored in the ckpt
+         self.save_hyperparameters(logger=False)
+
+         self.dataset = dataset
+
+         self.data_train: Optional[Dataset] = None
+         self.data_val: Optional[Dataset] = None
+         self.data_test: Optional[Dataset] = None
+
+         self.batch_size_per_device = batch_size
+
+     def prepare_data(self) -> None:
+         """Download data if needed. Lightning ensures that `self.prepare_data()` is called only
+         within a single process on CPU, so you can safely add your downloading logic within. In
+         case of multi-node training, the execution of this hook depends upon
+         `self.prepare_data_per_node()`.
+
+         Do not use it to assign state (self.x = y).
+         """
+         pass
+
+     def setup(self, stage: Optional[str] = None) -> None:
+         """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.
+
+         This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
+         `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
+         `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
+         `self.setup()` once the data is prepared and available for use.
+
+         :param stage: The stage to set up. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
+         """
+         # Divide batch size by the number of devices.
+         if self.trainer is not None:
+             if self.hparams.batch_size % self.trainer.world_size != 0:
+                 raise RuntimeError(
+                     f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
+                 )
+             self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size
+
+         # Load and split the dataset only if it has not been loaded already.
+         if stage == 'fit':
+             if not self.data_train and not self.data_val:
+                 self.data_train, self.data_val = random_split(
+                     dataset=self.dataset,
+                     lengths=self.hparams.train_val_split,
+                     generator=torch.Generator().manual_seed(self.hparams.generator_seed),
+                 )
+         elif stage in ('predict', 'test'):
+             self.data_test = self.dataset
+         else:
+             raise NotImplementedError(f"Stage {stage} not implemented.")
+
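A small, purely illustrative sketch of the two pieces of logic in `setup`: the seeded fractional `random_split` and the division of the global batch size across devices (the numbers are made up):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset split 95/5 with a fixed generator seed, mirroring the fractional
# `train_val_split` and `generator_seed` hyperparameters above.
toy = TensorDataset(torch.arange(100))
train, val = random_split(toy, lengths=(0.95, 0.05), generator=torch.Generator().manual_seed(42))
print(len(train), len(val))  # 95 5

# The global batch size is divided evenly across devices, e.g. 64 samples on 4 GPUs.
batch_size, world_size = 64, 4
assert batch_size % world_size == 0
print(batch_size // world_size)  # 16 per device
```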
+     def _dataloader_template(self, dataset: Dataset[Any]) -> DataLoader[Any]:
+         """Create a dataloader from a dataset.
+
+         :param dataset: The dataset.
+         :return: The dataloader.
+         """
+         batch_collator = BatchTensorConverter()  # list of dicts -> dict of tensors
+         return DataLoader(
+             dataset=dataset,
+             collate_fn=batch_collator,
+             batch_size=self.batch_size_per_device,
+             num_workers=self.hparams.num_workers,
+             pin_memory=self.hparams.pin_memory,
+             shuffle=self.hparams.shuffle,
+         )
+
+     def train_dataloader(self) -> DataLoader[Any]:
+         """Create and return the train dataloader.
+
+         :return: The train dataloader.
+         """
+         return self._dataloader_template(self.data_train)
+
+     def val_dataloader(self) -> DataLoader[Any]:
+         """Create and return the validation dataloader.
+
+         :return: The validation dataloader.
+         """
+         return self._dataloader_template(self.data_val)
+
+     def test_dataloader(self) -> DataLoader[Any]:
+         """Create and return the test dataloader.
+
+         :return: The test dataloader.
+         """
+         return self._dataloader_template(self.data_test)
+
+     def teardown(self, stage: Optional[str] = None) -> None:
+         """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
+         `trainer.test()`, and `trainer.predict()`.
+
+         :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
+             Defaults to ``None``.
+         """
+         pass
+
+     def state_dict(self) -> Dict[Any, Any]:
+         """Called when saving a checkpoint. Implement to generate and save the datamodule state.
+
+         :return: A dictionary containing the datamodule state that you want to save.
+         """
+         return {}
+
+     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+         """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
+         `state_dict()`.
+
+         :param state_dict: The datamodule state returned by `self.state_dict()`.
+         """
+         pass
+
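To make the datamodule's lifecycle concrete, here is a minimal usage sketch. `MyProteinDataset` is a hypothetical stand-in for the repository's actual dataset class (see `analysis/src/data/components/dataset.py`) and is only assumed to return dicts mixing tensor and non-tensor values:

```python
import torch
from torch.utils.data import Dataset

class MyProteinDataset(Dataset):            # hypothetical stand-in dataset
    def __len__(self):
        return 100
    def __getitem__(self, idx):
        L = 50 + idx % 10                   # variable-length "proteins"
        return {"ca_coords": torch.randn(L, 3), "name": f"protein_{idx}"}

dm = ProteinDataModule(
    dataset=MyProteinDataset(),
    batch_size=8,
    train_val_split=(0.95, 0.05),
    shuffle=True,
)
dm.setup(stage="fit")                       # splits into train/val
batch = next(iter(dm.train_dataloader()))
print(batch["ca_coords"].shape)             # (8, max_len_in_batch, 3), zero-padded
print(batch["name"][:2])                    # non-tensor fields come back as plain lists
```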
analysis/src/eval.py ADDED
@@ -0,0 +1,217 @@
+ from typing import Any, Dict, List, Tuple
+ import os
+ from time import strftime
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ # import hydra
+ # import rootutils
+ # from lightning import LightningDataModule, LightningModule, Trainer
+ # from lightning.pytorch.loggers import Logger
+ from omegaconf import DictConfig
+
+ # rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
+
+ # ------------------------------------------------------------------------------------ #
+ # the setup_root above is equivalent to:
+ # - adding project root dir to PYTHONPATH
+ #       (so you don't need to force user to install project as a package)
+ #       (necessary before importing any local modules e.g. `from src import utils`)
+ # - setting up PROJECT_ROOT environment variable
+ #       (which is used as a base for paths in "configs/paths/default.yaml")
+ #       (this way all filepaths are the same no matter where you run the code)
+ # - loading environment variables from ".env" in root dir
+ #
+ # you can remove it if you:
+ # 1. either install project as a package or move entry files to project root dir
+ # 2. set `root_dir` to "." in "configs/paths/default.yaml"
+ #
+ # more info: https://github.com/ashleve/rootutils
+ # ------------------------------------------------------------------------------------ #
+
+ from src.utils import (
+     RankedLogger,
+     extras,
+     instantiate_loggers,
+     log_hyperparameters,
+     task_wrapper,
+     checkpoint_utils,
+     plot_utils,
+ )
+ from src.common.pdb_utils import extract_backbone_coords
+ from src.metrics import metrics
+ from src.common.geo_utils import _find_rigid_alignment
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+ def evaluate_prediction(pred_dir: str, target_dir: str = None, crystal_dir: str = None, tag: str = None):
+     """Evaluate prediction results based on PDB files.
+     """
+     if target_dir is None or not os.path.isdir(target_dir):
+         log.warning(f"target_dir {target_dir} does not exist. Skipping evaluation.")
+         return {}
+
+     assert os.path.isdir(pred_dir), f"pred_dir {pred_dir} is not a directory."
+
+     targets = [
+         d.replace(".pdb", "") for d in os.listdir(target_dir)
+     ]
+     # pred_bases = os.listdir(pred_dir)
+     output_dir = pred_dir
+     tag = tag if tag is not None else "dev"
+     timestamp = strftime("%m%d-%H-%M")
+
+     fns = {
+         'val_clash': metrics.validity,
+         'val_bond': metrics.bonding_validity,
+         'js_pwd': metrics.js_pwd,
+         'js_rg': metrics.js_rg,
+         # 'js_tica_pos': metrics.js_tica_pos,
+         'w2_rmwd': metrics.w2_rmwd,
+         # 'div_rmsd': metrics.div_rmsd,
+         'div_rmsf': metrics.div_rmsf,
+         'pro_w_contacts': metrics.pro_w_contacts,
+         'pro_t_contacts': metrics.pro_t_contacts,
+         # 'pro_c_contacts': metrics.pro_c_contacts,
+     }
+     eval_res = {k: {} for k in fns}
+
+     print(f"total_md_num = {len(targets)}")
+     count = 0
+     for target in targets:
+         count += 1
+         print("")
+         print(count, target)
+         pred_file = os.path.join(pred_dir, f"{target}.pdb")
+         # assert os.path.isfile(pred_file), f"pred_file {pred_file} does not exist."
+         if not os.path.isfile(pred_file):
+             continue
+
+         target_file = os.path.join(target_dir, f"{target}.pdb")
+         ca_coords = {
+             'target': extract_backbone_coords(target_file),
+             'pred': extract_backbone_coords(pred_file),
+         }
+         cry_target_file = os.path.join(crystal_dir, f"{target}.pdb")
+         cry_ca_coords = extract_backbone_coords(cry_target_file)[0]
+
+         for f_name, func in fns.items():
+             print(f_name)
+
+             if f_name == 'w2_rmwd':
+                 # Rigidly align every frame onto the first target frame before computing the metric.
+                 v_ref = torch.as_tensor(ca_coords['target'][0])
+                 for k, v in ca_coords.items():
+                     v = torch.as_tensor(v)  # (n_frames, n_residues, 3), e.g. (250, 356, 3)
+                     for idx in range(v.shape[0]):
+                         R, t = _find_rigid_alignment(v[idx], v_ref)
+                         v[idx] = (torch.matmul(R, v[idx].transpose(-2, -1))).transpose(-2, -1) + t.unsqueeze(0)
+                     ca_coords[k] = v.numpy()
+
+             if f_name.startswith('js_'):
+                 res = func(ca_coords, ref_key='target')
+             elif f_name == 'pro_c_contacts':
+                 res = func(target_file, pred_file, cry_target_file)
+             elif f_name.startswith('pro_'):
+                 res = func(ca_coords, cry_ca_coords)
+             else:
+                 res = func(ca_coords)
+
+             if f_name == 'js_tica' or f_name == 'js_tica_pos':
+                 pass
+                 # eval_res[f_name][target] = res[0]['pred']
+                 # save_to = os.path.join(output_dir, f"tica_{target}_{tag}_{timestamp}.png")
+                 # plot_utils.scatterplot_2d(res[1], save_to=save_to, ref_key='target')
+             else:
+                 eval_res[f_name][target] = res['pred']
+
+     csv_save_to = os.path.join(output_dir, f"metrics_{tag}_{timestamp}.csv")
+     df = pd.DataFrame.from_dict(eval_res)  # row = target, col = metric name
+     df.to_csv(csv_save_to)
+     print(f"metrics saved to {csv_save_to}")
+     mean_metrics = np.around(df.mean(), decimals=4)
+
+     return mean_metrics
+
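A minimal usage sketch for `evaluate_prediction`; the directory paths and tag below are hypothetical, and each directory is assumed to contain one `<target>.pdb` file per protein (predicted ensemble, reference MD ensemble, and crystal structure, respectively):

```python
# Hypothetical paths; each directory holds one <target>.pdb per protein.
mean_metrics = evaluate_prediction(
    pred_dir="outputs/predictions",    # generated ensembles
    target_dir="data/md_ensembles",    # reference MD ensembles
    crystal_dir="data/crystal",        # crystal structures used by the pro_* metrics
    tag="p2dflow_eval",
)
print(mean_metrics)  # per-metric means; the per-target table is written to a CSV in pred_dir
```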
+
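The `w2_rmwd` branch above rigidly aligns every sampled frame onto the first target frame before computing the metric. The repository's `_find_rigid_alignment` helper lives in `analysis/src/common/geo_utils.py` (not shown here); the snippet below is only a generic Kabsch-style sketch of what such a helper typically returns, a rotation `R` and translation `t` with the same calling convention as used above.

```python
import math
import torch

def kabsch_alignment(A: torch.Tensor, B: torch.Tensor):
    """Generic Kabsch-style rigid alignment: returns (R, t) such that R @ A[i] + t ~= B[i].

    Illustration only -- the repo's own `_find_rigid_alignment` may differ in details.
    """
    a_mean, b_mean = A.mean(dim=0), B.mean(dim=0)
    A_c, B_c = A - a_mean, B - b_mean
    H = A_c.T @ B_c                           # 3x3 cross-covariance
    U, S, Vt = torch.linalg.svd(H)
    d = torch.sign(torch.det(Vt.T @ U.T))     # guard against reflections
    D = torch.eye(3)
    D[2, 2] = d
    R = Vt.T @ D @ U.T
    t = b_mean - R @ a_mean
    return R, t

# Aligning a rotated-and-shifted copy back onto the original should recover it.
c, s = math.cos(0.3), math.sin(0.3)
Rz = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
B = torch.randn(50, 3)
A = (Rz @ B.T).T + torch.tensor([1.0, 2.0, 3.0])
R, t = kabsch_alignment(A, B)
print(torch.allclose((R @ A.T).T + t, B, atol=1e-4))  # True
```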
+ # @task_wrapper
+ # def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ #     """Sample on a test set and report evaluation metrics.
+
+ #     This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
+ #     failure. Useful for multiruns, saving info about the crash, etc.
+
+ #     :param cfg: DictConfig configuration composed by Hydra.
+ #     :return: Tuple[dict, dict] with metrics and dict with all instantiated objects.
+ #     """
+ #     # assert cfg.ckpt_path
+ #     pred_dir = cfg.get("pred_dir")
+ #     if pred_dir and os.path.isdir(pred_dir):
+ #         log.info(f"Found pre-computed prediction directory {pred_dir}.")
+ #         metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
+ #         return metric_dict, None
+
+ #     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+ #     datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
+
+ #     log.info(f"Instantiating model <{cfg.model._target_}>")
+ #     model: LightningModule = hydra.utils.instantiate(cfg.model)
+
+ #     log.info("Instantiating loggers...")
+ #     logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
+
+ #     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+ #     trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=logger)
+
+ #     object_dict = {
+ #         "cfg": cfg,
+ #         "datamodule": datamodule,
+ #         "model": model,
+ #         "logger": logger,
+ #         "trainer": trainer,
+ #     }
+
+ #     if logger:
+ #         log.info("Logging hyperparameters!")
+ #         log_hyperparameters(object_dict)
+
+ #     # Load checkpoint manually.
+ #     model, ckpt_path = checkpoint_utils.load_model_checkpoint(model, cfg.ckpt_path)
+
+ #     # log.info("Starting testing!")
+ #     # trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
+
+ #     # Get dataloader for prediction.
+ #     datamodule.setup(stage="predict")
+ #     dataloaders = datamodule.test_dataloader()
+
+ #     log.info("Starting predictions.")
+ #     pred_dir = trainer.predict(model=model, dataloaders=dataloaders, ckpt_path=ckpt_path)[-1]
+
+ #     # metric_dict = trainer.callback_metrics
+ #     log.info("Starting evaluations.")
+ #     metric_dict = evaluate_prediction(pred_dir, target_dir=cfg.target_dir)
+
+ #     return metric_dict, object_dict
+
+
+ # @hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+ # def main(cfg: DictConfig) -> None:
+ #     """Main entry point for evaluation.
+
+ #     :param cfg: DictConfig configuration composed by Hydra.
+ #     """
+ #     # apply extra utilities
+ #     # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
+ #     extras(cfg)
+
+ #     evaluate(cfg)
+
+
+ # if __name__ == "__main__":
+ #     main()