# RoseTTAFold-All-Atom/rf2aa/data/covale.py
# Repository-viewer header kept from extraction: 309 lines, 12 KiB, Python;
# revision dated 2024-03-05 06:38:17 +00:00.
import torch
from openbabel import openbabel
from typing import Optional
from dataclasses import dataclass
from tempfile import NamedTemporaryFile
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.parsers import parse_mol
from rf2aa.data.small_molecule import compute_features_from_obmol
from rf2aa.util import get_bond_feats
@dataclass
class MoleculeToMoleculeBond:
    """Covalent bond joining two molecules of one combined small-molecule chain.

    Each end of the bond is addressed by the position of its molecule in the
    chain's list of combined molecules plus a 0-based atom index inside that
    molecule; ``get_absolute_index_from_relative_indices`` turns such a pair
    into an index into the concatenated atom list.
    """

    # Position of the first molecule in the combined-molecule list.
    chain_index_first: int
    # 0-based index of the bonded atom inside the first molecule.
    absolute_atom_index_first: int
    # Position of the second molecule in the combined-molecule list.
    chain_index_second: int
    # 0-based index of the bonded atom inside the second molecule.
    absolute_atom_index_second: int
    # Winding ("CW" / "CCW") to impose on the first atom, or None for no change.
    new_chirality_atom_first: Optional[str]
    # Winding ("CW" / "CCW") to impose on the second atom, or None for no change.
    new_chirality_atom_second: Optional[str]
@dataclass
class AtomizedResidue:
    """Protein residue that is represented atom-by-atom inside a small-molecule chain.

    Instances are created with indices relative to the residue's own molecule
    and later rewritten by ``update_absolute_indices_after_combination`` so the
    N/C indices refer to the concatenated atom list of the combined chain.
    """

    # Id of the small-molecule chain the residue was merged into.
    chain: str
    # Position of the residue's molecule in the combined-molecule list; set to
    # None once the N/C indices have been converted to absolute indices.
    chain_index_in_combined_chain: Optional[int]
    # Index of the residue's N atom (relative at creation, absolute after update).
    absolute_N_index_in_chain: int
    # Index of the residue's C atom (relative at creation, absolute after update).
    absolute_C_index_in_chain: int
    # Id of the protein chain the residue originally belongs to.
    original_chain: str
    # 0-based index of the residue in its original protein chain.
    index_in_original_chain: int
