RoseTTAFold-All-Atom/rf2aa/data/data_loader.py
2024-03-05 16:45:27 -08:00

202 lines
8.4 KiB
Python

import torch
from dataclasses import dataclass, fields
from typing import Optional, List
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.data.data_loader_utils import MSAFeaturize, get_bond_distances, generate_xyz_prev
from rf2aa.kinematics import xyz_to_t2d
from rf2aa.util import get_prot_sm_mask, xyz_t_to_frame_xyz, same_chain_from_bond_feats, \
Ls_from_same_chain_2d, idx_from_Ls, is_atom
@dataclass
class RawInputData:
    """Unfeaturized inputs for one (possibly multi-chain) complex.

    Holds the raw MSA / insertion, bond, template and chirality tensors.
    It can (a) rewrite itself after protein residues are atomized
    (``update_protein_features_after_atomize``) and (b) build the featurized
    ``RFInput`` consumed by the network (``construct_features``).
    """

    msa: torch.Tensor
    ins: torch.Tensor
    bond_feats: torch.Tensor
    xyz_t: torch.Tensor
    mask_t: torch.Tensor
    t1d: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    taxids: Optional[List[str]] = None
    term_info: Optional[torch.Tensor] = None
    # Sequence of (chain_id, length) pairs, in combined-input order; set by merge_inputs.
    chain_lengths: Optional[List] = None
    # NOTE(review): keep_features indexes this with a boolean tensor, so it must
    # be tensor-like at that point despite the List annotation.
    idx: Optional[List] = None

    def query_sequence(self):
        """Return the first MSA row (the query sequence as token indices)."""
        return self.msa[0]

    def sequence_string(self):
        """Return the query sequence as a string of one-letter codes.

        Each token is mapped to its three-letter name via ``ChemData().num2aa``,
        then to a one-letter code via ``ChemData().aa_321``.
        """
        chem = ChemData()
        return "".join(chem.aa_321[chem.num2aa[num]] for num in self.query_sequence())

    def is_atom(self):
        """Return the module-level ``is_atom`` result for the query sequence.

        Presumably a boolean mask of atom (non-residue) positions; it is used
        to index ``keep`` in ``keep_features``.
        """
        # Inside the method body this name resolves to the imported helper, not to this method.
        return is_atom(self.query_sequence())

    def length(self):
        """Return the number of positions (MSA columns) in the input."""
        return self.msa.shape[1]

    def get_chain_bins_from_chain_lengths(self):
        """Map each chain id to its half-open ``(start, end)`` range in the combined input.

        Raises:
            ValueError: if ``chain_lengths`` has not been set (see merge_inputs).
        """
        if self.chain_lengths is None:
            raise ValueError(
                "Cannot call get_chain_bins_from_chain_lengths without "
                "setting chain_lengths. Chain_lengths is set in merge_inputs"
            )
        chain_bins = {}
        running_length = 0
        for chain, length in self.chain_lengths:
            chain_bins[chain] = (running_length, running_length + length)
            running_length += length
        return chain_bins

    def update_protein_features_after_atomize(self, residues_to_atomize):
        """Replace atomized protein residues by bonds to their atom copies.

        For each residue this sets ``RESIDUE_ATOM_BOND`` between the residue's
        atomized N atom and the preceding residue of its original chain, and
        between its atomized C atom and the following residue. It also bonds
        (value 1) the C atom of an atomized residue to the N atom of the next
        atomized residue when both are consecutive in the same original chain.
        Finally it drops the residue-level positions of all atomized residues
        via ``keep_features``.

        Args:
            residues_to_atomize: iterable of objects exposing ``original_chain``,
                ``index_in_original_chain``, ``chain`` (the chain holding the
                atom copies), ``absolute_N_index_in_chain`` and
                ``absolute_C_index_in_chain``.

        Raises:
            ValueError: if ``chain_lengths`` has not been set (see merge_inputs).
        """
        if self.chain_lengths is None:
            raise ValueError(
                "Cannot update protein features without chain_lengths. "
                "merge_inputs must be called before this function"
            )
        chain_bins = self.get_chain_bins_from_chain_lengths()
        keep = torch.ones(self.length(), dtype=torch.bool)
        residue_atom_bond = ChemData().RESIDUE_ATOM_BOND

        def absolute_index(residue):
            # Position of the residue in the combined (all-chain) input.
            return chain_bins[residue.original_chain][0] + residue.index_in_original_chain

        prev_absolute_index = None
        prev_C = None
        # Process residues N- to C-terminus in combined-input order, so that
        # consecutive atomized residues are visited back to back. A numeric key
        # is required: a string key like chain + str(index) sorts "A10" before "A2".
        for residue in sorted(residues_to_atomize, key=absolute_index):
            original_chain_start_index, original_chain_end_index = chain_bins[residue.original_chain]
            absolute_index_in_combined_input = absolute_index(residue)
            atomized_chain_start_index = chain_bins[residue.chain][0]
            N_index = atomized_chain_start_index + residue.absolute_N_index_in_chain
            C_index = atomized_chain_start_index + residue.absolute_C_index_in_chain

            is_chain_start = absolute_index_in_combined_input == original_chain_start_index
            is_chain_end = absolute_index_in_combined_input == original_chain_end_index - 1

            # First residue of a chain: no preceding residue to bond to the N atom.
            if not is_chain_start:
                prev_residue = absolute_index_in_combined_input - 1
                self.bond_feats[prev_residue, N_index] = residue_atom_bond
                self.bond_feats[N_index, prev_residue] = residue_atom_bond
            # Last residue of a chain: no following residue to bond to the C atom.
            if not is_chain_end:
                next_residue = absolute_index_in_combined_input + 1
                self.bond_feats[next_residue, C_index] = residue_atom_bond
                self.bond_feats[C_index, next_residue] = residue_atom_bond
            keep[absolute_index_in_combined_input] = False

            # Link consecutive atomized residues atom-to-atom. Requiring a
            # non-chain-start residue keeps the last residue of one chain from
            # being bonded to the first residue of the next chain.
            if (
                not is_chain_start
                and prev_absolute_index is not None
                and prev_absolute_index + 1 == absolute_index_in_combined_input
            ):
                self.bond_feats[prev_C, N_index] = 1
                self.bond_feats[N_index, prev_C] = 1
            prev_absolute_index = absolute_index_in_combined_input
            prev_C = C_index

        # Remove the residue-level protein features of the atomized residues.
        self.keep_features(keep)

    def keep_features(self, keep):
        """Subset all per-position features to the positions where ``keep`` is True.

        Args:
            keep: boolean tensor over positions (presumably of length ``self.length()``).

        Raises:
            ValueError: if ``keep`` would remove any atom position.
        """
        if not torch.all(keep[self.is_atom()]):
            raise ValueError("cannot remove atoms")
        self.msa = self.msa[:, keep]
        self.ins = self.ins[:, keep]
        self.bond_feats = self.bond_feats[keep][:, keep]
        self.xyz_t = self.xyz_t[:, keep]
        self.t1d = self.t1d[:, keep]
        self.mask_t = self.mask_t[:, keep]
        if self.term_info is not None:
            self.term_info = self.term_info[keep]
        if self.idx is not None:
            self.idx = self.idx[keep]
        # The leading chiral columns are treated as position indices and are
        # shifted down by the number of removed positions. This assumes every
        # chiral position comes after every removed protein residue.
        # NOTE: this writes into the existing chirals tensor in place.
        self.chirals[..., :-1] = self.chirals[..., :-1] - torch.sum(~keep)

    def construct_features(self, model_runner):
        """Featurize this input into an ``RFInput`` for the network.

        Args:
            model_runner: object providing ``config.loader_params`` (dict-like,
                with ``.get``), ``deterministic`` and ``xyz_converter.get_torsions``.

        Returns:
            RFInput whose template tensors have had the temporary batch axis
            stripped (``[0]``); ``RFInput.add_batch_dim`` adds one back.

        Side effect:
            ``self.xyz_t`` is replaced by its ``nan_to_num`` version.
        """
        chem = ChemData()
        loader_params = model_runner.config.loader_params
        B, L = 1, self.length()
        seq, _msa_clust, msa_seed, msa_extra, _mask_pos = MSAFeaturize(
            self.msa.long(),
            self.ins.long(),
            loader_params,
            p_mask=loader_params.get("p_msa_mask", 0),
            term_info=self.term_info,
            deterministic=model_runner.deterministic,
        )
        dist_matrix = get_bond_distances(self.bond_feats)

        # The "proper" previous-structure initialisation would be:
        #     xyz_prev, mask_prev = generate_xyz_prev(self.xyz_t, self.mask_t, loader_params)
        #     xyz_prev = torch.nan_to_num(xyz_prev)
        # Instead every position starts from INIT_CRDS, for compatibility with training.
        xyz_prev = chem.INIT_CRDS.reshape(1, chem.NTOTAL, 3).repeat(L, 1, 1)
        self.xyz_t = torch.nan_to_num(self.xyz_t)

        mask_t_2d = get_prot_sm_mask(self.mask_t, seq[0])
        # Pairwise mask: (i, j) is valid only if both i and j are valid.
        # Presumably (T, L) -> (T, L, L); the batch axis is added with [None] below.
        mask_t_2d = mask_t_2d[:, None] * mask_t_2d[:, :, None]
        xyz_t_frame = xyz_t_to_frame_xyz(self.xyz_t[None], self.msa[0], self.atom_frames)
        t2d = xyz_to_t2d(xyz_t_frame, mask_t_2d[None])[0]

        # Torsion angles (and their validity mask) from the templates.
        seq_tmp = self.t1d[..., :-1].argmax(dim=-1)
        alpha, _, alpha_mask, _ = model_runner.xyz_converter.get_torsions(
            self.xyz_t.reshape(-1, L, chem.NTOTAL, 3),
            seq_tmp,
            mask_in=self.mask_t.reshape(-1, L, chem.NTOTAL),
        )
        alpha = alpha.reshape(B, -1, L, chem.NTOTALDOFS, 2)
        alpha_mask = alpha_mask.reshape(B, -1, L, chem.NTOTALDOFS, 1)
        alpha_t = torch.cat((alpha, alpha_mask), dim=-1).reshape(B, -1, L, 3 * chem.NTOTALDOFS)
        alpha_t = alpha_t[0]
        alpha_prev = torch.zeros((L, chem.NTOTALDOFS, 2))

        same_chain = same_chain_from_bond_feats(self.bond_feats)
        return RFInput(
            msa_latent=msa_seed,
            msa_full=msa_extra,
            seq=seq,
            seq_unmasked=self.query_sequence(),
            bond_feats=self.bond_feats,
            dist_matrix=dist_matrix,
            chirals=self.chirals,
            atom_frames=self.atom_frames.long(),
            xyz_prev=xyz_prev,
            alpha_prev=alpha_prev,
            t1d=self.t1d,
            t2d=t2d,
            xyz_t=self.xyz_t[..., 1, :],
            alpha_t=alpha_t.float(),
            mask_t=mask_t_2d.float(),
            same_chain=same_chain.long(),
            idx=self.idx,
        )
