Mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-04 22:25:42 +00:00
309 lines · 12 KiB · Python
|
import torch
|
||
|
from openbabel import openbabel
|
||
|
from typing import Optional
|
||
|
from dataclasses import dataclass
|
||
|
from tempfile import NamedTemporaryFile
|
||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||
|
from rf2aa.data.parsers import parse_mol
|
||
|
from rf2aa.data.small_molecule import compute_features_from_obmol
|
||
|
from rf2aa.util import get_bond_feats
|
||
|
|
||
|
|
||
|
@dataclass
class MoleculeToMoleculeBond:
    """A single covalent bond joining two molecules inside one combined chain.

    Each side is addressed by the position of its molecule in the combined
    molecule list (``chain_index_*``) plus an atom index *within that molecule*
    (``absolute_atom_index_*``). The latter is converted to an index in the
    combined molecule by ``get_absolute_index_from_relative_indices``.
    """

    # Position of the first molecule in the combined list
    # (find_residues_to_atomize puts the small-molecule side here).
    chain_index_first: int
    # 0-based atom index inside the first molecule.
    absolute_atom_index_first: int
    # Position of the second molecule (the atomized residue in find_residues_to_atomize).
    chain_index_second: int
    # 0-based atom index inside the second molecule.
    absolute_atom_index_second: int
    # Requested winding ("CW"/"CCW") for the first atom, or None.
    new_chirality_atom_first: Optional[str]
    # Requested winding ("CW"/"CCW") for the second atom, or None.
    new_chirality_atom_second: Optional[str]
|
||
|
|
||
|
@dataclass
class AtomizedResidue:
    """A protein residue that has been converted into a small molecule for covalent linking.

    Index fields start out relative to the residue's own molecule. Once the
    combined chain is built, ``update_absolute_indices_after_combination``
    rewrites them as indices into the combined molecule and sets
    ``chain_index_in_combined_chain`` to ``None``.
    """

    # Small-molecule chain id the residue was merged into.
    chain: str
    # Position of the residue molecule in the combined molecule list;
    # None after indices have been made absolute.
    chain_index_in_combined_chain: Optional[int]
    # Index of the residue's N atom (within its molecule, later within the
    # combined molecule). Set to 0 by find_residues_to_atomize — presumably the
    # backbone N position in the residue SDF; TODO confirm.
    absolute_N_index_in_chain: int
    # Index of the residue's C atom (set to 2 by find_residues_to_atomize —
    # presumably the backbone C position; TODO confirm).
    absolute_C_index_in_chain: int
    # Protein chain id the residue originally belonged to.
    original_chain: str
    # 0-based residue position in the original protein chain.
    index_in_original_chain: int
|
||
|
|
||
|
|
||
|
def load_covalent_molecules(protein_inputs, config, model_runner):
    """Build featurized small-molecule inputs for ligands covalently bonded to protein residues.

    Each bonded protein residue is converted into a small molecule and merged
    with its ligand into one combined molecule per small-molecule chain.

    Args:
        protein_inputs: mapping from protein chain id to protein input objects.
        config: object exposing ``covale_inputs`` and ``sm_inputs``.
            ``covale_inputs`` is either a Python-literal string describing the
            bonds or an already-parsed sequence of bonds.
        model_runner: runner passed through to the featurization helpers.

    Returns:
        None if ``config.covale_inputs`` is None. Otherwise a tuple
        ``(chainid_to_input, residues_to_atomize)``, where ``chainid_to_input``
        maps each small-molecule chain id to its computed features.

    Raises:
        ValueError: if covalent bonds are given without small-molecule inputs.
    """
    if config.covale_inputs is None:
        return None

    if config.sm_inputs is None:
        raise ValueError("If you provide covale_inputs, you must also provide small molecule inputs")

    covalent_bonds = config.covale_inputs
    if isinstance(covalent_bonds, str):
        # NOTE(review): eval() executes arbitrary code contained in the config
        # string; only run this on trusted configuration.
        covalent_bonds = eval(covalent_bonds)

    sm_inputs = delete_leaving_atoms(config.sm_inputs)
    residues_to_atomize, combined_molecules, extra_bonds = find_residues_to_atomize(
        protein_inputs, sm_inputs, covalent_bonds, model_runner
    )

    chainid_to_input = {}
    for chain, combined_molecule in combined_molecules.items():
        extra_bonds_for_chain = extra_bonds[chain]
        msa, bond_feats, xyz, Ls = get_combined_atoms_bonds(combined_molecule)
        # Residue N/C indices must be rebased onto the combined molecule of this chain.
        residues_to_atomize = update_absolute_indices_after_combination(residues_to_atomize, chain, Ls)
        mol = make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds_for_chain)
        # New covalent bonds change geometry, so coordinates are regenerated.
        xyz = recompute_xyz_after_chirality(mol)
        chainid_to_input[chain] = compute_features_from_obmol(mol, msa, xyz, model_runner)

    return chainid_to_input, residues_to_atomize
