RoseTTAFold-All-Atom/rf2aa/data/covale.py
2024-03-04 22:38:17 -08:00

308 lines
12 KiB
Python

import torch
from openbabel import openbabel
from typing import Optional
from dataclasses import dataclass
from tempfile import NamedTemporaryFile
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.parsers import parse_mol
from rf2aa.data.small_molecule import compute_features_from_obmol
from rf2aa.util import get_bond_feats
@dataclass
class MoleculeToMoleculeBond:
chain_index_first: int
absolute_atom_index_first: int
chain_index_second: int
absolute_atom_index_second: int
new_chirality_atom_first: Optional[str]
new_chirality_atom_second: Optional[str]
@dataclass
class AtomizedResidue:
chain: str
chain_index_in_combined_chain: int
absolute_N_index_in_chain: int
absolute_C_index_in_chain: int
original_chain: str
index_in_original_chain: int
def load_covalent_molecules(protein_inputs, config, model_runner):
if config.covale_inputs is None:
return None
if config.sm_inputs is None:
raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")
covalent_bonds = eval(config.covale_inputs)
sm_inputs = delete_leaving_atoms(config.sm_inputs)
residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner)
chainid_to_input = {}
for chain, combined_molecule in combined_molecules.items():
extra_bonds_for_chain = extra_bonds[chain]
msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
xyz = recompute_xyz_after_chirality(mol)
input = compute_features_from_obmol(mol, msa, xyz, model_runner)
chainid_to_input[chain] = input
return chainid_to_input, residues_to_atomize
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
residues_to_atomize = [] # hold on to delete wayward inputs
combined_molecules = {} # combined multiple molecules that are bonded
extra_bonds = {}
for bond in covalent_bonds:
prot_chid, prot_res_idx, atom_to_bond = bond[0]
sm_chid, sm_atom_num = bond[1]
chirality_first_atom, chirality_second_atom = bond[2]
if chirality_first_atom.strip() == "null":
chirality_first_atom = None
if chirality_second_atom.strip() == "null":
chirality_second_atom = None
sm_atom_num = int(sm_atom_num) - 1 # 0 index
try:
assert sm_chid in sm_inputs, f"must provide a small molecule chain {sm_chid} for covalent bond: {bond}"
except:
print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
continue
assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"
try:
protein_input = protein_inputs[prot_chid]
except Exception as e:
raise ValueError(f"first atom in covale_input must be present in\
a protein chain. Given chain: {prot_chid} was not in \
given protein chains: {list(protein_inputs.keys())}")
residue = (prot_chid, prot_res_idx, atom_to_bond)
file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)
if sm_chid not in combined_molecules:
combined_molecules[sm_chid] = [sm_inputs[sm_chid].input]
combined_molecules[sm_chid].insert(0, file) # this is a bug, revert
absolute_chain_index_first = combined_molecules[sm_chid].index(sm_inputs[sm_chid].input)
absolute_chain_index_second = combined_molecules[sm_chid].index(file)
if sm_chid not in extra_bonds:
extra_bonds[sm_chid] = []
extra_bonds[sm_chid].append(MoleculeToMoleculeBond(
absolute_chain_index_first,
sm_atom_num,
absolute_chain_index_second,
atom_index,
new_chirality_atom_first=chirality_first_atom,
new_chirality_atom_second=chirality_second_atom
))
residues_to_atomize.append(AtomizedResidue(
sm_chid,
absolute_chain_index_second,
0,
2,
prot_chid,
int(prot_res_idx) -1
))
return residues_to_atomize, combined_molecules, extra_bonds
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
"""convert residue into sdf and record index for covalent bond"""
prot_chid, prot_res_idx, atom_to_bond = residue
protein_input = protein_inputs[prot_chid]
prot_res_abs_idx = int(prot_res_idx) -1
residue_identity_num = protein_input.query_sequence()[prot_res_abs_idx]
residue_identity = ChemData().num2aa[residue_identity_num]
molecule_info = model_runner.molecule_db[residue_identity]
sdf = molecule_info["sdf"]
temp_file = create_and_populate_temp_file(sdf)
is_heavy = [i for i, a in enumerate(molecule_info["atom_id"]) if a[0] != "H"]
is_leaving = [a for i,a in enumerate(molecule_info["leaving"]) if i in is_heavy]
sdf_string_no_leaving_atoms = delete_leaving_atoms_single_chain(temp_file, is_leaving )
temp_file = create_and_populate_temp_file(sdf_string_no_leaving_atoms)
atom_names = molecule_info["atom_id"]
atom_index = atom_names.index(atom_to_bond.strip())
return temp_file, atom_index
def get_combined_atoms_bonds(combined_molecule):
atom_list = []
bond_feats_list = []
xyzs = []
Ls = []
for molecule in combined_molecule:
obmol, msa, ins, xyz, mask = parse_mol(
molecule,
filetype="sdf",
string=False,
generate_conformer=True,
find_automorphs=False
)
bond_feats = get_bond_feats(obmol)
atom_list.append(msa)
bond_feats_list.append(bond_feats)
xyzs.append(xyz)
Ls.append(msa.shape[0])
atoms = torch.cat(atom_list)
L_total = sum(Ls)
bond_feats = torch.zeros((L_total, L_total)).long()
offset = 0
for bf in bond_feats_list:
L = bf.shape[0]
bond_feats[offset:offset+L, offset:offset+L] = bf
offset += L
xyz = torch.cat(xyzs, dim=1)[0]
return atoms, bond_feats, xyz, Ls
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
mol = openbabel.OBMol()
for i,k in enumerate(msa):
element = ChemData().num2aa[k]
atomnum = ChemData().atomtype2atomnum[element]
a = mol.NewAtom()
a.SetAtomicNum(atomnum)
a.SetVector(float(xyz[i,0]), float(xyz[i,1]), float(xyz[i,2]))
first_index, second_index = bond_feats.nonzero(as_tuple=True)
for i, j in zip(first_index, second_index):
order = bond_feats[i,j]
bond = make_openbabel_bond(mol, i.item(), j.item(), order.item())
mol.AddBond(bond)
for bond in extra_bonds:
absolute_index_first = get_absolute_index_from_relative_indices(
bond.chain_index_first,
bond.absolute_atom_index_first,
Ls
)
absolute_index_second = get_absolute_index_from_relative_indices(
bond.chain_index_second,
bond.absolute_atom_index_second,
Ls
)
order = 1 #all covale bonds are single bonds
openbabel_bond = make_openbabel_bond(mol, absolute_index_first, absolute_index_second, order)
mol.AddBond(openbabel_bond)
set_chirality(mol, absolute_index_first, bond.new_chirality_atom_first)
set_chirality(mol, absolute_index_second, bond.new_chirality_atom_second)
return mol
def make_openbabel_bond(mol, i, j, order):
obb = openbabel.OBBond()
obb.SetBegin(mol.GetAtom(i+1))
obb.SetEnd(mol.GetAtom(j+1))
if order == 4:
obb.SetBondOrder(2)
obb.SetAromatic()
else:
obb.SetBondOrder(order)
return obb
def set_chirality(mol, absolute_atom_index, new_chirality):
stereo = openbabel.OBStereoFacade(mol)
if stereo.HasTetrahedralStereo(absolute_atom_index+1):
tetstereo = stereo.GetTetrahedralStereo(mol.GetAtom(absolute_atom_index+1).GetId())
if tetstereo is None:
return
assert new_chirality is not None, "you have introduced a new stereocenter, \
so you must specify its chirality either as CW, or CCW"
config = tetstereo.GetConfig()
config.winding = chirality_options[new_chirality]
tetstereo.SetConfig(config)
print("Updating chirality...")
else:
assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
chirality_options = {
"CW": openbabel.OBStereo.Clockwise,
"CCW": openbabel.OBStereo.AntiClockwise,
}
def recompute_xyz_after_chirality(obmol):
builder = openbabel.OBBuilder()
builder.Build(obmol)
ff = openbabel.OBForceField.FindForceField("mmff94")
did_setup = ff.Setup(obmol)
if did_setup:
ff.FastRotorSearch()
ff.GetCoordinates(obmol)
else:
raise ValueError(f"Failed to generate 3D coordinates for molecule {filename}.")
atom_coords = torch.tensor([[obmol.GetAtom(i).x(),obmol.GetAtom(i).y(), obmol.GetAtom(i).z()]
for i in range(1, obmol.NumAtoms()+1)]).unsqueeze(0) # (1, natoms, 3)
return atom_coords
def delete_leaving_atoms(sm_inputs):
updated_sm_inputs = {}
for chain in sm_inputs:
if "is_leaving" not in sm_inputs[chain]:
continue
is_leaving = eval(sm_inputs[chain]["is_leaving"])
sdf_string = delete_leaving_atoms_single_chain(sm_inputs[chain]["input"], is_leaving)
updated_sm_inputs[chain] = {
"input": create_and_populate_temp_file(sdf_string),
"input_type": "sdf"
}
sm_inputs.update(updated_sm_inputs)
return sm_inputs
def delete_leaving_atoms_single_chain(filename, is_leaving):
obmol, msa, ins, xyz, mask = parse_mol(
filename,
filetype="sdf",
string=False,
generate_conformer=True
)
assert len(is_leaving) == obmol.NumAtoms()
leaving_indices = torch.tensor(is_leaving).nonzero()
for idx in leaving_indices:
obmol.DeleteAtom(obmol.GetAtom(idx.item()+1))
obConversion = openbabel.OBConversion()
obConversion.SetInAndOutFormats("sdf", "sdf")
sdf_string = obConversion.WriteString(obmol)
return sdf_string
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
offset = sum(Ls[:chain_index])
return offset + absolute_index_in_chain
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
updated_residues_to_atomize = []
for residue in residues_to_atomize:
if residue.chain == chain:
absolute_index_N = get_absolute_index_from_relative_indices(
residue.chain_index_in_combined_chain,
residue.absolute_N_index_in_chain,
Ls)
absolute_index_C = get_absolute_index_from_relative_indices(
residue.chain_index_in_combined_chain,
residue.absolute_C_index_in_chain,
Ls)
updated_residue = AtomizedResidue(
residue.chain,
None,
absolute_index_N,
absolute_index_C,
residue.original_chain,
residue.index_in_original_chain
)
updated_residues_to_atomize.append(updated_residue)
else:
updated_residues_to_atomize.append(residue)
return updated_residues_to_atomize
def create_and_populate_temp_file(data):
# Create a temporary file
with NamedTemporaryFile(mode='w+', delete=False) as temp_file:
# Write the string to the temporary file
temp_file.write(data)
# Get the filename
temp_file_name = temp_file.name
return temp_file_name