mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-04 22:25:42 +00:00
1044 lines
40 KiB
Python
1044 lines
40 KiB
Python
import sys
|
|
import warnings
|
|
import assertpy
|
|
|
|
import numpy as np
|
|
import random
|
|
import torch
|
|
import warnings
|
|
from assertpy import assert_that
|
|
|
|
import networkx as nx
|
|
import itertools
|
|
from itertools import combinations
|
|
from collections import OrderedDict, Counter
|
|
from openbabel import openbabel
|
|
from scipy.spatial.transform import Rotation
|
|
from icecream import ic
|
|
|
|
from rf2aa.chemical import ChemicalData as ChemData
|
|
from rf2aa.kinematics import get_atomize_protein_chirals, generate_Cbeta
|
|
from rf2aa.scoring import *
|
|
|
|
|
|
def random_rot_trans(xyz, random_noise=20.0, deterministic: bool = False):
    """
    Apply an independent random rigid-body motion to every structure in a batch.

    Each of the N structures is rotated about the origin by a random rotation
    drawn with ``scipy.spatial.transform.Rotation.random`` and then shifted by a
    vector whose components are drawn uniformly from ``[0, random_noise)``.

    Parameters
    ----------
    xyz : (N, L, natoms, 3) tensor
    random_noise : float
        Upper bound of the per-axis translation.
    deterministic : bool
        If True, the global ``random``, NumPy and torch (CPU and CUDA) RNGs are
        re-seeded with 0 before sampling. This mutates the caller's global RNG state.

    Returns
    -------
    (N, L, natoms, 3) tensor of transformed coordinates.
    """
    if deterministic:
        for reseed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
            reseed(0)

    n_structs = xyz.shape[0]

    # one full 3x3 random rotation matrix per structure (not just an axis)
    rotations = Rotation.random(n_structs).as_matrix()
    rot = torch.tensor(rotations, dtype=xyz.dtype).to(xyz.device)
    rotated = torch.einsum('nij,nlaj->nlai', rot, xyz)

    shift = torch.rand(n_structs, 1, 1, 3, device=xyz.device) * random_noise
    return rotated + shift
|
|
|
|
def get_prot_sm_mask(atom_mask, seq):
    """
    Mask of positions that have usable coordinates.

    A polymer (non-small-molecule) position is usable when its first three atom
    slots (the backbone) are all present. A small-molecule position (``is_atom``)
    is usable when atom slot 1 is present.

    Parameters
    ----------
    atom_mask : (..., L, Natoms) bool
    seq : (L)

    Returns
    -------
    mask : (..., L) bool
    """
    is_sm = is_atom(seq).to(atom_mask.device)  # (L)

    polymer_ok = atom_mask[..., :3].all(dim=-1) & ~is_sm  # valid protein/NA residues
    sm_ok = atom_mask[..., 1] & is_sm  # valid small-molecule atoms

    return polymer_ok | sm_ok
|
|
|
|
def center_and_realign_missing(xyz, mask_t, seq=None, same_chain=None, should_center: bool = True):
    """
    Moves center of mass of xyz to origin, then moves positions with missing
    coordinates to nearest existing residue on same chain.

    Parameters
    ----------
    xyz : (L, Natms, 3)
    mask_t : (L, Natms)
        Atom mask; only consulted when ``seq`` is given.
    seq : (L), optional
        If None, every position is treated as having coordinates.
    same_chain : (L, L), optional
        If None, all positions are treated as belonging to one chain.
    should_center : bool
        If True, valid positions are translated so that their CA (atom 1) centroid
        is at the origin.

    Returns
    -------
    xyz : (L, Natms, 3)
        Unchanged if no position has coordinates.
    """
    L = xyz.shape[0]

    if same_chain is None:
        same_chain = torch.full((L, L), True, device=xyz.device)

    # valid protein/NA/small mol. positions
    if seq is None:
        mask = torch.full((L,), True, device=xyz.device)
    else:
        mask = get_prot_sm_mask(mask_t, seq)

    if not mask.any():
        # nothing to centre on and nothing to move missing positions towards
        return xyz

    # center c.o.m of existing residues at the origin
    if should_center:
        center_CA = xyz[mask, 1].mean(dim=0)  # (3)
        xyz = torch.where(mask.view(L, 1, 1), xyz - center_CA.view(1, 1, 3), xyz)

    # move missing residues to the closest valid residues on same chain
    exist_in_xyz = torch.where(mask)[0]  # (L_sub)
    same_chain_in_xyz = same_chain[:, mask].bool()  # (L, L_sub)
    seqmap = (torch.arange(L, device=xyz.device)[:, None] - exist_in_xyz[None, :]).abs()  # (L, L_sub)
    seqmap[~same_chain_in_xyz] += 99999  # strongly prefer valid residues on the same chain
    seqmap = torch.argmin(seqmap, dim=-1)  # (L)
    idx = torch.gather(exist_in_xyz, 0, seqmap)  # (L)
    offset_CA = torch.gather(xyz[:, 1], 0, idx.reshape(L, 1).expand(-1, 3))
    # a position has a neighbour if ANY valid residue shares its chain
    # (``.all`` here would wrongly zero the offset whenever another chain exists)
    has_neighbor = same_chain_in_xyz.any(-1)
    offset_CA[~has_neighbor] = 0  # stay at origin if nothing on same chain has coords
    xyz = torch.where(mask.view(L, 1, 1), xyz, xyz + offset_CA.reshape(L, 1, 3))

    return xyz
|
|
|
|
|
|
# note: needs consistency with chemical.py
def is_protein(seq):
    """Elementwise ``seq < ChemData().NPROTAAS``: True for protein residue tokens."""
    protein_cutoff = ChemData().NPROTAAS
    return seq < protein_cutoff
|
|
|
|
def is_nucleic(seq):
    """
    Elementwise ``NPROTAAS <= seq < NNAPROTAAS``: True for nucleic-acid tokens.

    The upper bound is exclusive so that this mask equals ``is_DNA(seq) | is_RNA(seq)``,
    which together cover ``[NPROTAAS, NNAPROTAAS)``. The token ``NNAPROTAAS`` itself
    is neither DNA, RNA nor ``is_atom`` (which starts strictly above it). The previous
    ``<=`` bound wrongly counted that token as nucleic.
    """
    chem = ChemData()
    return (seq >= chem.NPROTAAS) * (seq < chem.NNAPROTAAS)
|
|
|
|
# fd hacky
def is_DNA(seq):
    """Elementwise ``NPROTAAS <= seq < NPROTAAS + 5``: the five tokens right after the protein tokens."""
    first_dna = ChemData().NPROTAAS
    return (seq >= first_dna) * (seq < first_dna + 5)
|
|
|
|
# fd hacky
def is_RNA(seq):
    """Elementwise ``NPROTAAS + 5 <= seq < NNAPROTAAS``: the nucleic tokens after the five DNA tokens."""
    chem = ChemData()
    first_rna = chem.NPROTAAS + 5
    return (seq >= first_rna) * (seq < chem.NNAPROTAAS)
|
|
|
|
def is_atom(seq):
    """Elementwise ``seq > ChemData().NNAPROTAAS``: True for small-molecule atom tokens."""
    last_polymer_token = ChemData().NNAPROTAAS
    return seq > last_polymer_token
|
|
|
|
# build a frame from 3 points
#fd - more complicated version splits angle deviations between CA-N and CA-C (giving more accurate CB position)
#fd - makes no assumptions about input dims (other than last 1 is xyz)
def rigid_from_3_points(N, Ca, C, is_na=None, eps=1e-4):
    """
    Build a local frame (rotation, origin) from three points per position.

    The first axis points along Ca->C and the second lies in the N-Ca-C plane
    (Gram-Schmidt). The frame is then rotated about its third axis so that the
    deviation between the observed N-Ca-C geometry and the target cosine is
    split between the Ca-N and Ca-C directions. The target cosine is -0.3616,
    or ``ChemData().costgtNA`` where ``is_na`` is True.

    Parameters
    ----------
    N, Ca, C : (..., 3) tensors of the same dtype/device (any number of leading dims,
        including none)
    is_na : optional bool tensor of shape (...) selecting the nucleic-acid target cosine
    eps : small constant added to norms and square roots for numerical stability

    Returns
    -------
    R : (..., 3, 3) rotation matrices (columns are the frame axes), dtype of ``N``
    Ca : (..., 3) the input ``Ca`` (frame origins)
    """
    dims = N.shape[:-1]

    v1 = C - Ca
    v2 = N - Ca
    e1 = v1 / (torch.norm(v1, dim=-1, keepdim=True) + eps)
    # remove the e1 component of v2 (works for any number of leading dims)
    u2 = v2 - torch.sum(e1 * v2, dim=-1, keepdim=True) * e1
    e2 = u2 / (torch.norm(u2, dim=-1, keepdim=True) + eps)
    e3 = torch.cross(e1, e2, dim=-1)
    R = torch.stack([e1, e2, e3], dim=-1)  # (..., 3, 3) rotation matrix

    v2 = v2 / (torch.norm(v2, dim=-1, keepdim=True) + eps)
    cosref = torch.sum(e1 * v2, dim=-1)

    # build constants in the input dtype so float64 inputs don't mix with float32
    costgt = torch.full(dims, -0.3616, dtype=N.dtype, device=N.device)
    if is_na is not None:
        costgt[is_na] = ChemData().costgtNA

    cos2del = torch.clamp(cosref*costgt + torch.sqrt((1-cosref*cosref)*(1-costgt*costgt)+eps), min=-1.0, max=1.0)
    cosdel = torch.sqrt(0.5*(1+cos2del)+eps)
    sindel = torch.sign(costgt-cosref) * torch.sqrt(1-0.5*(1+cos2del)+eps)

    # in-plane correction rotation about the frame's third axis
    Rp = torch.eye(3, dtype=N.dtype, device=N.device).repeat(*dims, 1, 1)
    Rp[..., 0, 0] = cosdel
    Rp[..., 0, 1] = -sindel
    Rp[..., 1, 0] = sindel
    Rp[..., 1, 1] = cosdel
    R = torch.matmul(R, Rp)

    return R, Ca
