RoseTTAFold-All-Atom/rf2aa/data/data_loader_utils.py
2024-03-05 16:45:27 -08:00

910 lines
39 KiB
Python

import torch
import warnings
import time
from icecream import ic
from torch.utils import data
import os, csv, random, pickle, gzip, itertools, time, ast, copy, sys
script_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(script_dir)
sys.path.append(script_dir+'/../')
import numpy as np
import scipy
import networkx as nx
from rf2aa.data.parsers import parse_a3m, parse_pdb
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.util import random_rot_trans, \
is_atom, is_protein, is_nucleic, is_atom
def MSABlockDeletion(msa, ins, nb=5):
'''
Input: MSA having shape (N, L)
output: new MSA with block deletion
'''
N, L = msa.shape
block_size = max(int(N*0.3), 1)
block_start = np.random.randint(low=1, high=N, size=nb) # (nb)
to_delete = block_start[:,None] + np.arange(block_size)[None,:]
to_delete = np.unique(np.clip(to_delete, 1, N-1))
#
mask = np.ones(N, bool)
mask[to_delete] = 0
return msa[mask], ins[mask]
def cluster_sum(data, assignment, N_seq, N_res):
csum = torch.zeros(N_seq, N_res, data.shape[-1], device=data.device).scatter_add(0, assignment.view(-1,1,1).expand(-1,N_res,data.shape[-1]), data.float())
return csum
def get_term_feats(Ls):
"""Creates N/C-terminus binary features"""
term_info = torch.zeros((sum(Ls),2)).float()
start = 0
for L_chain in Ls:
term_info[start, 0] = 1.0 # flag for N-term
term_info[start+L_chain-1,1] = 1.0 # flag for C-term
start += L_chain
return term_info
def MSAFeaturize(msa, ins, params, p_mask=0.15, eps=1e-6, nmer=1, L_s=[],
term_info=None, tocpu=False, fixbb=False, seed_msa_clus=None, deterministic=False):
'''
Input: full MSA information (after Block deletion if necessary) & full insertion information
Output: seed MSA features & extra sequences
Seed MSA features:
- aatype of seed sequence (20 regular aa + 1 gap/unknown + 1 mask)
- profile of clustered sequences (22)
- insertion statistics (2)
- N-term or C-term? (2)
extra sequence features:
- aatype of extra sequence (22)
- insertion info (1)
- N-term or C-term? (2)
'''
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
# TODO: delete me, just for testing purposes
msa = msa[:2]
if fixbb:
p_mask = 0
msa = msa[:1]
ins = ins[:1]
N, L = msa.shape
if term_info is None:
if len(L_s)==0:
L_s = [L]
term_info = get_term_feats(L_s)
term_info = term_info.to(msa.device)
#binding_site = torch.zeros((L,1), device=msa.device).float()
binding_site = torch.zeros((L,0), device=msa.device).float() # keeping this off for now (Jue 12/19)
# raw MSA profile
raw_profile = torch.nn.functional.one_hot(msa, num_classes=ChemData().NAATOKENS) # N x L x NAATOKENS
raw_profile = raw_profile.float().mean(dim=0) # L x NAATOKENS
# Select Nclust sequence randomly (seed MSA or latent MSA)
Nclust = (min(N, params['MAXLAT'])-1) // nmer
Nclust = Nclust*nmer + 1
if N > Nclust*2:
Nextra = N - Nclust
else:
Nextra = N
Nextra = min(Nextra, params['MAXSEQ']) // nmer
Nextra = max(1, Nextra * nmer)
#
b_seq = list()
b_msa_clust = list()
b_msa_seed = list()
b_msa_extra = list()
b_mask_pos = list()
for i_cycle in range(params['MAXCYCLE']):
sample_mono = torch.randperm((N-1)//nmer, device=msa.device)
sample = [sample_mono + imer*((N-1)//nmer) for imer in range(nmer)]
sample = torch.stack(sample, dim=-1)
sample = sample.reshape(-1)
# add MSA clusters pre-chosen before calling this function
if seed_msa_clus is not None:
sample_orig_shape = sample.shape
sample_seed = seed_msa_clus[i_cycle]
sample_more = torch.tensor([i for i in sample if i not in sample_seed])
N_sample_more = len(sample) - len(sample_seed)
if N_sample_more > 0:
sample_more = sample_more[torch.randperm(len(sample_more))[:N_sample_more]]
sample = torch.cat([sample_seed, sample_more])
else:
sample = sample_seed[:len(sample)] # take all clusters from pre-chosen ones
msa_clust = torch.cat((msa[:1,:], msa[1:,:][sample[:Nclust-1]]), dim=0)
ins_clust = torch.cat((ins[:1,:], ins[1:,:][sample[:Nclust-1]]), dim=0)
# 15% random masking
# - 10%: aa replaced with a uniformly sampled random amino acid
# - 10%: aa replaced with an amino acid sampled from the MSA profile
# - 10%: not replaced
# - 70%: replaced with a special token ("mask")
random_aa = torch.tensor([[0.05]*20 + [0.0]*(ChemData().NAATOKENS-20)], device=msa.device)
same_aa = torch.nn.functional.one_hot(msa_clust, num_classes=ChemData().NAATOKENS)
# explicitly remove probabilities from nucleic acids and atoms
#same_aa[..., ChemData().NPROTAAS:] = 0
#raw_profile[...,ChemData().NPROTAAS:] = 0
probs = 0.1*random_aa + 0.1*raw_profile + 0.1*same_aa
#probs = torch.nn.functional.pad(probs, (0, 1), "constant", 0.7)
# explicitly set the probability of masking for nucleic acids and atoms
#probs[...,is_protein(seq),ChemData().MASKINDEX]=0.7
#probs[...,~is_protein(seq), :] = 0 # probably overkill but set all none protein elements to 0
#probs[1:, ~is_protein(seq),20] = 1.0 # want to leave the gaps as gaps
#probs[0,is_nucleic(seq), ChemData().MASKINDEX] = 1.0
#probs[0,is_atom(seq), ChemData().aa2num["ATM"]] = 1.0
sampler = torch.distributions.categorical.Categorical(probs=probs)
mask_sample = sampler.sample()
mask_pos = torch.rand(msa_clust.shape, device=msa_clust.device) < p_mask
mask_pos[msa_clust>ChemData().MASKINDEX]=False # no masking on NAs
use_seq = msa_clust
msa_masked = torch.where(mask_pos, mask_sample, use_seq)
b_seq.append(msa_masked[0].clone())
## get extra sequenes
if N > Nclust*2: # there are enough extra sequences
msa_extra = msa[1:,:][sample[Nclust-1:]]
ins_extra = ins[1:,:][sample[Nclust-1:]]
extra_mask = torch.full(msa_extra.shape, False, device=msa_extra.device)
elif N - Nclust < 1:
msa_extra = msa_masked.clone()
ins_extra = ins_clust.clone()
extra_mask = mask_pos.clone()
else:
msa_add = msa[1:,:][sample[Nclust-1:]]
ins_add = ins[1:,:][sample[Nclust-1:]]
mask_add = torch.full(msa_add.shape, False, device=msa_add.device)
msa_extra = torch.cat((msa_masked, msa_add), dim=0)
ins_extra = torch.cat((ins_clust, ins_add), dim=0)
extra_mask = torch.cat((mask_pos, mask_add), dim=0)
N_extra = msa_extra.shape[0]
# clustering (assign remaining sequences to their closest cluster by Hamming distance
msa_clust_onehot = torch.nn.functional.one_hot(msa_masked, num_classes=ChemData().NAATOKENS)
msa_extra_onehot = torch.nn.functional.one_hot(msa_extra, num_classes=ChemData().NAATOKENS)
count_clust = torch.logical_and(~mask_pos, msa_clust != 20).float() # 20: index for gap, ignore both masked & gaps
count_extra = torch.logical_and(~extra_mask, msa_extra != 20).float()
agreement = torch.matmul((count_extra[:,:,None]*msa_extra_onehot).view(N_extra, -1), (count_clust[:,:,None]*msa_clust_onehot).view(Nclust, -1).T)
assignment = torch.argmax(agreement, dim=-1)
# seed MSA features
# 1. one_hot encoded aatype: msa_clust_onehot
# 2. cluster profile
count_extra = ~extra_mask
count_clust = ~mask_pos
msa_clust_profile = cluster_sum(count_extra[:,:,None]*msa_extra_onehot, assignment, Nclust, L)
msa_clust_profile += count_clust[:,:,None]*msa_clust_profile
count_profile = cluster_sum(count_extra[:,:,None], assignment, Nclust, L).view(Nclust, L)
count_profile += count_clust
count_profile += eps
msa_clust_profile /= count_profile[:,:,None]
# 3. insertion statistics
msa_clust_del = cluster_sum((count_extra*ins_extra)[:,:,None], assignment, Nclust, L).view(Nclust, L)
msa_clust_del += count_clust*ins_clust
msa_clust_del /= count_profile
ins_clust = (2.0/np.pi)*torch.arctan(ins_clust.float()/3.0) # (from 0 to 1)
msa_clust_del = (2.0/np.pi)*torch.arctan(msa_clust_del.float()/3.0) # (from 0 to 1)
ins_clust = torch.stack((ins_clust, msa_clust_del), dim=-1)
#
if fixbb:
assert params['MAXCYCLE'] == 1
msa_clust_profile = msa_clust_onehot
msa_extra_onehot = msa_clust_onehot
ins_clust[:] = 0
ins_extra[:] = 0
# This is how it is done in rfdiff, but really it seems like it should be all 0.
# Keeping as-is for now for consistency, as it may be used in downstream masking done
# by apply_masks.
mask_pos = torch.full_like(msa_clust, 1).bool()
msa_seed = torch.cat((msa_clust_onehot, msa_clust_profile, ins_clust, term_info[None].expand(Nclust,-1,-1)), dim=-1)
# extra MSA features
ins_extra = (2.0/np.pi)*torch.arctan(ins_extra[:Nextra].float()/3.0) # (from 0 to 1)
try:
msa_extra = torch.cat((msa_extra_onehot[:Nextra], ins_extra[:,:,None], term_info[None].expand(Nextra,-1,-1)), dim=-1)
except Exception as e:
print('msa_extra.shape',msa_extra.shape)
print('ins_extra.shape',ins_extra.shape)
if (tocpu):
b_msa_clust.append(msa_clust.cpu())
b_msa_seed.append(msa_seed.cpu())
b_msa_extra.append(msa_extra.cpu())
b_mask_pos.append(mask_pos.cpu())
else:
b_msa_clust.append(msa_clust)
b_msa_seed.append(msa_seed)
b_msa_extra.append(msa_extra)
b_mask_pos.append(mask_pos)
b_seq = torch.stack(b_seq)
b_msa_clust = torch.stack(b_msa_clust)
b_msa_seed = torch.stack(b_msa_seed)
b_msa_extra = torch.stack(b_msa_extra)
b_mask_pos = torch.stack(b_mask_pos)
return b_seq, b_msa_clust, b_msa_seed, b_msa_extra, b_mask_pos
def blank_template(n_tmpl, L, random_noise=5.0, deterministic: bool = False):
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(n_tmpl,L,1,1) \
+ torch.rand(n_tmpl,L,1,3)*random_noise - random_noise/2
t1d = torch.nn.functional.one_hot(torch.full((n_tmpl, L), 20).long(), num_classes=ChemData().NAATOKENS-1).float() # all gaps
conf = torch.zeros((n_tmpl, L, 1)).float()
t1d = torch.cat((t1d, conf), -1)
mask_t = torch.full((n_tmpl,L,ChemData().NTOTAL), False)
return xyz, t1d, mask_t, np.full((n_tmpl), "")
def TemplFeaturize(tplt, qlen, params, offset=0, npick=1, npick_global=None, pick_top=True, same_chain=None, random_noise=5, deterministic: bool = False):
if deterministic:
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
seqID_cut = params['SEQID']
if npick_global == None:
npick_global=max(npick, 1)
ntplt = len(tplt['ids'])
if (ntplt < 1) or (npick < 1): #no templates in hhsearch file or not want to use templ
return blank_template(npick_global, qlen, random_noise)
# ignore templates having too high seqID
if seqID_cut <= 100.0:
tplt_valid_idx = torch.where(tplt['f0d'][0,:,4] < seqID_cut)[0]
tplt['ids'] = np.array(tplt['ids'])[tplt_valid_idx]
else:
tplt_valid_idx = torch.arange(len(tplt['ids']))
# check again if there are templates having seqID < cutoff
ntplt = len(tplt['ids'])
npick = min(npick, ntplt)
if npick<1: # no templates
return blank_template(npick_global, qlen, random_noise)
if not pick_top: # select randomly among all possible templates
sample = torch.randperm(ntplt)[:npick]
else: # only consider top 50 templates
sample = torch.randperm(min(50,ntplt))[:npick]
xyz = ChemData().INIT_CRDS.reshape(1,1,ChemData().NTOTAL,3).repeat(npick_global,qlen,1,1) + torch.rand(1,qlen,1,3)*random_noise
mask_t = torch.full((npick_global,qlen,ChemData().NTOTAL),False) # True for valid atom, False for missing atom
t1d = torch.full((npick_global, qlen), 20).long()
t1d_val = torch.zeros((npick_global, qlen)).float()
for i,nt in enumerate(sample):
tplt_idx = tplt_valid_idx[nt]
sel = torch.where(tplt['qmap'][0,:,1]==tplt_idx)[0]
pos = tplt['qmap'][0,sel,0] + offset
ntmplatoms = tplt['xyz'].shape[2] # will be bigger for NA templates
xyz[i,pos,:ntmplatoms] = tplt['xyz'][0,sel]
mask_t[i,pos,:ntmplatoms] = tplt['mask'][0,sel].bool()
# 1-D features: alignment confidence
t1d[i,pos] = tplt['seq'][0,sel]
t1d_val[i,pos] = tplt['f1d'][0,sel,2] # alignment confidence
# xyz[i] = center_and_realign_missing(xyz[i], mask_t[i], same_chain=same_chain)
t1d = torch.nn.functional.one_hot(t1d, num_classes=ChemData().NAATOKENS-1).float() # (no mask token)
t1d = torch.cat((t1d, t1d_val[...,None]), dim=-1)
tplt_ids = np.array(tplt["ids"])[sample].flatten() # np.array of chain ids (ordered)
return xyz, t1d, mask_t, tplt_ids
def merge_hetero_templates(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids, Ls_prot, deterministic: bool = False):
"""Diagonally tiles template coordinates, 1d input features, and masks across
template and residue dimensions. 1st template is concatenated directly on residue
dimension after a random rotation & translation.
"""
N_tmpl_tot = sum([x.shape[0] for x in xyz_t_prot])
xyz_t_out, f1d_t_out, mask_t_out, _ = blank_template(N_tmpl_tot, sum(Ls_prot))
tplt_ids_out = np.full((N_tmpl_tot),"", dtype=object) # rk bad practice.. should fix
i_tmpl = 0
i_res = 0
for xyz_, f1d_, mask_, ids in zip(xyz_t_prot, f1d_t_prot, mask_t_prot, tplt_ids):
N_tmpl, L_tmpl = xyz_.shape[:2]
if i_tmpl == 0:
i1, i2 = 1, N_tmpl
else:
i1, i2 = i_tmpl, i_tmpl+N_tmpl - 1
# 1st template is concatenated directly, so that all atoms are set in xyz_prev
xyz_t_out[0, i_res:i_res+L_tmpl] = random_rot_trans(xyz_[0:1], deterministic=deterministic)
f1d_t_out[0, i_res:i_res+L_tmpl] = f1d_[0]
mask_t_out[0, i_res:i_res+L_tmpl] = mask_[0]
if not tplt_ids_out[0]: # only add first template
tplt_ids_out[0] = ids[0]
# remaining templates are diagonally tiled
xyz_t_out[i1:i2, i_res:i_res+L_tmpl] = xyz_[1:]
f1d_t_out[i1:i2, i_res:i_res+L_tmpl] = f1d_[1:]
mask_t_out[i1:i2, i_res:i_res+L_tmpl] = mask_[1:]
tplt_ids_out[i1:i2] = ids[1:]
if i_tmpl == 0:
i_tmpl += N_tmpl
else:
i_tmpl += N_tmpl-1
i_res += L_tmpl
return xyz_t_out, f1d_t_out, mask_t_out, tplt_ids_out
def generate_xyz_prev(xyz_t, mask_t, params):
"""
allows you to use different initializations for the coordinate track specified in params
"""
L = xyz_t.shape[1]
if params["BLACK_HOLE_INIT"]:
xyz_t, _, mask_t = blank_template(1, L)
return xyz_t[0].clone(), mask_t[0].clone()
### merge msa & insertion statistics of two proteins having different taxID
def merge_a3m_hetero(a3mA, a3mB, L_s):
# merge msa
query = torch.cat([a3mA['msa'][0], a3mB['msa'][0]]).unsqueeze(0) # (1, L)
msa = [query]
if a3mA['msa'].shape[0] > 1:
extra_A = torch.nn.functional.pad(a3mA['msa'][1:], (0,sum(L_s[1:])), "constant", 20) # pad gaps
msa.append(extra_A)
if a3mB['msa'].shape[0] > 1:
extra_B = torch.nn.functional.pad(a3mB['msa'][1:], (L_s[0],0), "constant", 20)
msa.append(extra_B)
msa = torch.cat(msa, dim=0)
# merge ins
query = torch.cat([a3mA['ins'][0], a3mB['ins'][0]]).unsqueeze(0) # (1, L)
ins = [query]
if a3mA['ins'].shape[0] > 1:
extra_A = torch.nn.functional.pad(a3mA['ins'][1:], (0,sum(L_s[1:])), "constant", 0) # pad gaps
ins.append(extra_A)
if a3mB['ins'].shape[0] > 1:
extra_B = torch.nn.functional.pad(a3mB['ins'][1:], (L_s[0],0), "constant", 0)
ins.append(extra_B)
ins = torch.cat(ins, dim=0)
a3m = {'msa': msa, 'ins': ins}
# merge taxids
if 'taxid' in a3mA and 'taxid' in a3mB:
a3m['taxid'] = np.concatenate([np.array(a3mA['taxid']), np.array(a3mB['taxid'])[1:]])
return a3m
# merge msa & insertion statistics of units in homo-oligomers
def merge_a3m_homo(msa_orig, ins_orig, nmer, mode="default"):
N, L = msa_orig.shape[:2]
if mode == "repeat":
# AAAAAA
# AAAAAA
msa = torch.tile(msa_orig,(1,nmer))
ins = torch.tile(ins_orig,(1,nmer))
elif mode == "diag":
# AAAAAA
# A-----
# -A----
# --A---
# ---A--
# ----A-
# -----A
N = N - 1
new_N = 1 + N * nmer
new_L = L * nmer
msa = torch.full((new_N, new_L), 20, dtype=msa_orig.dtype, device=msa_orig.device)
ins = torch.full((new_N, new_L), 0, dtype=ins_orig.dtype, device=msa_orig.device)
start_L = 0
start_N = 1
for i_c in range(nmer):
msa[0, start_L:start_L+L] = msa_orig[0]
msa[start_N:start_N+N, start_L:start_L+L] = msa_orig[1:]
ins[0, start_L:start_L+L] = ins_orig[0]
ins[start_N:start_N+N, start_L:start_L+L] = ins_orig[1:]
start_L += L
start_N += N
else:
# AAAAAA
# A-----
# -AAAAA
msa = torch.full((2*N-1, L*nmer), 20, dtype=msa_orig.dtype, device=msa_orig.device)
ins = torch.full((2*N-1, L*nmer), 0, dtype=ins_orig.dtype, device=msa_orig.device)
msa[:N, :L] = msa_orig
ins[:N, :L] = ins_orig
start = L
for i_c in range(1,nmer):
msa[0, start:start+L] = msa_orig[0]
msa[N:, start:start+L] = msa_orig[1:]
ins[0, start:start+L] = ins_orig[0]
ins[N:, start:start+L] = ins_orig[1:]
start += L
return msa, ins
def merge_msas(a3m_list, L_s):
"""
takes a list of a3m dictionaries with keys msa, ins and a list of protein lengths and creates a
combined MSA
"""
seen = set()
taxIDs = []
a3mA = a3m_list[0]
taxIDs.extend(a3mA["taxID"])
seen.update(a3mA["hash"])
msaA, insA = a3mA["msa"], a3mA["ins"]
for i in range(1, len(a3m_list)):
a3mB = a3m_list[i]
pair_taxIDs = set(taxIDs).intersection(set(a3mB["taxID"]))
if a3mB["hash"] in seen or len(pair_taxIDs) < 5: #homomer/not enough pairs
a3mA = {"msa": msaA, "ins": insA}
L_s_to_merge = [sum(L_s[:i]), L_s[i]]
a3mA = merge_a3m_hetero(a3mA, a3mB, L_s_to_merge)
msaA, insA = a3mA["msa"], a3mA["ins"]
taxIDs.extend(a3mB["taxID"])
else:
final_pairsA = []
final_pairsB = []
msaB, insB = a3mB["msa"], a3mB["ins"]
for pair in pair_taxIDs:
pair_a3mA = np.where(np.array(taxIDs)==pair)[0]
pair_a3mB = np.where(a3mB["taxID"]==pair)[0]
msaApair = torch.argmin(torch.sum(msaA[pair_a3mA, :] == msaA[0, :],axis=-1))
msaBpair = torch.argmin(torch.sum(msaB[pair_a3mB, :] == msaB[0, :],axis=-1))
final_pairsA.append(pair_a3mA[msaApair])
final_pairsB.append(pair_a3mB[msaBpair])
paired_msaB = torch.full((msaA.shape[0], L_s[i]), 20).long() # (N_seq_A, L_B)
paired_msaB[final_pairsA] = msaB[final_pairsB]
msaA = torch.cat([msaA, paired_msaB], dim=1)
insA = torch.zeros_like(msaA) # paired MSAs in our dataset dont have insertions
seen.update(a3mB["hash"])
return msaA, insA
def remove_all_gap_seqs(a3m):
"""Removes sequences that are all gaps from an MSA represented as `a3m` dictionary"""
idx_seq_keep = ~(a3m['msa']==ChemData().UNKINDEX).all(dim=1)
a3m['msa'] = a3m['msa'][idx_seq_keep]
a3m['ins'] = a3m['ins'][idx_seq_keep]
return a3m
def join_msas_by_taxid(a3mA, a3mB, idx_overlap=None):
"""Joins (or "pairs") 2 MSAs by matching sequences with the same
taxonomic ID. If more than 1 sequence exists in both MSAs with the same tax
ID, only the sequence with the highest sequence identity to the query (1st
sequence in MSA) will be paired.
Sequences that aren't paired will be padded and added to the bottom of the
joined MSA. If a subregion of the input MSAs overlap (represent the same
chain), the subregion residue indices can be given as `idx_overlap`, and
the overlap region of the unpaired sequences will be included in the joined
MSA.
Parameters
----------
a3mA : dict
First MSA to be joined, with keys `msa` (N_seq, L_seq), `ins` (N_seq,
L_seq), `taxid` (N_seq,), and optionally `is_paired` (N_seq,), a
boolean tensor indicating whether each sequence is fully paired. Can be
a multi-MSA (contain >2 sub-MSAs).
a3mB : dict
2nd MSA to be joined, with keys `msa`, `ins`, `taxid`, and optionally
`is_paired`. Can be a multi-MSA ONLY if not overlapping with 1st MSA.
idx_overlap : tuple or list (optional)
Start and end indices of overlap region in 1st MSA, followed by the
same in 2nd MSA.
Returns
-------
a3m : dict
Paired MSA, with keys `msa`, `ins`, `taxid` and `is_paired`.
"""
# preprocess overlap region
L_A, L_B = a3mA['msa'].shape[1], a3mB['msa'].shape[1]
if idx_overlap is not None:
i1A, i2A, i1B, i2B = idx_overlap
i1B_new, i2B_new = (0, i1B) if i2B==L_B else (i2B, L_B) # MSA B residues that don't overlap MSA A
assert((i1B==0) or (i2B==a3mB['msa'].shape[1])), \
"When overlapping with 1st MSA, 2nd MSA must comprise at most 2 sub-MSAs "\
"(i.e. residue range should include 0 or a3mB['msa'].shape[1])"
else:
i1B_new, i2B_new = (0, L_B)
# pair sequences
taxids_shared = a3mA['taxid'][np.isin(a3mA['taxid'],a3mB['taxid'])]
i_pairedA, i_pairedB = [], []
for taxid in taxids_shared:
i_match = np.where(a3mA['taxid']==taxid)[0]
i_match_best = torch.argmin(torch.sum(a3mA['msa'][i_match]==a3mA['msa'][0], axis=1))
i_pairedA.append(i_match[i_match_best])
i_match = np.where(a3mB['taxid']==taxid)[0]
i_match_best = torch.argmin(torch.sum(a3mB['msa'][i_match]==a3mB['msa'][0], axis=1))
i_pairedB.append(i_match[i_match_best])
# unpaired sequences
i_unpairedA = np.setdiff1d(np.arange(a3mA['msa'].shape[0]), i_pairedA)
i_unpairedB = np.setdiff1d(np.arange(a3mB['msa'].shape[0]), i_pairedB)
N_paired, N_unpairedA, N_unpairedB = len(i_pairedA), len(i_unpairedA), len(i_unpairedB)
# handle overlap region
# if msa A consists of sub-MSAs 1,2,3 and msa B of 2,4 (i.e overlap region is 2),
# this diagram shows how the variables below make up the final multi-MSA
# (* denotes nongaps, - denotes gaps)
# 1 2 3 4
# |*|*|*|*| msa_paired
# |*|*|*|-| msaA_unpaired
# |-|*|-|*| msaB_unpaired
if idx_overlap is not None:
assert((a3mA['msa'][i_pairedA, i1A:i2A]==a3mB['msa'][i_pairedB, i1B:i2B]) |
(a3mA['msa'][i_pairedA, i1A:i2A]==ChemData().UNKINDEX)).all(),\
'Paired MSAs should be identical (or 1st MSA should be all gaps) in overlap region'
# overlap region gets sequences from 2nd MSA bc sometimes 1st MSA will be all gaps here
msa_paired = torch.cat([a3mA['msa'][i_pairedA, :i1A],
a3mB['msa'][i_pairedB, i1B:i2B],
a3mA['msa'][i_pairedA, i2A:],
a3mB['msa'][i_pairedB, i1B_new:i2B_new] ], dim=1)
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
torch.full((N_unpairedA, i2B_new-i1B_new), ChemData().UNKINDEX) ], dim=1)
msaB_unpaired = torch.cat([torch.full((N_unpairedB, i1A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB, i1B:i2B],
torch.full((N_unpairedB, L_A-i2A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB, i1B_new:i2B_new] ], dim=1)
else:
# no overlap region, simple offset pad & stack
# this code is actually a special case of "if" block above, but writing
# this out explicitly here to make the logic more clear
msa_paired = torch.cat([a3mA['msa'][i_pairedA], a3mB['msa'][i_pairedB, i1B_new:i2B_new]], dim=1)
msaA_unpaired = torch.cat([a3mA['msa'][i_unpairedA],
torch.full((N_unpairedA, L_B), ChemData().UNKINDEX)], dim=1) # pad with gaps
msaB_unpaired = torch.cat([torch.full((N_unpairedB, L_A), ChemData().UNKINDEX),
a3mB['msa'][i_unpairedB]], dim=1) # pad with gaps
# stack paired & unpaired
msa = torch.cat([msa_paired, msaA_unpaired, msaB_unpaired], dim=0)
taxids = np.concatenate([a3mA['taxid'][i_pairedA], a3mA['taxid'][i_unpairedA], a3mB['taxid'][i_unpairedB]])
# label "fully paired" sequences (a row of MSA that was never padded with gaps)
# output seq is fully paired if seqs A & B both started out as paired and were paired to
# each other on tax ID.
# NOTE: there is a rare edge case that is ignored here for simplicity: if
# pMSA 0+1 and 1+2 are joined and then joined to 2+3, a seq that exists in
# 0+1 and 2+3 but NOT 1+2 will become fully paired on the last join but
# will not be labeled as such here
is_pairedA = a3mA['is_paired'] if 'is_paired' in a3mA else torch.ones((a3mA['msa'].shape[0],)).bool()
is_pairedB = a3mB['is_paired'] if 'is_paired' in a3mB else torch.ones((a3mB['msa'].shape[0],)).bool()
is_paired = torch.cat([is_pairedA[i_pairedA] & is_pairedB[i_pairedB],
torch.zeros((N_unpairedA + N_unpairedB,)).bool()])
# insertion features in paired MSAs are assumed to be zero
a3m = dict(msa=msa, ins=torch.zeros_like(msa), taxid=taxids, is_paired=is_paired)
return a3m
def load_minimal_multi_msa(hash_list, taxid_list, Ls, params):
"""Load a multi-MSA, which is a MSA that is paired across more than 2
chains. This loads the MSA for unique chains. Use 'expand_multi_msa` to
duplicate portions of the MSA for homo-oligomer repeated chains.
Given a list of unique MSA hashes, loads all MSAs (using paired MSAs where
it can) and pairs sequences across as many sub-MSAs as possible by matching
taxonomic ID. For details on how pairing is done, see
`join_msas_by_taxid()`
Parameters
----------
hash_list : list of str
Hashes of MSAs to load and join. Must not contain duplicates.
taxid_list : list of str
Taxonomic IDs of query sequences of each input MSA.
Ls : list of int
Lengths of the chains corresponding to the hashes.
Returns
-------
a3m_out : dict
Multi-MSA with all input MSAs. Keys: `msa`,`ins` [torch.Tensor (N_seq, L)],
`taxid` [np.array (Nseq,)], `is_paired` [torch.Tensor (N_seq,)]
hashes_out : list of str
Hashes of MSAs in the order that they are joined in `a3m_out`.
Contains the same elements as the input `hash_list` but may be in a
different order.
Ls_out : list of int
Lengths of each chain in `a3m_out`
"""
assert(len(hash_list)==len(set(hash_list))), 'Input MSA hashes must be unique'
# the lists below are constructed such that `a3m_list[i_a3m]` is a multi-MSA
# comprising sub-MSAs whose indices in the input lists are
# `i_in = idx_list_groups[i_a3m][i_submsa]`, i.e. the sub-MSA hashes are
# `hash_list[i_in]` and lengths are `Ls[i_in]`.
# Each sub-MSA spans a region of its multi-MSA `a3m_list[i_a3m][:,i_start:i_end]`,
# where `(i_start,i_end) = res_range_groups[i_a3m][i_submsa]`
a3m_list = [] # list of multi-MSAs
idx_list_groups = [] # list of lists of indices of input chains making up each multi-MSA
res_range_groups = [] # list of lists of start and end residues of each sub-MSA in multi-MSA
# iterate through all pairs of hashes and look for paired MSAs (pMSAs)
# NOTE: in the below, if pMSAs are loaded for hashes 0+1 and then 2+3, and
# later a pMSA is found for 0+2, the last MSA will not be loaded. The 0+1
# and 2+3 pMSAs will still be joined on taxID at the end, but sequences
# only present in the 0+2 pMSA pMSAs will be missed. this is probably very
# rare and so is ignored here for simplicity.
N = len(hash_list)
for i1, i2 in itertools.permutations(range(N),2):
idx_list = [x for group in idx_list_groups for x in group] # flattened list of loaded hashes
if i1 in idx_list and i2 in idx_list: continue # already loaded
if i1 == '' or i2 == '': continue # no taxID means no pMSA
# a paired MSA exists
if taxid_list[i1]==taxid_list[i2]:
h1, h2 = hash_list[i1], hash_list[i2]
fn = params['COMPL_DIR']+'/pMSA/'+h1[:3]+'/'+h2[:3]+'/'+h1+'_'+h2+'.a3m.gz'
if os.path.exists(fn):
msa, ins, taxid = parse_a3m(fn, paired=True)
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins), taxid=taxid,
is_paired=torch.ones(msa.shape[0]).bool())
res_range1 = (0,Ls[i1])
res_range2 = (Ls[i1],msa.shape[1])
# both hashes are new, add paired MSA to list
if i1 not in idx_list and i2 not in idx_list:
a3m_list.append(a3m_new)
idx_list_groups.append([i1,i2])
res_range_groups.append([res_range1, res_range2])
# one of the hashes is already in a multi-MSA
# find that multi-MSA and join the new pMSA to it
elif i1 in idx_list:
# which multi-MSA & sub-MSA has the hash with index `i1`?
i_a3m = np.where([i1 in group for group in idx_list_groups])[0][0]
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i1)[0][0]
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range1
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
idx_list_groups[i_a3m].append(i2)
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
L_new = res_range2[1] - res_range2[0]
res_range_groups[i_a3m].append((L, L+L_new))
elif i2 in idx_list:
# which multi-MSA & sub-MSA has the hash with index `i2`?
i_a3m = np.where([i2 in group for group in idx_list_groups])[0][0]
i_submsa = np.where(np.array(idx_list_groups[i_a3m])==i2)[0][0]
idx_overlap = res_range_groups[i_a3m][i_submsa] + res_range2
a3m_list[i_a3m] = join_msas_by_taxid(a3m_list[i_a3m], a3m_new, idx_overlap)
idx_list_groups[i_a3m].append(i1)
L = res_range_groups[i_a3m][-1][1] # length of current multi-MSA
L_new = res_range1[1] - res_range1[0]
res_range_groups[i_a3m].append((L, L+L_new))
# add unpaired MSAs
# ungroup hash indices now, since we're done making multi-MSAs
idx_list = [x for group in idx_list_groups for x in group]
for i in range(N):
if i not in idx_list:
fn = params['PDB_DIR'] + '/a3m/' + hash_list[i][:3] + '/' + hash_list[i] + '.a3m.gz'
msa, ins, taxid = parse_a3m(fn)
a3m_new = dict(msa=torch.tensor(msa), ins=torch.tensor(ins),
taxid=taxid, is_paired=torch.ones(msa.shape[0]).bool())
a3m_list.append(a3m_new)
idx_list.append(i)
Ls_out = [Ls[i] for i in idx_list]
hashes_out = [hash_list[i] for i in idx_list]
# join multi-MSAs & unpaired MSAs
a3m_out = a3m_list[0]
for i in range(1, len(a3m_list)):
a3m_out = join_msas_by_taxid(a3m_out, a3m_list[i])
return a3m_out, hashes_out, Ls_out
def expand_multi_msa(a3m, hashes_in, hashes_out, Ls_in, Ls_out):
"""Expands a multi-MSA of unique chains into an MSA of a
hetero-homo-oligomer in which some chains appear more than once. The query
sequences (1st sequence of MSA) are concatenated directly along the
residue dimention. The remaining sequences are offset-tiled (i.e. "padded &
stacked") so that exact repeat sequences aren't paired.
For example, if the original multi-MSA contains unique chains 1,2,3 but
the final chain order is 1,2,1,3,3,1, this function will output an MSA like
(where - denotes a block of gap characters):
1 2 - 3 - -
- - 1 - 3 -
- - - - - 1
Parameters
----------
a3m : dict
Contains torch.Tensors `msa` and `ins` (N_seq, L) and np.array `taxid` (Nseq,),
representing the multi-MSA of unique chains.
hashes_in : list of str
Unique MSA hashes used in `a3m`.
hashes_out : list of str
Non-unique MSA hashes desired in expanded MSA.
Ls_in : list of int
Lengths of each chain in `a3m`
Ls_out : list of int
Lengths of each chain desired in expanded MSA.
params : dict
Data loading parameters
Returns
-------
a3m : dict
Contains torch.Tensors `msa` and `ins` of expanded MSA. No
taxids because no further joining needs to be done.
"""
assert(len(hashes_out)==len(Ls_out))
assert(set(hashes_in)==set(hashes_out))
assert(a3m['msa'].shape[1]==sum(Ls_in))
# figure out which oligomeric repeat is represented by each hash in `hashes_out`
# each new repeat will be offset in sequence dimension of final MSA
counts = dict()
n_copy = [] # n-th copy of this hash in `hashes`
for h in hashes_out:
if h in counts:
counts[h] += 1
else:
counts[h] = 1
n_copy.append(counts[h])
# num sequences in source & destination MSAs
N_in = a3m['msa'].shape[0]
N_out = (N_in-1)*max(n_copy)+1 # concatenate query seqs, pad&stack the rest
# source MSA
msa_in, ins_in = a3m['msa'], a3m['ins']
# initialize destination MSA to gap characters
msa_out = torch.full((N_out, sum(Ls_out)), ChemData().UNKINDEX)
ins_out = torch.full((N_out, sum(Ls_out)), 0)
# for each destination chain
for i_out, h_out in enumerate(hashes_out):
# identify index of source chain
i_in = np.where(np.array(hashes_in)==h_out)[0][0]
# residue indexes
i1_res_in = sum(Ls_in[:i_in])
i2_res_in = sum(Ls_in[:i_in+1])
i1_res_out = sum(Ls_out[:i_out])
i2_res_out = sum(Ls_out[:i_out+1])
# copy over query sequence
# NOTE: There is a bug in these next two lines!
# The second line should be ins_out[0, i1_res_out:i2_res_out] = ins_in[0, i1_res_in:i2_res_in]
msa_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
ins_out[0, i1_res_out:i2_res_out] = msa_in[0, i1_res_in:i2_res_in]
# offset non-query sequences along sequence dimension based on repeat number of a given hash
i1_seq_out = 1+(n_copy[i_out]-1)*(N_in-1)
i2_seq_out = 1+n_copy[i_out]*(N_in-1)
# copy over non-query sequences
msa_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = msa_in[1:, i1_res_in:i2_res_in]
ins_out[i1_seq_out:i2_seq_out, i1_res_out:i2_res_out] = ins_in[1:, i1_res_in:i2_res_in]
# only 1st oligomeric repeat can be fully paired
is_paired_out = torch.cat([a3m['is_paired'], torch.zeros((N_out-N_in,)).bool()])
a3m_out = dict(msa=msa_out, ins=ins_out, is_paired=is_paired_out)
a3m_out = remove_all_gap_seqs(a3m_out)
return a3m_out
def load_multi_msa(chain_ids, Ls, chid2hash, chid2taxid, params):
"""Loads multi-MSA for an arbitrary number of protein chains. Tries to
locate paired MSAs and pair sequences across all chains by taxonomic ID.
Unpaired sequences are padded and stacked on the bottom.
"""
# get MSA hashes (used to locate a3m files) and taxonomic IDs (used to determine pairing)
hashes = []
hashes_unique = []
taxids_unique = []
Ls_unique = []
for chid,L_ in zip(chain_ids, Ls):
hashes.append(chid2hash[chid])
if chid2hash[chid] not in hashes_unique:
hashes_unique.append(chid2hash[chid])
taxids_unique.append(chid2taxid.get(chid))
Ls_unique.append(L_)
# loads multi-MSA for unique chains
a3m_prot, hashes_unique, Ls_unique = \
load_minimal_multi_msa(hashes_unique, taxids_unique, Ls_unique, params)
# expands multi-MSA to repeat chains of homo-oligomers
a3m_prot = expand_multi_msa(a3m_prot, hashes_unique, hashes, Ls_unique, Ls, params)
return a3m_prot
def choose_multimsa_clusters(msa_seq_is_paired, params):
"""Returns indices of fully-paired sequences in a multi-MSA to use as seed
clusters during MSA featurization.
"""
frac_paired = msa_seq_is_paired.float().mean()
if frac_paired > 0.25: # enough fully paired sequences, just let MSAFeaturize choose randomly
return None
else:
# ensure that half of the clusters are fully-paired sequences,
# and let the rest be chosen randomly
N_seed = params['MAXLAT']//2
msa_seed_clus = []
for i_cycle in range(params['MAXCYCLE']):
idx_paired = torch.where(msa_seq_is_paired)[0]
msa_seed_clus.append(idx_paired[torch.randperm(len(idx_paired))][:N_seed])
return msa_seed_clus
#fd
def get_bond_distances(bond_feats):
atom_bonds = (bond_feats > 0)*(bond_feats<5)
dist_matrix = scipy.sparse.csgraph.shortest_path(atom_bonds.long().numpy(), directed=False)
# dist_matrix = torch.tensor(np.nan_to_num(dist_matrix, posinf=4.0)) # protein portion is inf and you don't want to mask it out
return torch.from_numpy(dist_matrix).float()
def get_pdb(pdbfilename, plddtfilename, item, lddtcut, sccut):
xyz, mask, res_idx = parse_pdb(pdbfilename)
plddt = np.load(plddtfilename)
# update mask info with plddt (ignore sidechains if plddt < 90.0)
mask_lddt = np.full_like(mask, False)
mask_lddt[plddt > sccut] = True
mask_lddt[:,:5] = True
mask = np.logical_and(mask, mask_lddt)
mask = np.logical_and(mask, (plddt > lddtcut)[:,None])
return {'xyz':torch.tensor(xyz), 'mask':torch.tensor(mask), 'idx': torch.tensor(res_idx), 'label':item}
def get_msa(a3mfilename, item, maxseq=5000):
msa,ins, taxIDs = parse_a3m(a3mfilename, maxseq=5000)
return {'msa':torch.tensor(msa), 'ins':torch.tensor(ins), 'taxIDs':taxIDs, 'label':item}