feat: new covalent bond input

This commit is contained in:
YaoYinYing 2024-05-16 20:15:09 +08:00
parent bf214835d6
commit a9f2512cda
2 changed files with 120 additions and 10 deletions

View file

@ -1,6 +1,7 @@
import re
import torch
from openbabel import openbabel
from typing import Optional
from typing import List, Literal, Optional, Tuple, Union
from dataclasses import dataclass
from tempfile import NamedTemporaryFile
from rf2aa.chemical import ChemicalData as ChemData
@ -15,8 +16,8 @@ class MoleculeToMoleculeBond:
absolute_atom_index_first: int
chain_index_second: int
absolute_atom_index_second: int
new_chirality_atom_first: Optional[str]
new_chirality_atom_second: Optional[str]
new_chirality_atom_first: Optional[Literal['null','CW','CCW']]
new_chirality_atom_second: Optional[Literal['null','CW','CCW']]
@dataclass
class AtomizedResidue:
@ -27,6 +28,110 @@ class AtomizedResidue:
original_chain: str
index_in_original_chain: int
@dataclass
class CovalentBond:
"""
A class representing a covalent bond between a protein and a small molecule.
Attributes:
- protein_chain: Identifier for the protein chain, of string type.
- protein_residue_number: Residue number in the protein, can be an integer or a string representation of a number.
- protein_atom_name: Name of the atom in the protein participating in the covalent bond, string type.
- small_molecule_chain: Identifier for the small molecule chain, string type.
- sm_atom_index: Index of the atom in the small molecule participating in the covalent bond, can be an integer or a string representation of a number.
- new_chirality_atom_1: Orientation of the new chiral center atom 1, can be 'null', 'CW' (clockwise), or 'CCW' (counterclockwise).
- new_chirality_atom_2: Orientation of the new chiral center atom 2, can be 'null', 'CW', or 'CCW'.
Returns:
- None. This initializer directly constructs the object.
"""
protein_chain: str
protein_residue_number: int
protein_atom_name: str
small_molecule_chain: str
sm_atom_index: int
new_chirality_atom_1: Literal['null','CW','CCW']
new_chirality_atom_2: Literal['null','CW','CCW']
def __post_init__(self):
"""
Special method called immediately after the object is initialized, used to validate and convert residue_number and atom_index values.
"""
# Validate and convert residue_number to ensure it's an integer
if isinstance(self.protein_residue_number, str):
if not self.protein_residue_number.isdigit():
raise ValueError("Residue number must be a number")
self.protein_residue_number = int(self.protein_residue_number)
# Validate and convert atom_index to ensure it's an integer
if isinstance(self.sm_atom_index, str):
if not self.sm_atom_index.isdigit():
raise ValueError("atom_index must be a number")
self.sm_atom_index = int(self.sm_atom_index)
@dataclass
class CovalentBondsParser:
"""
A class for parsing covalent bond information.
Attributes:
input (str): A string containing a series of covalent bond entries, separated by semicolons. Each entry follows the pattern: 'Chain,Residue,Atom:Chain,AtomIndex:NewChiralityAtom1,NewChiralityAtom2'.
sub_pattern (Optional[str]): A regular expression pattern for matching the input format. Defaults to `r'([A-Z]{1}),([\d]{1,4}),(\w+):([A-Z]{1}),([\d]{1,2}):(\w+),(\w+)'`.
Methods:
__post_init__: Validates the input format after initialization.
parse_one_covalent_bond(sub_input: str) -> CovalentBond: Parses a single covalent bond entry and returns a CovalentBond object.
formatted -> Tuple[CovalentBond]: Returns a tuple of parsed CovalentBond objects from the input string.
"""
input: str # A case: 'A,74,ND2:B,1:CW,null;...'
sub_pattern: Optional[str]=r'([A-Z]{1}),([\d]{1,4}),(\w+):([A-Z]{1}),([\d]{1,2}):(\w+),(\w+)'
def __post_init__(self):
"""
Validates the input format after initialization.
Raises:
ValueError: If the input format is invalid.
"""
for sub_input in self.input.split(';'):
if not (':' in sub_input and ',' in sub_input):
raise ValueError(f"Invalid input format: {sub_input=}.\nPlease follow comma-separated pattern like `r'{self.sub_pattern}'`")
def parse_one_covalent_bond(self, sub_input: str) -> CovalentBond:
"""
Parses a single covalent bond entry and returns a CovalentBond object.
Args:
sub_input (str): A covalent bond entry to be parsed.
Returns:
CovalentBond: The parsed covalent bond object.
Raises:
ValueError: If the input format is invalid.
"""
if not (is_match:= re.match(self.sub_pattern, sub_input)):
raise ValueError(f"Invalid input format: {sub_input=}")
return CovalentBond(
protein_chain=is_match.group(1),
protein_residue_number=is_match.group(2),
protein_atom_name=is_match.group(3),
small_molecule_chain=is_match.group(4),
sm_atom_index=is_match.group(5),
new_chirality_atom_1=is_match.group(6),
new_chirality_atom_2=is_match.group(7))
@property
def formatted(self) -> Tuple[CovalentBond]:
"""
Returns a tuple of parsed CovalentBond objects from the input string.
Returns:
Tuple[CovalentBond]: A tuple containing CovalentBond objects.
"""
return tuple(self.parse_one_covalent_bond(sub_input) for sub_input in self.input.split(';'))
def load_covalent_molecules(protein_inputs, config, model_runner):
if config.covale_inputs is None:
@ -35,7 +140,7 @@ def load_covalent_molecules(protein_inputs, config, model_runner):
if config.sm_inputs is None:
raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")
covalent_bonds = eval(config.covale_inputs)
covalent_bonds = CovalentBondsParser(config.covale_inputs).formatted
sm_inputs = delete_leaving_atoms(config.sm_inputs)
residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
chainid_to_input = {}
@ -50,14 +155,18 @@ def load_covalent_molecules(protein_inputs, config, model_runner):
return chainid_to_input, residues_to_atomize
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds: Tuple[CovalentBond], model_runner):
residues_to_atomize = [] # hold on to delete wayward inputs
combined_molecules = {} # combined multiple molecules that are bonded
extra_bonds = {}
for bond in covalent_bonds:
prot_chid, prot_res_idx, atom_to_bond = bond[0]
sm_chid, sm_atom_num = bond[1]
chirality_first_atom, chirality_second_atom = bond[2]
prot_chid = bond.protein_chain
prot_res_idx=bond.protein_residue_number
atom_to_bond=bond.protein_atom_name
sm_chid=bond.small_molecule_chain
sm_atom_num=bond.sm_atom_index
chirality_first_atom=bond.new_chirality_atom_1
chirality_second_atom = bond.new_chirality_atom_2
if chirality_first_atom.strip() == "null":
chirality_first_atom = None
if chirality_second_atom.strip() == "null":
@ -231,7 +340,7 @@ def recompute_xyz_after_chirality(obmol):
ff.FastRotorSearch()
ff.GetCoordinates(obmol)
else:
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
raise ValueError(f"Failed to generate 3D coordinates for molecule {obmol}.")
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
return atom_coords

View file

@ -17,6 +17,7 @@ from rf2aa.training.recycling import recycle_step_legacy
from rf2aa.util import writepdb, is_atom, Ls_from_same_chain_2d
from rf2aa.util_module import XYZConverter
script_path=os.path.dirname(os.path.realpath(__file__))
class ModelRunner:
@ -200,7 +201,7 @@ class ModelRunner:
return err_dict
@hydra.main(version_base=None, config_path='config/inference')
@hydra.main(version_base=None, config_path=os.path.join(script_path,'config','inference'))
def main(config):
runner = ModelRunner(config)
runner.infer()