mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-24 22:37:20 +00:00
46 lines
No EOL
1.5 KiB
Python
46 lines
No EOL
1.5 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
|
|
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
|
|
from rf2aa.data.data_loader import RawInputData
|
|
from rf2aa.util import get_protein_bond_feats
|
|
|
|
def _subsample_msa(msa, ins, max_seq):
    """Randomly keep at most ``max_seq`` MSA rows, always keeping row 0 first.

    Row 0 is presumably the query sequence (the original code pinned it); the
    remaining slots are filled with a random selection of the other rows,
    drawn without replacement so no row (in particular row 0) appears twice.

    Args:
        msa: numpy array whose first axis indexes sequences.
        ins: numpy array aligned with ``msa`` along the first axis.
        max_seq: maximum number of rows to keep; row 0 is kept even if this
            is < 1.

    Returns:
        ``(msa, ins)`` unchanged when there are already <= ``max_seq`` rows,
        otherwise the row-subsampled pair (row 0 first).
    """
    n_seq = msa.shape[0]
    if n_seq <= max_seq:
        return msa, ins
    # Draw extras only from rows 1..n_seq-1 so row 0 cannot be selected twice
    # (pinning it after a full permutation could duplicate it).
    n_extra = max(max_seq - 1, 0)
    extra_idxs = 1 + np.random.permutation(n_seq - 1)[:n_extra]
    idxs_tokeep = np.concatenate(([0], extra_idxs))
    return msa[idxs_tokeep], ins[idxs_tokeep]


def load_nucleic_acid(fasta_fn, input_type, model_runner):
    """Build a ``RawInputData`` for a single DNA or RNA chain from a FASTA file.

    Args:
        fasta_fn: FASTA file passed to ``parse_multichain_fasta``.
        input_type: either ``"dna"`` or ``"rna"``; selects the alphabet used
            when parsing.
        model_runner: object exposing ``config.loader_params``, a mapping that
            must contain ``"MAXSEQ"`` (max MSA rows kept) and ``"n_templ"``
            (number of blank templates).

    Returns:
        ``RawInputData`` holding the (possibly subsampled) MSA and insertion
        tensors, protein-style bond features, blank templates, empty chirals
        (shape ``(0, 5)``), empty atom frames (shape ``(0, 3, 2)``) and
        ``taxids=None``.

    Raises:
        ValueError: if ``input_type`` is not ``"dna"``/``"rna"``, or if the
            FASTA describes more than one chain.
    """
    if input_type not in ["dna", "rna"]:
        raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
    dna_alphabet = input_type == "dna"
    rna_alphabet = input_type == "rna"

    loader_params = model_runner.config.loader_params
    msa, ins, L = parse_multichain_fasta(
        fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet
    )

    # L is treated as a per-chain length list; only single-chain files are
    # supported here. Checked before any subsampling work is done.
    if len(L) > 1:
        raise ValueError("Please provide separate fasta files for each nucleic acid chain")
    L = L[0]

    msa, ins = _subsample_msa(msa, ins, loader_params["MAXSEQ"])

    xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)

    bond_feats = get_protein_bond_feats(L)
    # No ligand atoms: empty chirality and atom-frame tensors.
    chirals = torch.zeros(0, 5)
    atom_frames = torch.zeros(0, 3, 2)

    return RawInputData(
        torch.from_numpy(msa),
        torch.from_numpy(ins),
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        chirals,
        atom_frames,
        taxids=None,
    )