"""Utilities for merging per-chain raw inputs (protein, nucleic acid, small
molecule) into a single ``RawInputData`` object."""

from hashlib import md5

import torch

from rf2aa.data.data_loader_utils import (
    merge_a3m_hetero,
    merge_a3m_homo,
    merge_hetero_templates,
    get_term_feats,
    join_msas_by_taxid,
    expand_multi_msa,
)
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import (
    center_and_realign_missing,
    same_chain_from_bond_feats,
    random_rot_trans,
    idx_from_Ls,
)


def _block_diagonal_bond_feats(bond_feats_list, L_total):
    """Copy each square bond-feature matrix onto the diagonal of a zeroed
    ``(L_total, L_total)`` long tensor, in the order given."""
    merged = torch.zeros((L_total, L_total)).long()
    offset = 0
    for bf in bond_feats_list:
        L = bf.shape[0]
        merged[offset:offset + L, offset:offset + L] = bf
        offset += L
    return merged


def _merge_sequentially(inputs):
    """Fold a ``{chain_id: input}`` mapping into one input via ``merge_two_inputs``.

    Returns ``(merged_input_or_None, [(chain_id, length), ...])``; an empty
    (or ``None``) mapping yields ``(None, [])``.
    """
    running_inputs = None
    chain_lengths = []
    for chid, chain_input in (inputs or {}).items():
        running_inputs = merge_two_inputs(running_inputs, chain_input)
        chain_lengths.append((chid, chain_input.length()))
    return running_inputs, chain_lengths


def merge_protein_inputs(protein_inputs, deterministic: bool = False):
    """Merge the raw inputs of all protein chains into a single input.

    Args:
        protein_inputs: mapping of chain id -> per-chain input object
            (iterated in mapping order).
        deterministic: forwarded to ``random_rot_trans`` /
            ``merge_hetero_templates``.

    Returns:
        ``(merged_input, chain_lengths)`` where ``chain_lengths`` is a list of
        ``(chain_id, length)`` tuples. With no chains the result is
        ``(None, [])``. With exactly one chain the *same* input object is
        returned, after the first entry of its ``xyz_t`` has been replaced
        in place by the output of ``random_rot_trans``.
    """
    if not protein_inputs:
        return None, []

    chain_inputs = list(protein_inputs.values())

    if len(chain_inputs) == 1:
        chain = next(iter(protein_inputs))
        single = chain_inputs[0]
        xyz_t = single.xyz_t
        # NOTE: slice assignment writes into the caller's tensor.
        xyz_t[0:1] = random_rot_trans(xyz_t[0:1], deterministic=deterministic)
        single.xyz_t = xyz_t
        return single, [(chain, single.length())]

    # handle merging MSAs and such
    # first determine which sequences are identical, then which ones have
    # mergeable MSAs, then cat the templates and other feats
    a3m_list = [
        {"msa": chain_input.msa, "ins": chain_input.ins, "taxid": chain_input.taxids}
        for chain_input in chain_inputs
    ]
    hash_list = [
        md5(chain_input.sequence_string().encode()).hexdigest()
        for chain_input in chain_inputs
    ]
    lengths_list = [chain_input.length() for chain_input in chain_inputs]

    # Index of the first occurrence of every distinct sequence, in chain order.
    seen = set()
    unique_indices = []
    for idx, seq_hash in enumerate(hash_list):
        if seq_hash not in seen:
            seen.add(seq_hash)
            unique_indices.append(idx)

    unique_a3m = [a3m_list[i] for i in unique_indices]
    unique_hashes = [hash_list[i] for i in unique_indices]
    unique_lengths_list = [lengths_list[i] for i in unique_indices]

    if len(unique_a3m) > 1:
        # Join the MSAs of the *unique* chains. Indexing the full a3m_list
        # with a position in unique_a3m would pick the wrong chain whenever a
        # duplicate sequence precedes a distinct one (e.g. chains A, A, B).
        a3m_out = unique_a3m[0]
        for a3m in unique_a3m[1:]:
            a3m_out = join_msas_by_taxid(a3m_out, a3m)
        a3m_out = expand_multi_msa(
            a3m_out, unique_hashes, hash_list, unique_lengths_list, lengths_list
        )
    else:
        # All chains share one sequence: homo-oligomer merge.
        a3m = unique_a3m[0]
        a3m_out = merge_a3m_homo(a3m["msa"], a3m["ins"], len(hash_list))

    # merge templates
    max_template_dim = max(chain_input.xyz_t.shape[0] for chain_input in chain_inputs)
    xyz_t_list = [chain_input.xyz_t for chain_input in chain_inputs]
    mask_t_list = [chain_input.mask_t for chain_input in chain_inputs]
    t1d_list = [chain_input.t1d for chain_input in chain_inputs]
    ids = ["inference"] * len(t1d_list)
    xyz_t, t1d, mask_t, _ = merge_hetero_templates(
        xyz_t_list, t1d_list, mask_t_list, ids, lengths_list, deterministic=deterministic
    )

    # The merged protein input carries empty atom-frame and chiral tensors.
    atom_frames = torch.zeros(0, 3, 2)
    chirals = torch.zeros(0, 5)

    bond_feats = _block_diagonal_bond_feats(
        [chain_input.bond_feats for chain_input in chain_inputs],
        sum(lengths_list),
    )

    chain_lengths = list(zip(protein_inputs.keys(), lengths_list))

    # Merged templates are truncated to the largest per-chain template count.
    merged_input = RawInputData(
        a3m_out["msa"],
        a3m_out["ins"],
        bond_feats,
        xyz_t[:max_template_dim],
        mask_t[:max_template_dim],
        t1d[:max_template_dim],
        chirals,
        atom_frames,
        taxids=None,
    )
    return merged_input, chain_lengths