|
||
|
|
||
|
def find_residues_to_atomize(protein_inputs, sm_inputs, covalent_bonds, model_runner):
    """Turn each covalently bonded protein residue into a molecule grouped with its ligand.

    Args:
        protein_inputs: mapping from protein chain id to protein input objects.
        sm_inputs: mapping from small-molecule chain id to inputs exposing
            ``input`` (SDF path) and ``input_type``.
        covalent_bonds: iterable of bonds shaped
            ``((prot_chid, prot_res_idx, atom_name), (sm_chid, sm_atom_num),
            (chirality_first, chirality_second))``. Residue and atom numbers are
            1-based; a chirality of ``"null"`` means "not specified".
        model_runner: runner providing ``molecule_db`` for residue templates.

    Returns:
        ``(residues_to_atomize, combined_molecules, extra_bonds)``:
        - ``residues_to_atomize``: one AtomizedResidue per accepted bond.
        - ``combined_molecules``: sm chain id -> molecule files, with residue
          files placed before the ligand file.
        - ``extra_bonds``: sm chain id -> list of MoleculeToMoleculeBond.

    Raises:
        ValueError: if a bond references a protein chain not in ``protein_inputs``.
    """
    residues_to_atomize = []  # hold on to delete wayward inputs
    combined_molecules = {}  # combined multiple molecules that are bonded
    extra_bonds = {}
    # Chain indices can only be resolved once every residue file has been
    # inserted: each insert(0, ...) shifts the positions of earlier entries.
    pending = []

    for bond in covalent_bonds:
        prot_chid, prot_res_idx, atom_to_bond = bond[0]
        sm_chid, sm_atom_num = bond[1]
        chirality_first_atom, chirality_second_atom = bond[2]
        if chirality_first_atom.strip() == "null":
            chirality_first_atom = None
        if chirality_second_atom.strip() == "null":
            chirality_second_atom = None

        sm_atom_num = int(sm_atom_num) - 1  # 0 index

        if sm_chid not in sm_inputs:
            print(f"Skipping bond: {bond} since no sm chain {sm_chid} was provided")
            continue
        assert sm_inputs[sm_chid].input_type == "sdf", "only sdf inputs can be covalently linked to proteins"

        if prot_chid not in protein_inputs:
            raise ValueError(
                "first atom in covale_input must be present in a protein chain. "
                f"Given chain: {prot_chid} was not in given protein chains: "
                f"{list(protein_inputs.keys())}"
            )

        residue = (prot_chid, prot_res_idx, atom_to_bond)
        residue_file, atom_index = convert_residue_to_molecule(protein_inputs, residue, model_runner)

        sm_file = sm_inputs[sm_chid].input
        if sm_chid not in combined_molecules:
            combined_molecules[sm_chid] = [sm_file]
        # Residue molecules are placed before the ligand in the combined chain.
        combined_molecules[sm_chid].insert(0, residue_file)

        pending.append((
            sm_chid, sm_file, sm_atom_num, residue_file, atom_index,
            chirality_first_atom, chirality_second_atom, prot_chid, prot_res_idx,
        ))

    for (sm_chid, sm_file, sm_atom_num, residue_file, atom_index,
         chirality_first_atom, chirality_second_atom, prot_chid, prot_res_idx) in pending:
        molecules = combined_molecules[sm_chid]
        absolute_chain_index_first = molecules.index(sm_file)
        absolute_chain_index_second = molecules.index(residue_file)

        extra_bonds.setdefault(sm_chid, []).append(MoleculeToMoleculeBond(
            absolute_chain_index_first,
            sm_atom_num,
            absolute_chain_index_second,
            atom_index,
            new_chirality_atom_first=chirality_first_atom,
            new_chirality_atom_second=chirality_second_atom,
        ))
        # 0 and 2 are presumably the backbone N and C positions in the residue
        # molecule — TODO confirm against the residue SDF templates.
        residues_to_atomize.append(AtomizedResidue(
            sm_chid,
            absolute_chain_index_second,
            0,
            2,
            prot_chid,
            int(prot_res_idx) - 1,
        ))

    return residues_to_atomize, combined_molecules, extra_bonds
