mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-14 22:33:58 +00:00
initial commit
This commit is contained in:
commit
f87f5b8cdf
98 changed files with 26515 additions and 0 deletions
15
.gitignore
vendored
Normal file
15
.gitignore
vendored
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
valid_remapped
|
||||||
|
lig_test
|
||||||
|
dataset.pkl
|
||||||
|
run_digs.sh
|
||||||
|
*.pdb
|
||||||
|
.vscode
|
||||||
|
slurm_logs/
|
||||||
|
*/output/
|
||||||
|
*/notebooks/
|
||||||
|
*/models/
|
||||||
|
__pycache__/
|
||||||
|
*/run_scripts/
|
||||||
|
unit_tests/
|
||||||
|
ruff.toml
|
||||||
|
*/scratch/
|
254
README.md
Normal file
254
README.md
Normal file
|
@ -0,0 +1,254 @@
|
||||||
|
Code for RoseTTAFold All-Atom
|
||||||
|
--------------------
|
||||||
|
<p align="right">
|
||||||
|
<img style="float: right" src="./img/RFAA.png" alt="alt text" width="600px" align="right"/>
|
||||||
|
</p>
|
||||||
|
RoseTTAFold All-Atom is a biomolecular structure prediction neural network that can predict a broad range of biomolecular assemblies including proteins, nucleic acids, small molecules, covalent modifications and metals as outlined in the RFAA paper.
|
||||||
|
|
||||||
|
RFAA is not accurate for all cases, but produces useful error estimates to allow users to identify accurate predictions. Below are the instructions for setting up and using the model.
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
- [Setup/Installation](#set-up)
|
||||||
|
- [Inference Configs Using Hydra](#inference-config)
|
||||||
|
- [Predicting protein structures](#protein-pred)
|
||||||
|
- [Predicting protein/nucleic acid complexes](#p-na-complex)
|
||||||
|
- [Predicting protein/small molecule complexes](#p-sm-complex)
|
||||||
|
- [Predicting higher order complexes](#higher-order)
|
||||||
|
- [Predicting covalently modified proteins](#covale)
|
||||||
|
- [Understanding model outputs](#outputs)
|
||||||
|
- [Conclusion](#conclusion)
|
||||||
|
|
||||||
|
<a id="set-up"></a>
|
||||||
|
### Setup/Installation
|
||||||
|
1. Clone the package
|
||||||
|
```
|
||||||
|
git clone https://github.com/baker-laboratory/RoseTTAFold-All-Atom
|
||||||
|
cd RF2-allatom
|
||||||
|
```
|
||||||
|
2. Download the container used to run RFAA.
|
||||||
|
```
|
||||||
|
wget http://files.ipd.uw.edu/pub/RF-All-Atom/containers/SE3nv-20240131.sif
|
||||||
|
```
|
||||||
|
3. Download the model weights.
|
||||||
|
```
|
||||||
|
wget http://files.ipd.uw.edu/pub/RF-All-Atom/weights/RFAA_paper_weights.pt
|
||||||
|
|
||||||
|
```
|
||||||
|
4. Download sequence databases for MSA and template generation.
|
||||||
|
```
|
||||||
|
# uniref30 [46G]
|
||||||
|
wget http://wwwuser.gwdg.de/~compbiol/uniclust/2020_06/UniRef30_2020_06_hhsuite.tar.gz
|
||||||
|
mkdir -p UniRef30_2020_06
|
||||||
|
tar xfz UniRef30_2020_06_hhsuite.tar.gz -C ./UniRef30_2020_06
|
||||||
|
|
||||||
|
# BFD [272G]
|
||||||
|
wget https://bfd.mmseqs.com/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz
|
||||||
|
mkdir -p bfd
|
||||||
|
tar xfz bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz -C ./bfd
|
||||||
|
|
||||||
|
# structure templates (including *_a3m.ffdata, *_a3m.ffindex)
|
||||||
|
wget https://files.ipd.uw.edu/pub/RoseTTAFold/pdb100_2021Mar03.tar.gz
|
||||||
|
tar xfz pdb100_2021Mar03.tar.gz
|
||||||
|
```
|
||||||
|
|
||||||
|
<a id="inference-config"></a>
|
||||||
|
### Inference Configs Using Hydra
|
||||||
|
|
||||||
|
We use a library called Hydra to compose config files for predictions. The actual script that runs the model is in `rf2aa/run_inference.py` and default parameters that were used to train the model are in `rf2aa/config/inference/base.yaml`. We highly suggest using the default parameters since those are closest to the training task for RFAA but we have found that increasing loader_params.MAXCYCLE=10 (default set to 4) gives better results for hard cases (as noted in the paper).
|
||||||
|
|
||||||
|
We use a container system called apptainers which have very simple syntax. Instead of developing a local conda environment, users can use the apptainer to run the model which has all the dependencies already packaged.
|
||||||
|
|
||||||
|
The general way to run the model is as follows:
|
||||||
|
```
|
||||||
|
SE3nv-20240131.sif -m rf2aa.run_inference --config-name {your inference config}
|
||||||
|
```
|
||||||
|
The main inputs into the model are split into:
|
||||||
|
- protein inputs (protein_inputs)
|
||||||
|
- nucleic acid inputs (na_inputs)
|
||||||
|
- small molecule inputs (sm_inputs)
|
||||||
|
- covalent bonds between protein chains and small molecule chains
|
||||||
|
- modified or unnatural amino acids (COMING SOON)
|
||||||
|
|
||||||
|
In the following sections, we will describe how to set up configs for different prediction tasks that we described in the paper.
|
||||||
|
|
||||||
|
<a id="protein-pred"></a>
|
||||||
|
### Predicting Protein Monomers
|
||||||
|
|
||||||
|
Predicting a protein monomer structure requires an input fasta file and an optional job_name which will be used to name your output files. Here is a sample config (also in `rf2aa/config/inference/protein.yaml`).
|
||||||
|
```
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
||||||
|
```
|
||||||
|
The first line indicates that this job inherits all the configurations from the base file (this should be true for all your inference jobs). Then you can optionally specify the job name (the default job_name is "structure_prediction" so we highly recommend specifying one).
|
||||||
|
|
||||||
|
When specifying the fasta file for your protein, you might notice that it is nested within a mysterious "A" parameter. This represents a chain letter and is absolutely **required**, this is important when users are specifying multiple chains.
|
||||||
|
|
||||||
|
Now to predict the sample monomer structure, run:
|
||||||
|
```
|
||||||
|
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein
|
||||||
|
```
|
||||||
|
|
||||||
|
<a id="p-na-complex"></a>
|
||||||
|
### Predicting Protein Nucleic Acid Complexes
|
||||||
|
Protein-nucleic acid complexes have very similar syntax to protein monomer prediction, except with additional chains for nucleic acids. Here is sample config (also in `rf2aa/config/inference/nucleic_acid.yaml`):
|
||||||
|
```
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein_nucleic"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
||||||
|
na_inputs:
|
||||||
|
B:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_B.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
C:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_C.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
```
|
||||||
|
Once again this config inherits the base config, defines a job name and provides a protein fasta file for chain A. To add double stranded DNA, you must add two more chain inputs for each strand (shown here as chains B and C). In this case, the allowed input types are dna and rna.
|
||||||
|
|
||||||
|
This repo currently does not support making RNA MSAs or pairing protein MSAs with RNA MSAs but this is functionality that we are keen to add. For now, please use RF-NA for modeling cases requiring paired protein-RNA MSAs.
|
||||||
|
|
||||||
|
Now, predict the example protein/NA complex.
|
||||||
|
```
|
||||||
|
SE3nv-20240131.sif -m rf2aa.run_inference --config-name nucleic_acid
|
||||||
|
```
|
||||||
|
<a id="p-sm-complex"></a>
|
||||||
|
### Predicting Protein Small Molecule Complexes
|
||||||
|
To predict protein small molecule complexes, the syntax to input the protein remains the same. Adding in the small molecule works similarly to other inputs.
|
||||||
|
Here is an example (from `rf2aa/config/inference/protein_sm.yaml`):
|
||||||
|
```
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: 7qxr
|
||||||
|
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7qxr.fasta
|
||||||
|
|
||||||
|
sm_inputs:
|
||||||
|
B:
|
||||||
|
input: examples/small_molecule/NSW_ideal.sdf
|
||||||
|
input_type: "sdf"
|
||||||
|
```
|
||||||
|
Small molecule inputs are provided as sdf files or smiles strings and users are **required** to provide both an input and an input_type field for every small molecule that they want to provide. Metal ions can also be provided as sdf files or smiles strings.
|
||||||
|
|
||||||
|
To predict the example:
|
||||||
|
```
|
||||||
|
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein_sm
|
||||||
|
```
|
||||||
|
<a id="higher-order"></a>
|
||||||
|
### Predicting Higher Order Complexes
|
||||||
|
If you have been following thus-far, this is where we put all the previous sections together! To predict a protein-nucleic acid-small molecule complex, you can combine the schema for all the inputs we have seen so far!
|
||||||
|
|
||||||
|
Here is an example:
|
||||||
|
```
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein_nucleic_sm"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
||||||
|
na_inputs:
|
||||||
|
B:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_B.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
C:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_C.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
sm_inputs:
|
||||||
|
D:
|
||||||
|
input: examples/small_molecule/XG4.sdf
|
||||||
|
input_type: "sdf"
|
||||||
|
```
|
||||||
|
And to run:
|
||||||
|
```
|
||||||
|
SE3nv-20240131.sif -m rf2aa.run_inference --config-name protein_na_sm
|
||||||
|
```
|
||||||
|
<a id="covale"></a>
|
||||||
|
### Predicting Covalently Modified Proteins
|
||||||
|
Specifying covalent modifications is slightly more complicated for the following reasons.
|
||||||
|
|
||||||
|
- Forming new covalent bonds can create or remove chiral centers. Since RFAA specifies chirality at input, the network needs to be provided with chirality information. Under the hood, chiral centers are identified by a package called Openbabel which does not always agree with chemical intuition.
|
||||||
|
- Covalent modifications often have "leaving groups", or chemical groups that leave both the protein and the modification upon modification.
|
||||||
|
|
||||||
|
The way you input covalent bonds to RFAA is as a list of bonds between an atom on the protein and an atom on one of the input small molecules. This is the syntax for those bonds:
|
||||||
|
```
|
||||||
|
(protein_chain, residue_number, atom_name), (small_molecule_chain, atom_index), (new_chirality_atom_1, new_chirality_atom_2)
|
||||||
|
```
|
||||||
|
**Both the protein residue number and the atom_index are 1 indexed** (as you would normally count, as opposed to 0 indexed like many programming languages).
|
||||||
|
|
||||||
|
In most cases, the chirality of the atoms will not change. This is what an input for a case where the chirality does not change looks like:
|
||||||
|
```
|
||||||
|
(protein_chain, residue_number, atom_name), (small_molecule_chain, atom_index), ("null", "null")
|
||||||
|
```
|
||||||
|
The options for chirality are `CCW` and `CW` for counterclockwise and clockwise. The code will raise an Exception is there is a chiral center that Openbabel found that the user did not specify. Even if you believe Openbabel is wrong, the network likely received chirality information for those cases during training, so we expect that you will get the best results by specifying chirality at those positions.
|
||||||
|
|
||||||
|
**You cannot define bonds between two small molecule chains**. In cases, where the PDB defines molecules in "multiple residues", you must merge the residues into a single sdf file first.
|
||||||
|
|
||||||
|
**You must remove any leaving groups from your input molecules before inputting them into the network, but the code will handle leaving groups on the sidechain that is being modified automatically.** There is code for providing leaving group dynamically from the hydra config, but that is experimental and we have not fully tested it.
|
||||||
|
|
||||||
|
Given all of that background, this is how you specify covalent modification structure prediction to RFAA.
|
||||||
|
|
||||||
|
```
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: 7s69_A
|
||||||
|
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7s69_A.fasta
|
||||||
|
|
||||||
|
sm_inputs:
|
||||||
|
B:
|
||||||
|
input: examples/small_molecule/7s69_glycan.sdf
|
||||||
|
input_type: sdf
|
||||||
|
|
||||||
|
covale_inputs: "[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
|
||||||
|
|
||||||
|
loader_params:
|
||||||
|
MAXCYCLE: 10
|
||||||
|
```
|
||||||
|
**For covalently modified proteins, you must provide the input molecule as a sdf file**, since openbabel does not read smiles strings in a specific order. The syntax shown is identical to loading a protein and small molecule and then indicating a bond between them. In this case, hydra creates some problems because we have to escape the quotation marks using backslashes.
|
||||||
|
|
||||||
|
To clarify, this input:
|
||||||
|
```
|
||||||
|
[(("A", "74", "ND2"), ("B", "1"), ("CW", "null"))]
|
||||||
|
```
|
||||||
|
becomes this so it can be parsed correctly:
|
||||||
|
```
|
||||||
|
"[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
|
||||||
|
```
|
||||||
|
|
||||||
|
We know this syntax is hard to work with and we are happy to review PRs if anyone in the community can figure out how to specify all the necessary requirements in a more user friendly way!
|
||||||
|
|
||||||
|
<a id="outputs"></a>
|
||||||
|
### Understanding model outputs
|
||||||
|
|
||||||
|
The model returns two files:
|
||||||
|
- PDB file with predicted structure (bfactors represent predicted lddt at each position)
|
||||||
|
- pytorch file with confidence metrics stored (can load with `torch.load(file, map_location="cpu")`)
|
||||||
|
|
||||||
|
Here are the confidence metrics:
|
||||||
|
|
||||||
|
1. plddts, tensor with node-wise plddt for each node in the prediction
|
||||||
|
2. pae, a LxL tensor where the model predicts the error of every j position if the ith position's frame is aligned (or atom frame for atom nodes)
|
||||||
|
3. pde, a LxL tensor where the model predicts the unsigned error of the each pairwise distance
|
||||||
|
4. mean_plddt, the mean over all the plddts
|
||||||
|
5. mean_pae, the mean over all pairwise predicted aligned errors
|
||||||
|
6. pae_prot, the mean over all pairwise protein residues
|
||||||
|
7. pae_inter, the mean over all the errors of protein residues with respect to atom frames and atom coordinates with respect to protein frames. **This was the primary confidence metric we used in the paper and expect cases with pae_inter <10 to have high quality docks.**
|
||||||
|
|
||||||
|
<a id="conclusion"></a>
|
||||||
|
### Conclusion
|
||||||
|
We expect that RFAA will continue to improve and will share new models as we create them. Additionally, we are excited to see how the community uses RFAA and RFdiffusionAA and would love to get feedback and review PRs as necessary.
|
0
__init__.py
Normal file
0
__init__.py
Normal file
2
examples/nucleic_acid/7u7w_B.fasta
Normal file
2
examples/nucleic_acid/7u7w_B.fasta
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
>7U7W_2|Chain B[auth T]|DNA (5'-D(*CP*AP*TP*TP*AP*TP*GP*AP*CP*GP*CP*T)-3')|synthetic construct (32630)
|
||||||
|
CATTATGACGCT
|
2
examples/nucleic_acid/7u7w_C.fasta
Normal file
2
examples/nucleic_acid/7u7w_C.fasta
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
>7U7W_3|Chain C[auth P]|DNA (5'-D(*AP*GP*CP*GP*TP*CP*AP*T)-3')|synthetic construct (32630)
|
||||||
|
AGCGTCAT
|
2
examples/protein/7qxr.fasta
Normal file
2
examples/protein/7qxr.fasta
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
>7QXR_1|Chains A, B, C|Fragment transplantation onto a hyperstable ancestor of haloalkane dehalogenases and Renilla luciferase (Anc-FT)|synthetic construct (32630)
|
||||||
|
TATGDEWWAKCKQVDVLDSEMSYYDSDPGKHKNTVIFLHGNPTSSYLWRNVIPHVEPLARCLAPDLIGMGKSGKLPNHSYRFVDHYRYLSAWFDSVNLPEKVTIVCHDWGSGLGFHWCNEHRDRVKGIVHMESVVDVIESWDEWPDIEEDIALIKSEAGEEMVLKKNFFIERLLPSSIMRKLSEEEMDAYREPFVEPGESRRPTLTWPREIPIKGDGPEDVIEIVKSYNKWLSTSKDIPKLFINADPGFFSNAIKKVTKNWPNQKTVTVKGLHFLQEDSPEEIGEAIADFLNELT
|
2
examples/protein/7s69_A.fasta
Normal file
2
examples/protein/7s69_A.fasta
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
>7S69_1|Chains A, B|N-acetylglucosamine-1-phosphotransferase gamma subunit|Xenopus laevis (8355)
|
||||||
|
DRHHHHHHKLGKMKIVEEPNSFGLNNPFLSQTNKLQPRVQPSPVSGPSHLFRLAGKCFNLVESTYKYELCPFHNVTQHEQTFRWNAYSGILGIWQEWDIENNTFSGMWMREGDSCGNKNRQTKVLLVCGKANKLSSVSEPSTCLYSLTFETPLVCHPHSLLVYPTLSEGLQEKWNEAEQALYDELITEQGHGKILKEIFREAGYLKTTKPDGEGKETQDKPKEFDSLEKCNKGYTELTSEIQRLKKMLNEHGISYVTNGTSRSEGQPAEVNTTFARGEDKVHLRGDTGIRDGQ
|
2
examples/protein/7u7w_A.fasta
Normal file
2
examples/protein/7u7w_A.fasta
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
>7U7W_1|Chain A|DNA polymerase eta|Homo sapiens (9606)
|
||||||
|
GPHMATGQDRVVALVDMDCFFVQVEQRQNPHLRNKPCAVVQYKSWKGGGIIAVSYEARAFGVTRSMWADDAKKLCPDLLLAQVRESRGKANLTKYREASVEVMEIMSRFAVIERASIDEAYVDLTSAVQERLQKLQGQPISADLLPSTYIEGLPQGPTTAEETVQKEGMRKQGLFQWLDSLQIDNLTSPDLQLTVGAVIVEEMRAAIERETGFQCSAGISHNKVLAKLACGLNKPNRQTLVSHGSVPQLFSQMPIRKIRSLGGKLGASVIEILGIEYMGELTQFTESQLQSHFGEKNGSWLYAMCRGIEHDPVKPRQLPKTIGCSKNFPGKTALATREQVQWWLLQLAQELEERLTKDRNDNDRVATQLVVSIRVQGDKRLSSLRRCCALTRYDAHKMSHDAFTVIKNCNTSGIQTEWSPPLTMLFLCATKFSAS
|
155
examples/small_molecule/7s69_glycan.sdf
Normal file
155
examples/small_molecule/7s69_glycan.sdf
Normal file
|
@ -0,0 +1,155 @@
|
||||||
|
|
||||||
|
OpenBabel03042416223D
|
||||||
|
|
||||||
|
72 77 0 0 1 0 0 0 0 0999 V2000
|
||||||
|
29.7340 3.2540 76.7430 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
29.8160 4.4760 77.6460 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
28.5260 5.2840 77.5530 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
28.1780 5.5830 76.1020 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
28.2350 4.3240 75.2420 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
28.1040 4.6170 73.7650 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
31.3020 3.8250 79.4830 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
31.3910 3.4410 80.9280 C 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
30.0760 4.0880 79.0210 N 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
28.6870 6.5050 78.2670 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
26.8490 6.0910 76.0350 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
29.4950 3.6650 75.4130 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
29.3670 4.5550 73.1150 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
32.2950 3.8940 78.7640 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
26.7420 7.4140 75.6950 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
25.2700 7.7830 75.6110 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
25.1290 9.2300 75.1610 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
25.9180 10.1440 76.0880 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
27.3630 9.6720 76.2210 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
28.1310 10.4360 77.2730 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
23.8820 5.8170 75.1400 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
23.1980 5.0100 74.0810 C 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
24.5530 6.8930 74.7160 N 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
23.7530 9.5950 75.1670 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
25.9170 11.4700 75.5730 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
27.4050 8.2900 76.6040 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
29.5300 10.4030 77.0280 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
23.8300 5.5110 76.3290 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
25.3940 12.4250 76.4090 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
25.9490 13.7680 75.9090 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
25.1320 14.9560 76.4900 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
23.6130 14.6900 76.6390 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
23.3700 13.3000 77.2280 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
21.9020 12.9360 77.3500 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
25.9010 13.8490 74.4810 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
25.3420 16.1410 75.7110 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
23.0420 15.6520 77.5170 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
23.9910 12.3690 76.3570 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
21.3660 12.8480 76.0500 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
20.8090 11.6500 75.6780 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
20.6800 11.6410 74.1740 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
19.5510 12.5850 73.8180 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
18.2370 12.0940 74.4540 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
18.4030 11.9240 75.9810 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
17.2710 11.1260 76.6120 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
20.2900 10.3510 73.7080 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
19.4280 12.7380 72.4110 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
17.2120 13.0460 74.2030 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
19.6260 11.2000 76.3010 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
16.0670 11.4490 75.9360 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
20.2190 13.6280 71.7260 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
19.6090 14.0000 70.3810 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
19.6360 12.7820 69.4880 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
21.0860 12.3100 69.3240 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
21.7030 12.0240 70.7120 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
23.1940 11.7460 70.6620 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
20.4080 14.9810 69.7000 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
19.0310 13.0500 68.2340 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
21.1060 11.1280 68.5380 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
21.5380 13.1700 71.5840 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
23.8240 12.5210 71.6820 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
26.0070 17.3020 76.0200 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
27.0750 17.5250 74.9350 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
28.3660 16.8320 75.3290 C 0 0 2 0 0 3 0 0 0 0 0 0
|
||||||
|
28.7820 17.2470 76.7510 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
27.6930 16.8120 77.7320 C 0 0 1 0 0 3 0 0 0 0 0 0
|
||||||
|
27.9770 17.2020 79.1710 C 0 0 0 0 0 2 0 0 0 0 0 0
|
||||||
|
27.3990 18.9140 74.8010 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
29.4060 17.0990 74.3950 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
30.0160 16.6410 77.0930 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
26.4610 17.4820 77.3520 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
27.3660 18.4620 79.4040 O 0 0 0 0 0 1 0 0 0 0 0 0
|
||||||
|
1 2 1 0 0 0 0
|
||||||
|
1 12 1 0 0 0 0
|
||||||
|
2 3 1 0 0 0 0
|
||||||
|
2 9 1 1 0 0 0
|
||||||
|
3 10 1 1 0 0 0
|
||||||
|
3 4 1 0 0 0 0
|
||||||
|
4 5 1 0 0 0 0
|
||||||
|
4 11 1 1 0 0 0
|
||||||
|
5 6 1 6 0 0 0
|
||||||
|
5 12 1 0 0 0 0
|
||||||
|
6 13 1 0 0 0 0
|
||||||
|
7 14 2 0 0 0 0
|
||||||
|
7 8 1 0 0 0 0
|
||||||
|
7 9 1 0 0 0 0
|
||||||
|
15 16 1 0 0 0 0
|
||||||
|
15 11 1 1 0 0 0
|
||||||
|
15 26 1 0 0 0 0
|
||||||
|
16 23 1 6 0 0 0
|
||||||
|
16 17 1 0 0 0 0
|
||||||
|
17 18 1 0 0 0 0
|
||||||
|
17 24 1 1 0 0 0
|
||||||
|
18 25 1 6 0 0 0
|
||||||
|
18 19 1 0 0 0 0
|
||||||
|
19 20 1 1 0 0 0
|
||||||
|
19 26 1 0 0 0 0
|
||||||
|
20 27 1 0 0 0 0
|
||||||
|
21 22 1 0 0 0 0
|
||||||
|
21 23 1 0 0 0 0
|
||||||
|
21 28 2 0 0 0 0
|
||||||
|
29 38 1 0 0 0 0
|
||||||
|
29 25 1 6 0 0 0
|
||||||
|
29 30 1 0 0 0 0
|
||||||
|
30 35 1 6 0 0 0
|
||||||
|
30 31 1 0 0 0 0
|
||||||
|
31 32 1 0 0 0 0
|
||||||
|
31 36 1 6 0 0 0
|
||||||
|
32 33 1 0 0 0 0
|
||||||
|
32 37 1 1 0 0 0
|
||||||
|
33 38 1 0 0 0 0
|
||||||
|
33 34 1 6 0 0 0
|
||||||
|
34 39 1 0 0 0 0
|
||||||
|
40 49 1 0 0 0 0
|
||||||
|
40 41 1 0 0 0 0
|
||||||
|
40 39 1 1 0 0 0
|
||||||
|
41 46 1 1 0 0 0
|
||||||
|
41 42 1 0 0 0 0
|
||||||
|
42 43 1 0 0 0 0
|
||||||
|
42 47 1 6 0 0 0
|
||||||
|
43 48 1 1 0 0 0
|
||||||
|
43 44 1 0 0 0 0
|
||||||
|
44 49 1 0 0 0 0
|
||||||
|
44 45 1 6 0 0 0
|
||||||
|
45 50 1 0 0 0 0
|
||||||
|
51 47 1 6 0 0 0
|
||||||
|
51 60 1 0 0 0 0
|
||||||
|
51 52 1 0 0 0 0
|
||||||
|
52 53 1 0 0 0 0
|
||||||
|
52 57 1 6 0 0 0
|
||||||
|
53 54 1 0 0 0 0
|
||||||
|
53 58 1 6 0 0 0
|
||||||
|
54 59 1 6 0 0 0
|
||||||
|
54 55 1 0 0 0 0
|
||||||
|
55 56 1 6 0 0 0
|
||||||
|
55 60 1 0 0 0 0
|
||||||
|
56 61 1 0 0 0 0
|
||||||
|
62 71 1 0 0 0 0
|
||||||
|
62 36 1 1 0 0 0
|
||||||
|
62 63 1 0 0 0 0
|
||||||
|
63 68 1 1 0 0 0
|
||||||
|
63 64 1 0 0 0 0
|
||||||
|
64 69 1 6 0 0 0
|
||||||
|
64 65 1 0 0 0 0
|
||||||
|
65 70 1 1 0 0 0
|
||||||
|
65 66 1 0 0 0 0
|
||||||
|
66 67 1 1 0 0 0
|
||||||
|
66 71 1 0 0 0 0
|
||||||
|
67 72 1 0 0 0 0
|
||||||
|
M END
|
||||||
|
$$$$
|
129
examples/small_molecule/NSW_ideal.sdf
Normal file
129
examples/small_molecule/NSW_ideal.sdf
Normal file
|
@ -0,0 +1,129 @@
|
||||||
|
NSW
|
||||||
|
-OEChem-02232415193D
|
||||||
|
|
||||||
|
53 57 0 0 0 0 0 0 0999 V2000
|
||||||
|
5.9220 -0.3020 1.2970 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.4470 1.1670 -0.1420 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
0.2000 0.5990 0.0490 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.8390 2.2820 1.6000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.4190 4.1380 -0.0250 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-3.0970 4.8230 -0.0410 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.6580 3.7710 0.7410 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-3.2910 -1.6780 0.1420 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-5.5360 -2.0330 0.9170 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.3400 -0.7150 -1.3020 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.5200 0.5160 -1.4970 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.4840 -0.2350 -0.6170 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.8140 -1.5440 -0.9150 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
5.6890 -2.2400 -0.1020 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.2500 -1.6170 1.0030 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
5.0380 0.3850 0.4880 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.9010 1.0040 0.8040 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-1.3180 3.4280 0.7480 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.8580 5.1930 -0.8030 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.1970 5.5360 -0.8100 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.0560 -0.8580 0.1180 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.3880 -1.2680 0.8980 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-5.5990 -3.2100 0.1850 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.5100 -3.6190 -0.5740 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-3.3600 -2.8590 -0.5970 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.9650 -1.2700 -0.6360 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
2.1540 0.3180 -1.0070 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-1.9960 0.2710 0.8210 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
0.1460 -0.5340 -0.6580 N 0 3 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.6180 -1.6500 -2.0280 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
7.1200 -2.2960 1.7980 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.7310 -3.9620 0.2060 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.3530 0.1830 2.1600 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.7700 1.9970 0.2440 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-1.4750 2.1920 2.4810 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
0.1890 2.4660 1.9120 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
0.6270 3.8710 -0.0190 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.1440 5.0900 -0.0470 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-3.3610 3.2140 1.3420 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.3860 -1.7150 1.5030 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.7620 1.5790 -1.4770 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.5960 0.1460 -2.5200 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.3810 -2.0270 -1.7780 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
5.9460 -3.2620 -0.3350 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.7780 1.4080 0.7180 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.1550 5.7500 -1.4040 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.5410 6.3580 -1.4210 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.3390 -0.3520 1.4680 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.5640 -4.5340 -1.1460 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.5130 -3.1770 -1.1870 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-1.0160 -2.1870 -1.2050 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.6980 -2.7710 2.5270 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-7.3690 -3.7310 -0.4820 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6 7 2 0 0 0 0
|
||||||
|
6 20 1 0 0 0 0
|
||||||
|
7 18 1 0 0 0 0
|
||||||
|
19 20 2 0 0 0 0
|
||||||
|
24 25 2 0 0 0 0
|
||||||
|
23 24 1 0 0 0 0
|
||||||
|
8 25 1 0 0 0 0
|
||||||
|
23 32 1 0 0 0 0
|
||||||
|
4 18 1 0 0 0 0
|
||||||
|
5 18 2 0 0 0 0
|
||||||
|
9 23 2 0 0 0 0
|
||||||
|
15 31 1 0 0 0 0
|
||||||
|
4 17 1 0 0 0 0
|
||||||
|
5 19 1 0 0 0 0
|
||||||
|
17 28 2 0 0 0 0
|
||||||
|
21 28 1 0 0 0 0
|
||||||
|
8 21 1 0 0 0 0
|
||||||
|
8 22 2 0 0 0 0
|
||||||
|
3 17 1 0 0 0 0
|
||||||
|
1 15 2 0 0 0 0
|
||||||
|
14 15 1 0 0 0 0
|
||||||
|
21 26 2 0 0 0 0
|
||||||
|
9 22 1 0 0 0 0
|
||||||
|
1 16 1 0 0 0 0
|
||||||
|
13 14 2 0 0 0 0
|
||||||
|
2 3 1 0 0 0 0
|
||||||
|
3 29 2 0 0 0 0
|
||||||
|
26 29 1 0 0 0 0
|
||||||
|
2 27 1 0 0 0 0
|
||||||
|
12 16 2 0 0 0 0
|
||||||
|
10 29 1 0 0 0 0
|
||||||
|
12 13 1 0 0 0 0
|
||||||
|
11 12 1 0 0 0 0
|
||||||
|
10 27 1 0 0 0 0
|
||||||
|
11 27 1 0 0 0 0
|
||||||
|
10 30 2 0 0 0 0
|
||||||
|
1 33 1 0 0 0 0
|
||||||
|
2 34 1 0 0 0 0
|
||||||
|
4 35 1 0 0 0 0
|
||||||
|
4 36 1 0 0 0 0
|
||||||
|
5 37 1 0 0 0 0
|
||||||
|
6 38 1 0 0 0 0
|
||||||
|
7 39 1 0 0 0 0
|
||||||
|
9 40 1 0 0 0 0
|
||||||
|
11 41 1 0 0 0 0
|
||||||
|
11 42 1 0 0 0 0
|
||||||
|
13 43 1 0 0 0 0
|
||||||
|
14 44 1 0 0 0 0
|
||||||
|
16 45 1 0 0 0 0
|
||||||
|
19 46 1 0 0 0 0
|
||||||
|
20 47 1 0 0 0 0
|
||||||
|
22 48 1 0 0 0 0
|
||||||
|
24 49 1 0 0 0 0
|
||||||
|
25 50 1 0 0 0 0
|
||||||
|
26 51 1 0 0 0 0
|
||||||
|
31 52 1 0 0 0 0
|
||||||
|
32 53 1 0 0 0 0
|
||||||
|
M CHG 1 29 1
|
||||||
|
M END
|
||||||
|
> <OPENEYE_ISO_SMILES>
|
||||||
|
c1ccc(cc1)Cc2c3[nH]n(c(=O)[n+]3cc(n2)c4ccc(cc4)O)Cc5ccc(cc5)O
|
||||||
|
|
||||||
|
> <OPENEYE_INCHI>
|
||||||
|
InChI=1S/C25H20N4O3/c30-20-10-6-18(7-11-20)15-29-25(32)28-16-23(19-8-12-21(31)13-9-19)26-22(24(28)27-29)14-17-4-2-1-3-5-17/h1-13,16,30-31H,14-15H2/p+1
|
||||||
|
|
||||||
|
> <OPENEYE_INCHIKEY>
|
||||||
|
FAFQJYQLYIUGAT-UHFFFAOYSA-O
|
||||||
|
|
||||||
|
> <FORMULA>
|
||||||
|
C25H21N4O3+
|
||||||
|
|
||||||
|
$$$$
|
116
examples/small_molecule/XG4.sdf
Normal file
116
examples/small_molecule/XG4.sdf
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
XG4
|
||||||
|
-OEChem-02232415213D
|
||||||
|
|
||||||
|
48 50 0 1 0 0 0 0 0999 V2000
|
||||||
|
7.5710 2.0160 -0.3960 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
7.4820 0.6990 -0.7370 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
8.5980 0.0580 -1.2130 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.3590 0.0260 -0.6220 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
5.2470 0.6100 -0.1630 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
5.2730 1.9610 0.2040 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.4880 2.6730 0.0750 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
6.5580 3.8500 0.3880 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.0270 2.2870 0.6270 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.2570 1.2420 0.5440 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.9670 0.1810 0.0620 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.3540 -1.9400 -0.6300 P 0 0 1 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.4150 -0.1240 -0.2370 P 0 0 1 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.1130 2.2430 0.3070 P 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.4520 -1.1710 -0.1750 C 0 0 1 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.9530 -2.8530 0.3690 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.9590 -0.4340 -1.5780 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-7.0830 1.4520 1.0960 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.7430 -2.0630 1.0480 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.9280 -2.2930 -2.0920 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-5.0840 -1.1020 0.8530 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-5.8010 3.6270 1.0690 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
2.3450 -2.5490 1.5040 C 0 0 2 0 0 0 0 0 0 0 0 0
|
||||||
|
2.3820 -3.9270 1.8790 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.7470 -0.3500 -0.2430 N 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.7480 1.4050 0.1380 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.7320 2.5590 -1.1450 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.4970 -2.3470 0.2220 C 0 0 2 0 0 0 0 0 0 0 0 0
|
||||||
|
2.0250 -1.1280 -0.3430 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
0.0180 -2.1810 0.5750 C 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.7540 -2.1140 -0.6250 O 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
8.4160 2.4830 -0.4890 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
9.4340 0.5410 -1.3030 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
8.5510 -0.8790 -1.4600 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
2.2120 1.2170 0.8170 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
3.9200 -1.5950 -1.0630 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.3670 -2.9100 0.7620 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
4.2220 -1.4840 1.8370 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.5780 -1.7320 -2.7980 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-4.7730 -0.9550 1.7570 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.5800 4.1840 1.2060 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.9680 -1.9350 2.3210 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
2.9720 -4.1170 2.6220 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-2.3430 -0.0830 0.6430 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-6.1490 3.0750 -1.7180 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1.6350 -3.1820 -0.4650 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.3110 -3.0320 1.1710 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
-0.1180 -1.2630 1.1470 H 0 0 0 0 0 0 0 0 0 0 0 0
|
||||||
|
1 2 1 0 0 0 0
|
||||||
|
1 7 1 0 0 0 0
|
||||||
|
1 32 1 0 0 0 0
|
||||||
|
2 4 2 0 0 0 0
|
||||||
|
2 3 1 0 0 0 0
|
||||||
|
3 33 1 0 0 0 0
|
||||||
|
3 34 1 0 0 0 0
|
||||||
|
4 5 1 0 0 0 0
|
||||||
|
5 11 1 0 0 0 0
|
||||||
|
5 6 2 0 0 0 0
|
||||||
|
6 9 1 0 0 0 0
|
||||||
|
6 7 1 0 0 0 0
|
||||||
|
7 8 2 0 0 0 0
|
||||||
|
9 10 2 0 0 0 0
|
||||||
|
10 11 1 0 0 0 0
|
||||||
|
10 35 1 0 0 0 0
|
||||||
|
11 15 1 0 0 0 0
|
||||||
|
12 31 1 0 0 0 0
|
||||||
|
12 16 2 0 0 0 0
|
||||||
|
12 20 1 0 0 0 0
|
||||||
|
12 25 1 0 0 0 0
|
||||||
|
13 17 2 0 0 0 0
|
||||||
|
13 25 1 0 0 0 0
|
||||||
|
13 21 1 0 0 0 0
|
||||||
|
13 26 1 0 0 0 0
|
||||||
|
14 18 2 0 0 0 0
|
||||||
|
14 26 1 0 0 0 0
|
||||||
|
14 22 1 0 0 0 0
|
||||||
|
14 27 1 0 0 0 0
|
||||||
|
15 29 1 0 0 0 0
|
||||||
|
15 19 1 0 0 0 0
|
||||||
|
15 36 1 0 0 0 0
|
||||||
|
19 23 1 0 0 0 0
|
||||||
|
19 37 1 0 0 0 0
|
||||||
|
19 38 1 0 0 0 0
|
||||||
|
20 39 1 0 0 0 0
|
||||||
|
21 40 1 0 0 0 0
|
||||||
|
22 41 1 0 0 0 0
|
||||||
|
23 28 1 0 0 0 0
|
||||||
|
23 24 1 0 0 0 0
|
||||||
|
23 42 1 0 0 0 0
|
||||||
|
24 43 1 0 0 0 0
|
||||||
|
25 44 1 0 0 0 0
|
||||||
|
27 45 1 0 0 0 0
|
||||||
|
28 29 1 0 0 0 0
|
||||||
|
28 30 1 0 0 0 0
|
||||||
|
28 46 1 0 0 0 0
|
||||||
|
30 31 1 0 0 0 0
|
||||||
|
30 47 1 0 0 0 0
|
||||||
|
30 48 1 0 0 0 0
|
||||||
|
M END
|
||||||
|
> <OPENEYE_ISO_SMILES>
|
||||||
|
c1nc2c(=O)[nH]c(nc2n1[C@H]3C[C@@H]([C@H](O3)CO[P@@](=O)(N[P@@](=O)(O)OP(=O)(O)O)O)O)N
|
||||||
|
|
||||||
|
> <OPENEYE_INCHI>
|
||||||
|
InChI=1S/C10H17N6O12P3/c11-10-13-8-7(9(18)14-10)12-3-16(8)6-1-4(17)5(27-6)2-26-29(19,20)15-30(21,22)28-31(23,24)25/h3-6,17H,1-2H2,(H2,23,24,25)(H3,11,13,14,18)(H3,15,19,20,21,22)/t4-,5+,6+/m0/s1
|
||||||
|
|
||||||
|
> <OPENEYE_INCHIKEY>
|
||||||
|
DWGAAFQEGIMTIA-KVQBGUIXSA-N
|
||||||
|
|
||||||
|
> <FORMULA>
|
||||||
|
C10H17N6O12P3
|
||||||
|
|
||||||
|
$$$$
|
BIN
img/RFAA.png
Normal file
BIN
img/RFAA.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 722 KiB |
121
make_msa.sh
Executable file
121
make_msa.sh
Executable file
|
@ -0,0 +1,121 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# inputs
|
||||||
|
in_fasta="$1"
|
||||||
|
out_dir="$2"
|
||||||
|
|
||||||
|
# resources
|
||||||
|
CPU="$3"
|
||||||
|
MEM="$4"
|
||||||
|
|
||||||
|
# pipe_dir
|
||||||
|
PIPE_DIR="$5"
|
||||||
|
DB_TEMPL="$6"
|
||||||
|
|
||||||
|
# sequence databases
|
||||||
|
DB_UR30="$PIPE_DIR/uniclust/UniRef30_2021_06"
|
||||||
|
DB_BFD="$PIPE_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt"
|
||||||
|
|
||||||
|
# Running signalP 6.0
|
||||||
|
mkdir -p $out_dir/signalp
|
||||||
|
tmp_dir="$out_dir/signalp"
|
||||||
|
signalp6 --fastafile $in_fasta --organism other --output_dir $tmp_dir --format none --mode slow
|
||||||
|
trim_fasta="$tmp_dir/processed_entries.fasta"
|
||||||
|
if [ ! -s $trim_fasta ] # empty file -- no signal P
|
||||||
|
then
|
||||||
|
trim_fasta="$in_fasta"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# setup hhblits command
|
||||||
|
export HHLIB=/software/hhsuite/build/bin/
|
||||||
|
export PATH=$HHLIB:$PATH
|
||||||
|
HHBLITS_UR30="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_UR30"
|
||||||
|
HHBLITS_BFD="hhblits -o /dev/null -mact 0.35 -maxfilt 100000000 -neffmax 20 -cov 25 -cpu $CPU -nodiff -realign_max 100000000 -maxseq 1000000 -maxmem $MEM -n 4 -d $DB_BFD"
|
||||||
|
|
||||||
|
mkdir -p $out_dir/hhblits
|
||||||
|
tmp_dir="$out_dir/hhblits"
|
||||||
|
out_prefix="$out_dir/t000_"
|
||||||
|
|
||||||
|
# perform iterative searches against UniRef30
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
prev_a3m="$trim_fasta"
|
||||||
|
for e in 1e-10 1e-6 1e-3
|
||||||
|
do
|
||||||
|
echo "Running HHblits against UniRef30 with E-value cutoff $e"
|
||||||
|
if [ ! -s $tmp_dir/t000_.$e.a3m ]
|
||||||
|
then
|
||||||
|
$HHBLITS_UR30 -i $prev_a3m -oa3m $tmp_dir/t000_.$e.a3m -e $e -v 0
|
||||||
|
fi
|
||||||
|
hhfilter -maxseq 100000 -id 90 -cov 75 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov75.a3m
|
||||||
|
hhfilter -maxseq 100000 -id 90 -cov 50 -i $tmp_dir/t000_.$e.a3m -o $tmp_dir/t000_.$e.id90cov50.a3m
|
||||||
|
prev_a3m="$tmp_dir/t000_.$e.id90cov50.a3m"
|
||||||
|
n75=`grep -c "^>" $tmp_dir/t000_.$e.id90cov75.a3m`
|
||||||
|
n50=`grep -c "^>" $tmp_dir/t000_.$e.id90cov50.a3m`
|
||||||
|
|
||||||
|
if ((n75>2000))
|
||||||
|
then
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
cp $tmp_dir/t000_.$e.id90cov75.a3m ${out_prefix}.msa0.a3m
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
elif ((n50>4000))
|
||||||
|
then
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
cp $tmp_dir/t000_.$e.id90cov50.a3m ${out_prefix}.msa0.a3m
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# perform iterative searches against BFD if it failes to get enough sequences
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
e=1e-3
|
||||||
|
echo "Running HHblits against BFD with E-value cutoff $e"
|
||||||
|
if [ ! -s $tmp_dir/t000_.$e.bfd.a3m ]
|
||||||
|
then
|
||||||
|
$HHBLITS_BFD -i $prev_a3m -oa3m $tmp_dir/t000_.$e.bfd.a3m -e $e -v 0
|
||||||
|
fi
|
||||||
|
hhfilter -maxseq 100000 -id 90 -cov 75 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov75.a3m
|
||||||
|
hhfilter -maxseq 100000 -id 90 -cov 50 -i $tmp_dir/t000_.$e.bfd.a3m -o $tmp_dir/t000_.$e.bfd.id90cov50.a3m
|
||||||
|
prev_a3m="$tmp_dir/t000_.$e.bfd.id90cov50.a3m"
|
||||||
|
n75=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov75.a3m`
|
||||||
|
n50=`grep -c "^>" $tmp_dir/t000_.$e.bfd.id90cov50.a3m`
|
||||||
|
|
||||||
|
if ((n75>2000))
|
||||||
|
then
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
cp $tmp_dir/t000_.$e.bfd.id90cov75.a3m ${out_prefix}.msa0.a3m
|
||||||
|
fi
|
||||||
|
elif ((n50>4000))
|
||||||
|
then
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
cp $tmp_dir/t000_.$e.bfd.id90cov50.a3m ${out_prefix}.msa0.a3m
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -s ${out_prefix}.msa0.a3m ]
|
||||||
|
then
|
||||||
|
cp $prev_a3m ${out_prefix}.msa0.a3m
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "Running PSIPRED"
|
||||||
|
$PIPE_DIR/input_prep/make_ss.sh $out_dir/t000_.msa0.a3m $out_dir/t000_.ss2 > $out_dir/log/make_ss.stdout 2> $out_dir/log/make_ss.stderr
|
||||||
|
|
||||||
|
if [ ! -s $out_dir/t000_.hhr ]
|
||||||
|
then
|
||||||
|
echo "Running hhsearch"
|
||||||
|
HH="hhsearch -b 50 -B 500 -z 50 -Z 500 -mact 0.05 -cpu $CPU -maxmem $MEM -aliw 100000 -e 100 -p 5.0 -d $DB_TEMPL"
|
||||||
|
|
||||||
|
cat $out_dir/t000_.ss2 $out_dir/t000_.msa0.a3m > $out_dir/t000_.msa0.ss2.a3m
|
||||||
|
$HH -i $out_dir/t000_.msa0.ss2.a3m -o $out_dir/t000_.hhr -atab $out_dir/t000_.atab -v 0
|
||||||
|
fi
|
123
rf2aa/SE3Transformer/.dockerignore
Normal file
123
rf2aa/SE3Transformer/.dockerignore
Normal file
|
@ -0,0 +1,123 @@
|
||||||
|
.Trash-0
|
||||||
|
.git
|
||||||
|
data/
|
||||||
|
.DS_Store
|
||||||
|
*wandb/
|
||||||
|
*.pt
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# added by FAFU
|
||||||
|
.idea/
|
||||||
|
cache/
|
||||||
|
downloaded/
|
||||||
|
*.lprof
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
**/benchmark
|
||||||
|
**/results
|
||||||
|
*.pkl
|
||||||
|
*.log
|
121
rf2aa/SE3Transformer/.gitignore
vendored
Normal file
121
rf2aa/SE3Transformer/.gitignore
vendored
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
data/
|
||||||
|
.DS_Store
|
||||||
|
*wandb/
|
||||||
|
*.pt
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# added by FAFU
|
||||||
|
.idea/
|
||||||
|
cache/
|
||||||
|
downloaded/
|
||||||
|
*.lprof
|
||||||
|
|
||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
|
||||||
|
**/benchmark
|
||||||
|
**/results
|
||||||
|
*.pkl
|
||||||
|
*.log
|
58
rf2aa/SE3Transformer/Dockerfile
Normal file
58
rf2aa/SE3Transformer/Dockerfile
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
# run docker daemon with --default-runtime=nvidia for GPU detection during build
|
||||||
|
# multistage build for DGL with CUDA and FP16
|
||||||
|
|
||||||
|
ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.07-py3
|
||||||
|
|
||||||
|
FROM ${FROM_IMAGE_NAME} AS dgl_builder
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y git build-essential python3-dev make cmake \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
WORKDIR /dgl
|
||||||
|
RUN git clone --branch v0.7.0 --recurse-submodules --depth 1 https://github.com/dmlc/dgl.git .
|
||||||
|
RUN sed -i 's/"35 50 60 70"/"60 70 80"/g' cmake/modules/CUDA.cmake
|
||||||
|
WORKDIR build
|
||||||
|
RUN cmake -DUSE_CUDA=ON -DUSE_FP16=ON ..
|
||||||
|
RUN make -j8
|
||||||
|
|
||||||
|
|
||||||
|
FROM ${FROM_IMAGE_NAME}
|
||||||
|
|
||||||
|
RUN rm -rf /workspace/*
|
||||||
|
WORKDIR /workspace/se3-transformer
|
||||||
|
|
||||||
|
# copy built DGL and install it
|
||||||
|
COPY --from=dgl_builder /dgl ./dgl
|
||||||
|
RUN cd dgl/python && python setup.py install && cd ../.. && rm -rf dgl
|
||||||
|
|
||||||
|
ADD requirements.txt .
|
||||||
|
RUN pip install --no-cache-dir --upgrade --pre pip
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
ADD . .
|
||||||
|
|
||||||
|
ENV DGLBACKEND=pytorch
|
||||||
|
ENV OMP_NUM_THREADS=1
|
7
rf2aa/SE3Transformer/LICENSE
Normal file
7
rf2aa/SE3Transformer/LICENSE
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
Copyright 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
7
rf2aa/SE3Transformer/NOTICE
Normal file
7
rf2aa/SE3Transformer/NOTICE
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
SE(3)-Transformer PyTorch
|
||||||
|
|
||||||
|
This repository includes software from https://github.com/FabianFuchsML/se3-transformer-public
|
||||||
|
licensed under the MIT License.
|
||||||
|
|
||||||
|
This repository includes software from https://github.com/lucidrains/se3-transformer-pytorch
|
||||||
|
licensed under the MIT License.
|
580
rf2aa/SE3Transformer/README.md
Normal file
580
rf2aa/SE3Transformer/README.md
Normal file
|
@ -0,0 +1,580 @@
|
||||||
|
# SE(3)-Transformers For PyTorch
|
||||||
|
|
||||||
|
This repository provides a script and recipe to train the SE(3)-Transformer model to achieve state-of-the-art accuracy. The content of this repository is tested and maintained by NVIDIA.
|
||||||
|
|
||||||
|
## Table Of Contents
|
||||||
|
- [Model overview](#model-overview)
|
||||||
|
* [Model architecture](#model-architecture)
|
||||||
|
* [Default configuration](#default-configuration)
|
||||||
|
* [Feature support matrix](#feature-support-matrix)
|
||||||
|
* [Features](#features)
|
||||||
|
* [Mixed precision training](#mixed-precision-training)
|
||||||
|
* [Enabling mixed precision](#enabling-mixed-precision)
|
||||||
|
* [Enabling TF32](#enabling-tf32)
|
||||||
|
* [Glossary](#glossary)
|
||||||
|
- [Setup](#setup)
|
||||||
|
* [Requirements](#requirements)
|
||||||
|
- [Quick Start Guide](#quick-start-guide)
|
||||||
|
- [Advanced](#advanced)
|
||||||
|
* [Scripts and sample code](#scripts-and-sample-code)
|
||||||
|
* [Parameters](#parameters)
|
||||||
|
* [Command-line options](#command-line-options)
|
||||||
|
* [Getting the data](#getting-the-data)
|
||||||
|
* [Dataset guidelines](#dataset-guidelines)
|
||||||
|
* [Multi-dataset](#multi-dataset)
|
||||||
|
* [Training process](#training-process)
|
||||||
|
* [Inference process](#inference-process)
|
||||||
|
- [Performance](#performance)
|
||||||
|
* [Benchmarking](#benchmarking)
|
||||||
|
* [Training performance benchmark](#training-performance-benchmark)
|
||||||
|
* [Inference performance benchmark](#inference-performance-benchmark)
|
||||||
|
* [Results](#results)
|
||||||
|
* [Training accuracy results](#training-accuracy-results)
|
||||||
|
* [Training accuracy: NVIDIA DGX A100 (8x A100 80GB)](#training-accuracy-nvidia-dgx-a100-8x-a100-80gb)
|
||||||
|
* [Training accuracy: NVIDIA DGX-1 (8x V100 16GB)](#training-accuracy-nvidia-dgx-1-8x-v100-16gb)
|
||||||
|
* [Training stability test](#training-stability-test)
|
||||||
|
* [Training performance results](#training-performance-results)
|
||||||
|
* [Training performance: NVIDIA DGX A100 (8x A100 80GB)](#training-performance-nvidia-dgx-a100-8x-a100-80gb)
|
||||||
|
* [Training performance: NVIDIA DGX-1 (8x V100 16GB)](#training-performance-nvidia-dgx-1-8x-v100-16gb)
|
||||||
|
* [Inference performance results](#inference-performance-results)
|
||||||
|
* [Inference performance: NVIDIA DGX A100 (1x A100 80GB)](#inference-performance-nvidia-dgx-a100-1x-a100-80gb)
|
||||||
|
* [Inference performance: NVIDIA DGX-1 (1x V100 16GB)](#inference-performance-nvidia-dgx-1-1x-v100-16gb)
|
||||||
|
- [Release notes](#release-notes)
|
||||||
|
* [Changelog](#changelog)
|
||||||
|
* [Known issues](#known-issues)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Model overview
|
||||||
|
|
||||||
|
|
||||||
|
The **SE(3)-Transformer** is a Graph Neural Network using a variant of [self-attention](https://arxiv.org/abs/1706.03762v5) for 3D points and graphs processing.
|
||||||
|
This model is [equivariant](https://en.wikipedia.org/wiki/Equivariant_map) under [continuous 3D roto-translations](https://en.wikipedia.org/wiki/Euclidean_group), meaning that when the inputs (graphs or sets of points) rotate in 3D space (or more generally experience a [proper rigid transformation](https://en.wikipedia.org/wiki/Rigid_transformation)), the model outputs either stay invariant or transform with the input.
|
||||||
|
A mathematical guarantee of equivariance is important to ensure stable and predictable performance in the presence of nuisance transformations of the data input and when the problem has some inherent symmetries we want to exploit.
|
||||||
|
|
||||||
|
|
||||||
|
The model is based on the following publications:
|
||||||
|
- [SE(3)-Transformers: 3D Roto-Translation Equivariant Attention Networks](https://arxiv.org/abs/2006.10503) (NeurIPS 2020) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
||||||
|
- [Tensor field networks: Rotation- and translation-equivariant neural networks for 3D point clouds](https://arxiv.org/abs/1802.08219) by Nathaniel Thomas, Tess Smidt, et al.
|
||||||
|
|
||||||
|
A follow-up paper explains how this model can be used iteratively, for example, to predict or refine protein structures:
|
||||||
|
|
||||||
|
- [Iterative SE(3)-Transformers](https://arxiv.org/abs/2102.13419) by Fabian B. Fuchs, Daniel E. Worrall, et al.
|
||||||
|
|
||||||
|
Just like [the official implementation](https://github.com/FabianFuchsML/se3-transformer-public), this implementation uses [PyTorch](https://pytorch.org/) and the [Deep Graph Library (DGL)](https://www.dgl.ai/).
|
||||||
|
|
||||||
|
The main differences between this implementation of SE(3)-Transformers and the official one are the following:
|
||||||
|
|
||||||
|
- Training and inference support for multiple GPUs
|
||||||
|
- Training and inference support for [Mixed Precision](https://arxiv.org/abs/1710.03740)
|
||||||
|
- The [QM9 dataset from DGL](https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset) is used and automatically downloaded
|
||||||
|
- Significantly increased throughput
|
||||||
|
- Significantly reduced memory consumption
|
||||||
|
- The use of layer normalization in the fully connected radial profile layers is an option (`--use_layer_norm`), off by default
|
||||||
|
- The use of equivariant normalization between attention layers is an option (`--norm`), off by default
|
||||||
|
- The [spherical harmonics](https://en.wikipedia.org/wiki/Spherical_harmonic) and [Clebsch–Gordan coefficients](https://en.wikipedia.org/wiki/Clebsch%E2%80%93Gordan_coefficients), used to compute bases matrices, are computed with the [e3nn library](https://e3nn.org/)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
This model enables you to predict quantum chemical properties of small organic molecules in the [QM9 dataset](https://www.nature.com/articles/sdata201422).
|
||||||
|
In this case, the exploited symmetry is that these properties do not depend on the orientation or position of the molecules in space.
|
||||||
|
|
||||||
|
|
||||||
|
This model is trained with mixed precision using Tensor Cores on NVIDIA Volta, NVIDIA Turing, and the NVIDIA Ampere GPU architectures. Therefore, researchers can get results up to 1.5x faster than training without Tensor Cores while experiencing the benefits of mixed precision training. This model is tested against each NGC monthly container release to ensure consistent accuracy and performance over time.
|
||||||
|
|
||||||
|
### Model architecture
|
||||||
|
|
||||||
|
The model consists of stacked layers of equivariant graph self-attention and equivariant normalization.
|
||||||
|
Lastly, a Tensor Field Network convolution is applied to obtain invariant features. Graph pooling (mean or max over the nodes) is applied to these features, and the result is fed to a final MLP to get scalar predictions.
|
||||||
|
|
||||||
|
In this setup, the model is a graph-to-scalar network. The pooling can be removed to obtain a graph-to-graph network, and the final TFN can be modified to output features of any type (invariant scalars, 3D vectors, ...).
|
||||||
|
|
||||||
|
|
||||||
|
![Model high-level architecture](./images/se3-transformer.png)
|
||||||
|
|
||||||
|
|
||||||
|
### Default configuration
|
||||||
|
|
||||||
|
|
||||||
|
SE(3)-Transformers introduce a self-attention layer for graphs that is equivariant to 3D roto-translations. It achieves this by leveraging Tensor Field Networks to build attention weights that are invariant and attention values that are equivariant.
|
||||||
|
Combining the equivariant values with the invariant weights gives rise to an equivariant output. This output is normalized while preserving equivariance thanks to equivariant normalization layers operating on feature norms.
|
||||||
|
|
||||||
|
|
||||||
|
The following features were implemented in this model:
|
||||||
|
|
||||||
|
- Support for edge features of any degree (1D, 3D, 5D, ...), whereas the official implementation only supports scalar invariant edge features (degree 0). Edge features with a degree greater than one are
|
||||||
|
concatenated to node features of the same degree. This is required in order to reproduce published results on point cloud processing.
|
||||||
|
- Data-parallel multi-GPU training (DDP)
|
||||||
|
- Mixed precision training (autocast, gradient scaling)
|
||||||
|
- Gradient accumulation
|
||||||
|
- Model checkpointing
|
||||||
|
|
||||||
|
|
||||||
|
The following performance optimizations were implemented in this model:
|
||||||
|
|
||||||
|
|
||||||
|
**General optimizations**
|
||||||
|
|
||||||
|
- The option is provided to precompute bases at the beginning of the training instead of computing them at the beginning of each forward pass (`--precompute_bases`)
|
||||||
|
- The bases computation is just-in-time (JIT) compiled with `torch.jit.script`
|
||||||
|
- The Clebsch-Gordon coefficients are cached in RAM
|
||||||
|
|
||||||
|
|
||||||
|
**Tensor Field Network optimizations**
|
||||||
|
|
||||||
|
- The last layer of each radial profile network does not add any bias in order to avoid large broadcasting operations
|
||||||
|
- The layout (order of dimensions) of the bases tensors is optimized to avoid copies to contiguous memory in the downstream TFN layers
|
||||||
|
- When Tensor Cores are available, and the output feature dimension of computed bases is odd, then it is padded with zeros to make more effective use of Tensor Cores (AMP and TF32 precisions)
|
||||||
|
- Multiple levels of fusion for TFN convolutions (and radial profiles) are provided and automatically used when conditions are met
|
||||||
|
- A low-memory mode is provided that will trade throughput for less memory use (`--low_memory`)
|
||||||
|
|
||||||
|
**Self-attention optimizations**
|
||||||
|
|
||||||
|
- Attention keys and values are computed by a single partial TFN graph convolution in each attention layer instead of two
|
||||||
|
- Graph operations for different output degrees may be fused together if conditions are met
|
||||||
|
|
||||||
|
|
||||||
|
**Normalization optimizations**
|
||||||
|
|
||||||
|
- The equivariant normalization layer is optimized from multiple layer normalizations to a group normalization on fused norms when certain conditions are met
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Competitive training results and analysis are provided for the following hyperparameters (identical to the ones in the original publication):
|
||||||
|
- Number of layers: 7
|
||||||
|
- Number of degrees: 4
|
||||||
|
- Number of channels: 32
|
||||||
|
- Number of attention heads: 8
|
||||||
|
- Channels division: 2
|
||||||
|
- Use of equivariant normalization: true
|
||||||
|
- Use of layer normalization: true
|
||||||
|
- Pooling: max
|
||||||
|
|
||||||
|
|
||||||
|
### Feature support matrix
|
||||||
|
|
||||||
|
This model supports the following features::
|
||||||
|
|
||||||
|
| Feature | SE(3)-Transformer
|
||||||
|
|-----------------------|--------------------------
|
||||||
|
|Automatic mixed precision (AMP) | Yes
|
||||||
|
|Distributed data parallel (DDP) | Yes
|
||||||
|
|
||||||
|
#### Features
|
||||||
|
|
||||||
|
|
||||||
|
**Distributed data parallel (DDP)**
|
||||||
|
|
||||||
|
[DistributedDataParallel (DDP)](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel) implements data parallelism at the module level that can run across multiple GPUs or machines.
|
||||||
|
|
||||||
|
**Automatic Mixed Precision (AMP)**
|
||||||
|
|
||||||
|
This implementation uses the native PyTorch AMP implementation of mixed precision training. It allows us to use FP16 training with FP32 master weights by modifying just a few lines of code. A detailed explanation of mixed precision can be found in the next section.
|
||||||
|
|
||||||
|
### Mixed precision training
|
||||||
|
|
||||||
|
Mixed precision is the combined use of different numerical precisions in a computational method. [Mixed precision](https://arxiv.org/abs/1710.03740) training offers significant computational speedup by performing operations in half-precision format while storing minimal information in single-precision to retain as much information as possible in critical parts of the network. Since the introduction of [Tensor Cores](https://developer.nvidia.com/tensor-cores) in NVIDIA Volta, and following with both the NVIDIA Turing and NVIDIA Ampere Architectures, significant training speedups are experienced by switching to mixed precision -- up to 3x overall speedup on the most arithmetically intense model architectures. Using [mixed precision training](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) previously required two steps:
|
||||||
|
1. Porting the model to use the FP16 data type where appropriate.
|
||||||
|
2. Adding loss scaling to preserve small gradient values.
|
||||||
|
|
||||||
|
AMP enables mixed precision training on NVIDIA Volta, NVIDIA Turing, and NVIDIA Ampere GPU architectures automatically. The PyTorch framework code makes all necessary model changes internally.
|
||||||
|
|
||||||
|
For information about:
|
||||||
|
- How to train using mixed precision, refer to the [Mixed Precision Training](https://arxiv.org/abs/1710.03740) paper and [Training With Mixed Precision](https://docs.nvidia.com/deeplearning/performance/mixed-precision-training/index.html) documentation.
|
||||||
|
- Techniques used for mixed precision training, refer to the [Mixed-Precision Training of Deep Neural Networks](https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/) blog.
|
||||||
|
- APEX tools for mixed precision training, refer to the [NVIDIA Apex: Tools for Easy Mixed-Precision Training in PyTorch](https://devblogs.nvidia.com/apex-pytorch-easy-mixed-precision-training/).
|
||||||
|
|
||||||
|
#### Enabling mixed precision
|
||||||
|
|
||||||
|
Mixed precision is enabled in PyTorch by using the native [Automatic Mixed Precision package](https://pytorch.org/docs/stable/amp.html), which casts variables to half-precision upon retrieval while storing variables in single-precision format. Furthermore, to preserve small gradient magnitudes in backpropagation, a [loss scaling](https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#lossscaling) step must be included when applying gradients. In PyTorch, loss scaling can be applied automatically using a `GradScaler`.
|
||||||
|
Automatic Mixed Precision makes all the adjustments internally in PyTorch, providing two benefits over manual operations. First, programmers need not modify network model code, reducing development and maintenance effort. Second, using AMP maintains forward and backward compatibility with all the APIs for defining and running PyTorch models.
|
||||||
|
|
||||||
|
To enable mixed precision, you can simply use the `--amp` flag when running the training or inference scripts.
|
||||||
|
|
||||||
|
#### Enabling TF32
|
||||||
|
|
||||||
|
TensorFloat-32 (TF32) is the new math mode in [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) GPUs for handling the matrix math, also called tensor operations. TF32 running on Tensor Cores in A100 GPUs can provide up to 10x speedups compared to single-precision floating-point math (FP32) on NVIDIA Volta GPUs.
|
||||||
|
|
||||||
|
TF32 Tensor Cores can speed up networks using FP32, typically with no loss of accuracy. It is more robust than FP16 for models that require a high dynamic range for weights or activations.
|
||||||
|
|
||||||
|
For more information, refer to the [TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x](https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/) blog post.
|
||||||
|
|
||||||
|
TF32 is supported in the NVIDIA Ampere GPU architecture and is enabled by default.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Glossary
|
||||||
|
|
||||||
|
**Degree (type)**
|
||||||
|
|
||||||
|
In the model, every feature (input, output and hidden) transforms in an equivariant way in relation to the input graph. When we define a feature, we need to choose, in addition to the number of channels, which transformation rule it obeys.
|
||||||
|
|
||||||
|
The degree or type of a feature is a positive integer that describes how this feature transforms when the input rotates in 3D.
|
||||||
|
|
||||||
|
This is related to [irreducible representations](https://en.wikipedia.org/wiki/Irreducible_representation) of different rotation orders.
|
||||||
|
|
||||||
|
The degree of a feature determines its dimensionality. A type-d feature has a dimensionality of 2d+1.
|
||||||
|
|
||||||
|
Some common examples include:
|
||||||
|
- Degree 0: 1D scalars invariant to rotation
|
||||||
|
- Degree 1: 3D vectors that rotate according to 3D rotation matrices
|
||||||
|
- Degree 2: 5D vectors that rotate according to 5D [Wigner-D matrices](https://en.wikipedia.org/wiki/Wigner_D-matrix). These can represent symmetric traceless 3x3 matrices.
|
||||||
|
|
||||||
|
**Fiber**
|
||||||
|
|
||||||
|
A fiber can be viewed as a representation of a set of features of different types or degrees (positive integers), where each feature type transforms according to its rule.
|
||||||
|
|
||||||
|
In this repository, a fiber can be seen as a dictionary with degrees as keys and numbers of channels as values.
|
||||||
|
|
||||||
|
**Multiplicity**
|
||||||
|
|
||||||
|
The multiplicity of a feature of a given type is the number of channels of this feature.
|
||||||
|
|
||||||
|
**Tensor Field Network**
|
||||||
|
|
||||||
|
A [Tensor Field Network](https://arxiv.org/abs/1802.08219) is a kind of equivariant graph convolution that can combine features of different degrees and produce new ones while preserving equivariance thanks to [tensor products](https://en.wikipedia.org/wiki/Tensor_product).
|
||||||
|
|
||||||
|
**Equivariance**
|
||||||
|
|
||||||
|
[Equivariance](https://en.wikipedia.org/wiki/Equivariant_map) is a property of a function of model stating that applying a symmetry transformation to the input and then computing the function produces the same result as computing the function and then applying the transformation to the output.
|
||||||
|
|
||||||
|
In the case of SE(3)-Transformer, the symmetry group is the group of continuous roto-translations (SE(3)).
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
The following section lists the requirements that you need to meet in order to start training the SE(3)-Transformer model.
|
||||||
|
|
||||||
|
### Requirements
|
||||||
|
|
||||||
|
This repository contains a Dockerfile which extends the PyTorch 21.07 NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
|
||||||
|
- [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
|
||||||
|
- PyTorch 21.07+ NGC container
|
||||||
|
- Supported GPUs:
|
||||||
|
- [NVIDIA Volta architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/)
|
||||||
|
- [NVIDIA Turing architecture](https://www.nvidia.com/en-us/design-visualization/technologies/turing-architecture/)
|
||||||
|
- [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
|
||||||
|
|
||||||
|
For more information about how to get started with NGC containers, refer to the following sections from the NVIDIA GPU Cloud Documentation and the Deep Learning Documentation:
|
||||||
|
- [Getting Started Using NVIDIA GPU Cloud](https://docs.nvidia.com/ngc/ngc-getting-started-guide/index.html)
|
||||||
|
- [Accessing And Pulling From The NGC Container Registry](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#accessing_registry)
|
||||||
|
- [Running PyTorch](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/running.html#running)
|
||||||
|
|
||||||
|
For those unable to use the PyTorch NGC container to set up the required environment or create your own container, refer to the versioned [NVIDIA Container Support Matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html).
|
||||||
|
|
||||||
|
## Quick Start Guide
|
||||||
|
|
||||||
|
To train your model using mixed or TF32 precision with Tensor Cores or FP32, perform the following steps using the default parameters of the SE(3)-Transformer model on the QM9 dataset. For the specifics concerning training and inference, refer to the [Advanced](#advanced) section.
|
||||||
|
|
||||||
|
1. Clone the repository.
|
||||||
|
```
|
||||||
|
git clone https://github.com/NVIDIA/DeepLearningExamples
|
||||||
|
cd DeepLearningExamples/PyTorch/DrugDiscovery/SE3Transformer
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Build the `se3-transformer` PyTorch NGC container.
|
||||||
|
```
|
||||||
|
docker build -t se3-transformer .
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Start an interactive session in the NGC container to run training/inference.
|
||||||
|
```
|
||||||
|
mkdir -p results
|
||||||
|
docker run -it --runtime=nvidia --shm-size=8g --ulimit memlock=-1 --ulimit stack=67108864 --rm -v ${PWD}/results:/results se3-transformer:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Start training.
|
||||||
|
```
|
||||||
|
bash scripts/train.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Start inference/predictions.
|
||||||
|
```
|
||||||
|
bash scripts/predict.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
Now that you have your model trained and evaluated, you can choose to compare your training results with our [Training accuracy results](#training-accuracy-results). You can also choose to benchmark your performance to [Training performance benchmark](#training-performance-results) or [Inference performance benchmark](#inference-performance-results). Following the steps in these sections will ensure that you achieve the same accuracy and performance results as stated in the [Results](#results) section.
|
||||||
|
|
||||||
|
## Advanced
|
||||||
|
|
||||||
|
The following sections provide greater details of the dataset, running training and inference, and the training results.
|
||||||
|
|
||||||
|
### Scripts and sample code
|
||||||
|
|
||||||
|
In the root directory, the most important files are:
|
||||||
|
- `Dockerfile`: container with the basic set of dependencies to run SE(3)-Transformers
|
||||||
|
- `requirements.txt`: set of extra requirements to run SE(3)-Transformers
|
||||||
|
- `se3_transformer/data_loading/qm9.py`: QM9 data loading and preprocessing, as well as bases precomputation
|
||||||
|
- `se3_transformer/model/layers/`: directory containing model architecture layers
|
||||||
|
- `se3_transformer/model/transformer.py`: main Transformer module
|
||||||
|
- `se3_transformer/model/basis.py`: logic for computing bases matrices
|
||||||
|
- `se3_transformer/runtime/training.py`: training script, to be run as a python module
|
||||||
|
- `se3_transformer/runtime/inference.py`: inference script, to be run as a python module
|
||||||
|
- `se3_transformer/runtime/metrics.py`: MAE metric with support for multi-GPU synchronization
|
||||||
|
- `se3_transformer/runtime/loggers.py`: [DLLogger](https://github.com/NVIDIA/dllogger) and [W&B](wandb.ai/) loggers
|
||||||
|
|
||||||
|
|
||||||
|
### Parameters
|
||||||
|
|
||||||
|
The complete list of the available parameters for the `training.py` script contains:
|
||||||
|
|
||||||
|
**General**
|
||||||
|
|
||||||
|
- `--epochs`: Number of training epochs (default: `100` for single-GPU)
|
||||||
|
- `--batch_size`: Batch size (default: `240`)
|
||||||
|
- `--seed`: Set a seed globally (default: `None`)
|
||||||
|
- `--num_workers`: Number of dataloading workers (default: `8`)
|
||||||
|
- `--amp`: Use Automatic Mixed Precision (default `false`)
|
||||||
|
- `--gradient_clip`: Clipping of the gradient norms (default: `None`)
|
||||||
|
- `--accumulate_grad_batches`: Gradient accumulation (default: `1`)
|
||||||
|
- `--ckpt_interval`: Save a checkpoint every N epochs (default: `-1`)
|
||||||
|
- `--eval_interval`: Do an evaluation round every N epochs (default: `1`)
|
||||||
|
- `--silent`: Minimize stdout output (default: `false`)
|
||||||
|
|
||||||
|
**Paths**
|
||||||
|
|
||||||
|
- `--data_dir`: Directory where the data is located or should be downloaded (default: `./data`)
|
||||||
|
- `--log_dir`: Directory where the results logs should be saved (default: `/results`)
|
||||||
|
- `--save_ckpt_path`: File where the checkpoint should be saved (default: `None`)
|
||||||
|
- `--load_ckpt_path`: File of the checkpoint to be loaded (default: `None`)
|
||||||
|
|
||||||
|
**Optimizer**
|
||||||
|
|
||||||
|
- `--optimizer`: Optimizer to use (default: `adam`)
|
||||||
|
- `--learning_rate`: Learning rate to use (default: `0.002` for single-GPU)
|
||||||
|
- `--momentum`: Momentum to use (default: `0.9`)
|
||||||
|
- `--weight_decay`: Weight decay to use (default: `0.1`)
|
||||||
|
|
||||||
|
**QM9 dataset**
|
||||||
|
|
||||||
|
- `--task`: Regression task to train on (default: `homo`)
|
||||||
|
- `--precompute_bases`: Precompute bases at the beginning of the script during dataset initialization, instead of computing them at the beginning of each forward pass (default: `false`)
|
||||||
|
|
||||||
|
**Model architecture**
|
||||||
|
|
||||||
|
- `--num_layers`: Number of stacked Transformer layers (default: `7`)
|
||||||
|
- `--num_heads`: Number of heads in self-attention (default: `8`)
|
||||||
|
- `--channels_div`: Channels division before feeding to attention layer (default: `2`)
|
||||||
|
- `--pooling`: Type of graph pooling (default: `max`)
|
||||||
|
- `--norm`: Apply a normalization layer after each attention block (default: `false`)
|
||||||
|
- `--use_layer_norm`: Apply layer normalization between MLP layers (default: `false`)
|
||||||
|
- `--low_memory`: If true, will use fused ops that are slower but use less memory (expect 25 percent less memory). Only has an effect if AMP is enabled on NVIDIA Volta GPUs or if running on Ampere GPUs (default: `false`)
|
||||||
|
- `--num_degrees`: Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1] (default: `4`)
|
||||||
|
- `--num_channels`: Number of channels for the hidden features (default: `32`)
|
||||||
|
|
||||||
|
|
||||||
|
### Command-line options
|
||||||
|
|
||||||
|
To show the full list of available options and their descriptions, use the `-h` or `--help` command-line option, for example: `python -m se3_transformer.runtime.training --help`.
|
||||||
|
|
||||||
|
|
||||||
|
### Dataset guidelines
|
||||||
|
|
||||||
|
#### Demo dataset
|
||||||
|
|
||||||
|
The SE(3)-Transformer was trained on the QM9 dataset.
|
||||||
|
|
||||||
|
The QM9 dataset is hosted on DGL servers and downloaded (38MB) automatically when needed. By default, it is stored in the `./data` directory, but this location can be changed with the `--data_dir` argument.
|
||||||
|
|
||||||
|
The dataset is saved as a `qm9_edge.npz` file and converted to DGL graphs at runtime.
|
||||||
|
|
||||||
|
As input features, we use:
|
||||||
|
- Node features (6D):
|
||||||
|
- One-hot-encoded atom type (5D) (atom types: H, C, N, O, F)
|
||||||
|
- Number of protons of each atom (1D)
|
||||||
|
- Edge features: one-hot-encoded bond type (4D) (bond types: single, double, triple, aromatic)
|
||||||
|
- The relative positions between adjacent nodes (atoms)
|
||||||
|
|
||||||
|
#### Custom datasets
|
||||||
|
|
||||||
|
To use this network on a new dataset, you can extend the `DataModule` class present in `se3_transformer/data_loading/data_module.py`.
|
||||||
|
|
||||||
|
Your custom collate function should return a tuple with:
|
||||||
|
|
||||||
|
- A (batched) DGLGraph object
|
||||||
|
- A dictionary of node features ({‘{degree}’: tensor})
|
||||||
|
- A dictionary of edge features ({‘{degree}’: tensor})
|
||||||
|
- (Optional) Precomputed bases as a dictionary
|
||||||
|
- Labels as a tensor
|
||||||
|
|
||||||
|
You can then modify the `training.py` and `inference.py` scripts to use your new data module.
|
||||||
|
|
||||||
|
### Training process
|
||||||
|
|
||||||
|
The training script is `se3_transformer/runtime/training.py`, to be run as a module: `python -m se3_transformer.runtime.training`.
|
||||||
|
|
||||||
|
**Logs**
|
||||||
|
|
||||||
|
By default, the resulting logs are stored in `/results/`. This can be changed with `--log_dir`.
|
||||||
|
|
||||||
|
You can connect your existing Weights & Biases account by setting the `WANDB_API_KEY` environment variable.
|
||||||
|
|
||||||
|
**Checkpoints**
|
||||||
|
|
||||||
|
The argument `--save_ckpt_path` can be set to the path of the file where the checkpoints should be saved.
|
||||||
|
`--ckpt_interval` can also be set to the interval (in the number of epochs) between checkpoints.
|
||||||
|
|
||||||
|
**Evaluation**
|
||||||
|
|
||||||
|
The evaluation metric is the Mean Absolute Error (MAE).
|
||||||
|
|
||||||
|
`--eval_interval` can be set to the interval (in the number of epochs) between evaluation rounds. By default, an evaluation round is performed after each epoch.
|
||||||
|
|
||||||
|
**Automatic Mixed Precision**
|
||||||
|
|
||||||
|
To enable Mixed Precision training, add the `--amp` flag.
|
||||||
|
|
||||||
|
**Multi-GPU and multi-node**
|
||||||
|
|
||||||
|
The training script supports the PyTorch elastic launcher to run on multiple GPUs or nodes. Refer to the [official documentation](https://pytorch.org/docs/1.9.0/elastic/run.html).
|
||||||
|
|
||||||
|
For example, to train on all available GPUs with AMP:
|
||||||
|
|
||||||
|
```
|
||||||
|
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --module se3_transformer.runtime.training --amp
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Inference process
|
||||||
|
|
||||||
|
Inference can be run by using the `se3_transformer.runtime.inference` python module.
|
||||||
|
|
||||||
|
The inference script is `se3_transformer/runtime/inference.py`, to be run as a module: `python -m se3_transformer.runtime.inference`. It requires a pre-trained model checkpoint (to be passed as `--load_ckpt_path`).
|
||||||
|
|
||||||
|
|
||||||
|
## Performance
|
||||||
|
|
||||||
|
The performance measurements in this document were conducted at the time of publication and may not reflect the performance achieved from NVIDIA’s latest software release. For the most up-to-date performance measurements, go to [NVIDIA Data Center Deep Learning Product Performance](https://developer.nvidia.com/deep-learning-performance-training-inference).
|
||||||
|
|
||||||
|
### Benchmarking
|
||||||
|
|
||||||
|
The following section shows how to run benchmarks measuring the model performance in training and inference modes.
|
||||||
|
|
||||||
|
#### Training performance benchmark
|
||||||
|
|
||||||
|
To benchmark the training performance on a specific batch size, run `bash scripts/benchmarck_train.sh {BATCH_SIZE}` for single GPU, and `bash scripts/benchmarck_train_multi_gpu.sh {BATCH_SIZE}` for multi-GPU.
|
||||||
|
|
||||||
|
#### Inference performance benchmark
|
||||||
|
|
||||||
|
To benchmark the inference performance on a specific batch size, run `bash scripts/benchmarck_inference.sh {BATCH_SIZE}`.
|
||||||
|
|
||||||
|
### Results
|
||||||
|
|
||||||
|
|
||||||
|
The following sections provide details on how we achieved our performance and accuracy in training and inference.
|
||||||
|
|
||||||
|
#### Training accuracy results
|
||||||
|
|
||||||
|
##### Training accuracy: NVIDIA DGX A100 (8x A100 80GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 (8x A100 80GB) GPUs.
|
||||||
|
|
||||||
|
| GPUs | Batch size / GPU | Absolute error - TF32 | Absolute error - mixed precision | Time to train - TF32 | Time to train - mixed precision | Time to train speedup (mixed precision to TF32) |
|
||||||
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
|
| 1 | 240 | 0.03456 | 0.03460 | 1h23min | 1h03min | 1.32x |
|
||||||
|
| 8 | 240 | 0.03417 | 0.03424 | 15min | 12min | 1.25x |
|
||||||
|
|
||||||
|
|
||||||
|
##### Training accuracy: NVIDIA DGX-1 (8x V100 16GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/train.sh` training script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with (8x V100 16GB) GPUs.
|
||||||
|
|
||||||
|
| GPUs | Batch size / GPU | Absolute error - FP32 | Absolute error - mixed precision | Time to train - FP32 | Time to train - mixed precision | Time to train speedup (mixed precision to FP32) |
|
||||||
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
|
| 1 | 240 | 0.03432 | 0.03439 | 2h25min | 1h33min | 1.56x |
|
||||||
|
| 8 | 240 | 0.03380 | 0.03495 | 29min | 20min | 1.45x |
|
||||||
|
|
||||||
|
|
||||||
|
#### Training performance results
|
||||||
|
|
||||||
|
##### Training performance: NVIDIA DGX A100 (8x A100 80GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 8x A100 80GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
||||||
|
|
||||||
|
| GPUs | Batch size / GPU | Throughput - TF32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (mixed precision - TF32) | Weak scaling - TF32 | Weak scaling - mixed precision |
|
||||||
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
|
| 1 | 240 | 2.21 | 2.92 | 1.32x | | |
|
||||||
|
| 1 | 120 | 1.81 | 2.04 | 1.13x | | |
|
||||||
|
| 8 | 240 | 17.15 | 22.95 | 1.34x | 7.76 | 7.86 |
|
||||||
|
| 8 | 120 | 13.89 | 15.62 | 1.12x | 7.67 | 7.66 |
|
||||||
|
|
||||||
|
|
||||||
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
|
||||||
|
|
||||||
|
##### Training performance: NVIDIA DGX-1 (8x V100 16GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/benchmark_train.sh` and `scripts/benchmark_train_multi_gpu.sh` benchmarking scripts in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 8x V100 16GB GPUs. Performance numbers (in molecules per millisecond) were averaged over five entire training epochs after a warmup epoch.
|
||||||
|
|
||||||
|
| GPUs | Batch size / GPU | Throughput - FP32 [mol/ms] | Throughput - mixed precision [mol/ms] | Throughput speedup (FP32 - mixed precision) | Weak scaling - FP32 | Weak scaling - mixed precision |
|
||||||
|
|:------------------:|:----------------------:|:--------------------:|:------------------------------------:|:---------------------------------:|:----------------------:|:----------------------------------------------:|
|
||||||
|
| 1 | 240 | 1.25 | 1.88 | 1.50x | | |
|
||||||
|
| 1 | 120 | 1.03 | 1.41 | 1.37x | | |
|
||||||
|
| 8 | 240 | 9.33 | 14.02 | 1.50x | 7.46 | 7.46 |
|
||||||
|
| 8 | 120 | 7.39 | 9.41 | 1.27x | 7.17 | 6.67 |
|
||||||
|
|
||||||
|
|
||||||
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
|
||||||
|
|
||||||
|
#### Inference performance results
|
||||||
|
|
||||||
|
|
||||||
|
##### Inference performance: NVIDIA DGX A100 (1x A100 80GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX A100 with 1x A100 80GB GPU.
|
||||||
|
|
||||||
|
FP16
|
||||||
|
|
||||||
|
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
||||||
|
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
||||||
|
| 1600 | 11.60 | 140.94 | 138.29 | 140.12 | 386.40 |
|
||||||
|
| 800 | 10.74 | 75.69 | 75.74 | 76.50 | 79.77 |
|
||||||
|
| 400 | 8.86 | 45.57 | 46.11 | 46.60 | 49.97 |
|
||||||
|
|
||||||
|
TF32
|
||||||
|
|
||||||
|
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
||||||
|
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
||||||
|
| 1600 | 8.58 | 189.20 | 186.39 | 187.71 | 420.28 |
|
||||||
|
| 800 | 8.28 | 97.56 | 97.20 | 97.73 | 101.13 |
|
||||||
|
| 400 | 7.55 | 53.38 | 53.72 | 54.48 | 56.62 |
|
||||||
|
|
||||||
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
##### Inference performance: NVIDIA DGX-1 (1x V100 16GB)
|
||||||
|
|
||||||
|
Our results were obtained by running the `scripts/benchmark_inference.sh` inferencing benchmarking script in the PyTorch 21.07 NGC container on NVIDIA DGX-1 with 1x V100 16GB GPU.
|
||||||
|
|
||||||
|
FP16
|
||||||
|
|
||||||
|
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
||||||
|
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
||||||
|
| 1600 | 6.42 | 254.54 | 247.97 | 249.29 | 721.15 |
|
||||||
|
| 800 | 6.13 | 132.07 | 131.90 | 132.70 | 140.15 |
|
||||||
|
| 400 | 5.37 | 75.12 | 76.01 | 76.66 | 79.90 |
|
||||||
|
|
||||||
|
FP32
|
||||||
|
|
||||||
|
| Batch size | Throughput Avg [mol/ms] | Latency Avg [ms] | Latency 90% [ms] |Latency 95% [ms] |Latency 99% [ms] |
|
||||||
|
|:------------:|:------:|:-----:|:-----:|:-----:|:-----:|
|
||||||
|
| 1600 | 3.39 | 475.86 | 473.82 | 475.64 | 891.18 |
|
||||||
|
| 800 | 3.36 | 239.17 | 240.64 | 241.65 | 243.70 |
|
||||||
|
| 400 | 3.17 | 126.67 | 128.19 | 128.82 | 130.54 |
|
||||||
|
|
||||||
|
|
||||||
|
To achieve these same results, follow the steps in the [Quick Start Guide](#quick-start-guide).
|
||||||
|
|
||||||
|
|
||||||
|
## Release notes
|
||||||
|
|
||||||
|
### Changelog
|
||||||
|
|
||||||
|
August 2021
|
||||||
|
- Initial release
|
||||||
|
|
||||||
|
### Known issues
|
||||||
|
|
||||||
|
If you encounter `OSError: [Errno 12] Cannot allocate memory` during the Dataloader iterator creation (more precisely during the `fork()`, this is most likely due to the use of the `--precompute_bases` flag. If you cannot add more RAM or Swap to your machine, it is recommended to turn off bases precomputation by removing the `--precompute_bases` flag or using `--precompute_bases false`.
|
BIN
rf2aa/SE3Transformer/images/se3-transformer.png
Normal file
BIN
rf2aa/SE3Transformer/images/se3-transformer.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.1 MiB |
4
rf2aa/SE3Transformer/requirements.txt
Normal file
4
rf2aa/SE3Transformer/requirements.txt
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
e3nn==0.3.3
|
||||||
|
wandb==0.12.0
|
||||||
|
pynvml==11.0.0
|
||||||
|
git+https://github.com/NVIDIA/dllogger#egg=dllogger
|
15
rf2aa/SE3Transformer/scripts/benchmark_inference.sh
Executable file
15
rf2aa/SE3Transformer/scripts/benchmark_inference.sh
Executable file
|
@ -0,0 +1,15 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Script to benchmark inference performance, without bases precomputation
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.inference \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--task homo \
|
||||||
|
--seed 42 \
|
||||||
|
--benchmark
|
18
rf2aa/SE3Transformer/scripts/benchmark_train.sh
Executable file
18
rf2aa/SE3Transformer/scripts/benchmark_train.sh
Executable file
|
@ -0,0 +1,18 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Script to benchmark single-GPU training performance, with bases precomputation
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=0 python -m se3_transformer.runtime.training \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--epochs 6 \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--save_ckpt_path model_qm9.pth \
|
||||||
|
--task homo \
|
||||||
|
--precompute_bases \
|
||||||
|
--seed 42 \
|
||||||
|
--benchmark
|
19
rf2aa/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Executable file
19
rf2aa/SE3Transformer/scripts/benchmark_train_multi_gpu.sh
Executable file
|
@ -0,0 +1,19 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Script to benchmark multi-GPU training performance, with bases precomputation
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
|
||||||
|
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
|
||||||
|
se3_transformer.runtime.training \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--epochs 6 \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--save_ckpt_path model_qm9.pth \
|
||||||
|
--task homo \
|
||||||
|
--precompute_bases \
|
||||||
|
--seed 42 \
|
||||||
|
--benchmark
|
19
rf2aa/SE3Transformer/scripts/predict.sh
Executable file
19
rf2aa/SE3Transformer/scripts/predict.sh
Executable file
|
@ -0,0 +1,19 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
|
||||||
|
|
||||||
|
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||||
|
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
|
||||||
|
TASK=homo
|
||||||
|
|
||||||
|
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
|
||||||
|
se3_transformer.runtime.inference \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--load_ckpt_path model_qm9.pth \
|
||||||
|
--task "$TASK"
|
25
rf2aa/SE3Transformer/scripts/train.sh
Executable file
25
rf2aa/SE3Transformer/scripts/train.sh
Executable file
|
@ -0,0 +1,25 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
NUM_EPOCHS=${3:-100}
|
||||||
|
LEARNING_RATE=${4:-0.002}
|
||||||
|
WEIGHT_DECAY=${5:-0.1}
|
||||||
|
|
||||||
|
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||||
|
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
|
||||||
|
TASK=homo
|
||||||
|
|
||||||
|
python -m se3_transformer.runtime.training \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--epochs "$NUM_EPOCHS" \
|
||||||
|
--lr "$LEARNING_RATE" \
|
||||||
|
--weight_decay "$WEIGHT_DECAY" \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--save_ckpt_path model_qm9.pth \
|
||||||
|
--precompute_bases \
|
||||||
|
--seed 42 \
|
||||||
|
--task "$TASK"
|
27
rf2aa/SE3Transformer/scripts/train_multi_gpu.sh
Executable file
27
rf2aa/SE3Transformer/scripts/train_multi_gpu.sh
Executable file
|
@ -0,0 +1,27 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
# CLI args with defaults
|
||||||
|
BATCH_SIZE=${1:-240}
|
||||||
|
AMP=${2:-true}
|
||||||
|
NUM_EPOCHS=${3:-130}
|
||||||
|
LEARNING_RATE=${4:-0.01}
|
||||||
|
WEIGHT_DECAY=${5:-0.1}
|
||||||
|
|
||||||
|
# choices: 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||||
|
# 'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'
|
||||||
|
TASK=homo
|
||||||
|
|
||||||
|
python -m torch.distributed.run --nnodes=1 --nproc_per_node=gpu --max_restarts 0 --module \
|
||||||
|
se3_transformer.runtime.training \
|
||||||
|
--amp "$AMP" \
|
||||||
|
--batch_size "$BATCH_SIZE" \
|
||||||
|
--epochs "$NUM_EPOCHS" \
|
||||||
|
--lr "$LEARNING_RATE" \
|
||||||
|
--min_lr 0.00001 \
|
||||||
|
--weight_decay "$WEIGHT_DECAY" \
|
||||||
|
--use_layer_norm \
|
||||||
|
--norm \
|
||||||
|
--save_ckpt_path model_qm9.pth \
|
||||||
|
--precompute_bases \
|
||||||
|
--seed 42 \
|
||||||
|
--task "$TASK"
|
0
rf2aa/SE3Transformer/se3_transformer/__init__.py
Normal file
0
rf2aa/SE3Transformer/se3_transformer/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
from .qm9 import QM9DataModule
|
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import torch.distributed as dist
|
||||||
|
from abc import ABC
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler, Dataset
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dataloader(dataset: Dataset, shuffle: bool, **kwargs) -> DataLoader:
|
||||||
|
# Classic or distributed dataloader depending on the context
|
||||||
|
sampler = DistributedSampler(dataset, shuffle=shuffle) if dist.is_initialized() else None
|
||||||
|
return DataLoader(dataset, shuffle=(shuffle and sampler is None), sampler=sampler, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DataModule(ABC):
|
||||||
|
""" Abstract DataModule. Children must define self.ds_{train | val | test}. """
|
||||||
|
|
||||||
|
def __init__(self, **dataloader_kwargs):
|
||||||
|
super().__init__()
|
||||||
|
if get_local_rank() == 0:
|
||||||
|
self.prepare_data()
|
||||||
|
|
||||||
|
# Wait until rank zero has prepared the data (download, preprocessing, ...)
|
||||||
|
if dist.is_initialized():
|
||||||
|
dist.barrier(device_ids=[get_local_rank()])
|
||||||
|
|
||||||
|
self.dataloader_kwargs = {'pin_memory': True, 'persistent_workers': True, **dataloader_kwargs}
|
||||||
|
self.ds_train, self.ds_val, self.ds_test = None, None, None
|
||||||
|
|
||||||
|
def prepare_data(self):
|
||||||
|
""" Method called only once per node. Put here any downloading or preprocessing """
|
||||||
|
pass
|
||||||
|
|
||||||
|
def train_dataloader(self) -> DataLoader:
|
||||||
|
return _get_dataloader(self.ds_train, shuffle=True, **self.dataloader_kwargs)
|
||||||
|
|
||||||
|
def val_dataloader(self) -> DataLoader:
|
||||||
|
return _get_dataloader(self.ds_val, shuffle=False, **self.dataloader_kwargs)
|
||||||
|
|
||||||
|
def test_dataloader(self) -> DataLoader:
|
||||||
|
return _get_dataloader(self.ds_test, shuffle=False, **self.dataloader_kwargs)
|
173
rf2aa/SE3Transformer/se3_transformer/data_loading/qm9.py
Normal file
173
rf2aa/SE3Transformer/se3_transformer/data_loading/qm9.py
Normal file
|
@ -0,0 +1,173 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import dgl
|
||||||
|
import pathlib
|
||||||
|
import torch
|
||||||
|
from dgl.data import QM9EdgeDataset
|
||||||
|
from dgl import DGLGraph
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.utils.data import random_split, DataLoader, Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.data_loading.data_module import DataModule
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import get_local_rank, str2bool, using_tensor_cores
|
||||||
|
|
||||||
|
|
||||||
|
def _get_relative_pos(qm9_graph: DGLGraph) -> Tensor:
|
||||||
|
x = qm9_graph.ndata['pos']
|
||||||
|
src, dst = qm9_graph.edges()
|
||||||
|
rel_pos = x[dst] - x[src]
|
||||||
|
return rel_pos
|
||||||
|
|
||||||
|
|
||||||
|
def _get_split_sizes(full_dataset: Dataset) -> Tuple[int, int, int]:
|
||||||
|
len_full = len(full_dataset)
|
||||||
|
len_train = 100_000
|
||||||
|
len_test = int(0.1 * len_full)
|
||||||
|
len_val = len_full - len_train - len_test
|
||||||
|
return len_train, len_val, len_test
|
||||||
|
|
||||||
|
|
||||||
|
class QM9DataModule(DataModule):
|
||||||
|
"""
|
||||||
|
Datamodule wrapping https://docs.dgl.ai/en/latest/api/python/dgl.data.html#qm9edge-dataset
|
||||||
|
Training set is 100k molecules. Test set is 10% of the dataset. Validation set is the rest.
|
||||||
|
This includes all the molecules from QM9 except the ones that are uncharacterized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
NODE_FEATURE_DIM = 6
|
||||||
|
EDGE_FEATURE_DIM = 4
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
data_dir: pathlib.Path,
|
||||||
|
task: str = 'homo',
|
||||||
|
batch_size: int = 240,
|
||||||
|
num_workers: int = 8,
|
||||||
|
num_degrees: int = 4,
|
||||||
|
amp: bool = False,
|
||||||
|
precompute_bases: bool = False,
|
||||||
|
**kwargs):
|
||||||
|
self.data_dir = data_dir # This needs to be before __init__ so that prepare_data has access to it
|
||||||
|
super().__init__(batch_size=batch_size, num_workers=num_workers, collate_fn=self._collate)
|
||||||
|
self.amp = amp
|
||||||
|
self.task = task
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.num_degrees = num_degrees
|
||||||
|
|
||||||
|
qm9_kwargs = dict(label_keys=[self.task], verbose=False, raw_dir=str(data_dir))
|
||||||
|
if precompute_bases:
|
||||||
|
bases_kwargs = dict(max_degree=num_degrees - 1, use_pad_trick=using_tensor_cores(amp), amp=amp)
|
||||||
|
full_dataset = CachedBasesQM9EdgeDataset(bases_kwargs=bases_kwargs, batch_size=batch_size,
|
||||||
|
num_workers=num_workers, **qm9_kwargs)
|
||||||
|
else:
|
||||||
|
full_dataset = QM9EdgeDataset(**qm9_kwargs)
|
||||||
|
|
||||||
|
self.ds_train, self.ds_val, self.ds_test = random_split(full_dataset, _get_split_sizes(full_dataset),
|
||||||
|
generator=torch.Generator().manual_seed(0))
|
||||||
|
|
||||||
|
train_targets = full_dataset.targets[self.ds_train.indices, full_dataset.label_keys[0]]
|
||||||
|
self.targets_mean = train_targets.mean()
|
||||||
|
self.targets_std = train_targets.std()
|
||||||
|
|
||||||
|
def prepare_data(self):
|
||||||
|
# Download the QM9 preprocessed data
|
||||||
|
QM9EdgeDataset(verbose=True, raw_dir=str(self.data_dir))
|
||||||
|
|
||||||
|
def _collate(self, samples):
|
||||||
|
graphs, y, *bases = map(list, zip(*samples))
|
||||||
|
batched_graph = dgl.batch(graphs)
|
||||||
|
edge_feats = {'0': batched_graph.edata['edge_attr'][..., None]}
|
||||||
|
batched_graph.edata['rel_pos'] = _get_relative_pos(batched_graph)
|
||||||
|
# get node features
|
||||||
|
node_feats = {'0': batched_graph.ndata['attr'][:, :6, None]}
|
||||||
|
targets = (torch.cat(y) - self.targets_mean) / self.targets_std
|
||||||
|
|
||||||
|
if bases:
|
||||||
|
# collate bases
|
||||||
|
all_bases = {
|
||||||
|
key: torch.cat([b[key] for b in bases[0]], dim=0)
|
||||||
|
for key in bases[0][0].keys()
|
||||||
|
}
|
||||||
|
|
||||||
|
return batched_graph, node_feats, edge_feats, all_bases, targets
|
||||||
|
else:
|
||||||
|
return batched_graph, node_feats, edge_feats, targets
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_argparse_args(parent_parser):
|
||||||
|
parser = parent_parser.add_argument_group("QM9 dataset")
|
||||||
|
parser.add_argument('--task', type=str, default='homo', const='homo', nargs='?',
|
||||||
|
choices=['mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'U0', 'U', 'H', 'G', 'Cv',
|
||||||
|
'U0_atom', 'U_atom', 'H_atom', 'G_atom', 'A', 'B', 'C'],
|
||||||
|
help='Regression task to train on')
|
||||||
|
parser.add_argument('--precompute_bases', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='Precompute bases at the beginning of the script during dataset initialization,'
|
||||||
|
' instead of computing them at the beginning of each forward pass.')
|
||||||
|
return parent_parser
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'QM9({self.task})'
|
||||||
|
|
||||||
|
|
||||||
|
class CachedBasesQM9EdgeDataset(QM9EdgeDataset):
|
||||||
|
""" Dataset extending the QM9 dataset from DGL with precomputed (cached in RAM) pairwise bases """
|
||||||
|
|
||||||
|
def __init__(self, bases_kwargs: dict, batch_size: int, num_workers: int, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
:param bases_kwargs: Arguments to feed the bases computation function
|
||||||
|
:param batch_size: Batch size to use when iterating over the dataset for computing bases
|
||||||
|
"""
|
||||||
|
self.bases_kwargs = bases_kwargs
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.bases = None
|
||||||
|
self.num_workers = num_workers
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def load(self):
|
||||||
|
super().load()
|
||||||
|
# Iterate through the dataset and compute bases (pairwise only)
|
||||||
|
# Potential improvement: use multi-GPU and gather
|
||||||
|
dataloader = DataLoader(self, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers,
|
||||||
|
collate_fn=lambda samples: dgl.batch([sample[0] for sample in samples]))
|
||||||
|
bases = []
|
||||||
|
for i, graph in tqdm(enumerate(dataloader), total=len(dataloader), desc='Precomputing QM9 bases',
|
||||||
|
disable=get_local_rank() != 0):
|
||||||
|
rel_pos = _get_relative_pos(graph)
|
||||||
|
# Compute the bases with the GPU but convert the result to CPU to store in RAM
|
||||||
|
bases.append({k: v.cpu() for k, v in get_basis(rel_pos.cuda(), **self.bases_kwargs).items()})
|
||||||
|
self.bases = bases # Assign at the end so that __getitem__ isn't confused
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int):
|
||||||
|
graph, label = super().__getitem__(idx)
|
||||||
|
|
||||||
|
if self.bases:
|
||||||
|
bases_idx = idx // self.batch_size
|
||||||
|
bases_cumsum_idx = self.ne_cumsum[idx] - self.ne_cumsum[bases_idx * self.batch_size]
|
||||||
|
bases_cumsum_next_idx = self.ne_cumsum[idx + 1] - self.ne_cumsum[bases_idx * self.batch_size]
|
||||||
|
return graph, label, {key: basis[bases_cumsum_idx:bases_cumsum_next_idx] for key, basis in
|
||||||
|
self.bases[bases_idx].items()}
|
||||||
|
else:
|
||||||
|
return graph, label
|
2
rf2aa/SE3Transformer/se3_transformer/model/__init__.py
Normal file
2
rf2aa/SE3Transformer/se3_transformer/model/__init__.py
Normal file
|
@ -0,0 +1,2 @@
|
||||||
|
from .transformer import SE3Transformer, SE3TransformerPooled
|
||||||
|
from .fiber import Fiber
|
178
rf2aa/SE3Transformer/se3_transformer/model/basis.py
Normal file
178
rf2aa/SE3Transformer/se3_transformer/model/basis.py
Normal file
|
@ -0,0 +1,178 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import e3nn.o3 as o3
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_clebsch_gordon(J: int, d_in: int, d_out: int, device) -> Tensor:
|
||||||
|
""" Get the (cached) Q^{d_out,d_in}_J matrices from equation (8) """
|
||||||
|
return o3.wigner_3j(J, d_in, d_out, dtype=torch.float64, device=device).permute(2, 1, 0)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
|
def get_all_clebsch_gordon(max_degree: int, device) -> List[List[Tensor]]:
|
||||||
|
all_cb = []
|
||||||
|
for d_in in range(max_degree + 1):
|
||||||
|
for d_out in range(max_degree + 1):
|
||||||
|
K_Js = []
|
||||||
|
for J in range(abs(d_in - d_out), d_in + d_out + 1):
|
||||||
|
K_Js.append(get_clebsch_gordon(J, d_in, d_out, device))
|
||||||
|
all_cb.append(K_Js)
|
||||||
|
return all_cb
|
||||||
|
|
||||||
|
|
||||||
|
def get_spherical_harmonics(relative_pos: Tensor, max_degree: int) -> List[Tensor]:
|
||||||
|
all_degrees = list(range(2 * max_degree + 1))
|
||||||
|
with nvtx_range('spherical harmonics'):
|
||||||
|
sh = o3.spherical_harmonics(all_degrees, relative_pos, normalize=True)
|
||||||
|
return torch.split(sh, [degree_to_dim(d) for d in all_degrees], dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def get_basis_script(max_degree: int,
|
||||||
|
use_pad_trick: bool,
|
||||||
|
spherical_harmonics: List[Tensor],
|
||||||
|
clebsch_gordon: List[List[Tensor]],
|
||||||
|
amp: bool) -> Dict[str, Tensor]:
|
||||||
|
"""
|
||||||
|
Compute pairwise bases matrices for degrees up to max_degree
|
||||||
|
:param max_degree: Maximum input or output degree
|
||||||
|
:param use_pad_trick: Pad some of the odd dimensions for a better use of Tensor Cores
|
||||||
|
:param spherical_harmonics: List of computed spherical harmonics
|
||||||
|
:param clebsch_gordon: List of computed CB-coefficients
|
||||||
|
:param amp: When true, return bases in FP16 precision
|
||||||
|
"""
|
||||||
|
basis = {}
|
||||||
|
idx = 0
|
||||||
|
# Double for loop instead of product() because of JIT script
|
||||||
|
for d_in in range(max_degree + 1):
|
||||||
|
for d_out in range(max_degree + 1):
|
||||||
|
key = f'{d_in},{d_out}'
|
||||||
|
K_Js = []
|
||||||
|
for freq_idx, J in enumerate(range(abs(d_in - d_out), d_in + d_out + 1)):
|
||||||
|
Q_J = clebsch_gordon[idx][freq_idx]
|
||||||
|
K_Js.append(torch.einsum('n f, k l f -> n l k', spherical_harmonics[J].float(), Q_J.float()))
|
||||||
|
|
||||||
|
basis[key] = torch.stack(K_Js, 2) # Stack on second dim so order is n l f k
|
||||||
|
if amp:
|
||||||
|
basis[key] = basis[key].half()
|
||||||
|
if use_pad_trick:
|
||||||
|
basis[key] = F.pad(basis[key], (0, 1)) # Pad the k dimension, that can be sliced later
|
||||||
|
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
return basis
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.script
|
||||||
|
def update_basis_with_fused(basis: Dict[str, Tensor],
|
||||||
|
max_degree: int,
|
||||||
|
use_pad_trick: bool,
|
||||||
|
fully_fused: bool) -> Dict[str, Tensor]:
|
||||||
|
""" Update the basis dict with partially and optionally fully fused bases """
|
||||||
|
num_edges = basis['0,0'].shape[0]
|
||||||
|
device = basis['0,0'].device
|
||||||
|
dtype = basis['0,0'].dtype
|
||||||
|
sum_dim = sum([degree_to_dim(d) for d in range(max_degree + 1)])
|
||||||
|
|
||||||
|
# Fused per output degree
|
||||||
|
for d_out in range(max_degree + 1):
|
||||||
|
sum_freq = sum([degree_to_dim(min(d, d_out)) for d in range(max_degree + 1)])
|
||||||
|
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, degree_to_dim(d_out) + int(use_pad_trick),
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
acc_d, acc_f = 0, 0
|
||||||
|
for d_in in range(max_degree + 1):
|
||||||
|
basis_fused[:, acc_d:acc_d + degree_to_dim(d_in), acc_f:acc_f + degree_to_dim(min(d_out, d_in)),
|
||||||
|
:degree_to_dim(d_out)] = basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
||||||
|
|
||||||
|
acc_d += degree_to_dim(d_in)
|
||||||
|
acc_f += degree_to_dim(min(d_out, d_in))
|
||||||
|
|
||||||
|
basis[f'out{d_out}_fused'] = basis_fused
|
||||||
|
|
||||||
|
# Fused per input degree
|
||||||
|
for d_in in range(max_degree + 1):
|
||||||
|
sum_freq = sum([degree_to_dim(min(d, d_in)) for d in range(max_degree + 1)])
|
||||||
|
basis_fused = torch.zeros(num_edges, degree_to_dim(d_in), sum_freq, sum_dim,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
acc_d, acc_f = 0, 0
|
||||||
|
for d_out in range(max_degree + 1):
|
||||||
|
basis_fused[:, :, acc_f:acc_f + degree_to_dim(min(d_out, d_in)), acc_d:acc_d + degree_to_dim(d_out)] \
|
||||||
|
= basis[f'{d_in},{d_out}'][:, :, :, :degree_to_dim(d_out)]
|
||||||
|
|
||||||
|
acc_d += degree_to_dim(d_out)
|
||||||
|
acc_f += degree_to_dim(min(d_out, d_in))
|
||||||
|
|
||||||
|
basis[f'in{d_in}_fused'] = basis_fused
|
||||||
|
|
||||||
|
if fully_fused:
|
||||||
|
# Fully fused
|
||||||
|
# Double sum this way because of JIT script
|
||||||
|
sum_freq = sum([
|
||||||
|
sum([degree_to_dim(min(d_in, d_out)) for d_in in range(max_degree + 1)]) for d_out in range(max_degree + 1)
|
||||||
|
])
|
||||||
|
basis_fused = torch.zeros(num_edges, sum_dim, sum_freq, sum_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
acc_d, acc_f = 0, 0
|
||||||
|
for d_out in range(max_degree + 1):
|
||||||
|
b = basis[f'out{d_out}_fused']
|
||||||
|
basis_fused[:, :, acc_f:acc_f + b.shape[2], acc_d:acc_d + degree_to_dim(d_out)] = b[:, :, :,
|
||||||
|
:degree_to_dim(d_out)]
|
||||||
|
acc_f += b.shape[2]
|
||||||
|
acc_d += degree_to_dim(d_out)
|
||||||
|
|
||||||
|
basis['fully_fused'] = basis_fused
|
||||||
|
|
||||||
|
del basis['0,0'] # We know that the basis for l = k = 0 is filled with a constant
|
||||||
|
return basis
|
||||||
|
|
||||||
|
|
||||||
|
def get_basis(relative_pos: Tensor,
|
||||||
|
max_degree: int = 4,
|
||||||
|
compute_gradients: bool = False,
|
||||||
|
use_pad_trick: bool = False,
|
||||||
|
amp: bool = False) -> Dict[str, Tensor]:
|
||||||
|
with nvtx_range('spherical harmonics'):
|
||||||
|
spherical_harmonics = get_spherical_harmonics(relative_pos, max_degree)
|
||||||
|
with nvtx_range('CB coefficients'):
|
||||||
|
clebsch_gordon = get_all_clebsch_gordon(max_degree, relative_pos.device)
|
||||||
|
|
||||||
|
with torch.autograd.set_grad_enabled(compute_gradients):
|
||||||
|
with nvtx_range('bases'):
|
||||||
|
basis = get_basis_script(max_degree=max_degree,
|
||||||
|
use_pad_trick=use_pad_trick,
|
||||||
|
spherical_harmonics=spherical_harmonics,
|
||||||
|
clebsch_gordon=clebsch_gordon,
|
||||||
|
amp=amp)
|
||||||
|
return basis
|
144
rf2aa/SE3Transformer/se3_transformer/model/fiber.py
Normal file
144
rf2aa/SE3Transformer/se3_transformer/model/fiber.py
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
from collections import namedtuple
|
||||||
|
from itertools import product
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim
|
||||||
|
|
||||||
|
FiberEl = namedtuple('FiberEl', ['degree', 'channels'])
|
||||||
|
|
||||||
|
|
||||||
|
class Fiber(dict):
|
||||||
|
"""
|
||||||
|
Describes the structure of some set of features.
|
||||||
|
Features are split into types (0, 1, 2, 3, ...). A feature of type k has a dimension of 2k+1.
|
||||||
|
Type-0 features: invariant scalars
|
||||||
|
Type-1 features: equivariant 3D vectors
|
||||||
|
Type-2 features: equivariant symmetric traceless matrices
|
||||||
|
...
|
||||||
|
|
||||||
|
As inputs to a SE3 layer, there can be many features of the same types, and many features of different types.
|
||||||
|
The 'multiplicity' or 'number of channels' is the number of features of a given type.
|
||||||
|
This class puts together all the degrees and their multiplicities in order to describe
|
||||||
|
the inputs, outputs or hidden features of SE3 layers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, structure):
|
||||||
|
if isinstance(structure, dict):
|
||||||
|
structure = [FiberEl(int(d), int(m)) for d, m in sorted(structure.items(), key=lambda x: x[1])]
|
||||||
|
elif not isinstance(structure[0], FiberEl):
|
||||||
|
structure = list(map(lambda t: FiberEl(*t), sorted(structure, key=lambda x: x[1])))
|
||||||
|
self.structure = structure
|
||||||
|
super().__init__({d: m for d, m in self.structure})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def degrees(self):
|
||||||
|
return sorted([t.degree for t in self.structure])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def channels(self):
|
||||||
|
return [self[d] for d in self.degrees]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_features(self):
|
||||||
|
""" Size of the resulting tensor if all features were concatenated together """
|
||||||
|
return sum(t.channels * degree_to_dim(t.degree) for t in self.structure)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create(num_degrees: int, num_channels: int):
|
||||||
|
""" Create a Fiber with degrees 0..num_degrees-1, all with the same multiplicity """
|
||||||
|
return Fiber([(degree, num_channels) for degree in range(num_degrees)])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_features(feats: Dict[str, Tensor]):
|
||||||
|
""" Infer the Fiber structure from a feature dict """
|
||||||
|
structure = {}
|
||||||
|
for k, v in feats.items():
|
||||||
|
degree = int(k)
|
||||||
|
assert len(v.shape) == 3, 'Feature shape should be (N, C, 2D+1)'
|
||||||
|
assert v.shape[-1] == degree_to_dim(degree)
|
||||||
|
structure[degree] = v.shape[-2]
|
||||||
|
return Fiber(structure)
|
||||||
|
|
||||||
|
def __getitem__(self, degree: int):
|
||||||
|
""" fiber[degree] returns the multiplicity for this degree """
|
||||||
|
return dict(self.structure).get(degree, 0)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
""" Iterate over namedtuples (degree, channels) """
|
||||||
|
return iter(self.structure)
|
||||||
|
|
||||||
|
def __mul__(self, other):
|
||||||
|
"""
|
||||||
|
If other in an int, multiplies all the multiplicities by other.
|
||||||
|
If other is a fiber, returns the cartesian product.
|
||||||
|
"""
|
||||||
|
if isinstance(other, Fiber):
|
||||||
|
return product(self.structure, other.structure)
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return Fiber({t.degree: t.channels * other for t in self.structure})
|
||||||
|
|
||||||
|
def __add__(self, other):
|
||||||
|
"""
|
||||||
|
If other in an int, add other to all the multiplicities.
|
||||||
|
If other is a fiber, add the multiplicities of the fibers together.
|
||||||
|
"""
|
||||||
|
if isinstance(other, Fiber):
|
||||||
|
return Fiber({t.degree: t.channels + other[t.degree] for t in self.structure})
|
||||||
|
elif isinstance(other, int):
|
||||||
|
return Fiber({t.degree: t.channels + other for t in self.structure})
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.structure)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_max(f1, f2):
|
||||||
|
""" Combine two fiber by taking the maximum multiplicity for each degree in both fibers """
|
||||||
|
new_dict = dict(f1.structure)
|
||||||
|
for k, m in f2.structure:
|
||||||
|
new_dict[k] = max(new_dict.get(k, 0), m)
|
||||||
|
|
||||||
|
return Fiber(list(new_dict.items()))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def combine_selectively(f1, f2):
|
||||||
|
""" Combine two fiber by taking the sum of multiplicities for each degree in the first fiber """
|
||||||
|
# only use orders which occur in fiber f1
|
||||||
|
new_dict = dict(f1.structure)
|
||||||
|
for k in f1.degrees:
|
||||||
|
if k in f2.degrees:
|
||||||
|
new_dict[k] += f2[k]
|
||||||
|
return Fiber(list(new_dict.items()))
|
||||||
|
|
||||||
|
def to_attention_heads(self, tensors: Dict[str, Tensor], num_heads: int):
|
||||||
|
# dict(N, num_channels, 2d+1) -> (N, num_heads, -1)
|
||||||
|
fibers = [tensors[str(degree)].reshape(*tensors[str(degree)].shape[:-2], num_heads, -1) for degree in
|
||||||
|
self.degrees]
|
||||||
|
fibers = torch.cat(fibers, -1)
|
||||||
|
return fibers
|
|
@ -0,0 +1,5 @@
|
||||||
|
from .linear import LinearSE3
|
||||||
|
from .norm import NormSE3
|
||||||
|
from .pooling import GPooling
|
||||||
|
from .convolution import ConvSE3
|
||||||
|
from .attention import AttentionBlockSE3
|
186
rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py
Normal file
186
rf2aa/SE3Transformer/se3_transformer/model/layers/attention.py
Normal file
|
@ -0,0 +1,186 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import dgl
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from dgl import DGLGraph
|
||||||
|
from dgl.ops import edge_softmax
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, aggregate_residual, unfuse_features
|
||||||
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionSE3(nn.Module):
|
||||||
|
""" Multi-headed sparse graph self-attention (SE(3)-equivariant) """
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
key_fiber: Fiber,
|
||||||
|
value_fiber: Fiber
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param num_heads: Number of attention heads
|
||||||
|
:param key_fiber: Fiber for the keys (and also for the queries)
|
||||||
|
:param value_fiber: Fiber for the values
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.key_fiber = key_fiber
|
||||||
|
self.value_fiber = value_fiber
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
value: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
||||||
|
key: Union[Tensor, Dict[str, Tensor]], # edge features (may be fused)
|
||||||
|
query: Dict[str, Tensor], # node features
|
||||||
|
graph: DGLGraph
|
||||||
|
):
|
||||||
|
with nvtx_range('AttentionSE3'):
|
||||||
|
with nvtx_range('reshape keys and queries'):
|
||||||
|
if isinstance(key, Tensor):
|
||||||
|
# case where features of all types are fused
|
||||||
|
key = key.reshape(key.shape[0], self.num_heads, -1)
|
||||||
|
# need to reshape queries that way to keep the same layout as keys
|
||||||
|
out = torch.cat([query[str(d)] for d in self.key_fiber.degrees], dim=-1)
|
||||||
|
query = out.reshape(list(query.values())[0].shape[0], self.num_heads, -1)
|
||||||
|
else:
|
||||||
|
# features are not fused, need to fuse and reshape them
|
||||||
|
key = self.key_fiber.to_attention_heads(key, self.num_heads)
|
||||||
|
query = self.key_fiber.to_attention_heads(query, self.num_heads)
|
||||||
|
|
||||||
|
with nvtx_range('attention dot product + softmax'):
|
||||||
|
# Compute attention weights (softmax of inner product between key and query)
|
||||||
|
with torch.cuda.amp.autocast(False):
|
||||||
|
edge_weights = dgl.ops.e_dot_v(graph, key, query).squeeze(-1)
|
||||||
|
edge_weights /= np.sqrt(self.key_fiber.num_features)
|
||||||
|
edge_weights = edge_softmax(graph, edge_weights)
|
||||||
|
edge_weights = edge_weights[..., None, None]
|
||||||
|
|
||||||
|
with nvtx_range('weighted sum'):
|
||||||
|
if isinstance(value, Tensor):
|
||||||
|
# features of all types are fused
|
||||||
|
v = value.view(value.shape[0], self.num_heads, -1, value.shape[-1])
|
||||||
|
weights = edge_weights * v
|
||||||
|
feat_out = dgl.ops.copy_e_sum(graph, weights)
|
||||||
|
feat_out = feat_out.view(feat_out.shape[0], -1, feat_out.shape[-1]) # merge heads
|
||||||
|
out = unfuse_features(feat_out, self.value_fiber.degrees)
|
||||||
|
else:
|
||||||
|
out = {}
|
||||||
|
for degree, channels in self.value_fiber:
|
||||||
|
v = value[str(degree)].view(-1, self.num_heads, channels // self.num_heads,
|
||||||
|
degree_to_dim(degree))
|
||||||
|
weights = edge_weights * v
|
||||||
|
res = dgl.ops.copy_e_sum(graph, weights)
|
||||||
|
out[str(degree)] = res.view(-1, channels, degree_to_dim(degree)) # merge heads
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionBlockSE3(nn.Module):
|
||||||
|
""" Multi-headed sparse graph self-attention block with skip connection, linear projection (SE(3)-equivariant) """
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fiber_in: Fiber,
|
||||||
|
fiber_out: Fiber,
|
||||||
|
fiber_edge: Optional[Fiber] = None,
|
||||||
|
num_heads: int = 4,
|
||||||
|
channels_div: Optional[Dict[str,int]] = None,
|
||||||
|
use_layer_norm: bool = False,
|
||||||
|
max_degree: bool = 4,
|
||||||
|
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param fiber_in: Fiber describing the input features
|
||||||
|
:param fiber_out: Fiber describing the output features
|
||||||
|
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
||||||
|
:param num_heads: Number of attention heads
|
||||||
|
:param channels_div: Divide the channels by this integer for computing values
|
||||||
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
|
:param max_degree: Maximum degree used in the bases computation
|
||||||
|
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if fiber_edge is None:
|
||||||
|
fiber_edge = Fiber({})
|
||||||
|
self.fiber_in = fiber_in
|
||||||
|
# value_fiber has same structure as fiber_out but #channels divided by 'channels_div'
|
||||||
|
if channels_div is not None:
|
||||||
|
value_fiber = Fiber([(degree, channels // channels_div[str(degree)]) for degree, channels in fiber_out])
|
||||||
|
else:
|
||||||
|
value_fiber = Fiber([(degree, channels) for degree, channels in fiber_out])
|
||||||
|
|
||||||
|
# key_query_fiber has the same structure as fiber_out, but only degrees which are in in_fiber
|
||||||
|
# (queries are merely projected, hence degrees have to match input)
|
||||||
|
key_query_fiber = Fiber([(fe.degree, fe.channels) for fe in value_fiber if fe.degree in fiber_in.degrees])
|
||||||
|
|
||||||
|
self.to_key_value = ConvSE3(fiber_in, value_fiber + key_query_fiber, pool=False, fiber_edge=fiber_edge,
|
||||||
|
use_layer_norm=use_layer_norm, max_degree=max_degree, fuse_level=fuse_level,
|
||||||
|
allow_fused_output=True)
|
||||||
|
self.to_query = LinearSE3(fiber_in, key_query_fiber)
|
||||||
|
self.attention = AttentionSE3(num_heads, key_query_fiber, value_fiber)
|
||||||
|
self.project = LinearSE3(value_fiber + fiber_in, fiber_out)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
node_features: Dict[str, Tensor],
|
||||||
|
edge_features: Dict[str, Tensor],
|
||||||
|
graph: DGLGraph,
|
||||||
|
basis: Dict[str, Tensor]
|
||||||
|
):
|
||||||
|
with nvtx_range('AttentionBlockSE3'):
|
||||||
|
with nvtx_range('keys / values'):
|
||||||
|
fused_key_value = self.to_key_value(node_features, edge_features, graph, basis)
|
||||||
|
key, value = self._get_key_value_from_fused(fused_key_value)
|
||||||
|
|
||||||
|
with nvtx_range('queries'):
|
||||||
|
with torch.cuda.amp.autocast(False):
|
||||||
|
query = self.to_query(node_features)
|
||||||
|
|
||||||
|
z = self.attention(value, key, query, graph)
|
||||||
|
z_concat = aggregate_residual(node_features, z, 'cat')
|
||||||
|
return self.project(z_concat)
|
||||||
|
|
||||||
|
def _get_key_value_from_fused(self, fused_key_value):
|
||||||
|
# Extract keys and queries features from fused features
|
||||||
|
if isinstance(fused_key_value, Tensor):
|
||||||
|
# Previous layer was a fully fused convolution
|
||||||
|
value, key = torch.chunk(fused_key_value, chunks=2, dim=-2)
|
||||||
|
else:
|
||||||
|
key, value = {}, {}
|
||||||
|
for degree, feat in fused_key_value.items():
|
||||||
|
if int(degree) in self.fiber_in.degrees:
|
||||||
|
value[degree], key[degree] = torch.chunk(feat, chunks=2, dim=-2)
|
||||||
|
else:
|
||||||
|
value[degree] = feat
|
||||||
|
|
||||||
|
return key, value
|
381
rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py
Normal file
381
rf2aa/SE3Transformer/se3_transformer/model/layers/convolution.py
Normal file
|
@ -0,0 +1,381 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from itertools import product
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import dgl
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from dgl import DGLGraph
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import degree_to_dim, unfuse_features
|
||||||
|
|
||||||
|
|
||||||
|
class ConvSE3FuseLevel(Enum):
|
||||||
|
"""
|
||||||
|
Enum to select a maximum level of fusing optimizations that will be applied when certain conditions are met.
|
||||||
|
If a desired level L is picked and the level L cannot be applied to a level, other fused ops < L are considered.
|
||||||
|
A higher level means faster training, but also more memory usage.
|
||||||
|
If you are tight on memory and want to feed large inputs to the network, choose a low value.
|
||||||
|
If you want to train fast, choose a high value.
|
||||||
|
Recommended value is FULL with AMP.
|
||||||
|
|
||||||
|
Fully fused TFN convolutions requirements:
|
||||||
|
- all input channels are the same
|
||||||
|
- all output channels are the same
|
||||||
|
- input degrees span the range [0, ..., max_degree]
|
||||||
|
- output degrees span the range [0, ..., max_degree]
|
||||||
|
|
||||||
|
Partially fused TFN convolutions requirements:
|
||||||
|
* For fusing by output degree:
|
||||||
|
- all input channels are the same
|
||||||
|
- input degrees span the range [0, ..., max_degree]
|
||||||
|
* For fusing by input degree:
|
||||||
|
- all output channels are the same
|
||||||
|
- output degrees span the range [0, ..., max_degree]
|
||||||
|
|
||||||
|
Original TFN pairwise convolutions: no requirements
|
||||||
|
"""
|
||||||
|
|
||||||
|
FULL = 2
|
||||||
|
PARTIAL = 1
|
||||||
|
NONE = 0
|
||||||
|
|
||||||
|
|
||||||
|
class RadialProfile(nn.Module):
|
||||||
|
"""
|
||||||
|
Radial profile function.
|
||||||
|
Outputs weights used to weigh basis matrices in order to get convolution kernels.
|
||||||
|
In TFN notation: $R^{l,k}$
|
||||||
|
In SE(3)-Transformer notation: $\phi^{l,k}$
|
||||||
|
|
||||||
|
Note:
|
||||||
|
In the original papers, this function only depends on relative node distances ||x||.
|
||||||
|
Here, we allow this function to also take as input additional invariant edge features.
|
||||||
|
This does not break equivariance and adds expressive power to the model.
|
||||||
|
|
||||||
|
Diagram:
|
||||||
|
invariant edge features (node distances included) ───> MLP layer (shared across edges) ───> radial weights
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_freq: int,
|
||||||
|
channels_in: int,
|
||||||
|
channels_out: int,
|
||||||
|
edge_dim: int = 1,
|
||||||
|
mid_dim: int = 32,
|
||||||
|
use_layer_norm: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param num_freq: Number of frequencies
|
||||||
|
:param channels_in: Number of input channels
|
||||||
|
:param channels_out: Number of output channels
|
||||||
|
:param edge_dim: Number of invariant edge features (input to the radial function)
|
||||||
|
:param mid_dim: Size of the hidden MLP layers
|
||||||
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
modules = [
|
||||||
|
nn.Linear(edge_dim, mid_dim),
|
||||||
|
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(mid_dim, mid_dim),
|
||||||
|
nn.LayerNorm(mid_dim) if use_layer_norm else None,
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(mid_dim, num_freq * channels_in * channels_out, bias=False)
|
||||||
|
]
|
||||||
|
|
||||||
|
self.net = nn.Sequential(*[m for m in modules if m is not None])
|
||||||
|
|
||||||
|
def forward(self, features: Tensor) -> Tensor:
|
||||||
|
return self.net(features)
|
||||||
|
|
||||||
|
|
||||||
|
class VersatileConvSE3(nn.Module):
|
||||||
|
"""
|
||||||
|
Building block for TFN convolutions.
|
||||||
|
This single module can be used for fully fused convolutions, partially fused convolutions, or pairwise convolutions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
freq_sum: int,
|
||||||
|
channels_in: int,
|
||||||
|
channels_out: int,
|
||||||
|
edge_dim: int,
|
||||||
|
use_layer_norm: bool,
|
||||||
|
fuse_level: ConvSE3FuseLevel):
|
||||||
|
super().__init__()
|
||||||
|
self.freq_sum = freq_sum
|
||||||
|
self.channels_out = channels_out
|
||||||
|
self.channels_in = channels_in
|
||||||
|
self.fuse_level = fuse_level
|
||||||
|
self.radial_func = RadialProfile(num_freq=freq_sum,
|
||||||
|
channels_in=channels_in,
|
||||||
|
channels_out=channels_out,
|
||||||
|
edge_dim=edge_dim,
|
||||||
|
use_layer_norm=use_layer_norm)
|
||||||
|
|
||||||
|
def forward(self, features: Tensor, invariant_edge_feats: Tensor, basis: Tensor):
|
||||||
|
with nvtx_range(f'VersatileConvSE3'):
|
||||||
|
num_edges = features.shape[0]
|
||||||
|
in_dim = features.shape[2]
|
||||||
|
if (self.training or num_edges<=4096):
|
||||||
|
with nvtx_range(f'RadialProfile'):
|
||||||
|
radial_weights = self.radial_func(invariant_edge_feats) \
|
||||||
|
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
||||||
|
|
||||||
|
if basis is not None:
|
||||||
|
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
||||||
|
out_dim = basis.shape[-1]
|
||||||
|
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
||||||
|
out_dim += out_dim % 2 - 1 # Account for padded basis
|
||||||
|
basis_view = basis.view(num_edges, in_dim, -1)
|
||||||
|
tmp = (features @ basis_view).view(num_edges, -1, basis.shape[-1])
|
||||||
|
retval = (radial_weights @ tmp)[:, :, :out_dim]
|
||||||
|
return retval
|
||||||
|
else:
|
||||||
|
# k = l = 0 non-fused case
|
||||||
|
retval = radial_weights @ features
|
||||||
|
|
||||||
|
else:
|
||||||
|
#fd reduce memory in inference
|
||||||
|
EDGESTRIDE = 65536 #16384
|
||||||
|
if basis is not None:
|
||||||
|
out_dim = basis.shape[-1]
|
||||||
|
if self.fuse_level != ConvSE3FuseLevel.FULL:
|
||||||
|
out_dim += out_dim % 2 - 1 # Account for padded basis
|
||||||
|
else:
|
||||||
|
out_dim = features.shape[-1]
|
||||||
|
|
||||||
|
retval = torch.zeros(
|
||||||
|
(num_edges, self.channels_out, out_dim),
|
||||||
|
dtype=features.dtype,
|
||||||
|
device=features.device
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range((num_edges-1)//EDGESTRIDE+1):
|
||||||
|
e_i,e_j = i*EDGESTRIDE, min((i+1)*EDGESTRIDE,num_edges)
|
||||||
|
|
||||||
|
radial_weights = self.radial_func(invariant_edge_feats[e_i:e_j]) \
|
||||||
|
.view(-1, self.channels_out, self.channels_in * self.freq_sum)
|
||||||
|
|
||||||
|
if basis is not None:
|
||||||
|
# This block performs the einsum n i l, n o i f, n l f k -> n o k
|
||||||
|
basis_view = basis[e_i:e_j].view(e_j-e_i, in_dim, -1)
|
||||||
|
with torch.cuda.amp.autocast(False):
|
||||||
|
tmp = (features[e_i:e_j] @ basis_view.float()).view(e_j-e_i, -1, basis.shape[-1])
|
||||||
|
retslice = (radial_weights.float() @ tmp)[:, :, :out_dim]
|
||||||
|
retval[e_i:e_j] = retslice
|
||||||
|
|
||||||
|
else:
|
||||||
|
# k = l = 0 non-fused case
|
||||||
|
retval[e_i:e_j] = radial_weights @ features[e_i:e_j]
|
||||||
|
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class ConvSE3(nn.Module):
|
||||||
|
"""
|
||||||
|
SE(3)-equivariant graph convolution (Tensor Field Network convolution).
|
||||||
|
This convolution can map an arbitrary input Fiber to an arbitrary output Fiber, while preserving equivariance.
|
||||||
|
Features of different degrees interact together to produce output features.
|
||||||
|
|
||||||
|
Note 1:
|
||||||
|
The option is given to not pool the output. This means that the convolution sum over neighbors will not be
|
||||||
|
done, and the returned features will be edge features instead of node features.
|
||||||
|
|
||||||
|
Note 2:
|
||||||
|
Unlike the original paper and implementation, this convolution can handle edge feature of degree greater than 0.
|
||||||
|
Input edge features are concatenated with input source node features before the kernel is applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
fiber_in: Fiber,
|
||||||
|
fiber_out: Fiber,
|
||||||
|
fiber_edge: Fiber,
|
||||||
|
pool: bool = True,
|
||||||
|
use_layer_norm: bool = False,
|
||||||
|
self_interaction: bool = False,
|
||||||
|
sum_over_edge: bool = True,
|
||||||
|
max_degree: int = 4,
|
||||||
|
fuse_level: ConvSE3FuseLevel = ConvSE3FuseLevel.FULL,
|
||||||
|
allow_fused_output: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
:param fiber_in: Fiber describing the input features
|
||||||
|
:param fiber_out: Fiber describing the output features
|
||||||
|
:param fiber_edge: Fiber describing the edge features (node distances excluded)
|
||||||
|
:param pool: If True, compute final node features by averaging incoming edge features
|
||||||
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
|
:param self_interaction: Apply self-interaction of nodes
|
||||||
|
:param max_degree: Maximum degree used in the bases computation
|
||||||
|
:param fuse_level: Maximum fuse level to use in TFN convolutions
|
||||||
|
:param allow_fused_output: Allow the module to output a fused representation of features
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.pool = pool
|
||||||
|
self.fiber_in = fiber_in
|
||||||
|
self.fiber_out = fiber_out
|
||||||
|
self.self_interaction = self_interaction
|
||||||
|
self.sum_over_edge = sum_over_edge
|
||||||
|
self.max_degree = max_degree
|
||||||
|
self.allow_fused_output = allow_fused_output
|
||||||
|
|
||||||
|
# channels_in: account for the concatenation of edge features
|
||||||
|
channels_in_set = set([f.channels + fiber_edge[f.degree] * (f.degree > 0) for f in self.fiber_in])
|
||||||
|
channels_out_set = set([f.channels for f in self.fiber_out])
|
||||||
|
unique_channels_in = (len(channels_in_set) == 1)
|
||||||
|
unique_channels_out = (len(channels_out_set) == 1)
|
||||||
|
degrees_up_to_max = list(range(max_degree + 1))
|
||||||
|
common_args = dict(edge_dim=fiber_edge[0] + 1, use_layer_norm=use_layer_norm)
|
||||||
|
|
||||||
|
if fuse_level.value >= ConvSE3FuseLevel.FULL.value and \
|
||||||
|
unique_channels_in and fiber_in.degrees == degrees_up_to_max and \
|
||||||
|
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
||||||
|
# Single fused convolution
|
||||||
|
self.used_fuse_level = ConvSE3FuseLevel.FULL
|
||||||
|
|
||||||
|
sum_freq = sum([
|
||||||
|
degree_to_dim(min(d_in, d_out))
|
||||||
|
for d_in, d_out in product(degrees_up_to_max, degrees_up_to_max)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.conv = VersatileConvSE3(sum_freq, list(channels_in_set)[0], list(channels_out_set)[0],
|
||||||
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
|
|
||||||
|
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
||||||
|
unique_channels_in and fiber_in.degrees == degrees_up_to_max:
|
||||||
|
# Convolutions fused per output degree
|
||||||
|
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
||||||
|
self.conv_out = nn.ModuleDict()
|
||||||
|
for d_out, c_out in fiber_out:
|
||||||
|
sum_freq = sum([degree_to_dim(min(d_out, d)) for d in fiber_in.degrees])
|
||||||
|
self.conv_out[str(d_out)] = VersatileConvSE3(sum_freq, list(channels_in_set)[0], c_out,
|
||||||
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
|
|
||||||
|
elif fuse_level.value >= ConvSE3FuseLevel.PARTIAL.value and \
|
||||||
|
unique_channels_out and fiber_out.degrees == degrees_up_to_max:
|
||||||
|
# Convolutions fused per input degree
|
||||||
|
self.used_fuse_level = ConvSE3FuseLevel.PARTIAL
|
||||||
|
self.conv_in = nn.ModuleDict()
|
||||||
|
for d_in, c_in in fiber_in:
|
||||||
|
sum_freq = sum([degree_to_dim(min(d_in, d)) for d in fiber_out.degrees])
|
||||||
|
self.conv_in[str(d_in)] = VersatileConvSE3(sum_freq, c_in, list(channels_out_set)[0],
|
||||||
|
fuse_level=ConvSE3FuseLevel.FULL, **common_args)
|
||||||
|
else:
|
||||||
|
# Use pairwise TFN convolutions
|
||||||
|
self.used_fuse_level = ConvSE3FuseLevel.NONE
|
||||||
|
self.conv = nn.ModuleDict()
|
||||||
|
for (degree_in, channels_in), (degree_out, channels_out) in (self.fiber_in * self.fiber_out):
|
||||||
|
dict_key = f'{degree_in},{degree_out}'
|
||||||
|
channels_in_new = channels_in + fiber_edge[degree_in] * (degree_in > 0)
|
||||||
|
sum_freq = degree_to_dim(min(degree_in, degree_out))
|
||||||
|
self.conv[dict_key] = VersatileConvSE3(sum_freq, channels_in_new, channels_out,
|
||||||
|
fuse_level=self.used_fuse_level, **common_args)
|
||||||
|
|
||||||
|
if self_interaction:
|
||||||
|
self.to_kernel_self = nn.ParameterDict()
|
||||||
|
for degree_out, channels_out in fiber_out:
|
||||||
|
if fiber_in[degree_out]:
|
||||||
|
self.to_kernel_self[str(degree_out)] = nn.Parameter(
|
||||||
|
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
node_feats: Dict[str, Tensor],
|
||||||
|
edge_feats: Dict[str, Tensor],
|
||||||
|
graph: DGLGraph,
|
||||||
|
basis: Dict[str, Tensor]
|
||||||
|
):
|
||||||
|
with nvtx_range(f'ConvSE3'):
|
||||||
|
invariant_edge_feats = edge_feats['0'].squeeze(-1)
|
||||||
|
src, dst = graph.edges()
|
||||||
|
out = {}
|
||||||
|
in_features = []
|
||||||
|
|
||||||
|
# Fetch all input features from edge and node features
|
||||||
|
for degree_in in self.fiber_in.degrees:
|
||||||
|
src_node_features = node_feats[str(degree_in)][src]
|
||||||
|
if degree_in > 0 and str(degree_in) in edge_feats:
|
||||||
|
# Handle edge features of any type by concatenating them to node features
|
||||||
|
src_node_features = torch.cat([src_node_features, edge_feats[str(degree_in)]], dim=1)
|
||||||
|
in_features.append(src_node_features)
|
||||||
|
|
||||||
|
if self.used_fuse_level == ConvSE3FuseLevel.FULL:
|
||||||
|
in_features_fused = torch.cat(in_features, dim=-1)
|
||||||
|
out = self.conv(in_features_fused, invariant_edge_feats, basis['fully_fused'])
|
||||||
|
|
||||||
|
if not self.allow_fused_output or self.self_interaction or self.pool:
|
||||||
|
out = unfuse_features(out, self.fiber_out.degrees)
|
||||||
|
|
||||||
|
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_out'):
|
||||||
|
in_features_fused = torch.cat(in_features, dim=-1)
|
||||||
|
for degree_out in self.fiber_out.degrees:
|
||||||
|
out[str(degree_out)] = self.conv_out[str(degree_out)](in_features_fused, invariant_edge_feats,
|
||||||
|
basis[f'out{degree_out}_fused'])
|
||||||
|
|
||||||
|
elif self.used_fuse_level == ConvSE3FuseLevel.PARTIAL and hasattr(self, 'conv_in'):
|
||||||
|
out = 0
|
||||||
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
|
out += self.conv_in[str(degree_in)](feature, invariant_edge_feats,
|
||||||
|
basis[f'in{degree_in}_fused'])
|
||||||
|
if not self.allow_fused_output or self.self_interaction or self.pool:
|
||||||
|
out = unfuse_features(out, self.fiber_out.degrees)
|
||||||
|
else:
|
||||||
|
# Fallback to pairwise TFN convolutions
|
||||||
|
for degree_out in self.fiber_out.degrees:
|
||||||
|
out_feature = 0
|
||||||
|
for degree_in, feature in zip(self.fiber_in.degrees, in_features):
|
||||||
|
dict_key = f'{degree_in},{degree_out}'
|
||||||
|
out_feature = out_feature + self.conv[dict_key](feature, invariant_edge_feats,
|
||||||
|
basis.get(dict_key, None))
|
||||||
|
out[str(degree_out)] = out_feature
|
||||||
|
|
||||||
|
for degree_out in self.fiber_out.degrees:
|
||||||
|
if self.self_interaction and str(degree_out) in self.to_kernel_self:
|
||||||
|
with nvtx_range(f'self interaction'):
|
||||||
|
dst_features = node_feats[str(degree_out)][dst]
|
||||||
|
kernel_self = self.to_kernel_self[str(degree_out)]
|
||||||
|
out[str(degree_out)] += kernel_self @ dst_features
|
||||||
|
|
||||||
|
if self.pool:
|
||||||
|
if self.sum_over_edge:
|
||||||
|
with nvtx_range(f'pooling'):
|
||||||
|
if isinstance(out, dict):
|
||||||
|
out[str(degree_out)] = dgl.ops.copy_e_sum(graph, out[str(degree_out)])
|
||||||
|
else:
|
||||||
|
out = dgl.ops.copy_e_sum(graph, out)
|
||||||
|
else:
|
||||||
|
with nvtx_range(f'pooling'):
|
||||||
|
if isinstance(out, dict):
|
||||||
|
out[str(degree_out)] = dgl.ops.copy_e_mean(graph, out[str(degree_out)])
|
||||||
|
else:
|
||||||
|
out = dgl.ops.copy_e_mean(graph, out)
|
||||||
|
return out
|
59
rf2aa/SE3Transformer/se3_transformer/model/layers/linear.py
Normal file
59
rf2aa/SE3Transformer/se3_transformer/model/layers/linear.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
|
class LinearSE3(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph Linear SE(3)-equivariant layer, equivalent to a 1x1 convolution.
|
||||||
|
Maps a fiber to a fiber with the same degrees (channels may be different).
|
||||||
|
No interaction between degrees, but interaction between channels.
|
||||||
|
|
||||||
|
type-0 features (C_0 channels) ────> Linear(bias=False) ────> type-0 features (C'_0 channels)
|
||||||
|
type-1 features (C_1 channels) ────> Linear(bias=False) ────> type-1 features (C'_1 channels)
|
||||||
|
:
|
||||||
|
type-k features (C_k channels) ────> Linear(bias=False) ────> type-k features (C'_k channels)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, fiber_in: Fiber, fiber_out: Fiber):
|
||||||
|
super().__init__()
|
||||||
|
self.weights = nn.ParameterDict({
|
||||||
|
str(degree_out): nn.Parameter(
|
||||||
|
torch.randn(channels_out, fiber_in[degree_out]) / np.sqrt(fiber_in[degree_out]))
|
||||||
|
for degree_out, channels_out in fiber_out
|
||||||
|
})
|
||||||
|
|
||||||
|
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
|
return {
|
||||||
|
degree: self.weights[degree] @ features[degree]
|
||||||
|
for degree, weight in self.weights.items()
|
||||||
|
}
|
83
rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py
Normal file
83
rf2aa/SE3Transformer/se3_transformer/model/layers/norm.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.cuda.nvtx import range as nvtx_range
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
|
class NormSE3(nn.Module):
|
||||||
|
"""
|
||||||
|
Norm-based SE(3)-equivariant nonlinearity.
|
||||||
|
|
||||||
|
┌──> feature_norm ──> LayerNorm() ──> ReLU() ──┐
|
||||||
|
feature_in ──┤ * ──> feature_out
|
||||||
|
└──> feature_phase ────────────────────────────┘
|
||||||
|
"""
|
||||||
|
|
||||||
|
NORM_CLAMP = 2 ** -24 # Minimum positive subnormal for FP16
|
||||||
|
|
||||||
|
def __init__(self, fiber: Fiber, nonlinearity: nn.Module = nn.ReLU()):
|
||||||
|
super().__init__()
|
||||||
|
self.fiber = fiber
|
||||||
|
self.nonlinearity = nonlinearity
|
||||||
|
|
||||||
|
if len(set(fiber.channels)) == 1:
|
||||||
|
# Fuse all the layer normalizations into a group normalization
|
||||||
|
self.group_norm = nn.GroupNorm(num_groups=len(fiber.degrees), num_channels=sum(fiber.channels))
|
||||||
|
else:
|
||||||
|
# Use multiple layer normalizations
|
||||||
|
self.layer_norms = nn.ModuleDict({
|
||||||
|
str(degree): nn.LayerNorm(channels)
|
||||||
|
for degree, channels in fiber
|
||||||
|
})
|
||||||
|
|
||||||
|
def forward(self, features: Dict[str, Tensor], *args, **kwargs) -> Dict[str, Tensor]:
|
||||||
|
with nvtx_range('NormSE3'):
|
||||||
|
output = {}
|
||||||
|
if hasattr(self, 'group_norm'):
|
||||||
|
# Compute per-degree norms of features
|
||||||
|
norms = [features[str(d)].norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
||||||
|
for d in self.fiber.degrees]
|
||||||
|
fused_norms = torch.cat(norms, dim=-2)
|
||||||
|
|
||||||
|
# Transform the norms only
|
||||||
|
new_norms = self.nonlinearity(self.group_norm(fused_norms.squeeze(-1))).unsqueeze(-1)
|
||||||
|
new_norms = torch.chunk(new_norms, chunks=len(self.fiber.degrees), dim=-2)
|
||||||
|
|
||||||
|
# Scale features to the new norms
|
||||||
|
for norm, new_norm, d in zip(norms, new_norms, self.fiber.degrees):
|
||||||
|
output[str(d)] = features[str(d)] / norm * new_norm
|
||||||
|
else:
|
||||||
|
for degree, feat in features.items():
|
||||||
|
norm = feat.norm(dim=-1, keepdim=True).clamp(min=self.NORM_CLAMP)
|
||||||
|
new_norm = self.nonlinearity(self.layer_norms[degree](norm.squeeze(-1)).unsqueeze(-1))
|
||||||
|
output[degree] = new_norm * feat / norm
|
||||||
|
|
||||||
|
return output
|
53
rf2aa/SE3Transformer/se3_transformer/model/layers/pooling.py
Normal file
53
rf2aa/SE3Transformer/se3_transformer/model/layers/pooling.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from typing import Dict, Literal
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from dgl import DGLGraph
|
||||||
|
from dgl.nn.pytorch import AvgPooling, MaxPooling
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class GPooling(nn.Module):
|
||||||
|
"""
|
||||||
|
Graph max/average pooling on a given feature type.
|
||||||
|
The average can be taken for any feature type, and equivariance will be maintained.
|
||||||
|
The maximum can only be taken for invariant features (type 0).
|
||||||
|
If you want max-pooling for type > 0 features, look into Vector Neurons.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, feat_type: int = 0, pool: Literal['max', 'avg'] = 'max'):
|
||||||
|
"""
|
||||||
|
:param feat_type: Feature type to pool
|
||||||
|
:param pool: Type of pooling: max or avg
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert pool in ['max', 'avg'], f'Unknown pooling: {pool}'
|
||||||
|
assert feat_type == 0 or pool == 'avg', 'Max pooling on type > 0 features will break equivariance'
|
||||||
|
self.feat_type = feat_type
|
||||||
|
self.pool = MaxPooling() if pool == 'max' else AvgPooling()
|
||||||
|
|
||||||
|
def forward(self, features: Dict[str, Tensor], graph: DGLGraph, **kwargs) -> Tensor:
|
||||||
|
pooled = self.pool(graph, features[str(self.feat_type)])
|
||||||
|
return pooled.squeeze(dim=-1)
|
257
rf2aa/SE3Transformer/se3_transformer/model/transformer.py
Normal file
257
rf2aa/SE3Transformer/se3_transformer/model/transformer.py
Normal file
|
@ -0,0 +1,257 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Literal, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from dgl import DGLGraph
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.basis import get_basis, update_basis_with_fused
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.attention import AttentionBlockSE3
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.linear import LinearSE3
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.norm import NormSE3
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.layers.pooling import GPooling
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
|
||||||
|
|
||||||
|
class Sequential(nn.Sequential):
|
||||||
|
""" Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
|
||||||
|
|
||||||
|
def forward(self, input, *args, **kwargs):
|
||||||
|
for module in self:
|
||||||
|
input = module(input, *args, **kwargs)
|
||||||
|
return input
|
||||||
|
|
||||||
|
|
||||||
|
def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
|
||||||
|
""" Add relative positions to existing edge features """
|
||||||
|
edge_features = edge_features.copy() if edge_features else {}
|
||||||
|
r = relative_pos.norm(dim=-1, keepdim=True)
|
||||||
|
if '0' in edge_features:
|
||||||
|
edge_features['0'] = torch.cat([edge_features['0'], r[..., None]], dim=1)
|
||||||
|
else:
|
||||||
|
edge_features['0'] = r[..., None]
|
||||||
|
|
||||||
|
return edge_features
|
||||||
|
|
||||||
|
|
||||||
|
class SE3Transformer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
num_layers: int,
|
||||||
|
fiber_in: Fiber,
|
||||||
|
fiber_hidden: Fiber,
|
||||||
|
fiber_out: Fiber,
|
||||||
|
num_heads: int,
|
||||||
|
channels_div: int,
|
||||||
|
fiber_edge: Fiber = Fiber({}),
|
||||||
|
return_type: Optional[int] = None,
|
||||||
|
pooling: Optional[Literal['avg', 'max']] = None,
|
||||||
|
final_layer: Optional[Literal['conv', 'lin', 'att']] = 'conv',
|
||||||
|
norm: bool = True,
|
||||||
|
use_layer_norm: bool = True,
|
||||||
|
tensor_cores: bool = False,
|
||||||
|
low_memory: bool = False,
|
||||||
|
populate_edge: Optional[Literal['lin', 'arcsin', 'log', 'zero']] = 'lin',
|
||||||
|
sum_over_edge: bool = True,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
:param num_layers: Number of attention layers
|
||||||
|
:param fiber_in: Input fiber description
|
||||||
|
:param fiber_hidden: Hidden fiber description
|
||||||
|
:param fiber_out: Output fiber description
|
||||||
|
:param fiber_edge: Input edge fiber description
|
||||||
|
:param num_heads: Number of attention heads
|
||||||
|
:param channels_div: Channels division before feeding to attention layer
|
||||||
|
:param return_type: Return only features of this type
|
||||||
|
:param pooling: 'avg' or 'max' graph pooling before MLP layers
|
||||||
|
:param norm: Apply a normalization layer after each attention block
|
||||||
|
:param use_layer_norm: Apply layer normalization between MLP layers
|
||||||
|
:param tensor_cores: True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
|
||||||
|
:param low_memory: If True, will use slower ops that use less memory
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.fiber_edge = fiber_edge
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.channels_div = channels_div
|
||||||
|
self.return_type = return_type
|
||||||
|
self.pooling = pooling
|
||||||
|
self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
|
||||||
|
self.tensor_cores = tensor_cores
|
||||||
|
self.low_memory = low_memory
|
||||||
|
self.populate_edge = populate_edge
|
||||||
|
|
||||||
|
if low_memory and not tensor_cores:
|
||||||
|
logging.warning('Low memory mode will have no effect with no Tensor Cores')
|
||||||
|
|
||||||
|
# Fully fused convolutions when using Tensor Cores (and not low memory mode)
|
||||||
|
fuse_level = ConvSE3FuseLevel.FULL if tensor_cores and not low_memory else ConvSE3FuseLevel.PARTIAL
|
||||||
|
|
||||||
|
div = dict((str(degree), channels_div) for degree in range(self.max_degree+1))
|
||||||
|
div_fin = dict((str(degree), 1) for degree in range(self.max_degree+1))
|
||||||
|
div_fin['0'] = channels_div
|
||||||
|
|
||||||
|
graph_modules = []
|
||||||
|
for i in range(num_layers):
|
||||||
|
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
||||||
|
fiber_out=fiber_hidden,
|
||||||
|
fiber_edge=fiber_edge,
|
||||||
|
num_heads=num_heads,
|
||||||
|
channels_div=div,
|
||||||
|
use_layer_norm=use_layer_norm,
|
||||||
|
max_degree=self.max_degree,
|
||||||
|
fuse_level=fuse_level))
|
||||||
|
if norm:
|
||||||
|
graph_modules.append(NormSE3(fiber_hidden))
|
||||||
|
fiber_in = fiber_hidden
|
||||||
|
|
||||||
|
if final_layer == 'conv':
|
||||||
|
graph_modules.append(ConvSE3(fiber_in=fiber_in,
|
||||||
|
fiber_out=fiber_out,
|
||||||
|
fiber_edge=fiber_edge,
|
||||||
|
self_interaction=True,
|
||||||
|
sum_over_edge=sum_over_edge,
|
||||||
|
use_layer_norm=use_layer_norm,
|
||||||
|
max_degree=self.max_degree))
|
||||||
|
elif final_layer == "lin":
|
||||||
|
graph_modules.append(LinearSE3(fiber_in=fiber_in,
|
||||||
|
fiber_out=fiber_out))
|
||||||
|
else:
|
||||||
|
graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
|
||||||
|
fiber_out=fiber_out,
|
||||||
|
fiber_edge=fiber_edge,
|
||||||
|
num_heads=1,
|
||||||
|
channels_div=div_fin,
|
||||||
|
use_layer_norm=use_layer_norm,
|
||||||
|
max_degree=self.max_degree,
|
||||||
|
fuse_level=fuse_level))
|
||||||
|
self.graph_modules = Sequential(*graph_modules)
|
||||||
|
|
||||||
|
if pooling is not None:
|
||||||
|
assert return_type is not None, 'return_type must be specified when pooling'
|
||||||
|
self.pooling_module = GPooling(pool=pooling, feat_type=return_type)
|
||||||
|
|
||||||
|
def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
|
||||||
|
edge_feats: Optional[Dict[str, Tensor]] = None,
|
||||||
|
basis: Optional[Dict[str, Tensor]] = None):
|
||||||
|
# Compute bases in case they weren't precomputed as part of the data loading
|
||||||
|
basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
|
||||||
|
use_pad_trick=self.tensor_cores and not self.low_memory,
|
||||||
|
amp=torch.is_autocast_enabled())
|
||||||
|
|
||||||
|
# Add fused bases (per output degree, per input degree, and fully fused) to the dict
|
||||||
|
basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
|
||||||
|
fully_fused=self.tensor_cores and not self.low_memory)
|
||||||
|
|
||||||
|
if self.populate_edge=='lin':
|
||||||
|
edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)
|
||||||
|
elif self.populate_edge=='arcsin':
|
||||||
|
r = graph.edata['rel_pos'].norm(dim=-1, keepdim=True)
|
||||||
|
r = torch.maximum(r, torch.zeros_like(r) + 4.0) - 4.0
|
||||||
|
r = torch.arcsinh(r)/3.0
|
||||||
|
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
|
||||||
|
elif self.populate_edge=='log':
|
||||||
|
# fd - replace with log(1+x)
|
||||||
|
r = torch.log( 1 + graph.edata['rel_pos'].norm(dim=-1, keepdim=True) )
|
||||||
|
edge_feats['0'] = torch.cat([edge_feats['0'], r[..., None]], dim=1)
|
||||||
|
else:
|
||||||
|
edge_feats['0'] = torch.cat((edge_feats['0'], torch.zeros_like(edge_feats['0'][:,:1,:])), dim=1)
|
||||||
|
|
||||||
|
node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)
|
||||||
|
|
||||||
|
if self.pooling is not None:
|
||||||
|
return self.pooling_module(node_feats, graph=graph)
|
||||||
|
|
||||||
|
if self.return_type is not None:
|
||||||
|
return node_feats[str(self.return_type)]
|
||||||
|
|
||||||
|
return node_feats
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_argparse_args(parser):
|
||||||
|
parser.add_argument('--num_layers', type=int, default=7,
|
||||||
|
help='Number of stacked Transformer layers')
|
||||||
|
parser.add_argument('--num_heads', type=int, default=8,
|
||||||
|
help='Number of heads in self-attention')
|
||||||
|
parser.add_argument('--channels_div', type=int, default=2,
|
||||||
|
help='Channels division before feeding to attention layer')
|
||||||
|
parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
|
||||||
|
help='Type of graph pooling')
|
||||||
|
parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='Apply a normalization layer after each attention block')
|
||||||
|
parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='Apply layer normalization between MLP layers')
|
||||||
|
parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='If true, will use fused ops that are slower but that use less memory '
|
||||||
|
'(expect 25 percent less memory). '
|
||||||
|
'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
class SE3TransformerPooled(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
fiber_in: Fiber,
|
||||||
|
fiber_out: Fiber,
|
||||||
|
fiber_edge: Fiber,
|
||||||
|
num_degrees: int,
|
||||||
|
num_channels: int,
|
||||||
|
output_dim: int,
|
||||||
|
**kwargs):
|
||||||
|
super().__init__()
|
||||||
|
kwargs['pooling'] = kwargs['pooling'] or 'max'
|
||||||
|
self.transformer = SE3Transformer(
|
||||||
|
fiber_in=fiber_in,
|
||||||
|
fiber_hidden=Fiber.create(num_degrees, num_channels),
|
||||||
|
fiber_out=fiber_out,
|
||||||
|
fiber_edge=fiber_edge,
|
||||||
|
return_type=0,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
n_out_features = fiber_out.num_features
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(n_out_features, n_out_features),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Linear(n_out_features, output_dim)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, graph, node_feats, edge_feats, basis=None):
|
||||||
|
feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
|
||||||
|
y = self.mlp(feats).squeeze(-1)
|
||||||
|
return y
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def add_argparse_args(parent_parser):
|
||||||
|
parser = parent_parser.add_argument_group("Model architecture")
|
||||||
|
SE3Transformer.add_argparse_args(parser)
|
||||||
|
parser.add_argument('--num_degrees',
|
||||||
|
help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
|
||||||
|
type=int, default=4)
|
||||||
|
parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
|
||||||
|
return parent_parser
|
0
rf2aa/SE3Transformer/se3_transformer/runtime/__init__.py
Normal file
0
rf2aa/SE3Transformer/se3_transformer/runtime/__init__.py
Normal file
70
rf2aa/SE3Transformer/se3_transformer/runtime/arguments.py
Normal file
70
rf2aa/SE3Transformer/se3_transformer/runtime/arguments.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import str2bool
|
||||||
|
|
||||||
|
PARSER = argparse.ArgumentParser(description='SE(3)-Transformer')
|
||||||
|
|
||||||
|
paths = PARSER.add_argument_group('Paths')
|
||||||
|
paths.add_argument('--data_dir', type=pathlib.Path, default=pathlib.Path('./data'),
|
||||||
|
help='Directory where the data is located or should be downloaded')
|
||||||
|
paths.add_argument('--log_dir', type=pathlib.Path, default=pathlib.Path('/results'),
|
||||||
|
help='Directory where the results logs should be saved')
|
||||||
|
paths.add_argument('--dllogger_name', type=str, default='dllogger_results.json',
|
||||||
|
help='Name for the resulting DLLogger JSON file')
|
||||||
|
paths.add_argument('--save_ckpt_path', type=pathlib.Path, default=None,
|
||||||
|
help='File where the checkpoint should be saved')
|
||||||
|
paths.add_argument('--load_ckpt_path', type=pathlib.Path, default=None,
|
||||||
|
help='File of the checkpoint to be loaded')
|
||||||
|
|
||||||
|
optimizer = PARSER.add_argument_group('Optimizer')
|
||||||
|
optimizer.add_argument('--optimizer', choices=['adam', 'sgd', 'lamb'], default='adam')
|
||||||
|
optimizer.add_argument('--learning_rate', '--lr', dest='learning_rate', type=float, default=0.002)
|
||||||
|
optimizer.add_argument('--min_learning_rate', '--min_lr', dest='min_learning_rate', type=float, default=None)
|
||||||
|
optimizer.add_argument('--momentum', type=float, default=0.9)
|
||||||
|
optimizer.add_argument('--weight_decay', type=float, default=0.1)
|
||||||
|
|
||||||
|
PARSER.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
|
||||||
|
PARSER.add_argument('--batch_size', type=int, default=240, help='Batch size')
|
||||||
|
PARSER.add_argument('--seed', type=int, default=None, help='Set a seed globally')
|
||||||
|
PARSER.add_argument('--num_workers', type=int, default=8, help='Number of dataloading workers')
|
||||||
|
|
||||||
|
PARSER.add_argument('--amp', type=str2bool, nargs='?', const=True, default=False, help='Use Automatic Mixed Precision')
|
||||||
|
PARSER.add_argument('--gradient_clip', type=float, default=None, help='Clipping of the gradient norms')
|
||||||
|
PARSER.add_argument('--accumulate_grad_batches', type=int, default=1, help='Gradient accumulation')
|
||||||
|
PARSER.add_argument('--ckpt_interval', type=int, default=-1, help='Save a checkpoint every N epochs')
|
||||||
|
PARSER.add_argument('--eval_interval', dest='eval_interval', type=int, default=1,
|
||||||
|
help='Do an evaluation round every N epochs')
|
||||||
|
PARSER.add_argument('--silent', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='Minimize stdout output')
|
||||||
|
|
||||||
|
PARSER.add_argument('--benchmark', type=str2bool, nargs='?', const=True, default=False,
|
||||||
|
help='Benchmark mode')
|
||||||
|
|
||||||
|
QM9DataModule.add_argparse_args(PARSER)
|
||||||
|
SE3TransformerPooled.add_argparse_args(PARSER)
|
160
rf2aa/SE3Transformer/se3_transformer/runtime/callbacks.py
Normal file
160
rf2aa/SE3Transformer/se3_transformer/runtime/callbacks.py
Normal file
|
@ -0,0 +1,160 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import Logger
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.metrics import MeanAbsoluteError
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCallback(ABC):
|
||||||
|
def on_fit_start(self, optimizer, args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_fit_end(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_epoch_end(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_batch_start(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_step(self, input, target, pred):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_end(self, epoch=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_checkpoint_load(self, checkpoint):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_checkpoint_save(self, checkpoint):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class LRSchedulerCallback(BaseCallback):
|
||||||
|
def __init__(self, logger: Optional[Logger] = None):
|
||||||
|
self.logger = logger
|
||||||
|
self.scheduler = None
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_scheduler(self, optimizer, args):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_fit_start(self, optimizer, args):
|
||||||
|
self.scheduler = self.get_scheduler(optimizer, args)
|
||||||
|
|
||||||
|
def on_checkpoint_load(self, checkpoint):
|
||||||
|
self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
|
|
||||||
|
def on_checkpoint_save(self, checkpoint):
|
||||||
|
checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
|
||||||
|
|
||||||
|
def on_epoch_end(self):
|
||||||
|
if self.logger is not None:
|
||||||
|
self.logger.log_metrics({'learning rate': self.scheduler.get_last_lr()[0]}, step=self.scheduler.last_epoch)
|
||||||
|
self.scheduler.step()
|
||||||
|
|
||||||
|
|
||||||
|
class QM9MetricCallback(BaseCallback):
|
||||||
|
""" Logs the rescaled mean absolute error for QM9 regression tasks """
|
||||||
|
|
||||||
|
def __init__(self, logger, targets_std, prefix=''):
|
||||||
|
self.mae = MeanAbsoluteError()
|
||||||
|
self.logger = logger
|
||||||
|
self.targets_std = targets_std
|
||||||
|
self.prefix = prefix
|
||||||
|
self.best_mae = float('inf')
|
||||||
|
|
||||||
|
def on_validation_step(self, input, target, pred):
|
||||||
|
self.mae(pred.detach(), target.detach())
|
||||||
|
|
||||||
|
def on_validation_end(self, epoch=None):
|
||||||
|
mae = self.mae.compute() * self.targets_std
|
||||||
|
logging.info(f'{self.prefix} MAE: {mae}')
|
||||||
|
self.logger.log_metrics({f'{self.prefix} MAE': mae}, epoch)
|
||||||
|
self.best_mae = min(self.best_mae, mae)
|
||||||
|
|
||||||
|
def on_fit_end(self):
|
||||||
|
if self.best_mae != float('inf'):
|
||||||
|
self.logger.log_metrics({f'{self.prefix} best MAE': self.best_mae})
|
||||||
|
|
||||||
|
|
||||||
|
class QM9LRSchedulerCallback(LRSchedulerCallback):
|
||||||
|
def __init__(self, logger, epochs):
|
||||||
|
super().__init__(logger)
|
||||||
|
self.epochs = epochs
|
||||||
|
|
||||||
|
def get_scheduler(self, optimizer, args):
|
||||||
|
min_lr = args.min_learning_rate if args.min_learning_rate else args.learning_rate / 10.0
|
||||||
|
return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, self.epochs, eta_min=min_lr)
|
||||||
|
|
||||||
|
|
||||||
|
class PerformanceCallback(BaseCallback):
|
||||||
|
def __init__(self, logger, batch_size: int, warmup_epochs: int = 1, mode: str = 'train'):
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.warmup_epochs = warmup_epochs
|
||||||
|
self.epoch = 0
|
||||||
|
self.timestamps = []
|
||||||
|
self.mode = mode
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
def on_batch_start(self):
|
||||||
|
if self.epoch >= self.warmup_epochs:
|
||||||
|
self.timestamps.append(time.time() * 1000.0)
|
||||||
|
|
||||||
|
def _log_perf(self):
|
||||||
|
stats = self.process_performance_stats()
|
||||||
|
for k, v in stats.items():
|
||||||
|
logging.info(f'performance {k}: {v}')
|
||||||
|
|
||||||
|
self.logger.log_metrics(stats)
|
||||||
|
|
||||||
|
def on_epoch_end(self):
|
||||||
|
self.epoch += 1
|
||||||
|
|
||||||
|
def on_fit_end(self):
|
||||||
|
if self.epoch > self.warmup_epochs:
|
||||||
|
self._log_perf()
|
||||||
|
self.timestamps = []
|
||||||
|
|
||||||
|
def process_performance_stats(self):
|
||||||
|
timestamps = np.asarray(self.timestamps)
|
||||||
|
deltas = np.diff(timestamps)
|
||||||
|
throughput = (self.batch_size / deltas).mean()
|
||||||
|
stats = {
|
||||||
|
f"throughput_{self.mode}": throughput,
|
||||||
|
f"latency_{self.mode}_mean": deltas.mean(),
|
||||||
|
f"total_time_{self.mode}": timestamps[-1] - timestamps[0],
|
||||||
|
}
|
||||||
|
for level in [90, 95, 99]:
|
||||||
|
stats.update({f"latency_{self.mode}_{level}": np.percentile(deltas, level)})
|
||||||
|
|
||||||
|
return stats
|
325
rf2aa/SE3Transformer/se3_transformer/runtime/gpu_affinity.py
Normal file
325
rf2aa/SE3Transformer/se3_transformer/runtime/gpu_affinity.py
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import itertools
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import pathlib
|
||||||
|
import re
|
||||||
|
|
||||||
|
import pynvml
|
||||||
|
|
||||||
|
|
||||||
|
class Device:
|
||||||
|
# assumes nvml returns list of 64 bit ints
|
||||||
|
_nvml_affinity_elements = math.ceil(os.cpu_count() / 64)
|
||||||
|
|
||||||
|
def __init__(self, device_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.handle = pynvml.nvmlDeviceGetHandleByIndex(device_idx)
|
||||||
|
|
||||||
|
def get_name(self):
|
||||||
|
return pynvml.nvmlDeviceGetName(self.handle)
|
||||||
|
|
||||||
|
def get_uuid(self):
|
||||||
|
return pynvml.nvmlDeviceGetUUID(self.handle)
|
||||||
|
|
||||||
|
def get_cpu_affinity(self):
|
||||||
|
affinity_string = ""
|
||||||
|
for j in pynvml.nvmlDeviceGetCpuAffinity(self.handle, Device._nvml_affinity_elements):
|
||||||
|
# assume nvml returns list of 64 bit ints
|
||||||
|
affinity_string = "{:064b}".format(j) + affinity_string
|
||||||
|
|
||||||
|
affinity_list = [int(x) for x in affinity_string]
|
||||||
|
affinity_list.reverse() # so core 0 is in 0th element of list
|
||||||
|
|
||||||
|
ret = [i for i, e in enumerate(affinity_list) if e != 0]
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_siblings_list():
|
||||||
|
"""
|
||||||
|
Returns a list of 2-element integer tuples representing pairs of
|
||||||
|
hyperthreading cores.
|
||||||
|
"""
|
||||||
|
path = "/sys/devices/system/cpu/cpu*/topology/thread_siblings_list"
|
||||||
|
thread_siblings_list = []
|
||||||
|
pattern = re.compile(r"(\d+)\D(\d+)")
|
||||||
|
for fname in pathlib.Path(path[0]).glob(path[1:]):
|
||||||
|
with open(fname) as f:
|
||||||
|
content = f.read().strip()
|
||||||
|
res = pattern.findall(content)
|
||||||
|
if res:
|
||||||
|
pair = tuple(map(int, res[0]))
|
||||||
|
thread_siblings_list.append(pair)
|
||||||
|
return thread_siblings_list
|
||||||
|
|
||||||
|
|
||||||
|
def check_socket_affinities(socket_affinities):
|
||||||
|
# sets of cores should be either identical or disjoint
|
||||||
|
for i, j in itertools.product(socket_affinities, socket_affinities):
|
||||||
|
if not set(i) == set(j) and not set(i).isdisjoint(set(j)):
|
||||||
|
raise RuntimeError(f"Sets of cores should be either identical or disjoint, " f"but got {i} and {j}.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_socket_affinities(nproc_per_node, exclude_unavailable_cores=True):
|
||||||
|
devices = [Device(i) for i in range(nproc_per_node)]
|
||||||
|
socket_affinities = [dev.get_cpu_affinity() for dev in devices]
|
||||||
|
|
||||||
|
if exclude_unavailable_cores:
|
||||||
|
available_cores = os.sched_getaffinity(0)
|
||||||
|
socket_affinities = [list(set(affinity) & available_cores) for affinity in socket_affinities]
|
||||||
|
|
||||||
|
check_socket_affinities(socket_affinities)
|
||||||
|
|
||||||
|
return socket_affinities
|
||||||
|
|
||||||
|
|
||||||
|
def set_socket_affinity(gpu_id):
|
||||||
|
"""
|
||||||
|
The process is assigned with all available logical CPU cores from the CPU
|
||||||
|
socket connected to the GPU with a given id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: index of a GPU
|
||||||
|
"""
|
||||||
|
dev = Device(gpu_id)
|
||||||
|
affinity = dev.get_cpu_affinity()
|
||||||
|
os.sched_setaffinity(0, affinity)
|
||||||
|
|
||||||
|
|
||||||
|
def set_single_affinity(gpu_id):
|
||||||
|
"""
|
||||||
|
The process is assigned with the first available logical CPU core from the
|
||||||
|
list of all CPU cores from the CPU socket connected to the GPU with a given
|
||||||
|
id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: index of a GPU
|
||||||
|
"""
|
||||||
|
dev = Device(gpu_id)
|
||||||
|
affinity = dev.get_cpu_affinity()
|
||||||
|
|
||||||
|
# exclude unavailable cores
|
||||||
|
available_cores = os.sched_getaffinity(0)
|
||||||
|
affinity = list(set(affinity) & available_cores)
|
||||||
|
os.sched_setaffinity(0, affinity[:1])
|
||||||
|
|
||||||
|
|
||||||
|
def set_single_unique_affinity(gpu_id, nproc_per_node):
|
||||||
|
"""
|
||||||
|
The process is assigned with a single unique available physical CPU core
|
||||||
|
from the list of all CPU cores from the CPU socket connected to the GPU with
|
||||||
|
a given id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: index of a GPU
|
||||||
|
"""
|
||||||
|
socket_affinities = get_socket_affinities(nproc_per_node)
|
||||||
|
|
||||||
|
siblings_list = get_thread_siblings_list()
|
||||||
|
siblings_dict = dict(siblings_list)
|
||||||
|
|
||||||
|
# remove siblings
|
||||||
|
for idx, socket_affinity in enumerate(socket_affinities):
|
||||||
|
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
||||||
|
|
||||||
|
affinities = []
|
||||||
|
assigned = []
|
||||||
|
|
||||||
|
for socket_affinity in socket_affinities:
|
||||||
|
for core in socket_affinity:
|
||||||
|
if core not in assigned:
|
||||||
|
affinities.append([core])
|
||||||
|
assigned.append(core)
|
||||||
|
break
|
||||||
|
os.sched_setaffinity(0, affinities[gpu_id])
|
||||||
|
|
||||||
|
|
||||||
|
def set_socket_unique_affinity(gpu_id, nproc_per_node, mode, balanced=True):
|
||||||
|
"""
|
||||||
|
The process is assigned with an unique subset of available physical CPU
|
||||||
|
cores from the CPU socket connected to a GPU with a given id.
|
||||||
|
Assignment automatically includes hyperthreading siblings (if siblings are
|
||||||
|
available).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: index of a GPU
|
||||||
|
nproc_per_node: total number of processes per node
|
||||||
|
mode: mode
|
||||||
|
balanced: assign an equal number of physical cores to each process
|
||||||
|
"""
|
||||||
|
socket_affinities = get_socket_affinities(nproc_per_node)
|
||||||
|
|
||||||
|
siblings_list = get_thread_siblings_list()
|
||||||
|
siblings_dict = dict(siblings_list)
|
||||||
|
|
||||||
|
# remove hyperthreading siblings
|
||||||
|
for idx, socket_affinity in enumerate(socket_affinities):
|
||||||
|
socket_affinities[idx] = list(set(socket_affinity) - set(siblings_dict.values()))
|
||||||
|
|
||||||
|
socket_affinities_to_device_ids = collections.defaultdict(list)
|
||||||
|
|
||||||
|
for idx, socket_affinity in enumerate(socket_affinities):
|
||||||
|
socket_affinities_to_device_ids[tuple(socket_affinity)].append(idx)
|
||||||
|
|
||||||
|
# compute minimal number of physical cores per GPU across all GPUs and
|
||||||
|
# sockets, code assigns this number of cores per GPU if balanced == True
|
||||||
|
min_physical_cores_per_gpu = min(
|
||||||
|
[len(cores) // len(gpus) for cores, gpus in socket_affinities_to_device_ids.items()]
|
||||||
|
)
|
||||||
|
|
||||||
|
for socket_affinity, device_ids in socket_affinities_to_device_ids.items():
|
||||||
|
devices_per_group = len(device_ids)
|
||||||
|
if balanced:
|
||||||
|
cores_per_device = min_physical_cores_per_gpu
|
||||||
|
socket_affinity = socket_affinity[: devices_per_group * min_physical_cores_per_gpu]
|
||||||
|
else:
|
||||||
|
cores_per_device = len(socket_affinity) // devices_per_group
|
||||||
|
|
||||||
|
for group_id, device_id in enumerate(device_ids):
|
||||||
|
if device_id == gpu_id:
|
||||||
|
|
||||||
|
# In theory there should be no difference in performance between
|
||||||
|
# 'interleaved' and 'continuous' pattern on Intel-based DGX-1,
|
||||||
|
# but 'continuous' should be better for DGX A100 because on AMD
|
||||||
|
# Rome 4 consecutive cores are sharing L3 cache.
|
||||||
|
# TODO: code doesn't attempt to automatically detect layout of
|
||||||
|
# L3 cache, also external environment may already exclude some
|
||||||
|
# cores, this code makes no attempt to detect it and to align
|
||||||
|
# mapping to multiples of 4.
|
||||||
|
|
||||||
|
if mode == "interleaved":
|
||||||
|
affinity = list(socket_affinity[group_id::devices_per_group])
|
||||||
|
elif mode == "continuous":
|
||||||
|
affinity = list(socket_affinity[group_id * cores_per_device: (group_id + 1) * cores_per_device])
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown set_socket_unique_affinity mode")
|
||||||
|
|
||||||
|
# unconditionally reintroduce hyperthreading siblings, this step
|
||||||
|
# may result in a different numbers of logical cores assigned to
|
||||||
|
# each GPU even if balanced == True (if hyperthreading siblings
|
||||||
|
# aren't available for a subset of cores due to some external
|
||||||
|
# constraints, siblings are re-added unconditionally, in the
|
||||||
|
# worst case unavailable logical core will be ignored by
|
||||||
|
# os.sched_setaffinity().
|
||||||
|
affinity += [siblings_dict[aff] for aff in affinity if aff in siblings_dict]
|
||||||
|
os.sched_setaffinity(0, affinity)
|
||||||
|
|
||||||
|
|
||||||
|
def set_affinity(gpu_id, nproc_per_node, mode="socket_unique_continuous", balanced=True):
|
||||||
|
"""
|
||||||
|
The process is assigned with a proper CPU affinity which matches hardware
|
||||||
|
architecture on a given platform. Usually it improves and stabilizes
|
||||||
|
performance of deep learning training workloads.
|
||||||
|
|
||||||
|
This function assumes that the workload is running in multi-process
|
||||||
|
single-device mode (there are multiple training processes and each process
|
||||||
|
is running on a single GPU), which is typical for multi-GPU training
|
||||||
|
workloads using `torch.nn.parallel.DistributedDataParallel`.
|
||||||
|
|
||||||
|
Available affinity modes:
|
||||||
|
* 'socket' - the process is assigned with all available logical CPU cores
|
||||||
|
from the CPU socket connected to the GPU with a given id.
|
||||||
|
* 'single' - the process is assigned with the first available logical CPU
|
||||||
|
core from the list of all CPU cores from the CPU socket connected to the GPU
|
||||||
|
with a given id (multiple GPUs could be assigned with the same CPU core).
|
||||||
|
* 'single_unique' - the process is assigned with a single unique available
|
||||||
|
physical CPU core from the list of all CPU cores from the CPU socket
|
||||||
|
connected to the GPU with a given id.
|
||||||
|
* 'socket_unique_interleaved' - the process is assigned with an unique
|
||||||
|
subset of available physical CPU cores from the CPU socket connected to a
|
||||||
|
GPU with a given id, hyperthreading siblings are included automatically,
|
||||||
|
cores are assigned with interleaved indexing pattern
|
||||||
|
* 'socket_unique_continuous' - (the default) the process is assigned with an
|
||||||
|
unique subset of available physical CPU cores from the CPU socket connected
|
||||||
|
to a GPU with a given id, hyperthreading siblings are included
|
||||||
|
automatically, cores are assigned with continuous indexing pattern
|
||||||
|
|
||||||
|
'socket_unique_continuous' is the recommended mode for deep learning
|
||||||
|
training workloads on NVIDIA DGX machines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gpu_id: integer index of a GPU
|
||||||
|
nproc_per_node: number of processes per node
|
||||||
|
mode: affinity mode
|
||||||
|
balanced: assign an equal number of physical cores to each process,
|
||||||
|
affects only 'socket_unique_interleaved' and
|
||||||
|
'socket_unique_continuous' affinity modes
|
||||||
|
|
||||||
|
Returns a set of logical CPU cores on which the process is eligible to run.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
import gpu_affinity
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--local_rank',
|
||||||
|
type=int,
|
||||||
|
default=os.getenv('LOCAL_RANK', 0),
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
nproc_per_node = torch.cuda.device_count()
|
||||||
|
|
||||||
|
affinity = gpu_affinity.set_affinity(args.local_rank, nproc_per_node)
|
||||||
|
print(f'{args.local_rank}: core affinity: {affinity}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
Launch the example with:
|
||||||
|
python -m torch.distributed.launch --nproc_per_node <#GPUs> example.py
|
||||||
|
|
||||||
|
|
||||||
|
WARNING: On DGX A100 only a half of CPU cores have direct access to GPUs.
|
||||||
|
This function restricts execution only to the CPU cores directly connected
|
||||||
|
to GPUs, so on DGX A100 it will limit the code to half of CPU cores and half
|
||||||
|
of CPU memory bandwidth (which may be fine for many DL models).
|
||||||
|
"""
|
||||||
|
pynvml.nvmlInit()
|
||||||
|
|
||||||
|
if mode == "socket":
|
||||||
|
set_socket_affinity(gpu_id)
|
||||||
|
elif mode == "single":
|
||||||
|
set_single_affinity(gpu_id)
|
||||||
|
elif mode == "single_unique":
|
||||||
|
set_single_unique_affinity(gpu_id, nproc_per_node)
|
||||||
|
elif mode == "socket_unique_interleaved":
|
||||||
|
set_socket_unique_affinity(gpu_id, nproc_per_node, "interleaved", balanced)
|
||||||
|
elif mode == "socket_unique_continuous":
|
||||||
|
set_socket_unique_affinity(gpu_id, nproc_per_node, "continuous", balanced)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown affinity mode")
|
||||||
|
|
||||||
|
affinity = os.sched_getaffinity(0)
|
||||||
|
return affinity
|
131
rf2aa/SE3Transformer/se3_transformer/runtime/inference.py
Normal file
131
rf2aa/SE3Transformer/se3_transformer/runtime/inference.py
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import BaseCallback
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import DLLogger
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank
|
||||||
|
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def evaluate(model: nn.Module,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
callbacks: List[BaseCallback],
|
||||||
|
args):
|
||||||
|
model.eval()
|
||||||
|
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc=f'Evaluation',
|
||||||
|
leave=False, disable=(args.silent or get_local_rank() != 0)):
|
||||||
|
*input, target = to_cuda(batch)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_batch_start()
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||||
|
pred = model(*input)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_validation_step(input, target, pred)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import init_distributed, seed_everything
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled, Fiber
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||||
|
import torch.distributed as dist
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
is_distributed = init_distributed()
|
||||||
|
local_rank = get_local_rank()
|
||||||
|
args = PARSER.parse_args()
|
||||||
|
|
||||||
|
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
||||||
|
|
||||||
|
logging.info('====== SE(3)-Transformer ======')
|
||||||
|
logging.info('| Inference on the test set |')
|
||||||
|
logging.info('===============================')
|
||||||
|
|
||||||
|
if not args.benchmark and args.load_ckpt_path is None:
|
||||||
|
logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if args.benchmark:
|
||||||
|
logging.info('Running benchmark mode with one warmup pass')
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
seed_everything(args.seed)
|
||||||
|
|
||||||
|
major_cc, minor_cc = torch.cuda.get_device_capability()
|
||||||
|
|
||||||
|
logger = DLLogger(args.log_dir, filename=args.dllogger_name)
|
||||||
|
datamodule = QM9DataModule(**vars(args))
|
||||||
|
model = SE3TransformerPooled(
|
||||||
|
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
||||||
|
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
||||||
|
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
||||||
|
output_dim=1,
|
||||||
|
tensor_cores=(args.amp and major_cc >= 7) or major_cc >= 8, # use Tensor Cores more effectively
|
||||||
|
**vars(args)
|
||||||
|
)
|
||||||
|
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='test')]
|
||||||
|
|
||||||
|
model.to(device=torch.cuda.current_device())
|
||||||
|
if args.load_ckpt_path is not None:
|
||||||
|
checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
|
||||||
|
if is_distributed:
|
||||||
|
nproc_per_node = torch.cuda.device_count()
|
||||||
|
affinity = gpu_affinity.set_affinity(local_rank, nproc_per_node)
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
||||||
|
|
||||||
|
test_dataloader = datamodule.test_dataloader() if not args.benchmark else datamodule.train_dataloader()
|
||||||
|
evaluate(model,
|
||||||
|
test_dataloader,
|
||||||
|
callbacks,
|
||||||
|
args)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_validation_end()
|
||||||
|
|
||||||
|
if args.benchmark:
|
||||||
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
callbacks = [PerformanceCallback(logger, args.batch_size * world_size, warmup_epochs=1, mode='inference')]
|
||||||
|
for _ in range(6):
|
||||||
|
evaluate(model,
|
||||||
|
test_dataloader,
|
||||||
|
callbacks,
|
||||||
|
args)
|
||||||
|
callbacks[0].on_epoch_end()
|
||||||
|
|
||||||
|
callbacks[0].on_fit_end()
|
134
rf2aa/SE3Transformer/se3_transformer/runtime/loggers.py
Normal file
134
rf2aa/SE3Transformer/se3_transformer/runtime/loggers.py
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import pathlib
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Dict, Any, Callable, Optional
|
||||||
|
|
||||||
|
import dllogger
|
||||||
|
import torch.distributed as dist
|
||||||
|
import wandb
|
||||||
|
from dllogger import Verbosity
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import rank_zero_only
|
||||||
|
|
||||||
|
|
||||||
|
class Logger(ABC):
|
||||||
|
@rank_zero_only
|
||||||
|
@abstractmethod
|
||||||
|
def log_hyperparams(self, params):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
@abstractmethod
|
||||||
|
def log_metrics(self, metrics, step=None):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sanitize_params(params):
|
||||||
|
def _sanitize(val):
|
||||||
|
if isinstance(val, Callable):
|
||||||
|
try:
|
||||||
|
_val = val()
|
||||||
|
if isinstance(_val, Callable):
|
||||||
|
return val.__name__
|
||||||
|
return _val
|
||||||
|
except Exception:
|
||||||
|
return getattr(val, "__name__", None)
|
||||||
|
elif isinstance(val, pathlib.Path) or isinstance(val, Enum):
|
||||||
|
return str(val)
|
||||||
|
return val
|
||||||
|
|
||||||
|
return {key: _sanitize(val) for key, val in params.items()}
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerCollection(Logger):
|
||||||
|
def __init__(self, loggers):
|
||||||
|
super().__init__()
|
||||||
|
self.loggers = loggers
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
return [logger for logger in self.loggers][index]
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_metrics(self, metrics, step=None):
|
||||||
|
for logger in self.loggers:
|
||||||
|
logger.log_metrics(metrics, step)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_hyperparams(self, params):
|
||||||
|
for logger in self.loggers:
|
||||||
|
logger.log_hyperparams(params)
|
||||||
|
|
||||||
|
|
||||||
|
class DLLogger(Logger):
|
||||||
|
def __init__(self, save_dir: pathlib.Path, filename: str):
|
||||||
|
super().__init__()
|
||||||
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
dllogger.init(
|
||||||
|
backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))])
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_hyperparams(self, params):
|
||||||
|
params = self._sanitize_params(params)
|
||||||
|
dllogger.log(step="PARAMETER", data=params)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_metrics(self, metrics, step=None):
|
||||||
|
if step is None:
|
||||||
|
step = tuple()
|
||||||
|
|
||||||
|
dllogger.log(step=step, data=metrics)
|
||||||
|
|
||||||
|
|
||||||
|
class WandbLogger(Logger):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
save_dir: pathlib.Path,
|
||||||
|
id: Optional[str] = None,
|
||||||
|
project: Optional[str] = None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.experiment = wandb.init(name=name,
|
||||||
|
project=project,
|
||||||
|
id=id,
|
||||||
|
dir=str(save_dir),
|
||||||
|
resume='allow',
|
||||||
|
anonymous='must')
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_hyperparams(self, params: Dict[str, Any]) -> None:
|
||||||
|
params = self._sanitize_params(params)
|
||||||
|
self.experiment.config.update(params, allow_val_change=True)
|
||||||
|
|
||||||
|
@rank_zero_only
|
||||||
|
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
|
||||||
|
if step is not None:
|
||||||
|
self.experiment.log({**metrics, 'epoch': step})
|
||||||
|
else:
|
||||||
|
self.experiment.log(metrics)
|
83
rf2aa/SE3Transformer/se3_transformer/runtime/metrics.py
Normal file
83
rf2aa/SE3Transformer/se3_transformer/runtime/metrics.py
Normal file
|
@ -0,0 +1,83 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class Metric(ABC):
|
||||||
|
""" Metric class with synchronization capabilities similar to TorchMetrics """
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.states = {}
|
||||||
|
|
||||||
|
def add_state(self, name: str, default: Tensor):
|
||||||
|
assert name not in self.states
|
||||||
|
self.states[name] = default.clone()
|
||||||
|
setattr(self, name, default)
|
||||||
|
|
||||||
|
def synchronize(self):
|
||||||
|
if dist.is_initialized():
|
||||||
|
for state in self.states:
|
||||||
|
dist.all_reduce(getattr(self, state), op=dist.ReduceOp.SUM, group=dist.group.WORLD)
|
||||||
|
|
||||||
|
def __call__(self, *args, **kwargs):
|
||||||
|
self.update(*args, **kwargs)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
for name, default in self.states.items():
|
||||||
|
setattr(self, name, default.clone())
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
self.synchronize()
|
||||||
|
value = self._compute().item()
|
||||||
|
self.reset()
|
||||||
|
return value
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _compute(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(self, preds: Tensor, targets: Tensor):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MeanAbsoluteError(Metric):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.add_state('error', torch.tensor(0, dtype=torch.float32, device='cuda'))
|
||||||
|
self.add_state('total', torch.tensor(0, dtype=torch.int32, device='cuda'))
|
||||||
|
|
||||||
|
def update(self, preds: Tensor, targets: Tensor):
|
||||||
|
preds = preds.detach()
|
||||||
|
n = preds.shape[0]
|
||||||
|
error = torch.abs(preds.view(n, -1) - targets.view(n, -1)).sum()
|
||||||
|
self.total += n
|
||||||
|
self.error += error
|
||||||
|
|
||||||
|
def _compute(self):
|
||||||
|
return self.error / self.total
|
238
rf2aa/SE3Transformer/se3_transformer/runtime/training.py
Normal file
238
rf2aa/SE3Transformer/se3_transformer/runtime/training.py
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import pathlib
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
import torch.nn as nn
|
||||||
|
from apex.optimizers import FusedAdam, FusedLAMB
|
||||||
|
from torch.nn.modules.loss import _Loss
|
||||||
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
from torch.utils.data import DataLoader, DistributedSampler
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.data_loading import QM9DataModule
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model import SE3TransformerPooled
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime import gpu_affinity
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.arguments import PARSER
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback, \
|
||||||
|
PerformanceCallback
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.inference import evaluate
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.runtime.utils import to_cuda, get_local_rank, init_distributed, seed_everything, \
|
||||||
|
using_tensor_cores, increase_l2_fetch_granularity
|
||||||
|
|
||||||
|
|
||||||
|
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, path: pathlib.Path, callbacks: List[BaseCallback]):
|
||||||
|
""" Saves model, optimizer and epoch states to path (only once per node) """
|
||||||
|
if get_local_rank() == 0:
|
||||||
|
state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
|
||||||
|
checkpoint = {
|
||||||
|
'state_dict': state_dict,
|
||||||
|
'optimizer_state_dict': optimizer.state_dict(),
|
||||||
|
'epoch': epoch
|
||||||
|
}
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_checkpoint_save(checkpoint)
|
||||||
|
|
||||||
|
torch.save(checkpoint, str(path))
|
||||||
|
logging.info(f'Saved checkpoint to {str(path)}')
|
||||||
|
|
||||||
|
|
||||||
|
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
|
||||||
|
""" Loads model, optimizer and epoch states from path """
|
||||||
|
checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
|
||||||
|
if isinstance(model, DistributedDataParallel):
|
||||||
|
model.module.load_state_dict(checkpoint['state_dict'])
|
||||||
|
else:
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_checkpoint_load(checkpoint)
|
||||||
|
|
||||||
|
logging.info(f'Loaded checkpoint from {str(path)}')
|
||||||
|
return checkpoint['epoch']
|
||||||
|
|
||||||
|
|
||||||
|
def train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args):
|
||||||
|
losses = []
|
||||||
|
for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
|
||||||
|
desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
|
||||||
|
*inputs, target = to_cuda(batch)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_batch_start()
|
||||||
|
|
||||||
|
with torch.cuda.amp.autocast(enabled=args.amp):
|
||||||
|
pred = model(*inputs)
|
||||||
|
loss = loss_fn(pred, target) / args.accumulate_grad_batches
|
||||||
|
|
||||||
|
grad_scaler.scale(loss).backward()
|
||||||
|
|
||||||
|
# gradient accumulation
|
||||||
|
if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
|
||||||
|
if args.gradient_clip:
|
||||||
|
grad_scaler.unscale_(optimizer)
|
||||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
|
||||||
|
|
||||||
|
grad_scaler.step(optimizer)
|
||||||
|
grad_scaler.update()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
losses.append(loss.item())
|
||||||
|
|
||||||
|
return np.mean(losses)
|
||||||
|
|
||||||
|
|
||||||
|
def train(model: nn.Module,
|
||||||
|
loss_fn: _Loss,
|
||||||
|
train_dataloader: DataLoader,
|
||||||
|
val_dataloader: DataLoader,
|
||||||
|
callbacks: List[BaseCallback],
|
||||||
|
logger: Logger,
|
||||||
|
args):
|
||||||
|
device = torch.cuda.current_device()
|
||||||
|
model.to(device=device)
|
||||||
|
local_rank = get_local_rank()
|
||||||
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
|
||||||
|
|
||||||
|
model.train()
|
||||||
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
||||||
|
if args.optimizer == 'adam':
|
||||||
|
optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
||||||
|
weight_decay=args.weight_decay)
|
||||||
|
elif args.optimizer == 'lamb':
|
||||||
|
optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
|
||||||
|
weight_decay=args.weight_decay)
|
||||||
|
else:
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
||||||
|
weight_decay=args.weight_decay)
|
||||||
|
|
||||||
|
epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks) if args.load_ckpt_path else 0
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_fit_start(optimizer, args)
|
||||||
|
|
||||||
|
for epoch_idx in range(epoch_start, args.epochs):
|
||||||
|
if isinstance(train_dataloader.sampler, DistributedSampler):
|
||||||
|
train_dataloader.sampler.set_epoch(epoch_idx)
|
||||||
|
|
||||||
|
loss = train_epoch(model, train_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks, args)
|
||||||
|
if dist.is_initialized():
|
||||||
|
loss = torch.tensor(loss, dtype=torch.float, device=device)
|
||||||
|
torch.distributed.all_reduce(loss)
|
||||||
|
loss = (loss / world_size).item()
|
||||||
|
|
||||||
|
logging.info(f'Train loss: {loss}')
|
||||||
|
logger.log_metrics({'train loss': loss}, epoch_idx)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_epoch_end()
|
||||||
|
|
||||||
|
if not args.benchmark and args.save_ckpt_path is not None and args.ckpt_interval > 0 \
|
||||||
|
and (epoch_idx + 1) % args.ckpt_interval == 0:
|
||||||
|
save_state(model, optimizer, epoch_idx, args.save_ckpt_path, callbacks)
|
||||||
|
|
||||||
|
if not args.benchmark and args.eval_interval > 0 and (epoch_idx + 1) % args.eval_interval == 0:
|
||||||
|
evaluate(model, val_dataloader, callbacks, args)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_validation_end(epoch_idx)
|
||||||
|
|
||||||
|
if args.save_ckpt_path is not None and not args.benchmark:
|
||||||
|
save_state(model, optimizer, args.epochs, args.save_ckpt_path, callbacks)
|
||||||
|
|
||||||
|
for callback in callbacks:
|
||||||
|
callback.on_fit_end()
|
||||||
|
|
||||||
|
|
||||||
|
def print_parameters_count(model):
|
||||||
|
num_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
logging.info(f'Number of trainable parameters: {num_params_trainable}')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
is_distributed = init_distributed()
|
||||||
|
local_rank = get_local_rank()
|
||||||
|
args = PARSER.parse_args()
|
||||||
|
|
||||||
|
logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)
|
||||||
|
|
||||||
|
logging.info('====== SE(3)-Transformer ======')
|
||||||
|
logging.info('| Training procedure |')
|
||||||
|
logging.info('===============================')
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
logging.info(f'Using seed {args.seed}')
|
||||||
|
seed_everything(args.seed)
|
||||||
|
|
||||||
|
logger = LoggerCollection([
|
||||||
|
DLLogger(save_dir=args.log_dir, filename=args.dllogger_name),
|
||||||
|
WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer')
|
||||||
|
])
|
||||||
|
|
||||||
|
datamodule = QM9DataModule(**vars(args))
|
||||||
|
model = SE3TransformerPooled(
|
||||||
|
fiber_in=Fiber({0: datamodule.NODE_FEATURE_DIM}),
|
||||||
|
fiber_out=Fiber({0: args.num_degrees * args.num_channels}),
|
||||||
|
fiber_edge=Fiber({0: datamodule.EDGE_FEATURE_DIM}),
|
||||||
|
output_dim=1,
|
||||||
|
tensor_cores=using_tensor_cores(args.amp), # use Tensor Cores more effectively
|
||||||
|
**vars(args)
|
||||||
|
)
|
||||||
|
loss_fn = nn.L1Loss()
|
||||||
|
|
||||||
|
if args.benchmark:
|
||||||
|
logging.info('Running benchmark mode')
|
||||||
|
world_size = dist.get_world_size() if dist.is_initialized() else 1
|
||||||
|
callbacks = [PerformanceCallback(logger, args.batch_size * world_size)]
|
||||||
|
else:
|
||||||
|
callbacks = [QM9MetricCallback(logger, targets_std=datamodule.targets_std, prefix='validation'),
|
||||||
|
QM9LRSchedulerCallback(logger, epochs=args.epochs)]
|
||||||
|
|
||||||
|
if is_distributed:
|
||||||
|
gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())
|
||||||
|
|
||||||
|
print_parameters_count(model)
|
||||||
|
logger.log_hyperparams(vars(args))
|
||||||
|
increase_l2_fetch_granularity()
|
||||||
|
train(model,
|
||||||
|
loss_fn,
|
||||||
|
datamodule.train_dataloader(),
|
||||||
|
datamodule.val_dataloader(),
|
||||||
|
callbacks,
|
||||||
|
logger,
|
||||||
|
args)
|
||||||
|
|
||||||
|
logging.info('Training finished successfully')
|
130
rf2aa/SE3Transformer/se3_transformer/runtime/utils.py
Normal file
130
rf2aa/SE3Transformer/se3_transformer/runtime/utils.py
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import ctypes
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Union, List, Dict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_residual(feats1, feats2, method: str):
|
||||||
|
""" Add or concatenate two fiber features together. If degrees don't match, will use the ones of feats2. """
|
||||||
|
if method in ['add', 'sum']:
|
||||||
|
return {k: (v + feats1[k]) if k in feats1 else v for k, v in feats2.items()}
|
||||||
|
elif method in ['cat', 'concat']:
|
||||||
|
return {k: torch.cat([v, feats1[k]], dim=1) if k in feats1 else v for k, v in feats2.items()}
|
||||||
|
else:
|
||||||
|
raise ValueError('Method must be add/sum or cat/concat')
|
||||||
|
|
||||||
|
|
||||||
|
def degree_to_dim(degree: int) -> int:
|
||||||
|
return 2 * degree + 1
|
||||||
|
|
||||||
|
|
||||||
|
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
|
||||||
|
return dict(zip(map(str, degrees), features.split([degree_to_dim(deg) for deg in degrees], dim=-1)))
|
||||||
|
|
||||||
|
|
||||||
|
def str2bool(v: Union[bool, str]) -> bool:
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return v
|
||||||
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||||
|
return True
|
||||||
|
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
|
||||||
|
def to_cuda(x):
|
||||||
|
""" Try to convert a Tensor, a collection of Tensors or a DGLGraph to CUDA """
|
||||||
|
if isinstance(x, Tensor):
|
||||||
|
return x.cuda(non_blocking=True)
|
||||||
|
elif isinstance(x, tuple):
|
||||||
|
return (to_cuda(v) for v in x)
|
||||||
|
elif isinstance(x, list):
|
||||||
|
return [to_cuda(v) for v in x]
|
||||||
|
elif isinstance(x, dict):
|
||||||
|
return {k: to_cuda(v) for k, v in x.items()}
|
||||||
|
else:
|
||||||
|
# DGLGraph or other objects
|
||||||
|
return x.to(device=torch.cuda.current_device())
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_rank() -> int:
|
||||||
|
return int(os.environ.get('LOCAL_RANK', 0))
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed() -> bool:
|
||||||
|
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||||
|
distributed = world_size > 1
|
||||||
|
if distributed:
|
||||||
|
backend = 'nccl' if torch.cuda.is_available() else 'gloo'
|
||||||
|
dist.init_process_group(backend=backend, init_method='env://')
|
||||||
|
if backend == 'nccl':
|
||||||
|
torch.cuda.set_device(get_local_rank())
|
||||||
|
else:
|
||||||
|
logging.warning('Running on CPU only!')
|
||||||
|
assert torch.distributed.is_initialized()
|
||||||
|
return distributed
|
||||||
|
|
||||||
|
|
||||||
|
def increase_l2_fetch_granularity():
|
||||||
|
# maximum fetch granularity of L2: 128 bytes
|
||||||
|
_libcudart = ctypes.CDLL('libcudart.so')
|
||||||
|
# set device limit on the current device
|
||||||
|
# cudaLimitMaxL2FetchGranularity = 0x05
|
||||||
|
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
|
||||||
|
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
|
||||||
|
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
|
||||||
|
assert pValue.contents.value == 128
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed):
|
||||||
|
seed = int(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def rank_zero_only(fn):
|
||||||
|
@wraps(fn)
|
||||||
|
def wrapped_fn(*args, **kwargs):
|
||||||
|
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapped_fn
|
||||||
|
|
||||||
|
|
||||||
|
def using_tensor_cores(amp: bool) -> bool:
|
||||||
|
major_cc, minor_cc = torch.cuda.get_device_capability()
|
||||||
|
return (amp and major_cc >= 7) or major_cc >= 8
|
11
rf2aa/SE3Transformer/setup.py
Normal file
11
rf2aa/SE3Transformer/setup.py
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='se3-transformer',
|
||||||
|
packages=find_packages(),
|
||||||
|
include_package_data=True,
|
||||||
|
version='1.0.0',
|
||||||
|
description='PyTorch + DGL implementation of SE(3)-Transformers',
|
||||||
|
author='Alexandre Milesi',
|
||||||
|
author_email='alexandrem@nvidia.com',
|
||||||
|
)
|
0
rf2aa/SE3Transformer/tests/__init__.py
Normal file
0
rf2aa/SE3Transformer/tests/__init__.py
Normal file
103
rf2aa/SE3Transformer/tests/test_equivariance.py
Normal file
103
rf2aa/SE3Transformer/tests/test_equivariance.py
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
from tests.utils import get_random_graph, assign_relative_pos, get_max_diff, rot
|
||||||
|
|
||||||
|
# Tolerances for equivariance error abs( f(x) @ R - f(x @ R) )
|
||||||
|
TOL = 1e-3
|
||||||
|
CHANNELS, NODES = 32, 512
|
||||||
|
|
||||||
|
|
||||||
|
def _get_outputs(model, R):
|
||||||
|
feats0 = torch.randn(NODES, CHANNELS, 1)
|
||||||
|
feats1 = torch.randn(NODES, CHANNELS, 3)
|
||||||
|
|
||||||
|
coords = torch.randn(NODES, 3)
|
||||||
|
graph = get_random_graph(NODES)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
feats0 = feats0.cuda()
|
||||||
|
feats1 = feats1.cuda()
|
||||||
|
R = R.cuda()
|
||||||
|
coords = coords.cuda()
|
||||||
|
graph = graph.to('cuda')
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
graph1 = assign_relative_pos(graph, coords)
|
||||||
|
out1 = model(graph1, {'0': feats0, '1': feats1}, {})
|
||||||
|
graph2 = assign_relative_pos(graph, coords @ R)
|
||||||
|
out2 = model(graph2, {'0': feats0, '1': feats1 @ R}, {})
|
||||||
|
|
||||||
|
return out1, out2
|
||||||
|
|
||||||
|
|
||||||
|
def _get_model(**kwargs):
|
||||||
|
return SE3Transformer(
|
||||||
|
num_layers=4,
|
||||||
|
fiber_in=Fiber.create(2, CHANNELS),
|
||||||
|
fiber_hidden=Fiber.create(3, CHANNELS),
|
||||||
|
fiber_out=Fiber.create(2, CHANNELS),
|
||||||
|
fiber_edge=Fiber({}),
|
||||||
|
num_heads=8,
|
||||||
|
channels_div=2,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_equivariance():
|
||||||
|
model = _get_model()
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
R = R.cuda()
|
||||||
|
out1, out2 = _get_outputs(model, R)
|
||||||
|
|
||||||
|
assert torch.allclose(out2['0'], out1['0'], atol=TOL), \
|
||||||
|
f'type-0 features should be invariant {get_max_diff(out1["0"], out2["0"])}'
|
||||||
|
assert torch.allclose(out2['1'], (out1['1'] @ R), atol=TOL), \
|
||||||
|
f'type-1 features should be equivariant {get_max_diff(out1["1"] @ R, out2["1"])}'
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_equivariance_pooled():
|
||||||
|
model = _get_model(pooling='avg', return_type=1)
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
R = R.cuda()
|
||||||
|
out1, out2 = _get_outputs(model, R)
|
||||||
|
|
||||||
|
assert torch.allclose(out2, (out1 @ R), atol=TOL), \
|
||||||
|
f'type-1 features should be equivariant {get_max_diff(out1 @ R, out2)}'
|
||||||
|
|
||||||
|
@pytest.mark.skip
|
||||||
|
def test_invariance_pooled():
|
||||||
|
model = _get_model(pooling='avg', return_type=0)
|
||||||
|
R = rot(*torch.rand(3))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
R = R.cuda()
|
||||||
|
out1, out2 = _get_outputs(model, R)
|
||||||
|
|
||||||
|
assert torch.allclose(out2, out1, atol=TOL), \
|
||||||
|
f'type-0 features should be invariant {get_max_diff(out1, out2)}'
|
60
rf2aa/SE3Transformer/tests/utils.py
Normal file
60
rf2aa/SE3Transformer/tests/utils.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||||
|
#
|
||||||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||||||
|
# copy of this software and associated documentation files (the "Software"),
|
||||||
|
# to deal in the Software without restriction, including without limitation
|
||||||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||||||
|
# Software is furnished to do so, subject to the following conditions:
|
||||||
|
#
|
||||||
|
# The above copyright notice and this permission notice shall be included in
|
||||||
|
# all copies or substantial portions of the Software.
|
||||||
|
#
|
||||||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||||||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
|
# DEALINGS IN THE SOFTWARE.
|
||||||
|
#
|
||||||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||||||
|
# SPDX-License-Identifier: MIT
|
||||||
|
|
||||||
|
import dgl
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_graph(N, num_edges_factor=18):
|
||||||
|
graph = dgl.transforms.remove_self_loop(dgl.rand_graph(N, N * num_edges_factor))
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def assign_relative_pos(graph, coords):
|
||||||
|
src, dst = graph.edges()
|
||||||
|
graph.edata['rel_pos'] = coords[src] - coords[dst]
|
||||||
|
return graph
|
||||||
|
|
||||||
|
|
||||||
|
def get_max_diff(a, b):
|
||||||
|
return (a - b).abs().max().item()
|
||||||
|
|
||||||
|
|
||||||
|
def rot_z(gamma):
|
||||||
|
return torch.tensor([
|
||||||
|
[torch.cos(gamma), -torch.sin(gamma), 0],
|
||||||
|
[torch.sin(gamma), torch.cos(gamma), 0],
|
||||||
|
[0, 0, 1]
|
||||||
|
], dtype=gamma.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rot_y(beta):
|
||||||
|
return torch.tensor([
|
||||||
|
[torch.cos(beta), 0, torch.sin(beta)],
|
||||||
|
[0, 1, 0],
|
||||||
|
[-torch.sin(beta), 0, torch.cos(beta)]
|
||||||
|
], dtype=beta.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def rot(alpha, beta, gamma):
|
||||||
|
return rot_z(alpha) @ rot_y(beta) @ rot_z(gamma)
|
0
rf2aa/__init__.py
Normal file
0
rf2aa/__init__.py
Normal file
BIN
rf2aa/atomized_protein_frames.pt
Normal file
BIN
rf2aa/atomized_protein_frames.pt
Normal file
Binary file not shown.
9052
rf2aa/cartbonded.json
Normal file
9052
rf2aa/cartbonded.json
Normal file
File diff suppressed because it is too large
Load diff
2862
rf2aa/chemical.py
Normal file
2862
rf2aa/chemical.py
Normal file
File diff suppressed because it is too large
Load diff
70
rf2aa/config/inference/base.yaml
Normal file
70
rf2aa/config/inference/base.yaml
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
job_name: "structure_prediction"
|
||||||
|
output_path: ""
|
||||||
|
checkpoint_path: RFAA_paper_weights.pt
|
||||||
|
database_params:
|
||||||
|
sequencedb: ""
|
||||||
|
hhdb: "pdb100_2022Apr19/pdb100_2022Apr19"
|
||||||
|
command: make_msa.sh
|
||||||
|
num_cpus: 4
|
||||||
|
mem: 64
|
||||||
|
protein_inputs: null
|
||||||
|
na_inputs: null
|
||||||
|
sm_inputs: null
|
||||||
|
covale_inputs: null
|
||||||
|
residue_replacement: null
|
||||||
|
|
||||||
|
chem_params:
|
||||||
|
use_phospate_frames_for_NA: True
|
||||||
|
use_cif_ordering_for_trp: True
|
||||||
|
|
||||||
|
loader_params:
|
||||||
|
n_templ: 4
|
||||||
|
MAXLAT: 128
|
||||||
|
MAXSEQ: 1024
|
||||||
|
MAXCYCLE: 4
|
||||||
|
BLACK_HOLE_INIT: False
|
||||||
|
seqid: 150.0
|
||||||
|
|
||||||
|
|
||||||
|
legacy_model_param:
|
||||||
|
n_extra_block: 4
|
||||||
|
n_main_block: 32
|
||||||
|
n_ref_block: 4
|
||||||
|
n_finetune_block: 0
|
||||||
|
d_msa: 256
|
||||||
|
d_msa_full: 64
|
||||||
|
d_pair: 192
|
||||||
|
d_templ: 64
|
||||||
|
n_head_msa: 8
|
||||||
|
n_head_pair: 6
|
||||||
|
n_head_templ: 4
|
||||||
|
d_hidden_templ: 64
|
||||||
|
p_drop: 0.0
|
||||||
|
use_chiral_l1: True
|
||||||
|
use_lj_l1: True
|
||||||
|
use_atom_frames: True
|
||||||
|
recycling_type: "all"
|
||||||
|
use_same_chain: True
|
||||||
|
lj_lin: 0.75
|
||||||
|
SE3_param:
|
||||||
|
num_layers: 1
|
||||||
|
num_channels: 32
|
||||||
|
num_degrees: 2
|
||||||
|
l0_in_features: 64
|
||||||
|
l0_out_features: 64
|
||||||
|
l1_in_features: 3
|
||||||
|
l1_out_features: 2
|
||||||
|
num_edge_features: 64
|
||||||
|
n_heads: 4
|
||||||
|
div: 4
|
||||||
|
SE3_ref_param:
|
||||||
|
num_layers: 2
|
||||||
|
num_channels: 32
|
||||||
|
num_degrees: 2
|
||||||
|
l0_in_features: 64
|
||||||
|
l0_out_features: 64
|
||||||
|
l1_in_features: 3
|
||||||
|
l1_out_features: 2
|
||||||
|
num_edge_features: 64
|
||||||
|
n_heads: 4
|
||||||
|
div: 4
|
18
rf2aa/config/inference/covalent.yaml
Normal file
18
rf2aa/config/inference/covalent.yaml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: 7s69_A
|
||||||
|
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7s69_A.fasta
|
||||||
|
|
||||||
|
sm_inputs:
|
||||||
|
B:
|
||||||
|
input: examples/small_molecule/7s69_glycan.sdf
|
||||||
|
input_type: sdf
|
||||||
|
|
||||||
|
covale_inputs: "[((\"A\", \"74\", \"ND2\"), (\"B\", \"1\"), (\"CW\", \"null\"))]"
|
||||||
|
|
||||||
|
loader_params:
|
||||||
|
MAXCYCLE: 10
|
14
rf2aa/config/inference/nucleic_acid.yaml
Normal file
14
rf2aa/config/inference/nucleic_acid.yaml
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein_nucleic"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
||||||
|
na_inputs:
|
||||||
|
B:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_B.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
C:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_C.fasta
|
||||||
|
input_type: "dna"
|
7
rf2aa/config/inference/protein.yaml
Normal file
7
rf2aa/config/inference/protein.yaml
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
18
rf2aa/config/inference/protein_na_sm.yaml
Normal file
18
rf2aa/config/inference/protein_na_sm.yaml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
job_name: "7u7w_protein_nucleic_sm"
|
||||||
|
protein_inputs:
|
||||||
|
A:
|
||||||
|
fasta_file: examples/protein/7u7w_A.fasta
|
||||||
|
na_inputs:
|
||||||
|
B:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_B.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
C:
|
||||||
|
fasta: examples/nucleic_acid/7u7w_C.fasta
|
||||||
|
input_type: "dna"
|
||||||
|
sm_inputs:
|
||||||
|
D:
|
||||||
|
input: examples/small_molecule/XG4.sdf
|
||||||
|
input_type: "sdf"
|
308
rf2aa/data/covale.py
Normal file
308
rf2aa/data/covale.py
Normal file
|
@ -0,0 +1,308 @@
|
||||||
|
import torch
|
||||||
|
from openbabel import openbabel
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
from rf2aa.data.parsers import parse_mol
|
||||||
|
from rf2aa.data.small_molecule import compute_features_from_obmol
|
||||||
|
from rf2aa.util import get_bond_feats
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MoleculeToMoleculeBond:
|
||||||
|
chain_index_first: int
|
||||||
|
absolute_atom_index_first: int
|
||||||
|
chain_index_second: int
|
||||||
|
absolute_atom_index_second: int
|
||||||
|
new_chirality_atom_first: Optional[str]
|
||||||
|
new_chirality_atom_second: Optional[str]
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AtomizedResidue:
|
||||||
|
chain: str
|
||||||
|
chain_index_in_combined_chain: int
|
||||||
|
absolute_N_index_in_chain: int
|
||||||
|
absolute_C_index_in_chain: int
|
||||||
|
original_chain: str
|
||||||
|
index_in_original_chain: int
|
||||||
|
|
||||||
|
|
||||||
|
def load_covalent_molecules(protein_inputs, config, model_runner):
|
||||||
|
if config.covale_inputs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if config.sm_inputs is None:
|
||||||
|
raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")
|
||||||
|
|
||||||
|
covalent_bonds = eval(config.covale_inputs)
|
||||||
|
sm_inputs = delete_leaving_atoms(config.sm_inputs)
|
||||||
|
residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
|
||||||
|
chainid_to_input = {}
|
||||||
|
for chain, combined_molecule in combined_molecules.items():
|
||||||
|
extra_bonds_for_chain = extra_bonds[chain]
|
||||||
|
msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
|
||||||
|
residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
|
||||||
|
mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
|
||||||
|
xyz = recompute_xyz_after_chirality(mol)
|
||||||
|
input = compute_features_from_obmol(mol, msa, xyz, model_runner)
|
||||||
|
chainid_to_input[chain] = input
|
||||||
|
|
||||||
|
return chainid_to_input, residues_to_atomize
|
||||||
|
|
||||||
|
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
|
||||||
|
residues_to_atomize = [] # hold on to delete wayward inputs
|
||||||
|
combined_molecules = {} # combined multiple molecules that are bonded
|
||||||
|
extra_bonds = {}
|
||||||
|
for bond in covalent_bonds:
|
||||||
|
prot_chid, prot_res_idx, atom_to_bond = bond[0]
|
||||||
|
sm_chid, sm_atom_num = bond[1]
|
||||||
|
chirality_first_atom, chirality_second_atom = bond[2]
|
||||||
|
if chirality_first_atom.strip() == "null":
|
||||||
|
chirality_first_atom = None
|
||||||
|
if chirality_second_atom.strip() == "null":
|
||||||
|
chirality_second_atom = None
|
||||||
|
|
||||||
|
sm_atom_num = int(sm_atom_num) - 1 # 0 index
|
||||||
|
try:
|
||||||
|
assert sm_chid in sm_inputs, f"must provide a small molecule chain {sm_chid} for covalent bond: {bond}"
|
||||||
|
except:
|
||||||
|
print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
|
||||||
|
continue
|
||||||
|
assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"
|
||||||
|
try:
|
||||||
|
protein_input = protein_inputs[prot_chid]
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"first atom in covale_input must be present in\
|
||||||
|
a protein chain. Given chain: {prot_chid} was not in \
|
||||||
|
given protein chains: {list(protein_inputs.keys())}")
|
||||||
|
|
||||||
|
residue = (prot_chid, prot_res_idx, atom_to_bond)
|
||||||
|
file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)
|
||||||
|
if sm_chid not in combined_molecules:
|
||||||
|
combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
|
||||||
|
combined_molecules[sm_chid].insert(0, file) # this is a bug, revert
|
||||||
|
absolute_chain_index_first = combined_molecules[sm_chid].index(sm_inputs[sm_chid].input)
|
||||||
|
absolute_chain_index_second = combined_molecules[sm_chid].index(file)
|
||||||
|
|
||||||
|
if sm_chid not in extra_bonds:
|
||||||
|
extra_bonds[sm_chid] = []
|
||||||
|
extra_bonds[sm_chid].append(MoleculeToMoleculeBond(
|
||||||
|
absolute_chain_index_first,
|
||||||
|
sm_atom_num,
|
||||||
|
absolute_chain_index_second,
|
||||||
|
atom_index,
|
||||||
|
new_chirality_atom_first=chirality_first_atom,
|
||||||
|
new_chirality_atom_second=chirality_second_atom
|
||||||
|
))
|
||||||
|
residues_to_atomize.append(AtomizedResidue(
|
||||||
|
sm_chid,
|
||||||
|
absolute_chain_index_second,
|
||||||
|
0,
|
||||||
|
2,
|
||||||
|
prot_chid,
|
||||||
|
int(prot_res_idx) -1
|
||||||
|
))
|
||||||
|
|
||||||
|
return residues_to_atomize, combined_molecules, extra_bonds
|
||||||
|
|
||||||
|
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
|
||||||
|
"""convert residue into sdf and record index for covalent bond"""
|
||||||
|
prot_chid, prot_res_idx, atom_to_bond = residue
|
||||||
|
protein_input = protein_inputs[prot_chid]
|
||||||
|
prot_res_abs_idx = int(prot_res_idx) -1
|
||||||
|
residue_identity_num = protein_input.query_sequence()[prot_res_abs_idx]
|
||||||
|
residue_identity = ChemData().num2aa[residue_identity_num]
|
||||||
|
molecule_info = model_runner.molecule_db[residue_identity]
|
||||||
|
sdf = molecule_info["sdf"]
|
||||||
|
temp_file = create_and_populate_temp_file(sdf)
|
||||||
|
is_heavy = [i for i, a in enumerate(molecule_info["atom_id"]) if a[0] != "H"]
|
||||||
|
is_leaving = [a for i,a in enumerate(molecule_info["leaving"]) if i in is_heavy]
|
||||||
|
|
||||||
|
sdf_string_no_leaving_atoms = delete_leaving_atoms_single_chain(temp_file, is_leaving )
|
||||||
|
temp_file = create_and_populate_temp_file(sdf_string_no_leaving_atoms)
|
||||||
|
atom_names = molecule_info["atom_id"]
|
||||||
|
atom_index = atom_names.index(atom_to_bond.strip())
|
||||||
|
return temp_file, atom_index
|
||||||
|
|
||||||
|
def get_combined_atoms_bonds(combined_molecule):
|
||||||
|
atom_list = []
|
||||||
|
bond_feats_list = []
|
||||||
|
xyzs = []
|
||||||
|
Ls = []
|
||||||
|
for molecule in combined_molecule:
|
||||||
|
obmol, msa, ins, xyz, mask = parse_mol(
|
||||||
|
molecule,
|
||||||
|
filetype="sdf",
|
||||||
|
string=False,
|
||||||
|
generate_conformer=True,
|
||||||
|
find_automorphs=False
|
||||||
|
)
|
||||||
|
bond_feats = get_bond_feats(obmol)
|
||||||
|
|
||||||
|
atom_list.append(msa)
|
||||||
|
bond_feats_list.append(bond_feats)
|
||||||
|
xyzs.append(xyz)
|
||||||
|
Ls.append(msa.shape[0])
|
||||||
|
|
||||||
|
atoms = torch.cat(atom_list)
|
||||||
|
L_total = sum(Ls)
|
||||||
|
bond_feats = torch.zeros((L_total, L_total)).long()
|
||||||
|
offset = 0
|
||||||
|
for bf in bond_feats_list:
|
||||||
|
L = bf.shape[0]
|
||||||
|
bond_feats[offset:offset+L, offset:offset+L] = bf
|
||||||
|
offset += L
|
||||||
|
xyz = torch.cat(xyzs, dim=1)[0]
|
||||||
|
return atoms, bond_feats, xyz, Ls
|
||||||
|
|
||||||
|
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
|
||||||
|
mol = openbabel.OBMol()
|
||||||
|
for i,k in enumerate(msa):
|
||||||
|
element = ChemData().num2aa[k]
|
||||||
|
atomnum = ChemData().atomtype2atomnum[element]
|
||||||
|
a = mol.NewAtom()
|
||||||
|
a.SetAtomicNum(atomnum)
|
||||||
|
a.SetVector(float(xyz[i,0]), float(xyz[i,1]), float(xyz[i,2]))
|
||||||
|
|
||||||
|
first_index, second_index = bond_feats.nonzero(as_tuple=True)
|
||||||
|
for i, j in zip(first_index, second_index):
|
||||||
|
order = bond_feats[i,j]
|
||||||
|
bond = make_openbabel_bond(mol, i.item(), j.item(), order.item())
|
||||||
|
mol.AddBond(bond)
|
||||||
|
|
||||||
|
for bond in extra_bonds:
|
||||||
|
absolute_index_first = get_absolute_index_from_relative_indices(
|
||||||
|
bond.chain_index_first,
|
||||||
|
bond.absolute_atom_index_first,
|
||||||
|
Ls
|
||||||
|
)
|
||||||
|
absolute_index_second = get_absolute_index_from_relative_indices(
|
||||||
|
bond.chain_index_second,
|
||||||
|
bond.absolute_atom_index_second,
|
||||||
|
Ls
|
||||||
|
)
|
||||||
|
order = 1 #all covale bonds are single bonds
|
||||||
|
openbabel_bond = make_openbabel_bond(mol, absolute_index_first, absolute_index_second, order)
|
||||||
|
mol.AddBond(openbabel_bond)
|
||||||
|
set_chirality(mol, absolute_index_first, bond.new_chirality_atom_first)
|
||||||
|
set_chirality(mol, absolute_index_second, bond.new_chirality_atom_second)
|
||||||
|
return mol
|
||||||
|
|
||||||
|
def make_openbabel_bond(mol, i, j, order):
|
||||||
|
obb = openbabel.OBBond()
|
||||||
|
obb.SetBegin(mol.GetAtom(i+1))
|
||||||
|
obb.SetEnd(mol.GetAtom(j+1))
|
||||||
|
if order == 4:
|
||||||
|
obb.SetBondOrder(2)
|
||||||
|
obb.SetAromatic()
|
||||||
|
else:
|
||||||
|
obb.SetBondOrder(order)
|
||||||
|
return obb
|
||||||
|
|
||||||
|
def set_chirality(mol, absolute_atom_index, new_chirality):
|
||||||
|
stereo = openbabel.OBStereoFacade(mol)
|
||||||
|
if stereo.HasTetrahedralStereo(absolute_atom_index+1):
|
||||||
|
tetstereo = stereo.GetTetrahedralStereo(mol.GetAtom(absolute_atom_index+1).GetId())
|
||||||
|
if tetstereo is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
assert new_chirality is not None, "you have introduced a new stereocenter, \
|
||||||
|
so you must specify its chirality either as CW, or CCW"
|
||||||
|
|
||||||
|
config = tetstereo.GetConfig()
|
||||||
|
config.winding = chirality_options[new_chirality]
|
||||||
|
tetstereo.SetConfig(config)
|
||||||
|
print("Updating chirality...")
|
||||||
|
else:
|
||||||
|
assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
|
||||||
|
|
||||||
|
chirality_options = {
|
||||||
|
"CW": openbabel.OBStereo.Clockwise,
|
||||||
|
"CCW": openbabel.OBStereo.AntiClockwise,
|
||||||
|
}
|
||||||
|
|
||||||
|
def recompute_xyz_after_chirality(obmol):
|
||||||
|
builder = openbabel.OBBuilder()
|
||||||
|
builder.Build(obmol)
|
||||||
|
ff = openbabel.OBForceField.FindForceField("mmff94")
|
||||||
|
did_setup = ff.Setup(obmol)
|
||||||
|
if did_setup:
|
||||||
|
ff.FastRotorSearch()
|
||||||
|
ff.GetCoordinates(obmol)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
|
||||||
|
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
|
||||||
|
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
|
||||||
|
return atom_coords
|
||||||
|
|
||||||
|
def delete_leaving_atoms(sm_inputs):
|
||||||
|
updated_sm_inputs = {}
|
||||||
|
for chain in sm_inputs:
|
||||||
|
if "is_leaving" not in sm_inputs[chain]:
|
||||||
|
continue
|
||||||
|
is_leaving = eval(sm_inputs[chain]["is_leaving"])
|
||||||
|
sdf_string = delete_leaving_atoms_single_chain(sm_inputs[chain]["input"], is_leaving)
|
||||||
|
updated_sm_inputs[chain] = {
|
||||||
|
"input": create_and_populate_temp_file(sdf_string),
|
||||||
|
"input_type": "sdf"
|
||||||
|
}
|
||||||
|
|
||||||
|
sm_inputs.update(updated_sm_inputs)
|
||||||
|
return sm_inputs
|
||||||
|
|
||||||
|
def delete_leaving_atoms_single_chain(filename, is_leaving):
|
||||||
|
obmol, msa, ins, xyz, mask = parse_mol(
|
||||||
|
filename,
|
||||||
|
filetype="sdf",
|
||||||
|
string=False,
|
||||||
|
generate_conformer=True
|
||||||
|
)
|
||||||
|
assert len(is_leaving) == obmol.NumAtoms()
|
||||||
|
leaving_indices = torch.tensor(is_leaving).nonzero()
|
||||||
|
for idx in leaving_indices:
|
||||||
|
obmol.DeleteAtom(obmol.GetAtom(idx.item()+1))
|
||||||
|
obConversion = openbabel.OBConversion()
|
||||||
|
obConversion.SetInAndOutFormats("sdf", "sdf")
|
||||||
|
sdf_string = obConversion.WriteString(obmol)
|
||||||
|
return sdf_string
|
||||||
|
|
||||||
|
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
|
||||||
|
offset = sum(Ls[:chain_index])
|
||||||
|
return offset + absolute_index_in_chain
|
||||||
|
|
||||||
|
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
|
||||||
|
updated_residues_to_atomize = []
|
||||||
|
for residue in residues_to_atomize:
|
||||||
|
if residue.chain == chain:
|
||||||
|
absolute_index_N = get_absolute_index_from_relative_indices(
|
||||||
|
residue.chain_index_in_combined_chain,
|
||||||
|
residue.absolute_N_index_in_chain,
|
||||||
|
Ls)
|
||||||
|
absolute_index_C = get_absolute_index_from_relative_indices(
|
||||||
|
residue.chain_index_in_combined_chain,
|
||||||
|
residue.absolute_C_index_in_chain,
|
||||||
|
Ls)
|
||||||
|
updated_residue = AtomizedResidue(
|
||||||
|
residue.chain,
|
||||||
|
None,
|
||||||
|
absolute_index_N,
|
||||||
|
absolute_index_C,
|
||||||
|
residue.original_chain,
|
||||||
|
residue.index_in_original_chain
|
||||||
|
)
|
||||||
|
updated_residues_to_atomize.append(updated_residue)
|
||||||
|
else:
|
||||||
|
updated_residues_to_atomize.append(residue)
|
||||||
|
return updated_residues_to_atomize
|
||||||
|
|
||||||
|
def create_and_populate_temp_file(data):
|
||||||
|
# Create a temporary file
|
||||||
|
with NamedTemporaryFile(mode='w+', delete=False) as temp_file:
|
||||||
|
# Write the string to the temporary file
|
||||||
|
temp_file.write(data)
|
||||||
|
|
||||||
|
# Get the filename
|
||||||
|
temp_file_name = temp_file.name
|
||||||
|
|
||||||
|
return temp_file_name
|
198
rf2aa/data/data_loader.py
Normal file
198
rf2aa/data/data_loader.py
Normal file
|
@ -0,0 +1,198 @@
|
||||||
|
import torch
|
||||||
|
from dataclasses import dataclass, fields
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
|
||||||
|
from rf2aa.kinematics import xyz_to_t2d
|
||||||
|
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
|
||||||
|
Ls_from_same_chain_2d, idx_from_Ls, is_atom
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RawInputData:
|
||||||
|
msa: torch.Tensor
|
||||||
|
ins: torch.Tensor
|
||||||
|
bond_feats: torch.Tensor
|
||||||
|
xyz_t: torch.Tensor
|
||||||
|
mask_t: torch.Tensor
|
||||||
|
t1d: torch.Tensor
|
||||||
|
chirals: torch.Tensor
|
||||||
|
atom_frames: torch.Tensor
|
||||||
|
taxids: Optional[List[str]] = None
|
||||||
|
term_info: Optional[torch.Tensor] = None
|
||||||
|
chain_lengths: Optional[List] = None
|
||||||
|
idx: Optional[List] = None
|
||||||
|
|
||||||
|
def query_sequence(self):
|
||||||
|
return self.msa[0]
|
||||||
|
|
||||||
|
def is_atom(self):
|
||||||
|
return is_atom(self.query_sequence())
|
||||||
|
|
||||||
|
def length(self):
|
||||||
|
return self.msa.shape[1]
|
||||||
|
|
||||||
|
def get_chain_bins_from_chain_lengths(self):
|
||||||
|
if self.chain_lengths is None:
|
||||||
|
raise ValueError("Cannot call get_chain_bins_from_chain_lengths without \
|
||||||
|
setting chain_lengths. Chain_lengths is set in merge_inputs")
|
||||||
|
chain_bins = {}
|
||||||
|
running_length = 0
|
||||||
|
for chain, length in self.chain_lengths:
|
||||||
|
chain_bins[chain] = (running_length, running_length+length)
|
||||||
|
running_length = running_length + length
|
||||||
|
return chain_bins
|
||||||
|
|
||||||
|
def update_protein_features_after_atomize(self, residues_to_atomize):
|
||||||
|
if self.chain_lengths is None:
|
||||||
|
raise("Cannot update protein features without chain_lengths. \
|
||||||
|
merge_inputs must be called before this function")
|
||||||
|
chain_bins = self.get_chain_bins_from_chain_lengths()
|
||||||
|
keep = torch.ones(self.length())
|
||||||
|
prev_absolute_index = None
|
||||||
|
prev_C = None
|
||||||
|
#need to atomize residues from N term to Cterm to handle atomizing neighbors
|
||||||
|
residues_to_atomize = sorted(residues_to_atomize, key= lambda x: x.original_chain +str(x.index_in_original_chain))
|
||||||
|
for residue in residues_to_atomize:
|
||||||
|
original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
|
||||||
|
absolute_index_in_combined_input = original_chain_start_index + residue.index_in_original_chain
|
||||||
|
|
||||||
|
atomized_chain_start_index, atomized_chain_end_index = chain_bins[residue.chain]
|
||||||
|
N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
|
||||||
|
C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain
|
||||||
|
# if residue is first in the chain, no extra bond feats to following residue
|
||||||
|
if absolute_index_in_combined_input != original_chain_start_index:
|
||||||
|
self.bond_feats[absolute_index_in_combined_input-1, N_index] = ChemData().RESIDUE_ATOM_BOND
|
||||||
|
self.bond_feats[N_index, absolute_index_in_combined_input-1] = ChemData().RESIDUE_ATOM_BOND
|
||||||
|
|
||||||
|
# if residue is last in chain, no extra bonds feats to following residue
|
||||||
|
if absolute_index_in_combined_input != original_chain_end_index-1:
|
||||||
|
self.bond_feats[absolute_index_in_combined_input+1, C_index] = ChemData().RESIDUE_ATOM_BOND
|
||||||
|
self.bond_feats[C_index,absolute_index_in_combined_input+1] = ChemData().RESIDUE_ATOM_BOND
|
||||||
|
keep[absolute_index_in_combined_input] = 0
|
||||||
|
|
||||||
|
# find neighboring residues that were atomized
|
||||||
|
if prev_absolute_index is not None:
|
||||||
|
if prev_absolute_index + 1 == absolute_index_in_combined_input:
|
||||||
|
self.bond_feats[prev_C, N_index] = 1
|
||||||
|
self.bond_feats[N_index, prev_C] = 1
|
||||||
|
|
||||||
|
prev_absolute_index = absolute_index_in_combined_input
|
||||||
|
prev_C = C_index
|
||||||
|
# remove protein features
|
||||||
|
self.keep_features(keep.bool())
|
||||||
|
|
||||||
|
def keep_features(self, keep):
|
||||||
|
if not torch.all(keep[self.is_atom()]):
|
||||||
|
raise ValueError("cannot remove atoms")
|
||||||
|
self.msa = self.msa[:,keep]
|
||||||
|
self.ins = self.ins[:,keep]
|
||||||
|
self.bond_feats = self.bond_feats[keep][:,keep]
|
||||||
|
self.xyz_t = self.xyz_t[:,keep]
|
||||||
|
self.t1d = self.t1d[:,keep]
|
||||||
|
self.mask_t = self.mask_t[:,keep]
|
||||||
|
if self.term_info is not None:
|
||||||
|
self.term_info = self.term_info[keep]
|
||||||
|
if self.idx is not None:
|
||||||
|
self.idx = self.idx[keep]
|
||||||
|
# assumes all chirals are after all protein residues
|
||||||
|
self.chirals[...,:-1] = self.chirals[...,:-1] - torch.sum(~keep)
|
||||||
|
|
||||||
|
def construct_features(self, model_runner):
|
||||||
|
loader_params = model_runner.config.loader_params
|
||||||
|
B, L = 1, self.length()
|
||||||
|
seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(
|
||||||
|
self.msa.long(),
|
||||||
|
self.ins.long(),
|
||||||
|
loader_params,
|
||||||
|
p_mask=loader_params.get("p_msa_mask", 0),
|
||||||
|
term_info=self.term_info,
|
||||||
|
deterministic=model_runner.deterministic,
|
||||||
|
)
|
||||||
|
dist_matrix = get_bond_distances(self.bond_feats)
|
||||||
|
|
||||||
|
# xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
|
||||||
|
# xyz_prev = torch.nan_to_num(xyz_prev)
|
||||||
|
|
||||||
|
# NOTE: The above is the way things "should" be done, this is for compatability with training.
|
||||||
|
xyz_prev = ChemData().INIT_CRDS.reshape(1,ChemData().NTOTAL,3).repeat(L,1,1)
|
||||||
|
|
||||||
|
self.xyz_t = torch.nan_to_num(self.xyz_t)
|
||||||
|
|
||||||
|
mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
|
||||||
|
mask_t_2d = mask_t_2d[:,None]*mask_t_2d[:,:,None] # (B, T, L, L)
|
||||||
|
|
||||||
|
xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
|
||||||
|
t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])
|
||||||
|
t2d = t2d[0]
|
||||||
|
# get torsion angles from templates
|
||||||
|
seq_tmp = self.t1d[...,:-1].argmax(dim=-1)
|
||||||
|
alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(self.xyz_t.reshape(-1,L,ChemData().NTOTAL,3),
|
||||||
|
seq_tmp, mask_in=self.mask_t.reshape(-1,L,ChemData().NTOTAL))
|
||||||
|
alpha = alpha.reshape(B,-1,L,ChemData().NTOTALDOFS,2)
|
||||||
|
alpha_mask = alpha_mask.reshape(B,-1,L,ChemData().NTOTALDOFS,1)
|
||||||
|
alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3*ChemData().NTOTALDOFS)
|
||||||
|
alpha_t = alpha_t[0]
|
||||||
|
alpha_prev = torch.zeros((L,ChemData().NTOTALDOFS,2))
|
||||||
|
|
||||||
|
same_chain = same_chain_from_bond_feats(self.bond_feats)
|
||||||
|
return RFInput(
|
||||||
|
msa_latent=msa_seed,
|
||||||
|
msa_full=msa_extra,
|
||||||
|
seq=seq,
|
||||||
|
seq_unmasked=self.query_sequence(),
|
||||||
|
bond_feats=self.bond_feats,
|
||||||
|
dist_matrix=dist_matrix,
|
||||||
|
chirals=self.chirals,
|
||||||
|
atom_frames=self.atom_frames.long(),
|
||||||
|
xyz_prev=xyz_prev,
|
||||||
|
alpha_prev=alpha_prev,
|
||||||
|
t1d=self.t1d,
|
||||||
|
t2d=t2d,
|
||||||
|
xyz_t=self.xyz_t[..., 1, :],
|
||||||
|
alpha_t=alpha_t.float(),
|
||||||
|
mask_t=mask_t_2d.float(),
|
||||||
|
same_chain=same_chain.long(),
|
||||||
|
idx=self.idx
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RFInput:
|
||||||
|
msa_latent: torch.Tensor
|
||||||
|
msa_full: torch.Tensor
|
||||||
|
seq: torch.Tensor
|
||||||
|
seq_unmasked: torch.Tensor
|
||||||
|
idx: torch.Tensor
|
||||||
|
bond_feats: torch.Tensor
|
||||||
|
dist_matrix: torch.Tensor
|
||||||
|
chirals: torch.Tensor
|
||||||
|
atom_frames: torch.Tensor
|
||||||
|
xyz_prev: torch.Tensor
|
||||||
|
alpha_prev: torch.Tensor
|
||||||
|
t1d: torch.Tensor
|
||||||
|
t2d: torch.Tensor
|
||||||
|
xyz_t: torch.Tensor
|
||||||
|
alpha_t: torch.Tensor
|
||||||
|
mask_t: torch.Tensor
|
||||||
|
same_chain: torch.Tensor
|
||||||
|
msa_prev: Optional[torch.Tensor] = None
|
||||||
|
pair_prev: Optional[torch.Tensor] = None
|
||||||
|
state_prev: Optional[torch.Tensor] = None
|
||||||
|
mask_recycle: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def to(self, gpu):
|
||||||
|
for field in fields(self):
|
||||||
|
field_value = getattr(self, field.name)
|
||||||
|
if torch.is_tensor(field_value):
|
||||||
|
setattr(self, field.name, field_value.to(gpu))
|
||||||
|
|
||||||
|
def add_batch_dim(self):
|
||||||
|
""" mimic pytorch dataloader at inference time"""
|
||||||
|
for field in fields(self):
|
||||||
|
field_value = getattr(self, field.name)
|
||||||
|
if torch.is_tensor(field_value):
|
||||||
|
setattr(self, field.name, field_value[None])
|
||||||
|
|
909
rf2aa/data/data_loader_utils.py
Normal file
909
rf2aa/data/data_loader_utils.py
Normal file
|
@ -0,0 +1,909 @@
|
||||||
|
import torch
|
||||||
|
import warnings
|
||||||
|
import time
|
||||||
|
from icecream import ic
|
||||||
|
from torch.utils import data
|
||||||
|
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(script_dir)
|
||||||
|
sys.path.append(script_dir+'/../')
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
from rf2aa.data.parsers import parse_a3m, parse_pdb
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
|
||||||
|
from rf2aa.util import random_rot_trans, \
|
||||||
|
is_atom, is_protein, is_nucleic, is_atom
|
||||||
|
|
||||||
|
|
||||||
|
def MSABlockDeletion(msa, ins, nb=5):
|
||||||
|
'''
|
||||||
|
Input: MSA having shape (N, L)
|
||||||
|
output: new MSA with block deletion
|
||||||
|
'''
|
||||||
|
N, L = msa.shape
|
||||||
|
block_size = max(int(N*0.3), 1)
|
||||||
|
block_start = np.random.randint(low=1, high=N, size=nb) # (nb)
|
||||||
|
to_delete = block_start[:,None] + np.arange(block_size)[None,:]
|
||||||
|
to_delete = np.unique(np.clip(to_delete, 1, N-1))
|
||||||
|
#
|
||||||
|
mask = np.ones(N, bool)
|
||||||
|
mask[to_delete] = 0
|
||||||
|
|
||||||
|
return msa[mask], ins[mask]
|
||||||
|
|
||||||
|
def cluster_sum(data, assignment, N_seq, N_res):
|
||||||
|
csum = torch.zeros(N_seq, N_res, data.shape[-1], device=data.device).scatter_add(0, assignment.view(-1,1,1).expand(-1,N_res,data.shape[-1]), data.float())
|
||||||
|
return csum
|
||||||
|
|
||||||
|
def get_term_feats(Ls):
|
||||||
|
"""Creates N/C-terminus binary features"""
|
||||||
|
term_info = torch.zeros((sum(Ls),2)).float()
|
||||||
|
start = 0
|
||||||
|
for L_chain in Ls:
|
||||||
|
term_info[start, 0] = 1.0 # flag for N-term
|
||||||
|
term_info[start+L_chain-1,1] = 1.0 # flag for C-term
|
||||||
|
start += L_chain
|
||||||
|
return term_info
|
||||||
|
|
||||||
|
|
||||||
|
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
|
||||||
|
term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
|
||||||
|
'''
|
||||||
|
Input: full MSA information (after Block deletion if necessary) & full insertion information
|
||||||
|
Output: seed MSA features & extra sequences
|
||||||
|
|
||||||
|
Seed MSA features:
|
||||||
|
- aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
|
||||||
|
- profile of clustered sequences (22)
|
||||||
|
- insertion statistics (2)
|
||||||
|
- N-term or C-term? (2)
|
||||||
|
extra sequence features:
|
||||||
|
- aatype of extra sequence (22)
|
||||||
|
- insertion info (1)
|
||||||
|
- N-term or C-term? (2)
|
||||||
|
'''
|
||||||
|
if deterministic:
|
||||||
|
random.seed(0)
|
||||||
|
np.random.seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
# TODO: delete me, just for testing purposes
|
||||||
|
msa = msa[:2]
|
||||||
|
|
||||||
|
if fixbb:
|
||||||
|
p_mask = 0
|
||||||
|
msa = msa[:1]
|
||||||
|
ins = ins[:1]
|
||||||
|
N, L = msa.shape
|
||||||
|
|
||||||
|
if term_info is None:
|
||||||
|
if len(L_s)==0:
|
||||||
|
L_s = [L]
|
||||||
|
term_info = get_term_feats(L_s)
|
||||||
|
term_info = term_info.to(msa.device)
|
||||||
|
|
||||||
|
#binding_site = torch.zeros((L,1), device=msa.device).float()
|
||||||
|
binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)
|
||||||
|
|
||||||
|
# raw MSA profile
|
||||||
|
raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
|
||||||
|
raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS
|
||||||
|
|
||||||
|
# Select Nclust sequence randomly (seed MSA or latent MSA)
|
||||||
|
Nclust = (min(N, params['MAXLAT'])-1) // nmer
|
||||||
|
Nclust = Nclust*nmer + 1
|
||||||
|
|
||||||
|
if N > Nclust*2:
|
||||||
|
Nextra = N - Nclust
|
||||||
|
else:
|
||||||
|
Nextra = N
|
||||||
|
Nextra = min(Nextra, params['MAXSEQ']) // nmer
|
||||||
|
Nextra = max(1, Nextra * nmer)
|
||||||
|
#
|
||||||
|
b_seq = list()
|
||||||
|
b_msa_clust = list()
|
||||||
|
b_msa_seed = list()
|
||||||
|
b_msa_extra = list()
|
||||||
|
b_mask_pos = list()
|
||||||
|
for i_cycle in range(params['MAXCYCLE']):
|
||||||
|
sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
|
||||||
|
sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
|
||||||
|
sample = torch.stack(sample, dim=-1)
|
||||||
|
sample = sample.reshape(-1)
|
||||||
|
|
||||||
|
# add MSA clusters pre-chosen before calling this function
|
||||||
|
if seed_msa_clus is not None:
|
||||||
|
sample_orig_shape = sample.shape
|
||||||
|
sample_seed = seed_msa_clus[i_cycle]
|
||||||
|
sample_more = torch.tensor([i for i in sample if i not in sample_seed])
|
||||||
|
N_sample_more = len(sample) - len(sample_seed)
|
||||||
|
if N_sample_more > 0:
|
||||||
|
sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
|
||||||
|
sample = torch.cat([sample_seed, sample_more])
|
||||||
|
else:
|
||||||
|
sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones
|
||||||
|
|
||||||
|
msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
|
||||||
|
ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)
|
||||||
|
|
||||||
|
# 15% random masking
|
||||||
|
# - 10%: aa replaced with a uniformly sampled random amino acid
|
||||||
|
# - 10%: aa replaced with an amino acid sampled from the MSA profile
|
||||||
|
# - 10%: not replaced
|
||||||
|
# - 70%: replaced with a special token ("mask")
|
||||||
|
random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
|
||||||
|
same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
|
||||||
|
# explicitly remove probabilities from nucleic acids and atoms
|
||||||
|
#same_aa[..., ChemData().NPROTAAS:] = 0
|
||||||
|
#raw_profile[...,ChemData().NPROTAAS:] = 0
|
||||||
|
probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
|
||||||
|
#probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)
|
||||||
|
|
||||||
|
# explicitly set the probability of masking for nucleic acids and atoms
|
||||||
|
#probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
|
||||||
|
#probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
|
||||||
|
#probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
|
||||||
|
#probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
|
||||||
|
#probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0
|
||||||
|
|
||||||
|
sampler = torch.distributions.categorical.Categorical(probs=probs)
|
||||||
|
mask_sample = sampler.sample()
|
||||||
|
|
||||||
|
mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
|
||||||
|
mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
|
||||||
|
use_seq = msa_clust
|
||||||
|
msa_masked = torch.where(mask_pos, mask_sample, use_seq)
|
||||||
|
b_seq.append(msa_masked[0].clone())
|
||||||
|
|
||||||
|
## get extra sequenes
|
||||||
|
if N > Nclust*2: # there are enough extra sequences
|
||||||
|
msa_extra = msa[1:,:][sample[Nclust-1:]]
|
||||||
|
ins_extra = ins[1:,:][sample[Nclust-1:]]
|
||||||
|
extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
|
||||||
|
elif N - Nclust < 1:
|
||||||
|
msa_extra = msa_masked.clone()
|
||||||
|
ins_extra = ins_clust.clone()
|
||||||
|
extra_mask = mask_pos.clone()
|
||||||
|
else:
|
||||||
|
msa_add = msa[1:,:][sample[Nclust-1:]]
|
||||||
|
ins_add = ins[1:,:][sample[Nclust-1:]]
|
||||||
|
mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
|
||||||
|
msa_extra = torch.cat((msa_masked, msa_add), dim=0)
|
||||||
|
ins_extra = torch.cat((ins_clust, ins_add), dim=0)
|
||||||
|
extra_mask = torch.cat((mask_pos, mask_add), dim=0)
|
||||||
|
N_extra = msa_extra.shape[0]
|
||||||
|
|
||||||
|
# clustering (assign remaining sequences to their closest cluster by Hamming distance
|
||||||
|
msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
|
||||||
|
msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
|
||||||
|
count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
|
||||||
|
count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
|
||||||
|
agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
|
||||||
|
assignment = torch.argmax(agreement, dim=-1)
|
||||||
|
|
||||||
|
# seed MSA features
|
||||||
|
# 1. one_hot encoded aatype: msa_clust_onehot
|
||||||
|
# 2. cluster profile
|
||||||
|
count_extra = ~extra_mask
|
||||||
|
count_clust = ~mask_pos
|
||||||
|
msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
|
||||||
|
msa_clust_profile += count_clust[:,:,None]*msa_clust_profile
|
||||||
|
count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
|
||||||
|
count_profile += count_clust
|
||||||
|
count_profile += eps
|
||||||
|
msa_clust_profile /= count_profile[:,:,None]
|
||||||
|
# 3. insertion statistics
|
||||||
|
msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
|
||||||
|
msa_clust_del += count_clust*ins_clust
|
||||||
|
msa_clust_del /= count_profile
|
||||||
|
ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
|
||||||
|
msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
|
||||||
|
ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)
|
||||||
|
#
|
||||||
|
if fixbb:
|
||||||
|
assert params['MAXCYCLE'] == 1
|
||||||
|
msa_clust_profile = msa_clust_onehot
|
||||||
|
msa_extra_onehot = msa_clust_onehot
|
||||||
|
ins_clust[:] = 0
|
||||||
|
ins_extra[:] = 0
|
||||||
|
# This is how it is done in rfdiff, but really it seems like it should be all 0.
|
||||||
|
# Keeping as-is for now for consistency, as it may be used in downstream masking done
|
||||||
|
# by apply_masks.
|
||||||
|
mask_pos = torch.full_like(msa_clust, 1).bool()
|
||||||
|
msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)
|
||||||
|
|
||||||
|
# extra MSA features
|
||||||
|
ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
|
||||||
|
try:
|
||||||
|
msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
|
||||||
|
except Exception as e:
|
||||||
|
print('msa_extra.shape',msa_extra.shape)
|
||||||
|
print('ins_extra.shape',ins_extra.shape)
|
||||||
|
|
||||||
|
if (tocpu):
|
||||||
|
b_msa_clust.append(msa_clust.cpu())
|
||||||
|
b_msa_seed.append(msa_seed.cpu())
|
||||||
|
b_msa_extra.append(msa_extra.cpu())
|
||||||
|
b_mask_pos.append(mask_pos.cpu())
|
||||||
|
else:
|
||||||
|
b_msa_clust.append(msa_clust)
|
||||||
|
b_msa_seed.append(msa_seed)
|
||||||
|
b_msa_extra.append(msa_extra)
|
||||||
|
b_mask_pos.append(mask_pos)
|
||||||
|
|
||||||
|
b_seq = torch.stack(b_seq)
|
||||||
|
b_msa_clust = torch.stack(b_msa_clust)
|
||||||
|
b_msa_seed = torch.stack(b_msa_seed)
|
||||||
|
b_msa_extra = torch.stack(b_msa_extra)
|
||||||
|
b_mask_pos = torch.stack(b_mask_pos)
|
||||||
|
|
||||||
|
return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
|
||||||
|
|
||||||
|
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
|
||||||
|
if deterministic:
|
||||||
|
random.seed(0)
|
||||||
|
np.random.seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(n_tmpl,L,1,1) \
|
||||||
|
+ torch.rand(n_tmpl,L,1,3)*random_noise - random_noise/2
|
||||||
|
t1d = torch.nn.functional.one_hot(torch.full((n_tmpl, L), 20).long(), num_classes=ChemData().NAATOKENS-1).float() # all gaps
|
||||||
|
conf = torch.zeros((n_tmpl, L, 1)).float()
|
||||||
|
t1d = torch.cat((t1d, conf), -1)
|
||||||
|
mask_t = torch.full((n_tmpl,L,ChemData().NTOTAL), False)
|
||||||
|
return xyz, t1d, mask_t, np.full((n_tmpl), "")
|
||||||
|
|
||||||
|
|
||||||
|
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
|
||||||
|
if deterministic:
|
||||||
|
random.seed(0)
|
||||||
|
np.random.seed(0)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
torch.cuda.manual_seed(0)
|
||||||
|
|
||||||
|
seqID_cut = params['SEQID']
|
||||||
|
|
||||||
|
if npick_global == None:
|
||||||
|
npick_global=max(npick, 1)
|
||||||
|
|
||||||
|
ntplt = len(tplt['ids'])
|
||||||
|
if (ntplt < 1) or (npick < 1): #no templates in hhsearch file or not want to use templ
|
||||||
|
return blank_template(npick_global, qlen, random_noise)
|
||||||
|
|
||||||
|
# ignore templates having too high seqID
|
||||||
|
if seqID_cut <= 100.0:
|
||||||
|
tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
|
||||||
|
tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
|
||||||
|
else:
|
||||||
|
tplt_valid_idx = torch.arange(len(tplt['ids']))
|
||||||
|
|
||||||
|
# check again if there are templates having seqID < cutoff
|
||||||
|
ntplt = len(tplt['ids'])
|
||||||
|
npick = min(npick, ntplt)
|
||||||
|
if npick<1: # no templates
|
||||||
|
return blank_template(npick_global, qlen, random_noise)
|
||||||
|
|
||||||
|
if not pick_top: # select randomly among all possible templates
|
||||||
|
sample = torch.randperm(ntplt)[:npick]
|
||||||
|
else: # only consider top 50 templates
|
||||||
|
sample = torch.randperm(min(50,ntplt))[:npick]
|
||||||
|
|
||||||
|
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
|
||||||
|
mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL),False) # True for valid atom, False for missing atom
|
||||||
|
t1d = torch.full((npick_global, qlen), 20).long()
|
||||||
|
t1d_val = torch.zeros((npick_global, qlen)).float()
|
||||||
|
for i,nt in enumerate(sample):
|
||||||
|
tplt_idx = tplt_valid_idx[nt]
|
||||||
|
sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
|
||||||
|
pos = tplt['qmap'][0,sel,0] + offset
|
||||||
|
|
||||||
|
ntmplatoms = tplt['xyz'].shape[2] # will be bigger for NA templates
|
||||||
|
xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
|
||||||
|
mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()
|
||||||
|
|
||||||
|
# 1-D features: alignment confidence
|
||||||
|
t1d[i,pos] = tplt['seq'][0,sel]
|
||||||
|
t1d_val[i,pos] = tplt['f1d'][0,sel,2] # alignment confidence
|
||||||
|
# xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)
|
||||||
|
|
||||||
|
t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float() # (no mask token)
|
||||||
|
t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)
|
||||||
|
|
||||||
|
tplt_ids = np.array(tplt["ids"])[sample].flatten() # np.array of chain ids (ordered)
|
||||||
|
return xyz, t1d, mask_t, tplt_ids
|
||||||
|
|
||||||
|
def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
|
||||||
|
"""Diagonally tiles template coordinates, 1d input features, and masks across
|
||||||
|
template and residue dimensions. 1st template is concatenated directly on residue
|
||||||
|
dimension after a random rotation & translation.
|
||||||
|
"""
|
||||||
|
N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])
|
||||||
|
|
||||||
|
xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
|
||||||
|
tplt_ids_out = np.full((N_tmpl_tot),"", dtype=object) # rk bad practice.. should fix
|
||||||
|
i_tmpl = 0
|
||||||
|
i_res = 0
|
||||||
|
for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
|
||||||
|
N_tmpl, L_tmpl = xyz_.shape[:2]
|
||||||
|
if i_tmpl == 0:
|
||||||
|
i1, i2 = 1, N_tmpl
|
||||||
|
else:
|
||||||
|
i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1
|
||||||
|
|
||||||
|
# 1st template is concatenated directly, so that all atoms are set in xyz_prev
|
||||||
|
xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
|
||||||
|
f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
|
||||||
|
mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]
|
||||||
|
|
||||||
|
if not tplt_ids_out[0]: # only add first template
|
||||||
|
tplt_ids_out[0] = ids[0]
|
||||||
|
# remaining templates are diagonally tiled
|
||||||
|
xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
|
||||||
|
f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
|
||||||
|
mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
|
||||||
|
tplt_ids_out[i1:i2] = ids[1:]
|
||||||
|
if i_tmpl == 0:
|
||||||
|
i_tmpl += N_tmpl
|
||||||
|
else:
|
||||||
|
i_tmpl += N_tmpl-1
|
||||||
|
i_res += L_tmpl
|
||||||
|
|
||||||
|
return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out
|
||||||
|
|
||||||
|
def generate_xyz_prev(xyz_t, mask_t, params):
|
||||||
|
"""
|
||||||
|
allows you to use different initializations for the coordinate track specified in params
|
||||||
|
"""
|
||||||
|
L = xyz_t.shape[1]
|
||||||
|
if params["BLACK_HOLE_INIT"]:
|
||||||
|
xyz_t, _, mask_t = blank_template(1, L)
|
||||||
|
return xyz_t[0].clone(), mask_t[0].clone()
|
||||||
|
|
||||||
|
### merge msa & insertion statistics of two proteins having different taxID
|
||||||
|
def merge_a3m_hetero(a3mA, a3mB, L_s):
|
||||||
|
# merge msa
|
||||||
|
query = torch.cat([a3mA['msa'][0], a3mB['msa'][0]]).unsqueeze(0) # (1, L)
|
||||||
|
|
||||||
|
msa = [query]
|
||||||
|
if a3mA['msa'].shape[0] > 1:
|
||||||
|
extra_A = torch.nn.functional.pad(a3mA['msa'][1:], (0,sum(L_s[1:])), "constant", 20) # pad gaps
|
||||||
|
msa.append(extra_A)
|
||||||
|
if a3mB['msa'].shape[0] > 1:
|
||||||
|
extra_B = torch.nn.functional.pad(a3mB['msa'][1:], (L_s[0],0), "constant", 20)
|
||||||
|
msa.append(extra_B)
|
||||||
|
msa = torch.cat(msa, dim=0)
|
||||||
|
|
||||||
|
# merge ins
|
||||||
|
query = torch.cat([a3mA['ins'][0], a3mB['ins'][0]]).unsqueeze(0) # (1, L)
|
||||||
|
ins = [query]
|
||||||
|
if a3mA['ins'].shape[0] > 1:
|
||||||
|
extra_A = torch.nn.functional.pad(a3mA['ins'][1:], (0,sum(L_s[1:])), "constant", 0) # pad gaps
|
||||||
|
ins.append(extra_A)
|
||||||
|
if a3mB['ins'].shape[0] > 1:
|
||||||
|
extra_B = torch.nn.functional.pad(a3mB['ins'][1:], (L_s[0],0), "constant", 0)
|
||||||
|
ins.append(extra_B)
|
||||||
|
ins = torch.cat(ins, dim=0)
|
||||||
|
|
||||||
|
a3m = {'msa': msa, 'ins': ins}
|
||||||
|
|
||||||
|
# merge taxids
|
||||||
|
if 'taxid' in a3mA and 'taxid' in a3mB:
|
||||||
|
a3m['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])
|
||||||
|
|
||||||
|
return a3m
|
||||||
|
|
||||||
|
# merge msa & insertion statistics of units in homo-oligomers
|
||||||
|
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
|
||||||
|
N, L = msa_orig.shape[:2]
|
||||||
|
if mode == "repeat":
|
||||||
|
|
||||||
|
# AAAAAA
|
||||||
|
# AAAAAA
|
||||||
|
|
||||||
|
msa = torch.tile(msa_orig,(1,nmer))
|
||||||
|
ins = torch.tile(ins_orig,(1,nmer))
|
||||||
|
|
||||||
|
elif mode == "diag":
|
||||||
|
|
||||||
|
# AAAAAA
|
||||||
|
# A-----
|
||||||
|
# -A----
|
||||||
|
# --A---
|
||||||
|
# ---A--
|
||||||
|
# ----A-
|
||||||
|
# -----A
|
||||||
|
|
||||||
|
N = N - 1
|
||||||
|
new_N = 1 + N * nmer
|
||||||
|
new_L = L * nmer
|
||||||
|
msa = torch.full((new_N, new_L), 20, dtype=msa_orig.dtype, device=msa_orig.device)
|
||||||
|
ins = torch.full((new_N, new_L), 0, dtype=ins_orig.dtype, device=msa_orig.device)
|
||||||
|
|
||||||
|
start_L = 0
|
||||||
|
start_N = 1
|
||||||
|
for i_c in range(nmer):
|
||||||
|
msa[0, start_L:start_L+L] = msa_orig[0]
|
||||||
|
msa[start_N:start_N+N, start_L:start_L+L] = msa_orig[1:]
|
||||||
|
ins[0, start_L:start_L+L] = ins_orig[0]
|
||||||
|
ins[start_N:start_N+N, start_L:start_L+L] = ins_orig[1:]
|
||||||
|
start_L += L
|
||||||
|
start_N += N
|
||||||
|
else:
|
||||||
|
|
||||||
|
# AAAAAA
|
||||||
|
# A-----
|
||||||
|
# -AAAAA
|
||||||
|
|
||||||
|
msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
|
||||||
|
ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
|
||||||
|
|
||||||
|
msa[:N, :L] = msa_orig
|
||||||
|
ins[:N, :L] = ins_orig
|
||||||
|
start = L
|
||||||
|
|
||||||
|
for i_c in range(1,nmer):
|
||||||
|
msa[0, start:start+L] = msa_orig[0]
|
||||||
|
msa[N:, start:start+L] = msa_orig[1:]
|
||||||
|
ins[0, start:start+L] = ins_orig[0]
|
||||||
|
ins[N:, start:start+L] = ins_orig[1:]
|
||||||
|
start += L
|
||||||
|
|
||||||
|
return msa, ins
|
||||||
|
|
||||||
|
def merge_msas(a3m_list, L_s):
|
||||||
|
"""
|
||||||
|
takes a list of a3m dictionaries with keys msa, ins and a list of protein lengths and creates a
|
||||||
|
combined MSA
|
||||||
|
"""
|
||||||
|
seen = set()
|
||||||
|
taxIDs = []
|
||||||
|
a3mA = a3m_list[0]
|
||||||
|
taxIDs.extend(a3mA["taxID"])
|
||||||
|
seen.update(a3mA["hash"])
|
||||||
|
msaA, insA = a3mA["msa"], a3mA["ins"]
|
||||||
|
for i in range(1, len(a3m_list)):
|
||||||
|
a3mB = a3m_list[i]
|
||||||
|
pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
|
||||||
|
if a3mB["hash"] in seen or len(pair_taxIDs) < 5: #homomer/not enough pairs
|
||||||
|
a3mA = {"msa": msaA, "ins": insA}
|
||||||
|
L_s_to_merge = [sum(L_s[:i]), L_s[i]]
|
||||||
|
a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
|
||||||
|
msaA, insA = a3mA["msa"], a3mA["ins"]
|
||||||
|
taxIDs.extend(a3mB["taxID"])
|
||||||
|
else:
|
||||||
|
final_pairsA = []
|
||||||
|
final_pairsB = []
|
||||||
|
msaB, insB = a3mB["msa"], a3mB["ins"]
|
||||||
|
for pair in pair_taxIDs:
|
||||||
|
pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
|
||||||
|
pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
|
||||||
|
msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :],axis=-1))
|
||||||
|
msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :],axis=-1))
|
||||||
|
final_pairsA.append(pair_a3mA[msaApair])
|
||||||
|
final_pairsB.append(pair_a3mB[msaBpair])
|
||||||
|
paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
|
||||||
|
paired_msaB[final_pairsA] = msaB[final_pairsB]
|
||||||
|
msaA = torch.cat([msaA, paired_msaB], dim=1)
|
||||||
|
insA = torch.zeros_like(msaA) # paired MSAs in our dataset dont have insertions
|
||||||
|
seen.update(a3mB["hash"])
|
||||||
|
|
||||||
|
return msaA, insA
|
||||||
|
|
||||||
|
def remove_all_gap_seqs(a3m):
|
||||||
|
"""Removes sequences that are all gaps from an MSA represented as `a3m` dictionary"""
|
||||||
|
idx_seq_keep = ~(a3m['msa']==ChemData().UNKINDEX).all(dim=1)
|
||||||
|
a3m['msa'] = a3m['msa'][idx_seq_keep]
|
||||||
|
a3m['ins'] = a3m['ins'][idx_seq_keep]
|
||||||
|
return a3m
|
||||||
|
|
||||||
|
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
|
||||||
|
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
|
||||||
|
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
|
||||||
|
ID, only the sequence with the highest sequence identity to the query (1st
|
||||||
|
sequence in MSA) will be paired.
|
||||||
|
|
||||||
|
Sequences that aren't paired will be padded and added to the bottom of the
|
||||||
|
joined MSA. If a subregion of the input MSAs overlap (represent the same
|
||||||
|
chain), the subregion residue indices can be given as `idx_overlap`, and
|
||||||
|
the overlap region of the unpaired sequences will be included in the joined
|
||||||
|
MSA.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a3mA : dict
|
||||||
|
First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
|
||||||
|
L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
|
||||||
|
boolean tensor indicating whether each sequence is fully paired. Can be
|
||||||
|
a multi-MSA (contain >2 sub-MSAs).
|
||||||
|
a3mB : dict
|
||||||
|
2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
|
||||||
|
`is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
|
||||||
|
idx_overlap : tuple or list (optional)
|
||||||
|
Start and end indices of overlap region in 1st MSA, followed by the
|
||||||
|
same in 2nd MSA.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
a3m : dict
|
||||||
|
Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
|
||||||
|
"""
|
||||||
|
# preprocess overlap region
|
||||||
|
L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
|
||||||
|
if idx_overlap is not None:
|
||||||
|
i1A, i2A, i1B, i2B = idx_overlap
|
||||||
|
i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
|
||||||
|
assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
|
||||||
|
"When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
|
||||||
|
"(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
|
||||||
|
else:
|
||||||
|
i1B_new, i2B_new = (0, L_B)
|
||||||
|
|
||||||
|
# pair sequences
|
||||||
|
taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
|
||||||
|
i_pairedA, i_pairedB = [], []
|
||||||
|
|
||||||
|
for taxid in taxids_shared:
|
||||||
|
i_match = np.where(a3mA['taxid']==taxid)[0]
|
||||||
|
i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
|
||||||
|
i_pairedA.append(i_match[i_match_best])
|
||||||
|
|
||||||
|
i_match = np.where(a3mB['taxid']==taxid)[0]
|
||||||
|
i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
|
||||||
|
i_pairedB.append(i_match[i_match_best])
|
||||||
|
|
||||||
|
# unpaired sequences
|
||||||
|
i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
|
||||||
|
i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
|
||||||
|
N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)
|
||||||
|
|
||||||
|
# handle overlap region
|
||||||
|
# if msa A consists of sub-MSAs 1,2,3 and msa B of 2,4 (i.e overlap region is 2),
|
||||||
|
# this diagram shows how the variables below make up the final multi-MSA
|
||||||
|
# (* denotes nongaps, - denotes gaps)
|
||||||
|
# 1 2 3 4
|
||||||
|
# |*|*|*|*| msa_paired
|
||||||
|
# |*|*|*|-| msaA_unpaired
|
||||||
|
# |-|*|-|*| msaB_unpaired
|
||||||
|
if idx_overlap is not None:
|
||||||
|
assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
|
||||||
|
(a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
|
||||||
|
'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'
|
||||||
|
|
||||||
|
# overlap region gets sequences from 2nd MSA bc sometimes 1st MSA will be all gaps here
|
||||||
|
msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
|
||||||
|
a3mB['msa'][i_pairedB, i1B:i2B],
|
||||||
|
a3mA['msa'][i_pairedA, i2A:],
|
||||||
|
a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
|
||||||
|
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
|
||||||
|
torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
|
||||||
|
msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
|
||||||
|
a3mB['msa'][i_unpairedB, i1B:i2B],
|
||||||
|
torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
|
||||||
|
a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
|
||||||
|
else:
|
||||||
|
# no overlap region, simple offset pad & stack
|
||||||
|
# this code is actually a special case of "if" block above, but writing
|
||||||
|
# this out explicitly here to make the logic more clear
|
||||||
|
msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
|
||||||
|
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
|
||||||
|
torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
|
||||||
|
msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
|
||||||
|
a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps
|
||||||
|
|
||||||
|
# stack paired & unpaired
|
||||||
|
msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
|
||||||
|
taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])
|
||||||
|
|
||||||
|
# label "fully paired" sequences (a row of MSA that was never padded with gaps)
|
||||||
|
# output seq is fully paired if seqs A & B both started out as paired and were paired to
|
||||||
|
# each other on tax ID.
|
||||||
|
# NOTE: there is a rare edge case that is ignored here for simplicity: if
|
||||||
|
# pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
|
||||||
|
# 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
|
||||||
|
# will not be labeled as such here
|
||||||
|
is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
|
||||||
|
is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
|
||||||
|
is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
|
||||||
|
torch.zeros((N_unpairedA + N_unpairedB,)).bool()])
|
||||||
|
|
||||||
|
# insertion features in paired MSAs are assumed to be zero
|
||||||
|
a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
|
||||||
|
return a3m
|
||||||
|
|
||||||
|
|
||||||
|
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
|
||||||
|
"""Load a multi-MSA, which is a MSA that is paired across more than 2
|
||||||
|
chains. This loads the MSA for unique chains. Use 'expand_multi_msa` to
|
||||||
|
duplicate portions of the MSA for homo-oligomer repeated chains.
|
||||||
|
|
||||||
|
Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
|
||||||
|
it can) and pairs sequences across as many sub-MSAs as possible by matching
|
||||||
|
taxonomic ID. For details on how pairing is done, see
|
||||||
|
`join_msas_by_taxid()`
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
hash_list : list of str
|
||||||
|
Hashes of MSAs to load and join. Must not contain duplicates.
|
||||||
|
taxid_list : list of str
|
||||||
|
Taxonomic IDs of query sequences of each input MSA.
|
||||||
|
Ls : list of int
|
||||||
|
Lengths of the chains corresponding to the hashes.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
a3m_out : dict
|
||||||
|
Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
|
||||||
|
`taxid` [np.array (Nseq,)], `is_paired` [torch.Tensor (N_seq,)]
|
||||||
|
hashes_out : list of str
|
||||||
|
Hashes of MSAs in the order that they are joined in `a3m_out`.
|
||||||
|
Contains the same elements as the input `hash_list` but may be in a
|
||||||
|
different order.
|
||||||
|
Ls_out : list of int
|
||||||
|
Lengths of each chain in `a3m_out`
|
||||||
|
"""
|
||||||
|
assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'
|
||||||
|
|
||||||
|
# the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
|
||||||
|
# comprising sub-MSAs whose indices in the input lists are
|
||||||
|
# `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
|
||||||
|
# `hash_list[i_in]` and lengths are `Ls[i_in]`.
|
||||||
|
# Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
|
||||||
|
# where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
|
||||||
|
a3m_list = [] # list of multi-MSAs
|
||||||
|
idx_list_groups = [] # list of lists of indices of input chains making up each multi-MSA
|
||||||
|
res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA
|
||||||
|
|
||||||
|
# iterate through all pairs of hashes and look for paired MSAs (pMSAs)
|
||||||
|
# NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
|
||||||
|
# later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
|
||||||
|
# and 2+3 pMSAs will still be joined on taxID at the end, but sequences
|
||||||
|
# only present in the 0+2 pMSA pMSAs will be missed. this is probably very
|
||||||
|
# rare and so is ignored here for simplicity.
|
||||||
|
N = len(hash_list)
|
||||||
|
for i1, i2 in itertools.permutations(range(N),2):
|
||||||
|
|
||||||
|
idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
|
||||||
|
if i1 in idx_list and i2 in idx_list: continue # already loaded
|
||||||
|
if i1 == '' or i2 == '': continue # no taxID means no pMSA
|
||||||
|
|
||||||
|
# a paired MSA exists
|
||||||
|
if taxid_list[i1]==taxid_list[i2]:
|
||||||
|
|
||||||
|
h1, h2 = hash_list[i1], hash_list[i2]
|
||||||
|
fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'
|
||||||
|
|
||||||
|
if os.path.exists(fn):
|
||||||
|
msa, ins, taxid = parse_a3m(fn, paired=True)
|
||||||
|
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
|
||||||
|
is_paired=torch.ones(msa.shape[0]).bool())
|
||||||
|
res_range1 = (0,Ls[i1])
|
||||||
|
res_range2 = (Ls[i1],msa.shape[1])
|
||||||
|
|
||||||
|
# both hashes are new, add paired MSA to list
|
||||||
|
if i1 not in idx_list and i2 not in idx_list:
|
||||||
|
a3m_list.append(a3m_new)
|
||||||
|
idx_list_groups.append([i1,i2])
|
||||||
|
res_range_groups.append([res_range1, res_range2])
|
||||||
|
|
||||||
|
# one of the hashes is already in a multi-MSA
|
||||||
|
# find that multi-MSA and join the new pMSA to it
|
||||||
|
elif i1 in idx_list:
|
||||||
|
# which multi-MSA & sub-MSA has the hash with index `i1`?
|
||||||
|
i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
|
||||||
|
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]
|
||||||
|
|
||||||
|
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
|
||||||
|
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
|
||||||
|
|
||||||
|
idx_list_groups[i_a3m].append(i2)
|
||||||
|
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
|
||||||
|
L_new = res_range2[1] - res_range2[0]
|
||||||
|
res_range_groups[i_a3m].append((L, L+L_new))
|
||||||
|
|
||||||
|
elif i2 in idx_list:
|
||||||
|
# which multi-MSA & sub-MSA has the hash with index `i2`?
|
||||||
|
i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
|
||||||
|
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]
|
||||||
|
|
||||||
|
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
|
||||||
|
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
|
||||||
|
|
||||||
|
idx_list_groups[i_a3m].append(i1)
|
||||||
|
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
|
||||||
|
L_new = res_range1[1] - res_range1[0]
|
||||||
|
res_range_groups[i_a3m].append((L, L+L_new))
|
||||||
|
|
||||||
|
# add unpaired MSAs
|
||||||
|
# ungroup hash indices now, since we're done making multi-MSAs
|
||||||
|
idx_list = [x for group in idx_list_groups for x in group]
|
||||||
|
for i in range(N):
|
||||||
|
if i not in idx_list:
|
||||||
|
fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
|
||||||
|
msa, ins, taxid = parse_a3m(fn)
|
||||||
|
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
|
||||||
|
taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
|
||||||
|
a3m_list.append(a3m_new)
|
||||||
|
idx_list.append(i)
|
||||||
|
|
||||||
|
Ls_out = [Ls[i] for i in idx_list]
|
||||||
|
hashes_out = [hash_list[i] for i in idx_list]
|
||||||
|
|
||||||
|
# join multi-MSAs & unpaired MSAs
|
||||||
|
a3m_out = a3m_list[0]
|
||||||
|
for i in range(1, len(a3m_list)):
|
||||||
|
a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])
|
||||||
|
|
||||||
|
return a3m_out, hashes_out, Ls_out
|
||||||
|
|
||||||
|
|
||||||
|
def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out, params):
|
||||||
|
"""Expands a multi-MSA of unique chains into an MSA of a
|
||||||
|
hetero-homo-oligomer in which some chains appear more than once. The query
|
||||||
|
sequences (1st sequence of MSA) are concatenated directly along the
|
||||||
|
residue dimention. The remaining sequences are offset-tiled (i.e. "padded &
|
||||||
|
stacked") so that exact repeat sequences aren't paired.
|
||||||
|
|
||||||
|
For example, if the original multi-MSA contains unique chains 1,2,3 but
|
||||||
|
the final chain order is 1,2,1,3,3,1, this function will output an MSA like
|
||||||
|
(where - denotes a block of gap characters):
|
||||||
|
|
||||||
|
1 2 - 3 - -
|
||||||
|
- - 1 - 3 -
|
||||||
|
- - - - - 1
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a3m : dict
|
||||||
|
Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (Nseq,),
|
||||||
|
representing the multi-MSA of unique chains.
|
||||||
|
hashes_in : list of str
|
||||||
|
Unique MSA hashes used in `a3m`.
|
||||||
|
hashes_out : list of str
|
||||||
|
Non-unique MSA hashes desired in expanded MSA.
|
||||||
|
Ls_in : list of int
|
||||||
|
Lengths of each chain in `a3m`
|
||||||
|
Ls_out : list of int
|
||||||
|
Lengths of each chain desired in expanded MSA.
|
||||||
|
params : dict
|
||||||
|
Data loading parameters
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
a3m : dict
|
||||||
|
Contains torch.Tensors `msa` and `ins` of expanded MSA. No
|
||||||
|
taxids because no further joining needs to be done.
|
||||||
|
"""
|
||||||
|
assert(len(hashes_out)==len(Ls_out))
|
||||||
|
assert(set(hashes_in)==set(hashes_out))
|
||||||
|
assert(a3m['msa'].shape[1]==sum(Ls_in))
|
||||||
|
|
||||||
|
# figure out which oligomeric repeat is represented by each hash in `hashes_out`
|
||||||
|
# each new repeat will be offset in sequence dimension of final MSA
|
||||||
|
counts = dict()
|
||||||
|
n_copy = [] # n-th copy of this hash in `hashes`
|
||||||
|
for h in hashes_out:
|
||||||
|
if h in counts:
|
||||||
|
counts[h] += 1
|
||||||
|
else:
|
||||||
|
counts[h] = 1
|
||||||
|
n_copy.append(counts[h])
|
||||||
|
|
||||||
|
# num sequences in source & destination MSAs
|
||||||
|
N_in = a3m['msa'].shape[0]
|
||||||
|
N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad&stack the rest
|
||||||
|
|
||||||
|
# source MSA
|
||||||
|
msa_in, ins_in = a3m['msa'], a3m['ins']
|
||||||
|
|
||||||
|
# initialize destination MSA to gap characters
|
||||||
|
msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
|
||||||
|
ins_out = torch.full((N_out, sum(Ls_out)), 0)
|
||||||
|
|
||||||
|
# for each destination chain
|
||||||
|
for i_out, h_out in enumerate(hashes_out):
|
||||||
|
# identify index of source chain
|
||||||
|
i_in = np.where(np.array(hashes_in)==h_out)[0][0]
|
||||||
|
|
||||||
|
# residue indexes
|
||||||
|
i1_res_in = sum(Ls_in[:i_in])
|
||||||
|
i2_res_in = sum(Ls_in[:i_in+1])
|
||||||
|
i1_res_out = sum(Ls_out[:i_out])
|
||||||
|
i2_res_out = sum(Ls_out[:i_out+1])
|
||||||
|
|
||||||
|
# copy over query sequence
|
||||||
|
# NOTE: There is a bug in these next two lines!
|
||||||
|
# The second line should be ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]
|
||||||
|
msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
|
||||||
|
ins_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
|
||||||
|
|
||||||
|
# offset non-query sequences along sequence dimension based on repeat number of a given hash
|
||||||
|
i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
|
||||||
|
i2_seq_out = 1+n_copy[i_out]*(N_in-1)
|
||||||
|
# copy over non-query sequences
|
||||||
|
msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
|
||||||
|
ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]
|
||||||
|
|
||||||
|
# only 1st oligomeric repeat can be fully paired
|
||||||
|
is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])
|
||||||
|
|
||||||
|
a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
|
||||||
|
a3m_out = remove_all_gap_seqs(a3m_out)
|
||||||
|
|
||||||
|
return a3m_out
|
||||||
|
|
||||||
|
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
|
||||||
|
"""Loads multi-MSA for an arbitrary number of protein chains. Tries to
|
||||||
|
locate paired MSAs and pair sequences across all chains by taxonomic ID.
|
||||||
|
Unpaired sequences are padded and stacked on the bottom.
|
||||||
|
"""
|
||||||
|
# get MSA hashes (used to locate a3m files) and taxonomic IDs (used to determine pairing)
|
||||||
|
hashes = []
|
||||||
|
hashes_unique = []
|
||||||
|
taxids_unique = []
|
||||||
|
Ls_unique = []
|
||||||
|
for chid,L_ in zip(chain_ids, Ls):
|
||||||
|
hashes.append(chid2hash[chid])
|
||||||
|
if chid2hash[chid] not in hashes_unique:
|
||||||
|
hashes_unique.append(chid2hash[chid])
|
||||||
|
taxids_unique.append(chid2taxid.get(chid))
|
||||||
|
Ls_unique.append(L_)
|
||||||
|
|
||||||
|
# loads multi-MSA for unique chains
|
||||||
|
a3m_prot, hashes_unique, Ls_unique = \
|
||||||
|
load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)
|
||||||
|
|
||||||
|
# expands multi-MSA to repeat chains of homo-oligomers
|
||||||
|
a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)
|
||||||
|
|
||||||
|
return a3m_prot
|
||||||
|
|
||||||
|
def choose_multimsa_clusters(msa_seq_is_paired, params):
|
||||||
|
"""Returns indices of fully-paired sequences in a multi-MSA to use as seed
|
||||||
|
clusters during MSA featurization.
|
||||||
|
"""
|
||||||
|
frac_paired = msa_seq_is_paired.float().mean()
|
||||||
|
if frac_paired > 0.25: # enough fully paired sequences, just let MSAFeaturize choose randomly
|
||||||
|
return None
|
||||||
|
else:
|
||||||
|
# ensure that half of the clusters are fully-paired sequences,
|
||||||
|
# and let the rest be chosen randomly
|
||||||
|
N_seed = params['MAXLAT']//2
|
||||||
|
msa_seed_clus = []
|
||||||
|
for i_cycle in range(params['MAXCYCLE']):
|
||||||
|
idx_paired = torch.where(msa_seq_is_paired)[0]
|
||||||
|
msa_seed_clus.append(idx_paired[torch.randperm(len(idx_paired))][:N_seed])
|
||||||
|
return msa_seed_clus
|
||||||
|
|
||||||
|
|
||||||
|
#fd
|
||||||
|
def get_bond_distances(bond_feats):
|
||||||
|
atom_bonds = (bond_feats > 0)*(bond_feats<5)
|
||||||
|
dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds.long().numpy(), directed=False)
|
||||||
|
# dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0)) # protein portion is inf and you don't want to mask it out
|
||||||
|
return torch.from_numpy(dist_matrix).float()
|
||||||
|
|
||||||
|
|
||||||
|
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
|
||||||
|
xyz, mask, res_idx = parse_pdb(pdbfilename)
|
||||||
|
plddt = np.load(plddtfilename)
|
||||||
|
|
||||||
|
# update mask info with plddt (ignore sidechains if plddt < 90.0)
|
||||||
|
mask_lddt = np.full_like(mask, False)
|
||||||
|
mask_lddt[plddt > sccut] = True
|
||||||
|
mask_lddt[:,:5] = True
|
||||||
|
mask = np.logical_and(mask, mask_lddt)
|
||||||
|
mask = np.logical_and(mask, (plddt > lddtcut)[:,None])
|
||||||
|
|
||||||
|
return {'xyz':torch.tensor(xyz), 'mask':torch.tensor(mask), 'idx': torch.tensor(res_idx), 'label':item}
|
||||||
|
|
||||||
|
def get_msa(a3mfilename, item, maxseq=5000):
|
||||||
|
msa,ins, taxIDs = parse_a3m(a3mfilename, maxseq=5000)
|
||||||
|
return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'taxIDs':taxIDs, 'label':item}
|
151
rf2aa/data/merge_inputs.py
Normal file
151
rf2aa/data/merge_inputs.py
Normal file
|
@ -0,0 +1,151 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, merge_hetero_templates, get_term_feats
|
||||||
|
from rf2aa.data.data_loader import RawInputData
|
||||||
|
from rf2aa.util import center_and_realign_missing, same_chain_from_bond_feats, random_rot_trans, idx_from_Ls
|
||||||
|
|
||||||
|
|
||||||
|
def merge_protein_inputs(protein_inputs, deterministic: bool = False):
|
||||||
|
if len(protein_inputs) == 0:
|
||||||
|
return None,[]
|
||||||
|
elif len(protein_inputs) == 1:
|
||||||
|
chain = list(protein_inputs.keys())[0]
|
||||||
|
input = list(protein_inputs.values())[0]
|
||||||
|
xyz_t = input.xyz_t
|
||||||
|
xyz_t[0:1] = random_rot_trans(xyz_t[0:1], deterministic=deterministic)
|
||||||
|
input.xyz_t = xyz_t
|
||||||
|
return input, [(chain, input.length())]
|
||||||
|
# handle merging MSAs and such
|
||||||
|
# first determine which sequence are identical, then which one have mergeable MSAs
|
||||||
|
# then cat the templates, other feats
|
||||||
|
pass
|
||||||
|
|
||||||
|
def merge_na_inputs(na_inputs):
|
||||||
|
# should just be trivially catting features
|
||||||
|
running_inputs = None
|
||||||
|
chain_lengths = []
|
||||||
|
for chid, input in na_inputs.items():
|
||||||
|
running_inputs = merge_two_inputs(running_inputs, input)
|
||||||
|
chain_lengths.append((chid, input.length()))
|
||||||
|
return running_inputs, chain_lengths
|
||||||
|
|
||||||
|
def merge_sm_inputs(sm_inputs):
|
||||||
|
# should be trivially catting features
|
||||||
|
running_inputs = None
|
||||||
|
chain_lengths = []
|
||||||
|
for chid, input in sm_inputs.items():
|
||||||
|
running_inputs = merge_two_inputs(running_inputs, input)
|
||||||
|
chain_lengths.append((chid, input.length()))
|
||||||
|
return running_inputs, chain_lengths
|
||||||
|
|
||||||
|
def merge_two_inputs(first_input, second_input):
|
||||||
|
# merges two arbitrary inputs of data types
|
||||||
|
if first_input is None and second_input is None:
|
||||||
|
return None
|
||||||
|
elif first_input is None:
|
||||||
|
return second_input
|
||||||
|
elif second_input is None:
|
||||||
|
return first_input
|
||||||
|
|
||||||
|
Ls = [first_input.length(), second_input.length()]
|
||||||
|
L_total = sum(Ls)
|
||||||
|
# merge msas
|
||||||
|
|
||||||
|
a3m_first = {
|
||||||
|
"msa": first_input.msa,
|
||||||
|
"ins": first_input.ins,
|
||||||
|
}
|
||||||
|
a3m_second = {
|
||||||
|
"msa": second_input.msa,
|
||||||
|
"ins": second_input.ins,
|
||||||
|
}
|
||||||
|
a3m = merge_a3m_hetero(a3m_first, a3m_second, Ls)
|
||||||
|
# merge bond_feats
|
||||||
|
bond_feats = torch.zeros((L_total, L_total)).long()
|
||||||
|
offset = 0
|
||||||
|
for bf in [first_input.bond_feats, second_input.bond_feats]:
|
||||||
|
L = bf.shape[0]
|
||||||
|
bond_feats[offset:offset+L, offset:offset+L] = bf
|
||||||
|
offset += L
|
||||||
|
|
||||||
|
# merge templates
|
||||||
|
xyz_t = torch.cat([first_input.xyz_t, second_input.xyz_t],dim=1)
|
||||||
|
t1d = torch.cat([first_input.t1d, second_input.t1d],dim=1)
|
||||||
|
mask_t = torch.cat([first_input.mask_t, second_input.mask_t],dim=1)
|
||||||
|
|
||||||
|
# handle chirals (need to residue offset)
|
||||||
|
if second_input.chirals.shape[0] > 0 :
|
||||||
|
second_input.chirals[:, :-1] = second_input.chirals[:, :-1] + first_input.length()
|
||||||
|
chirals = torch.cat([first_input.chirals, second_input.chirals])
|
||||||
|
|
||||||
|
# cat atom frames
|
||||||
|
atom_frames = torch.cat([first_input.atom_frames, second_input.atom_frames])
|
||||||
|
# return new object
|
||||||
|
return RawInputData(
|
||||||
|
a3m["msa"],
|
||||||
|
a3m["ins"],
|
||||||
|
bond_feats,
|
||||||
|
xyz_t,
|
||||||
|
mask_t,
|
||||||
|
t1d,
|
||||||
|
chirals,
|
||||||
|
atom_frames,
|
||||||
|
taxids=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def merge_all(
|
||||||
|
protein_inputs,
|
||||||
|
na_inputs,
|
||||||
|
sm_inputs,
|
||||||
|
residues_to_atomize,
|
||||||
|
deterministic: bool = False,
|
||||||
|
):
|
||||||
|
|
||||||
|
#protein_lengths = [protein_input.length() for protein_input in protein_inputs.values()]
|
||||||
|
#na_lengths = [na_input.length() for na_input in na_inputs.values()]
|
||||||
|
#sm_lengths = [sm_input.length() for sm_input in sm_inputs.values()]
|
||||||
|
#all_lengths = protein_lengths + na_lengths + sm_lengths
|
||||||
|
|
||||||
|
#term_info = get_term_feats(all_lengths)
|
||||||
|
#term_info[sum(protein_lengths):, :] = 0
|
||||||
|
|
||||||
|
protein_inputs, protein_chain_lengths = merge_protein_inputs(protein_inputs, deterministic=deterministic)
|
||||||
|
|
||||||
|
na_inputs, na_chain_lengths = merge_na_inputs(na_inputs)
|
||||||
|
sm_inputs, sm_chain_lengths = merge_sm_inputs(sm_inputs)
|
||||||
|
if protein_inputs is None and na_inputs is None and sm_inputs is None:
|
||||||
|
raise ValueError("No valid inputs were provided")
|
||||||
|
running_inputs = merge_two_inputs(protein_inputs, na_inputs) #could handle pairing protein/NA MSAs here
|
||||||
|
running_inputs = merge_two_inputs(running_inputs, sm_inputs)
|
||||||
|
|
||||||
|
all_chain_lengths = protein_chain_lengths + na_chain_lengths + sm_chain_lengths
|
||||||
|
running_inputs.chain_lengths = all_chain_lengths
|
||||||
|
|
||||||
|
all_lengths = get_Ls_from_chain_lengths(running_inputs.chain_lengths)
|
||||||
|
protein_lengths = get_Ls_from_chain_lengths(protein_chain_lengths)
|
||||||
|
term_info = get_term_feats(all_lengths)
|
||||||
|
term_info[sum(protein_lengths):, :] = 0
|
||||||
|
running_inputs.term_info = term_info
|
||||||
|
|
||||||
|
xyz_t = running_inputs.xyz_t
|
||||||
|
mask_t = running_inputs.mask_t
|
||||||
|
|
||||||
|
same_chain = same_chain = same_chain_from_bond_feats(running_inputs.bond_feats)
|
||||||
|
ntempl = xyz_t.shape[0]
|
||||||
|
xyz_t = torch.stack(
|
||||||
|
[center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain) for i in range(ntempl)]
|
||||||
|
)
|
||||||
|
xyz_t = torch.nan_to_num(xyz_t)
|
||||||
|
running_inputs.xyz_t = xyz_t
|
||||||
|
running_inputs.idx = idx_from_Ls(all_lengths)
|
||||||
|
|
||||||
|
# after everything is merged need to add bond feats for covales
|
||||||
|
# reindex protein feats function
|
||||||
|
if residues_to_atomize:
|
||||||
|
running_inputs.update_protein_features_after_atomize(residues_to_atomize)
|
||||||
|
|
||||||
|
return running_inputs
|
||||||
|
|
||||||
|
def get_Ls_from_chain_lengths(chain_lengths):
|
||||||
|
return [val[1] for val in chain_lengths]
|
||||||
|
|
46
rf2aa/data/nucleic_acid.py
Normal file
46
rf2aa/data/nucleic_acid.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
|
||||||
|
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
|
||||||
|
from rf2aa.data.data_loader import RawInputData
|
||||||
|
from rf2aa.util import get_protein_bond_feats
|
||||||
|
|
||||||
|
def load_nucleic_acid(fasta_fn, input_type, model_runner):
|
||||||
|
if input_type not in ["dna", "rna"]:
|
||||||
|
raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
|
||||||
|
if input_type == "dna":
|
||||||
|
dna_alphabet = True
|
||||||
|
rna_alphabet = False
|
||||||
|
elif input_type == "rna":
|
||||||
|
dna_alphabet = False
|
||||||
|
rna_alphabet = True
|
||||||
|
|
||||||
|
loader_params = model_runner.config.loader_params
|
||||||
|
msa, ins, L = parse_multichain_fasta(fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet)
|
||||||
|
if (msa.shape[0] > loader_params["MAXSEQ"]):
|
||||||
|
idxs_tokeep = np.random.permutation(msa.shape[0])[:loader_params["MAXSEQ"]]
|
||||||
|
idxs_tokeep[0] = 0
|
||||||
|
msa = msa[idxs_tokeep]
|
||||||
|
ins = ins[idxs_tokeep]
|
||||||
|
if len(L) > 1:
|
||||||
|
raise ValueError("Please provide separate fasta files for each nucleic acid chain")
|
||||||
|
L = L[0]
|
||||||
|
xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)
|
||||||
|
|
||||||
|
|
||||||
|
bond_feats = get_protein_bond_feats(L)
|
||||||
|
chirals = torch.zeros(0, 5)
|
||||||
|
atom_frames = torch.zeros(0, 3, 2)
|
||||||
|
|
||||||
|
return RawInputData(
|
||||||
|
torch.from_numpy(msa),
|
||||||
|
torch.from_numpy(ins),
|
||||||
|
bond_feats,
|
||||||
|
xyz_t,
|
||||||
|
mask_t,
|
||||||
|
t1d,
|
||||||
|
chirals,
|
||||||
|
atom_frames,
|
||||||
|
taxids=None,
|
||||||
|
)
|
809
rf2aa/data/parsers.py
Normal file
809
rf2aa/data/parsers.py
Normal file
|
@ -0,0 +1,809 @@
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import scipy.spatial
|
||||||
|
import string
|
||||||
|
import os,re
|
||||||
|
from os.path import exists
|
||||||
|
import random
|
||||||
|
import rf2aa.util as util
|
||||||
|
import gzip
|
||||||
|
import rf2aa
|
||||||
|
from rf2aa.ffindex import *
|
||||||
|
import torch
|
||||||
|
from openbabel import openbabel
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
def get_dislf(seq, xyz, mask):
|
||||||
|
L = seq.shape[0]
|
||||||
|
resolved_cys_mask = ((seq==ChemData().aa2num['CYS']) * mask[:,5]).nonzero().squeeze(-1) # cys[5]=='sg'
|
||||||
|
sgs = xyz[resolved_cys_mask,5]
|
||||||
|
ii,jj = torch.triu_indices(sgs.shape[0],sgs.shape[0],1)
|
||||||
|
d_sg_sg = torch.linalg.norm(sgs[ii,:]-sgs[jj,:], dim=-1)
|
||||||
|
is_dslf = (d_sg_sg>1.7)*(d_sg_sg<2.3)
|
||||||
|
|
||||||
|
dslf = []
|
||||||
|
for i in is_dslf.nonzero():
|
||||||
|
dslf.append( (
|
||||||
|
resolved_cys_mask[ii[i]].item(),
|
||||||
|
resolved_cys_mask[jj[i]].item(),
|
||||||
|
) )
|
||||||
|
return dslf
|
||||||
|
|
||||||
|
def read_template_pdb(L, pdb_fn, target_chain=None):
|
||||||
|
# get full sequence from given PDB
|
||||||
|
seq_full = list()
|
||||||
|
prev_chain=''
|
||||||
|
with open(pdb_fn) as fp:
|
||||||
|
for line in fp:
|
||||||
|
if line[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
if line[12:16].strip() != "CA":
|
||||||
|
continue
|
||||||
|
if line[21] != prev_chain:
|
||||||
|
if len(seq_full) > 0:
|
||||||
|
L_s.append(len(seq_full)-offset)
|
||||||
|
offset = len(seq_full)
|
||||||
|
prev_chain = line[21]
|
||||||
|
aa = line[17:20]
|
||||||
|
seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
|
||||||
|
|
||||||
|
seq_full = torch.tensor(seq_full).long()
|
||||||
|
|
||||||
|
xyz = torch.full((L, 36, 3), np.nan).float()
|
||||||
|
seq = torch.full((L,), 20).long()
|
||||||
|
conf = torch.zeros(L,1).float()
|
||||||
|
|
||||||
|
with open(pdb_fn) as fp:
|
||||||
|
for line in fp:
|
||||||
|
if line[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
|
||||||
|
aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
|
||||||
|
#
|
||||||
|
idx = resNo - 1
|
||||||
|
for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
|
||||||
|
if tgtatm == atom:
|
||||||
|
xyz[idx, i_atm, :] = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
|
||||||
|
break
|
||||||
|
seq[idx] = aa_idx
|
||||||
|
|
||||||
|
mask = torch.logical_not(torch.isnan(xyz[:,:3,0])) # (L, 3)
|
||||||
|
mask = mask.all(dim=-1)[:,None]
|
||||||
|
conf = torch.where(mask, torch.full((L,1),0.1), torch.zeros(L,1)).float()
|
||||||
|
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=32).float()
|
||||||
|
t1d = torch.cat((seq_1hot, conf), -1)
|
||||||
|
|
||||||
|
#return seq_full[None], ins[None], L_s, xyz[None], t1d[None]
|
||||||
|
return xyz[None], t1d[None]
|
||||||
|
|
||||||
|
def read_multichain_pdb(pdb_fn, tmpl_chain=None, tmpl_conf=0.1):
|
||||||
|
print ('read_multichain_pdb',tmpl_chain)
|
||||||
|
|
||||||
|
# get full sequence from PDB
|
||||||
|
seq_full = list()
|
||||||
|
L_s = list()
|
||||||
|
prev_chain=''
|
||||||
|
offset = 0
|
||||||
|
with open(pdb_fn) as fp:
|
||||||
|
for line in fp:
|
||||||
|
if line[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
if line[12:16].strip() != "CA":
|
||||||
|
continue
|
||||||
|
if line[21] != prev_chain:
|
||||||
|
if len(seq_full) > 0:
|
||||||
|
L_s.append(len(seq_full)-offset)
|
||||||
|
offset = len(seq_full)
|
||||||
|
prev_chain = line[21]
|
||||||
|
aa = line[17:20]
|
||||||
|
seq_full.append(ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20)
|
||||||
|
L_s.append(len(seq_full) - offset)
|
||||||
|
|
||||||
|
seq_full = torch.tensor(seq_full).long()
|
||||||
|
L = len(seq_full)
|
||||||
|
msa = torch.stack((seq_full,seq_full,seq_full), dim=0)
|
||||||
|
msa[1,:L_s[0]] = 20
|
||||||
|
msa[2,L_s[0]:] = 20
|
||||||
|
ins = torch.zeros_like(msa)
|
||||||
|
|
||||||
|
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
|
||||||
|
xyz_t = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1) + torch.rand(1,L,1,3)*5.0
|
||||||
|
|
||||||
|
mask = torch.full((1, L, ChemData().NTOTAL), False)
|
||||||
|
mask_t = torch.full((1, L, ChemData().NTOTAL), False)
|
||||||
|
seq = torch.full((1, L,), 20).long()
|
||||||
|
conf = torch.zeros(1, L,1).float()
|
||||||
|
|
||||||
|
with open(pdb_fn) as fp:
|
||||||
|
for line in fp:
|
||||||
|
if line[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
outbatch = 0
|
||||||
|
|
||||||
|
resNo, atom, aa = int(line[22:26]), line[12:16], line[17:20]
|
||||||
|
aa_idx = ChemData().aa2num[aa] if aa in ChemData().aa2num.keys() else 20
|
||||||
|
|
||||||
|
idx = resNo - 1
|
||||||
|
|
||||||
|
for i_atm, tgtatm in enumerate(ChemData().aa2long[aa_idx]):
|
||||||
|
if tgtatm == atom:
|
||||||
|
xyz_i = torch.tensor([float(line[30:38]), float(line[38:46]), float(line[46:54])])
|
||||||
|
xyz[0, idx, i_atm, :] = xyz_i
|
||||||
|
mask[0, idx, i_atm] = True
|
||||||
|
if line[21] == tmpl_chain:
|
||||||
|
xyz_t[0, idx, i_atm, :] = xyz_i
|
||||||
|
mask_t[0, idx, i_atm] = True
|
||||||
|
break
|
||||||
|
seq[0, idx] = aa_idx
|
||||||
|
|
||||||
|
if (mask_t.any()):
|
||||||
|
xyz_t[0] = rf2aa.util.center_and_realign_missing(xyz[0], mask[0])
|
||||||
|
|
||||||
|
dslf = get_dislf(seq[0], xyz[0], mask[0])
|
||||||
|
|
||||||
|
# assign confidence 'CONF' to all residues with backbone in template
|
||||||
|
conf = torch.where(mask_t[...,:3].all(dim=-1)[...,None], torch.full((1,L,1),tmpl_conf), torch.zeros(L,1)).float()
|
||||||
|
|
||||||
|
seq_1hot = torch.nn.functional.one_hot(seq, num_classes=ChemData().NAATOKENS-1).float()
|
||||||
|
t1d = torch.cat((seq_1hot, conf), -1)
|
||||||
|
|
||||||
|
return msa, ins, L_s, xyz_t, mask_t, t1d, dslf
|
||||||
|
|
||||||
|
def parse_fasta(filename, maxseq=10000, rmsa_alphabet=False):
|
||||||
|
msa = []
|
||||||
|
ins = []
|
||||||
|
|
||||||
|
fstream = open(filename,"r")
|
||||||
|
|
||||||
|
for line in fstream:
|
||||||
|
# skip labels
|
||||||
|
if line[0] == '>':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove right whitespaces
|
||||||
|
line = line.rstrip()
|
||||||
|
|
||||||
|
if len(line) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove lowercase letters and append to MSA
|
||||||
|
msa.append(line)
|
||||||
|
|
||||||
|
# sequence length
|
||||||
|
L = len(msa[-1])
|
||||||
|
|
||||||
|
i = np.zeros((L))
|
||||||
|
ins.append(i)
|
||||||
|
|
||||||
|
# convert letters into numbers
|
||||||
|
if rmsa_alphabet:
|
||||||
|
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
|
||||||
|
else:
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8)
|
||||||
|
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa[msa == alphabet[i]] = i
|
||||||
|
|
||||||
|
ins = np.array(ins, dtype=np.uint8)
|
||||||
|
|
||||||
|
return msa,ins
|
||||||
|
|
||||||
|
# Parse a fasta file containing multiple chains separated by '/'
|
||||||
|
def parse_multichain_fasta(filename, maxseq=10000, rna_alphabet=False, dna_alphabet=False):
|
||||||
|
msa = []
|
||||||
|
ins = []
|
||||||
|
|
||||||
|
fstream = open(filename,"r")
|
||||||
|
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
||||||
|
|
||||||
|
L_s = []
|
||||||
|
for line in fstream:
|
||||||
|
# skip labels
|
||||||
|
if line[0] == '>':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove right whitespaces
|
||||||
|
line = line.rstrip()
|
||||||
|
|
||||||
|
if len(line) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove lowercase letters and append to MSA
|
||||||
|
msa_i = line.translate(table)
|
||||||
|
msa_i = msa_i.replace('B','D') # hacky...
|
||||||
|
if L_s == []:
|
||||||
|
L_s = [len(x) for x in msa_i.split('/')]
|
||||||
|
msa_i = msa_i.replace('/','')
|
||||||
|
msa.append(msa_i)
|
||||||
|
|
||||||
|
# sequence length
|
||||||
|
L = len(msa[-1])
|
||||||
|
|
||||||
|
i = np.zeros((L))
|
||||||
|
ins.append(i)
|
||||||
|
|
||||||
|
if (len(msa) >= maxseq):
|
||||||
|
break
|
||||||
|
|
||||||
|
# convert letters into numbers
|
||||||
|
if rna_alphabet:
|
||||||
|
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
|
||||||
|
elif dna_alphabet:
|
||||||
|
alphabet = np.array(list("00000000000000000000-0ACGTD00000"), dtype='|S1').view(np.uint8)
|
||||||
|
else:
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
|
||||||
|
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
|
||||||
|
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa[msa == alphabet[i]] = i
|
||||||
|
|
||||||
|
ins = np.array(ins, dtype=np.uint8)
|
||||||
|
|
||||||
|
return msa,ins,L_s
|
||||||
|
|
||||||
|
#fd - parse protein/RNA coupled fastas
|
||||||
|
def parse_mixed_fasta(filename, maxseq=10000):
|
||||||
|
msa1,msa2 = [],[]
|
||||||
|
|
||||||
|
fstream = open(filename,"r")
|
||||||
|
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
||||||
|
|
||||||
|
unpaired_r, unpaired_p = 0, 0
|
||||||
|
|
||||||
|
for line in fstream:
|
||||||
|
# skip labels
|
||||||
|
if line[0] == '>':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove right whitespaces
|
||||||
|
line = line.rstrip()
|
||||||
|
|
||||||
|
if len(line) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove lowercase letters and append to MSA
|
||||||
|
msa_i = line.translate(table)
|
||||||
|
msa_i = msa_i.replace('B','D') # hacky...
|
||||||
|
|
||||||
|
msas_i = msa_i.split('/')
|
||||||
|
|
||||||
|
if (len(msas_i)==1):
|
||||||
|
msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
|
||||||
|
|
||||||
|
if (len(msa1)==0 or (
|
||||||
|
len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
|
||||||
|
)):
|
||||||
|
# skip if we've already found half of our limit in unpaired protein seqs
|
||||||
|
if sum([1 for x in msas_i[1] if x != '-']) == 0:
|
||||||
|
unpaired_p += 1
|
||||||
|
if unpaired_p > maxseq // 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip if we've already found half of our limit in unpaired rna seqs
|
||||||
|
if sum([1 for x in msas_i[0] if x != '-']) == 0:
|
||||||
|
unpaired_r += 1
|
||||||
|
if unpaired_r > maxseq // 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
msa1.append(msas_i[0])
|
||||||
|
msa2.append(msas_i[1])
|
||||||
|
else:
|
||||||
|
print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))
|
||||||
|
|
||||||
|
if (len(msa1) >= maxseq):
|
||||||
|
break
|
||||||
|
|
||||||
|
# convert letters into numbers
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
|
||||||
|
msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa1[msa1 == alphabet[i]] = i
|
||||||
|
msa1[msa1>=31] = 21 # anything unknown to 'X'
|
||||||
|
|
||||||
|
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
|
||||||
|
msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa2[msa2 == alphabet[i]] = i
|
||||||
|
msa2[msa2>=31] = 30 # anything unknown to 'N'
|
||||||
|
|
||||||
|
msa = np.concatenate((msa1,msa2),axis=-1)
|
||||||
|
|
||||||
|
ins = np.zeros(msa.shape, dtype=np.uint8)
|
||||||
|
|
||||||
|
return msa,ins
|
||||||
|
|
||||||
|
|
||||||
|
# parse a fasta alignment IF it exists
|
||||||
|
# otherwise return single-sequence msa
|
||||||
|
def parse_fasta_if_exists(seq, filename, maxseq=10000, rmsa_alphabet=False):
|
||||||
|
if (exists(filename)):
|
||||||
|
return parse_fasta(filename, maxseq, rmsa_alphabet)
|
||||||
|
else:
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-0acgtxbdhuy"), dtype='|S1').view(np.uint8) # -0 are UNK/mask
|
||||||
|
seq = np.array([list(seq)], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
seq[seq == alphabet[i]] = i
|
||||||
|
|
||||||
|
return (seq, np.zeros_like(seq))
|
||||||
|
|
||||||
|
|
||||||
|
#fd - parse protein/RNA coupled fastas
|
||||||
|
def parse_mixed_fasta(filename, maxseq=8000):
|
||||||
|
msa1,msa2 = [],[]
|
||||||
|
|
||||||
|
fstream = open(filename,"r")
|
||||||
|
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
||||||
|
|
||||||
|
unpaired_r, unpaired_p = 0, 0
|
||||||
|
|
||||||
|
for line in fstream:
|
||||||
|
# skip labels
|
||||||
|
if line[0] == '>':
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove right whitespaces
|
||||||
|
line = line.rstrip()
|
||||||
|
|
||||||
|
if len(line) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove lowercase letters and append to MSA
|
||||||
|
msa_i = line.translate(table)
|
||||||
|
msa_i = msa_i.replace('B','D') # hacky...
|
||||||
|
|
||||||
|
msas_i = msa_i.split('/')
|
||||||
|
|
||||||
|
if (len(msas_i)==1):
|
||||||
|
msas_i = [msas_i[0][:len(msa1[0])], msas_i[0][len(msa1[0]):]]
|
||||||
|
|
||||||
|
if (len(msa1)==0 or (
|
||||||
|
len(msas_i[0])==len(msa1[0]) and len(msas_i[1])==len(msa2[0])
|
||||||
|
)):
|
||||||
|
# skip if we've already found half of our limit in unpaired protein seqs
|
||||||
|
if sum([1 for x in msas_i[1] if x != '-']) == 0:
|
||||||
|
unpaired_p += 1
|
||||||
|
if unpaired_p > maxseq // 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# skip if we've already found half of our limit in unpaired rna seqs
|
||||||
|
if sum([1 for x in msas_i[0] if x != '-']) == 0:
|
||||||
|
unpaired_r += 1
|
||||||
|
if unpaired_r > maxseq // 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
msa1.append(msas_i[0])
|
||||||
|
msa2.append(msas_i[1])
|
||||||
|
else:
|
||||||
|
print ("Len error",filename, len(msas_i[0]),len(msa1[0]),len(msas_i[1]),len(msas_i[1]))
|
||||||
|
|
||||||
|
if (len(msa1) >= maxseq):
|
||||||
|
break
|
||||||
|
|
||||||
|
# convert letters into numbers
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-Xacgtxbdhuy"), dtype='|S1').view(np.uint8)
|
||||||
|
msa1 = np.array([list(s) for s in msa1], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa1[msa1 == alphabet[i]] = i
|
||||||
|
msa1[msa1>=31] = 21 # anything unknown to 'X'
|
||||||
|
|
||||||
|
alphabet = np.array(list("00000000000000000000-000000ACGTN"), dtype='|S1').view(np.uint8)
|
||||||
|
msa2 = np.array([list(s) for s in msa2], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa2[msa2 == alphabet[i]] = i
|
||||||
|
msa2[msa2>=31] = 30 # anything unknown to 'N'
|
||||||
|
|
||||||
|
msa = np.concatenate((msa1,msa2),axis=-1)
|
||||||
|
|
||||||
|
ins = np.zeros(msa.shape, dtype=np.uint8)
|
||||||
|
|
||||||
|
return msa,ins
|
||||||
|
|
||||||
|
|
||||||
|
# read A3M and convert letters into
|
||||||
|
# integers in the 0..20 range,
|
||||||
|
# also keep track of insertions
|
||||||
|
def parse_a3m(filename, maxseq=8000, paired=False):
|
||||||
|
msa = []
|
||||||
|
ins = []
|
||||||
|
taxIDs = []
|
||||||
|
table = str.maketrans(dict.fromkeys(string.ascii_lowercase))
|
||||||
|
|
||||||
|
# read file line by line
|
||||||
|
if filename.split('.')[-1] == 'gz':
|
||||||
|
fstream = gzip.open(filename, 'rt')
|
||||||
|
else:
|
||||||
|
fstream = open(filename, 'r')
|
||||||
|
|
||||||
|
for line in fstream:
|
||||||
|
|
||||||
|
# skip labels
|
||||||
|
if line[0] == '>':
|
||||||
|
if paired: # paired MSAs only have a TAXID in the fasta header
|
||||||
|
taxIDs.append(line[1:].strip())
|
||||||
|
else: # unpaired MSAs have all the metadata so use regex to pull out TAXID
|
||||||
|
match = re.search( r'TaxID=(\d+)', line)
|
||||||
|
if match:
|
||||||
|
taxIDs.append(match.group(1))
|
||||||
|
else:
|
||||||
|
taxIDs.append("query") # query sequence
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove right whitespaces
|
||||||
|
line = line.rstrip()
|
||||||
|
|
||||||
|
if len(line) == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# remove lowercase letters and append to MSA
|
||||||
|
msa.append(line.translate(table))
|
||||||
|
|
||||||
|
# sequence length
|
||||||
|
L = len(msa[-1])
|
||||||
|
|
||||||
|
# 0 - match or gap; 1 - insertion
|
||||||
|
a = np.array([0 if c.isupper() or c=='-' else 1 for c in line])
|
||||||
|
i = np.zeros((L))
|
||||||
|
|
||||||
|
if np.sum(a) > 0:
|
||||||
|
# positions of insertions
|
||||||
|
pos = np.where(a==1)[0]
|
||||||
|
|
||||||
|
# shift by occurrence
|
||||||
|
a = pos - np.arange(pos.shape[0])
|
||||||
|
|
||||||
|
# position of insertions in cleaned sequence
|
||||||
|
# and their length
|
||||||
|
pos,num = np.unique(a, return_counts=True)
|
||||||
|
|
||||||
|
# append to the matrix of insetions
|
||||||
|
i[pos] = num
|
||||||
|
|
||||||
|
ins.append(i)
|
||||||
|
|
||||||
|
if (len(msa) >= maxseq):
|
||||||
|
break
|
||||||
|
|
||||||
|
# convert letters into numbers
|
||||||
|
alphabet = np.array(list("ARNDCQEGHILKMFPSTWYV-"), dtype='|S1').view(np.uint8)
|
||||||
|
msa = np.array([list(s) for s in msa], dtype='|S1').view(np.uint8)
|
||||||
|
for i in range(alphabet.shape[0]):
|
||||||
|
msa[msa == alphabet[i]] = i
|
||||||
|
|
||||||
|
# treat all unknown characters as gaps
|
||||||
|
msa[msa > 20] = 20
|
||||||
|
|
||||||
|
ins = np.array(ins, dtype=np.uint8)
|
||||||
|
|
||||||
|
return msa,ins, np.array(taxIDs)
|
||||||
|
|
||||||
|
|
||||||
|
# read and extract xyz coords of N,Ca,C atoms
|
||||||
|
# from a PDB file
|
||||||
|
def parse_pdb(filename, seq=False, lddt_mask=False):
|
||||||
|
lines = open(filename,'r').readlines()
|
||||||
|
if seq:
|
||||||
|
return parse_pdb_lines_w_seq(lines, lddt_mask=lddt_mask)
|
||||||
|
return parse_pdb_lines(lines)
|
||||||
|
|
||||||
|
def parse_pdb_lines_w_seq(lines, lddt_mask=False):
|
||||||
|
|
||||||
|
# indices of residues observed in the structure
|
||||||
|
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
|
||||||
|
pdb_idx_s = [(r[0], int(r[1])) for r in res]
|
||||||
|
idx_s = [int(r[1]) for r in res]
|
||||||
|
plddt = [float(r[3]) for r in res]
|
||||||
|
seq = [ChemData().aa2num[r[2]] if r[2] in ChemData().aa2num.keys() else 20 for r in res]
|
||||||
|
|
||||||
|
# 4 BB + up to 10 SC atoms
|
||||||
|
xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
|
||||||
|
for l in lines:
|
||||||
|
if l[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
|
||||||
|
idx = pdb_idx_s.index((chain,resNo))
|
||||||
|
for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
|
||||||
|
if tgtatm == atom:
|
||||||
|
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||||
|
break
|
||||||
|
|
||||||
|
# save atom mask
|
||||||
|
mask = np.logical_not(np.isnan(xyz[...,0]))
|
||||||
|
xyz[np.isnan(xyz[...,0])] = 0.0
|
||||||
|
if lddt_mask == True:
|
||||||
|
plddt = np.array(plddt)
|
||||||
|
mask_lddt = np.full_like(mask, False)
|
||||||
|
mask_lddt[plddt > .85, 5:] = True
|
||||||
|
mask_lddt[plddt > .70, :5] = True
|
||||||
|
mask = np.logical_and(mask, mask_lddt)
|
||||||
|
|
||||||
|
return xyz,mask,np.array(idx_s), np.array(seq)
|
||||||
|
|
||||||
|
#'''
|
||||||
|
def parse_pdb_lines(lines):
|
||||||
|
|
||||||
|
# indices of residues observed in the structure
|
||||||
|
res = [(l[21:22].strip(), l[22:26],l[17:20], l[60:66].strip()) for l in lines if l[:4]=="ATOM" and l[12:16].strip() in ["CA", "P"]] # (chain letter, res num, aa)
|
||||||
|
pdb_idx_s = [(r[0], int(r[1])) for r in res]
|
||||||
|
idx_s = [int(r[1]) for r in res]
|
||||||
|
|
||||||
|
# 4 BB + up to 10 SC atoms
|
||||||
|
xyz = np.full((len(idx_s), ChemData().NTOTAL, 3), np.nan, dtype=np.float32)
|
||||||
|
for l in lines:
|
||||||
|
if l[:4] != "ATOM":
|
||||||
|
continue
|
||||||
|
chain, resNo, atom, aa = l[21:22].strip(), int(l[22:26]), l[12:16], l[17:20]
|
||||||
|
idx = pdb_idx_s.index((chain,resNo))
|
||||||
|
for i_atm, tgtatm in enumerate(ChemData().aa2long[ChemData().aa2num[aa]]):
|
||||||
|
if tgtatm == atom:
|
||||||
|
xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
|
||||||
|
break
|
||||||
|
|
||||||
|
# save atom mask
|
||||||
|
mask = np.logical_not(np.isnan(xyz[...,0]))
|
||||||
|
xyz[np.isnan(xyz[...,0])] = 0.0
|
||||||
|
|
||||||
|
return xyz,mask,np.array(idx_s)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_templates(item, params):
|
||||||
|
|
||||||
|
# init FFindexDB of templates
|
||||||
|
### and extract template IDs
|
||||||
|
### present in the DB
|
||||||
|
ffdb = FFindexDB(read_index(params['FFDB']+'_pdb.ffindex'),
|
||||||
|
read_data(params['FFDB']+'_pdb.ffdata'))
|
||||||
|
#ffids = set([i.name for i in ffdb.index])
|
||||||
|
|
||||||
|
# process tabulated hhsearch output to get
|
||||||
|
# matched positions and positional scores
|
||||||
|
infile = params['DIR']+'/hhr/'+item[-2:]+'/'+item+'.atab'
|
||||||
|
hits = []
|
||||||
|
for l in open(infile, "r").readlines():
|
||||||
|
if l[0]=='>':
|
||||||
|
key = l[1:].split()[0]
|
||||||
|
hits.append([key,[],[]])
|
||||||
|
elif "score" in l or "dssp" in l:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
hi = l.split()[:5]+[0.0,0.0,0.0]
|
||||||
|
hits[-1][1].append([int(hi[0]),int(hi[1])])
|
||||||
|
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
|
||||||
|
|
||||||
|
# get per-hit statistics from an .hhr file
|
||||||
|
# (!!! assume that .hhr and .atab have the same hits !!!)
|
||||||
|
# [Probab, E-value, Score, Aligned_cols,
|
||||||
|
# Identities, Similarity, Sum_probs, Template_Neff]
|
||||||
|
lines = open(infile[:-4]+'hhr', "r").readlines()
|
||||||
|
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
|
||||||
|
for i,posi in enumerate(pos):
|
||||||
|
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
|
||||||
|
|
||||||
|
# parse templates from FFDB
|
||||||
|
for hi in hits:
|
||||||
|
#if hi[0] not in ffids:
|
||||||
|
# continue
|
||||||
|
entry = get_entry_by_name(hi[0], ffdb.index)
|
||||||
|
if entry == None:
|
||||||
|
continue
|
||||||
|
data = read_entry_lines(entry, ffdb.data)
|
||||||
|
hi += list(parse_pdb_lines(data))
|
||||||
|
|
||||||
|
# process hits
|
||||||
|
counter = 0
|
||||||
|
xyz,qmap,mask,f0d,f1d,ids = [],[],[],[],[],[]
|
||||||
|
for data in hits:
|
||||||
|
if len(data)<7:
|
||||||
|
continue
|
||||||
|
|
||||||
|
qi,ti = np.array(data[1]).T
|
||||||
|
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
|
||||||
|
ncol = sel1.shape[0]
|
||||||
|
if ncol < 10:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ids.append(data[0])
|
||||||
|
f0d.append(data[3])
|
||||||
|
f1d.append(np.array(data[2])[sel1])
|
||||||
|
xyz.append(data[4][sel2])
|
||||||
|
mask.append(data[5][sel2])
|
||||||
|
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
xyz = np.vstack(xyz).astype(np.float32)
|
||||||
|
mask = np.vstack(mask).astype(bool)
|
||||||
|
qmap = np.vstack(qmap).astype(np.long)
|
||||||
|
f0d = np.vstack(f0d).astype(np.float32)
|
||||||
|
f1d = np.vstack(f1d).astype(np.float32)
|
||||||
|
ids = ids
|
||||||
|
|
||||||
|
return xyz,mask,qmap,f0d,f1d,ids
|
||||||
|
|
||||||
|
def parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=20):
|
||||||
|
# process tabulated hhsearch output to get
|
||||||
|
# matched positions and positional scores
|
||||||
|
hits = []
|
||||||
|
for l in open(atab_fn, "r").readlines():
|
||||||
|
if l[0]=='>':
|
||||||
|
if len(hits) == max_templ:
|
||||||
|
break
|
||||||
|
key = l[1:].split()[0]
|
||||||
|
hits.append([key,[],[]])
|
||||||
|
elif "score" in l or "dssp" in l:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
hi = l.split()[:5]+[0.0,0.0,0.0]
|
||||||
|
hits[-1][1].append([int(hi[0]),int(hi[1])])
|
||||||
|
hits[-1][2].append([float(hi[2]),float(hi[3]),float(hi[4])])
|
||||||
|
|
||||||
|
# get per-hit statistics from an .hhr file
|
||||||
|
# (!!! assume that .hhr and .atab have the same hits !!!)
|
||||||
|
# [Probab, E-value, Score, Aligned_cols,
|
||||||
|
# Identities, Similarity, Sum_probs, Template_Neff]
|
||||||
|
lines = open(hhr_fn, "r").readlines()
|
||||||
|
pos = [i+1 for i,l in enumerate(lines) if l[0]=='>']
|
||||||
|
for i,posi in enumerate(pos[:len(hits)]):
|
||||||
|
hits[i].append([float(s) for s in re.sub('[=%]',' ',lines[posi]).split()[1::2]])
|
||||||
|
|
||||||
|
# parse templates from FFDB
|
||||||
|
for hi in hits:
|
||||||
|
#if hi[0] not in ffids:
|
||||||
|
# continue
|
||||||
|
entry = get_entry_by_name(hi[0], ffdb.index)
|
||||||
|
if entry == None:
|
||||||
|
print ("Failed to find %s in *_pdb.ffindex"%hi[0])
|
||||||
|
continue
|
||||||
|
data = read_entry_lines(entry, ffdb.data)
|
||||||
|
hi += list(parse_pdb_lines_w_seq(data))
|
||||||
|
|
||||||
|
# process hits
|
||||||
|
counter = 0
|
||||||
|
xyz,qmap,mask,f0d,f1d,ids,seq = [],[],[],[],[],[],[]
|
||||||
|
for data in hits:
|
||||||
|
if len(data)<7:
|
||||||
|
continue
|
||||||
|
# print ("Process %s..."%data[0])
|
||||||
|
|
||||||
|
qi,ti = np.array(data[1]).T
|
||||||
|
_,sel1,sel2 = np.intersect1d(ti, data[6], return_indices=True)
|
||||||
|
ncol = sel1.shape[0]
|
||||||
|
if ncol < 10:
|
||||||
|
continue
|
||||||
|
|
||||||
|
ids.append(data[0])
|
||||||
|
f0d.append(data[3])
|
||||||
|
f1d.append(np.array(data[2])[sel1])
|
||||||
|
xyz.append(data[4][sel2])
|
||||||
|
mask.append(data[5][sel2])
|
||||||
|
seq.append(data[-1][sel2])
|
||||||
|
qmap.append(np.stack([qi[sel1]-1,[counter]*ncol],axis=-1))
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
xyz = np.vstack(xyz).astype(np.float32)
|
||||||
|
mask = np.vstack(mask).astype(bool)
|
||||||
|
qmap = np.vstack(qmap).astype(np.int64)
|
||||||
|
f0d = np.vstack(f0d).astype(np.float32)
|
||||||
|
f1d = np.vstack(f1d).astype(np.float32)
|
||||||
|
seq = np.hstack(seq).astype(np.int64)
|
||||||
|
ids = ids
|
||||||
|
|
||||||
|
return torch.from_numpy(xyz), torch.from_numpy(mask), torch.from_numpy(qmap), \
|
||||||
|
torch.from_numpy(f0d), torch.from_numpy(f1d), torch.from_numpy(seq), ids
|
||||||
|
|
||||||
|
def read_templates(qlen, ffdb, hhr_fn, atab_fn, n_templ=10):
|
||||||
|
xyz_t, mask_t, qmap, t1d, seq, ids = parse_templates_raw(ffdb, hhr_fn, atab_fn, max_templ=max(n_templ, 20))
|
||||||
|
ntmplatoms = xyz_t.shape[1]
|
||||||
|
|
||||||
|
npick = min(n_templ, len(ids))
|
||||||
|
if npick < 1: # no templates
|
||||||
|
xyz = torch.full((1,qlen,ChemData().NTOTAL,3),np.nan).float()
|
||||||
|
mask = torch.full((1,qlen,ChemData().NTOTAL),False)
|
||||||
|
t1d = torch.nn.functional.one_hot(torch.full((1, qlen), 20).long(), num_classes=21).float() # all gaps
|
||||||
|
t1d = torch.cat((t1d, torch.zeros((1,qlen,1)).float()), -1)
|
||||||
|
return xyz, mask, t1d
|
||||||
|
|
||||||
|
sample = torch.arange(npick)
|
||||||
|
#
|
||||||
|
xyz = torch.full((npick, qlen, ChemData().NTOTAL, 3), np.nan).float()
|
||||||
|
mask = torch.full((npick, qlen, ChemData().NTOTAL), False)
|
||||||
|
f1d = torch.full((npick, qlen), 20).long()
|
||||||
|
f1d_val = torch.zeros((npick, qlen, 1)).float()
|
||||||
|
#
|
||||||
|
for i, nt in enumerate(sample):
|
||||||
|
sel = torch.where(qmap[:,1] == nt)[0]
|
||||||
|
pos = qmap[sel, 0]
|
||||||
|
xyz[i, pos] = xyz_t[sel]
|
||||||
|
mask[i, pos, :ntmplatoms] = mask_t[sel].bool()
|
||||||
|
f1d[i, pos] = seq[sel]
|
||||||
|
f1d_val[i,pos] = t1d[sel, 2].unsqueeze(-1)
|
||||||
|
xyz[i] = util.center_and_realign_missing(xyz[i], mask[i], seq=f1d[i])
|
||||||
|
|
||||||
|
f1d = torch.nn.functional.one_hot(f1d, num_classes=ChemData().NAATOKENS-1).float()
|
||||||
|
f1d = torch.cat((f1d, f1d_val), dim=-1)
|
||||||
|
|
||||||
|
return xyz, mask, f1d
|
||||||
|
|
||||||
|
|
||||||
|
def clean_sdffile(filename):
|
||||||
|
# lowercase the 2nd letter of the element name (e.g. FE->Fe) so openbabel can parse it correctly
|
||||||
|
lines2 = []
|
||||||
|
with open(filename) as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
num_atoms = int(lines[3][:3])
|
||||||
|
for i in range(len(lines)):
|
||||||
|
if i>=4 and i<4+num_atoms:
|
||||||
|
lines2.append(lines[i][:32]+lines[i][32].lower()+lines[i][33:])
|
||||||
|
else:
|
||||||
|
lines2.append(lines[i])
|
||||||
|
molstring = ''.join(lines2)
|
||||||
|
|
||||||
|
return molstring
|
||||||
|
|
||||||
|
def parse_mol(filename, filetype="mol2", string=False, remove_H=True, find_automorphs=True, generate_conformer: bool = False):
|
||||||
|
"""Parse small molecule ligand.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename : str
|
||||||
|
filetype : str
|
||||||
|
string : bool
|
||||||
|
If True, `filename` is a string containing the molecule data.
|
||||||
|
remove_H : bool
|
||||||
|
Whether to remove hydrogen atoms.
|
||||||
|
find_automorphs : bool
|
||||||
|
Whether to enumerate atom symmetry permutations.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
obmol : OBMol
|
||||||
|
openbabel molecule object representing the ligand
|
||||||
|
msa : torch.Tensor (N_atoms,) long
|
||||||
|
Integer-encoded "sequence" (atom types) of ligand
|
||||||
|
ins : torch.Tensor (N_atoms,) long
|
||||||
|
Insertion features (all zero) for RF input
|
||||||
|
atom_coords : torch.Tensor (N_symmetry, N_atoms, 3) float
|
||||||
|
Atom coordinates
|
||||||
|
mask : torch.Tensor (N_symmetry, N_atoms) bool
|
||||||
|
Boolean mask for whether atom exists
|
||||||
|
"""
|
||||||
|
obConversion = openbabel.OBConversion()
|
||||||
|
obConversion.SetInFormat(filetype)
|
||||||
|
obmol = openbabel.OBMol()
|
||||||
|
if string:
|
||||||
|
obConversion.ReadString(obmol,filename)
|
||||||
|
elif filetype=='sdf':
|
||||||
|
molstring = clean_sdffile(filename)
|
||||||
|
obConversion.ReadString(obmol,molstring)
|
||||||
|
else:
|
||||||
|
obConversion.ReadFile(obmol,filename)
|
||||||
|
if generate_conformer:
|
||||||
|
builder = openbabel.OBBuilder()
|
||||||
|
builder.Build(obmol)
|
||||||
|
ff = openbabel.OBForceField.FindForceField("mmff94")
|
||||||
|
did_setup = ff.Setup(obmol)
|
||||||
|
if did_setup:
|
||||||
|
ff.FastRotorSearch()
|
||||||
|
ff.GetCoordinates(obmol)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
|
||||||
|
if remove_H:
|
||||||
|
obmol.DeleteHydrogens()
|
||||||
|
# the above sometimes fails to get all the hydrogens
|
||||||
|
i = 1
|
||||||
|
while i < obmol.NumAtoms()+1:
|
||||||
|
if obmol.GetAtom(i).GetAtomicNum()==1:
|
||||||
|
obmol.DeleteAtom(obmol.GetAtom(i))
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
atomtypes = [ChemData().atomnum2atomtype.get(obmol.GetAtom(i).GetAtomicNum(), 'ATM')
|
||||||
|
for i in range(1, obmol.NumAtoms()+1)]
|
||||||
|
msa = torch.tensor([ChemData().aa2num[x] for x in atomtypes])
|
||||||
|
ins = torch.zeros_like(msa)
|
||||||
|
|
||||||
|
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
|
||||||
|
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
|
||||||
|
mask = torch.full(atom_coords.shape[:-1], True) # (1, natoms,)
|
||||||
|
|
||||||
|
if find_automorphs:
|
||||||
|
atom_coords, mask = util.get_automorphs(obmol, atom_coords[0], mask[0])
|
||||||
|
|
||||||
|
return obmol, msa, ins, atom_coords, mask
|
34
rf2aa/data/preprocessing.py
Normal file
34
rf2aa/data/preprocessing.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
import os
|
||||||
|
from hydra import initialize, compose
|
||||||
|
from pathlib import Path
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
#from rf2aa.run_inference import ModelRunner
|
||||||
|
|
||||||
|
|
||||||
|
def make_msa(
|
||||||
|
fasta_file,
|
||||||
|
model_runner
|
||||||
|
):
|
||||||
|
out_dir_base = Path(model_runner.config.output_path)
|
||||||
|
hash = model_runner.config.job_name
|
||||||
|
out_dir = out_dir_base / hash
|
||||||
|
out_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
command = model_runner.config.database_params.command
|
||||||
|
search_base = model_runner.config.database_params.sequencedb
|
||||||
|
num_cpus = model_runner.config.database_params.num_cpus
|
||||||
|
ram_gb = model_runner.config.database_params.mem
|
||||||
|
template_database = model_runner.config.database_params.hhdb
|
||||||
|
|
||||||
|
out_a3m = out_dir / "t000_.msa0.a3m"
|
||||||
|
out_atab = out_dir / "t000_.atab"
|
||||||
|
out_hhr = out_dir / "t000_.hhr"
|
||||||
|
if out_a3m.exists() and out_atab.exists() and out_hhr.exists():
|
||||||
|
return out_a3m, out_hhr, out_atab
|
||||||
|
|
||||||
|
search_command = f"./{command} {fasta_file} {out_dir} {num_cpus} {ram_gb} {search_base} {template_database}"
|
||||||
|
print(search_command)
|
||||||
|
_ = subprocess.run(search_command, shell=True)
|
||||||
|
return out_a3m, out_hhr, out_atab
|
||||||
|
|
93
rf2aa/data/protein.py
Normal file
93
rf2aa/data/protein.py
Normal file
|
@ -0,0 +1,93 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from rf2aa.data.data_loader import RawInputData
|
||||||
|
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
|
||||||
|
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
|
||||||
|
from rf2aa.data.preprocessing import make_msa
|
||||||
|
from rf2aa.util import get_protein_bond_feats
|
||||||
|
|
||||||
|
|
||||||
|
def get_templates(
|
||||||
|
qlen,
|
||||||
|
ffdb,
|
||||||
|
hhr_fn,
|
||||||
|
atab_fn,
|
||||||
|
seqID_cut,
|
||||||
|
n_templ,
|
||||||
|
pick_top: bool = True,
|
||||||
|
offset: int = 0,
|
||||||
|
random_noise: float = 5.0,
|
||||||
|
deterministic: bool = False,
|
||||||
|
):
|
||||||
|
(
|
||||||
|
xyz_parsed,
|
||||||
|
mask_parsed,
|
||||||
|
qmap_parsed,
|
||||||
|
f0d_parsed,
|
||||||
|
f1d_parsed,
|
||||||
|
seq_parsed,
|
||||||
|
ids_parsed,
|
||||||
|
) = parse_templates_raw(ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn)
|
||||||
|
tplt = {
|
||||||
|
"xyz": xyz_parsed.unsqueeze(0),
|
||||||
|
"mask": mask_parsed.unsqueeze(0),
|
||||||
|
"qmap": qmap_parsed.unsqueeze(0),
|
||||||
|
"f0d": f0d_parsed.unsqueeze(0),
|
||||||
|
"f1d": f1d_parsed.unsqueeze(0),
|
||||||
|
"seq": seq_parsed.unsqueeze(0),
|
||||||
|
"ids": ids_parsed,
|
||||||
|
}
|
||||||
|
params = {
|
||||||
|
"SEQID": seqID_cut,
|
||||||
|
}
|
||||||
|
return TemplFeaturize(
|
||||||
|
tplt,
|
||||||
|
qlen,
|
||||||
|
params,
|
||||||
|
offset=offset,
|
||||||
|
npick=n_templ,
|
||||||
|
pick_top=pick_top,
|
||||||
|
random_noise=random_noise,
|
||||||
|
deterministic=deterministic,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
|
||||||
|
msa, ins, taxIDs = parse_a3m(msa_file)
|
||||||
|
# NOTE: this next line is a bug, but is the way that
|
||||||
|
# the code is written in the original implementation!
|
||||||
|
ins[0] = msa[0]
|
||||||
|
|
||||||
|
L = msa.shape[1]
|
||||||
|
if hhr_fn is None or atab_fn is None:
|
||||||
|
print("No templates provided")
|
||||||
|
xyz_t, t1d, mask_t, _ = blank_template(1, L)
|
||||||
|
else:
|
||||||
|
xyz_t, t1d, mask_t, _ = get_templates(
|
||||||
|
L,
|
||||||
|
model_runner.ffdb,
|
||||||
|
hhr_fn,
|
||||||
|
atab_fn,
|
||||||
|
seqID_cut=model_runner.config.loader_params.seqid,
|
||||||
|
n_templ=model_runner.config.loader_params.n_templ,
|
||||||
|
deterministic=model_runner.deterministic,
|
||||||
|
)
|
||||||
|
|
||||||
|
bond_feats = get_protein_bond_feats(L)
|
||||||
|
chirals = torch.zeros(0, 5)
|
||||||
|
atom_frames = torch.zeros(0, 3, 2)
|
||||||
|
return RawInputData(
|
||||||
|
torch.from_numpy(msa),
|
||||||
|
torch.from_numpy(ins),
|
||||||
|
bond_feats,
|
||||||
|
xyz_t,
|
||||||
|
mask_t,
|
||||||
|
t1d,
|
||||||
|
chirals,
|
||||||
|
atom_frames,
|
||||||
|
taxids=taxIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_msa_and_load_protein(fasta_file, model_runner):
|
||||||
|
msa_file, hhr_file, atab_file = make_msa(fasta_file, model_runner)
|
||||||
|
return load_protein(str(msa_file), str(hhr_file), str(atab_file), model_runner)
|
41
rf2aa/data/small_molecule.py
Normal file
41
rf2aa/data/small_molecule.py
Normal file
|
@ -0,0 +1,41 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from rf2aa.data.data_loader import RawInputData
|
||||||
|
from rf2aa.data.data_loader_utils import blank_template
|
||||||
|
from rf2aa.data.parsers import parse_mol
|
||||||
|
from rf2aa.kinematics import get_chirals
|
||||||
|
from rf2aa.util import get_bond_feats, get_nxgraph, get_atom_frames
|
||||||
|
|
||||||
|
|
||||||
|
def load_small_molecule(input_file, input_type, model_runner):
|
||||||
|
if input_type == "smiles":
|
||||||
|
is_string = True
|
||||||
|
else:
|
||||||
|
is_string = False
|
||||||
|
|
||||||
|
obmol, msa, ins, xyz, mask = parse_mol(
|
||||||
|
input_file, filetype=input_type, string=is_string, generate_conformer=True
|
||||||
|
)
|
||||||
|
return compute_features_from_obmol(obmol, msa, xyz, model_runner)
|
||||||
|
|
||||||
|
def compute_features_from_obmol(obmol, msa, xyz, model_runner):
|
||||||
|
L = msa.shape[0]
|
||||||
|
ins = torch.zeros_like(msa)
|
||||||
|
bond_feats = get_bond_feats(obmol)
|
||||||
|
|
||||||
|
xyz_t, t1d, mask_t, _ = blank_template(
|
||||||
|
model_runner.config.loader_params.n_templ,
|
||||||
|
L,
|
||||||
|
deterministic=model_runner.deterministic,
|
||||||
|
)
|
||||||
|
chirals = get_chirals(obmol, xyz[0])
|
||||||
|
G = get_nxgraph(obmol)
|
||||||
|
atom_frames = get_atom_frames(msa, G)
|
||||||
|
msa, ins = msa[None], ins[None]
|
||||||
|
return RawInputData(
|
||||||
|
msa, ins, bond_feats, xyz_t, mask_t, t1d, chirals, atom_frames, taxids=None
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_leaving_atoms(input, is_leaving):
|
||||||
|
keep = ~is_leaving
|
||||||
|
return input.keep_features(keep)
|
91
rf2aa/ffindex.py
Normal file
91
rf2aa/ffindex.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# https://raw.githubusercontent.com/ahcm/ffindex/master/python/ffindex.py
|
||||||
|
|
||||||
|
'''
|
||||||
|
Created on Apr 30, 2014
|
||||||
|
|
||||||
|
@author: meiermark
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import mmap
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
|
FFindexEntry = namedtuple("FFindexEntry", "name, offset, length")
|
||||||
|
|
||||||
|
|
||||||
|
def read_index(ffindex_filename):
|
||||||
|
entries = []
|
||||||
|
|
||||||
|
fh = open(ffindex_filename, "r")
|
||||||
|
for line in fh:
|
||||||
|
tokens = line.split("\t")
|
||||||
|
entries.append(FFindexEntry(tokens[0], int(tokens[1]), int(tokens[2])))
|
||||||
|
fh.close()
|
||||||
|
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
|
def read_data(ffdata_filename):
|
||||||
|
fh = open(ffdata_filename, "rb")
|
||||||
|
data = mmap.mmap(fh.fileno(), 0, access=mmap.ACCESS_READ)
|
||||||
|
fh.close()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_entry_by_name(name, index):
|
||||||
|
#TODO: bsearch
|
||||||
|
for entry in index:
|
||||||
|
if(name == entry.name):
|
||||||
|
return entry
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def read_entry_lines(entry, data):
|
||||||
|
lines = data[entry.offset:entry.offset + entry.length - 1].decode("utf-8").split("\n")
|
||||||
|
return lines
|
||||||
|
|
||||||
|
|
||||||
|
def read_entry_data(entry, data):
|
||||||
|
return data[entry.offset:entry.offset + entry.length - 1]
|
||||||
|
|
||||||
|
|
||||||
|
def write_entry(entries, data_fh, entry_name, offset, data):
|
||||||
|
data_fh.write(data[:-1])
|
||||||
|
data_fh.write(bytearray(1))
|
||||||
|
|
||||||
|
entry = FFindexEntry(entry_name, offset, len(data))
|
||||||
|
entries.append(entry)
|
||||||
|
|
||||||
|
return offset + len(data)
|
||||||
|
|
||||||
|
|
||||||
|
def write_entry_with_file(entries, data_fh, entry_name, offset, file_name):
|
||||||
|
with open(file_name, "rb") as fh:
|
||||||
|
data = bytearray(fh.read())
|
||||||
|
return write_entry(entries, data_fh, entry_name, offset, data)
|
||||||
|
|
||||||
|
|
||||||
|
def finish_db(entries, ffindex_filename, data_fh):
|
||||||
|
data_fh.close()
|
||||||
|
write_entries_to_db(entries, ffindex_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def write_entries_to_db(entries, ffindex_filename):
|
||||||
|
sorted(entries, key=lambda x: x.name)
|
||||||
|
index_fh = open(ffindex_filename, "w")
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
index_fh.write("{name:.64}\t{offset}\t{length}\n".format(name=entry.name, offset=entry.offset, length=entry.length))
|
||||||
|
|
||||||
|
index_fh.close()
|
||||||
|
|
||||||
|
|
||||||
|
def write_entry_to_file(entry, data, file):
|
||||||
|
lines = read_lines(entry, data)
|
||||||
|
|
||||||
|
fh = open(file, "w")
|
||||||
|
for line in lines:
|
||||||
|
fh.write(line+"\n")
|
||||||
|
fh.close()
|
311
rf2aa/kinematics.py
Normal file
311
rf2aa/kinematics.py
Normal file
|
@ -0,0 +1,311 @@
|
||||||
|
from itertools import permutations
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from icecream import ic
|
||||||
|
from openbabel import openbabel
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
PARAMS = {
|
||||||
|
'DMIN':1,
|
||||||
|
'DMID':4,
|
||||||
|
'DMAX':20.0,
|
||||||
|
'DBINS1':30,
|
||||||
|
'DBINS2':30,
|
||||||
|
'ABINS':36,
|
||||||
|
'USE_CB':False
|
||||||
|
}
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def get_pair_dist(a, b):
|
||||||
|
"""calculate pair distances between two sets of points
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a,b : pytorch tensors of shape [batch,nres,3]
|
||||||
|
store Cartesian coordinates of two sets of atoms
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dist : pytorch tensor of shape [batch,nres,nres]
|
||||||
|
stores paitwise distances between atoms in a and b
|
||||||
|
"""
|
||||||
|
|
||||||
|
dist = torch.cdist(a, b, p=2)
|
||||||
|
return dist
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def get_ang(a, b, c, eps=1e-6):
|
||||||
|
"""calculate planar angles for all consecutive triples (a[i],b[i],c[i])
|
||||||
|
from Cartesian coordinates of three sets of atoms a,b,c
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a,b,c : pytorch tensors of shape [batch,nres,3]
|
||||||
|
store Cartesian coordinates of three sets of atoms
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
ang : pytorch tensor of shape [batch,nres]
|
||||||
|
stores resulting planar angles
|
||||||
|
"""
|
||||||
|
v = a - b
|
||||||
|
w = c - b
|
||||||
|
vn = v / (torch.norm(v, dim=-1, keepdim=True)+eps)
|
||||||
|
wn = w / (torch.norm(w, dim=-1, keepdim=True)+eps)
|
||||||
|
vw = torch.sum(vn*wn, dim=-1)
|
||||||
|
|
||||||
|
return torch.acos(torch.clamp(vw,-0.999,0.999))
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def get_dih(a, b, c, d, eps=1e-6):
|
||||||
|
"""calculate dihedral angles for all consecutive quadruples (a[i],b[i],c[i],d[i])
|
||||||
|
given Cartesian coordinates of four sets of atoms a,b,c,d
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
a,b,c,d : pytorch tensors of shape [batch,nres,3]
|
||||||
|
store Cartesian coordinates of four sets of atoms
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dih : pytorch tensor of shape [batch,nres]
|
||||||
|
stores resulting dihedrals
|
||||||
|
"""
|
||||||
|
b0 = a - b
|
||||||
|
b1 = c - b
|
||||||
|
b2 = d - c
|
||||||
|
|
||||||
|
b1n = b1 / (torch.norm(b1, dim=-1, keepdim=True) + eps)
|
||||||
|
|
||||||
|
v = b0 - torch.sum(b0*b1n, dim=-1, keepdim=True)*b1n
|
||||||
|
w = b2 - torch.sum(b2*b1n, dim=-1, keepdim=True)*b1n
|
||||||
|
|
||||||
|
x = torch.sum(v*w, dim=-1)
|
||||||
|
y = torch.sum(torch.cross(b1n,v,dim=-1)*w, dim=-1)
|
||||||
|
|
||||||
|
return torch.atan2(y+eps, x+eps)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def generate_Cbeta(N,Ca,C):
|
||||||
|
# recreate Cb given N,Ca,C
|
||||||
|
b = Ca - N
|
||||||
|
c = C - Ca
|
||||||
|
a = torch.cross(b, c, dim=-1)
|
||||||
|
#Cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + Ca
|
||||||
|
# fd: below matches sidechain generator (=Rosetta params)
|
||||||
|
Cb = -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca
|
||||||
|
|
||||||
|
return Cb
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def xyz_to_c6d(xyz, params=PARAMS):
|
||||||
|
"""convert cartesian coordinates into 2d distance
|
||||||
|
and orientation maps
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
xyz : pytorch tensor of shape [batch,nres,3,3]
|
||||||
|
stores Cartesian coordinates of backbone N,Ca,C atoms
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
c6d : pytorch tensor of shape [batch,nres,nres,4]
|
||||||
|
stores stacked dist,omega,theta,phi 2D maps
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch = xyz.shape[0]
|
||||||
|
nres = xyz.shape[1]
|
||||||
|
|
||||||
|
# three anchor atoms
|
||||||
|
N = xyz[:,:,0]
|
||||||
|
Ca = xyz[:,:,1]
|
||||||
|
C = xyz[:,:,2]
|
||||||
|
|
||||||
|
# recreate Cb given N,Ca,C
|
||||||
|
Cb = generate_Cbeta(N,Ca,C)
|
||||||
|
|
||||||
|
# 6d coordinates order: (dist,omega,theta,phi)
|
||||||
|
c6d = torch.zeros([batch,nres,nres,4],dtype=xyz.dtype,device=xyz.device)
|
||||||
|
|
||||||
|
if params['USE_CB']:
|
||||||
|
dist = get_pair_dist(Cb,Cb)
|
||||||
|
else:
|
||||||
|
dist = get_pair_dist(Ca,Ca)
|
||||||
|
|
||||||
|
dist[torch.isnan(dist)] = 999.9
|
||||||
|
c6d[...,0] = dist + 999.9*torch.eye(nres,device=xyz.device)[None,...]
|
||||||
|
b,i,j = torch.where(c6d[...,0]<params['DMAX'])
|
||||||
|
|
||||||
|
c6d[b,i,j,torch.full_like(b,1)] = get_dih(Ca[b,i], Cb[b,i], Cb[b,j], Ca[b,j])
|
||||||
|
c6d[b,i,j,torch.full_like(b,2)] = get_dih(N[b,i], Ca[b,i], Cb[b,i], Cb[b,j])
|
||||||
|
c6d[b,i,j,torch.full_like(b,3)] = get_ang(Ca[b,i], Cb[b,i], Cb[b,j])
|
||||||
|
|
||||||
|
# fix long-range distances
|
||||||
|
c6d[...,0][c6d[...,0]>=params['DMAX']] = 999.9
|
||||||
|
c6d = torch.nan_to_num(c6d)
|
||||||
|
|
||||||
|
return c6d
|
||||||
|
|
||||||
|
def xyz_to_t2d(xyz_t, mask, params=PARAMS):
|
||||||
|
"""convert template cartesian coordinates into 2d distance
|
||||||
|
and orientation maps
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
xyz_t : pytorch tensor of shape [batch,templ,nres,3,3]
|
||||||
|
stores Cartesian coordinates of template backbone N,Ca,C atoms
|
||||||
|
mask : pytorch tensor [batch,templ,nres,nres]
|
||||||
|
indicates whether valid residue pairs or not
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
t2d : pytorch tensor of shape [batch,nres,nres,37+6+3]
|
||||||
|
stores stacked dist,omega,theta,phi 2D maps
|
||||||
|
"""
|
||||||
|
B, T, L = xyz_t.shape[:3]
|
||||||
|
c6d = xyz_to_c6d(xyz_t[:,:,:,:3].view(B*T,L,3,3), params=params)
|
||||||
|
c6d = c6d.view(B, T, L, L, 4)
|
||||||
|
|
||||||
|
# dist to one-hot encoded
|
||||||
|
mask = mask[...,None]
|
||||||
|
dist = dist_to_onehot(c6d[...,0], params)*mask
|
||||||
|
orien = torch.cat((torch.sin(c6d[...,1:]), torch.cos(c6d[...,1:])), dim=-1)*mask # (B, T, L, L, 6)
|
||||||
|
#
|
||||||
|
t2d = torch.cat((dist, orien, mask), dim=-1)
|
||||||
|
return t2d
|
||||||
|
|
||||||
|
def xyz_to_bbtor(xyz, params=PARAMS):
|
||||||
|
batch = xyz.shape[0]
|
||||||
|
nres = xyz.shape[1]
|
||||||
|
|
||||||
|
# three anchor atoms
|
||||||
|
N = xyz[:,:,0]
|
||||||
|
Ca = xyz[:,:,1]
|
||||||
|
C = xyz[:,:,2]
|
||||||
|
|
||||||
|
# recreate Cb given N,Ca,C
|
||||||
|
next_N = torch.roll(N, -1, dims=1)
|
||||||
|
prev_C = torch.roll(C, 1, dims=1)
|
||||||
|
phi = get_dih(prev_C, N, Ca, C)
|
||||||
|
psi = get_dih(N, Ca, C, next_N)
|
||||||
|
#
|
||||||
|
phi[:,0] = 0.0
|
||||||
|
psi[:,-1] = 0.0
|
||||||
|
#
|
||||||
|
astep = 2.0*np.pi / params['ABINS']
|
||||||
|
phi_bin = torch.round((phi+np.pi-astep/2)/astep)
|
||||||
|
psi_bin = torch.round((psi+np.pi-astep/2)/astep)
|
||||||
|
return torch.stack([phi_bin, psi_bin], axis=-1).long()
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def dist_to_onehot(dist, params=PARAMS):
|
||||||
|
db = dist_to_bins(dist, params)
|
||||||
|
dist = torch.nn.functional.one_hot(db, num_classes=params['DBINS1'] + params['DBINS2']+1).float()
|
||||||
|
return dist
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def dist_to_bins(dist,params=PARAMS):
|
||||||
|
"""bin 2d distance maps
|
||||||
|
"""
|
||||||
|
dist[torch.isnan(dist)] = 999.9
|
||||||
|
dstep1 = (params['DMID'] - params['DMIN']) / params['DBINS1']
|
||||||
|
dstep2 = (params['DMAX'] - params['DMID']) / params['DBINS2']
|
||||||
|
dbins = torch.cat([
|
||||||
|
torch.linspace(params['DMIN']+dstep1, params['DMID'], params['DBINS1'],
|
||||||
|
dtype=dist.dtype,device=dist.device),
|
||||||
|
torch.linspace(params['DMID']+dstep2, params['DMAX'], params['DBINS2'],
|
||||||
|
dtype=dist.dtype,device=dist.device),
|
||||||
|
])
|
||||||
|
db = torch.bucketize(dist.contiguous(),dbins).long()
|
||||||
|
|
||||||
|
return db
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
|
||||||
|
"""bin 2d distance and orientation maps
|
||||||
|
"""
|
||||||
|
|
||||||
|
db = dist_to_bins(c6d[...,0], params) # all dist < DMIN are in bin 0
|
||||||
|
|
||||||
|
astep = 2.0*np.pi / params['ABINS']
|
||||||
|
ob = torch.round((c6d[...,1]+np.pi-astep/2)/astep)
|
||||||
|
tb = torch.round((c6d[...,2]+np.pi-astep/2)/astep)
|
||||||
|
pb = torch.round((c6d[...,3]-astep/2)/astep)
|
||||||
|
|
||||||
|
# synchronize no-contact bins
|
||||||
|
params['DBINS'] = params['DBINS1'] + params['DBINS2']
|
||||||
|
ob[db==params['DBINS']] = params['ABINS']
|
||||||
|
tb[db==params['DBINS']] = params['ABINS']
|
||||||
|
pb[db==params['DBINS']] = params['ABINS']//2
|
||||||
|
|
||||||
|
if negative:
|
||||||
|
db = torch.where(same_chain.bool(), db.long(), params['DBINS'])
|
||||||
|
ob = torch.where(same_chain.bool(), ob.long(), params['ABINS'])
|
||||||
|
tb = torch.where(same_chain.bool(), tb.long(), params['ABINS'])
|
||||||
|
pb = torch.where(same_chain.bool(), pb.long(), params['ABINS']//2)
|
||||||
|
|
||||||
|
return torch.stack([db,ob,tb,pb],axis=-1).long()
|
||||||
|
|
||||||
|
def standardize_dihedral_retain_first(a,b,c,d):
|
||||||
|
isomorphisms = [(a,b,c,d), (a,c,b,d)]
|
||||||
|
return sorted(isomorphisms)[0]
|
||||||
|
|
||||||
|
def get_chirals(obmol, xyz):
|
||||||
|
'''
|
||||||
|
get all quadruples of atoms forming chiral centers and the expected ideal pseudodihedral between them
|
||||||
|
'''
|
||||||
|
stereo = openbabel.OBStereoFacade(obmol)
|
||||||
|
angle = np.arcsin(1/3**0.5)
|
||||||
|
chiral_idx_set = set()
|
||||||
|
for i in range(obmol.NumAtoms()):
|
||||||
|
if not stereo.HasTetrahedralStereo(i):
|
||||||
|
continue
|
||||||
|
si = stereo.GetTetrahedralStereo(i)
|
||||||
|
config = si.GetConfig()
|
||||||
|
|
||||||
|
o = config.center
|
||||||
|
c = config.from_or_towards
|
||||||
|
i,j,k = list(config.refs)
|
||||||
|
for a, b, c in permutations((c,i,j,k), 3):
|
||||||
|
chiral_idx_set.add(standardize_dihedral_retain_first(o,a,b,c))
|
||||||
|
|
||||||
|
chiral_idx = list(chiral_idx_set)
|
||||||
|
chiral_idx.sort()
|
||||||
|
chiral_idx = torch.tensor(chiral_idx, dtype=torch.float32)
|
||||||
|
chiral_idx = chiral_idx[(chiral_idx<obmol.NumAtoms()).all(dim=-1)]
|
||||||
|
|
||||||
|
if chiral_idx.numel() == 0:
|
||||||
|
return torch.zeros((0,5))
|
||||||
|
|
||||||
|
dih = get_dih(*xyz[chiral_idx.long()].split(split_size=1,dim=1))[:,0]
|
||||||
|
chirals = torch.nn.functional.pad(chiral_idx, (0, 1), mode='constant', value=angle)
|
||||||
|
chirals[dih<0.0,-1] *= -1
|
||||||
|
return chirals
|
||||||
|
|
||||||
|
def get_atomize_protein_chirals(residues_atomize, lig_xyz, residue_atomize_mask, bond_feats):
|
||||||
|
"""
|
||||||
|
Enumerate chiral centers in residues and provide features for chiral centers
|
||||||
|
"""
|
||||||
|
angle = np.arcsin(1/3**0.5) # perfect tetrahedral geometry
|
||||||
|
chiral_atoms = ChemData().aachirals[residues_atomize]
|
||||||
|
ra = residue_atomize_mask.nonzero()
|
||||||
|
r,a = ra.T
|
||||||
|
|
||||||
|
chiral_atoms = chiral_atoms[r,a].nonzero().squeeze(1) #num_chiral_centers
|
||||||
|
num_chiral_centers = chiral_atoms.shape[0]
|
||||||
|
chiral_bonds = bond_feats[chiral_atoms] # find bonds to each chiral atom
|
||||||
|
chiral_bonds_idx = chiral_bonds.nonzero() # find indices of each bonded neighbor to chiral atom
|
||||||
|
# in practice all chiral atoms in proteins have 3 heavy atom neighbors, so reshape to 3
|
||||||
|
chiral_bonds_idx = chiral_bonds_idx.reshape(num_chiral_centers, 3, 2)
|
||||||
|
|
||||||
|
chirals = torch.zeros((num_chiral_centers, 5))
|
||||||
|
chirals[:,0] = chiral_atoms.long()
|
||||||
|
chirals[:, 1:-1] = chiral_bonds_idx[...,-1].long()
|
||||||
|
chirals[:, -1] = angle
|
||||||
|
n = chirals.shape[0]
|
||||||
|
if n>0:
|
||||||
|
chirals = chirals.repeat(3,1).float()
|
||||||
|
chirals[n:2*n,1:-1] = torch.roll(chirals[n:2*n,1:-1],1,1)
|
||||||
|
chirals[2*n: ,1:-1] = torch.roll(chirals[2*n: ,1:-1],2,1)
|
||||||
|
dih = get_dih(*lig_xyz[chirals[:,:4].long()].split(split_size=1,dim=1))[:,0]
|
||||||
|
chirals[dih<0.0,-1] = -angle
|
||||||
|
else:
|
||||||
|
chirals = torch.zeros((0,5))
|
||||||
|
return chirals
|
BIN
rf2aa/ligands.json.gz
Normal file
BIN
rf2aa/ligands.json.gz
Normal file
Binary file not shown.
240
rf2aa/loss/loss.py
Normal file
240
rf2aa/loss/loss.py
Normal file
|
@ -0,0 +1,240 @@
|
||||||
|
import torch
|
||||||
|
from icecream import ic
|
||||||
|
import numpy as np
|
||||||
|
import scipy
|
||||||
|
import pandas as pd
|
||||||
|
import networkx as nx
|
||||||
|
|
||||||
|
from rf2aa.util import (
|
||||||
|
is_protein,
|
||||||
|
rigid_from_3_points,
|
||||||
|
is_nucleic,
|
||||||
|
find_all_paths_of_length_n,
|
||||||
|
find_all_rigid_groups
|
||||||
|
)
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
from rf2aa.kinematics import get_dih, get_ang
|
||||||
|
from rf2aa.scoring import HbHybType
|
||||||
|
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
#fd more efficient LJ loss
|
||||||
|
class LJLoss(torch.autograd.Function):
|
||||||
|
@staticmethod
|
||||||
|
def ljVdV(deltas, sigma, epsilon, lj_lin, eps):
|
||||||
|
# deltas - (N,natompair,3)
|
||||||
|
N = deltas.shape[0]
|
||||||
|
|
||||||
|
dist = torch.sqrt( torch.sum ( torch.square( deltas ), dim=-1 ) + eps )
|
||||||
|
linpart = dist<lj_lin*sigma[None]
|
||||||
|
deff = dist.clone()
|
||||||
|
deff[linpart] = lj_lin*sigma.repeat(N,1)[linpart]
|
||||||
|
sd = sigma / deff
|
||||||
|
sd2 = sd*sd
|
||||||
|
sd6 = sd2 * sd2 * sd2
|
||||||
|
sd12 = sd6 * sd6
|
||||||
|
ljE = epsilon * (sd12 - 2 * sd6)
|
||||||
|
ljE[linpart] += epsilon.repeat(N,1)[linpart] * (
|
||||||
|
-12 * sd12[linpart]/deff[linpart] + 12 * sd6[linpart]/deff[linpart]
|
||||||
|
) * (dist[linpart]-deff[linpart])
|
||||||
|
|
||||||
|
# works for linpart too
|
||||||
|
dljEdd_over_r = epsilon * (-12 * sd12/deff + 12 * sd6/deff) / (dist)
|
||||||
|
|
||||||
|
return ljE.sum(dim=-1), dljEdd_over_r
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(
|
||||||
|
ctx, xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
|
||||||
|
lj_lin=0.75, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
|
||||||
|
eps=1e-8, training=True
|
||||||
|
):
|
||||||
|
N, L, A = xs.shape[:3]
|
||||||
|
assert (N==1) # see comment below
|
||||||
|
|
||||||
|
ds_res = torch.sqrt( torch.sum ( torch.square(
|
||||||
|
xs.detach()[:,:,None,1,:]-xs.detach()[:,None,:,1,:]), dim=-1 ))
|
||||||
|
rs = torch.triu_indices(L,L,0, device=xs.device)
|
||||||
|
ri,rj = rs[0],rs[1]
|
||||||
|
|
||||||
|
# batch during inference for huge systems
|
||||||
|
BATCHSIZE = 65536//N
|
||||||
|
|
||||||
|
ljval = 0
|
||||||
|
dljEdx = torch.zeros_like(xs, dtype=torch.float)
|
||||||
|
|
||||||
|
for i_batch in range((len(ri)-1)//BATCHSIZE + 1):
|
||||||
|
idx = torch.arange(
|
||||||
|
i_batch*BATCHSIZE,
|
||||||
|
min( (i_batch+1)*BATCHSIZE, len(ri)),
|
||||||
|
device=xs.device
|
||||||
|
)
|
||||||
|
rii,rjj = ri[idx],rj[idx] # residue pairs we consider
|
||||||
|
|
||||||
|
ridx,ai,aj = (
|
||||||
|
aamask[seq[rii]][:,:,None]*aamask[seq[rjj]][:,None,:]
|
||||||
|
).nonzero(as_tuple=True)
|
||||||
|
|
||||||
|
deltas = xs[:,rii,:,None,:]-xs[:,rjj,None,:,:] # N,BATCHSIZE,Natm,Natm,3
|
||||||
|
seqi,seqj = seq[rii[ridx]], seq[rjj[ridx]]
|
||||||
|
|
||||||
|
mask = torch.ones_like(ridx, dtype=torch.bool) # are atoms defined?
|
||||||
|
|
||||||
|
# mask out atom pairs from too-distant residues (C-alpha dist > 24A)
|
||||||
|
ca_dist = torch.linalg.norm(deltas[:,:,1,1],dim=-1)
|
||||||
|
mask *= (ca_dist[:,ridx]<24).any(dim=0) # will work for batch>1 but very inefficient
|
||||||
|
|
||||||
|
intrares = (rii[ridx]==rjj[ridx])
|
||||||
|
mask[intrares*(ai<aj)] = False # upper tri (atoms)
|
||||||
|
|
||||||
|
## count-pair
|
||||||
|
# a) intra-protein
|
||||||
|
mask[intrares] *= num_bonds[seqi[intrares],ai[intrares],aj[intrares]]>=4
|
||||||
|
pepbondres = ri[ridx]+1==rj[ridx]
|
||||||
|
mask[pepbondres] *= (
|
||||||
|
num_bonds[seqi[pepbondres],ai[pepbondres],2]
|
||||||
|
+ num_bonds[seqj[pepbondres],0,aj[pepbondres]]
|
||||||
|
+ 1) >=4
|
||||||
|
|
||||||
|
# b) intra-ligand
|
||||||
|
atommask = (ai==1)*(aj==1)
|
||||||
|
dist_matrix = torch.nan_to_num(dist_matrix, posinf=4.0) #NOTE: need to run nan_to_num to remove infinities
|
||||||
|
resmask = (dist_matrix[0,rii,rjj] >= 4) # * will only work for batch=1
|
||||||
|
mask[atommask] *= resmask[ ridx[atommask] ]
|
||||||
|
|
||||||
|
# c) protein/ligand
|
||||||
|
##fd NOTE1: changed 6->5 in masking (atom 5 is CG which should always be 4+ bonds away from connected atom)
|
||||||
|
##fd NOTE2: this does NOT work correctly for nucleic acids
|
||||||
|
##fd for NAs atoms 0-4 are masked, but also 5,7,8 and 9 should be masked!
|
||||||
|
bbatommask = (ai<5)*(aj<5)
|
||||||
|
resmask = (bond_feats[0,rii,rjj] != 6) # * will only work for batch=1
|
||||||
|
mask[bbatommask] *= resmask[ ridx[bbatommask] ]
|
||||||
|
|
||||||
|
# apply mask. only interactions to be scored remain
|
||||||
|
ai,aj,seqi,seqj,ridx = ai[mask],aj[mask],seqi[mask],seqj[mask],ridx[mask]
|
||||||
|
deltas = deltas[:,ridx,ai,aj]
|
||||||
|
|
||||||
|
# hbond correction
|
||||||
|
use_hb_dis = (
|
||||||
|
ljcorr[seqi,ai,0]*ljcorr[seqj,aj,1]
|
||||||
|
+ ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0] ).nonzero()
|
||||||
|
use_ohdon_dis = ( # OH are both donors & acceptors
|
||||||
|
ljcorr[seqi,ai,0]*ljcorr[seqi,ai,1]*ljcorr[seqj,aj,0]
|
||||||
|
+ljcorr[seqi,ai,0]*ljcorr[seqj,aj,0]*ljcorr[seqj,aj,1]
|
||||||
|
).nonzero()
|
||||||
|
use_hb_hdis = (
|
||||||
|
ljcorr[seqi,ai,2]*ljcorr[seqj,aj,1]
|
||||||
|
+ljcorr[seqi,ai,1]*ljcorr[seqj,aj,2]
|
||||||
|
).nonzero()
|
||||||
|
|
||||||
|
# disulfide correction
|
||||||
|
potential_disulf = (ljcorr[seqi,ai,3]*ljcorr[seqj,aj,3] ).nonzero()
|
||||||
|
|
||||||
|
ljrs = ljparams[seqi,ai,0] + ljparams[seqj,aj,0]
|
||||||
|
ljrs[use_hb_dis] = lj_hb_dis
|
||||||
|
ljrs[use_ohdon_dis] = lj_OHdon_dis
|
||||||
|
ljrs[use_hb_hdis] = lj_hbond_hdis
|
||||||
|
|
||||||
|
ljss = torch.sqrt( ljparams[seqi,ai,1] * ljparams[seqj,aj,1] + eps )
|
||||||
|
ljss [potential_disulf] = 0.0
|
||||||
|
|
||||||
|
natoms = torch.sum(aamask[seq])
|
||||||
|
ljval_i,dljEdd_i = LJLoss.ljVdV(deltas,ljrs,ljss,lj_lin,eps)
|
||||||
|
|
||||||
|
ljval += ljval_i / natoms
|
||||||
|
|
||||||
|
# sum per-atom-pair grads into per-atom grads
|
||||||
|
# note this is stochastic op on GPU
|
||||||
|
idxI,idxJ = rii[ridx]*A + ai, rjj[ridx]*A + aj
|
||||||
|
|
||||||
|
dljEdx.view(N,-1,3).index_add_(1, idxI, dljEdd_i[...,None]*deltas, alpha=1.0/natoms)
|
||||||
|
dljEdx.view(N,-1,3).index_add_(1, idxJ, dljEdd_i[...,None]*deltas, alpha=-1.0/natoms)
|
||||||
|
|
||||||
|
ctx.save_for_backward(dljEdx)
|
||||||
|
|
||||||
|
return ljval
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
"""
|
||||||
|
In the backward pass we receive a Tensor containing the gradient of the loss
|
||||||
|
with respect to the output, and we need to compute the gradient of the loss
|
||||||
|
with respect to the input.
|
||||||
|
"""
|
||||||
|
dljEdx, = ctx.saved_tensors
|
||||||
|
return (
|
||||||
|
grad_output * dljEdx,
|
||||||
|
None, None, None, None, None, None, None, None, None, None, None, None, None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rosetta-like version of LJ (fa_atr+fa_rep)
|
||||||
|
# lj_lin is switch from linear to 12-6. Smaller values more sharply penalize clashes
|
||||||
|
def calc_lj(
|
||||||
|
seq, xs, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
|
||||||
|
lj_lin=0.75, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
|
||||||
|
lj_maxrad=-1.0, eps=1e-8,
|
||||||
|
training=True
|
||||||
|
):
|
||||||
|
lj = LJLoss.apply
|
||||||
|
ljval = lj(
|
||||||
|
xs, seq, aamask, bond_feats, dist_matrix, ljparams, ljcorr, num_bonds,
|
||||||
|
lj_lin, lj_hb_dis, lj_OHdon_dis, lj_hbond_hdis, eps, training)
|
||||||
|
|
||||||
|
return ljval
|
||||||
|
|
||||||
|
|
||||||
|
def calc_chiral_loss(pred, chirals):
|
||||||
|
"""
|
||||||
|
calculate error in dihedral angles for chiral atoms
|
||||||
|
Input:
|
||||||
|
- pred: predicted coords (B, L, :, 3)
|
||||||
|
- chirals: True coords (B, nchiral, 5), skip if 0 chiral sites, 5 dimension are indices for 4 atoms that make dihedral and the ideal angle they should form
|
||||||
|
Output:
|
||||||
|
- mean squared error of chiral angles
|
||||||
|
"""
|
||||||
|
if chirals.shape[1] == 0:
|
||||||
|
return torch.tensor(0.0, device=pred.device)
|
||||||
|
chiral_dih = pred[:, chirals[..., :-1].long(), 1]
|
||||||
|
pred_dih = get_dih(chiral_dih[...,0, :], chiral_dih[...,1, :], chiral_dih[...,2, :], chiral_dih[...,3, :]) # n_symm, b, n, 36, 3
|
||||||
|
l = torch.square(pred_dih-chirals[...,-1]).mean()
|
||||||
|
return l
|
||||||
|
|
||||||
|
@torch.enable_grad()
|
||||||
|
def calc_lj_grads(
|
||||||
|
seq, xyz, alpha, toaa, bond_feats, dist_matrix,
|
||||||
|
aamask, ljparams, ljcorr, num_bonds,
|
||||||
|
lj_lin=0.85, lj_hb_dis=3.0, lj_OHdon_dis=2.6, lj_hbond_hdis=1.75,
|
||||||
|
lj_maxrad=-1.0, eps=1e-8
|
||||||
|
):
|
||||||
|
xyz.requires_grad_(True)
|
||||||
|
alpha.requires_grad_(True)
|
||||||
|
_, xyzaa = toaa(seq, xyz, alpha)
|
||||||
|
Elj = calc_lj(
|
||||||
|
seq[0],
|
||||||
|
xyzaa[...,:3],
|
||||||
|
aamask,
|
||||||
|
bond_feats,
|
||||||
|
dist_matrix,
|
||||||
|
ljparams,
|
||||||
|
ljcorr,
|
||||||
|
num_bonds,
|
||||||
|
lj_lin,
|
||||||
|
lj_hb_dis,
|
||||||
|
lj_OHdon_dis,
|
||||||
|
lj_hbond_hdis,
|
||||||
|
lj_maxrad,
|
||||||
|
eps
|
||||||
|
)
|
||||||
|
return torch.autograd.grad(Elj, (xyz,alpha))
|
||||||
|
|
||||||
|
@torch.enable_grad()
|
||||||
|
def calc_chiral_grads(xyz, chirals):
|
||||||
|
xyz.requires_grad_(True)
|
||||||
|
l = calc_chiral_loss(xyz, chirals)
|
||||||
|
if l.item() == 0.0:
|
||||||
|
return (torch.zeros(xyz.shape, device=xyz.device),) # autograd returns a tuple..
|
||||||
|
return torch.autograd.grad(l, xyz)
|
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
418
rf2aa/model/RoseTTAFoldModel.py
Normal file
|
@ -0,0 +1,418 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import assertpy
|
||||||
|
from assertpy import assert_that
|
||||||
|
from icecream import ic
|
||||||
|
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
||||||
|
from rf2aa.model.Track_module import IterativeSimulator
|
||||||
|
from rf2aa.model.layers.AuxiliaryPredictor import (
|
||||||
|
DistanceNetwork,
|
||||||
|
MaskedTokenNetwork,
|
||||||
|
LDDTNetwork,
|
||||||
|
PAENetwork,
|
||||||
|
BinderNetwork,
|
||||||
|
)
|
||||||
|
from rf2aa.tensor_util import assert_shape, assert_equal
|
||||||
|
import rf2aa.util
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
|
||||||
|
def get_shape(t):
|
||||||
|
if hasattr(t, "shape"):
|
||||||
|
return t.shape
|
||||||
|
if type(t) is tuple:
|
||||||
|
return [get_shape(e) for e in t]
|
||||||
|
else:
|
||||||
|
return type(t)
|
||||||
|
|
||||||
|
|
||||||
|
class RoseTTAFoldModule(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
|
||||||
|
repeat_length=None, # if symmetrizing repeats, what length are they?
|
||||||
|
symmsub_k=None, # if symmetrizing repeats, which diagonals?
|
||||||
|
sym_method=None, # if symmetrizing repeats, which block symmetrization method?
|
||||||
|
main_block=None, # if copying template blocks along main diag, which block is main block? (the one w/ motif)
|
||||||
|
copy_main_block_template=None, # whether or not to copy main block template along main diag
|
||||||
|
n_extra_block=4,
|
||||||
|
n_main_block=8,
|
||||||
|
n_ref_block=4,
|
||||||
|
n_finetune_block=0,
|
||||||
|
d_msa=256,
|
||||||
|
d_msa_full=64,
|
||||||
|
d_pair=128,
|
||||||
|
d_templ=64,
|
||||||
|
n_head_msa=8,
|
||||||
|
n_head_pair=4,
|
||||||
|
n_head_templ=4,
|
||||||
|
d_hidden=32,
|
||||||
|
d_hidden_templ=64,
|
||||||
|
d_t1d=0,
|
||||||
|
p_drop=0.15,
|
||||||
|
additional_dt1d=0,
|
||||||
|
recycling_type="msa_pair",
|
||||||
|
SE3_param={}, SE3_ref_param={},
|
||||||
|
atom_type_index=None,
|
||||||
|
aamask=None,
|
||||||
|
ljlk_parameters=None,
|
||||||
|
lj_correction_parameters=None,
|
||||||
|
cb_len=None,
|
||||||
|
cb_ang=None,
|
||||||
|
cb_tor=None,
|
||||||
|
num_bonds=None,
|
||||||
|
lj_lin=0.6,
|
||||||
|
use_chiral_l1=True,
|
||||||
|
use_lj_l1=False,
|
||||||
|
use_atom_frames=True,
|
||||||
|
use_same_chain=False,
|
||||||
|
enable_same_chain=False,
|
||||||
|
refiner_topk=64,
|
||||||
|
get_quaternion=False,
|
||||||
|
# New for diffusion
|
||||||
|
freeze_track_motif=False,
|
||||||
|
assert_single_sequence_input=False,
|
||||||
|
fit=False,
|
||||||
|
tscale=1.0
|
||||||
|
):
|
||||||
|
super(RoseTTAFoldModule, self).__init__()
|
||||||
|
self.freeze_track_motif = freeze_track_motif
|
||||||
|
self.assert_single_sequence_input = assert_single_sequence_input
|
||||||
|
self.recycling_type = recycling_type
|
||||||
|
#
|
||||||
|
# Input Embeddings
|
||||||
|
d_state = SE3_param["l0_out_features"]
|
||||||
|
self.latent_emb = MSA_emb(
|
||||||
|
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
|
||||||
|
enable_same_chain=enable_same_chain
|
||||||
|
)
|
||||||
|
self.full_emb = Extra_emb(
|
||||||
|
d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
|
||||||
|
)
|
||||||
|
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)
|
||||||
|
|
||||||
|
self.templ_emb = Templ_emb(d_t1d=d_t1d,
|
||||||
|
d_pair=d_pair,
|
||||||
|
d_templ=d_templ,
|
||||||
|
d_state=d_state,
|
||||||
|
n_head=n_head_templ,
|
||||||
|
d_hidden=d_hidden_templ,
|
||||||
|
p_drop=0.25,
|
||||||
|
symmetrize_repeats=symmetrize_repeats, # repeat protein stuff
|
||||||
|
repeat_length=repeat_length,
|
||||||
|
symmsub_k=symmsub_k,
|
||||||
|
sym_method=sym_method,
|
||||||
|
main_block=main_block,
|
||||||
|
copy_main_block=copy_main_block_template,
|
||||||
|
additional_dt1d=additional_dt1d)
|
||||||
|
|
||||||
|
# Update inputs with outputs from previous round
|
||||||
|
|
||||||
|
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
|
||||||
|
#
|
||||||
|
self.simulator = IterativeSimulator(
|
||||||
|
n_extra_block=n_extra_block,
|
||||||
|
n_main_block=n_main_block,
|
||||||
|
n_ref_block=n_ref_block,
|
||||||
|
n_finetune_block=n_finetune_block,
|
||||||
|
d_msa=d_msa,
|
||||||
|
d_msa_full=d_msa_full,
|
||||||
|
d_pair=d_pair,
|
||||||
|
d_hidden=d_hidden,
|
||||||
|
n_head_msa=n_head_msa,
|
||||||
|
n_head_pair=n_head_pair,
|
||||||
|
SE3_param=SE3_param,
|
||||||
|
SE3_ref_param=SE3_ref_param,
|
||||||
|
p_drop=p_drop,
|
||||||
|
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
|
||||||
|
aamask=aamask,
|
||||||
|
ljlk_parameters=ljlk_parameters,
|
||||||
|
lj_correction_parameters=lj_correction_parameters,
|
||||||
|
num_bonds=num_bonds,
|
||||||
|
cb_len=cb_len,
|
||||||
|
cb_ang=cb_ang,
|
||||||
|
cb_tor=cb_tor,
|
||||||
|
lj_lin=lj_lin,
|
||||||
|
use_lj_l1=use_lj_l1,
|
||||||
|
use_chiral_l1=use_chiral_l1,
|
||||||
|
symmetrize_repeats=symmetrize_repeats,
|
||||||
|
repeat_length=repeat_length,
|
||||||
|
symmsub_k=symmsub_k,
|
||||||
|
sym_method=sym_method,
|
||||||
|
main_block=main_block,
|
||||||
|
use_same_chain=use_same_chain,
|
||||||
|
enable_same_chain=enable_same_chain,
|
||||||
|
refiner_topk=refiner_topk
|
||||||
|
)
|
||||||
|
|
||||||
|
##
|
||||||
|
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
|
||||||
|
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
|
||||||
|
self.lddt_pred = LDDTNetwork(d_state)
|
||||||
|
self.pae_pred = PAENetwork(d_pair)
|
||||||
|
self.pde_pred = PAENetwork(
|
||||||
|
d_pair
|
||||||
|
) # distance error, but use same architecture as aligned error
|
||||||
|
# binder predictions are made on top of the pair features, just like
|
||||||
|
# PAE predictions are. It's not clear if this is the best place to insert
|
||||||
|
# this prediction head.
|
||||||
|
# self.binder_network = BinderNetwork(d_pair, d_state)
|
||||||
|
|
||||||
|
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
|
||||||
|
|
||||||
|
self.use_atom_frames = use_atom_frames
|
||||||
|
self.enable_same_chain = enable_same_chain
|
||||||
|
self.get_quaternion = get_quaternion
|
||||||
|
self.verbose_checks = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
msa_latent,
|
||||||
|
msa_full,
|
||||||
|
seq,
|
||||||
|
seq_unmasked,
|
||||||
|
xyz,
|
||||||
|
sctors,
|
||||||
|
idx,
|
||||||
|
bond_feats,
|
||||||
|
dist_matrix,
|
||||||
|
chirals,
|
||||||
|
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
|
||||||
|
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
|
||||||
|
return_raw=False,
|
||||||
|
use_checkpoint=False,
|
||||||
|
return_infer=False, #fd ?
|
||||||
|
p2p_crop=-1, topk_crop=-1, # striping
|
||||||
|
symmids=None, symmsub=None, symmRs=None, symmmeta=None, # symmetry
|
||||||
|
):
|
||||||
|
# ic(get_shape(msa_latent))
|
||||||
|
# ic(get_shape(msa_full))
|
||||||
|
# ic(get_shape(seq))
|
||||||
|
# ic(get_shape(seq_unmasked))
|
||||||
|
# ic(get_shape(xyz))
|
||||||
|
# ic(get_shape(sctors))
|
||||||
|
# ic(get_shape(idx))
|
||||||
|
# ic(get_shape(bond_feats))
|
||||||
|
# ic(get_shape(chirals))
|
||||||
|
# ic(get_shape(atom_frames))
|
||||||
|
# ic(get_shape(t1d))
|
||||||
|
# ic(get_shape(t2d))
|
||||||
|
# ic(get_shape(xyz_t))
|
||||||
|
# ic(get_shape(alpha_t))
|
||||||
|
# ic(get_shape(mask_t))
|
||||||
|
# ic(get_shape(same_chain))
|
||||||
|
# ic(get_shape(msa_prev))
|
||||||
|
# ic(get_shape(pair_prev))
|
||||||
|
# ic(get_shape(mask_recycle))
|
||||||
|
# ic()
|
||||||
|
# ic()
|
||||||
|
B, N, L = msa_latent.shape[:3]
|
||||||
|
A = atom_frames.shape[1]
|
||||||
|
dtype = msa_latent.dtype
|
||||||
|
|
||||||
|
if self.assert_single_sequence_input:
|
||||||
|
assert_shape(msa_latent, (1, 1, L, 164))
|
||||||
|
assert_shape(msa_full, (1, 1, L, 83))
|
||||||
|
assert_shape(seq, (1, L))
|
||||||
|
assert_shape(seq_unmasked, (1, L))
|
||||||
|
assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
|
||||||
|
assert_shape(sctors, (1, L, 20, 2))
|
||||||
|
assert_shape(idx, (1, L))
|
||||||
|
assert_shape(bond_feats, (1, L, L))
|
||||||
|
assert_shape(dist_matrix, (1, L, L))
|
||||||
|
# assert_shape(chirals, (1, 0))
|
||||||
|
# assert_shape(atom_frames, (1, 4, L)) # This is set to 4 for the recycle count, but that can't be right
|
||||||
|
assert_shape(atom_frames, (1, A, 3, 2)) # What is 4?
|
||||||
|
assert_shape(t1d, (1, 1, L, 80))
|
||||||
|
assert_shape(t2d, (1, 1, L, L, 68))
|
||||||
|
assert_shape(xyz_t, (1, 1, L, 3))
|
||||||
|
assert_shape(alpha_t, (1, 1, L, 60))
|
||||||
|
assert_shape(mask_t, (1, 1, L, L))
|
||||||
|
assert_shape(same_chain, (1, L, L))
|
||||||
|
device = msa_latent.device
|
||||||
|
assert_that(msa_full.device).is_equal_to(device)
|
||||||
|
assert_that(seq.device).is_equal_to(device)
|
||||||
|
assert_that(seq_unmasked.device).is_equal_to(device)
|
||||||
|
assert_that(xyz.device).is_equal_to(device)
|
||||||
|
assert_that(sctors.device).is_equal_to(device)
|
||||||
|
assert_that(idx.device).is_equal_to(device)
|
||||||
|
assert_that(bond_feats.device).is_equal_to(device)
|
||||||
|
assert_that(dist_matrix.device).is_equal_to(device)
|
||||||
|
assert_that(atom_frames.device).is_equal_to(device)
|
||||||
|
assert_that(t1d.device).is_equal_to(device)
|
||||||
|
assert_that(t2d.device).is_equal_to(device)
|
||||||
|
assert_that(xyz_t.device).is_equal_to(device)
|
||||||
|
assert_that(alpha_t.device).is_equal_to(device)
|
||||||
|
assert_that(mask_t.device).is_equal_to(device)
|
||||||
|
assert_that(same_chain.device).is_equal_to(device)
|
||||||
|
|
||||||
|
if self.verbose_checks:
|
||||||
|
#ic(is_motif.shape)
|
||||||
|
is_sm = rf2aa.util.is_atom(seq[0]) # (L)
|
||||||
|
#is_protein_motif = is_motif & ~is_sm
|
||||||
|
#if is_motif.any():
|
||||||
|
# motif_protein_i = torch.where(is_motif)[0][0]
|
||||||
|
#is_motif_sm = is_motif & is_sm
|
||||||
|
#if is_sm.any():
|
||||||
|
# motif_sm_i = torch.where(is_motif_sm)[0][0]
|
||||||
|
#diffused_protein_i = torch.where(~is_sm & ~is_motif)[0][0]
|
||||||
|
|
||||||
|
"""
|
||||||
|
msa_full: NSEQ,N_INDEL,N_TERMINUS,
|
||||||
|
msa_masked: NSEQ,NSEQ,N_INDEL,N_INDEL,N_TERMINUS
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
NINDEL = 1
|
||||||
|
NTERMINUS = 2
|
||||||
|
NMSAFULL = ChemData().NAATOKENS + NINDEL + NTERMINUS
|
||||||
|
NMSAMASKED = ChemData().NAATOKENS + ChemData().NAATOKENS + NINDEL + NINDEL + NTERMINUS
|
||||||
|
assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
|
||||||
|
assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)
|
||||||
|
|
||||||
|
msa_full_seq = np.r_[0:ChemData().NAATOKENS]
|
||||||
|
msa_full_indel = np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||||
|
msa_full_term = np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]
|
||||||
|
|
||||||
|
msa_latent_seq1 = np.r_[0:ChemData().NAATOKENS]
|
||||||
|
msa_latent_seq2 = np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]
|
||||||
|
msa_latent_indel1 = np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||||
|
msa_latent_indel2 = np.r_[
|
||||||
|
2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL
|
||||||
|
]
|
||||||
|
msa_latent_terminus = np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||||
|
|
||||||
|
#i_name = [(diffused_protein_i, "diffused_protein")]
|
||||||
|
#if is_sm.any():
|
||||||
|
# i_name.insert(0, (motif_sm_i, "motif_sm"))
|
||||||
|
#if is_motif.any():
|
||||||
|
# i_name.insert(0, (motif_protein_i, "motif_protein"))
|
||||||
|
i_name = [(0, "tst")]
|
||||||
|
for i, name in i_name:
|
||||||
|
ic(f"------------------{name}:{i}----------------")
|
||||||
|
msa_full_seq = msa_full[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||||
|
msa_full_indel = msa_full[
|
||||||
|
0, 0, i, np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
|
||||||
|
]
|
||||||
|
msa_full_term = msa_full[0, 0, i, np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]]
|
||||||
|
|
||||||
|
msa_latent_seq1 = msa_latent[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
|
||||||
|
msa_latent_seq2 = msa_latent[0, 0, i, np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]]
|
||||||
|
msa_latent_indel1 = msa_latent[
|
||||||
|
0, 0, i, np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
|
||||||
|
]
|
||||||
|
msa_latent_indel2 = msa_latent[
|
||||||
|
0,
|
||||||
|
0,
|
||||||
|
i,
|
||||||
|
np.r_[2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL],
|
||||||
|
]
|
||||||
|
msa_latent_term = msa_latent[
|
||||||
|
0, 0, i, np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
|
||||||
|
]
|
||||||
|
|
||||||
|
assert_equal(msa_full_seq, msa_latent_seq1)
|
||||||
|
assert_equal(msa_full_seq, msa_latent_seq2)
|
||||||
|
assert_equal(msa_full_indel, msa_latent_indel1)
|
||||||
|
assert_equal(msa_full_indel, msa_latent_indel2)
|
||||||
|
assert_equal(msa_full_term, msa_latent_term)
|
||||||
|
# if 'motif' in name:
|
||||||
|
msa_cat = torch.where(msa_full_seq)[0]
|
||||||
|
ic(msa_cat, seq[0, i])
|
||||||
|
assert_equal(seq[0, i : i + 1], msa_cat)
|
||||||
|
assert_equal(seq[0, i], seq_unmasked[0, i])
|
||||||
|
ic(
|
||||||
|
name,
|
||||||
|
# torch.where(msa_latent[0,0,i,:80]),
|
||||||
|
# torch.where(msa_full[0,0,i]),
|
||||||
|
seq[0, i],
|
||||||
|
seq_unmasked[0, i],
|
||||||
|
torch.where(t1d[0, 0, i]),
|
||||||
|
xyz[0, i, :4, 0],
|
||||||
|
xyz_t[0, 0, i, 0],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get embeddings
|
||||||
|
#if self.enable_same_chain == False:
|
||||||
|
# same_chain = None
|
||||||
|
msa_latent, pair, state = self.latent_emb(
|
||||||
|
msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
|
||||||
|
)
|
||||||
|
msa_full = self.full_emb(msa_full, seq, idx)
|
||||||
|
pair = pair + self.bond_emb(bond_feats)
|
||||||
|
|
||||||
|
msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
|
||||||
|
msa_full = msa_full.to(dtype)
|
||||||
|
|
||||||
|
#
|
||||||
|
# Do recycling
|
||||||
|
if msa_prev is None:
|
||||||
|
msa_prev = torch.zeros_like(msa_latent[:,0])
|
||||||
|
if pair_prev is None:
|
||||||
|
pair_prev = torch.zeros_like(pair)
|
||||||
|
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
|
||||||
|
state_prev = torch.zeros_like(state)
|
||||||
|
|
||||||
|
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
|
||||||
|
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
|
||||||
|
|
||||||
|
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
|
||||||
|
pair = pair + pair_recycle
|
||||||
|
state = state + state_recycle # if state is not recycled these will be zeros
|
||||||
|
|
||||||
|
# add template embedding
|
||||||
|
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
|
||||||
|
|
||||||
|
# Predict coordinates from given inputs
|
||||||
|
is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
|
||||||
|
msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
|
||||||
|
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx,
|
||||||
|
symmids, symmsub, symmRs, symmmeta,
|
||||||
|
bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
|
||||||
|
use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
|
||||||
|
p2p_crop=p2p_crop, topk_crop=topk_crop
|
||||||
|
)
|
||||||
|
|
||||||
|
if return_raw:
|
||||||
|
# get last structure
|
||||||
|
xyz_last = xyz_allatom[-1].unsqueeze(0)
|
||||||
|
return msa[:,0], pair, xyz_last, alpha_s[-1], None
|
||||||
|
|
||||||
|
# predict masked amino acids
|
||||||
|
logits_aa = self.aa_pred(msa)
|
||||||
|
|
||||||
|
# predict distogram & orientograms
|
||||||
|
logits = self.c6d_pred(pair)
|
||||||
|
|
||||||
|
# Predict LDDT
|
||||||
|
lddt = self.lddt_pred(state)
|
||||||
|
|
||||||
|
if self.verbose_checks:
|
||||||
|
pseq_0 = logits_aa.permute(0, 2, 1)
|
||||||
|
ic(pseq_0.shape)
|
||||||
|
pseq_0 = pseq_0[0]
|
||||||
|
ic(
|
||||||
|
f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
|
||||||
|
)
|
||||||
|
ic(
|
||||||
|
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logits_pae = logits_pde = p_bind = None
|
||||||
|
# predict aligned error and distance error
|
||||||
|
logits_pae = self.pae_pred(pair)
|
||||||
|
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
|
||||||
|
|
||||||
|
#fd predict bind/no-bind
|
||||||
|
p_bind = self.bind_pred(logits_pae,same_chain)
|
||||||
|
|
||||||
|
if self.get_quaternion:
|
||||||
|
return (
|
||||||
|
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||||
|
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state, quat
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return (
|
||||||
|
logits, logits_aa, logits_pae, logits_pde, p_bind,
|
||||||
|
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
|
||||||
|
)
|
1220
rf2aa/model/Track_module.py
Normal file
1220
rf2aa/model/Track_module.py
Normal file
File diff suppressed because it is too large
Load diff
475
rf2aa/model/layers/Attention_module.py
Normal file
475
rf2aa/model/layers/Attention_module.py
Normal file
|
@ -0,0 +1,475 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
|
from opt_einsum import contract as einsum
|
||||||
|
from rf2aa.util_module import init_lecun_normal
|
||||||
|
class FeedForwardLayer(nn.Module):
|
||||||
|
def __init__(self, d_model, r_ff, p_drop=0.1):
|
||||||
|
super(FeedForwardLayer, self).__init__()
|
||||||
|
self.norm = nn.LayerNorm(d_model)
|
||||||
|
self.linear1 = nn.Linear(d_model, d_model*r_ff)
|
||||||
|
self.dropout = nn.Dropout(p_drop)
|
||||||
|
self.linear2 = nn.Linear(d_model*r_ff, d_model)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# initialize linear layer right before ReLu: He initializer (kaiming normal)
|
||||||
|
nn.init.kaiming_normal_(self.linear1.weight, nonlinearity='relu')
|
||||||
|
nn.init.zeros_(self.linear1.bias)
|
||||||
|
|
||||||
|
# initialize linear layer right before residual connection: zero initialize
|
||||||
|
nn.init.zeros_(self.linear2.weight)
|
||||||
|
nn.init.zeros_(self.linear2.bias)
|
||||||
|
|
||||||
|
def forward(self, src):
|
||||||
|
src = self.norm(src)
|
||||||
|
src = self.linear2(self.dropout(F.relu_(self.linear1(src))))
|
||||||
|
return src
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
# calculate multi-head attention
|
||||||
|
def __init__(self, d_query, d_key, n_head, d_hidden, d_out, p_drop=0.1):
|
||||||
|
super(Attention, self).__init__()
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
#
|
||||||
|
self.to_q = nn.Linear(d_query, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_key, n_head*d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_key, n_head*d_hidden, bias=False)
|
||||||
|
#
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_out)
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
#
|
||||||
|
# initialize all parameters properly
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, query, key, value):
|
||||||
|
B, Q = query.shape[:2]
|
||||||
|
B, K = key.shape[:2]
|
||||||
|
#
|
||||||
|
query = self.to_q(query).reshape(B, Q, self.h, self.dim)
|
||||||
|
key = self.to_k(key).reshape(B, K, self.h, self.dim)
|
||||||
|
value = self.to_v(value).reshape(B, K, self.h, self.dim)
|
||||||
|
#
|
||||||
|
query = query * self.scaling
|
||||||
|
attn = einsum('bqhd,bkhd->bhqk', query, key)
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
#
|
||||||
|
out = einsum('bhqk,bkhd->bqhd', attn, value)
|
||||||
|
out = out.reshape(B, Q, self.h*self.dim)
|
||||||
|
#
|
||||||
|
out = self.to_out(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
# MSA Attention (row/column) from AlphaFold architecture
|
||||||
|
class SequenceWeight(nn.Module):
|
||||||
|
def __init__(self, d_msa, n_head, d_hidden, p_drop=0.1):
|
||||||
|
super(SequenceWeight, self).__init__()
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
self.scale = 1.0 / math.sqrt(self.dim)
|
||||||
|
|
||||||
|
self.to_query = nn.Linear(d_msa, n_head*d_hidden)
|
||||||
|
self.to_key = nn.Linear(d_msa, n_head*d_hidden)
|
||||||
|
self.dropout = nn.Dropout(p_drop)
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_query.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_key.weight)
|
||||||
|
|
||||||
|
def forward(self, msa):
|
||||||
|
B, N, L = msa.shape[:3]
|
||||||
|
|
||||||
|
tar_seq = msa[:,0]
|
||||||
|
|
||||||
|
q = self.to_query(tar_seq).view(B, 1, L, self.h, self.dim)
|
||||||
|
k = self.to_key(msa).view(B, N, L, self.h, self.dim)
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = einsum('bqihd,bkihd->bkihq', q, k)
|
||||||
|
attn = F.softmax(attn, dim=1)
|
||||||
|
return self.dropout(attn)
|
||||||
|
|
||||||
|
class MSARowAttentionWithBias(nn.Module):
|
||||||
|
def __init__(self, d_msa=256, d_pair=128, n_head=8, d_hidden=32):
|
||||||
|
super(MSARowAttentionWithBias, self).__init__()
|
||||||
|
self.norm_msa = nn.LayerNorm(d_msa)
|
||||||
|
self.norm_pair = nn.LayerNorm(d_pair)
|
||||||
|
#
|
||||||
|
self.seq_weight = SequenceWeight(d_msa, n_head, d_hidden, p_drop=0.1)
|
||||||
|
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_b = nn.Linear(d_pair, n_head, bias=False)
|
||||||
|
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||||
|
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# bias: normal distribution
|
||||||
|
self.to_b = init_lecun_normal(self.to_b)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.to_g.weight)
|
||||||
|
nn.init.ones_(self.to_g.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, msa, pair): # TODO: make this as tied-attention
|
||||||
|
B, N, L = msa.shape[:3]
|
||||||
|
#
|
||||||
|
msa = self.norm_msa(msa)
|
||||||
|
pair = self.norm_pair(pair)
|
||||||
|
#
|
||||||
|
seq_weight = self.seq_weight(msa) # (B, N, L, h, 1)
|
||||||
|
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
bias = self.to_b(pair) # (B, L, L, h)
|
||||||
|
gate = torch.sigmoid(self.to_g(msa))
|
||||||
|
#
|
||||||
|
query = query * seq_weight.expand(-1, -1, -1, -1, self.dim)
|
||||||
|
key = key * self.scaling
|
||||||
|
attn = einsum('bsqhd,bskhd->bqkh', query, key)
|
||||||
|
attn = attn + bias
|
||||||
|
attn = F.softmax(attn, dim=-2)
|
||||||
|
#
|
||||||
|
out = einsum('bqkh,bskhd->bsqhd', attn, value).reshape(B, N, L, -1)
|
||||||
|
out = gate * out
|
||||||
|
#
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class MSAColAttention(nn.Module):
|
||||||
|
def __init__(self, d_msa=256, n_head=8, d_hidden=32):
|
||||||
|
super(MSAColAttention, self).__init__()
|
||||||
|
self.norm_msa = nn.LayerNorm(d_msa)
|
||||||
|
#
|
||||||
|
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||||
|
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.to_g.weight)
|
||||||
|
nn.init.ones_(self.to_g.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, msa):
|
||||||
|
B, N, L = msa.shape[:3]
|
||||||
|
#
|
||||||
|
msa = self.norm_msa(msa)
|
||||||
|
#
|
||||||
|
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
key = self.to_k(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
value = self.to_v(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
gate = torch.sigmoid(self.to_g(msa))
|
||||||
|
#
|
||||||
|
query = query * self.scaling
|
||||||
|
attn = einsum('bqihd,bkihd->bihqk', query, key)
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
#
|
||||||
|
out = einsum('bihqk,bkihd->bqihd', attn, value).reshape(B, N, L, -1)
|
||||||
|
out = gate * out
|
||||||
|
#
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class MSAColGlobalAttention(nn.Module):
|
||||||
|
def __init__(self, d_msa=64, n_head=8, d_hidden=8):
|
||||||
|
super(MSAColGlobalAttention, self).__init__()
|
||||||
|
self.norm_msa = nn.LayerNorm(d_msa)
|
||||||
|
#
|
||||||
|
self.to_q = nn.Linear(d_msa, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_msa, d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_msa, d_hidden, bias=False)
|
||||||
|
self.to_g = nn.Linear(d_msa, n_head*d_hidden)
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_msa)
|
||||||
|
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.to_g.weight)
|
||||||
|
nn.init.ones_(self.to_g.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, msa):
|
||||||
|
B, N, L = msa.shape[:3]
|
||||||
|
#
|
||||||
|
msa = self.norm_msa(msa)
|
||||||
|
#
|
||||||
|
query = self.to_q(msa).reshape(B, N, L, self.h, self.dim)
|
||||||
|
query = query.mean(dim=1) # (B, L, h, dim)
|
||||||
|
key = self.to_k(msa) # (B, N, L, dim)
|
||||||
|
value = self.to_v(msa) # (B, N, L, dim)
|
||||||
|
gate = torch.sigmoid(self.to_g(msa)) # (B, N, L, h*dim)
|
||||||
|
#
|
||||||
|
query = query * self.scaling
|
||||||
|
attn = einsum('bihd,bkid->bihk', query, key) # (B, L, h, N)
|
||||||
|
attn = F.softmax(attn, dim=-1)
|
||||||
|
#
|
||||||
|
out = einsum('bihk,bkid->bihd', attn, value).reshape(B, 1, L, -1) # (B, 1, L, h*dim)
|
||||||
|
out = gate * out # (B, N, L, h*dim)
|
||||||
|
#
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# TriangleAttention & TriangleMultiplication from AlphaFold architecture
|
||||||
|
class TriangleAttention(nn.Module):
|
||||||
|
def __init__(self, d_pair, n_head=4, d_hidden=32, p_drop=0.1, start_node=True):
|
||||||
|
super(TriangleAttention, self).__init__()
|
||||||
|
self.norm = nn.LayerNorm(d_pair)
|
||||||
|
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
|
||||||
|
self.to_b = nn.Linear(d_pair, n_head, bias=False)
|
||||||
|
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
|
||||||
|
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
|
||||||
|
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
self.start_node=start_node
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# bias: normal distribution
|
||||||
|
self.to_b = init_lecun_normal(self.to_b)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.to_g.weight)
|
||||||
|
nn.init.ones_(self.to_g.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, pair):
|
||||||
|
B, L = pair.shape[:2]
|
||||||
|
|
||||||
|
pair = self.norm(pair)
|
||||||
|
|
||||||
|
# input projection
|
||||||
|
query = self.to_q(pair).reshape(B, L, L, self.h, -1)
|
||||||
|
key = self.to_k(pair).reshape(B, L, L, self.h, -1)
|
||||||
|
value = self.to_v(pair).reshape(B, L, L, self.h, -1)
|
||||||
|
bias = self.to_b(pair) # (B, L, L, h)
|
||||||
|
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
|
||||||
|
|
||||||
|
# attention
|
||||||
|
query = query * self.scaling
|
||||||
|
if self.start_node:
|
||||||
|
attn = einsum('bijhd,bikhd->bijkh', query, key)
|
||||||
|
else:
|
||||||
|
attn = einsum('bijhd,bkjhd->bijkh', query, key)
|
||||||
|
attn = attn + bias.unsqueeze(1).expand(-1,L,-1,-1,-1) # (bijkh)
|
||||||
|
attn = F.softmax(attn, dim=-2)
|
||||||
|
if self.start_node:
|
||||||
|
out = einsum('bijkh,bikhd->bijhd', attn, value).reshape(B, L, L, -1)
|
||||||
|
else:
|
||||||
|
out = einsum('bijkh,bkjhd->bijhd', attn, value).reshape(B, L, L, -1)
|
||||||
|
out = gate * out # gated attention
|
||||||
|
|
||||||
|
# output projection
|
||||||
|
out = self.to_out(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
class TriangleMultiplication(nn.Module):
|
||||||
|
def __init__(self, d_pair, d_hidden=128, outgoing=True):
|
||||||
|
super(TriangleMultiplication, self).__init__()
|
||||||
|
self.norm = nn.LayerNorm(d_pair)
|
||||||
|
self.left_proj = nn.Linear(d_pair, d_hidden)
|
||||||
|
self.right_proj = nn.Linear(d_pair, d_hidden)
|
||||||
|
self.left_gate = nn.Linear(d_pair, d_hidden)
|
||||||
|
self.right_gate = nn.Linear(d_pair, d_hidden)
|
||||||
|
#
|
||||||
|
self.gate = nn.Linear(d_pair, d_pair)
|
||||||
|
self.norm_out = nn.LayerNorm(d_hidden)
|
||||||
|
self.out_proj = nn.Linear(d_hidden, d_pair)
|
||||||
|
|
||||||
|
self.outgoing = outgoing
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# normal distribution for regular linear weights
|
||||||
|
self.left_proj = init_lecun_normal(self.left_proj)
|
||||||
|
self.right_proj = init_lecun_normal(self.right_proj)
|
||||||
|
|
||||||
|
# Set Bias of Linear layers to zeros
|
||||||
|
nn.init.zeros_(self.left_proj.bias)
|
||||||
|
nn.init.zeros_(self.right_proj.bias)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.left_gate.weight)
|
||||||
|
nn.init.ones_(self.left_gate.bias)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.right_gate.weight)
|
||||||
|
nn.init.ones_(self.right_gate.bias)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.gate.weight)
|
||||||
|
nn.init.ones_(self.gate.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.out_proj.weight)
|
||||||
|
nn.init.zeros_(self.out_proj.bias)
|
||||||
|
|
||||||
|
def forward(self, pair):
|
||||||
|
B, L = pair.shape[:2]
|
||||||
|
pair = self.norm(pair)
|
||||||
|
|
||||||
|
left = self.left_proj(pair) # (B, L, L, d_h)
|
||||||
|
left_gate = torch.sigmoid(self.left_gate(pair))
|
||||||
|
left = left_gate * left
|
||||||
|
|
||||||
|
right = self.right_proj(pair) # (B, L, L, d_h)
|
||||||
|
right_gate = torch.sigmoid(self.right_gate(pair))
|
||||||
|
right = right_gate * right
|
||||||
|
|
||||||
|
if self.outgoing:
|
||||||
|
out = einsum('bikd,bjkd->bijd', left, right/float(L))
|
||||||
|
else:
|
||||||
|
out = einsum('bkid,bkjd->bijd', left, right/float(L))
|
||||||
|
out = self.norm_out(out)
|
||||||
|
out = self.out_proj(out)
|
||||||
|
|
||||||
|
gate = torch.sigmoid(self.gate(pair)) # (B, L, L, d_pair)
|
||||||
|
out = gate * out
|
||||||
|
return out
|
||||||
|
|
||||||
|
# Instead of triangle attention, use Tied axail attention with bias from coordinates..?
|
||||||
|
class BiasedAxialAttention(nn.Module):
|
||||||
|
def __init__(self, d_pair, d_bias, n_head, d_hidden, p_drop=0.1, is_row=True):
|
||||||
|
super(BiasedAxialAttention, self).__init__()
|
||||||
|
#
|
||||||
|
self.is_row = is_row
|
||||||
|
self.norm_pair = nn.LayerNorm(d_pair)
|
||||||
|
self.norm_bias = nn.LayerNorm(d_bias)
|
||||||
|
|
||||||
|
self.to_q = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
self.to_k = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
self.to_v = nn.Linear(d_pair, n_head*d_hidden, bias=False)
|
||||||
|
self.to_b = nn.Linear(d_bias, n_head, bias=False)
|
||||||
|
self.to_g = nn.Linear(d_pair, n_head*d_hidden)
|
||||||
|
self.to_out = nn.Linear(n_head*d_hidden, d_pair)
|
||||||
|
|
||||||
|
self.scaling = 1/math.sqrt(d_hidden)
|
||||||
|
self.h = n_head
|
||||||
|
self.dim = d_hidden
|
||||||
|
|
||||||
|
# initialize all parameters properly
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# query/key/value projection: Glorot uniform / Xavier uniform
|
||||||
|
nn.init.xavier_uniform_(self.to_q.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_k.weight)
|
||||||
|
nn.init.xavier_uniform_(self.to_v.weight)
|
||||||
|
|
||||||
|
# bias: normal distribution
|
||||||
|
self.to_b = init_lecun_normal(self.to_b)
|
||||||
|
|
||||||
|
# gating: zero weights, one biases (mostly open gate at the begining)
|
||||||
|
nn.init.zeros_(self.to_g.weight)
|
||||||
|
nn.init.ones_(self.to_g.bias)
|
||||||
|
|
||||||
|
# to_out: right before residual connection: zero initialize -- to make it sure residual operation is same to the Identity at the begining
|
||||||
|
nn.init.zeros_(self.to_out.weight)
|
||||||
|
nn.init.zeros_(self.to_out.bias)
|
||||||
|
|
||||||
|
def forward(self, pair, bias):
|
||||||
|
# pair: (B, L, L, d_pair)
|
||||||
|
B, L = pair.shape[:2]
|
||||||
|
|
||||||
|
if self.is_row:
|
||||||
|
pair = pair.permute(0,2,1,3)
|
||||||
|
bias = bias.permute(0,2,1,3)
|
||||||
|
|
||||||
|
pair = self.norm_pair(pair)
|
||||||
|
bias = self.norm_bias(bias)
|
||||||
|
|
||||||
|
query = self.to_q(pair).reshape(B, L, L, self.h, self.dim)
|
||||||
|
key = self.to_k(pair).reshape(B, L, L, self.h, self.dim)
|
||||||
|
value = self.to_v(pair).reshape(B, L, L, self.h, self.dim)
|
||||||
|
bias = self.to_b(bias) # (B, L, L, h)
|
||||||
|
gate = torch.sigmoid(self.to_g(pair)) # (B, L, L, h*dim)
|
||||||
|
|
||||||
|
query = query * self.scaling
|
||||||
|
key = key / L # normalize for tied attention
|
||||||
|
attn = einsum('bnihk,bnjhk->bijh', query, key) # tied attention
|
||||||
|
attn = attn + bias # apply bias
|
||||||
|
attn = F.softmax(attn, dim=-2) # (B, L, L, h)
|
||||||
|
|
||||||
|
out = einsum('bijh,bnjhd->bnihd', attn, value).reshape(B, L, L, -1)
|
||||||
|
out = gate * out
|
||||||
|
|
||||||
|
out = self.to_out(out)
|
||||||
|
if self.is_row:
|
||||||
|
out = out.permute(0,2,1,3)
|
||||||
|
return out
|
||||||
|
|
111
rf2aa/model/layers/AuxiliaryPredictor.py
Normal file
111
rf2aa/model/layers/AuxiliaryPredictor.py
Normal file
|
@ -0,0 +1,111 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
class DistanceNetwork(nn.Module):
|
||||||
|
def __init__(self, n_feat, p_drop=0.0):
|
||||||
|
super(DistanceNetwork, self).__init__()
|
||||||
|
#HACK: dimensions are hard coded here
|
||||||
|
self.proj_symm = nn.Linear(n_feat, 61+37) # must match bin counts defined in kinematics.py
|
||||||
|
self.proj_asymm = nn.Linear(n_feat, 37+19)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
# initialize linear layer for final logit prediction
|
||||||
|
nn.init.zeros_(self.proj_symm.weight)
|
||||||
|
nn.init.zeros_(self.proj_asymm.weight)
|
||||||
|
nn.init.zeros_(self.proj_symm.bias)
|
||||||
|
nn.init.zeros_(self.proj_asymm.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# input: pair info (B, L, L, C)
|
||||||
|
|
||||||
|
# predict theta, phi (non-symmetric)
|
||||||
|
logits_asymm = self.proj_asymm(x)
|
||||||
|
logits_theta = logits_asymm[:,:,:,:37].permute(0,3,1,2)
|
||||||
|
logits_phi = logits_asymm[:,:,:,37:].permute(0,3,1,2)
|
||||||
|
|
||||||
|
# predict dist, omega
|
||||||
|
logits_symm = self.proj_symm(x)
|
||||||
|
logits_symm = logits_symm + logits_symm.permute(0,2,1,3)
|
||||||
|
logits_dist = logits_symm[:,:,:,:61].permute(0,3,1,2)
|
||||||
|
logits_omega = logits_symm[:,:,:,61:].permute(0,3,1,2)
|
||||||
|
|
||||||
|
return logits_dist, logits_omega, logits_theta, logits_phi
|
||||||
|
|
||||||
|
class MaskedTokenNetwork(nn.Module):
|
||||||
|
def __init__(self, n_feat, p_drop=0.0):
|
||||||
|
super(MaskedTokenNetwork, self).__init__()
|
||||||
|
|
||||||
|
#fd note this predicts probability for the mask token (which is never in ground truth)
|
||||||
|
# it should be ok though(?)
|
||||||
|
self.proj = nn.Linear(n_feat, ChemData().NAATOKENS)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
nn.init.zeros_(self.proj.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, L = x.shape[:3]
|
||||||
|
logits = self.proj(x).permute(0,3,1,2).reshape(B, -1, N*L)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
class LDDTNetwork(nn.Module):
|
||||||
|
def __init__(self, n_feat, n_bin_lddt=50):
|
||||||
|
super(LDDTNetwork, self).__init__()
|
||||||
|
self.proj = nn.Linear(n_feat, n_bin_lddt)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
nn.init.zeros_(self.proj.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
logits = self.proj(x) # (B, L, 50)
|
||||||
|
|
||||||
|
return logits.permute(0,2,1)
|
||||||
|
|
||||||
|
class PAENetwork(nn.Module):
|
||||||
|
def __init__(self, n_feat, n_bin_pae=64):
|
||||||
|
super(PAENetwork, self).__init__()
|
||||||
|
self.proj = nn.Linear(n_feat, n_bin_pae)
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
nn.init.zeros_(self.proj.weight)
|
||||||
|
nn.init.zeros_(self.proj.bias)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
logits = self.proj(x) # (B, L, L, 64)
|
||||||
|
|
||||||
|
return logits.permute(0,3,1,2)
|
||||||
|
|
||||||
|
class BinderNetwork(nn.Module):
|
||||||
|
def __init__(self, n_bin_pae=64):
|
||||||
|
super(BinderNetwork, self).__init__()
|
||||||
|
self.classify = torch.nn.Linear(n_bin_pae, 1)
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
nn.init.zeros_(self.classify.weight)
|
||||||
|
nn.init.zeros_(self.classify.bias)
|
||||||
|
|
||||||
|
def forward(self, pae, same_chain):
|
||||||
|
logits = pae.permute(0,2,3,1)
|
||||||
|
logits_inter = torch.mean( logits[same_chain==0], dim=0 ).nan_to_num() # all zeros if single chain
|
||||||
|
prob = torch.sigmoid( self.classify( logits_inter ) )
|
||||||
|
return prob
|
||||||
|
|
||||||
|
aux_predictor_factory = {
|
||||||
|
"c6d": DistanceNetwork,
|
||||||
|
"mlm": MaskedTokenNetwork,
|
||||||
|
"plddt": LDDTNetwork,
|
||||||
|
"pae": PAENetwork,
|
||||||
|
"binder": BinderNetwork
|
||||||
|
}
|
458
rf2aa/model/layers/Embeddings.py
Normal file
458
rf2aa/model/layers/Embeddings.py
Normal file
|
@ -0,0 +1,458 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from opt_einsum import contract as einsum
|
||||||
|
import torch.utils.checkpoint as checkpoint
|
||||||
|
from rf2aa.util import *
|
||||||
|
from rf2aa.util_module import Dropout, get_clones, create_custom_forward, rbf, init_lecun_normal, get_res_atom_dist
|
||||||
|
from rf2aa.model.layers.Attention_module import Attention, TriangleMultiplication, TriangleAttention, FeedForwardLayer
|
||||||
|
from rf2aa.model.Track_module import PairStr2Pair, PositionalEncoding2D
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
# Module contains classes and functions to generate initial embeddings
|
||||||
|
|
||||||
|
class MSA_emb(nn.Module):
|
||||||
|
# Get initial seed MSA embedding
|
||||||
|
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0,
|
||||||
|
minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False, enable_same_chain=False):
|
||||||
|
if (d_init==0):
|
||||||
|
d_init = 2*ChemData().NAATOKENS+2+2
|
||||||
|
|
||||||
|
super(MSA_emb, self).__init__()
|
||||||
|
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
|
||||||
|
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence -- used for MSA embedding
|
||||||
|
self.emb_left = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
|
||||||
|
self.emb_right = nn.Embedding(ChemData().NAATOKENS, d_pair) # embedding for query sequence -- used for pair embedding
|
||||||
|
self.emb_state = nn.Embedding(ChemData().NAATOKENS, d_state)
|
||||||
|
self.pos = PositionalEncoding2D(d_pair, minpos=minpos, maxpos=maxpos,
|
||||||
|
maxpos_atom=maxpos_atom, p_drop=p_drop, use_same_chain=use_same_chain,
|
||||||
|
enable_same_chain=enable_same_chain)
|
||||||
|
self.enable_same_chain = enable_same_chain
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.emb = init_lecun_normal(self.emb)
|
||||||
|
self.emb_q = init_lecun_normal(self.emb_q)
|
||||||
|
self.emb_left = init_lecun_normal(self.emb_left)
|
||||||
|
self.emb_right = init_lecun_normal(self.emb_right)
|
||||||
|
self.emb_state = init_lecun_normal(self.emb_state)
|
||||||
|
|
||||||
|
nn.init.zeros_(self.emb.bias)
|
||||||
|
|
||||||
|
|
||||||
|
def _msa_emb(self, msa, seq):
|
||||||
|
N = msa.shape[1]
|
||||||
|
msa = self.emb(msa) # (B, N, L, d_pair) # MSA embedding
|
||||||
|
tmp = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_pair) -- query embedding
|
||||||
|
msa = msa + tmp.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||||
|
|
||||||
|
return msa
|
||||||
|
|
||||||
|
def _pair_emb(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
|
||||||
|
left = self.emb_left(seq)[:,None] # (B, 1, L, d_pair)
|
||||||
|
right = self.emb_right(seq)[:,:,None] # (B, L, 1, d_pair)
|
||||||
|
pair = left + right # (B, L, L, d_pair)
|
||||||
|
pair = pair + self.pos(seq, idx, bond_feats, dist_matrix, same_chain=same_chain) # add relative position
|
||||||
|
|
||||||
|
return pair
|
||||||
|
|
||||||
|
def _state_emb(self, seq):
|
||||||
|
return self.emb_state(seq)
|
||||||
|
|
||||||
|
def forward(self, msa, seq, idx, bond_feats, dist_matrix, same_chain=None):
|
||||||
|
# Inputs:
|
||||||
|
# - msa: Input MSA (B, N, L, d_init)
|
||||||
|
# - seq: Input Sequence (B, L)
|
||||||
|
# - idx: Residue index
|
||||||
|
# - bond_feats: Bond features (B, L, L)
|
||||||
|
# Outputs:
|
||||||
|
# - msa: Initial MSA embedding (B, N, L, d_msa)
|
||||||
|
# - pair: Initial Pair embedding (B, L, L, d_pair)
|
||||||
|
|
||||||
|
if self.enable_same_chain == False:
|
||||||
|
same_chain = None
|
||||||
|
|
||||||
|
msa = self._msa_emb(msa, seq)
|
||||||
|
|
||||||
|
# pair embedding
|
||||||
|
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix, same_chain=same_chain)
|
||||||
|
# state embedding
|
||||||
|
state = self._state_emb(seq)
|
||||||
|
return msa, pair, state
|
||||||
|
|
||||||
|
class MSA_emb_nostate(MSA_emb):
|
||||||
|
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_init=0, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.1, use_same_chain=False):
|
||||||
|
super().__init__(d_msa, d_pair, d_state, d_init, minpos, maxpos, maxpos_atom, p_drop, use_same_chain)
|
||||||
|
if d_init==0:
|
||||||
|
d_init = 2*ChemData().NAATOKENS + 2 + 2
|
||||||
|
self.emb_state = None # emb state is just the identity
|
||||||
|
|
||||||
|
def forward(self, msa, seq, idx, bond_feats, dist_matrix):
|
||||||
|
msa = self._msa_emb(msa, seq)
|
||||||
|
pair = self._pair_emb(seq, idx, bond_feats, dist_matrix)
|
||||||
|
return msa, pair, None
|
||||||
|
|
||||||
|
class Extra_emb(nn.Module):
|
||||||
|
# Get initial seed MSA embedding
|
||||||
|
def __init__(self, d_msa=256, d_init=0, p_drop=0.1):
|
||||||
|
super(Extra_emb, self).__init__()
|
||||||
|
if d_init==0:
|
||||||
|
d_init=ChemData().NAATOKENS-1+4
|
||||||
|
self.emb = nn.Linear(d_init, d_msa) # embedding for general MSA
|
||||||
|
self.emb_q = nn.Embedding(ChemData().NAATOKENS, d_msa) # embedding for query sequence
|
||||||
|
#self.drop = nn.Dropout(p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.emb = init_lecun_normal(self.emb)
|
||||||
|
nn.init.zeros_(self.emb.bias)
|
||||||
|
|
||||||
|
def forward(self, msa, seq, idx):
|
||||||
|
# Inputs:
|
||||||
|
# - msa: Input MSA (B, N, L, d_init)
|
||||||
|
# - seq: Input Sequence (B, L)
|
||||||
|
# - idx: Residue index
|
||||||
|
# Outputs:
|
||||||
|
# - msa: Initial MSA embedding (B, N, L, d_msa)
|
||||||
|
N = msa.shape[1] # number of sequenes in MSA
|
||||||
|
msa = self.emb(msa) # (B, N, L, d_model) # MSA embedding
|
||||||
|
seq = self.emb_q(seq).unsqueeze(1) # (B, 1, L, d_model) -- query embedding
|
||||||
|
msa = msa + seq.expand(-1, N, -1, -1) # adding query embedding to MSA
|
||||||
|
#return self.drop(msa)
|
||||||
|
return (msa)
|
||||||
|
|
||||||
|
class Bond_emb(nn.Module):
|
||||||
|
def __init__(self, d_pair=128, d_init=0):
|
||||||
|
super(Bond_emb, self).__init__()
|
||||||
|
|
||||||
|
if d_init==0:
|
||||||
|
d_init = ChemData().NBTYPES
|
||||||
|
|
||||||
|
self.emb = nn.Linear(d_init, d_pair)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.emb = init_lecun_normal(self.emb)
|
||||||
|
nn.init.zeros_(self.emb.bias)
|
||||||
|
|
||||||
|
def forward(self, bond_feats):
|
||||||
|
bond_feats = torch.nn.functional.one_hot(bond_feats, num_classes=ChemData().NBTYPES)
|
||||||
|
return self.emb(bond_feats.float())
|
||||||
|
|
||||||
|
class TemplatePairStack(nn.Module):
|
||||||
|
def __init__(self, n_block=2, d_templ=64, n_head=4, d_hidden=32, d_t1d=22, d_state=32, p_drop=0.25,
|
||||||
|
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method=None):
|
||||||
|
|
||||||
|
super(TemplatePairStack, self).__init__()
|
||||||
|
self.n_block = n_block
|
||||||
|
self.proj_t1d = nn.Linear(d_t1d, d_state)
|
||||||
|
|
||||||
|
proc_s = [PairStr2Pair(d_pair=d_templ,
|
||||||
|
n_head=n_head,
|
||||||
|
d_hidden=d_hidden,
|
||||||
|
d_state=d_state,
|
||||||
|
p_drop=p_drop,
|
||||||
|
symmetrize_repeats=symmetrize_repeats,
|
||||||
|
repeat_length=repeat_length,
|
||||||
|
symmsub_k=symmsub_k,
|
||||||
|
sym_method=sym_method) for i in range(n_block)]
|
||||||
|
|
||||||
|
self.block = nn.ModuleList(proc_s)
|
||||||
|
self.norm = nn.LayerNorm(d_templ)
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.proj_t1d = init_lecun_normal(self.proj_t1d)
|
||||||
|
nn.init.zeros_(self.proj_t1d.bias)
|
||||||
|
|
||||||
|
def forward(self, templ, rbf_feat, t1d, use_checkpoint=False, p2p_crop=-1):
|
||||||
|
B, T, L = templ.shape[:3]
|
||||||
|
templ = templ.reshape(B*T, L, L, -1)
|
||||||
|
t1d = t1d.reshape(B*T, L, -1)
|
||||||
|
state = self.proj_t1d(t1d)
|
||||||
|
|
||||||
|
for i_block in range(self.n_block):
|
||||||
|
if use_checkpoint:
|
||||||
|
templ = checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.block[i_block]),
|
||||||
|
templ, rbf_feat, state, p2p_crop,
|
||||||
|
use_reentrant=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
templ = self.block[i_block](templ, rbf_feat, state)
|
||||||
|
return self.norm(templ).reshape(B, T, L, L, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def copy_main_2d(pair, Leff, idx):
|
||||||
|
"""
|
||||||
|
Copies the "main unit" of a block in generic 2D representation of shape (...,L,L,h)
|
||||||
|
along the main diagonal
|
||||||
|
"""
|
||||||
|
start = idx*Leff
|
||||||
|
end = (idx+1)*Leff
|
||||||
|
|
||||||
|
# grab the main block
|
||||||
|
main = torch.clone( pair[..., start:end, start:end, :] )
|
||||||
|
|
||||||
|
# copy it around the main diag
|
||||||
|
L = pair.shape[-2]
|
||||||
|
assert L%Leff == 0
|
||||||
|
N = L//Leff
|
||||||
|
|
||||||
|
for i_block in range(N):
|
||||||
|
start = i_block*Leff
|
||||||
|
stop = (i_block+1)*Leff
|
||||||
|
|
||||||
|
pair[...,start:stop, start:stop, :] = main
|
||||||
|
|
||||||
|
return pair
|
||||||
|
|
||||||
|
|
||||||
|
def copy_main_1d(single, Leff, idx):
|
||||||
|
"""
|
||||||
|
Copies the "main unit" of a block in generic 1D representation of shape (...,L,h)
|
||||||
|
to all other (non-main) blocks
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
single (torch.tensor, required): Shape [...,L,h] "1D" tensor
|
||||||
|
"""
|
||||||
|
main_start = idx*Leff
|
||||||
|
main_end = (idx+1)*Leff
|
||||||
|
|
||||||
|
# grab main block
|
||||||
|
main = torch.clone(single[..., main_start:main_end, :])
|
||||||
|
|
||||||
|
# copy it around
|
||||||
|
L = single.shape[-2]
|
||||||
|
assert L%Leff == 0
|
||||||
|
N = L//Leff
|
||||||
|
|
||||||
|
for i_block in range(N):
|
||||||
|
start = i_block*Leff
|
||||||
|
end = (i_block+1)*Leff
|
||||||
|
|
||||||
|
single[..., start:end, :] = main
|
||||||
|
|
||||||
|
return single
|
||||||
|
|
||||||
|
|
||||||
|
class Templ_emb(nn.Module):
|
||||||
|
# Get template embedding
|
||||||
|
# Features are
|
||||||
|
# t2d:
|
||||||
|
# - 61 distogram bins + 6 orientations (67)
|
||||||
|
# - Mask (missing/unaligned) (1)
|
||||||
|
# t1d:
|
||||||
|
# - tiled AA sequence (20 standard aa + gap)
|
||||||
|
# - confidence (1)
|
||||||
|
#
|
||||||
|
def __init__(self, d_t1d=0, d_t2d=67+1, d_tor=0, d_pair=128, d_state=32,
|
||||||
|
n_block=2, d_templ=64,
|
||||||
|
n_head=4, d_hidden=16, p_drop=0.25,
|
||||||
|
symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method='mean',
|
||||||
|
main_block=None, copy_main_block=None, additional_dt1d=0):
|
||||||
|
if d_t1d==0:
|
||||||
|
d_t1d=(ChemData().NAATOKENS-1)+1
|
||||||
|
if d_tor==0:
|
||||||
|
d_tor=3*ChemData().NTOTALDOFS
|
||||||
|
|
||||||
|
self.main_block = main_block
|
||||||
|
self.symmetrize_repeats = symmetrize_repeats
|
||||||
|
self.copy_main_block = copy_main_block
|
||||||
|
self.repeat_length = repeat_length
|
||||||
|
d_t1d += additional_dt1d
|
||||||
|
|
||||||
|
super(Templ_emb, self).__init__()
|
||||||
|
# process 2D features
|
||||||
|
self.emb = nn.Linear(d_t1d*2+d_t2d, d_templ)
|
||||||
|
|
||||||
|
self.templ_stack = TemplatePairStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||||
|
d_hidden=d_hidden, d_t1d=d_t1d, d_state=d_state, p_drop=p_drop,
|
||||||
|
symmetrize_repeats=symmetrize_repeats, repeat_length=repeat_length,
|
||||||
|
symmsub_k=symmsub_k, sym_method=sym_method)
|
||||||
|
|
||||||
|
self.attn = Attention(d_pair, d_templ, n_head, d_hidden, d_pair, p_drop=p_drop)
|
||||||
|
|
||||||
|
# process torsion angles
|
||||||
|
self.emb_t1d = nn.Linear(d_t1d+d_tor, d_templ)
|
||||||
|
self.proj_t1d = nn.Linear(d_templ, d_templ)
|
||||||
|
#self.tor_stack = TemplateTorsionStack(n_block=n_block, d_templ=d_templ, n_head=n_head,
|
||||||
|
# d_hidden=d_hidden, p_drop=p_drop)
|
||||||
|
self.attn_tor = Attention(d_state, d_templ, n_head, d_hidden, d_state, p_drop=p_drop)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.emb = init_lecun_normal(self.emb)
|
||||||
|
nn.init.zeros_(self.emb.bias)
|
||||||
|
|
||||||
|
nn.init.kaiming_normal_(self.emb_t1d.weight, nonlinearity='relu')
|
||||||
|
nn.init.zeros_(self.emb_t1d.bias)
|
||||||
|
|
||||||
|
self.proj_t1d = init_lecun_normal(self.proj_t1d)
|
||||||
|
nn.init.zeros_(self.proj_t1d.bias)
|
||||||
|
|
||||||
|
def _get_templ_emb(self, t1d, t2d):
|
||||||
|
B, T, L, _ = t1d.shape
|
||||||
|
# Prepare 2D template features
|
||||||
|
left = t1d.unsqueeze(3).expand(-1,-1,-1,L,-1)
|
||||||
|
right = t1d.unsqueeze(2).expand(-1,-1,L,-1,-1)
|
||||||
|
#
|
||||||
|
templ = torch.cat((t2d, left, right), -1) # (B, T, L, L, 88)
|
||||||
|
return self.emb(templ) # Template templures (B, T, L, L, d_templ)
|
||||||
|
|
||||||
|
def _get_templ_rbf(self, xyz_t, mask_t):
|
||||||
|
B, T, L = xyz_t.shape[:3]
|
||||||
|
|
||||||
|
# process each template features
|
||||||
|
xyz_t = xyz_t.reshape(B*T, L, 3).contiguous()
|
||||||
|
mask_t = mask_t.reshape(B*T, L, L)
|
||||||
|
assert(xyz_t.is_contiguous())
|
||||||
|
rbf_feat = rbf(torch.cdist(xyz_t, xyz_t)) * mask_t[...,None] # (B*T, L, L, d_rbf)
|
||||||
|
return rbf_feat
|
||||||
|
|
||||||
|
def forward(self, t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=False, p2p_crop=-1):
|
||||||
|
# Input
|
||||||
|
# - t1d: 1D template info (B, T, L, 30)
|
||||||
|
# - t2d: 2D template info (B, T, L, L, 44)
|
||||||
|
# - alpha_t: torsion angle info (B, T, L, 30) - DOUBLE-CHECK
|
||||||
|
# - xyz_t: template CA coordinates (B, T, L, 3)
|
||||||
|
# - mask_t: is valid residue pair? (B, T, L, L)
|
||||||
|
# - pair: query pair features (B, L, L, d_pair)
|
||||||
|
# - state: query state features (B, L, d_state)
|
||||||
|
B, T, L, _ = t1d.shape
|
||||||
|
|
||||||
|
templ = self._get_templ_emb(t1d, t2d)
|
||||||
|
# this looks a lot like a bug but it is not
|
||||||
|
# mask_t has already been updated by same_chain in the train_EMA script so pairwise distances between
|
||||||
|
# protein chains are ignored
|
||||||
|
rbf_feat = self._get_templ_rbf(xyz_t, mask_t)
|
||||||
|
|
||||||
|
# process each template pair feature
|
||||||
|
templ = self.templ_stack(templ, rbf_feat, t1d, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop) # (B, T, L,L, d_templ)
|
||||||
|
|
||||||
|
# DJ - repeat protein symmetrization (2D)
|
||||||
|
if self.copy_main_block:
|
||||||
|
assert not (self.main_block is None)
|
||||||
|
assert self.symmetrize_repeats
|
||||||
|
# copy the main repeat unit internally down the pair representation diagonal
|
||||||
|
templ = copy_main_2d(templ, self.repeat_length, self.main_block)
|
||||||
|
|
||||||
|
# Prepare 1D template torsion angle features
|
||||||
|
t1d = torch.cat((t1d, alpha_t), dim=-1) # (B, T, L, 30+3*17)
|
||||||
|
# process each template features
|
||||||
|
t1d = self.proj_t1d(F.relu_(self.emb_t1d(t1d)))
|
||||||
|
|
||||||
|
# DJ - repeat protein symmetrization (1D)
|
||||||
|
if self.copy_main_block:
|
||||||
|
# already made assertions above
|
||||||
|
# copy main unit down single rep
|
||||||
|
t1d = copy_main_1d(t1d, self.repeat_length, self.main_block)
|
||||||
|
|
||||||
|
# mixing query state features to template state features
|
||||||
|
state = state.reshape(B*L, 1, -1)
|
||||||
|
t1d = t1d.permute(0,2,1,3).reshape(B*L, T, -1)
|
||||||
|
if use_checkpoint:
|
||||||
|
out = checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.attn_tor), state, t1d, t1d, use_reentrant=True
|
||||||
|
)
|
||||||
|
out = out.reshape(B, L, -1)
|
||||||
|
else:
|
||||||
|
out = self.attn_tor(state, t1d, t1d).reshape(B, L, -1)
|
||||||
|
state = state.reshape(B, L, -1)
|
||||||
|
state = state + out
|
||||||
|
|
||||||
|
# mixing query pair features to template information (Template pointwise attention)
|
||||||
|
pair = pair.reshape(B*L*L, 1, -1)
|
||||||
|
templ = templ.permute(0, 2, 3, 1, 4).reshape(B*L*L, T, -1)
|
||||||
|
if use_checkpoint:
|
||||||
|
out = checkpoint.checkpoint(
|
||||||
|
create_custom_forward(self.attn), pair, templ, templ, use_reentrant=True
|
||||||
|
)
|
||||||
|
out = out.reshape(B, L, L, -1)
|
||||||
|
else:
|
||||||
|
out = self.attn(pair, templ, templ).reshape(B, L, L, -1)
|
||||||
|
#
|
||||||
|
pair = pair.reshape(B, L, L, -1)
|
||||||
|
pair = pair + out
|
||||||
|
|
||||||
|
return pair, state
|
||||||
|
|
||||||
|
|
||||||
|
class Recycling(nn.Module):
|
||||||
|
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
|
||||||
|
super(Recycling, self).__init__()
|
||||||
|
self.proj_dist = nn.Linear(d_rbf, d_pair)
|
||||||
|
self.norm_pair = nn.LayerNorm(d_pair)
|
||||||
|
self.norm_msa = nn.LayerNorm(d_msa)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
#self.emb_rbf = init_lecun_normal(self.emb_rbf)
|
||||||
|
#nn.init.zeros_(self.emb_rbf.bias)
|
||||||
|
self.proj_dist = init_lecun_normal(self.proj_dist)
|
||||||
|
nn.init.zeros_(self.proj_dist.bias)
|
||||||
|
|
||||||
|
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||||
|
B, L = msa.shape[:2]
|
||||||
|
msa = self.norm_msa(msa)
|
||||||
|
pair = self.norm_pair(pair)
|
||||||
|
|
||||||
|
Ca = xyz[:,:,1]
|
||||||
|
dist_CA = rbf(
|
||||||
|
torch.cdist(Ca, Ca)
|
||||||
|
).reshape(B,L,L,-1)
|
||||||
|
|
||||||
|
if mask_recycle != None:
|
||||||
|
dist_CA = mask_recycle[...,None].float()*dist_CA
|
||||||
|
|
||||||
|
pair = pair + self.proj_dist(dist_CA)
|
||||||
|
|
||||||
|
return msa, pair, state # state is just zeros
|
||||||
|
|
||||||
|
class RecyclingAllFeatures(nn.Module):
|
||||||
|
def __init__(self, d_msa=256, d_pair=128, d_state=32, d_rbf=64):
|
||||||
|
super(RecyclingAllFeatures, self).__init__()
|
||||||
|
self.proj_dist = nn.Linear(d_rbf+d_state*2, d_pair)
|
||||||
|
self.norm_pair = nn.LayerNorm(d_pair)
|
||||||
|
self.proj_sctors = nn.Linear(2*ChemData().NTOTALDOFS, d_msa)
|
||||||
|
self.norm_msa = nn.LayerNorm(d_msa)
|
||||||
|
self.norm_state = nn.LayerNorm(d_state)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
self.proj_dist = init_lecun_normal(self.proj_dist)
|
||||||
|
nn.init.zeros_(self.proj_dist.bias)
|
||||||
|
self.proj_sctors = init_lecun_normal(self.proj_sctors)
|
||||||
|
nn.init.zeros_(self.proj_sctors.bias)
|
||||||
|
|
||||||
|
def forward(self, msa, pair, xyz, state, sctors, mask_recycle=None):
|
||||||
|
B, L = pair.shape[:2]
|
||||||
|
state = self.norm_state(state)
|
||||||
|
|
||||||
|
left = state.unsqueeze(2).expand(-1,-1,L,-1)
|
||||||
|
right = state.unsqueeze(1).expand(-1,L,-1,-1)
|
||||||
|
|
||||||
|
Ca_or_P = xyz[:,:,1].contiguous()
|
||||||
|
|
||||||
|
dist = rbf(torch.cdist(Ca_or_P, Ca_or_P))
|
||||||
|
if mask_recycle != None:
|
||||||
|
dist = mask_recycle[...,None].float()*dist
|
||||||
|
dist = torch.cat((dist, left, right), dim=-1)
|
||||||
|
dist = self.proj_dist(dist)
|
||||||
|
pair = dist + self.norm_pair(pair)
|
||||||
|
|
||||||
|
sctors = self.proj_sctors(sctors.reshape(B,-1,2*ChemData().NTOTALDOFS))
|
||||||
|
msa = sctors + self.norm_msa(msa)
|
||||||
|
|
||||||
|
return msa, pair, state
|
||||||
|
|
||||||
|
recycling_factory = {
|
||||||
|
"msa_pair": Recycling,
|
||||||
|
"all": RecyclingAllFeatures
|
||||||
|
}
|
100
rf2aa/model/layers/SE3_network.py
Normal file
100
rf2aa/model/layers/SE3_network.py
Normal file
|
@ -0,0 +1,100 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from icecream import ic
|
||||||
|
import inspect
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
#script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||||
|
#sys.path.insert(0,script_dir+'SE3Transformer')
|
||||||
|
|
||||||
|
from rf2aa.util import xyz_frame_from_rotation_mask
|
||||||
|
from rf2aa.util_module import init_lecun_normal_param, \
|
||||||
|
make_full_graph, rbf, init_lecun_normal
|
||||||
|
from rf2aa.loss.loss import calc_chiral_grads
|
||||||
|
from rf2aa.model.layers.Attention_module import FeedForwardLayer
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model import SE3Transformer
|
||||||
|
from rf2aa.SE3Transformer.se3_transformer.model.fiber import Fiber
|
||||||
|
from rf2aa.util_module import get_seqsep_protein_sm
|
||||||
|
|
||||||
|
se3_transformer_path = inspect.getfile(SE3Transformer)
|
||||||
|
se3_fiber_path = inspect.getfile(Fiber)
|
||||||
|
assert 'rf2aa' in se3_transformer_path
|
||||||
|
|
||||||
|
class SE3TransformerWrapper(nn.Module):
|
||||||
|
"""SE(3) equivariant GCN with attention"""
|
||||||
|
def __init__(self, num_layers=2, num_channels=32, num_degrees=3, n_heads=4, div=4,
|
||||||
|
l0_in_features=32, l0_out_features=32,
|
||||||
|
l1_in_features=3, l1_out_features=2,
|
||||||
|
num_edge_features=32):
|
||||||
|
super().__init__()
|
||||||
|
# Build the network
|
||||||
|
self.l1_in = l1_in_features
|
||||||
|
self.l1_out = l1_out_features
|
||||||
|
#
|
||||||
|
fiber_edge = Fiber({0: num_edge_features})
|
||||||
|
if l1_out_features > 0:
|
||||||
|
if l1_in_features > 0:
|
||||||
|
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
|
||||||
|
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||||
|
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
|
||||||
|
else:
|
||||||
|
fiber_in = Fiber({0: l0_in_features})
|
||||||
|
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||||
|
fiber_out = Fiber({0: l0_out_features, 1: l1_out_features})
|
||||||
|
else:
|
||||||
|
if l1_in_features > 0:
|
||||||
|
fiber_in = Fiber({0: l0_in_features, 1: l1_in_features})
|
||||||
|
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||||
|
fiber_out = Fiber({0: l0_out_features})
|
||||||
|
else:
|
||||||
|
fiber_in = Fiber({0: l0_in_features})
|
||||||
|
fiber_hidden = Fiber.create(num_degrees, num_channels)
|
||||||
|
fiber_out = Fiber({0: l0_out_features})
|
||||||
|
|
||||||
|
self.se3 = SE3Transformer(num_layers=num_layers,
|
||||||
|
fiber_in=fiber_in,
|
||||||
|
fiber_hidden=fiber_hidden,
|
||||||
|
fiber_out = fiber_out,
|
||||||
|
num_heads=n_heads,
|
||||||
|
channels_div=div,
|
||||||
|
fiber_edge=fiber_edge,
|
||||||
|
populate_edge="arcsin",
|
||||||
|
final_layer="lin",
|
||||||
|
use_layer_norm=True)
|
||||||
|
|
||||||
|
self.reset_parameter()
|
||||||
|
|
||||||
|
def reset_parameter(self):
|
||||||
|
|
||||||
|
# make sure linear layer before ReLu are initialized with kaiming_normal_
|
||||||
|
for n, p in self.se3.named_parameters():
|
||||||
|
if "bias" in n:
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
elif len(p.shape) == 1:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
if "radial_func" not in n:
|
||||||
|
p = init_lecun_normal_param(p)
|
||||||
|
else:
|
||||||
|
if "net.6" in n:
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
else:
|
||||||
|
nn.init.kaiming_normal_(p, nonlinearity='relu')
|
||||||
|
|
||||||
|
# make last layers to be zero-initialized
|
||||||
|
#self.se3.graph_modules[-1].to_kernel_self['0'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||||
|
#self.se3.graph_modules[-1].to_kernel_self['1'] = init_lecun_normal_param(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||||
|
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['0'])
|
||||||
|
#nn.init.zeros_(self.se3.graph_modules[-1].to_kernel_self['1'])
|
||||||
|
nn.init.zeros_(self.se3.graph_modules[-1].weights['0'])
|
||||||
|
if self.l1_out > 0:
|
||||||
|
nn.init.zeros_(self.se3.graph_modules[-1].weights['1'])
|
||||||
|
|
||||||
|
def forward(self, G, type_0_features, type_1_features=None, edge_features=None):
|
||||||
|
if self.l1_in > 0:
|
||||||
|
node_features = {'0': type_0_features, '1': type_1_features}
|
||||||
|
else:
|
||||||
|
node_features = {'0': type_0_features}
|
||||||
|
edge_features = {'0': edge_features}
|
||||||
|
return self.se3(G, node_features, edge_features)
|
||||||
|
|
208
rf2aa/run_inference.py
Normal file
208
rf2aa/run_inference.py
Normal file
|
@ -0,0 +1,208 @@
|
||||||
|
import os
|
||||||
|
import hydra
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
from rf2aa.data.merge_inputs import merge_all
|
||||||
|
from rf2aa.data.covale import load_covalent_molecules
|
||||||
|
from rf2aa.data.nucleic_acid import load_nucleic_acid
|
||||||
|
from rf2aa.data.protein import generate_msa_and_load_protein
|
||||||
|
from rf2aa.data.small_molecule import load_small_molecule
|
||||||
|
from rf2aa.ffindex import *
|
||||||
|
from rf2aa.chemical import initialize_chemdata, load_pdb_ideal_sdf_strings
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||||||
|
from rf2aa.training.recycling import recycle_step_legacy
|
||||||
|
from rf2aa.util import writepdb, is_atom, Ls_from_same_chain_2d
|
||||||
|
from rf2aa.util_module import XYZConverter
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRunner:
|
||||||
|
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
self.config = config
|
||||||
|
initialize_chemdata(self.config.chem_params)
|
||||||
|
FFindexDB = namedtuple("FFindexDB", "index, data")
|
||||||
|
self.ffdb = FFindexDB(read_index(config.database_params.hhdb+'_pdb.ffindex'),
|
||||||
|
read_data(config.database_params.hhdb+'_pdb.ffdata'))
|
||||||
|
self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
self.xyz_converter = XYZConverter()
|
||||||
|
self.deterministic = config.get("deterministic", False)
|
||||||
|
self.molecule_db = load_pdb_ideal_sdf_strings()
|
||||||
|
|
||||||
|
def parse_inference_config(self):
|
||||||
|
residues_to_atomize = [] # chain letter, residue number, residue name
|
||||||
|
chains = []
|
||||||
|
protein_inputs = {}
|
||||||
|
if self.config.protein_inputs is not None:
|
||||||
|
for chain in self.config.protein_inputs:
|
||||||
|
if chain in chains:
|
||||||
|
raise ValueError(f"Duplicate chain found with name: {chain}. Please specify unique chain names")
|
||||||
|
elif len(chain) > 1:
|
||||||
|
raise ValueError(f"Chain name must be a single character, found chain with name: {chain}")
|
||||||
|
else:
|
||||||
|
chains.append(chain)
|
||||||
|
protein_input = generate_msa_and_load_protein(
|
||||||
|
self.config.protein_inputs[chain]["fasta_file"],
|
||||||
|
self
|
||||||
|
)
|
||||||
|
protein_inputs[chain] = protein_input
|
||||||
|
|
||||||
|
na_inputs = {}
|
||||||
|
if self.config.na_inputs is not None:
|
||||||
|
for chain in self.config.na_inputs:
|
||||||
|
na_input = load_nucleic_acid(
|
||||||
|
self.config.na_inputs[chain]["fasta"],
|
||||||
|
self.config.na_inputs[chain]["input_type"],
|
||||||
|
self
|
||||||
|
)
|
||||||
|
na_inputs[chain] = na_input
|
||||||
|
|
||||||
|
sm_inputs = {}
|
||||||
|
# first if any of the small molecules are covalently bonded to the protein
|
||||||
|
# merge the small molecule with the residue and add it as a separate ligand
|
||||||
|
# also add it to residues_to_atomize for bookkeeping later on
|
||||||
|
# need to handle atomizing multiple consecutive residues here too
|
||||||
|
if self.config.covale_inputs is not None:
|
||||||
|
covalent_sm_inputs, residues_to_atomize_covale = load_covalent_molecules(protein_inputs, self.config, self)
|
||||||
|
sm_inputs.update(covalent_sm_inputs)
|
||||||
|
residues_to_atomize.extend(residues_to_atomize_covale)
|
||||||
|
|
||||||
|
if self.config.sm_inputs is not None:
|
||||||
|
for chain in self.config.sm_inputs:
|
||||||
|
if self.config.sm_inputs[chain]["input_type"] not in ["smiles", "sdf"]:
|
||||||
|
raise ValueError("Small molecule input type must be smiles or sdf")
|
||||||
|
if chain in sm_inputs: # chain already processed as covale
|
||||||
|
continue
|
||||||
|
if "is_leaving" in self.config.sm_inputs[chain]:
|
||||||
|
raise ValueError("Leaving atoms are not supported for non-covalently bonded molecules")
|
||||||
|
sm_input = load_small_molecule(
|
||||||
|
self.config.sm_inputs[chain]["input"],
|
||||||
|
self.config.sm_inputs[chain]["input_type"],
|
||||||
|
self
|
||||||
|
)
|
||||||
|
sm_inputs[chain] = sm_input
|
||||||
|
|
||||||
|
if self.config.residue_replacement is not None:
|
||||||
|
# add to the sm_inputs list
|
||||||
|
# add to residues to atomize
|
||||||
|
raise NotImplementedError("Modres inference is not implemented")
|
||||||
|
|
||||||
|
raw_data = merge_all(protein_inputs, na_inputs, sm_inputs, residues_to_atomize, deterministic=self.deterministic)
|
||||||
|
self.raw_data = raw_data
|
||||||
|
|
||||||
|
def load_model(self):
|
||||||
|
self.model = RoseTTAFoldModule(
|
||||||
|
**self.config.legacy_model_param,
|
||||||
|
aamask = ChemData().allatom_mask.to(self.device),
|
||||||
|
atom_type_index = ChemData().atom_type_index.to(self.device),
|
||||||
|
ljlk_parameters = ChemData().ljlk_parameters.to(self.device),
|
||||||
|
lj_correction_parameters = ChemData().lj_correction_parameters.to(self.device),
|
||||||
|
num_bonds = ChemData().num_bonds.to(self.device),
|
||||||
|
cb_len = ChemData().cb_length_t.to(self.device),
|
||||||
|
cb_ang = ChemData().cb_angle_t.to(self.device),
|
||||||
|
cb_tor = ChemData().cb_torsion_t.to(self.device),
|
||||||
|
|
||||||
|
).to(self.device)
|
||||||
|
checkpoint = torch.load(self.config.checkpoint_path, map_location=self.device)
|
||||||
|
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
|
|
||||||
|
def construct_features(self):
|
||||||
|
return self.raw_data.construct_features(self)
|
||||||
|
|
||||||
|
def run_model_forward(self, input_feats):
|
||||||
|
input_feats.add_batch_dim()
|
||||||
|
input_feats.to(self.device)
|
||||||
|
input_dict = asdict(input_feats)
|
||||||
|
input_dict["bond_feats"] = input_dict["bond_feats"].long()
|
||||||
|
input_dict["seq_unmasked"] = input_dict["seq_unmasked"].long()
|
||||||
|
outputs = recycle_step_legacy(self.model,
|
||||||
|
input_dict,
|
||||||
|
self.config.loader_params.MAXCYCLE,
|
||||||
|
use_amp=False,
|
||||||
|
nograds=True,
|
||||||
|
force_device=self.device)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def write_outputs(self, input_feats, outputs):
|
||||||
|
logits, logits_aa, logits_pae, logits_pde, p_bind, \
|
||||||
|
xyz, alpha_s, xyz_allatom, lddt, _, _, _ \
|
||||||
|
= outputs
|
||||||
|
seq_unmasked = input_feats.seq_unmasked
|
||||||
|
bond_feats = input_feats.bond_feats
|
||||||
|
err_dict = self.calc_pred_err(lddt, logits_pae, logits_pde, seq_unmasked)
|
||||||
|
err_dict["same_chain"] = input_feats.same_chain
|
||||||
|
plddts = err_dict["plddts"]
|
||||||
|
Ls = Ls_from_same_chain_2d(input_feats.same_chain)
|
||||||
|
plddts = plddts[0]
|
||||||
|
writepdb(os.path.join(f"{self.config.output_path}", f"{self.config.job_name}.pdb"),
|
||||||
|
xyz_allatom,
|
||||||
|
seq_unmasked,
|
||||||
|
bond_feats=bond_feats,
|
||||||
|
bfacts=plddts,
|
||||||
|
chain_Ls=Ls
|
||||||
|
)
|
||||||
|
torch.save(err_dict, os.path.join(f"{self.config.output_path}",
|
||||||
|
f"{self.config.job_name}_aux.pt"))
|
||||||
|
|
||||||
|
def infer(self):
|
||||||
|
self.load_model()
|
||||||
|
self.parse_inference_config()
|
||||||
|
input_feats = self.construct_features()
|
||||||
|
outputs = self.run_model_forward(input_feats)
|
||||||
|
self.write_outputs(input_feats, outputs)
|
||||||
|
|
||||||
|
def lddt_unbin(self, pred_lddt):
|
||||||
|
# calculate lddt prediction loss
|
||||||
|
nbin = pred_lddt.shape[1]
|
||||||
|
bin_step = 1.0 / nbin
|
||||||
|
lddt_bins = torch.linspace(bin_step, 1.0, nbin, dtype=pred_lddt.dtype, device=pred_lddt.device)
|
||||||
|
|
||||||
|
pred_lddt = nn.Softmax(dim=1)(pred_lddt)
|
||||||
|
return torch.sum(lddt_bins[None,:,None]*pred_lddt, dim=1)
|
||||||
|
|
||||||
|
def pae_unbin(self, logits_pae, bin_step=0.5):
|
||||||
|
nbin = logits_pae.shape[1]
|
||||||
|
bins = torch.linspace(bin_step*0.5, bin_step*nbin-bin_step*0.5, nbin,
|
||||||
|
dtype=logits_pae.dtype, device=logits_pae.device)
|
||||||
|
logits_pae = torch.nn.Softmax(dim=1)(logits_pae)
|
||||||
|
return torch.sum(bins[None,:,None,None]*logits_pae, dim=1)
|
||||||
|
|
||||||
|
def pde_unbin(self, logits_pde, bin_step=0.3):
|
||||||
|
nbin = logits_pde.shape[1]
|
||||||
|
bins = torch.linspace(bin_step*0.5, bin_step*nbin-bin_step*0.5, nbin,
|
||||||
|
dtype=logits_pde.dtype, device=logits_pde.device)
|
||||||
|
logits_pde = torch.nn.Softmax(dim=1)(logits_pde)
|
||||||
|
return torch.sum(bins[None,:,None,None]*logits_pde, dim=1)
|
||||||
|
|
||||||
|
def calc_pred_err(self, pred_lddts, logit_pae, logit_pde, seq):
|
||||||
|
"""Calculates summary metrics on predicted lDDT and distance errors"""
|
||||||
|
plddts = self.lddt_unbin(pred_lddts)
|
||||||
|
pae = self.pae_unbin(logit_pae) if logit_pae is not None else None
|
||||||
|
pde = self.pde_unbin(logit_pde) if logit_pde is not None else None
|
||||||
|
sm_mask = is_atom(seq)[0]
|
||||||
|
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
|
||||||
|
prot_mask_2d = (~sm_mask[None,:])*(~sm_mask[:,None])
|
||||||
|
inter_mask_2d = sm_mask[None,:]*(~sm_mask[:,None]) + (~sm_mask[None,:])*sm_mask[:,None]
|
||||||
|
# assumes B=1
|
||||||
|
err_dict = dict(
|
||||||
|
plddts = plddts.cpu(),
|
||||||
|
pae = pae.cpu(),
|
||||||
|
pde = pde.cpu(),
|
||||||
|
mean_plddt = float(plddts.mean()),
|
||||||
|
mean_pae = float(pae.mean()) if pae is not None else None,
|
||||||
|
pae_prot = float(pae[0,prot_mask_2d].mean()) if pae is not None else None,
|
||||||
|
pae_inter = float(pae[0,inter_mask_2d].mean()) if pae is not None else None,
|
||||||
|
)
|
||||||
|
return err_dict
|
||||||
|
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path='config/inference')
|
||||||
|
def main(config):
|
||||||
|
runner = ModelRunner(config)
|
||||||
|
runner.infer()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
371
rf2aa/scoring.py
Normal file
371
rf2aa/scoring.py
Normal file
|
@ -0,0 +1,371 @@
|
||||||
|
import json, os
|
||||||
|
|
||||||
|
script_dir = os.path.dirname(os.path.realpath(__file__))+'/'
|
||||||
|
|
||||||
|
##
|
||||||
|
## lk and lk term
|
||||||
|
#(LJ_RADIUS LJ_WDEPTH LK_DGFREE LK_LAMBDA LK_VOLUME)
|
||||||
|
type2ljlk = {
|
||||||
|
"CNH2":(1.968297,0.094638,3.077030,3.5000,13.500000),
|
||||||
|
"COO":(1.916661,0.141799,-3.332648,3.5000,14.653000),
|
||||||
|
"CH0":(2.011760,0.062642,1.409284,3.5000,8.998000),
|
||||||
|
"CH1":(2.011760,0.062642,-3.538387,3.5000,10.686000),
|
||||||
|
"CH2":(2.011760,0.062642,-1.854658,3.5000,18.331000),
|
||||||
|
"CH3":(2.011760,0.062642,7.292929,3.5000,25.855000),
|
||||||
|
"aroC":(2.016441,0.068775,1.797950,3.5000,16.704000),
|
||||||
|
"Ntrp":(1.802452,0.161725,-8.413116,3.5000,9.522100),
|
||||||
|
"Nhis":(1.802452,0.161725,-9.739606,3.5000,9.317700),
|
||||||
|
"NtrR":(1.802452,0.161725,-5.158080,3.5000,9.779200),
|
||||||
|
"NH2O":(1.802452,0.161725,-8.101638,3.5000,15.689000),
|
||||||
|
"Nlys":(1.802452,0.161725,-20.864641,3.5000,16.514000),
|
||||||
|
"Narg":(1.802452,0.161725,-8.968351,3.5000,15.717000),
|
||||||
|
"Npro":(1.802452,0.161725,-0.984585,3.5000,3.718100),
|
||||||
|
"OH":(1.542743,0.161947,-8.133520,3.5000,10.722000),
|
||||||
|
"OHY":(1.542743,0.161947,-8.133520,3.5000,10.722000),
|
||||||
|
"ONH2":(1.548662,0.182924,-6.591644,3.5000,10.102000),
|
||||||
|
"OOC":(1.492871,0.099873,-9.239832,3.5000,9.995600),
|
||||||
|
"S":(1.975967,0.455970,-1.707229,3.5000,17.640000),
|
||||||
|
"SH1":(1.975967,0.455970,3.291643,3.5000,23.240000),
|
||||||
|
"Nbb":(1.802452,0.161725,-9.969494,3.5000,15.992000),
|
||||||
|
"CAbb":(2.011760,0.062642,2.533791,3.5000,12.137000),
|
||||||
|
"CObb":(1.916661,0.141799,3.104248,3.5000,13.221000),
|
||||||
|
"OCbb":(1.540580,0.142417,-8.006829,3.5000,12.196000),
|
||||||
|
"Phos":(2.1500,0.5850,-4.1000,3.5000,14.7000), # phil
|
||||||
|
"Oet2":(1.5500,0.1591,-5.8500,3.5000,10.8000),
|
||||||
|
"Oet3":(1.5500,0.1591,-6.7000,3.5000,10.8000),
|
||||||
|
"HNbb":(0.901681,0.005000,0.0000,3.5000,0.0000),
|
||||||
|
"Hapo":(1.421272,0.021808,0.0000,3.5000,0.0000),
|
||||||
|
"Haro":(1.374914,0.015909,0.0000,3.5000,0.0000),
|
||||||
|
"Hpol":(0.901681,0.005000,0.0000,3.5000,0.0000),
|
||||||
|
"HS":(0.363887,0.050836,0.0000,3.5000,0.0000),
|
||||||
|
"genAl":(1,0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genAs":(1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genAu":(1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genB": (1,0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genBe": (1,0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genBr": (2.1971, 0.1090, 2.7951, 3.5000, 19.6876),
|
||||||
|
"genC": (2.0067, 0.0689, 2.2256, 3.5000, 10.6860), # params from CT
|
||||||
|
"genCa": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genCl": (2.0496, 0.1070, 2.3668, 3.5000, 17.5849),
|
||||||
|
"genCo": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genCr": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genCu": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genF": (1.6941, 0.0750, 1.6442, 3.5000, 12.2163),
|
||||||
|
"genFe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genHg": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genI": (2.3600, 0.1110, 3.1361, 3.5000, 22.0891),
|
||||||
|
"genIr": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genK": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genLi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genMg": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genMn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genMo": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genN": (1.7854, 0.1497, -6.3760, 3.5000, 9.5221), # params from NG2
|
||||||
|
"genNi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genO": (1.5492, 0.1576, -3.5363, 3.5000, 10.7220), # params for OG3
|
||||||
|
"genOs": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genP": (2.1290, 0.5838, -9.6272, 3.5000, 34.8000), # params for PG5
|
||||||
|
"genPb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genPd": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genPr": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genPt": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genRe": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genRh": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genRu": (1, 0.1, 0.0000, 0.0000,0.0000),
|
||||||
|
"genS": (1.9893, 0.3634, -2.3560, 3.5000, 17.6400), # params for SG3
|
||||||
|
"genSb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genSe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genSi": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genSn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genTb": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genTe": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genU": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genW": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genV": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genY": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genZn": (1, 0.1, 0.0000, 0.0000, 0.0000),
|
||||||
|
"genATM": (1, 0.0, 0.0000, 0.0000, 0.0000), # masked
|
||||||
|
}
|
||||||
|
|
||||||
|
# cartbonded
|
||||||
|
with open(script_dir+'cartbonded.json', 'r') as j:
|
||||||
|
cartbonded_data_raw = json.loads(j.read())
|
||||||
|
|
||||||
|
# hbond donor/acceptors
|
||||||
|
class HbAtom:
|
||||||
|
NO = 0
|
||||||
|
DO = 1 # donor
|
||||||
|
AC = 2 # acceptor
|
||||||
|
DA = 3 # donor & acceptor
|
||||||
|
HP = 4 # polar H
|
||||||
|
|
||||||
|
type2hb = {
|
||||||
|
"CNH2":HbAtom.NO, "COO":HbAtom.NO, "CH0":HbAtom.NO, "CH1":HbAtom.NO,
|
||||||
|
"CH2":HbAtom.NO, "CH3":HbAtom.NO, "aroC":HbAtom.NO, "Ntrp":HbAtom.DO,
|
||||||
|
"Nhis":HbAtom.AC, "NtrR":HbAtom.DO, "NH2O":HbAtom.DO, "Nlys":HbAtom.DO,
|
||||||
|
"Narg":HbAtom.DO, "Npro":HbAtom.NO, "OH":HbAtom.DA, "OHY":HbAtom.DA,
|
||||||
|
"ONH2":HbAtom.AC, "OOC":HbAtom.AC, "S":HbAtom.NO, "SH1":HbAtom.NO,
|
||||||
|
"Nbb":HbAtom.DO, "CAbb":HbAtom.NO, "CObb":HbAtom.NO, "OCbb":HbAtom.AC,
|
||||||
|
"HNbb":HbAtom.HP, "Hapo":HbAtom.NO, "Haro":HbAtom.NO, "Hpol":HbAtom.HP,
|
||||||
|
"HS":HbAtom.HP, # HP in rosetta(?)
|
||||||
|
"Phos":HbAtom.NO, "Oet2":HbAtom.AC, "Oet3":HbAtom.AC,
|
||||||
|
"genAl":HbAtom.NO, "genAs":HbAtom.NO, "genAu":HbAtom.NO, "genB": HbAtom.NO,
|
||||||
|
"genBe": HbAtom.NO, "genBr": HbAtom.NO, "genC": HbAtom.NO, "genCa": HbAtom.NO,
|
||||||
|
"genCl": HbAtom.NO, "genCo": HbAtom.NO, "genCr": HbAtom.NO, "genCu": HbAtom.NO,
|
||||||
|
"genF": HbAtom.DA, "genFe": HbAtom.NO, "genHg": HbAtom.NO, "genI": HbAtom.NO,
|
||||||
|
"genIr": HbAtom.NO, "genK": HbAtom.NO, "genLi": HbAtom.NO, "genMg": HbAtom.NO,
|
||||||
|
"genMn": HbAtom.NO, "genMo": HbAtom.NO, "genN": HbAtom.DA, "genNi": HbAtom.NO,
|
||||||
|
"genO": HbAtom.DA, "genOs": HbAtom.NO, "genP": HbAtom.NO, "genPb": HbAtom.NO,
|
||||||
|
"genPd": HbAtom.NO, "genPr": HbAtom.NO, "genPt": HbAtom.NO, "genRe": HbAtom.NO,
|
||||||
|
"genRh": HbAtom.NO, "genRu": HbAtom.NO, "genS": HbAtom.DA, "genSb": HbAtom.NO,
|
||||||
|
"genSe": HbAtom.NO, "genSi": HbAtom.NO,"genSn": HbAtom.NO,"genTb": HbAtom.NO,
|
||||||
|
"genTe": HbAtom.NO, "genU": HbAtom.NO, "genW": HbAtom.NO, "genV": HbAtom.NO,
|
||||||
|
"genY": HbAtom.NO, "genZn": HbAtom.NO, "genATM": HbAtom.NO, # masked
|
||||||
|
}
|
||||||
|
|
||||||
|
##
|
||||||
|
## hbond term
|
||||||
|
## TO DO: ADD DNA
|
||||||
|
class HbDonType:
|
||||||
|
PBA = 0
|
||||||
|
IND = 1
|
||||||
|
IME = 2
|
||||||
|
GDE = 3
|
||||||
|
CXA = 4
|
||||||
|
AMO = 5
|
||||||
|
HXL = 6
|
||||||
|
AHX = 7
|
||||||
|
NTYPES = 8
|
||||||
|
|
||||||
|
class HbAccType:
|
||||||
|
PBA = 0
|
||||||
|
CXA = 1
|
||||||
|
CXL = 2
|
||||||
|
HXL = 3
|
||||||
|
AHX = 4
|
||||||
|
IME = 5
|
||||||
|
NTYPES = 6
|
||||||
|
|
||||||
|
class HbHybType:
|
||||||
|
SP2 = 0
|
||||||
|
SP3 = 1
|
||||||
|
RING = 2
|
||||||
|
NTYPES = 3
|
||||||
|
|
||||||
|
type2dontype = {
|
||||||
|
"Nbb": HbDonType.PBA,
|
||||||
|
"Ntrp": HbDonType.IND,
|
||||||
|
"NtrR": HbDonType.GDE,
|
||||||
|
"Narg": HbDonType.GDE,
|
||||||
|
"NH2O": HbDonType.CXA,
|
||||||
|
"Nlys": HbDonType.AMO,
|
||||||
|
"OH": HbDonType.HXL,
|
||||||
|
"OHY": HbDonType.AHX,
|
||||||
|
}
|
||||||
|
|
||||||
|
type2acctype = {
|
||||||
|
"OCbb": HbAccType.PBA,
|
||||||
|
"ONH2": HbAccType.CXA,
|
||||||
|
"OOC": HbAccType.CXL,
|
||||||
|
"OH": HbAccType.HXL,
|
||||||
|
"OHY": HbAccType.AHX,
|
||||||
|
"Nhis": HbAccType.IME,
|
||||||
|
}
|
||||||
|
|
||||||
|
type2hybtype = {
|
||||||
|
"OCbb": HbHybType.SP2,
|
||||||
|
"ONH2": HbHybType.SP2,
|
||||||
|
"OOC": HbHybType.SP2,
|
||||||
|
"OHY": HbHybType.SP3,
|
||||||
|
"OH": HbHybType.SP3,
|
||||||
|
"Nhis": HbHybType.RING,
|
||||||
|
}
|
||||||
|
|
||||||
|
dontype2wt = {
|
||||||
|
HbDonType.PBA: 1.45,
|
||||||
|
HbDonType.IND: 1.15,
|
||||||
|
HbDonType.IME: 1.42,
|
||||||
|
HbDonType.GDE: 1.11,
|
||||||
|
HbDonType.CXA: 1.29,
|
||||||
|
HbDonType.AMO: 1.17,
|
||||||
|
HbDonType.HXL: 0.99,
|
||||||
|
HbDonType.AHX: 1.00,
|
||||||
|
}
|
||||||
|
|
||||||
|
acctype2wt = {
|
||||||
|
HbAccType.PBA: 1.19,
|
||||||
|
HbAccType.CXA: 1.21,
|
||||||
|
HbAccType.CXL: 1.10,
|
||||||
|
HbAccType.HXL: 1.15,
|
||||||
|
HbAccType.AHX: 1.15,
|
||||||
|
HbAccType.IME: 1.17,
|
||||||
|
}
|
||||||
|
|
||||||
|
class HbPolyType:
|
||||||
|
ahdist_aASN_dARG = 0
|
||||||
|
ahdist_aASN_dASN = 1
|
||||||
|
ahdist_aASN_dGLY = 2
|
||||||
|
ahdist_aASN_dHIS = 3
|
||||||
|
ahdist_aASN_dLYS = 4
|
||||||
|
ahdist_aASN_dSER = 5
|
||||||
|
ahdist_aASN_dTRP = 6
|
||||||
|
ahdist_aASN_dTYR = 7
|
||||||
|
ahdist_aASP_dARG = 8
|
||||||
|
ahdist_aASP_dASN = 9
|
||||||
|
ahdist_aASP_dGLY = 10
|
||||||
|
ahdist_aASP_dHIS = 11
|
||||||
|
ahdist_aASP_dLYS = 12
|
||||||
|
ahdist_aASP_dSER = 13
|
||||||
|
ahdist_aASP_dTRP = 14
|
||||||
|
ahdist_aASP_dTYR = 15
|
||||||
|
ahdist_aGLY_dARG = 16
|
||||||
|
ahdist_aGLY_dASN = 17
|
||||||
|
ahdist_aGLY_dGLY = 18
|
||||||
|
ahdist_aGLY_dHIS = 19
|
||||||
|
ahdist_aGLY_dLYS = 20
|
||||||
|
ahdist_aGLY_dSER = 21
|
||||||
|
ahdist_aGLY_dTRP = 22
|
||||||
|
ahdist_aGLY_dTYR = 23
|
||||||
|
ahdist_aHIS_dARG = 24
|
||||||
|
ahdist_aHIS_dASN = 25
|
||||||
|
ahdist_aHIS_dGLY = 26
|
||||||
|
ahdist_aHIS_dHIS = 27
|
||||||
|
ahdist_aHIS_dLYS = 28
|
||||||
|
ahdist_aHIS_dSER = 29
|
||||||
|
ahdist_aHIS_dTRP = 30
|
||||||
|
ahdist_aHIS_dTYR = 31
|
||||||
|
ahdist_aSER_dARG = 32
|
||||||
|
ahdist_aSER_dASN = 33
|
||||||
|
ahdist_aSER_dGLY = 34
|
||||||
|
ahdist_aSER_dHIS = 35
|
||||||
|
ahdist_aSER_dLYS = 36
|
||||||
|
ahdist_aSER_dSER = 37
|
||||||
|
ahdist_aSER_dTRP = 38
|
||||||
|
ahdist_aSER_dTYR = 39
|
||||||
|
ahdist_aTYR_dARG = 40
|
||||||
|
ahdist_aTYR_dASN = 41
|
||||||
|
ahdist_aTYR_dGLY = 42
|
||||||
|
ahdist_aTYR_dHIS = 43
|
||||||
|
ahdist_aTYR_dLYS = 44
|
||||||
|
ahdist_aTYR_dSER = 45
|
||||||
|
ahdist_aTYR_dTRP = 46
|
||||||
|
ahdist_aTYR_dTYR = 47
|
||||||
|
cosBAH_off = 48
|
||||||
|
cosBAH_7 = 49
|
||||||
|
cosBAH_6i = 50
|
||||||
|
AHD_1h = 51
|
||||||
|
AHD_1i = 52
|
||||||
|
AHD_1j = 53
|
||||||
|
AHD_1k = 54
|
||||||
|
|
||||||
|
# map donor:acceptor pairs to polynomials
|
||||||
|
hbtypepair2poly = {
|
||||||
|
(HbDonType.PBA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||||
|
(HbDonType.CXA,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||||
|
(HbDonType.IME,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||||
|
(HbDonType.IND,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||||
|
(HbDonType.AMO,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.GDE,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1j),
|
||||||
|
(HbDonType.AHX,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.HXL,HbAccType.PBA): (HbPolyType.ahdist_aGLY_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.PBA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.CXA,HbAccType.CXA): (HbPolyType.ahdist_aASN_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.IME,HbAccType.CXA): (HbPolyType.ahdist_aASN_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.IND,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.AMO,HbAccType.CXA): (HbPolyType.ahdist_aASN_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.GDE,HbAccType.CXA): (HbPolyType.ahdist_aASN_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.AHX,HbAccType.CXA): (HbPolyType.ahdist_aASN_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.HXL,HbAccType.CXA): (HbPolyType.ahdist_aASN_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.PBA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dGLY,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.CXA,HbAccType.CXL): (HbPolyType.ahdist_aASP_dASN,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.IME,HbAccType.CXL): (HbPolyType.ahdist_aASP_dHIS,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.IND,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTRP,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.AMO,HbAccType.CXL): (HbPolyType.ahdist_aASP_dLYS,HbPolyType.cosBAH_off,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.GDE,HbAccType.CXL): (HbPolyType.ahdist_aASP_dARG,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.AHX,HbAccType.CXL): (HbPolyType.ahdist_aASP_dTYR,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.HXL,HbAccType.CXL): (HbPolyType.ahdist_aASP_dSER,HbPolyType.cosBAH_off,HbPolyType.AHD_1k),
|
||||||
|
(HbDonType.PBA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dGLY,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.CXA,HbAccType.IME): (HbPolyType.ahdist_aHIS_dASN,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.IME,HbAccType.IME): (HbPolyType.ahdist_aHIS_dHIS,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.IND,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTRP,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.AMO,HbAccType.IME): (HbPolyType.ahdist_aHIS_dLYS,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.GDE,HbAccType.IME): (HbPolyType.ahdist_aHIS_dARG,HbPolyType.cosBAH_7,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.AHX,HbAccType.IME): (HbPolyType.ahdist_aHIS_dTYR,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.HXL,HbAccType.IME): (HbPolyType.ahdist_aHIS_dSER,HbPolyType.cosBAH_7,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.PBA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.CXA,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.IME,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.IND,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.AMO,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.GDE,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.AHX,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.HXL,HbAccType.AHX): (HbPolyType.ahdist_aTYR_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.PBA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dGLY,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.CXA,HbAccType.HXL): (HbPolyType.ahdist_aSER_dASN,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.IME,HbAccType.HXL): (HbPolyType.ahdist_aSER_dHIS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.IND,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTRP,HbPolyType.cosBAH_6i,HbPolyType.AHD_1h),
|
||||||
|
(HbDonType.AMO,HbAccType.HXL): (HbPolyType.ahdist_aSER_dLYS,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.GDE,HbAccType.HXL): (HbPolyType.ahdist_aSER_dARG,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.AHX,HbAccType.HXL): (HbPolyType.ahdist_aSER_dTYR,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
(HbDonType.HXL,HbAccType.HXL): (HbPolyType.ahdist_aSER_dSER,HbPolyType.cosBAH_6i,HbPolyType.AHD_1i),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# polynomials are triplets, (x_min, x_max), (y[x<x_min],y[x>x_max]), (c_9,...,c_0)
|
||||||
|
hbpolytype2coeffs = { # Parameters imported from rosetta sp2_elec_params @v2017.48-dev59886
|
||||||
|
HbPolyType.ahdist_aASN_dARG: ((0.7019094761929999, 2.86820307153,),(1.1, 1.1,),( 0.58376113, -9.29345473, 64.86270904, -260.3946711, 661.43138077, -1098.01378958, 1183.58371466, -790.82929582, 291.33125475, -43.01629727,)),
|
||||||
|
HbPolyType.ahdist_aASN_dASN: ((0.625841094801, 2.75107708444,),(1.1, 1.1,),( -1.31243015, 18.6745072, -112.63858313, 373.32878091, -734.99145504, 861.38324861, -556.21026097, 143.5626977, 20.03238394, -11.52167705,)),
|
||||||
|
HbPolyType.ahdist_aASN_dGLY: ((0.7477341047139999, 2.6796350782799996,),(1.1, 1.1,),( -1.61294554, 23.3150793, -144.11313069, 496.13575, -1037.83809166, 1348.76826073, -1065.14368678, 473.89008925, -100.41142701, 7.44453515,)),
|
||||||
|
HbPolyType.ahdist_aASN_dHIS: ((0.344789524346, 2.8303582266000005,),(1.1, 1.1,),( -0.2657122, 4.1073775, -26.9099632, 97.10486507, -209.96002602, 277.33057268, -218.74766996, 97.42852213, -24.07382402, 3.73962807,)),
|
||||||
|
HbPolyType.ahdist_aASN_dLYS: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
|
||||||
|
HbPolyType.ahdist_aASN_dSER: ((1.0812774602500002, 2.6832123582599996,),(1.1, 1.1,),( -3.51524353, 47.54032873, -254.40168577, 617.84606386, -255.49935027, -2361.56230539, 6426.85797934, -7760.4403891, 4694.08106855, -1149.83549068,)),
|
||||||
|
HbPolyType.ahdist_aASN_dTRP: ((0.6689984999999999, 3.0704254,),(1.1, 1.1,),( -0.5284840422, 8.3510150838, -56.4100479414, 212.4884326254, -488.3178610608, 703.7762350506, -628.9936994633999, 331.4294356146, -93.265817571, 11.9691623698,)),
|
||||||
|
HbPolyType.ahdist_aASN_dTYR: ((1.08950268805, 2.6887046709400004,),(1.1, 1.1,),( -4.4488705, 63.27696281, -371.44187037, 1121.71921621, -1638.11394306, 142.99988401, 3436.65879147, -5496.07011787, 3709.30505237, -962.79669688,)),
|
||||||
|
HbPolyType.ahdist_aASP_dARG: ((0.8100404642229999, 2.9851230124799994,),(1.1, 1.1,),( -0.66430344, 10.41343145, -70.12656205, 265.12578414, -617.05849171, 911.39378582, -847.25013928, 472.09090981, -141.71513167, 18.57721132,)),
|
||||||
|
HbPolyType.ahdist_aASP_dASN: ((1.05401125073, 3.11129675908,),(1.1, 1.1,),( 0.02090728, -0.24144928, -0.19578075, 16.80904547, -117.70216251, 407.18551288, -809.95195924, 939.83137947, -593.94527692, 159.57610528,)),
|
||||||
|
HbPolyType.ahdist_aASP_dGLY: ((0.886260952629, 2.66843608743,),(1.1, 1.1,),( -7.00699267, 107.33021779, -713.45752385, 2694.43092298, -6353.05100287, 9667.94098394, -9461.9261027, 5721.0086877, -1933.97818198, 279.47763789,)),
|
||||||
|
HbPolyType.ahdist_aASP_dHIS: ((1.03597611139, 2.78208509117,),(1.1, 1.1,),( -1.34823406, 17.08925926, -78.75087193, 106.32795459, 400.18459698, -2041.04320193, 4033.83557387, -4239.60530204, 2324.00877252, -519.38410941,)),
|
||||||
|
HbPolyType.ahdist_aASP_dLYS: ((0.97789485082, 2.50496946108,),(1.1, 1.1,),( -0.41300315, 6.59243438, -44.44525308, 163.11796012, -351.2307798, 443.2463146, -297.84582856, 62.38600547, 33.77496227, -14.11652182,)),
|
||||||
|
HbPolyType.ahdist_aASP_dSER: ((0.542905671869, 2.45259389314,),(1.1, 1.1,),( 1.38531754, -18.48733797, 106.14444613, -344.70585054, 698.91577956, -917.0879402, 775.32787908, -403.09588787, 113.65054778, -11.66516403,)),
|
||||||
|
HbPolyType.ahdist_aASP_dTRP: ((0.419155746414, 3.0486938610500003,),(1.1, 1.1,),( -0.24563471, 3.85598551, -25.75176874, 95.36525025, -214.13175785, 299.76133553, -259.0691378, 132.06975835, -37.15612683, 5.60445773,)),
|
||||||
|
HbPolyType.ahdist_aASP_dTYR: ((1.01057521468, 2.7207545786900003,),(1.1, 1.1,),( -0.15808672, -10.21398871, 178.80080949, -1238.0583801, 4736.25248274, -11071.96777725, 16239.07550047, -14593.21092621, 7335.66765017, -1575.08145078,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dARG: ((0.499016667857, 2.9377031027599996,),(1.1, 1.1,),( -0.15923533, 2.5526639, -17.38788803, 65.71046957, -151.13491186, 218.78048387, -199.15882919, 110.56568974, -35.95143745, 6.47580213,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dASN: ((0.7194388032060001, 2.9303772333599998,),(1.1, 1.1,),( -1.40718342, 23.65929694, -172.97144348, 720.64417348, -1882.85420815, 3194.87197776, -3515.52467458, 2415.75238278, -941.47705161, 159.84784277,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dGLY: ((1.38403812683, 2.9981039433,),(1.1, 1.1,),( -0.5307601, 6.47949946, -22.39522814, -55.14303544, 708.30945242, -2619.49318162, 5227.8805795, -6043.31211632, 3806.04676175, -1007.66024144,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dHIS: ((0.47406840932899996, 2.9234200830400003,),(1.1, 1.1,),( -0.12881679, 1.933838, -12.03134888, 39.92691227, -75.41519959, 78.87968016, -37.82769801, -0.13178679, 4.50193019, 0.45408359,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dLYS: ((0.545347533475, 2.42624380351,),(1.1, 1.1,),( -0.22921901, 2.07015714, -6.2947417, 0.66645697, 45.21805416, -130.26668981, 176.32401031, -126.68226346, 43.96744431, -4.40105281,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dSER: ((1.2803349239700001, 2.2465996077400003,),(1.1, 1.1,),( 6.72508613, -86.98495585, 454.18518444, -1119.89141452, 715.624663, 3172.36852982, -9455.49113097, 11797.38766934, -7363.28302948, 1885.50119665,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dTRP: ((0.686512740494, 3.02901351815,),(1.1, 1.1,),( -0.1051487, 1.41597708, -7.42149173, 17.31830704, -6.98293652, -54.76605063, 130.95272289, -132.77575305, 62.75460448, -9.89110842,)),
|
||||||
|
HbPolyType.ahdist_aGLY_dTYR: ((1.28894687639, 2.26335316892,),(1.1, 1.1,),( 13.84536925, -169.40579865, 893.79467505, -2670.60617561, 5016.46234701, -6293.79378818, 5585.1049063, -3683.50722701, 1709.48661405, -399.5712153,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dARG: ((0.8967400957230001, 2.96809434226,),(1.1, 1.1,),( 0.43460495, -10.52727665, 103.16979807, -551.42887412, 1793.25378923, -3701.08304991, 4861.05155388, -3922.4285529, 1763.82137881, -335.43441944,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dASN: ((0.887120931718, 2.59166903153,),(1.1, 1.1,),( -3.50289894, 54.42813924, -368.14395507, 1418.90186454, -3425.60485859, 5360.92334837, -5428.54462336, 3424.68800187, -1221.49631986, 189.27122436,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dGLY: ((1.01629363411, 2.58523052904,),(1.1, 1.1,),( -1.68095217, 21.31894078, -107.72203494, 251.81021758, -134.07465831, -707.64527046, 1894.6282743, -2156.85951846, 1216.83585872, -275.48078944,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dHIS: ((0.9773010778919999, 2.72533796329,),(1.1, 1.1,),( -2.33350626, 35.66072412, -233.98966111, 859.13714961, -1925.30958567, 2685.35293578, -2257.48067507, 1021.49796136, -169.36082523, -12.1348055,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dLYS: ((0.7080936539849999, 2.47191718632,),(1.1, 1.1,),( -1.88479369, 28.38084382, -185.74039957, 690.81875917, -1605.11404391, 2414.83545623, -2355.9723201, 1442.24496229, -506.45880637, 79.47512505,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dSER: ((0.90846809159, 2.5477956147,),(1.1, 1.1,),( -0.92004641, 15.91841533, -117.83979251, 488.22211296, -1244.13047376, 2017.43704053, -2076.04468019, 1302.42621488, -451.29138643, 67.15812575,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dTRP: ((0.991999676806, 2.81296584506,),(1.1, 1.1,),( -1.29358587, 19.97152857, -131.89796017, 485.29199356, -1084.0466445, 1497.3352889, -1234.58042682, 535.8048197, -75.58951691, -9.91148332,)),
|
||||||
|
HbPolyType.ahdist_aHIS_dTYR: ((0.882661836357, 2.5469016429900004,),(1.1, 1.1,),( -6.94700143, 109.07997256, -747.64035726, 2929.83959536, -7220.15788571, 11583.34170519, -12078.443492, 7881.85479715, -2918.19482068, 468.23988622,)),
|
||||||
|
HbPolyType.ahdist_aSER_dARG: ((1.0204658147399999, 2.8899566041900004,),(1.1, 1.1,),( 0.33887327, -7.54511361, 70.87316645, -371.88263665, 1206.67454443, -2516.82084076, 3379.45432693, -2819.73384601, 1325.33307517, -265.54533008,)),
|
||||||
|
HbPolyType.ahdist_aSER_dASN: ((1.01393052233, 3.0024434159299997,),(1.1, 1.1,),( 0.37012361, -7.46486204, 64.85775924, -318.6047209, 974.66322243, -1924.37334018, 2451.63840629, -1943.1915675, 867.07870559, -163.83771761,)),
|
||||||
|
HbPolyType.ahdist_aSER_dGLY: ((1.3856562156299999, 2.74160605537,),(1.1, 1.1,),( -1.32847415, 22.67528654, -172.53450064, 770.79034865, -2233.48829652, 4354.38807288, -5697.35144236, 4803.38686157, -2361.48028857, 518.28202382,)),
|
||||||
|
HbPolyType.ahdist_aSER_dHIS: ((0.550992321207, 2.68549261999,),(1.1, 1.1,),( -1.98041793, 29.59668639, -190.36751773, 688.43324385, -1534.68894765, 2175.66568976, -1952.07622113, 1066.28943929, -324.23381388, 43.41006168,)),
|
||||||
|
HbPolyType.ahdist_aSER_dLYS: ((0.8603189393170001, 2.77729502744,),(1.1, 1.1,),( 0.90884741, -17.24690746, 141.78469099, -661.85989315, 1929.7674992, -3636.43392779, 4419.00727923, -3332.43482061, 1410.78913266, -253.53829424,)),
|
||||||
|
HbPolyType.ahdist_aSER_dSER: ((1.10866545921, 2.61727781204,),(1.1, 1.1,),( -0.38264308, 4.41779675, -10.7016645, -81.91314845, 668.91174735, -2187.50684758, 3983.56103269, -4213.32320546, 2418.41531442, -580.28918569,)),
|
||||||
|
HbPolyType.ahdist_aSER_dTRP: ((1.4092077245899999, 2.8066121197099996,),(1.1, 1.1,),( 0.73762477, -11.70741276, 73.05154232, -205.00144794, 89.58794368, 1082.94541375, -3343.98293188, 4601.70815729, -3178.53568678, 896.59487831,)),
|
||||||
|
HbPolyType.ahdist_aSER_dTYR: ((1.10773547919, 2.60403567341,),(1.1, 1.1,),( -1.13249925, 14.66643161, -69.01708791, 93.96846742, 380.56063898, -1984.56675689, 4074.08891127, -4492.76927139, 2613.13168054, -627.71933508,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dARG: ((1.05581400627, 2.85499888099,),(1.1, 1.1,),( -0.30396592, 5.30288548, -39.75788579, 167.5416547, -435.15958911, 716.52357586, -735.95195083, 439.76284677, -130.00400085, 13.23827556,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dASN: ((1.0994919065200002, 2.8400869077900004,),(1.1, 1.1,),( 0.33548259, -3.5890451, 8.97769025, 48.1492734, -400.5983616, 1269.89613211, -2238.03101675, 2298.33009115, -1290.42961162, 308.43185147,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dGLY: ((1.36546155066, 2.7303075916400004,),(1.1, 1.1,),( -1.55312915, 18.62092487, -70.91365499, -41.83066505, 1248.88835245, -4719.81948329, 9186.09528168, -10266.11434548, 6266.21959533, -1622.19652457,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dHIS: ((0.5955982461899999, 2.6643551317500003,),(1.1, 1.1,),( -0.47442788, 7.16629863, -46.71287553, 171.46128947, -388.17484011, 558.45202337, -506.35587481, 276.46237273, -83.52554392, 12.05709329,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dLYS: ((0.7978598238760001, 2.7620933782,),(1.1, 1.1,),( -0.20201464, 1.69684984, 0.27677515, -55.05786347, 286.29918332, -725.92372531, 1054.771746, -889.33602341, 401.11342256, -73.02221189,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dSER: ((0.7083554962559999, 2.7032011990599996,),(1.1, 1.1,),( -0.70764192, 11.67978065, -82.80447482, 329.83401367, -810.58976486, 1269.57613941, -1261.04047117, 761.72890446, -254.37526011, 37.24301861,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dTRP: ((1.10934023051, 2.8819112108,),(1.1, 1.1,),( -11.58453967, 204.88308091, -1589.77384548, 7100.84791905, -20113.61354433, 37457.83646055, -45850.02969172, 35559.8805122, -15854.78726237, 3098.04931146,)),
|
||||||
|
HbPolyType.ahdist_aTYR_dTYR: ((1.1105954899400001, 2.60081798685,),(1.1, 1.1,),( -1.63120628, 19.48493187, -81.0332905, 56.80517706, 687.42717782, -2842.77799908, 5385.52231471, -5656.74159307, 3178.83470588, -744.70042777,)),
|
||||||
|
HbPolyType.AHD_1h: ((1.76555274367, 3.1416,),(1.1, 1.1,),( 0.62725838, -9.98558225, 59.39060071, -120.82930213, -333.26536028, 2603.13082592, -6895.51207142, 9651.25238056, -7127.13394872, 2194.77244026,)),
|
||||||
|
HbPolyType.AHD_1i: ((1.59914724347, 3.1416,),(1.1, 1.1,),( -0.18888801, 3.48241679, -25.65508662, 89.57085435, -95.91708218, -367.93452341, 1589.6904702, -2662.3582135, 2184.40194483, -723.28383545,)),
|
||||||
|
HbPolyType.AHD_1j: ((1.1435646388, 3.1416,),(1.1, 1.1,),( 0.47683259, -9.54524724, 83.62557693, -420.55867774, 1337.19354878, -2786.26265686, 3803.178227, -3278.62879901, 1619.04116204, -347.50157909,)),
|
||||||
|
HbPolyType.AHD_1k: ((1.15651981164, 3.1416,),(1.1, 1.1,),( -0.10757999, 2.0276542, -16.51949978, 75.83866839, -214.18025678, 380.55117567, -415.47847283, 255.66998474, -69.94662165, 3.21313428,)),
|
||||||
|
HbPolyType.cosBAH_off: ((-1234.0, 1.1,),(1.1, 1.1,),( 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,)),
|
||||||
|
HbPolyType.cosBAH_6i: ((-0.23538144897100002, 1.1,),(1.1, 1.1,),( -0.822093, -3.75364636, 46.88852157, -129.5440564, 146.69151428, -67.60598792, 2.91683129, 9.26673173, -3.84488178, 0.05706659,)),
|
||||||
|
HbPolyType.cosBAH_7: ((-0.019373850666900002, 1.1,),(1.1, 1.1,),( 0.0, -27.942923450028, 136.039920253368, -268.06959056747, 275.400462507919, -153.502076215949, 39.741591385461, 0.693861510121, -3.885952320499, 1.024765090788892)),
|
||||||
|
}
|
90
rf2aa/setup_model.py
Normal file
90
rf2aa/setup_model.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
import hydra
|
||||||
|
import os
|
||||||
|
|
||||||
|
from rf2aa.training.EMA import EMA
|
||||||
|
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||||||
|
from rf2aa.util_module import XYZConverter
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
#TODO: control environment variables from config
|
||||||
|
# limit thread counts
|
||||||
|
os.environ['OMP_NUM_THREADS'] = '4'
|
||||||
|
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||||||
|
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||||||
|
|
||||||
|
## To reproduce errors
|
||||||
|
import random
|
||||||
|
|
||||||
|
def seed_all(seed=0):
|
||||||
|
random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
torch.cuda.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
|
||||||
|
torch.set_num_threads(4)
|
||||||
|
#torch.autograd.set_detect_anomaly(True)
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
|
||||||
|
if self.config.experiment.output_dir is not None:
|
||||||
|
self.output_dir = self.config.experiment.output_dir
|
||||||
|
else:
|
||||||
|
self.output_dir = "models/"
|
||||||
|
if not os.path.exists(self.output_dir):
|
||||||
|
os.makedirs(self.output_dir)
|
||||||
|
|
||||||
|
def move_constants_to_device(self, gpu):
|
||||||
|
self.fi_dev = ChemData().frame_indices.to(gpu)
|
||||||
|
self.xyz_converter = XYZConverter().to(gpu)
|
||||||
|
|
||||||
|
self.l2a = ChemData().long2alt.to(gpu)
|
||||||
|
self.aamask = ChemData().allatom_mask.to(gpu)
|
||||||
|
self.num_bonds = ChemData().num_bonds.to(gpu)
|
||||||
|
self.atom_type_index = ChemData().atom_type_index.to(gpu)
|
||||||
|
self.ljlk_parameters = ChemData().ljlk_parameters.to(gpu)
|
||||||
|
self.lj_correction_parameters = ChemData().lj_correction_parameters.to(gpu)
|
||||||
|
self.hbtypes = ChemData().hbtypes.to(gpu)
|
||||||
|
self.hbbaseatoms = ChemData().hbbaseatoms.to(gpu)
|
||||||
|
self.hbpolys = ChemData().hbpolys.to(gpu)
|
||||||
|
self.cb_len = ChemData().cb_length_t.to(gpu)
|
||||||
|
self.cb_ang = ChemData().cb_angle_t.to(gpu)
|
||||||
|
self.cb_tor = ChemData().cb_torsion_t.to(gpu)
|
||||||
|
|
||||||
|
class LegacyTrainer(Trainer):
|
||||||
|
def __init__(self, config) -> None:
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
def construct_model(self, device="cpu"):
|
||||||
|
self.model = RoseTTAFoldModule(
|
||||||
|
**self.config.legacy_model_param,
|
||||||
|
aamask = ChemData().allatom_mask.to(device),
|
||||||
|
atom_type_index = ChemData().atom_type_index.to(device),
|
||||||
|
ljlk_parameters = ChemData().ljlk_parameters.to(device),
|
||||||
|
lj_correction_parameters = ChemData().lj_correction_parameters.to(device),
|
||||||
|
num_bonds = ChemData().num_bonds.to(device),
|
||||||
|
cb_len = ChemData().cb_length_t.to(device),
|
||||||
|
cb_ang = ChemData().cb_angle_t.to(device),
|
||||||
|
cb_tor = ChemData().cb_torsion_t.to(device),
|
||||||
|
|
||||||
|
).to(device)
|
||||||
|
if self.config.training_params.EMA is not None:
|
||||||
|
self.model = EMA(self.model, self.config.training_params.EMA)
|
||||||
|
|
||||||
|
@hydra.main(version_base=None, config_path='config/train')
|
||||||
|
def main(config):
|
||||||
|
seed_all()
|
||||||
|
trainer = trainer_factory[config.experiment.trainer](config=config)
|
||||||
|
|
||||||
|
trainer_factory = {
|
||||||
|
"legacy": LegacyTrainer,
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
716
rf2aa/symmetry.py
Normal file
716
rf2aa/symmetry.py
Normal file
|
@ -0,0 +1,716 @@
|
||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
SYMA = 1.0
|
||||||
|
|
||||||
|
def generateC(angs, eps=1e-6):
|
||||||
|
L = angs.shape[0]
|
||||||
|
Rs = torch.eye(3, device=angs.device).repeat(L,1,1)
|
||||||
|
Rs[:,1,1] = torch.cos(angs)
|
||||||
|
Rs[:,1,2] = -torch.sin(angs)
|
||||||
|
Rs[:,2,1] = torch.sin(angs)
|
||||||
|
Rs[:,2,2] = torch.cos(angs)
|
||||||
|
return Rs
|
||||||
|
|
||||||
|
def generateD(angs, eps=1e-6):
|
||||||
|
L = angs.shape[0]
|
||||||
|
Rs = torch.eye(3, device=angs.device).repeat(2*L,1,1)
|
||||||
|
Rs[:L,1,1] = torch.cos(angs)
|
||||||
|
Rs[:L,1,2] = -torch.sin(angs)
|
||||||
|
Rs[:L,2,1] = torch.sin(angs)
|
||||||
|
Rs[:L,2,2] = torch.cos(angs)
|
||||||
|
Rx = torch.tensor([[-1.,0,0],[0,-1,0],[0,0,1]],device=angs.device)
|
||||||
|
Rs[L:] = torch.einsum('ij,bjk->bik',Rx,Rs[:L])
|
||||||
|
return Rs
|
||||||
|
|
||||||
|
def find_symm_subs(xyz,Rs,metasymm):
|
||||||
|
com = xyz[:,:,1].mean(dim=-2)
|
||||||
|
rcoms = torch.einsum('sij,bj->si', Rs, com)
|
||||||
|
|
||||||
|
subsymms, nneighs = metasymm
|
||||||
|
|
||||||
|
subs = []
|
||||||
|
for i in range(len(subsymms)):
|
||||||
|
drcoms = torch.linalg.norm(rcoms[0,:] - rcoms[subsymms[i],:], dim=-1)
|
||||||
|
_,subs_i = torch.topk(drcoms,nneighs[i],largest=False)
|
||||||
|
subs_i,_ = torch.sort( subsymms[i][subs_i] )
|
||||||
|
subs.append(subs_i)
|
||||||
|
|
||||||
|
subs=torch.cat(subs)
|
||||||
|
xyz_new = torch.einsum('sij,braj->bsrai', Rs[subs], xyz).reshape(
|
||||||
|
xyz.shape[0],-1,xyz.shape[2],3)
|
||||||
|
return xyz_new, subs
|
||||||
|
|
||||||
|
def update_symm_subs(xyz,subs,Rs,metasymm):
|
||||||
|
xyz_new = torch.einsum('sij,braj->bsrai', Rs[subs], xyz).reshape(
|
||||||
|
xyz.shape[0],-1,xyz.shape[2],3)
|
||||||
|
return xyz_new
|
||||||
|
|
||||||
|
|
||||||
|
def get_symm_map(subs,O):
|
||||||
|
symmmask = torch.zeros(O,dtype=torch.long)
|
||||||
|
symmmask[subs] = torch.arange(1,subs.shape[0]+1)
|
||||||
|
return symmmask
|
||||||
|
|
||||||
|
|
||||||
|
def rotation_from_matrix(R, eps=1e-4):
|
||||||
|
w, W = torch.linalg.eig(R.T)
|
||||||
|
i = torch.where(abs(torch.real(w) - 1.0) < eps)[0]
|
||||||
|
if (len(i)==0):
|
||||||
|
i = torch.tensor([0])
|
||||||
|
print ('rotation_from_matrix w',torch.real(w))
|
||||||
|
print ('rotation_from_matrix R.T',R.T)
|
||||||
|
axis = torch.real(W[:, i[-1]]).squeeze()
|
||||||
|
|
||||||
|
cosa = (torch.trace(R) - 1.0) / 2.0
|
||||||
|
if abs(axis[2]) > eps:
|
||||||
|
sina = (R[1, 0] + (cosa-1.0)*axis[0]*axis[1]) / axis[2]
|
||||||
|
elif abs(axis[1]) > eps:
|
||||||
|
sina = (R[0, 2] + (cosa-1.0)*axis[0]*axis[2]) / axis[1]
|
||||||
|
else:
|
||||||
|
sina = (R[2, 1] + (cosa-1.0)*axis[1]*axis[2]) / axis[0]
|
||||||
|
angle = torch.atan2(sina, cosa)
|
||||||
|
|
||||||
|
return angle, axis
|
||||||
|
|
||||||
|
def kabsch(pred, true):
|
||||||
|
def rmsd(V, W, eps=1e-6):
|
||||||
|
L = V.shape[0]
|
||||||
|
return torch.sqrt(torch.sum((V-W)*(V-W)) / L + eps)
|
||||||
|
def centroid(X):
|
||||||
|
return X.mean(dim=-2, keepdim=True)
|
||||||
|
|
||||||
|
cP = centroid(pred)
|
||||||
|
cT = centroid(true)
|
||||||
|
pred = pred - cP
|
||||||
|
true = true - cT
|
||||||
|
C = torch.matmul(pred.permute(1,0), true)
|
||||||
|
V, S, W = torch.svd(C)
|
||||||
|
d = torch.ones([3,3], device=pred.device)
|
||||||
|
d[:,-1] = torch.sign(torch.det(V)*torch.det(W))
|
||||||
|
U = torch.matmul(d*V, W.permute(1,0)) # (IB, 3, 3)
|
||||||
|
rpred = torch.matmul(pred, U) # (IB, L*3, 3)
|
||||||
|
rms = rmsd(rpred, true)
|
||||||
|
return rms, U, cP, cT
|
||||||
|
|
||||||
|
# do lines X0->X and Y0->Y intersect?
|
||||||
|
def intersect(X0,X,Y0,Y,eps=0.1):
|
||||||
|
mtx = torch.cat(
|
||||||
|
(torch.stack((X0,X0+X,Y0,Y0+Y)), torch.ones((4,1))) , axis=1
|
||||||
|
)
|
||||||
|
det = torch.linalg.det( mtx )
|
||||||
|
return (torch.abs(det) <= eps)
|
||||||
|
|
||||||
|
def get_angle(X,Y):
|
||||||
|
angle = torch.acos( torch.clamp( torch.sum(X*Y), -1., 1. ) )
|
||||||
|
if (angle > np.pi/2):
|
||||||
|
angle = np.pi - angle
|
||||||
|
return angle
|
||||||
|
|
||||||
|
# given the coordinates of a subunit +
|
||||||
|
def get_symmetry(xyz, mask, rms_cut=2.5, nfold_cut=0.1, angle_cut=0.05, trans_cut=2.0):
|
||||||
|
nops = xyz.shape[0]
|
||||||
|
L = xyz.shape[1]//2
|
||||||
|
|
||||||
|
# PASS 1: find all symm axes
|
||||||
|
symmaxes = []
|
||||||
|
for i in range(nops):
|
||||||
|
# if there are multiple biomt records, this may occur.
|
||||||
|
# rather than try to rescue, we will take the 1st (typically author-assigned)
|
||||||
|
offset0 = torch.linalg.norm(xyz[i,:L,1]-xyz[0,:L,1], dim=-1)
|
||||||
|
if (torch.mean(offset0)>1e-4):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# get alignment
|
||||||
|
mask_i = mask[i,:L,1]*mask[i,L:,1]
|
||||||
|
xyz_i = xyz[i,:L,1][mask_i,:]
|
||||||
|
xyz_j = xyz[i,L:,1][mask_i,:]
|
||||||
|
rms_ij, Uij, cI, cJ = kabsch(xyz_i, xyz_j)
|
||||||
|
if (rms_ij > rms_cut):
|
||||||
|
#print (i,'rms',rms_ij)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# get axis and point symmetry about axis
|
||||||
|
angle, axis = rotation_from_matrix(Uij)
|
||||||
|
nfold = 2*np.pi/torch.abs(angle)
|
||||||
|
# a) ensure integer # of subunits per rotation
|
||||||
|
if (torch.abs( nfold - torch.round(nfold) ) > nfold_cut ):
|
||||||
|
#print ('nfold fail',nfold)
|
||||||
|
continue
|
||||||
|
nfold = torch.round(nfold).long()
|
||||||
|
# b) ensure rotation only (no translation)
|
||||||
|
delCOM = torch.mean(xyz_i, dim=-2) - torch.mean(xyz_j, dim=-2)
|
||||||
|
trans_dot_symaxis = nfold * torch.abs(torch.dot(delCOM, axis))
|
||||||
|
if (trans_dot_symaxis > trans_cut ):
|
||||||
|
#print ('trans fail',trans_dot_symaxis)
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
# 3) get a point on the symm axis from CoMs and angle
|
||||||
|
cIJ = torch.sign(angle) * (cJ-cI).squeeze(0)
|
||||||
|
dIJ = torch.linalg.norm(cIJ)
|
||||||
|
p_mid = (cI+cJ).squeeze(0) / 2
|
||||||
|
u = cIJ / dIJ # unit vector in plane of circle
|
||||||
|
v = torch.cross(axis, u) # unit vector from sym axis to p_mid
|
||||||
|
r = dIJ / (2*torch.sin(angle/2))
|
||||||
|
d = torch.sqrt( r*r - dIJ*dIJ/4 ) # distance from mid-chord to center
|
||||||
|
point = p_mid - (d)*v
|
||||||
|
|
||||||
|
# check if redundant
|
||||||
|
toadd = True
|
||||||
|
for j,(nf_j,ax_j,pt_j,err_j) in enumerate(symmaxes):
|
||||||
|
if (not intersect(pt_j,ax_j,point,axis)):
|
||||||
|
continue
|
||||||
|
angle_j = get_angle(ax_j,axis)
|
||||||
|
if (angle_j < angle_cut):
|
||||||
|
if (nf_j < nfold): # stored is a subsymmetry of complex, overwrite
|
||||||
|
symmaxes[j] = (nfold, axis, point, i)
|
||||||
|
toadd = False
|
||||||
|
|
||||||
|
if (toadd):
|
||||||
|
symmaxes.append( (nfold, axis, point, i) )
|
||||||
|
|
||||||
|
# PASS 2: combine
|
||||||
|
symmgroup = 'C1'
|
||||||
|
subsymm = []
|
||||||
|
if len(symmaxes)==1:
|
||||||
|
symmgroup = 'C%d'%(symmaxes[0][0])
|
||||||
|
subsymm = [symmaxes[0][3]]
|
||||||
|
elif len(symmaxes)>1:
|
||||||
|
symmaxes = sorted(symmaxes, key=lambda x: x[0], reverse=True)
|
||||||
|
angle = get_angle(symmaxes[0][1],symmaxes[1][1])
|
||||||
|
subsymm = [symmaxes[0][3],symmaxes[1][3]]
|
||||||
|
|
||||||
|
# 2-fold and n-fold intersect at 90 degress => Dn
|
||||||
|
if (symmaxes[1][0] == 2 and torch.abs(angle-np.pi/2) < angle_cut):
|
||||||
|
symmgroup = 'D%d'%(symmaxes[0][0])
|
||||||
|
else:
|
||||||
|
# polyhedral rules:
|
||||||
|
# 3-Fold + 2-fold intersecting at acos(-1/sqrt(3)) -> T
|
||||||
|
angle_tgt = np.arccos(-1/np.sqrt(3))
|
||||||
|
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'T'
|
||||||
|
|
||||||
|
# 3-Fold + 2-fold intersecting at asin(1/sqrt(3)) -> O
|
||||||
|
angle_tgt = np.arcsin(1/np.sqrt(3))
|
||||||
|
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'O'
|
||||||
|
|
||||||
|
# 4-Fold + 3-fold intersecting at acos(1/sqrt(3)) -> O
|
||||||
|
angle_tgt = np.arccos(1/np.sqrt(3))
|
||||||
|
if (symmaxes[0][0] == 4 and symmaxes[1][0] == 3 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'O'
|
||||||
|
|
||||||
|
# 3-Fold + 2-fold intersecting at 0.5*acos(sqrt(5)/3) -> I
|
||||||
|
angle_tgt = 0.5*np.arccos(np.sqrt(5)/3)
|
||||||
|
if (symmaxes[0][0] == 3 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'I'
|
||||||
|
|
||||||
|
# 5-Fold + 2-fold intersecting at 0.5*acos(1/sqrt(5)) -> I
|
||||||
|
angle_tgt = 0.5*np.arccos(1/np.sqrt(5))
|
||||||
|
if (symmaxes[0][0] == 5 and symmaxes[1][0] == 2 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'I'
|
||||||
|
|
||||||
|
# 5-Fold + 3-fold intersecting at 0.5*acos((4*sqrt(5)-5)/15) -> I
|
||||||
|
angle_tgt = 0.5*np.arccos((4*np.sqrt(5)-5)/15)
|
||||||
|
if (symmaxes[0][0] == 5 and symmaxes[1][0] == 3 and torch.abs(angle - angle_tgt) < angle_cut):
|
||||||
|
symmgroup = 'I'
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
#fd: we could use a single symmetry here instead.
|
||||||
|
# But these cases mostly are bad BIOUNIT annotations...
|
||||||
|
#print ('nomatch',angle, [(x,y) for x,_,_,y in symmaxes])
|
||||||
|
|
||||||
|
return symmgroup, subsymm
|
||||||
|
|
||||||
|
|
||||||
|
def symm_subunit_matrix(symmid):
|
||||||
|
if (symmid[0]=='C'):
|
||||||
|
nsub = int(symmid[1:])
|
||||||
|
symmatrix = (
|
||||||
|
torch.arange(nsub)[:,None]-torch.arange(nsub)[None,:]
|
||||||
|
)%nsub
|
||||||
|
angles = torch.linspace(0,2*np.pi,nsub+1)[:nsub]
|
||||||
|
Rs = generateC(angles)
|
||||||
|
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(nsub)],
|
||||||
|
[min(3,nsub)]
|
||||||
|
)
|
||||||
|
|
||||||
|
if (nsub==1):
|
||||||
|
D = 0.0
|
||||||
|
else:
|
||||||
|
est_radius = 2.0*SYMA
|
||||||
|
theta = 2.0*np.pi/nsub
|
||||||
|
D = est_radius/np.sin(theta/2)
|
||||||
|
|
||||||
|
offset = torch.tensor([ 0.0,0.0,float(D) ])
|
||||||
|
elif (symmid[0]=='D'):
|
||||||
|
nsub = int(symmid[1:])
|
||||||
|
cblk=(torch.arange(nsub)[:,None]-torch.arange(nsub)[None,:])%nsub
|
||||||
|
symmatrix=torch.zeros((2*nsub,2*nsub),dtype=torch.long)
|
||||||
|
symmatrix[:nsub,:nsub] = cblk
|
||||||
|
symmatrix[:nsub,nsub:] = cblk+nsub
|
||||||
|
symmatrix[nsub:,:nsub] = cblk+nsub
|
||||||
|
symmatrix[nsub:,nsub:] = cblk
|
||||||
|
angles = torch.linspace(0,2*np.pi,nsub+1)[:nsub]
|
||||||
|
Rs = generateD(angles)
|
||||||
|
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(nsub), nsub+torch.arange(nsub)],
|
||||||
|
[min(3,nsub),2]
|
||||||
|
)
|
||||||
|
#metasymm = (
|
||||||
|
# [torch.arange(2*nsub)],
|
||||||
|
# [min(2*nsub,5)]
|
||||||
|
#)
|
||||||
|
|
||||||
|
est_radius = 2.0*SYMA
|
||||||
|
theta1 = 2.0*np.pi/nsub
|
||||||
|
theta2 = np.pi
|
||||||
|
D1 = est_radius/np.sin(theta1/2)
|
||||||
|
D2 = est_radius/np.sin(theta2/2)
|
||||||
|
offset = torch.tensor([ float(D2),0.0,float(D1) ])
|
||||||
|
#offset = torch.tensor([ 0.0,0.0,0.0 ])
|
||||||
|
elif (symmid=='T'):
|
||||||
|
symmatrix=torch.tensor(
|
||||||
|
[[ 0, 1, 2, 3, 8, 11, 9, 10, 4, 6, 7, 5],
|
||||||
|
[ 1, 0, 3, 2, 9, 10, 8, 11, 5, 7, 6, 4],
|
||||||
|
[ 2, 3, 0, 1, 10, 9, 11, 8, 6, 4, 5, 7],
|
||||||
|
[ 3, 2, 1, 0, 11, 8, 10, 9, 7, 5, 4, 6],
|
||||||
|
[ 4, 6, 7, 5, 0, 1, 2, 3, 8, 11, 9, 10],
|
||||||
|
[ 5, 7, 6, 4, 1, 0, 3, 2, 9, 10, 8, 11],
|
||||||
|
[ 6, 4, 5, 7, 2, 3, 0, 1, 10, 9, 11, 8],
|
||||||
|
[ 7, 5, 4, 6, 3, 2, 1, 0, 11, 8, 10, 9],
|
||||||
|
[ 8, 11, 9, 10, 4, 6, 7, 5, 0, 1, 2, 3],
|
||||||
|
[ 9, 10, 8, 11, 5, 7, 6, 4, 1, 0, 3, 2],
|
||||||
|
[10, 9, 11, 8, 6, 4, 5, 7, 2, 3, 0, 1],
|
||||||
|
[11, 8, 10, 9, 7, 5, 4, 6, 3, 2, 1, 0]])
|
||||||
|
Rs = torch.zeros(12,3,3)
|
||||||
|
Rs[ 0]=torch.tensor([[1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000],[0.000000,0.000000,1.000000]])
|
||||||
|
Rs[ 1]=torch.tensor([[-1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000],[0.000000,0.000000,1.000000]])
|
||||||
|
Rs[ 2]=torch.tensor([[-1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000],[0.000000,0.000000,-1.000000]])
|
||||||
|
Rs[ 3]=torch.tensor([[1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000],[0.000000,0.000000,-1.000000]])
|
||||||
|
Rs[ 4]=torch.tensor([[0.000000,0.000000,1.000000],[1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000]])
|
||||||
|
Rs[ 5]=torch.tensor([[0.000000,0.000000,1.000000],[-1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[ 6]=torch.tensor([[0.000000,0.000000,-1.000000],[-1.000000,0.000000,0.000000],[0.000000,1.000000,0.000000]])
|
||||||
|
Rs[ 7]=torch.tensor([[0.000000,0.000000,-1.000000],[1.000000,0.000000,0.000000],[0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[ 8]=torch.tensor([[0.000000,1.000000,0.000000],[0.000000,0.000000,1.000000],[1.000000,0.000000,0.000000]])
|
||||||
|
Rs[ 9]=torch.tensor([[0.000000,-1.000000,0.000000],[0.000000,0.000000,1.000000],[-1.000000,0.000000,0.000000]])
|
||||||
|
Rs[10]=torch.tensor([[0.000000,1.000000,0.000000],[0.000000,0.000000,-1.000000],[-1.000000,0.000000,0.000000]])
|
||||||
|
Rs[11]=torch.tensor([[0.000000,-1.000000,0.000000],[0.000000,0.000000,-1.000000],[1.000000,0.000000,0.000000]])
|
||||||
|
nneigh = 5
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(12)],
|
||||||
|
[6]
|
||||||
|
)
|
||||||
|
|
||||||
|
est_radius = 4.0*SYMA
|
||||||
|
offset = torch.tensor([ 1.0,0.0,0.0 ])
|
||||||
|
offset = est_radius * offset / torch.linalg.norm(offset)
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(12)],
|
||||||
|
[6]
|
||||||
|
)
|
||||||
|
|
||||||
|
elif (symmid=='O'):
|
||||||
|
symmatrix=torch.tensor(
|
||||||
|
[[ 0, 1, 2, 3, 8, 11, 9, 10, 4, 6, 7, 5, 12, 13, 15, 14, 19, 17,
|
||||||
|
18, 16, 22, 21, 20, 23],
|
||||||
|
[ 1, 0, 3, 2, 9, 10, 8, 11, 5, 7, 6, 4, 13, 12, 14, 15, 18, 16,
|
||||||
|
19, 17, 23, 20, 21, 22],
|
||||||
|
[ 2, 3, 0, 1, 10, 9, 11, 8, 6, 4, 5, 7, 14, 15, 13, 12, 17, 19,
|
||||||
|
16, 18, 20, 23, 22, 21],
|
||||||
|
[ 3, 2, 1, 0, 11, 8, 10, 9, 7, 5, 4, 6, 15, 14, 12, 13, 16, 18,
|
||||||
|
17, 19, 21, 22, 23, 20],
|
||||||
|
[ 4, 6, 7, 5, 0, 1, 2, 3, 8, 11, 9, 10, 16, 18, 17, 19, 21, 22,
|
||||||
|
23, 20, 15, 14, 12, 13],
|
||||||
|
[ 5, 7, 6, 4, 1, 0, 3, 2, 9, 10, 8, 11, 17, 19, 16, 18, 20, 23,
|
||||||
|
22, 21, 14, 15, 13, 12],
|
||||||
|
[ 6, 4, 5, 7, 2, 3, 0, 1, 10, 9, 11, 8, 18, 16, 19, 17, 23, 20,
|
||||||
|
21, 22, 13, 12, 14, 15],
|
||||||
|
[ 7, 5, 4, 6, 3, 2, 1, 0, 11, 8, 10, 9, 19, 17, 18, 16, 22, 21,
|
||||||
|
20, 23, 12, 13, 15, 14],
|
||||||
|
[ 8, 11, 9, 10, 4, 6, 7, 5, 0, 1, 2, 3, 20, 23, 22, 21, 14, 15,
|
||||||
|
13, 12, 17, 19, 16, 18],
|
||||||
|
[ 9, 10, 8, 11, 5, 7, 6, 4, 1, 0, 3, 2, 21, 22, 23, 20, 15, 14,
|
||||||
|
12, 13, 16, 18, 17, 19],
|
||||||
|
[10, 9, 11, 8, 6, 4, 5, 7, 2, 3, 0, 1, 22, 21, 20, 23, 12, 13,
|
||||||
|
15, 14, 19, 17, 18, 16],
|
||||||
|
[11, 8, 10, 9, 7, 5, 4, 6, 3, 2, 1, 0, 23, 20, 21, 22, 13, 12,
|
||||||
|
14, 15, 18, 16, 19, 17],
|
||||||
|
[12, 13, 15, 14, 19, 17, 18, 16, 22, 21, 20, 23, 0, 1, 2, 3, 8, 11,
|
||||||
|
9, 10, 4, 6, 7, 5],
|
||||||
|
[13, 12, 14, 15, 18, 16, 19, 17, 23, 20, 21, 22, 1, 0, 3, 2, 9, 10,
|
||||||
|
8, 11, 5, 7, 6, 4],
|
||||||
|
[14, 15, 13, 12, 17, 19, 16, 18, 20, 23, 22, 21, 2, 3, 0, 1, 10, 9,
|
||||||
|
11, 8, 6, 4, 5, 7],
|
||||||
|
[15, 14, 12, 13, 16, 18, 17, 19, 21, 22, 23, 20, 3, 2, 1, 0, 11, 8,
|
||||||
|
10, 9, 7, 5, 4, 6],
|
||||||
|
[16, 18, 17, 19, 21, 22, 23, 20, 15, 14, 12, 13, 4, 6, 7, 5, 0, 1,
|
||||||
|
2, 3, 8, 11, 9, 10],
|
||||||
|
[17, 19, 16, 18, 20, 23, 22, 21, 14, 15, 13, 12, 5, 7, 6, 4, 1, 0,
|
||||||
|
3, 2, 9, 10, 8, 11],
|
||||||
|
[18, 16, 19, 17, 23, 20, 21, 22, 13, 12, 14, 15, 6, 4, 5, 7, 2, 3,
|
||||||
|
0, 1, 10, 9, 11, 8],
|
||||||
|
[19, 17, 18, 16, 22, 21, 20, 23, 12, 13, 15, 14, 7, 5, 4, 6, 3, 2,
|
||||||
|
1, 0, 11, 8, 10, 9],
|
||||||
|
[20, 23, 22, 21, 14, 15, 13, 12, 17, 19, 16, 18, 8, 11, 9, 10, 4, 6,
|
||||||
|
7, 5, 0, 1, 2, 3],
|
||||||
|
[21, 22, 23, 20, 15, 14, 12, 13, 16, 18, 17, 19, 9, 10, 8, 11, 5, 7,
|
||||||
|
6, 4, 1, 0, 3, 2],
|
||||||
|
[22, 21, 20, 23, 12, 13, 15, 14, 19, 17, 18, 16, 10, 9, 11, 8, 6, 4,
|
||||||
|
5, 7, 2, 3, 0, 1],
|
||||||
|
[23, 20, 21, 22, 13, 12, 14, 15, 18, 16, 19, 17, 11, 8, 10, 9, 7, 5,
|
||||||
|
4, 6, 3, 2, 1, 0]])
|
||||||
|
Rs = torch.zeros(24,3,3)
|
||||||
|
Rs[0]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[1]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[2]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[3]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[4]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[5]=torch.tensor([[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[6]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[7]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[8]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[9]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[10]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[11]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[12]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[13]=torch.tensor([[ 0.000000,-1.000000,0.000000],[-1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[14]=torch.tensor([[ 0.000000, 1.000000,0.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[15]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[16]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[17]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[18]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[19]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[20]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 0.000000, 1.000000,0.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[21]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 0.000000,-1.000000,0.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[22]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 0.000000, 1.000000,0.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[23]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 0.000000,-1.000000,0.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
|
||||||
|
est_radius = 6.0*SYMA
|
||||||
|
offset = torch.tensor([ 1.0,0.0,0.0 ])
|
||||||
|
offset = est_radius * offset / torch.linalg.norm(offset)
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(24)],
|
||||||
|
[6]
|
||||||
|
)
|
||||||
|
elif (symmid=='I'):
|
||||||
|
symmatrix=torch.tensor(
|
||||||
|
[[ 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59, 38, 15, 16, 17,
|
||||||
|
18, 19, 40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6, 34, 50,
|
||||||
|
58, 37, 14, 26, 20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53,
|
||||||
|
56, 25, 54, 57, 36, 13],
|
||||||
|
[ 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55, 39, 16, 17, 18,
|
||||||
|
19, 15, 41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7, 30, 51,
|
||||||
|
59, 38, 10, 27, 21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54,
|
||||||
|
57, 26, 50, 58, 37, 14],
|
||||||
|
[ 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56, 35, 17, 18, 19,
|
||||||
|
15, 16, 42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8, 31, 52,
|
||||||
|
55, 39, 11, 28, 22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50,
|
||||||
|
58, 27, 51, 59, 38, 10],
|
||||||
|
[ 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57, 36, 18, 19, 15,
|
||||||
|
16, 17, 43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9, 32, 53,
|
||||||
|
56, 35, 12, 29, 23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51,
|
||||||
|
59, 28, 52, 55, 39, 11],
|
||||||
|
[ 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58, 37, 19, 15, 16,
|
||||||
|
17, 18, 44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5, 33, 54,
|
||||||
|
57, 36, 13, 25, 24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52,
|
||||||
|
55, 29, 53, 56, 35, 12],
|
||||||
|
[ 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18, 19, 10, 27, 51,
|
||||||
|
59, 38, 45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32, 48, 55,
|
||||||
|
39, 11, 28, 52, 25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24,
|
||||||
|
7, 20, 8, 31, 47, 44],
|
||||||
|
[ 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19, 15, 11, 28, 52,
|
||||||
|
55, 39, 46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33, 49, 56,
|
||||||
|
35, 12, 29, 53, 26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20,
|
||||||
|
8, 21, 9, 32, 48, 40],
|
||||||
|
[ 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15, 16, 12, 29, 53,
|
||||||
|
56, 35, 47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34, 45, 57,
|
||||||
|
36, 13, 25, 54, 27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21,
|
||||||
|
9, 22, 5, 33, 49, 41],
|
||||||
|
[ 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16, 17, 13, 25, 54,
|
||||||
|
57, 36, 48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30, 46, 58,
|
||||||
|
37, 14, 26, 50, 28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22,
|
||||||
|
5, 23, 6, 34, 45, 42],
|
||||||
|
[ 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17, 18, 14, 26, 50,
|
||||||
|
58, 37, 49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31, 47, 59,
|
||||||
|
38, 10, 27, 51, 29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23,
|
||||||
|
6, 24, 7, 30, 46, 43],
|
||||||
|
[10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2, 1, 5, 33, 49,
|
||||||
|
41, 22, 50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28, 52, 40,
|
||||||
|
21, 9, 32, 48, 30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36,
|
||||||
|
13, 35, 12, 29, 53, 56],
|
||||||
|
[11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3, 2, 6, 34, 45,
|
||||||
|
42, 23, 51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29, 53, 41,
|
||||||
|
22, 5, 33, 49, 31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37,
|
||||||
|
14, 36, 13, 25, 54, 57],
|
||||||
|
[12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4, 3, 7, 30, 46,
|
||||||
|
43, 24, 52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25, 54, 42,
|
||||||
|
23, 6, 34, 45, 32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38,
|
||||||
|
10, 37, 14, 26, 50, 58],
|
||||||
|
[13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0, 4, 8, 31, 47,
|
||||||
|
44, 20, 53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26, 50, 43,
|
||||||
|
24, 7, 30, 46, 33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39,
|
||||||
|
11, 38, 10, 27, 51, 59],
|
||||||
|
[14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1, 0, 9, 32, 48,
|
||||||
|
40, 21, 54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27, 51, 44,
|
||||||
|
20, 8, 31, 47, 34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35,
|
||||||
|
12, 39, 11, 28, 52, 55],
|
||||||
|
[15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41, 22, 0, 4, 3,
|
||||||
|
2, 1, 55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14, 26, 45,
|
||||||
|
42, 23, 6, 34, 35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47,
|
||||||
|
44, 30, 46, 43, 24, 7],
|
||||||
|
[16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42, 23, 1, 0, 4,
|
||||||
|
3, 2, 56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10, 27, 46,
|
||||||
|
43, 24, 7, 30, 36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48,
|
||||||
|
40, 31, 47, 44, 20, 8],
|
||||||
|
[17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43, 24, 2, 1, 0,
|
||||||
|
4, 3, 57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11, 28, 47,
|
||||||
|
44, 20, 8, 31, 37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49,
|
||||||
|
41, 32, 48, 40, 21, 9],
|
||||||
|
[18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44, 20, 3, 2, 1,
|
||||||
|
0, 4, 58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12, 29, 48,
|
||||||
|
40, 21, 9, 32, 38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45,
|
||||||
|
42, 33, 49, 41, 22, 5],
|
||||||
|
[19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40, 21, 4, 3, 2,
|
||||||
|
1, 0, 59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13, 25, 49,
|
||||||
|
41, 22, 5, 33, 39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46,
|
||||||
|
43, 34, 45, 42, 23, 6],
|
||||||
|
[20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53, 56, 25, 54, 57,
|
||||||
|
36, 13, 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59, 38, 15,
|
||||||
|
16, 17, 18, 19, 40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6,
|
||||||
|
34, 50, 58, 37, 14, 26],
|
||||||
|
[21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54, 57, 26, 50, 58,
|
||||||
|
37, 14, 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55, 39, 16,
|
||||||
|
17, 18, 19, 15, 41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7,
|
||||||
|
30, 51, 59, 38, 10, 27],
|
||||||
|
[22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50, 58, 27, 51, 59,
|
||||||
|
38, 10, 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56, 35, 17,
|
||||||
|
18, 19, 15, 16, 42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8,
|
||||||
|
31, 52, 55, 39, 11, 28],
|
||||||
|
[23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51, 59, 28, 52, 55,
|
||||||
|
39, 11, 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57, 36, 18,
|
||||||
|
19, 15, 16, 17, 43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9,
|
||||||
|
32, 53, 56, 35, 12, 29],
|
||||||
|
[24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52, 55, 29, 53, 56,
|
||||||
|
35, 12, 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58, 37, 19,
|
||||||
|
15, 16, 17, 18, 44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5,
|
||||||
|
33, 54, 57, 36, 13, 25],
|
||||||
|
[25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24, 7, 20, 8, 31,
|
||||||
|
47, 44, 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18, 19, 10,
|
||||||
|
27, 51, 59, 38, 45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32,
|
||||||
|
48, 55, 39, 11, 28, 52],
|
||||||
|
[26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20, 8, 21, 9, 32,
|
||||||
|
48, 40, 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19, 15, 11,
|
||||||
|
28, 52, 55, 39, 46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33,
|
||||||
|
49, 56, 35, 12, 29, 53],
|
||||||
|
[27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21, 9, 22, 5, 33,
|
||||||
|
49, 41, 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15, 16, 12,
|
||||||
|
29, 53, 56, 35, 47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34,
|
||||||
|
45, 57, 36, 13, 25, 54],
|
||||||
|
[28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22, 5, 23, 6, 34,
|
||||||
|
45, 42, 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16, 17, 13,
|
||||||
|
25, 54, 57, 36, 48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30,
|
||||||
|
46, 58, 37, 14, 26, 50],
|
||||||
|
[29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23, 6, 24, 7, 30,
|
||||||
|
46, 43, 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17, 18, 14,
|
||||||
|
26, 50, 58, 37, 49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31,
|
||||||
|
47, 59, 38, 10, 27, 51],
|
||||||
|
[30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36, 13, 35, 12, 29,
|
||||||
|
53, 56, 10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2, 1, 5,
|
||||||
|
33, 49, 41, 22, 50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28,
|
||||||
|
52, 40, 21, 9, 32, 48],
|
||||||
|
[31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37, 14, 36, 13, 25,
|
||||||
|
54, 57, 11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3, 2, 6,
|
||||||
|
34, 45, 42, 23, 51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29,
|
||||||
|
53, 41, 22, 5, 33, 49],
|
||||||
|
[32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38, 10, 37, 14, 26,
|
||||||
|
50, 58, 12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4, 3, 7,
|
||||||
|
30, 46, 43, 24, 52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25,
|
||||||
|
54, 42, 23, 6, 34, 45],
|
||||||
|
[33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39, 11, 38, 10, 27,
|
||||||
|
51, 59, 13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0, 4, 8,
|
||||||
|
31, 47, 44, 20, 53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26,
|
||||||
|
50, 43, 24, 7, 30, 46],
|
||||||
|
[34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35, 12, 39, 11, 28,
|
||||||
|
52, 55, 14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1, 0, 9,
|
||||||
|
32, 48, 40, 21, 54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27,
|
||||||
|
51, 44, 20, 8, 31, 47],
|
||||||
|
[35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47, 44, 30, 46, 43,
|
||||||
|
24, 7, 15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41, 22, 0,
|
||||||
|
4, 3, 2, 1, 55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14,
|
||||||
|
26, 45, 42, 23, 6, 34],
|
||||||
|
[36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48, 40, 31, 47, 44,
|
||||||
|
20, 8, 16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42, 23, 1,
|
||||||
|
0, 4, 3, 2, 56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10,
|
||||||
|
27, 46, 43, 24, 7, 30],
|
||||||
|
[37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49, 41, 32, 48, 40,
|
||||||
|
21, 9, 17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43, 24, 2,
|
||||||
|
1, 0, 4, 3, 57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11,
|
||||||
|
28, 47, 44, 20, 8, 31],
|
||||||
|
[38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45, 42, 33, 49, 41,
|
||||||
|
22, 5, 18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44, 20, 3,
|
||||||
|
2, 1, 0, 4, 58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12,
|
||||||
|
29, 48, 40, 21, 9, 32],
|
||||||
|
[39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46, 43, 34, 45, 42,
|
||||||
|
23, 6, 19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40, 21, 4,
|
||||||
|
3, 2, 1, 0, 59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13,
|
||||||
|
25, 49, 41, 22, 5, 33],
|
||||||
|
[40, 21, 9, 32, 48, 55, 39, 11, 28, 52, 45, 42, 23, 6, 34, 50, 58, 37,
|
||||||
|
14, 26, 20, 8, 31, 47, 44, 30, 46, 43, 24, 7, 35, 12, 29, 53, 56, 25,
|
||||||
|
54, 57, 36, 13, 0, 4, 3, 2, 1, 5, 33, 49, 41, 22, 10, 27, 51, 59,
|
||||||
|
38, 15, 16, 17, 18, 19],
|
||||||
|
[41, 22, 5, 33, 49, 56, 35, 12, 29, 53, 46, 43, 24, 7, 30, 51, 59, 38,
|
||||||
|
10, 27, 21, 9, 32, 48, 40, 31, 47, 44, 20, 8, 36, 13, 25, 54, 57, 26,
|
||||||
|
50, 58, 37, 14, 1, 0, 4, 3, 2, 6, 34, 45, 42, 23, 11, 28, 52, 55,
|
||||||
|
39, 16, 17, 18, 19, 15],
|
||||||
|
[42, 23, 6, 34, 45, 57, 36, 13, 25, 54, 47, 44, 20, 8, 31, 52, 55, 39,
|
||||||
|
11, 28, 22, 5, 33, 49, 41, 32, 48, 40, 21, 9, 37, 14, 26, 50, 58, 27,
|
||||||
|
51, 59, 38, 10, 2, 1, 0, 4, 3, 7, 30, 46, 43, 24, 12, 29, 53, 56,
|
||||||
|
35, 17, 18, 19, 15, 16],
|
||||||
|
[43, 24, 7, 30, 46, 58, 37, 14, 26, 50, 48, 40, 21, 9, 32, 53, 56, 35,
|
||||||
|
12, 29, 23, 6, 34, 45, 42, 33, 49, 41, 22, 5, 38, 10, 27, 51, 59, 28,
|
||||||
|
52, 55, 39, 11, 3, 2, 1, 0, 4, 8, 31, 47, 44, 20, 13, 25, 54, 57,
|
||||||
|
36, 18, 19, 15, 16, 17],
|
||||||
|
[44, 20, 8, 31, 47, 59, 38, 10, 27, 51, 49, 41, 22, 5, 33, 54, 57, 36,
|
||||||
|
13, 25, 24, 7, 30, 46, 43, 34, 45, 42, 23, 6, 39, 11, 28, 52, 55, 29,
|
||||||
|
53, 56, 35, 12, 4, 3, 2, 1, 0, 9, 32, 48, 40, 21, 14, 26, 50, 58,
|
||||||
|
37, 19, 15, 16, 17, 18],
|
||||||
|
[45, 42, 23, 6, 34, 50, 58, 37, 14, 26, 40, 21, 9, 32, 48, 55, 39, 11,
|
||||||
|
28, 52, 25, 54, 57, 36, 13, 35, 12, 29, 53, 56, 30, 46, 43, 24, 7, 20,
|
||||||
|
8, 31, 47, 44, 5, 33, 49, 41, 22, 0, 4, 3, 2, 1, 15, 16, 17, 18,
|
||||||
|
19, 10, 27, 51, 59, 38],
|
||||||
|
[46, 43, 24, 7, 30, 51, 59, 38, 10, 27, 41, 22, 5, 33, 49, 56, 35, 12,
|
||||||
|
29, 53, 26, 50, 58, 37, 14, 36, 13, 25, 54, 57, 31, 47, 44, 20, 8, 21,
|
||||||
|
9, 32, 48, 40, 6, 34, 45, 42, 23, 1, 0, 4, 3, 2, 16, 17, 18, 19,
|
||||||
|
15, 11, 28, 52, 55, 39],
|
||||||
|
[47, 44, 20, 8, 31, 52, 55, 39, 11, 28, 42, 23, 6, 34, 45, 57, 36, 13,
|
||||||
|
25, 54, 27, 51, 59, 38, 10, 37, 14, 26, 50, 58, 32, 48, 40, 21, 9, 22,
|
||||||
|
5, 33, 49, 41, 7, 30, 46, 43, 24, 2, 1, 0, 4, 3, 17, 18, 19, 15,
|
||||||
|
16, 12, 29, 53, 56, 35],
|
||||||
|
[48, 40, 21, 9, 32, 53, 56, 35, 12, 29, 43, 24, 7, 30, 46, 58, 37, 14,
|
||||||
|
26, 50, 28, 52, 55, 39, 11, 38, 10, 27, 51, 59, 33, 49, 41, 22, 5, 23,
|
||||||
|
6, 34, 45, 42, 8, 31, 47, 44, 20, 3, 2, 1, 0, 4, 18, 19, 15, 16,
|
||||||
|
17, 13, 25, 54, 57, 36],
|
||||||
|
[49, 41, 22, 5, 33, 54, 57, 36, 13, 25, 44, 20, 8, 31, 47, 59, 38, 10,
|
||||||
|
27, 51, 29, 53, 56, 35, 12, 39, 11, 28, 52, 55, 34, 45, 42, 23, 6, 24,
|
||||||
|
7, 30, 46, 43, 9, 32, 48, 40, 21, 4, 3, 2, 1, 0, 19, 15, 16, 17,
|
||||||
|
18, 14, 26, 50, 58, 37],
|
||||||
|
[50, 58, 37, 14, 26, 45, 42, 23, 6, 34, 55, 39, 11, 28, 52, 40, 21, 9,
|
||||||
|
32, 48, 30, 46, 43, 24, 7, 20, 8, 31, 47, 44, 25, 54, 57, 36, 13, 35,
|
||||||
|
12, 29, 53, 56, 10, 27, 51, 59, 38, 15, 16, 17, 18, 19, 0, 4, 3, 2,
|
||||||
|
1, 5, 33, 49, 41, 22],
|
||||||
|
[51, 59, 38, 10, 27, 46, 43, 24, 7, 30, 56, 35, 12, 29, 53, 41, 22, 5,
|
||||||
|
33, 49, 31, 47, 44, 20, 8, 21, 9, 32, 48, 40, 26, 50, 58, 37, 14, 36,
|
||||||
|
13, 25, 54, 57, 11, 28, 52, 55, 39, 16, 17, 18, 19, 15, 1, 0, 4, 3,
|
||||||
|
2, 6, 34, 45, 42, 23],
|
||||||
|
[52, 55, 39, 11, 28, 47, 44, 20, 8, 31, 57, 36, 13, 25, 54, 42, 23, 6,
|
||||||
|
34, 45, 32, 48, 40, 21, 9, 22, 5, 33, 49, 41, 27, 51, 59, 38, 10, 37,
|
||||||
|
14, 26, 50, 58, 12, 29, 53, 56, 35, 17, 18, 19, 15, 16, 2, 1, 0, 4,
|
||||||
|
3, 7, 30, 46, 43, 24],
|
||||||
|
[53, 56, 35, 12, 29, 48, 40, 21, 9, 32, 58, 37, 14, 26, 50, 43, 24, 7,
|
||||||
|
30, 46, 33, 49, 41, 22, 5, 23, 6, 34, 45, 42, 28, 52, 55, 39, 11, 38,
|
||||||
|
10, 27, 51, 59, 13, 25, 54, 57, 36, 18, 19, 15, 16, 17, 3, 2, 1, 0,
|
||||||
|
4, 8, 31, 47, 44, 20],
|
||||||
|
[54, 57, 36, 13, 25, 49, 41, 22, 5, 33, 59, 38, 10, 27, 51, 44, 20, 8,
|
||||||
|
31, 47, 34, 45, 42, 23, 6, 24, 7, 30, 46, 43, 29, 53, 56, 35, 12, 39,
|
||||||
|
11, 28, 52, 55, 14, 26, 50, 58, 37, 19, 15, 16, 17, 18, 4, 3, 2, 1,
|
||||||
|
0, 9, 32, 48, 40, 21],
|
||||||
|
[55, 39, 11, 28, 52, 40, 21, 9, 32, 48, 50, 58, 37, 14, 26, 45, 42, 23,
|
||||||
|
6, 34, 35, 12, 29, 53, 56, 25, 54, 57, 36, 13, 20, 8, 31, 47, 44, 30,
|
||||||
|
46, 43, 24, 7, 15, 16, 17, 18, 19, 10, 27, 51, 59, 38, 5, 33, 49, 41,
|
||||||
|
22, 0, 4, 3, 2, 1],
|
||||||
|
[56, 35, 12, 29, 53, 41, 22, 5, 33, 49, 51, 59, 38, 10, 27, 46, 43, 24,
|
||||||
|
7, 30, 36, 13, 25, 54, 57, 26, 50, 58, 37, 14, 21, 9, 32, 48, 40, 31,
|
||||||
|
47, 44, 20, 8, 16, 17, 18, 19, 15, 11, 28, 52, 55, 39, 6, 34, 45, 42,
|
||||||
|
23, 1, 0, 4, 3, 2],
|
||||||
|
[57, 36, 13, 25, 54, 42, 23, 6, 34, 45, 52, 55, 39, 11, 28, 47, 44, 20,
|
||||||
|
8, 31, 37, 14, 26, 50, 58, 27, 51, 59, 38, 10, 22, 5, 33, 49, 41, 32,
|
||||||
|
48, 40, 21, 9, 17, 18, 19, 15, 16, 12, 29, 53, 56, 35, 7, 30, 46, 43,
|
||||||
|
24, 2, 1, 0, 4, 3],
|
||||||
|
[58, 37, 14, 26, 50, 43, 24, 7, 30, 46, 53, 56, 35, 12, 29, 48, 40, 21,
|
||||||
|
9, 32, 38, 10, 27, 51, 59, 28, 52, 55, 39, 11, 23, 6, 34, 45, 42, 33,
|
||||||
|
49, 41, 22, 5, 18, 19, 15, 16, 17, 13, 25, 54, 57, 36, 8, 31, 47, 44,
|
||||||
|
20, 3, 2, 1, 0, 4],
|
||||||
|
[59, 38, 10, 27, 51, 44, 20, 8, 31, 47, 54, 57, 36, 13, 25, 49, 41, 22,
|
||||||
|
5, 33, 39, 11, 28, 52, 55, 29, 53, 56, 35, 12, 24, 7, 30, 46, 43, 34,
|
||||||
|
45, 42, 23, 6, 19, 15, 16, 17, 18, 14, 26, 50, 58, 37, 9, 32, 48, 40,
|
||||||
|
21, 4, 3, 2, 1, 0]])
|
||||||
|
Rs = torch.zeros(60,3,3)
|
||||||
|
Rs[0]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[1]=torch.tensor([[ 0.500000,-0.809017,0.309017],[ 0.809017 ,0.309017,-0.500000],[ 0.309017, 0.500000,0.809017]])
|
||||||
|
Rs[2]=torch.tensor([[-0.309017,-0.500000,0.809017],[ 0.500000,-0.809017,-0.309017],[ 0.809017, 0.309017,0.500000]])
|
||||||
|
Rs[3]=torch.tensor([[-0.309017, 0.500000,0.809017],[-0.500000,-0.809017,0.309017],[ 0.809017,-0.309017,0.500000]])
|
||||||
|
Rs[4]=torch.tensor([[ 0.500000, 0.809017,0.309017],[-0.809017, 0.309017,0.500000],[ 0.309017,-0.500000,0.809017]])
|
||||||
|
Rs[5]=torch.tensor([[-0.809017, 0.309017,0.500000],[ 0.309017,-0.500000,0.809017],[ 0.500000, 0.809017,0.309017]])
|
||||||
|
Rs[6]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[7]=torch.tensor([[ 0.809017 ,0.309017,-0.500000],[ 0.309017, 0.500000,0.809017],[ 0.500000,-0.809017,0.309017]])
|
||||||
|
Rs[8]=torch.tensor([[ 0.500000,-0.809017,-0.309017],[ 0.809017, 0.309017,0.500000],[-0.309017,-0.500000,0.809017]])
|
||||||
|
Rs[9]=torch.tensor([[-0.500000,-0.809017,0.309017],[ 0.809017,-0.309017,0.500000],[-0.309017, 0.500000,0.809017]])
|
||||||
|
Rs[10]=torch.tensor([[-0.500000,-0.809017,0.309017],[-0.809017 ,0.309017,-0.500000],[ 0.309017,-0.500000,-0.809017]])
|
||||||
|
Rs[11]=torch.tensor([[-0.809017, 0.309017,0.500000],[-0.309017 ,0.500000,-0.809017],[-0.500000,-0.809017,-0.309017]])
|
||||||
|
Rs[12]=torch.tensor([[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[13]=torch.tensor([[ 0.809017 ,0.309017,-0.500000],[-0.309017,-0.500000,-0.809017],[-0.500000 ,0.809017,-0.309017]])
|
||||||
|
Rs[14]=torch.tensor([[ 0.500000,-0.809017,-0.309017],[-0.809017,-0.309017,-0.500000],[ 0.309017 ,0.500000,-0.809017]])
|
||||||
|
Rs[15]=torch.tensor([[ 0.309017 ,0.500000,-0.809017],[ 0.500000,-0.809017,-0.309017],[-0.809017,-0.309017,-0.500000]])
|
||||||
|
Rs[16]=torch.tensor([[ 0.309017,-0.500000,-0.809017],[-0.500000,-0.809017,0.309017],[-0.809017 ,0.309017,-0.500000]])
|
||||||
|
Rs[17]=torch.tensor([[-0.500000,-0.809017,-0.309017],[-0.809017, 0.309017,0.500000],[-0.309017 ,0.500000,-0.809017]])
|
||||||
|
Rs[18]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[19]=torch.tensor([[-0.500000 ,0.809017,-0.309017],[ 0.809017 ,0.309017,-0.500000],[-0.309017,-0.500000,-0.809017]])
|
||||||
|
Rs[20]=torch.tensor([[-0.500000,-0.809017,-0.309017],[ 0.809017,-0.309017,-0.500000],[ 0.309017,-0.500000,0.809017]])
|
||||||
|
Rs[21]=torch.tensor([[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000]])
|
||||||
|
Rs[22]=torch.tensor([[-0.500000 ,0.809017,-0.309017],[-0.809017,-0.309017,0.500000],[ 0.309017, 0.500000,0.809017]])
|
||||||
|
Rs[23]=torch.tensor([[ 0.309017 ,0.500000,-0.809017],[-0.500000, 0.809017,0.309017],[ 0.809017, 0.309017,0.500000]])
|
||||||
|
Rs[24]=torch.tensor([[ 0.309017,-0.500000,-0.809017],[ 0.500000 ,0.809017,-0.309017],[ 0.809017,-0.309017,0.500000]])
|
||||||
|
Rs[25]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[26]=torch.tensor([[-0.309017,-0.500000,-0.809017],[-0.500000 ,0.809017,-0.309017],[ 0.809017 ,0.309017,-0.500000]])
|
||||||
|
Rs[27]=torch.tensor([[-0.809017,-0.309017,-0.500000],[ 0.309017 ,0.500000,-0.809017],[ 0.500000,-0.809017,-0.309017]])
|
||||||
|
Rs[28]=torch.tensor([[-0.809017 ,0.309017,-0.500000],[ 0.309017,-0.500000,-0.809017],[-0.500000,-0.809017,0.309017]])
|
||||||
|
Rs[29]=torch.tensor([[-0.309017 ,0.500000,-0.809017],[-0.500000,-0.809017,-0.309017],[-0.809017, 0.309017,0.500000]])
|
||||||
|
Rs[30]=torch.tensor([[ 0.809017, 0.309017,0.500000],[-0.309017,-0.500000,0.809017],[ 0.500000,-0.809017,-0.309017]])
|
||||||
|
Rs[31]=torch.tensor([[ 0.809017,-0.309017,0.500000],[-0.309017, 0.500000,0.809017],[-0.500000,-0.809017,0.309017]])
|
||||||
|
Rs[32]=torch.tensor([[ 0.309017,-0.500000,0.809017],[ 0.500000, 0.809017,0.309017],[-0.809017, 0.309017,0.500000]])
|
||||||
|
Rs[33]=torch.tensor([[ 0.000000, 0.000000,1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000, 1.000000,0.000000]])
|
||||||
|
Rs[34]=torch.tensor([[ 0.309017, 0.500000,0.809017],[ 0.500000,-0.809017,0.309017],[ 0.809017 ,0.309017,-0.500000]])
|
||||||
|
Rs[35]=torch.tensor([[-0.309017, 0.500000,0.809017],[ 0.500000 ,0.809017,-0.309017],[-0.809017 ,0.309017,-0.500000]])
|
||||||
|
Rs[36]=torch.tensor([[ 0.500000, 0.809017,0.309017],[ 0.809017,-0.309017,-0.500000],[-0.309017 ,0.500000,-0.809017]])
|
||||||
|
Rs[37]=torch.tensor([[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000]])
|
||||||
|
Rs[38]=torch.tensor([[ 0.500000,-0.809017,0.309017],[-0.809017,-0.309017,0.500000],[-0.309017,-0.500000,-0.809017]])
|
||||||
|
Rs[39]=torch.tensor([[-0.309017,-0.500000,0.809017],[-0.500000, 0.809017,0.309017],[-0.809017,-0.309017,-0.500000]])
|
||||||
|
Rs[40]=torch.tensor([[-0.500000, 0.809017,0.309017],[-0.809017,-0.309017,-0.500000],[-0.309017,-0.500000,0.809017]])
|
||||||
|
Rs[41]=torch.tensor([[ 0.500000 ,0.809017,-0.309017],[-0.809017 ,0.309017,-0.500000],[-0.309017, 0.500000,0.809017]])
|
||||||
|
Rs[42]=torch.tensor([[ 0.809017,-0.309017,-0.500000],[-0.309017 ,0.500000,-0.809017],[ 0.500000, 0.809017,0.309017]])
|
||||||
|
Rs[43]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[44]=torch.tensor([[-0.809017,-0.309017,0.500000],[-0.309017,-0.500000,-0.809017],[ 0.500000,-0.809017,0.309017]])
|
||||||
|
Rs[45]=torch.tensor([[ 0.809017,-0.309017,0.500000],[ 0.309017,-0.500000,-0.809017],[ 0.500000 ,0.809017,-0.309017]])
|
||||||
|
Rs[46]=torch.tensor([[ 0.309017,-0.500000,0.809017],[-0.500000,-0.809017,-0.309017],[ 0.809017,-0.309017,-0.500000]])
|
||||||
|
Rs[47]=torch.tensor([[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[48]=torch.tensor([[ 0.309017, 0.500000,0.809017],[-0.500000 ,0.809017,-0.309017],[-0.809017,-0.309017,0.500000]])
|
||||||
|
Rs[49]=torch.tensor([[ 0.809017, 0.309017,0.500000],[ 0.309017 ,0.500000,-0.809017],[-0.500000, 0.809017,0.309017]])
|
||||||
|
Rs[50]=torch.tensor([[-0.309017 ,0.500000,-0.809017],[ 0.500000, 0.809017,0.309017],[ 0.809017,-0.309017,-0.500000]])
|
||||||
|
Rs[51]=torch.tensor([[ 0.000000 ,0.000000,-1.000000],[ 1.000000, 0.000000,0.000000],[ 0.000000,-1.000000,0.000000]])
|
||||||
|
Rs[52]=torch.tensor([[-0.309017,-0.500000,-0.809017],[ 0.500000,-0.809017,0.309017],[-0.809017,-0.309017,0.500000]])
|
||||||
|
Rs[53]=torch.tensor([[-0.809017,-0.309017,-0.500000],[-0.309017,-0.500000,0.809017],[-0.500000, 0.809017,0.309017]])
|
||||||
|
Rs[54]=torch.tensor([[-0.809017 ,0.309017,-0.500000],[-0.309017, 0.500000,0.809017],[ 0.500000 ,0.809017,-0.309017]])
|
||||||
|
Rs[55]=torch.tensor([[ 0.000000,-1.000000,0.000000],[ 0.000000, 0.000000,1.000000],[-1.000000, 0.000000,0.000000]])
|
||||||
|
Rs[56]=torch.tensor([[-0.809017,-0.309017,0.500000],[ 0.309017, 0.500000,0.809017],[-0.500000 ,0.809017,-0.309017]])
|
||||||
|
Rs[57]=torch.tensor([[-0.500000, 0.809017,0.309017],[ 0.809017, 0.309017,0.500000],[ 0.309017 ,0.500000,-0.809017]])
|
||||||
|
Rs[58]=torch.tensor([[ 0.500000 ,0.809017,-0.309017],[ 0.809017,-0.309017,0.500000],[ 0.309017,-0.500000,-0.809017]])
|
||||||
|
Rs[59]=torch.tensor([[ 0.809017,-0.309017,-0.500000],[ 0.309017,-0.500000,0.809017],[-0.500000,-0.809017,-0.309017]])
|
||||||
|
|
||||||
|
est_radius = 10.0*SYMA
|
||||||
|
offset = torch.tensor([ 1.0,0.0,0.0 ])
|
||||||
|
offset = est_radius * offset / torch.linalg.norm(offset)
|
||||||
|
metasymm = (
|
||||||
|
[torch.arange(60)],
|
||||||
|
[6]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print ("Unknown symmetry",symmid)
|
||||||
|
assert False
|
||||||
|
|
||||||
|
return symmatrix,Rs,metasymm,offset
|
181
rf2aa/tensor_util.py
Normal file
181
rf2aa/tensor_util.py
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import json
|
||||||
|
from deepdiff import DeepDiff
|
||||||
|
import pprint
|
||||||
|
import assertpy
|
||||||
|
import dataclasses
|
||||||
|
from collections import OrderedDict
|
||||||
|
import contextlib
|
||||||
|
from deepdiff.operator import BaseOperator
|
||||||
|
|
||||||
|
def assert_shape(t, s):
|
||||||
|
assertpy.assert_that(tuple(t.shape)).is_equal_to(s)
|
||||||
|
|
||||||
|
def assert_same_shape(t, s):
|
||||||
|
assertpy.assert_that(tuple(t.shape)).is_equal_to(tuple(s.shape))
|
||||||
|
|
||||||
|
class ExceptionLogger(contextlib.AbstractContextManager):
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback):
|
||||||
|
if exc_type:
|
||||||
|
print("***Logging exception {}***".format((exc_type, exc_value,
|
||||||
|
traceback)))
|
||||||
|
|
||||||
|
|
||||||
|
def assert_equal(got, want):
|
||||||
|
assertpy.assert_that(got.dtype).is_equal_to(want.dtype)
|
||||||
|
assertpy.assert_that(got.shape).is_equal_to(want.shape)
|
||||||
|
is_eq = got.nan_to_num()==want.nan_to_num()
|
||||||
|
unequal_idx = torch.nonzero(~is_eq)
|
||||||
|
unequal_got = got[~is_eq]
|
||||||
|
unequal_want = want[~is_eq]
|
||||||
|
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_want, unequal_got))[:3]
|
||||||
|
|
||||||
|
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
|
||||||
|
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
|
||||||
|
if torch.numel(got) < 10:
|
||||||
|
msg = f'got {got}, want: {want}'
|
||||||
|
assert len(unequal_idx) == 0, msg
|
||||||
|
|
||||||
|
|
||||||
|
def assert_close(got, want, atol=1e-4, rtol=1e-8):
|
||||||
|
got = torch.nan_to_num(got)
|
||||||
|
want = torch.nan_to_num(want)
|
||||||
|
if got.shape != want.shape:
|
||||||
|
raise ValueError(f'Wrong shapes: got shape {got.shape} want shape {want.shape}')
|
||||||
|
elif not torch.allclose(got, want, atol=atol, rtol=rtol):
|
||||||
|
maximum_difference = torch.abs(got - want).max().item()
|
||||||
|
indices_different = torch.nonzero(got != want)
|
||||||
|
raise ValueError(f'Maximum difference: {maximum_difference}, indices different: {indices_different}')
|
||||||
|
|
||||||
|
|
||||||
|
def cpu(e):
|
||||||
|
if isinstance(e, dict):
|
||||||
|
return {cpu(k): cpu(v) for k,v in e.items()}
|
||||||
|
if isinstance(e, list) or isinstance(e, tuple):
|
||||||
|
return tuple(cpu(i) for i in e)
|
||||||
|
if hasattr(e, 'cpu'):
|
||||||
|
return e.cpu().detach()
|
||||||
|
return e
|
||||||
|
|
||||||
|
# Dataclass functions
|
||||||
|
|
||||||
|
def to_ordered_dict(dc):
|
||||||
|
return OrderedDict((field.name, getattr(dc, field.name)) for field in dataclasses.fields(dc))
|
||||||
|
|
||||||
|
def to_device(dc, device):
|
||||||
|
d = to_ordered_dict(dc)
|
||||||
|
for k, v in d.items():
|
||||||
|
if v is not None:
|
||||||
|
setattr(dc, k, v.to(device))
|
||||||
|
|
||||||
|
def shapes(dc):
|
||||||
|
d = to_ordered_dict(dc)
|
||||||
|
return {k:v.shape if hasattr(v, 'shape') else None for k,v in d.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def pprint_obj(obj):
|
||||||
|
pprint.pprint(obj.__dict__, indent=4)
|
||||||
|
|
||||||
|
def assert_squeeze(t):
|
||||||
|
assert t.shape[0] == 1, f'{t.shape}[0] != 1'
|
||||||
|
return t[0]
|
||||||
|
|
||||||
|
def apply_to_tensors(e, op):
|
||||||
|
# if isinstance(e, dataclasses.)
|
||||||
|
if dataclasses.is_dataclass(e):
|
||||||
|
# return to_ordered_dict
|
||||||
|
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
|
||||||
|
if isinstance(e, dict):
|
||||||
|
return {k: apply_to_tensors(v, op) for k,v in e.items()}
|
||||||
|
if isinstance(e, list) or isinstance(e, tuple):
|
||||||
|
return tuple(apply_to_tensors(i, op) for i in e)
|
||||||
|
if hasattr(e, 'cpu'):
|
||||||
|
return op(e)
|
||||||
|
return e
|
||||||
|
|
||||||
|
def apply_to_matching(e, op, filt):
|
||||||
|
# if isinstance(e, dataclasses.)
|
||||||
|
if filt(e):
|
||||||
|
return op(e)
|
||||||
|
if dataclasses.is_dataclass(e):
|
||||||
|
# return to_ordered_dict
|
||||||
|
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
|
||||||
|
if isinstance(e, dict):
|
||||||
|
return {k: apply_to_tensors(v, op) for k,v in e.items()}
|
||||||
|
if isinstance(e, list) or isinstance(e, tuple):
|
||||||
|
return tuple(apply_to_tensors(i, op) for i in e)
|
||||||
|
return e
|
||||||
|
|
||||||
|
def set_grad(t):
|
||||||
|
t.requires_grad = True
|
||||||
|
|
||||||
|
def require_grad(e):
|
||||||
|
apply_to_tensors(e, set_grad)
|
||||||
|
|
||||||
|
def get_grad(e):
|
||||||
|
return apply_to_tensors(e, lambda x: x.grad)
|
||||||
|
|
||||||
|
def info(e):
|
||||||
|
shap = apply_to_tensors(e, lambda x: x.shape)
|
||||||
|
shap = apply_to_matching(shap, str, dataclasses.is_dataclass)
|
||||||
|
return json.dumps(shap, indent=4)
|
||||||
|
|
||||||
|
def minmax(e):
|
||||||
|
return apply_to_tensors(e, lambda x: (torch.log10(torch.min(x)), torch.log10(torch.max(x))))
|
||||||
|
|
||||||
|
class TensorMatchOperator(BaseOperator):
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, atol=1e-3, rtol=0, **kwargs):
|
||||||
|
super(TensorMatchOperator, self).__init__(**kwargs)
|
||||||
|
self.atol = atol
|
||||||
|
self.rtol = rtol
|
||||||
|
|
||||||
|
def _equal_msg(self, got, want):
|
||||||
|
if got.shape != want.shape:
|
||||||
|
return f'got shape {got.shape} want shape {want.shape}'
|
||||||
|
if got.dtype != want.dtype:
|
||||||
|
return f'got dtype {got.dtype} want dtype {want.dtype}'
|
||||||
|
if torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol).all():
|
||||||
|
return ''
|
||||||
|
is_eq = torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol)
|
||||||
|
unequal_idx = torch.nonzero(~is_eq)
|
||||||
|
unequal_got = got[~is_eq]
|
||||||
|
unequal_want = want[~is_eq]
|
||||||
|
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
|
||||||
|
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
|
||||||
|
uneq_msg += f' fraction unequal:{unequal_got.numel()}/{got.numel()}'
|
||||||
|
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
|
||||||
|
if torch.numel(got) < 10:
|
||||||
|
msg = f'got {got}, want: {want}'
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def give_up_diffing(self, level, diff_instance):
|
||||||
|
msg = self._equal_msg(level.t1, level.t2)
|
||||||
|
if msg:
|
||||||
|
print(f'got:\n{level.t1}\n\nwant:\n{level.t2}')
|
||||||
|
msg = self._equal_msg(level.t1, level.t2)
|
||||||
|
if msg:
|
||||||
|
diff_instance.custom_report_result('tensors unequal', level, {
|
||||||
|
"msg": msg
|
||||||
|
})
|
||||||
|
return True
|
||||||
|
|
||||||
|
class NumpyMatchOperator(TensorMatchOperator):
|
||||||
|
def give_up_diffing(self, level, diff_instance):
|
||||||
|
level.t1 = torch.Tensor(level.t1)
|
||||||
|
level.t2 = torch.Tensor(level.t2)
|
||||||
|
return super(NumpyMatchOperator, self).give_up_diffing(level, diff_instance)
|
||||||
|
|
||||||
|
|
||||||
|
def cmp(got, want, **kwargs):
|
||||||
|
|
||||||
|
dd = DeepDiff(got, want, custom_operators=[
|
||||||
|
NumpyMatchOperator(types=[np.ndarray], **kwargs),
|
||||||
|
TensorMatchOperator(types=[torch.Tensor], **kwargs)])
|
||||||
|
if dd:
|
||||||
|
return dd
|
||||||
|
return ''
|
BIN
rf2aa/test_pickles/model/legacy_train_na_compl_regression.pt
Normal file
BIN
rf2aa/test_pickles/model/legacy_train_na_compl_regression.pt
Normal file
Binary file not shown.
BIN
rf2aa/test_pickles/model/legacy_train_rna_regression.pt
Normal file
BIN
rf2aa/test_pickles/model/legacy_train_rna_regression.pt
Normal file
Binary file not shown.
Binary file not shown.
BIN
rf2aa/test_pickles/model/legacy_train_sm_compl_regression.pt
Normal file
BIN
rf2aa/test_pickles/model/legacy_train_sm_compl_regression.pt
Normal file
Binary file not shown.
79
rf2aa/tests/test_conditions.py
Normal file
79
rf2aa/tests/test_conditions.py
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import itertools
|
||||||
|
from collections import OrderedDict
|
||||||
|
from hydra import initialize, compose
|
||||||
|
|
||||||
|
from rf2aa.setup_model import trainer_factory, seed_all
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
# configurations to test
|
||||||
|
configs = ["legacy_train"]
|
||||||
|
datasets = ["compl", "na_compl", "rna", "sm_compl", "sm_compl_covale", "sm_compl_asmb"]
|
||||||
|
|
||||||
|
cfg_overrides = [
|
||||||
|
"loader_params.p_msa_mask=0.0",
|
||||||
|
"loader_params.crop=100000",
|
||||||
|
"loader_params.mintplt=0",
|
||||||
|
"loader_params.maxtplt=2"
|
||||||
|
]
|
||||||
|
|
||||||
|
def make_deterministic(seed=0):
|
||||||
|
seed_all(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.backends.cudnn.deterministic = True
|
||||||
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
|
def setup_dataset_names():
|
||||||
|
data = {}
|
||||||
|
for name in datasets:
|
||||||
|
data[name] = [name]
|
||||||
|
return data
|
||||||
|
|
||||||
|
# set up models for regression tests
|
||||||
|
def setup_models(device="cpu"):
|
||||||
|
models, chem_cfgs = [], []
|
||||||
|
for config in configs:
|
||||||
|
with initialize(version_base=None, config_path="../config/train"):
|
||||||
|
cfg = compose(config_name=config, overrides=cfg_overrides)
|
||||||
|
|
||||||
|
# initializing the model needs the chemical DB initialized. Force a reload
|
||||||
|
ChemData.reset()
|
||||||
|
ChemData(cfg.chem_params)
|
||||||
|
|
||||||
|
trainer = trainer_factory[cfg.experiment.trainer](cfg)
|
||||||
|
seed_all()
|
||||||
|
trainer.construct_model(device=device)
|
||||||
|
models.append(trainer.model)
|
||||||
|
chem_cfgs.append(cfg.chem_params)
|
||||||
|
trainer = None
|
||||||
|
|
||||||
|
return dict(zip(configs, (zip(configs, models, chem_cfgs))))
|
||||||
|
|
||||||
|
# set up job array for regression
|
||||||
|
def setup_array(datasets, models, device="cpu"):
|
||||||
|
test_data = setup_dataset_names()
|
||||||
|
test_models = setup_models(device=device)
|
||||||
|
test_data = [test_data[dataset] for dataset in datasets]
|
||||||
|
test_models = [test_models[model] for model in models]
|
||||||
|
return (list(itertools.product(test_data, test_models)))
|
||||||
|
|
||||||
|
def random_param_init(model):
|
||||||
|
seed_all()
|
||||||
|
with torch.no_grad():
|
||||||
|
fake_state_dict = OrderedDict()
|
||||||
|
for name, param in model.model.named_parameters():
|
||||||
|
fake_state_dict[name] = torch.randn_like(param)
|
||||||
|
model.model.load_state_dict(fake_state_dict)
|
||||||
|
model.shadow.load_state_dict(fake_state_dict)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def dataset_pickle_path(dataset_name):
|
||||||
|
return f"test_pickles/data/{dataset_name}_regression.pt"
|
||||||
|
|
||||||
|
def model_pickle_path(dataset_name, model_name):
|
||||||
|
return f"test_pickles/model/{model_name}_{dataset_name}_regression.pt"
|
||||||
|
|
||||||
|
def loss_pickle_path(dataset_name, model_name, loss_name):
|
||||||
|
return f"test_pickles/loss/{loss_name}_{model_name}_{dataset_name}_regression.pt"
|
73
rf2aa/tests/test_model.py
Normal file
73
rf2aa/tests/test_model.py
Normal file
|
@ -0,0 +1,73 @@
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
from rf2aa.data.dataloader_adaptor import prepare_input
|
||||||
|
from rf2aa.training.recycling import run_model_forward_legacy
|
||||||
|
from rf2aa.tensor_util import assert_equal
|
||||||
|
from rf2aa.tests.test_conditions import setup_array,\
|
||||||
|
make_deterministic, dataset_pickle_path, model_pickle_path
|
||||||
|
from rf2aa.util_module import XYZConverter
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
|
||||||
|
# goal is to test all the configs on a broad set of datasets
|
||||||
|
|
||||||
|
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
legacy_test_conditions = setup_array(["na_compl", "rna", "sm_compl", "sm_compl_covale"], ["legacy_train"], device=gpu)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("example,model", legacy_test_conditions)
|
||||||
|
def test_regression_legacy(example, model):
|
||||||
|
dataset_name, dataset_inputs, model_name, model = setup_test(example, model)
|
||||||
|
make_deterministic()
|
||||||
|
output_i = run_model_forward_legacy(model, dataset_inputs, gpu)
|
||||||
|
model_pickle = model_pickle_path(dataset_name, model_name)
|
||||||
|
output_names = ("logits_c6d", "logits_aa", "logits_pae", \
|
||||||
|
"logits_pde", "p_bind", "xyz", "alpha", "xyz_allatom", \
|
||||||
|
"lddt", "seq", "pair", "state")
|
||||||
|
|
||||||
|
if not os.path.exists(model_pickle):
|
||||||
|
torch.save(output_i, model_pickle)
|
||||||
|
else:
|
||||||
|
output_regression = torch.load(model_pickle, map_location=gpu)
|
||||||
|
for idx, output in enumerate(output_i):
|
||||||
|
got = output
|
||||||
|
want = output_regression[idx]
|
||||||
|
if output_names[idx] == "logits_c6d":
|
||||||
|
for i in range(len(want)):
|
||||||
|
|
||||||
|
got_i = got[i]
|
||||||
|
want_i = want[i]
|
||||||
|
try:
|
||||||
|
assert_equal(got_i, want_i)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
|
||||||
|
elif output_names[idx] in ["alpha", "xyz_allatom", "seq", "pair", "state"]:
|
||||||
|
try:
|
||||||
|
assert torch.allclose(got, want, atol=1e-4)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
assert_equal(got, want)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
|
||||||
|
|
||||||
|
def setup_test(example, model):
|
||||||
|
model_name, model, config = model
|
||||||
|
|
||||||
|
# initialize chemical database
|
||||||
|
ChemData.reset() # force reload chemical data
|
||||||
|
ChemData(config)
|
||||||
|
|
||||||
|
model = model.to(gpu)
|
||||||
|
dataset_name = example[0]
|
||||||
|
dataloader_inputs = torch.load(dataset_pickle_path(dataset_name), map_location=gpu)
|
||||||
|
xyz_converter = XYZConverter().to(gpu)
|
||||||
|
task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, \
|
||||||
|
negative, symmRs, Lasu, ch_label = prepare_input(dataloader_inputs,xyz_converter, gpu)
|
||||||
|
return dataset_name, network_input, model_name, model
|
||||||
|
|
60
rf2aa/training/EMA.py
Normal file
60
rf2aa/training/EMA.py
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
|
||||||
|
class EMA(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, model, decay):
|
||||||
|
super().__init__()
|
||||||
|
self.decay = decay
|
||||||
|
|
||||||
|
self.model = model
|
||||||
|
self.shadow = deepcopy(self.model)
|
||||||
|
|
||||||
|
for param in self.shadow.parameters():
|
||||||
|
param.detach_()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def update(self):
|
||||||
|
if not self.training:
|
||||||
|
print("EMA update should only be called during training", file=stderr, flush=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_params = OrderedDict(self.model.named_parameters())
|
||||||
|
shadow_params = OrderedDict(self.shadow.named_parameters())
|
||||||
|
|
||||||
|
# check if both model contains the same set of keys
|
||||||
|
assert model_params.keys() == shadow_params.keys()
|
||||||
|
|
||||||
|
for name, param in model_params.items():
|
||||||
|
# see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||||
|
# shadow_variable -= (1 - decay) * (shadow_variable - variable)
|
||||||
|
if param.requires_grad:
|
||||||
|
shadow_params[name].sub_((1. - self.decay) * (shadow_params[name] - param))
|
||||||
|
|
||||||
|
model_buffers = OrderedDict(self.model.named_buffers())
|
||||||
|
shadow_buffers = OrderedDict(self.shadow.named_buffers())
|
||||||
|
|
||||||
|
# check if both model contains the same set of keys
|
||||||
|
assert model_buffers.keys() == shadow_buffers.keys()
|
||||||
|
|
||||||
|
for name, buffer in model_buffers.items():
|
||||||
|
# buffers are copied
|
||||||
|
shadow_buffers[name].copy_(buffer)
|
||||||
|
|
||||||
|
#fd A hack to allow non-DDP models to be passed into the Trainer
|
||||||
|
def no_sync(self):
|
||||||
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
if self.training:
|
||||||
|
return self.model(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return self.shadow(*args, **kwargs)
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
5
rf2aa/training/checkpoint.py
Normal file
5
rf2aa/training/checkpoint.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# for gradient checkpointing
|
||||||
|
def create_custom_forward(module, **kwargs):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, **kwargs)
|
||||||
|
return custom_forward
|
71
rf2aa/training/recycling.py
Normal file
71
rf2aa/training/recycling.py
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from contextlib import ExitStack
|
||||||
|
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
|
||||||
|
def recycle_step_legacy(ddp_model, input, n_cycle, use_amp, nograds=False, force_device=None):
|
||||||
|
if force_device is not None:
|
||||||
|
gpu = force_device
|
||||||
|
else:
|
||||||
|
gpu = ddp_model.device
|
||||||
|
|
||||||
|
xyz_prev, alpha_prev, mask_recycle = \
|
||||||
|
input["xyz_prev"], input["alpha_prev"], input["mask_recycle"]
|
||||||
|
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||||
|
for i_cycle in range(n_cycle):
|
||||||
|
with ExitStack() as stack:
|
||||||
|
stack.enter_context(torch.cuda.amp.autocast(enabled=use_amp))
|
||||||
|
if i_cycle < n_cycle -1 or nograds is True:
|
||||||
|
stack.enter_context(torch.no_grad())
|
||||||
|
if force_device is None:
|
||||||
|
stack.enter_context(ddp_model.no_sync())
|
||||||
|
return_raw = (i_cycle < n_cycle -1)
|
||||||
|
use_checkpoint = not nograds and (i_cycle == n_cycle -1)
|
||||||
|
|
||||||
|
input_i = add_recycle_inputs(input, output_i, i_cycle, gpu, return_raw=return_raw, use_checkpoint=use_checkpoint)
|
||||||
|
output_i = ddp_model(**input_i)
|
||||||
|
return output_i
|
||||||
|
|
||||||
|
|
||||||
|
def run_model_forward_legacy(model, network_input, device="cpu"):
|
||||||
|
""" run model forward pass, no recycling or ddp with legacy model (for tests)"""
|
||||||
|
gpu = device
|
||||||
|
xyz_prev, alpha_prev, mask_recycle = \
|
||||||
|
network_input["xyz_prev"], network_input["alpha_prev"], network_input["mask_recycle"]
|
||||||
|
output_i = (None, None, xyz_prev, alpha_prev, mask_recycle)
|
||||||
|
input_i = add_recycle_inputs(network_input, output_i, 0, gpu, return_raw=False, use_checkpoint=False)
|
||||||
|
input_i["seq_unmasked"] = input_i["seq_unmasked"].to(gpu)
|
||||||
|
input_i["sctors"] = input_i["sctors"].to(gpu)
|
||||||
|
model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
output_i = model(**input_i)
|
||||||
|
|
||||||
|
return output_i
|
||||||
|
|
||||||
|
def add_recycle_inputs(network_input, output_i, i_cycle, gpu, return_raw=False, use_checkpoint=False):
|
||||||
|
input_i = {}
|
||||||
|
for key in network_input:
|
||||||
|
if key in ['msa_latent', 'msa_full', 'seq']:
|
||||||
|
input_i[key] = network_input[key][:,i_cycle].to(gpu, non_blocking=True)
|
||||||
|
else:
|
||||||
|
input_i[key] = network_input[key]
|
||||||
|
|
||||||
|
L = input_i["msa_latent"].shape[2]
|
||||||
|
msa_prev, pair_prev, _, alpha, mask_recycle = output_i
|
||||||
|
xyz_prev = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(1,L,1,1).to(gpu, non_blocking=True)
|
||||||
|
|
||||||
|
input_i['msa_prev'] = msa_prev
|
||||||
|
input_i['pair_prev'] = pair_prev
|
||||||
|
input_i['xyz'] = xyz_prev
|
||||||
|
input_i['mask_recycle'] = mask_recycle
|
||||||
|
input_i['sctors'] = alpha
|
||||||
|
input_i['return_raw'] = return_raw
|
||||||
|
input_i['use_checkpoint'] = use_checkpoint
|
||||||
|
|
||||||
|
input_i.pop('xyz_prev')
|
||||||
|
input_i.pop('alpha_prev')
|
||||||
|
return input_i
|
1044
rf2aa/util.py
Normal file
1044
rf2aa/util.py
Normal file
File diff suppressed because it is too large
Load diff
710
rf2aa/util_module.py
Normal file
710
rf2aa/util_module.py
Normal file
|
@ -0,0 +1,710 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from opt_einsum import contract as einsum
|
||||||
|
import copy
|
||||||
|
import dgl
|
||||||
|
import networkx as nx
|
||||||
|
from rf2aa.util import *
|
||||||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||||||
|
from rf2aa.chemical import th_dih, th_ang_v
|
||||||
|
|
||||||
|
def init_lecun_normal(module, scale=1.0):
|
||||||
|
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
|
||||||
|
normal = torch.distributions.normal.Normal(0, 1)
|
||||||
|
|
||||||
|
alpha = (a - mu) / sigma
|
||||||
|
beta = (b - mu) / sigma
|
||||||
|
|
||||||
|
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
|
||||||
|
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
|
||||||
|
|
||||||
|
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
|
||||||
|
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
|
||||||
|
x = torch.clamp(x, a, b)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def sample_truncated_normal(shape, scale=1.0):
|
||||||
|
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||||
|
return stddev * truncated_normal(torch.rand(shape))
|
||||||
|
|
||||||
|
module.weight = torch.nn.Parameter( (sample_truncated_normal(module.weight.shape)) )
|
||||||
|
return module
|
||||||
|
|
||||||
|
def init_lecun_normal_param(weight, scale=1.0):
|
||||||
|
def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
|
||||||
|
normal = torch.distributions.normal.Normal(0, 1)
|
||||||
|
|
||||||
|
alpha = (a - mu) / sigma
|
||||||
|
beta = (b - mu) / sigma
|
||||||
|
|
||||||
|
alpha_normal_cdf = normal.cdf(torch.tensor(alpha))
|
||||||
|
p = alpha_normal_cdf + (normal.cdf(torch.tensor(beta)) - alpha_normal_cdf) * uniform
|
||||||
|
|
||||||
|
v = torch.clamp(2 * p - 1, -1 + 1e-8, 1 - 1e-8)
|
||||||
|
x = mu + sigma * np.sqrt(2) * torch.erfinv(v)
|
||||||
|
x = torch.clamp(x, a, b)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def sample_truncated_normal(shape, scale=1.0):
|
||||||
|
stddev = np.sqrt(scale/shape[-1])/.87962566103423978 # shape[-1] = fan_in
|
||||||
|
return stddev * truncated_normal(torch.rand(shape))
|
||||||
|
|
||||||
|
weight = torch.nn.Parameter( (sample_truncated_normal(weight.shape)) )
|
||||||
|
return weight
|
||||||
|
|
||||||
|
# for gradient checkpointing
|
||||||
|
def create_custom_forward(module, **kwargs):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs, **kwargs)
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
def get_clones(module, N):
|
||||||
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||||
|
|
||||||
|
class Dropout(nn.Module):
|
||||||
|
# Dropout entire row or column
|
||||||
|
def __init__(self, broadcast_dim=None, p_drop=0.15):
|
||||||
|
super(Dropout, self).__init__()
|
||||||
|
# give ones with probability of 1-p_drop / zeros with p_drop
|
||||||
|
self.sampler = torch.distributions.bernoulli.Bernoulli(torch.tensor([1-p_drop]))
|
||||||
|
self.broadcast_dim=broadcast_dim
|
||||||
|
self.p_drop=p_drop
|
||||||
|
def forward(self, x):
|
||||||
|
if not self.training: # no drophead during evaluation mode
|
||||||
|
return x
|
||||||
|
shape = list(x.shape)
|
||||||
|
if not self.broadcast_dim == None:
|
||||||
|
shape[self.broadcast_dim] = 1
|
||||||
|
mask = self.sampler.sample(shape).to(x.device).view(shape)
|
||||||
|
|
||||||
|
x = mask * x / (1.0 - self.p_drop)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def rbf(D, D_min=0.0, D_count=64, D_sigma=0.5):
|
||||||
|
# Distance radial basis function
|
||||||
|
D_max = D_min + (D_count-1) * D_sigma
|
||||||
|
D_mu = torch.linspace(D_min, D_max, D_count).to(D.device)
|
||||||
|
D_mu = D_mu[None,:]
|
||||||
|
D_expand = torch.unsqueeze(D, -1)
|
||||||
|
RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)
|
||||||
|
return RBF
|
||||||
|
|
||||||
|
def get_seqsep(idx):
|
||||||
|
'''
|
||||||
|
Sequence separation feature for structure module. Protein-only.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- idx: residue indices of given sequence (B,L)
|
||||||
|
Output:
|
||||||
|
- seqsep: sequence separation feature with sign (B, L, L, 1)
|
||||||
|
Sergey found that having sign in seqsep features helps a little
|
||||||
|
'''
|
||||||
|
seqsep = idx[:,None,:] - idx[:,:,None]
|
||||||
|
sign = torch.sign(seqsep)
|
||||||
|
neigh = torch.abs(seqsep)
|
||||||
|
neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
|
||||||
|
neigh = sign * neigh
|
||||||
|
return neigh.unsqueeze(-1)
|
||||||
|
|
||||||
|
def get_seqsep_protein_sm(idx, bond_feats, dist_matrix, sm_mask):
|
||||||
|
'''
|
||||||
|
Sequence separation features for protein-SM complex
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- idx: residue indices of given sequence (B,L)
|
||||||
|
- bond_feats: bond features (B, L, L)
|
||||||
|
- dist_matrix: precomputed bond distances (B, L, L) NOTE: need to run nan_to_num to remove infinities
|
||||||
|
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- seqsep: sequence separation feature with sign (B, L, L, 1)
|
||||||
|
-1 or 1 for bonded protein residues
|
||||||
|
1 for bonded SM atoms or residue-atom bonds
|
||||||
|
0 elsewhere
|
||||||
|
'''
|
||||||
|
sm_mask = sm_mask[0] # assume batch = 1
|
||||||
|
res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask)
|
||||||
|
|
||||||
|
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
|
||||||
|
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
|
||||||
|
res_dist[(res_dist > 1) | (res_dist < -1)] = 0.0
|
||||||
|
atom_dist[(atom_dist > 1)] = 0.0
|
||||||
|
|
||||||
|
seqsep = sm_mask_2d*atom_dist + prot_mask_2d*res_dist + inter_mask_2d*(bond_feats==6)
|
||||||
|
|
||||||
|
return seqsep.unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask, minpos_res=-32, maxpos_res=32, maxpos_atom=8):
|
||||||
|
'''
|
||||||
|
Calculates residue and atom bond distances of protein/SM complex. Used for positional
|
||||||
|
embedding and structure module. 2nd version (2022-9-19); handles atomized proteins.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- idx: residue index (B, L)
|
||||||
|
- bond_feats: bond features (B, L, L)
|
||||||
|
- dist_matrix: precomputed bond distances (B, L, L) NOTE: need to run nan_to_num to remove infinities
|
||||||
|
- sm_mask: boolean feature (L). True if a position represents atom, False otherwise
|
||||||
|
- minpos_res: minimum value of residue distances
|
||||||
|
- maxpos_res: maximum value of residue distances
|
||||||
|
- maxpos_atom: maximum value of atom bond distances
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- res_dist: residue distance (B, L, L)
|
||||||
|
- atom_dist: atom bond distance (B, L, L)
|
||||||
|
'''
|
||||||
|
bond_feats = bond_feats[0] # assume batch = 1
|
||||||
|
L = bond_feats.shape[0]
|
||||||
|
device = bond_feats.device
|
||||||
|
|
||||||
|
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
|
||||||
|
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
|
||||||
|
# protein residue distances
|
||||||
|
res_dist_prot = torch.clamp(idx[0,None,:] - idx[0,:,None],
|
||||||
|
min=minpos_res, max=maxpos_res) # (L, L) intra-protein
|
||||||
|
res_dist_sm = torch.full((L,L), maxpos_res+1, device=device) # (L, L) with "unknown" res. dist. token
|
||||||
|
|
||||||
|
# small molecule atom bond graph
|
||||||
|
atom_dist_sm = torch.nan_to_num(dist_matrix, posinf=maxpos_atom)[0].long() # this comes through the dataloader so it is batched
|
||||||
|
atom_dist_prot = torch.full((L,L), maxpos_atom+1, device=device)
|
||||||
|
|
||||||
|
#fd new impl
|
||||||
|
i_s, j_s = torch.where(bond_feats==6)
|
||||||
|
i_sm = i_s[sm_mask[i_s]]
|
||||||
|
i_prot = j_s[sm_mask[i_s]]
|
||||||
|
res_dist_inter = torch.full((L,L), maxpos_res, device=device)
|
||||||
|
atom_dist_inter = torch.full((L,L), maxpos_atom, device=device)
|
||||||
|
if i_prot.shape[0] > 0:
|
||||||
|
closest_prot_res = i_prot[torch.argmin(atom_dist_sm[sm_mask,:][:,i_sm], dim=-1)]
|
||||||
|
res_dist_inter[sm_mask,:] = res_dist_prot[closest_prot_res,:]
|
||||||
|
res_dist_inter[:,sm_mask] = res_dist_prot[:,closest_prot_res]
|
||||||
|
|
||||||
|
closest_atom = i_sm[torch.argmin(torch.abs(res_dist_prot[~sm_mask,:][:,i_prot]), dim=-1)]
|
||||||
|
atom_dist_inter[~sm_mask,:] = atom_dist_sm[closest_atom,:] + 1
|
||||||
|
atom_dist_inter[:,~sm_mask] = atom_dist_sm[:,closest_atom] + 1
|
||||||
|
|
||||||
|
res_dist = res_dist_prot * prot_mask_2d + res_dist_inter * inter_mask_2d + res_dist_sm * sm_mask_2d
|
||||||
|
atom_dist = atom_dist_prot * prot_mask_2d + atom_dist_inter * inter_mask_2d + atom_dist_sm * sm_mask_2d
|
||||||
|
|
||||||
|
return res_dist[None], atom_dist[None] # add batch dim.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def get_relpos(idx, bond_feats, sm_mask, inter_pos=32, maxpath=32):
|
||||||
|
'''
|
||||||
|
Relative position matrix of protein/SM complex. Used for positional
|
||||||
|
embedding and structure module. Simple version from 9/2/2022 that doesn't
|
||||||
|
handle atomized proteins.
|
||||||
|
|
||||||
|
Input:
|
||||||
|
- idx: residue index (B, L)
|
||||||
|
- bond_feats: bond features (B, L, L)
|
||||||
|
- sm_mask: boolean feature True if a position represents atom, False if residue (B, L)
|
||||||
|
- inter_pos: value to assign as the protein-SM residue index differences
|
||||||
|
- maxpath: bond distances greater than this are clipped to this value
|
||||||
|
|
||||||
|
Output:
|
||||||
|
- relpos: relative position feature (B, L, L)
|
||||||
|
for intra-protein this is the residue index difference
|
||||||
|
for intra-SM this is the bond distance
|
||||||
|
for protein-SM this is user-defined value inter_pos
|
||||||
|
'''
|
||||||
|
bond_feats = bond_feats[0]
|
||||||
|
|
||||||
|
sm_mask_2d = sm_mask[None,:]*sm_mask[:,None]
|
||||||
|
prot_mask_2d = (~sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
inter_mask_2d = (~sm_mask[None,:]) * (sm_mask[:,None]) + (sm_mask[None,:]) * (~sm_mask[:,None])
|
||||||
|
|
||||||
|
# intra-protein: residue # differences
|
||||||
|
seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
|
||||||
|
|
||||||
|
# intra-small molecule: bond distances
|
||||||
|
sm_bond_feats = torch.zeros_like(bond_feats) + sm_mask*bond_feats
|
||||||
|
G = nx.from_numpy_matrix(sm_bond_feats.detach().cpu().numpy())
|
||||||
|
paths = dict(nx.all_pairs_shortest_path_length(G,cutoff=maxpath))
|
||||||
|
paths = [(i,j,vij) for i,vi in paths.items() for j,vij in vi.items()]
|
||||||
|
i,j,v = torch.tensor(paths).T
|
||||||
|
|
||||||
|
bond_separation = torch.full_like(bond_feats, maxpath) \
|
||||||
|
- maxpath*torch.eye(bond_feats.shape[0]).to(bond_feats.device).long()
|
||||||
|
bond_separation[i,j] = v.to(bond_feats.device)
|
||||||
|
|
||||||
|
# combine: protein-s.m. are always positive maximum distance apart
|
||||||
|
# assumes one small molecule per example
|
||||||
|
relpos = prot_mask_2d * seqsep + sm_mask_2d * bond_separation + inter_mask_2d * inter_pos # (B, L, L)
|
||||||
|
relpos = relpos.to(bond_feats.device)
|
||||||
|
|
||||||
|
return relpos
|
||||||
|
|
||||||
|
def make_full_graph(xyz, pair, idx):
|
||||||
|
'''
|
||||||
|
Input:
|
||||||
|
- xyz: current backbone cooordinates (B, L, 3, 3)
|
||||||
|
- pair: pair features from Trunk (B, L, L, E)
|
||||||
|
- idx: residue index from ground truth pdb
|
||||||
|
Output:
|
||||||
|
- G: defined graph
|
||||||
|
'''
|
||||||
|
|
||||||
|
B, L = xyz.shape[:2]
|
||||||
|
device = xyz.device
|
||||||
|
|
||||||
|
# seq sep
|
||||||
|
sep = idx[:,None,:] - idx[:,:,None]
|
||||||
|
b,i,j = torch.where(sep.abs() > 0)
|
||||||
|
src = b*L+i
|
||||||
|
tgt = b*L+j
|
||||||
|
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
|
||||||
|
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]) #.detach() # no gradient through basis function
|
||||||
|
return G, pair[b,i,j][...,None]
|
||||||
|
|
||||||
|
def make_topk_graph(xyz, pair, idx, top_k=128, nlocal=33, topk_incl_local=True, eps=1e-6):
|
||||||
|
'''
|
||||||
|
Input:
|
||||||
|
- xyz: current backbone cooordinates (B, L, 3, 3)
|
||||||
|
- pair: pair features from Trunk (B, L, L, E)
|
||||||
|
- idx: residue index from ground truth pdb
|
||||||
|
Output:
|
||||||
|
- G: defined graph
|
||||||
|
'''
|
||||||
|
|
||||||
|
B, L = xyz.shape[:2]
|
||||||
|
device = xyz.device
|
||||||
|
|
||||||
|
# distance map from current CA coordinates
|
||||||
|
D = torch.cdist(xyz, xyz) + torch.eye(L, device=device).unsqueeze(0)*9999.9 # (B, L, L)
|
||||||
|
|
||||||
|
# seq sep
|
||||||
|
sep = idx[:,None,:] - idx[:,:,None]
|
||||||
|
sep = sep.abs() + torch.eye(L, device=device).unsqueeze(0)*9999.9
|
||||||
|
|
||||||
|
if (topk_incl_local):
|
||||||
|
D = D + sep*eps
|
||||||
|
D[sep<nlocal] = 0.0
|
||||||
|
|
||||||
|
# get top_k neighbors
|
||||||
|
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
|
||||||
|
topk_matrix = torch.zeros((B, L, L), device=device)
|
||||||
|
topk_matrix.scatter_(2, E_idx, 1.0)
|
||||||
|
cond = topk_matrix > 0.0
|
||||||
|
|
||||||
|
else:
|
||||||
|
|
||||||
|
D = D + sep*eps
|
||||||
|
|
||||||
|
# get top_k neighbors
|
||||||
|
D_neigh, E_idx = torch.topk(D, min(top_k, L-1), largest=False) # shape of E_idx: (B, L, top_k)
|
||||||
|
topk_matrix = torch.zeros((B, L, L), device=device)
|
||||||
|
topk_matrix.scatter_(2, E_idx, 1.0)
|
||||||
|
|
||||||
|
# put an edge if any of the 3 conditions are met:
|
||||||
|
# 1) |i-j| <= kmin (connect sequentially adjacent residues)
|
||||||
|
# 2) top_k neighbors
|
||||||
|
cond = torch.logical_or(topk_matrix > 0.0, sep < nlocal)
|
||||||
|
b,i,j = torch.where(cond)
|
||||||
|
|
||||||
|
src = b*L+i
|
||||||
|
tgt = b*L+j
|
||||||
|
G = dgl.graph((src, tgt), num_nodes=B*L).to(device)
|
||||||
|
G.edata['rel_pos'] = (xyz[b,j,:] - xyz[b,i,:]).detach() # no gradient through basis function
|
||||||
|
|
||||||
|
return G, pair[b,i,j][...,None]
|
||||||
|
|
||||||
|
def make_atom_graph( xyz, mask, num_bonds, top_k=16, maxbonds=4 ):
|
||||||
|
B,L,A = xyz.shape[:3]
|
||||||
|
device = xyz.device
|
||||||
|
|
||||||
|
D = torch.norm(
|
||||||
|
xyz[:,None,None,:,:] - xyz[:,:,:,None,None], dim=-1
|
||||||
|
)
|
||||||
|
mask2d = mask[:,:,:,None,None]*mask[:,None,None,:,:]
|
||||||
|
D[~mask2d] = 9999.
|
||||||
|
D[D==0] = 9999.
|
||||||
|
|
||||||
|
# select top K neighbors for each atom
|
||||||
|
# keep indices as batch/res/atm indices
|
||||||
|
D_neigh, E_idx = torch.topk(D.reshape(B,L,A,-1), top_k, largest=False) # shape of E_idx: (B, L, top_k)
|
||||||
|
Eres, Eatm = torch.div(E_idx,A,rounding_mode='trunc'), E_idx%A
|
||||||
|
bi,ri,ai = mask.nonzero(as_tuple=True)
|
||||||
|
bi = bi[:,None].repeat(1,top_k).reshape(-1)
|
||||||
|
ri = ri[:,None].repeat(1,top_k).reshape(-1)
|
||||||
|
ai = ai[:,None].repeat(1,top_k).reshape(-1)
|
||||||
|
rj,aj = Eres[mask].reshape(-1), Eatm[mask].reshape(-1)
|
||||||
|
|
||||||
|
# on each edge, 1-hot encode the number of bonds (up to maxbonds) seperating each atom
|
||||||
|
edge = torch.full(ri.shape, maxbonds, device=device)
|
||||||
|
resmask = ri==rj
|
||||||
|
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],aj[resmask]]-1
|
||||||
|
resmask = ri+1==rj
|
||||||
|
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],2]+num_bonds[bi[resmask],rj[resmask],0,aj[resmask]]
|
||||||
|
resmask = ri-1==rj
|
||||||
|
edge[resmask] = num_bonds[bi[resmask],ri[resmask],ai[resmask],0]+num_bonds[bi[resmask],rj[resmask],2,aj[resmask]]
|
||||||
|
edge = edge.clamp(0,maxbonds-1)
|
||||||
|
edge = F.one_hot(edge)[...,None]
|
||||||
|
|
||||||
|
natm = torch.sum(mask)
|
||||||
|
index = torch.zeros_like(mask, dtype=torch.long, device=device)
|
||||||
|
index[mask] = torch.arange(natm, device=device)
|
||||||
|
src=index[bi,ri,ai]
|
||||||
|
tgt=index[bi,rj,aj]
|
||||||
|
|
||||||
|
G = dgl.graph((src, tgt), num_nodes=natm).to(device)
|
||||||
|
G.edata['rel_pos'] = (xyz[bi,ri,ai] - xyz[bi,rj,aj]).detach() # no gradient through basis function
|
||||||
|
|
||||||
|
return G, edge
|
||||||
|
|
||||||
|
|
||||||
|
# rotate about the x axis
|
||||||
|
def make_rotX(angs, eps=1e-6):
|
||||||
|
B,L = angs.shape[:2]
|
||||||
|
NORM = torch.linalg.norm(angs, dim=-1) + eps
|
||||||
|
|
||||||
|
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
|
||||||
|
|
||||||
|
RTs[:,:,1,1] = angs[:,:,0]/NORM
|
||||||
|
RTs[:,:,1,2] = -angs[:,:,1]/NORM
|
||||||
|
RTs[:,:,2,1] = angs[:,:,1]/NORM
|
||||||
|
RTs[:,:,2,2] = angs[:,:,0]/NORM
|
||||||
|
return RTs
|
||||||
|
|
||||||
|
# rotate about the x axis
|
||||||
|
def make_rotZ(angs, eps=1e-6):
|
||||||
|
B,L = angs.shape[:2]
|
||||||
|
NORM = torch.linalg.norm(angs, dim=-1) + eps
|
||||||
|
|
||||||
|
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
|
||||||
|
|
||||||
|
RTs[:,:,0,0] = angs[:,:,0]/NORM
|
||||||
|
RTs[:,:,0,1] = -angs[:,:,1]/NORM
|
||||||
|
RTs[:,:,1,0] = angs[:,:,1]/NORM
|
||||||
|
RTs[:,:,1,1] = angs[:,:,0]/NORM
|
||||||
|
return RTs
|
||||||
|
|
||||||
|
# rotate about an arbitrary axis
|
||||||
|
def make_rot_axis(angs, u, eps=1e-6):
|
||||||
|
B,L = angs.shape[:2]
|
||||||
|
NORM = torch.linalg.norm(angs, dim=-1) + eps
|
||||||
|
|
||||||
|
RTs = torch.eye(4, device=angs.device).repeat(B,L,1,1)
|
||||||
|
|
||||||
|
ct = angs[:,:,0]/NORM
|
||||||
|
st = angs[:,:,1]/NORM
|
||||||
|
u0 = u[:,:,0]
|
||||||
|
u1 = u[:,:,1]
|
||||||
|
u2 = u[:,:,2]
|
||||||
|
|
||||||
|
RTs[:,:,0,0] = ct+u0*u0*(1-ct)
|
||||||
|
RTs[:,:,0,1] = u0*u1*(1-ct)-u2*st
|
||||||
|
RTs[:,:,0,2] = u0*u2*(1-ct)+u1*st
|
||||||
|
RTs[:,:,1,0] = u0*u1*(1-ct)+u2*st
|
||||||
|
RTs[:,:,1,1] = ct+u1*u1*(1-ct)
|
||||||
|
RTs[:,:,1,2] = u1*u2*(1-ct)-u0*st
|
||||||
|
RTs[:,:,2,0] = u0*u2*(1-ct)-u1*st
|
||||||
|
RTs[:,:,2,1] = u1*u2*(1-ct)+u0*st
|
||||||
|
RTs[:,:,2,2] = ct+u2*u2*(1-ct)
|
||||||
|
return RTs
|
||||||
|
|
||||||
|
|
||||||
|
# compute allatom structure from backbone frames and torsions
|
||||||
|
#
|
||||||
|
# alphas:
|
||||||
|
# omega/phi/psi: 0-2
|
||||||
|
# chi_1-4(prot): 3-6
|
||||||
|
# cb/cg bend: 7-9
|
||||||
|
# eps(p)/zeta(p): 10-11
|
||||||
|
# alpha/beta/gamma/delta: 12-15
|
||||||
|
# nu2/nu1/nu0: 16-18
|
||||||
|
# chi_1(na): 19
|
||||||
|
#
|
||||||
|
# RTs_in_base_frame:
|
||||||
|
# omega/phi/psi: 0-2
|
||||||
|
# chi_1-4(prot): 3-6
|
||||||
|
# eps(p)/zeta(p): 7-8
|
||||||
|
# alpha/beta/gamma/delta: 9-12
|
||||||
|
# nu2/nu1/nu0: 13-15
|
||||||
|
# chi_1(na): 16
|
||||||
|
#
|
||||||
|
# RT frames (output):
|
||||||
|
# origin: 0
|
||||||
|
# omega/phi/psi: 1-3
|
||||||
|
# chi_1-4(prot): 4-7
|
||||||
|
# cb bend: 8
|
||||||
|
# alpha/beta/gamma/delta: 9-12
|
||||||
|
# nu2/nu1/nu0: 13-15
|
||||||
|
# chi_1(na): 16
|
||||||
|
#
|
||||||
|
class XYZConverter(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(XYZConverter, self).__init__()
|
||||||
|
|
||||||
|
self.register_buffer("torsion_indices", ChemData().torsion_indices, persistent=False)
|
||||||
|
self.register_buffer("torsion_can_flip", ChemData().torsion_can_flip.to(torch.int32), persistent=False)
|
||||||
|
self.register_buffer("ref_angles", ChemData().reference_angles, persistent=False)
|
||||||
|
self.register_buffer("base_indices", ChemData().base_indices, persistent=False)
|
||||||
|
self.register_buffer("RTs_in_base_frame", ChemData().RTs_by_torsion, persistent=False)
|
||||||
|
self.register_buffer("xyzs_in_base_frame", ChemData().xyzs_in_base_frame, persistent=False)
|
||||||
|
|
||||||
|
def compute_all_atom(self, seq, xyz, alphas):
|
||||||
|
B,L = xyz.shape[:2]
|
||||||
|
|
||||||
|
is_NA = is_nucleic(seq)
|
||||||
|
Rs, Ts = rigid_from_3_points(xyz[...,0,:],xyz[...,1,:],xyz[...,2,:], is_NA)
|
||||||
|
|
||||||
|
RTF0 = torch.eye(4).repeat(B,L,1,1).to(device=Rs.device)
|
||||||
|
|
||||||
|
# bb
|
||||||
|
RTF0[:,:,:3,:3] = Rs
|
||||||
|
RTF0[:,:,:3,3] = Ts
|
||||||
|
|
||||||
|
# omega
|
||||||
|
RTF1 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,0,:], make_rotX(alphas[:,:,0,:]))
|
||||||
|
|
||||||
|
# phi
|
||||||
|
RTF2 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,1,:], make_rotX(alphas[:,:,1,:]))
|
||||||
|
|
||||||
|
# psi
|
||||||
|
RTF3 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,2,:], make_rotX(alphas[:,:,2,:]))
|
||||||
|
|
||||||
|
# CB bend
|
||||||
|
basexyzs = self.xyzs_in_base_frame[seq]
|
||||||
|
NCr = 0.5*(basexyzs[:,:,2,:3]+basexyzs[:,:,0,:3])
|
||||||
|
CAr = (basexyzs[:,:,1,:3])
|
||||||
|
CBr = (basexyzs[:,:,4,:3])
|
||||||
|
CBrotaxis1 = (CBr-CAr).cross(NCr-CAr)
|
||||||
|
CBrotaxis1 /= torch.linalg.norm(CBrotaxis1, dim=-1, keepdim=True)+1e-8
|
||||||
|
|
||||||
|
# CB twist
|
||||||
|
NCp = basexyzs[:,:,2,:3] - basexyzs[:,:,0,:3]
|
||||||
|
NCpp = NCp - torch.sum(NCp*NCr, dim=-1, keepdim=True)/ torch.sum(NCr*NCr, dim=-1, keepdim=True) * NCr
|
||||||
|
CBrotaxis2 = (CBr-CAr).cross(NCpp)
|
||||||
|
CBrotaxis2 /= torch.linalg.norm(CBrotaxis2, dim=-1, keepdim=True)+1e-8
|
||||||
|
|
||||||
|
CBrot1 = make_rot_axis(alphas[:,:,7,:], CBrotaxis1 )
|
||||||
|
CBrot2 = make_rot_axis(alphas[:,:,8,:], CBrotaxis2 )
|
||||||
|
|
||||||
|
RTF8 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, CBrot1,CBrot2)
|
||||||
|
|
||||||
|
# chi1 + CG bend
|
||||||
|
RTF4 = torch.einsum(
|
||||||
|
'brij,brjk,brkl,brlm->brim',
|
||||||
|
RTF8,
|
||||||
|
self.RTs_in_base_frame[seq,3,:],
|
||||||
|
make_rotX(alphas[:,:,3,:]),
|
||||||
|
make_rotZ(alphas[:,:,9,:]))
|
||||||
|
|
||||||
|
# chi2
|
||||||
|
RTF5 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF4, self.RTs_in_base_frame[seq,4,:],make_rotX(alphas[:,:,4,:]))
|
||||||
|
|
||||||
|
# chi3
|
||||||
|
RTF6 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF5,self.RTs_in_base_frame[seq,5,:],make_rotX(alphas[:,:,5,:]))
|
||||||
|
|
||||||
|
# chi4
|
||||||
|
RTF7 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF6,self.RTs_in_base_frame[seq,6,:],make_rotX(alphas[:,:,6,:]))
|
||||||
|
|
||||||
|
# ignore RTs_in_base_frame[seq,7:9,:] and alphas[:,:,10:12,:]
|
||||||
|
|
||||||
|
# which mode are we running in
|
||||||
|
if (not ChemData().params.use_phospate_frames_for_NA):
|
||||||
|
# NA nu1 --> from base frame
|
||||||
|
RTF14 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
|
||||||
|
|
||||||
|
# NA nu0 --> from base frame
|
||||||
|
RTF15 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
|
||||||
|
|
||||||
|
# NA chi --> from base frame
|
||||||
|
RTF16= torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
|
||||||
|
|
||||||
|
# NA nu2 --> from nu1 frame
|
||||||
|
RTF13 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF14, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
|
||||||
|
|
||||||
|
# NA delta --> from nu2 frame
|
||||||
|
RTF12 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF13, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
|
||||||
|
|
||||||
|
# NA gamma --> from delta frame
|
||||||
|
RTF11 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF12, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
|
||||||
|
|
||||||
|
# NA beta --> from gamma frame
|
||||||
|
RTF10 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF11, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
|
||||||
|
|
||||||
|
# NA alpha --> from beta frame
|
||||||
|
RTF9 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF10, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
|
||||||
|
else:
|
||||||
|
# NA alpha
|
||||||
|
RTF9 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF0, self.RTs_in_base_frame[seq,9,:], make_rotX(alphas[:,:,12,:]))
|
||||||
|
|
||||||
|
# NA beta
|
||||||
|
RTF10 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF9, self.RTs_in_base_frame[seq,10,:], make_rotX(alphas[:,:,13,:]))
|
||||||
|
|
||||||
|
# NA gamma
|
||||||
|
RTF11 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF10, self.RTs_in_base_frame[seq,11,:], make_rotX(alphas[:,:,14,:]))
|
||||||
|
|
||||||
|
# NA delta
|
||||||
|
RTF12 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF11, self.RTs_in_base_frame[seq,12,:], make_rotX(alphas[:,:,15,:]))
|
||||||
|
|
||||||
|
# NA nu2 - from gamma frame
|
||||||
|
RTF13 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF11, self.RTs_in_base_frame[seq,13,:], make_rotX(alphas[:,:,16,:]))
|
||||||
|
|
||||||
|
# NA nu1
|
||||||
|
RTF14 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF13, self.RTs_in_base_frame[seq,14,:], make_rotX(alphas[:,:,17,:]))
|
||||||
|
|
||||||
|
# NA nu0
|
||||||
|
RTF15 = torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF14, self.RTs_in_base_frame[seq,15,:], make_rotX(alphas[:,:,18,:]))
|
||||||
|
|
||||||
|
# NA chi - from nu1 frame
|
||||||
|
RTF16= torch.einsum(
|
||||||
|
'brij,brjk,brkl->bril',
|
||||||
|
RTF14, self.RTs_in_base_frame[seq,16,:], make_rotX(alphas[:,:,19,:]))
|
||||||
|
|
||||||
|
|
||||||
|
RTframes = torch.stack((
|
||||||
|
RTF0,RTF1,RTF2,RTF3,RTF4,RTF5,RTF6,RTF7,RTF8,
|
||||||
|
RTF9,RTF10,RTF11,RTF12,RTF13,RTF14,RTF15,RTF16
|
||||||
|
),dim=2)
|
||||||
|
|
||||||
|
xyzs = torch.einsum(
|
||||||
|
'brtij,brtj->brti',
|
||||||
|
RTframes.gather(2,self.base_indices[seq][...,None,None].repeat(1,1,1,4,4)), basexyzs
|
||||||
|
)
|
||||||
|
|
||||||
|
return RTframes, xyzs[...,:3]
|
||||||
|
|
||||||
|
|
||||||
|
def get_tor_mask(self, seq, mask_in=None):
|
||||||
|
B,L = seq.shape[:2]
|
||||||
|
dna_mask = is_nucleic(seq)
|
||||||
|
prot_mask = ~dna_mask
|
||||||
|
|
||||||
|
tors_mask = self.torsion_indices[seq,:,-1] > 0
|
||||||
|
|
||||||
|
if mask_in != None:
|
||||||
|
N = mask_in.shape[2]
|
||||||
|
ts = self.torsion_indices[seq]
|
||||||
|
bs = torch.arange(B, device=seq.device)[:,None,None,None]
|
||||||
|
rs = torch.arange(L, device=seq.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
|
||||||
|
ts = torch.abs(ts)
|
||||||
|
tors_mask *= mask_in[bs,rs,ts].all(dim=-1)
|
||||||
|
|
||||||
|
return tors_mask
|
||||||
|
|
||||||
|
def get_torsions(self, xyz_in, seq, mask_in=None):
|
||||||
|
B,L = xyz_in.shape[:2]
|
||||||
|
|
||||||
|
tors_mask = self.get_tor_mask(seq, mask_in)
|
||||||
|
# idealize given xyz coordinates before computing torsion angles
|
||||||
|
xyz = idealize_reference_frame(seq, xyz_in)
|
||||||
|
|
||||||
|
ts = self.torsion_indices[seq]
|
||||||
|
bs = torch.arange(B, device=xyz_in.device)[:,None,None,None]
|
||||||
|
xs = torch.arange(L, device=xyz_in.device)[None,:,None,None] - (ts<0)*1 # ts<-1 ==> prev res
|
||||||
|
ys = torch.abs(ts)
|
||||||
|
xyzs_bytor = xyz[bs,xs,ys,:]
|
||||||
|
|
||||||
|
torsions = torch.zeros( (B,L,ChemData().NTOTALDOFS,2), device=xyz_in.device )
|
||||||
|
|
||||||
|
# protein torsion
|
||||||
|
torsions[...,:7,:] = th_dih(
|
||||||
|
xyzs_bytor[...,:7,0,:],xyzs_bytor[...,:7,1,:],xyzs_bytor[...,:7,2,:],xyzs_bytor[...,:7,3,:]
|
||||||
|
)
|
||||||
|
torsions[:,:,2,:] = -1 * torsions[:,:,2,:] # shift psi by pi
|
||||||
|
|
||||||
|
# NA
|
||||||
|
torsions[...,10:,:] = th_dih(
|
||||||
|
xyzs_bytor[...,10:,0,:],xyzs_bytor[...,10:,1,:],xyzs_bytor[...,10:,2,:],xyzs_bytor[...,10:,3,:]
|
||||||
|
)
|
||||||
|
|
||||||
|
# protein angles
|
||||||
|
# CB bend
|
||||||
|
NC = 0.5*( xyz[:,:,0,:3] + xyz[:,:,2,:3] )
|
||||||
|
CA = xyz[:,:,1,:3]
|
||||||
|
CB = xyz[:,:,4,:3]
|
||||||
|
t = th_ang_v(CB-CA,NC-CA)
|
||||||
|
t0 = self.ref_angles[seq][...,0,:]
|
||||||
|
torsions[:,:,7,:] = torch.stack(
|
||||||
|
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
|
||||||
|
dim=-1 )
|
||||||
|
|
||||||
|
# CB twist
|
||||||
|
NCCA = NC-CA
|
||||||
|
NCp = xyz[:,:,2,:3] - xyz[:,:,0,:3]
|
||||||
|
NCpp = NCp - torch.sum(NCp*NCCA, dim=-1, keepdim=True)/ torch.sum(NCCA*NCCA, dim=-1, keepdim=True) * NCCA
|
||||||
|
t = th_ang_v(CB-CA,NCpp)
|
||||||
|
t0 = self.ref_angles[seq][...,1,:]
|
||||||
|
torsions[:,:,8,:] = torch.stack(
|
||||||
|
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
|
||||||
|
dim=-1 )
|
||||||
|
|
||||||
|
# CG bend
|
||||||
|
CG = xyz[:,:,5,:3]
|
||||||
|
t = th_ang_v(CG-CB,CA-CB)
|
||||||
|
t0 = self.ref_angles[seq][...,2,:]
|
||||||
|
torsions[:,:,9,:] = torch.stack(
|
||||||
|
(torch.sum(t*t0,dim=-1),t[...,0]*t0[...,1]-t[...,1]*t0[...,0]),
|
||||||
|
dim=-1 )
|
||||||
|
|
||||||
|
mask0 = (torch.isnan(torsions[...,0])).nonzero()
|
||||||
|
mask1 = (torch.isnan(torsions[...,1])).nonzero()
|
||||||
|
torsions[mask0[:,0],mask0[:,1],mask0[:,2],0] = 1.0
|
||||||
|
torsions[mask1[:,0],mask1[:,1],mask1[:,2],1] = 0.0
|
||||||
|
|
||||||
|
# alt chis
|
||||||
|
torsions_alt = torsions.clone()
|
||||||
|
torsions_alt[self.torsion_can_flip[seq,:].to(torch.bool)] *= -1
|
||||||
|
|
||||||
|
# torsions to restrain to 0 or 180 degree
|
||||||
|
# (this should be specified in chemical?)
|
||||||
|
tors_planar = torch.zeros((B, L, ChemData().NTOTALDOFS), dtype=torch.bool, device=xyz_in.device)
|
||||||
|
tors_planar[:,:,5] = seq == ChemData().aa2num['TYR'] # TYR chi 3 should be planar
|
||||||
|
|
||||||
|
return torsions, torsions_alt, tors_mask, tors_planar
|
Loading…
Reference in a new issue