from copy import deepcopy
from dataclasses import dataclass, fields
from typing import Any, TypeVar
import numpy as np
import torch
from jaxtyping import Int
from ..utils import pretty_dict
TensorClass = TypeVar("TensorClass", bound="TensorWrapper")
@dataclass
class TensorWrapper:
"""Wrapper for tensors and lists of tensors to allow for easy access to their attributes."""
    @staticmethod
    def _getitem(attr, subscript):
        # Slice along the sequence dimension: dim 0 for 1-D tensors, dim 1 for
        # batched tensors. Nested wrappers, per-batch lists and dict values
        # recurse; anything else passes through unchanged.
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 1:
                return attr[subscript]
            if attr.ndim >= 2:
                return attr[:, subscript, ...]
            return attr  # 0-dim tensors have no sequence dimension
        elif isinstance(attr, TensorWrapper):
            return attr[subscript]
        elif isinstance(attr, list) and attr and isinstance(attr[0], list):
            return [seq[subscript] for seq in attr]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._getitem(val, subscript) for key, val in attr.items()}
        else:
            return attr
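    # A sketch of the dispatch above, with purely illustrative shapes:
    #     t = torch.arange(12).reshape(3, 4)          # (batch=3, seq=4)
    #     TensorWrapper._getitem(t, slice(0, 2))      # t[:, 0:2] -> shape (3, 2)
    #     TensorWrapper._getitem(torch.arange(4), 0)  # 1-D: plain indexing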
    @staticmethod
    def _slice_batch(attr, subscript):
        # Slice along the batch dimension (dim 0) of every tensor leaf; lists
        # are sliced directly, dict values recurse, the rest pass through.
        if isinstance(attr, torch.Tensor):
            if attr.ndim >= 1:
                return attr[subscript, ...]
            return attr  # 0-dim tensors have no batch dimension
        elif isinstance(attr, (TensorWrapper, list)):
            return attr[subscript]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._slice_batch(val, subscript) for key, val in attr.items()}
        else:
            return attr
    @staticmethod
    def _select_active(attr, mask):
        # Keep only the batch rows where ``mask`` is truthy. The mask is
        # broadcast to the tensor's rank so whole rows are selected.
        if isinstance(attr, torch.Tensor):
            if attr.ndim <= 1:
                return attr
            curr_mask = mask.clone()
            if curr_mask.dtype != torch.bool:
                curr_mask = curr_mask.bool()
            while curr_mask.ndim < attr.ndim:
                curr_mask = curr_mask.unsqueeze(-1)
            orig_shape = attr.shape[1:]
            return attr.masked_select(curr_mask).reshape(-1, *orig_shape)
        elif isinstance(attr, TensorWrapper):
            return attr.select_active(mask)
        elif isinstance(attr, list):
            # Flatten the mask so (batch, 1) masks index correctly (a nested
            # [0] entry would otherwise be truthy as a non-empty list).
            return [val for val, keep in zip(attr, mask.reshape(-1).tolist()) if keep]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._select_active(val, mask) for key, val in attr.items()}
        else:
            return attr
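    # Sketch: a (batch, 1) mask keeps whole batch rows. For t of shape (3, 4):
    #     mask = torch.tensor([[1], [0], [1]])
    #     TensorWrapper._select_active(t, mask)  # rows 0 and 2 -> shape (2, 4)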
    @staticmethod
    def _to(attr, device: str):
        # Move tensor leaves (and nested wrappers) to ``device``; leave the rest.
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.to(device)
        elif isinstance(attr, dict):
            return {key: TensorWrapper._to(val, device) for key, val in attr.items()}
        else:
            return attr
    @staticmethod
    def _detach(attr):
        # Detach tensor leaves (and nested wrappers) from the autograd graph.
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.detach()
        elif isinstance(attr, dict):
            return {key: TensorWrapper._detach(val) for key, val in attr.items()}
        else:
            return attr
    @staticmethod
    def _numpy(attr):
        # Convert tensor leaves to contiguous numpy arrays. Nested wrappers
        # convert themselves (their ``numpy()`` returns the wrapper, not an array).
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            np_array = attr.numpy()
            if isinstance(np_array, np.ndarray):
                return np.ascontiguousarray(np_array)
            return np_array
        elif isinstance(attr, dict):
            return {key: TensorWrapper._numpy(val) for key, val in attr.items()}
        else:
            return attr
    @staticmethod
    def _torch(attr):
        # Convert numpy leaves back to tensors (``torch.tensor`` copies the data).
        if isinstance(attr, np.ndarray):
            return torch.tensor(attr)
        elif isinstance(attr, TensorWrapper):
            return attr.torch()
        elif isinstance(attr, dict):
            return {key: TensorWrapper._torch(val) for key, val in attr.items()}
        else:
            return attr
    @staticmethod
    def _eq(self_attr: Any, other_attr: Any) -> bool:
        # Tensors compare with a small tolerance, dicts recurse, anything else
        # falls back to ``==``. Mismatched types or shapes compare unequal.
        try:
            if isinstance(self_attr, torch.Tensor):
                return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
            elif isinstance(self_attr, dict):
                return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr)
            else:
                return self_attr == other_attr
        except Exception:
            return False
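    # Sketch: the comparison tolerates small float noise and matching NaNs:
    #     TensorWrapper._eq(torch.tensor([1.0]), torch.tensor([1.0 + 1e-6]))          # True (atol=1e-5)
    #     TensorWrapper._eq(torch.tensor([float("nan")]), torch.tensor([float("nan")]))  # True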
    def __getitem__(self: TensorClass, subscript) -> TensorClass:
        """Slice the sequence dimension across all batch elements.

        To slice the batch dimension, use `slice_batch` instead.
        """
        return self.__class__(
            **{field.name: self._getitem(getattr(self, field.name), subscript) for field in fields(self.__class__)}
        )
    def slice_batch(self: TensorClass, subscript) -> TensorClass:
        """Slice the batch dimension (dim 0) of every field."""
        return self.__class__(
            **{field.name: self._slice_batch(getattr(self, field.name), subscript) for field in fields(self.__class__)}
        )
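    # Sketch contrasting the two slicing modes (using the hypothetical ``Batch``
    # with ids of shape (4, 7) and scores of shape (4, 7, 10) from the sketch
    # under the class docstring):
    #     batch[0:3]                      # sequence dim: scores -> (4, 3, 10)
    #     batch.slice_batch(slice(0, 2))  # batch dim:    scores -> (2, 7, 10)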
    def select_active(self: TensorClass, mask: Int[torch.Tensor, "batch_size 1"]) -> TensorClass:
        """Keep only the batch rows for which ``mask`` is nonzero."""
        return self.__class__(
            **{field.name: self._select_active(getattr(self, field.name), mask) for field in fields(self.__class__)}
        )
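    # Sketch: drop finished rows from the batch (shapes as in the class sketch):
    #     active = torch.tensor([[1], [0], [1], [1]])
    #     kept = batch.select_active(active)  # scores -> (3, 7, 10), ids -> (3, 7)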
    def to(self: TensorClass, device: str) -> TensorClass:
        """Move every field to ``device`` in place and return ``self``."""
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._to(attr, device))
        if device == "cpu" and torch.cuda.is_available():
            # Free cached GPU memory once the tensors have moved off the device.
            torch.cuda.empty_cache()
        return self
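    # Sketch: ``to`` mutates the wrapper in place and returns it, so both
    # styles are equivalent:
    #     batch.to("cuda")         # moves every field on the existing object
    #     moved = batch.to("cpu")  # ``moved is batch`` holds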
    def detach(self: TensorClass) -> TensorClass:
        """Detach every field from the autograd graph in place and return ``self``."""
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._detach(attr))
        return self
    def numpy(self: TensorClass) -> TensorClass:
        """Convert every tensor field to a numpy array in place and return ``self``."""
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._numpy(attr))
        return self
    def torch(self: TensorClass) -> TensorClass:
        """Convert every numpy field back to a tensor in place and return ``self``."""
        for field, val in self.to_dict().items():
            setattr(self, field, self._torch(val))
        return self
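    # Sketch: ``numpy()`` / ``torch()`` convert fields in place and round-trip
    # (``torch.tensor`` copies, so the round trip does not alias memory):
    #     batch.numpy()  # fields become np.ndarray
    #     batch.torch()  # fields become torch.Tensor again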
    def clone(self: TensorClass) -> TensorClass:
        """Return a deep, independent copy of this wrapper."""
        out_params = {}
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            if isinstance(attr, (torch.Tensor, TensorWrapper)):
                out_params[field.name] = attr.clone()
            elif attr is not None:
                out_params[field.name] = deepcopy(attr)
            else:
                out_params[field.name] = None
        return self.__class__(**out_params)
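    # Sketch: clones are fully independent:
    #     copy = batch.clone()
    #     copy.scores += 1  # batch.scores is unchanged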
    def clone_empty(self: TensorClass) -> TensorClass:
        """Return a new instance carrying over only the private ("_"-prefixed) attributes."""
        out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
        return self.__class__(**out_params)
    def to_dict(self: TensorClass) -> dict[str, Any]:
        """Return the public (non-underscore) attributes as a plain dict."""
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}
def __str__(self):
return f"{self.__class__.__name__}({pretty_dict(self.to_dict())})"
def __repr__(self):
return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"
    def __eq__(self, other):
        # Attributes missing on ``other`` compare as unequal rather than raising.
        equals = {field: self._eq(val, getattr(other, field, None)) for field, val in self.__dict__.items()}
        return all(equals.values())
    def __json_encode__(self):
        # Serialize a detached CPU copy as a plain dict of numpy arrays.
        return self.clone().detach().to("cpu").numpy().to_dict()
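    # Sketch: encoding yields a JSON-friendly dict of numpy arrays without
    # touching the original wrapper (it operates on a clone):
    #     payload = batch.__json_encode__()  # {"ids": np.ndarray, "scores": np.ndarray}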
    def __json_decode__(self, **attrs):
        # Restores attributes directly; classes using __slots__ are not supported.
        self.__dict__ = attrs
        self.__post_init__()
    def __post_init__(self):
        pass