|
|
|
|
def idealize_reference_frame(seq, xyz_in):
    """
    Replace atoms 0 and 2 of every residue with ideal positions in that residue's frame.

    Each residue's frame comes from ``rigid_from_3_points`` on its atoms 0, 1 and 2.
    Positions flagged by ``is_nucleic`` receive ``ChemData().init_O1`` / ``init_O2``;
    all other positions receive ``ChemData().init_N`` / ``init_C``. All other atoms
    are left as they are, and ``xyz_in`` itself is not modified.

    Parameters
    ----------
    seq : 2-D residue-token tensor (``nonzero(as_tuple=True)`` is unpacked into two
        index tensors), presumably (B, L)
    xyz_in : (B, L, natoms, 3)

    Returns
    -------
    A copy of ``xyz_in`` with idealized atoms 0 and 2.
    """
    xyz = xyz_in.clone()
    chem = ChemData()
    device = xyz_in.device

    na_mask = is_nucleic(seq)
    rot, trans = rigid_from_3_points(xyz[..., 0, :], xyz[..., 1, :], xyz[..., 2, :], na_mask)

    prot_b, prot_r = (~na_mask).nonzero(as_tuple=True)
    na_b, na_r = na_mask.nonzero(as_tuple=True)

    # (batch idx, residue idx, atom slot, ideal local coordinate)
    placements = (
        (prot_b, prot_r, 0, chem.init_N),
        (prot_b, prot_r, 2, chem.init_C),
        (na_b, na_r, 0, chem.init_O1),
        (na_b, na_r, 2, chem.init_O2),
    )
    for b, r, slot, ideal in placements:
        local = ideal.to(device=device)
        xyz[b, r, slot, :] = torch.einsum('...ij,j->...i', rot[b, r], local) + trans[b, r]

    return xyz
|
|
|
|
def xyz_to_frame_xyz(xyz, seq_unmasked, atom_frames):
    """
    For every small-molecule position, overwrite atom slots 0-2 with the coordinates
    of the three atoms that define its frame.

    xyz (1, L, natoms, 3)
    seq_unmasked (1, L)
    atom_frames (1, A, 3, 2)
        For the k-th small-molecule position (in sequence order), entry ``[..., k, f, 0]``
        is an offset to another small-molecule position and ``[..., k, f, 1]`` an atom
        slot. A must equal the number of small-molecule positions.

    Returns a modified copy of ``xyz``. The input is not changed.
    """
    xyz_frame = xyz.clone()
    atoms = is_atom(seq_unmasked)
    if torch.all(~atoms):
        return xyz_frame

    atom_crds = xyz_frame[atoms]
    atom_L, natoms, _ = atom_crds.shape
    if atom_frames.shape[-3] != atom_L:
        raise ValueError(
            f"atom_frames describes {atom_frames.shape[-3]} atoms but {atom_L} small-molecule positions are selected"
        )

    # flatten (position + offset, slot) into a row index of the (atom_L*natoms, 3) array,
    # vectorized over all positions and kept as integers throughout
    positions = torch.arange(atom_L, device=atom_frames.device).view(atom_L, 1)
    frames_reindex = (positions + atom_frames[..., 0]) * natoms + atom_frames[..., 1]
    frames_reindex = frames_reindex.long().to(atom_crds.device)

    xyz_frame[atoms, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex]
    return xyz_frame
|
|
|
|
def xyz_frame_from_rotation_mask(xyz,rotation_mask, atom_frames):
    """
    function to get xyz_frame for l1 feature in Structure module

    For every position selected by ``rotation_mask``, overwrite atom slots 0-2 with
    the coordinates of the three atoms that define its frame.

    xyz (1, L, natoms, 3)
    rotation_mask (1, L)
    atom_frames (1, A, 3, 2)
        For the k-th selected position (in sequence order), entry ``[..., k, f, 0]`` is
        an offset to another selected position and ``[..., k, f, 1]`` an atom slot.
        A must equal the number of selected positions.

    Returns a modified copy of ``xyz``. The input is not changed.
    """
    xyz_frame = xyz.clone()
    if torch.all(~rotation_mask):
        return xyz_frame

    atom_crds = xyz_frame[rotation_mask]
    atom_L, natoms, _ = atom_crds.shape
    if atom_frames.shape[-3] != atom_L:
        raise ValueError(
            f"atom_frames describes {atom_frames.shape[-3]} atoms but {atom_L} positions are selected"
        )

    # flatten (position + offset, slot) into a row index of the (atom_L*natoms, 3) array,
    # vectorized over all positions and kept as integers throughout
    positions = torch.arange(atom_L, device=atom_frames.device).view(atom_L, 1)
    frames_reindex = (positions + atom_frames[..., 0]) * natoms + atom_frames[..., 1]
    frames_reindex = frames_reindex.long().to(atom_crds.device)

    xyz_frame[rotation_mask, :, :3] = atom_crds.reshape(atom_L*natoms, 3)[frames_reindex]
    return xyz_frame
|
|
|
|
def xyz_t_to_frame_xyz(xyz_t, seq_unmasked, atom_frames):
    """
    Sequence-driven wrapper around ``xyz_t_to_frame_xyz_sm_mask``.

    Small-molecule positions are taken from the first sequence of the batch,
    ``is_atom(seq_unmasked[0])``.

    Parameters:
        xyz_t (1, T, L, natoms, 3)
        seq_unmasked (B, L)
        atom_frames (1, A, 3, 2)
    Returns:
        xyz_t_frame (B, T, L, natoms, 3)
    """
    sm_positions = is_atom(seq_unmasked[0])
    return xyz_t_to_frame_xyz_sm_mask(xyz_t, sm_positions, atom_frames)
|
|
|
|
def xyz_t_to_frame_xyz_sm_mask(xyz_t, is_sm, atom_frames):
    """
    For every small-molecule position of every template, overwrite atom slots 0-2
    with the coordinates of the three atoms that define its frame.

    Parameters:
        xyz_t (B, T, L, natoms, 3)   e.g. torch.Size([1, 1, 194, 36, 3])
        is_sm (L) bool               e.g. torch.Size([194])
        atom_frames (1, A, 3, 2)     e.g. torch.Size([1, 29, 3, 2])
            For the k-th small-molecule position, ``[..., k, f, 0]`` is an offset to
            another small-molecule position and ``[..., k, f, 1]`` an atom slot.
            A must equal ``is_sm.sum()``. The same frames are used for every batch
            element and template.
    Returns:
        xyz_t_frame (B, T, L, natoms, 3), a modified copy of ``xyz_t``
    """
    xyz_t_frame = xyz_t.clone()
    if torch.all(~is_sm):
        return xyz_t_frame

    atom_crds_t = xyz_t_frame[:, :, is_sm]
    B, T, atom_L, natoms, _ = atom_crds_t.shape
    if atom_frames.shape[-3] != atom_L:
        raise ValueError(
            f"atom_frames describes {atom_frames.shape[-3]} atoms but is_sm selects {atom_L} positions"
        )

    # flatten (position + offset, slot) into a row index of the per-template
    # (atom_L*natoms, 3) array, vectorized and kept as integers throughout
    positions = torch.arange(atom_L, device=atom_frames.device).view(atom_L, 1)
    frames_reindex = (positions + atom_frames[..., 0]) * natoms + atom_frames[..., 1]
    frames_reindex = frames_reindex.long().squeeze(0).to(atom_crds_t.device)  # (atom_L, 3)

    # keep the batch dimension explicit so B > 1 works too
    flat = atom_crds_t.reshape(B, T, atom_L*natoms, 3)
    xyz_t_frame[:, :, is_sm, :3] = flat[:, :, frames_reindex]
    return xyz_t_frame
