Source code for

from copy import deepcopy
from dataclasses import dataclass, fields
from typing import Any, TypeVar

import numpy as np
import torch
from jaxtyping import Int

from ..utils import pretty_dict

# Type variable bound to TensorWrapper so that methods can declare they return
# an instance of the caller's own (sub)class.
TensorClass = TypeVar("TensorClass", bound="TensorWrapper")


@dataclass
class TensorWrapper:
    """Wrapper for tensors and lists of tensors to allow for easy access to their attributes.

    Subclasses declare their contents as dataclass fields. The static helpers
    below dispatch on each field value's type: ``torch.Tensor``, nested
    ``TensorWrapper``, ``list`` and ``dict`` (recursively) get special
    treatment, and any other value is passed through unchanged.

    Most transforms come in two flavours:

    - ``__getitem__`` / ``slice_batch`` / ``select_active`` / ``clone`` build
      and return a NEW instance.
    - ``to`` / ``detach`` / ``numpy`` / ``torch`` rewrite the fields IN PLACE
      and return ``self`` so calls can be chained.

    NOTE(review): a subclass decorated with plain ``@dataclass`` gets a
    generated ``__eq__`` that replaces the tolerance-based one defined here;
    declare subclasses with ``@dataclass(eq=False)`` to inherit it.
    """

    @staticmethod
    def _getitem(attr, subscript):
        """Index ``attr`` along its sequence dimension.

        1-D tensors are indexed directly. Tensors with two or more dims are
        indexed on dim 1, because dim 0 is the batch. Nested wrappers delegate
        to their own ``__getitem__``, and a list of lists is indexed per inner
        list. Anything else is returned unchanged, including 0-dim tensors,
        flat lists and empty lists.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 0:
                # A scalar has no sequence dimension; previously this fell
                # through every branch and silently became None.
                return attr
            if attr.ndim == 1:
                return attr[subscript]
            return attr[:, subscript, ...]
        if isinstance(attr, TensorWrapper):
            return attr[subscript]
        # `attr and ...` keeps attr[0] from raising IndexError on an empty list.
        if isinstance(attr, list) and attr and isinstance(attr[0], list):
            return [seq[subscript] for seq in attr]
        if isinstance(attr, dict):
            return {key: TensorWrapper._getitem(val, subscript) for key, val in attr.items()}
        return attr

    @staticmethod
    def _slice_batch(attr, subscript):
        """Index ``attr`` along its batch (first) dimension.

        Tensors and lists are indexed on dim 0. Nested wrappers delegate to
        their own ``slice_batch``, and dicts recurse. Everything else is
        returned unchanged, including 0-dim tensors.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 0:
                # No batch dimension to slice; previously this returned None.
                return attr
            if attr.ndim == 1:
                return attr[subscript]
            return attr[subscript, ...]
        if isinstance(attr, TensorWrapper):
            # Must slice the nested wrapper's batch dim too; `attr[subscript]`
            # would go through __getitem__ and slice its sequence dim instead.
            return attr.slice_batch(subscript)
        if isinstance(attr, list):
            return attr[subscript]
        if isinstance(attr, dict):
            return {key: TensorWrapper._slice_batch(val, subscript) for key, val in attr.items()}
        return attr

    @staticmethod
    def _select_active(attr, mask):
        """Keep only the batch rows of ``attr`` whose ``mask`` entry is non-zero.

        Tensors with fewer than two dims, and values of unhandled types, are
        returned unchanged.
        """
        if isinstance(attr, torch.Tensor):
            if attr.ndim <= 1:
                return attr
            keep = mask.bool()
            # Add trailing singleton dims so the mask broadcasts across every
            # non-batch dimension of `attr` in masked_select.
            while keep.ndim < attr.ndim:
                keep = keep.unsqueeze(-1)
            return attr.masked_select(keep).reshape(-1, *attr.shape[1:])
        if isinstance(attr, TensorWrapper):
            return attr.select_active(mask)
        if isinstance(attr, list):
            # Flatten the mask first. For a (batch, 1) mask, `mask.tolist()[i]`
            # is a 1-element list, which is always truthy, so nothing was
            # filtered. Converting once also avoids an O(n^2) rebuild per item.
            keep_flags = mask.reshape(-1).bool().tolist()
            return [val for i, val in enumerate(attr) if keep_flags[i]]
        if isinstance(attr, dict):
            return {key: TensorWrapper._select_active(val, mask) for key, val in attr.items()}
        return attr

    @staticmethod
    def _to(attr, device: str):
        """Return ``attr`` moved to ``device`` (tensors, wrappers, dicts of them)."""
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.to(device)
        if isinstance(attr, dict):
            return {key: TensorWrapper._to(val, device) for key, val in attr.items()}
        return attr

    @staticmethod
    def _detach(attr):
        """Return ``attr`` detached from the autograd graph (tensors, wrappers, dicts)."""
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.detach()
        if isinstance(attr, dict):
            return {key: TensorWrapper._detach(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _numpy(attr):
        """Convert tensors to C-contiguous numpy arrays; wrappers and dicts recurse."""
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            np_array = attr.numpy()
            # A nested wrapper's numpy() returns the wrapper itself, not an array.
            if isinstance(np_array, np.ndarray):
                return np.ascontiguousarray(np_array, dtype=np_array.dtype)
            return np_array
        if isinstance(attr, dict):
            return {key: TensorWrapper._numpy(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _torch(attr):
        """Convert numpy arrays back into tensors; wrappers and dicts recurse."""
        if isinstance(attr, np.ndarray):
            return torch.tensor(attr)
        if isinstance(attr, TensorWrapper):
            return attr.torch()
        if isinstance(attr, dict):
            return {key: TensorWrapper._torch(val) for key, val in attr.items()}
        return attr

    @staticmethod
    def _eq(self_attr: Any, other_attr: Any) -> bool:
        """Tolerance-aware equality used by ``__eq__``. Always returns a bool.

        Comparison rules:

        - Tensors and floating/complex arrays use ``allclose`` with
          ``atol=1e-5``; NaNs compare equal.
        - Other arrays compare exactly.
        - Dicts must have the same keys, and lists the same length; both then
          compare element-wise.
        - Everything else falls back to ``==``.

        Any exception (shape, dtype or type mismatch) means "not equal".
        """
        try:
            if isinstance(self_attr, torch.Tensor):
                return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
            if isinstance(self_attr, np.ndarray):
                # Plain `==` would return an array whose truth value is ambiguous.
                if np.issubdtype(self_attr.dtype, np.inexact):
                    return bool(np.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5))
                return bool(np.array_equal(self_attr, other_attr))
            if isinstance(self_attr, dict):
                return self_attr.keys() == other_attr.keys() and all(
                    TensorWrapper._eq(self_attr[key], other_attr[key]) for key in self_attr
                )
            if isinstance(self_attr, list):
                return (
                    isinstance(other_attr, list)
                    and len(self_attr) == len(other_attr)
                    and all(TensorWrapper._eq(a, b) for a, b in zip(self_attr, other_attr))
                )
            return bool(self_attr == other_attr)
        except Exception:  # mismatched shapes/types are simply "not equal"
            return False

    def _field_values(self) -> dict[str, Any]:
        """Map each dataclass field name to its current value on this instance."""
        return {field.name: getattr(self, field.name) for field in fields(self)}

    def __getitem__(self: TensorClass, subscript) -> TensorClass:
        """By default, idiomatic slicing is used for the sequence dimension across batches.
        For batching use `slice_batch` instead.

        Returns a new instance; ``self`` is not modified.
        """
        return self.__class__(
            **{name: self._getitem(val, subscript) for name, val in self._field_values().items()}
        )

    def slice_batch(self: TensorClass, subscript) -> TensorClass:
        """Return a new instance with every field indexed along the batch dimension."""
        return self.__class__(
            **{name: self._slice_batch(val, subscript) for name, val in self._field_values().items()}
        )

    def select_active(self: TensorClass, mask: Int[torch.Tensor, "batch_size 1"]) -> TensorClass:
        """Return a new instance keeping only the batch rows where ``mask`` is non-zero."""
        return self.__class__(
            **{name: self._select_active(val, mask) for name, val in self._field_values().items()}
        )

    def to(self: TensorClass, device: str) -> TensorClass:
        """Move tensor fields to ``device`` in place and return ``self``.

        When moving to ``"cpu"`` with CUDA available, this also calls
        ``torch.cuda.empty_cache()``.
        """
        for name, val in self._field_values().items():
            setattr(self, name, self._to(val, device))
        if device == "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        return self

    def detach(self: TensorClass) -> TensorClass:
        """Detach tensor fields from the autograd graph in place and return ``self``."""
        for name, val in self._field_values().items():
            setattr(self, name, self._detach(val))
        return self

    def numpy(self: TensorClass) -> TensorClass:
        """Convert tensor fields to numpy arrays in place and return ``self``."""
        for name, val in self._field_values().items():
            setattr(self, name, self._numpy(val))
        return self

    def torch(self: TensorClass) -> TensorClass:
        """Convert numpy-array attributes back to tensors in place and return ``self``.

        Unlike the other in-place converters, this walks ``to_dict()`` (public
        instance attributes) rather than the dataclass fields.
        """
        for field, val in self.to_dict().items():
            setattr(self, field, self._torch(val))
        return self

    def clone(self: TensorClass) -> TensorClass:
        """Return a copy of ``self``.

        Tensors and nested wrappers are copied with ``.clone()``; every other
        field value is deep-copied.
        """
        out_params = {}
        for name, val in self._field_values().items():
            if isinstance(val, (torch.Tensor, TensorWrapper)):
                out_params[name] = val.clone()
            else:
                # deepcopy(None) is None, so None fields need no special case.
                out_params[name] = deepcopy(val)
        return self.__class__(**out_params)

    def clone_empty(self: TensorClass) -> TensorClass:
        """Return a new instance built only from the non-None ``_``-prefixed attributes."""
        out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
        return self.__class__(**out_params)

    def to_dict(self: TensorClass) -> dict[str, Any]:
        """Return the public (non-``_``-prefixed) instance attributes as a dict."""
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def __str__(self):
        return f"{self.__class__.__name__}({pretty_dict(self.to_dict())})"

    def __repr__(self):
        return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"

    def __eq__(self, other):
        """Compare every instance attribute with ``_eq`` tolerance rules."""
        if not isinstance(other, TensorWrapper):
            # Let Python try the reflected comparison / fall back to identity
            # instead of raising AttributeError on arbitrary objects.
            return NotImplemented
        for name, val in self.__dict__.items():
            if not hasattr(other, name) or not self._eq(val, getattr(other, name)):
                return False
        return True

    def __json_encode__(self):
        """Serialize a detached, CPU, numpy copy of the public attributes."""
        return self.clone().detach().to("cpu").numpy().to_dict()

    def __json_decode__(self, **attrs):
        # Does not account for classes that use __slots__.
        self.__dict__ = attrs
        self.__post_init__()

    def __post_init__(self):
        pass