def merge_na_inputs(na_inputs):
    """Merge nucleic-acid chain inputs by concatenating their features.

    Returns ``(merged_input_or_None, [(chain_id, length), ...])``.
    """
    # should just be trivially catting features
    return _merge_sequentially(na_inputs)


def merge_sm_inputs(sm_inputs):
    """Merge small-molecule inputs by concatenating their features.

    Returns ``(merged_input_or_None, [(chain_id, length), ...])``.
    """
    # should be trivially catting features
    return _merge_sequentially(sm_inputs)


def merge_two_inputs(first_input, second_input):
    """Merge two arbitrary inputs into a new ``RawInputData``.

    If either argument is ``None`` the other is returned unchanged (``None``
    when both are). Otherwise MSAs are merged with ``merge_a3m_hetero``,
    bond features are laid out block-diagonally, templates are concatenated
    along dim 1, and chirals / atom frames are concatenated along dim 0.
    Neither argument is modified.
    """
    if first_input is None and second_input is None:
        return None
    if first_input is None:
        return second_input
    if second_input is None:
        return first_input

    Ls = [first_input.length(), second_input.length()]
    L_total = sum(Ls)

    # merge msas
    a3m_first = {"msa": first_input.msa, "ins": first_input.ins}
    a3m_second = {"msa": second_input.msa, "ins": second_input.ins}
    a3m = merge_a3m_hetero(a3m_first, a3m_second, Ls)

    # merge bond_feats
    bond_feats = _block_diagonal_bond_feats(
        [first_input.bond_feats, second_input.bond_feats], L_total
    )

    # merge templates
    xyz_t = torch.cat([first_input.xyz_t, second_input.xyz_t], dim=1)
    t1d = torch.cat([first_input.t1d, second_input.t1d], dim=1)
    mask_t = torch.cat([first_input.mask_t, second_input.mask_t], dim=1)

    # handle chirals: every column except the last is shifted by the length of
    # the first input (presumably those columns are indices into the merged
    # sequence -- confirm against RawInputData). Work on a clone so the
    # caller's tensor is not shifted in place (and not shifted twice if the
    # same input is merged again).
    second_chirals = second_input.chirals
    if second_chirals.shape[0] > 0:
        second_chirals = second_chirals.clone()
        second_chirals[:, :-1] = second_chirals[:, :-1] + first_input.length()
    chirals = torch.cat([first_input.chirals, second_chirals])

    # cat atom frames
    atom_frames = torch.cat([first_input.atom_frames, second_input.atom_frames])

    # return new object
    return RawInputData(
        a3m["msa"],
        a3m["ins"],
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None,
    )


def merge_all(
    protein_inputs,
    na_inputs,
    sm_inputs,
    residues_to_atomize,
    deterministic: bool = False,
):
    """Merge protein, nucleic-acid and small-molecule inputs, in that order.

    Args:
        protein_inputs, na_inputs, sm_inputs: mappings of chain id -> input.
        residues_to_atomize: if truthy, passed to
            ``update_protein_features_after_atomize`` on the merged input.
        deterministic: forwarded to ``merge_protein_inputs``.

    Returns:
        The merged input with ``chain_lengths``, ``term_info``, ``xyz_t`` and
        ``idx`` set on it.

    Raises:
        ValueError: if all three groups merge to ``None`` (nothing provided).
    """
    protein_inputs, protein_chain_lengths = merge_protein_inputs(
        protein_inputs, deterministic=deterministic
    )
    na_inputs, na_chain_lengths = merge_na_inputs(na_inputs)
    sm_inputs, sm_chain_lengths = merge_sm_inputs(sm_inputs)
    if protein_inputs is None and na_inputs is None and sm_inputs is None:
        raise ValueError("No valid inputs were provided")

    # could handle pairing protein/NA MSAs here
    running_inputs = merge_two_inputs(protein_inputs, na_inputs)
    running_inputs = merge_two_inputs(running_inputs, sm_inputs)

    all_chain_lengths = protein_chain_lengths + na_chain_lengths + sm_chain_lengths
    running_inputs.chain_lengths = all_chain_lengths

    all_lengths = get_Ls_from_chain_lengths(running_inputs.chain_lengths)
    protein_lengths = get_Ls_from_chain_lengths(protein_chain_lengths)

    # Terminus features are kept only for the protein rows (which come
    # first); every row after them is zeroed.
    term_info = get_term_feats(all_lengths)
    term_info[sum(protein_lengths):, :] = 0
    running_inputs.term_info = term_info

    xyz_t = running_inputs.xyz_t
    mask_t = running_inputs.mask_t

    same_chain = same_chain_from_bond_feats(running_inputs.bond_feats)
    ntempl = xyz_t.shape[0]
    xyz_t = torch.stack(
        [
            center_and_realign_missing(xyz_t[i], mask_t[i], same_chain=same_chain)
            for i in range(ntempl)
        ]
    )
    # Replace NaNs left in the template coordinates.
    xyz_t = torch.nan_to_num(xyz_t)
    running_inputs.xyz_t = xyz_t
    running_inputs.idx = idx_from_Ls(all_lengths)

    # after everything is merged need to add bond feats for covales
    # reindex protein feats function
    if residues_to_atomize:
        running_inputs.update_protein_features_after_atomize(residues_to_atomize)

    return running_inputs


def get_Ls_from_chain_lengths(chain_lengths):
    """Return the length (second element) of every ``(chain_id, length)`` entry."""
    return [val[1] for val in chain_lengths]