|
|
|
|
def get_frames(xyz_in, xyz_mask, seq, frame_indices, atom_frames=None):
    """
    Look up per-residue frame atom indices and flag which frames are usable.

    ``frame_indices`` is indexed by residue token (``frame_indices[seq]``). For
    small-molecule positions (``is_atom``, taken from batch row 0), frame 0 is
    replaced by ``atom_frames``. A frame is flagged unusable when its entries 0 and 1
    along the second-to-last axis are identical.

    ``xyz_in`` and ``xyz_mask`` are accepted but currently unused.

    Returns
    -------
    frames : ``frame_indices[seq]`` with small-molecule frames substituted
    frame_mask : bool, True where the frame is usable
    """
    frames = frame_indices[seq]

    sm_atoms = is_atom(seq)
    if torch.any(sm_atoms):
        sm_positions = sm_atoms[0].nonzero().flatten()
        frames[:, sm_positions, 0] = atom_frames

    degenerate = torch.all(frames[..., 0, :] == frames[..., 1, :], dim=-1)
    return frames, ~degenerate
|
|
|
|
def get_tips(xyz, seq):
    """
    Coordinates of each residue's tip atom, with a virtual Cb as NaN fallback.

    The tip atom slot of each residue is ``tip_indices[seq]``. ``tip_indices`` is a
    module-level name not defined in this part of the file (presumably provided by one
    of the imports; verify). Wherever the gathered tip coordinate is NaN, an ideal Cb
    position built from atoms N(0), CA(1) and C(2) is used instead.

    Parameters
    ----------
    xyz : (B, L, natoms, 3)
    seq : (B, L) residue tokens

    Returns
    -------
    (B, L, 3)
    """
    B, L = xyz.shape[:2]

    tip_slot = tip_indices.to(xyz.device)[seq]  # (B, L)
    gather_idx = tip_slot[:, :, None, None].expand(-1, -1, -1, 3)
    xyz_tips = torch.gather(xyz, 2, gather_idx).reshape(B, L, 3)

    if not torch.isnan(xyz_tips).any():
        return xyz_tips

    # recreate Cb from the three anchor atoms N, CA, C
    n_xyz, ca_xyz, c_xyz = xyz[:, :, 0], xyz[:, :, 1], xyz[:, :, 2]
    b = ca_xyz - n_xyz
    c = c_xyz - ca_xyz
    a = torch.cross(b, c, dim=-1)
    virtual_cb = -0.58273431*a + 0.56802827*b - 0.54067466*c + ca_xyz

    return torch.where(torch.isnan(xyz_tips), virtual_cb, xyz_tips)
|
|
|
|
def superimpose(pred, true, atom_mask):
    """
    Rigidly superimpose ``pred`` onto ``true`` with the Kabsch algorithm.

    One rotation and translation are fit on all atoms selected by ``atom_mask``,
    pooled over every leading dimension. They are then applied to every atom of
    ``pred``.

    Parameters
    ----------
    pred : (B, L, natoms, 3)
    true : (B, L, natoms, 3)
    atom_mask : (B, L, natoms) bool

    Returns
    -------
    (B, L, natoms, 3) : ``pred`` rotated and translated onto ``true``.
    """
    def centroid(X):
        return X.mean(dim=-2, keepdim=True)

    # gather the masked atoms into one (1, N, 3) point cloud each
    pred_allatom = pred[atom_mask][None]
    true_allatom = true[atom_mask][None]

    cp = centroid(pred_allatom)  # (1, 1, 3)
    ct = centroid(true_allatom)

    pred_allatom_origin = pred_allatom - cp
    true_allatom_origin = true_allatom - ct

    # covariance matrix (1, 3, 3)
    C = torch.matmul(pred_allatom_origin.transpose(-1, -2), true_allatom_origin)

    # C = V diag(S) W^T  (torch.svd is deprecated in favour of torch.linalg.svd)
    V, S, Wh = torch.linalg.svd(C)
    W = Wh.transpose(-1, -2)

    # flip the last singular vector when needed so U is a proper rotation (det = +1)
    flip = (torch.det(V) * torch.det(W) < 0).to(V.dtype)
    d = torch.ones_like(V)
    d[..., -1] = (1.0 - 2.0 * flip)[..., None]

    # rotation matrix U, shared by every batch element (pooled fit)
    U = torch.matmul(d * V, W.transpose(-1, -2))  # (1, 3, 3)

    # rotate pred about its centroid and move it onto true's centroid
    rP = torch.matmul(pred - cp, U)
    return rP + ct
|
|
|
|
def writepdb(filename, *args, file_mode='w', **kwargs):
    """
    Open ``filename`` with ``file_mode`` and write a PDB to it via ``writepdb_file``.

    All other positional and keyword arguments are forwarded to ``writepdb_file``.
    The file is closed when writing finishes or raises (it was previously left open).
    """
    with open(filename, file_mode) as f:
        writepdb_file(f, *args, **kwargs)
|
|
|
|
def writepdb_file(f, atoms, seq, modelnum=None, chain="A", idx_pdb=None, bfacts=None,
                  bond_feats=None, file_mode="w", atom_mask=None, atom_idx_offset=0, chain_Ls=None,
                  remap_atomtype=True, lig_name='LG1', atom_names=None):
    """
    Write coordinates as fixed-column PDB ATOM/HETATM records (plus CONECT) to ``f``.

    Residues whose token is ``>= len(ChemData().aa2long)`` are written as one HETATM
    line each. Such a line has residue name ``lig_name``, coordinates from atom slot 1,
    and residue number ``idx_pdb.max()+10``. Every other residue gets one ATOM line per
    named atom of ``ChemData().aa2long[token]`` that is within ``natoms``, not NaN, and
    not masked out.

    Parameters
    ----------
    f : writable text file
    atoms : coordinates, (1, L, natoms, 3) or (L, natoms, 3)
    seq : residue tokens, (1, L) or (L)
    modelnum : if given, wrap the records in MODEL / ENDMDL
    chain : chain letter used when ``chain_Ls`` is None
    idx_pdb : residue numbers (default 1..L)
    bfacts : per-residue B-factors, clamped to [0, 1] (default 0)
    bond_feats : (1, L, L); entries with 0 < type < 5 become CONECT records. Only
        HETATM positions get serial numbers recorded, so those bonds must be between
        ligand positions.
    file_mode : unused here (kept for signature compatibility)
    atom_mask : optional (L, natoms); False entries are skipped
    atom_idx_offset : added to the first atom serial number (which is 1 by default)
    chain_Ls : chain lengths; chains are lettered A, B, C, ... in order
    remap_atomtype : map element names through the legacy atomic-number table below
    lig_name : residue name for HETATM records
    atom_names : optional explicit names for the HETATM atoms, in order
    """
    def _get_atom_type(atom_name):
        # element symbol: first char if alphabetic, plus the second char
        atype = ''
        if atom_name[0].isalpha():
            atype += atom_name[0]
        atype += atom_name[1]
        return atype

    chem = ChemData()

    # if needed, correct mistake in atomic number assignment in RF2-allatom (fold&dock 3 & earlier)
    atom_names_ = [
        "F", "Cl", "Br", "I", "O", "S", "Se", "Te", "N", "P", "As", "Sb",
        "C", "Si", "Ge", "Sn", "Pb", "B", "Al", "Zn", "Hg", "Cu", "Au", "Ni",
        "Pd", "Pt", "Co", "Rh", "Ir", "Pr", "Fe", "Ru", "Os", "Mn", "Re", "Cr",
        "Mo", "W", "V", "U", "Tb", "Y", "Be", "Mg", "Ca", "Li", "K", "ATM"]
    atom_num = [
        9, 17, 35, 53, 8, 16, 34, 52, 7, 15, 33, 51,
        6, 14, 32, 50, 82, 5, 13, 30, 80, 29, 79, 28,
        46, 78, 27, 45, 77, 59, 26, 44, 76, 25, 75, 24,
        42, 74, 23, 92, 65, 39, 4, 12, 20, 3, 19, 0]
    atomnum2atomtype_ = dict(zip(atom_num, atom_names_))
    if remap_atomtype:
        atomtype_map = {v: atomnum2atomtype_[k] for k, v in chem.atomnum2atomtype.items()}
    else:
        atomtype_map = {v: v for k, v in chem.atomnum2atomtype.items()}  # no change

    ctr = 1 + atom_idx_offset
    scpu = seq.cpu().squeeze(0)
    atomscpu = atoms.cpu().squeeze(0)

    if bfacts is None:
        bfacts = torch.zeros(atomscpu.shape[0])
    if idx_pdb is None:
        idx_pdb = 1 + torch.arange(atomscpu.shape[0])

    alphabet = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789')
    if chain_Ls is not None:
        chain_letters = np.concatenate([np.full(L, alphabet[i]) for i, L in enumerate(chain_Ls)])
    else:
        chain_letters = [chain]*len(scpu)

    if modelnum is not None:
        # PDB MODEL record: serial number in columns 11-14
        f.write(f"MODEL     {modelnum:>4}\n")

    Bfacts = torch.clamp(bfacts.cpu(), 0, 1)
    natoms = atomscpu.shape[-2]
    aa2long = chem.aa2long
    num2aa = chem.num2aa

    # Fixed-column PDB layout: x/y/z in columns 31-54, occupancy 55-60,
    # B-factor 61-66, element symbol 77-78.
    atom_line = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f\n"
    hetatm_line = "%-6s%5s %4s %3s %s%4d    %8.3f%8.3f%8.3f%6.2f%6.2f          %+2s\n"

    atom_idxs = {}  # residue index -> HETATM serial number, for CONECT records
    i_res_lig = 0
    for i_res, (s, ch) in enumerate(zip(scpu, chain_letters)):
        if s >= len(aa2long):
            atom_idxs[i_res] = ctr

            # hack to make sure H's are output properly (they are not in RFAA alphabet)
            if atom_names is not None:
                atom_type = _get_atom_type(atom_names[i_res_lig])
                atom_name = atom_names[i_res_lig]
            else:
                atom_type = atomtype_map[num2aa[s]]
                atom_name = atom_type

            f.write(hetatm_line % (
                "HETATM", ctr, atom_name, lig_name,
                ch, idx_pdb.max()+10, atomscpu[i_res, 1, 0], atomscpu[i_res, 1, 1], atomscpu[i_res, 1, 2],
                1.0, Bfacts[i_res], atom_type))
            i_res_lig += 1
            ctr += 1
            continue

        for i_atm, atm in enumerate(aa2long[s]):
            if atom_mask is not None and not atom_mask[i_res, i_atm]:
                continue  # skip missing atoms
            if i_atm < natoms and atm is not None and not torch.isnan(atomscpu[i_res, i_atm, :]).any():
                f.write(atom_line % (
                    "ATOM", ctr, atm, num2aa[s],
                    ch, idx_pdb[i_res], atomscpu[i_res, i_atm, 0], atomscpu[i_res, i_atm, 1], atomscpu[i_res, i_atm, 2],
                    1.0, Bfacts[i_res]))
                ctr += 1

    if bond_feats is not None:
        atom_bonds = ((bond_feats > 0) * (bond_feats < 5)).cpu()
        _, i, j = atom_bonds.nonzero(as_tuple=True)
        for start, end in zip(i, j):
            f.write(f"CONECT{atom_idxs[int(start)]:5d}{atom_idxs[int(end)]:5d}\n")

    if modelnum is not None:
        f.write("ENDMDL\n")
