RoseTTAFold-All-Atom/rf2aa/data/protein.py
2024-03-05 16:45:27 -08:00

94 lines
2.5 KiB
Python

import torch
from rf2aa.data.data_loader import RawInputData
from rf2aa.data.data_loader_utils import blank_template, TemplFeaturize
from rf2aa.data.parsers import parse_a3m, parse_templates_raw
from rf2aa.data.preprocessing import make_msa
from rf2aa.util import get_protein_bond_feats
def get_templates(
    qlen,
    ffdb,
    hhr_fn,
    atab_fn,
    seqID_cut,
    n_templ,
    pick_top: bool = True,
    offset: int = 0,
    random_noise: float = 5.0,
    deterministic: bool = False,
):
    """Parse raw template hits and featurize them for a query of length ``qlen``.

    ``parse_templates_raw`` is called with ``ffdb`` and the ``hhr_fn`` /
    ``atab_fn`` paths. It yields seven values: xyz, mask, qmap, f0d, f1d,
    seq, and ids. The first six get a new leading axis of size 1 via
    ``unsqueeze(0)``; presumably this is a batch axis, so confirm against
    ``TemplFeaturize``. ``ids`` is stored as-is. The resulting dict is
    handed to ``TemplFeaturize`` together with ``{"SEQID": seqID_cut}``.

    Args:
        qlen: query length forwarded to ``TemplFeaturize``.
        ffdb: template database handle forwarded to ``parse_templates_raw``.
        hhr_fn: path to the hhr hit file.
        atab_fn: path to the atab alignment file.
        seqID_cut: value stored under the ``"SEQID"`` params key.
        n_templ: number of templates to pick (``npick``).
        pick_top, offset, random_noise, deterministic: forwarded unchanged
            to ``TemplFeaturize``.

    Returns:
        Whatever ``TemplFeaturize`` returns. ``load_protein`` unpacks it as
        four values.
    """
    xyz, mask, qmap, f0d, f1d, seq, ids = parse_templates_raw(
        ffdb, hhr_fn=hhr_fn, atab_fn=atab_fn
    )

    # Every per-template tensor gets a leading size-1 axis; ids stays untouched.
    per_template = dict(xyz=xyz, mask=mask, qmap=qmap, f0d=f0d, f1d=f1d, seq=seq)
    tplt = {name: tensor.unsqueeze(0) for name, tensor in per_template.items()}
    tplt["ids"] = ids

    return TemplFeaturize(
        tplt,
        qlen,
        {"SEQID": seqID_cut},
        offset=offset,
        npick=n_templ,
        pick_top=pick_top,
        random_noise=random_noise,
        deterministic=deterministic,
    )
def load_protein(msa_file, hhr_fn, atab_fn, model_runner):
    """Assemble a ``RawInputData`` for a single protein chain.

    Args:
        msa_file: a3m alignment path, parsed with ``parse_a3m``.
        hhr_fn: template hit file. If it or ``atab_fn`` is ``None``, a
            blank template from ``blank_template`` is used instead.
        atab_fn: template alignment file, used together with ``hhr_fn``.
        model_runner: object providing ``ffdb``,
            ``config.loader_params.seqid``, ``config.loader_params.n_templ``
            and ``deterministic``.

    Returns:
        ``RawInputData`` holding the MSA and insertion arrays (converted with
        ``torch.from_numpy``), protein bond features, template coordinates,
        mask and 1D features, empty chirality / atom-frame tensors, and the
        taxonomy IDs.
    """
    msa, ins, taxIDs = parse_a3m(msa_file)
    # NOTE: this next line is a bug, but it is how the original
    # implementation is written. It is kept on purpose so inputs match the
    # reference code. It overwrites insertion row 0 with MSA row 0.
    ins[0] = msa[0]
    seq_len = msa.shape[1]

    if hhr_fn is not None and atab_fn is not None:
        ffdb = model_runner.ffdb
        loader_params = model_runner.config.loader_params
        xyz_t, t1d, mask_t, _ = get_templates(
            seq_len,
            ffdb,
            hhr_fn,
            atab_fn,
            seqID_cut=loader_params.seqid,
            n_templ=loader_params.n_templ,
            deterministic=model_runner.deterministic,
        )
    else:
        print("No templates provided")
        xyz_t, t1d, mask_t, _ = blank_template(1, seq_len)

    bond_feats = get_protein_bond_feats(seq_len)
    # A protein-only input carries no chiral centres or small-molecule frames.
    no_chirals = torch.zeros(0, 5)
    no_atom_frames = torch.zeros(0, 3, 2)

    msa_tensor = torch.from_numpy(msa)
    ins_tensor = torch.from_numpy(ins)
    return RawInputData(
        msa_tensor,
        ins_tensor,
        bond_feats,
        xyz_t,
        mask_t,
        t1d,
        no_chirals,
        no_atom_frames,
        taxids=taxIDs,
    )
def generate_msa_and_load_protein(fasta_file, chain, model_runner):
    """Build the MSA/template files for ``chain`` of ``fasta_file``, then load them.

    ``make_msa`` returns three paths: the MSA plus the hhr and atab template
    files. These are forwarded to ``load_protein`` together with
    ``model_runner``.

    The two template paths are only stringified when they are not ``None``.
    ``str(None)`` evaluates to ``"None"``, which would bypass
    ``load_protein``'s ``hhr_fn is None or atab_fn is None`` check. It would
    then try to read a file literally named "None" instead of falling back
    to a blank template. Whether ``make_msa`` can actually return ``None``
    here is not visible from this file, so this is a defensive guard.

    Returns:
        The value returned by ``load_protein``.
    """
    msa_file, hhr_file, atab_file = make_msa(fasta_file, chain, model_runner)

    def _path_or_none(path):
        # Convert real (possibly Path-like) paths to str, but keep None as None.
        return None if path is None else str(path)

    return load_protein(
        str(msa_file),
        _path_or_none(hhr_file),
        _path_or_none(atab_file),
        model_runner,
    )