mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-04 22:25:42 +00:00
91 lines
3.1 KiB
Python
91 lines
3.1 KiB
Python
|
import torch
|
||
|
import numpy as np
|
||
|
|
||
|
|
||
|
import hydra
|
||
|
import os
|
||
|
|
||
|
from rf2aa.training.EMA import EMA
|
||
|
from rf2aa.model.RoseTTAFoldModel import RoseTTAFoldModule
|
||
|
from rf2aa.util_module import XYZConverter
|
||
|
from rf2aa.chemical import ChemicalData as ChemData
|
||
|
|
||
|
#TODO: control environment variables from config
|
||
|
# limit thread counts
|
||
|
os.environ['OMP_NUM_THREADS'] = '4'
|
||
|
os.environ['OPENBLAS_NUM_THREADS'] = '4'
|
||
|
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = "max_split_size_mb:512"
|
||
|
|
||
|
## To reproduce errors
|
||
|
import random
|
||
|
|
||
|
def seed_all(seed=0):
|
||
|
random.seed(seed)
|
||
|
torch.manual_seed(seed)
|
||
|
torch.cuda.manual_seed(seed)
|
||
|
np.random.seed(seed)
|
||
|
|
||
|
torch.set_num_threads(4)
|
||
|
#torch.autograd.set_detect_anomaly(True)
|
||
|
|
||
|
class Trainer:
|
||
|
def __init__(self, config) -> None:
|
||
|
self.config = config
|
||
|
|
||
|
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
|
||
|
if self.config.experiment.output_dir is not None:
|
||
|
self.output_dir = self.config.experiment.output_dir
|
||
|
else:
|
||
|
self.output_dir = "models/"
|
||
|
if not os.path.exists(self.output_dir):
|
||
|
os.makedirs(self.output_dir)
|
||
|
|
||
|
def move_constants_to_device(self, gpu):
|
||
|
self.fi_dev = ChemData().frame_indices.to(gpu)
|
||
|
self.xyz_converter = XYZConverter().to(gpu)
|
||
|
|
||
|
self.l2a = ChemData().long2alt.to(gpu)
|
||
|
self.aamask = ChemData().allatom_mask.to(gpu)
|
||
|
self.num_bonds = ChemData().num_bonds.to(gpu)
|
||
|
self.atom_type_index = ChemData().atom_type_index.to(gpu)
|
||
|
self.ljlk_parameters = ChemData().ljlk_parameters.to(gpu)
|
||
|
self.lj_correction_parameters = ChemData().lj_correction_parameters.to(gpu)
|
||
|
self.hbtypes = ChemData().hbtypes.to(gpu)
|
||
|
self.hbbaseatoms = ChemData().hbbaseatoms.to(gpu)
|
||
|
self.hbpolys = ChemData().hbpolys.to(gpu)
|
||
|
self.cb_len = ChemData().cb_length_t.to(gpu)
|
||
|
self.cb_ang = ChemData().cb_angle_t.to(gpu)
|
||
|
self.cb_tor = ChemData().cb_torsion_t.to(gpu)
|
||
|
|
||
|
class LegacyTrainer(Trainer):
|
||
|
def __init__(self, config) -> None:
|
||
|
super().__init__(config)
|
||
|
|
||
|
def construct_model(self, device="cpu"):
|
||
|
self.model = RoseTTAFoldModule(
|
||
|
**self.config.legacy_model_param,
|
||
|
aamask = ChemData().allatom_mask.to(device),
|
||
|
atom_type_index = ChemData().atom_type_index.to(device),
|
||
|
ljlk_parameters = ChemData().ljlk_parameters.to(device),
|
||
|
lj_correction_parameters = ChemData().lj_correction_parameters.to(device),
|
||
|
num_bonds = ChemData().num_bonds.to(device),
|
||
|
cb_len = ChemData().cb_length_t.to(device),
|
||
|
cb_ang = ChemData().cb_angle_t.to(device),
|
||
|
cb_tor = ChemData().cb_torsion_t.to(device),
|
||
|
|
||
|
).to(device)
|
||
|
if self.config.training_params.EMA is not None:
|
||
|
self.model = EMA(self.model, self.config.training_params.EMA)
|
||
|
|
||
|
@hydra.main(version_base=None, config_path='config/train')
|
||
|
def main(config):
|
||
|
seed_all()
|
||
|
trainer = trainer_factory[config.experiment.trainer](config=config)
|
||
|
|
||
|
trainer_factory = {
|
||
|
"legacy": LegacyTrainer,
|
||
|
}
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
main()
|