RoseTTAFold-All-Atom/rf2aa/data/nucleic_acid.py
2024-03-04 22:38:17 -08:00

46 lines
No EOL
1.5 KiB
Python

import numpy as np
import torch
from rf2aa.data.parsers import parse_mixed_fasta, parse_multichain_fasta
from rf2aa.data.data_loader_utils import merge_a3m_hetero, merge_a3m_homo, blank_template
from rf2aa.data.data_loader import RawInputData
from rf2aa.util import get_protein_bond_feats
def load_nucleic_acid(fasta_fn, input_type, model_runner):
if input_type not in ["dna", "rna"]:
raise ValueError("Only DNA and RNA inputs allowed for nucleic acids")
if input_type == "dna":
dna_alphabet = True
rna_alphabet = False
elif input_type == "rna":
dna_alphabet = False
rna_alphabet = True
loader_params = model_runner.config.loader_params
msa, ins, L = parse_multichain_fasta(fasta_fn, rna_alphabet=rna_alphabet, dna_alphabet=dna_alphabet)
if (msa.shape[0] > loader_params["MAXSEQ"]):
idxs_tokeep = np.random.permutation(msa.shape[0])[:loader_params["MAXSEQ"]]
idxs_tokeep[0] = 0
msa = msa[idxs_tokeep]
ins = ins[idxs_tokeep]
if len(L) > 1:
raise ValueError("Please provide separate fasta files for each nucleic acid chain")
L = L[0]
xyz_t, t1d, mask_t, _ = blank_template(loader_params["n_templ"], L)
bond_feats = get_protein_bond_feats(L)
chirals = torch.zeros(0, 5)
atom_frames = torch.zeros(0, 3, 2)
return RawInputData(
torch.from_numpy(msa),
torch.from_numpy(ins),
bond_feats,
xyz_t,
mask_t,
t1d,
chirals,
atom_frames,
taxids=None,
)