def load_covalent_molecules(protein_inputs, config, model_runner):
    """Build model inputs for small molecules covalently bonded to protein residues.

    Args:
        protein_inputs: mapping from protein chain id to that chain's input object.
        config: run configuration; ``covale_inputs`` and ``sm_inputs`` are read.
        model_runner: passed through to the residue-conversion and
            featurization helpers.

    Returns:
        ``None`` when ``config.covale_inputs`` is ``None``. Otherwise a tuple
        ``(chainid_to_input, residues_to_atomize)`` where ``chainid_to_input``
        maps each small-molecule chain id to the result of
        ``compute_features_from_obmol`` and ``residues_to_atomize`` is the list
        of ``AtomizedResidue`` records with absolute N/C indices.

    Raises:
        ValueError: if covalent bonds are requested but no small-molecule
            inputs were configured.
    """
    if config.covale_inputs is None:
        return None
    if config.sm_inputs is None:
        raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")

    # NOTE(review): eval() executes whatever expression the config string
    # contains. If the value is always a plain literal, ast.literal_eval would
    # be a safer parser -- confirm before switching.
    bond_specs = eval(config.covale_inputs)
    ligand_inputs = delete_leaving_atoms(config.sm_inputs)
    residues_to_atomize, molecules_by_chain, bonds_by_chain = find_residues_to_atomize(
        protein_inputs, ligand_inputs, bond_specs, model_runner
    )

    chainid_to_input = {}
    for chain in molecules_by_chain:
        chain_bonds = bonds_by_chain[chain]
        msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(molecules_by_chain[chain])
        # Residue N/C indices become absolute now that molecule lengths are known.
        residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
        mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, chain_bonds)
        # Coordinates are regenerated because bonds/chirality were edited.
        rebuilt_xyz = recompute_xyz_after_chirality(mol)
        chainid_to_input[chain] = compute_features_from_obmol(mol, msa, rebuilt_xyz, model_runner)
    return chainid_to_input, residues_to_atomize
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
    """Collect the residues, molecules and bonds needed for each covalent link.

    Each entry of ``covalent_bonds`` is expected to unpack as
    ``((prot_chain, residue_number, atom_name), (sm_chain, atom_number),
    (chirality_first, chirality_second))``. Residue and atom numbers are
    1-based; a chirality of ``"null"`` means "no chirality given".

    Args:
        protein_inputs: mapping from protein chain id to its input object.
        sm_inputs: mapping from small-molecule chain id to an entry exposing
            ``input`` and ``input_type``.
        covalent_bonds: iterable of bond specifications as described above.
        model_runner: forwarded to ``convert_residue_to_molecule``.

    Returns:
        ``(residues_to_atomize, combined_molecules, extra_bonds)``:
        a list of ``AtomizedResidue``; a dict mapping small-molecule chain id
        to the list of molecule files making up the combined chain; and a dict
        mapping small-molecule chain id to its ``MoleculeToMoleculeBond`` list.

    Raises:
        ValueError: if the protein chain of a bond is not in ``protein_inputs``.
        AssertionError: if the bonded small molecule is not an sdf input.
    """
    combined_molecules = {}  # combined multiple molecules that are bonded
    # Bond records whose list positions are resolved only after every residue
    # has been inserted (see the comment at the insert below).
    pending = []

    for bond in covalent_bonds:
        prot_chid, prot_res_idx, atom_to_bond = bond[0]
        sm_chid, sm_atom_num = bond[1]
        chirality_first_atom, chirality_second_atom = bond[2]
        if chirality_first_atom.strip() == "null":
            chirality_first_atom = None
        if chirality_second_atom.strip() == "null":
            chirality_second_atom = None
        sm_atom_num = int(sm_atom_num) - 1  # 0 index

        # Bonds to small-molecule chains that were not provided are skipped
        # (best effort), not treated as fatal.
        if sm_chid not in sm_inputs:
            print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
            continue
        assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"

        try:
            protein_inputs[prot_chid]
        except (LookupError, TypeError) as err:
            raise ValueError(
                "first atom in covale_input must be present in a protein chain. "
                f"Given chain: {prot_chid} was not in "
                f"given protein chains: {list(protein_inputs.keys())}"
            ) from err

        residue = (prot_chid, prot_res_idx, atom_to_bond)
        residue_file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)

        if sm_chid not in combined_molecules:
            combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
        # Residues are prepended, which shifts the position of everything
        # already in the list. Positions are therefore looked up after the
        # loop, so bonds recorded earlier for the same chain do not end up
        # pointing at the wrong molecule.
        combined_molecules[sm_chid].insert(0, residue_file)
        pending.append((
            sm_chid,
            residue_file,
            sm_atom_num,
            atom_index,
            chirality_first_atom,
            chirality_second_atom,
            prot_chid,
            int(prot_res_idx) - 1,
        ))

    residues_to_atomize = []  # hold on to delete wayward inputs
    extra_bonds = {}
    for (sm_chid, residue_file, sm_atom_num, atom_index,
         chirality_first_atom, chirality_second_atom,
         prot_chid, residue_index) in pending:
        molecules = combined_molecules[sm_chid]
        absolute_chain_index_first = molecules.index(sm_inputs[sm_chid].input)
        absolute_chain_index_second = molecules.index(residue_file)

        extra_bonds.setdefault(sm_chid, []).append(MoleculeToMoleculeBond(
            absolute_chain_index_first,
            sm_atom_num,
            absolute_chain_index_second,
            atom_index,
            new_chirality_atom_first=chirality_first_atom,
            new_chirality_atom_second=chirality_second_atom,
        ))
        residues_to_atomize.append(AtomizedResidue(
            sm_chid,
            absolute_chain_index_second,
            0,  # N index inside the residue molecule
            2,  # C index inside the residue molecule
            prot_chid,
            residue_index,
        ))
    return residues_to_atomize, combined_molecules, extra_bonds
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
    """Convert a residue into an sdf file and record the index for the covalent bond.

    Args:
        protein_inputs: mapping from protein chain id to an input object that
            provides ``query_sequence()``.
        residue: ``(prot_chain_id, residue_number, atom_name)``; the residue
            number is 1-based.
        model_runner: provides ``molecule_db``, indexed by residue name, whose
            entries hold ``"sdf"``, ``"atom_id"`` and ``"leaving"``.

    Returns:
        ``(temp_file, atom_index)``: path of a temporary sdf file holding the
        residue without its leaving atoms, and the position of ``atom_name``
        in the residue's ``atom_id`` list.

    Raises:
        ValueError: if ``atom_name`` is not one of the residue's atom ids.
    """
    import os  # local import: only needed to clean up the intermediate file

    prot_chid, prot_res_idx, atom_to_bond = residue
    protein_input = protein_inputs[prot_chid]
    prot_res_abs_idx = int(prot_res_idx) - 1
    residue_identity_num = protein_input.query_sequence()[prot_res_abs_idx]
    residue_identity = ChemData().num2aa[residue_identity_num]
    molecule_info = model_runner.molecule_db[residue_identity]
    atom_names = molecule_info["atom_id"]

    # Leaving flags are kept only for atoms whose name does not start with "H".
    heavy_indices = {i for i, name in enumerate(atom_names) if name[0] != "H"}
    is_leaving = [flag for i, flag in enumerate(molecule_info["leaving"]) if i in heavy_indices]

    # The full residue is written to disk only so it can be parsed; the file
    # is removed afterwards and not left behind in the temp directory.
    full_residue_file = create_and_populate_temp_file(molecule_info["sdf"])
    try:
        sdf_string_no_leaving_atoms = delete_leaving_atoms_single_chain(full_residue_file, is_leaving)
    finally:
        try:
            os.remove(full_residue_file)
        except OSError:
            pass  # cleanup is best effort
    temp_file = create_and_populate_temp_file(sdf_string_no_leaving_atoms)

    # NOTE(review): the index is taken from the unfiltered atom_id list, while
    # the written molecule has leaving atoms removed. The two agree only if no
    # removed atom precedes the bonded atom -- confirm for non-standard residues.
    atom_index = atom_names.index(atom_to_bond.strip())
    return temp_file, atom_index
