import torch
import warnings
from icecream import ic
from torch.utils import data
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys

script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
sys.path.append(script_dir+'/../')

import numpy as np
import scipy
import scipy.sparse.csgraph # used by get_bond_distances
import networkx as nx

from rf2aa.data.parsers import parse_a3m, parse_pdb
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.util import random_rot_trans, \
    is_atom, is_protein, is_nucleic

def MSABlockDeletion(msa, ins, nb=5):
    '''
    Randomly deletes `nb` contiguous blocks of sequences (each ~30% of the MSA
    depth) from an MSA; the query (row 0) is always kept.
    Input: MSA having shape (N, L) and matching insertion features
    Output: new MSA (and insertions) with block deletion applied
    '''
    N, L = msa.shape
    block_size = max(int(N*0.3), 1)
    block_start = np.random.randint(low=1, high=N, size=nb) # (nb)
    to_delete = block_start[:,None] + np.arange(block_size)[None,:]
    to_delete = np.unique(np.clip(to_delete, 1, N-1))
    #
    mask = np.ones(N, bool)
    mask[to_delete] = 0

    return msa[mask], ins[mask]

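# Usage sketch for MSABlockDeletion (kept as a comment so importing this module
# is unchanged; the toy shapes are illustrative assumptions):
#
#   msa = torch.randint(0, 21, (100, 64))  # 100 sequences x 64 residues
#   ins = torch.zeros_like(msa)
#   msa_sub, ins_sub = MSABlockDeletion(msa, ins)
#   # row 0 (the query) is always retained; up to 5 blocks of ~30 rows are dropped
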
def cluster_sum(data, assignment, N_seq, N_res):
    """Sums rows of `data` (N, N_res, C) into `N_seq` cluster bins, given a
    per-row cluster `assignment`, via scatter_add."""
    csum = torch.zeros(N_seq, N_res, data.shape[-1], device=data.device).scatter_add(0, assignment.view(-1,1,1).expand(-1,N_res,data.shape[-1]), data.float())
    return csum

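# Worked example of cluster_sum (comment-only; values checked by hand):
# 4 extra rows assigned to 2 clusters, 3 residues, feature depth 8:
#
#   data = torch.ones(4, 3, 8)
#   assignment = torch.tensor([0, 0, 1, 0])
#   csum = cluster_sum(data, assignment, N_seq=2, N_res=3)  # shape (2, 3, 8)
#   # csum[0] is all 3.0 (rows 0, 1, 3); csum[1] is all 1.0 (row 2)
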
def get_term_feats(Ls):
    """Creates N/C-terminus binary features"""
    term_info = torch.zeros((sum(Ls),2)).float()
    start = 0
    for L_chain in Ls:
        term_info[start, 0] = 1.0 # flag for N-term
        term_info[start+L_chain-1, 1] = 1.0 # flag for C-term
        start += L_chain
    return term_info

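# Example (comment-only): get_term_feats([3, 2]) marks chain termini as
#
#   tensor([[1., 0.],   # N-term, chain 1
#           [0., 0.],
#           [0., 1.],   # C-term, chain 1
#           [1., 0.],   # N-term, chain 2
#           [0., 1.]])  # C-term, chain 2
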
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
                 term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
    '''
    Input: full MSA information (after block deletion if necessary) & full insertion information
    Output: seed MSA features & extra sequences

    Seed MSA features:
        - aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
        - profile of clustered sequences (22)
        - insertion statistics (2)
        - N-term or C-term? (2)
    Extra sequence features:
        - aatype of extra sequence (22)
        - insertion info (1)
        - N-term or C-term? (2)
    '''
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

    if fixbb:
        p_mask = 0
        msa = msa[:1]
        ins = ins[:1]
    N, L = msa.shape

    if term_info is None:
        if len(L_s)==0:
            L_s = [L]
        term_info = get_term_feats(L_s)
    term_info = term_info.to(msa.device)

    #binding_site = torch.zeros((L,1), device=msa.device).float()
    binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)

    # raw MSA profile
    raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
    raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS

    # Select Nclust sequences randomly (seed MSA or latent MSA)
    Nclust = (min(N, params['MAXLAT'])-1) // nmer
    Nclust = Nclust*nmer + 1

    if N > Nclust*2:
        Nextra = N - Nclust
    else:
        Nextra = N
    Nextra = min(Nextra, params['MAXSEQ']) // nmer
    Nextra = max(1, Nextra * nmer)
    #
    b_seq = list()
    b_msa_clust = list()
    b_msa_seed = list()
    b_msa_extra = list()
    b_mask_pos = list()
    for i_cycle in range(params['MAXCYCLE']):
        sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
        sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
        sample = torch.stack(sample, dim=-1)
        sample = sample.reshape(-1)

        # add MSA clusters pre-chosen before calling this function
        if seed_msa_clus is not None:
            sample_seed = seed_msa_clus[i_cycle]
            sample_more = torch.tensor([i for i in sample if i not in sample_seed])
            N_sample_more = len(sample) - len(sample_seed)
            if N_sample_more > 0:
                sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
                sample = torch.cat([sample_seed, sample_more])
            else:
                sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones

        msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
        ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)

        # 15% random masking
        # - 10%: aa replaced with a uniformly sampled random amino acid
        # - 10%: aa replaced with an amino acid sampled from the MSA profile
        # - 10%: not replaced
        # - 70%: replaced with a special token ("mask")
        random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
        same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
        # explicitly remove probabilities from nucleic acids and atoms
        #same_aa[..., ChemData().NPROTAAS:] = 0
        #raw_profile[...,ChemData().NPROTAAS:] = 0
        probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
        #probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)

        # explicitly set the probability of masking for nucleic acids and atoms
        #probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
        #probs[...,~is_protein(seq), :] = 0 # probably overkill but set all non-protein elements to 0
        #probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
        #probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
        #probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0

        sampler = torch.distributions.categorical.Categorical(probs=probs)
        mask_sample = sampler.sample()

        mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
        mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
        use_seq = msa_clust
        msa_masked = torch.where(mask_pos, mask_sample, use_seq)
        b_seq.append(msa_masked[0].clone())

        ## get extra sequences
        if N > Nclust*2: # there are enough extra sequences
            msa_extra = msa[1:,:][sample[Nclust-1:]]
            ins_extra = ins[1:,:][sample[Nclust-1:]]
            extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
        elif N - Nclust < 1: # no unsampled sequences left; reuse the masked seed MSA
            msa_extra = msa_masked.clone()
            ins_extra = ins_clust.clone()
            extra_mask = mask_pos.clone()
        else: # few extras; use the seed plus the remaining unsampled sequences
            msa_add = msa[1:,:][sample[Nclust-1:]]
            ins_add = ins[1:,:][sample[Nclust-1:]]
            mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
            msa_extra = torch.cat((msa_masked, msa_add), dim=0)
            ins_extra = torch.cat((ins_clust, ins_add), dim=0)
            extra_mask = torch.cat((mask_pos, mask_add), dim=0)
        N_extra = msa_extra.shape[0]

        # clustering (assign remaining sequences to their closest cluster by Hamming distance)
        msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
        msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
        count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
        count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
        agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
        assignment = torch.argmax(agreement, dim=-1)

        # seed MSA features
        # 1. one-hot encoded aatype: msa_clust_onehot
        # 2. cluster profile
        count_extra = ~extra_mask
        count_clust = ~mask_pos
        msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
        msa_clust_profile += count_clust[:,:,None]*msa_clust_onehot # add each cluster's own (unmasked) sequence
        count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
        count_profile += count_clust
        count_profile += eps
        msa_clust_profile /= count_profile[:,:,None]
        # 3. insertion statistics
        msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
        msa_clust_del += count_clust*ins_clust
        msa_clust_del /= count_profile
        ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
        msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
        ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)
        #
        if fixbb:
            assert params['MAXCYCLE'] == 1
            msa_clust_profile = msa_clust_onehot
            msa_extra_onehot = msa_clust_onehot
            ins_clust[:] = 0
            ins_extra[:] = 0
            # This is how it is done in rfdiff, but really it seems like it should be all 0.
            # Keeping as-is for now for consistency, as it may be used in downstream masking done
            # by apply_masks.
            mask_pos = torch.full_like(msa_clust, 1).bool()
        msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)

        # extra MSA features
        ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
        try:
            msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
        except Exception:
            # report shapes before re-raising, rather than silently swallowing the error
            print('msa_extra.shape', msa_extra.shape)
            print('ins_extra.shape', ins_extra.shape)
            raise

        if tocpu:
            b_msa_clust.append(msa_clust.cpu())
            b_msa_seed.append(msa_seed.cpu())
            b_msa_extra.append(msa_extra.cpu())
            b_mask_pos.append(mask_pos.cpu())
        else:
            b_msa_clust.append(msa_clust)
            b_msa_seed.append(msa_seed)
            b_msa_extra.append(msa_extra)
            b_mask_pos.append(mask_pos)

    b_seq = torch.stack(b_seq)
    b_msa_clust = torch.stack(b_msa_clust)
    b_msa_seed = torch.stack(b_msa_seed)
    b_msa_extra = torch.stack(b_msa_extra)
    b_mask_pos = torch.stack(b_mask_pos)

    return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos

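# Usage sketch for MSAFeaturize (comment-only; the params values are illustrative
# assumptions, not training defaults):
#
#   params = {'MAXLAT': 128, 'MAXSEQ': 1024, 'MAXCYCLE': 4}
#   seq, msa_clust, msa_seed, msa_extra, mask_pos = MSAFeaturize(msa, ins, params)
#   # msa_seed:  (MAXCYCLE, Nclust, L, 2*NAATOKENS + 2 + 2) per-cycle seed features
#   # msa_extra: (MAXCYCLE, Nextra, L, NAATOKENS + 1 + 2) per-cycle extra features
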
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
    """Returns 'empty' template features: initial coordinates with a small
    random offset, all-gap 1-D features with zero confidence, an all-False
    atom mask, and empty template ids."""
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(n_tmpl,L,1,1) \
        + torch.rand(n_tmpl,L,1,3)*random_noise - random_noise/2
    t1d = torch.nn.functional.one_hot(torch.full((n_tmpl, L), 20).long(), num_classes=ChemData().NAATOKENS-1).float() # all gaps
    conf = torch.zeros((n_tmpl, L, 1)).float()
    t1d = torch.cat((t1d, conf), -1)
    mask_t = torch.full((n_tmpl,L,ChemData().NTOTAL), False)
    return xyz, t1d, mask_t, np.full((n_tmpl,), "")

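# Usage sketch for blank_template (comment-only; shapes follow from the code above):
#
#   xyz, t1d, mask_t, ids = blank_template(n_tmpl=2, L=150)
#   # xyz:    (2, 150, NTOTAL, 3) initial coordinates plus small random offsets
#   # t1d:    (2, 150, NAATOKENS) all-gap one-hot with a zero confidence column
#   # mask_t: (2, 150, NTOTAL) all False (no valid template atoms)
#   # ids:    array of 2 empty strings
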
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
    if deterministic:
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        torch.cuda.manual_seed(0)

    seqID_cut = params['SEQID']

    if npick_global is None:
        npick_global = max(npick, 1)

    ntplt = len(tplt['ids'])
    if (ntplt < 1) or (npick < 1): # no templates in hhsearch file, or templates are disabled
        return blank_template(npick_global, qlen, random_noise)

    # ignore templates having too high seqID
    if seqID_cut <= 100.0:
        tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
        tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
    else:
        tplt_valid_idx = torch.arange(len(tplt['ids']))

    # check again whether any templates with seqID < cutoff remain
    ntplt = len(tplt['ids'])
    npick = min(npick, ntplt)
    if npick < 1: # no templates
        return blank_template(npick_global, qlen, random_noise)

    if not pick_top: # select randomly among all possible templates
        sample = torch.randperm(ntplt)[:npick]
    else: # only consider top 50 templates
        sample = torch.randperm(min(50,ntplt))[:npick]

    xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
    mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL), False) # True for valid atom, False for missing atom
    t1d = torch.full((npick_global, qlen), 20).long()
    t1d_val = torch.zeros((npick_global, qlen)).float()
    for i,nt in enumerate(sample):
        tplt_idx = tplt_valid_idx[nt]
        sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
        pos = tplt['qmap'][0,sel,0] + offset

        ntmplatoms = tplt['xyz'].shape[2] # will be bigger for NA templates
        xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
        mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()

        # 1-D features: alignment confidence
        t1d[i,pos] = tplt['seq'][0,sel]
        t1d_val[i,pos] = tplt['f1d'][0,sel,2] # alignment confidence
        # xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)

    t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float() # (no mask token)
    t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)

    tplt_ids = np.array(tplt["ids"])[sample].flatten() # np.array of chain ids (ordered)
    return xyz, t1d, mask_t, tplt_ids

def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
    """Diagonally tiles template coordinates, 1-D input features, and masks across
    the template and residue dimensions. The 1st template is concatenated directly on
    the residue dimension after a random rotation & translation.
    """
    N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])

    xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
    tplt_ids_out = np.full((N_tmpl_tot,), "", dtype=object) # rk bad practice.. should fix
    i_tmpl = 0
    i_res = 0
    for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
        N_tmpl, L_tmpl = xyz_.shape[:2]
        if i_tmpl == 0:
            i1, i2 = 1, N_tmpl
        else:
            i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1

        # 1st template is concatenated directly, so that all atoms are set in xyz_prev
        xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
        f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
        mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]

        if not tplt_ids_out[0]: # only add first template
            tplt_ids_out[0] = ids[0]
        # remaining templates are diagonally tiled
        xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
        f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
        mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
        tplt_ids_out[i1:i2] = ids[1:]
        if i_tmpl == 0:
            i_tmpl += N_tmpl
        else:
            i_tmpl += N_tmpl-1
        i_res += L_tmpl

    return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out

def generate_xyz_prev(xyz_t, mask_t, params):
    """
    Allows you to use different initializations for the coordinate track, as specified in params.
    """
    L = xyz_t.shape[1]
    if params["BLACK_HOLE_INIT"]:
        xyz_t, _, mask_t, _ = blank_template(1, L) # blank_template returns 4 values; ids are unused here
    return xyz_t[0].clone(), mask_t[0].clone()

### merge msa & insertion statistics of two proteins having different taxIDs
def merge_a3m_hetero(a3mA, a3mB, L_s):
    # merge msa
    query = torch.cat([a3mA['msa'][0], a3mB['msa'][0]]).unsqueeze(0) # (1, L)

    msa = [query]
    if a3mA['msa'].shape[0] > 1:
        extra_A = torch.nn.functional.pad(a3mA['msa'][1:], (0,sum(L_s[1:])), "constant", 20) # pad with gaps
        msa.append(extra_A)
    if a3mB['msa'].shape[0] > 1:
        extra_B = torch.nn.functional.pad(a3mB['msa'][1:], (L_s[0],0), "constant", 20)
        msa.append(extra_B)
    msa = torch.cat(msa, dim=0)

    # merge ins
    query = torch.cat([a3mA['ins'][0], a3mB['ins'][0]]).unsqueeze(0) # (1, L)
    ins = [query]
    if a3mA['ins'].shape[0] > 1:
        extra_A = torch.nn.functional.pad(a3mA['ins'][1:], (0,sum(L_s[1:])), "constant", 0) # pad with zeros
        ins.append(extra_A)
    if a3mB['ins'].shape[0] > 1:
        extra_B = torch.nn.functional.pad(a3mB['ins'][1:], (L_s[0],0), "constant", 0)
        ins.append(extra_B)
    ins = torch.cat(ins, dim=0)

    a3m = {'msa': msa, 'ins': ins}

    # merge taxids
    if 'taxid' in a3mA and 'taxid' in a3mB:
        a3m['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])

    return a3m

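# Layout produced by merge_a3m_hetero (comment-only sketch, '-' = gap padding):
#
#   queryA queryB    <- row 0: concatenated query sequences
#   seqA   ------    <- extra sequences of A, gap-padded on the right
#   ------ seqB      <- extra sequences of B, gap-padded on the left
#
#   a3m = merge_a3m_hetero({'msa': msaA, 'ins': insA}, {'msa': msaB, 'ins': insB},
#                          [L_A, L_B])
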
# merge msa & insertion statistics of units in homo-oligomers
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
    N, L = msa_orig.shape[:2]
    if mode == "repeat":

        # AAAAAA
        # AAAAAA

        msa = torch.tile(msa_orig,(1,nmer))
        ins = torch.tile(ins_orig,(1,nmer))

    elif mode == "diag":

        # AAAAAA
        # A-----
        # -A----
        # --A---
        # ---A--
        # ----A-
        # -----A

        N = N - 1
        new_N = 1 + N * nmer
        new_L = L * nmer
        msa = torch.full((new_N, new_L), 20, dtype=msa_orig.dtype, device=msa_orig.device)
        ins = torch.full((new_N, new_L), 0, dtype=ins_orig.dtype, device=msa_orig.device)

        start_L = 0
        start_N = 1
        for i_c in range(nmer):
            msa[0, start_L:start_L+L] = msa_orig[0]
            msa[start_N:start_N+N, start_L:start_L+L] = msa_orig[1:]
            ins[0, start_L:start_L+L] = ins_orig[0]
            ins[start_N:start_N+N, start_L:start_L+L] = ins_orig[1:]
            start_L += L
            start_N += N
    else:

        # AAAAAA
        # A-----
        # -AAAAA

        msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
        ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)

        msa[:N, :L] = msa_orig
        ins[:N, :L] = ins_orig
        start = L

        for i_c in range(1,nmer):
            msa[0, start:start+L] = msa_orig[0]
            msa[N:, start:start+L] = msa_orig[1:]
            ins[0, start:start+L] = ins_orig[0]
            ins[N:, start:start+L] = ins_orig[1:]
            start += L

    return msa, ins

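# Example shapes for merge_a3m_homo with nmer=2 (comment-only):
#
#   msa2, ins2 = merge_a3m_homo(msa, ins, 2, mode="diag")
#   # mode="repeat": (N, 2L)          every row tiled across both copies
#   # mode="diag":   (1 + 2(N-1), 2L) non-query rows offset-tiled per the diagram
#   # default:       (2N-1, 2L)       copy 1 unpaired, remaining copies stacked right
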
def merge_msas(a3m_list, L_s):
    """
    Takes a list of a3m dictionaries with keys msa, ins, plus a list of protein
    lengths, and creates a combined MSA.
    """
    seen = set()
    taxIDs = []
    a3mA = a3m_list[0]
    taxIDs.extend(a3mA["taxID"])
    seen.add(a3mA["hash"])
    msaA, insA = a3mA["msa"], a3mA["ins"]
    for i in range(1, len(a3m_list)):
        a3mB = a3m_list[i]
        pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
        if a3mB["hash"] in seen or len(pair_taxIDs) < 5: # homomer, or not enough pairs: pad & stack
            a3mA = {"msa": msaA, "ins": insA}
            L_s_to_merge = [sum(L_s[:i]), L_s[i]]
            a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
            msaA, insA = a3mA["msa"], a3mA["ins"]
            taxIDs.extend(a3mB["taxID"])
        else:
            final_pairsA = []
            final_pairsB = []
            msaB, insB = a3mB["msa"], a3mB["ins"]
            for pair in pair_taxIDs:
                pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
                pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
                msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :], axis=-1))
                msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :], axis=-1))
                final_pairsA.append(pair_a3mA[msaApair])
                final_pairsB.append(pair_a3mB[msaBpair])
            paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
            paired_msaB[final_pairsA] = msaB[final_pairsB]
            msaA = torch.cat([msaA, paired_msaB], dim=1)
            insA = torch.zeros_like(msaA) # paired MSAs in our dataset don't have insertions
            seen.add(a3mB["hash"])

    return msaA, insA

def remove_all_gap_seqs(a3m):
    """Removes sequences that are all gaps from an MSA represented as an `a3m` dictionary"""
    idx_seq_keep = ~(a3m['msa']==ChemData().UNKINDEX).all(dim=1)
    a3m['msa'] = a3m['msa'][idx_seq_keep]
    a3m['ins'] = a3m['ins'][idx_seq_keep]
    return a3m

def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
    """Joins (or "pairs") 2 MSAs by matching sequences with the same
    taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
    ID, only the sequence with the highest sequence identity to the query (1st
    sequence in MSA) will be paired.

    Sequences that aren't paired will be padded and added to the bottom of the
    joined MSA. If a subregion of the input MSAs overlap (represent the same
    chain), the subregion residue indices can be given as `idx_overlap`, and
    the overlap region of the unpaired sequences will be included in the joined
    MSA.

    Parameters
    ----------
    a3mA : dict
        First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
        L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
        boolean tensor indicating whether each sequence is fully paired. Can be
        a multi-MSA (contain >2 sub-MSAs).
    a3mB : dict
        2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
        `is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
    idx_overlap : tuple or list (optional)
        Start and end indices of overlap region in 1st MSA, followed by the
        same in 2nd MSA.

    Returns
    -------
    a3m : dict
        Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
    """
    # preprocess overlap region
    L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
    if idx_overlap is not None:
        i1A, i2A, i1B, i2B = idx_overlap
        i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
        assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
            "When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
            "(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
    else:
        i1B_new, i2B_new = (0, L_B)

    # pair sequences
    taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
    i_pairedA, i_pairedB = [], []
    for taxid in taxids_shared:
        i_match = np.where(a3mA['taxid']==taxid)[0]
        i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
        i_pairedA.append(i_match[i_match_best])

        i_match = np.where(a3mB['taxid']==taxid)[0]
        i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
        i_pairedB.append(i_match[i_match_best])

    # unpaired sequences
    i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
    i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
    N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)

    # handle overlap region
    # if MSA A consists of sub-MSAs 1,2,3 and MSA B of 2,4 (i.e. the overlap region is 2),
    # this diagram shows how the variables below make up the final multi-MSA
    # (* denotes non-gaps, - denotes gaps)
    #    1 2 3 4
    #   |*|*|*|*|  msa_paired
    #   |*|*|*|-|  msaA_unpaired
    #   |-|*|-|*|  msaB_unpaired
    if idx_overlap is not None:
        assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
               (a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
            'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'

        # overlap region gets sequences from 2nd MSA because sometimes 1st MSA will be all gaps here
        msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
                                a3mB['msa'][i_pairedB, i1B:i2B],
                                a3mA['msa'][i_pairedA, i2A:],
                                a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B:i2B],
                                   torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
    else:
        # no overlap region, simple offset pad & stack
        # this code is actually a special case of the "if" block above, but writing
        # it out explicitly here makes the logic more clear
        msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
        msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
                                   torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
        msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
                                   a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps

    # stack paired & unpaired
    msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
    taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])

    # label "fully paired" sequences (a row of the MSA that was never padded with gaps)
    # an output seq is fully paired if seqs A & B both started out as paired and were
    # paired to each other on tax ID.
    # NOTE: there is a rare edge case that is ignored here for simplicity: if
    # pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
    # 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
    # will not be labeled as such here
    is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
    is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
    is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
                           torch.zeros((N_unpairedA + N_unpairedB,)).bool()])

    # insertion features in paired MSAs are assumed to be zero
    a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
    return a3m

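# Usage sketch for join_msas_by_taxid (comment-only; the taxids shown are
# hypothetical):
#
#   a3mA = dict(msa=msaA, ins=insA, taxid=np.array(['9606', '10090']))
#   a3mB = dict(msa=msaB, ins=insB, taxid=np.array(['9606', '7227']))
#   joined = join_msas_by_taxid(a3mA, a3mB)
#   # joined['msa'] stacks paired rows (matched on taxid), then gap-padded
#   # unpaired rows of A, then of B; joined['is_paired'] flags the paired rows
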
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
    """Load a multi-MSA, which is an MSA that is paired across more than 2
    chains. This loads the MSA for unique chains. Use `expand_multi_msa` to
    duplicate portions of the MSA for homo-oligomer repeated chains.

    Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
    it can) and pairs sequences across as many sub-MSAs as possible by matching
    taxonomic ID. For details on how pairing is done, see
    `join_msas_by_taxid()`

    Parameters
    ----------
    hash_list : list of str
        Hashes of MSAs to load and join. Must not contain duplicates.
    taxid_list : list of str
        Taxonomic IDs of query sequences of each input MSA.
    Ls : list of int
        Lengths of the chains corresponding to the hashes.

    Returns
    -------
    a3m_out : dict
        Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
        `taxid` [np.array (N_seq,)], `is_paired` [torch.Tensor (N_seq,)]
    hashes_out : list of str
        Hashes of MSAs in the order that they are joined in `a3m_out`.
        Contains the same elements as the input `hash_list` but may be in a
        different order.
    Ls_out : list of int
        Lengths of each chain in `a3m_out`
    """
    assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'

    # the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
    # comprising sub-MSAs whose indices in the input lists are
    # `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
    # `hash_list[i_in]` and lengths are `Ls[i_in]`.
    # Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
    # where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
    a3m_list = [] # list of multi-MSAs
    idx_list_groups = [] # list of lists of indices of input chains making up each multi-MSA
    res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA

    # iterate through all pairs of hashes and look for paired MSAs (pMSAs)
    # NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
    # later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
    # and 2+3 pMSAs will still be joined on taxID at the end, but sequences
    # only present in the 0+2 pMSA will be missed. This is probably very
    # rare and so is ignored here for simplicity.
    N = len(hash_list)
    for i1, i2 in itertools.permutations(range(N),2):

        idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
        if i1 in idx_list and i2 in idx_list: continue # already loaded
        if not taxid_list[i1] or not taxid_list[i2]: continue # no taxID means no pMSA

        # a paired MSA may exist when the query sequences share a taxID
        if taxid_list[i1]==taxid_list[i2]:

            h1, h2 = hash_list[i1], hash_list[i2]
            fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'

            if os.path.exists(fn):
                msa, ins, taxid = parse_a3m(fn, paired=True)
                a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
                               is_paired=torch.ones(msa.shape[0]).bool())
                res_range1 = (0,Ls[i1])
                res_range2 = (Ls[i1],msa.shape[1])

                # both hashes are new, add paired MSA to list
                if i1 not in idx_list and i2 not in idx_list:
                    a3m_list.append(a3m_new)
                    idx_list_groups.append([i1,i2])
                    res_range_groups.append([res_range1, res_range2])

                # one of the hashes is already in a multi-MSA;
                # find that multi-MSA and join the new pMSA to it
                elif i1 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i1`?
                    i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]

                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)

                    idx_list_groups[i_a3m].append(i2)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range2[1] - res_range2[0]
                    res_range_groups[i_a3m].append((L, L+L_new))

                elif i2 in idx_list:
                    # which multi-MSA & sub-MSA has the hash with index `i2`?
                    i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
                    i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]

                    idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
                    a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)

                    idx_list_groups[i_a3m].append(i1)
                    L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
                    L_new = res_range1[1] - res_range1[0]
                    res_range_groups[i_a3m].append((L, L+L_new))

    # add unpaired MSAs
    # ungroup hash indices now, since we're done making multi-MSAs
    idx_list = [x for group in idx_list_groups for x in group]
    for i in range(N):
        if i not in idx_list:
            fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
            msa, ins, taxid = parse_a3m(fn)
            a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
                           taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
            a3m_list.append(a3m_new)
            idx_list.append(i)

    Ls_out = [Ls[i] for i in idx_list]
    hashes_out = [hash_list[i] for i in idx_list]

    # join multi-MSAs & unpaired MSAs
    a3m_out = a3m_list[0]
    for i in range(1, len(a3m_list)):
        a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])

    return a3m_out, hashes_out, Ls_out

def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out, params):
    """Expands a multi-MSA of unique chains into an MSA of a
    hetero-homo-oligomer in which some chains appear more than once. The query
    sequences (1st sequence of MSA) are concatenated directly along the
    residue dimension. The remaining sequences are offset-tiled (i.e. "padded &
    stacked") so that exact repeat sequences aren't paired.

    For example, if the original multi-MSA contains unique chains 1,2,3 but
    the final chain order is 1,2,1,3,3,1, this function will output an MSA like
    (where - denotes a block of gap characters):

        1 2 - 3 - -
        - - 1 - 3 -
        - - - - - 1

    Parameters
    ----------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (N_seq,),
        representing the multi-MSA of unique chains.
    hashes_in : list of str
        Unique MSA hashes used in `a3m`.
    hashes_out : list of str
        Non-unique MSA hashes desired in expanded MSA.
    Ls_in : list of int
        Lengths of each chain in `a3m`
    Ls_out : list of int
        Lengths of each chain desired in expanded MSA.
    params : dict
        Data loading parameters

    Returns
    -------
    a3m : dict
        Contains torch.Tensors `msa` and `ins` of expanded MSA. No
        taxids because no further joining needs to be done.
    """
    assert(len(hashes_out)==len(Ls_out))
    assert(set(hashes_in)==set(hashes_out))
    assert(a3m['msa'].shape[1]==sum(Ls_in))

    # figure out which oligomeric repeat is represented by each hash in `hashes_out`;
    # each new repeat will be offset in the sequence dimension of the final MSA
    counts = dict()
    n_copy = [] # n-th copy of this hash in `hashes_out`
    for h in hashes_out:
        if h in counts:
            counts[h] += 1
        else:
            counts[h] = 1
        n_copy.append(counts[h])

    # num sequences in source & destination MSAs
    N_in = a3m['msa'].shape[0]
    N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad & stack the rest

    # source MSA
    msa_in, ins_in = a3m['msa'], a3m['ins']

    # initialize destination MSA to gap characters
    msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
    ins_out = torch.full((N_out, sum(Ls_out)), 0)

    # for each destination chain
    for i_out, h_out in enumerate(hashes_out):
        # identify index of source chain
        i_in = np.where(np.array(hashes_in)==h_out)[0][0]

        # residue indexes
        i1_res_in = sum(Ls_in[:i_in])
        i2_res_in = sum(Ls_in[:i_in+1])
        i1_res_out = sum(Ls_out[:i_out])
        i2_res_out = sum(Ls_out[:i_out+1])

        # copy over query sequence
        msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
        ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]

        # offset non-query sequences along the sequence dimension based on the repeat number of a given hash
        i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
        i2_seq_out = 1+n_copy[i_out]*(N_in-1)
        # copy over non-query sequences
        msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
        ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]

    # only the 1st oligomeric repeat can be fully paired
    is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])

    a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
    a3m_out = remove_all_gap_seqs(a3m_out)

    return a3m_out

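# Usage sketch for expand_multi_msa (comment-only; hashes 'h1'/'h2' are hypothetical):
#
#   # unique chains h1 (L=10) and h2 (L=8); desired assembly h1,h2,h1
#   a3m_exp = expand_multi_msa(a3m, ['h1','h2'], ['h1','h2','h1'],
#                              [10, 8], [10, 8, 10], params)
#   # queries are concatenated on row 0; non-query rows of the 2nd copy of h1
#   # are offset-tiled below so identical repeat sequences are never paired
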
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
    """Loads a multi-MSA for an arbitrary number of protein chains. Tries to
    locate paired MSAs and pair sequences across all chains by taxonomic ID.
    Unpaired sequences are padded and stacked on the bottom.
    """
    # get MSA hashes (used to locate a3m files) and taxonomic IDs (used to determine pairing)
    hashes = []
    hashes_unique = []
    taxids_unique = []
    Ls_unique = []
    for chid, L_ in zip(chain_ids, Ls):
        hashes.append(chid2hash[chid])
        if chid2hash[chid] not in hashes_unique:
            hashes_unique.append(chid2hash[chid])
            taxids_unique.append(chid2taxid.get(chid))
            Ls_unique.append(L_)

    # load multi-MSA for unique chains
    a3m_prot, hashes_unique, Ls_unique = \
        load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)

    # expand multi-MSA to repeat chains of homo-oligomers
    a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)

    return a3m_prot

def choose_multimsa_clusters(msa_seq_is_paired, params):
    """Returns indices of fully-paired sequences in a multi-MSA to use as seed
    clusters during MSA featurization.
    """
    frac_paired = msa_seq_is_paired.float().mean()
    if frac_paired > 0.25: # enough fully paired sequences, just let MSAFeaturize choose randomly
        return None
    else:
        # ensure that half of the clusters are fully-paired sequences,
        # and let the rest be chosen randomly
        N_seed = params['MAXLAT']//2
        msa_seed_clus = []
        for i_cycle in range(params['MAXCYCLE']):
            idx_paired = torch.where(msa_seq_is_paired)[0]
            msa_seed_clus.append(idx_paired[torch.randperm(len(idx_paired))][:N_seed])
        return msa_seed_clus

#fd
def get_bond_distances(bond_feats):
    atom_bonds = (bond_feats > 0)*(bond_feats < 5)
    dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds.long().numpy(), directed=False)
    # dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0)) # protein portion is inf and you don't want to mask it out
    return torch.from_numpy(dist_matrix).float()

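# Worked example for get_bond_distances (comment-only): a 3-atom chain A-B-C with
# single bonds (bond type 1) gives graph distances in bond hops:
#
#   bond_feats = torch.tensor([[0,1,0],[1,0,1],[0,1,0]])
#   get_bond_distances(bond_feats)
#   # tensor([[0., 1., 2.],
#   #         [1., 0., 1.],
#   #         [2., 1., 0.]])
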
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
    xyz, mask, res_idx = parse_pdb(pdbfilename)
    plddt = np.load(plddtfilename)

    # update mask info with plddt (ignore sidechains if plddt < sccut)
    mask_lddt = np.full_like(mask, False)
    mask_lddt[plddt > sccut] = True
    mask_lddt[:,:5] = True # always keep the first 5 (backbone) atoms
    mask = np.logical_and(mask, mask_lddt)
    mask = np.logical_and(mask, (plddt > lddtcut)[:,None])

    return {'xyz':torch.tensor(xyz), 'mask':torch.tensor(mask), 'idx': torch.tensor(res_idx), 'label':item}

def get_msa(a3mfilename, item, maxseq=5000):
    msa, ins, taxIDs = parse_a3m(a3mfilename, maxseq=maxseq)
    return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'taxIDs':taxIDs, 'label':item}