|
|
|
|
|
|
### Create atom frames for FAPE loss calculation ###
def get_nxgraph(mol):
    '''Build an undirected NetworkX connectivity graph from an openbabel OBMol.

    Nodes are the 0-based atom indices ``0 .. mol.NumAtoms()-1``; there is one edge
    per bond reported by ``OBMolBondIter``.
    '''
    graph = nx.Graph()
    graph.add_nodes_from(range(mol.NumAtoms()))
    # openbabel atom indices start at 1, so shift them to 0-based node ids
    graph.add_edges_from(
        (bond.GetBeginAtomIdx() - 1, bond.GetEndAtomIdx() - 1)
        for bond in openbabel.OBMolBondIter(mol)
    )
    return graph
|
|
|
|
def find_all_rigid_groups(bond_feats):
    """
    remove all single bonds from the graph and find connected components

    Only bonds with type 2, 3 or 4 (``1 < bond_feats < 5``) are kept, using batch
    element 0. For each connected component with more than two atoms, every
    unordered atom pair inside it is listed.

    Returns
    -------
    (P, 2) tensor of atom-index pairs, or None if there is no such component.
    """
    rigid_bonds = ((bond_feats > 1) * (bond_feats < 5))[0].cpu().numpy()
    graph = nx.from_numpy_array(rigid_bonds)

    pair_blocks = [
        torch.tensor(list(combinations(component, 2)))
        for component in nx.connected_components(graph)
        if len(component) > 2
    ]
    return torch.cat(pair_blocks, dim=0) if pair_blocks else None
|
|
|
|
def find_all_paths_of_length_n(G : "nx.Graph",
                               n : int,
                               **karg) -> list:
    '''find all paths of length N in a networkx graph

    Adapted from
    https://stackoverflow.com/questions/28095646/finding-all-paths-walks-of-given-length-in-a-networkx-graph

    Parameters
    ----------
    G : graph supporting iteration over its nodes and ``G.neighbors(u)``
        (e.g. ``networkx.Graph``)
    n : number of edges per path (each path has ``n + 1`` distinct nodes)
    omit_permutation : bool, optional keyword (default True)
        If True or absent, a path and its reverse are reported once, oriented so that
        the first node is smaller than the last. If False, both orientations are kept.

    Returns
    -------
    list of tuples of node ids without duplicates (order unspecified).
    (The previous ``-> torch.Tensor`` annotation was wrong; a list has always been returned.)
    '''
    def find_paths(u, n_edges):
        if n_edges == 0:
            return [[u]]
        # extend with every simple path from each neighbour that does not revisit u
        return [[u] + path
                for neighbor in G.neighbors(u)
                for path in find_paths(neighbor, n_edges - 1)
                if u not in path]

    # enumerate once; orientation handling no longer repeats the whole search
    all_paths = [tuple(p) for node in G for p in find_paths(node, n)]

    if karg.get('omit_permutation', True):
        # canonical orientation so a path and its reverse collapse to one entry
        all_paths = [p if p[0] < p[-1] else p[::-1] for p in all_paths]

    # unique paths
    return list(set(all_paths))
|
|
|
|
def get_atom_frames(msa, G, **karg):
    """choose a frame of 3 bonded atoms for each atom in the molecule, rule based system that chooses frame based on atom priorities

    Candidate frames are the 2-edge paths of ``G`` (``find_all_paths_of_length_n``).
    For atom n, frames centred on n are preferred. If there are none, any frame
    containing n is used. Among the candidates, the frame whose other atoms have the
    lowest ``ChemData().atom2frame_priority`` list wins (lexicographic). Ties keep the
    earlier candidate. Unless ``omit_permutation=False`` is passed, the priority lists
    are sorted before comparing.

    Parameters
    ----------
    msa : (A,) atom-type tokens of the molecule (indexed by atom/node id)
    G : graph over the same atom ids

    Returns
    -------
    (A, 3, 2) long tensor. For each atom there are three (offset, 1) pairs, where
    offset = frame atom id - this atom's id. Atoms with no candidate frame get
    [(0, 1), (0, 1), (0, 1)].
    """
    chem = ChemData()
    query_seq = msa
    keep_order = not karg.get('omit_permutation', True)
    frames = find_all_paths_of_length_n(G, 2, **karg)

    # index candidate frames once by centre atom and by member atom instead of
    # rescanning every frame for every atom; list order follows ``frames``
    # (which decides ties below)
    by_center = {}
    by_member = {}
    for frame in frames:
        by_center.setdefault(frame[1], []).append(frame)
        for atom in set(frame):
            by_member.setdefault(atom, []).append(frame)

    selected_frames = []
    for n in range(msa.shape[0]):
        frames_with_n = by_center.get(n)

        # some chemical groups don't have two bonded heavy atoms; so choose a frame with an atom 2 bonds away
        if not frames_with_n:
            frames_with_n = by_member.get(n)
        # if the atom isn't in a 3 atom frame, it should be ignored in loss calc: point every frame atom at n itself
        if not frames_with_n:
            selected_frames.append([(0, 1), (0, 1), (0, 1)])
            continue

        frame_priorities = []
        for frame in frames_with_n:
            # hacky but uses the "query_seq" to convert index of the atom into an "atom type" and converts that into a priority
            # (int() also works for CUDA tensors, unlike .numpy())
            aas = [chem.num2aa[int(query_seq[index])] for index in frame if index != n]
            priorities = [chem.atom2frame_priority[aa] for aa in aas]
            frame_priorities.append(priorities if keep_order else sorted(priorities))

        # first candidate with the smallest priority list
        best = min(range(len(frame_priorities)), key=frame_priorities.__getitem__)
        # calculate residue offset for frame
        selected_frames.append([(atom - n, 1) for atom in frames_with_n[best]])

    assert msa.shape[0] == len(selected_frames)
    return torch.tensor(selected_frames).long()
|
|
|
|
|
|
### Generate bond features for small molecules ###
def get_bond_feats(mol):
    """creates 2d bond graph for small molecules

    Returns a symmetric (N, N) long tensor, N = ``mol.NumAtoms()``. Bonded pairs hold
    openbabel's bond order, or 4 for aromatic bonds. All other entries are 0.
    """
    n_atoms = mol.NumAtoms()
    bond_feats = torch.zeros((n_atoms, n_atoms), dtype=torch.long)

    for bond in openbabel.OBMolBondIter(mol):
        # openbabel atom indices are 1-based
        a = bond.GetBeginAtomIdx() - 1
        b = bond.GetEndAtomIdx() - 1
        order = 4 if bond.IsAromatic() else bond.GetBondOrder()
        bond_feats[a, b] = order
        bond_feats[b, a] = order

    return bond_feats