|
||
|
|
||
|
def convert_residue_to_molecule(protein_inputs, residue, model_runner):
    """Convert a protein residue into an SDF file and locate the atom that forms the covalent bond.

    Args:
        protein_inputs: mapping from protein chain id to an input object whose
            ``query_sequence()`` is indexable by 0-based residue position.
        residue: ``(prot_chid, prot_res_idx, atom_to_bond)`` where
            ``prot_res_idx`` is 1-based and ``atom_to_bond`` is an atom name
            from the residue template's ``atom_id`` list.
        model_runner: object whose ``molecule_db[residue_name]`` provides
            ``"sdf"``, ``"atom_id"`` and ``"leaving"`` entries.

    Returns:
        ``(temp_file, atom_index)``:
        - ``temp_file``: path of a temporary SDF with the residue's leaving
          atoms removed.
        - ``atom_index``: 0-based index of ``atom_to_bond`` within that
          molecule's atoms (heavy atoms only, leaving atoms removed).

    Raises:
        ValueError: if ``atom_to_bond`` is not a retained heavy atom of the residue.
    """
    prot_chid, prot_res_idx, atom_to_bond = residue
    protein_input = protein_inputs[prot_chid]
    prot_res_abs_idx = int(prot_res_idx) - 1
    residue_identity_num = protein_input.query_sequence()[prot_res_abs_idx]
    residue_identity = ChemData().num2aa[residue_identity_num]
    molecule_info = model_runner.molecule_db[residue_identity]

    temp_file = create_and_populate_temp_file(molecule_info["sdf"])

    atom_names = molecule_info["atom_id"]
    # The parsed molecule contains one atom per heavy atom; this is enforced by
    # the length assertion in delete_leaving_atoms_single_chain. So leaving
    # flags are restricted to heavy atoms.
    heavy_indices = [i for i, name in enumerate(atom_names) if name[0] != "H"]
    heavy_index_set = set(heavy_indices)
    is_leaving = [flag for i, flag in enumerate(molecule_info["leaving"]) if i in heavy_index_set]

    sdf_string_no_leaving_atoms = delete_leaving_atoms_single_chain(temp_file, is_leaving)
    temp_file = create_and_populate_temp_file(sdf_string_no_leaving_atoms)

    # The bonding atom must be addressed in the molecule that was actually
    # written: heavy atoms only, with leaving atoms removed. Looking it up in the
    # full atom_id list would be off whenever an H or leaving atom precedes it.
    retained_names = [atom_names[i] for i, leaving in zip(heavy_indices, is_leaving) if not leaving]
    target = atom_to_bond.strip()
    if target not in retained_names:
        raise ValueError(
            f"atom {target!r} is not a retained heavy atom of residue {residue_identity} "
            f"(chain {prot_chid}, residue {prot_res_idx}); available: {retained_names}"
        )
    atom_index = retained_names.index(target)
    return temp_file, atom_index