def get_combined_atoms_bonds(combined_molecule):
    """Parse every molecule of a combined chain and concatenate their features.

    Args:
        combined_molecule: iterable of sdf file paths, in chain order.

    Returns:
        ``(atoms, bond_feats, xyz, Ls)``: the concatenated atom tokens; a
        square ``long`` bond matrix with each molecule's bond features placed
        as a diagonal block (zeros between different molecules); the
        coordinates concatenated along dim 1 with the first entry of dim 0
        selected; and the list of per-molecule atom counts.
    """
    parsed = []
    for sdf_path in combined_molecule:
        obmol, msa, _ins, coords, _mask = parse_mol(
            sdf_path,
            filetype="sdf",
            string=False,
            generate_conformer=True,
            find_automorphs=False,
        )
        parsed.append((msa, get_bond_feats(obmol), coords))

    Ls = [msa.shape[0] for msa, _, _ in parsed]
    atoms = torch.cat([msa for msa, _, _ in parsed])

    total_atoms = sum(Ls)
    bond_feats = torch.zeros((total_atoms, total_atoms)).long()
    start = 0
    for _, block, _ in parsed:
        stop = start + block.shape[0]
        bond_feats[start:stop, start:stop] = block
        start = stop

    # presumably dim 0 enumerates conformers/automorphs -- TODO confirm
    xyz = torch.cat([coords for _, _, coords in parsed], dim=1)[0]
    return atoms, bond_feats, xyz, Ls
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
    """Assemble an OpenBabel molecule from atom tokens, bonds and coordinates.

    Args:
        msa: sequence of atom tokens, one per atom.
        bond_feats: square matrix whose non-zero entries give bond orders.
        xyz: per-atom coordinates, indexed as ``xyz[atom, axis]``.
        Ls: atom count of each molecule in the combined chain.
        extra_bonds: ``MoleculeToMoleculeBond`` records to add on top of the
            bonds in ``bond_feats``; they are added as single bonds.

    Returns:
        The populated ``openbabel.OBMol``.
    """
    mol = openbabel.OBMol()

    for atom_pos, token in enumerate(msa):
        element = ChemData().num2aa[token]
        atomic_number = ChemData().atomtype2atomnum[element]
        atom = mol.NewAtom()
        atom.SetAtomicNum(atomic_number)
        x, y, z = (float(xyz[atom_pos, axis]) for axis in range(3))
        atom.SetVector(x, y, z)

    rows, cols = bond_feats.nonzero(as_tuple=True)
    for row, col in zip(rows.tolist(), cols.tolist()):
        mol.AddBond(make_openbabel_bond(mol, row, col, bond_feats[row, col].item()))

    for link in extra_bonds:
        first_atom = get_absolute_index_from_relative_indices(
            link.chain_index_first, link.absolute_atom_index_first, Ls
        )
        second_atom = get_absolute_index_from_relative_indices(
            link.chain_index_second, link.absolute_atom_index_second, Ls
        )
        # all covale bonds are single bonds
        mol.AddBond(make_openbabel_bond(mol, first_atom, second_atom, 1))
        set_chirality(mol, first_atom, link.new_chirality_atom_first)
        set_chirality(mol, second_atom, link.new_chirality_atom_second)
    return mol
