# RoseTTAFold-All-Atom/rf2aa/setup_model.py

import torch
import numpy as np
import hydra
import os
from rf2aa.training.EMA import EMA
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
from rf2aa.util_module import XYZConverter
from rf2aa.chemical import ChemicalData as ChemData
#TODO: control environment variables from config
# limit thread counts
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['OPENBLAS_NUM_THREADS'] = '4'
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
## To reproduce errors
import random

def seed_all(seed=0):
    """Seed every RNG source used here (python, torch CPU/CUDA, numpy)."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

torch.set_num_threads(4)
#torch.autograd.set_detect_anomaly(True)
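
# Hedged sketch, not part of the original pipeline: seed_all above seeds only
# the current CUDA device. For stricter reproducibility one would typically
# also seed all visible GPUs and pin cuDNN to deterministic kernels. The
# torch.cuda / torch.backends.cudnn calls below are standard PyTorch APIs;
# the helper name _seed_all_strict is hypothetical.
def _seed_all_strict(seed=0):
    seed_all(seed)
    torch.cuda.manual_seed_all(seed)           # seed every visible GPU
    torch.backends.cudnn.deterministic = True  # force deterministic kernels
    torch.backends.cudnn.benchmark = False     # disable autotuning nondeterminism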
class Trainer:
    def __init__(self, config) -> None:
        self.config = config
        assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
        # Resolve the output directory, falling back to models/ when unset.
        if self.config.experiment.output_dir is not None:
            self.output_dir = self.config.experiment.output_dir
        else:
            self.output_dir = "models/"
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def move_constants_to_device(self, gpu):
        # Cache the chemical constant tensors on the target device once,
        # rather than re-transferring them on every use.
        self.fi_dev = ChemData().frame_indices.to(gpu)
        self.xyz_converter = XYZConverter().to(gpu)

        self.l2a = ChemData().long2alt.to(gpu)
        self.aamask = ChemData().allatom_mask.to(gpu)
        self.num_bonds = ChemData().num_bonds.to(gpu)
        self.atom_type_index = ChemData().atom_type_index.to(gpu)
        self.ljlk_parameters = ChemData().ljlk_parameters.to(gpu)
        self.lj_correction_parameters = ChemData().lj_correction_parameters.to(gpu)
        self.hbtypes = ChemData().hbtypes.to(gpu)
        self.hbbaseatoms = ChemData().hbbaseatoms.to(gpu)
        self.hbpolys = ChemData().hbpolys.to(gpu)
        self.cb_len = ChemData().cb_length_t.to(gpu)
        self.cb_ang = ChemData().cb_angle_t.to(gpu)
        self.cb_tor = ChemData().cb_torsion_t.to(gpu)
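
# Hedged refactor sketch (hypothetical helper, not used by this repo): the
# repeated ChemData().<attr>.to(gpu) pattern above can be collapsed into a
# loop. Note that move_constants_to_device renames some attributes on self
# (l2a <- long2alt, aamask <- allatom_mask); the pairs below are a subset
# copied from it.
def _stage_chemdata_attrs(target, gpu):
    renames = {"l2a": "long2alt", "aamask": "allatom_mask", "num_bonds": "num_bonds"}
    for dst, src in renames.items():
        setattr(target, dst, getattr(ChemData(), src).to(gpu))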
class LegacyTrainer(Trainer):
    def __init__(self, config) -> None:
        super().__init__(config)

    def construct_model(self, device="cpu"):
        # The chemical constant tensors are handed to the module so that
        # structure-based terms can be computed on-device.
        self.model = RoseTTAFoldModule(
            **self.config.legacy_model_param,
            aamask=ChemData().allatom_mask.to(device),
            atom_type_index=ChemData().atom_type_index.to(device),
            ljlk_parameters=ChemData().ljlk_parameters.to(device),
            lj_correction_parameters=ChemData().lj_correction_parameters.to(device),
            num_bonds=ChemData().num_bonds.to(device),
            cb_len=ChemData().cb_length_t.to(device),
            cb_ang=ChemData().cb_angle_t.to(device),
            cb_tor=ChemData().cb_torsion_t.to(device),
        ).to(device)
        # Optionally track an exponential moving average of the weights.
        if self.config.training_params.EMA is not None:
            self.model = EMA(self.model, self.config.training_params.EMA)
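
# Hedged usage sketch (hypothetical helper, never called in this file):
# build the model and stage the chemical constants on one device. The
# "cuda:0" default and the config contents are assumptions, not repo defaults.
def _example_stage_on_gpu(config, device="cuda:0"):
    trainer = LegacyTrainer(config)
    trainer.construct_model(device=device)
    trainer.move_constants_to_device(device)
    return trainer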
# Registry mapping config.experiment.trainer to a Trainer implementation;
# defined before main so the lookup below reads top-to-bottom.
trainer_factory = {
    "legacy": LegacyTrainer,
}

@hydra.main(version_base=None, config_path='config/train')
def main(config):
    seed_all()
    trainer = trainer_factory[config.experiment.trainer](config=config)
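
# Hedged sketch of exercising the factory without hydra (hypothetical,
# never invoked here). OmegaConf ships with hydra; the config below contains
# only the fields this file actually reads, everything else is omitted.
def _example_build_trainer():
    from omegaconf import OmegaConf
    cfg = OmegaConf.create({
        "ddp_params": {"batch_size": 1},
        "experiment": {"output_dir": "models/", "trainer": "legacy"},
    })
    return trainer_factory[cfg.experiment.trainer](config=cfg)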
if __name__ == "__main__":
    main()