# RoseTTAFold-All-Atom/rf2aa/tensor_util.py

import torch
import numpy as np
import json
from deepdiff import DeepDiff
import pprint
import assertpy
import dataclasses
from collections import OrderedDict
import contextlib
from deepdiff.operator import BaseOperator
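
# Utilities for working with (possibly nested) tensors: shape assertions,
# exact and approximate equality checks, moving nested structures between
# devices, applying an op to every tensor leaf of a container or dataclass,
# and DeepDiff custom operators for comparing tensors and numpy arrays
# embedded in larger objects.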


def assert_shape(t, s):
    assertpy.assert_that(tuple(t.shape)).is_equal_to(s)


def assert_same_shape(t, s):
    assertpy.assert_that(tuple(t.shape)).is_equal_to(tuple(s.shape))


class ExceptionLogger(contextlib.AbstractContextManager):
    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type:
            print("***Logging exception {}***".format((exc_type, exc_value,
                                                       traceback)))


def assert_equal(got, want):
    assertpy.assert_that(got.dtype).is_equal_to(want.dtype)
    assertpy.assert_that(got.shape).is_equal_to(want.shape)
    # Zero out NaNs so positions that are NaN in both tensors compare as equal.
    is_eq = got.nan_to_num() == want.nan_to_num()
    unequal_idx = torch.nonzero(~is_eq)
    unequal_got = got[~is_eq]
    unequal_want = want[~is_eq]
    uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
    uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
    msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
    if torch.numel(got) < 10:
        msg = f'got {got}, want: {want}'
    assert len(unequal_idx) == 0, msg


def assert_close(got, want, atol=1e-4, rtol=1e-8):
    got = torch.nan_to_num(got)
    want = torch.nan_to_num(want)
    if got.shape != want.shape:
        raise ValueError(f'Wrong shapes: got shape {got.shape} want shape {want.shape}')
    elif not torch.allclose(got, want, atol=atol, rtol=rtol):
        maximum_difference = torch.abs(got - want).max().item()
        indices_different = torch.nonzero(got != want)
        raise ValueError(f'Maximum difference: {maximum_difference}, indices different: {indices_different}')
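
# Example (illustrative): assert_close(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 2.001]), atol=1e-2)
# passes, while the default atol=1e-4 would raise a ValueError reporting the
# maximum difference and the indices that differ.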


def cpu(e):
    if isinstance(e, dict):
        return {cpu(k): cpu(v) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(cpu(i) for i in e)
    if hasattr(e, 'cpu'):
        return e.cpu().detach()
    return e
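
# Example (illustrative): cpu() recurses through dicts/lists/tuples, detaching and
# moving every tensor it finds to the CPU; note that lists come back as tuples, e.g.
#   cpu({"a": [torch.ones(2)]}) -> {"a": (tensor([1., 1.]),)}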


# Dataclass functions
def to_ordered_dict(dc):
    return OrderedDict((field.name, getattr(dc, field.name)) for field in dataclasses.fields(dc))


def to_device(dc, device):
    # Moves each non-None field of the dataclass to `device`, mutating dc in place.
    d = to_ordered_dict(dc)
    for k, v in d.items():
        if v is not None:
            setattr(dc, k, v.to(device))


def shapes(dc):
    d = to_ordered_dict(dc)
    return {k: v.shape if hasattr(v, 'shape') else None for k, v in d.items()}


def pprint_obj(obj):
    pprint.pprint(obj.__dict__, indent=4)


def assert_squeeze(t):
    assert t.shape[0] == 1, f'{t.shape}[0] != 1'
    return t[0]


def apply_to_tensors(e, op):
    if dataclasses.is_dataclass(e):
        return type(e)(*(apply_to_tensors(getattr(e, field.name), op) for field in dataclasses.fields(e)))
    if isinstance(e, dict):
        return {k: apply_to_tensors(v, op) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(apply_to_tensors(i, op) for i in e)
    if hasattr(e, 'cpu'):
        return op(e)
    return e
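
# Example (illustrative): apply_to_tensors walks nested dicts/lists/tuples/dataclasses
# and applies `op` to every tensor-like leaf (anything with a .cpu method), e.g.
#   apply_to_tensors(batch, lambda t: t.float())
# Dataclasses are rebuilt positionally, so their fields must match the constructor order.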


def apply_to_matching(e, op, filt):
    # Applies `op` to every element matching `filt`, recursing with the filter intact.
    if filt(e):
        return op(e)
    if dataclasses.is_dataclass(e):
        return type(e)(*(apply_to_matching(getattr(e, field.name), op, filt) for field in dataclasses.fields(e)))
    if isinstance(e, dict):
        return {k: apply_to_matching(v, op, filt) for k, v in e.items()}
    if isinstance(e, (list, tuple)):
        return tuple(apply_to_matching(i, op, filt) for i in e)
    return e


def set_grad(t):
    t.requires_grad = True


def require_grad(e):
    apply_to_tensors(e, set_grad)


def get_grad(e):
    return apply_to_tensors(e, lambda x: x.grad)
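
# Example (illustrative): require_grad(inputs) flags every tensor leaf in place;
# after calling .backward() on a loss computed from them, get_grad(inputs) returns
# the matching nested structure of .grad tensors.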


def info(e):
    shap = apply_to_tensors(e, lambda x: x.shape)
    shap = apply_to_matching(shap, str, dataclasses.is_dataclass)
    return json.dumps(shap, indent=4)


def minmax(e):
    return apply_to_tensors(e, lambda x: (torch.log10(torch.min(x)), torch.log10(torch.max(x))))


class TensorMatchOperator(BaseOperator):
    def __init__(self, atol=1e-3, rtol=0, **kwargs):
        super(TensorMatchOperator, self).__init__(**kwargs)
        self.atol = atol
        self.rtol = rtol

    def _equal_msg(self, got, want):
        """Returns '' if the tensors match, otherwise a description of the mismatch."""
        if got.shape != want.shape:
            return f'got shape {got.shape} want shape {want.shape}'
        if got.dtype != want.dtype:
            return f'got dtype {got.dtype} want dtype {want.dtype}'
        is_eq = torch.isclose(got, want, equal_nan=True, atol=self.atol, rtol=self.rtol)
        if is_eq.all():
            return ''
        unequal_idx = torch.nonzero(~is_eq)
        unequal_got = got[~is_eq]
        unequal_want = want[~is_eq]
        uneq_idx_got_want = list(zip(unequal_idx.tolist(), unequal_got, unequal_want))[:3]
        uneq_msg = ' '.join(f'idx:{idx}, got:{got}, want:{want}' for idx, got, want in uneq_idx_got_want)
        uneq_msg += f' fraction unequal:{unequal_got.numel()}/{got.numel()}'
        msg = f'tensors with shape {got.shape}: first unequal indices: {uneq_msg}'
        if torch.numel(got) < 10:
            msg = f'got {got}, want: {want}'
        return msg

    def give_up_diffing(self, level, diff_instance):
        msg = self._equal_msg(level.t1, level.t2)
        if msg:
            print(f'got:\n{level.t1}\n\nwant:\n{level.t2}')
            diff_instance.custom_report_result('tensors unequal', level, {
                "msg": msg
            })
        return True


class NumpyMatchOperator(TensorMatchOperator):
    def give_up_diffing(self, level, diff_instance):
        # Convert numpy arrays to tensors and reuse the tensor comparison logic.
        level.t1 = torch.Tensor(level.t1)
        level.t2 = torch.Tensor(level.t2)
        return super(NumpyMatchOperator, self).give_up_diffing(level, diff_instance)


def cmp(got, want, **kwargs):
    dd = DeepDiff(got, want, custom_operators=[
        NumpyMatchOperator(types=[np.ndarray], **kwargs),
        TensorMatchOperator(types=[torch.Tensor], **kwargs)])
    if dd:
        return dd
    return ''
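

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the original module):
    # cmp() compares nested structures and returns a DeepDiff report, or '' if equal.
    a = {"xyz": torch.zeros(2, 3), "mask": np.ones(4)}
    b = {"xyz": torch.zeros(2, 3), "mask": np.ones(4)}
    b["xyz"][0, 0] = 1.0
    print(cmp(a, a) or "a matches a")
    print(cmp(a, b, atol=1e-6))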