|
||
|
|
||
|
def get_combined_atoms_bonds(combined_molecule):
    """Parse several SDF files and stack them into one disconnected molecule.

    Args:
        combined_molecule: sequence of SDF file paths, in combined-chain order.

    Returns:
        ``(atoms, bond_feats, xyz, Ls)``:
        - ``atoms``: concatenation of each molecule's ``msa`` tensor.
        - ``bond_feats``: block-diagonal ``(L_total, L_total)`` long tensor built
          from each molecule's bond features.
        - ``xyz``: coordinates concatenated along dim 1, then indexed ``[0]``.
        - ``Ls``: list of per-molecule atom counts.
    """
    parsed = []
    for sdf_path in combined_molecule:
        obmol, msa, ins, xyz, mask = parse_mol(
            sdf_path,
            filetype="sdf",
            string=False,
            generate_conformer=True,
            find_automorphs=False,
        )
        parsed.append((msa, get_bond_feats(obmol), xyz))

    tokens = [msa for msa, _, _ in parsed]
    Ls = [msa.shape[0] for msa in tokens]
    atoms = torch.cat(tokens)

    total = sum(Ls)
    bond_feats = torch.zeros((total, total)).long()
    start = 0
    for _, block, _ in parsed:
        stop = start + block.shape[0]
        # Molecules are not bonded to each other here; only diagonal blocks are filled.
        bond_feats[start:stop, start:stop] = block
        start = stop

    xyz = torch.cat([coords for _, _, coords in parsed], dim=1)[0]
    return atoms, bond_feats, xyz, Ls
|
||
|
|
||
|
def make_obmol_from_atoms_bonds(msa, bond_feats, xyz, Ls, extra_bonds):
    """Build an OpenBabel molecule from atom tokens, a bond matrix and inter-molecule bonds.

    Args:
        msa: sequence of atom tokens; each is mapped through
            ``ChemData().num2aa`` and ``atomtype2atomnum`` to an atomic number.
        bond_feats: square tensor of bond orders; 0 means no bond, 4 aromatic.
        xyz: per-atom coordinates indexable as ``xyz[i, 0..2]``.
        Ls: per-molecule atom counts, used to resolve ``extra_bonds`` indices.
        extra_bonds: iterable of MoleculeToMoleculeBond joining the molecules.

    Returns:
        The constructed ``openbabel.OBMol``, with requested chiralities applied.
    """
    chem_data = ChemData()  # hoisted: loop-invariant
    mol = openbabel.OBMol()
    for i, token in enumerate(msa):
        element = chem_data.num2aa[token]
        atom = mol.NewAtom()
        atom.SetAtomicNum(chem_data.atomtype2atomnum[element])
        atom.SetVector(float(xyz[i, 0]), float(xyz[i, 1]), float(xyz[i, 2]))

    first_index, second_index = bond_feats.nonzero(as_tuple=True)
    for i, j in zip(first_index.tolist(), second_index.tolist()):
        # bond_feats is typically symmetric; the (j, i) entry with j < i was
        # already turned into a bond, so skip its mirror instead of duplicating it.
        if i > j and bond_feats[j, i] != 0:
            continue
        order = int(bond_feats[i, j])
        mol.AddBond(make_openbabel_bond(mol, i, j, order))

    for bond in extra_bonds:
        absolute_index_first = get_absolute_index_from_relative_indices(
            bond.chain_index_first,
            bond.absolute_atom_index_first,
            Ls,
        )
        absolute_index_second = get_absolute_index_from_relative_indices(
            bond.chain_index_second,
            bond.absolute_atom_index_second,
            Ls,
        )
        order = 1  # all covale bonds are single bonds
        mol.AddBond(make_openbabel_bond(mol, absolute_index_first, absolute_index_second, order))
        set_chirality(mol, absolute_index_first, bond.new_chirality_atom_first)
        set_chirality(mol, absolute_index_second, bond.new_chirality_atom_second)
    return mol
