# mirror of https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
import torch
from dataclasses import dataclass, fields
from typing import Optional, List

from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
    Ls_from_same_chain_2d, idx_from_Ls, is_atom


@dataclass
class RawInputData:
    msa: torch.Tensor
    ins: torch.Tensor
    bond_feats: torch.Tensor
    xyz_t: torch.Tensor
    mask_t: torch.Tensor
    t1d: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    taxids: Optional[List[str]] = None
    term_info: Optional[torch.Tensor] = None
    chain_lengths: Optional[List] = None
    idx: Optional[List] = None

    def query_sequence(self):
        return self.msa[0]

    def sequence_string(self):
        three_letter_sequence = [ChemData().num2aa[num] for num in self.query_sequence()]
        return "".join([ChemData().aa_321[three] for three in three_letter_sequence])

    def is_atom(self):
        return is_atom(self.query_sequence())

    def length(self):
        return self.msa.shape[1]

    def get_chain_bins_from_chain_lengths(self):
        if self.chain_lengths is None:
            raise ValueError("Cannot call get_chain_bins_from_chain_lengths without "
                             "setting chain_lengths. chain_lengths is set in merge_inputs")
        chain_bins = {}
        running_length = 0
        for chain, length in self.chain_lengths:
            chain_bins[chain] = (running_length, running_length + length)
            running_length = running_length + length
        return chain_bins
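    # Illustrative example (assumed chain IDs and lengths, not from the source):
    #   chain_lengths = [("A", 3), ("B", 2)]  ->  {"A": (0, 3), "B": (3, 5)}
    # i.e. half-open index ranges of each chain within the concatenated input.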

    def update_protein_features_after_atomize(self, residues_to_atomize):
        if self.chain_lengths is None:
            raise ValueError("Cannot update protein features without chain_lengths. "
                             "merge_inputs must be called before this function")
        chain_bins = self.get_chain_bins_from_chain_lengths()
        keep = torch.ones(self.length())
        prev_absolute_index = None
        prev_C = None
        # need to atomize residues from N-term to C-term to handle atomizing neighbors
        residues_to_atomize = sorted(residues_to_atomize, key=lambda x: x.original_chain + str(x.index_in_original_chain))
        for residue in residues_to_atomize:
            original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
            absolute_index_in_combined_input = original_chain_start_index + residue.index_in_original_chain

            atomized_chain_start_index, atomized_chain_end_index = chain_bins[residue.chain]
            N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
            C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain

            # if residue is first in the chain, no extra bond feats to the preceding residue
            if absolute_index_in_combined_input != original_chain_start_index:
                self.bond_feats[absolute_index_in_combined_input - 1, N_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[N_index, absolute_index_in_combined_input - 1] = ChemData().RESIDUE_ATOM_BOND

            # if residue is last in the chain, no extra bond feats to the following residue
            if absolute_index_in_combined_input != original_chain_end_index - 1:
                self.bond_feats[absolute_index_in_combined_input + 1, C_index] = ChemData().RESIDUE_ATOM_BOND
                self.bond_feats[C_index, absolute_index_in_combined_input + 1] = ChemData().RESIDUE_ATOM_BOND
            keep[absolute_index_in_combined_input] = 0

            # find neighboring residues that were atomized
            if prev_absolute_index is not None:
                if prev_absolute_index + 1 == absolute_index_in_combined_input:
                    self.bond_feats[prev_C, N_index] = 1
                    self.bond_feats[N_index, prev_C] = 1

            prev_absolute_index = absolute_index_in_combined_input
            prev_C = C_index
        # remove protein features for the atomized residues
        self.keep_features(keep.bool())
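    # Illustrative sketch (assumed layout, not from the source): if chain "A" spans
    # indices 0..9 and residue 5 is atomized into an appended atom chain, bond_feats
    # gains RESIDUE_ATOM_BOND entries linking residue 4 to the new backbone N and
    # residue 6 to the new backbone C, and index 5 is then dropped via keep_features.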

    def keep_features(self, keep):
        if not torch.all(keep[self.is_atom()]):
            raise ValueError("cannot remove atoms")
        self.msa = self.msa[:, keep]
        self.ins = self.ins[:, keep]
        self.bond_feats = self.bond_feats[keep][:, keep]
        self.xyz_t = self.xyz_t[:, keep]
        self.t1d = self.t1d[:, keep]
        self.mask_t = self.mask_t[:, keep]
        if self.term_info is not None:
            self.term_info = self.term_info[keep]
        if self.idx is not None:
            self.idx = self.idx[keep]
        # assumes all chirals are after all protein residues
        self.chirals[..., :-1] = self.chirals[..., :-1] - torch.sum(~keep)
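    # Example (illustrative): if three protein positions are dropped, every chiral
    # atom index (all columns except the last, which presumably stores the chiral
    # target value) shifts down by 3 so it still points at the same atom after slicing.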

    def construct_features(self, model_runner):
        loader_params = model_runner.config.loader_params
        B, L = 1, self.length()
        seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(
            self.msa.long(),
            self.ins.long(),
            loader_params,
            p_mask=loader_params.get("p_msa_mask", 0),
            term_info=self.term_info,
            deterministic=model_runner.deterministic,
        )
        dist_matrix = get_bond_distances(self.bond_feats)

        # xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
        # xyz_prev = torch.nan_to_num(xyz_prev)

        # NOTE: The above is the way things "should" be done; this is for compatibility with training.
        xyz_prev = ChemData().INIT_CRDS.reshape(1, ChemData().NTOTAL, 3).repeat(L, 1, 1)

        self.xyz_t = torch.nan_to_num(self.xyz_t)

        mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
        mask_t_2d = mask_t_2d[:, None] * mask_t_2d[:, :, None]  # (B, T, L, L)

        xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
        t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])
        t2d = t2d[0]
        # get torsion angles from templates
        seq_tmp = self.t1d[..., :-1].argmax(dim=-1)
        alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(
            self.xyz_t.reshape(-1, L, ChemData().NTOTAL, 3),
            seq_tmp, mask_in=self.mask_t.reshape(-1, L, ChemData().NTOTAL))
        alpha = alpha.reshape(B, -1, L, ChemData().NTOTALDOFS, 2)
        alpha_mask = alpha_mask.reshape(B, -1, L, ChemData().NTOTALDOFS, 1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3 * ChemData().NTOTALDOFS)
        alpha_t = alpha_t[0]
        alpha_prev = torch.zeros((L, ChemData().NTOTALDOFS, 2))

        same_chain = same_chain_from_bond_feats(self.bond_feats)
        return RFInput(
            msa_latent=msa_seed,
            msa_full=msa_extra,
            seq=seq,
            seq_unmasked=self.query_sequence(),
            bond_feats=self.bond_feats,
            dist_matrix=dist_matrix,
            chirals=self.chirals,
            atom_frames=self.atom_frames.long(),
            xyz_prev=xyz_prev,
            alpha_prev=alpha_prev,
            t1d=self.t1d,
            t2d=t2d,
            xyz_t=self.xyz_t[..., 1, :],
            alpha_t=alpha_t.float(),
            mask_t=mask_t_2d.float(),
            same_chain=same_chain.long(),
            idx=self.idx,
        )


@dataclass
class RFInput:
    msa_latent: torch.Tensor
    msa_full: torch.Tensor
    seq: torch.Tensor
    seq_unmasked: torch.Tensor
    idx: torch.Tensor
    bond_feats: torch.Tensor
    dist_matrix: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    xyz_prev: torch.Tensor
    alpha_prev: torch.Tensor
    t1d: torch.Tensor
    t2d: torch.Tensor
    xyz_t: torch.Tensor
    alpha_t: torch.Tensor
    mask_t: torch.Tensor
    same_chain: torch.Tensor
    msa_prev: Optional[torch.Tensor] = None
    pair_prev: Optional[torch.Tensor] = None
    state_prev: Optional[torch.Tensor] = None
    mask_recycle: Optional[torch.Tensor] = None

    def to(self, gpu):
        for field in fields(self):
            field_value = getattr(self, field.name)
            if torch.is_tensor(field_value):
                setattr(self, field.name, field_value.to(gpu))

    def add_batch_dim(self):
        """mimic pytorch dataloader at inference time"""
        for field in fields(self):
            field_value = getattr(self, field.name)
            if torch.is_tensor(field_value):
                setattr(self, field.name, field_value[None])
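

# Minimal usage sketch (illustrative only): the tensors below are dummy placeholders
# with arbitrary shapes, not valid RF2AA features. It only demonstrates the
# fields()-based helpers above: add_batch_dim prepends a batch axis to every tensor
# field, and to() moves every tensor field to the requested device.
if __name__ == "__main__":
    L = 4
    dummy = torch.zeros(L)
    rf_input = RFInput(
        msa_latent=dummy, msa_full=dummy, seq=dummy, seq_unmasked=dummy,
        idx=torch.arange(L), bond_feats=torch.zeros(L, L),
        dist_matrix=torch.zeros(L, L), chirals=torch.zeros(0, 5),
        atom_frames=torch.zeros(0, 3, 2), xyz_prev=dummy, alpha_prev=dummy,
        t1d=dummy, t2d=dummy, xyz_t=dummy, alpha_t=dummy,
        mask_t=torch.ones(L, L), same_chain=torch.ones(L, L),
    )
    rf_input.add_batch_dim()            # every tensor gains a leading batch axis
    rf_input.to("cpu")                  # e.g. "cuda:0" to move tensors to a GPU
    print(rf_input.bond_feats.shape)    # torch.Size([1, 4, 4])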