mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-24 22:37:20 +00:00
feat: new covalent bond input
This commit is contained in:
parent
bf214835d6
commit
a9f2512cda
2 changed files with 120 additions and 10 deletions
|
@ -1,6 +1,7 @@
|
|||
import re
|
||||
import torch
|
||||
from openbabel import openbabel
|
||||
from typing import Optional
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
from tempfile import NamedTemporaryFile
|
||||
from rf2aa.chemical import ChemicalData as ChemData
|
||||
|
@ -15,8 +16,8 @@ class MoleculeToMoleculeBond:
|
|||
absolute_atom_index_first: int
|
||||
chain_index_second: int
|
||||
absolute_atom_index_second: int
|
||||
new_chirality_atom_first: Optional[str]
|
||||
new_chirality_atom_second: Optional[str]
|
||||
new_chirality_atom_first: Optional[Literal['null','CW','CCW']]
|
||||
new_chirality_atom_second: Optional[Literal['null','CW','CCW']]
|
||||
|
||||
@dataclass
|
||||
class AtomizedResidue:
|
||||
|
@ -27,6 +28,110 @@ class AtomizedResidue:
|
|||
original_chain: str
|
||||
index_in_original_chain: int
|
||||
|
||||
@dataclass
|
||||
class CovalentBond:
|
||||
"""
|
||||
A class representing a covalent bond between a protein and a small molecule.
|
||||
|
||||
Attributes:
|
||||
- protein_chain: Identifier for the protein chain, of string type.
|
||||
- protein_residue_number: Residue number in the protein, can be an integer or a string representation of a number.
|
||||
- protein_atom_name: Name of the atom in the protein participating in the covalent bond, string type.
|
||||
- small_molecule_chain: Identifier for the small molecule chain, string type.
|
||||
- sm_atom_index: Index of the atom in the small molecule participating in the covalent bond, can be an integer or a string representation of a number.
|
||||
- new_chirality_atom_1: Orientation of the new chiral center atom 1, can be 'null', 'CW' (clockwise), or 'CCW' (counterclockwise).
|
||||
- new_chirality_atom_2: Orientation of the new chiral center atom 2, can be 'null', 'CW', or 'CCW'.
|
||||
|
||||
Returns:
|
||||
- None. This initializer directly constructs the object.
|
||||
"""
|
||||
|
||||
protein_chain: str
|
||||
protein_residue_number: int
|
||||
protein_atom_name: str
|
||||
small_molecule_chain: str
|
||||
sm_atom_index: int
|
||||
|
||||
new_chirality_atom_1: Literal['null','CW','CCW']
|
||||
new_chirality_atom_2: Literal['null','CW','CCW']
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Special method called immediately after the object is initialized, used to validate and convert residue_number and atom_index values.
|
||||
"""
|
||||
# Validate and convert residue_number to ensure it's an integer
|
||||
if isinstance(self.protein_residue_number, str):
|
||||
if not self.protein_residue_number.isdigit():
|
||||
raise ValueError("Residue number must be a number")
|
||||
self.protein_residue_number = int(self.protein_residue_number)
|
||||
|
||||
# Validate and convert atom_index to ensure it's an integer
|
||||
if isinstance(self.sm_atom_index, str):
|
||||
if not self.sm_atom_index.isdigit():
|
||||
raise ValueError("atom_index must be a number")
|
||||
self.sm_atom_index = int(self.sm_atom_index)
|
||||
|
||||
@dataclass
|
||||
class CovalentBondsParser:
|
||||
"""
|
||||
A class for parsing covalent bond information.
|
||||
|
||||
Attributes:
|
||||
input (str): A string containing a series of covalent bond entries, separated by semicolons. Each entry follows the pattern: 'Chain,Residue,Atom:Chain,AtomIndex:NewChiralityAtom1,NewChiralityAtom2'.
|
||||
sub_pattern (Optional[str]): A regular expression pattern for matching the input format. Defaults to `r'([A-Z]{1}),([\d]{1,4}),(\w+):([A-Z]{1}),([\d]{1,2}):(\w+),(\w+)'`.
|
||||
|
||||
Methods:
|
||||
__post_init__: Validates the input format after initialization.
|
||||
parse_one_covalent_bond(sub_input: str) -> CovalentBond: Parses a single covalent bond entry and returns a CovalentBond object.
|
||||
formatted -> Tuple[CovalentBond]: Returns a tuple of parsed CovalentBond objects from the input string.
|
||||
"""
|
||||
|
||||
input: str # A case: 'A,74,ND2:B,1:CW,null;...'
|
||||
sub_pattern: Optional[str]=r'([A-Z]{1}),([\d]{1,4}),(\w+):([A-Z]{1}),([\d]{1,2}):(\w+),(\w+)'
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Validates the input format after initialization.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input format is invalid.
|
||||
"""
|
||||
for sub_input in self.input.split(';'):
|
||||
if not (':' in sub_input and ',' in sub_input):
|
||||
raise ValueError(f"Invalid input format: {sub_input=}.\nPlease follow comma-separated pattern like `r'{self.sub_pattern}'`")
|
||||
def parse_one_covalent_bond(self, sub_input: str) -> CovalentBond:
|
||||
"""
|
||||
Parses a single covalent bond entry and returns a CovalentBond object.
|
||||
|
||||
Args:
|
||||
sub_input (str): A covalent bond entry to be parsed.
|
||||
|
||||
Returns:
|
||||
CovalentBond: The parsed covalent bond object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the input format is invalid.
|
||||
"""
|
||||
if not (is_match:= re.match(self.sub_pattern, sub_input)):
|
||||
raise ValueError(f"Invalid input format: {sub_input=}")
|
||||
return CovalentBond(
|
||||
protein_chain=is_match.group(1),
|
||||
protein_residue_number=is_match.group(2),
|
||||
protein_atom_name=is_match.group(3),
|
||||
small_molecule_chain=is_match.group(4),
|
||||
sm_atom_index=is_match.group(5),
|
||||
new_chirality_atom_1=is_match.group(6),
|
||||
new_chirality_atom_2=is_match.group(7))
|
||||
|
||||
@property
|
||||
def formatted(self) -> Tuple[CovalentBond]:
|
||||
"""
|
||||
Returns a tuple of parsed CovalentBond objects from the input string.
|
||||
|
||||
Returns:
|
||||
Tuple[CovalentBond]: A tuple containing CovalentBond objects.
|
||||
"""
|
||||
return tuple(self.parse_one_covalent_bond(sub_input) for sub_input in self.input.split(';'))
|
||||
|
||||
def load_covalent_molecules(protein_inputs, config, model_runner):
|
||||
if config.covale_inputs is None:
|
||||
|
@ -35,7 +140,7 @@ def load_covalent_molecules(protein_inputs, config, model_runner):
|
|||
if config.sm_inputs is None:
|
||||
raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")
|
||||
|
||||
covalent_bonds = eval(config.covale_inputs)
|
||||
covalent_bonds = CovalentBondsParser(config.covale_inputs).formatted
|
||||
sm_inputs = delete_leaving_atoms(config.sm_inputs)
|
||||
residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
|
||||
chainid_to_input = {}
|
||||
|
@ -50,14 +155,18 @@ def load_covalent_molecules(protein_inputs, config, model_runner):
|
|||
|
||||
return chainid_to_input, residues_to_atomize
|
||||
|
||||
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
|
||||
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds: Tuple[CovalentBond], model_runner):
|
||||
residues_to_atomize = [] # hold on to delete wayward inputs
|
||||
combined_molecules = {} # combined multiple molecules that are bonded
|
||||
extra_bonds = {}
|
||||
for bond in covalent_bonds:
|
||||
prot_chid, prot_res_idx, atom_to_bond = bond[0]
|
||||
sm_chid, sm_atom_num = bond[1]
|
||||
chirality_first_atom, chirality_second_atom = bond[2]
|
||||
prot_chid = bond.protein_chain
|
||||
prot_res_idx=bond.protein_residue_number
|
||||
atom_to_bond=bond.protein_atom_name
|
||||
sm_chid=bond.small_molecule_chain
|
||||
sm_atom_num=bond.sm_atom_index
|
||||
chirality_first_atom=bond.new_chirality_atom_1
|
||||
chirality_second_atom = bond.new_chirality_atom_2
|
||||
if chirality_first_atom.strip() == "null":
|
||||
chirality_first_atom = None
|
||||
if chirality_second_atom.strip() == "null":
|
||||
|
@ -231,7 +340,7 @@ def recompute_xyz_after_chirality(obmol):
|
|||
ff.FastRotorSearch()
|
||||
ff.GetCoordinates(obmol)
|
||||
else:
|
||||
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
|
||||
raise ValueError(f"Failed to generate 3D coordinates for molecule {obmol}.")
|
||||
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
|
||||
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
|
||||
return atom_coords
|
||||
|
|
|
@ -17,6 +17,7 @@ from rf2aa.training.recycling import recycle_step_legacy
|
|||
from rf2aa.util import writepdb, is_atom, Ls_from_same_chain_2d
|
||||
from rf2aa.util_module import XYZConverter
|
||||
|
||||
script_path=os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
class ModelRunner:
|
||||
|
||||
|
@ -200,7 +201,7 @@ class ModelRunner:
|
|||
return err_dict
|
||||
|
||||
|
||||
@hydra.main(version_base=None, config_path='config/inference')
|
||||
@hydra.main(version_base=None, config_path=os.path.join(script_path,'config','inference'))
|
||||
def main(config):
|
||||
runner = ModelRunner(config)
|
||||
runner.infer()
|
||||
|
|
Loading…
Reference in a new issue