# Mirror of https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
# Synced 2024-11-04 22:25:42 +00:00
# 418 lines, 17 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
import assertpy
|
|
from assertpy import assert_that
|
|
from icecream import ic
|
|
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
|
|
from rf2aa.model.Track_module import IterativeSimulator
|
|
from rf2aa.model.layers.AuxiliaryPredictor import (
|
|
DistanceNetwork,
|
|
MaskedTokenNetwork,
|
|
LDDTNetwork,
|
|
PAENetwork,
|
|
BinderNetwork,
|
|
)
|
|
from rf2aa.tensor_util import assert_shape, assert_equal
|
|
import rf2aa.util
|
|
from rf2aa.chemical import ChemicalData as ChemData
|
|
|
|
|
|
def get_shape(t):
    """Return a shape summary of ``t`` for debugging.

    Args:
        t: Any object. Typically a tensor/array, or a (possibly nested)
            tuple or list of them.

    Returns:
        * ``t.shape`` when ``t`` exposes a ``shape`` attribute;
        * a list with the summary of each element when ``t`` is a tuple or
          list (including subclasses such as namedtuples), applied
          recursively;
        * ``type(t)`` for anything else.
    """
    if hasattr(t, "shape"):
        return t.shape
    # isinstance (rather than an exact type check) so that namedtuples and
    # lists of tensors are expanded instead of being reported as a bare type.
    if isinstance(t, (tuple, list)):
        return [get_shape(e) for e in t]
    return type(t)