def make_openbabel_bond(mol, i, j, order):
    """Create an ``OBBond`` between atoms ``i`` and ``j`` of ``mol``.

    Args:
        mol: molecule that owns both atoms.
        i: 0-based index of the begin atom (``GetAtom`` is called with ``i + 1``).
        j: 0-based index of the end atom (``GetAtom`` is called with ``j + 1``).
        order: bond order; the value 4 is stored as bond order 2 with the
            aromatic flag set.

    Returns:
        The new bond. It is not added to ``mol`` here.
    """
    is_aromatic = order == 4
    bond = openbabel.OBBond()
    bond.SetBegin(mol.GetAtom(i + 1))
    bond.SetEnd(mol.GetAtom(j + 1))
    bond.SetBondOrder(2 if is_aromatic else order)
    if is_aromatic:
        bond.SetAromatic()
    return bond
def set_chirality(mol, absolute_atom_index, new_chirality):
    """Impose a winding on the tetrahedral stereocenter at the given atom.

    Args:
        mol: OpenBabel molecule to modify in place.
        absolute_atom_index: 0-based index of the atom in ``mol``.
        new_chirality: key of ``chirality_options`` (``"CW"`` or ``"CCW"``),
            or ``None`` when the atom is not expected to be a stereocenter.

    Raises:
        AssertionError: if the atom is a tetrahedral stereocenter but no
            chirality was given, or a chirality was given for an atom that is
            not a tetrahedral stereocenter.
        KeyError: if ``new_chirality`` is not a key of ``chirality_options``.
    """
    stereo = openbabel.OBStereoFacade(mol)
    # OBStereoFacade is keyed by atom Id, not by the 1-based atom index, so
    # the same Id is used for both the presence check and the lookup.
    atom_id = mol.GetAtom(absolute_atom_index + 1).GetId()

    if not stereo.HasTetrahedralStereo(atom_id):
        assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
        return

    tetstereo = stereo.GetTetrahedralStereo(atom_id)
    if tetstereo is None:
        return
    assert new_chirality is not None, (
        "you have introduced a new stereocenter, "
        "so you must specify its chirality either as CW, or CCW"
    )
    config = tetstereo.GetConfig()
    config.winding = chirality_options[new_chirality]
    tetstereo.SetConfig(config)
    print("Updating chirality...")
# Maps the chirality keywords accepted in covalent-bond specifications to the
# corresponding OpenBabel winding constants.
chirality_options = dict(
    CW=openbabel.OBStereo.Clockwise,
    CCW=openbabel.OBStereo.AntiClockwise,
)
def recompute_xyz_after_chirality(obmol):
    """Regenerate 3D coordinates for ``obmol`` and return them as a tensor.

    The molecule is rebuilt with ``OBBuilder`` and then refined with an
    MMFF94 fast rotor search; ``obmol`` is updated in place.

    Args:
        obmol: OpenBabel molecule whose coordinates should be regenerated.

    Returns:
        Tensor of atom coordinates with shape ``(1, natoms, 3)``.

    Raises:
        ValueError: if the MMFF94 force field is unavailable or cannot be set
            up for the molecule.
    """
    builder = openbabel.OBBuilder()
    builder.Build(obmol)

    ff = openbabel.OBForceField.FindForceField("mmff94")
    # The original message interpolated an undefined name, so this failure
    # surfaced as a NameError instead of the intended ValueError.
    if ff is None or not ff.Setup(obmol):
        raise ValueError("Failed to generate 3D coordinates for molecule.")
    ff.FastRotorSearch()
    ff.GetCoordinates(obmol)

    coords = []
    for atom_number in range(1, obmol.NumAtoms() + 1):
        atom = obmol.GetAtom(atom_number)
        coords.append([atom.x(), atom.y(), atom.z()])
    return torch.tensor(coords).unsqueeze(0)  # (1, natoms, 3)