|
|
|
|
def get_protein_bond_feats(protein_L):
    """ creates protein residue connectivity graphs

    Returns a (protein_L, protein_L) float tensor with bond type 5 between every pair
    of consecutive residues (i, i+1), in both directions, and 0 elsewhere.
    """
    bond_feats = torch.zeros((protein_L, protein_L))
    nxt = torch.arange(1, protein_L)  # residues that have a predecessor
    bond_feats[nxt - 1, nxt] = 5
    bond_feats[nxt, nxt - 1] = 5
    return bond_feats
|
|
|
|
def get_protein_bond_feats_from_idx(protein_L, idx_protein):
    """ creates protein residue connectivity graphs

    Bond type 5 is placed between residues i and j whenever their residue numbers
    ``idx_protein`` differ by exactly one (in either direction), so gaps in the
    numbering break the chain. Returns a (protein_L, protein_L) float tensor.
    """
    bond_feats = torch.zeros((protein_L, protein_L))
    # follows[i, j]: residue i is numbered directly after residue j
    follows = idx_protein[:, None] == idx_protein[None, :] + 1
    bond_feats[follows | follows.T] = 5
    return bond_feats
|
|
|
|
def get_atomize_protein_bond_feats(i_start, msa, ra, n_res_atomize=5):
    """
    generate atom bond features for atomized residues
    currently ignores long-range bonds like disulfides

    Parameters
    ----------
    i_start : index (into ``msa[0]``) of the first atomized residue
    msa : tokens; only row 0 is read
    ra : (A, 2) rows of (residue offset from i_start, atom slot), one per atom node
    n_res_atomize : number of consecutive residues atomized

    Returns
    -------
    (A', A') float tensor, A' = number of distinct ``ra`` rows. Intra-residue bonds get
    their ``ChemData().aabtypes`` type. Peptide bonds (slot 2 of the previous residue
    to slot 0 of the next) get ``ChemData().SINGLE_BOND``. Bonds involving atoms
    absent from ``ra`` are skipped.
    """
    chem = ChemData()
    ra2ind = {tuple(two_d.numpy()): i for i, two_d in enumerate(ra)}
    N = len(ra2ind)
    bond_feats = torch.zeros((N, N))

    def _set_bond(a, b, btype):
        bond_feats[a, b] = btype
        bond_feats[b, a] = btype

    for i, res in enumerate(msa[0, i_start:i_start+n_res_atomize]):
        residue_atoms = chem.aa2long[res]
        for j, bond in enumerate(chem.aabonds[res]):
            key_a = (i, residue_atoms.index(bond[0]))
            key_b = (i, residue_atoms.index(bond[1]))
            if key_a not in ra2ind or key_b not in ra2ind:
                # skip bonds with atoms that aren't observed in the structure
                continue
            _set_bond(ra2ind[key_a], ra2ind[key_b], chem.aabtypes[res][j])

        # accounting for peptide bonds (previous residue's slot 2 to this residue's slot 0)
        if i > 0 and (i-1, 2) in ra2ind and (i, 0) in ra2ind:
            _set_bond(ra2ind[(i-1, 2)], ra2ind[(i, 0)], chem.SINGLE_BOND)

    return bond_feats
|
|
|
|
|
|
### Generate atom features for proteins ###
def atomize_protein(i_start, msa, xyz, mask, n_res_atomize=5):
    """ given an index i_start, make the following flank residues into "atom" nodes

    Residues ``msa[0, i_start:i_start+n_res_atomize]`` are converted into per-atom
    ("ligand-like") features. Every atom slot among the first 14 that
    ``ChemData().allatom_mask`` marks for the residue type becomes a node. Coordinates
    are kept for each symmetry alternative produced via ``ChemData().long2alt``.

    Parameters
    ----------
    i_start : index of the first residue to atomize
    msa : tokens; only row 0 (``msa[0]``) is read
    xyz : per-residue coordinates, presumably (L, natoms, 3) given the
        ``xyz[i_start:...]`` slicing and ``unsqueeze(0)`` use -- TODO confirm
    mask : per-residue atom mask of resolved atoms, presumably (L, natoms)
    n_res_atomize : number of consecutive residues to atomize

    Returns
    -------
    lig_seq : (A,) element tokens (``ChemData().aa2num`` of the element, else "ATM")
    ins : (A,) zeros
    lig_xyz : (N_symm, A, 3) coordinates, one copy per symmetry alternative
    lig_mask : (N_symm, A) float resolved-atom mask, repeated per copy
    frames : per-atom frames from ``get_atom_frames``
    bond_feats : (A, A) bond types from ``get_atomize_protein_bond_feats``
    ra : (A, 2) rows of (residue offset, atom slot) for each atom node
    chirals : output of ``get_atomize_protein_chirals``
    """
    residues_atomize = msa[0, i_start:i_start+n_res_atomize]
    residues_atom_types = [ChemData().aa2elt[num][:14] for num in residues_atomize]
    residue_atomize_mask = mask[i_start:i_start+n_res_atomize].float() # mask of resolved atoms in the sidechain
    residue_atomize_allatom_mask = ChemData().allatom_mask[residues_atomize][:, :14] # the indices that have heavy atoms in that sidechain
    xyz_atomize = xyz[i_start:i_start+n_res_atomize]

    # handle symmetries
    # xyz_alt: the same coordinates scattered into the alternate atom ordering given by
    # ChemData().long2alt (presumably the swapped numbering of symmetric side-chain atoms)
    xyz_alt = torch.zeros_like(xyz.unsqueeze(0))
    xyz_alt.scatter_(2, ChemData().long2alt[msa[0],:,None].repeat(1,1,1,3), xyz.unsqueeze(0))
    xyz_alt_atomize = xyz_alt[0, i_start:i_start+n_res_atomize]

    # coords_stack[0] = original numbering, coords_stack[1] = alternate numbering
    coords_stack = torch.stack((xyz_atomize, xyz_alt_atomize), dim=0)
    # True where a residue's coordinates are identical under both numberings.
    # NOTE(review): NaN never compares equal, so residues with NaN coordinates are
    # also treated as having a swap -- confirm this is intended.
    swaps = (coords_stack[0] == coords_stack[1]).all(dim=1).all(dim=1).squeeze() #checks whether theres a swap at each position
    swaps = torch.nonzero(~swaps).squeeze() # indices with a swap eg. [2,3]
    if swaps.numel() != 0:
        # if there are residues with alternate numbering scheme, create a stack of coordinate with each combo of swaps
        # NOTE(review): combinations *with replacement* of [0,1] yield only k+1 patterns for
        # k swappable residues (e.g. [0,1] but not [1,0]), not all 2**k -- confirm intended.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore",category=UserWarning)
            combs = torch.combinations(torch.tensor([0,1]), r=swaps.numel(), with_replacement=True) #[[0,0], [0,1], [1,1]]
        # stack[..., 0]: 0 = original / 1 = alternate numbering; stack[..., 1]: residue index of the swap
        stack = torch.stack((combs, swaps.repeat(swaps.numel()+1,1)), dim=-1).squeeze()
        coords_stack = coords_stack.repeat(swaps.numel()+1,1,1,1)
        nat_symm = coords_stack[0].repeat(swaps.numel()+1,1,1,1) # (N_symm, num_atomize_residues, natoms, 3)
        # pick, per pattern and per swappable residue, the original or alternate coordinates
        swapped_coords = coords_stack[stack[...,0], stack[...,1]].squeeze(1) #
        nat_symm[:,swaps] = swapped_coords
    else:
        # no swappable residues: a single symmetry copy
        nat_symm = xyz_atomize.unsqueeze(0)
    # every heavy atom that is in the sidechain is modelled but losses only applied to resolved atoms
    ra = residue_atomize_allatom_mask.nonzero()
    # element token of each atom node; unknown elements map to "ATM"
    lig_seq = torch.tensor([ChemData().aa2num[residues_atom_types[r][a]] if residues_atom_types[r][a] in ChemData().aa2num else ChemData().aa2num["ATM"] for r,a in ra])
    ins = torch.zeros_like(lig_seq)

    r,a = ra.T
    lig_xyz = torch.zeros((len(ra), 3))  # placeholder; immediately overwritten on the next line
    lig_xyz = nat_symm[:, r, a]
    lig_mask = residue_atomize_mask[r, a].repeat(nat_symm.shape[0], 1)
    bond_feats = get_atomize_protein_bond_feats(i_start, msa, ra, n_res_atomize=n_res_atomize)
    #HACK: use networkx graph to make the atom frames, correct implementation will include frames with "residue atoms"
    G = nx.from_numpy_array(bond_feats.numpy())

    frames = get_atom_frames(lig_seq, G)
    # chirality features are computed from the first symmetry copy only
    chirals = get_atomize_protein_chirals(residues_atomize, lig_xyz[0], residue_atomize_allatom_mask, bond_feats)
    return lig_seq, ins, lig_xyz, lig_mask, frames, bond_feats, ra, chirals
