RoseTTAFold-All-Atom/rf2aa/model/RoseTTAFoldModel.py
2024-03-04 22:38:17 -08:00

419 lines
17 KiB
Python

import torch
import torch.nn as nn
import assertpy
from assertpy import assert_that
from icecream import ic
from rf2aa.model.layers.Embeddings import MSA_emb, Extra_emb, Bond_emb, Templ_emb, recycling_factory
from rf2aa.model.Track_module import IterativeSimulator
from rf2aa.model.layers.AuxiliaryPredictor import (
DistanceNetwork,
MaskedTokenNetwork,
LDDTNetwork,
PAENetwork,
BinderNetwork,
)
from rf2aa.tensor_util import assert_shape, assert_equal
import rf2aa.util
from rf2aa.chemical import ChemicalData as ChemData
def get_shape(t):
if hasattr(t, "shape"):
return t.shape
if type(t) is tuple:
return [get_shape(e) for e in t]
else:
return type(t)
class RoseTTAFoldModule(nn.Module):
def __init__(
self,
symmetrize_repeats=None, # whether to symmetrize repeats in the pair track
repeat_length=None, # if symmetrizing repeats, what length are they?
symmsub_k=None, # if symmetrizing repeats, which diagonals?
sym_method=None, # if symmetrizing repeats, which block symmetrization method?
main_block=None, # if copying template blocks along main diag, which block is main block? (the one w/ motif)
copy_main_block_template=None, # whether or not to copy main block template along main diag
n_extra_block=4,
n_main_block=8,
n_ref_block=4,
n_finetune_block=0,
d_msa=256,
d_msa_full=64,
d_pair=128,
d_templ=64,
n_head_msa=8,
n_head_pair=4,
n_head_templ=4,
d_hidden=32,
d_hidden_templ=64,
d_t1d=0,
p_drop=0.15,
additional_dt1d=0,
recycling_type="msa_pair",
SE3_param={}, SE3_ref_param={},
atom_type_index=None,
aamask=None,
ljlk_parameters=None,
lj_correction_parameters=None,
cb_len=None,
cb_ang=None,
cb_tor=None,
num_bonds=None,
lj_lin=0.6,
use_chiral_l1=True,
use_lj_l1=False,
use_atom_frames=True,
use_same_chain=False,
enable_same_chain=False,
refiner_topk=64,
get_quaternion=False,
# New for diffusion
freeze_track_motif=False,
assert_single_sequence_input=False,
fit=False,
tscale=1.0
):
super(RoseTTAFoldModule, self).__init__()
self.freeze_track_motif = freeze_track_motif
self.assert_single_sequence_input = assert_single_sequence_input
self.recycling_type = recycling_type
#
# Input Embeddings
d_state = SE3_param["l0_out_features"]
self.latent_emb = MSA_emb(
d_msa=d_msa, d_pair=d_pair, d_state=d_state, p_drop=p_drop, use_same_chain=use_same_chain,
enable_same_chain=enable_same_chain
)
self.full_emb = Extra_emb(
d_msa=d_msa_full, d_init=ChemData().NAATOKENS - 1 + 4, p_drop=p_drop
)
self.bond_emb = Bond_emb(d_pair=d_pair, d_init=ChemData().NBTYPES)
self.templ_emb = Templ_emb(d_t1d=d_t1d,
d_pair=d_pair,
d_templ=d_templ,
d_state=d_state,
n_head=n_head_templ,
d_hidden=d_hidden_templ,
p_drop=0.25,
symmetrize_repeats=symmetrize_repeats, # repeat protein stuff
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method,
main_block=main_block,
copy_main_block=copy_main_block_template,
additional_dt1d=additional_dt1d)
# Update inputs with outputs from previous round
self.recycle = recycling_factory[recycling_type](d_msa=d_msa, d_pair=d_pair, d_state=d_state)
#
self.simulator = IterativeSimulator(
n_extra_block=n_extra_block,
n_main_block=n_main_block,
n_ref_block=n_ref_block,
n_finetune_block=n_finetune_block,
d_msa=d_msa,
d_msa_full=d_msa_full,
d_pair=d_pair,
d_hidden=d_hidden,
n_head_msa=n_head_msa,
n_head_pair=n_head_pair,
SE3_param=SE3_param,
SE3_ref_param=SE3_ref_param,
p_drop=p_drop,
atom_type_index=atom_type_index, # change if encoding elements instead of atomtype
aamask=aamask,
ljlk_parameters=ljlk_parameters,
lj_correction_parameters=lj_correction_parameters,
num_bonds=num_bonds,
cb_len=cb_len,
cb_ang=cb_ang,
cb_tor=cb_tor,
lj_lin=lj_lin,
use_lj_l1=use_lj_l1,
use_chiral_l1=use_chiral_l1,
symmetrize_repeats=symmetrize_repeats,
repeat_length=repeat_length,
symmsub_k=symmsub_k,
sym_method=sym_method,
main_block=main_block,
use_same_chain=use_same_chain,
enable_same_chain=enable_same_chain,
refiner_topk=refiner_topk
)
##
self.c6d_pred = DistanceNetwork(d_pair, p_drop=p_drop)
self.aa_pred = MaskedTokenNetwork(d_msa, p_drop=p_drop)
self.lddt_pred = LDDTNetwork(d_state)
self.pae_pred = PAENetwork(d_pair)
self.pde_pred = PAENetwork(
d_pair
) # distance error, but use same architecture as aligned error
# binder predictions are made on top of the pair features, just like
# PAE predictions are. It's not clear if this is the best place to insert
# this prediction head.
# self.binder_network = BinderNetwork(d_pair, d_state)
self.bind_pred = BinderNetwork() #fd - expose n_hidden as variable?
self.use_atom_frames = use_atom_frames
self.enable_same_chain = enable_same_chain
self.get_quaternion = get_quaternion
self.verbose_checks = False
def forward(
self,
msa_latent,
msa_full,
seq,
seq_unmasked,
xyz,
sctors,
idx,
bond_feats,
dist_matrix,
chirals,
atom_frames=None, t1d=None, t2d=None, xyz_t=None, alpha_t=None, mask_t=None, same_chain=None,
msa_prev=None, pair_prev=None, state_prev=None, mask_recycle=None, is_motif=None,
return_raw=False,
use_checkpoint=False,
return_infer=False, #fd ?
p2p_crop=-1, topk_crop=-1, # striping
symmids=None, symmsub=None, symmRs=None, symmmeta=None, # symmetry
):
# ic(get_shape(msa_latent))
# ic(get_shape(msa_full))
# ic(get_shape(seq))
# ic(get_shape(seq_unmasked))
# ic(get_shape(xyz))
# ic(get_shape(sctors))
# ic(get_shape(idx))
# ic(get_shape(bond_feats))
# ic(get_shape(chirals))
# ic(get_shape(atom_frames))
# ic(get_shape(t1d))
# ic(get_shape(t2d))
# ic(get_shape(xyz_t))
# ic(get_shape(alpha_t))
# ic(get_shape(mask_t))
# ic(get_shape(same_chain))
# ic(get_shape(msa_prev))
# ic(get_shape(pair_prev))
# ic(get_shape(mask_recycle))
# ic()
# ic()
B, N, L = msa_latent.shape[:3]
A = atom_frames.shape[1]
dtype = msa_latent.dtype
if self.assert_single_sequence_input:
assert_shape(msa_latent, (1, 1, L, 164))
assert_shape(msa_full, (1, 1, L, 83))
assert_shape(seq, (1, L))
assert_shape(seq_unmasked, (1, L))
assert_shape(xyz, (1, L, ChemData().NTOTAL, 3))
assert_shape(sctors, (1, L, 20, 2))
assert_shape(idx, (1, L))
assert_shape(bond_feats, (1, L, L))
assert_shape(dist_matrix, (1, L, L))
# assert_shape(chirals, (1, 0))
# assert_shape(atom_frames, (1, 4, L)) # This is set to 4 for the recycle count, but that can't be right
assert_shape(atom_frames, (1, A, 3, 2)) # What is 4?
assert_shape(t1d, (1, 1, L, 80))
assert_shape(t2d, (1, 1, L, L, 68))
assert_shape(xyz_t, (1, 1, L, 3))
assert_shape(alpha_t, (1, 1, L, 60))
assert_shape(mask_t, (1, 1, L, L))
assert_shape(same_chain, (1, L, L))
device = msa_latent.device
assert_that(msa_full.device).is_equal_to(device)
assert_that(seq.device).is_equal_to(device)
assert_that(seq_unmasked.device).is_equal_to(device)
assert_that(xyz.device).is_equal_to(device)
assert_that(sctors.device).is_equal_to(device)
assert_that(idx.device).is_equal_to(device)
assert_that(bond_feats.device).is_equal_to(device)
assert_that(dist_matrix.device).is_equal_to(device)
assert_that(atom_frames.device).is_equal_to(device)
assert_that(t1d.device).is_equal_to(device)
assert_that(t2d.device).is_equal_to(device)
assert_that(xyz_t.device).is_equal_to(device)
assert_that(alpha_t.device).is_equal_to(device)
assert_that(mask_t.device).is_equal_to(device)
assert_that(same_chain.device).is_equal_to(device)
if self.verbose_checks:
#ic(is_motif.shape)
is_sm = rf2aa.util.is_atom(seq[0]) # (L)
#is_protein_motif = is_motif & ~is_sm
#if is_motif.any():
# motif_protein_i = torch.where(is_motif)[0][0]
#is_motif_sm = is_motif & is_sm
#if is_sm.any():
# motif_sm_i = torch.where(is_motif_sm)[0][0]
#diffused_protein_i = torch.where(~is_sm & ~is_motif)[0][0]
"""
msa_full: NSEQ,N_INDEL,N_TERMINUS,
msa_masked: NSEQ,NSEQ,N_INDEL,N_INDEL,N_TERMINUS
"""
import numpy as np
NINDEL = 1
NTERMINUS = 2
NMSAFULL = ChemData().NAATOKENS + NINDEL + NTERMINUS
NMSAMASKED = ChemData().NAATOKENS + ChemData().NAATOKENS + NINDEL + NINDEL + NTERMINUS
assert_that(msa_latent.shape[-1]).is_equal_to(NMSAMASKED)
assert_that(msa_full.shape[-1]).is_equal_to(NMSAFULL)
msa_full_seq = np.r_[0:ChemData().NAATOKENS]
msa_full_indel = np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
msa_full_term = np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]
msa_latent_seq1 = np.r_[0:ChemData().NAATOKENS]
msa_latent_seq2 = np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]
msa_latent_indel1 = np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
msa_latent_indel2 = np.r_[
2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL
]
msa_latent_terminus = np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
#i_name = [(diffused_protein_i, "diffused_protein")]
#if is_sm.any():
# i_name.insert(0, (motif_sm_i, "motif_sm"))
#if is_motif.any():
# i_name.insert(0, (motif_protein_i, "motif_protein"))
i_name = [(0, "tst")]
for i, name in i_name:
ic(f"------------------{name}:{i}----------------")
msa_full_seq = msa_full[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
msa_full_indel = msa_full[
0, 0, i, np.r_[ChemData().NAATOKENS : ChemData().NAATOKENS + NINDEL]
]
msa_full_term = msa_full[0, 0, i, np.r_[ChemData().NAATOKENS + NINDEL : NMSAFULL]]
msa_latent_seq1 = msa_latent[0, 0, i, np.r_[0:ChemData().NAATOKENS]]
msa_latent_seq2 = msa_latent[0, 0, i, np.r_[ChemData().NAATOKENS : 2 * ChemData().NAATOKENS]]
msa_latent_indel1 = msa_latent[
0, 0, i, np.r_[2 * ChemData().NAATOKENS : 2 * ChemData().NAATOKENS + NINDEL]
]
msa_latent_indel2 = msa_latent[
0,
0,
i,
np.r_[2 * ChemData().NAATOKENS + NINDEL : 2 * ChemData().NAATOKENS + NINDEL + NINDEL],
]
msa_latent_term = msa_latent[
0, 0, i, np.r_[2 * ChemData().NAATOKENS + 2 * NINDEL : NMSAMASKED]
]
assert_equal(msa_full_seq, msa_latent_seq1)
assert_equal(msa_full_seq, msa_latent_seq2)
assert_equal(msa_full_indel, msa_latent_indel1)
assert_equal(msa_full_indel, msa_latent_indel2)
assert_equal(msa_full_term, msa_latent_term)
# if 'motif' in name:
msa_cat = torch.where(msa_full_seq)[0]
ic(msa_cat, seq[0, i])
assert_equal(seq[0, i : i + 1], msa_cat)
assert_equal(seq[0, i], seq_unmasked[0, i])
ic(
name,
# torch.where(msa_latent[0,0,i,:80]),
# torch.where(msa_full[0,0,i]),
seq[0, i],
seq_unmasked[0, i],
torch.where(t1d[0, 0, i]),
xyz[0, i, :4, 0],
xyz_t[0, 0, i, 0],
)
# Get embeddings
#if self.enable_same_chain == False:
# same_chain = None
msa_latent, pair, state = self.latent_emb(
msa_latent, seq, idx, bond_feats, dist_matrix, same_chain=same_chain
)
msa_full = self.full_emb(msa_full, seq, idx)
pair = pair + self.bond_emb(bond_feats)
msa_latent, pair, state = msa_latent.to(dtype), pair.to(dtype), state.to(dtype)
msa_full = msa_full.to(dtype)
#
# Do recycling
if msa_prev is None:
msa_prev = torch.zeros_like(msa_latent[:,0])
if pair_prev is None:
pair_prev = torch.zeros_like(pair)
if state_prev is None or self.recycling_type == "msa_pair": #explicitly remove state features if only recycling msa and pair
state_prev = torch.zeros_like(state)
msa_recycle, pair_recycle, state_recycle = self.recycle(msa_prev, pair_prev, xyz, state_prev, sctors, mask_recycle)
msa_recycle, pair_recycle = msa_recycle.to(dtype), pair_recycle.to(dtype)
msa_latent[:,0] = msa_latent[:,0] + msa_recycle.reshape(B,L,-1)
pair = pair + pair_recycle
state = state + state_recycle # if state is not recycled these will be zeros
# add template embedding
pair, state = self.templ_emb(t1d, t2d, alpha_t, xyz_t, mask_t, pair, state, use_checkpoint=use_checkpoint, p2p_crop=p2p_crop)
# Predict coordinates from given inputs
is_motif = is_motif if self.freeze_track_motif else torch.zeros_like(seq).bool()[0]
msa, pair, xyz, alpha_s, xyz_allatom, state, symmsub, quat = self.simulator(
seq_unmasked, msa_latent, msa_full, pair, xyz[:,:,:3], state, idx,
symmids, symmsub, symmRs, symmmeta,
bond_feats, dist_matrix, same_chain, chirals, is_motif, atom_frames,
use_checkpoint=use_checkpoint, use_atom_frames=self.use_atom_frames,
p2p_crop=p2p_crop, topk_crop=topk_crop
)
if return_raw:
# get last structure
xyz_last = xyz_allatom[-1].unsqueeze(0)
return msa[:,0], pair, xyz_last, alpha_s[-1], None
# predict masked amino acids
logits_aa = self.aa_pred(msa)
# predict distogram & orientograms
logits = self.c6d_pred(pair)
# Predict LDDT
lddt = self.lddt_pred(state)
if self.verbose_checks:
pseq_0 = logits_aa.permute(0, 2, 1)
ic(pseq_0.shape)
pseq_0 = pseq_0[0]
ic(
f"motif sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[is_motif], dim=-1).tolist())}"
)
ic(
f"diffused sequence: { rf2aa.chemical.seq2chars(torch.argmax(pseq_0[~is_motif], dim=-1).tolist())}"
)
logits_pae = logits_pde = p_bind = None
# predict aligned error and distance error
logits_pae = self.pae_pred(pair)
logits_pde = self.pde_pred(pair + pair.permute(0,2,1,3)) # symmetrize pair features
#fd predict bind/no-bind
p_bind = self.bind_pred(logits_pae,same_chain)
if self.get_quaternion:
return (
logits, logits_aa, logits_pae, logits_pde, p_bind,
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state, quat
)
else:
return (
logits, logits_aa, logits_pae, logits_pde, p_bind,
xyz, alpha_s, xyz_allatom, lddt, msa[:,0], pair, state
)