from dataclasses import dataclass
from typing import Optional, Union
from ..utils import get_aligned_idx
from ..utils.typing import EmbeddingsTensor, ExpandedTargetIdsTensor, IdsTensor, OneOrMoreTokenSequences
from .data_utils import TensorWrapper
@dataclass(eq=False, repr=False)
class BatchEncoding(TensorWrapper):
    """Output produced by the tokenization process using :meth:`~inseq.models.AttributionModel.encode`.

    Attributes:
        input_ids (:obj:`torch.Tensor`): Batch of token ids with shape ``[batch_size, longest_seq_length]``.
            Extra tokens for each sentence are padded, and truncation to ``max_seq_length`` is performed.
        input_tokens (:obj:`list(list(str))`): List of lists containing tokens for each sentence in the batch.
        attention_mask (:obj:`torch.Tensor`): Batch of attention masks with shape ``[batch_size, longest_seq_length]``.
            1 for positions that are valid, 0 for padded positions.
        baseline_ids (torch.Tensor, optional): Batch of reference token ids with shape
            ``[batch_size, longest_seq_length]``. Used for attribution methods requiring a baseline input (e.g. IG).
    """

    input_ids: IdsTensor
    attention_mask: IdsTensor
    input_tokens: Optional[OneOrMoreTokenSequences] = None
    baseline_ids: Optional[IdsTensor] = None

    def __len__(self) -> int:
        """Return the number of sequences in the batch.

        Falls back to the batch dimension of ``input_ids`` when ``input_tokens``
        was not provided, since it defaults to ``None`` and ``len(None)`` would
        raise a :exc:`TypeError`.
        """
        if self.input_tokens is not None:
            return len(self.input_tokens)
        return self.input_ids.shape[0]
@dataclass(eq=False, repr=False)
class BatchEmbedding(TensorWrapper):
    """Embeddings produced by the embedding process using :meth:`~inseq.models.AttributionModel.embed`.

    Attributes:
        input_embeds (:obj:`torch.Tensor`): Batch of token embeddings with shape
            ``[batch_size, longest_seq_length, embedding_size]`` for each sentence in the batch.
        baseline_embeds (:obj:`torch.Tensor`, optional): Batch of reference token embeddings with shape
            ``[batch_size, longest_seq_length, embedding_size]`` for each sentence in the batch.
    """

    input_embeds: Optional[EmbeddingsTensor] = None
    baseline_embeds: Optional[EmbeddingsTensor] = None

    def __len__(self) -> Optional[int]:
        # Batch size is the leading dimension of the embeddings tensor.
        # NOTE(review): returning None violates the __len__ protocol (len() would
        # raise TypeError); callers presumably invoke __len__ directly — confirm.
        if self.input_embeds is not None:
            return self.input_embeds.shape[0]
        return None
@dataclass(eq=False, repr=False)
class Batch(TensorWrapper):
    """Batch of input data for the attribution model.

    Attributes:
        encoding (:class:`~inseq.data.BatchEncoding`): Output produced by the tokenization process using
            :meth:`~inseq.models.AttributionModel.encode`.
        embedding (:class:`~inseq.data.BatchEmbedding`): Embeddings produced by the embedding process using
            :meth:`~inseq.models.AttributionModel.embed`.

    All attribute fields are accessible as properties (e.g. ``batch.input_ids`` corresponds to
    ``batch.encoding.input_ids``)
    """

    encoding: BatchEncoding
    embedding: BatchEmbedding

    # Read-only views delegating to the wrapped encoding/embedding; paired
    # setters below keep the delegation symmetric for in-place updates.

    @property
    def input_ids(self) -> IdsTensor:
        return self.encoding.input_ids

    @property
    def input_tokens(self) -> OneOrMoreTokenSequences:
        return self.encoding.input_tokens

    @property
    def attention_mask(self) -> IdsTensor:
        return self.encoding.attention_mask

    @property
    def baseline_ids(self) -> Optional[IdsTensor]:
        return self.encoding.baseline_ids

    @property
    def input_embeds(self) -> Optional[EmbeddingsTensor]:
        return self.embedding.input_embeds

    @property
    def baseline_embeds(self) -> Optional[EmbeddingsTensor]:
        return self.embedding.baseline_embeds

    @input_ids.setter
    def input_ids(self, value: IdsTensor):
        self.encoding.input_ids = value

    @input_tokens.setter
    def input_tokens(self, value: list[list[str]]):
        self.encoding.input_tokens = value

    @attention_mask.setter
    def attention_mask(self, value: IdsTensor):
        self.encoding.attention_mask = value

    @baseline_ids.setter
    def baseline_ids(self, value: Optional[IdsTensor]):
        self.encoding.baseline_ids = value

    @input_embeds.setter
    def input_embeds(self, value: Optional[EmbeddingsTensor]):
        self.embedding.input_embeds = value

    @baseline_embeds.setter
    def baseline_embeds(self, value: Optional[EmbeddingsTensor]):
        self.embedding.baseline_embeds = value
