mirror of
https://github.com/baker-laboratory/RoseTTAFold-All-Atom.git
synced 2024-11-04 22:25:42 +00:00
181 lines
6.2 KiB
Python
181 lines
6.2 KiB
Python
import torch
|
|
import numpy as np
|
|
import json
|
|
from deepdiff import DeepDiff
|
|
import pprint
|
|
import assertpy
|
|
import dataclasses
|
|
from collections import OrderedDict
|
|
import contextlib
|
|
from deepdiff.operator import BaseOperator
|
|
|
|
def assert_shape(t, s):
|
|
assertpy.assert_that(tuple(t.shape)).is_equal_to(s)
|
|
|
|
def assert_same_shape(t, s):
|
|
assertpy.assert_that(tuple(t.shape)).is_equal_to(tuple(s.shape))
|
|
|
|
class ExceptionLogger(contextlib.AbstractContextManager):
|
|
|
|
def __exit__(self, exc_type, exc_value, traceback):
|
|
if exc_type:
|
|
print("***Logging exception {}***".format((exc_type, exc_value,
|
|
traceback)))
|
|
|
|
|
|
def assert_equal(got, want):
|
|
assertpy.assert_that(got.dtype).is_equal_to(want.dtype)
|
|
assertpy.assert_that(got.shape).is_equal_to(want.shape)
|
|
is_eq = got.nan_to_num()==want.nan_to_num()
|
|
unequal_idx = torch.nonzero(~is_eq)
|
|
unequal_got = got[~is_eq]
|
|
unequal_want = want[~is_eq]
|
|
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_want, unequal_got))[:3]
|
|
|
|
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
|
|
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
|
|
if torch.numel(got) < 10:
|
|
msg = f'got {got}, want: {want}'
|
|
assert len(unequal_idx) == 0, msg
|
|
|
|
|
|
def assert_close(got, want, atol=1e-4, rtol=1e-8):
|
|
got = torch.nan_to_num(got)
|
|
want = torch.nan_to_num(want)
|
|
if got.shape != want.shape:
|
|
raise ValueError(f'Wrong shapes: got shape {got.shape} want shape {want.shape}')
|
|
elif not torch.allclose(got, want, atol=atol, rtol=rtol):
|
|
maximum_difference = torch.abs(got - want).max().item()
|
|
indices_different = torch.nonzero(got != want)
|
|
raise ValueError(f'Maximum difference: {maximum_difference}, indices different: {indices_different}')
|
|
|
|
|
|
def cpu(e):
|
|
if isinstance(e, dict):
|
|
return {cpu(k): cpu(v) for k,v in e.items()}
|
|
if isinstance(e, list) or isinstance(e, tuple):
|
|
return tuple(cpu(i) for i in e)
|
|
if hasattr(e, 'cpu'):
|
|
return e.cpu().detach()
|
|
return e
|
|
|
|
# Dataclass functions
|
|
|
|
def to_ordered_dict(dc):
|
|
return OrderedDict((field.name, getattr(dc, field.name)) for field in dataclasses.fields(dc))
|
|
|
|
def to_device(dc, device):
|
|
d = to_ordered_dict(dc)
|
|
for k, v in d.items():
|
|
if v is not None:
|
|
setattr(dc, k, v.to(device))
|
|
|
|
def shapes(dc):
|
|
d = to_ordered_dict(dc)
|
|
return {k:v.shape if hasattr(v, 'shape') else None for k,v in d.items()}
|
|
|
|
|
|
def pprint_obj(obj):
|
|
pprint.pprint(obj.__dict__, indent=4)
|
|
|
|
def assert_squeeze(t):
|
|
assert t.shape[0] == 1, f'{t.shape}[0] != 1'
|
|
return t[0]
|
|
|
|
def apply_to_tensors(e, op):
|
|
# if isinstance(e, dataclasses.)
|
|
if dataclasses.is_dataclass(e):
|
|
# return to_ordered_dict
|
|
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
|
|
if isinstance(e, dict):
|
|
return {k: apply_to_tensors(v, op) for k,v in e.items()}
|
|
if isinstance(e, list) or isinstance(e, tuple):
|
|
return tuple(apply_to_tensors(i, op) for i in e)
|
|
if hasattr(e, 'cpu'):
|
|
return op(e)
|
|
return e
|
|
|
|
def apply_to_matching(e, op, filt):
|
|
# if isinstance(e, dataclasses.)
|
|
if filt(e):
|
|
return op(e)
|
|
if dataclasses.is_dataclass(e):
|
|
# return to_ordered_dict
|
|
return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
|
|
if isinstance(e, dict):
|
|
return {k: apply_to_tensors(v, op) for k,v in e.items()}
|
|
if isinstance(e, list) or isinstance(e, tuple):
|
|
return tuple(apply_to_tensors(i, op) for i in e)
|
|
return e
|
|
|
|
def set_grad(t):
|
|
t.requires_grad = True
|
|
|
|
def require_grad(e):
|
|
apply_to_tensors(e, set_grad)
|
|
|
|
def get_grad(e):
|
|
return apply_to_tensors(e, lambda x: x.grad)
|
|
|
|
def info(e):
|
|
shap = apply_to_tensors(e, lambda x: x.shape)
|
|
shap = apply_to_matching(shap, str, dataclasses.is_dataclass)
|
|
return json.dumps(shap, indent=4)
|
|
|
|
def minmax(e):
|
|
return apply_to_tensors(e, lambda x: (torch.log10(torch.min(x)), torch.log10(torch.max(x))))
|
|
|
|
class TensorMatchOperator(BaseOperator):
|
|
|
|
|
|
def __init__(self, atol=1e-3, rtol=0, **kwargs):
|
|
super(TensorMatchOperator, self).__init__(**kwargs)
|
|
self.atol = atol
|
|
self.rtol = rtol
|
|
|
|
def _equal_msg(self, got, want):
|
|
if got.shape != want.shape:
|
|
return f'got shape {got.shape} want shape {want.shape}'
|
|
if got.dtype != want.dtype:
|
|
return f'got dtype {got.dtype} want dtype {want.dtype}'
|
|
if torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol).all():
|
|
return ''
|
|
is_eq = torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol)
|
|
unequal_idx = torch.nonzero(~is_eq)
|
|
unequal_got = got[~is_eq]
|
|
unequal_want = want[~is_eq]
|
|
uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
|
|
uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
|
|
uneq_msg += f' fraction unequal:{unequal_got.numel()}/{got.numel()}'
|
|
msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
|
|
if torch.numel(got) < 10:
|
|
msg = f'got {got}, want: {want}'
|
|
return msg
|
|
|
|
|
|
def give_up_diffing(self, level, diff_instance):
|
|
msg = self._equal_msg(level.t1, level.t2)
|
|
if msg:
|
|
print(f'got:\n{level.t1}\n\nwant:\n{level.t2}')
|
|
msg = self._equal_msg(level.t1, level.t2)
|
|
if msg:
|
|
diff_instance.custom_report_result('tensors unequal', level, {
|
|
"msg": msg
|
|
})
|
|
return True
|
|
|
|
class NumpyMatchOperator(TensorMatchOperator):
|
|
def give_up_diffing(self, level, diff_instance):
|
|
level.t1 = torch.Tensor(level.t1)
|
|
level.t2 = torch.Tensor(level.t2)
|
|
return super(NumpyMatchOperator, self).give_up_diffing(level, diff_instance)
|
|
|
|
|
|
def cmp(got, want, **kwargs):
|
|
|
|
dd = DeepDiff(got, want, custom_operators=[
|
|
NumpyMatchOperator(types=[np.ndarray], **kwargs),
|
|
TensorMatchOperator(types=[torch.Tensor], **kwargs)])
|
|
if dd:
|
|
return dd
|
|
return ''
|