@dataclass
class RFInput:
    """Featurized network input, as built by ``RawInputData.construct_features``.

    The trailing optional fields (``msa_prev``, ``pair_prev``, ``state_prev``,
    ``mask_recycle``) default to None; presumably they are populated during
    recycling — TODO confirm against the caller.
    """

    msa_latent: torch.Tensor
    msa_full: torch.Tensor
    seq: torch.Tensor
    seq_unmasked: torch.Tensor
    idx: torch.Tensor
    bond_feats: torch.Tensor
    dist_matrix: torch.Tensor
    chirals: torch.Tensor
    atom_frames: torch.Tensor
    xyz_prev: torch.Tensor
    alpha_prev: torch.Tensor
    t1d: torch.Tensor
    t2d: torch.Tensor
    xyz_t: torch.Tensor
    alpha_t: torch.Tensor
    mask_t: torch.Tensor
    same_chain: torch.Tensor
    msa_prev: Optional[torch.Tensor] = None
    pair_prev: Optional[torch.Tensor] = None
    state_prev: Optional[torch.Tensor] = None
    mask_recycle: Optional[torch.Tensor] = None

    def _map_tensors(self, fn):
        """Replace every tensor-valued field with ``fn(field)``; leave other fields untouched. Returns self."""
        for dc_field in fields(self):
            value = getattr(self, dc_field.name)
            if torch.is_tensor(value):
                setattr(self, dc_field.name, fn(value))
        return self

    def to(self, gpu):
        """Move every tensor field to ``gpu`` (any argument ``Tensor.to`` accepts).

        The object is modified in place. ``self`` is returned so that the
        torch-style ``inp = inp.to(device)`` idiom works instead of yielding None.
        """
        return self._map_tensors(lambda tensor: tensor.to(gpu))

    def add_batch_dim(self):
        """ mimic pytorch dataloader at inference time

        Prepend a size-1 batch axis to every tensor field, in place. Returns
        ``self`` for chaining.
        """
        return self._map_tensors(lambda tensor: tensor[None])