@dataclass(eq=False, repr=False)
class EncoderDecoderBatch(TensorWrapper):
    """Batch of input data for the encoder-decoder attribution model, including information for the source text and the
    target prefix.

    Attributes:
        sources (:class:`~inseq.data.Batch`): Batch of input data for the source text.
        targets (:class:`~inseq.data.Batch`): Batch of input data for the target prefix.
    """

    sources: Batch
    targets: Batch

    def __getitem__(self, subscript: Union[slice, int]) -> "EncoderDecoderBatch":
        # Only the target prefix is sliced: the full source is needed at every
        # generation step, so it is passed through unchanged.
        return EncoderDecoderBatch(sources=self.sources, targets=self.targets[subscript])

    @property
    def max_generation_length(self) -> int:
        # Sequence-length dimension of the target ids.
        return self.targets.input_ids.shape[1]

    @property
    def source_tokens(self) -> OneOrMoreTokenSequences:
        return self.sources.input_tokens

    @property
    def target_tokens(self) -> OneOrMoreTokenSequences:
        return self.targets.input_tokens

    @property
    def source_ids(self) -> IdsTensor:
        return self.sources.input_ids

    @property
    def target_ids(self) -> IdsTensor:
        return self.targets.input_ids

    @property
    def source_embeds(self) -> EmbeddingsTensor:
        return self.sources.input_embeds

    @property
    def target_embeds(self) -> EmbeddingsTensor:
        return self.targets.input_embeds

    @property
    def source_mask(self) -> IdsTensor:
        return self.sources.attention_mask

    @property
    def target_mask(self) -> IdsTensor:
        return self.targets.attention_mask

    def get_step_target(
        self, step: int, with_attention: bool = False
    ) -> Union[ExpandedTargetIdsTensor, tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]]:
        """Return the column of target ids at position ``step``, optionally paired
        with the matching attention mask column when ``with_attention`` is True.
        """
        tgt = self.targets.input_ids[:, step]
        if with_attention:
            return tgt, self.targets.attention_mask[:, step]
        return tgt
@dataclass(eq=False, repr=False)
class DecoderOnlyBatch(Batch):
    """Input batch adapted for decoder-only attribution models, including information for the target prefix.

    Mirrors the :class:`EncoderDecoderBatch` accessor interface: all ``source_*``
    properties return ``None`` since decoder-only models have no separate source,
    while ``target_*`` properties alias the batch's own inputs.
    """

    @property
    def max_generation_length(self) -> int:
        # Sequence-length dimension of the input ids.
        return self.input_ids.shape[1]

    @property
    def source_tokens(self) -> Optional[OneOrMoreTokenSequences]:
        return None

    @property
    def target_tokens(self) -> OneOrMoreTokenSequences:
        return self.input_tokens

    @property
    def source_ids(self) -> Optional[IdsTensor]:
        return None

    @property
    def target_ids(self) -> IdsTensor:
        return self.input_ids

    @property
    def source_embeds(self) -> Optional[EmbeddingsTensor]:
        return None

    @property
    def target_embeds(self) -> EmbeddingsTensor:
        return self.input_embeds

    @property
    def source_mask(self) -> Optional[IdsTensor]:
        return None

    @property
    def target_mask(self) -> IdsTensor:
        return self.attention_mask

    def get_step_target(
        self, step: int, with_attention: bool = False
    ) -> Union[ExpandedTargetIdsTensor, tuple[ExpandedTargetIdsTensor, ExpandedTargetIdsTensor]]:
        """Return the column of input ids at position ``step``, optionally paired
        with the matching attention mask column when ``with_attention`` is True.
        """
        tgt = self.input_ids[:, step]
        if with_attention:
            return tgt, self.attention_mask[:, step]
        return tgt

    @classmethod
    def from_batch(cls, batch: Batch) -> "DecoderOnlyBatch":
        """Alternate constructor wrapping an existing :class:`Batch`.

        Uses ``cls`` (first parameter was previously misnamed ``self`` and the
        class was hardcoded) so that subclasses construct instances of themselves.
        """
        return cls(
            encoding=batch.encoding,
            embedding=batch.embedding,
        )
def slice_batch_from_position(
    batch: DecoderOnlyBatch, curr_idx: int, alignments: Optional[list[tuple[int, int]]] = None
) -> tuple[DecoderOnlyBatch, IdsTensor]:
    """Truncate ``batch`` at the position aligned with ``curr_idx``.

    Args:
        batch: Decoder-only batch to slice.
        curr_idx: Current position in the reference sequence.
        alignments: Optional list of (reference, target) index pairs; a singly
            nested list is unwrapped to its first element.

    Returns:
        The batch truncated before the aligned index, and the column of target
        ids at that index.
    """
    # Guard with truthiness: the default is None, and the previous unconditional
    # len(alignments) call raised a TypeError in that case.
    if alignments and isinstance(alignments[0], list):
        alignments = alignments[0]
    truncate_idx = get_aligned_idx(curr_idx, alignments)
    tgt_ids = batch.target_ids[:, truncate_idx]
    return batch[:truncate_idx], tgt_ids