Mirror of https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
Synced 2024-11-04 22:25:42 +00:00
remove pickle files
This commit is contained in:
parent 097ad85d4e
commit bd290cca68
6 changed files with 0 additions and 152 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
rf2aa/tests/test_conditions.py
@@ -1,79 +0,0 @@
import torch
import pandas as pd
import numpy as np
import itertools
from collections import OrderedDict
from hydra import initialize, compose

from rf2aa.setup_model import trainer_factory, seed_all
from rf2aa.chemical import ChemicalData as ChemData

# configurations to test
configs = ["legacy_train"]
datasets = ["compl", "na_compl", "rna", "sm_compl", "sm_compl_covale", "sm_compl_asmb"]

cfg_overrides = [
    "loader_params.p_msa_mask=0.0",
    "loader_params.crop=100000",
    "loader_params.mintplt=0",
    "loader_params.maxtplt=2"
]


def make_deterministic(seed=0):
    seed_all(seed)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def setup_dataset_names():
    data = {}
    for name in datasets:
        data[name] = [name]
    return data


# set up models for regression tests
def setup_models(device="cpu"):
    models, chem_cfgs = [], []
    for config in configs:
        with initialize(version_base=None, config_path="../config/train"):
            cfg = compose(config_name=config, overrides=cfg_overrides)

        # initializing the model needs the chemical DB initialized; force a reload
        ChemData.reset()
        ChemData(cfg.chem_params)

        trainer = trainer_factory[cfg.experiment.trainer](cfg)
        seed_all()
        trainer.construct_model(device=device)
        models.append(trainer.model)
        chem_cfgs.append(cfg.chem_params)
        trainer = None

    return dict(zip(configs, zip(configs, models, chem_cfgs)))


# set up job array for regression
def setup_array(datasets, models, device="cpu"):
    test_data = setup_dataset_names()
    test_models = setup_models(device=device)
    test_data = [test_data[dataset] for dataset in datasets]
    test_models = [test_models[model] for model in models]
    return list(itertools.product(test_data, test_models))


def random_param_init(model):
    seed_all()
    with torch.no_grad():
        fake_state_dict = OrderedDict()
        for name, param in model.model.named_parameters():
            fake_state_dict[name] = torch.randn_like(param)
        model.model.load_state_dict(fake_state_dict)
        model.shadow.load_state_dict(fake_state_dict)
    return model


def dataset_pickle_path(dataset_name):
    return f"test_pickles/data/{dataset_name}_regression.pt"


def model_pickle_path(dataset_name, model_name):
    return f"test_pickles/model/{model_name}_{dataset_name}_regression.pt"


def loss_pickle_path(dataset_name, model_name, loss_name):
    return f"test_pickles/loss/{loss_name}_{model_name}_{dataset_name}_regression.pt"
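The three *_pickle_path helpers above implement a golden-file scheme: a test writes its outputs to the pickle path on the first run and compares against that saved reference on every later run. Below is a minimal sketch of the same save-or-compare pattern, using a hypothetical golden_compare helper that is not part of rf2aa:

import os
import torch

def golden_compare(got, pickle_path, atol=1e-4):
    # Hypothetical helper illustrating the golden-file pattern used here.
    if not os.path.exists(pickle_path):
        # First run: record the current output as the reference.
        torch.save(got, pickle_path)
        return
    # Later runs: fail if the output drifted beyond tolerance.
    want = torch.load(pickle_path, map_location="cpu")
    assert torch.allclose(got.cpu(), want, atol=atol), f"regression vs {pickle_path}"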
@@ -1,73 +0,0 @@
import os
import torch
import pytest
import warnings
warnings.filterwarnings("ignore")

from rf2aa.data.dataloader_adaptor import prepare_input
from rf2aa.training.recycling import run_model_forward_legacy
from rf2aa.tensor_util import assert_equal
from rf2aa.tests.test_conditions import setup_array, \
    make_deterministic, dataset_pickle_path, model_pickle_path
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData


# goal is to test all the configs on a broad set of datasets

gpu = "cuda:0" if torch.cuda.is_available() else "cpu"

legacy_test_conditions = setup_array(["na_compl", "rna", "sm_compl", "sm_compl_covale"], ["legacy_train"], device=gpu)

@pytest.mark.parametrize("example,model", legacy_test_conditions)
def test_regression_legacy(example, model):
    dataset_name, dataset_inputs, model_name, model = setup_test(example, model)
    make_deterministic()
    output_i = run_model_forward_legacy(model, dataset_inputs, gpu)
    model_pickle = model_pickle_path(dataset_name, model_name)
    output_names = ("logits_c6d", "logits_aa", "logits_pae",
                    "logits_pde", "p_bind", "xyz", "alpha", "xyz_allatom",
                    "lddt", "seq", "pair", "state")

    if not os.path.exists(model_pickle):
        torch.save(output_i, model_pickle)
    else:
        output_regression = torch.load(model_pickle, map_location=gpu)
        for idx, output in enumerate(output_i):
            got = output
            want = output_regression[idx]
            if output_names[idx] == "logits_c6d":
                for i in range(len(want)):
                    got_i = got[i]
                    want_i = want[i]
                    try:
                        assert_equal(got_i, want_i)
                    except Exception as e:
                        raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
            elif output_names[idx] in ["alpha", "xyz_allatom", "seq", "pair", "state"]:
                try:
                    assert torch.allclose(got, want, atol=1e-4)
                except Exception as e:
                    raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e
            else:
                try:
                    assert_equal(got, want)
                except Exception as e:
                    raise ValueError(f"{output_names[idx]} not same for model: {model_name} on dataset: {dataset_name}") from e


def setup_test(example, model):
    model_name, model, config = model

    # initialize chemical database
    ChemData.reset()  # force reload chemical data
    ChemData(config)

    model = model.to(gpu)
    dataset_name = example[0]
    dataloader_inputs = torch.load(dataset_pickle_path(dataset_name), map_location=gpu)
    xyz_converter = XYZConverter().to(gpu)
    task, item, network_input, true_crds, mask_crds, msa, mask_msa, unclamp, \
        negative, symmRs, Lasu, ch_label = prepare_input(dataloader_inputs, xyz_converter, gpu)
    return dataset_name, network_input, model_name, model
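Each parametrized case unpacks as follows: example is the one-element list [dataset_name] built by setup_dataset_names, and model is the (model_name, model, chem_cfg) triple built by setup_models. A sketch of iterating the same pairs outside pytest, assuming the deleted test_conditions module above were still importable:

from rf2aa.tests.test_conditions import setup_array

for example, packed in setup_array(["rna"], ["legacy_train"], device="cpu"):
    (dataset_name,) = example             # one-element dataset list
    model_name, model, chem_cfg = packed  # triple built by setup_models
    print(f"would run {model_name} on {dataset_name}")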