initial commit

This commit is contained in:
Rohith Krishna 2024-03-04 22:38:17 -08:00
commit f87f5b8cdf
98 changed files with 26515 additions and 0 deletions

15
.gitignore vendored Normal file
View file

@ -0,0 +1,15 @@
valid_remapped
lig_test
dataset.pkl
run_digs.sh
*.pdb
.vscode
slurm_logs/
*/output/
*/notebooks/
*/models/
__pycache__/
*/run_scripts/
unit_tests/
ruff.toml
*/scratch/

254
README.md Normal file
View file

@ -0,0 +1,254 @@
Code for RoseTTAFold All-Atom
--------------------
<p align="right">
<img style="float: right" src="./img/RFAA.png" alt="alt text" width="600px" align="right"/>
</p>
RoseTTAFold All-Atom is a biomolecular structure prediction neural network that can predict a broad range of biomolecular assemblies including proteins, nucleic acids, small molecules, covalent modifications and metals as outlined in the RFAA paper.
RFAA is not accurate for all cases, but produces useful error estimates to allow users to identify accurate predictions. Below are the instructions for setting up and using the model.
## Table of Contents
- [Setup/Installation](#set-up)
- [Inference Configs Using Hydra](#inference-config)
- [Predicting protein structures](#protein-pred)
- [Predicting protein/nucleic acid complexes](#p-na-complex)
- [Predicting protein/small molecule complexes](#p-sm-complex)
- [Predicting higher order complexes](#higher-order)
- [Predicting covalently modified proteins](#covale)
- [Understanding model outputs](#outputs)
- [Conclusion](#conclusion)
<a id="set-up"></a>
### Setup/Installation
1. Clone the package
```
git clone https://github.com/baker-laboratory/RoseTTAFold-All-Atom
cd RF2-allatom
```
2. Download the container used to run RFAA.
```
wget http://files.ipd.uw.edu/pub/RF-All-Atom/containers/SE3nv-20240131.sif
```
3. Download the model weights.
```
wget http://files.ipd.uw.edu/pub/RF-All-Atom/weights/RFAA_paper_weights.pt
```
4. Download sequence databases for MSA and template generation.
```
# uniref30 [46G]
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz
mkdir -p UniRef30_2020_06
tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06
# BFD [272G]
wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
mkdir -p bfd
tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
# structure templates (including *_a3m.ffdata, *_a3m.ffindex)
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz
tar xfz pdb100_2021Mar03.tar.gz
```
<a id="inference-config"></a>
### Inference Configs Using Hydra
We use a library called Hydra to compose config files for predictions. The actual script that runs the model is in `rf2aa/run_inference.py` and default parameters that were used to train the model are in `rf2aa/config/inference/base.yaml`. We highly suggest using the default parameters since those are closest to the training task for RFAA but we have found that increasing loader_params.MAXCYCLE=10 (default set to 4) gives better results for hard cases (as noted in the paper).
We use a container system called apptainers which have very simple syntax. Instead of developing a local conda environment, users can use the apptainer to run the model which has all the dependencies already packaged.
The general way to run the model is as follows:
```
SE3nv-20240131.sif -m rf2aa.run_inference --config-name {your inference config}
```
The main inputs into the model are split into:
- protein inputs (protein_inputs)
- nucleic acid inputs (na_inputs)
- small molecule inputs (sm_inputs)
- covalent bonds between protein chains and small molecule chains
- modified or unnatural amino acids (COMING SOON)
In the following sections, we will describe how to set up configs for different prediction tasks that we described in the paper.
<a id="protein-pred"></a>
### Predicting Protein Monomers
Predicting a protein monomer structure requires an input fasta file and an optional job_name which will be used to name your output files. Here is a sample config (also in `rf2aa/config/inference/protein.yaml`).
```
defaults:
- base
job_name: "7u7w_protein"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta
```
The first line indicates that this job inherits all the configurations from the base file (this should be true for all your inference jobs). Then you can optionally specify the job name (the default job_name is "structure_prediction" so we highly recommend specifying one).
When specifying the fasta file for your protein, you might notice that it is nested within a mysterious "A" parameter. This represents a chain letter and is absolutely **required**, this is important when users are specifying multiple chains.
Now to predict the sample monomer structure, run:
```
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein
```
<a id="p-na-complex"></a>
### Predicting Protein Nucleic Acid Complexes
Protein-nucleic acid complexes have very similar syntax to protein monomer prediction, except with additional chains for nucleic acids. Here is sample config (also in `rf2aa/config/inference/nucleic_acid.yaml`):
```
defaults:
- base
job_name: "7u7w_protein_nucleic"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta
na_inputs:
B:
fasta: examples/nucleic_acid/7u7w_B.fasta
input_type: "dna"
C:
fasta: examples/nucleic_acid/7u7w_C.fasta
input_type: "dna"
```
Once again this config inherits the base config, defines a job name and provides a protein fasta file for chain A. To add double stranded DNA, you must add two more chain inputs for each strand (shown here as chains B and C). In this case, the allowed input types are dna and rna.
This repo currently does not support making RNA MSAs or pairing protein MSAs with RNA MSAs but this is functionality that we are keen to add. For now, please use RF-NA for modeling cases requiring paired protein-RNA MSAs.
Now, predict the example protein/NA complex.
```
SE3nv-20240131.sif -m rf2aa.run_inference --config-name nucleic_acid
```
<a id="p-sm-complex"></a>
### Predicting Protein Small Molecule Complexes
To predict protein small molecule complexes, the syntax to input the protein remains the same. Adding in the small molecule works similarly to other inputs.
Here is an example (from `rf2aa/config/inference/protein_sm.yaml`):
```
defaults:
- base
job_name: 7qxr
protein_inputs:
A:
fasta_file: examples/protein/7qxr.fasta
sm_inputs:
B:
input: examples/small_molecule/NSW_ideal.sdf
input_type: "sdf"
```
Small molecule inputs are provided as sdf files or smiles strings and users are **required** to provide both an input and an input_type field for every small molecule that they want to provide. Metal ions can also be provided as sdf files or smiles strings.
To predict the example:
```
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein_sm
```
<a id="higher-order"></a>
### Predicting Higher Order Complexes
If you have been following thus-far, this is where we put all the previous sections together! To predict a protein-nucleic acid-small molecule complex, you can combine the schema for all the inputs we have seen so far!
Here is an example:
```
defaults:
- base
job_name: "7u7w_protein_nucleic_sm"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta
na_inputs:
B:
fasta: examples/nucleic_acid/7u7w_B.fasta
input_type: "dna"
C:
fasta: examples/nucleic_acid/7u7w_C.fasta
input_type: "dna"
sm_inputs:
D:
input: examples/small_molecule/XG4.sdf
input_type: "sdf"
```
And to run:
```
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein_na_sm
```
<a id="covale"></a>
### Predicting Covalently Modified Proteins
Specifying covalent modifications is slightly more complicated for the following reasons.
- Forming new covalent bonds can create or remove chiral centers. Since RFAA specifies chirality at input, the network needs to be provided with chirality information. Under the hood, chiral centers are identified by a package called Openbabel which does not always agree with chemical intuition.
- Covalent modifications often have "leaving groups", or chemical groups that leave both the protein and the modification upon modification.
The way you input covalent bonds to RFAA is as a list of bonds between an atom on the protein and an atom on one of the input small molecules. This is the syntax for those bonds:
```
(protein_chain, residue_number, atom_name), (small_molecule_chain, atom_index), (new_chirality_atom_1, new_chirality_atom_2)
```
**Both the protein residue number and the atom_index are 1 indexed** (as you would normally count, as opposed to 0 indexed like many programming languages).
In most cases, the chirality of the atoms will not change. This is what an input for a case where the chirality does not change looks like:
```
(protein_chain, residue_number, atom_name), (small_molecule_chain, atom_index), ("null", "null")
```
The options for chirality are `CCW` and `CW` for counterclockwise and clockwise. The code will raise an Exception is there is a chiral center that Openbabel found that the user did not specify. Even if you believe Openbabel is wrong, the network likely received chirality information for those cases during training, so we expect that you will get the best results by specifying chirality at those positions.
**You cannot define bonds between two small molecule chains**. In cases, where the PDB defines molecules in "multiple residues", you must merge the residues into a single sdf file first.
**You must remove any leaving groups from your input molecules before inputting them into the network, but the code will handle leaving groups on the sidechain that is being modified automatically.** There is code for providing leaving group dynamically from the hydra config, but that is experimental and we have not fully tested it.
Given all of that background, this is how you specify covalent modification structure prediction to RFAA.
```
defaults:
- base
job_name: 7s69_A
protein_inputs:
A:
fasta_file: examples/protein/7s69_A.fasta
sm_inputs:
B:
input: examples/small_molecule/7s69_glycan.sdf
input_type: sdf
covale_inputs: "[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
loader_params:
MAXCYCLE: 10
```
**For covalently modified proteins, you must provide the input molecule as a sdf file**, since openbabel does not read smiles strings in a specific order. The syntax shown is identical to loading a protein and small molecule and then indicating a bond between them. In this case, hydra creates some problems because we have to escape the quotation marks using backslashes.
To clarify, this input:
```
[(("A", "74", "ND2"), ("B", "1"), ("CW", "null"))]
```
becomes this so it can be parsed correctly:
```
"[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
```
We know this syntax is hard to work with and we are happy to review PRs if anyone in the community can figure out how to specify all the necessary requirements in a more user friendly way!
<a id="outputs"></a>
### Understanding model outputs
The model returns two files:
- PDB file with predicted structure (bfactors represent predicted lddt at each position)
- pytorch file with confidence metrics stored (can load with `torch.load(file, map_location="cpu")`)
Here are the confidence metrics:
1. plddts, tensor with node-wise plddt for each node in the prediction
2. pae, a LxL tensor where the model predicts the error of every j position if the ith position's frame is aligned (or atom frame for atom nodes)
3. pde, a LxL tensor where the model predicts the unsigned error of the each pairwise distance
4. mean_plddt, the mean over all the plddts
5. mean_pae, the mean over all pairwise predicted aligned errors
6. pae_prot, the mean over all pairwise protein residues
7. pae_inter, the mean over all the errors of protein residues with respect to atom frames and atom coordinates with respect to protein frames. **This was the primary confidence metric we used in the paper and expect cases with pae_inter <10 to have high quality docks.**
<a id="conclusion"></a>
### Conclusion
We expect that RFAA will continue to improve and will share new models as we create them. Additionally, we are excited to see how the community uses RFAA and RFdiffusionAA and would love to get feedback and review PRs as necessary.

0
__init__.py Normal file
View file

View file

@ -0,0 +1,2 @@
>7U7W_2|Chain B[auth T]|DNA (5'-D(*CP*AP*TP*TP*AP*TP*GP*AP*CP*GP*CP*T)-3')|synthetic construct (32630)
CATTATGACGCT

View file

@ -0,0 +1,2 @@
>7U7W_3|Chain C[auth P]|DNA (5'-D(*AP*GP*CP*GP*TP*CP*AP*T)-3')|synthetic construct (32630)
AGCGTCAT

View file

@ -0,0 +1,2 @@
>7QXR_1|Chains A, B, C|Fragment transplantation onto a hyperstable ancestor of haloalkane dehalogenases and Renilla luciferase (Anc-FT)|synthetic construct (32630)
TATGDEWWAKCKQVDVLDSEMSYYDSDPGKHKNTVIFLHGNPTSSYLWRNVIPHVEPLARCLAPDLIGMGKSGKLPNHSYRFVDHYRYLSAWFDSVNLPEKVTIVCHDWGSGLGFHWCNEHRDRVKGIVHMESVVDVIESWDEWPDIEEDIALIKSEAGEEMVLKKNFFIERLLPSSIMRKLSEEEMDAYREPFVEPGESRRPTLTWPREIPIKGDGPEDVIEIVKSYNKWLSTSKDIPKLFINADPGFFSNAIKKVTKNWPNQKTVTVKGLHFLQEDSPEEIGEAIADFLNELT

View file

@ -0,0 +1,2 @@
>7S69_1|Chains A, B|N-acetylglucosamine-1-phosphotransferase gamma subunit|Xenopus laevis (8355)
DRHHHHHHKLGKMKIVEEPNSFGLNNPFLSQTNKLQPRVQPSPVSGPSHLFRLAGKCFNLVESTYKYELCPFHNVTQHEQTFRWNAYSGILGIWQEWDIENNTFSGMWMREGDSCGNKNRQTKVLLVCGKANKLSSVSEPSTCLYSLTFETPLVCHPHSLLVYPTLSEGLQEKWNEAEQALYDELITEQGHGKILKEIFREAGYLKTTKPDGEGKETQDKPKEFDSLEKCNKGYTELTSEIQRLKKMLNEHGISYVTNGTSRSEGQPAEVNTTFARGEDKVHLRGDTGIRDGQ

View file

@ -0,0 +1,2 @@
>7U7W_1|Chain A|DNA polymerase eta|Homo sapiens (9606)
GPHMATGQDRVVALVDMDCFFVQVEQRQNPHLRNKPCAVVQYKSWKGGGIIAVSYEARAFGVTRSMWADDAKKLCPDLLLAQVRESRGKANLTKYREASVEVMEIMSRFAVIERASIDEAYVDLTSAVQERLQKLQGQPISADLLPSTYIEGLPQGPTTAEETVQKEGMRKQGLFQWLDSLQIDNLTSPDLQLTVGAVIVEEMRAAIERETGFQCSAGISHNKVLAKLACGLNKPNRQTLVSHGSVPQLFSQMPIRKIRSLGGKLGASVIEILGIEYMGELTQFTESQLQSHFGEKNGSWLYAMCRGIEHDPVKPRQLPKTIGCSKNFPGKTALATREQVQWWLLQLAQELEERLTKDRNDNDRVATQLVVSIRVQGDKRLSSLRRCCALTRYDAHKMSHDAFTVIKNCNTSGIQTEWSPPLTMLFLCATKFSAS

View file

@ -0,0 +1,155 @@
OpenBabel03042416223D
72 77 0 0 1 0 0 0 0 0999 V2000
29.7340 3.2540 76.7430 C 0 0 0 0 0 2 0 0 0 0 0 0
29.8160 4.4760 77.6460 C 0 0 1 0 0 3 0 0 0 0 0 0
28.5260 5.2840 77.5530 C 0 0 2 0 0 3 0 0 0 0 0 0
28.1780 5.5830 76.1020 C 0 0 1 0 0 3 0 0 0 0 0 0
28.2350 4.3240 75.2420 C 0 0 1 0 0 3 0 0 0 0 0 0
28.1040 4.6170 73.7650 C 0 0 0 0 0 2 0 0 0 0 0 0
31.3020 3.8250 79.4830 C 0 0 0 0 0 0 0 0 0 0 0 0
31.3910 3.4410 80.9280 C 0 0 0 0 0 1 0 0 0 0 0 0
30.0760 4.0880 79.0210 N 0 0 0 0 0 2 0 0 0 0 0 0
28.6870 6.5050 78.2670 O 0 0 0 0 0 1 0 0 0 0 0 0
26.8490 6.0910 76.0350 O 0 0 0 0 0 0 0 0 0 0 0 0
29.4950 3.6650 75.4130 O 0 0 0 0 0 0 0 0 0 0 0 0
29.3670 4.5550 73.1150 O 0 0 0 0 0 1 0 0 0 0 0 0
32.2950 3.8940 78.7640 O 0 0 0 0 0 0 0 0 0 0 0 0
26.7420 7.4140 75.6950 C 0 0 1 0 0 3 0 0 0 0 0 0
25.2700 7.7830 75.6110 C 0 0 1 0 0 3 0 0 0 0 0 0
25.1290 9.2300 75.1610 C 0 0 2 0 0 3 0 0 0 0 0 0
25.9180 10.1440 76.0880 C 0 0 1 0 0 3 0 0 0 0 0 0
27.3630 9.6720 76.2210 C 0 0 1 0 0 3 0 0 0 0 0 0
28.1310 10.4360 77.2730 C 0 0 0 0 0 2 0 0 0 0 0 0
23.8820 5.8170 75.1400 C 0 0 0 0 0 0 0 0 0 0 0 0
23.1980 5.0100 74.0810 C 0 0 0 0 0 1 0 0 0 0 0 0
24.5530 6.8930 74.7160 N 0 0 0 0 0 2 0 0 0 0 0 0
23.7530 9.5950 75.1670 O 0 0 0 0 0 1 0 0 0 0 0 0
25.9170 11.4700 75.5730 O 0 0 0 0 0 0 0 0 0 0 0 0
27.4050 8.2900 76.6040 O 0 0 0 0 0 0 0 0 0 0 0 0
29.5300 10.4030 77.0280 O 0 0 0 0 0 1 0 0 0 0 0 0
23.8300 5.5110 76.3290 O 0 0 0 0 0 0 0 0 0 0 0 0
25.3940 12.4250 76.4090 C 0 0 1 0 0 3 0 0 0 0 0 0
25.9490 13.7680 75.9090 C 0 0 2 0 0 3 0 0 0 0 0 0
25.1320 14.9560 76.4900 C 0 0 2 0 0 3 0 0 0 0 0 0
23.6130 14.6900 76.6390 C 0 0 1 0 0 3 0 0 0 0 0 0
23.3700 13.3000 77.2280 C 0 0 1 0 0 3 0 0 0 0 0 0
21.9020 12.9360 77.3500 C 0 0 0 0 0 2 0 0 0 0 0 0
25.9010 13.8490 74.4810 O 0 0 0 0 0 1 0 0 0 0 0 0
25.3420 16.1410 75.7110 O 0 0 0 0 0 0 0 0 0 0 0 0
23.0420 15.6520 77.5170 O 0 0 0 0 0 1 0 0 0 0 0 0
23.9910 12.3690 76.3570 O 0 0 0 0 0 0 0 0 0 0 0 0
21.3660 12.8480 76.0500 O 0 0 0 0 0 0 0 0 0 0 0 0
20.8090 11.6500 75.6780 C 0 0 2 0 0 3 0 0 0 0 0 0
20.6800 11.6410 74.1740 C 0 0 2 0 0 3 0 0 0 0 0 0
19.5510 12.5850 73.8180 C 0 0 2 0 0 3 0 0 0 0 0 0
18.2370 12.0940 74.4540 C 0 0 1 0 0 3 0 0 0 0 0 0
18.4030 11.9240 75.9810 C 0 0 1 0 0 3 0 0 0 0 0 0
17.2710 11.1260 76.6120 C 0 0 0 0 0 2 0 0 0 0 0 0
20.2900 10.3510 73.7080 O 0 0 0 0 0 1 0 0 0 0 0 0
19.4280 12.7380 72.4110 O 0 0 0 0 0 0 0 0 0 0 0 0
17.2120 13.0460 74.2030 O 0 0 0 0 0 1 0 0 0 0 0 0
19.6260 11.2000 76.3010 O 0 0 0 0 0 0 0 0 0 0 0 0
16.0670 11.4490 75.9360 O 0 0 0 0 0 1 0 0 0 0 0 0
20.2190 13.6280 71.7260 C 0 0 2 0 0 3 0 0 0 0 0 0
19.6090 14.0000 70.3810 C 0 0 2 0 0 3 0 0 0 0 0 0
19.6360 12.7820 69.4880 C 0 0 2 0 0 3 0 0 0 0 0 0
21.0860 12.3100 69.3240 C 0 0 1 0 0 3 0 0 0 0 0 0
21.7030 12.0240 70.7120 C 0 0 1 0 0 3 0 0 0 0 0 0
23.1940 11.7460 70.6620 C 0 0 0 0 0 2 0 0 0 0 0 0
20.4080 14.9810 69.7000 O 0 0 0 0 0 1 0 0 0 0 0 0
19.0310 13.0500 68.2340 O 0 0 0 0 0 1 0 0 0 0 0 0
21.1060 11.1280 68.5380 O 0 0 0 0 0 1 0 0 0 0 0 0
21.5380 13.1700 71.5840 O 0 0 0 0 0 0 0 0 0 0 0 0
23.8240 12.5210 71.6820 O 0 0 0 0 0 1 0 0 0 0 0 0
26.0070 17.3020 76.0200 C 0 0 2 0 0 3 0 0 0 0 0 0
27.0750 17.5250 74.9350 C 0 0 2 0 0 3 0 0 0 0 0 0
28.3660 16.8320 75.3290 C 0 0 2 0 0 3 0 0 0 0 0 0
28.7820 17.2470 76.7510 C 0 0 1 0 0 3 0 0 0 0 0 0
27.6930 16.8120 77.7320 C 0 0 1 0 0 3 0 0 0 0 0 0
27.9770 17.2020 79.1710 C 0 0 0 0 0 2 0 0 0 0 0 0
27.3990 18.9140 74.8010 O 0 0 0 0 0 1 0 0 0 0 0 0
29.4060 17.0990 74.3950 O 0 0 0 0 0 1 0 0 0 0 0 0
30.0160 16.6410 77.0930 O 0 0 0 0 0 1 0 0 0 0 0 0
26.4610 17.4820 77.3520 O 0 0 0 0 0 0 0 0 0 0 0 0
27.3660 18.4620 79.4040 O 0 0 0 0 0 1 0 0 0 0 0 0
1 2 1 0 0 0 0
1 12 1 0 0 0 0
2 3 1 0 0 0 0
2 9 1 1 0 0 0
3 10 1 1 0 0 0
3 4 1 0 0 0 0
4 5 1 0 0 0 0
4 11 1 1 0 0 0
5 6 1 6 0 0 0
5 12 1 0 0 0 0
6 13 1 0 0 0 0
7 14 2 0 0 0 0
7 8 1 0 0 0 0
7 9 1 0 0 0 0
15 16 1 0 0 0 0
15 11 1 1 0 0 0
15 26 1 0 0 0 0
16 23 1 6 0 0 0
16 17 1 0 0 0 0
17 18 1 0 0 0 0
17 24 1 1 0 0 0
18 25 1 6 0 0 0
18 19 1 0 0 0 0
19 20 1 1 0 0 0
19 26 1 0 0 0 0
20 27 1 0 0 0 0
21 22 1 0 0 0 0
21 23 1 0 0 0 0
21 28 2 0 0 0 0
29 38 1 0 0 0 0
29 25 1 6 0 0 0
29 30 1 0 0 0 0
30 35 1 6 0 0 0
30 31 1 0 0 0 0
31 32 1 0 0 0 0
31 36 1 6 0 0 0
32 33 1 0 0 0 0
32 37 1 1 0 0 0
33 38 1 0 0 0 0
33 34 1 6 0 0 0
34 39 1 0 0 0 0
40 49 1 0 0 0 0
40 41 1 0 0 0 0
40 39 1 1 0 0 0
41 46 1 1 0 0 0
41 42 1 0 0 0 0
42 43 1 0 0 0 0
42 47 1 6 0 0 0
43 48 1 1 0 0 0
43 44 1 0 0 0 0
44 49 1 0 0 0 0
44 45 1 6 0 0 0
45 50 1 0 0 0 0
51 47 1 6 0 0 0
51 60 1 0 0 0 0
51 52 1 0 0 0 0
52 53 1 0 0 0 0
52 57 1 6 0 0 0
53 54 1 0 0 0 0
53 58 1 6 0 0 0
54 59 1 6 0 0 0
54 55 1 0 0 0 0
55 56 1 6 0 0 0
55 60 1 0 0 0 0
56 61 1 0 0 0 0
62 71 1 0 0 0 0
62 36 1 1 0 0 0
62 63 1 0 0 0 0
63 68 1 1 0 0 0
63 64 1 0 0 0 0
64 69 1 6 0 0 0
64 65 1 0 0 0 0
65 70 1 1 0 0 0
65 66 1 0 0 0 0
66 67 1 1 0 0 0
66 71 1 0 0 0 0
67 72 1 0 0 0 0
M END
$$$$

View file

@ -0,0 +1,129 @@
NSW
-OEChem-02232415193D
53 57 0 0 0 0 0 0 0999 V2000
5.9220 -0.3020 1.2970 C 0 0 0 0 0 0 0 0 0 0 0 0
1.4470 1.1670 -0.1420 N 0 0 0 0 0 0 0 0 0 0 0 0
0.2000 0.5990 0.0490 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.8390 2.2820 1.6000 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.4190 4.1380 -0.0250 C 0 0 0 0 0 0 0 0 0 0 0 0
-3.0970 4.8230 -0.0410 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.6580 3.7710 0.7410 C 0 0 0 0 0 0 0 0 0 0 0 0
-3.2910 -1.6780 0.1420 C 0 0 0 0 0 0 0 0 0 0 0 0
-5.5360 -2.0330 0.9170 C 0 0 0 0 0 0 0 0 0 0 0 0
1.3400 -0.7150 -1.3020 C 0 0 0 0 0 0 0 0 0 0 0 0
3.5200 0.5160 -1.4970 C 0 0 0 0 0 0 0 0 0 0 0 0
4.4840 -0.2350 -0.6170 C 0 0 0 0 0 0 0 0 0 0 0 0
4.8140 -1.5440 -0.9150 C 0 0 0 0 0 0 0 0 0 0 0 0
5.6890 -2.2400 -0.1020 C 0 0 0 0 0 0 0 0 0 0 0 0
6.2500 -1.6170 1.0030 C 0 0 0 0 0 0 0 0 0 0 0 0
5.0380 0.3850 0.4880 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.9010 1.0040 0.8040 C 0 0 0 0 0 0 0 0 0 0 0 0
-1.3180 3.4280 0.7480 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.8580 5.1930 -0.8030 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.1970 5.5360 -0.8100 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.0560 -0.8580 0.1180 C 0 0 0 0 0 0 0 0 0 0 0 0
-4.3880 -1.2680 0.8980 C 0 0 0 0 0 0 0 0 0 0 0 0
-5.5990 -3.2100 0.1850 C 0 0 0 0 0 0 0 0 0 0 0 0
-4.5100 -3.6190 -0.5740 C 0 0 0 0 0 0 0 0 0 0 0 0
-3.3600 -2.8590 -0.5970 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.9650 -1.2700 -0.6360 C 0 0 0 0 0 0 0 0 0 0 0 0
2.1540 0.3180 -1.0070 N 0 0 0 0 0 0 0 0 0 0 0 0
-1.9960 0.2710 0.8210 N 0 0 0 0 0 0 0 0 0 0 0 0
0.1460 -0.5340 -0.6580 N 0 3 0 0 0 0 0 0 0 0 0 0
1.6180 -1.6500 -2.0280 O 0 0 0 0 0 0 0 0 0 0 0 0
7.1200 -2.2960 1.7980 O 0 0 0 0 0 0 0 0 0 0 0 0
-6.7310 -3.9620 0.2060 O 0 0 0 0 0 0 0 0 0 0 0 0
6.3530 0.1830 2.1600 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7700 1.9970 0.2440 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.4750 2.1920 2.4810 H 0 0 0 0 0 0 0 0 0 0 0 0
0.1890 2.4660 1.9120 H 0 0 0 0 0 0 0 0 0 0 0 0
0.6270 3.8710 -0.0190 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.1440 5.0900 -0.0470 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.3610 3.2140 1.3420 H 0 0 0 0 0 0 0 0 0 0 0 0
-6.3860 -1.7150 1.5030 H 0 0 0 0 0 0 0 0 0 0 0 0
3.7620 1.5790 -1.4770 H 0 0 0 0 0 0 0 0 0 0 0 0
3.5960 0.1460 -2.5200 H 0 0 0 0 0 0 0 0 0 0 0 0
4.3810 -2.0270 -1.7780 H 0 0 0 0 0 0 0 0 0 0 0 0
5.9460 -3.2620 -0.3350 H 0 0 0 0 0 0 0 0 0 0 0 0
4.7780 1.4080 0.7180 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.1550 5.7500 -1.4040 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.5410 6.3580 -1.4210 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.3390 -0.3520 1.4680 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.5640 -4.5340 -1.1460 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.5130 -3.1770 -1.1870 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.0160 -2.1870 -1.2050 H 0 0 0 0 0 0 0 0 0 0 0 0
6.6980 -2.7710 2.5270 H 0 0 0 0 0 0 0 0 0 0 0 0
-7.3690 -3.7310 -0.4820 H 0 0 0 0 0 0 0 0 0 0 0 0
6 7 2 0 0 0 0
6 20 1 0 0 0 0
7 18 1 0 0 0 0
19 20 2 0 0 0 0
24 25 2 0 0 0 0
23 24 1 0 0 0 0
8 25 1 0 0 0 0
23 32 1 0 0 0 0
4 18 1 0 0 0 0
5 18 2 0 0 0 0
9 23 2 0 0 0 0
15 31 1 0 0 0 0
4 17 1 0 0 0 0
5 19 1 0 0 0 0
17 28 2 0 0 0 0
21 28 1 0 0 0 0
8 21 1 0 0 0 0
8 22 2 0 0 0 0
3 17 1 0 0 0 0
1 15 2 0 0 0 0
14 15 1 0 0 0 0
21 26 2 0 0 0 0
9 22 1 0 0 0 0
1 16 1 0 0 0 0
13 14 2 0 0 0 0
2 3 1 0 0 0 0
3 29 2 0 0 0 0
26 29 1 0 0 0 0
2 27 1 0 0 0 0
12 16 2 0 0 0 0
10 29 1 0 0 0 0
12 13 1 0 0 0 0
11 12 1 0 0 0 0
10 27 1 0 0 0 0
11 27 1 0 0 0 0
10 30 2 0 0 0 0
1 33 1 0 0 0 0
2 34 1 0 0 0 0
4 35 1 0 0 0 0
4 36 1 0 0 0 0
5 37 1 0 0 0 0
6 38 1 0 0 0 0
7 39 1 0 0 0 0
9 40 1 0 0 0 0
11 41 1 0 0 0 0
11 42 1 0 0 0 0
13 43 1 0 0 0 0
14 44 1 0 0 0 0
16 45 1 0 0 0 0
19 46 1 0 0 0 0
20 47 1 0 0 0 0
22 48 1 0 0 0 0
24 49 1 0 0 0 0
25 50 1 0 0 0 0
26 51 1 0 0 0 0
31 52 1 0 0 0 0
32 53 1 0 0 0 0
M CHG 1 29 1
M END
> <OPENEYE_ISO_SMILES>
c1ccc(cc1)Cc2c3[nH]n(c(=O)[n+]3cc(n2)c4ccc(cc4)O)Cc5ccc(cc5)O
> <OPENEYE_INCHI>
InChI=1S/C25H20N4O3/c30-20-10-6-18(7-11-20)15-29-25(32)28-16-23(19-8-12-21(31)13-9-19)26-22(24(28)27-29)14-17-4-2-1-3-5-17/h1-13,16,30-31H,14-15H2/p+1
> <OPENEYE_INCHIKEY>
FAFQJYQLYIUGAT-UHFFFAOYSA-O
> <FORMULA>
C25H21N4O3+
$$$$

View file

@ -0,0 +1,116 @@
XG4
-OEChem-02232415213D
48 50 0 1 0 0 0 0 0999 V2000
7.5710 2.0160 -0.3960 N 0 0 0 0 0 0 0 0 0 0 0 0
7.4820 0.6990 -0.7370 C 0 0 0 0 0 0 0 0 0 0 0 0
8.5980 0.0580 -1.2130 N 0 0 0 0 0 0 0 0 0 0 0 0
6.3590 0.0260 -0.6220 N 0 0 0 0 0 0 0 0 0 0 0 0
5.2470 0.6100 -0.1630 C 0 0 0 0 0 0 0 0 0 0 0 0
5.2730 1.9610 0.2040 C 0 0 0 0 0 0 0 0 0 0 0 0
6.4880 2.6730 0.0750 C 0 0 0 0 0 0 0 0 0 0 0 0
6.5580 3.8500 0.3880 O 0 0 0 0 0 0 0 0 0 0 0 0
4.0270 2.2870 0.6270 N 0 0 0 0 0 0 0 0 0 0 0 0
3.2570 1.2420 0.5440 C 0 0 0 0 0 0 0 0 0 0 0 0
3.9670 0.1810 0.0620 N 0 0 0 0 0 0 0 0 0 0 0 0
-2.3540 -1.9400 -0.6300 P 0 0 1 0 0 0 0 0 0 0 0 0
-4.4150 -0.1240 -0.2370 P 0 0 1 0 0 0 0 0 0 0 0 0
-6.1130 2.2430 0.3070 P 0 0 0 0 0 0 0 0 0 0 0 0
3.4520 -1.1710 -0.1750 C 0 0 1 0 0 0 0 0 0 0 0 0
-2.9530 -2.8530 0.3690 O 0 0 0 0 0 0 0 0 0 0 0 0
-4.9590 -0.4340 -1.5780 O 0 0 0 0 0 0 0 0 0 0 0 0
-7.0830 1.4520 1.0960 O 0 0 0 0 0 0 0 0 0 0 0 0
3.7430 -2.0630 1.0480 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.9280 -2.2930 -2.0920 O 0 0 0 0 0 0 0 0 0 0 0 0
-5.0840 -1.1020 0.8530 O 0 0 0 0 0 0 0 0 0 0 0 0
-5.8010 3.6270 1.0690 O 0 0 0 0 0 0 0 0 0 0 0 0
2.3450 -2.5490 1.5040 C 0 0 2 0 0 0 0 0 0 0 0 0
2.3820 -3.9270 1.8790 O 0 0 0 0 0 0 0 0 0 0 0 0
-2.7470 -0.3500 -0.2430 N 0 0 0 0 0 0 0 0 0 0 0 0
-4.7480 1.4050 0.1380 O 0 0 0 0 0 0 0 0 0 0 0 0
-6.7320 2.5590 -1.1450 O 0 0 0 0 0 0 0 0 0 0 0 0
1.4970 -2.3470 0.2220 C 0 0 2 0 0 0 0 0 0 0 0 0
2.0250 -1.1280 -0.3430 O 0 0 0 0 0 0 0 0 0 0 0 0
0.0180 -2.1810 0.5750 C 0 0 0 0 0 0 0 0 0 0 0 0
-0.7540 -2.1140 -0.6250 O 0 0 0 0 0 0 0 0 0 0 0 0
8.4160 2.4830 -0.4890 H 0 0 0 0 0 0 0 0 0 0 0 0
9.4340 0.5410 -1.3030 H 0 0 0 0 0 0 0 0 0 0 0 0
8.5510 -0.8790 -1.4600 H 0 0 0 0 0 0 0 0 0 0 0 0
2.2120 1.2170 0.8170 H 0 0 0 0 0 0 0 0 0 0 0 0
3.9200 -1.5950 -1.0630 H 0 0 0 0 0 0 0 0 0 0 0 0
4.3670 -2.9100 0.7620 H 0 0 0 0 0 0 0 0 0 0 0 0
4.2220 -1.4840 1.8370 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.5780 -1.7320 -2.7980 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.7730 -0.9550 1.7570 H 0 0 0 0 0 0 0 0 0 0 0 0
-6.5800 4.1840 1.2060 H 0 0 0 0 0 0 0 0 0 0 0 0
1.9680 -1.9350 2.3210 H 0 0 0 0 0 0 0 0 0 0 0 0
2.9720 -4.1170 2.6220 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.3430 -0.0830 0.6430 H 0 0 0 0 0 0 0 0 0 0 0 0
-6.1490 3.0750 -1.7180 H 0 0 0 0 0 0 0 0 0 0 0 0
1.6350 -3.1820 -0.4650 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.3110 -3.0320 1.1710 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.1180 -1.2630 1.1470 H 0 0 0 0 0 0 0 0 0 0 0 0
1 2 1 0 0 0 0
1 7 1 0 0 0 0
1 32 1 0 0 0 0
2 4 2 0 0 0 0
2 3 1 0 0 0 0
3 33 1 0 0 0 0
3 34 1 0 0 0 0
4 5 1 0 0 0 0
5 11 1 0 0 0 0
5 6 2 0 0 0 0
6 9 1 0 0 0 0
6 7 1 0 0 0 0
7 8 2 0 0 0 0
9 10 2 0 0 0 0
10 11 1 0 0 0 0
10 35 1 0 0 0 0
11 15 1 0 0 0 0
12 31 1 0 0 0 0
12 16 2 0 0 0 0
12 20 1 0 0 0 0
12 25 1 0 0 0 0
13 17 2 0 0 0 0
13 25 1 0 0 0 0
13 21 1 0 0 0 0
13 26 1 0 0 0 0
14 18 2 0 0 0 0
14 26 1 0 0 0 0
14 22 1 0 0 0 0
14 27 1 0 0 0 0
15 29 1 0 0 0 0
15 19 1 0 0 0 0
15 36 1 0 0 0 0
19 23 1 0 0 0 0
19 37 1 0 0 0 0
19 38 1 0 0 0 0
20 39 1 0 0 0 0
21 40 1 0 0 0 0
22 41 1 0 0 0 0
23 28 1 0 0 0 0
23 24 1 0 0 0 0
23 42 1 0 0 0 0
24 43 1 0 0 0 0
25 44 1 0 0 0 0
27 45 1 0 0 0 0
28 29 1 0 0 0 0
28 30 1 0 0 0 0
28 46 1 0 0 0 0
30 31 1 0 0 0 0
30 47 1 0 0 0 0
30 48 1 0 0 0 0
M END
> <OPENEYE_ISO_SMILES>
c1nc2c(=O)[nH]c(nc2n1[C@H]3C[C@@H]([C@H](O3)CO[P@@](=O)(N[P@@](=O)(O)OP(=O)(O)O)O)O)N
> <OPENEYE_INCHI>
InChI=1S/C10H17N6O12P3/c11-10-13-8-7(9(18)14-10)12-3-16(8)6-1-4(17)5(27-6)2-26-29(19,20)15-30(21,22)28-31(23,24)25/h3-6,17H,1-2H2,(H2,23,24,25)(H3,11,13,14,18)(H3,15,19,20,21,22)/t4-,5+,6+/m0/s1
> <OPENEYE_INCHIKEY>
DWGAAFQEGIMTIA-KVQBGUIXSA-N
> <FORMULA>
C10H17N6O12P3
$$$$

BIN
img/RFAA.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 722 KiB

121
make_msa.sh Executable file
View file

@ -0,0 +1,121 @@
#!/bin/bash
# inputs
in_fasta="$1"
out_dir="$2"
# resources
CPU="$3"
MEM="$4"
# pipe_dir
PIPE_DIR="$5"
DB_TEMPL="$6"
# sequence databases
DB_UR30="$PIPE_DIR/uniclust/UniRef30_2021_06"
DB_BFD="$PIPE_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
# Running signalP 6.0
mkdir -p $out_dir/signalp
tmp_dir="$out_dir/signalp"
signalp6 --fastafile $in_fasta --organism other --output_dir $tmp_dir --format none --mode slow
trim_fasta="$tmp_dir/processed_entries.fasta"
if [ ! -s $trim_fasta ] # empty file -- no signal P
then
trim_fasta="$in_fasta"
fi
# setup hhblits command
export HHLIB=/software/hhsuite/build/bin/
export PATH=$HHLIB:$PATH
HHBLITS_UR30="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_UR30"
HHBLITS_BFD="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_BFD"
mkdir -p $out_dir/hhblits
tmp_dir="$out_dir/hhblits"
out_prefix="$out_dir/t000_"
# perform iterative searches against UniRef30
if [ ! -s ${out_prefix}.msa0.a3m ]
then
prev_a3m="$trim_fasta"
for e in 1e-10 1e-6 1e-3
do
echo "Running HHblits against UniRef30 with E-value cutoff $e"
if [ ! -s $tmp_dir/t000_.$e.a3m ]
then
$HHBLITS_UR30 -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0
fi
hhfilter -maxseq 100000 -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m
hhfilter -maxseq 100000 -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m
prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m"
n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m`
n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m`
if ((n75>2000))
then
if [ ! -s ${out_prefix}.msa0.a3m ]
then
cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m
break
fi
elif ((n50>4000))
then
if [ ! -s ${out_prefix}.msa0.a3m ]
then
cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m
break
fi
else
continue
fi
done
# perform iterative searches against BFD if it failes to get enough sequences
if [ ! -s ${out_prefix}.msa0.a3m ]
then
e=1e-3
echo "Running HHblits against BFD with E-value cutoff $e"
if [ ! -s $tmp_dir/t000_.$e.bfd.a3m ]
then
$HHBLITS_BFD -i $prev_a3m -oa3m $tmp_dir/t000_.$e.bfd.a3m -e $e -v 0
fi
hhfilter -maxseq 100000 -id 90 -cov 75 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov75.a3m
hhfilter -maxseq 100000 -id 90 -cov 50 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov50.a3m
prev_a3m="$tmp_dir/t000_.$e.bfd.id90cov50.a3m"
n75=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov75.a3m`
n50=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov50.a3m`
if ((n75>2000))
then
if [ ! -s ${out_prefix}.msa0.a3m ]
then
cp $tmp_dir/t000_.$e.bfd.id90cov75.a3m ${out_prefix}.msa0.a3m
fi
elif ((n50>4000))
then
if [ ! -s ${out_prefix}.msa0.a3m ]
then
cp $tmp_dir/t000_.$e.bfd.id90cov50.a3m ${out_prefix}.msa0.a3m
fi
fi
fi
if [ ! -s ${out_prefix}.msa0.a3m ]
then
cp $prev_a3m ${out_prefix}.msa0.a3m
fi
fi
echo "Running PSIPRED"
$PIPE_DIR/input_prep/make_ss.sh $out_dir/t000_.msa0.a3m $out_dir/t000_.ss2 > $out_dir/log/make_ss.stdout 2> $out_dir/log/make_ss.stderr
if [ ! -s $out_dir/t000_.hhr ]
then
echo "Running hhsearch"
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB_TEMPL"
cat $out_dir/t000_.ss2 $out_dir/t000_.msa0.a3m > $out_dir/t000_.msa0.ss2.a3m
$HH -i $out_dir/t000_.msa0.ss2.a3m -o $out_dir/t000_.hhr -atab $out_dir/t000_.atab -v 0
fi

View file

@ -0,0 +1,123 @@
.Trash-0
.git
data/
.DS_Store
*wandb/
*.pt
*.swp
# added by FAFU
.idea/
cache/
downloaded/
*.lprof
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
**/benchmark
**/results
*.pkl
*.log

121
rf2aa/SE3Transformer/.gitignore vendored Normal file
View file

@ -0,0 +1,121 @@
data/
.DS_Store
*wandb/
*.pt
*.swp
# added by FAFU
.idea/
cache/
downloaded/
*.lprof
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
**/benchmark
**/results
*.pkl
*.log

View file

@ -0,0 +1,58 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
# run docker daemon with --default-runtime=nvidia for GPU detection during build
# multistage build for DGL with CUDA and FP16
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
FROM ${FROM_IMAGE_NAME} AS dgl_builder
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update \
&& apt-get install -y git build-essential python3-dev make cmake \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /dgl
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
WORKDIR build
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
RUN make -j8
FROM ${FROM_IMAGE_NAME}
RUN rm -rf /workspace/*
WORKDIR /workspace/se3-transformer
# copy built DGL and install it
COPY --from=dgl_builder /dgl ./dgl
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
ADD requirements.txt .
RUN pip install --no-cache-dir --upgrade --pre pip
RUN pip install --no-cache-dir -r requirements.txt
ADD . .
ENV DGLBACKEND=pytorch
ENV OMP_NUM_THREADS=1

View file

@ -0,0 +1,7 @@
Copyright 2021 NVIDIA CORPORATION & AFFILIATES
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,7 @@
SE(3)-Transformer PyTorch
This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
licensed under the MIT License.
This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
licensed under the MIT License.

View file

@ -0,0 +1,580 @@
# SE(3)-Transformers For PyTorch
This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
## Table Of Contents
- [Model overview](#model-overview)
* [Model architecture](#model-architecture)
* [Default configuration](#default-configuration)
* [Feature support matrix](#feature-support-matrix)
* [Features](#features)
* [Mixed precision training](#mixed-precision-training)
* [Enabling mixed precision](#enabling-mixed-precision)
* [Enabling TF32](#enabling-tf32)
* [Glossary](#glossary)
- [Setup](#setup)
* [Requirements](#requirements)
- [Quick Start Guide](#quick-start-guide)
- [Advanced](#advanced)
* [Scripts and sample code](#scripts-and-sample-code)
* [Parameters](#parameters)
* [Command-line options](#command-line-options)
* [Getting the data](#getting-the-data)
* [Dataset guidelines](#dataset-guidelines)
* [Multi-dataset](#multi-dataset)
* [Training process](#training-process)
* [Inference process](#inference-process)
- [Performance](#performance)
* [Benchmarking](#benchmarking)
* [Training performance benchmark](#training-performance-benchmark)
* [Inference performance benchmark](#inference-performance-benchmark)
* [Results](#results)
* [Training accuracy results](#training-accuracy-results)
* [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
* [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
* [Training stability test](#training-stability-test)
* [Training performance results](#training-performance-results)
* [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
* [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
* [Inference performance results](#inference-performance-results)
* [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
- [Release notes](#release-notes)
* [Changelog](#changelog)
* [Known issues](#known-issues)
## Model overview
The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for 3D points and graphs processing.
This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
The model is based on the following publications:
- [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
- [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
- [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
The main differences between this implementation of SE(3)-Transformers and the official one are the following:
- Training and inference support for multiple GPUs
- Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
- The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
- Significantly increased throughput
- Significantly reduced memory consumption
- The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
- The use of equivariant normalization between attention layers is an option (`--norm`), off by default
- The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [ClebschGordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
### Model architecture
The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
![Model high-level architecture](./images/se3-transformer.png)
### Default configuration
SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
The following features were implemented in this model:
- Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
- Data-parallel multi-GPU training (DDP)
- Mixed precision training (autocast, gradient scaling)
- Gradient accumulation
- Model checkpointing
The following performance optimizations were implemented in this model:
**General optimizations**
- The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
- The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
- The Clebsch-Gordon coefficients are cached in RAM
**Tensor Field Network optimizations**
- The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
**Self-attention optimizations**
- Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
- Graph operations for different output degrees may be fused together if conditions are met
**Normalization optimizations**
- The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
- Number of layers: 7
- Number of degrees: 4
- Number of channels: 32
- Number of attention heads: 8
- Channels division: 2
- Use of equivariant normalization: true
- Use of layer normalization: true
- Pooling: max
### Feature support matrix
This model supports the following features::
| Feature | SE(3)-Transformer
|-----------------------|--------------------------
|Automatic mixed precision (AMP) | Yes
|Distributed data parallel (DDP) | Yes
#### Features
**Distributed data parallel (DDP)**
[DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
**Automatic Mixed Precision (AMP)**
This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
### Mixed precision training
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
1. Porting the model to use the FP16 data type where appropriate.
2. Adding loss scaling to preserve small gradient values.
AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
For information about:
- How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
- Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
#### Enabling mixed precision
Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
#### Enabling TF32
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
### Glossary
**Degree (type)**
In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
The degree or type of a feature is a positive integer that describes how this feature transforms when the input rotates in 3D.
This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
Some common examples include:
- Degree 0: 1D scalars invariant to rotation
- Degree 1: 3D vectors that rotate according to 3D rotation matrices
- Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
**Fiber**
A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
**Multiplicity**
The multiplicity of a feature of a given type is the number of channels of this feature.
**Tensor Field Network**
A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
**Equivariance**
[Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function of model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
## Setup
The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
### Requirements
This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
- PyTorch 21.07+ NGC container
- Supported GPUs:
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
For those unable to use the PyTorch NGC container to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
## Quick Start Guide
To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
1. Clone the repository.
```
git clone https://github.com/NVIDIA/DeepLearningExamples
cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
```
2. Build the `se3-transformer` PyTorch NGC container.
```
docker build -t se3-transformer .
```
3. Start an interactive session in the NGC container to run training/inference.
```
mkdir -p results
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
```
4. Start training.
```
bash scripts/train.sh
```
5. Start inference/predictions.
```
bash scripts/predict.sh
```
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark your performance to [Training performance benchmark](#training-performance-results) or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
## Advanced
The following sections provide greater details of the dataset, running training and inference, and the training results.
### Scripts and sample code
In the root directory, the most important files are:
- `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
- `requirements.txt`: set of extra requirements to run SE(3)-Transformers
- `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
- `se3_transformer/model/layers/`: directory containing model architecture layers
- `se3_transformer/model/transformer.py`: main Transformer module
- `se3_transformer/model/basis.py`: logic for computing bases matrices
- `se3_transformer/runtime/training.py`: training script, to be run as a python module
- `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
- `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
- `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](wandb.ai/) loggers
### Parameters
The complete list of the available parameters for the `training.py` script contains:
**General**
- `--epochs`: Number of training epochs (default: `100` for single-GPU)
- `--batch_size`: Batch size (default: `240`)
- `--seed`: Set a seed globally (default: `None`)
- `--num_workers`: Number of dataloading workers (default: `8`)
- `--amp`: Use Automatic Mixed Precision (default `false`)
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
- `--silent`: Minimize stdout output (default: `false`)
**Paths**
- `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
- `--log_dir`: Directory where the results logs should be saved (default: `/results`)
- `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
- `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
**Optimizer**
- `--optimizer`: Optimizer to use (default: `adam`)
- `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
- `--momentum`: Momentum to use (default: `0.9`)
- `--weight_decay`: Weight decay to use (default: `0.1`)
**QM9 dataset**
- `--task`: Regression task to train on (default: `homo`)
- `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
**Model architecture**
- `--num_layers`: Number of stacked Transformer layers (default: `7`)
- `--num_heads`: Number of heads in self-attention (default: `8`)
- `--channels_div`: Channels division before feeding to attention layer (default: `2`)
- `--pooling`: Type of graph pooling (default: `max`)
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
- `--num_channels`: Number of channels for the hidden features (default: `32`)
### Command-line options
To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
### Dataset guidelines
#### Demo dataset
The SE(3)-Transformer was trained on the QM9 dataset.
The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
As input features, we use:
- Node features (6D):
- One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
- Number of protons of each atom (1D)
- Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
- The relative positions between adjacent nodes (atoms)
#### Custom datasets
To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
Your custom collate function should return a tuple with:
- A (batched) DGLGraph object
- A dictionary of node features ({{degree}: tensor})
- A dictionary of edge features ({{degree}: tensor})
- (Optional) Precomputed bases as a dictionary
- Labels as a tensor
You can then modify the `training.py` and `inference.py` scripts to use your new data module.
### Training process
The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
**Logs**
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
**Checkpoints**
The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
`--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
**Evaluation**
The evaluation metric is the Mean Absolute Error (MAE).
`--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
**Automatic Mixed Precision**
To enable Mixed Precision training, add the `--amp` flag.
**Multi-GPU and multi-node**
The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
For example, to train on all available GPUs with AMP:
```
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
```
### Inference process
Inference can be run by using the `se3_transformer.runtime.inference` python module.
The inference script is `se3_transformer/runtime/inference.py`, to be run as a module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (to be passed as `--load_ckpt_path`).
## Performance
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIAs latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
### Benchmarking
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
#### Training performance benchmark
To benchmark the training performance on a specific batch size, run `bash scripts/benchmarck_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmarck_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
#### Inference performance benchmark
To benchmark the inference performance on a specific batch size, run `bash scripts/benchmarck_inference.sh {BATCH_SIZE}`.
### Results
The following sections provide details on how we achieved our performance and accuracy in training and inference.
#### Training accuracy results
##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
#### Training performance results
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
| GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
| GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
#### Inference performance results
##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
FP16
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
| 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
| 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
TF32
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
| 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
| 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
FP16
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
| 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
| 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
FP32
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
| 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
| 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
| 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
## Release notes
### Changelog
August 2021
- Initial release
### Known issues
If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`, this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

View file

@ -0,0 +1,4 @@
e3nn==0.3.3
wandb==0.12.0
pynvml==11.0.0
git+https://github.com/NVIDIA/dllogger#egg=dllogger

View file

@ -0,0 +1,15 @@
#!/usr/bin/env bash
# Script to benchmark inference performance, without bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--use_layer_norm \
--norm \
--task homo \
--seed 42 \
--benchmark

View file

@ -0,0 +1,18 @@
#!/usr/bin/env bash
# Script to benchmark single-GPU training performance, with bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs 6 \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--task homo \
--precompute_bases \
--seed 42 \
--benchmark

View file

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# Script to benchmark multi-GPU training performance, with bases precomputation
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs 6 \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--task homo \
--precompute_bases \
--seed 42 \
--benchmark

View file

@ -0,0 +1,19 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.inference \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--use_layer_norm \
--norm \
--load_ckpt_path model_qm9.pth \
--task "$TASK"

View file

@ -0,0 +1,25 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-100}
LEARNING_RATE=${4:-0.002}
WEIGHT_DECAY=${5:-0.1}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs "$NUM_EPOCHS" \
--lr "$LEARNING_RATE" \
--weight_decay "$WEIGHT_DECAY" \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--precompute_bases \
--seed 42 \
--task "$TASK"

View file

@ -0,0 +1,27 @@
#!/usr/bin/env bash
# CLI args with defaults
BATCH_SIZE=${1:-240}
AMP=${2:-true}
NUM_EPOCHS=${3:-130}
LEARNING_RATE=${4:-0.01}
WEIGHT_DECAY=${5:-0.1}
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
TASK=homo
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
se3_transformer.runtime.training \
--amp "$AMP" \
--batch_size "$BATCH_SIZE" \
--epochs "$NUM_EPOCHS" \
--lr "$LEARNING_RATE" \
--min_lr 0.00001 \
--weight_decay "$WEIGHT_DECAY" \
--use_layer_norm \
--norm \
--save_ckpt_path model_qm9.pth \
--precompute_bases \
--seed 42 \
--task "$TASK"

View file

@ -0,0 +1 @@
from .qm9 import QM9DataModule

View file

@ -0,0 +1,63 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch.distributed as dist
from abc import ABC
from torch.utils.data import DataLoader, DistributedSampler, Dataset
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
# Classic or distributed dataloader depending on the context
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
class DataModule(ABC):
""" Abstract DataModule. Children must define self.ds_{train | val | test}. """
def __init__(self, **dataloader_kwargs):
super().__init__()
if get_local_rank() == 0:
self.prepare_data()
# Wait until rank zero has prepared the data (download, preprocessing, ...)
if dist.is_initialized():
dist.barrier(device_ids=[get_local_rank()])
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
self.ds_train, self.ds_val, self.ds_test = None, None, None
def prepare_data(self):
""" Method called only once per node. Put here any downloading or preprocessing """
pass
def train_dataloader(self) -> DataLoader:
return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
def val_dataloader(self) -> DataLoader:
return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
def test_dataloader(self) -> DataLoader:
return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)

View file

@ -0,0 +1,173 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Tuple
import dgl
import pathlib
import torch
from dgl.data import QM9EdgeDataset
from dgl import DGLGraph
from torch import Tensor
from torch.utils.data import random_split, DataLoader, Dataset
from tqdm import tqdm
from rf2aa.SE3Transformer.se3_transformer.data_loading.data_module import DataModule
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
x = qm9_graph.ndata['pos']
src, dst = qm9_graph.edges()
rel_pos = x[dst] - x[src]
return rel_pos
def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
len_full = len(full_dataset)
len_train = 100_000
len_test = int(0.1 * len_full)
len_val = len_full - len_train - len_test
return len_train, len_val, len_test
class QM9DataModule(DataModule):
"""
Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
This includes all the molecules from QM9 except the ones that are uncharacterized.
"""
NODE_FEATURE_DIM = 6
EDGE_FEATURE_DIM = 4
def __init__(self,
data_dir: pathlib.Path,
task: str = 'homo',
batch_size: int = 240,
num_workers: int = 8,
num_degrees: int = 4,
amp: bool = False,
precompute_bases: bool = False,
**kwargs):
self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it
super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
self.amp = amp
self.task = task
self.batch_size = batch_size
self.num_degrees = num_degrees
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
if precompute_bases:
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
num_workers=num_workers, **qm9_kwargs)
else:
full_dataset = QM9EdgeDataset(**qm9_kwargs)
self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
generator=torch.Generator().manual_seed(0))
train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
self.targets_mean = train_targets.mean()
self.targets_std = train_targets.std()
def prepare_data(self):
# Download the QM9 preprocessed data
QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))
def _collate(self, samples):
graphs, y, *bases = map(list, zip(*samples))
batched_graph = dgl.batch(graphs)
edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
# get node features
node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
if bases:
# collate bases
all_bases = {
key: torch.cat([b[key] for b in bases[0]], dim=0)
for key in bases[0][0].keys()
}
return batched_graph, node_feats, edge_feats, all_bases, targets
else:
return batched_graph, node_feats, edge_feats, targets
@staticmethod
def add_argparse_args(parent_parser):
parser = parent_parser.add_argument_group("QM9 dataset")
parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
help='Regression task to train on')
parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
help='Precompute bases at the beginning of the script during dataset initialization,'
' instead of computing them at the beginning of each forward pass.')
return parent_parser
def __repr__(self):
return f'QM9({self.task})'
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
"""
:param bases_kwargs: Arguments to feed the bases computation function
:param batch_size: Batch size to use when iterating over the dataset for computing bases
"""
self.bases_kwargs = bases_kwargs
self.batch_size = batch_size
self.bases = None
self.num_workers = num_workers
super().__init__(*args, **kwargs)
def load(self):
super().load()
# Iterate through the dataset and compute bases (pairwise only)
# Potential improvement: use multi-GPU and gather
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
bases = []
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
disable=get_local_rank() != 0):
rel_pos = _get_relative_pos(graph)
# Compute the bases with the GPU but convert the result to CPU to store in RAM
bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
self.bases = bases # Assign at the end so that __getitem__ isn't confused
def __getitem__(self, idx: int):
graph, label = super().__getitem__(idx)
if self.bases:
bases_idx = idx // self.batch_size
bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
self.bases[bases_idx].items()}
else:
return graph, label

View file

@ -0,0 +1,2 @@
from .transformer import SE3Transformer, SE3TransformerPooled
from .fiber import Fiber

View file

@ -0,0 +1,178 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from functools import lru_cache
from typing import Dict, List
import e3nn.o3 as o3
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
@lru_cache(maxsize=None)
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
@lru_cache(maxsize=None)
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
all_cb = []
for d_in in range(max_degree + 1):
for d_out in range(max_degree + 1):
K_Js = []
for J in range(abs(d_in - d_out), d_in + d_out + 1):
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
all_cb.append(K_Js)
return all_cb
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
all_degrees = list(range(2 * max_degree + 1))
with nvtx_range('spherical harmonics'):
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
@torch.jit.script
def get_basis_script(max_degree: int,
use_pad_trick: bool,
spherical_harmonics: List[Tensor],
clebsch_gordon: List[List[Tensor]],
amp: bool) -> Dict[str, Tensor]:
"""
Compute pairwise bases matrices for degrees up to max_degree
:param max_degree: Maximum input or output degree
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
:param spherical_harmonics: List of computed spherical harmonics
:param clebsch_gordon: List of computed CB-coefficients
:param amp: When true, return bases in FP16 precision
"""
basis = {}
idx = 0
# Double for loop instead of product() because of JIT script
for d_in in range(max_degree + 1):
for d_out in range(max_degree + 1):
key = f'{d_in},{d_out}'
K_Js = []
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
Q_J = clebsch_gordon[idx][freq_idx]
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
if amp:
basis[key] = basis[key].half()
if use_pad_trick:
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
idx += 1
return basis
@torch.jit.script
def update_basis_with_fused(basis: Dict[str, Tensor],
max_degree: int,
use_pad_trick: bool,
fully_fused: bool) -> Dict[str, Tensor]:
""" Update the basis dict with partially and optionally fully fused bases """
num_edges = basis['0,0'].shape[0]
device = basis['0,0'].device
dtype = basis['0,0'].dtype
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
# Fused per output degree
for d_out in range(max_degree + 1):
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_in in range(max_degree + 1):
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
acc_d += degree_to_dim(d_in)
acc_f += degree_to_dim(min(d_out, d_in))
basis[f'out{d_out}_fused'] = basis_fused
# Fused per input degree
for d_in in range(max_degree + 1):
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_out in range(max_degree + 1):
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
acc_d += degree_to_dim(d_out)
acc_f += degree_to_dim(min(d_out, d_in))
basis[f'in{d_in}_fused'] = basis_fused
if fully_fused:
# Fully fused
# Double sum this way because of JIT script
sum_freq = sum([
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
])
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
acc_d, acc_f = 0, 0
for d_out in range(max_degree + 1):
b = basis[f'out{d_out}_fused']
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
:degree_to_dim(d_out)]
acc_f += b.shape[2]
acc_d += degree_to_dim(d_out)
basis['fully_fused'] = basis_fused
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
return basis
def get_basis(relative_pos: Tensor,
max_degree: int = 4,
compute_gradients: bool = False,
use_pad_trick: bool = False,
amp: bool = False) -> Dict[str, Tensor]:
with nvtx_range('spherical harmonics'):
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
with nvtx_range('CB coefficients'):
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
with torch.autograd.set_grad_enabled(compute_gradients):
with nvtx_range('bases'):
basis = get_basis_script(max_degree=max_degree,
use_pad_trick=use_pad_trick,
spherical_harmonics=spherical_harmonics,
clebsch_gordon=clebsch_gordon,
amp=amp)
return basis

View file

@ -0,0 +1,144 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from collections import namedtuple
from itertools import product
from typing import Dict
import torch
from torch import Tensor
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
class Fiber(dict):
"""
Describes the structure of some set of features.
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
Type-0 features: invariant scalars
Type-1 features: equivariant 3D vectors
Type-2 features: equivariant symmetric traceless matrices
...
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
The 'multiplicity' or 'number of channels' is the number of features of a given type.
This class puts together all the degrees and their multiplicities in order to describe
the inputs, outputs or hidden features of SE3 layers.
"""
def __init__(self, structure):
if isinstance(structure, dict):
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
elif not isinstance(structure[0], FiberEl):
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
self.structure = structure
super().__init__({d: m for d, m in self.structure})
@property
def degrees(self):
return sorted([t.degree for t in self.structure])
@property
def channels(self):
return [self[d] for d in self.degrees]
@property
def num_features(self):
""" Size of the resulting tensor if all features were concatenated together """
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
@staticmethod
def create(num_degrees: int, num_channels: int):
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
@staticmethod
def from_features(feats: Dict[str, Tensor]):
""" Infer the Fiber structure from a feature dict """
structure = {}
for k, v in feats.items():
degree = int(k)
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
assert v.shape[-1] == degree_to_dim(degree)
structure[degree] = v.shape[-2]
return Fiber(structure)
def __getitem__(self, degree: int):
""" fiber[degree] returns the multiplicity for this degree """
return dict(self.structure).get(degree, 0)
def __iter__(self):
""" Iterate over namedtuples (degree, channels) """
return iter(self.structure)
def __mul__(self, other):
"""
If other in an int, multiplies all the multiplicities by other.
If other is a fiber, returns the cartesian product.
"""
if isinstance(other, Fiber):
return product(self.structure, other.structure)
elif isinstance(other, int):
return Fiber({t.degree: t.channels * other for t in self.structure})
def __add__(self, other):
"""
If other in an int, add other to all the multiplicities.
If other is a fiber, add the multiplicities of the fibers together.
"""
if isinstance(other, Fiber):
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
elif isinstance(other, int):
return Fiber({t.degree: t.channels + other for t in self.structure})
def __repr__(self):
return str(self.structure)
@staticmethod
def combine_max(f1, f2):
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
new_dict = dict(f1.structure)
for k, m in f2.structure:
new_dict[k] = max(new_dict.get(k, 0), m)
return Fiber(list(new_dict.items()))
@staticmethod
def combine_selectively(f1, f2):
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
# only use orders which occur in fiber f1
new_dict = dict(f1.structure)
for k in f1.degrees:
if k in f2.degrees:
new_dict[k] += f2[k]
return Fiber(list(new_dict.items()))
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
self.degrees]
fibers = torch.cat(fibers, -1)
return fibers

View file

@ -0,0 +1,5 @@
from .linear import LinearSE3
from .norm import NormSE3
from .pooling import GPooling
from .convolution import ConvSE3
from .attention import AttentionBlockSE3

View file

@ -0,0 +1,186 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from dgl.ops import edge_softmax
from torch import Tensor
from typing import Dict, Optional, Union
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
from torch.cuda.nvtx import range as nvtx_range
class AttentionSE3(nn.Module):
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
def __init__(
self,
num_heads: int,
key_fiber: Fiber,
value_fiber: Fiber
):
"""
:param num_heads: Number of attention heads
:param key_fiber: Fiber for the keys (and also for the queries)
:param value_fiber: Fiber for the values
"""
super().__init__()
self.num_heads = num_heads
self.key_fiber = key_fiber
self.value_fiber = value_fiber
def forward(
self,
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
query: Dict[str, Tensor], # node features
graph: DGLGraph
):
with nvtx_range('AttentionSE3'):
with nvtx_range('reshape keys and queries'):
if isinstance(key, Tensor):
# case where features of all types are fused
key = key.reshape(key.shape[0], self.num_heads, -1)
# need to reshape queries that way to keep the same layout as keys
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
else:
# features are not fused, need to fuse and reshape them
key = self.key_fiber.to_attention_heads(key, self.num_heads)
query = self.key_fiber.to_attention_heads(query, self.num_heads)
with nvtx_range('attention dot product + softmax'):
# Compute attention weights (softmax of inner product between key and query)
with torch.cuda.amp.autocast(False):
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
edge_weights /= np.sqrt(self.key_fiber.num_features)
edge_weights = edge_softmax(graph, edge_weights)
edge_weights = edge_weights[..., None, None]
with nvtx_range('weighted sum'):
if isinstance(value, Tensor):
# features of all types are fused
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
weights = edge_weights * v
feat_out = dgl.ops.copy_e_sum(graph, weights)
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
out = unfuse_features(feat_out, self.value_fiber.degrees)
else:
out = {}
for degree, channels in self.value_fiber:
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
degree_to_dim(degree))
weights = edge_weights * v
res = dgl.ops.copy_e_sum(graph, weights)
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
return out
class AttentionBlockSE3(nn.Module):
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Optional[Fiber] = None,
num_heads: int = 4,
channels_div: Optional[Dict[str,int]] = None,
use_layer_norm: bool = False,
max_degree: bool = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
**kwargs
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param num_heads: Number of attention heads
:param channels_div: Divide the channels by this integer for computing values
:param use_layer_norm: Apply layer normalization between MLP layers
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
"""
super().__init__()
if fiber_edge is None:
fiber_edge = Fiber({})
self.fiber_in = fiber_in
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
if channels_div is not None:
value_fiber = Fiber([(degree, channels // channels_div[str(degree)]) for degree, channels in fiber_out])
else:
value_fiber = Fiber([(degree, channels) for degree, channels in fiber_out])
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
# (queries are merely projected, hence degrees have to match input)
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
allow_fused_output=True)
self.to_query = LinearSE3(fiber_in, key_query_fiber)
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
def forward(
self,
node_features: Dict[str, Tensor],
edge_features: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range('AttentionBlockSE3'):
with nvtx_range('keys / values'):
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
key, value = self._get_key_value_from_fused(fused_key_value)
with nvtx_range('queries'):
with torch.cuda.amp.autocast(False):
query = self.to_query(node_features)
z = self.attention(value, key, query, graph)
z_concat = aggregate_residual(node_features, z, 'cat')
return self.project(z_concat)
def _get_key_value_from_fused(self, fused_key_value):
# Extract keys and queries features from fused features
if isinstance(fused_key_value, Tensor):
# Previous layer was a fully fused convolution
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
else:
key, value = {}, {}
for degree, feat in fused_key_value.items():
if int(degree) in self.fiber_in.degrees:
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
else:
value[degree] = feat
return key, value

View file

@ -0,0 +1,381 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from enum import Enum
from itertools import product
from typing import Dict
import dgl
import numpy as np
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, unfuse_features
class ConvSE3FuseLevel(Enum):
"""
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
A higher level means faster training, but also more memory usage.
If you are tight on memory and want to feed large inputs to the network, choose a low value.
If you want to train fast, choose a high value.
Recommended value is FULL with AMP.
Fully fused TFN convolutions requirements:
- all input channels are the same
- all output channels are the same
- input degrees span the range [0, ..., max_degree]
- output degrees span the range [0, ..., max_degree]
Partially fused TFN convolutions requirements:
* For fusing by output degree:
- all input channels are the same
- input degrees span the range [0, ..., max_degree]
* For fusing by input degree:
- all output channels are the same
- output degrees span the range [0, ..., max_degree]
Original TFN pairwise convolutions: no requirements
"""
FULL = 2
PARTIAL = 1
NONE = 0
class RadialProfile(nn.Module):
"""
Radial profile function.
Outputs weights used to weigh basis matrices in order to get convolution kernels.
In TFN notation: $R^{l,k}$
In SE(3)-Transformer notation: $\phi^{l,k}$
Note:
In the original papers, this function only depends on relative node distances ||x||.
Here, we allow this function to also take as input additional invariant edge features.
This does not break equivariance and adds expressive power to the model.
Diagram:
invariant edge features (node distances included) > MLP layer (shared across edges) > radial weights
"""
def __init__(
self,
num_freq: int,
channels_in: int,
channels_out: int,
edge_dim: int = 1,
mid_dim: int = 32,
use_layer_norm: bool = False
):
"""
:param num_freq: Number of frequencies
:param channels_in: Number of input channels
:param channels_out: Number of output channels
:param edge_dim: Number of invariant edge features (input to the radial function)
:param mid_dim: Size of the hidden MLP layers
:param use_layer_norm: Apply layer normalization between MLP layers
"""
super().__init__()
modules = [
nn.Linear(edge_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, mid_dim),
nn.LayerNorm(mid_dim) if use_layer_norm else None,
nn.ReLU(),
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
]
self.net = nn.Sequential(*[m for m in modules if m is not None])
def forward(self, features: Tensor) -> Tensor:
return self.net(features)
class VersatileConvSE3(nn.Module):
"""
Building block for TFN convolutions.
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
"""
def __init__(self,
freq_sum: int,
channels_in: int,
channels_out: int,
edge_dim: int,
use_layer_norm: bool,
fuse_level: ConvSE3FuseLevel):
super().__init__()
self.freq_sum = freq_sum
self.channels_out = channels_out
self.channels_in = channels_in
self.fuse_level = fuse_level
self.radial_func = RadialProfile(num_freq=freq_sum,
channels_in=channels_in,
channels_out=channels_out,
edge_dim=edge_dim,
use_layer_norm=use_layer_norm)
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
with nvtx_range(f'VersatileConvSE3'):
num_edges = features.shape[0]
in_dim = features.shape[2]
if (self.training or num_edges<=4096):
with nvtx_range(f'RadialProfile'):
radial_weights = self.radial_func(invariant_edge_feats) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k
out_dim = basis.shape[-1]
if self.fuse_level != ConvSE3FuseLevel.FULL:
out_dim += out_dim % 2 - 1 # Account for padded basis
basis_view = basis.view(num_edges, in_dim, -1)
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
retval = (radial_weights @ tmp)[:, :, :out_dim]
return retval
else:
# k = l = 0 non-fused case
retval = radial_weights @ features
else:
#fd reduce memory in inference
EDGESTRIDE = 65536 #16384
if basis is not None:
out_dim = basis.shape[-1]
if self.fuse_level != ConvSE3FuseLevel.FULL:
out_dim += out_dim % 2 - 1 # Account for padded basis
else:
out_dim = features.shape[-1]
retval = torch.zeros(
(num_edges, self.channels_out, out_dim),
dtype=features.dtype,
device=features.device
)
for i in range((num_edges-1)//EDGESTRIDE+1):
e_i,e_j = i*EDGESTRIDE, min((i+1)*EDGESTRIDE,num_edges)
radial_weights = self.radial_func(invariant_edge_feats[e_i:e_j]) \
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
if basis is not None:
# This block performs the einsum n i l, n o i f, n l f k -> n o k
basis_view = basis[e_i:e_j].view(e_j-e_i, in_dim, -1)
with torch.cuda.amp.autocast(False):
tmp = (features[e_i:e_j] @ basis_view.float()).view(e_j-e_i, -1, basis.shape[-1])
retslice = (radial_weights.float() @ tmp)[:, :, :out_dim]
retval[e_i:e_j] = retslice
else:
# k = l = 0 non-fused case
retval[e_i:e_j] = radial_weights @ features[e_i:e_j]
return retval
class ConvSE3(nn.Module):
"""
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
Features of different degrees interact together to produce output features.
Note 1:
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
done, and the returned features will be edge features instead of node features.
Note 2:
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
Input edge features are concatenated with input source node features before the kernel is applied.
"""
def __init__(
self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
pool: bool = True,
use_layer_norm: bool = False,
self_interaction: bool = False,
sum_over_edge: bool = True,
max_degree: int = 4,
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
allow_fused_output: bool = False
):
"""
:param fiber_in: Fiber describing the input features
:param fiber_out: Fiber describing the output features
:param fiber_edge: Fiber describing the edge features (node distances excluded)
:param pool: If True, compute final node features by averaging incoming edge features
:param use_layer_norm: Apply layer normalization between MLP layers
:param self_interaction: Apply self-interaction of nodes
:param max_degree: Maximum degree used in the bases computation
:param fuse_level: Maximum fuse level to use in TFN convolutions
:param allow_fused_output: Allow the module to output a fused representation of features
"""
super().__init__()
self.pool = pool
self.fiber_in = fiber_in
self.fiber_out = fiber_out
self.self_interaction = self_interaction
self.sum_over_edge = sum_over_edge
self.max_degree = max_degree
self.allow_fused_output = allow_fused_output
# channels_in: account for the concatenation of edge features
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
channels_out_set = set([f.channels for f in self.fiber_out])
unique_channels_in = (len(channels_in_set) == 1)
unique_channels_out = (len(channels_out_set) == 1)
degrees_up_to_max = list(range(max_degree + 1))
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Single fused convolution
self.used_fuse_level = ConvSE3FuseLevel.FULL
sum_freq = sum([
degree_to_dim(min(d_in, d_out))
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
])
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
# Convolutions fused per output degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_out = nn.ModuleDict()
for d_out, c_out in fiber_out:
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
fuse_level=self.used_fuse_level, **common_args)
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
# Convolutions fused per input degree
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
self.conv_in = nn.ModuleDict()
for d_in, c_in in fiber_in:
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
fuse_level=ConvSE3FuseLevel.FULL, **common_args)
else:
# Use pairwise TFN convolutions
self.used_fuse_level = ConvSE3FuseLevel.NONE
self.conv = nn.ModuleDict()
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
dict_key = f'{degree_in},{degree_out}'
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
sum_freq = degree_to_dim(min(degree_in, degree_out))
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
fuse_level=self.used_fuse_level, **common_args)
if self_interaction:
self.to_kernel_self = nn.ParameterDict()
for degree_out, channels_out in fiber_out:
if fiber_in[degree_out]:
self.to_kernel_self[str(degree_out)] = nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
def forward(
self,
node_feats: Dict[str, Tensor],
edge_feats: Dict[str, Tensor],
graph: DGLGraph,
basis: Dict[str, Tensor]
):
with nvtx_range(f'ConvSE3'):
invariant_edge_feats = edge_feats['0'].squeeze(-1)
src, dst = graph.edges()
out = {}
in_features = []
# Fetch all input features from edge and node features
for degree_in in self.fiber_in.degrees:
src_node_features = node_feats[str(degree_in)][src]
if degree_in > 0 and str(degree_in) in edge_feats:
# Handle edge features of any type by concatenating them to node features
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
in_features.append(src_node_features)
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
in_features_fused = torch.cat(in_features, dim=-1)
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
in_features_fused = torch.cat(in_features, dim=-1)
for degree_out in self.fiber_out.degrees:
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
basis[f'out{degree_out}_fused'])
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
out = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
basis[f'in{degree_in}_fused'])
if not self.allow_fused_output or self.self_interaction or self.pool:
out = unfuse_features(out, self.fiber_out.degrees)
else:
# Fallback to pairwise TFN convolutions
for degree_out in self.fiber_out.degrees:
out_feature = 0
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
dict_key = f'{degree_in},{degree_out}'
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
basis.get(dict_key, None))
out[str(degree_out)] = out_feature
for degree_out in self.fiber_out.degrees:
if self.self_interaction and str(degree_out) in self.to_kernel_self:
with nvtx_range(f'self interaction'):
dst_features = node_feats[str(degree_out)][dst]
kernel_self = self.to_kernel_self[str(degree_out)]
out[str(degree_out)] += kernel_self @ dst_features
if self.pool:
if self.sum_over_edge:
with nvtx_range(f'pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_sum(graph, out)
else:
with nvtx_range(f'pooling'):
if isinstance(out, dict):
out[str(degree_out)] = dgl.ops.copy_e_mean(graph, out[str(degree_out)])
else:
out = dgl.ops.copy_e_mean(graph, out)
return out

View file

@ -0,0 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class LinearSE3(nn.Module):
"""
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
Maps a fiber to a fiber with the same degrees (channels may be different).
No interaction between degrees, but interaction between channels.
type-0 features (C_0 channels) > Linear(bias=False) > type-0 features (C'_0 channels)
type-1 features (C_1 channels) > Linear(bias=False) > type-1 features (C'_1 channels)
:
type-k features (C_k channels) > Linear(bias=False) > type-k features (C'_k channels)
"""
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
super().__init__()
self.weights = nn.ParameterDict({
str(degree_out): nn.Parameter(
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
for degree_out, channels_out in fiber_out
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
return {
degree: self.weights[degree] @ features[degree]
for degree, weight in self.weights.items()
}

View file

@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict
import torch
import torch.nn as nn
from torch import Tensor
from torch.cuda.nvtx import range as nvtx_range
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class NormSE3(nn.Module):
"""
Norm-based SE(3)-equivariant nonlinearity.
> feature_norm > LayerNorm() > ReLU()
feature_in * > feature_out
> feature_phase
"""
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
super().__init__()
self.fiber = fiber
self.nonlinearity = nonlinearity
if len(set(fiber.channels)) == 1:
# Fuse all the layer normalizations into a group normalization
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
else:
# Use multiple layer normalizations
self.layer_norms = nn.ModuleDict({
str(degree): nn.LayerNorm(channels)
for degree, channels in fiber
})
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
with nvtx_range('NormSE3'):
output = {}
if hasattr(self, 'group_norm'):
# Compute per-degree norms of features
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
for d in self.fiber.degrees]
fused_norms = torch.cat(norms, dim=-2)
# Transform the norms only
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
# Scale features to the new norms
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
output[str(d)] = features[str(d)] / norm * new_norm
else:
for degree, feat in features.items():
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
output[degree] = new_norm * feat / norm
return output

View file

@ -0,0 +1,53 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import Dict, Literal
import torch.nn as nn
from dgl import DGLGraph
from dgl.nn.pytorch import AvgPooling, MaxPooling
from torch import Tensor
class GPooling(nn.Module):
"""
Graph max/average pooling on a given feature type.
The average can be taken for any feature type, and equivariance will be maintained.
The maximum can only be taken for invariant features (type 0).
If you want max-pooling for type > 0 features, look into Vector Neurons.
"""
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
"""
:param feat_type: Feature type to pool
:param pool: Type of pooling: max or avg
"""
super().__init__()
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
self.feat_type = feat_type
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
pooled = self.pool(graph, features[str(self.feat_type)])
return pooled.squeeze(dim=-1)

View file

@ -0,0 +1,257 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
from typing import Optional, Literal, Dict
import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis, update_basis_with_fused
from rf2aa.SE3Transformer.se3_transformer.model.layers.attention import AttentionBlockSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.norm import NormSE3
from rf2aa.SE3Transformer.se3_transformer.model.layers.pooling import GPooling
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
class Sequential(nn.Sequential):
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
def forward(self, input, *args, **kwargs):
for module in self:
input = module(input, *args, **kwargs)
return input
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
""" Add relative positions to existing edge features """
edge_features = edge_features.copy() if edge_features else {}
r = relative_pos.norm(dim=-1, keepdim=True)
if '0' in edge_features:
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
else:
edge_features['0'] = r[..., None]
return edge_features
class SE3Transformer(nn.Module):
def __init__(self,
num_layers: int,
fiber_in: Fiber,
fiber_hidden: Fiber,
fiber_out: Fiber,
num_heads: int,
channels_div: int,
fiber_edge: Fiber = Fiber({}),
return_type: Optional[int] = None,
pooling: Optional[Literal['avg', 'max']] = None,
final_layer: Optional[Literal['conv', 'lin', 'att']] = 'conv',
norm: bool = True,
use_layer_norm: bool = True,
tensor_cores: bool = False,
low_memory: bool = False,
populate_edge: Optional[Literal['lin', 'arcsin', 'log', 'zero']] = 'lin',
sum_over_edge: bool = True,
**kwargs):
"""
:param num_layers: Number of attention layers
:param fiber_in: Input fiber description
:param fiber_hidden: Hidden fiber description
:param fiber_out: Output fiber description
:param fiber_edge: Input edge fiber description
:param num_heads: Number of attention heads
:param channels_div: Channels division before feeding to attention layer
:param return_type: Return only features of this type
:param pooling: 'avg' or 'max' graph pooling before MLP layers
:param norm: Apply a normalization layer after each attention block
:param use_layer_norm: Apply layer normalization between MLP layers
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
:param low_memory: If True, will use slower ops that use less memory
"""
super().__init__()
self.num_layers = num_layers
self.fiber_edge = fiber_edge
self.num_heads = num_heads
self.channels_div = channels_div
self.return_type = return_type
self.pooling = pooling
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
self.tensor_cores = tensor_cores
self.low_memory = low_memory
self.populate_edge = populate_edge
if low_memory and not tensor_cores:
logging.warning('Low memory mode will have no effect with no Tensor Cores')
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
div = dict((str(degree), channels_div) for degree in range(self.max_degree+1))
div_fin = dict((str(degree), 1) for degree in range(self.max_degree+1))
div_fin['0'] = channels_div
graph_modules = []
for i in range(num_layers):
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_hidden,
fiber_edge=fiber_edge,
num_heads=num_heads,
channels_div=div,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=fuse_level))
if norm:
graph_modules.append(NormSE3(fiber_hidden))
fiber_in = fiber_hidden
if final_layer == 'conv':
graph_modules.append(ConvSE3(fiber_in=fiber_in,
fiber_out=fiber_out,
fiber_edge=fiber_edge,
self_interaction=True,
sum_over_edge=sum_over_edge,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree))
elif final_layer == "lin":
graph_modules.append(LinearSE3(fiber_in=fiber_in,
fiber_out=fiber_out))
else:
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
fiber_out=fiber_out,
fiber_edge=fiber_edge,
num_heads=1,
channels_div=div_fin,
use_layer_norm=use_layer_norm,
max_degree=self.max_degree,
fuse_level=fuse_level))
self.graph_modules = Sequential(*graph_modules)
if pooling is not None:
assert return_type is not None, 'return_type must be specified when pooling'
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
edge_feats: Optional[Dict[str, Tensor]] = None,
basis: Optional[Dict[str, Tensor]] = None):
# Compute bases in case they weren't precomputed as part of the data loading
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
use_pad_trick=self.tensor_cores and not self.low_memory,
amp=torch.is_autocast_enabled())
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
fully_fused=self.tensor_cores and not self.low_memory)
if self.populate_edge=='lin':
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
elif self.populate_edge=='arcsin':
r = graph.edata['rel_pos'].norm(dim=-1, keepdim=True)
r = torch.maximum(r, torch.zeros_like(r) + 4.0) - 4.0
r = torch.arcsinh(r)/3.0
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
elif self.populate_edge=='log':
# fd - replace with log(1+x)
r = torch.log( 1 + graph.edata['rel_pos'].norm(dim=-1, keepdim=True) )
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
else:
edge_feats['0'] = torch.cat((edge_feats['0'], torch.zeros_like(edge_feats['0'][:,:1,:])), dim=1)
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
if self.pooling is not None:
return self.pooling_module(node_feats, graph=graph)
if self.return_type is not None:
return node_feats[str(self.return_type)]
return node_feats
@staticmethod
def add_argparse_args(parser):
parser.add_argument('--num_layers', type=int, default=7,
help='Number of stacked Transformer layers')
parser.add_argument('--num_heads', type=int, default=8,
help='Number of heads in self-attention')
parser.add_argument('--channels_div', type=int, default=2,
help='Channels division before feeding to attention layer')
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
help='Type of graph pooling')
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply a normalization layer after each attention block')
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
help='Apply layer normalization between MLP layers')
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
help='If true, will use fused ops that are slower but that use less memory '
'(expect 25 percent less memory). '
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
return parser
class SE3TransformerPooled(nn.Module):
def __init__(self,
fiber_in: Fiber,
fiber_out: Fiber,
fiber_edge: Fiber,
num_degrees: int,
num_channels: int,
output_dim: int,
**kwargs):
super().__init__()
kwargs['pooling'] = kwargs['pooling'] or 'max'
self.transformer = SE3Transformer(
fiber_in=fiber_in,
fiber_hidden=Fiber.create(num_degrees, num_channels),
fiber_out=fiber_out,
fiber_edge=fiber_edge,
return_type=0,
**kwargs
)
n_out_features = fiber_out.num_features
self.mlp = nn.Sequential(
nn.Linear(n_out_features, n_out_features),
nn.ReLU(),
nn.Linear(n_out_features, output_dim)
)
def forward(self, graph, node_feats, edge_feats, basis=None):
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
y = self.mlp(feats).squeeze(-1)
return y
@staticmethod
def add_argparse_args(parent_parser):
parser = parent_parser.add_argument_group("Model architecture")
SE3Transformer.add_argparse_args(parser)
parser.add_argument('--num_degrees',
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
type=int, default=4)
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
return parent_parser

View file

@ -0,0 +1,70 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import pathlib
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
paths = PARSER.add_argument_group('Paths')
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
help='Directory where the data is located or should be downloaded')
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
help='Directory where the results logs should be saved')
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
help='Name for the resulting DLLogger JSON file')
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
help='File where the checkpoint should be saved')
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
help='File of the checkpoint to be loaded')
optimizer = PARSER.add_argument_group('Optimizer')
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
optimizer.add_argument('--momentum', type=float, default=0.9)
optimizer.add_argument('--weight_decay', type=float, default=0.1)
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
help='Do an evaluation round every N epochs')
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
help='Minimize stdout output')
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
help='Benchmark mode')
QM9DataModule.add_argparse_args(PARSER)
SE3TransformerPooled.add_argparse_args(PARSER)

View file

@ -0,0 +1,160 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import time
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
import torch
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.metrics import MeanAbsoluteError
class BaseCallback(ABC):
def on_fit_start(self, optimizer, args):
pass
def on_fit_end(self):
pass
def on_epoch_end(self):
pass
def on_batch_start(self):
pass
def on_validation_step(self, input, target, pred):
pass
def on_validation_end(self, epoch=None):
pass
def on_checkpoint_load(self, checkpoint):
pass
def on_checkpoint_save(self, checkpoint):
pass
class LRSchedulerCallback(BaseCallback):
def __init__(self, logger: Optional[Logger] = None):
self.logger = logger
self.scheduler = None
@abstractmethod
def get_scheduler(self, optimizer, args):
pass
def on_fit_start(self, optimizer, args):
self.scheduler = self.get_scheduler(optimizer, args)
def on_checkpoint_load(self, checkpoint):
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
def on_checkpoint_save(self, checkpoint):
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
def on_epoch_end(self):
if self.logger is not None:
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
self.scheduler.step()
class QM9MetricCallback(BaseCallback):
""" Logs the rescaled mean absolute error for QM9 regression tasks """
def __init__(self, logger, targets_std, prefix=''):
self.mae = MeanAbsoluteError()
self.logger = logger
self.targets_std = targets_std
self.prefix = prefix
self.best_mae = float('inf')
def on_validation_step(self, input, target, pred):
self.mae(pred.detach(), target.detach())
def on_validation_end(self, epoch=None):
mae = self.mae.compute() * self.targets_std
logging.info(f'{self.prefix} MAE: {mae}')
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
self.best_mae = min(self.best_mae, mae)
def on_fit_end(self):
if self.best_mae != float('inf'):
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
class QM9LRSchedulerCallback(LRSchedulerCallback):
def __init__(self, logger, epochs):
super().__init__(logger)
self.epochs = epochs
def get_scheduler(self, optimizer, args):
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
class PerformanceCallback(BaseCallback):
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
self.batch_size = batch_size
self.warmup_epochs = warmup_epochs
self.epoch = 0
self.timestamps = []
self.mode = mode
self.logger = logger
def on_batch_start(self):
if self.epoch >= self.warmup_epochs:
self.timestamps.append(time.time() * 1000.0)
def _log_perf(self):
stats = self.process_performance_stats()
for k, v in stats.items():
logging.info(f'performance {k}: {v}')
self.logger.log_metrics(stats)
def on_epoch_end(self):
self.epoch += 1
def on_fit_end(self):
if self.epoch > self.warmup_epochs:
self._log_perf()
self.timestamps = []
def process_performance_stats(self):
timestamps = np.asarray(self.timestamps)
deltas = np.diff(timestamps)
throughput = (self.batch_size / deltas).mean()
stats = {
f"throughput_{self.mode}": throughput,
f"latency_{self.mode}_mean": deltas.mean(),
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
}
for level in [90, 95, 99]:
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
return stats

View file

@ -0,0 +1,325 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import collections
import itertools
import math
import os
import pathlib
import re
import pynvml
class Device:
# assumes nvml returns list of 64 bit ints
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
def __init__(self, device_idx):
super().__init__()
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
def get_name(self):
return pynvml.nvmlDeviceGetName(self.handle)
def get_uuid(self):
return pynvml.nvmlDeviceGetUUID(self.handle)
def get_cpu_affinity(self):
affinity_string = ""
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
# assume nvml returns list of 64 bit ints
affinity_string = "{:064b}".format(j) + affinity_string
affinity_list = [int(x) for x in affinity_string]
affinity_list.reverse() # so core 0 is in 0th element of list
ret = [i for i, e in enumerate(affinity_list) if e != 0]
return ret
def get_thread_siblings_list():
"""
Returns a list of 2-element integer tuples representing pairs of
hyperthreading cores.
"""
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
thread_siblings_list = []
pattern = re.compile(r"(\d+)\D(\d+)")
for fname in pathlib.Path(path[0]).glob(path[1:]):
with open(fname) as f:
content = f.read().strip()
res = pattern.findall(content)
if res:
pair = tuple(map(int, res[0]))
thread_siblings_list.append(pair)
return thread_siblings_list
def check_socket_affinities(socket_affinities):
# sets of cores should be either identical or disjoint
for i, j in itertools.product(socket_affinities, socket_affinities):
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
devices = [Device(i) for i in range(nproc_per_node)]
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
if exclude_unavailable_cores:
available_cores = os.sched_getaffinity(0)
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
check_socket_affinities(socket_affinities)
return socket_affinities
def set_socket_affinity(gpu_id):
"""
The process is assigned with all available logical CPU cores from the CPU
socket connected to the GPU with a given id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
os.sched_setaffinity(0, affinity)
def set_single_affinity(gpu_id):
"""
The process is assigned with the first available logical CPU core from the
list of all CPU cores from the CPU socket connected to the GPU with a given
id.
Args:
gpu_id: index of a GPU
"""
dev = Device(gpu_id)
affinity = dev.get_cpu_affinity()
# exclude unavailable cores
available_cores = os.sched_getaffinity(0)
affinity = list(set(affinity) & available_cores)
os.sched_setaffinity(0, affinity[:1])
def set_single_unique_affinity(gpu_id, nproc_per_node):
"""
The process is assigned with a single unique available physical CPU core
from the list of all CPU cores from the CPU socket connected to the GPU with
a given id.
Args:
gpu_id: index of a GPU
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
affinities = []
assigned = []
for socket_affinity in socket_affinities:
for core in socket_affinity:
if core not in assigned:
affinities.append([core])
assigned.append(core)
break
os.sched_setaffinity(0, affinities[gpu_id])
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
"""
The process is assigned with an unique subset of available physical CPU
cores from the CPU socket connected to a GPU with a given id.
Assignment automatically includes hyperthreading siblings (if siblings are
available).
Args:
gpu_id: index of a GPU
nproc_per_node: total number of processes per node
mode: mode
balanced: assign an equal number of physical cores to each process
"""
socket_affinities = get_socket_affinities(nproc_per_node)
siblings_list = get_thread_siblings_list()
siblings_dict = dict(siblings_list)
# remove hyperthreading siblings
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
socket_affinities_to_device_ids = collections.defaultdict(list)
for idx, socket_affinity in enumerate(socket_affinities):
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
# compute minimal number of physical cores per GPU across all GPUs and
# sockets, code assigns this number of cores per GPU if balanced == True
min_physical_cores_per_gpu = min(
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
)
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
devices_per_group = len(device_ids)
if balanced:
cores_per_device = min_physical_cores_per_gpu
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
else:
cores_per_device = len(socket_affinity) // devices_per_group
for group_id, device_id in enumerate(device_ids):
if device_id == gpu_id:
# In theory there should be no difference in performance between
# 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
# but 'continuous' should be better for DGX A100 because on AMD
# Rome 4 consecutive cores are sharing L3 cache.
# TODO: code doesn't attempt to automatically detect layout of
# L3 cache, also external environment may already exclude some
# cores, this code makes no attempt to detect it and to align
# mapping to multiples of 4.
if mode == "interleaved":
affinity = list(socket_affinity[group_id::devices_per_group])
elif mode == "continuous":
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
else:
raise RuntimeError("Unknown set_socket_unique_affinity mode")
# unconditionally reintroduce hyperthreading siblings, this step
# may result in a different numbers of logical cores assigned to
# each GPU even if balanced == True (if hyperthreading siblings
# aren't available for a subset of cores due to some external
# constraints, siblings are re-added unconditionally, in the
# worst case unavailable logical core will be ignored by
# os.sched_setaffinity().
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
os.sched_setaffinity(0, affinity)
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
"""
The process is assigned with a proper CPU affinity which matches hardware
architecture on a given platform. Usually it improves and stabilizes
performance of deep learning training workloads.
This function assumes that the workload is running in multi-process
single-device mode (there are multiple training processes and each process
is running on a single GPU), which is typical for multi-GPU training
workloads using `torch.nn.parallel.DistributedDataParallel`.
Available affinity modes:
* 'socket' - the process is assigned with all available logical CPU cores
from the CPU socket connected to the GPU with a given id.
* 'single' - the process is assigned with the first available logical CPU
core from the list of all CPU cores from the CPU socket connected to the GPU
with a given id (multiple GPUs could be assigned with the same CPU core).
* 'single_unique' - the process is assigned with a single unique available
physical CPU core from the list of all CPU cores from the CPU socket
connected to the GPU with a given id.
* 'socket_unique_interleaved' - the process is assigned with an unique
subset of available physical CPU cores from the CPU socket connected to a
GPU with a given id, hyperthreading siblings are included automatically,
cores are assigned with interleaved indexing pattern
* 'socket_unique_continuous' - (the default) the process is assigned with an
unique subset of available physical CPU cores from the CPU socket connected
to a GPU with a given id, hyperthreading siblings are included
automatically, cores are assigned with continuous indexing pattern
'socket_unique_continuous' is the recommended mode for deep learning
training workloads on NVIDIA DGX machines.
Args:
gpu_id: integer index of a GPU
nproc_per_node: number of processes per node
mode: affinity mode
balanced: assign an equal number of physical cores to each process,
affects only 'socket_unique_interleaved' and
'socket_unique_continuous' affinity modes
Returns a set of logical CPU cores on which the process is eligible to run.
Example:
import argparse
import os
import gpu_affinity
import torch
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
'--local_rank',
type=int,
default=os.getenv('LOCAL_RANK', 0),
)
args = parser.parse_args()
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
print(f'{args.local_rank}: core affinity: {affinity}')
if __name__ == "__main__":
main()
Launch the example with:
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
This function restricts execution only to the CPU cores directly connected
to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
of CPU memory bandwidth (which may be fine for many DL models).
"""
pynvml.nvmlInit()
if mode == "socket":
set_socket_affinity(gpu_id)
elif mode == "single":
set_single_affinity(gpu_id)
elif mode == "single_unique":
set_single_unique_affinity(gpu_id, nproc_per_node)
elif mode == "socket_unique_interleaved":
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
elif mode == "socket_unique_continuous":
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
else:
raise RuntimeError("Unknown affinity mode")
affinity = os.sched_getaffinity(0)
return affinity

View file

@ -0,0 +1,131 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from typing import List
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import BaseCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import DLLogger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank
@torch.inference_mode()
def evaluate(model: nn.Module,
dataloader: DataLoader,
callbacks: List[BaseCallback],
args):
model.eval()
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
leave=False, disable=(args.silent or get_local_rank() != 0)):
*input, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*input)
for callback in callbacks:
callback.on_validation_step(input, target, pred)
if __name__ == '__main__':
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import init_distributed, seed_everything
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled, Fiber
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
import torch.distributed as dist
import logging
import sys
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Inference on the test set |')
logging.info('===============================')
if not args.benchmark and args.load_ckpt_path is None:
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
sys.exit(1)
if args.benchmark:
logging.info('Running benchmark mode with one warmup pass')
if args.seed is not None:
seed_everything(args.seed)
major_cc, minor_cc = torch.cuda.get_device_capability()
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
**vars(args)
)
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
model.to(device=torch.cuda.current_device())
if args.load_ckpt_path is not None:
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
model.load_state_dict(checkpoint['state_dict'])
if is_distributed:
nproc_per_node = torch.cuda.device_count()
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
evaluate(model,
test_dataloader,
callbacks,
args)
for callback in callbacks:
callback.on_validation_end()
if args.benchmark:
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
for _ in range(6):
evaluate(model,
test_dataloader,
callbacks,
args)
callbacks[0].on_epoch_end()
callbacks[0].on_fit_end()

View file

@ -0,0 +1,134 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Any, Callable, Optional
import dllogger
import torch.distributed as dist
import wandb
from dllogger import Verbosity
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import rank_zero_only
class Logger(ABC):
@rank_zero_only
@abstractmethod
def log_hyperparams(self, params):
pass
@rank_zero_only
@abstractmethod
def log_metrics(self, metrics, step=None):
pass
@staticmethod
def _sanitize_params(params):
def _sanitize(val):
if isinstance(val, Callable):
try:
_val = val()
if isinstance(_val, Callable):
return val.__name__
return _val
except Exception:
return getattr(val, "__name__", None)
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
return str(val)
return val
return {key: _sanitize(val) for key, val in params.items()}
class LoggerCollection(Logger):
def __init__(self, loggers):
super().__init__()
self.loggers = loggers
def __getitem__(self, index):
return [logger for logger in self.loggers][index]
@rank_zero_only
def log_metrics(self, metrics, step=None):
for logger in self.loggers:
logger.log_metrics(metrics, step)
@rank_zero_only
def log_hyperparams(self, params):
for logger in self.loggers:
logger.log_hyperparams(params)
class DLLogger(Logger):
def __init__(self, save_dir: pathlib.Path, filename: str):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
dllogger.init(
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
@rank_zero_only
def log_hyperparams(self, params):
params = self._sanitize_params(params)
dllogger.log(step="PARAMETER", data=params)
@rank_zero_only
def log_metrics(self, metrics, step=None):
if step is None:
step = tuple()
dllogger.log(step=step, data=metrics)
class WandbLogger(Logger):
def __init__(
self,
name: str,
save_dir: pathlib.Path,
id: Optional[str] = None,
project: Optional[str] = None
):
super().__init__()
if not dist.is_initialized() or dist.get_rank() == 0:
save_dir.mkdir(parents=True, exist_ok=True)
self.experiment = wandb.init(name=name,
project=project,
id=id,
dir=str(save_dir),
resume='allow',
anonymous='must')
@rank_zero_only
def log_hyperparams(self, params: Dict[str, Any]) -> None:
params = self._sanitize_params(params)
self.experiment.config.update(params, allow_val_change=True)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
if step is not None:
self.experiment.log({**metrics, 'epoch': step})
else:
self.experiment.log(metrics)

View file

@ -0,0 +1,83 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
from abc import ABC, abstractmethod
import torch
import torch.distributed as dist
from torch import Tensor
class Metric(ABC):
""" Metric class with synchronization capabilities similar to TorchMetrics """
def __init__(self):
self.states = {}
def add_state(self, name: str, default: Tensor):
assert name not in self.states
self.states[name] = default.clone()
setattr(self, name, default)
def synchronize(self):
if dist.is_initialized():
for state in self.states:
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
def __call__(self, *args, **kwargs):
self.update(*args, **kwargs)
def reset(self):
for name, default in self.states.items():
setattr(self, name, default.clone())
def compute(self):
self.synchronize()
value = self._compute().item()
self.reset()
return value
@abstractmethod
def _compute(self):
pass
@abstractmethod
def update(self, preds: Tensor, targets: Tensor):
pass
class MeanAbsoluteError(Metric):
def __init__(self):
super().__init__()
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
def update(self, preds: Tensor, targets: Tensor):
preds = preds.detach()
n = preds.shape[0]
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
self.total += n
self.error += error
def _compute(self):
return self.error / self.total

View file

@ -0,0 +1,238 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import logging
import pathlib
from typing import List
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
PerformanceCallback
from rf2aa.SE3Transformer.se3_transformer.runtime.inference import evaluate
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
using_tensor_cores, increase_l2_fetch_granularity
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Saves model, optimizer and epoch states to path (only once per node) """
if get_local_rank() == 0:
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
checkpoint = {
'state_dict': state_dict,
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch
}
for callback in callbacks:
callback.on_checkpoint_save(checkpoint)
torch.save(checkpoint, str(path))
logging.info(f'Saved checkpoint to {str(path)}')
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
""" Loads model, optimizer and epoch states from path """
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
if isinstance(model, DistributedDataParallel):
model.module.load_state_dict(checkpoint['state_dict'])
else:
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for callback in callbacks:
callback.on_checkpoint_load(checkpoint)
logging.info(f'Loaded checkpoint from {str(path)}')
return checkpoint['epoch']
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
losses = []
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
*inputs, target = to_cuda(batch)
for callback in callbacks:
callback.on_batch_start()
with torch.cuda.amp.autocast(enabled=args.amp):
pred = model(*inputs)
loss = loss_fn(pred, target) / args.accumulate_grad_batches
grad_scaler.scale(loss).backward()
# gradient accumulation
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
if args.gradient_clip:
grad_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
grad_scaler.step(optimizer)
grad_scaler.update()
optimizer.zero_grad()
losses.append(loss.item())
return np.mean(losses)
def train(model: nn.Module,
loss_fn: _Loss,
train_dataloader: DataLoader,
val_dataloader: DataLoader,
callbacks: List[BaseCallback],
logger: Logger,
args):
device = torch.cuda.current_device()
model.to(device=device)
local_rank = get_local_rank()
world_size = dist.get_world_size() if dist.is_initialized() else 1
if dist.is_initialized():
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
model.train()
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
if args.optimizer == 'adam':
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
elif args.optimizer == 'lamb':
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
weight_decay=args.weight_decay)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
weight_decay=args.weight_decay)
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
for callback in callbacks:
callback.on_fit_start(optimizer, args)
for epoch_idx in range(epoch_start, args.epochs):
if isinstance(train_dataloader.sampler, DistributedSampler):
train_dataloader.sampler.set_epoch(epoch_idx)
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
if dist.is_initialized():
loss = torch.tensor(loss, dtype=torch.float, device=device)
torch.distributed.all_reduce(loss)
loss = (loss / world_size).item()
logging.info(f'Train loss: {loss}')
logger.log_metrics({'train loss': loss}, epoch_idx)
for callback in callbacks:
callback.on_epoch_end()
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
and (epoch_idx + 1) % args.ckpt_interval == 0:
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
evaluate(model, val_dataloader, callbacks, args)
model.train()
for callback in callbacks:
callback.on_validation_end(epoch_idx)
if args.save_ckpt_path is not None and not args.benchmark:
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
for callback in callbacks:
callback.on_fit_end()
def print_parameters_count(model):
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
logging.info(f'Number of trainable parameters: {num_params_trainable}')
if __name__ == '__main__':
is_distributed = init_distributed()
local_rank = get_local_rank()
args = PARSER.parse_args()
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
logging.info('====== SE(3)-Transformer ======')
logging.info('| Training procedure |')
logging.info('===============================')
if args.seed is not None:
logging.info(f'Using seed {args.seed}')
seed_everything(args.seed)
logger = LoggerCollection([
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
])
datamodule = QM9DataModule(**vars(args))
model = SE3TransformerPooled(
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
output_dim=1,
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
**vars(args)
)
loss_fn = nn.L1Loss()
if args.benchmark:
logging.info('Running benchmark mode')
world_size = dist.get_world_size() if dist.is_initialized() else 1
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
else:
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
if is_distributed:
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
print_parameters_count(model)
logger.log_hyperparams(vars(args))
increase_l2_fetch_granularity()
train(model,
loss_fn,
datamodule.train_dataloader(),
datamodule.val_dataloader(),
callbacks,
logger,
args)
logging.info('Training finished successfully')

View file

@ -0,0 +1,130 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict
import numpy as np
import torch
import torch.distributed as dist
from torch import Tensor
def aggregate_residual(feats1, feats2, method: str):
""" Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
if method in ['add', 'sum']:
return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
elif method in ['cat', 'concat']:
return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
else:
raise ValueError('Method must be add/sum or cat/concat')
def degree_to_dim(degree: int) -> int:
return 2 * degree + 1
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))
def str2bool(v: Union[bool, str]) -> bool:
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def to_cuda(x):
""" Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
if isinstance(x, Tensor):
return x.cuda(non_blocking=True)
elif isinstance(x, tuple):
return (to_cuda(v) for v in x)
elif isinstance(x, list):
return [to_cuda(v) for v in x]
elif isinstance(x, dict):
return {k: to_cuda(v) for k, v in x.items()}
else:
# DGLGraph or other objects
return x.to(device=torch.cuda.current_device())
def get_local_rank() -> int:
return int(os.environ.get('LOCAL_RANK', 0))
def init_distributed() -> bool:
world_size = int(os.environ.get('WORLD_SIZE', 1))
distributed = world_size > 1
if distributed:
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
dist.init_process_group(backend=backend, init_method='env://')
if backend == 'nccl':
torch.cuda.set_device(get_local_rank())
else:
logging.warning('Running on CPU only!')
assert torch.distributed.is_initialized()
return distributed
def increase_l2_fetch_granularity():
# maximum fetch granularity of L2: 128 bytes
_libcudart = ctypes.CDLL('libcudart.so')
# set device limit on the current device
# cudaLimitMaxL2FetchGranularity = 0x05
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128
def seed_everything(seed):
seed = int(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def rank_zero_only(fn):
@wraps(fn)
def wrapped_fn(*args, **kwargs):
if not dist.is_initialized() or dist.get_rank() == 0:
return fn(*args, **kwargs)
return wrapped_fn
def using_tensor_cores(amp: bool) -> bool:
major_cc, minor_cc = torch.cuda.get_device_capability()
return (amp and major_cc >= 7) or major_cc >= 8

View file

@ -0,0 +1,11 @@
from setuptools import setup, find_packages
setup(
name='se3-transformer',
packages=find_packages(),
include_package_data=True,
version='1.0.0',
description='PyTorch + DGL implementation of SE(3)-Transformers',
author='Alexandre Milesi',
author_email='alexandrem@nvidia.com',
)

View file

View file

@ -0,0 +1,103 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import torch
import pytest
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
TOL = 1e-3
CHANNELS, NODES = 32, 512
def _get_outputs(model, R):
feats0 = torch.randn(NODES, CHANNELS, 1)
feats1 = torch.randn(NODES, CHANNELS, 3)
coords = torch.randn(NODES, 3)
graph = get_random_graph(NODES)
if torch.cuda.is_available():
feats0 = feats0.cuda()
feats1 = feats1.cuda()
R = R.cuda()
coords = coords.cuda()
graph = graph.to('cuda')
model.cuda()
graph1 = assign_relative_pos(graph, coords)
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
graph2 = assign_relative_pos(graph, coords @ R)
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
return out1, out2
def _get_model(**kwargs):
return SE3Transformer(
num_layers=4,
fiber_in=Fiber.create(2, CHANNELS),
fiber_hidden=Fiber.create(3, CHANNELS),
fiber_out=Fiber.create(2, CHANNELS),
fiber_edge=Fiber({}),
num_heads=8,
channels_div=2,
**kwargs
)
@pytest.mark.skip
def test_equivariance():
model = _get_model()
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
@pytest.mark.skip
def test_equivariance_pooled():
model = _get_model(pooling='avg', return_type=1)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
@pytest.mark.skip
def test_invariance_pooled():
model = _get_model(pooling='avg', return_type=0)
R = rot(*torch.rand(3))
if torch.cuda.is_available():
R = R.cuda()
out1, out2 = _get_outputs(model, R)
assert torch.allclose(out2, out1, atol=TOL), \
f'type-0 features should be invariant {get_max_diff(out1, out2)}'

View file

@ -0,0 +1,60 @@
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT
import dgl
import torch
def get_random_graph(N, num_edges_factor=18):
graph = dgl.transforms.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
return graph
def assign_relative_pos(graph, coords):
src, dst = graph.edges()
graph.edata['rel_pos'] = coords[src] - coords[dst]
return graph
def get_max_diff(a, b):
return (a - b).abs().max().item()
def rot_z(gamma):
return torch.tensor([
[torch.cos(gamma), -torch.sin(gamma), 0],
[torch.sin(gamma), torch.cos(gamma), 0],
[0, 0, 1]
], dtype=gamma.dtype)
def rot_y(beta):
return torch.tensor([
[torch.cos(beta), 0, torch.sin(beta)],
[0, 1, 0],
[-torch.sin(beta), 0, torch.cos(beta)]
], dtype=beta.dtype)
def rot(alpha, beta, gamma):
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)

0
rf2aa/__init__.py Normal file
View file

Binary file not shown.

9052
rf2aa/cartbonded.json Normal file

File diff suppressed because it is too large Load diff

2862
rf2aa/chemical.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
job_name: "structure_prediction"
output_path: ""
checkpoint_path: RFAA_paper_weights.pt
database_params:
sequencedb: ""
hhdb: "pdb100_2022Apr19/pdb100_2022Apr19"
command: make_msa.sh
num_cpus: 4
mem: 64
protein_inputs: null
na_inputs: null
sm_inputs: null
covale_inputs: null
residue_replacement: null
chem_params:
use_phospate_frames_for_NA: True
use_cif_ordering_for_trp: True
loader_params:
n_templ: 4
MAXLAT: 128
MAXSEQ: 1024
MAXCYCLE: 4
BLACK_HOLE_INIT: False
seqid: 150.0
legacy_model_param:
n_extra_block: 4
n_main_block: 32
n_ref_block: 4
n_finetune_block: 0
d_msa: 256
d_msa_full: 64
d_pair: 192
d_templ: 64
n_head_msa: 8
n_head_pair: 6
n_head_templ: 4
d_hidden_templ: 64
p_drop: 0.0
use_chiral_l1: True
use_lj_l1: True
use_atom_frames: True
recycling_type: "all"
use_same_chain: True
lj_lin: 0.75
SE3_param:
num_layers: 1
num_channels: 32
num_degrees: 2
l0_in_features: 64
l0_out_features: 64
l1_in_features: 3
l1_out_features: 2
num_edge_features: 64
n_heads: 4
div: 4
SE3_ref_param:
num_layers: 2
num_channels: 32
num_degrees: 2
l0_in_features: 64
l0_out_features: 64
l1_in_features: 3
l1_out_features: 2
num_edge_features: 64
n_heads: 4
div: 4

View file

@ -0,0 +1,18 @@
defaults:
- base
job_name: 7s69_A
protein_inputs:
A:
fasta_file: examples/protein/7s69_A.fasta
sm_inputs:
B:
input: examples/small_molecule/7s69_glycan.sdf
input_type: sdf
covale_inputs: "[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
loader_params:
MAXCYCLE: 10

View file

@ -0,0 +1,14 @@
defaults:
- base
job_name: "7u7w_protein_nucleic"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta
na_inputs:
B:
fasta: examples/nucleic_acid/7u7w_B.fasta
input_type: "dna"
C:
fasta: examples/nucleic_acid/7u7w_C.fasta
input_type: "dna"

View file

@ -0,0 +1,7 @@
defaults:
- base
job_name: "7u7w_protein"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta

View file

@ -0,0 +1,18 @@
defaults:
- base
job_name: "7u7w_protein_nucleic_sm"
protein_inputs:
A:
fasta_file: examples/protein/7u7w_A.fasta
na_inputs:
B:
fasta: examples/nucleic_acid/7u7w_B.fasta
input_type: "dna"
C:
fasta: examples/nucleic_acid/7u7w_C.fasta
input_type: "dna"
sm_inputs:
D:
input: examples/small_molecule/XG4.sdf
input_type: "sdf"

308
rf2aa/data/covale.py Normal file
View file

@ -0,0 +1,308 @@
import torch
from openbabel import openbabel
from typing import Optional
from dataclasses import dataclass
from tempfile import NamedTemporaryFile
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.parsers import parse_mol
from rf2aa.data.small_molecule import compute_features_from_obmol
from rf2aa.util import get_bond_feats
@dataclass
class MoleculeToMoleculeBond:
chain_index_first: int
absolute_atom_index_first: int
chain_index_second: int
absolute_atom_index_second: int
new_chirality_atom_first: Optional[str]
new_chirality_atom_second: Optional[str]
@dataclass
class AtomizedResidue:
chain: str
chain_index_in_combined_chain: int
absolute_N_index_in_chain: int
absolute_C_index_in_chain: int
original_chain: str
index_in_original_chain: int
def load_covalent_molecules(protein_inputs, config, model_runner):
if config.covale_inputs is None:
return None
if config.sm_inputs is None:
raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")
covalent_bonds = eval(config.covale_inputs)
sm_inputs = delete_leaving_atoms(config.sm_inputs)
residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
chainid_to_input = {}
for chain, combined_molecule in combined_molecules.items():
extra_bonds_for_chain = extra_bonds[chain]
msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
xyz = recompute_xyz_after_chirality(mol)
input = compute_features_from_obmol(mol, msa, xyz, model_runner)
chainid_to_input[chain] = input
return chainid_to_input, residues_to_atomize
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
residues_to_atomize = [] # hold on to delete wayward inputs
combined_molecules = {} # combined multiple molecules that are bonded
extra_bonds = {}
for bond in covalent_bonds:
prot_chid, prot_res_idx, atom_to_bond = bond[0]
sm_chid, sm_atom_num = bond[1]
chirality_first_atom, chirality_second_atom = bond[2]
if chirality_first_atom.strip() == "null":
chirality_first_atom = None
if chirality_second_atom.strip() == "null":
chirality_second_atom = None
sm_atom_num = int(sm_atom_num) - 1 # 0 index
try:
assert sm_chid in sm_inputs, f"must provide a small molecule chain {sm_chid} for covalent bond: {bond}"
except:
print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
continue
assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"
try:
protein_input = protein_inputs[prot_chid]
except Exception as e:
raise ValueError(f"first atom in covale_input must be present in\
a protein chain. Given chain: {prot_chid} was not in \
given protein chains: {list(protein_inputs.keys())}")
residue = (prot_chid, prot_res_idx, atom_to_bond)
file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)
if sm_chid not in combined_molecules:
combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
combined_molecules[sm_chid].insert(0, file) # this is a bug, revert
absolute_chain_index_first = combined_molecules[sm_chid].index(sm_inputs[sm_chid].input)
absolute_chain_index_second = combined_molecules[sm_chid].index(file)
if sm_chid not in extra_bonds:
extra_bonds[sm_chid] = []
extra_bonds[sm_chid].append(MoleculeToMoleculeBond(
absolute_chain_index_first,
sm_atom_num,
absolute_chain_index_second,
atom_index,
new_chirality_atom_first=chirality_first_atom,
new_chirality_atom_second=chirality_second_atom
))
residues_to_atomize.append(AtomizedResidue(
sm_chid,
absolute_chain_index_second,
0,
2,
prot_chid,
int(prot_res_idx) -1
))
return residues_to_atomize, combined_molecules, extra_bonds
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
"""convert residue into sdf and record index for covalent bond"""
prot_chid, prot_res_idx, atom_to_bond = residue
protein_input = protein_inputs[prot_chid]
prot_res_abs_idx = int(prot_res_idx) -1
residue_identity_num = protein_input.query_sequence()[prot_res_abs_idx]
residue_identity = ChemData().num2aa[residue_identity_num]
molecule_info = model_runner.molecule_db[residue_identity]
sdf = molecule_info["sdf"]
temp_file = create_and_populate_temp_file(sdf)
is_heavy = [i for i, a in enumerate(molecule_info["atom_id"]) if a[0] != "H"]
is_leaving = [a for i,a in enumerate(molecule_info["leaving"]) if i in is_heavy]
sdf_string_no_leaving_atoms = delete_leaving_atoms_single_chain(temp_file, is_leaving )
temp_file = create_and_populate_temp_file(sdf_string_no_leaving_atoms)
atom_names = molecule_info["atom_id"]
atom_index = atom_names.index(atom_to_bond.strip())
return temp_file, atom_index
def get_combined_atoms_bonds(combined_molecule):
atom_list = []
bond_feats_list = []
xyzs = []
Ls = []
for molecule in combined_molecule:
obmol, msa, ins, xyz, mask = parse_mol(
molecule,
filetype="sdf",
string=False,
generate_conformer=True,
find_automorphs=False
)
bond_feats = get_bond_feats(obmol)
atom_list.append(msa)
bond_feats_list.append(bond_feats)
xyzs.append(xyz)
Ls.append(msa.shape[0])
atoms = torch.cat(atom_list)
L_total = sum(Ls)
bond_feats = torch.zeros((L_total, L_total)).long()
offset = 0
for bf in bond_feats_list:
L = bf.shape[0]
bond_feats[offset:offset+L, offset:offset+L] = bf
offset += L
xyz = torch.cat(xyzs, dim=1)[0]
return atoms, bond_feats, xyz, Ls
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
mol = openbabel.OBMol()
for i,k in enumerate(msa):
element = ChemData().num2aa[k]
atomnum = ChemData().atomtype2atomnum[element]
a = mol.NewAtom()
a.SetAtomicNum(atomnum)
a.SetVector(float(xyz[i,0]), float(xyz[i,1]), float(xyz[i,2]))
first_index, second_index = bond_feats.nonzero(as_tuple=True)
for i, j in zip(first_index, second_index):
order = bond_feats[i,j]
bond = make_openbabel_bond(mol, i.item(), j.item(), order.item())
mol.AddBond(bond)
for bond in extra_bonds:
absolute_index_first = get_absolute_index_from_relative_indices(
bond.chain_index_first,
bond.absolute_atom_index_first,
Ls
)
absolute_index_second = get_absolute_index_from_relative_indices(
bond.chain_index_second,
bond.absolute_atom_index_second,
Ls
)
order = 1 #all covale bonds are single bonds
openbabel_bond = make_openbabel_bond(mol, absolute_index_first, absolute_index_second, order)
mol.AddBond(openbabel_bond)
set_chirality(mol, absolute_index_first, bond.new_chirality_atom_first)
set_chirality(mol, absolute_index_second, bond.new_chirality_atom_second)
return mol
def make_openbabel_bond(mol, i, j, order):
obb = openbabel.OBBond()
obb.SetBegin(mol.GetAtom(i+1))
obb.SetEnd(mol.GetAtom(j+1))
if order == 4:
obb.SetBondOrder(2)
obb.SetAromatic()
else:
obb.SetBondOrder(order)
return obb
def set_chirality(mol, absolute_atom_index, new_chirality):
stereo = openbabel.OBStereoFacade(mol)
if stereo.HasTetrahedralStereo(absolute_atom_index+1):
tetstereo = stereo.GetTetrahedralStereo(mol.GetAtom(absolute_atom_index+1).GetId())
if tetstereo is None:
return
assert new_chirality is not None, "you have introduced a new stereocenter, \
so you must specify its chirality either as CW, or CCW"
config = tetstereo.GetConfig()
config.winding = chirality_options[new_chirality]
tetstereo.SetConfig(config)
print("Updating chirality...")
else:
assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
chirality_options = {
"CW": openbabel.OBStereo.Clockwise,
"CCW": openbabel.OBStereo.AntiClockwise,
}
def recompute_xyz_after_chirality(obmol):
builder = openbabel.OBBuilder()
builder.Build(obmol)
ff = openbabel.OBForceField.FindForceField("mmff94")
did_setup = ff.Setup(obmol)
if did_setup:
ff.FastRotorSearch()
ff.GetCoordinates(obmol)
else:
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
return atom_coords
def delete_leaving_atoms(sm_inputs):
updated_sm_inputs = {}
for chain in sm_inputs:
if "is_leaving" not in sm_inputs[chain]:
continue
is_leaving = eval(sm_inputs[chain]["is_leaving"])
sdf_string = delete_leaving_atoms_single_chain(sm_inputs[chain]["input"], is_leaving)
updated_sm_inputs[chain] = {
"input": create_and_populate_temp_file(sdf_string),
"input_type": "sdf"
}
sm_inputs.update(updated_sm_inputs)
return sm_inputs
def delete_leaving_atoms_single_chain(filename, is_leaving):
obmol, msa, ins, xyz, mask = parse_mol(
filename,
filetype="sdf",
string=False,
generate_conformer=True
)
assert len(is_leaving) == obmol.NumAtoms()
leaving_indices = torch.tensor(is_leaving).nonzero()
for idx in leaving_indices:
obmol.DeleteAtom(obmol.GetAtom(idx.item()+1))
obConversion = openbabel.OBConversion()
obConversion.SetInAndOutFormats("sdf", "sdf")
sdf_string = obConversion.WriteString(obmol)
return sdf_string
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
offset = sum(Ls[:chain_index])
return offset + absolute_index_in_chain
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
updated_residues_to_atomize = []
for residue in residues_to_atomize:
if residue.chain == chain:
absolute_index_N = get_absolute_index_from_relative_indices(
residue.chain_index_in_combined_chain,
residue.absolute_N_index_in_chain,
Ls)
absolute_index_C = get_absolute_index_from_relative_indices(
residue.chain_index_in_combined_chain,
residue.absolute_C_index_in_chain,
Ls)
updated_residue = AtomizedResidue(
residue.chain,
None,
absolute_index_N,
absolute_index_C,
residue.original_chain,
residue.index_in_original_chain
)
updated_residues_to_atomize.append(updated_residue)
else:
updated_residues_to_atomize.append(residue)
return updated_residues_to_atomize
def create_and_populate_temp_file(data):
# Create a temporary file
with NamedTemporaryFile(mode='w+', delete=False) as temp_file:
# Write the string to the temporary file
temp_file.write(data)
# Get the filename
temp_file_name = temp_file.name
return temp_file_name

198
rf2aa/data/data_loader.py Normal file
View file

@ -0,0 +1,198 @@
import torch
from dataclasses import dataclass, fields
from typing import Optional, List
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
Ls_from_same_chain_2d, idx_from_Ls, is_atom
@dataclass
class RawInputData:
msa: torch.Tensor
ins: torch.Tensor
bond_feats: torch.Tensor
xyz_t: torch.Tensor
mask_t: torch.Tensor
t1d: torch.Tensor
chirals: torch.Tensor
atom_frames: torch.Tensor
taxids: Optional[List[str]] = None
term_info: Optional[torch.Tensor] = None
chain_lengths: Optional[List] = None
idx: Optional[List] = None
def query_sequence(self):
return self.msa[0]
def is_atom(self):
return is_atom(self.query_sequence())
def length(self):
return self.msa.shape[1]
def get_chain_bins_from_chain_lengths(self):
if self.chain_lengths is None:
raise ValueError("Cannot call get_chain_bins_from_chain_lengths without \
setting chain_lengths. Chain_lengths is set in merge_inputs")
chain_bins = {}
running_length = 0
for chain, length in self.chain_lengths:
chain_bins[chain] = (running_length, running_length+length)
running_length = running_length + length
return chain_bins
def update_protein_features_after_atomize(self, residues_to_atomize):
if self.chain_lengths is None:
raise("Cannot update protein features without chain_lengths. \
merge_inputs must be called before this function")
chain_bins = self.get_chain_bins_from_chain_lengths()
keep = torch.ones(self.length())
prev_absolute_index = None
prev_C = None
#need to atomize residues from N term to Cterm to handle atomizing neighbors
residues_to_atomize = sorted(residues_to_atomize, key= lambda x: x.original_chain +str(x.index_in_original_chain))
for residue in residues_to_atomize:
original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
absolute_index_in_combined_input = original_chain_start_index + residue.index_in_original_chain
atomized_chain_start_index, atomized_chain_end_index = chain_bins[residue.chain]
N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain
# if residue is first in the chain, no extra bond feats to following residue
if absolute_index_in_combined_input != original_chain_start_index:
self.bond_feats[absolute_index_in_combined_input-1, N_index] = ChemData().RESIDUE_ATOM_BOND
self.bond_feats[N_index, absolute_index_in_combined_input-1] = ChemData().RESIDUE_ATOM_BOND
# if residue is last in chain, no extra bonds feats to following residue
if absolute_index_in_combined_input != original_chain_end_index-1:
self.bond_feats[absolute_index_in_combined_input+1, C_index] = ChemData().RESIDUE_ATOM_BOND
self.bond_feats[C_index,absolute_index_in_combined_input+1] = ChemData().RESIDUE_ATOM_BOND
keep[absolute_index_in_combined_input] = 0
# find neighboring residues that were atomized
if prev_absolute_index is not None:
if prev_absolute_index + 1 == absolute_index_in_combined_input:
self.bond_feats[prev_C, N_index] = 1
self.bond_feats[N_index, prev_C] = 1
prev_absolute_index = absolute_index_in_combined_input
prev_C = C_index
# remove protein features
self.keep_features(keep.bool())
def keep_features(self, keep):
if not torch.all(keep[self.is_atom()]):
raise ValueError("cannot remove atoms")
self.msa = self.msa[:,keep]
self.ins = self.ins[:,keep]
self.bond_feats = self.bond_feats[keep][:,keep]
self.xyz_t = self.xyz_t[:,keep]
self.t1d = self.t1d[:,keep]
self.mask_t = self.mask_t[:,keep]
if self.term_info is not None:
self.term_info = self.term_info[keep]
if self.idx is not None:
self.idx = self.idx[keep]
# assumes all chirals are after all protein residues
self.chirals[...,:-1] = self.chirals[...,:-1] - torch.sum(~keep)
def construct_features(self, model_runner):
loader_params = model_runner.config.loader_params
B, L = 1, self.length()
seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(
self.msa.long(),
self.ins.long(),
loader_params,
p_mask=loader_params.get("p_msa_mask", 0),
term_info=self.term_info,
deterministic=model_runner.deterministic,
)
dist_matrix = get_bond_distances(self.bond_feats)
# xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
# xyz_prev = torch.nan_to_num(xyz_prev)
# NOTE: The above is the way things "should" be done, this is for compatability with training.
xyz_prev = ChemData().INIT_CRDS.reshape(1,ChemData().NTOTAL,3).repeat(L,1,1)
self.xyz_t = torch.nan_to_num(self.xyz_t)
mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
mask_t_2d = mask_t_2d[:,None]*mask_t_2d[:,:,None] # (B, T, L, L)
xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])
t2d = t2d[0]
# get torsion angles from templates
seq_tmp = self.t1d[...,:-1].argmax(dim=-1)
alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(self.xyz_t.reshape(-1,L,ChemData().NTOTAL,3),
seq_tmp, mask_in=self.mask_t.reshape(-1,L,ChemData().NTOTAL))
alpha = alpha.reshape(B,-1,L,ChemData().NTOTALDOFS,2)
alpha_mask = alpha_mask.reshape(B,-1,L,ChemData().NTOTALDOFS,1)
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*ChemData().NTOTALDOFS)
alpha_t = alpha_t[0]
alpha_prev = torch.zeros((L,ChemData().NTOTALDOFS,2))
same_chain = same_chain_from_bond_feats(self.bond_feats)
return RFInput(
msa_latent=msa_seed,
msa_full=msa_extra,
seq=seq,
seq_unmasked=self.query_sequence(),
bond_feats=self.bond_feats,
dist_matrix=dist_matrix,
chirals=self.chirals,
atom_frames=self.atom_frames.long(),
xyz_prev=xyz_prev,
alpha_prev=alpha_prev,
t1d=self.t1d,
t2d=t2d,
xyz_t=self.xyz_t[..., 1, :],
alpha_t=alpha_t.float(),
mask_t=mask_t_2d.float(),
same_chain=same_chain.long(),
idx=self.idx
)
@dataclass
class RFInput:
msa_latent: torch.Tensor
msa_full: torch.Tensor
seq: torch.Tensor
seq_unmasked: torch.Tensor
idx: torch.Tensor
bond_feats: torch.Tensor
dist_matrix: torch.Tensor
chirals: torch.Tensor
atom_frames: torch.Tensor
xyz_prev: torch.Tensor
alpha_prev: torch.Tensor
t1d: torch.Tensor
t2d: torch.Tensor
xyz_t: torch.Tensor
alpha_t: torch.Tensor
mask_t: torch.Tensor
same_chain: torch.Tensor
msa_prev: Optional[torch.Tensor] = None
pair_prev: Optional[torch.Tensor] = None
state_prev: Optional[torch.Tensor] = None
mask_recycle: Optional[torch.Tensor] = None
def to(self, gpu):
for field in fields(self):
field_value = getattr(self, field.name)
if torch.is_tensor(field_value):
setattr(self, field.name, field_value.to(gpu))
def add_batch_dim(self):
""" mimic pytorch dataloader at inference time"""
for field in fields(self):
field_value = getattr(self, field.name)
if torch.is_tensor(field_value):
setattr(self, field.name, field_value[None])

View file

@ -0,0 +1,909 @@
import torch
import warnings
import time
from icecream import ic
from torch.utils import data
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
sys.path.append(script_dir+'/../')
import numpy as np
import scipy
import networkx as nx
from rf2aa.data.parsers import parse_a3m, parse_pdb
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.util import random_rot_trans, \
is_atom, is_protein, is_nucleic, is_atom
def MSABlockDeletion(msa, ins, nb=5):
'''
Input: MSA having shape (N, L)
output: new MSA with block deletion
'''
N, L = msa.shape
block_size = max(int(N*0.3), 1)
block_start = np.random.randint(low=1, high=N, size=nb) # (nb)
to_delete = block_start[:,None] + np.arange(block_size)[None,:]
to_delete = np.unique(np.clip(to_delete, 1, N-1))
#
mask = np.ones(N, bool)
mask[to_delete] = 0
return msa[mask], ins[mask]
def cluster_sum(data, assignment, N_seq, N_res):
csum = torch.zeros(N_seq, N_res, data.shape[-1], device=data.device).scatter_add(0, assignment.view(-1,1,1).expand(-1,N_res,data.shape[-1]), data.float())
return csum
def get_term_feats(Ls):
"""Creates N/C-terminus binary features"""
term_info = torch.zeros((sum(Ls),2)).float()
start = 0
for L_chain in Ls:
term_info[start, 0] = 1.0 # flag for N-term
term_info[start+L_chain-1,1] = 1.0 # flag for C-term
start += L_chain
return term_info
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
'''
Input: full MSA information (after Block deletion if necessary) & full insertion information
Output: seed MSA features & extra sequences
Seed MSA features:
- aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
- profile of clustered sequences (22)
- insertion statistics (2)
- N-term or C-term? (2)
extra sequence features:
- aatype of extra sequence (22)
- insertion info (1)
- N-term or C-term? (2)
'''
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# TODO: delete me, just for testing purposes
msa = msa[:2]
if fixbb:
p_mask = 0
msa = msa[:1]
ins = ins[:1]
N, L = msa.shape
if term_info is None:
if len(L_s)==0:
L_s = [L]
term_info = get_term_feats(L_s)
term_info = term_info.to(msa.device)
#binding_site = torch.zeros((L,1), device=msa.device).float()
binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)
# raw MSA profile
raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS
# Select Nclust sequence randomly (seed MSA or latent MSA)
Nclust = (min(N, params['MAXLAT'])-1) // nmer
Nclust = Nclust*nmer + 1
if N > Nclust*2:
Nextra = N - Nclust
else:
Nextra = N
Nextra = min(Nextra, params['MAXSEQ']) // nmer
Nextra = max(1, Nextra * nmer)
#
b_seq = list()
b_msa_clust = list()
b_msa_seed = list()
b_msa_extra = list()
b_mask_pos = list()
for i_cycle in range(params['MAXCYCLE']):
sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
sample = torch.stack(sample, dim=-1)
sample = sample.reshape(-1)
# add MSA clusters pre-chosen before calling this function
if seed_msa_clus is not None:
sample_orig_shape = sample.shape
sample_seed = seed_msa_clus[i_cycle]
sample_more = torch.tensor([i for i in sample if i not in sample_seed])
N_sample_more = len(sample) - len(sample_seed)
if N_sample_more > 0:
sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
sample = torch.cat([sample_seed, sample_more])
else:
sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones
msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)
# 15% random masking
# - 10%: aa replaced with a uniformly sampled random amino acid
# - 10%: aa replaced with an amino acid sampled from the MSA profile
# - 10%: not replaced
# - 70%: replaced with a special token ("mask")
random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
# explicitly remove probabilities from nucleic acids and atoms
#same_aa[..., ChemData().NPROTAAS:] = 0
#raw_profile[...,ChemData().NPROTAAS:] = 0
probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
#probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)
# explicitly set the probability of masking for nucleic acids and atoms
#probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
#probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
#probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
#probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
#probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0
sampler = torch.distributions.categorical.Categorical(probs=probs)
mask_sample = sampler.sample()
mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
use_seq = msa_clust
msa_masked = torch.where(mask_pos, mask_sample, use_seq)
b_seq.append(msa_masked[0].clone())
## get extra sequenes
if N > Nclust*2: # there are enough extra sequences
msa_extra = msa[1:,:][sample[Nclust-1:]]
ins_extra = ins[1:,:][sample[Nclust-1:]]
extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
elif N - Nclust < 1:
msa_extra = msa_masked.clone()
ins_extra = ins_clust.clone()
extra_mask = mask_pos.clone()
else:
msa_add = msa[1:,:][sample[Nclust-1:]]
ins_add = ins[1:,:][sample[Nclust-1:]]
mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
msa_extra = torch.cat((msa_masked, msa_add), dim=0)
ins_extra = torch.cat((ins_clust, ins_add), dim=0)
extra_mask = torch.cat((mask_pos, mask_add), dim=0)
N_extra = msa_extra.shape[0]
# clustering (assign remaining sequences to their closest cluster by Hamming distance
msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
assignment = torch.argmax(agreement, dim=-1)
# seed MSA features
# 1. one_hot encoded aatype: msa_clust_onehot
# 2. cluster profile
count_extra = ~extra_mask
count_clust = ~mask_pos
msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
msa_clust_profile += count_clust[:,:,None]*msa_clust_profile
count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
count_profile += count_clust
count_profile += eps
msa_clust_profile /= count_profile[:,:,None]
# 3. insertion statistics
msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
msa_clust_del += count_clust*ins_clust
msa_clust_del /= count_profile
ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)
#
if fixbb:
assert params['MAXCYCLE'] == 1
msa_clust_profile = msa_clust_onehot
msa_extra_onehot = msa_clust_onehot
ins_clust[:] = 0
ins_extra[:] = 0
# This is how it is done in rfdiff, but really it seems like it should be all 0.
# Keeping as-is for now for consistency, as it may be used in downstream masking done
# by apply_masks.
mask_pos = torch.full_like(msa_clust, 1).bool()
msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)
# extra MSA features
ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
try:
msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
except Exception as e:
print('msa_extra.shape',msa_extra.shape)
print('ins_extra.shape',ins_extra.shape)
if (tocpu):
b_msa_clust.append(msa_clust.cpu())
b_msa_seed.append(msa_seed.cpu())
b_msa_extra.append(msa_extra.cpu())
b_mask_pos.append(mask_pos.cpu())
else:
b_msa_clust.append(msa_clust)
b_msa_seed.append(msa_seed)
b_msa_extra.append(msa_extra)
b_mask_pos.append(mask_pos)
b_seq = torch.stack(b_seq)
b_msa_clust = torch.stack(b_msa_clust)
b_msa_seed = torch.stack(b_msa_seed)
b_msa_extra = torch.stack(b_msa_extra)
b_mask_pos = torch.stack(b_mask_pos)
return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(n_tmpl,L,1,1) \
+ torch.rand(n_tmpl,L,1,3)*random_noise - random_noise/2
t1d = torch.nn.functional.one_hot(torch.full((n_tmpl, L), 20).long(), num_classes=ChemData().NAATOKENS-1).float() # all gaps
conf = torch.zeros((n_tmpl, L, 1)).float()
t1d = torch.cat((t1d, conf), -1)
mask_t = torch.full((n_tmpl,L,ChemData().NTOTAL), False)
return xyz, t1d, mask_t, np.full((n_tmpl), "")
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
seqID_cut = params['SEQID']
if npick_global == None:
npick_global=max(npick, 1)
ntplt = len(tplt['ids'])
if (ntplt < 1) or (npick < 1): #no templates in hhsearch file or not want to use templ
return blank_template(npick_global, qlen, random_noise)
# ignore templates having too high seqID
if seqID_cut <= 100.0:
tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
else:
tplt_valid_idx = torch.arange(len(tplt['ids']))
# check again if there are templates having seqID < cutoff
ntplt = len(tplt['ids'])
npick = min(npick, ntplt)
if npick<1: # no templates
return blank_template(npick_global, qlen, random_noise)
if not pick_top: # select randomly among all possible templates
sample = torch.randperm(ntplt)[:npick]
else: # only consider top 50 templates
sample = torch.randperm(min(50,ntplt))[:npick]
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL),False) # True for valid atom, False for missing atom
t1d = torch.full((npick_global, qlen), 20).long()
t1d_val = torch.zeros((npick_global, qlen)).float()
for i,nt in enumerate(sample):
tplt_idx = tplt_valid_idx[nt]
sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
pos = tplt['qmap'][0,sel,0] + offset
ntmplatoms = tplt['xyz'].shape[2] # will be bigger for NA templates
xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()
# 1-D features: alignment confidence
t1d[i,pos] = tplt['seq'][0,sel]
t1d_val[i,pos] = tplt['f1d'][0,sel,2] # alignment confidence
# xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)
t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float() # (no mask token)
t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)
tplt_ids = np.array(tplt["ids"])[sample].flatten() # np.array of chain ids (ordered)
return xyz, t1d, mask_t, tplt_ids
def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
"""Diagonally tiles template coordinates, 1d input features, and masks across
template and residue dimensions. 1st template is concatenated directly on residue
dimension after a random rotation & translation.
"""
N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])
xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
tplt_ids_out = np.full((N_tmpl_tot),"", dtype=object) # rk bad practice.. should fix
i_tmpl = 0
i_res = 0
for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
N_tmpl, L_tmpl = xyz_.shape[:2]
if i_tmpl == 0:
i1, i2 = 1, N_tmpl
else:
i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1
# 1st template is concatenated directly, so that all atoms are set in xyz_prev
xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]
if not tplt_ids_out[0]: # only add first template
tplt_ids_out[0] = ids[0]
# remaining templates are diagonally tiled
xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
tplt_ids_out[i1:i2] = ids[1:]
if i_tmpl == 0:
i_tmpl += N_tmpl
else:
i_tmpl += N_tmpl-1
i_res += L_tmpl
return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out
def generate_xyz_prev(xyz_t, mask_t, params):
"""
allows you to use different initializations for the coordinate track specified in params
"""
L = xyz_t.shape[1]
if params["BLACK_HOLE_INIT"]:
xyz_t, _, mask_t = blank_template(1, L)
return xyz_t[0].clone(), mask_t[0].clone()
### merge msa & insertion statistics of two proteins having different taxID
def merge_a3m_hetero(a3mA, a3mB, L_s):
# merge msa
query = torch.cat([a3mA['msa'][0], a3mB['msa'][0]]).unsqueeze(0) # (1, L)
msa = [query]
if a3mA['msa'].shape[0] > 1:
extra_A = torch.nn.functional.pad(a3mA['msa'][1:], (0,sum(L_s[1:])), "constant", 20) # pad gaps
msa.append(extra_A)
if a3mB['msa'].shape[0] > 1:
extra_B = torch.nn.functional.pad(a3mB['msa'][1:], (L_s[0],0), "constant", 20)
msa.append(extra_B)
msa = torch.cat(msa, dim=0)
# merge ins
query = torch.cat([a3mA['ins'][0], a3mB['ins'][0]]).unsqueeze(0) # (1, L)
ins = [query]
if a3mA['ins'].shape[0] > 1:
extra_A = torch.nn.functional.pad(a3mA['ins'][1:], (0,sum(L_s[1:])), "constant", 0) # pad gaps
ins.append(extra_A)
if a3mB['ins'].shape[0] > 1:
extra_B = torch.nn.functional.pad(a3mB['ins'][1:], (L_s[0],0), "constant", 0)
ins.append(extra_B)
ins = torch.cat(ins, dim=0)
a3m = {'msa': msa, 'ins': ins}
# merge taxids
if 'taxid' in a3mA and 'taxid' in a3mB:
a3m['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])
return a3m
# merge msa & insertion statistics of units in homo-oligomers
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
N, L = msa_orig.shape[:2]
if mode == "repeat":
# AAAAAA
# AAAAAA
msa = torch.tile(msa_orig,(1,nmer))
ins = torch.tile(ins_orig,(1,nmer))
elif mode == "diag":
# AAAAAA
# A-----
# -A----
# --A---
# ---A--
# ----A-
# -----A
N = N - 1
new_N = 1 + N * nmer
new_L = L * nmer
msa = torch.full((new_N, new_L), 20, dtype=msa_orig.dtype, device=msa_orig.device)
ins = torch.full((new_N, new_L), 0, dtype=ins_orig.dtype, device=msa_orig.device)
start_L = 0
start_N = 1
for i_c in range(nmer):
msa[0, start_L:start_L+L] = msa_orig[0]
msa[start_N:start_N+N, start_L:start_L+L] = msa_orig[1:]
ins[0, start_L:start_L+L] = ins_orig[0]
ins[start_N:start_N+N, start_L:start_L+L] = ins_orig[1:]
start_L += L
start_N += N
else:
# AAAAAA
# A-----
# -AAAAA
msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
msa[:N, :L] = msa_orig
ins[:N, :L] = ins_orig
start = L
for i_c in range(1,nmer):
msa[0, start:start+L] = msa_orig[0]
msa[N:, start:start+L] = msa_orig[1:]
ins[0, start:start+L] = ins_orig[0]
ins[N:, start:start+L] = ins_orig[1:]
start += L
return msa, ins
def merge_msas(a3m_list, L_s):
"""
takes a list of a3m dictionaries with keys msa, ins and a list of protein lengths and creates a
combined MSA
"""
seen = set()
taxIDs = []
a3mA = a3m_list[0]
taxIDs.extend(a3mA["taxID"])
seen.update(a3mA["hash"])
msaA, insA = a3mA["msa"], a3mA["ins"]
for i in range(1, len(a3m_list)):
a3mB = a3m_list[i]
pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
if a3mB["hash"] in seen or len(pair_taxIDs) < 5: #homomer/not enough pairs
a3mA = {"msa": msaA, "ins": insA}
L_s_to_merge = [sum(L_s[:i]), L_s[i]]
a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
msaA, insA = a3mA["msa"], a3mA["ins"]
taxIDs.extend(a3mB["taxID"])
else:
final_pairsA = []
final_pairsB = []
msaB, insB = a3mB["msa"], a3mB["ins"]
for pair in pair_taxIDs:
pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :],axis=-1))
msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :],axis=-1))
final_pairsA.append(pair_a3mA[msaApair])
final_pairsB.append(pair_a3mB[msaBpair])
paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
paired_msaB[final_pairsA] = msaB[final_pairsB]
msaA = torch.cat([msaA, paired_msaB], dim=1)
insA = torch.zeros_like(msaA) # paired MSAs in our dataset dont have insertions
seen.update(a3mB["hash"])
return msaA, insA
def remove_all_gap_seqs(a3m):
"""Removes sequences that are all gaps from an MSA represented as `a3m` dictionary"""
idx_seq_keep = ~(a3m['msa']==ChemData().UNKINDEX).all(dim=1)
a3m['msa'] = a3m['msa'][idx_seq_keep]
a3m['ins'] = a3m['ins'][idx_seq_keep]
return a3m
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
ID, only the sequence with the highest sequence identity to the query (1st
sequence in MSA) will be paired.
Sequences that aren't paired will be padded and added to the bottom of the
joined MSA. If a subregion of the input MSAs overlap (represent the same
chain), the subregion residue indices can be given as `idx_overlap`, and
the overlap region of the unpaired sequences will be included in the joined
MSA.
Parameters
----------
a3mA : dict
First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
boolean tensor indicating whether each sequence is fully paired. Can be
a multi-MSA (contain >2 sub-MSAs).
a3mB : dict
2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
`is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
idx_overlap : tuple or list (optional)
Start and end indices of overlap region in 1st MSA, followed by the
same in 2nd MSA.
Returns
-------
a3m : dict
Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
"""
# preprocess overlap region
L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
if idx_overlap is not None:
i1A, i2A, i1B, i2B = idx_overlap
i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
"When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
"(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
else:
i1B_new, i2B_new = (0, L_B)
# pair sequences
taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
i_pairedA, i_pairedB = [], []
for taxid in taxids_shared:
i_match = np.where(a3mA['taxid']==taxid)[0]
i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
i_pairedA.append(i_match[i_match_best])
i_match = np.where(a3mB['taxid']==taxid)[0]
i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
i_pairedB.append(i_match[i_match_best])
# unpaired sequences
i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)
# handle overlap region
# if msa A consists of sub-MSAs 1,2,3 and msa B of 2,4 (i.e overlap region is 2),
# this diagram shows how the variables below make up the final multi-MSA
# (* denotes nongaps, - denotes gaps)
# 1 2 3 4
# |*|*|*|*| msa_paired
# |*|*|*|-| msaA_unpaired
# |-|*|-|*| msaB_unpaired
if idx_overlap is not None:
assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
(a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'
# overlap region gets sequences from 2nd MSA bc sometimes 1st MSA will be all gaps here
msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
a3mB['msa'][i_pairedB, i1B:i2B],
a3mA['msa'][i_pairedA, i2A:],
a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB, i1B:i2B],
torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
else:
# no overlap region, simple offset pad & stack
# this code is actually a special case of "if" block above, but writing
# this out explicitly here to make the logic more clear
msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps
# stack paired & unpaired
msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])
# label "fully paired" sequences (a row of MSA that was never padded with gaps)
# output seq is fully paired if seqs A & B both started out as paired and were paired to
# each other on tax ID.
# NOTE: there is a rare edge case that is ignored here for simplicity: if
# pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
# 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
# will not be labeled as such here
is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
torch.zeros((N_unpairedA + N_unpairedB,)).bool()])
# insertion features in paired MSAs are assumed to be zero
a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
return a3m
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
"""Load a multi-MSA, which is a MSA that is paired across more than 2
chains. This loads the MSA for unique chains. Use 'expand_multi_msa` to
duplicate portions of the MSA for homo-oligomer repeated chains.
Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
it can) and pairs sequences across as many sub-MSAs as possible by matching
taxonomic ID. For details on how pairing is done, see
`join_msas_by_taxid()`
Parameters
----------
hash_list : list of str
Hashes of MSAs to load and join. Must not contain duplicates.
taxid_list : list of str
Taxonomic IDs of query sequences of each input MSA.
Ls : list of int
Lengths of the chains corresponding to the hashes.
Returns
-------
a3m_out : dict
Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
`taxid` [np.array (Nseq,)], `is_paired` [torch.Tensor (N_seq,)]
hashes_out : list of str
Hashes of MSAs in the order that they are joined in `a3m_out`.
Contains the same elements as the input `hash_list` but may be in a
different order.
Ls_out : list of int
Lengths of each chain in `a3m_out`
"""
assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'
# the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
# comprising sub-MSAs whose indices in the input lists are
# `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
# `hash_list[i_in]` and lengths are `Ls[i_in]`.
# Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
# where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
a3m_list = [] # list of multi-MSAs
idx_list_groups = [] # list of lists of indices of input chains making up each multi-MSA
res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA
# iterate through all pairs of hashes and look for paired MSAs (pMSAs)
# NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
# later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
# and 2+3 pMSAs will still be joined on taxID at the end, but sequences
# only present in the 0+2 pMSA pMSAs will be missed. this is probably very
# rare and so is ignored here for simplicity.
N = len(hash_list)
for i1, i2 in itertools.permutations(range(N),2):
idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
if i1 in idx_list and i2 in idx_list: continue # already loaded
if i1 == '' or i2 == '': continue # no taxID means no pMSA
# a paired MSA exists
if taxid_list[i1]==taxid_list[i2]:
h1, h2 = hash_list[i1], hash_list[i2]
fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'
if os.path.exists(fn):
msa, ins, taxid = parse_a3m(fn, paired=True)
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
is_paired=torch.ones(msa.shape[0]).bool())
res_range1 = (0,Ls[i1])
res_range2 = (Ls[i1],msa.shape[1])
# both hashes are new, add paired MSA to list
if i1 not in idx_list and i2 not in idx_list:
a3m_list.append(a3m_new)
idx_list_groups.append([i1,i2])
res_range_groups.append([res_range1, res_range2])
# one of the hashes is already in a multi-MSA
# find that multi-MSA and join the new pMSA to it
elif i1 in idx_list:
# which multi-MSA & sub-MSA has the hash with index `i1`?
i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
idx_list_groups[i_a3m].append(i2)
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
L_new = res_range2[1] - res_range2[0]
res_range_groups[i_a3m].append((L, L+L_new))
elif i2 in idx_list:
# which multi-MSA & sub-MSA has the hash with index `i2`?
i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
idx_list_groups[i_a3m].append(i1)
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
L_new = res_range1[1] - res_range1[0]
res_range_groups[i_a3m].append((L, L+L_new))
# add unpaired MSAs
# ungroup hash indices now, since we're done making multi-MSAs
idx_list = [x for group in idx_list_groups for x in group]
for i in range(N):
if i not in idx_list:
fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
msa, ins, taxid = parse_a3m(fn)
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
a3m_list.append(a3m_new)
idx_list.append(i)
Ls_out = [Ls[i] for i in idx_list]
hashes_out = [hash_list[i] for i in idx_list]
# join multi-MSAs & unpaired MSAs
a3m_out = a3m_list[0]
for i in range(1, len(a3m_list)):
a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])
return a3m_out, hashes_out, Ls_out
def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out, params):
"""Expands a multi-MSA of unique chains into an MSA of a
hetero-homo-oligomer in which some chains appear more than once. The query
sequences (1st sequence of MSA) are concatenated directly along the
residue dimention. The remaining sequences are offset-tiled (i.e. "padded &
stacked") so that exact repeat sequences aren't paired.
For example, if the original multi-MSA contains unique chains 1,2,3 but
the final chain order is 1,2,1,3,3,1, this function will output an MSA like
(where - denotes a block of gap characters):
1 2 - 3 - -
- - 1 - 3 -
- - - - - 1
Parameters
----------
a3m : dict
Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (Nseq,),
representing the multi-MSA of unique chains.
hashes_in : list of str
Unique MSA hashes used in `a3m`.
hashes_out : list of str
Non-unique MSA hashes desired in expanded MSA.
Ls_in : list of int
Lengths of each chain in `a3m`
Ls_out : list of int
Lengths of each chain desired in expanded MSA.
params : dict
Data loading parameters
Returns
-------
a3m : dict
Contains torch.Tensors `msa` and `ins` of expanded MSA. No
taxids because no further joining needs to be done.
"""
assert(len(hashes_out)==len(Ls_out))
assert(set(hashes_in)==set(hashes_out))
assert(a3m['msa'].shape[1]==sum(Ls_in))
# figure out which oligomeric repeat is represented by each hash in `hashes_out`
# each new repeat will be offset in sequence dimension of final MSA
counts = dict()
n_copy = [] # n-th copy of this hash in `hashes`
for h in hashes_out:
if h in counts:
counts[h] += 1
else:
counts[h] = 1
n_copy.append(counts[h])
# num sequences in source & destination MSAs
N_in = a3m['msa'].shape[0]
N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad&stack the rest
# source MSA
msa_in, ins_in = a3m['msa'], a3m['ins']
# initialize destination MSA to gap characters
msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
ins_out = torch.full((N_out, sum(Ls_out)), 0)
# for each destination chain
for i_out, h_out in enumerate(hashes_out):
# identify index of source chain
i_in = np.where(np.array(hashes_in)==h_out)[0][0]
# residue indexes
i1_res_in = sum(Ls_in[:i_in])
i2_res_in = sum(Ls_in[:i_in+1])
i1_res_out = sum(Ls_out[:i_out])
i2_res_out = sum(Ls_out[:i_out+1])
# copy over query sequence
# NOTE: There is a bug in these next two lines!
# The second line should be ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]
msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
ins_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
# offset non-query sequences along sequence dimension based on repeat number of a given hash
i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
i2_seq_out = 1+n_copy[i_out]*(N_in-1)
# copy over non-query sequences
msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]
# only 1st oligomeric repeat can be fully paired
is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])
a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
a3m_out = remove_all_gap_seqs(a3m_out)
return a3m_out
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
"""Loads multi-MSA for an arbitrary number of protein chains. Tries to
locate paired MSAs and pair sequences across all chains by taxonomic ID.
Unpaired sequences are padded and stacked on the bottom.
"""
# get MSA hashes (used to locate a3m files) and taxonomic IDs (used to determine pairing)
hashes = []
hashes_unique = []
taxids_unique = []
Ls_unique = []
for chid,L_ in zip(chain_ids, Ls):
hashes.append(chid2hash[chid])
if chid2hash[chid] not in hashes_unique:
hashes_unique.append(chid2hash[chid])
taxids_unique.append(chid2taxid.get(chid))
Ls_unique.append(L_)
# loads multi-MSA for unique chains
a3m_prot, hashes_unique, Ls_unique = \
load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)
# expands multi-MSA to repeat chains of homo-oligomers
a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)
return a3m_prot
def choose_multimsa_clusters(msa_seq_is_paired, params):
"""Returns indices of fully-paired sequences in a multi-MSA to use as seed
clusters during MSA featurization.
"""
frac_paired = msa_seq_is_paired.float().mean()
if frac_paired > 0.25: # enough fully paired sequences, just let MSAFeaturize choose randomly
return None
else:
# ensure that half of the clusters are fully-paired sequences,
# and let the rest be chosen randomly
N_seed = params['MAXLAT']//2
msa_seed_clus = []
for i_cycle in range(params['MAXCYCLE']):
idx_paired = torch.where(msa_seq_is_paired)[0]
msa_seed_clus.append(idx_paired[torch.randperm(len(idx_paired))][:N_seed])
return msa_seed_clus
#fd
def get_bond_distances(bond_feats):
atom_bonds = (bond_feats > 0)*(bond_feats<5)
dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds.long().numpy(), directed=False)
# dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0)) # protein portion is inf and you don't want to mask it out
return torch.from_numpy(dist_matrix).float()
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
xyz, mask, res_idx = parse_pdb(pdbfilename)
plddt = np.load(plddtfilename)
# update mask info with plddt (ignore sidechains if plddt < 90.0)
mask_lddt = np.full_like(mask, False)
mask_lddt[plddt > sccut] = True
mask_lddt[:,:5] = True
mask = np.logical_and(mask, mask_lddt)
mask = np.logical_and(mask, (plddt > lddtcut)[:,None])
return {'xyz':torch.tensor(xyz), 'mask':torch.tensor(mask), 'idx': torch.tensor(res_idx), 'label':item}
def get_msa(a3mfilename, item, maxseq=5000):
msa,ins, taxIDs = parse_a3m(a3mfilename, maxseq=5000)
return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'taxIDs':taxIDs, 'label':item}

151
rf2aa/data/merge_inputs.py Normal file
View file

@ -0,0 +1,151 @@
import torch
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, merge_hetero_templates, get_term_feats
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import center_and_realign_missing, same_chain_from_bond_feats, random_rot_trans, idx_from_Ls
def merge_protein_inputs(protein_inputs, deterministic: bool = False):
if len(protein_inputs) == 0:
return None,[]
elif len(protein_inputs) == 1:
chain = list(protein_inputs.keys())[0]
input = list(protein_inputs.values())[0]
xyz_t = input.xyz_t
xyz_t[0:1] = random_rot_trans(xyz_t[0:1], deterministic=deterministic)
input.xyz_t = xyz_t
return input, [(chain, input.length())]
# handle merging MSAs and such
# first determine which sequence are identical, then which one have mergeable MSAs
# then cat the templates, other feats
pass
def merge_na_inputs(na_inputs):
# should just be trivially catting features
running_inputs = None
chain_lengths = []
for chid, input in na_inputs.items():
running_inputs = merge_two_inputs(running_inputs, input)
chain_lengths.append((chid, input.length()))
return running_inputs, chain_lengths
def merge_sm_inputs(sm_inputs):
# should be trivially catting features
running_inputs = None
chain_lengths = []
for chid, input in sm_inputs.items():
running_inputs = merge_two_inputs(running_inputs, input)
chain_lengths.append((chid, input.length()))
return running_inputs, chain_lengths
def merge_two_inputs(first_input, second_input):
# merges two arbitrary inputs of data types
if first_input is None and second_input is None:
return None
elif first_input is None:
return second_input
elif second_input is None:
return first_input
Ls = [first_input.length(), second_input.length()]
L_total = sum(Ls)
# merge msas
a3m_first = {
"msa": first_input.msa,
"ins": first_input.ins,
}
a3m_second = {
"msa": second_input.msa,
"ins": second_input.ins,
}
a3m = merge_a3m_hetero(a3m_first, a3m_second, Ls)
# merge bond_feats
bond_feats = torch.zeros((L_total, L_total)).long()
offset = 0
for bf in [first_input.bond_feats, second_input.bond_feats]:
L = bf.shape[0]
bond_feats[offset:offset+L, offset:offset+L] = bf
offset += L
# merge templates
xyz_t = torch.cat([first_input.xyz_t, second_input.xyz_t],dim=1)
t1d = torch.cat([first_input.t1d, second_input.t1d],dim=1)
mask_t = torch.cat([first_input.mask_t, second_input.mask_t],dim=1)
# handle chirals (need to residue offset)
if second_input.chirals.shape[0] > 0 :
second_input.chirals[:, :-1] = second_input.chirals[:, :-1] + first_input.length()
chirals = torch.cat([first_input.chirals, second_input.chirals])
# cat atom frames
atom_frames = torch.cat([first_input.atom_frames, second_input.atom_frames])
# return new object
return RawInputData(
a3m["msa"],
a3m["ins"],
bond_feats,
xyz_t,
mask_t,
t1d,
chirals,
atom_frames,
taxids=None
)
def merge_all(
protein_inputs,
na_inputs,
sm_inputs,
residues_to_atomize,
deterministic: bool = False,
):
#protein_lengths = [protein_input.length() for protein_input in protein_inputs.values()]
#na_lengths = [na_input.length() for na_input in na_inputs.values()]
#sm_lengths = [sm_input.length() for sm_input in sm_inputs.values()]
#all_lengths = protein_lengths + na_lengths + sm_lengths
#term_info = get_term_feats(all_lengths)
#term_info[sum(protein_lengths):, :] = 0
protein_inputs, protein_chain_lengths = merge_protein_inputs(protein_inputs, deterministic=deterministic)
na_inputs, na_chain_lengths = merge_na_inputs(na_inputs)
sm_inputs, sm_chain_lengths = merge_sm_inputs(sm_inputs)
if protein_inputs is None and na_inputs is None and sm_inputs is None:
raise ValueError("No valid inputs were provided")
running_inputs = merge_two_inputs(protein_inputs, na_inputs) #could handle pairing protein/NA MSAs here
running_inputs = merge_two_inputs(running_inputs, sm_inputs)
all_chain_lengths = protein_chain_lengths + na_chain_lengths + sm_chain_lengths
running_inputs.chain_lengths = all_chain_lengths
all_lengths = get_Ls_from_chain_lengths(running_inputs.chain_lengths)
protein_lengths = get_Ls_from_chain_lengths(protein_chain_lengths)
term_info = get_term_feats(all_lengths)
term_info[sum(protein_lengths):, :] = 0
running_inputs.term_info = term_info
xyz_t = running_inputs.xyz_t
mask_t = running_inputs.mask_t
same_chain = same_chain = same_chain_from_bond_feats(running_inputs.bond_feats)
ntempl = xyz_t.shape[0]
xyz_t = torch.stack(
[center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain) for i in range(ntempl)]
)
xyz_t = torch.nan_to_num(xyz_t)
running_inputs.xyz_t = xyz_t
running_inputs.idx = idx_from_Ls(all_lengths)
# after everything is merged need to add bond feats for covales
# reindex protein feats function
if residues_to_atomize:
running_inputs.update_protein_features_after_atomize(residues_to_atomize)
return running_inputs
def get_Ls_from_chain_lengths(chain_lengths):
return [val[1] for val in chain_lengths]

View file

@ -0,0 +1,46 @@
import numpy as np
import torch
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import get_protein_bond_feats
def load_nucleic_acid(fasta_fn, input_type, model_runner):
if input_type not in ["dna", "rna"]:
raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
if input_type == "dna":
dna_alphabet = True
rna_alphabet = False
elif input_type == "rna":
dna_alphabet = False
rna_alphabet = True
loader_params = model_runner.config.loader_params
msa, ins, L = parse_multichain_fasta(fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet)
if (msa.shape[0] > loader_params["MAXSEQ"]):
idxs_tokeep = np.random.permutation(msa.shape[0])[:loader_params["MAXSEQ"]]
idxs_tokeep[0] = 0
msa = msa[idxs_tokeep]
ins = ins[idxs_tokeep]
if len(L) > 1:
raise ValueError("Please provide separate fasta files for each nucleic acid chain")
L = L[0]
xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)
bond_feats = get_protein_bond_feats(L)
chirals = torch.zeros(0, 5)
atom_frames = torch.zeros(0, 3, 2)
return RawInputData(
torch.from_numpy(msa),
torch.from_numpy(ins),
bond_feats,
xyz_t,
mask_t,
t1d,
chirals,
atom_frames,
taxids=None,
)

809
rf2aa/data/parsers.py Normal file
View file

@ -0,0 +1,809 @@
import numpy as np
import scipy
import scipy.spatial
import string
import os,re
from os.path import exists
import random
import rf2aa.util as util
import gzip
import rf2aa
from rf2aa.ffindex import *
import torch
from openbabel import openbabel
from rf2aa.chemical import ChemicalData as ChemData
def get_dislf(seq, xyz, mask):
L = seq.shape[0]
resolved_cys_mask = ((seq==ChemData().aa2num['CYS']) * mask[:,5]).nonzero().squeeze(-1) # cys[5]=='sg'
sgs = xyz[resolved_cys_mask,5]
ii,jj = torch.triu_indices(sgs.shape[0],sgs.shape[0],1)
d_sg_sg = torch.linalg.norm(sgs[ii,:]-sgs[jj,:], dim=-1)
is_dslf = (d_sg_sg>1.7)*(d_sg_sg<2.3)
dslf = []
for i in is_dslf.nonzero():
dslf.append( (
resolved_cys_mask[ii[i]].item(),
resolved_cys_mask[jj[i]].item(),
) )
return dslf
def read_template_pdb(L, pdb_fn, target_chain=None):
# get full sequence from given PDB
seq_full = list()
prev_chain=''
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
if line[12:16].strip() != "CA":
continue
if line[21] != prev_chain:
if len(seq_full) > 0:
L_s.append(len(seq_full)-offset)
offset = len(seq_full)
prev_chain = line[21]
aa = line[17:20]
seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
seq_full = torch.tensor(seq_full).long()
xyz = torch.full((L, 36, 3), np.nan).float()
seq = torch.full((L,), 20).long()
conf = torch.zeros(L,1).float()
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
#
idx = resNo - 1
for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
if tgtatm == atom:
xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
break
seq[idx] = aa_idx
mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
mask = mask.all(dim=-1)[:,None]
conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
t1d = torch.cat((seq_1hot, conf), -1)
#return seq_full[None], ins[None], L_s, xyz[None], t1d[None]
return xyz[None], t1d[None]
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
print ('read_multichain_pdb',tmpl_chain)
# get full sequence from PDB
seq_full = list()
L_s = list()
prev_chain=''
offset = 0
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
if line[12:16].strip() != "CA":
continue
if line[21] != prev_chain:
if len(seq_full) > 0:
L_s.append(len(seq_full)-offset)
offset = len(seq_full)
prev_chain = line[21]
aa = line[17:20]
seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
L_s.append(len(seq_full) - offset)
seq_full = torch.tensor(seq_full).long()
L = len(seq_full)
msa = torch.stack((seq_full,seq_full,seq_full), dim=0)
msa[1,:L_s[0]] = 20
msa[2,L_s[0]:] = 20
ins = torch.zeros_like(msa)
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
xyz_t = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
mask = torch.full((1, L, ChemData().NTOTAL), False)
mask_t = torch.full((1, L, ChemData().NTOTAL), False)
seq = torch.full((1, L,), 20).long()
conf = torch.zeros(1, L,1).float()
with open(pdb_fn) as fp:
for line in fp:
if line[:4] != "ATOM":
continue
outbatch = 0
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
idx = resNo - 1
for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
if tgtatm == atom:
xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
xyz[0, idx, i_atm, :] = xyz_i
mask[0, idx, i_atm] = True
if line[21] == tmpl_chain:
xyz_t[0, idx, i_atm, :] = xyz_i
mask_t[0, idx, i_atm] = True
break
seq[0, idx] = aa_idx
if (mask_t.any()):
xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])
dslf = get_dislf(seq[0], xyz[0], mask[0])
# assign confidence 'CONF' to all residues with backbone in template
conf = torch.where(mask_t[...,:3].all(dim=-1)[...,None], torch.full((1,L,1),tmpl_conf), torch.zeros(L,1)).float()
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=ChemData().NAATOKENS-1).float()
t1d = torch.cat((seq_1hot, conf), -1)
return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
msa = []
ins = []
fstream = open(filename,"r")
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line)
# sequence length
L = len(msa[-1])
i = np.zeros((L))
ins.append(i)
# convert letters into numbers
if rmsa_alphabet:
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
ins = np.array(ins, dtype=np.uint8)
return msa,ins
# Parse a fasta file containing multiple chains separated by '/'
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
msa = []
ins = []
fstream = open(filename,"r")
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
L_s = []
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa_i = line.translate(table)
msa_i = msa_i.replace('B','D') # hacky...
if L_s == []:
L_s = [len(x) for x in msa_i.split('/')]
msa_i = msa_i.replace('/','')
msa.append(msa_i)
# sequence length
L = len(msa[-1])
i = np.zeros((L))
ins.append(i)
if (len(msa) >= maxseq):
break
# convert letters into numbers
if rna_alphabet:
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
elif dna_alphabet:
alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
ins = np.array(ins, dtype=np.uint8)
return msa,ins,L_s
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=10000):
msa1,msa2 = [],[]
fstream = open(filename,"r")
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
unpaired_r, unpaired_p = 0, 0
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa_i = line.translate(table)
msa_i = msa_i.replace('B','D') # hacky...
msas_i = msa_i.split('/')
if (len(msas_i)==1):
msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
if (len(msa1)==0 or (
len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
)):
# skip if we've already found half of our limit in unpaired protein seqs
if sum([1 for x in msas_i[1] if x != '-']) == 0:
unpaired_p += 1
if unpaired_p > maxseq // 2:
continue
# skip if we've already found half of our limit in unpaired rna seqs
if sum([1 for x in msas_i[0] if x != '-']) == 0:
unpaired_r += 1
if unpaired_r > maxseq // 2:
continue
msa1.append(msas_i[0])
msa2.append(msas_i[1])
else:
print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))
if (len(msa1) >= maxseq):
break
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa1[msa1 == alphabet[i]] = i
msa1[msa1>=31] = 21 # anything unknown to 'X'
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa2[msa2 == alphabet[i]] = i
msa2[msa2>=31] = 30 # anything unknown to 'N'
msa = np.concatenate((msa1,msa2),axis=-1)
ins = np.zeros(msa.shape, dtype=np.uint8)
return msa,ins
# parse a fasta alignment IF it exists
# otherwise return single-sequence msa
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
if (exists(filename)):
return parse_fasta(filename, maxseq, rmsa_alphabet)
else:
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask
seq = np.array([list(seq)], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
seq[seq == alphabet[i]] = i
return (seq, np.zeros_like(seq))
#fd - parse protein/RNA coupled fastas
def parse_mixed_fasta(filename, maxseq=8000):
msa1,msa2 = [],[]
fstream = open(filename,"r")
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
unpaired_r, unpaired_p = 0, 0
for line in fstream:
# skip labels
if line[0] == '>':
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa_i = line.translate(table)
msa_i = msa_i.replace('B','D') # hacky...
msas_i = msa_i.split('/')
if (len(msas_i)==1):
msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
if (len(msa1)==0 or (
len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
)):
# skip if we've already found half of our limit in unpaired protein seqs
if sum([1 for x in msas_i[1] if x != '-']) == 0:
unpaired_p += 1
if unpaired_p > maxseq // 2:
continue
# skip if we've already found half of our limit in unpaired rna seqs
if sum([1 for x in msas_i[0] if x != '-']) == 0:
unpaired_r += 1
if unpaired_r > maxseq // 2:
continue
msa1.append(msas_i[0])
msa2.append(msas_i[1])
else:
print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))
if (len(msa1) >= maxseq):
break
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa1[msa1 == alphabet[i]] = i
msa1[msa1>=31] = 21 # anything unknown to 'X'
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa2[msa2 == alphabet[i]] = i
msa2[msa2>=31] = 30 # anything unknown to 'N'
msa = np.concatenate((msa1,msa2),axis=-1)
ins = np.zeros(msa.shape, dtype=np.uint8)
return msa,ins
# read A3M and convert letters into
# integers in the 0..20 range,
# also keep track of insertions
def parse_a3m(filename, maxseq=8000, paired=False):
msa = []
ins = []
taxIDs = []
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
# read file line by line
if filename.split('.')[-1] == 'gz':
fstream = gzip.open(filename, 'rt')
else:
fstream = open(filename, 'r')
for line in fstream:
# skip labels
if line[0] == '>':
if paired: # paired MSAs only have a TAXID in the fasta header
taxIDs.append(line[1:].strip())
else: # unpaired MSAs have all the metadata so use regex to pull out TAXID
match = re.search( r'TaxID=(\d+)', line)
if match:
taxIDs.append(match.group(1))
else:
taxIDs.append("query") # query sequence
continue
# remove right whitespaces
line = line.rstrip()
if len(line) == 0:
continue
# remove lowercase letters and append to MSA
msa.append(line.translate(table))
# sequence length
L = len(msa[-1])
# 0 - match or gap; 1 - insertion
a = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
i = np.zeros((L))
if np.sum(a) > 0:
# positions of insertions
pos = np.where(a==1)[0]
# shift by occurrence
a = pos - np.arange(pos.shape[0])
# position of insertions in cleaned sequence
# and their length
pos,num = np.unique(a, return_counts=True)
# append to the matrix of insetions
i[pos] = num
ins.append(i)
if (len(msa) >= maxseq):
break
# convert letters into numbers
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
for i in range(alphabet.shape[0]):
msa[msa == alphabet[i]] = i
# treat all unknown characters as gaps
msa[msa > 20] = 20
ins = np.array(ins, dtype=np.uint8)
return msa,ins, np.array(taxIDs)
# read and extract xyz coords of N,Ca,C atoms
# from a PDB file
def parse_pdb(filename, seq=False, lddt_mask=False):
lines = open(filename,'r').readlines()
if seq:
return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
return parse_pdb_lines(lines)
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
# indices of residues observed in the structure
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
pdb_idx_s = [(r[0], int(r[1])) for r in res]
idx_s = [int(r[1]) for r in res]
plddt = [float(r[3]) for r in res]
seq = [ChemData().aa2num[r[2]] if r[2] in ChemData().aa2num.keys() else 20 for r in res]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
idx = pdb_idx_s.index((chain,resNo))
for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
if lddt_mask == True:
plddt = np.array(plddt)
mask_lddt = np.full_like(mask, False)
mask_lddt[plddt > .85, 5:] = True
mask_lddt[plddt > .70, :5] = True
mask = np.logical_and(mask, mask_lddt)
return xyz,mask,np.array(idx_s), np.array(seq)
#'''
def parse_pdb_lines(lines):
# indices of residues observed in the structure
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
pdb_idx_s = [(r[0], int(r[1])) for r in res]
idx_s = [int(r[1]) for r in res]
# 4 BB + up to 10 SC atoms
xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
for l in lines:
if l[:4] != "ATOM":
continue
chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
idx = pdb_idx_s.index((chain,resNo))
for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
if tgtatm == atom:
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
break
# save atom mask
mask = np.logical_not(np.isnan(xyz[...,0]))
xyz[np.isnan(xyz[...,0])] = 0.0
return xyz,mask,np.array(idx_s)
def parse_templates(item, params):
# init FFindexDB of templates
### and extract template IDs
### present in the DB
ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
read_data(params['FFDB']+'_pdb.ffdata'))
#ffids = set([i.name for i in ffdb.index])
# process tabulated hhsearch output to get
# matched positions and positional scores
infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
hits = []
for l in open(infile, "r").readlines():
if l[0]=='>':
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(infile[:-4]+'hhr', "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
mask = np.vstack(mask).astype(bool)
qmap = np.vstack(qmap).astype(np.long)
f0d = np.vstack(f0d).astype(np.float32)
f1d = np.vstack(f1d).astype(np.float32)
ids = ids
return xyz,mask,qmap,f0d,f1d,ids
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
# process tabulated hhsearch output to get
# matched positions and positional scores
hits = []
for l in open(atab_fn, "r").readlines():
if l[0]=='>':
if len(hits) == max_templ:
break
key = l[1:].split()[0]
hits.append([key,[],[]])
elif "score" in l or "dssp" in l:
continue
else:
hi = l.split()[:5]+[0.0,0.0,0.0]
hits[-1][1].append([int(hi[0]),int(hi[1])])
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
# get per-hit statistics from an .hhr file
# (!!! assume that .hhr and .atab have the same hits !!!)
# [Probab, E-value, Score, Aligned_cols,
# Identities, Similarity, Sum_probs, Template_Neff]
lines = open(hhr_fn, "r").readlines()
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
for i,posi in enumerate(pos[:len(hits)]):
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
# parse templates from FFDB
for hi in hits:
#if hi[0] not in ffids:
# continue
entry = get_entry_by_name(hi[0], ffdb.index)
if entry == None:
print ("Failed to find %s in *_pdb.ffindex"%hi[0])
continue
data = read_entry_lines(entry, ffdb.data)
hi += list(parse_pdb_lines_w_seq(data))
# process hits
counter = 0
xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
for data in hits:
if len(data)<7:
continue
# print ("Process %s..."%data[0])
qi,ti = np.array(data[1]).T
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
ncol = sel1.shape[0]
if ncol < 10:
continue
ids.append(data[0])
f0d.append(data[3])
f1d.append(np.array(data[2])[sel1])
xyz.append(data[4][sel2])
mask.append(data[5][sel2])
seq.append(data[-1][sel2])
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
counter += 1
xyz = np.vstack(xyz).astype(np.float32)
mask = np.vstack(mask).astype(bool)
qmap = np.vstack(qmap).astype(np.int64)
f0d = np.vstack(f0d).astype(np.float32)
f1d = np.vstack(f1d).astype(np.float32)
seq = np.hstack(seq).astype(np.int64)
ids = ids
return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
xyz_t, mask_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20))
ntmplatoms = xyz_t.shape[1]
npick = min(n_templ, len(ids))
if npick < 1: # no templates
xyz = torch.full((1,qlen,ChemData().NTOTAL,3),np.nan).float()
mask = torch.full((1,qlen,ChemData().NTOTAL),False)
t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
return xyz, mask, t1d
sample = torch.arange(npick)
#
xyz = torch.full((npick, qlen, ChemData().NTOTAL, 3), np.nan).float()
mask = torch.full((npick, qlen, ChemData().NTOTAL), False)
f1d = torch.full((npick, qlen), 20).long()
f1d_val = torch.zeros((npick, qlen, 1)).float()
#
for i, nt in enumerate(sample):
sel = torch.where(qmap[:,1] == nt)[0]
pos = qmap[sel, 0]
xyz[i, pos] = xyz_t[sel]
mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
f1d[i, pos] = seq[sel]
f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])
f1d = torch.nn.functional.one_hot(f1d, num_classes=ChemData().NAATOKENS-1).float()
f1d = torch.cat((f1d, f1d_val), dim=-1)
return xyz, mask, f1d
def clean_sdffile(filename):
# lowercase the 2nd letter of the element name (e.g. FE->Fe) so openbabel can parse it correctly
lines2 = []
with open(filename) as f:
lines = f.readlines()
num_atoms = int(lines[3][:3])
for i in range(len(lines)):
if i>=4 and i<4+num_atoms:
lines2.append(lines[i][:32]+lines[i][32].lower()+lines[i][33:])
else:
lines2.append(lines[i])
molstring = ''.join(lines2)
return molstring
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
"""Parse small molecule ligand.
Parameters
----------
filename : str
filetype : str
string : bool
If True, `filename` is a string containing the molecule data.
remove_H : bool
Whether to remove hydrogen atoms.
find_automorphs : bool
Whether to enumerate atom symmetry permutations.
Returns
-------
obmol : OBMol
openbabel molecule object representing the ligand
msa : torch.Tensor (N_atoms,) long
Integer-encoded "sequence" (atom types) of ligand
ins : torch.Tensor (N_atoms,) long
Insertion features (all zero) for RF input
atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
Atom coordinates
mask : torch.Tensor (N_symmetry, N_atoms) bool
Boolean mask for whether atom exists
"""
obConversion = openbabel.OBConversion()
obConversion.SetInFormat(filetype)
obmol = openbabel.OBMol()
if string:
obConversion.ReadString(obmol,filename)
elif filetype=='sdf':
molstring = clean_sdffile(filename)
obConversion.ReadString(obmol,molstring)
else:
obConversion.ReadFile(obmol,filename)
if generate_conformer:
builder = openbabel.OBBuilder()
builder.Build(obmol)
ff = openbabel.OBForceField.FindForceField("mmff94")
did_setup = ff.Setup(obmol)
if did_setup:
ff.FastRotorSearch()
ff.GetCoordinates(obmol)
else:
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
if remove_H:
obmol.DeleteHydrogens()
# the above sometimes fails to get all the hydrogens
i = 1
while i < obmol.NumAtoms()+1:
if obmol.GetAtom(i).GetAtomicNum()==1:
obmol.DeleteAtom(obmol.GetAtom(i))
else:
i += 1
atomtypes = [ChemData().atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM')
for i in range(1, obmol.NumAtoms()+1)]
msa = torch.tensor([ChemData().aa2num[x] for x in atomtypes])
ins = torch.zeros_like(msa)
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
mask = torch.full(atom_coords.shape[:-1], True) # (1, natoms,)
if find_automorphs:
atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])
return obmol, msa, ins, atom_coords, mask

View file

@ -0,0 +1,34 @@
import os
from hydra import initialize, compose
from pathlib import Path
import subprocess
#from rf2aa.run_inference import ModelRunner
def make_msa(
fasta_file,
model_runner
):
out_dir_base = Path(model_runner.config.output_path)
hash = model_runner.config.job_name
out_dir = out_dir_base / hash
out_dir.mkdir(parents=True, exist_ok=True)
command = model_runner.config.database_params.command
search_base = model_runner.config.database_params.sequencedb
num_cpus = model_runner.config.database_params.num_cpus
ram_gb = model_runner.config.database_params.mem
template_database = model_runner.config.database_params.hhdb
out_a3m = out_dir / "t000_.msa0.a3m"
out_atab = out_dir / "t000_.atab"
out_hhr = out_dir / "t000_.hhr"
if out_a3m.exists() and out_atab.exists() and out_hhr.exists():
return out_a3m, out_hhr, out_atab
search_command = f"./{command} {fasta_file} {out_dir} {num_cpus} {ram_gb} {search_base} {template_database}"
print(search_command)
_ = subprocess.run(search_command, shell=True)
return out_a3m, out_hhr, out_atab

93
rf2aa/data/protein.py Normal file
View file

@ -0,0 +1,93 @@
import torch
from rf2aa.data.data_loader import RawInputData
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
from rf2aa.data.preprocessing import make_msa
from rf2aa.util import get_protein_bond_feats
def get_templates(
qlen,
ffdb,
hhr_fn,
atab_fn,
seqID_cut,
n_templ,
pick_top: bool = True,
offset: int = 0,
random_noise: float = 5.0,
deterministic: bool = False,
):
(
xyz_parsed,
mask_parsed,
qmap_parsed,
f0d_parsed,
f1d_parsed,
seq_parsed,
ids_parsed,
) = parse_templates_raw(ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn)
tplt = {
"xyz": xyz_parsed.unsqueeze(0),
"mask": mask_parsed.unsqueeze(0),
"qmap": qmap_parsed.unsqueeze(0),
"f0d": f0d_parsed.unsqueeze(0),
"f1d": f1d_parsed.unsqueeze(0),
"seq": seq_parsed.unsqueeze(0),
"ids": ids_parsed,
}
params = {
"SEQID": seqID_cut,
}
return TemplFeaturize(
tplt,
qlen,
params,
offset=offset,
npick=n_templ,
pick_top=pick_top,
random_noise=random_noise,
deterministic=deterministic,
)
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
msa, ins, taxIDs = parse_a3m(msa_file)
# NOTE: this next line is a bug, but is the way that
# the code is written in the original implementation!
ins[0] = msa[0]
L = msa.shape[1]
if hhr_fn is None or atab_fn is None:
print("No templates provided")
xyz_t, t1d, mask_t, _ = blank_template(1, L)
else:
xyz_t, t1d, mask_t, _ = get_templates(
L,
model_runner.ffdb,
hhr_fn,
atab_fn,
seqID_cut=model_runner.config.loader_params.seqid,
n_templ=model_runner.config.loader_params.n_templ,
deterministic=model_runner.deterministic,
)
bond_feats = get_protein_bond_feats(L)
chirals = torch.zeros(0, 5)
atom_frames = torch.zeros(0, 3, 2)
return RawInputData(
torch.from_numpy(msa),
torch.from_numpy(ins),
bond_feats,
xyz_t,
mask_t,
t1d,
chirals,
atom_frames,
taxids=taxIDs,
)
def generate_msa_and_load_protein(fasta_file, model_runner):
msa_file, hhr_file, atab_file = make_msa(fasta_file, model_runner)
return load_protein(str(msa_file), str(hhr_file), str(atab_file), model_runner)

View file

@ -0,0 +1,41 @@
import torch
from rf2aa.data.data_loader import RawInputData
from rf2aa.data.data_loader_utils import blank_template
from rf2aa.data.parsers import parse_mol
from rf2aa.kinematics import get_chirals
from rf2aa.util import get_bond_feats, get_nxgraph, get_atom_frames
def load_small_molecule(input_file, input_type, model_runner):
if input_type == "smiles":
is_string = True
else:
is_string = False
obmol, msa, ins, xyz, mask = parse_mol(
input_file, filetype=input_type, string=is_string, generate_conformer=True
)
return compute_features_from_obmol(obmol, msa, xyz, model_runner)
def compute_features_from_obmol(obmol, msa, xyz, model_runner):
L = msa.shape[0]
ins = torch.zeros_like(msa)
bond_feats = get_bond_feats(obmol)
xyz_t, t1d, mask_t, _ = blank_template(
model_runner.config.loader_params.n_templ,
L,
deterministic=model_runner.deterministic,
)
chirals = get_chirals(obmol, xyz[0])
G = get_nxgraph(obmol)
atom_frames = get_atom_frames(msa, G)
msa, ins = msa[None], ins[None]
return RawInputData(
msa, ins, bond_feats, xyz_t, mask_t, t1d, chirals, atom_frames, taxids=None
)
def remove_leaving_atoms(input, is_leaving):
keep = ~is_leaving
return input.keep_features(keep)

91
rf2aa/ffindex.py Normal file
View file

@ -0,0 +1,91 @@
#!/usr/bin/env python
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
'''
Created on Apr 30, 2014
@author: meiermark
'''
import sys
import mmap
from collections import namedtuple
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
def read_index(ffindex_filename):
entries = []
fh = open(ffindex_filename, "r")
for line in fh:
tokens = line.split("\t")
entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
fh.close()
return entries
def read_data(ffdata_filename):
fh = open(ffdata_filename, "rb")
data = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
fh.close()
return data
def get_entry_by_name(name, index):
#TODO: bsearch
for entry in index:
if(name == entry.name):
return entry
return None
def read_entry_lines(entry, data):
lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
return lines
def read_entry_data(entry, data):
return data[entry.offset:entry.offset + entry.length - 1]
def write_entry(entries, data_fh, entry_name, offset, data):
data_fh.write(data[:-1])
data_fh.write(bytearray(1))
entry = FFindexEntry(entry_name, offset, len(data))
entries.append(entry)
return offset + len(data)
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
with open(file_name, "rb") as fh:
data = bytearray(fh.read())
return write_entry(entries, data_fh, entry_name, offset, data)
def finish_db(entries, ffindex_filename, data_fh):
data_fh.close()
write_entries_to_db(entries, ffindex_filename)
def write_entries_to_db(entries, ffindex_filename):
sorted(entries, key=lambda x: x.name)
index_fh = open(ffindex_filename, "w")
for entry in entries:
index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
index_fh.close()
def write_entry_to_file(entry, data, file):
lines = read_lines(entry, data)
fh = open(file, "w")
for line in lines:
fh.write(line+"\n")
fh.close()

311
rf2aa/kinematics.py Normal file
View file

@ -0,0 +1,311 @@
from itertools import permutations
import numpy as np
import torch
from icecream import ic
from openbabel import openbabel
from rf2aa.chemical import ChemicalData as ChemData
PARAMS = {
'DMIN':1,
'DMID':4,
'DMAX':20.0,
'DBINS1':30,
'DBINS2':30,
'ABINS':36,
'USE_CB':False
}
# ============================================================
def get_pair_dist(a, b):
"""calculate pair distances between two sets of points
Parameters
----------
a,b : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of two sets of atoms
Returns
-------
dist : pytorch tensor of shape [batch,nres,nres]
stores paitwise distances between atoms in a and b
"""
dist = torch.cdist(a, b, p=2)
return dist
# ============================================================
def get_ang(a, b, c, eps=1e-6):
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
from Cartesian coordinates of three sets of atoms a,b,c
Parameters
----------
a,b,c : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of three sets of atoms
Returns
-------
ang : pytorch tensor of shape [batch,nres]
stores resulting planar angles
"""
v = a - b
w = c - b
vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps)
wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps)
vw = torch.sum(vn*wn, dim=-1)
return torch.acos(torch.clamp(vw,-0.999,0.999))
# ============================================================
def get_dih(a, b, c, d, eps=1e-6):
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
given Cartesian coordinates of four sets of atoms a,b,c,d
Parameters
----------
a,b,c,d : pytorch tensors of shape [batch,nres,3]
store Cartesian coordinates of four sets of atoms
Returns
-------
dih : pytorch tensor of shape [batch,nres]
stores resulting dihedrals
"""
b0 = a - b
b1 = c - b
b2 = d - c
b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps)
v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n
w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n
x = torch.sum(v*w, dim=-1)
y = torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1)
return torch.atan2(y+eps, x+eps)
# ============================================================
def generate_Cbeta(N,Ca,C):
# recreate Cb given N,Ca,C
b = Ca - N
c = C - Ca
a = torch.cross(b, c, dim=-1)
#Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
# fd: below matches sidechain generator (=Rosetta params)
Cb = -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca
return Cb
# ============================================================
def xyz_to_c6d(xyz, params=PARAMS):
"""convert cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz : pytorch tensor of shape [batch,nres,3,3]
stores Cartesian coordinates of backbone N,Ca,C atoms
Returns
-------
c6d : pytorch tensor of shape [batch,nres,nres,4]
stores stacked dist,omega,theta,phi 2D maps
"""
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
Cb = generate_Cbeta(N,Ca,C)
# 6d coordinates order: (dist,omega,theta,phi)
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
if params['USE_CB']:
dist = get_pair_dist(Cb,Cb)
else:
dist = get_pair_dist(Ca,Ca)
dist[torch.isnan(dist)] = 999.9
c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
b,i,j = torch.where(c6d[...,0]<params['DMAX'])
c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
# fix long-range distances
c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
c6d = torch.nan_to_num(c6d)
return c6d
def xyz_to_t2d(xyz_t, mask, params=PARAMS):
"""convert template cartesian coordinates into 2d distance
and orientation maps
Parameters
----------
xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
stores Cartesian coordinates of template backbone N,Ca,C atoms
mask : pytorch tensor [batch,templ,nres,nres]
indicates whether valid residue pairs or not
Returns
-------
t2d : pytorch tensor of shape [batch,nres,nres,37+6+3]
stores stacked dist,omega,theta,phi 2D maps
"""
B, T, L = xyz_t.shape[:3]
c6d = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params)
c6d = c6d.view(B, T, L, L, 4)
# dist to one-hot encoded
mask = mask[...,None]
dist = dist_to_onehot(c6d[...,0], params)*mask
orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6)
#
t2d = torch.cat((dist, orien, mask), dim=-1)
return t2d
def xyz_to_bbtor(xyz, params=PARAMS):
batch = xyz.shape[0]
nres = xyz.shape[1]
# three anchor atoms
N = xyz[:,:,0]
Ca = xyz[:,:,1]
C = xyz[:,:,2]
# recreate Cb given N,Ca,C
next_N = torch.roll(N, -1, dims=1)
prev_C = torch.roll(C, 1, dims=1)
phi = get_dih(prev_C, N, Ca, C)
psi = get_dih(N, Ca, C, next_N)
#
phi[:,0] = 0.0
psi[:,-1] = 0.0
#
astep = 2.0*np.pi / params['ABINS']
phi_bin = torch.round((phi+np.pi-astep/2)/astep)
psi_bin = torch.round((psi+np.pi-astep/2)/astep)
return torch.stack([phi_bin, psi_bin], axis=-1).long()
# ============================================================
def dist_to_onehot(dist, params=PARAMS):
db = dist_to_bins(dist, params)
dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS1'] + params['DBINS2']+1).float()
return dist
# ============================================================
def dist_to_bins(dist,params=PARAMS):
"""bin 2d distance maps
"""
dist[torch.isnan(dist)] = 999.9
dstep1 = (params['DMID'] - params['DMIN']) / params['DBINS1']
dstep2 = (params['DMAX'] - params['DMID']) / params['DBINS2']
dbins = torch.cat([
torch.linspace(params['DMIN']+dstep1, params['DMID'], params['DBINS1'],
dtype=dist.dtype,device=dist.device),
torch.linspace(params['DMID']+dstep2, params['DMAX'], params['DBINS2'],
dtype=dist.dtype,device=dist.device),
])
db = torch.bucketize(dist.contiguous(),dbins).long()
return db
# ============================================================
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
"""bin 2d distance and orientation maps
"""
db = dist_to_bins(c6d[...,0], params) # all dist < DMIN are in bin 0
astep = 2.0*np.pi / params['ABINS']
ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep)
tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep)
pb = torch.round((c6d[...,3]-astep/2)/astep)
# synchronize no-contact bins
params['DBINS'] = params['DBINS1'] + params['DBINS2']
ob[db==params['DBINS']] = params['ABINS']
tb[db==params['DBINS']] = params['ABINS']
pb[db==params['DBINS']] = params['ABINS']//2
if negative:
db = torch.where(same_chain.bool(), db.long(), params['DBINS'])
ob = torch.where(same_chain.bool(), ob.long(), params['ABINS'])
tb = torch.where(same_chain.bool(), tb.long(), params['ABINS'])
pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2)
return torch.stack([db,ob,tb,pb],axis=-1).long()
def standardize_dihedral_retain_first(a,b,c,d):
isomorphisms = [(a,b,c,d), (a,c,b,d)]
return sorted(isomorphisms)[0]
def get_chirals(obmol, xyz):
'''
get all quadruples of atoms forming chiral centers and the expected ideal pseudodihedral between them
'''
stereo = openbabel.OBStereoFacade(obmol)
angle = np.arcsin(1/3**0.5)
chiral_idx_set = set()
for i in range(obmol.NumAtoms()):
if not stereo.HasTetrahedralStereo(i):
continue
si = stereo.GetTetrahedralStereo(i)
config = si.GetConfig()
o = config.center
c = config.from_or_towards
i,j,k = list(config.refs)
for a, b, c in permutations((c,i,j,k), 3):
chiral_idx_set.add(standardize_dihedral_retain_first(o,a,b,c))
chiral_idx = list(chiral_idx_set)
chiral_idx.sort()
chiral_idx = torch.tensor(chiral_idx, dtype=torch.float32)
chiral_idx = chiral_idx[(chiral_idx<obmol.NumAtoms()).all(dim=-1)]
if chiral_idx.numel() == 0:
return torch.zeros((0,5))
dih = get_dih(*xyz[chiral_idx.long()].split(split_size=1,dim=1))[:,0]
chirals = torch.nn.functional.pad(chiral_idx, (0, 1), mode='constant', value=angle)
chirals[dih<0.0,-1] *= -1
return chirals
def get_atomize_protein_chirals(residues_atomize, lig_xyz, residue_atomize_mask, bond_feats):
"""
Enumerate chiral centers in residues and provide features for chiral centers
"""
angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry
chiral_atoms = ChemData().aachirals[residues_atomize]
ra = residue_atomize_mask.nonzero()
r,a = ra.T
chiral_atoms = chiral_atoms[r,a].nonzero().squeeze(1) #num_chiral_centers
num_chiral_centers = chiral_atoms.shape[0]
chiral_bonds = bond_feats[chiral_atoms] # find bonds to each chiral atom
chiral_bonds_idx = chiral_bonds.nonzero() # find indices of each bonded neighbor to chiral atom
# in practice all chiral atoms in proteins have 3 heavy atom neighbors, so reshape to 3
chiral_bonds_idx = chiral_bonds_idx.reshape(num_chiral_centers, 3, 2)
chirals = torch.zeros((num_chiral_centers, 5))
chirals[:,0] = chiral_atoms.long()
chirals[:, 1:-1] = chiral_bonds_idx[...,-1].long()
chirals[:, -1] = angle
n = chirals.shape[0]
if n>0:
chirals = chirals.repeat(3,1).float()
chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
dih = get_dih(*lig_xyz[chirals[:,:4].long()].split(split_size=1,dim=1))[:,0]
chirals[dih<0.0,-1] = -angle
else:
chirals = torch.zeros((0,5))
return chirals

BIN
rf2aa/ligands.json.gz Normal file

Binary file not shown.

240
rf2aa/loss/loss.py Normal file
View file

@ -0,0 +1,240 @@
import torch
from icecream import ic
import numpy as np
import scipy
import pandas as pd
import networkx as nx
from rf2aa.util import (
is_protein,
rigid_from_3_points,
is_nucleic,
find_all_paths_of_length_n,
find_all_rigid_groups
)
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.kinematics import get_dih, get_ang
from rf2aa.scoring import HbHybType
from typing import List, Dict, Optional
#fd more efficient LJ loss
class LJLoss(torch.autograd.Function):
@staticmethod
def ljVdV(deltas, sigma, epsilon, lj_lin, eps):
# deltas - (N,natompair,3)
N = deltas.shape[0]
dist = torch.sqrt( torch.sum ( torch.square( deltas ), dim=-1 ) + eps )
linpart = dist<lj_lin*sigma[None]
deff = dist.clone()
deff[linpart] = lj_lin*sigma.repeat(N,1)[linpart]
sd = sigma / deff
sd2 = sd*sd
sd6 = sd2 * sd2 * sd2
sd12 = sd6 * sd6
ljE = epsilon * (sd12 - 2 * sd6)
ljE[linpart] += epsilon.repeat(N,1)[linpart] * (
-12 * sd12[linpart]/deff[linpart] + 12 * sd6[linpart]/deff[linpart]
) * (dist[linpart]-deff[linpart])
# works for linpart too
dljEdd_over_r = epsilon * (-12 * sd12/deff + 12 * sd6/deff) / (dist)
return ljE.sum(dim=-1), dljEdd_over_r
@staticmethod
def forward(
ctx, xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
lj_lin=0.75, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
eps=1e-8, training=True
):
N, L, A = xs.shape[:3]
assert (N==1) # see comment below
ds_res = torch.sqrt( torch.sum ( torch.square(
xs.detach()[:,:,None,1,:]-xs.detach()[:,None,:,1,:]), dim=-1 ))
rs = torch.triu_indices(L,L,0, device=xs.device)
ri,rj = rs[0],rs[1]
# batch during inference for huge systems
BATCHSIZE = 65536//N
ljval = 0
dljEdx = torch.zeros_like(xs, dtype=torch.float)
for i_batch in range((len(ri)-1)//BATCHSIZE + 1):
idx = torch.arange(
i_batch*BATCHSIZE,
min( (i_batch+1)*BATCHSIZE, len(ri)),
device=xs.device
)
rii,rjj = ri[idx],rj[idx] # residue pairs we consider
ridx,ai,aj = (
aamask[seq[rii]][:,:,None]*aamask[seq[rjj]][:,None,:]
).nonzero(as_tuple=True)
deltas = xs[:,rii,:,None,:]-xs[:,rjj,None,:,:] # N,BATCHSIZE,Natm,Natm,3
seqi,seqj = seq[rii[ridx]], seq[rjj[ridx]]
mask = torch.ones_like(ridx, dtype=torch.bool) # are atoms defined?
# mask out atom pairs from too-distant residues (C-alpha dist > 24A)
ca_dist = torch.linalg.norm(deltas[:,:,1,1],dim=-1)
mask *= (ca_dist[:,ridx]<24).any(dim=0) # will work for batch>1 but very inefficient
intrares = (rii[ridx]==rjj[ridx])
mask[intrares*(ai<aj)] = False # upper tri (atoms)
## count-pair
# a) intra-protein
mask[intrares] *= num_bonds[seqi[intrares],ai[intrares],aj[intrares]]>=4
pepbondres = ri[ridx]+1==rj[ridx]
mask[pepbondres] *= (
num_bonds[seqi[pepbondres],ai[pepbondres],2]
+ num_bonds[seqj[pepbondres],0,aj[pepbondres]]
+ 1) >=4
# b) intra-ligand
atommask = (ai==1)*(aj==1)
dist_matrix = torch.nan_to_num(dist_matrix, posinf=4.0) #NOTE: need to run nan_to_num to remove infinities
resmask = (dist_matrix[0,rii,rjj] >= 4) # * will only work for batch=1
mask[atommask] *= resmask[ ridx[atommask] ]
# c) protein/ligand
##fd NOTE1: changed 6->5 in masking (atom 5 is CG which should always be 4+ bonds away from connected atom)
##fd NOTE2: this does NOT work correctly for nucleic acids
##fd for NAs atoms 0-4 are masked, but also 5,7,8 and 9 should be masked!
bbatommask = (ai<5)*(aj<5)
resmask = (bond_feats[0,rii,rjj] != 6) # * will only work for batch=1
mask[bbatommask] *= resmask[ ridx[bbatommask] ]
# apply mask. only interactions to be scored remain
ai,aj,seqi,seqj,ridx = ai[mask],aj[mask],seqi[mask],seqj[mask],ridx[mask]
deltas = deltas[:,ridx,ai,aj]
# hbond correction
use_hb_dis = (
ljcorr[seqi,ai,0]*ljcorr[seqj,aj,1]
+ ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0] ).nonzero()
use_ohdon_dis = ( # OH are both donors & acceptors
ljcorr[seqi,ai,0]*ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0]
+ljcorr[seqi,ai,0]*ljcorr[seqj,aj,0]*ljcorr[seqj,aj,1]
).nonzero()
use_hb_hdis = (
ljcorr[seqi,ai,2]*ljcorr[seqj,aj,1]
+ljcorr[seqi,ai,1]*ljcorr[seqj,aj,2]
).nonzero()
# disulfide correction
potential_disulf = (ljcorr[seqi,ai,3]*ljcorr[seqj,aj,3] ).nonzero()
ljrs = ljparams[seqi,ai,0] + ljparams[seqj,aj,0]
ljrs[use_hb_dis] = lj_hb_dis
ljrs[use_ohdon_dis] = lj_OHdon_dis
ljrs[use_hb_hdis] = lj_hbond_hdis
ljss = torch.sqrt( ljparams[seqi,ai,1] * ljparams[seqj,aj,1] + eps )
ljss [potential_disulf] = 0.0
natoms = torch.sum(aamask[seq])
ljval_i,dljEdd_i = LJLoss.ljVdV(deltas,ljrs,ljss,lj_lin,eps)
ljval += ljval_i / natoms
# sum per-atom-pair grads into per-atom grads
# note this is stochastic op on GPU
idxI,idxJ = rii[ridx]*A + ai, rjj[ridx]*A + aj
dljEdx.view(N,-1,3).index_add_(1, idxI, dljEdd_i[...,None]*deltas, alpha=1.0/natoms)
dljEdx.view(N,-1,3).index_add_(1, idxJ, dljEdd_i[...,None]*deltas, alpha=-1.0/natoms)
ctx.save_for_backward(dljEdx)
return ljval
@staticmethod
def backward(ctx, grad_output):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
dljEdx, = ctx.saved_tensors
return (
grad_output * dljEdx,
None, None, None, None, None, None, None, None, None, None, None, None, None
)
# Rosetta-like version of LJ (fa_atr+fa_rep)
# lj_lin is switch from linear to 12-6. Smaller values more sharply penalize clashes
def calc_lj(
seq, xs, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
lj_lin=0.75, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8,
training=True
):
lj = LJLoss.apply
ljval = lj(
xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
lj_lin, lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis, eps, training)
return ljval
def calc_chiral_loss(pred, chirals):
"""
calculate error in dihedral angles for chiral atoms
Input:
- pred: predicted coords (B, L, :, 3)
- chirals: True coords (B, nchiral, 5), skip if 0 chiral sites, 5 dimension are indices for 4 atoms that make dihedral and the ideal angle they should form
Output:
- mean squared error of chiral angles
"""
if chirals.shape[1] == 0:
return torch.tensor(0.0, device=pred.device)
chiral_dih = pred[:, chirals[..., :-1].long(), 1]
pred_dih = get_dih(chiral_dih[...,0, :], chiral_dih[...,1, :], chiral_dih[...,2, :], chiral_dih[...,3, :]) # n_symm, b, n, 36, 3
l = torch.square(pred_dih-chirals[...,-1]).mean()
return l
@torch.enable_grad()
def calc_lj_grads(
seq, xyz, alpha, toaa, bond_feats, dist_matrix,
aamask, ljparams, ljcorr, num_bonds,
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
lj_maxrad=-1.0, eps=1e-8
):
xyz.requires_grad_(True)
alpha.requires_grad_(True)
_, xyzaa = toaa(seq, xyz, alpha)
Elj = calc_lj(
seq[0],
xyzaa[...,:3],
aamask,
bond_feats,
dist_matrix,
ljparams,
ljcorr,
num_bonds,
lj_lin,
lj_hb_dis,
lj_OHdon_dis,
lj_hbond_hdis,
lj_maxrad,
eps
)
return torch.autograd.grad(Elj, (xyz,alpha))
@torch.enable_grad()
def calc_chiral_grads(xyz, chirals):
xyz.requires_grad_(True)
l = calc_chiral_loss(xyz, chirals)
if l.item() == 0.0:
return (torch.zeros(xyz.shape, device=xyz.device),) # autograd returns a tuple..
return torch.autograd.grad(l, xyz)

View file

@ -0,0 +1,418 @@
import torch
import torch.nn as nn
import assertpy
from assertpy import assert_that
from icecream import ic
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
from rf2aa.model.Track_module import IterativeSimulator
from rf2aa.model.layers.AuxiliaryPredictor import (
DistanceNetwork,
MaskedTokenNetwork,
LDDTNetwork,
PAENetwork,
BinderNetwork,
)
from rf2aa.tensor_util import assert_shape, assert_equal
import rf2aa.util
from rf2aa.chemical import ChemicalData as ChemData
def get_shape(t):
if hasattr(t, "shape"):
return t.shape
if type(t) is tuple:
return [get_shape(e) for e in t]
else:
return type(t)
class RoseTTAFoldModule(nn.Module):
def __init__(
self,
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
repeat_length=None, # if symmetrizing repeats, what length are they?
symmsub_k=None, # if symmetrizing repeats, which diagonals?
sym_method=None, # if symmetrizing repeats, which block symmetrization method?
main_block=None, # if copying template blocks along main diag, which block is main block? (the one w/ motif)
copy_main_block_template=None, # whether or not to copy main block template along main diag
n_extra_block=4,
n_main_block=8,
n_ref_block=4,
n_finetune_block=0,
d_msa=256,
d_msa_full=64,
d_pair=128,
d_templ=64,
n_head_msa=8,
n_head_pair=4,
n_head_templ=4,
d_hidden=32,
d_hidden_templ=64,
d_t1d=0,
p_drop=0.15,
additional_dt1d=0,
recycling_type="msa_pair",
SE3_param={}, SE3_ref_param={},
atom_type_index=None,
aamask=None,
ljlk_parameters=None,
lj_correction_parameters=None,
cb_len=None,
cb_ang=None,
cb_tor=None,
num_bonds=None,
lj_lin=0.6,
use_chiral_l1=True,
use_lj_l1=False,
use_atom_frames=True,
use_same_chain=False,
enable_same_chain=False,
refiner_topk=64,
get_quaternion=False,
# New for diffusion
freeze_track_motif=False,
assert_single_sequence_input=False,
fit=False,
tscale=1.0
):
super(RoseTTAFoldModule, self).__init__()
self.freeze_track_motif = freeze_track_motif
self.assert_single_sequence_input = assert_single_sequence_input
self.recycling_type = recycling_type
#
# Input Embeddings
d_state = SE3_param["l0_out_features"]
self.latent_emb = MSA_emb(
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
enable_same_chain=enable_same_chain
)
self.full_emb = Extra_emb(
d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
)
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)
self.templ_emb = Templ_emb(d_t1d=d_t1d,
d_pair=d_pair,
d_templ=d_templ,
d_state=d_state,
n_head=n_head_templ,
d_hidden=d_hidden_templ,
p_drop=0.25,
symmetrize_repeats=symmetrize_repeats, # repeat protein stuff
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method,
main_block=main_block,
copy_main_block=copy_main_block_template,
additional_dt1d=additional_dt1d)
# Update inputs with outputs from previous round
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
#
self.simulator = IterativeSimulator(
n_extra_block=n_extra_block,
n_main_block=n_main_block,
n_ref_block=n_ref_block,
n_finetune_block=n_finetune_block,
d_msa=d_msa,
d_msa_full=d_msa_full,
d_pair=d_pair,
d_hidden=d_hidden,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
SE3_param=SE3_param,
SE3_ref_param=SE3_ref_param,
p_drop=p_drop,
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
aamask=aamask,
ljlk_parameters=ljlk_parameters,
lj_correction_parameters=lj_correction_parameters,
num_bonds=num_bonds,
cb_len=cb_len,
cb_ang=cb_ang,
cb_tor=cb_tor,
lj_lin=lj_lin,
use_lj_l1=use_lj_l1,
use_chiral_l1=use_chiral_l1,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method,
main_block=main_block,
use_same_chain=use_same_chain,
enable_same_chain=enable_same_chain,
refiner_topk=refiner_topk
)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.lddt_pred = LDDTNetwork(d_state)
self.pae_pred = PAENetwork(d_pair)
self.pde_pred = PAENetwork(
d_pair
) # distance error, but use same architecture as aligned error
# binder predictions are made on top of the pair features, just like
# PAE predictions are. It's not clear if this is the best place to insert
# this prediction head.
# self.binder_network = BinderNetwork(d_pair, d_state)
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
self.use_atom_frames = use_atom_frames
self.enable_same_chain = enable_same_chain
self.get_quaternion = get_quaternion
self.verbose_checks = False
def forward(
self,
msa_latent,
msa_full,
seq,
seq_unmasked,
xyz,
sctors,
idx,
bond_feats,
dist_matrix,
chirals,
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
return_raw=False,
use_checkpoint=False,
return_infer=False, #fd ?
p2p_crop=-1, topk_crop=-1, # striping
symmids=None, symmsub=None, symmRs=None, symmmeta=None, # symmetry
):
# ic(get_shape(msa_latent))
# ic(get_shape(msa_full))
# ic(get_shape(seq))
# ic(get_shape(seq_unmasked))
# ic(get_shape(xyz))
# ic(get_shape(sctors))
# ic(get_shape(idx))
# ic(get_shape(bond_feats))
# ic(get_shape(chirals))
# ic(get_shape(atom_frames))
# ic(get_shape(t1d))
# ic(get_shape(t2d))
# ic(get_shape(xyz_t))
# ic(get_shape(alpha_t))
# ic(get_shape(mask_t))
# ic(get_shape(same_chain))
# ic(get_shape(msa_prev))
# ic(get_shape(pair_prev))
# ic(get_shape(mask_recycle))
# ic()
# ic()
B, N, L = msa_latent.shape[:3]
A = atom_frames.shape[1]
dtype = msa_latent.dtype
if self.assert_single_sequence_input:
assert_shape(msa_latent, (1, 1, L, 164))
assert_shape(msa_full, (1, 1, L, 83))
assert_shape(seq, (1, L))
assert_shape(seq_unmasked, (1, L))
assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
assert_shape(sctors, (1, L, 20, 2))
assert_shape(idx, (1, L))
assert_shape(bond_feats, (1, L, L))
assert_shape(dist_matrix, (1, L, L))
# assert_shape(chirals, (1, 0))
# assert_shape(atom_frames, (1, 4, L)) # This is set to 4 for the recycle count, but that can't be right
assert_shape(atom_frames, (1, A, 3, 2)) # What is 4?
assert_shape(t1d, (1, 1, L, 80))
assert_shape(t2d, (1, 1, L, L, 68))
assert_shape(xyz_t, (1, 1, L, 3))
assert_shape(alpha_t, (1, 1, L, 60))
assert_shape(mask_t, (1, 1, L, L))
assert_shape(same_chain, (1, L, L))
device = msa_latent.device
assert_that(msa_full.device).is_equal_to(device)
assert_that(seq.device).is_equal_to(device)
assert_that(seq_unmasked.device).is_equal_to(device)
assert_that(xyz.device).is_equal_to(device)
assert_that(sctors.device).is_equal_to(device)
assert_that(idx.device).is_equal_to(device)
assert_that(bond_feats.device).is_equal_to(device)
assert_that(dist_matrix.device).is_equal_to(device)
assert_that(atom_frames.device).is_equal_to(device)
assert_that(t1d.device).is_equal_to(device)
assert_that(t2d.device).is_equal_to(device)
assert_that(xyz_t.device).is_equal_to(device)
assert_that(alpha_t.device).is_equal_to(device)
assert_that(mask_t.device).is_equal_to(device)
assert_that(same_chain.device).is_equal_to(device)
if self.verbose_checks:
#ic(is_motif.shape)
is_sm = rf2aa.util.is_atom(seq[0]) # (L)
#is_protein_motif = is_motif & ~is_sm
#if is_motif.any():
# motif_protein_i = torch.where(is_motif)[0][0]
#is_motif_sm = is_motif & is_sm
#if is_sm.any():
# motif_sm_i = torch.where(is_motif_sm)[0][0]
#diffused_protein_i = torch.where(~is_sm & ~is_motif)[0][0]
"""
msa_full: NSEQ,N_INDEL,N_TERMINUS,
msa_masked: NSEQ,NSEQ,N_INDEL,N_INDEL,N_TERMINUS
"""
import numpy as np
NINDEL = 1
NTERMINUS = 2
NMSAFULL = ChemData().NAATOKENS + NINDEL + NTERMINUS
NMSAMASKED = ChemData().NAATOKENS + ChemData().NAATOKENS + NINDEL + NINDEL + NTERMINUS
assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)
msa_full_seq = np.r_[0:ChemData().NAATOKENS]
msa_full_indel = np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
msa_full_term = np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]
msa_latent_seq1 = np.r_[0:ChemData().NAATOKENS]
msa_latent_seq2 = np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]
msa_latent_indel1 = np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
msa_latent_indel2 = np.r_[
2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL
]
msa_latent_terminus = np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
#i_name = [(diffused_protein_i, "diffused_protein")]
#if is_sm.any():
# i_name.insert(0, (motif_sm_i, "motif_sm"))
#if is_motif.any():
# i_name.insert(0, (motif_protein_i, "motif_protein"))
i_name = [(0, "tst")]
for i, name in i_name:
ic(f"------------------{name}:{i}----------------")
msa_full_seq = msa_full[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
msa_full_indel = msa_full[
0, 0, i, np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
]
msa_full_term = msa_full[0, 0, i, np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]]
msa_latent_seq1 = msa_latent[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
msa_latent_seq2 = msa_latent[0, 0, i, np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]]
msa_latent_indel1 = msa_latent[
0, 0, i, np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
]
msa_latent_indel2 = msa_latent[
0,
0,
i,
np.r_[2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL],
]
msa_latent_term = msa_latent[
0, 0, i, np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
]
assert_equal(msa_full_seq, msa_latent_seq1)
assert_equal(msa_full_seq, msa_latent_seq2)
assert_equal(msa_full_indel, msa_latent_indel1)
assert_equal(msa_full_indel, msa_latent_indel2)
assert_equal(msa_full_term, msa_latent_term)
# if 'motif' in name:
msa_cat = torch.where(msa_full_seq)[0]
ic(msa_cat, seq[0, i])
assert_equal(seq[0, i : i + 1], msa_cat)
assert_equal(seq[0, i], seq_unmasked[0, i])
ic(
name,
# torch.where(msa_latent[0,0,i,:80]),
# torch.where(msa_full[0,0,i]),
seq[0, i],
seq_unmasked[0, i],
torch.where(t1d[0, 0, i]),
xyz[0, i, :4, 0],
xyz_t[0, 0, i, 0],
)
# Get embeddings
#if self.enable_same_chain == False:
# same_chain = None
msa_latent, pair, state = self.latent_emb(
msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
)
msa_full = self.full_emb(msa_full, seq, idx)
pair = pair + self.bond_emb(bond_feats)
msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
msa_full = msa_full.to(dtype)
#
# Do recycling
if msa_prev is None:
msa_prev = torch.zeros_like(msa_latent[:,0])
if pair_prev is None:
pair_prev = torch.zeros_like(pair)
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
state_prev = torch.zeros_like(state)
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
state = state + state_recycle # if state is not recycled these will be zeros
# add template embedding
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
# Predict coordinates from given inputs
is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx,
symmids, symmsub, symmRs, symmmeta,
bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
p2p_crop=p2p_crop, topk_crop=topk_crop
)
if return_raw:
# get last structure
xyz_last = xyz_allatom[-1].unsqueeze(0)
return msa[:,0], pair, xyz_last, alpha_s[-1], None
# predict masked amino acids
logits_aa = self.aa_pred(msa)
# predict distogram & orientograms
logits = self.c6d_pred(pair)
# Predict LDDT
lddt = self.lddt_pred(state)
if self.verbose_checks:
pseq_0 = logits_aa.permute(0, 2, 1)
ic(pseq_0.shape)
pseq_0 = pseq_0[0]
ic(
f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
)
ic(
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
)
logits_pae = logits_pde = p_bind = None
# predict aligned error and distance error
logits_pae = self.pae_pred(pair)
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
#fd predict bind/no-bind
p_bind = self.bind_pred(logits_pae,same_chain)
if self.get_quaternion:
return (
logits, logits_aa, logits_pae, logits_pde, p_bind,
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state, quat
)
else:
return (
logits, logits_aa, logits_pae, logits_pde, p_bind,
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
)

1220
rf2aa/model/Track_module.py Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,475 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from opt_einsum import contract as einsum
from rf2aa.util_module import init_lecun_normal
class FeedForwardLayer(nn.Module):
def __init__(self, d_model, r_ff, p_drop=0.1):
super(FeedForwardLayer, self).__init__()
self.norm = nn.LayerNorm(d_model)
self.linear1 = nn.Linear(d_model, d_model*r_ff)
self.dropout = nn.Dropout(p_drop)
self.linear2 = nn.Linear(d_model*r_ff, d_model)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer right before ReLu: He initializer (kaiming normal)
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
nn.init.zeros_(self.linear1.bias)
# initialize linear layer right before residual connection: zero initialize
nn.init.zeros_(self.linear2.weight)
nn.init.zeros_(self.linear2.bias)
def forward(self, src):
src = self.norm(src)
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
return src
class Attention(nn.Module):
# calculate multi-head attention
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
super(Attention, self).__init__()
self.h = n_head
self.dim = d_hidden
#
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
#
self.to_out = nn.Linear(n_head*d_hidden, d_out)
self.scaling = 1/math.sqrt(d_hidden)
#
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, query, key, value):
B, Q = query.shape[:2]
B, K = key.shape[:2]
#
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
key = self.to_k(key).reshape(B, K, self.h, self.dim)
value = self.to_v(value).reshape(B, K, self.h, self.dim)
#
query = query * self.scaling
attn = einsum('bqhd,bkhd->bhqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bhqk,bkhd->bqhd', attn, value)
out = out.reshape(B, Q, self.h*self.dim)
#
out = self.to_out(out)
return out
# MSA Attention (row/column) from AlphaFold architecture
class SequenceWeight(nn.Module):
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
super(SequenceWeight, self).__init__()
self.h = n_head
self.dim = d_hidden
self.scale = 1.0 / math.sqrt(self.dim)
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
self.dropout = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_query.weight)
nn.init.xavier_uniform_(self.to_key.weight)
def forward(self, msa):
B, N, L = msa.shape[:3]
tar_seq = msa[:,0]
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
q = q * self.scale
attn = einsum('bqihd,bkihd->bkihq', q, k)
attn = F.softmax(attn, dim=1)
return self.dropout(attn)
class MSARowAttentionWithBias(nn.Module):
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
super(MSARowAttentionWithBias, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_pair = nn.LayerNorm(d_pair)
#
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa, pair): # TODO: make this as tied-attention
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
#
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
key = key * self.scaling
attn = einsum('bsqhd,bskhd->bqkh', query, key)
attn = attn + bias
attn = F.softmax(attn, dim=-2)
#
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColAttention(nn.Module):
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
super(MSAColAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
gate = torch.sigmoid(self.to_g(msa))
#
query = query * self.scaling
attn = einsum('bqihd,bkihd->bihqk', query, key)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
out = gate * out
#
out = self.to_out(out)
return out
class MSAColGlobalAttention(nn.Module):
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
super(MSAColGlobalAttention, self).__init__()
self.norm_msa = nn.LayerNorm(d_msa)
#
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, msa):
B, N, L = msa.shape[:3]
#
msa = self.norm_msa(msa)
#
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
query = query.mean(dim=1) # (B, L, h, dim)
key = self.to_k(msa) # (B, N, L, dim)
value = self.to_v(msa) # (B, N, L, dim)
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
#
query = query * self.scaling
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
attn = F.softmax(attn, dim=-1)
#
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
out = gate * out # (B, N, L, h*dim)
#
out = self.to_out(out)
return out
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
class TriangleAttention(nn.Module):
def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
super(TriangleAttention, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_pair, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
self.start_node=start_node
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
# input projection
query = self.to_q(pair).reshape(B, L, L, self.h, -1)
key = self.to_k(pair).reshape(B, L, L, self.h, -1)
value = self.to_v(pair).reshape(B, L, L, self.h, -1)
bias = self.to_b(pair) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
# attention
query = query * self.scaling
if self.start_node:
attn = einsum('bijhd,bikhd->bijkh', query, key)
else:
attn = einsum('bijhd,bkjhd->bijkh', query, key)
attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh)
attn = F.softmax(attn, dim=-2)
if self.start_node:
out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1)
else:
out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1)
out = gate * out # gated attention
# output projection
out = self.to_out(out)
return out
class TriangleMultiplication(nn.Module):
def __init__(self, d_pair, d_hidden=128, outgoing=True):
super(TriangleMultiplication, self).__init__()
self.norm = nn.LayerNorm(d_pair)
self.left_proj = nn.Linear(d_pair, d_hidden)
self.right_proj = nn.Linear(d_pair, d_hidden)
self.left_gate = nn.Linear(d_pair, d_hidden)
self.right_gate = nn.Linear(d_pair, d_hidden)
#
self.gate = nn.Linear(d_pair, d_pair)
self.norm_out = nn.LayerNorm(d_hidden)
self.out_proj = nn.Linear(d_hidden, d_pair)
self.outgoing = outgoing
self.reset_parameter()
def reset_parameter(self):
# normal distribution for regular linear weights
self.left_proj = init_lecun_normal(self.left_proj)
self.right_proj = init_lecun_normal(self.right_proj)
# Set Bias of Linear layers to zeros
nn.init.zeros_(self.left_proj.bias)
nn.init.zeros_(self.right_proj.bias)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.left_gate.weight)
nn.init.ones_(self.left_gate.bias)
nn.init.zeros_(self.right_gate.weight)
nn.init.ones_(self.right_gate.bias)
nn.init.zeros_(self.gate.weight)
nn.init.ones_(self.gate.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
def forward(self, pair):
B, L = pair.shape[:2]
pair = self.norm(pair)
left = self.left_proj(pair) # (B, L, L, d_h)
left_gate = torch.sigmoid(self.left_gate(pair))
left = left_gate * left
right = self.right_proj(pair) # (B, L, L, d_h)
right_gate = torch.sigmoid(self.right_gate(pair))
right = right_gate * right
if self.outgoing:
out = einsum('bikd,bjkd->bijd', left, right/float(L))
else:
out = einsum('bkid,bkjd->bijd', left, right/float(L))
out = self.norm_out(out)
out = self.out_proj(out)
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
out = gate * out
return out
# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
class BiasedAxialAttention(nn.Module):
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
super(BiasedAxialAttention, self).__init__()
#
self.is_row = is_row
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_bias = nn.LayerNorm(d_bias)
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
self.to_b = nn.Linear(d_bias, n_head, bias=False)
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
self.scaling = 1/math.sqrt(d_hidden)
self.h = n_head
self.dim = d_hidden
# initialize all parameters properly
self.reset_parameter()
def reset_parameter(self):
# query/key/value projection: Glorot uniform / Xavier uniform
nn.init.xavier_uniform_(self.to_q.weight)
nn.init.xavier_uniform_(self.to_k.weight)
nn.init.xavier_uniform_(self.to_v.weight)
# bias: normal distribution
self.to_b = init_lecun_normal(self.to_b)
# gating: zero weights, one biases (mostly open gate at the begining)
nn.init.zeros_(self.to_g.weight)
nn.init.ones_(self.to_g.bias)
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
nn.init.zeros_(self.to_out.weight)
nn.init.zeros_(self.to_out.bias)
def forward(self, pair, bias):
# pair: (B, L, L, d_pair)
B, L = pair.shape[:2]
if self.is_row:
pair = pair.permute(0,2,1,3)
bias = bias.permute(0,2,1,3)
pair = self.norm_pair(pair)
bias = self.norm_bias(bias)
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
bias = self.to_b(bias) # (B, L, L, h)
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
query = query * self.scaling
key = key / L # normalize for tied attention
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
attn = attn + bias # apply bias
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(B, L, L, -1)
out = gate * out
out = self.to_out(out)
if self.is_row:
out = out.permute(0,2,1,3)
return out

View file

@ -0,0 +1,111 @@
import torch
import torch.nn as nn
from rf2aa.chemical import ChemicalData as ChemData
class DistanceNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.0):
super(DistanceNetwork, self).__init__()
#HACK: dimensions are hard coded here
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
self.proj_asymm = nn.Linear(n_feat, 37+19)
self.reset_parameter()
def reset_parameter(self):
# initialize linear layer for final logit prediction
nn.init.zeros_(self.proj_symm.weight)
nn.init.zeros_(self.proj_asymm.weight)
nn.init.zeros_(self.proj_symm.bias)
nn.init.zeros_(self.proj_asymm.bias)
def forward(self, x):
# input: pair info (B, L, L, C)
# predict theta, phi (non-symmetric)
logits_asymm = self.proj_asymm(x)
logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
# predict dist, omega
logits_symm = self.proj_symm(x)
logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
logits_dist = logits_symm[:,:,:,:61].permute(0,3,1,2)
logits_omega = logits_symm[:,:,:,61:].permute(0,3,1,2)
return logits_dist, logits_omega, logits_theta, logits_phi
class MaskedTokenNetwork(nn.Module):
def __init__(self, n_feat, p_drop=0.0):
super(MaskedTokenNetwork, self).__init__()
#fd note this predicts probability for the mask token (which is never in ground truth)
# it should be ok though(?)
self.proj = nn.Linear(n_feat, ChemData().NAATOKENS)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
B, N, L = x.shape[:3]
logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
return logits
class LDDTNetwork(nn.Module):
def __init__(self, n_feat, n_bin_lddt=50):
super(LDDTNetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_lddt)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
logits = self.proj(x) # (B, L, 50)
return logits.permute(0,2,1)
class PAENetwork(nn.Module):
def __init__(self, n_feat, n_bin_pae=64):
super(PAENetwork, self).__init__()
self.proj = nn.Linear(n_feat, n_bin_pae)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.proj.weight)
nn.init.zeros_(self.proj.bias)
def forward(self, x):
logits = self.proj(x) # (B, L, L, 64)
return logits.permute(0,3,1,2)
class BinderNetwork(nn.Module):
def __init__(self, n_bin_pae=64):
super(BinderNetwork, self).__init__()
self.classify = torch.nn.Linear(n_bin_pae, 1)
self.reset_parameter()
def reset_parameter(self):
nn.init.zeros_(self.classify.weight)
nn.init.zeros_(self.classify.bias)
def forward(self, pae, same_chain):
logits = pae.permute(0,2,3,1)
logits_inter = torch.mean( logits[same_chain==0], dim=0 ).nan_to_num() # all zeros if single chain
prob = torch.sigmoid( self.classify( logits_inter ) )
return prob
aux_predictor_factory = {
"c6d": DistanceNetwork,
"mlm": MaskedTokenNetwork,
"plddt": LDDTNetwork,
"pae": PAENetwork,
"binder": BinderNetwork
}

View file

@ -0,0 +1,458 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from rf2aa.util import *
from rf2aa.util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
from rf2aa.model.layers.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
from rf2aa.model.Track_module import PairStr2Pair, PositionalEncoding2D
from rf2aa.chemical import ChemicalData as ChemData
# Module contains classes and functions to generate initial embeddings
class MSA_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0,
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False, enable_same_chain=False):
if (d_init==0):
d_init = 2*ChemData().NAATOKENS+2+2
super(MSA_emb, self).__init__()
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
self.emb_left = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_right = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
self.emb_state = nn.Embedding(ChemData().NAATOKENS, d_state)
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain,
enable_same_chain=enable_same_chain)
self.enable_same_chain = enable_same_chain
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
self.emb_q = init_lecun_normal(self.emb_q)
self.emb_left = init_lecun_normal(self.emb_left)
self.emb_right = init_lecun_normal(self.emb_right)
self.emb_state = init_lecun_normal(self.emb_state)
nn.init.zeros_(self.emb.bias)
def _msa_emb(self, msa, seq):
N = msa.shape[1]
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
return msa
def _pair_emb(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
pair = left + right # (B, L, L, d_pair)
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix, same_chain=same_chain) # add relative position
return pair
def _state_emb(self, seq):
return self.emb_state(seq)
def forward(self, msa, seq, idx, bond_feats, dist_matrix, same_chain=None):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# - bond_feats: Bond features (B, L, L)
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
# - pair: Initial Pair embedding (B, L, L, d_pair)
if self.enable_same_chain == False:
same_chain = None
msa = self._msa_emb(msa, seq)
# pair embedding
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix, same_chain=same_chain)
# state embedding
state = self._state_emb(seq)
return msa, pair, state
class MSA_emb_nostate(MSA_emb):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
super().__init__(d_msa, d_pair, d_state, d_init, minpos, maxpos, maxpos_atom, p_drop, use_same_chain)
if d_init==0:
d_init = 2*ChemData().NAATOKENS + 2 + 2
self.emb_state = None # emb state is just the identity
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
msa = self._msa_emb(msa, seq)
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
return msa, pair, None
class Extra_emb(nn.Module):
# Get initial seed MSA embedding
def __init__(self, d_msa=256, d_init=0, p_drop=0.1):
super(Extra_emb, self).__init__()
if d_init==0:
d_init=ChemData().NAATOKENS-1+4
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence
#self.drop = nn.Dropout(p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, msa, seq, idx):
# Inputs:
# - msa: Input MSA (B, N, L, d_init)
# - seq: Input Sequence (B, L)
# - idx: Residue index
# Outputs:
# - msa: Initial MSA embedding (B, N, L, d_msa)
N = msa.shape[1] # number of sequenes in MSA
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
#return self.drop(msa)
return (msa)
class Bond_emb(nn.Module):
def __init__(self, d_pair=128, d_init=0):
super(Bond_emb, self).__init__()
if d_init==0:
d_init = ChemData().NBTYPES
self.emb = nn.Linear(d_init, d_pair)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
def forward(self, bond_feats):
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=ChemData().NBTYPES)
return self.emb(bond_feats.float())
class TemplatePairStack(nn.Module):
def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, d_t1d=22, d_state=32, p_drop=0.25,
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method=None):
super(TemplatePairStack, self).__init__()
self.n_block = n_block
self.proj_t1d = nn.Linear(d_t1d, d_state)
proc_s = [PairStr2Pair(d_pair=d_templ,
n_head=n_head,
d_hidden=d_hidden,
d_state=d_state,
p_drop=p_drop,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method) for i in range(n_block)]
self.block = nn.ModuleList(proc_s)
self.norm = nn.LayerNorm(d_templ)
self.reset_parameter()
def reset_parameter(self):
self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)
def forward(self, templ, rbf_feat, t1d, use_checkpoint=False, p2p_crop=-1):
B, T, L = templ.shape[:3]
templ = templ.reshape(B*T, L, L, -1)
t1d = t1d.reshape(B*T, L, -1)
state = self.proj_t1d(t1d)
for i_block in range(self.n_block):
if use_checkpoint:
templ = checkpoint.checkpoint(
create_custom_forward(self.block[i_block]),
templ, rbf_feat, state, p2p_crop,
use_reentrant=True
)
else:
templ = self.block[i_block](templ, rbf_feat, state)
return self.norm(templ).reshape(B, T, L, L, -1)
def copy_main_2d(pair, Leff, idx):
"""
Copies the "main unit" of a block in generic 2D representation of shape (...,L,L,h)
along the main diagonal
"""
start = idx*Leff
end = (idx+1)*Leff
# grab the main block
main = torch.clone( pair[..., start:end, start:end, :] )
# copy it around the main diag
L = pair.shape[-2]
assert L%Leff == 0
N = L//Leff
for i_block in range(N):
start = i_block*Leff
stop = (i_block+1)*Leff
pair[...,start:stop, start:stop, :] = main
return pair
def copy_main_1d(single, Leff, idx):
"""
Copies the "main unit" of a block in generic 1D representation of shape (...,L,h)
to all other (non-main) blocks
Parameters:
single (torch.tensor, required): Shape [...,L,h] "1D" tensor
"""
main_start = idx*Leff
main_end = (idx+1)*Leff
# grab main block
main = torch.clone(single[..., main_start:main_end, :])
# copy it around
L = single.shape[-2]
assert L%Leff == 0
N = L//Leff
for i_block in range(N):
start = i_block*Leff
end = (i_block+1)*Leff
single[..., start:end, :] = main
return single
class Templ_emb(nn.Module):
# Get template embedding
# Features are
# t2d:
# - 61 distogram bins + 6 orientations (67)
# - Mask (missing/unaligned) (1)
# t1d:
# - tiled AA sequence (20 standard aa + gap)
# - confidence (1)
#
def __init__(self, d_t1d=0, d_t2d=67+1, d_tor=0, d_pair=128, d_state=32,
n_block=2, d_templ=64,
n_head=4, d_hidden=16, p_drop=0.25,
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method='mean',
main_block=None, copy_main_block=None, additional_dt1d=0):
if d_t1d==0:
d_t1d=(ChemData().NAATOKENS-1)+1
if d_tor==0:
d_tor=3*ChemData().NTOTALDOFS
self.main_block = main_block
self.symmetrize_repeats = symmetrize_repeats
self.copy_main_block = copy_main_block
self.repeat_length = repeat_length
d_t1d += additional_dt1d
super(Templ_emb, self).__init__()
# process 2D features
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
d_hidden=d_hidden, d_t1d=d_t1d, d_state=d_state, p_drop=p_drop,
symmetrize_repeats=symmetrize_repeats, repeat_length=repeat_length,
symmsub_k=symmsub_k, sym_method=sym_method)
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
# process torsion angles
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
self.proj_t1d = nn.Linear(d_templ, d_templ)
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
# d_hidden=d_hidden, p_drop=p_drop)
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
self.reset_parameter()
def reset_parameter(self):
self.emb = init_lecun_normal(self.emb)
nn.init.zeros_(self.emb.bias)
nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
nn.init.zeros_(self.emb_t1d.bias)
self.proj_t1d = init_lecun_normal(self.proj_t1d)
nn.init.zeros_(self.proj_t1d.bias)
def _get_templ_emb(self, t1d, t2d):
B, T, L, _ = t1d.shape
# Prepare 2D template features
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
#
templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88)
return self.emb(templ) # Template templures (B, T, L, L, d_templ)
def _get_templ_rbf(self, xyz_t, mask_t):
B, T, L = xyz_t.shape[:3]
# process each template features
xyz_t = xyz_t.reshape(B*T, L, 3).contiguous()
mask_t = mask_t.reshape(B*T, L, L)
assert(xyz_t.is_contiguous())
rbf_feat = rbf(torch.cdist(xyz_t, xyz_t)) * mask_t[...,None] # (B*T, L, L, d_rbf)
return rbf_feat
def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False, p2p_crop=-1):
# Input
# - t1d: 1D template info (B, T, L, 30)
# - t2d: 2D template info (B, T, L, L, 44)
# - alpha_t: torsion angle info (B, T, L, 30) - DOUBLE-CHECK
# - xyz_t: template CA coordinates (B, T, L, 3)
# - mask_t: is valid residue pair? (B, T, L, L)
# - pair: query pair features (B, L, L, d_pair)
# - state: query state features (B, L, d_state)
B, T, L, _ = t1d.shape
templ = self._get_templ_emb(t1d, t2d)
# this looks a lot like a bug but it is not
# mask_t has already been updated by same_chain in the train_EMA script so pairwise distances between
# protein chains are ignored
rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
# process each template pair feature
templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop) # (B, T, L,L, d_templ)
# DJ - repeat protein symmetrization (2D)
if self.copy_main_block:
assert not (self.main_block is None)
assert self.symmetrize_repeats
# copy the main repeat unit internally down the pair representation diagonal
templ = copy_main_2d(templ, self.repeat_length, self.main_block)
# Prepare 1D template torsion angle features
t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17)
# process each template features
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
# DJ - repeat protein symmetrization (1D)
if self.copy_main_block:
# already made assertions above
# copy main unit down single rep
t1d = copy_main_1d(t1d, self.repeat_length, self.main_block)
# mixing query state features to template state features
state = state.reshape(B*L, 1, -1)
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(
create_custom_forward(self.attn_tor), state, t1d, t1d, use_reentrant=True
)
out = out.reshape(B, L, -1)
else:
out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
state = state.reshape(B, L, -1)
state = state + out
# mixing query pair features to template information (Template pointwise attention)
pair = pair.reshape(B*L*L, 1, -1)
templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
if use_checkpoint:
out = checkpoint.checkpoint(
create_custom_forward(self.attn), pair, templ, templ, use_reentrant=True
)
out = out.reshape(B, L, L, -1)
else:
out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
#
pair = pair.reshape(B, L, L, -1)
pair = pair + out
return pair, state
class Recycling(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
super(Recycling, self).__init__()
self.proj_dist = nn.Linear(d_rbf, d_pair)
self.norm_pair = nn.LayerNorm(d_pair)
self.norm_msa = nn.LayerNorm(d_msa)
self.reset_parameter()
def reset_parameter(self):
#self.emb_rbf = init_lecun_normal(self.emb_rbf)
#nn.init.zeros_(self.emb_rbf.bias)
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = msa.shape[:2]
msa = self.norm_msa(msa)
pair = self.norm_pair(pair)
Ca = xyz[:,:,1]
dist_CA = rbf(
torch.cdist(Ca, Ca)
).reshape(B,L,L,-1)
if mask_recycle != None:
dist_CA = mask_recycle[...,None].float()*dist_CA
pair = pair + self.proj_dist(dist_CA)
return msa, pair, state # state is just zeros
class RecyclingAllFeatures(nn.Module):
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
super(RecyclingAllFeatures, self).__init__()
self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
self.norm_pair = nn.LayerNorm(d_pair)
self.proj_sctors = nn.Linear(2*ChemData().NTOTALDOFS, d_msa)
self.norm_msa = nn.LayerNorm(d_msa)
self.norm_state = nn.LayerNorm(d_state)
self.reset_parameter()
def reset_parameter(self):
self.proj_dist = init_lecun_normal(self.proj_dist)
nn.init.zeros_(self.proj_dist.bias)
self.proj_sctors = init_lecun_normal(self.proj_sctors)
nn.init.zeros_(self.proj_sctors.bias)
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
B, L = pair.shape[:2]
state = self.norm_state(state)
left = state.unsqueeze(2).expand(-1,-1,L,-1)
right = state.unsqueeze(1).expand(-1,L,-1,-1)
Ca_or_P = xyz[:,:,1].contiguous()
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
if mask_recycle != None:
dist = mask_recycle[...,None].float()*dist
dist = torch.cat((dist, left, right), dim=-1)
dist = self.proj_dist(dist)
pair = dist + self.norm_pair(pair)
sctors = self.proj_sctors(sctors.reshape(B,-1,2*ChemData().NTOTALDOFS))
msa = sctors + self.norm_msa(msa)
return msa, pair, state
recycling_factory = {
"msa_pair": Recycling,
"all": RecyclingAllFeatures
}

View file

@ -0,0 +1,100 @@
import torch
import torch.nn as nn
from icecream import ic
import inspect
import sys, os
#script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
#sys.path.insert(0,script_dir+'SE3Transformer')
from rf2aa.util import xyz_frame_from_rotation_mask
from rf2aa.util_module import init_lecun_normal_param, \
make_full_graph, rbf, init_lecun_normal
from rf2aa.loss.loss import calc_chiral_grads
from rf2aa.model.layers.Attention_module import FeedForwardLayer
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
from rf2aa.util_module import get_seqsep_protein_sm
se3_transformer_path = inspect.getfile(SE3Transformer)
se3_fiber_path = inspect.getfile(Fiber)
assert 'rf2aa' in se3_transformer_path
class SE3TransformerWrapper(nn.Module):
"""SE(3) equivariant GCN with attention"""
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
l0_in_features=32, l0_out_features=32,
l1_in_features=3, l1_out_features=2,
num_edge_features=32):
super().__init__()
# Build the network
self.l1_in = l1_in_features
self.l1_out = l1_out_features
#
fiber_edge = Fiber({0: num_edge_features})
if l1_out_features > 0:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
else:
if l1_in_features > 0:
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
else:
fiber_in = Fiber({0: l0_in_features})
fiber_hidden = Fiber.create(num_degrees, num_channels)
fiber_out = Fiber({0: l0_out_features})
self.se3 = SE3Transformer(num_layers=num_layers,
fiber_in=fiber_in,
fiber_hidden=fiber_hidden,
fiber_out = fiber_out,
num_heads=n_heads,
channels_div=div,
fiber_edge=fiber_edge,
populate_edge="arcsin",
final_layer="lin",
use_layer_norm=True)
self.reset_parameter()
def reset_parameter(self):
# make sure linear layer before ReLu are initialized with kaiming_normal_
for n, p in self.se3.named_parameters():
if "bias" in n:
nn.init.zeros_(p)
elif len(p.shape) == 1:
continue
else:
if "radial_func" not in n:
p = init_lecun_normal_param(p)
else:
if "net.6" in n:
nn.init.zeros_(p)
else:
nn.init.kaiming_normal_(p, nonlinearity='relu')
# make last layers to be zero-initialized
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
if self.l1_out > 0:
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
if self.l1_in > 0:
node_features = {'0': type_0_features, '1': type_1_features}
else:
node_features = {'0': type_0_features}
edge_features = {'0': edge_features}
return self.se3(G, node_features, edge_features)

208
rf2aa/run_inference.py Normal file
View file

@ -0,0 +1,208 @@
import os
import hydra
import torch
import torch.nn as nn
from dataclasses import asdict
from rf2aa.data.merge_inputs import merge_all
from rf2aa.data.covale import load_covalent_molecules
from rf2aa.data.nucleic_acid import load_nucleic_acid
from rf2aa.data.protein import generate_msa_and_load_protein
from rf2aa.data.small_molecule import load_small_molecule
from rf2aa.ffindex import *
from rf2aa.chemical import initialize_chemdata, load_pdb_ideal_sdf_strings
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.training.recycling import recycle_step_legacy
from rf2aa.util import writepdb, is_atom, Ls_from_same_chain_2d
from rf2aa.util_module import XYZConverter
class ModelRunner:
def __init__(self, config) -> None:
self.config = config
initialize_chemdata(self.config.chem_params)
FFindexDB = namedtuple("FFindexDB", "index, data")
self.ffdb = FFindexDB(read_index(config.database_params.hhdb+'_pdb.ffindex'),
read_data(config.database_params.hhdb+'_pdb.ffdata'))
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
self.xyz_converter = XYZConverter()
self.deterministic = config.get("deterministic", False)
self.molecule_db = load_pdb_ideal_sdf_strings()
def parse_inference_config(self):
residues_to_atomize = [] # chain letter, residue number, residue name
chains = []
protein_inputs = {}
if self.config.protein_inputs is not None:
for chain in self.config.protein_inputs:
if chain in chains:
raise ValueError(f"Duplicate chain found with name: {chain}. Please specify unique chain names")
elif len(chain) > 1:
raise ValueError(f"Chain name must be a single character, found chain with name: {chain}")
else:
chains.append(chain)
protein_input = generate_msa_and_load_protein(
self.config.protein_inputs[chain]["fasta_file"],
self
)
protein_inputs[chain] = protein_input
na_inputs = {}
if self.config.na_inputs is not None:
for chain in self.config.na_inputs:
na_input = load_nucleic_acid(
self.config.na_inputs[chain]["fasta"],
self.config.na_inputs[chain]["input_type"],
self
)
na_inputs[chain] = na_input
sm_inputs = {}
# first if any of the small molecules are covalently bonded to the protein
# merge the small molecule with the residue and add it as a separate ligand
# also add it to residues_to_atomize for bookkeeping later on
# need to handle atomizing multiple consecutive residues here too
if self.config.covale_inputs is not None:
covalent_sm_inputs, residues_to_atomize_covale = load_covalent_molecules(protein_inputs, self.config, self)
sm_inputs.update(covalent_sm_inputs)
residues_to_atomize.extend(residues_to_atomize_covale)
if self.config.sm_inputs is not None:
for chain in self.config.sm_inputs:
if self.config.sm_inputs[chain]["input_type"] not in ["smiles", "sdf"]:
raise ValueError("Small molecule input type must be smiles or sdf")
if chain in sm_inputs: # chain already processed as covale
continue
if "is_leaving" in self.config.sm_inputs[chain]:
raise ValueError("Leaving atoms are not supported for non-covalently bonded molecules")
sm_input = load_small_molecule(
self.config.sm_inputs[chain]["input"],
self.config.sm_inputs[chain]["input_type"],
self
)
sm_inputs[chain] = sm_input
if self.config.residue_replacement is not None:
# add to the sm_inputs list
# add to residues to atomize
raise NotImplementedError("Modres inference is not implemented")
raw_data = merge_all(protein_inputs, na_inputs, sm_inputs, residues_to_atomize, deterministic=self.deterministic)
self.raw_data = raw_data
def load_model(self):
self.model = RoseTTAFoldModule(
**self.config.legacy_model_param,
aamask = ChemData().allatom_mask.to(self.device),
atom_type_index = ChemData().atom_type_index.to(self.device),
ljlk_parameters = ChemData().ljlk_parameters.to(self.device),
lj_correction_parameters = ChemData().lj_correction_parameters.to(self.device),
num_bonds = ChemData().num_bonds.to(self.device),
cb_len = ChemData().cb_length_t.to(self.device),
cb_ang = ChemData().cb_angle_t.to(self.device),
cb_tor = ChemData().cb_torsion_t.to(self.device),
).to(self.device)
checkpoint = torch.load(self.config.checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
def construct_features(self):
return self.raw_data.construct_features(self)
def run_model_forward(self, input_feats):
input_feats.add_batch_dim()
input_feats.to(self.device)
input_dict = asdict(input_feats)
input_dict["bond_feats"] = input_dict["bond_feats"].long()
input_dict["seq_unmasked"] = input_dict["seq_unmasked"].long()
outputs = recycle_step_legacy(self.model,
input_dict,
self.config.loader_params.MAXCYCLE,
use_amp=False,
nograds=True,
force_device=self.device)
return outputs
def write_outputs(self, input_feats, outputs):
logits, logits_aa, logits_pae, logits_pde, p_bind, \
xyz, alpha_s, xyz_allatom, lddt, _, _, _ \
= outputs
seq_unmasked = input_feats.seq_unmasked
bond_feats = input_feats.bond_feats
err_dict = self.calc_pred_err(lddt, logits_pae, logits_pde, seq_unmasked)
err_dict["same_chain"] = input_feats.same_chain
plddts = err_dict["plddts"]
Ls = Ls_from_same_chain_2d(input_feats.same_chain)
plddts = plddts[0]
writepdb(os.path.join(f"{self.config.output_path}", f"{self.config.job_name}.pdb"),
xyz_allatom,
seq_unmasked,
bond_feats=bond_feats,
bfacts=plddts,
chain_Ls=Ls
)
torch.save(err_dict, os.path.join(f"{self.config.output_path}",
f"{self.config.job_name}_aux.pt"))
def infer(self):
self.load_model()
self.parse_inference_config()
input_feats = self.construct_features()
outputs = self.run_model_forward(input_feats)
self.write_outputs(input_feats, outputs)
def lddt_unbin(self, pred_lddt):
# calculate lddt prediction loss
nbin = pred_lddt.shape[1]
bin_step = 1.0 / nbin
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
def pae_unbin(self, logits_pae, bin_step=0.5):
nbin = logits_pae.shape[1]
bins = torch.linspace(bin_step*0.5, bin_step*nbin-bin_step*0.5, nbin,
dtype=logits_pae.dtype, device=logits_pae.device)
logits_pae = torch.nn.Softmax(dim=1)(logits_pae)
return torch.sum(bins[None,:,None,None]*logits_pae, dim=1)
def pde_unbin(self, logits_pde, bin_step=0.3):
nbin = logits_pde.shape[1]
bins = torch.linspace(bin_step*0.5, bin_step*nbin-bin_step*0.5, nbin,
dtype=logits_pde.dtype, device=logits_pde.device)
logits_pde = torch.nn.Softmax(dim=1)(logits_pde)
return torch.sum(bins[None,:,None,None]*logits_pde, dim=1)
def calc_pred_err(self, pred_lddts, logit_pae, logit_pde, seq):
"""Calculates summary metrics on predicted lDDT and distance errors"""
plddts = self.lddt_unbin(pred_lddts)
pae = self.pae_unbin(logit_pae) if logit_pae is not None else None
pde = self.pde_unbin(logit_pde) if logit_pde is not None else None
sm_mask = is_atom(seq)[0]
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:])*(~sm_mask[:,None])
inter_mask_2d = sm_mask[None,:]*(~sm_mask[:,None]) + (~sm_mask[None,:])*sm_mask[:,None]
# assumes B=1
err_dict = dict(
plddts = plddts.cpu(),
pae = pae.cpu(),
pde = pde.cpu(),
mean_plddt = float(plddts.mean()),
mean_pae = float(pae.mean()) if pae is not None else None,
pae_prot = float(pae[0,prot_mask_2d].mean()) if pae is not None else None,
pae_inter = float(pae[0,inter_mask_2d].mean()) if pae is not None else None,
)
return err_dict
@hydra.main(version_base=None, config_path='config/inference')
def main(config):
runner = ModelRunner(config)
runner.infer()
if __name__ == "__main__":
main()

371
rf2aa/scoring.py Normal file
View file

@ -0,0 +1,371 @@
import json, os
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
##
## lk and lk term
#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME)
type2ljlk = {
"CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000),
"COO":(1.916661,0.141799,-3.332648,3.5000,14.653000),
"CH0":(2.011760,0.062642,1.409284,3.5000,8.998000),
"CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000),
"CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000),
"CH3":(2.011760,0.062642,7.292929,3.5000,25.855000),
"aroC":(2.016441,0.068775,1.797950,3.5000,16.704000),
"Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100),
"Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700),
"NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200),
"NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000),
"Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000),
"Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000),
"Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100),
"OH":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000),
"ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000),
"OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600),
"S":(1.975967,0.455970,-1.707229,3.5000,17.640000),
"SH1":(1.975967,0.455970,3.291643,3.5000,23.240000),
"Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000),
"CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000),
"CObb":(1.916661,0.141799,3.104248,3.5000,13.221000),
"OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000),
"Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000),
"Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000),
"HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000),
"Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000),
"Haro":(1.374914,0.015909,0.0000,3.5000,0.0000),
"Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000),
"HS":(0.363887,0.050836,0.0000,3.5000,0.0000),
"genAl":(1,0.1, 0.0000, 0.0000, 0.0000),
"genAs":(1, 0.1, 0.0000, 0.0000, 0.0000),
"genAu":(1, 0.1, 0.0000, 0.0000, 0.0000),
"genB": (1,0.1, 0.0000, 0.0000, 0.0000),
"genBe": (1,0.1, 0.0000, 0.0000, 0.0000),
"genBr": (2.1971, 0.1090, 2.7951, 3.5000, 19.6876),
"genC": (2.0067, 0.0689, 2.2256, 3.5000, 10.6860), # params from CT
"genCa": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCl": (2.0496, 0.1070, 2.3668, 3.5000, 17.5849),
"genCo": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCr": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genCu": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genF": (1.6941, 0.0750, 1.6442, 3.5000, 12.2163),
"genFe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genHg": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genI": (2.3600, 0.1110, 3.1361, 3.5000, 22.0891),
"genIr": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genK": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genLi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMg": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genMo": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genN": (1.7854, 0.1497, -6.3760, 3.5000, 9.5221), # params from NG2
"genNi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genO": (1.5492, 0.1576, -3.5363, 3.5000, 10.7220), # params for OG3
"genOs": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genP": (2.1290, 0.5838, -9.6272, 3.5000, 34.8000), # params for PG5
"genPb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genPd": (1, 0.1, 0.0000, 0.0000,0.0000),
"genPr": (1, 0.1, 0.0000, 0.0000,0.0000),
"genPt": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRe": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRh": (1, 0.1, 0.0000, 0.0000,0.0000),
"genRu": (1, 0.1, 0.0000, 0.0000,0.0000),
"genS": (1.9893, 0.3634, -2.3560, 3.5000, 17.6400), # params for SG3
"genSb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSi": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genSn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genTb": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genTe": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genU": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genW": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genV": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genY": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genZn": (1, 0.1, 0.0000, 0.0000, 0.0000),
"genATM": (1, 0.0, 0.0000, 0.0000, 0.0000), # masked
}
# cartbonded
with open(script_dir+'cartbonded.json', 'r') as j:
cartbonded_data_raw = json.loads(j.read())
# hbond donor/acceptors
class HbAtom:
NO = 0
DO = 1 # donor
AC = 2 # acceptor
DA = 3 # donor & acceptor
HP = 4 # polar H
type2hb = {
"CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO,
"CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO,
"Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO,
"Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA,
"ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO,
"Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC,
"HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP,
"HS":HbAtom.HP, # HP in rosetta(?)
"Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC,
"genAl":HbAtom.NO, "genAs":HbAtom.NO, "genAu":HbAtom.NO, "genB": HbAtom.NO,
"genBe": HbAtom.NO, "genBr": HbAtom.NO, "genC": HbAtom.NO, "genCa": HbAtom.NO,
"genCl": HbAtom.NO, "genCo": HbAtom.NO, "genCr": HbAtom.NO, "genCu": HbAtom.NO,
"genF": HbAtom.DA, "genFe": HbAtom.NO, "genHg": HbAtom.NO, "genI": HbAtom.NO,
"genIr": HbAtom.NO, "genK": HbAtom.NO, "genLi": HbAtom.NO, "genMg": HbAtom.NO,
"genMn": HbAtom.NO, "genMo": HbAtom.NO, "genN": HbAtom.DA, "genNi": HbAtom.NO,
"genO": HbAtom.DA, "genOs": HbAtom.NO, "genP": HbAtom.NO, "genPb": HbAtom.NO,
"genPd": HbAtom.NO, "genPr": HbAtom.NO, "genPt": HbAtom.NO, "genRe": HbAtom.NO,
"genRh": HbAtom.NO, "genRu": HbAtom.NO, "genS": HbAtom.DA, "genSb": HbAtom.NO,
"genSe": HbAtom.NO, "genSi": HbAtom.NO,"genSn": HbAtom.NO,"genTb": HbAtom.NO,
"genTe": HbAtom.NO, "genU": HbAtom.NO, "genW": HbAtom.NO, "genV": HbAtom.NO,
"genY": HbAtom.NO, "genZn": HbAtom.NO, "genATM": HbAtom.NO, # masked
}
##
## hbond term
## TO DO: ADD DNA
class HbDonType:
PBA = 0
IND = 1
IME = 2
GDE = 3
CXA = 4
AMO = 5
HXL = 6
AHX = 7
NTYPES = 8
class HbAccType:
PBA = 0
CXA = 1
CXL = 2
HXL = 3
AHX = 4
IME = 5
NTYPES = 6
class HbHybType:
SP2 = 0
SP3 = 1
RING = 2
NTYPES = 3
type2dontype = {
"Nbb": HbDonType.PBA,
"Ntrp": HbDonType.IND,
"NtrR": HbDonType.GDE,
"Narg": HbDonType.GDE,
"NH2O": HbDonType.CXA,
"Nlys": HbDonType.AMO,
"OH": HbDonType.HXL,
"OHY": HbDonType.AHX,
}
type2acctype = {
"OCbb": HbAccType.PBA,
"ONH2": HbAccType.CXA,
"OOC": HbAccType.CXL,
"OH": HbAccType.HXL,
"OHY": HbAccType.AHX,
"Nhis": HbAccType.IME,
}
type2hybtype = {
"OCbb": HbHybType.SP2,
"ONH2": HbHybType.SP2,
"OOC": HbHybType.SP2,
"OHY": HbHybType.SP3,
"OH": HbHybType.SP3,
"Nhis": HbHybType.RING,
}
dontype2wt = {
HbDonType.PBA: 1.45,
HbDonType.IND: 1.15,
HbDonType.IME: 1.42,
HbDonType.GDE: 1.11,
HbDonType.CXA: 1.29,
HbDonType.AMO: 1.17,
HbDonType.HXL: 0.99,
HbDonType.AHX: 1.00,
}
acctype2wt = {
HbAccType.PBA: 1.19,
HbAccType.CXA: 1.21,
HbAccType.CXL: 1.10,
HbAccType.HXL: 1.15,
HbAccType.AHX: 1.15,
HbAccType.IME: 1.17,
}
class HbPolyType:
ahdist_aASN_dARG = 0
ahdist_aASN_dASN = 1
ahdist_aASN_dGLY = 2
ahdist_aASN_dHIS = 3
ahdist_aASN_dLYS = 4
ahdist_aASN_dSER = 5
ahdist_aASN_dTRP = 6
ahdist_aASN_dTYR = 7
ahdist_aASP_dARG = 8
ahdist_aASP_dASN = 9
ahdist_aASP_dGLY = 10
ahdist_aASP_dHIS = 11
ahdist_aASP_dLYS = 12
ahdist_aASP_dSER = 13
ahdist_aASP_dTRP = 14
ahdist_aASP_dTYR = 15
ahdist_aGLY_dARG = 16
ahdist_aGLY_dASN = 17
ahdist_aGLY_dGLY = 18
ahdist_aGLY_dHIS = 19
ahdist_aGLY_dLYS = 20
ahdist_aGLY_dSER = 21
ahdist_aGLY_dTRP = 22
ahdist_aGLY_dTYR = 23
ahdist_aHIS_dARG = 24
ahdist_aHIS_dASN = 25
ahdist_aHIS_dGLY = 26
ahdist_aHIS_dHIS = 27
ahdist_aHIS_dLYS = 28
ahdist_aHIS_dSER = 29
ahdist_aHIS_dTRP = 30
ahdist_aHIS_dTYR = 31
ahdist_aSER_dARG = 32
ahdist_aSER_dASN = 33
ahdist_aSER_dGLY = 34
ahdist_aSER_dHIS = 35
ahdist_aSER_dLYS = 36
ahdist_aSER_dSER = 37
ahdist_aSER_dTRP = 38
ahdist_aSER_dTYR = 39
ahdist_aTYR_dARG = 40
ahdist_aTYR_dASN = 41
ahdist_aTYR_dGLY = 42
ahdist_aTYR_dHIS = 43
ahdist_aTYR_dLYS = 44
ahdist_aTYR_dSER = 45
ahdist_aTYR_dTRP = 46
ahdist_aTYR_dTYR = 47
cosBAH_off = 48
cosBAH_7 = 49
cosBAH_6i = 50
AHD_1h = 51
AHD_1i = 52
AHD_1j = 53
AHD_1k = 54
# map donor:acceptor pairs to polynomials
hbtypepair2poly = {
(HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
(HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
(HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
(HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
(HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
(HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
(HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
}
# polynomials are triplets, (x_min, x_max), (y[x<x_min],y[x>x_max]), (c_9,...,c_0)
hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886
HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)),
HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)),
HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)),
HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)),
HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)),
HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)),
HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)),
HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)),
HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)),
HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)),
HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)),
HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)),
HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)),
HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)),
HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)),
HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)),
HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)),
HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)),
HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)),
HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)),
HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)),
HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)),
HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)),
HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)),
HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)),
HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)),
HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)),
HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)),
HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)),
HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)),
HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)),
HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)),
HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)),
HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)),
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)),
HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)),
HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)),
HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)),
HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)),
HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)),
HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)),
HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)),
HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)),
HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)),
HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)),
HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)),
HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)),
HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)),
HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)),
HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)),
HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)),
HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)),
HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)),
}

90
rf2aa/setup_model.py Normal file
View file

@ -0,0 +1,90 @@
import torch
import numpy as np
import hydra
import os
from rf2aa.training.EMA import EMA
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData
#TODO: control environment variables from config
# limit thread counts
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
## To reproduce errors
import random
def seed_all(seed=0):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
torch.set_num_threads(4)
#torch.autograd.set_detect_anomaly(True)
class Trainer:
def __init__(self, config) -> None:
self.config = config
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
if self.config.experiment.output_dir is not None:
self.output_dir = self.config.experiment.output_dir
else:
self.output_dir = "models/"
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
def move_constants_to_device(self, gpu):
self.fi_dev = ChemData().frame_indices.to(gpu)
self.xyz_converter = XYZConverter().to(gpu)
self.l2a = ChemData().long2alt.to(gpu)
self.aamask = ChemData().allatom_mask.to(gpu)
self.num_bonds = ChemData().num_bonds.to(gpu)
self.atom_type_index = ChemData().atom_type_index.to(gpu)
self.ljlk_parameters = ChemData().ljlk_parameters.to(gpu)
self.lj_correction_parameters = ChemData().lj_correction_parameters.to(gpu)
self.hbtypes = ChemData().hbtypes.to(gpu)
self.hbbaseatoms = ChemData().hbbaseatoms.to(gpu)
self.hbpolys = ChemData().hbpolys.to(gpu)
self.cb_len = ChemData().cb_length_t.to(gpu)
self.cb_ang = ChemData().cb_angle_t.to(gpu)
self.cb_tor = ChemData().cb_torsion_t.to(gpu)
class LegacyTrainer(Trainer):
def __init__(self, config) -> None:
super().__init__(config)
def construct_model(self, device="cpu"):
self.model = RoseTTAFoldModule(
**self.config.legacy_model_param,
aamask = ChemData().allatom_mask.to(device),
atom_type_index = ChemData().atom_type_index.to(device),
ljlk_parameters = ChemData().ljlk_parameters.to(device),
lj_correction_parameters = ChemData().lj_correction_parameters.to(device),
num_bonds = ChemData().num_bonds.to(device),
cb_len = ChemData().cb_length_t.to(device),
cb_ang = ChemData().cb_angle_t.to(device),
cb_tor = ChemData().cb_torsion_t.to(device),
).to(device)
if self.config.training_params.EMA is not None:
self.model = EMA(self.model, self.config.training_params.EMA)
@hydra.main(version_base=None, config_path='config/train')
def main(config):
seed_all()
trainer = trainer_factory[config.experiment.trainer](config=config)
trainer_factory = {
"legacy": LegacyTrainer,
}
if __name__ == "__main__":
main()

716
rf2aa/symmetry.py Normal file
View file

@ -0,0 +1,716 @@
import sys
import numpy as np
import torch
SYMA = 1.0
def generateC(angs, eps=1e-6):
L = angs.shape[0]
Rs = torch.eye(3, device=angs.device).repeat(L,1,1)
Rs[:,1,1] = torch.cos(angs)
Rs[:,1,2] = -torch.sin(angs)
Rs[:,2,1] = torch.sin(angs)
Rs[:,2,2] = torch.cos(angs)
return Rs
def generateD(angs, eps=1e-6):
L = angs.shape[0]
Rs = torch.eye(3, device=angs.device).repeat(2*L,1,1)
Rs[:L,1,1] = torch.cos(angs)
Rs[:L,1,2] = -torch.sin(angs)
Rs[:L,2,1] = torch.sin(angs)
Rs[:L,2,2] = torch.cos(angs)
Rx = torch.tensor([[-1.,0,0],[0,-1,0],[0,0,1]],device=angs.device)
Rs[L:] = torch.einsum('ij,bjk->bik',Rx,Rs[:L])
return Rs
def find_symm_subs(xyz,Rs,metasymm):
com = xyz[:,:,1].mean(dim=-2)
rcoms = torch.einsum('sij,bj->si', Rs, com)
subsymms, nneighs = metasymm
subs = []
for i in range(len(subsymms)):
drcoms = torch.linalg.norm(rcoms[0,:] - rcoms[subsymms[i],:], dim=-1)
_,subs_i = torch.topk(drcoms,nneighs[i],largest=False)
subs_i,_ = torch.sort( subsymms[i][subs_i] )
subs.append(subs_i)
subs=torch.cat(subs)
xyz_new = torch.einsum('sij,braj->bsrai', Rs[subs], xyz).reshape(
xyz.shape[0],-1,xyz.shape[2],3)
return xyz_new, subs
def update_symm_subs(xyz,subs,Rs,metasymm):
xyz_new = torch.einsum('sij,braj->bsrai', Rs[subs], xyz).reshape(
xyz.shape[0],-1,xyz.shape[2],3)
return xyz_new
def get_symm_map(subs,O):
symmmask = torch.zeros(O,dtype=torch.long)
symmmask[subs] = torch.arange(1,subs.shape[0]+1)
return symmmask
def rotation_from_matrix(R, eps=1e-4):
w, W = torch.linalg.eig(R.T)
i = torch.where(abs(torch.real(w) - 1.0) < eps)[0]
if (len(i)==0):
i = torch.tensor([0])
print ('rotation_from_matrix w',torch.real(w))
print ('rotation_from_matrix R.T',R.T)
axis = torch.real(W[:, i[-1]]).squeeze()
cosa = (torch.trace(R) - 1.0) / 2.0
if abs(axis[2]) > eps:
sina = (R[1, 0] + (cosa-1.0)*axis[0]*axis[1]) / axis[2]
elif abs(axis[1]) > eps:
sina = (R[0, 2] + (cosa-1.0)*axis[0]*axis[2]) / axis[1]
else:
sina = (R[2, 1] + (cosa-1.0)*axis[1]*axis[2]) / axis[0]
angle = torch.atan2(sina, cosa)
return angle, axis
def kabsch(pred, true):
def rmsd(V, W, eps=1e-6):
L = V.shape[0]
return torch.sqrt(torch.sum((V-W)*(V-W)) / L + eps)
def centroid(X):
return X.mean(dim=-2, keepdim=True)
cP = centroid(pred)
cT = centroid(true)
pred = pred - cP
true = true - cT
C = torch.matmul(pred.permute(1,0), true)
V, S, W = torch.svd(C)
d = torch.ones([3,3], device=pred.device)
d[:,-1] = torch.sign(torch.det(V)*torch.det(W))
U = torch.matmul(d*V, W.permute(1,0)) # (IB, 3, 3)
rpred = torch.matmul(pred, U) # (IB, L*3, 3)
rms = rmsd(rpred, true)
return rms, U, cP, cT
# do lines X0->X and Y0->Y intersect?
def intersect(X0,X,Y0,Y,eps=0.1):
mtx = torch.cat(
(torch.stack((X0,X0+X,Y0,Y0+Y)), torch.ones((4,1))) , axis=1
)
det = torch.linalg.det( mtx )
return (torch.abs(det) <= eps)
def get_angle(X,Y):
angle = torch.acos( torch.clamp( torch.sum(X*Y), -1., 1. ) )
if (angle > np.pi/2):
angle = np.pi - angle
return angle
# given the coordinates of a subunit +
def get_symmetry(xyz, mask, rms_cut=2.5, nfold_cut=0.1, angle_cut=0.05, trans_cut=2.0):
nops = xyz.shape[0]
L = xyz.shape[1]//2
# PASS 1: find all symm axes
symmaxes = []
for i in range(nops):
# if there are multiple biomt records, this may occur.
# rather than try to rescue, we will take the 1st (typically author-assigned)
offset0 = torch.linalg.norm(xyz[i,:L,1]-xyz[0,:L,1], dim=-1)
if (torch.mean(offset0)>1e-4):
continue
# get alignment
mask_i = mask[i,:L,1]*mask[i,L:,1]
xyz_i = xyz[i,:L,1][mask_i,:]
xyz_j = xyz[i,L:,1][mask_i,:]
rms_ij, Uij, cI, cJ = kabsch(xyz_i, xyz_j)
if (rms_ij > rms_cut):
#print (i,'rms',rms_ij)
continue
# get axis and point symmetry about axis
angle, axis = rotation_from_matrix(Uij)
nfold = 2*np.pi/torch.abs(angle)
# a) ensure integer # of subunits per rotation
if (torch.abs( nfold - torch.round(nfold) ) > nfold_cut ):
#print ('nfold fail',nfold)
continue
nfold = torch.round(nfold).long()
# b) ensure rotation only (no translation)
delCOM = torch.mean(xyz_i, dim=-2) - torch.mean(xyz_j, dim=-2)
trans_dot_symaxis = nfold * torch.abs(torch.dot(delCOM, axis))
if (trans_dot_symaxis > trans_cut ):
#print ('trans fail',trans_dot_symaxis)
continue
# 3) get a point on the symm axis from CoMs and angle
cIJ = torch.sign(angle) * (cJ-cI).squeeze(0)
dIJ = torch.linalg.norm(cIJ)
p_mid = (cI+cJ).squeeze(0) / 2
u = cIJ / dIJ # unit vector in plane of circle
v = torch.cross(axis, u) # unit vector from sym axis to p_mid
r = dIJ / (2*torch.sin(angle/2))
d = torch.sqrt( r*r - dIJ*dIJ/4 ) # distance from mid-chord to center
point = p_mid - (d)*v
# check if redundant
toadd = True
for j,(nf_j,ax_j,pt_j,err_j) in enumerate(symmaxes):
if (not intersect(pt_j,ax_j,point,axis)):
continue
angle_j = get_angle(ax_j,axis)
if (angle_j < angle_cut):
if (nf_j < nfold): # stored is a subsymmetry of complex, overwrite
symmaxes[j] = (nfold, axis, point, i)
toadd = False
if (toadd):
symmaxes.append( (nfold, axis, point, i) )
# PASS 2: combine
symmgroup = 'C1'
subsymm = []
if len(symmaxes)==1:
symmgroup = 'C%d'%(symmaxes[0][0])
subsymm = [symmaxes[0][3]]
elif len(symmaxes)>1:
symmaxes = sorted(symmaxes, key=lambda x: x[0], reverse=True)
angle = get_angle(symmaxes[0][1],symmaxes[1][1])
subsymm = [symmaxes[0][3],symmaxes[1][3]]
# 2-fold and n-fold intersect at 90 degress => Dn
if (symmaxes[1][0] == 2 and torch.abs(angle-np.pi/2) < angle_cut):
symmgroup = 'D%d'%(symmaxes[0][0])
else:
# polyhedral rules:
# 3-Fold + 2-fold intersecting at acos(-1/sqrt(3)) -> T
angle_tgt = np.arccos(-1/np.sqrt(3))
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'T'
# 3-Fold + 2-fold intersecting at asin(1/sqrt(3)) -> O
angle_tgt = np.arcsin(1/np.sqrt(3))
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'O'
# 4-Fold + 3-fold intersecting at acos(1/sqrt(3)) -> O
angle_tgt = np.arccos(1/np.sqrt(3))
if (symmaxes[0][0] == 4 and symmaxes[1][0] == 3 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'O'
# 3-Fold + 2-fold intersecting at 0.5*acos(sqrt(5)/3) -> I
angle_tgt = 0.5*np.arccos(np.sqrt(5)/3)
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'I'
# 5-Fold + 2-fold intersecting at 0.5*acos(1/sqrt(5)) -> I
angle_tgt = 0.5*np.arccos(1/np.sqrt(5))
if (symmaxes[0][0] == 5 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'I'
# 5-Fold + 3-fold intersecting at 0.5*acos((4*sqrt(5)-5)/15) -> I
angle_tgt = 0.5*np.arccos((4*np.sqrt(5)-5)/15)
if (symmaxes[0][0] == 5 and symmaxes[1][0] == 3 and torch.abs(angle - angle_tgt) < angle_cut):
symmgroup = 'I'
else:
pass
#fd: we could use a single symmetry here instead.
# But these cases mostly are bad BIOUNIT annotations...
#print ('nomatch',angle, [(x,y) for x,_,_,y in symmaxes])
return symmgroup, subsymm
def symm_subunit_matrix(symmid):
if (symmid[0]=='C'):
nsub = int(symmid[1:])
symmatrix = (
torch.arange(nsub)[:,None]-torch.arange(nsub)[None,:]
)%nsub
angles = torch.linspace(0,2*np.pi,nsub+1)[:nsub]
Rs = generateC(angles)
metasymm = (
[torch.arange(nsub)],
[min(3,nsub)]
)
if (nsub==1):
D = 0.0
else:
est_radius = 2.0*SYMA
theta = 2.0*np.pi/nsub
D = est_radius/np.sin(theta/2)
offset = torch.tensor([ 0.0,0.0,float(D) ])
elif (symmid[0]=='D'):
nsub = int(symmid[1:])
cblk=(torch.arange(nsub)[:,None]-torch.arange(nsub)[None,:])%nsub
symmatrix=torch.zeros((2*nsub,2*nsub),dtype=torch.long)
symmatrix[:nsub,:nsub] = cblk
symmatrix[:nsub,nsub:] = cblk+nsub
symmatrix[nsub:,:nsub] = cblk+nsub
symmatrix[nsub:,nsub:] = cblk
angles = torch.linspace(0,2*np.pi,nsub+1)[:nsub]
Rs = generateD(angles)
metasymm = (
[torch.arange(nsub), nsub+torch.arange(nsub)],
[min(3,nsub),2]
)
#metasymm = (
# [torch.arange(2*nsub)],
# [min(2*nsub,5)]
#)
est_radius = 2.0*SYMA
theta1 = 2.0*np.pi/nsub
theta2 = np.pi
D1 = est_radius/np.sin(theta1/2)
D2 = est_radius/np.sin(theta2/2)
offset = torch.tensor([ float(D2),0.0,float(D1) ])
#offset = torch.tensor([ 0.0,0.0,0.0 ])
elif (symmid=='T'):
symmatrix=torch.tensor(
[[ 0, 1, 2, 3, 8, 11, 9, 10, 4, 6, 7, 5],
[ 1, 0, 3, 2, 9, 10, 8, 11, 5, 7, 6, 4],
[ 2, 3, 0, 1, 10, 9, 11, 8, 6, 4, 5, 7],
[ 3, 2, 1, 0, 11, 8, 10, 9, 7, 5, 4, 6],
[ 4, 6, 7, 5, 0, 1, 2, 3, 8, 11, 9, 10],
[ 5, 7, 6, 4, 1, 0, 3, 2, 9, 10, 8, 11],
[ 6, 4, 5, 7, 2, 3, 0, 1, 10, 9, 11, 8],
[ 7, 5, 4, 6, 3, 2, 1, 0, 11, 8, 10, 9],
[ 8, 11, 9, 10, 4, 6, 7, 5, 0, 1, 2, 3],
[ 9, 10, 8, 11, 5, 7, 6, 4, 1, 0, 3, 2],
[10, 9, 11, 8, 6, 4, 5, 7, 2, 3, 0, 1],
[11, 8, 10, 9, 7, 5, 4, 6, 3, 2, 1, 0]])
Rs = torch.zeros(12,3,3)
Rs[ 0]=torch.tensor([[1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000],[0.000000,0.000000,1.000000]])
Rs[ 1]=torch.tensor([[-1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000],[0.000000,0.000000,1.000000]])
Rs[ 2]=torch.tensor([[-1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000],[0.000000,0.000000,-1.000000]])
Rs[ 3]=torch.tensor([[1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000],[0.000000,0.000000,-1.000000]])
Rs[ 4]=torch.tensor([[0.000000,0.000000,1.000000],[1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000]])
Rs[ 5]=torch.tensor([[0.000000,0.000000,1.000000],[-1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000]])
Rs[ 6]=torch.tensor([[0.000000,0.000000,-1.000000],[-1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000]])
Rs[ 7]=torch.tensor([[0.000000,0.000000,-1.000000],[1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000]])
Rs[ 8]=torch.tensor([[0.000000,1.000000,0.000000],[0.000000,0.000000,1.000000],[1.000000,0.000000,0.000000]])
Rs[ 9]=torch.tensor([[0.000000,-1.000000,0.000000],[0.000000,0.000000,1.000000],[-1.000000,0.000000,0.000000]])
Rs[10]=torch.tensor([[0.000000,1.000000,0.000000],[0.000000,0.000000,-1.000000],[-1.000000,0.000000,0.000000]])
Rs[11]=torch.tensor([[0.000000,-1.000000,0.000000],[0.000000,0.000000,-1.000000],[1.000000,0.000000,0.000000]])
nneigh = 5
metasymm = (
[torch.arange(12)],
[6]
)
est_radius = 4.0*SYMA
offset = torch.tensor([ 1.0,0.0,0.0 ])
offset = est_radius * offset / torch.linalg.norm(offset)
metasymm = (
[torch.arange(12)],
[6]
)
elif (symmid=='O'):
symmatrix=torch.tensor(
[[ 0, 1, 2, 3, 8, 11, 9, 10, 4, 6, 7, 5, 12, 13, 15, 14, 19, 17,
18, 16, 22, 21, 20, 23],
[ 1, 0, 3, 2, 9, 10, 8, 11, 5, 7, 6, 4, 13, 12, 14, 15, 18, 16,
19, 17, 23, 20, 21, 22],
[ 2, 3, 0, 1, 10, 9, 11, 8, 6, 4, 5, 7, 14, 15, 13, 12, 17, 19,
16, 18, 20, 23, 22, 21],
[ 3, 2, 1, 0, 11, 8, 10, 9, 7, 5, 4, 6, 15, 14, 12, 13, 16, 18,
17, 19, 21, 22, 23, 20],
[ 4, 6, 7, 5, 0, 1, 2, 3, 8, 11, 9, 10, 16, 18, 17, 19, 21, 22,
23, 20, 15, 14, 12, 13],
[ 5, 7, 6, 4, 1, 0, 3, 2, 9, 10, 8, 11, 17, 19, 16, 18, 20, 23,
22, 21, 14, 15, 13, 12],
[ 6, 4, 5, 7, 2, 3, 0, 1, 10, 9, 11, 8, 18, 16, 19, 17, 23, 20,
21, 22, 13, 12, 14, 15],
[ 7, 5, 4, 6, 3, 2, 1, 0, 11, 8, 10, 9, 19, 17, 18, 16, 22, 21,
20, 23, 12, 13, 15, 14],
[ 8, 11, 9, 10, 4, 6, 7, 5, 0, 1, 2, 3, 20, 23, 22, 21, 14, 15,
13, 12, 17, 19, 16, 18],
[ 9, 10, 8, 11, 5, 7, 6, 4, 1, 0, 3, 2, 21, 22, 23, 20, 15, 14,
12, 13, 16, 18, 17, 19],
[10, 9, 11, 8, 6, 4, 5, 7, 2, 3, 0, 1, 22, 21, 20, 23, 12, 13,
15, 14, 19, 17, 18, 16],
[11, 8, 10, 9, 7, 5, 4, 6, 3, 2, 1, 0, 23, 20, 21, 22, 13, 12,
14, 15, 18, 16, 19, 17],
[12, 13, 15, 14, 19, 17, 18, 16, 22, 21, 20, 23, 0, 1, 2, 3, 8, 11,
9, 10, 4, 6, 7, 5],
[13, 12, 14, 15, 18, 16, 19, 17, 23, 20, 21, 22, 1, 0, 3, 2, 9, 10,
8, 11, 5, 7, 6, 4],
[14, 15, 13, 12, 17, 19, 16, 18, 20, 23, 22, 21, 2, 3, 0, 1, 10, 9,
11, 8, 6, 4, 5, 7],
[15, 14, 12, 13, 16, 18, 17, 19, 21, 22, 23, 20, 3, 2, 1, 0, 11, 8,
10, 9, 7, 5, 4, 6],
[16, 18, 17, 19, 21, 22, 23, 20, 15, 14, 12, 13, 4, 6, 7, 5, 0, 1,
2, 3, 8, 11, 9, 10],
[17, 19, 16, 18, 20, 23, 22, 21, 14, 15, 13, 12, 5, 7, 6, 4, 1, 0,
3, 2, 9, 10, 8, 11],
[18, 16, 19, 17, 23, 20, 21, 22, 13, 12, 14, 15, 6, 4, 5, 7, 2, 3,
0, 1, 10, 9, 11, 8],
[19, 17, 18, 16, 22, 21, 20, 23, 12, 13, 15, 14, 7, 5, 4, 6, 3, 2,
1, 0, 11, 8, 10, 9],
[20, 23, 22, 21, 14, 15, 13, 12, 17, 19, 16, 18, 8, 11, 9, 10, 4, 6,
7, 5, 0, 1, 2, 3],
[21, 22, 23, 20, 15, 14, 12, 13, 16, 18, 17, 19, 9, 10, 8, 11, 5, 7,
6, 4, 1, 0, 3, 2],
[22, 21, 20, 23, 12, 13, 15, 14, 19, 17, 18, 16, 10, 9, 11, 8, 6, 4,
5, 7, 2, 3, 0, 1],
[23, 20, 21, 22, 13, 12, 14, 15, 18, 16, 19, 17, 11, 8, 10, 9, 7, 5,
4, 6, 3, 2, 1, 0]])
Rs = torch.zeros(24,3,3)
Rs[0]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[1]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[2]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[3]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[4]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
Rs[5]=torch.tensor([[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
Rs[6]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
Rs[7]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
Rs[8]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000]])
Rs[9]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000]])
Rs[10]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000]])
Rs[11]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000]])
Rs[12]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[13]=torch.tensor([[ 0.000000,-1.000000,0.000000],[-1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[14]=torch.tensor([[ 0.000000, 1.000000,0.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[15]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[16]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 0.000000,-1.000000,0.000000]])
Rs[17]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 0.000000, 1.000000,0.000000]])
Rs[18]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 0.000000,-1.000000,0.000000]])
Rs[19]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 0.000000, 1.000000,0.000000]])
Rs[20]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 0.000000, 1.000000,0.000000],[-1.000000, 0.000000,0.000000]])
Rs[21]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 0.000000,-1.000000,0.000000],[ 1.000000, 0.000000,0.000000]])
Rs[22]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 0.000000, 1.000000,0.000000],[ 1.000000, 0.000000,0.000000]])
Rs[23]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 0.000000,-1.000000,0.000000],[-1.000000, 0.000000,0.000000]])
est_radius = 6.0*SYMA
offset = torch.tensor([ 1.0,0.0,0.0 ])
offset = est_radius * offset / torch.linalg.norm(offset)
metasymm = (
[torch.arange(24)],
[6]
)
elif (symmid=='I'):
symmatrix=torch.tensor(
[[ 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59, 38, 15, 16, 17,
18, 19, 40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6, 34, 50,
58, 37, 14, 26, 20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53,
56, 25, 54, 57, 36, 13],
[ 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55, 39, 16, 17, 18,
19, 15, 41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7, 30, 51,
59, 38, 10, 27, 21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54,
57, 26, 50, 58, 37, 14],
[ 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56, 35, 17, 18, 19,
15, 16, 42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8, 31, 52,
55, 39, 11, 28, 22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50,
58, 27, 51, 59, 38, 10],
[ 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57, 36, 18, 19, 15,
16, 17, 43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9, 32, 53,
56, 35, 12, 29, 23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51,
59, 28, 52, 55, 39, 11],
[ 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58, 37, 19, 15, 16,
17, 18, 44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5, 33, 54,
57, 36, 13, 25, 24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52,
55, 29, 53, 56, 35, 12],
[ 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18, 19, 10, 27, 51,
59, 38, 45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32, 48, 55,
39, 11, 28, 52, 25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24,
7, 20, 8, 31, 47, 44],
[ 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19, 15, 11, 28, 52,
55, 39, 46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33, 49, 56,
35, 12, 29, 53, 26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20,
8, 21, 9, 32, 48, 40],
[ 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15, 16, 12, 29, 53,
56, 35, 47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34, 45, 57,
36, 13, 25, 54, 27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21,
9, 22, 5, 33, 49, 41],
[ 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16, 17, 13, 25, 54,
57, 36, 48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30, 46, 58,
37, 14, 26, 50, 28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22,
5, 23, 6, 34, 45, 42],
[ 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17, 18, 14, 26, 50,
58, 37, 49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31, 47, 59,
38, 10, 27, 51, 29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23,
6, 24, 7, 30, 46, 43],
[10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2, 1, 5, 33, 49,
41, 22, 50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28, 52, 40,
21, 9, 32, 48, 30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36,
13, 35, 12, 29, 53, 56],
[11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3, 2, 6, 34, 45,
42, 23, 51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29, 53, 41,
22, 5, 33, 49, 31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37,
14, 36, 13, 25, 54, 57],
[12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4, 3, 7, 30, 46,
43, 24, 52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25, 54, 42,
23, 6, 34, 45, 32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38,
10, 37, 14, 26, 50, 58],
[13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0, 4, 8, 31, 47,
44, 20, 53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26, 50, 43,
24, 7, 30, 46, 33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39,
11, 38, 10, 27, 51, 59],
[14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1, 0, 9, 32, 48,
40, 21, 54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27, 51, 44,
20, 8, 31, 47, 34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35,
12, 39, 11, 28, 52, 55],
[15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41, 22, 0, 4, 3,
2, 1, 55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14, 26, 45,
42, 23, 6, 34, 35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47,
44, 30, 46, 43, 24, 7],
[16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42, 23, 1, 0, 4,
3, 2, 56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10, 27, 46,
43, 24, 7, 30, 36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48,
40, 31, 47, 44, 20, 8],
[17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43, 24, 2, 1, 0,
4, 3, 57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11, 28, 47,
44, 20, 8, 31, 37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49,
41, 32, 48, 40, 21, 9],
[18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44, 20, 3, 2, 1,
0, 4, 58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12, 29, 48,
40, 21, 9, 32, 38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45,
42, 33, 49, 41, 22, 5],
[19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40, 21, 4, 3, 2,
1, 0, 59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13, 25, 49,
41, 22, 5, 33, 39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46,
43, 34, 45, 42, 23, 6],
[20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53, 56, 25, 54, 57,
36, 13, 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59, 38, 15,
16, 17, 18, 19, 40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6,
34, 50, 58, 37, 14, 26],
[21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54, 57, 26, 50, 58,
37, 14, 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55, 39, 16,
17, 18, 19, 15, 41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7,
30, 51, 59, 38, 10, 27],
[22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50, 58, 27, 51, 59,
38, 10, 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56, 35, 17,
18, 19, 15, 16, 42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8,
31, 52, 55, 39, 11, 28],
[23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51, 59, 28, 52, 55,
39, 11, 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57, 36, 18,
19, 15, 16, 17, 43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9,
32, 53, 56, 35, 12, 29],
[24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52, 55, 29, 53, 56,
35, 12, 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58, 37, 19,
15, 16, 17, 18, 44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5,
33, 54, 57, 36, 13, 25],
[25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24, 7, 20, 8, 31,
47, 44, 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18, 19, 10,
27, 51, 59, 38, 45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32,
48, 55, 39, 11, 28, 52],
[26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20, 8, 21, 9, 32,
48, 40, 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19, 15, 11,
28, 52, 55, 39, 46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33,
49, 56, 35, 12, 29, 53],
[27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21, 9, 22, 5, 33,
49, 41, 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15, 16, 12,
29, 53, 56, 35, 47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34,
45, 57, 36, 13, 25, 54],
[28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22, 5, 23, 6, 34,
45, 42, 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16, 17, 13,
25, 54, 57, 36, 48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30,
46, 58, 37, 14, 26, 50],
[29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23, 6, 24, 7, 30,
46, 43, 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17, 18, 14,
26, 50, 58, 37, 49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31,
47, 59, 38, 10, 27, 51],
[30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36, 13, 35, 12, 29,
53, 56, 10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2, 1, 5,
33, 49, 41, 22, 50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28,
52, 40, 21, 9, 32, 48],
[31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37, 14, 36, 13, 25,
54, 57, 11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3, 2, 6,
34, 45, 42, 23, 51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29,
53, 41, 22, 5, 33, 49],
[32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38, 10, 37, 14, 26,
50, 58, 12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4, 3, 7,
30, 46, 43, 24, 52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25,
54, 42, 23, 6, 34, 45],
[33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39, 11, 38, 10, 27,
51, 59, 13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0, 4, 8,
31, 47, 44, 20, 53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26,
50, 43, 24, 7, 30, 46],
[34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35, 12, 39, 11, 28,
52, 55, 14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1, 0, 9,
32, 48, 40, 21, 54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27,
51, 44, 20, 8, 31, 47],
[35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47, 44, 30, 46, 43,
24, 7, 15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41, 22, 0,
4, 3, 2, 1, 55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14,
26, 45, 42, 23, 6, 34],
[36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48, 40, 31, 47, 44,
20, 8, 16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42, 23, 1,
0, 4, 3, 2, 56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10,
27, 46, 43, 24, 7, 30],
[37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49, 41, 32, 48, 40,
21, 9, 17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43, 24, 2,
1, 0, 4, 3, 57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11,
28, 47, 44, 20, 8, 31],
[38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45, 42, 33, 49, 41,
22, 5, 18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44, 20, 3,
2, 1, 0, 4, 58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12,
29, 48, 40, 21, 9, 32],
[39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46, 43, 34, 45, 42,
23, 6, 19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40, 21, 4,
3, 2, 1, 0, 59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13,
25, 49, 41, 22, 5, 33],
[40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6, 34, 50, 58, 37,
14, 26, 20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53, 56, 25,
54, 57, 36, 13, 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59,
38, 15, 16, 17, 18, 19],
[41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7, 30, 51, 59, 38,
10, 27, 21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54, 57, 26,
50, 58, 37, 14, 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55,
39, 16, 17, 18, 19, 15],
[42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8, 31, 52, 55, 39,
11, 28, 22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50, 58, 27,
51, 59, 38, 10, 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56,
35, 17, 18, 19, 15, 16],
[43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9, 32, 53, 56, 35,
12, 29, 23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51, 59, 28,
52, 55, 39, 11, 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57,
36, 18, 19, 15, 16, 17],
[44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5, 33, 54, 57, 36,
13, 25, 24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52, 55, 29,
53, 56, 35, 12, 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58,
37, 19, 15, 16, 17, 18],
[45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32, 48, 55, 39, 11,
28, 52, 25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24, 7, 20,
8, 31, 47, 44, 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18,
19, 10, 27, 51, 59, 38],
[46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33, 49, 56, 35, 12,
29, 53, 26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20, 8, 21,
9, 32, 48, 40, 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19,
15, 11, 28, 52, 55, 39],
[47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34, 45, 57, 36, 13,
25, 54, 27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21, 9, 22,
5, 33, 49, 41, 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15,
16, 12, 29, 53, 56, 35],
[48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30, 46, 58, 37, 14,
26, 50, 28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22, 5, 23,
6, 34, 45, 42, 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16,
17, 13, 25, 54, 57, 36],
[49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31, 47, 59, 38, 10,
27, 51, 29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23, 6, 24,
7, 30, 46, 43, 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17,
18, 14, 26, 50, 58, 37],
[50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28, 52, 40, 21, 9,
32, 48, 30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36, 13, 35,
12, 29, 53, 56, 10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2,
1, 5, 33, 49, 41, 22],
[51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29, 53, 41, 22, 5,
33, 49, 31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37, 14, 36,
13, 25, 54, 57, 11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3,
2, 6, 34, 45, 42, 23],
[52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25, 54, 42, 23, 6,
34, 45, 32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38, 10, 37,
14, 26, 50, 58, 12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4,
3, 7, 30, 46, 43, 24],
[53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26, 50, 43, 24, 7,
30, 46, 33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39, 11, 38,
10, 27, 51, 59, 13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0,
4, 8, 31, 47, 44, 20],
[54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27, 51, 44, 20, 8,
31, 47, 34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35, 12, 39,
11, 28, 52, 55, 14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1,
0, 9, 32, 48, 40, 21],
[55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14, 26, 45, 42, 23,
6, 34, 35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47, 44, 30,
46, 43, 24, 7, 15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41,
22, 0, 4, 3, 2, 1],
[56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10, 27, 46, 43, 24,
7, 30, 36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48, 40, 31,
47, 44, 20, 8, 16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42,
23, 1, 0, 4, 3, 2],
[57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11, 28, 47, 44, 20,
8, 31, 37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49, 41, 32,
48, 40, 21, 9, 17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43,
24, 2, 1, 0, 4, 3],
[58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12, 29, 48, 40, 21,
9, 32, 38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45, 42, 33,
49, 41, 22, 5, 18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44,
20, 3, 2, 1, 0, 4],
[59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13, 25, 49, 41, 22,
5, 33, 39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46, 43, 34,
45, 42, 23, 6, 19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40,
21, 4, 3, 2, 1, 0]])
Rs = torch.zeros(60,3,3)
Rs[0]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[1]=torch.tensor([[ 0.500000,-0.809017,0.309017],[ 0.809017 ,0.309017,-0.500000],[ 0.309017, 0.500000,0.809017]])
Rs[2]=torch.tensor([[-0.309017,-0.500000,0.809017],[ 0.500000,-0.809017,-0.309017],[ 0.809017, 0.309017,0.500000]])
Rs[3]=torch.tensor([[-0.309017, 0.500000,0.809017],[-0.500000,-0.809017,0.309017],[ 0.809017,-0.309017,0.500000]])
Rs[4]=torch.tensor([[ 0.500000, 0.809017,0.309017],[-0.809017, 0.309017,0.500000],[ 0.309017,-0.500000,0.809017]])
Rs[5]=torch.tensor([[-0.809017, 0.309017,0.500000],[ 0.309017,-0.500000,0.809017],[ 0.500000, 0.809017,0.309017]])
Rs[6]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000]])
Rs[7]=torch.tensor([[ 0.809017 ,0.309017,-0.500000],[ 0.309017, 0.500000,0.809017],[ 0.500000,-0.809017,0.309017]])
Rs[8]=torch.tensor([[ 0.500000,-0.809017,-0.309017],[ 0.809017, 0.309017,0.500000],[-0.309017,-0.500000,0.809017]])
Rs[9]=torch.tensor([[-0.500000,-0.809017,0.309017],[ 0.809017,-0.309017,0.500000],[-0.309017, 0.500000,0.809017]])
Rs[10]=torch.tensor([[-0.500000,-0.809017,0.309017],[-0.809017 ,0.309017,-0.500000],[ 0.309017,-0.500000,-0.809017]])
Rs[11]=torch.tensor([[-0.809017, 0.309017,0.500000],[-0.309017 ,0.500000,-0.809017],[-0.500000,-0.809017,-0.309017]])
Rs[12]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000]])
Rs[13]=torch.tensor([[ 0.809017 ,0.309017,-0.500000],[-0.309017,-0.500000,-0.809017],[-0.500000 ,0.809017,-0.309017]])
Rs[14]=torch.tensor([[ 0.500000,-0.809017,-0.309017],[-0.809017,-0.309017,-0.500000],[ 0.309017 ,0.500000,-0.809017]])
Rs[15]=torch.tensor([[ 0.309017 ,0.500000,-0.809017],[ 0.500000,-0.809017,-0.309017],[-0.809017,-0.309017,-0.500000]])
Rs[16]=torch.tensor([[ 0.309017,-0.500000,-0.809017],[-0.500000,-0.809017,0.309017],[-0.809017 ,0.309017,-0.500000]])
Rs[17]=torch.tensor([[-0.500000,-0.809017,-0.309017],[-0.809017, 0.309017,0.500000],[-0.309017 ,0.500000,-0.809017]])
Rs[18]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[19]=torch.tensor([[-0.500000 ,0.809017,-0.309017],[ 0.809017 ,0.309017,-0.500000],[-0.309017,-0.500000,-0.809017]])
Rs[20]=torch.tensor([[-0.500000,-0.809017,-0.309017],[ 0.809017,-0.309017,-0.500000],[ 0.309017,-0.500000,0.809017]])
Rs[21]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
Rs[22]=torch.tensor([[-0.500000 ,0.809017,-0.309017],[-0.809017,-0.309017,0.500000],[ 0.309017, 0.500000,0.809017]])
Rs[23]=torch.tensor([[ 0.309017 ,0.500000,-0.809017],[-0.500000, 0.809017,0.309017],[ 0.809017, 0.309017,0.500000]])
Rs[24]=torch.tensor([[ 0.309017,-0.500000,-0.809017],[ 0.500000 ,0.809017,-0.309017],[ 0.809017,-0.309017,0.500000]])
Rs[25]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
Rs[26]=torch.tensor([[-0.309017,-0.500000,-0.809017],[-0.500000 ,0.809017,-0.309017],[ 0.809017 ,0.309017,-0.500000]])
Rs[27]=torch.tensor([[-0.809017,-0.309017,-0.500000],[ 0.309017 ,0.500000,-0.809017],[ 0.500000,-0.809017,-0.309017]])
Rs[28]=torch.tensor([[-0.809017 ,0.309017,-0.500000],[ 0.309017,-0.500000,-0.809017],[-0.500000,-0.809017,0.309017]])
Rs[29]=torch.tensor([[-0.309017 ,0.500000,-0.809017],[-0.500000,-0.809017,-0.309017],[-0.809017, 0.309017,0.500000]])
Rs[30]=torch.tensor([[ 0.809017, 0.309017,0.500000],[-0.309017,-0.500000,0.809017],[ 0.500000,-0.809017,-0.309017]])
Rs[31]=torch.tensor([[ 0.809017,-0.309017,0.500000],[-0.309017, 0.500000,0.809017],[-0.500000,-0.809017,0.309017]])
Rs[32]=torch.tensor([[ 0.309017,-0.500000,0.809017],[ 0.500000, 0.809017,0.309017],[-0.809017, 0.309017,0.500000]])
Rs[33]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
Rs[34]=torch.tensor([[ 0.309017, 0.500000,0.809017],[ 0.500000,-0.809017,0.309017],[ 0.809017 ,0.309017,-0.500000]])
Rs[35]=torch.tensor([[-0.309017, 0.500000,0.809017],[ 0.500000 ,0.809017,-0.309017],[-0.809017 ,0.309017,-0.500000]])
Rs[36]=torch.tensor([[ 0.500000, 0.809017,0.309017],[ 0.809017,-0.309017,-0.500000],[-0.309017 ,0.500000,-0.809017]])
Rs[37]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
Rs[38]=torch.tensor([[ 0.500000,-0.809017,0.309017],[-0.809017,-0.309017,0.500000],[-0.309017,-0.500000,-0.809017]])
Rs[39]=torch.tensor([[-0.309017,-0.500000,0.809017],[-0.500000, 0.809017,0.309017],[-0.809017,-0.309017,-0.500000]])
Rs[40]=torch.tensor([[-0.500000, 0.809017,0.309017],[-0.809017,-0.309017,-0.500000],[-0.309017,-0.500000,0.809017]])
Rs[41]=torch.tensor([[ 0.500000 ,0.809017,-0.309017],[-0.809017 ,0.309017,-0.500000],[-0.309017, 0.500000,0.809017]])
Rs[42]=torch.tensor([[ 0.809017,-0.309017,-0.500000],[-0.309017 ,0.500000,-0.809017],[ 0.500000, 0.809017,0.309017]])
Rs[43]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000]])
Rs[44]=torch.tensor([[-0.809017,-0.309017,0.500000],[-0.309017,-0.500000,-0.809017],[ 0.500000,-0.809017,0.309017]])
Rs[45]=torch.tensor([[ 0.809017,-0.309017,0.500000],[ 0.309017,-0.500000,-0.809017],[ 0.500000 ,0.809017,-0.309017]])
Rs[46]=torch.tensor([[ 0.309017,-0.500000,0.809017],[-0.500000,-0.809017,-0.309017],[ 0.809017,-0.309017,-0.500000]])
Rs[47]=torch.tensor([[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
Rs[48]=torch.tensor([[ 0.309017, 0.500000,0.809017],[-0.500000 ,0.809017,-0.309017],[-0.809017,-0.309017,0.500000]])
Rs[49]=torch.tensor([[ 0.809017, 0.309017,0.500000],[ 0.309017 ,0.500000,-0.809017],[-0.500000, 0.809017,0.309017]])
Rs[50]=torch.tensor([[-0.309017 ,0.500000,-0.809017],[ 0.500000, 0.809017,0.309017],[ 0.809017,-0.309017,-0.500000]])
Rs[51]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
Rs[52]=torch.tensor([[-0.309017,-0.500000,-0.809017],[ 0.500000,-0.809017,0.309017],[-0.809017,-0.309017,0.500000]])
Rs[53]=torch.tensor([[-0.809017,-0.309017,-0.500000],[-0.309017,-0.500000,0.809017],[-0.500000, 0.809017,0.309017]])
Rs[54]=torch.tensor([[-0.809017 ,0.309017,-0.500000],[-0.309017, 0.500000,0.809017],[ 0.500000 ,0.809017,-0.309017]])
Rs[55]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000]])
Rs[56]=torch.tensor([[-0.809017,-0.309017,0.500000],[ 0.309017, 0.500000,0.809017],[-0.500000 ,0.809017,-0.309017]])
Rs[57]=torch.tensor([[-0.500000, 0.809017,0.309017],[ 0.809017, 0.309017,0.500000],[ 0.309017 ,0.500000,-0.809017]])
Rs[58]=torch.tensor([[ 0.500000 ,0.809017,-0.309017],[ 0.809017,-0.309017,0.500000],[ 0.309017,-0.500000,-0.809017]])
Rs[59]=torch.tensor([[ 0.809017,-0.309017,-0.500000],[ 0.309017,-0.500000,0.809017],[-0.500000,-0.809017,-0.309017]])
est_radius = 10.0*SYMA
offset = torch.tensor([ 1.0,0.0,0.0 ])
offset = est_radius * offset / torch.linalg.norm(offset)
metasymm = (
[torch.arange(60)],
[6]
)
else:
print ("Unknown symmetry",symmid)
assert False
return symmatrix,Rs,metasymm,offset

181
rf2aa/tensor_util.py Normal file
View file

@ -0,0 +1,181 @@
import torch
import numpy as np
import json
from deepdiff import DeepDiff
import pprint
import assertpy
import dataclasses
from collections import OrderedDict
import contextlib
from deepdiff.operator import BaseOperator
def assert_shape(t, s):
assertpy.assert_that(tuple(t.shape)).is_equal_to(s)
def assert_same_shape(t, s):
assertpy.assert_that(tuple(t.shape)).is_equal_to(tuple(s.shape))
class ExceptionLogger(contextlib.AbstractContextManager):
def __exit__(self, exc_type, exc_value, traceback):
if exc_type:
print("***Logging exception {}***".format((exc_type, exc_value,
traceback)))
def assert_equal(got, want):
assertpy.assert_that(got.dtype).is_equal_to(want.dtype)
assertpy.assert_that(got.shape).is_equal_to(want.shape)
is_eq = got.nan_to_num()==want.nan_to_num()
unequal_idx = torch.nonzero(~is_eq)
unequal_got = got[~is_eq]
unequal_want = want[~is_eq]
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_want, unequal_got))[:3]
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
if torch.numel(got) < 10:
msg = f'got {got}, want: {want}'
assert len(unequal_idx) == 0, msg
def assert_close(got, want, atol=1e-4, rtol=1e-8):
got = torch.nan_to_num(got)
want = torch.nan_to_num(want)
if got.shape != want.shape:
raise ValueError(f'Wrong shapes: got shape {got.shape} want shape {want.shape}')
elif not torch.allclose(got, want, atol=atol, rtol=rtol):
maximum_difference = torch.abs(got - want).max().item()
indices_different = torch.nonzero(got != want)
raise ValueError(f'Maximum difference: {maximum_difference}, indices different: {indices_different}')
def cpu(e):
if isinstance(e, dict):
return {cpu(k): cpu(v) for k,v in e.items()}
if isinstance(e, list) or isinstance(e, tuple):
return tuple(cpu(i) for i in e)
if hasattr(e, 'cpu'):
return e.cpu().detach()
return e
# Dataclass functions
def to_ordered_dict(dc):
return OrderedDict((field.name, getattr(dc, field.name)) for field in dataclasses.fields(dc))
def to_device(dc, device):
d = to_ordered_dict(dc)
for k, v in d.items():
if v is not None:
setattr(dc, k, v.to(device))
def shapes(dc):
d = to_ordered_dict(dc)
return {k:v.shape if hasattr(v, 'shape') else None for k,v in d.items()}
def pprint_obj(obj):
pprint.pprint(obj.__dict__, indent=4)
def assert_squeeze(t):
assert t.shape[0] == 1, f'{t.shape}[0] != 1'
return t[0]
def apply_to_tensors(e, op):
# if isinstance(e, dataclasses.)
if dataclasses.is_dataclass(e):
# return to_ordered_dict
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
if isinstance(e, dict):
return {k: apply_to_tensors(v, op) for k,v in e.items()}
if isinstance(e, list) or isinstance(e, tuple):
return tuple(apply_to_tensors(i, op) for i in e)
if hasattr(e, 'cpu'):
return op(e)
return e
def apply_to_matching(e, op, filt):
# if isinstance(e, dataclasses.)
if filt(e):
return op(e)
if dataclasses.is_dataclass(e):
# return to_ordered_dict
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
if isinstance(e, dict):
return {k: apply_to_tensors(v, op) for k,v in e.items()}
if isinstance(e, list) or isinstance(e, tuple):
return tuple(apply_to_tensors(i, op) for i in e)
return e
def set_grad(t):
t.requires_grad = True
def require_grad(e):
apply_to_tensors(e, set_grad)
def get_grad(e):
return apply_to_tensors(e, lambda x: x.grad)
def info(e):
shap = apply_to_tensors(e, lambda x: x.shape)
shap = apply_to_matching(shap, str, dataclasses.is_dataclass)
return json.dumps(shap, indent=4)
def minmax(e):
return apply_to_tensors(e, lambda x: (torch.log10(torch.min(x)), torch.log10(torch.max(x))))
class TensorMatchOperator(BaseOperator):
def __init__(self, atol=1e-3, rtol=0, **kwargs):
super(TensorMatchOperator, self).__init__(**kwargs)
self.atol = atol
self.rtol = rtol
def _equal_msg(self, got, want):
if got.shape != want.shape:
return f'got shape {got.shape} want shape {want.shape}'
if got.dtype != want.dtype:
return f'got dtype {got.dtype} want dtype {want.dtype}'
if torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol).all():
return ''
is_eq = torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol)
unequal_idx = torch.nonzero(~is_eq)
unequal_got = got[~is_eq]
unequal_want = want[~is_eq]
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
uneq_msg += f' fraction unequal:{unequal_got.numel()}/{got.numel()}'
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
if torch.numel(got) < 10:
msg = f'got {got}, want: {want}'
return msg
def give_up_diffing(self, level, diff_instance):
msg = self._equal_msg(level.t1, level.t2)
if msg:
print(f'got:\n{level.t1}\n\nwant:\n{level.t2}')
msg = self._equal_msg(level.t1, level.t2)
if msg:
diff_instance.custom_report_result('tensors unequal', level, {
"msg": msg
})
return True
class NumpyMatchOperator(TensorMatchOperator):
def give_up_diffing(self, level, diff_instance):
level.t1 = torch.Tensor(level.t1)
level.t2 = torch.Tensor(level.t2)
return super(NumpyMatchOperator, self).give_up_diffing(level, diff_instance)
def cmp(got, want, **kwargs):
dd = DeepDiff(got, want, custom_operators=[
NumpyMatchOperator(types=[np.ndarray], **kwargs),
TensorMatchOperator(types=[torch.Tensor], **kwargs)])
if dd:
return dd
return ''

View file

@ -0,0 +1,79 @@
import torch
import pandas as pd
import numpy as np
import itertools
from collections import OrderedDict
from hydra import initialize, compose
from rf2aa.setup_model import trainer_factory, seed_all
from rf2aa.chemical import ChemicalData as ChemData
# configurations to test
configs = ["legacy_train"]
datasets = ["compl", "na_compl", "rna", "sm_compl", "sm_compl_covale", "sm_compl_asmb"]
cfg_overrides = [
"loader_params.p_msa_mask=0.0",
"loader_params.crop=100000",
"loader_params.mintplt=0",
"loader_params.maxtplt=2"
]
def make_deterministic(seed=0):
seed_all(seed)
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def setup_dataset_names():
data = {}
for name in datasets:
data[name] = [name]
return data
# set up models for regression tests
def setup_models(device="cpu"):
models, chem_cfgs = [], []
for config in configs:
with initialize(version_base=None, config_path="../config/train"):
cfg = compose(config_name=config, overrides=cfg_overrides)
# initializing the model needs the chemical DB initialized. Force a reload
ChemData.reset()
ChemData(cfg.chem_params)
trainer = trainer_factory[cfg.experiment.trainer](cfg)
seed_all()
trainer.construct_model(device=device)
models.append(trainer.model)
chem_cfgs.append(cfg.chem_params)
trainer = None
return dict(zip(configs, (zip(configs, models, chem_cfgs))))
# set up job array for regression
def setup_array(datasets, models, device="cpu"):
test_data = setup_dataset_names()
test_models = setup_models(device=device)
test_data = [test_data[dataset] for dataset in datasets]
test_models = [test_models[model] for model in models]
return (list(itertools.product(test_data, test_models)))
def random_param_init(model):
seed_all()
with torch.no_grad():
fake_state_dict = OrderedDict()
for name, param in model.model.named_parameters():
fake_state_dict[name] = torch.randn_like(param)
model.model.load_state_dict(fake_state_dict)
model.shadow.load_state_dict(fake_state_dict)
return model
def dataset_pickle_path(dataset_name):
return f"test_pickles/data/{dataset_name}_regression.pt"
def model_pickle_path(dataset_name, model_name):
return f"test_pickles/model/{model_name}_{dataset_name}_regression.pt"
def loss_pickle_path(dataset_name, model_name, loss_name):
return f"test_pickles/loss/{loss_name}_{model_name}_{dataset_name}_regression.pt"

73
rf2aa/tests/test_model.py Normal file
View file

@ -0,0 +1,73 @@
import os
import torch
import pytest
import warnings
warnings.filterwarnings("ignore")
from rf2aa.data.dataloader_adaptor import prepare_input
from rf2aa.training.recycling import run_model_forward_legacy
from rf2aa.tensor_util import assert_equal
from rf2aa.tests.test_conditions import setup_array,\
make_deterministic, dataset_pickle_path, model_pickle_path
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData
# goal is to test all the configs on a broad set of datasets
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
legacy_test_conditions = setup_array(["na_compl", "rna", "sm_compl", "sm_compl_covale"], ["legacy_train"], device=gpu)
@pytest.mark.parametrize("example,model", legacy_test_conditions)
def test_regression_legacy(example, model):
dataset_name, dataset_inputs, model_name, model = setup_test(example, model)
make_deterministic()
output_i = run_model_forward_legacy(model, dataset_inputs, gpu)
model_pickle = model_pickle_path(dataset_name, model_name)
output_names = ("logits_c6d", "logits_aa", "logits_pae", \
"logits_pde", "p_bind", "xyz", "alpha", "xyz_allatom", \
"lddt", "seq", "pair", "state")
if not os.path.exists(model_pickle):
torch.save(output_i, model_pickle)
else:
output_regression = torch.load(model_pickle, map_location=gpu)
for idx, output in enumerate(output_i):
got = output
want = output_regression[idx]
if output_names[idx] == "logits_c6d":
for i in range(len(want)):
got_i = got[i]
want_i = want[i]
try:
assert_equal(got_i, want_i)
except Exception as e:
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
elif output_names[idx] in ["alpha", "xyz_allatom", "seq", "pair", "state"]:
try:
assert torch.allclose(got, want, atol=1e-4)
except Exception as e:
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
else:
try:
assert_equal(got, want)
except Exception as e:
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
def setup_test(example, model):
model_name, model, config = model
# initialize chemical database
ChemData.reset() # force reload chemical data
ChemData(config)
model = model.to(gpu)
dataset_name = example[0]
dataloader_inputs = torch.load(dataset_pickle_path(dataset_name), map_location=gpu)
xyz_converter = XYZConverter().to(gpu)
task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, \
negative, symmRs, Lasu, ch_label = prepare_input(dataloader_inputs,xyz_converter, gpu)
return dataset_name, network_input, model_name, model

60
rf2aa/training/EMA.py Normal file
View file

@ -0,0 +1,60 @@
import torch
import torch.nn as nn
from collections import OrderedDict
from copy import deepcopy
import contextlib
class EMA(nn.Module):
def __init__(self, model, decay):
super().__init__()
self.decay = decay
self.model = model
self.shadow = deepcopy(self.model)
for param in self.shadow.parameters():
param.detach_()
@torch.no_grad()
def update(self):
if not self.training:
print("EMA update should only be called during training", file=stderr, flush=True)
return
model_params = OrderedDict(self.model.named_parameters())
shadow_params = OrderedDict(self.shadow.named_parameters())
# check if both model contains the same set of keys
assert model_params.keys() == shadow_params.keys()
for name, param in model_params.items():
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
if param.requires_grad:
shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))
model_buffers = OrderedDict(self.model.named_buffers())
shadow_buffers = OrderedDict(self.shadow.named_buffers())
# check if both model contains the same set of keys
assert model_buffers.keys() == shadow_buffers.keys()
for name, buffer in model_buffers.items():
# buffers are copied
shadow_buffers[name].copy_(buffer)
#fd A hack to allow non-DDP models to be passed into the Trainer
def no_sync(self):
return contextlib.nullcontext()
def forward(self, *args, **kwargs):
if self.training:
return self.model(*args, **kwargs)
else:
return self.shadow(*args, **kwargs)
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

View file

@ -0,0 +1,5 @@
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
def custom_forward(*inputs):
return module(*inputs, **kwargs)
return custom_forward

View file

@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import numpy as np
from contextlib import ExitStack
from rf2aa.chemical import ChemicalData as ChemData
def recycle_step_legacy(ddp_model, input, n_cycle, use_amp, nograds=False, force_device=None):
if force_device is not None:
gpu = force_device
else:
gpu = ddp_model.device
xyz_prev, alpha_prev, mask_recycle = \
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
for i_cycle in range(n_cycle):
with ExitStack() as stack:
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
if i_cycle < n_cycle -1 or nograds is True:
stack.enter_context(torch.no_grad())
if force_device is None:
stack.enter_context(ddp_model.no_sync())
return_raw = (i_cycle < n_cycle -1)
use_checkpoint = not nograds and (i_cycle == n_cycle -1)
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
output_i = ddp_model(**input_i)
return output_i
def run_model_forward_legacy(model, network_input, device="cpu"):
""" run model forward pass, no recycling or ddp with legacy model (for tests)"""
gpu = device
xyz_prev, alpha_prev, mask_recycle = \
network_input["xyz_prev"], network_input["alpha_prev"], network_input["mask_recycle"]
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
input_i = add_recycle_inputs(network_input, output_i, 0, gpu, return_raw=False, use_checkpoint=False)
input_i["seq_unmasked"] = input_i["seq_unmasked"].to(gpu)
input_i["sctors"] = input_i["sctors"].to(gpu)
model.eval()
with torch.no_grad():
output_i = model(**input_i)
return output_i
def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
input_i = {}
for key in network_input:
if key in ['msa_latent', 'msa_full', 'seq']:
input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
else:
input_i[key] = network_input[key]
L = input_i["msa_latent"].shape[2]
msa_prev, pair_prev, _, alpha, mask_recycle = output_i
xyz_prev = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
input_i['msa_prev'] = msa_prev
input_i['pair_prev'] = pair_prev
input_i['xyz'] = xyz_prev
input_i['mask_recycle'] = mask_recycle
input_i['sctors'] = alpha
input_i['return_raw'] = return_raw
input_i['use_checkpoint'] = use_checkpoint
input_i.pop('xyz_prev')
input_i.pop('alpha_prev')
return input_i

1044
rf2aa/util.py Normal file

File diff suppressed because it is too large Load diff

710
rf2aa/util_module.py Normal file
View file

@ -0,0 +1,710 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import copy
import dgl
import networkx as nx
from rf2aa.util import *
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.chemical import th_dih, th_ang_v
def init_lecun_normal(module, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
return module
def init_lecun_normal_param(weight, scale=1.0):
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
normal = torch.distributions.normal.Normal(0, 1)
alpha = (a - mu) / sigma
beta = (b - mu) / sigma
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
x = torch.clamp(x, a, b)
return x
def sample_truncated_normal(shape, scale=1.0):
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
return stddev * truncated_normal(torch.rand(shape))
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
return weight
# for gradient checkpointing
def create_custom_forward(module, **kwargs):
def custom_forward(*inputs):
return module(*inputs, **kwargs)
return custom_forward
def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
class Dropout(nn.Module):
# Dropout entire row or column
def __init__(self, broadcast_dim=None, p_drop=0.15):
super(Dropout, self).__init__()
# give ones with probability of 1-p_drop / zeros with p_drop
self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop]))
self.broadcast_dim=broadcast_dim
self.p_drop=p_drop
def forward(self, x):
if not self.training: # no drophead during evaluation mode
return x
shape = list(x.shape)
if not self.broadcast_dim == None:
shape[self.broadcast_dim] = 1
mask = self.sampler.sample(shape).to(x.device).view(shape)
x = mask * x / (1.0 - self.p_drop)
return x
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
# Distance radial basis function
D_max = D_min + (D_count-1) * D_sigma
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
D_mu = D_mu[None,:]
D_expand = torch.unsqueeze(D, -1)
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
return RBF
def get_seqsep(idx):
'''
Sequence separation feature for structure module. Protein-only.
Input:
- idx: residue indices of given sequence (B,L)
Output:
- seqsep: sequence separation feature with sign (B, L, L, 1)
Sergey found that having sign in seqsep features helps a little
'''
seqsep = idx[:,None,:] - idx[:,:,None]
sign = torch.sign(seqsep)
neigh = torch.abs(seqsep)
neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
neigh = sign * neigh
return neigh.unsqueeze(-1)
def get_seqsep_protein_sm(idx, bond_feats, dist_matrix, sm_mask):
'''
Sequence separation features for protein-SM complex
Input:
- idx: residue indices of given sequence (B,L)
- bond_feats: bond features (B, L, L)
- dist_matrix: precomputed bond distances (B, L, L) NOTE: need to run nan_to_num to remove infinities
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
Output:
- seqsep: sequence separation feature with sign (B, L, L, 1)
-1 or 1 for bonded protein residues
1 for bonded SM atoms or residue-atom bonds
0 elsewhere
'''
sm_mask = sm_mask[0] # assume batch = 1
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask)
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
res_dist[(res_dist > 1) | (res_dist < -1)] = 0.0
atom_dist[(atom_dist > 1)] = 0.0
seqsep = sm_mask_2d*atom_dist + prot_mask_2d*res_dist + inter_mask_2d*(bond_feats==6)
return seqsep.unsqueeze(-1)
def get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask, minpos_res=-32, maxpos_res=32, maxpos_atom=8):
'''
Calculates residue and atom bond distances of protein/SM complex. Used for positional
embedding and structure module. 2nd version (2022-9-19); handles atomized proteins.
Input:
- idx: residue index (B, L)
- bond_feats: bond features (B, L, L)
- dist_matrix: precomputed bond distances (B, L, L) NOTE: need to run nan_to_num to remove infinities
- sm_mask: boolean feature (L). True if a position represents atom, False otherwise
- minpos_res: minimum value of residue distances
- maxpos_res: maximum value of residue distances
- maxpos_atom: maximum value of atom bond distances
Output:
- res_dist: residue distance (B, L, L)
- atom_dist: atom bond distance (B, L, L)
'''
bond_feats = bond_feats[0] # assume batch = 1
L = bond_feats.shape[0]
device = bond_feats.device
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
# protein residue distances
res_dist_prot = torch.clamp(idx[0,None,:] - idx[0,:,None],
min=minpos_res, max=maxpos_res) # (L, L) intra-protein
res_dist_sm = torch.full((L,L), maxpos_res+1, device=device) # (L, L) with "unknown" res. dist. token
# small molecule atom bond graph
atom_dist_sm = torch.nan_to_num(dist_matrix, posinf=maxpos_atom)[0].long() # this comes through the dataloader so it is batched
atom_dist_prot = torch.full((L,L), maxpos_atom+1, device=device)
#fd new impl
i_s, j_s = torch.where(bond_feats==6)
i_sm = i_s[sm_mask[i_s]]
i_prot = j_s[sm_mask[i_s]]
res_dist_inter = torch.full((L,L), maxpos_res, device=device)
atom_dist_inter = torch.full((L,L), maxpos_atom, device=device)
if i_prot.shape[0] > 0:
closest_prot_res = i_prot[torch.argmin(atom_dist_sm[sm_mask,:][:,i_sm], dim=-1)]
res_dist_inter[sm_mask,:] = res_dist_prot[closest_prot_res,:]
res_dist_inter[:,sm_mask] = res_dist_prot[:,closest_prot_res]
closest_atom = i_sm[torch.argmin(torch.abs(res_dist_prot[~sm_mask,:][:,i_prot]), dim=-1)]
atom_dist_inter[~sm_mask,:] = atom_dist_sm[closest_atom,:] + 1
atom_dist_inter[:,~sm_mask] = atom_dist_sm[:,closest_atom] + 1
res_dist = res_dist_prot * prot_mask_2d + res_dist_inter * inter_mask_2d + res_dist_sm * sm_mask_2d
atom_dist = atom_dist_prot * prot_mask_2d + atom_dist_inter * inter_mask_2d + atom_dist_sm * sm_mask_2d
return res_dist[None], atom_dist[None] # add batch dim.
def get_relpos(idx, bond_feats, sm_mask, inter_pos=32, maxpath=32):
'''
Relative position matrix of protein/SM complex. Used for positional
embedding and structure module. Simple version from 9/2/2022 that doesn't
handle atomized proteins.
Input:
- idx: residue index (B, L)
- bond_feats: bond features (B, L, L)
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
- inter_pos: value to assign as the protein-SM residue index differences
- maxpath: bond distances greater than this are clipped to this value
Output:
- relpos: relative position feature (B, L, L)
for intra-protein this is the residue index difference
for intra-SM this is the bond distance
for protein-SM this is user-defined value inter_pos
'''
bond_feats = bond_feats[0]
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
# intra-protein: residue # differences
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
# intra-small molecule: bond distances
sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask*bond_feats
G = nx.from_numpy_matrix(sm_bond_feats.detach().cpu().numpy())
paths = dict(nx.all_pairs_shortest_path_length(G,cutoff=maxpath))
paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
i,j,v = torch.tensor(paths).T
bond_separation = torch.full_like(bond_feats, maxpath) \
- maxpath*torch.eye(bond_feats.shape[0]).to(bond_feats.device).long()
bond_separation[i,j] = v.to(bond_feats.device)
# combine: protein-s.m. are always positive maximum distance apart
# assumes one small molecule per example
relpos = prot_mask_2d * seqsep + sm_mask_2d * bond_separation + inter_mask_2d * inter_pos # (B, L, L)
relpos = relpos.to(bond_feats.device)
return relpos
def make_full_graph(xyz, pair, idx):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
b,i,j = torch.where(sep.abs() > 0)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
'''
Input:
- xyz: current backbone cooordinates (B, L, 3, 3)
- pair: pair features from Trunk (B, L, L, E)
- idx: residue index from ground truth pdb
Output:
- G: defined graph
'''
B, L = xyz.shape[:2]
device = xyz.device
# distance map from current CA coordinates
D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*9999.9 # (B, L, L)
# seq sep
sep = idx[:,None,:] - idx[:,:,None]
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*9999.9
if (topk_incl_local):
D = D + sep*eps
D[sep<nlocal] = 0.0
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
cond = topk_matrix > 0.0
else:
D = D + sep*eps
# get top_k neighbors
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
topk_matrix = torch.zeros((B, L, L), device=device)
topk_matrix.scatter_(2, E_idx, 1.0)
# put an edge if any of the 3 conditions are met:
# 1) |i-j| <= kmin (connect sequentially adjacent residues)
# 2) top_k neighbors
cond = torch.logical_or(topk_matrix > 0.0, sep < nlocal)
b,i,j = torch.where(cond)
src = b*L+i
tgt = b*L+j
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
return G, pair[b,i,j][...,None]
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
B,L,A = xyz.shape[:3]
device = xyz.device
D = torch.norm(
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
)
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
D[~mask2d] = 9999.
D[D==0] = 9999.
# select top K neighbors for each atom
# keep indices as batch/res/atm indices
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
bi,ri,ai = mask.nonzero(as_tuple=True)
bi = bi[:,None].repeat(1,top_k).reshape(-1)
ri = ri[:,None].repeat(1,top_k).reshape(-1)
ai = ai[:,None].repeat(1,top_k).reshape(-1)
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
edge = torch.full(ri.shape, maxbonds, device=device)
resmask = ri==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
resmask = ri+1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
resmask = ri-1==rj
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
edge = edge.clamp(0,maxbonds-1)
edge = F.one_hot(edge)[...,None]
natm = torch.sum(mask)
index = torch.zeros_like(mask, dtype=torch.long, device=device)
index[mask] = torch.arange(natm, device=device)
src=index[bi,ri,ai]
tgt=index[bi,rj,aj]
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
return G, edge
# rotate about the x axis
def make_rotX(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,1,1] = angs[:,:,0]/NORM
RTs[:,:,1,2] = -angs[:,:,1]/NORM
RTs[:,:,2,1] = angs[:,:,1]/NORM
RTs[:,:,2,2] = angs[:,:,0]/NORM
return RTs
# rotate about the x axis
def make_rotZ(angs, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
RTs[:,:,0,0] = angs[:,:,0]/NORM
RTs[:,:,0,1] = -angs[:,:,1]/NORM
RTs[:,:,1,0] = angs[:,:,1]/NORM
RTs[:,:,1,1] = angs[:,:,0]/NORM
return RTs
# rotate about an arbitrary axis
def make_rot_axis(angs, u, eps=1e-6):
B,L = angs.shape[:2]
NORM = torch.linalg.norm(angs, dim=-1) + eps
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
ct = angs[:,:,0]/NORM
st = angs[:,:,1]/NORM
u0 = u[:,:,0]
u1 = u[:,:,1]
u2 = u[:,:,2]
RTs[:,:,0,0] = ct+u0*u0*(1-ct)
RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st
RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st
RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st
RTs[:,:,1,1] = ct+u1*u1*(1-ct)
RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st
RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st
RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st
RTs[:,:,2,2] = ct+u2*u2*(1-ct)
return RTs
# compute allatom structure from backbone frames and torsions
#
# alphas:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# cb/cg bend: 7-9
# eps(p)/zeta(p): 10-11
# alpha/beta/gamma/delta: 12-15
# nu2/nu1/nu0: 16-18
# chi_1(na): 19
#
# RTs_in_base_frame:
# omega/phi/psi: 0-2
# chi_1-4(prot): 3-6
# eps(p)/zeta(p): 7-8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
# RT frames (output):
# origin: 0
# omega/phi/psi: 1-3
# chi_1-4(prot): 4-7
# cb bend: 8
# alpha/beta/gamma/delta: 9-12
# nu2/nu1/nu0: 13-15
# chi_1(na): 16
#
class XYZConverter(nn.Module):
def __init__(self):
super(XYZConverter, self).__init__()
self.register_buffer("torsion_indices", ChemData().torsion_indices, persistent=False)
self.register_buffer("torsion_can_flip", ChemData().torsion_can_flip.to(torch.int32), persistent=False)
self.register_buffer("ref_angles", ChemData().reference_angles, persistent=False)
self.register_buffer("base_indices", ChemData().base_indices, persistent=False)
self.register_buffer("RTs_in_base_frame", ChemData().RTs_by_torsion, persistent=False)
self.register_buffer("xyzs_in_base_frame", ChemData().xyzs_in_base_frame, persistent=False)
def compute_all_atom(self, seq, xyz, alphas):
B,L = xyz.shape[:2]
is_NA = is_nucleic(seq)
Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], is_NA)
RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device)
# bb
RTF0[:,:,:3,:3] = Rs
RTF0[:,:,:3,3] = Ts
# omega
RTF1 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:]))
# phi
RTF2 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:]))
# psi
RTF3 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:]))
# CB bend
basexyzs = self.xyzs_in_base_frame[seq]
NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3])
CAr = (basexyzs[:,:,1,:3])
CBr = (basexyzs[:,:,4,:3])
CBrotaxis1 = (CBr-CAr).cross(NCr-CAr)
CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8
# CB twist
NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3]
NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr
CBrotaxis2 = (CBr-CAr).cross(NCpp)
CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8
CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 )
CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 )
RTF8 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, CBrot1,CBrot2)
# chi1 + CG bend
RTF4 = torch.einsum(
'brij,brjk,brkl,brlm->brim',
RTF8,
self.RTs_in_base_frame[seq,3,:],
make_rotX(alphas[:,:,3,:]),
make_rotZ(alphas[:,:,9,:]))
# chi2
RTF5 = torch.einsum(
'brij,brjk,brkl->bril',
RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:]))
# chi3
RTF6 = torch.einsum(
'brij,brjk,brkl->bril',
RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:]))
# chi4
RTF7 = torch.einsum(
'brij,brjk,brkl->bril',
RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:]))
# ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]
# which mode are we running in
if (not ChemData().params.use_phospate_frames_for_NA):
# NA nu1 --> from base frame
RTF14 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
# NA nu0 --> from base frame
RTF15 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
# NA chi --> from base frame
RTF16= torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
# NA nu2 --> from nu1 frame
RTF13 = torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
# NA delta --> from nu2 frame
RTF12 = torch.einsum(
'brij,brjk,brkl->bril',
RTF13, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
# NA gamma --> from delta frame
RTF11 = torch.einsum(
'brij,brjk,brkl->bril',
RTF12, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
# NA beta --> from gamma frame
RTF10 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
# NA alpha --> from beta frame
RTF9 = torch.einsum(
'brij,brjk,brkl->bril',
RTF10, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
else:
# NA alpha
RTF9 = torch.einsum(
'brij,brjk,brkl->bril',
RTF0, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
# NA beta
RTF10 = torch.einsum(
'brij,brjk,brkl->bril',
RTF9, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
# NA gamma
RTF11 = torch.einsum(
'brij,brjk,brkl->bril',
RTF10, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
# NA delta
RTF12 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
# NA nu2 - from gamma frame
RTF13 = torch.einsum(
'brij,brjk,brkl->bril',
RTF11, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
# NA nu1
RTF14 = torch.einsum(
'brij,brjk,brkl->bril',
RTF13, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
# NA nu0
RTF15 = torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
# NA chi - from nu1 frame
RTF16= torch.einsum(
'brij,brjk,brkl->bril',
RTF14, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
RTframes = torch.stack((
RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8,
RTF9,RTF10,RTF11,RTF12,RTF13,RTF14,RTF15,RTF16
),dim=2)
xyzs = torch.einsum(
'brtij,brtj->brti',
RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs
)
return RTframes, xyzs[...,:3]
def get_tor_mask(self, seq, mask_in=None):
B,L = seq.shape[:2]
dna_mask = is_nucleic(seq)
prot_mask = ~dna_mask
tors_mask = self.torsion_indices[seq,:,-1] > 0
if mask_in != None:
N = mask_in.shape[2]
ts = self.torsion_indices[seq]
bs = torch.arange(B, device=seq.device)[:,None,None,None]
rs = torch.arange(L, device=seq.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
ts = torch.abs(ts)
tors_mask *= mask_in[bs,rs,ts].all(dim=-1)
return tors_mask
def get_torsions(self, xyz_in, seq, mask_in=None):
B,L = xyz_in.shape[:2]
tors_mask = self.get_tor_mask(seq, mask_in)
# idealize given xyz coordinates before computing torsion angles
xyz = idealize_reference_frame(seq, xyz_in)
ts = self.torsion_indices[seq]
bs = torch.arange(B, device=xyz_in.device)[:,None,None,None]
xs = torch.arange(L, device=xyz_in.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
ys = torch.abs(ts)
xyzs_bytor = xyz[bs,xs,ys,:]
torsions = torch.zeros( (B,L,ChemData().NTOTALDOFS,2), device=xyz_in.device )
# protein torsion
torsions[...,:7,:] = th_dih(
xyzs_bytor[...,:7,0,:],xyzs_bytor[...,:7,1,:],xyzs_bytor[...,:7,2,:],xyzs_bytor[...,:7,3,:]
)
torsions[:,:,2,:] = -1 * torsions[:,:,2,:] # shift psi by pi
# NA
torsions[...,10:,:] = th_dih(
xyzs_bytor[...,10:,0,:],xyzs_bytor[...,10:,1,:],xyzs_bytor[...,10:,2,:],xyzs_bytor[...,10:,3,:]
)
# protein angles
# CB bend
NC = 0.5*( xyz[:,:,0,:3] + xyz[:,:,2,:3] )
CA = xyz[:,:,1,:3]
CB = xyz[:,:,4,:3]
t = th_ang_v(CB-CA,NC-CA)
t0 = self.ref_angles[seq][...,0,:]
torsions[:,:,7,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
# CB twist
NCCA = NC-CA
NCp = xyz[:,:,2,:3] - xyz[:,:,0,:3]
NCpp = NCp - torch.sum(NCp*NCCA, dim=-1, keepdim=True)/ torch.sum(NCCA*NCCA, dim=-1, keepdim=True) * NCCA
t = th_ang_v(CB-CA,NCpp)
t0 = self.ref_angles[seq][...,1,:]
torsions[:,:,8,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
# CG bend
CG = xyz[:,:,5,:3]
t = th_ang_v(CG-CB,CA-CB)
t0 = self.ref_angles[seq][...,2,:]
torsions[:,:,9,:] = torch.stack(
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
dim=-1 )
mask0 = (torch.isnan(torsions[...,0])).nonzero()
mask1 = (torch.isnan(torsions[...,1])).nonzero()
torsions[mask0[:,0],mask0[:,1],mask0[:,2],0] = 1.0
torsions[mask1[:,0],mask1[:,1],mask1[:,2],1] = 0.0
# alt chis
torsions_alt = torsions.clone()
torsions_alt[self.torsion_can_flip[seq,:].to(torch.bool)] *= -1
# torsions to restrain to 0 or 180 degree
# (this should be specified in chemical?)
tors_planar = torch.zeros((B, L, ChemData().NTOTALDOFS), dtype=torch.bool, device=xyz_in.device)
tors_planar[:,:,5] = seq == ChemData().aa2num['TYR'] # TYR chi 3 should be planar
return torsions, torsions_alt, tors_mask, tors_planar