# RoseTTAFold-All-Atom/rf2aa/tensor_util.py

import contextlib
import dataclasses
import json
import pprint
from collections import OrderedDict

import assertpy
import numpy as np
import torch
from deepdiff import DeepDiff
from deepdiff.operator import BaseOperator

def assert_shape(t, s):
    assertpy.assert_that(tuple(t.shape)).is_equal_to(s)


def assert_same_shape(t, s):
    assertpy.assert_that(tuple(t.shape)).is_equal_to(tuple(s.shape))
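
# Illustrative checks (invented for this sketch, not from the RF2AA tests):
#   assert_shape(torch.zeros(2, 3), (2, 3))                  # passes
#   assert_same_shape(torch.zeros(2, 3), torch.ones(2, 3))   # passes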

class ExceptionLogger(contextlib.AbstractContextManager):
    """Print exceptions as they leave the block without suppressing them."""

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type:
            print("***Logging exception {}***".format((exc_type, exc_value, traceback)))
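
# Minimal usage sketch: the exception is printed and still propagates, because
# __exit__ returns None rather than True. risky_op is hypothetical.
#   with ExceptionLogger():
#       risky_op()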

def assert_equal(got, want):
    assertpy.assert_that(got.dtype).is_equal_to(want.dtype)
    assertpy.assert_that(got.shape).is_equal_to(want.shape)
    is_eq = got.nan_to_num() == want.nan_to_num()
    unequal_idx = torch.nonzero(~is_eq)
    unequal_got = got[~is_eq]
    unequal_want = want[~is_eq]
    # Report at most the first three mismatches; zip order matches the unpack below.
    uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
    uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
    msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
    if torch.numel(got) < 10:
        msg = f'got {got}, want: {want}'
    assert len(unequal_idx) == 0, msg
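
# Illustrative failure (values invented; small tensors print whole):
#   assert_equal(torch.tensor([1., 2.]), torch.tensor([1., 3.]))
#   -> AssertionError: got tensor([1., 2.]), want: tensor([1., 3.])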

def assert_close(got, want, atol=1e-4, rtol=1e-8):
    got = torch.nan_to_num(got)
    want = torch.nan_to_num(want)
    if got.shape != want.shape:
        raise ValueError(f'Wrong shapes: got shape {got.shape} want shape {want.shape}')
    elif not torch.allclose(got, want, atol=atol, rtol=rtol):
        maximum_difference = torch.abs(got - want).max().item()
        # Locate offenders with the same tolerances used by the allclose check.
        indices_different = torch.nonzero(~torch.isclose(got, want, atol=atol, rtol=rtol))
        raise ValueError(f'Maximum difference: {maximum_difference}, indices different: {indices_different}')
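
# Illustrative tolerance behavior (values invented):
#   assert_close(torch.tensor([1.0]), torch.tensor([1.00001]))  # within atol=1e-4, passes
#   assert_close(torch.tensor([1.0]), torch.tensor([1.1]))      # raises ValueError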

def cpu(e):
    """Recursively move tensors in nested dicts/lists/tuples to CPU and detach them."""
    if isinstance(e, dict):
        return {cpu(k): cpu(v) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(cpu(i) for i in e)
    if hasattr(e, 'cpu'):
        return e.cpu().detach()
    return e
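
# Sketch: containers are walked, lists become tuples, non-tensor leaves pass through.
#   cpu({'a': [torch.ones(1), 'label']})  -> {'a': (tensor([1.]), 'label')}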

# Dataclass functions

def to_ordered_dict(dc):
    return OrderedDict((field.name, getattr(dc, field.name)) for field in dataclasses.fields(dc))


def to_device(dc, device):
    # Mutates dc in place; assumes every non-None field is a tensor.
    d = to_ordered_dict(dc)
    for k, v in d.items():
        if v is not None:
            setattr(dc, k, v.to(device))


def shapes(dc):
    d = to_ordered_dict(dc)
    return {k: v.shape if hasattr(v, 'shape') else None for k, v in d.items()}
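
# Hedged example with a hypothetical dataclass (Batch is not defined in rf2aa):
#   @dataclasses.dataclass
#   class Batch:
#       xyz: torch.Tensor
#       mask: torch.Tensor
#   b = Batch(torch.zeros(1, 3), torch.ones(1))
#   shapes(b)            -> {'xyz': torch.Size([1, 3]), 'mask': torch.Size([1])}
#   to_device(b, 'cpu')  # moves both fields in place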

def pprint_obj(obj):
    pprint.pprint(obj.__dict__, indent=4)


def assert_squeeze(t):
    assert t.shape[0] == 1, f'{t.shape}[0] != 1'
    return t[0]
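
# Sketch: drops a singleton leading (batch) dimension, failing loudly otherwise.
#   assert_squeeze(torch.zeros(1, 5)).shape  -> torch.Size([5])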

def apply_to_tensors(e, op):
    """Recursively apply op to every tensor-like leaf (anything with .cpu) in e."""
    if dataclasses.is_dataclass(e):
        return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
    if isinstance(e, dict):
        return {k: apply_to_tensors(v, op) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(apply_to_tensors(i, op) for i in e)
    if hasattr(e, 'cpu'):
        return op(e)
    return e
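
# Sketch: structure is preserved while tensor leaves are transformed.
#   apply_to_tensors({'x': torch.zeros(2, 3)}, lambda t: t.shape)
#   -> {'x': torch.Size([2, 3])}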

def apply_to_matching(e, op, filt):
    """Recursively apply op to every element of e that satisfies filt."""
    if filt(e):
        return op(e)
    # Recurse with the same op and filter so nested matches are found too.
    if dataclasses.is_dataclass(e):
        return type(e)(*(apply_to_matching(getattr(e, field.name), op, filt) for field in dataclasses.fields(e)))
    if isinstance(e, dict):
        return {k: apply_to_matching(v, op, filt) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(apply_to_matching(i, op, filt) for i in e)
    return e
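
# Sketch: op fires only on elements matching filt; everything else is traversed.
#   apply_to_matching({'n': 3, 's': 'x'}, op=str, filt=lambda v: isinstance(v, int))
#   -> {'n': '3', 's': 'x'}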

def set_grad(t):
    t.requires_grad = True


def require_grad(e):
    apply_to_tensors(e, set_grad)


def get_grad(e):
    return apply_to_tensors(e, lambda x: x.grad)
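
# Hedged round-trip sketch: enable grads on every leaf, backprop, then collect .grad.
#   inputs = {'x': torch.zeros(2)}
#   require_grad(inputs)
#   inputs['x'].sum().backward()
#   get_grad(inputs)  -> {'x': tensor([1., 1.])}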

def info(e):
    # Summarize nested tensor shapes as JSON; dataclasses are stringified first
    # so that json.dumps can serialize the result.
    shap = apply_to_tensors(e, lambda x: x.shape)
    shap = apply_to_matching(shap, str, dataclasses.is_dataclass)
    return json.dumps(shap, indent=4)


def minmax(e):
    # Order of magnitude (log10) of each tensor's min and max; note that
    # non-positive values yield nan or -inf.
    return apply_to_tensors(e, lambda x: (torch.log10(torch.min(x)), torch.log10(torch.max(x))))
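
# Sketch: tensor shapes serialize as JSON lists (indent=4 whitespace elided):
#   info({'x': torch.zeros(2, 3)})  -> '{"x": [2, 3]}'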

class TensorMatchOperator(BaseOperator):
    """DeepDiff operator that compares torch tensors with tolerances."""

    def __init__(self, atol=1e-3, rtol=0, **kwargs):
        super(TensorMatchOperator, self).__init__(**kwargs)
        self.atol = atol
        self.rtol = rtol

    def _equal_msg(self, got, want):
        """Return '' if the tensors match, otherwise a description of the mismatch."""
        if got.shape != want.shape:
            return f'got shape {got.shape} want shape {want.shape}'
        if got.dtype != want.dtype:
            return f'got dtype {got.dtype} want dtype {want.dtype}'
        is_eq = torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol)
        if is_eq.all():
            return ''
        unequal_idx = torch.nonzero(~is_eq)
        unequal_got = got[~is_eq]
        unequal_want = want[~is_eq]
        uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
        uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
        uneq_msg += f' fraction unequal:{unequal_got.numel()}/{got.numel()}'
        msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
        if torch.numel(got) < 10:
            msg = f'got {got}, want: {want}'
        return msg

    def give_up_diffing(self, level, diff_instance):
        msg = self._equal_msg(level.t1, level.t2)
        if msg:
            print(f'got:\n{level.t1}\n\nwant:\n{level.t2}')
            diff_instance.custom_report_result('tensors unequal', level, {
                "msg": msg
            })
        return True

class NumpyMatchOperator(TensorMatchOperator):
    def give_up_diffing(self, level, diff_instance):
        # Note: torch.Tensor(ndarray) casts to float32, so differences between
        # numpy dtypes are not detected here.
        level.t1 = torch.Tensor(level.t1)
        level.t2 = torch.Tensor(level.t2)
        return super(NumpyMatchOperator, self).give_up_diffing(level, diff_instance)

def cmp(got, want, **kwargs):
    dd = DeepDiff(got, want, custom_operators=[
        NumpyMatchOperator(types=[np.ndarray], **kwargs),
        TensorMatchOperator(types=[torch.Tensor], **kwargs)])
    if dd:
        return dd
    return ''
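
# Hedged usage sketch: an empty string means the nested structures match within
# the operators' tolerances; otherwise a DeepDiff report is returned.
#   cmp({'x': torch.ones(2)}, {'x': torch.ones(2)})   -> ''
#   cmp({'x': torch.ones(2)}, {'x': torch.zeros(2)})  -> DeepDiff report ('tensors unequal')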