mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-24 22:37:20 +00:00
131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
|
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
#
|
||
|
# Permission is hereby granted, free of charge, to any person obtaining a
|
||
|
# copy of this software and associated documentation files (the "Software"),
|
||
|
# to deal in the Software without restriction, including without limitation
|
||
|
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||
|
# and/or sell copies of the Software, and to permit persons to whom the
|
||
|
# Software is furnished to do so, subject to the following conditions:
|
||
|
#
|
||
|
# The above copyright notice and this permission notice shall be included in
|
||
|
# all copies or substantial portions of the Software.
|
||
|
#
|
||
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
|
||
|
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||
|
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||
|
# DEALINGS IN THE SOFTWARE.
|
||
|
#
|
||
|
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
|
||
|
# SPDX-License-Identifier: MIT
|
||
|
|
||
|
import argparse
|
||
|
import ctypes
|
||
|
import logging
|
||
|
import os
|
||
|
import random
|
||
|
from functools import wraps
|
||
|
from typing import Union, List, Dict
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.distributed as dist
|
||
|
from torch import Tensor
|
||
|
|
||
|
|
||
|
def aggregate_residual(feats1, feats2, method: str):
    """Merge two per-degree feature dicts, keyed like ``feats2``.

    For every key of ``feats2`` the result holds either the combination of the
    ``feats2`` entry with the matching ``feats1`` entry, or the ``feats2`` entry
    unchanged when ``feats1`` has no such key. Keys that only exist in
    ``feats1`` are dropped.

    Args:
        feats1: mapping of degree key -> tensor (the residual being merged in).
        feats2: mapping of degree key -> tensor; its keys define the output keys.
        method: ``'add'``/``'sum'`` for element-wise addition
            (``feats2[k] + feats1[k]``), or ``'cat'``/``'concat'`` for
            ``torch.cat([feats2[k], feats1[k]], dim=1)``.

    Returns:
        A new dict with the same keys as ``feats2``.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    if method in ('add', 'sum'):
        def combine(new, old):
            return new + old
    elif method in ('cat', 'concat'):
        def combine(new, old):
            # feats2's entry comes first along dim 1, then feats1's.
            return torch.cat([new, old], dim=1)
    else:
        raise ValueError('Method must be add/sum or cat/concat')

    merged = {}
    for degree, new_feat in feats2.items():
        merged[degree] = combine(new_feat, feats1[degree]) if degree in feats1 else new_feat
    return merged
|
||
|
|
||
|
|
||
|
def degree_to_dim(degree: int) -> int:
    """Return the feature dimension ``2 * degree + 1`` for a given degree.

    NOTE(review): presumably this is the size of the degree-``degree`` SO(3)
    representation used by the fiber features; confirm against callers.
    """
    doubled = 2 * degree
    return doubled + 1
|
||
|
|
||
|
|
||
|
def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
    """Split a fused feature tensor along its last axis into per-degree chunks.

    The last axis of ``features`` is cut into consecutive pieces of width
    ``2 * deg + 1``, one per entry of ``degrees`` and in that order.

    Args:
        features: tensor whose last dimension equals the sum of
            ``2 * deg + 1`` over ``degrees``.
        degrees: degrees in the order their chunks appear in ``features``.

    Returns:
        Dict mapping ``str(deg)`` to the corresponding chunk (a view produced
        by ``Tensor.split``).
    """
    chunk_widths = [2 * deg + 1 for deg in degrees]
    chunks = features.split(chunk_widths, dim=-1)
    return {str(deg): chunk for deg, chunk in zip(degrees, chunks)}
|
||
|
|
||
|
|
||
|
def str2bool(v: Union[bool, str]) -> bool:
    """Parse a command-line style boolean (usable as an argparse ``type=``).

    Booleans pass through unchanged. Strings are matched case-insensitively:
    ``yes/true/t/y/1`` -> True, ``no/false/f/n/0`` -> False.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(v, bool):
        return v
    normalized = v.lower()
    if normalized in {'yes', 'true', 't', 'y', '1'}:
        return True
    if normalized in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
|
||
|
|
||
|
|
||
|
def to_cuda(x):
    """Try to move a Tensor, a collection of Tensors or a DGLGraph to CUDA.

    Tensors are copied with ``non_blocking=True``. Tuples, lists and dicts are
    rebuilt as the same built-in container type with every element converted
    recursively. Any other object (e.g. a DGLGraph) is moved with
    ``x.to(device=torch.cuda.current_device())``, so it must provide a
    compatible ``.to`` method.

    Args:
        x: object to move.

    Returns:
        The converted object; containers keep their kind (tuple/list/dict).
    """
    if isinstance(x, Tensor):
        return x.cuda(non_blocking=True)
    elif isinstance(x, tuple):
        # Materialize the tuple: a bare generator expression here would hand the
        # caller a one-shot generator that cannot be indexed, len()'d or reused.
        return tuple(to_cuda(v) for v in x)
    elif isinstance(x, list):
        return [to_cuda(v) for v in x]
    elif isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    else:
        # DGLGraph or other objects
        return x.to(device=torch.cuda.current_device())