|
||
|
|
||
|
def make_openbabel_bond(mol, i, j, order):
    """Create (but do not add) an OBBond between 0-based atoms ``i`` and ``j`` of ``mol``.

    An ``order`` of 4 denotes an aromatic bond: it is stored as bond order 2
    with the aromatic flag set. Any other value is used as the bond order directly.
    """
    begin_atom = mol.GetAtom(i + 1)  # OBMol atom indices are 1-based
    end_atom = mol.GetAtom(j + 1)
    bond = openbabel.OBBond()
    bond.SetBegin(begin_atom)
    bond.SetEnd(end_atom)

    aromatic = order == 4
    bond.SetBondOrder(2 if aromatic else order)
    if aromatic:
        bond.SetAromatic()
    return bond
|
||
|
|
||
|
def set_chirality(mol, absolute_atom_index, new_chirality):
    """Apply a user-requested tetrahedral winding to one atom of ``mol`` (in place).

    Args:
        mol: ``openbabel.OBMol`` to modify.
        absolute_atom_index: 0-based index of the atom in ``mol``.
        new_chirality: a key of ``chirality_options`` ("CW"/"CCW"), or None.

    Raises:
        AssertionError: if the index is out of range, if the atom is a
            stereocenter but no chirality was given, if a chirality was given
            for an atom that is not a stereocenter, or if the chirality is not
            a known option.
    """
    atom = mol.GetAtom(absolute_atom_index + 1)  # OBMol atom indices are 1-based
    assert atom is not None, (
        f"atom index {absolute_atom_index} is out of range for a molecule with {mol.NumAtoms()} atoms"
    )
    # OBStereoFacade queries are keyed by atom *id*, not by the 1-based index,
    # so the same id must be used for both the existence check and the lookup.
    atom_id = atom.GetId()
    stereo = openbabel.OBStereoFacade(mol)

    if not stereo.HasTetrahedralStereo(atom_id):
        assert new_chirality is None, "you have specified a chirality without creating a new chiral center"
        return

    tetstereo = stereo.GetTetrahedralStereo(atom_id)
    if tetstereo is None:
        return

    assert new_chirality is not None, (
        "you have introduced a new stereocenter, "
        "so you must specify its chirality either as CW, or CCW"
    )
    assert new_chirality in chirality_options, (
        f"unknown chirality {new_chirality!r}; expected one of {list(chirality_options)}"
    )

    config = tetstereo.GetConfig()
    config.winding = chirality_options[new_chirality]
    tetstereo.SetConfig(config)
    print("Updating chirality...")
|
||
|
|
||
|
# User-facing winding names accepted in covale_inputs, mapped to OpenBabel's
# tetrahedral winding constants (consumed by set_chirality).
chirality_options = dict(
    CW=openbabel.OBStereo.Clockwise,
    CCW=openbabel.OBStereo.AntiClockwise,
)
|
||
|
|
||
|
def recompute_xyz_after_chirality(obmol):
    """Regenerate 3D coordinates for ``obmol`` and return them as a tensor.

    The molecule is rebuilt with ``OBBuilder`` and refined with an MMFF94 fast
    rotor search. Its coordinates are updated in place.

    Args:
        obmol: ``openbabel.OBMol`` to rebuild.

    Returns:
        Tensor of shape ``(1, natoms, 3)`` with the new coordinates.

    Raises:
        ValueError: if the MMFF94 force field is unavailable or cannot be set
            up for this molecule.
    """
    builder = openbabel.OBBuilder()
    builder.Build(obmol)

    ff = openbabel.OBForceField.FindForceField("mmff94")
    if ff is None or not ff.Setup(obmol):
        raise ValueError(
            f"Failed to generate 3D coordinates for molecule with {obmol.NumAtoms()} atoms "
            "(MMFF94 force field setup failed)."
        )
    ff.FastRotorSearch()
    ff.GetCoordinates(obmol)

    atoms = [obmol.GetAtom(i) for i in range(1, obmol.NumAtoms() + 1)]
    atom_coords = torch.tensor([[a.x(), a.y(), a.z()] for a in atoms]).unsqueeze(0)  # (1, natoms, 3)
    return atom_coords
