# mirror of https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
# synced 2024-11-04 22:25:42 +00:00
# 1221 lines | 48 KiB | Python
import math  # used by find_symmsub below
import torch
import torch.nn as nn
import torch.nn.functional as F
from opt_einsum import contract as einsum
import torch.utils.checkpoint as checkpoint
from icecream import ic

from contextlib import ExitStack, nullcontext

from rf2aa.util_module import *
from rf2aa.model.layers.Attention_module import *
from rf2aa.model.layers.SE3_network import SE3TransformerWrapper
from rf2aa.util import is_atom, xyz_frame_from_rotation_mask
from rf2aa.loss.loss import (
    calc_lj_grads, calc_chiral_grads
)

from rf2aa.chemical import ChemicalData as ChemData


# Components for three-track blocks
# 1. MSA -> MSA update (biased attention; bias from pair & structure)
# 2. Pair -> Pair update (biased attention; bias from structure)
# 3. MSA -> Pair update (extract coevolution signal)
# 4. Str -> Str update (node from MSA, edge from Pair)

class PositionalEncoding2D(nn.Module):
    # Add relative positional encoding to pair features
    def __init__(self, d_pair, minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15, use_same_chain=False, enable_same_chain=False):
        super(PositionalEncoding2D, self).__init__()
        self.minpos = minpos
        self.maxpos = maxpos
        self.maxpos_atom = maxpos_atom
        self.nbin_res = abs(minpos)+maxpos+2  # include 0 and "unknown" value (maxpos+1)
        self.nbin_atom = maxpos_atom+2        # include 0 and "unknown" token (maxpos_sm+1)
        self.emb_res = nn.Embedding(self.nbin_res, d_pair)
        self.emb_atom = nn.Embedding(self.nbin_atom, d_pair)

        self.use_same_chain = use_same_chain
        if use_same_chain:
            self.emb_chain = nn.Embedding(2, d_pair)
        self.enable_same_chain = enable_same_chain

    def forward(self, seq, idx, bond_feats, dist_matrix, same_chain=None):
        sm_mask = is_atom(seq[0])

        res_dist, atom_dist = get_res_atom_dist(idx, bond_feats, dist_matrix, sm_mask,
                                                minpos_res=self.minpos, maxpos_res=self.maxpos, maxpos_atom=self.maxpos_atom)

        bins = torch.arange(self.minpos, self.maxpos+1, device=seq.device)
        ib_res = torch.bucketize(res_dist, bins).long()  # (B, L, L)
        emb_res = self.emb_res(ib_res)                   # (B, L, L, d_pair)

        bins = torch.arange(0, self.maxpos_atom+1, device=seq.device)
        ib_atom = torch.bucketize(atom_dist, bins).long()  # (B, L, L)
        emb_atom = self.emb_atom(ib_atom)                  # (B, L, L, d_pair)

        out = emb_res + emb_atom

        if self.use_same_chain and self.enable_same_chain == False and same_chain is not None:
            emb_c = self.emb_chain(same_chain.long())
            out += emb_c*0  # cursed but exists for backwards compatibility
        elif self.enable_same_chain == True:
            emb_c = self.emb_chain(same_chain.long())
            out += emb_c

        return out

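# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original source): how the relative
# position binning above behaves. Signed residue offsets are mapped into
# [minpos, maxpos] bins by torch.bucketize, so separations beyond +/-32 all
# fall into the outermost bins before the embedding lookup.
# `_demo_relpos_binning` is a hypothetical helper added for documentation only.
# --------------------------------------------------------------------------
def _demo_relpos_binning(minpos=-32, maxpos=32):
    idx = torch.arange(5)                                # residue indices 0..4
    res_dist = idx[None, :, None] - idx[None, None, :]   # (1, 5, 5) signed offsets
    bins = torch.arange(minpos, maxpos+1)
    ib = torch.bucketize(res_dist, bins).long()          # bin index per pair, in [0, nbin_res)
    # an nn.Embedding(nbin_res, d_pair) lookup on `ib` then yields the
    # (1, 5, 5, d_pair) positional bias added to the pair features
    return ib
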
# Update MSA with biased self-attention. bias from Pair & Str
class MSAPairStr2MSA(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, n_head=8, d_state=16, d_rbf=64,
                 d_hidden=32, p_drop=0.15, use_global_attn=False):
        super(MSAPairStr2MSA, self).__init__()
        self.norm_pair = nn.LayerNorm(d_pair)
        self.emb_rbf = nn.Linear(d_rbf, d_pair)
        self.norm_state = nn.LayerNorm(d_state)
        self.proj_state = nn.Linear(d_state, d_msa)
        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.row_attn = MSARowAttentionWithBias(d_msa=d_msa, d_pair=d_pair,
                                                n_head=n_head, d_hidden=d_hidden)
        if use_global_attn:
            self.col_attn = MSAColGlobalAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        else:
            self.col_attn = MSAColAttention(d_msa=d_msa, n_head=n_head, d_hidden=d_hidden)
        self.ff = FeedForwardLayer(d_msa, 4, p_drop=p_drop)

        # Do proper initialization
        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights to normal distribution
        self.emb_rbf = init_lecun_normal(self.emb_rbf)
        self.proj_state = init_lecun_normal(self.proj_state)

        # initialize bias to zeros
        nn.init.zeros_(self.emb_rbf.bias)
        nn.init.zeros_(self.proj_state.bias)

    def forward(self, msa, pair, rbf_feat, state):
        '''
        Inputs:
            - msa: MSA feature (B, N, L, d_msa)
            - pair: Pair feature (B, L, L, d_pair)
            - rbf_feat: Ca-Ca distance feature calculated from xyz coordinates (B, L, L, d_rbf)
            - state: updated node features after SE(3)-Transformer layer (B, L, d_state)
        Output:
            - msa: Updated MSA feature (B, N, L, d_msa)
        '''
        B, N, L = msa.shape[:3]

        # prepare input bias feature by combining pair & coordinate info
        pair = self.norm_pair(pair)
        pair = pair + self.emb_rbf(rbf_feat)

        # update query sequence feature (first sequence in the MSA) with feedback (state) from SE3
        state = self.norm_state(state)
        state = self.proj_state(state).reshape(B, 1, L, -1)
        msa = msa.type_as(state)
        msa = msa.index_add(1, torch.tensor([0,], device=state.device), state)

        # Apply row/column attention to msa & transform
        msa = msa + self.drop_row(self.row_attn(msa, pair))
        msa = msa + self.col_attn(msa)
        msa = msa + self.ff(msa)

        return msa

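# Illustrative sketch (not in the original source): expected tensor shapes for a
# single MSAPairStr2MSA call. All sizes below are arbitrary demo values, and the
# sketch assumes the imported attention layers accept these dimensions the same
# way they are constructed above. `_demo_msa_update_shapes` is hypothetical.
def _demo_msa_update_shapes():
    B, N, L = 1, 8, 16
    block = MSAPairStr2MSA(d_msa=64, d_pair=32, n_head=4, d_state=16, d_rbf=64, d_hidden=8)
    msa   = torch.randn(B, N, L, 64)
    pair  = torch.randn(B, L, L, 32)
    rbf   = torch.randn(B, L, L, 64)
    state = torch.randn(B, L, 16)
    out = block(msa, pair, rbf, state)
    assert out.shape == (B, N, L, 64)   # MSA shape is preserved by the residual updates
    return out
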
def find_symmsub(Ltot, Lasu, k):
    """
    Creates a symmsub matrix

    Parameters:
        Ltot (int, required): Total length of all residues
        Lasu (int, required): Length of the asymmetric unit
        k (int, required): Number of off-diagonals to include in symmetrization
    """
    assert Ltot % Lasu == 0
    nchunk = Ltot // Lasu

    N = 2*k + 1  # total number of diagonals being accessed
    symmsub = torch.ones((nchunk, nchunk))*-1
    C = 0  # a marker for blocks of the same category

    for i in range(N):  # i = 0, 1,2, 3,4, 5,6...
        offset = int(((i+1) // 2) * (math.pow(-1,i)))  # offset = 0,-1,1,-2,2,-3,3...

        row = torch.arange(nchunk)
        col = torch.roll(row, offset)

        if offset < 0:
            row = row[:-abs(offset)]
            col = col[:-abs(offset)]
        elif offset > 0:
            row = row[abs(offset):]
            col = col[abs(offset):]
        else:  # i=0
            pass

        symmsub[row, col] = i

    return symmsub.long()

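# Illustrative sketch (not in the original source): what find_symmsub produces
# for a toy repeat protein with three repeats (Ltot=6, Lasu=2) and k=1.
# The main diagonal is labeled 0, the first super-diagonal 1, the first
# sub-diagonal 2, and blocks outside the band stay -1:
#
#     tensor([[ 0,  1, -1],
#             [ 2,  0,  1],
#             [-1,  2,  0]])
#
# `_demo_find_symmsub` is a hypothetical helper added for documentation only.
def _demo_find_symmsub():
    S = find_symmsub(Ltot=6, Lasu=2, k=1)
    expected = torch.tensor([[ 0, 1, -1],
                             [ 2, 0,  1],
                             [-1, 2,  0]])
    assert torch.equal(S, expected)
    return S
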
def copy_block_activations(pair, symmsub, main_block):
    """
    Copies pair activations around in blocks according to matrix S
    """
    raise NotImplementedError


def max_block_activations(pair, symmsub):
    """
    Symmetrizes pair activations by taking, for each block label in symmsub,
    the elementwise max over all blocks carrying that label
    """
    B,L = pair.shape[:2]

    Osub = symmsub.shape[0]

    # effective (asymmetric-unit) length of each block
    Leff = L//Osub

    stacks = {}

    # find all of the activation blocks
    for i in range(Osub):
        for j in range(Osub):
            sij = symmsub[i,j]
            if (sij>=0):
                if not stacks.get(int(sij), False):
                    stacks[int(sij)] = []
                stacks[int(sij)].append( pair[0, i*Leff:(i+1)*Leff, j*Leff:(j+1)*Leff] )

    # make tensors and find max activation in each tensor
    for key,val in stacks.items():
        A = torch.stack(stacks[key])     # stacked block activations
        B,max_idx = torch.max(A, dim=0)  # find the max
        stacks[key] = B                  # replace with the max

    for i in range(Osub):
        for j in range(Osub):
            sij = symmsub[i,j]
            if (sij>=0):
                pair[0, i*Leff:(i+1)*Leff, j*Leff:(j+1)*Leff] = stacks[int(sij)]

    return pair

def mean_block_activations(pair, symmsub):
    """
    Applies block-average symmetrization
    """
    B,L = pair.shape[:2]

    Osub = symmsub.shape[0]

    # average pairs/blocks together
    Leff = L//Osub

    # applies block averaging to the pair representation based on symmsub
    pairsymm = torch.zeros([Osub,Leff,Leff,pair.shape[-1]], device=pair.device, dtype=pair.dtype)
    Nsymm = torch.zeros([Osub], device=pair.device, dtype=torch.int)

    for i in range(Osub):
        for j in range(Osub):
            sij = symmsub[i,j]
            if (sij>=0):
                pairsymm[sij] += pair[0, i*Leff:(i+1)*Leff, j*Leff:(j+1)*Leff]
                Nsymm[sij] += 1

    for i in range(Osub):
        for j in range(Osub):
            sij = symmsub[i,j]
            if (sij>=0):
                pair[0, i*Leff:(i+1)*Leff, j*Leff:(j+1)*Leff] = pairsymm[sij]/Nsymm[sij]

    return pair

def apply_pair_symmetry(pair, symmsub, method='mean', main_block=None):
    """
    Applies the pair symmetrizing operation
    """
    assert method in ['mean','max','copy']

    if method == 'mean':
        pair = mean_block_activations(pair, symmsub)
    elif method == 'copy':
        assert not (main_block is None), "main_block cannot be None for the 'copy' method"
        pair = copy_block_activations(pair, symmsub, main_block=main_block)
    elif method == 'max':
        pair = max_block_activations(pair, symmsub)

    return pair

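# Illustrative sketch (not in the original source): after 'mean' symmetrization,
# every pair block that shares a symmsub label holds identical activations.
# Toy sizes only; `_demo_pair_symmetry` is a hypothetical helper for documentation.
def _demo_pair_symmetry():
    Osub, Leff, d = 3, 4, 8
    symmsub = find_symmsub(Ltot=Osub*Leff, Lasu=Leff, k=1)
    pair = torch.randn(1, Osub*Leff, Osub*Leff, d)
    pair = apply_pair_symmetry(pair, symmsub, method='mean')
    # blocks (0,0), (1,1), (2,2) all carry label 0, so they now match exactly
    assert torch.allclose(pair[0, :Leff, :Leff], pair[0, Leff:2*Leff, Leff:2*Leff])
    return pair
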
class PairStr2Pair(nn.Module):
    def __init__(self, d_pair=128, n_head=4, d_hidden=32, d_hidden_state=16, d_rbf=64, d_state=32, p_drop=0.15,
                 symmetrize_repeats=False, repeat_length=None, symmsub_k=1, sym_method='max', main_block=None):
        """
        Parameters:
            symmetrize_repeats (bool, optional): whether to symmetrize the repeats
            repeat_length (int, optional): length of the repeat unit in the repeat protein
            symmsub_k (int, optional): number of off-diagonals to use for symmetrization
            sym_method (str, optional): method to use for symmetrization
            main_block (int, optional): main block to use for symmetrization (the one with the motif)
        """
        super(PairStr2Pair, self).__init__()

        self.symmetrize_repeats = symmetrize_repeats
        self.repeat_length = repeat_length
        self.symmsub_k = symmsub_k
        self.sym_method = sym_method
        self.main_block = main_block

        self.norm_state = nn.LayerNorm(d_state)
        self.proj_left = nn.Linear(d_state, d_hidden_state)
        self.proj_right = nn.Linear(d_state, d_hidden_state)
        self.to_gate = nn.Linear(d_hidden_state*d_hidden_state, d_pair)

        self.emb_rbf = nn.Linear(d_rbf, d_pair)

        self.drop_row = Dropout(broadcast_dim=1, p_drop=p_drop)
        self.drop_col = Dropout(broadcast_dim=2, p_drop=p_drop)

        self.tri_mul_out = TriangleMultiplication(d_pair, d_hidden=d_hidden)
        self.tri_mul_in = TriangleMultiplication(d_pair, d_hidden, outgoing=False)

        self.row_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=True)
        self.col_attn = BiasedAxialAttention(d_pair, d_pair, n_head, d_hidden, p_drop=p_drop, is_row=False)

        self.ff = FeedForwardLayer(d_pair, 2)

        self.reset_parameter()

    def reset_parameter(self):
        self.emb_rbf = init_lecun_normal(self.emb_rbf)
        nn.init.zeros_(self.emb_rbf.bias)

        self.proj_left = init_lecun_normal(self.proj_left)
        nn.init.zeros_(self.proj_left.bias)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_right.bias)

        # gating: zero weights, one biases (mostly open gate at the beginning)
        nn.init.zeros_(self.to_gate.weight)
        nn.init.ones_(self.to_gate.bias)

    # perform a striped p2p op
    def subblock(self, OP, pair, rbf_feat, crop):
        N,L = pair.shape[:2]

        nbox = (L-1)//(crop//2)+1
        idx = torch.triu_indices(nbox,nbox,1, device=pair.device)
        ncrops = idx.shape[1]

        pairnew = torch.zeros((N,L*L,pair.shape[-1]), device=pair.device, dtype=pair.dtype)
        countnew = torch.zeros((N,L*L), device=pair.device, dtype=torch.int)

        for i in range(ncrops):
            # reindex sub-blocks
            offsetC = torch.clamp( (1+idx[1,i:(i+1)])*(crop//2)-L, min=0 )  # account for going past L
            offsetN = torch.zeros_like(offsetC)
            mask = (offsetC>0)*((idx[0,i]+1)==idx[1,i])
            offsetN[mask] = offsetC[mask]
            pairIdx = torch.zeros((1,crop), dtype=torch.long, device=pair.device)
            pairIdx[:,:(crop//2)] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[0,i:(i+1),None]*(crop//2) - offsetN[:,None]
            pairIdx[:,(crop//2):] = torch.arange(crop//2, dtype=torch.long, device=pair.device)+idx[1,i:(i+1),None]*(crop//2) - offsetC[:,None]

            # do reindexing
            iL,iU = pairIdx[:,:,None], pairIdx[:,None,:]
            paircrop = pair[:,iL,iU,:].reshape(-1,crop,crop,pair.shape[-1])
            rbfcrop = rbf_feat[:,iL,iU,:].reshape(-1,crop,crop,rbf_feat.shape[-1])

            # row attn
            paircrop = OP(paircrop, rbfcrop).to(pair.dtype)

            # unindex
            iUL = (iL*L+iU).flatten()
            pairnew.index_add_(1,iUL, paircrop.reshape(N,iUL.shape[0],pair.shape[-1]))
            countnew.index_add_(1,iUL, torch.ones((N,iUL.shape[0]), device=pair.device, dtype=torch.int))

        return pair + (pairnew/countnew[...,None]).reshape(N,L,L,-1)

    def forward(self, pair, rbf_feat, state, crop=-1):
        B,L = pair.shape[:2]

        rbf_feat = self.emb_rbf(rbf_feat)  # B, L, L, d_pair

        state = self.norm_state(state)
        left = self.proj_left(state)
        right = self.proj_right(state)
        gate = einsum('bli,bmj->blmij', left, right).reshape(B,L,L,-1)
        gate = torch.sigmoid(self.to_gate(gate))
        rbf_feat = gate*rbf_feat

        crop = 2*(crop//2)  # make sure even

        if (crop>0 and crop<=L):
            pair = self.subblock(
                lambda x,y:self.drop_row(self.tri_mul_out(x)),
                pair, rbf_feat, crop
            )

            pair = self.subblock(
                lambda x,y:self.drop_row(self.tri_mul_in(x)),
                pair, rbf_feat, crop
            )

            pair = self.subblock(
                lambda x,y:self.drop_row(self.row_attn(x,y)),
                pair, rbf_feat, crop
            )

            pair = self.subblock(
                lambda x,y:self.drop_col(self.col_attn(x,y)),
                pair, rbf_feat, crop
            )

            # feed forward layer
            RESSTRIDE = 16384//L
            for i in range((L-1)//RESSTRIDE+1):
                r_i,r_j = i*RESSTRIDE, min((i+1)*RESSTRIDE,L)
                pair[:,r_i:r_j] = pair[:,r_i:r_j] + self.ff(pair[:,r_i:r_j])

        else:
            pair = pair + self.drop_row(self.tri_mul_out(pair))
            pair = pair + self.drop_row(self.tri_mul_in(pair))
            pair = pair + self.drop_row(self.row_attn(pair, rbf_feat))
            pair = pair + self.drop_col(self.col_attn(pair, rbf_feat))
            pair = pair + self.ff(pair)

        # symmetry/repeat proteins (Diffusion inference only)
        if self.symmetrize_repeats:
            symmsub = find_symmsub(L, self.repeat_length, self.symmsub_k)
        else:
            symmsub = None

        if symmsub is not None:
            pair = apply_pair_symmetry(pair, symmsub, self.sym_method, self.main_block)
        return pair

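# Illustrative sketch (not in the original source): the state-driven gate above is
# an outer product of two low-rank projections of the per-residue state, mapped to
# d_pair channels and squashed with a sigmoid. With the zero-weight / one-bias
# initialization in reset_parameter, the gate starts near sigmoid(1) ~ 0.73, i.e.
# mostly open. Toy dimensions; `_demo_state_gate` is a hypothetical helper.
def _demo_state_gate(B=1, L=5, d_state=32, d_hidden_state=16, d_pair=128):
    norm = nn.LayerNorm(d_state)
    proj_left = nn.Linear(d_state, d_hidden_state)
    proj_right = nn.Linear(d_state, d_hidden_state)
    to_gate = nn.Linear(d_hidden_state*d_hidden_state, d_pair)
    state = norm(torch.randn(B, L, d_state))
    left, right = proj_left(state), proj_right(state)
    gate = einsum('bli,bmj->blmij', left, right).reshape(B, L, L, -1)
    gate = torch.sigmoid(to_gate(gate))   # (B, L, L, d_pair), values in (0, 1)
    assert gate.shape == (B, L, L, d_pair)
    return gate
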
class MSA2Pair(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_hidden=16, p_drop=0.15):
        super(MSA2Pair, self).__init__()
        self.norm = nn.LayerNorm(d_msa)
        self.proj_left = nn.Linear(d_msa, d_hidden)
        self.proj_right = nn.Linear(d_msa, d_hidden)
        self.proj_out = nn.Linear(d_hidden*d_hidden, d_pair)

        self.reset_parameter()

    def reset_parameter(self):
        # normal initialization
        self.proj_left = init_lecun_normal(self.proj_left)
        self.proj_right = init_lecun_normal(self.proj_right)
        nn.init.zeros_(self.proj_left.bias)
        nn.init.zeros_(self.proj_right.bias)

        # zero initialize output
        nn.init.zeros_(self.proj_out.weight)
        nn.init.zeros_(self.proj_out.bias)

    def forward(self, msa, pair):
        B, N, L = msa.shape[:3]
        msa = self.norm(msa)

        left = self.proj_left(msa)
        right = self.proj_right(msa)
        right = right / float(N)
        out = einsum('bsli,bsmj->blmij', left, right).reshape(B, L, L, -1)
        out = self.proj_out(out)

        pair = pair + out

        return pair

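# Illustrative sketch (not in the original source): MSA2Pair is an outer-product-mean
# style coevolution extractor. Dividing `right` by N turns the einsum sum over the
# sequence axis into an average of per-sequence outer products. Toy check below;
# `_demo_outer_product_mean` is a hypothetical helper for documentation only.
def _demo_outer_product_mean(B=1, N=4, L=6, h=3):
    left = torch.randn(B, N, L, h)
    right = torch.randn(B, N, L, h)
    fused = einsum('bsli,bsmj->blmij', left, right / N)
    manual = torch.stack([torch.einsum('li,mj->lmij', left[0, s], right[0, s])
                          for s in range(N)]).mean(dim=0)
    assert torch.allclose(fused[0], manual, atol=1e-5)
    return fused
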
class Str2Str(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_state=16, d_rbf=64,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 nextra_l0=0, nextra_l1=0, p_drop=0.1
                 ):
        super(Str2Str, self).__init__()

        # initial node & pair feature process
        self.norm_msa = nn.LayerNorm(d_msa)
        self.norm_pair = nn.LayerNorm(d_pair)
        self.norm_state = nn.LayerNorm(d_state)

        self.embed_node = nn.Linear(d_msa+d_state, SE3_param['l0_in_features'])
        self.ff_node = FeedForwardLayer(SE3_param['l0_in_features'], 2, p_drop=p_drop)
        self.norm_node = nn.LayerNorm(SE3_param['l0_in_features'])

        self.embed_edge = nn.Linear(d_pair+d_rbf+1, SE3_param['num_edge_features'])
        self.ff_edge = FeedForwardLayer(SE3_param['num_edge_features'], 2, p_drop=p_drop)
        self.norm_edge = nn.LayerNorm(SE3_param['num_edge_features'])

        SE3_param_temp = SE3_param.copy()
        SE3_param_temp['l0_in_features'] += nextra_l0
        SE3_param_temp['l1_in_features'] += nextra_l1

        self.se3 = SE3TransformerWrapper(**SE3_param_temp)

        self.sc_predictor = SCPred(
            d_msa=d_msa,
            d_state=SE3_param['l0_out_features'],
            p_drop=p_drop)

        self.nextra_l0 = nextra_l0
        self.nextra_l1 = nextra_l1

        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights to normal distribution
        self.embed_node = init_lecun_normal(self.embed_node)
        self.embed_edge = init_lecun_normal(self.embed_edge)

        # initialize bias to zeros
        nn.init.zeros_(self.embed_node.bias)
        nn.init.zeros_(self.embed_edge.bias)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, msa, pair, xyz, state, idx, rotation_mask, bond_feats, dist_matrix, atom_frames, is_motif, extra_l0=None, extra_l1=None, use_atom_frames=True, top_k=128, eps=1e-5):
        # process msa & pair features
        B, N, L = msa.shape[:3]
        seq = self.norm_msa(msa[:,0])
        pair = self.norm_pair(pair)
        state = self.norm_state(state)

        node = torch.cat((seq, state), dim=-1)
        node = self.embed_node(node)
        node = node + self.ff_node(node)
        node = self.norm_node(node)

        neighbor = get_seqsep_protein_sm(idx, bond_feats, dist_matrix, rotation_mask)
        cas = xyz[:,:,1].contiguous()
        rbf_feat = rbf(torch.cdist(cas, cas))
        edge = torch.cat((pair, rbf_feat, neighbor), dim=-1)
        edge = self.embed_edge(edge)
        edge = edge + self.ff_edge(edge)
        edge = self.norm_edge(edge)

        # define graph
        if top_k > 0:
            G, edge_feats = make_topk_graph(xyz[:,:,1,:], edge, idx, top_k=top_k)
        else:
            G, edge_feats = make_full_graph(xyz[:,:,1,:], edge, idx)

        if use_atom_frames:  # ligand l1 features are vectors to neighboring atoms
            xyz_frame = xyz_frame_from_rotation_mask(xyz, rotation_mask, atom_frames)
            l1_feats = xyz_frame - xyz_frame[:,:,1,:].unsqueeze(2)
        else:  # old (incorrect) behavior: vectors to random initial coords of virtual N and C
            l1_feats = xyz - xyz[:,:,1,:].unsqueeze(2)
        l1_feats = l1_feats.reshape(B*L, -1, 3)

        if extra_l1 is not None:
            l1_feats = torch.cat( (l1_feats,extra_l1), dim=1 )
        if extra_l0 is not None:
            node = torch.cat( (node,extra_l0), dim=2 )

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node.reshape(B*L, -1, 1), l1_feats, edge_feats)

        state = shift['0'].reshape(B, L, -1)  # (B, L, C)

        offset = shift['1'].reshape(B, L, 2, 3)
        offset[:,is_motif,...] = 0  # NOTE: DJ - frozen motif!!
        T = offset[:,:,0,:] / 10
        R = offset[:,:,1,:] / 100.0

        Qnorm = torch.sqrt( 1 + torch.sum(R*R, dim=-1) )
        qA, qB, qC, qD = 1/Qnorm, R[:,:,0]/Qnorm, R[:,:,1]/Qnorm, R[:,:,2]/Qnorm

        v = xyz - xyz[:,:,1:2,:]
        Rout = torch.zeros((B,L,3,3), device=xyz.device)
        Rout[:,:,0,0] = qA*qA+qB*qB-qC*qC-qD*qD
        Rout[:,:,0,1] = 2*qB*qC - 2*qA*qD
        Rout[:,:,0,2] = 2*qB*qD + 2*qA*qC
        Rout[:,:,1,0] = 2*qB*qC + 2*qA*qD
        Rout[:,:,1,1] = qA*qA-qB*qB+qC*qC-qD*qD
        Rout[:,:,1,2] = 2*qC*qD - 2*qA*qB
        Rout[:,:,2,0] = 2*qB*qD - 2*qA*qC
        Rout[:,:,2,1] = 2*qC*qD + 2*qA*qB
        Rout[:,:,2,2] = qA*qA-qB*qB-qC*qC+qD*qD
        I = torch.eye(3, device=Rout.device).expand(B,L,3,3)
        Rout = torch.where(rotation_mask.reshape(B, L, 1,1), I, Rout)
        xyz = torch.einsum('blij,blaj->blai', Rout,v)+xyz[:,:,1:2,:]+T[:,:,None,:]

        alpha = self.sc_predictor(msa[:,0], state)
        return xyz, state, alpha, torch.stack([qA, qB, qC, qD], dim=2)

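# Illustrative sketch (not in the original source): the update above parameterizes a
# rotation by an unnormalized 3-vector R, forming a unit quaternion
# (1, Rx, Ry, Rz) / sqrt(1 + |R|^2) and converting it to a rotation matrix. The
# check below verifies the resulting matrix is orthonormal with determinant +1.
# `_demo_vec_to_rotation` is a hypothetical helper for documentation only.
def _demo_vec_to_rotation(R=torch.tensor([0.1, -0.2, 0.05])):
    Qnorm = torch.sqrt(1 + torch.sum(R*R))
    qA, qB, qC, qD = 1/Qnorm, R[0]/Qnorm, R[1]/Qnorm, R[2]/Qnorm
    row0 = torch.stack([qA*qA+qB*qB-qC*qC-qD*qD, 2*qB*qC-2*qA*qD,         2*qB*qD+2*qA*qC])
    row1 = torch.stack([2*qB*qC+2*qA*qD,         qA*qA-qB*qB+qC*qC-qD*qD, 2*qC*qD-2*qA*qB])
    row2 = torch.stack([2*qB*qD-2*qA*qC,         2*qC*qD+2*qA*qB,         qA*qA-qB*qB-qC*qC+qD*qD])
    Rout = torch.stack([row0, row1, row2])
    assert torch.allclose(Rout @ Rout.T, torch.eye(3), atol=1e-5)
    assert torch.isclose(torch.linalg.det(Rout), torch.tensor(1.0), atol=1e-5)
    return Rout
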
class Allatom2Allatom(nn.Module):
    def __init__(
        self,
        SE3_param
    ):
        super(Allatom2Allatom, self).__init__()

        self.se3 = SE3TransformerWrapper(**SE3_param)

    @torch.cuda.amp.autocast(enabled=False)
    def forward(self, seq, xyz, aamask, num_bonds, state, grads, top_k=24, eps=1e-5):
        raise Exception('not implemented for diffusion')
        # seq       (B,L)
        # xyz       (B,L,27,3)
        # aamask    (22,27)      [per-amino-acid]
        # num_bonds (22,27,27)   [per-amino-acid]
        # state     (N,B,L,K)    [K channels]
        # grads     (N,B,L,27,3) [N terms]

        B, L = xyz.shape[:2]

        mask = aamask[seq]
        G, edge = make_atom_graph( xyz, mask, num_bonds[seq], top_k, maxbonds=4 )
        node = state[mask]
        node_l1 = grads[:,mask].permute(1,0,2)

        # apply SE(3) Transformer & update coordinates
        shift = self.se3(G, node[...,None], node_l1, edge)

        state[mask] = shift['0'][...,0]
        xyz[mask] = xyz[mask] + shift['1'].squeeze(1) / 100.0

        return xyz, state

# embed residue state + atomtype -> per-atom state
class AllatomEmbed(nn.Module):
    def __init__(
        self,
        d_state_in=64,
        d_state_out=32,
        p_mask=0.15
    ):
        super(AllatomEmbed, self).__init__()

        self.p_mask = p_mask

        # initial node & pair feature process
        self.compress_embed = nn.Linear(d_state_in + 29, d_state_out)  # 29->5 if using element
        self.norm_state = nn.LayerNorm(d_state_out)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights to normal distribution
        self.compress_embed = init_lecun_normal(self.compress_embed)
        # initialize bias to zeros
        nn.init.zeros_(self.compress_embed.bias)

    def forward(self, state, seq, eltmap):
        B,L = state.shape[:2]
        mask = torch.rand(B,L) < self.p_mask

        state = state.reshape(B,L,1,-1).repeat(1,1,27,1)
        state[mask] = 0.0
        elements = F.one_hot(eltmap[seq], num_classes=29)  # 29->5 if using element
        state = self.compress_embed(
            torch.cat( (state,elements), dim=-1 )
        )
        state = self.norm_state( state )

        return state


# embed per-atom state -> residue state
class ResidueEmbed(nn.Module):
    def __init__(
        self,
        d_state_in=16,
        d_state_out=64
    ):
        super(ResidueEmbed, self).__init__()

        self.compress_embed = nn.Linear(27*d_state_in, d_state_out)
        self.norm_state = nn.LayerNorm(d_state_out)

        self.reset_parameter()

    def reset_parameter(self):
        # initialize weights to normal distribution
        self.compress_embed = init_lecun_normal(self.compress_embed)
        # initialize bias to zeros
        nn.init.zeros_(self.compress_embed.bias)

    def forward(self, state):
        B,L = state.shape[:2]

        state = self.compress_embed( state.reshape(B,L,-1) )
        state = self.norm_state( state )

        return state

class SCPred(nn.Module):
    def __init__(self, d_msa=256, d_state=32, d_hidden=128, p_drop=0.15):
        super(SCPred, self).__init__()
        self.norm_s0 = nn.LayerNorm(d_msa)
        self.norm_si = nn.LayerNorm(d_state)
        self.linear_s0 = nn.Linear(d_msa, d_hidden)
        self.linear_si = nn.Linear(d_state, d_hidden)

        # ResNet layers
        self.linear_1 = nn.Linear(d_hidden, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_hidden)
        self.linear_3 = nn.Linear(d_hidden, d_hidden)
        self.linear_4 = nn.Linear(d_hidden, d_hidden)

        # Final outputs
        self.linear_out = nn.Linear(d_hidden, 2*ChemData().NTOTALDOFS)

        self.reset_parameter()

    def reset_parameter(self):
        # normal initialization
        self.linear_s0 = init_lecun_normal(self.linear_s0)
        self.linear_si = init_lecun_normal(self.linear_si)
        self.linear_out = init_lecun_normal(self.linear_out)
        nn.init.zeros_(self.linear_s0.bias)
        nn.init.zeros_(self.linear_si.bias)
        nn.init.zeros_(self.linear_out.bias)

        # right before relu activation: He initializer (kaiming normal)
        nn.init.kaiming_normal_(self.linear_1.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_1.bias)
        nn.init.kaiming_normal_(self.linear_3.weight, nonlinearity='relu')
        nn.init.zeros_(self.linear_3.bias)

        # right before residual connection: zero initialize
        nn.init.zeros_(self.linear_2.weight)
        nn.init.zeros_(self.linear_2.bias)
        nn.init.zeros_(self.linear_4.weight)
        nn.init.zeros_(self.linear_4.bias)

    def forward(self, seq, state):
        '''
        Predict side-chain torsion angles along with backbone torsions
        Inputs:
            - seq: hidden embeddings corresponding to query sequence (B, L, d_msa)
            - state: state feature (output l0 feature) from previous SE3 layer (B, L, d_state)
        Outputs:
            - si: predicted torsion/pseudotorsion angles (phi, psi, omega, chi1~4 with cos/sin, theta) (B, L, NTOTALDOFS, 2)
        '''
        B, L = seq.shape[:2]
        seq = self.norm_s0(seq)
        state = self.norm_si(state)
        si = self.linear_s0(seq) + self.linear_si(state)

        si = si + self.linear_2(F.relu_(self.linear_1(F.relu_(si))))
        si = si + self.linear_4(F.relu_(self.linear_3(F.relu_(si))))

        si = self.linear_out(F.relu_(si))
        return si.view(B, L, ChemData().NTOTALDOFS, 2)

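# Illustrative sketch (not in the original source): SCPred emits each torsion as an
# unnormalized (cos, sin) pair. Downstream code typically normalizes the pair onto
# the unit circle or recovers the angle with atan2, as below. Hypothetical helper
# `_demo_decode_torsions`, for documentation only.
def _demo_decode_torsions(si):
    # si: (B, L, NTOTALDOFS, 2) raw (cos, sin) predictions
    norm = torch.linalg.norm(si, dim=-1, keepdim=True).clamp(min=1e-8)
    cos_sin = si / norm                                       # unit-circle (cos, sin)
    angles = torch.atan2(cos_sin[..., 1], cos_sin[..., 0])    # radians in (-pi, pi]
    return cos_sin, angles
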
def update_symm_Rs(xyz, Lasu, symmsub, symmRs, fit=False, tscale=1.0):
    def dist_error_comp(R0,T0,xyz,tscale):
        B = xyz.shape[0]
        Tcom = xyz[:,:Lasu].mean(dim=1,keepdim=True)
        Tcorr = torch.einsum('ij,brj->bri', R0, xyz[:,:Lasu]-Tcom) + Tcom + tscale*T0[None,None,:]

        # distance map loss
        Xsymm = torch.einsum('sij,brj->bsri', symmRs[symmsub], Tcorr).reshape(B,-1,3)
        Xtrue = Ts  # NOTE: `Ts` is not defined in this scope in the original source
                    # (presumably the reference coordinates); only reached when fit=True

        delsx = Xsymm[:,:Lasu,None]-Xsymm[:, None, Lasu:]
        deltx = Xtrue[:,:Lasu,None]-Xtrue[:, None, Lasu:]
        dsymm = torch.linalg.norm(delsx, dim=-1)
        dtrue = torch.linalg.norm(deltx, dim=-1)
        loss1 = torch.abs(dsymm-dtrue).mean()

        return loss1, 0.0  # loss2

    def dist_error(R0,T0,xyz,tscale,w_clash=0.0):
        l1,l2 = dist_error_comp(R0,T0,xyz,tscale)
        return l1+w_clash*l2

    def Q2R(Q):
        Qs = torch.cat((torch.ones((1),device=Q.device),Q),dim=-1)
        Qs = normQ(Qs)
        return Qs2Rs(Qs[None,:]).squeeze(0)

    B = xyz.shape[0]

    # symmetry correction 1: don't let COM (of entire complex) move
    #Tmean = xyz[:,:Lasu].reshape(-1,3).mean(dim=0)
    #Tmean = torch.einsum('sij,j->si', symmRs, Tmean).mean(dim=0)
    #xyz = xyz - Tmean

    if fit:
        # symmetry correction 2: use minimization to minimize drms
        with torch.enable_grad():
            T0 = torch.zeros(3,device=xyz.device).requires_grad_(True)
            Q0 = torch.zeros(3,device=xyz.device).requires_grad_(True)
            lbfgs = torch.optim.LBFGS([T0,Q0],
                                      history_size=10,
                                      max_iter=4,
                                      line_search_fn="strong_wolfe")

            def closure():
                lbfgs.zero_grad()
                loss = dist_error(Q2R(Q0),T0,xyz[:,:,1], tscale)
                loss.backward()  #retain_graph=True)
                return loss

            for e in range(4):
                loss = lbfgs.step(closure)

            Tcom = xyz[:,:Lasu].mean(dim=1,keepdim=True).detach()
            Q0 = Q0.detach()
            T0 = T0.detach()
            xyz = torch.einsum('ij,braj->brai', Q2R(Q0), xyz[:,:Lasu]-Tcom) + Tcom + tscale*T0[None,None,:]

    xyz = torch.einsum('sij,braj->bsrai', symmRs[symmsub], xyz[:,:Lasu])
    xyz = xyz.reshape(B,-1,3,3)  # (B, S*Lasu, 3, 3)

    return xyz

def update_symm_subs(xyz, pair, symmids, symmsub, symmRs, metasymm, fit=False, tscale=1.0):
    B,Ls = xyz.shape[0:2]
    Osub = symmsub.shape[0]
    L = Ls//Osub

    com = xyz[:,:L,1].mean(dim=-2)
    rcoms = torch.einsum('sij,bj->si', symmRs, com)
    subsymms, nneighs = metasymm
    symmsub_new = []
    for i in range(len(subsymms)):
        drcoms = torch.linalg.norm(rcoms[0,:] - rcoms[subsymms[i],:], dim=-1)
        _,subs_i = torch.topk(drcoms,nneighs[i],largest=False)
        subs_i,_ = torch.sort( subsymms[i][subs_i] )
        symmsub_new.append(subs_i)

    symmsub_new = torch.cat(symmsub_new)

    s_old = symmids[symmsub[:,None],symmsub[None,:]]
    s_new = symmids[symmsub_new[:,None],symmsub_new[None,:]]

    # remap old->new
    # a) find highest-magnitude patches
    pairsub = dict()
    pairmag = dict()
    for i in range(Osub):
        for j in range(Osub):
            idx_old = s_old[i,j].item()
            sub_ij = pair[:,i*L:(i+1)*L,j*L:(j+1)*L,:].clone()
            mag_ij = torch.max(sub_ij.flatten())  #torch.norm(sub_ij.flatten())
            if idx_old not in pairsub or mag_ij > pairmag[idx_old]:
                pairmag[idx_old] = mag_ij
                pairsub[idx_old] = (i,j)  #sub_ij

    # b) reindex
    idx = torch.zeros((Osub*L,Osub*L),dtype=torch.long,device=pair.device)
    idx = (
        torch.arange(Osub*L,device=pair.device)[:,None]*Osub*L
        + torch.arange(Osub*L,device=pair.device)[None,:]
    )
    for i in range(Osub):
        for j in range(Osub):
            idx_new = s_new[i,j].item()
            if idx_new in pairsub:
                inew,jnew = pairsub[idx_new]
                idx[i*L:(i+1)*L,j*L:(j+1)*L] = (
                    Osub*L*torch.arange(inew*L,(inew+1)*L)[:,None]
                    + torch.arange(jnew*L,(jnew+1)*L)[None,:]
                )
    pair = pair.view(1,-1,pair.shape[-1])[:,idx.flatten(),:].view(1,Osub*L,Osub*L,pair.shape[-1])

    if symmsub is not None and symmsub.shape[0]>1:
        xyz = update_symm_Rs(xyz, L, symmsub_new, symmRs, fit, tscale)

    return xyz, pair, symmsub_new

class IterBlock(nn.Module):
    def __init__(self, d_msa=256, d_pair=128, d_rbf=64,
                 n_head_msa=8, n_head_pair=4,
                 use_global_attn=False,
                 d_hidden=32, d_hidden_msa=None,
                 minpos=-32, maxpos=32, maxpos_atom=8, p_drop=0.15,
                 SE3_param={'l0_in_features':32, 'l0_out_features':16, 'num_edge_features':32},
                 nextra_l0=0, nextra_l1=0, use_same_chain=False, enable_same_chain=False,
                 symmetrize_repeats=None, repeat_length=None, symmsub_k=None, sym_method=None, main_block=None,
                 fit=False, tscale=1.0
                 ):
        super(IterBlock, self).__init__()
        if d_hidden_msa == None:
            d_hidden_msa = d_hidden

        self.fit = fit
        self.tscale = tscale

        self.pos = PositionalEncoding2D(d_rbf, minpos=minpos, maxpos=maxpos,
                                        maxpos_atom=maxpos_atom, p_drop=p_drop,
                                        use_same_chain=use_same_chain,
                                        enable_same_chain=enable_same_chain)

        self.msa2msa = MSAPairStr2MSA(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                                      n_head=n_head_msa,
                                      d_state=SE3_param['l0_out_features'],
                                      use_global_attn=use_global_attn,
                                      d_hidden=d_hidden_msa, p_drop=p_drop)

        self.msa2pair = MSA2Pair(d_msa=d_msa, d_pair=d_pair,
                                 d_hidden=d_hidden//2, p_drop=p_drop)

        self.pair2pair = PairStr2Pair(d_pair=d_pair, n_head=n_head_pair, d_rbf=d_rbf,
                                      d_state=SE3_param['l0_out_features'],
                                      d_hidden=d_hidden, p_drop=p_drop,
                                      symmetrize_repeats=symmetrize_repeats, repeat_length=repeat_length,
                                      symmsub_k=symmsub_k, sym_method=sym_method, main_block=main_block)

        self.str2str = Str2Str(d_msa=d_msa, d_pair=d_pair, d_rbf=d_rbf,
                               d_state=SE3_param['l0_out_features'],
                               SE3_param=SE3_param,
                               p_drop=p_drop,
                               nextra_l0=nextra_l0,
                               nextra_l1=nextra_l1)

    def forward(
        self, msa, pair, xyz, state, seq_unmasked, idx,
        symmids, symmsub, symmRs, symmmeta,
        bond_feats, same_chain, is_motif, dist_matrix,
        use_checkpoint=False, top_k=128, rotation_mask=None, atom_frames=None, extra_l0=None, extra_l1=None, use_atom_frames=True,
        crop=-1
    ):
        cas = xyz[:,:,1].contiguous()
        rbf_feat = rbf(torch.cdist(cas, cas)) + self.pos(seq_unmasked, idx, bond_feats, dist_matrix, same_chain)
        if use_checkpoint:
            msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state, use_reentrant=True)
            pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair, use_reentrant=True)
            pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat, state, crop, use_reentrant=True)

            xyz, state, alpha, quat = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=top_k),
                msa.float(), pair.float(), xyz.detach().float(), state.float(), idx,
                rotation_mask, bond_feats, dist_matrix, atom_frames, is_motif, extra_l0, extra_l1, use_atom_frames,
                use_reentrant=True
            )

        else:
            msa = self.msa2msa(msa, pair, rbf_feat, state)
            pair = self.msa2pair(msa, pair)
            pair = self.pair2pair(pair, rbf_feat, state, crop)

            xyz, state, alpha, quat = self.str2str(
                msa.float(), pair.float(), xyz.detach().float(), state.float(),
                idx, rotation_mask, bond_feats, dist_matrix, atom_frames, is_motif, extra_l0, extra_l1, use_atom_frames, top_k=top_k
            )

        # update contacting subunits
        # symmetrize pair features
        if symmsub is not None and symmsub.shape[0]>1:
            xyz, pair, symmsub = update_symm_subs(xyz, pair, symmids, symmsub, symmRs, symmmeta, self.fit, self.tscale)

        return msa, pair, xyz, state, alpha, symmsub, quat

class IterativeSimulator(nn.Module):
    def __init__(self, n_extra_block=4, n_main_block=12, n_ref_block=4, n_finetune_block=0,
                 d_msa=256, d_msa_full=64, d_pair=128, d_hidden=32,
                 n_head_msa=8, n_head_pair=4,
                 SE3_param={}, SE3_ref_param={}, p_drop=0.15,
                 atom_type_index=None, aamask=None,
                 ljlk_parameters=None, lj_correction_parameters=None,
                 cb_len=None, cb_ang=None, cb_tor=None,
                 num_bonds=None, lj_lin=0.6, use_same_chain=False, enable_same_chain=False,
                 use_chiral_l1=True, use_lj_l1=False, refiner_topk=64,
                 symmetrize_repeats=None,
                 repeat_length=None,
                 symmsub_k=None,
                 sym_method=None,
                 main_block=None,
                 fit=False,
                 tscale=1.0
                 ):
        super(IterativeSimulator, self).__init__()
        self.n_extra_block = n_extra_block
        self.n_main_block = n_main_block
        self.n_ref_block = n_ref_block
        self.n_finetune_block = n_finetune_block

        self.atom_type_index = atom_type_index
        self.aamask = aamask
        self.ljlk_parameters = ljlk_parameters
        self.lj_correction_parameters = lj_correction_parameters
        self.num_bonds = num_bonds
        self.lj_lin = lj_lin
        self.cb_len = cb_len
        self.cb_ang = cb_ang
        self.cb_tor = cb_tor
        self.use_chiral_l1 = use_chiral_l1
        self.use_lj_l1 = use_lj_l1
        self.enable_same_chain = enable_same_chain
        self.fit = fit
        self.tscale = tscale

        # Update with extra sequences
        if n_extra_block > 0:
            self.extra_block = nn.ModuleList([IterBlock(d_msa=d_msa_full, d_pair=d_pair,
                                                        n_head_msa=n_head_msa,
                                                        n_head_pair=n_head_pair,
                                                        d_hidden_msa=8,
                                                        d_hidden=d_hidden,
                                                        p_drop=p_drop,
                                                        use_global_attn=True,
                                                        SE3_param=SE3_param,
                                                        nextra_l1=3 if self.use_chiral_l1 else 0,
                                                        use_same_chain=use_same_chain,
                                                        enable_same_chain=enable_same_chain,
                                                        symmetrize_repeats=symmetrize_repeats,
                                                        repeat_length=repeat_length,
                                                        symmsub_k=symmsub_k,
                                                        sym_method=sym_method,
                                                        main_block=main_block,
                                                        fit=fit,
                                                        tscale=tscale)
                                              for i in range(n_extra_block)])

        # Update with seed sequences
        if n_main_block > 0:
            self.main_block = nn.ModuleList([IterBlock(d_msa=d_msa, d_pair=d_pair,
                                                       n_head_msa=n_head_msa,
                                                       n_head_pair=n_head_pair,
                                                       d_hidden=d_hidden,
                                                       p_drop=p_drop,
                                                       use_global_attn=False,
                                                       SE3_param=SE3_param,
                                                       nextra_l1=3 if self.use_chiral_l1 else 0,
                                                       use_same_chain=use_same_chain,
                                                       enable_same_chain=enable_same_chain,
                                                       symmetrize_repeats=symmetrize_repeats,
                                                       repeat_length=repeat_length,
                                                       symmsub_k=symmsub_k,
                                                       sym_method=sym_method,
                                                       main_block=main_block,
                                                       fit=fit,
                                                       tscale=tscale)
                                             for i in range(n_main_block)])

        # Final SE(3) refinement
        if n_ref_block > 0:
            n_extra_l0 = 0
            n_extra_l1 = 0
            if self.use_chiral_l1:
                n_extra_l1 += 3
            if self.use_lj_l1:
                n_extra_l0 += 2*ChemData().NTOTALDOFS
                n_extra_l1 += 3
            self.str_refiner = Str2Str(d_msa=d_msa, d_pair=d_pair,
                                       d_state=SE3_param['l0_out_features'],
                                       SE3_param=SE3_ref_param,
                                       p_drop=p_drop,
                                       nextra_l0=n_extra_l0,
                                       nextra_l1=n_extra_l1,
                                       )

        # To get all-atom coordinates
        self.xyzconverter = XYZConverter()

        self.refiner_topk = refiner_topk

    def forward(
        self, seq_unmasked, msa, msa_full, pair, xyz, state, idx,
        symmids, symmsub, symmRs, symmmeta,
        bond_feats, dist_matrix, same_chain, chirals, is_motif,
        atom_frames=None, use_checkpoint=False, use_atom_frames=True,
        p2p_crop=-1, topk_crop=0
    ):
        # input:
        #   msa: initial MSA embeddings (N, L, d_msa)
        #   pair: initial residue pair embeddings (L, L, d_pair)

        if self.enable_same_chain == False:
            same_chain = None

        B,_,L = msa.shape[:3]
        if symmsub is not None:
            Lasu = L//symmsub.shape[0]
        else:
            Lasu = L
        rotation_mask = is_atom(seq_unmasked)
        xyz_s = list()
        alpha_s = list()
        quat_s = list()
        for i_m in range(self.n_extra_block):
            extra_l0 = None
            extra_l1 = None
            if self.use_chiral_l1:
                dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
                extra_l1 = dchiraldxyz[0].detach()

            msa_full, pair, xyz, state, alpha, symmsub, quat = self.extra_block[i_m](msa_full, pair,
                xyz, state, seq_unmasked, idx,
                symmids, symmsub, symmRs, symmmeta,
                bond_feats,
                same_chain,
                use_checkpoint=use_checkpoint,
                top_k=topk_crop, rotation_mask=rotation_mask,
                atom_frames=atom_frames,
                extra_l0=extra_l0,
                extra_l1=extra_l1,
                is_motif=is_motif,
                dist_matrix=dist_matrix,
                use_atom_frames=use_atom_frames, crop=p2p_crop)
            xyz_s.append(xyz)
            alpha_s.append(alpha)
            quat_s.append(quat)

        for i_m in range(self.n_main_block):
            extra_l0 = None
            extra_l1 = None
            if self.use_chiral_l1:
                dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
                extra_l1 = dchiraldxyz[0].detach()
            msa, pair, xyz, state, alpha, symmsub, quat = self.main_block[i_m](msa, pair,
                xyz, state, seq_unmasked, idx,
                symmids, symmsub, symmRs, symmmeta,
                bond_feats,
                same_chain,
                use_checkpoint=use_checkpoint,
                top_k=topk_crop, rotation_mask=rotation_mask,
                atom_frames=atom_frames,
                extra_l0=extra_l0,
                extra_l1=extra_l1,
                is_motif=is_motif,
                dist_matrix=dist_matrix,
                use_atom_frames=use_atom_frames, crop=p2p_crop)
            xyz_s.append(xyz)
            alpha_s.append(alpha)
            quat_s.append(quat)

        _, xyzallatom = self.xyzconverter.compute_all_atom(seq_unmasked, xyz, alpha)  # think about detach here...

        backprop = torch.arange(self.n_ref_block)  # backprop through everything
        for i_m in range(self.n_ref_block):
            with ExitStack() as stack:
                if (backprop != i_m).all():
                    stack.enter_context(torch.no_grad())

                extra_l0 = None
                extra_l1 = []

                if self.use_lj_l1:
                    dljdxyz, dljdalpha = calc_lj_grads(
                        seq_unmasked, xyz.detach(), alpha.detach(),
                        self.xyzconverter.compute_all_atom,
                        bond_feats, dist_matrix,
                        self.aamask,
                        self.ljlk_parameters,
                        self.lj_correction_parameters,
                        self.num_bonds,
                        lj_lin=self.lj_lin)
                    extra_l0 = dljdalpha.reshape(1,-1,2*ChemData().NTOTALDOFS).detach()
                    extra_l1.append(dljdxyz[0].detach())

                if self.use_chiral_l1:
                    dchiraldxyz, = calc_chiral_grads(xyz.detach(),chirals)
                    #extra_l1 = torch.cat((dljdxyz[0].detach(), dchiraldxyz[0].detach()), dim=1)
                    extra_l1.append(dchiraldxyz[0].detach())
                extra_l1 = torch.cat(extra_l1, dim=1)

                xyz, state, alpha, quat = self.str_refiner(
                    msa.float(), pair.float(), xyz.detach().float(), state.float(), idx,
                    rotation_mask, bond_feats, dist_matrix, atom_frames,
                    is_motif, extra_l0, extra_l1.float(), top_k=self.refiner_topk, use_atom_frames=use_atom_frames
                )

            if symmsub is not None and symmsub.shape[0]>1:
                xyz = update_symm_Rs(xyz, Lasu, symmsub, symmRs, self.fit, self.tscale)

            xyz_s.append(xyz)
            alpha_s.append(alpha)
            quat_s.append(quat)

        _, xyzallatom = self.xyzconverter.compute_all_atom(seq_unmasked, xyz, alpha)  # think about detach here...
        xyzallatom_s = list()
        xyzallatom_s.append(xyzallatom.clone())
        # if (self.n_finetune_block>0):
        #     state = self.allatom_embed(state, seq_unmasked, self.atom_type_index)
        #
        #     for i_m in range(self.n_finetune_block):
        #         extra_l1 = None
        #
        #         xyzallatom, state = self.finetune_refiner(
        #             seq_unmasked,
        #             xyzallatom.detach().float(),
        #             self.aamask,
        #             self.num_bonds,
        #             state,
        #             extra_l1.float()
        #         )
        #
        #         xyzallatom_s.append(xyzallatom.clone())
        #
        #         state = self.residue_embed(state)

        xyz = torch.stack(xyz_s, dim=0)
        alpha_s = torch.stack(alpha_s, dim=0)
        xyzallatom_s = torch.cat(xyzallatom_s, dim=0)
        quat = torch.stack(quat_s, dim=1)

        return msa, pair, xyz, alpha_s, xyzallatom_s, state, symmsub, quat