|
||
|
|
||
|
|
||
|
def get_local_rank() -> int:
    """Return the ``LOCAL_RANK`` environment variable as an int, or 0 when unset."""
    if 'LOCAL_RANK' not in os.environ:
        return 0
    return int(os.environ['LOCAL_RANK'])
|
||
|
|
||
|
|
||
|
def init_distributed() -> bool:
    """Initialize ``torch.distributed`` when ``WORLD_SIZE`` > 1.

    Uses the NCCL backend (and pins the current CUDA device to
    ``get_local_rank()``) when CUDA is available. Otherwise it uses Gloo and
    logs a CPU-only warning. The process group is created with
    ``init_method='env://'``, so rendezvous settings come from the environment.

    Returns:
        True if a process group was initialized, False for single-process runs.
    """
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if world_size <= 1:
        return False

    use_nccl = torch.cuda.is_available()
    backend = 'nccl' if use_nccl else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    if use_nccl:
        torch.cuda.set_device(get_local_rank())
    else:
        logging.warning('Running on CPU only!')
    assert torch.distributed.is_initialized()
    return True
|
||
|
|
||
|
|
||
|
def increase_l2_fetch_granularity():
    """Set the max L2 fetch granularity limit of the current CUDA device to 128 bytes.

    Loads ``libcudart.so`` via ctypes, calls ``cudaDeviceSetLimit`` with limit
    id 0x05, reads the value back with ``cudaDeviceGetLimit`` and asserts it
    is 128.

    Raises:
        OSError: if ``libcudart.so`` cannot be loaded.
        AssertionError: if the limit read back is not 128 (the message includes
            the CUDA status codes returned by both calls).
    """
    # maximum fetch granularity of L2: 128 bytes
    _libcudart = ctypes.CDLL('libcudart.so')
    # set device limit on the current device
    # cudaLimitMaxL2FetchGranularity = 0x05
    limit_id = ctypes.c_int(0x05)
    # The CUDA runtime declares the limit value as size_t:
    #   cudaDeviceSetLimit(cudaLimit, size_t) / cudaDeviceGetLimit(size_t*, cudaLimit).
    # Passing a C int (or a pointer to a 4-byte int buffer) leaves the upper bits
    # undefined / lets the runtime write 8 bytes into 4 on 64-bit platforms.
    set_status = _libcudart.cudaDeviceSetLimit(limit_id, ctypes.c_size_t(128))
    value = ctypes.c_size_t(0)
    get_status = _libcudart.cudaDeviceGetLimit(ctypes.byref(value), limit_id)
    assert value.value == 128, (
        f'L2 fetch granularity is {value.value}, expected 128 '
        f'(cudaDeviceSetLimit status={set_status}, cudaDeviceGetLimit status={get_status})'
    )
|
||
|
|
||
|
|
||
|
def seed_everything(seed):
    """Seed Python's ``random``, NumPy, torch (CPU) and all CUDA devices with ``int(seed)``."""
    seed_value = int(seed)
    # Same order as before: random -> numpy -> torch CPU -> torch CUDA (all devices).
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed_value)
|
||
|
|
||
|
|
||
|
def rank_zero_only(fn):
    """Decorate ``fn`` so it only runs on distributed rank 0.

    When ``torch.distributed`` is not initialized the call always goes through.
    On non-zero ranks the wrapper skips ``fn`` and returns None.
    """
    @wraps(fn)
    def _guarded(*args, **kwargs):
        on_rank_zero = (not dist.is_initialized()) or dist.get_rank() == 0
        if not on_rank_zero:
            return None
        return fn(*args, **kwargs)

    return _guarded
|
||
|
|
||
|
|
||
|
def using_tensor_cores(amp: bool) -> bool:
    """Heuristic: will matmuls on the current CUDA device use tensor cores?

    True when ``amp`` is set and the compute capability major version is >= 7,
    or when the major version is >= 8 regardless of ``amp``.
    NOTE(review): the >= 8 case presumably relies on TF32 being enabled; confirm.
    """
    major_cc, _ = torch.cuda.get_device_capability()
    amp_capable = major_cc >= 7
    ampere_or_newer = major_cc >= 8
    return (amp and amp_capable) or ampere_or_newer
|