|
||
|
|
||
|
def delete_leaving_atoms(sm_inputs):
    """Strip leaving atoms from every small-molecule input that declares them.

    Entries containing an ``"is_leaving"`` key are replaced by a new entry
    pointing at a temporary SDF without those atoms. Other entries are left as-is.

    Args:
        sm_inputs: mapping from chain id to a mapping with ``"input"`` (SDF
            path) and optionally ``"is_leaving"`` (string of per-atom flags).

    Returns:
        The same ``sm_inputs`` object, updated in place.
    """
    replacements = {}
    for chain, chain_input in sm_inputs.items():
        if "is_leaving" in chain_input:
            # NOTE(review): eval() executes arbitrary code from the config
            # string; only run this on trusted configuration.
            leaving_mask = eval(chain_input["is_leaving"])
            stripped_sdf = delete_leaving_atoms_single_chain(chain_input["input"], leaving_mask)
            replacements[chain] = {
                "input": create_and_populate_temp_file(stripped_sdf),
                "input_type": "sdf",
            }

    sm_inputs.update(replacements)
    return sm_inputs
|
||
|
|
||
|
def delete_leaving_atoms_single_chain(filename, is_leaving):
    """Remove flagged atoms from an SDF file and return the result as an SDF string.

    Args:
        filename: path to an SDF file.
        is_leaving: per-atom truthy flags, one per atom of the parsed molecule;
            truthy entries are deleted.

    Returns:
        SDF text of the molecule with the leaving atoms removed.

    Raises:
        AssertionError: if ``len(is_leaving)`` differs from the parsed atom count.
    """
    obmol, msa, ins, xyz, mask = parse_mol(
        filename,
        filetype="sdf",
        string=False,
        generate_conformer=True,
    )
    assert len(is_leaving) == obmol.NumAtoms(), (
        f"is_leaving has {len(is_leaving)} entries but the molecule has {obmol.NumAtoms()} atoms"
    )

    leaving_indices = [i for i, flag in enumerate(is_leaving) if flag]
    # Deleting an atom renumbers every atom after it. Removing from the highest
    # index down keeps the remaining (lower) indices valid.
    for idx in reversed(leaving_indices):
        obmol.DeleteAtom(obmol.GetAtom(idx + 1))  # OBMol atom indices are 1-based

    obConversion = openbabel.OBConversion()
    obConversion.SetInAndOutFormats("sdf", "sdf")
    sdf_string = obConversion.WriteString(obmol)
    return sdf_string
|
||
|
|
||
|
def get_absolute_index_from_relative_indices(chain_index, absolute_index_in_chain, Ls):
    """Map an atom index within one molecule to its index in the combined molecule.

    ``Ls`` holds the atom count of each molecule in combined order. The result
    is the atom's index plus the atom counts of all molecules before
    ``chain_index``.
    """
    atoms_before = sum(Ls[:chain_index])
    return absolute_index_in_chain + atoms_before
|
||
|
|
||
|
def update_absolute_indices_after_combination(residues_to_atomize, chain, Ls):
    """Rebase the N/C atom indices of residues in ``chain`` onto the combined molecule.

    Residues belonging to ``chain`` are replaced by new AtomizedResidue objects
    whose N/C indices are absolute within the combined molecule and whose
    ``chain_index_in_combined_chain`` is None. Residues of other chains are
    passed through unchanged.

    Returns:
        A new list in the same order as ``residues_to_atomize``.
    """
    def rebase(residue):
        if residue.chain != chain:
            return residue
        molecule_position = residue.chain_index_in_combined_chain
        n_index = get_absolute_index_from_relative_indices(
            molecule_position, residue.absolute_N_index_in_chain, Ls)
        c_index = get_absolute_index_from_relative_indices(
            molecule_position, residue.absolute_C_index_in_chain, Ls)
        return AtomizedResidue(
            residue.chain,
            None,
            n_index,
            c_index,
            residue.original_chain,
            residue.index_in_original_chain,
        )

    return [rebase(residue) for residue in residues_to_atomize]
|
||
|
|
||
|
def create_and_populate_temp_file(data):
    """Write ``data`` to a new temporary text file and return its path.

    The file is created with ``delete=False``, so it persists after closing;
    callers are responsible for removing it.
    """
    handle = NamedTemporaryFile(mode='w+', delete=False)
    with handle:
        handle.write(data)
    return handle.name
|