|
|
|
|
def atomize_discontiguous_residues(idxs, msa, xyz, mask, bond_feats, same_chain, dslfs=None):
    """
    this atomizes multiple discontiguous residues at the same time, this is the default interface into atomizing residues
    (using the non assembly dataset)

    Each residue in ``idxs`` is atomized on its own (``atomize_protein`` with
    ``n_res_atomize=1``). Its atoms are appended after all existing nodes of
    ``bond_feats`` / ``same_chain``, which therefore grow every iteration. The following
    bonds are drawn:
      * type 6 between the neighbouring protein residue and the atomized N (first new
        atom) / C, unless that neighbour is itself being atomized or the residue is
        at a protein terminus;
      * type 1 between consecutive atomized residues (previous C to this N);
      * type 1 between the slot-5 atoms of each ``dslfs`` pair (disulfides; slot 5 is
        presumably CYS SG -- verify against ChemData().aa2long).

    NOTE(review): the "previous atomized residue" logic relies on ``idxs`` being
    iterated in ascending order -- confirm callers pass sorted indices.

    Parameters
    ----------
    idxs : residue indices (tensor elements; ``.item()`` is called on them)
    msa, xyz, mask : passed through to ``atomize_protein``
    bond_feats : (L, L) bond graph of the current nodes (protein first)
    same_chain : (L, L) same-chain matrix of the current nodes
    dslfs : optional iterable of (i, j) residue-index pairs that are disulfide-bonded;
        both residues must be in ``idxs``

    Returns
    -------
    seq, ins : concatenated per-atom features of all atomized residues
    xyz, mask : per-residue symmetry copies combined with ``cartprodcat`` (not defined
        in this part of the file; presumably a Cartesian-product concatenation -- verify)
    frames : atom frames recomputed over the combined small-molecule bond graph
    chirals : concatenated chirality features, atom indices offset per residue
    bond_feats, same_chain : the grown (L + A, L + A) matrices
    """
    protein_L = msa.shape[1]
    seq_atomize_all = []
    ins_atomize_all = []
    xyz_atomize_all = []
    mask_atomize_all = []
    frames_atomize_all = []
    chirals_atomize_all = []
    prev_C_index = None  # node index of the previous atomized residue's C atom
    total_num_atoms = 0  # atoms appended so far (offset for chirality indices)
    sgs = {}  # residue index -> node index of its slot-5 atom (used for disulfides)
    for idx in idxs:
        seq_atomize, ins_atomize, xyz_atomize, mask_atomize, frames_atomize, bond_feats_atomize, resatom2idx, chirals_atomize = \
            atomize_protein(idx, msa, xyz, mask, n_res_atomize=1)
        r,_ = resatom2idx.T
        #print ('atomize_discontiguous_residues', idx, resatom2idx)
        # local atom index of slot 2 (C) and slot 5 within this residue's atom list
        last_C = torch.all(resatom2idx==torch.tensor([r[-1],2]),dim=1).nonzero()
        sgs[idx.item()] = torch.all(resatom2idx==torch.tensor([r[-1],5]),dim=1).nonzero()

        natoms = seq_atomize.shape[0]
        L = bond_feats.shape[0]  # current node count; new atoms start at index L

        sgs[idx.item()] = L+sgs[idx.item()]

        # update the chirals to be after all the other atoms (still need to update to put it behind all the proteins)
        chirals_atomize[:, :-1] += total_num_atoms

        seq_atomize_all.append(seq_atomize)
        ins_atomize_all.append(ins_atomize)
        xyz_atomize_all.append(xyz_atomize)
        mask_atomize_all.append(mask_atomize)
        frames_atomize_all.append(frames_atomize)
        chirals_atomize_all.append(chirals_atomize)

        N_term = idx == 0
        C_term = idx == protein_L-1

        # update bond_feats every iteration, update all other features at the end
        bond_feats_new = torch.zeros((L+natoms, L+natoms))
        bond_feats_new[:L, :L] = bond_feats
        bond_feats_new[L:, L:] = bond_feats_atomize
        # add bond between protein and atomized N
        # (assumes the residue's N is its first atom, i.e. lands at node L)
        if not N_term and idx-1 not in idxs:
            bond_feats_new[idx-1, L] = 6 # protein (backbone)-atom bond
            bond_feats_new[L, idx-1] = 6 # protein (backbone)-atom bond
        # add bond between protein and C, assumes every residue is being atomized one at a time (eg n_res_atomize=1)
        if not C_term and idx+1 not in idxs:
            bond_feats_new[idx+1, L+int(last_C.numpy())] = 6 # protein (backbone)-atom bond
            bond_feats_new[L+int(last_C.numpy()), idx+1] = 6 # protein (backbone)-atom bond
        # handle drawing peptide bond between contiguous atomized residues
        if idx-1 in idxs:
            if prev_C_index is None:
                raise ValueError("prev_C_index is None even though the previous residue has been atomized")
            bond_feats_new[prev_C_index, L] = 1 # single bond
            bond_feats_new[L, prev_C_index] = 1 # single bond

        prev_C_index = L+int(last_C.numpy()) #update prev_C to draw bond to upcoming residue
        # update same_chain every iteration
        # new atoms share a chain with each other and with the residue's protein chain
        same_chain_new = torch.zeros((L+natoms, L+natoms))
        same_chain_new[:L, :L] = same_chain
        residues_in_prot_chain = same_chain[idx].squeeze().nonzero()

        same_chain_new[L:, residues_in_prot_chain] = 1
        same_chain_new[residues_in_prot_chain, L:] = 1
        same_chain_new[L:, L:] = 1

        bond_feats = bond_feats_new
        same_chain = same_chain_new
        total_num_atoms += natoms

    # disulfides
    if dslfs is not None:
        for i,j in dslfs:
            start_idx = sgs[i].item()
            end_idx = sgs[j].item()
            bond_feats[start_idx, end_idx] = 1
            bond_feats[end_idx, start_idx] = 1

    seq_atomize_all = torch.cat(seq_atomize_all)
    ins_atomize_all = torch.cat(ins_atomize_all)
    xyz_atomize_all = cartprodcat(xyz_atomize_all)
    mask_atomize_all = cartprodcat(mask_atomize_all)

    # frames were calculated per residue -- we want them over all residues in case there are contiguous residues
    bond_feats_sm = bond_feats[protein_L:][:, protein_L:]
    G = nx.from_numpy_array(bond_feats_sm.detach().cpu().numpy())
    frames_atomize_all = get_atom_frames(seq_atomize_all, G)

    # frames_atomize_all = torch.cat(frames_atomize_all)
    chirals_atomize_all = torch.cat(chirals_atomize_all)

    return seq_atomize_all, ins_atomize_all, xyz_atomize_all, mask_atomize_all, frames_atomize_all, chirals_atomize_all, \
        bond_feats, same_chain