def delete_leaving_atoms(sm_inputs):
    """Replace small-molecule inputs that declare leaving atoms with stripped copies.

    For every chain whose entry has an ``"is_leaving"`` key, the flagged atoms
    are removed and the entry is replaced by a new sdf temp file. Entries
    without that key are left untouched.

    Args:
        sm_inputs: mapping from chain id to an entry with ``"input"`` and,
            optionally, ``"is_leaving"``. It is updated in place.

    Returns:
        The same ``sm_inputs`` object.
    """
    stripped_entries = {}
    for chain in sm_inputs:
        entry = sm_inputs[chain]
        if "is_leaving" in entry:
            # NOTE(review): eval() executes whatever expression the config
            # string contains; ast.literal_eval would be safer if the value is
            # always a plain list literal -- confirm before switching.
            leaving_flags = eval(entry["is_leaving"])
            stripped_sdf = delete_leaving_atoms_single_chain(entry["input"], leaving_flags)
            stripped_entries[chain] = {
                "input": create_and_populate_temp_file(stripped_sdf),
                "input_type": "sdf",
            }
    # Applied after the loop so the mapping is not modified while iterating.
    sm_inputs.update(stripped_entries)
    return sm_inputs
def delete_leaving_atoms_single_chain(filename, is_leaving):
    """Remove the flagged atoms from an sdf file and return the result as sdf text.

    Args:
        filename: path of the sdf file to read.
        is_leaving: one flag per atom of the parsed molecule; atoms with a
            non-zero flag are deleted.

    Returns:
        The sdf string of the molecule without its leaving atoms.

    Raises:
        AssertionError: if the number of flags differs from the number of
            atoms in the parsed molecule.
    """
    obmol, msa, ins, xyz, mask = parse_mol(
        filename,
        filetype="sdf",
        string=False,
        generate_conformer=True
    )
    assert len(is_leaving) == obmol.NumAtoms()

    leaving_indices = torch.tensor(is_leaving).nonzero().flatten().tolist()
    # DeleteAtom renumbers the atoms that follow the deleted one, so deletion
    # runs from the highest index down to keep the remaining indices valid.
    for idx in sorted(leaving_indices, reverse=True):
        obmol.DeleteAtom(obmol.GetAtom(idx + 1))

    obConversion = openbabel.OBConversion()
    obConversion.SetInAndOutFormats("sdf", "sdf")
    sdf_string = obConversion.WriteString(obmol)
    return sdf_string
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
    """Convert a (molecule position, atom index) pair into a flat atom index.

    Args:
        chain_index: position of the molecule in the combined chain.
        absolute_index_in_chain: atom index inside that molecule.
        Ls: atom count of each molecule, in chain order.

    Returns:
        The atom index in the concatenation of all molecules, i.e. the total
        length of the molecules before ``chain_index`` plus the atom index.
    """
    atoms_before = 0
    for length in Ls[:chain_index]:
        atoms_before += length
    return atoms_before + absolute_index_in_chain
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
    """Rewrite the N/C indices of one chain's residues as absolute indices.

    Args:
        residues_to_atomize: list of ``AtomizedResidue`` records.
        chain: id of the combined chain whose residues should be updated.
        Ls: atom count of each molecule in that combined chain.

    Returns:
        A new list in the same order. Residues of ``chain`` are replaced by
        new ``AtomizedResidue`` objects whose N/C indices include the offset of
        the preceding molecules and whose ``chain_index_in_combined_chain`` is
        ``None``; all other residues are passed through unchanged.
    """
    def to_absolute(residue, index_in_molecule):
        return get_absolute_index_from_relative_indices(
            residue.chain_index_in_combined_chain, index_in_molecule, Ls
        )

    refreshed = []
    for residue in residues_to_atomize:
        if not residue.chain == chain:
            refreshed.append(residue)
            continue
        n_index = to_absolute(residue, residue.absolute_N_index_in_chain)
        c_index = to_absolute(residue, residue.absolute_C_index_in_chain)
        refreshed.append(AtomizedResidue(
            residue.chain,
            None,  # position in the chain is no longer needed once indices are absolute
            n_index,
            c_index,
            residue.original_chain,
            residue.index_in_original_chain,
        ))
    return refreshed
def create_and_populate_temp_file(data):
    """Write ``data`` to a new temporary file and return the file's path.

    The file is created with ``delete=False``, so it remains on disk after
    this function returns.

    Args:
        data: text to write.

    Returns:
        Path of the temporary file.
    """
    handle = NamedTemporaryFile(mode="w+", delete=False)
    try:
        handle.write(data)
    finally:
        # Closing flushes the content, so readers see the complete file.
        handle.close()
    return handle.name