Source code for inseq.data.data_utils

from copy import deepcopy
from dataclasses import dataclass, fields
from typing import Any, TypeVar

import numpy as np
import torch
from jaxtyping import Int

from ..utils import pretty_dict

TensorClass = TypeVar("TensorClass", bound="TensorWrapper")


@dataclass
class TensorWrapper:
    """Wrapper for tensors and lists of tensors to allow for easy access to their attributes."""

    @staticmethod
    def _getitem(attr, subscript):
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 1:
                return attr[subscript]
            if attr.ndim >= 2:
                return attr[:, subscript, ...]
        elif isinstance(attr, TensorWrapper):
            return attr[subscript]
        elif isinstance(attr, list) and isinstance(attr[0], list):
            return [seq[subscript] for seq in attr]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._getitem(val, subscript) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _slice_batch(attr, subscript):
        if isinstance(attr, torch.Tensor):
            if attr.ndim == 1:
                return attr[subscript]
            if attr.ndim >= 2:
                return attr[subscript, ...]
        elif isinstance(attr, (TensorWrapper, list)):
            return attr[subscript]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._slice_batch(val, subscript) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _select_active(attr, mask):
        if isinstance(attr, torch.Tensor):
            if attr.ndim <= 1:
                return attr
            else:
                curr_mask = mask.clone()
                if curr_mask.dtype != torch.bool:
                    curr_mask = curr_mask.bool()
                while curr_mask.ndim < attr.ndim:
                    curr_mask = curr_mask.unsqueeze(-1)
                orig_shape = attr.shape[1:]
                return attr.masked_select(curr_mask).reshape(-1, *orig_shape)
        elif isinstance(attr, TensorWrapper):
            return attr.select_active(mask)
        elif isinstance(attr, list):
            return [val for i, val in enumerate(attr) if mask.tolist()[i]]
        elif isinstance(attr, dict):
            return {key: TensorWrapper._select_active(val, mask) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _to(attr, device: str):
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.to(device)
        elif isinstance(attr, dict):
            return {key: TensorWrapper._to(val, device) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _detach(attr):
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            return attr.detach()
        elif isinstance(attr, dict):
            return {key: TensorWrapper._detach(val) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _numpy(attr):
        if isinstance(attr, (torch.Tensor, TensorWrapper)):
            np_array = attr.numpy()
            if isinstance(np_array, np.ndarray):
                return np.ascontiguousarray(np_array, dtype=np_array.dtype)
            return np_array
        elif isinstance(attr, dict):
            return {key: TensorWrapper._numpy(val) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _torch(attr):
        if isinstance(attr, np.ndarray):
            return torch.tensor(attr)
        elif isinstance(attr, TensorWrapper):
            return attr.torch()
        elif isinstance(attr, dict):
            return {key: TensorWrapper._torch(val) for key, val in attr.items()}
        else:
            return attr

    @staticmethod
    def _eq(self_attr: TensorClass, other_attr: TensorClass) -> bool:
        try:
            if isinstance(self_attr, torch.Tensor):
                return torch.allclose(self_attr, other_attr, equal_nan=True, atol=1e-5)
            elif isinstance(self_attr, dict):
                return all(TensorWrapper._eq(self_attr[k], other_attr[k]) for k in self_attr.keys())
            else:
                return self_attr == other_attr
        except:  # noqa: E722
            return False

    def __getitem__(self: TensorClass, subscript) -> TensorClass:
        """Slicing with ``[]`` indexes the sequence dimension across all batch elements.
        To slice the batch dimension, use ``slice_batch`` instead.
        """
        return self.__class__(
            **{field.name: self._getitem(getattr(self, field.name), subscript) for field in fields(self.__class__)}
        )

    def slice_batch(self: TensorClass, subscript) -> TensorClass:
        return self.__class__(
            **{field.name: self._slice_batch(getattr(self, field.name), subscript) for field in fields(self.__class__)}
        )

    def select_active(self: TensorClass, mask: Int[torch.Tensor, "batch_size 1"]) -> TensorClass:
        return self.__class__(
            **{field.name: self._select_active(getattr(self, field.name), mask) for field in fields(self.__class__)}
        )

    def to(self: TensorClass, device: str) -> TensorClass:
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._to(attr, device))
        if device == "cpu" and torch.cuda.is_available():
            torch.cuda.empty_cache()
        return self

    def detach(self: TensorClass) -> TensorClass:
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._detach(attr))
        return self

    def numpy(self: TensorClass) -> TensorClass:
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            setattr(self, field.name, self._numpy(attr))
        return self

    def torch(self: TensorClass) -> TensorClass:
        for field, val in self.to_dict().items():
            setattr(self, field, self._torch(val))
        return self

    def clone(self: TensorClass) -> TensorClass:
        out_params = {}
        for field in fields(self.__class__):
            attr = getattr(self, field.name)
            if isinstance(attr, (torch.Tensor, TensorWrapper)):
                out_params[field.name] = attr.clone()
            elif attr is not None:
                out_params[field.name] = deepcopy(attr)
            else:
                out_params[field.name] = None
        return self.__class__(**out_params)

    def clone_empty(self: TensorClass) -> TensorClass:
        out_params = {k: v for k, v in self.__dict__.items() if k.startswith("_") and v is not None}
        return self.__class__(**out_params)

    def to_dict(self: TensorClass) -> dict[str, Any]:
        return {k: v for k, v in self.__dict__.items() if not k.startswith("_")}

    def __str__(self):
        return f"{self.__class__.__name__}({pretty_dict(self.to_dict())})"

    def __repr__(self):
        return f"{self.__class__.__name__}({pretty_dict(self.__dict__)})"

    def __eq__(self, other):
        equals = {field: self._eq(val, getattr(other, field)) for field, val in self.__dict__.items()}
        return all(x for x in equals.values())

    def __json_encode__(self):
        return self.clone().detach().to("cpu").numpy().to_dict()

    def __json_decode__(self, **attrs):
        # Does not contemplate the usage of __slots__
        self.__dict__ = attrs
        self.__post_init__()

    def __post_init__(self):
        pass
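

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal, hypothetical subclass showing how ``__getitem__`` (sequence-dimension
# slicing) differs from ``slice_batch`` (batch-dimension slicing). The class name
# ``ToyBatch`` and its fields ``scores`` and ``tokens`` are made up for this example;
# any dataclass deriving from ``TensorWrapper`` would behave the same way.
if __name__ == "__main__":

    @dataclass
    class ToyBatch(TensorWrapper):
        scores: torch.Tensor  # shape: (batch_size, seq_len)
        tokens: list[list[str]]  # one token list per sequence

    batch = ToyBatch(
        scores=torch.rand(2, 5),
        tokens=[["a", "b", "c", "d", "e"], ["f", "g", "h", "i", "j"]],
    )

    # Sequence slicing: every field keeps the full batch.
    # scores -> shape (2, 3); tokens -> first three tokens of each sequence.
    first_steps = batch[:3]

    # Batch slicing: every field keeps the full sequence.
    # scores -> shape (1, 5); tokens -> only the first sequence.
    first_example = batch.slice_batch(slice(0, 1))

    print(first_steps.scores.shape, first_example.scores.shape)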