|
|
|
|
def reindex_protein_feats_after_atomize(
    residues_to_atomize,
    prot_partners,
    msa,
    ins,
    xyz,
    mask,
    bond_feats,
    idx,
    xyz_t,
    f1d_t,
    mask_t,
    same_chain,
    ch_label,
    Ls_prot,
    Ls_sm,
    akeys_sm,
    remove_residue=True
):
    """
    Removes residues that have been atomized from protein features.

    For each residue in ``residues_to_atomize`` that has all three backbone atoms
    (``mask[0, i, :3]``), this function:
      * adds a protein(backbone)-atom bond (type 6) from the previous residue to the
        atomized N and from the atomized C to the next residue, except at the
        first / last residue of that residue's protein chain;
      * marks the atomized atoms' chain and the residue's protein chain as ``same_chain``.
    When ``remove_residue`` is True, those residues are then dropped from every
    per-residue feature and ``Ls_prot`` is decremented accordingly.

    Chain boundaries are cumulative sums of ``Ls_prot``. The previous check compared
    residue indices with the chain *lengths*, which drew bonds across chains whenever
    there were more than two protein chains or chains of unequal length.

    Note: ``bond_feats`` and ``same_chain`` are written in place, and ``Ls_prot`` is
    decremented in place (and returned).

    Returns
    -------
    msa, ins, xyz, mask, bond_feats, idx, xyz_t, f1d_t, mask_t, same_chain, ch_label, Ls_prot, Ls_sm
    """
    Ls = Ls_prot + Ls_sm
    chain_bins = [sum(Ls[:i]) for i in range(len(Ls)+1)]
    akeys_sm = list(itertools.chain.from_iterable(akeys_sm))  # list of list of tuples get flattened to a list of tuples
    n_prot = sum(Ls_prot)
    partner_keys = [p[:2] for p in prot_partners]

    # get tensor indices of atomized residues
    residue_chain_nums = []
    residue_indices = []
    for residue in residues_to_atomize:
        # residue object is a list of tuples:
        # ((chain_letter, res_number, res_name), (chain_letter, xform_index))

        #### Need to identify what chain you're in to get correct res idx
        residue_chid_xf = residue[1]
        residue_chain_num = partner_keys.index(residue_chid_xf)
        chain_start = sum(Ls_prot[:residue_chain_num])
        chain_end = chain_start + Ls_prot[residue_chain_num] - 1  # last residue of this chain
        residue_index = (int(residue[0][1]) - 1) + chain_start  # residues are 1 indexed in the cif files

        # skip residues missing any of the three backbone atoms
        if torch.sum(mask[0, residue_index, :3]) < 3:
            continue

        residue_chain_nums.append(residue_chain_num)
        residue_indices.append(residue_index)
        atomize_N = residue[0] + ("N",)
        atomize_C = residue[0] + ("C",)

        N_index = akeys_sm.index(atomize_N) + n_prot
        C_index = akeys_sm.index(atomize_C) + n_prot

        # if first residue in chain, no extra bond feats to previous residue
        if residue_index != chain_start:
            bond_feats[residue_index-1, N_index] = 6
            bond_feats[N_index, residue_index-1] = 6

        # if residue is last in chain, no extra bonds feats to following residue
        if residue_index != chain_end:
            bond_feats[residue_index+1, C_index] = 6
            bond_feats[C_index, residue_index+1] = 6

        lig_chain_num = np.digitize([N_index], chain_bins)[0] - 1  # np.digitize is 1 indexed
        lig_lo, lig_hi = chain_bins[lig_chain_num], chain_bins[lig_chain_num+1]
        prot_lo, prot_hi = chain_bins[residue_chain_num], chain_bins[residue_chain_num+1]
        same_chain[lig_lo:lig_hi, prot_lo:prot_hi] = 1
        same_chain[prot_lo:prot_hi, lig_lo:lig_hi] = 1

    if remove_residue:
        # remove atomized residues from feature tensors
        removed = set(residue_indices)
        # explicit long dtype so an empty selection is still a valid index
        i_res = torch.tensor([i for i in range(sum(Ls)) if i not in removed], dtype=torch.long)
        msa = msa[:, i_res]
        ins = ins[:, i_res]
        xyz = xyz[:, i_res]
        mask = mask[:, i_res]
        bond_feats = bond_feats[i_res][:, i_res]
        idx = idx[i_res]
        xyz_t = xyz_t[:, i_res]
        f1d_t = f1d_t[:, i_res]
        mask_t = mask_t[:, i_res]
        same_chain = same_chain[i_res][:, i_res]
        ch_label = ch_label[i_res]

        for i_ch in residue_chain_nums:
            Ls_prot[i_ch] -= 1

    return msa, ins, xyz, mask, bond_feats, idx, xyz_t, f1d_t, mask_t, same_chain, ch_label, Ls_prot, Ls_sm
|
|
|
|
|
|
def pop_protein_feats(residue_indices, msa, ins, xyz, mask, bond_feats, idx, xyz_t, f1d_t, mask_t, same_chain, ch_label, Ls):
    """Drop the residues at ``residue_indices`` from every per-residue feature.

    ``Ls`` is only used for the total length ``sum(Ls)``; it is not modified
    or returned.

    Features sliced on axis 1: msa, ins, xyz, mask, xyz_t, f1d_t, mask_t.
    Features sliced on both axes 0 and 1: bond_feats, same_chain.
    Features sliced on axis 0: idx, ch_label.

    Returns
    -------
    tuple
        (msa, ins, xyz, mask, bond_feats, idx, xyz_t, f1d_t, mask_t,
        same_chain, ch_label) with the selected residues removed.
    """
    keep = torch.ones(sum(Ls), dtype=torch.bool)
    keep[residue_indices] = False

    def along_dim1(t):
        return t[:, keep]

    def pairwise(t):
        return t[keep][:, keep]

    return (
        along_dim1(msa),
        along_dim1(ins),
        along_dim1(xyz),
        along_dim1(mask),
        pairwise(bond_feats),
        idx[keep],
        along_dim1(xyz_t),
        along_dim1(f1d_t),
        along_dim1(mask_t),
        pairwise(same_chain),
        ch_label[keep],
    )
|
|
|
|
def get_automorphs(mol, xyz_sm, mask_sm, max_symm=1000):
    """Enumerate atom symmetry permutations of a small molecule.

    Uses OpenBabel's ``FindAutomorphisms`` on ``mol`` and applies every
    automorphism (a list of (destination, source) atom-index pairs) to the
    coordinates and mask, producing one permuted copy per automorphism.

    Parameters
    ----------
    mol : openbabel OBMol
        Molecule whose automorphisms are enumerated. Presumably its atom
        order matches the rows of ``xyz_sm`` (TODO confirm at call sites).
    xyz_sm : torch.Tensor, (L, 3)
        Atom coordinates.
    mask_sm : torch.Tensor, (L,)
        Per-atom validity mask.
    max_symm : int
        Maximum number of permuted copies to return.

    Returns
    -------
    xyz_sm : torch.Tensor, (N_symm, L, 3)
    mask_sm : torch.Tensor, (N_symm, L)
        If enumerating or applying the automorphisms fails for any reason,
        a single unpermuted copy (N_symm = 1) is returned instead.
    """
    try:
        automorphs = openbabel.vvpairUIntUInt()
        openbabel.FindAutomorphisms(mol, automorphs)

        automorphs = torch.tensor(automorphs)
        n_symmetry = automorphs.shape[0]

        # Work on new names so the fallback below always sees the original
        # (unexpanded) inputs, even if a later step in this block raises.
        xyz_all = xyz_sm[None].repeat(n_symmetry, 1, 1)
        mask_all = mask_sm[None].repeat(n_symmetry, 1)

        # position automorphs[..., 0] receives the value at automorphs[..., 1]
        xyz_all = torch.scatter(xyz_all, 1, automorphs[:, :, 0:1].repeat(1, 1, 3),
                                torch.gather(xyz_all, 1, automorphs[:, :, 1:2].repeat(1, 1, 3)))
        mask_all = torch.scatter(mask_all, 1, automorphs[:, :, 0],
                                 torch.gather(mask_all, 1, automorphs[:, :, 1]))
    except Exception:
        # Deliberate best effort: fall back to the identity permutation only.
        xyz_all = xyz_sm[None]
        mask_all = mask_sm[None]

    if xyz_all.shape[0] > max_symm:
        xyz_all = xyz_all[:max_symm]
        mask_all = mask_all[:max_symm]

    return xyz_all, mask_all
|
|
|
|
def expand_xyz_sm_to_ntotal(xyz_sm, mask_sm, N_symmetry=None):
    """Place small-molecule atom coordinates into the full NTOTAL atom layout.

    Each small-molecule atom is stored in atom slot 1, the slot that
    `get_prot_sm_mask` reads for small-molecule positions. All other slots
    are NaN (coordinates) / False (mask).

    Parameters
    ----------
    xyz_sm : torch.Tensor, (N_symm_sm, L, 3)
    mask_sm : torch.Tensor, (N_symm_sm, L)
    N_symmetry : int, optional
        Number of symmetry copies in the output; defaults to N_symm_sm.
        Rows beyond N_symm_sm are left as NaN / False.

    Returns
    -------
    xyz : torch.Tensor, float32, (N_symmetry, L, NTOTAL, 3)
    mask : torch.Tensor, bool, (N_symmetry, L, NTOTAL)
    """
    n_symm_in, n_atoms = xyz_sm.shape[:2]
    n_out = n_symm_in if N_symmetry is None else N_symmetry

    xyz = torch.full((n_out, n_atoms, ChemData().NTOTAL, 3), np.nan).float()
    mask = torch.full((n_out, n_atoms, ChemData().NTOTAL), False).bool()

    xyz[:n_symm_in, :, 1, :] = xyz_sm
    mask[:n_symm_in, :, 1] = mask_sm
    return xyz, mask
|
|
|
|
def same_chain_2d_from_Ls(Ls):
    """Build a block-diagonal (sum(Ls), sum(Ls)) long tensor from chain lengths.

    Entry (i, j) is 1 when residues i and j fall in the same contiguous chain
    segment given by ``Ls``, and 0 otherwise.
    """
    total = sum(Ls)
    out = torch.zeros((total, total)).long()

    # cumulative chain boundaries: [0, L0, L0+L1, ...]
    bounds = [0]
    for chain_len in Ls:
        bounds.append(bounds[-1] + chain_len)

    for lo, hi in zip(bounds[:-1], bounds[1:]):
        out[lo:hi, lo:hi] = 1
    return out