|
|
|
|
|
|
class RoseTTAFoldModule(nn.Module):
    """Top-level network: input embeddings, recycling, the iterative
    simulator and the auxiliary prediction heads.

    The constructor wires together the embedding modules (``MSA_emb``,
    ``Extra_emb``, ``Bond_emb``, ``Templ_emb``), the recycling module chosen
    from ``recycling_factory``, the ``IterativeSimulator`` and the auxiliary
    heads (distance, masked-token, lDDT, PAE, PDE and binder networks).
    """

    def __init__(
        self,
        symmetrize_repeats=None,  # whether to symmetrize repeats in the pair track
        repeat_length=None,  # if symmetrizing repeats, what length are they?
        symmsub_k=None,  # if symmetrizing repeats, which diagonals?
        sym_method=None,  # if symmetrizing repeats, which block symmetrization method?
        main_block=None,  # if copying template blocks along main diag, which block is main block? (the one w/ motif)
        copy_main_block_template=None,  # whether or not to copy main block template along main diag
        n_extra_block=4,
        n_main_block=8,
        n_ref_block=4,
        n_finetune_block=0,
        d_msa=256,
        d_msa_full=64,
        d_pair=128,
        d_templ=64,
        n_head_msa=8,
        n_head_pair=4,
        n_head_templ=4,
        d_hidden=32,
        d_hidden_templ=64,
        d_t1d=0,
        p_drop=0.15,
        additional_dt1d=0,
        recycling_type="msa_pair",
        SE3_param=None, SE3_ref_param=None,
        atom_type_index=None,
        aamask=None,
        ljlk_parameters=None,
        lj_correction_parameters=None,
        cb_len=None,
        cb_ang=None,
        cb_tor=None,
        num_bonds=None,
        lj_lin=0.6,
        use_chiral_l1=True,
        use_lj_l1=False,
        use_atom_frames=True,
        use_same_chain=False,
        enable_same_chain=False,
        refiner_topk=64,
        get_quaternion=False,
        # New for diffusion
        freeze_track_motif=False,
        assert_single_sequence_input=False,
        fit=False,
        tscale=1.0
    ):
        """Build all sub-modules.

        Args:
            symmetrize_repeats, repeat_length, symmsub_k, sym_method,
            main_block, copy_main_block_template: repeat-protein options
                forwarded to ``Templ_emb`` and (except
                ``copy_main_block_template``) to ``IterativeSimulator``.
            n_extra_block, n_main_block, n_ref_block, n_finetune_block:
                block counts forwarded to ``IterativeSimulator``.
            d_msa, d_msa_full, d_pair, d_templ, d_hidden, d_hidden_templ,
            d_t1d, additional_dt1d: feature dimensions forwarded to the
                embedding modules, simulator and prediction heads.
            n_head_msa, n_head_pair, n_head_templ: attention head counts.
            p_drop: dropout probability for every sub-module that takes one,
                except the template embedding, which uses a fixed 0.25.
            recycling_type: key into ``recycling_factory``. When it is
                ``"msa_pair"`` the recycled state is zeroed in ``forward``.
            SE3_param: dict of SE(3) settings; must contain
                ``"l0_out_features"``, which is used as the state dimension.
                ``None`` is treated as an empty dict (and therefore raises
                ``KeyError``).
            SE3_ref_param: dict forwarded to ``IterativeSimulator``;
                ``None`` is treated as an empty dict.
            atom_type_index, aamask, ljlk_parameters,
            lj_correction_parameters, cb_len, cb_ang, cb_tor, num_bonds,
            lj_lin, use_chiral_l1, use_lj_l1, refiner_topk: forwarded
                unchanged to ``IterativeSimulator``.
            use_atom_frames: stored and passed to the simulator on each
                forward call.
            use_same_chain, enable_same_chain: forwarded to ``MSA_emb`` and
                ``IterativeSimulator``.
            get_quaternion: when true, ``forward`` appends the simulator's
                quaternion output to its return tuple.
            freeze_track_motif: when false, ``forward`` replaces ``is_motif``
                with an all-False mask.
            assert_single_sequence_input: when true, ``forward`` validates
                input shapes and devices before running.
            fit, tscale: accepted for interface compatibility; not used by
                this constructor.

        Raises:
            KeyError: if ``SE3_param`` lacks ``"l0_out_features"`` or
                ``recycling_type`` is not a key of ``recycling_factory``.
        """
        super(RoseTTAFoldModule, self).__init__()

        # Use fresh dicts instead of shared mutable default arguments.
        SE3_param = {} if SE3_param is None else SE3_param
        SE3_ref_param = {} if SE3_ref_param is None else SE3_ref_param

        self.freeze_track_motif = freeze_track_motif
        self.assert_single_sequence_input = assert_single_sequence_input
        self.recycling_type = recycling_type

        #
        # Input Embeddings
        d_state = SE3_param["l0_out_features"]
        self.latent_emb = MSA_emb(
            d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
            enable_same_chain=enable_same_chain
        )
        self.full_emb = Extra_emb(
            d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
        )
        self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)

        self.templ_emb = Templ_emb(d_t1d=d_t1d,
                                   d_pair=d_pair,
                                   d_templ=d_templ,
                                   d_state=d_state,
                                   n_head=n_head_templ,
                                   d_hidden=d_hidden_templ,
                                   p_drop=0.25,
                                   symmetrize_repeats=symmetrize_repeats,  # repeat protein stuff
                                   repeat_length=repeat_length,
                                   symmsub_k=symmsub_k,
                                   sym_method=sym_method,
                                   main_block=main_block,
                                   copy_main_block=copy_main_block_template,
                                   additional_dt1d=additional_dt1d)

        # Update inputs with outputs from previous round
        self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)

        #
        self.simulator = IterativeSimulator(
            n_extra_block=n_extra_block,
            n_main_block=n_main_block,
            n_ref_block=n_ref_block,
            n_finetune_block=n_finetune_block,
            d_msa=d_msa,
            d_msa_full=d_msa_full,
            d_pair=d_pair,
            d_hidden=d_hidden,
            n_head_msa=n_head_msa,
            n_head_pair=n_head_pair,
            SE3_param=SE3_param,
            SE3_ref_param=SE3_ref_param,
            p_drop=p_drop,
            atom_type_index=atom_type_index,  # change if encoding elements instead of atomtype
            aamask=aamask,
            ljlk_parameters=ljlk_parameters,
            lj_correction_parameters=lj_correction_parameters,
            num_bonds=num_bonds,
            cb_len=cb_len,
            cb_ang=cb_ang,
            cb_tor=cb_tor,
            lj_lin=lj_lin,
            use_lj_l1=use_lj_l1,
            use_chiral_l1=use_chiral_l1,
            symmetrize_repeats=symmetrize_repeats,
            repeat_length=repeat_length,
            symmsub_k=symmsub_k,
            sym_method=sym_method,
            main_block=main_block,
            use_same_chain=use_same_chain,
            enable_same_chain=enable_same_chain,
            refiner_topk=refiner_topk
        )

        ##
        self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
        self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
        self.lddt_pred = LDDTNetwork(d_state)
        self.pae_pred = PAENetwork(d_pair)
        self.pde_pred = PAENetwork(
            d_pair
        )  # distance error, but use same architecture as aligned error
        # binder predictions are made on top of the pair features, just like
        # PAE predictions are. It's not clear if this is the best place to insert
        # this prediction head.
        self.bind_pred = BinderNetwork()  #fd - expose n_hidden as variable?

        self.use_atom_frames = use_atom_frames
        self.enable_same_chain = enable_same_chain
        self.get_quaternion = get_quaternion
        # Debug switch: enables the extra consistency checks and icecream
        # dumps in forward(). Only ever changed by setting the attribute.
        self.verbose_checks = False

    def forward(
        self,
        msa_latent,
        msa_full,
        seq,
        seq_unmasked,
        xyz,
        sctors,
        idx,
        bond_feats,
        dist_matrix,
        chirals,
        atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
        msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
        return_raw=False,
        use_checkpoint=False,
        return_infer=False,  #fd ?
        p2p_crop=-1, topk_crop=-1,  # striping
        symmids=None, symmsub=None, symmRs=None, symmmeta=None,  # symmetry
    ):
        """Run one pass (one recycle) of the network.

        The first three dimensions of ``msa_latent`` are read as
        ``(B, N, L)``. When ``assert_single_sequence_input`` is enabled the
        inputs are checked against the shapes listed in
        ``_assert_single_sequence_inputs`` (e.g. ``seq`` is ``(1, L)`` and
        ``bond_feats`` is ``(1, L, L)``); other configurations are not
        validated here.

        Args:
            msa_latent, msa_full: raw MSA features fed to ``MSA_emb`` /
                ``Extra_emb``.
            seq, seq_unmasked: sequence tokens; ``seq`` feeds the embeddings
                and ``seq_unmasked`` feeds the simulator.
            xyz: input coordinates; the recycling module receives all of
                them, the simulator only ``xyz[:, :, :3]``.
            sctors: side-chain torsions passed to the recycling module.
            idx, bond_feats, dist_matrix, chirals, atom_frames, same_chain:
                forwarded to the embeddings and/or the simulator.
            t1d, t2d, xyz_t, alpha_t, mask_t: template features passed to
                ``Templ_emb``.
            msa_prev, pair_prev, state_prev: outputs of the previous recycle;
                each defaults to zeros. ``state_prev`` is always zeroed when
                ``recycling_type == "msa_pair"``.
            mask_recycle: forwarded to the recycling module.
            is_motif: motif mask; ignored (replaced by all-False) unless
                ``freeze_track_motif`` is set.
                NOTE(review): with ``freeze_track_motif`` set and
                ``is_motif=None`` the simulator receives ``None`` — confirm
                callers always supply it in that mode.
            return_raw: if true, return early with only the last structure.
            use_checkpoint: forwarded to ``Templ_emb`` and the simulator.
            return_infer: accepted but not used by this method.
            p2p_crop, topk_crop: striping options forwarded downstream.
            symmids, symmsub, symmRs, symmmeta: symmetry inputs forwarded to
                the simulator.

        Returns:
            If ``return_raw``: ``(msa[:, 0], pair, xyz_last, alpha_s[-1], None)``
            where ``xyz_last`` is ``xyz_allatom[-1].unsqueeze(0)``.

            Otherwise the tuple ``(logits, logits_aa, logits_pae, logits_pde,
            p_bind, xyz, alpha_s, xyz_allatom, lddt, msa[:, 0], pair, state)``
            with ``quat`` appended when ``get_quaternion`` is set.
        """
        B, _, L = msa_latent.shape[:3]
        dtype = msa_latent.dtype

        if self.assert_single_sequence_input:
            self._assert_single_sequence_inputs(
                L, msa_latent, msa_full, seq, seq_unmasked, xyz, sctors, idx,
                bond_feats, dist_matrix, atom_frames, t1d, t2d, xyz_t,
                alpha_t, mask_t, same_chain,
            )

        if self.verbose_checks:
            # Must run before msa_latent / msa_full are replaced by embeddings.
            self._verbose_input_checks(msa_latent, msa_full, seq, seq_unmasked, xyz, t1d, xyz_t)

        # Get embeddings
        #if self.enable_same_chain == False:
        #    same_chain = None
        msa_latent, pair, state = self.latent_emb(
            msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
        )
        msa_full = self.full_emb(msa_full, seq, idx)
        pair = pair + self.bond_emb(bond_feats)

        msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
        msa_full = msa_full.to(dtype)

        #
        # Do recycling
        if msa_prev is None:
            msa_prev = torch.zeros_like(msa_latent[:, 0])
        if pair_prev is None:
            pair_prev = torch.zeros_like(pair)
        # explicitly remove state features if only recycling msa and pair
        if state_prev is None or self.recycling_type == "msa_pair":
            state_prev = torch.zeros_like(state)

        msa_recycle, pair_recycle, state_recycle = self.recycle(
            msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle
        )
        msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)

        msa_latent[:, 0] = msa_latent[:, 0] + msa_recycle.reshape(B, L, -1)
        pair = pair + pair_recycle
        state = state + state_recycle  # if state is not recycled these will be zeros

        # add template embedding
        pair, state = self.templ_emb(
            t1d, t2d, alpha_t, xyz_t, mask_t, pair, state,
            use_checkpoint=use_checkpoint, p2p_crop=p2p_crop
        )

        # Predict coordinates from given inputs
        is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
        msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
            seq_unmasked, msa_latent, msa_full, pair, xyz[:, :, :3], state, idx,
            symmids, symmsub, symmRs, symmmeta,
            bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
            use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
            p2p_crop=p2p_crop, topk_crop=topk_crop
        )

        if return_raw:
            # get last structure
            xyz_last = xyz_allatom[-1].unsqueeze(0)
            return msa[:, 0], pair, xyz_last, alpha_s[-1], None

        # predict masked amino acids
        logits_aa = self.aa_pred(msa)

        # predict distogram & orientograms
        logits = self.c6d_pred(pair)

        # Predict LDDT
        lddt = self.lddt_pred(state)

        if self.verbose_checks:
            self._verbose_sequence_dump(logits_aa, is_motif)

        # predict aligned error and distance error
        logits_pae = self.pae_pred(pair)
        logits_pde = self.pde_pred(pair + pair.permute(0, 2, 1, 3))  # symmetrize pair features

        #fd  predict bind/no-bind
        p_bind = self.bind_pred(logits_pae, same_chain)

        outputs = (
            logits, logits_aa, logits_pae, logits_pde, p_bind,
            xyz, alpha_s, xyz_allatom, lddt, msa[:, 0], pair, state
        )
        if self.get_quaternion:
            return outputs + (quat,)
        return outputs

    def _assert_single_sequence_inputs(
        self, L, msa_latent, msa_full, seq, seq_unmasked, xyz, sctors, idx,
        bond_feats, dist_matrix, atom_frames, t1d, t2d, xyz_t, alpha_t,
        mask_t, same_chain,
    ):
        """Check that single-sequence inputs have the expected shapes and all share msa_latent's device."""
        assert_shape(msa_latent, (1, 1, L, 164))
        assert_shape(msa_full, (1, 1, L, 83))
        assert_shape(seq, (1, L))
        assert_shape(seq_unmasked, (1, L))
        assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
        assert_shape(sctors, (1, L, 20, 2))
        assert_shape(idx, (1, L))
        assert_shape(bond_feats, (1, L, L))
        assert_shape(dist_matrix, (1, L, L))
        # Only dims 0, 2 and 3 of atom_frames are constrained; dim 1 is
        # read back from the tensor itself.
        A = atom_frames.shape[1]
        assert_shape(atom_frames, (1, A, 3, 2))
        assert_shape(t1d, (1, 1, L, 80))
        assert_shape(t2d, (1, 1, L, L, 68))
        assert_shape(xyz_t, (1, 1, L, 3))
        assert_shape(alpha_t, (1, 1, L, 60))
        assert_shape(mask_t, (1, 1, L, L))
        assert_shape(same_chain, (1, L, L))

        device = msa_latent.device
        for tensor in (
            msa_full, seq, seq_unmasked, xyz, sctors, idx, bond_feats,
            dist_matrix, atom_frames, t1d, t2d, xyz_t, alpha_t, mask_t,
            same_chain,
        ):
            assert_that(tensor.device).is_equal_to(device)

    def _verbose_input_checks(self, msa_latent, msa_full, seq, seq_unmasked, xyz, t1d, xyz_t):
        """Debug-only check that the raw msa_full / msa_latent feature layouts agree at position 0.

        Assumed channel layout (from the slicing below):
            msa_full:   [seq | indel | terminus]
            msa_latent: [seq | seq | indel | indel | terminus]
        """
        NAATOKENS = ChemData().NAATOKENS
        NINDEL = 1
        NTERMINUS = 2
        NMSAFULL = NAATOKENS + NINDEL + NTERMINUS
        NMSAMASKED = NAATOKENS + NAATOKENS + NINDEL + NINDEL + NTERMINUS
        assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
        assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)

        i_name = [(0, "tst")]
        for i, name in i_name:
            ic(f"------------------{name}:{i}----------------")
            msa_full_seq = msa_full[0, 0, i, 0:NAATOKENS]
            msa_full_indel = msa_full[0, 0, i, NAATOKENS : NAATOKENS + NINDEL]
            msa_full_term = msa_full[0, 0, i, NAATOKENS + NINDEL : NMSAFULL]

            msa_latent_seq1 = msa_latent[0, 0, i, 0:NAATOKENS]
            msa_latent_seq2 = msa_latent[0, 0, i, NAATOKENS : 2 * NAATOKENS]
            msa_latent_indel1 = msa_latent[0, 0, i, 2 * NAATOKENS : 2 * NAATOKENS + NINDEL]
            msa_latent_indel2 = msa_latent[
                0, 0, i, 2 * NAATOKENS + NINDEL : 2 * NAATOKENS + NINDEL + NINDEL
            ]
            msa_latent_term = msa_latent[0, 0, i, 2 * NAATOKENS + 2 * NINDEL : NMSAMASKED]

            assert_equal(msa_full_seq, msa_latent_seq1)
            assert_equal(msa_full_seq, msa_latent_seq2)
            assert_equal(msa_full_indel, msa_latent_indel1)
            assert_equal(msa_full_indel, msa_latent_indel2)
            assert_equal(msa_full_term, msa_latent_term)

            # Index of the non-zero sequence channel must match the token.
            msa_cat = torch.where(msa_full_seq)[0]
            ic(msa_cat, seq[0, i])
            assert_equal(seq[0, i : i + 1], msa_cat)
            assert_equal(seq[0, i], seq_unmasked[0, i])
            ic(
                name,
                seq[0, i],
                seq_unmasked[0, i],
                torch.where(t1d[0, 0, i]),
                xyz[0, i, :4, 0],
                xyz_t[0, 0, i, 0],
            )

    def _verbose_sequence_dump(self, logits_aa, is_motif):
        """Debug-only dump of the argmax-decoded sequence for motif and non-motif positions."""
        pseq_0 = logits_aa.permute(0, 2, 1)
        ic(pseq_0.shape)
        pseq_0 = pseq_0[0]
        ic(
            f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
        )
        ic(
            f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
        )
|