|
|
|
|
def Ls_from_same_chain_2d(same_chain):
    """Recover contiguous chain lengths from a same-chain matrix.

    Inverse of ``same_chain_2d_from_Ls`` for block-diagonal input. Starting at
    residue ``i``, the chain length is the number of consecutive nonzero
    entries ``same_chain[i, i], same_chain[i, i+1], ...``. Scanning then
    resumes right after that run.

    Counting only the contiguous run (rather than the first/last nonzero entry
    anywhere in the row) keeps unrelated residues from being swallowed when
    ``same_chain`` also links non-adjacent blocks. It also guarantees that
    ``sum(Ls) == L`` and that the loop terminates.

    Parameters
    ----------
    same_chain : torch.Tensor, (L, L) or (1, L, L)
        Nonzero entries mark residue pairs on the same chain.

    Returns
    -------
    list of int
        Chain lengths in residue order. A residue whose diagonal entry is zero
        is reported as a chain of length 1.
    """
    if same_chain.dim() == 3:  # remove batch dimension
        same_chain = same_chain.squeeze(0)
    linked = same_chain != 0
    n_res = linked.shape[0]

    Ls = []
    i_curr = 0
    while i_curr < n_res:
        row = linked[i_curr, i_curr:]
        breaks = torch.nonzero(~row, as_tuple=True)[0]
        run = int(breaks[0]) if breaks.numel() > 0 else int(row.numel())
        run = max(run, 1)  # always advance, even on a malformed row
        Ls.append(run)
        i_curr += run
    return Ls
|
|
|
|
def get_prot_seqstring(ch, modres):
    """Return the one-letter amino-acid sequence string of a parsed CIF chain.

    The string spans residue numbers min..max found in ``ch.atoms``. Residue
    numbers with no atoms are filled with '-'.

    Parameters
    ----------
    ch : object with an ``atoms`` mapping
        Atom keys are indexed as ``k[1]`` (residue number, int-convertible)
        and ``k[2]`` (residue name).
    modres : dict
        Maps modified residue names to their standard residue names.

    Returns
    -------
    str
        Residue names found in ``ChemData().to1letter`` map directly.
        Modified residues map via ``modres``. Anything else becomes 'X'.
        A chain with no atoms yields ''.
    """
    if not ch.atoms:
        # min()/max() below are undefined for an empty chain
        return ''

    to1letter = ChemData().to1letter  # hoisted out of the per-atom loop
    res_nums = [int(k[1]) for k in ch.atoms]
    i_min = min(res_nums)
    seq = ['-'] * (max(res_nums) - i_min + 1)

    for k in ch.atoms:
        res_name = k[2]
        if res_name in to1letter:  # standard AA
            aa = to1letter[res_name]
        elif res_name in modres and modres[res_name] in to1letter:  # nonstandard AA, map to standard
            aa = to1letter[modres[res_name]]
        else:  # unknown AA, still try to store BB atoms
            aa = 'X'
        seq[int(k[1]) - i_min] = aa
    return ''.join(seq)
|
|
|
|
def map_identical_prot_chains(partners, chains, modres):
    """Group protein chain letters that share an identical sequence.

    Each protein partner chain gets its sequence string from
    ``get_prot_seqstring``. Chains with equal strings are grouped, and each
    group gets an integer number in order of first appearance.

    Parameters
    ----------
    partners : list of tuples (partner, transform_index, num_contacts, partner_type)
        Information about neighboring chains to the query ligand in an
        assembly. Only tuples whose last element is ``'polypeptide(L)'`` are
        used; their first element is the chain letter. Repeated chain
        letters (e.g. under different transforms) are counted once.
    chains : dict
        Maps chain letters to cifutils.Chain objects representing the chains
        in a PDB entry.
    modres : dict
        Maps modified residue names to their canonical equivalents, so that
        modified residues compare equal to their standard counterparts.

    Returns
    -------
    chnum2chlet : OrderedDict[int, set of str]
        Maps 0, 1, 2, ... to the set of chain letters that share one
        sequence. Empty if there are no protein partners.
    """
    chlet2seq = OrderedDict()
    for partner in partners:
        chlet, partner_type = partner[0], partner[-1]
        if partner_type != 'polypeptide(L)' or chlet in chlet2seq:
            continue
        chlet2seq[chlet] = get_prot_seqstring(chains[chlet], modres)

    seq2chlet = OrderedDict()
    for chlet, seq in chlet2seq.items():
        seq2chlet.setdefault(seq, set()).add(chlet)

    return OrderedDict(enumerate(seq2chlet.values()))
|
|
|
|
def cartprodcat(X_s):
    """Concatenate tensors on dim 1 while taking the Cartesian product over dim 0.

    For ``X_s = [A (N, L, ...), B (N_, L_, ...)]`` the result has shape
    ``(N*N_, L+L_, ...)``. Row ``i*N_ + j`` is ``cat([A[i], B[j]])``. More than
    two tensors are combined left to right. Trailing dims must match.

    The result keeps the inputs' dtype and device (the previous implementation
    built a float32 CPU buffer, silently converting bool masks, integer
    tensors and float64 coordinates).
    """
    X = X_s[0]
    for X_ in X_s[1:]:
        N, L = X.shape[:2]
        N_, L_ = X_.shape[:2]
        trailing = X.shape[2:]
        # broadcast both operands to (N, N_, length, ...) without copying
        left = X.unsqueeze(1).expand(N, N_, *X.shape[1:])
        right = X_.unsqueeze(0).expand(N, N_, *X_.shape[1:])
        X = torch.cat([left, right], dim=2).reshape(N * N_, L + L_, *trailing)
    return X
|
|
|
|
def idx_from_Ls(Ls):
    """Generate residue indexes from a list of chain lengths.

    Chain ``k`` gets consecutive indexes starting at
    ``sum(Ls[:k]) + k * ChemData().CHAIN_GAP``, so consecutive chains are
    separated by a gap of ``CHAIN_GAP`` in index space.

    Returns
    -------
    torch.Tensor, int64, (sum(Ls),)
        Empty when ``Ls`` is empty.
    """
    if len(Ls) == 0:
        # torch.cat raises on an empty list; return an empty index tensor instead
        return torch.zeros(0, dtype=torch.long)

    chain_gap = ChemData().CHAIN_GAP  # loop-invariant
    idx = []
    offset = 0
    for L in Ls:
        idx.append(torch.arange(L) + offset)
        offset += L + chain_gap
    return torch.cat(idx, dim=0)
|
|
|
|
|
|
def bond_feats_from_Ls(Ls):
    """Build polymer (protein or DNA/RNA) bond features from chain lengths.

    Returns a (sum(Ls), sum(Ls)) long tensor. Each chain's diagonal block is
    filled with ``get_protein_bond_feats(chain_len)``. Off-diagonal
    (inter-chain) blocks are zero.
    """
    total = sum(Ls)
    feats = torch.zeros((total, total)).long()
    start = 0
    for chain_len in Ls:
        stop = start + chain_len
        feats[start:stop, start:stop] = get_protein_bond_feats(chain_len)
        start = stop
    return feats
|
|
|
|
def same_chain_from_bond_feats(bond_feats):
    """Derive a same-chain matrix from bond features.

    Nonzero entries of ``bond_feats`` are treated as edges of an undirected
    graph. Residues in the same connected component are marked 1 in the
    returned (L, L) float tensor, including each residue with itself.

    ``bond_feats`` must be 2-D (no batch dimension).
    """
    assert bond_feats.ndim == 2
    n_res = bond_feats.shape[0]
    graph = nx.from_numpy_array(bond_feats.detach().cpu().numpy())

    same_chain = torch.zeros((n_res, n_res))
    for component in nx.connected_components(graph):
        members = torch.tensor(list(component))
        # mark every ordered pair within this component at once
        same_chain[members[:, None], members[None, :]] = 1
    return same_chain
|
|
|
|
|
|
def kabsch(xyz1, xyz2, eps=1e-6):
    """Superimpose `xyz2` onto `xyz1` (Kabsch algorithm); return RMSD and rotation.

    Both point sets are centered on their own centroid. ``U`` is the proper
    rotation (det +1) such that ``(xyz2 - xyz2.mean(0)) @ U`` best matches
    ``xyz1 - xyz1.mean(0)`` in the least-squares sense.

    Parameters
    ----------
    xyz1, xyz2 : torch.Tensor, (L, 3)
        Corresponding point sets (same length and ordering).
    eps : float
        Added inside the square root to keep the RMSD (and its gradient)
        finite when the fit is exact.

    Returns
    -------
    rmsd : torch.Tensor, scalar
    U : torch.Tensor, (3, 3)
        Has the same dtype and device as the inputs.
    """
    # center each set on its own centroid
    xyz1 = xyz1 - xyz1.mean(0)
    xyz2 = xyz2 - xyz2.mean(0)

    # Computation of the covariance matrix
    C = xyz2.T @ xyz1

    # Compute optimal rotation matrix using SVD (C = V diag(S) W)
    V, S, W = torch.linalg.svd(C)

    # Flip the last singular direction if needed so U is a rotation, not a
    # reflection. Build it on V's device/dtype: a CPU float32 constant
    # breaks CUDA inputs.
    d = torch.ones((3, 3), dtype=V.dtype, device=V.device)
    d[:, -1] = torch.sign(torch.linalg.det(V) * torch.linalg.det(W))

    # Rotation matrix U
    U = (d * V) @ W

    # Rotate xyz2
    xyz2_ = xyz2 @ U

    L = xyz2_.shape[0]
    rmsd = torch.sqrt(torch.sum((xyz2_ - xyz1) * (xyz2_ - xyz1), dim=(0, 1)) / L + eps)

    return rmsd, U
|
|
|
|
|