# Source code for inseq.data.aggregator

import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union

import torch

from ..utils import (
    Registry,
    aggregate_contiguous,
    aggregate_token_pair,
    aggregate_token_sequence,
    available_classes,
    extract_signature_args,
    validate_indices,
)
from ..utils import normalize as normalize_fn
from ..utils.typing import IndexSpan, OneOrMoreIndices, TokenWithId
from .aggregation_functions import AggregationFunction
from .data_utils import TensorWrapper

if TYPE_CHECKING:
    from .attribution import FeatureAttributionSequenceOutput


logger = logging.getLogger(__name__)

# Generic type variable bound to AggregableMixin so that aggregate() signatures can
# declare that they return the same concrete subclass they received as input.
AggregableMixinClass = TypeVar("AggregableMixinClass", bound="AggregableMixin")


class DictWithDefault(dict):
    """Used to pass specific values to field-specific calls of the aggregate function in Aggregator.

    DictWithDefault dictionary objects won't be passed as a whole to all field-specific functions called by
    Aggregator.aggregate, and instead only the values with the name of the corresponding field will be used.
    When these are missing, the default field of DictWithDefault will be used as fallback.
    """

    @staticmethod
    def _get_fn(name: str) -> Callable:
        """Resolve a string identifier to an instantiated AggregationFunction.

        Raises:
            ValueError: If ``name`` is not a registered aggregation function identifier.
        """
        if name not in available_classes(AggregationFunction):
            # Fixed separator: ", " for consistency with every other error message in this module.
            raise ValueError(
                f"Unknown aggregation function {name}. Choose from {', '.join(available_classes(AggregationFunction))}."
            )
        return AggregationFunction.available_classes()[name]()

    def __init__(self, default: Union[str, Callable], **kwargs):
        super().__init__(**kwargs)
        # A string default is resolved eagerly to an AggregationFunction; a callable is stored as-is.
        self.default = self._get_fn(default) if isinstance(default, str) else default

    def __getitem__(self, key):
        try:
            value = super().__getitem__(key)
            if isinstance(value, str):
                # String values are lazily resolved to aggregation function instances.
                return self._get_fn(value)
            elif isinstance(value, dict):
                # Nested dicts keep propagating the same fallback default.
                return DictWithDefault(self.default, **value)
            return value
        except KeyError:
            # Missing keys fall back to the shared default instead of raising.
            return self.default


class Aggregator(Registry):
    """Registry-backed base class for all aggregators, defining the hook-based aggregation lifecycle."""

    registry_attr = "aggregator_name"

    @classmethod
    def start_aggregation_hook(cls, tensors: TensorWrapper, **kwargs):
        """Hook called at the start of the aggregation process.

        Use to ensure a prerequisite that is independent of previous aggregation steps and
        fundamental to the aggregation process (e.g. parameters are of the correct type).
        Will avoid performing aggregation steps before returning an error.
        """
        pass

    @classmethod
    def pre_aggregate_hook(cls, tensors: TensorWrapper, **kwargs):
        """Hook called right before the aggregation function is called.

        Use to ensure a prerequisite that is functional of previous aggregation steps and
        fundamental to the aggregation process (e.g. the aggregatable object produced by
        the previous step has correct shapes).
        """
        pass

    @classmethod
    @abstractmethod
    def _aggregate(cls, tensors: TensorWrapper, **kwargs):
        # Concrete subclasses implement the actual aggregation step here.
        pass

    @classmethod
    def aggregate(
        cls,
        tensors: AggregableMixinClass,
        do_pre_aggregation_checks: bool = True,
        do_post_aggregation_checks: bool = True,
        **kwargs,
    ) -> AggregableMixinClass:
        """Run the full aggregation lifecycle: start/pre hooks, aggregation, post/end hooks."""
        if do_pre_aggregation_checks:
            cls.start_aggregation_hook(tensors, **kwargs)
        # Pre/post hooks always run per-aggregator, even inside a pipeline that disables the
        # start/end checks (the pipeline fires those once for the whole chain instead).
        cls.pre_aggregate_hook(tensors, **kwargs)
        aggregated = cls._aggregate(tensors, **kwargs)
        cls.post_aggregate_hook(aggregated, **kwargs)
        if do_post_aggregation_checks:
            cls.end_aggregation_hook(aggregated, **kwargs)
        return aggregated

    @classmethod
    def post_aggregate_hook(cls, tensors: TensorWrapper, **kwargs):
        """Hook called right after the aggregation function is called.

        Verifies that the aggregated object has the correct properties.
        """
        pass

    @classmethod
    def end_aggregation_hook(cls, tensors: TensorWrapper, **kwargs):
        """Hook called at the end of the aggregation process.

        Use to ensure that the final product of aggregation is compliant with the
        requirements of individual aggregators.
        """
        pass
def _get_aggregators_from_id(
    aggregator: str,
    aggregate_fn: Optional[str] = None,
) -> tuple[type[Aggregator], Optional[AggregationFunction]]:
    """Resolve string identifiers to an aggregator class and an optional aggregation function instance.

    The identifier may name either a registered aggregator, or a registered aggregation
    function (in which case the default scores aggregator is paired with it).
    """
    if aggregator in available_classes(Aggregator):
        aggregator_cls = Aggregator.available_classes()[aggregator]
    elif aggregator in available_classes(AggregationFunction):
        if aggregate_fn is not None:
            raise ValueError(
                "If aggregator is a string identifying an aggregation function, aggregate_fn should not be provided."
            )
        # The identifier names a function: pair it with the default scores aggregator.
        aggregator_cls, aggregate_fn = SequenceAttributionAggregator, aggregator
    else:
        raise ValueError(
            f"Unknown aggregator {aggregator}. Choose from {', '.join(available_classes(Aggregator))}.\n"
            f"Alternatively, choose from the aggregate_fn options {', '.join(available_classes(AggregationFunction))} "
            "for scores aggregation with the chosen function."
        )
    if aggregate_fn is None:
        return aggregator_cls, None
    if aggregate_fn not in available_classes(AggregationFunction):
        raise ValueError(
            f"Unknown aggregation function {aggregate_fn}. "
            f"Choose from {', '.join(available_classes(AggregationFunction))}"
        )
    return aggregator_cls, AggregationFunction.available_classes()[aggregate_fn]()
class AggregatorPipeline:
    """Chains multiple aggregators (with optional per-step aggregation functions) into one pipeline."""

    def __init__(
        self,
        aggregators: list[Union[str, type[Aggregator]]],
        aggregate_fn: Optional[list[Union[str, Callable]]] = None,
    ):
        if aggregate_fn is not None and len(aggregate_fn) != len(aggregators):
            raise ValueError(
                "If custom aggregate_fn are provided, their number should match the number of aggregators."
            )
        self.aggregators: list[type[Aggregator]] = []
        self.aggregate_fn: list[Callable] = []
        step_fns = aggregate_fn if aggregate_fn is not None else [None] * len(aggregators)
        for step_aggregator, step_fn in zip(aggregators, step_fns):
            # String identifiers are resolved to classes (and possibly a paired function).
            if isinstance(step_aggregator, str):
                step_aggregator, step_fn = _get_aggregators_from_id(step_aggregator, step_fn)
            self.aggregators.append(step_aggregator)
            self.aggregate_fn.append(step_fn)

    def aggregate(
        self,
        tensors: AggregableMixinClass,
        do_pre_aggregation_checks: bool = True,
        do_post_aggregation_checks: bool = True,
        **kwargs,
    ) -> AggregableMixinClass:
        """Run every aggregator in sequence, firing start/end hooks once for the whole pipeline."""
        if do_pre_aggregation_checks:
            for aggregator in self.aggregators:
                aggregator.start_aggregation_hook(tensors, **kwargs)
        for aggregator, step_fn in zip(self.aggregators, self.aggregate_fn):
            step_kwargs = kwargs.copy()
            if step_fn is not None:
                step_kwargs["aggregate_fn"] = step_fn
            # Per-step start/end checks are disabled: they are handled once at pipeline level.
            tensors = aggregator.aggregate(
                tensors, do_pre_aggregation_checks=False, do_post_aggregation_checks=False, **step_kwargs
            )
        if do_post_aggregation_checks:
            for aggregator in self.aggregators:
                aggregator.end_aggregation_hook(tensors, **kwargs)
        return tensors
# Accepted forms for specifying an aggregator: a ready pipeline, an Aggregator class,
# a string identifier, a sequence of identifiers/classes (chained into a pipeline),
# or None to fall back to the object's default aggregator.
AggregatorInput = Union[AggregatorPipeline, type[Aggregator], str, Sequence[Union[str, type[Aggregator]]], None]
def list_aggregators() -> list[str]:
    """Return the identifiers of all registered aggregators."""
    return available_classes(Aggregator)
class AggregableMixin(ABC):
    """Mixin providing an ``aggregate`` entry point for output classes that support aggregation."""

    # Default aggregator (pipeline or single Aggregator class) used when none is provided.
    _aggregator: Union[AggregatorPipeline, type[Aggregator]]

    def aggregate(
        self: AggregableMixinClass,
        aggregator: AggregatorInput = None,
        aggregate_fn: Union[str, Sequence[str], None] = None,
        do_pre_aggregation_checks: bool = True,
        do_post_aggregation_checks: bool = True,
        **kwargs,
    ) -> AggregableMixinClass:
        """Aggregate outputs using the default or provided aggregator.

        Args:
            aggregator (:obj:`AggregatorPipeline` or :obj:`Type[Aggregator]` or :obj:`str` or sequence, optional):
                Aggregator or pipeline to use. If not provided, the default aggregator pipeline is used.
            aggregate_fn (:obj:`str` or sequence of :obj:`str`, optional): Identifier(s) of the aggregation
                function(s) to pair with the chosen aggregator(s).
            do_pre_aggregation_checks (:obj:`bool`, optional): Whether to run start-of-aggregation hooks.
            do_post_aggregation_checks (:obj:`bool`, optional): Whether to run end-of-aggregation hooks.

        Returns:
            :obj:`AggregableMixin`: The aggregated output class.

        Raises:
            ValueError: If identifiers are unknown or the aggregator/aggregate_fn combination is invalid.
        """
        if aggregator is None:
            aggregator = self._aggregator
        if isinstance(aggregator, str):
            if isinstance(aggregate_fn, (list, tuple)):
                raise ValueError(
                    "If a single aggregator is used, aggregate_fn should also be a string identifier for the "
                    "corresponding aggregation function if defined."
                )
            aggregator, aggregate_fn = _get_aggregators_from_id(aggregator, aggregate_fn)
            if aggregate_fn is not None:
                kwargs["aggregate_fn"] = aggregate_fn
        elif isinstance(aggregator, (list, tuple)):
            if all(isinstance(a, (str, type)) for a in aggregator):
                aggregator = AggregatorPipeline(aggregator, aggregate_fn)
            elif all(isinstance(agg, tuple) for agg in aggregator) and all(
                isinstance(idx, (str, type)) for agg in aggregator for idx in agg
            ):
                aggregator = AggregatorPipeline([a[0] for a in aggregator], [a[1] for a in aggregator])
            else:
                # Fix: a mixed sequence (neither all identifiers nor all (aggregator, fn) pairs)
                # previously fell through silently and crashed later with AttributeError.
                raise ValueError(
                    "If aggregator is a sequence, it should contain either strings/classes identifying aggregators"
                    "or tuples of pairs of strings/classes identifying aggregators and aggregate functions."
                )
        return aggregator.aggregate(
            self,
            do_pre_aggregation_checks=do_pre_aggregation_checks,
            do_post_aggregation_checks=do_post_aggregation_checks,
            **kwargs,
        )

    @abstractmethod
    def __post_init__(self):
        pass
class SequenceAttributionAggregator(Aggregator):
    """Aggregates sequence attributions using a custom function. By default, the mean function is used.

    Enables aggregation for the FeatureAttributionSequenceOutput class using an aggregation function of choice.

    Args:
        attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object to aggregate.
        aggregate_fn (:obj:`Callable`, optional): Function used to aggregate sequence attributions.
            Defaults to summing over the last dimension and renormalizing by the norm of the source(+target)
            attributions for granular attributions, no aggregation for token-level attributions.
    """

    aggregator_name = "scores"
    aggregator_family = "scores"
    default_fn = "mean"

    @classmethod
    def _aggregate(
        cls, attr: "FeatureAttributionSequenceOutput", aggregate_fn: Union[str, Callable, None] = None, **kwargs
    ) -> "FeatureAttributionSequenceOutput":
        # Wrap the aggregation function spec in a DictWithDefault so per-field lookups always succeed.
        if aggregate_fn is None and isinstance(attr._dict_aggregate_fn, dict):
            aggregate_fn = DictWithDefault(default=cls.default_fn, **attr._dict_aggregate_fn)
        elif aggregate_fn is not None:
            aggregate_fn = DictWithDefault(default=aggregate_fn)
        # Dispatch kwargs to the corresponding field-specific functions.
        # E.g. aggregate_source_attributions will take care of the source_attributions field.
        aggregated_sequence_attribution_fields = {}
        for field in attr.to_dict().keys():
            if aggregate_fn is not None:
                kwargs["aggregate_fn"] = aggregate_fn[field]
                # If the subclass is a dict, then we assume its fields represent variants depending on the
                # aggregator family that is being used (see e.g. step_scores in DEFAULT_ATTRIBUTION_AGGREGATE_DICT)
                if isinstance(kwargs["aggregate_fn"], dict):
                    kwargs["aggregate_fn"] = kwargs["aggregate_fn"][cls.aggregator_family]
            field_func = getattr(cls, f"aggregate_{field}")
            aggregated_sequence_attribution_fields[field] = field_func(attr, **kwargs)
        return attr.__class__(**aggregated_sequence_attribution_fields)

    @classmethod
    def _process_attribution_scores(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        aggregate_fn: AggregationFunction,
        select_idx: Optional[OneOrMoreIndices] = None,
        normalize: bool = True,
        **kwargs,
    ):
        """Filter, aggregate and optionally normalize source/target attribution scores.

        Returns a (source, target) tuple when both attribution types are present,
        otherwise a single tensor.
        """
        fn_kwargs = extract_signature_args(kwargs, aggregate_fn)
        # If select_idx is a single int, no aggregation is performed
        do_aggregate = not isinstance(select_idx, int)
        has_source = attr.source_attributions is not None
        has_target = attr.target_attributions is not None
        src_scores = None
        if has_source:
            src_scores = cls._filter_scores(attr.source_attributions, dim=-1, indices=select_idx)
        tgt_scores = None
        if has_target:
            tgt_scores = cls._filter_scores(attr.target_attributions, dim=-1, indices=select_idx)
        if has_source and has_target:
            scores = (src_scores, tgt_scores)
        else:
            scores = src_scores if src_scores is not None else tgt_scores
        if aggregate_fn.takes_sequence_scores:
            fn_kwargs["sequence_scores"] = attr.sequence_scores
        if do_aggregate:
            scores = cls._aggregate_scores(scores, aggregate_fn, dim=-1, **fn_kwargs)
        if normalize:
            scores = normalize_fn(scores)
        return scores

    @classmethod
    def post_aggregate_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs):
        super().post_aggregate_hook(attr, **kwargs)
        cls.is_compatible(attr)

    @classmethod
    def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs):
        super().end_aggregation_hook(attr, **kwargs)
        # Needed to ensure the attribution can be visualized
        try:
            if attr.source_attributions is not None:
                assert attr.source_attributions.ndim == 2, attr.source_attributions.shape
            if attr.target_attributions is not None:
                assert attr.target_attributions.ndim == 2, attr.target_attributions.shape
        except AssertionError as e:
            # Fix: the original message concatenated adjacent literals without separators,
            # producing e.g. "...right away, usedo_post_aggregation_checks=False...".
            raise RuntimeError(
                f"The aggregated attributions should be 2-dimensional to be visualized. "
                f"Found dimensions: {e.args[0]}. "
                "If you're performing intermediate aggregation and don't aim to visualize the output right away, "
                "use do_post_aggregation_checks=False in the aggregate method to bypass this check."
            ) from e

    @staticmethod
    def aggregate_source(attr: "FeatureAttributionSequenceOutput", **kwargs):
        return attr.source

    @staticmethod
    def aggregate_target(attr: "FeatureAttributionSequenceOutput", **kwargs):
        return attr.target

    @classmethod
    def aggregate_source_attributions(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        aggregate_fn: AggregationFunction,
        select_idx: Optional[OneOrMoreIndices] = None,
        normalize: bool = True,
        **kwargs,
    ):
        if attr.source_attributions is None:
            return attr.source_attributions
        scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, **kwargs)
        # When both source and target attributions exist, scores is a (source, target) tuple.
        return scores[0] if attr.target_attributions is not None else scores

    @classmethod
    def aggregate_target_attributions(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        aggregate_fn: AggregationFunction,
        select_idx: Optional[OneOrMoreIndices] = None,
        normalize: bool = True,
        **kwargs,
    ):
        if attr.target_attributions is None:
            return attr.target_attributions
        scores = cls._process_attribution_scores(attr, aggregate_fn, select_idx, normalize, **kwargs)
        # When both source and target attributions exist, scores is a (source, target) tuple.
        return scores[1] if attr.source_attributions is not None else scores

    @staticmethod
    def aggregate_step_scores(attr: "FeatureAttributionSequenceOutput", **kwargs):
        return attr.step_scores

    @classmethod
    def aggregate_sequence_scores(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        aggregate_fn: AggregationFunction,
        select_idx: Optional[OneOrMoreIndices] = None,
        **kwargs,
    ):
        # Functions consuming sequence scores leave them untouched here (they read them elsewhere).
        if aggregate_fn.takes_sequence_scores:
            return attr.sequence_scores
        fn_kwargs = extract_signature_args(kwargs, aggregate_fn)
        new_sequence_scores = {}
        for scores_id, seq_scores in attr.sequence_scores.items():
            filtered_scores = cls._filter_scores(seq_scores, dim=-1, indices=select_idx)
            # A single int index already reduces the dimension; skip aggregation in that case.
            if not isinstance(select_idx, int):
                filtered_scores = cls._aggregate_scores(filtered_scores, aggregate_fn, dim=-1, **fn_kwargs)
            new_sequence_scores[scores_id] = filtered_scores
        return new_sequence_scores

    @staticmethod
    def aggregate_attr_pos_start(attr: "FeatureAttributionSequenceOutput", **kwargs):
        return attr.attr_pos_start

    @staticmethod
    def aggregate_attr_pos_end(attr: "FeatureAttributionSequenceOutput", **kwargs):
        return attr.attr_pos_end

    @staticmethod
    def is_compatible(attr: "FeatureAttributionSequenceOutput"):
        """Assert that attribution shapes are consistent with token sequences and attribution bounds."""
        from .attribution import FeatureAttributionSequenceOutput

        assert isinstance(attr, FeatureAttributionSequenceOutput)
        if attr.source_attributions is not None:
            assert attr.source_attributions.shape[0] == len(attr.source)
            assert attr.source_attributions.shape[1] == attr.attr_pos_end - attr.attr_pos_start
        if attr.target_attributions is not None:
            assert attr.target_attributions.shape[0] == min(len(attr.target), attr.attr_pos_end)
            assert attr.target_attributions.shape[1] == attr.attr_pos_end - attr.attr_pos_start
        if attr.step_scores is not None:
            for step_score in attr.step_scores.values():
                assert len(step_score) == attr.attr_pos_end - attr.attr_pos_start

    @staticmethod
    def _filter_scores(
        scores: torch.Tensor,
        dim: int = -1,
        indices: Optional[OneOrMoreIndices] = None,
    ) -> torch.Tensor:
        """Select ``indices`` along ``dim``; a single int index also drops the selected dimension."""
        indexed = scores.index_select(dim, validate_indices(scores, dim, indices).to(scores.device))
        if isinstance(indices, int):
            return indexed.squeeze(dim)
        return indexed

    @staticmethod
    def _aggregate_scores(
        scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
        aggregate_fn: AggregationFunction,
        dim: int = -1,
        **kwargs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        # Single-tensor functions are mapped over tuple inputs element by element.
        if isinstance(scores, tuple) and aggregate_fn.takes_single_tensor:
            return tuple(aggregate_fn(score, dim=dim, **kwargs) for score in scores)
        return aggregate_fn(scores, dim=dim, **kwargs)
class ContiguousSpanAggregator(SequenceAttributionAggregator):
    """Reduces sequence attributions across one or more contiguous spans.

    Args:
        attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object to aggregate.
        aggregate_fn (:obj:`Callable`, optional): Function used to aggregate sequence attributions.
            Defaults to the highest absolute value score across the aggregated span, with original sign
            preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
        source_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to aggregate
            over for the source sequence. Defaults to no aggregation performed.
        target_spans (tuple of [int, int] or sequence of tuples of [int, int], optional): Spans to aggregate
            over for the target sequence. Defaults to no aggregation performed.
    """

    aggregator_name = "spans"
    aggregator_family = "spans"
    default_fn = "absmax"

    @classmethod
    def start_aggregation_hook(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        source_spans: Optional[IndexSpan] = None,
        target_spans: Optional[IndexSpan] = None,
        **kwargs,
    ):
        super().start_aggregation_hook(attr, **kwargs)
        cls.validate_spans(attr.source, source_spans)
        cls.validate_spans(attr.target, target_spans)

    @classmethod
    def end_aggregation_hook(cls, attr: "FeatureAttributionSequenceOutput", **kwargs):
        # Deliberately skips the parent's 2D-visualization check: span aggregation may be an
        # intermediate step producing outputs that are not yet visualizable.
        pass

    @classmethod
    def aggregate(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        source_spans: Optional[IndexSpan] = None,
        target_spans: Optional[IndexSpan] = None,
        **kwargs,
    ):
        """Spans can be:

        1. A list of the form [pos_start, pos_end] including the contiguous positions of tokens that are
           to be aggregated, if all values are integers and len(span) < len(original_seq)
        2. A list of the form [(pos_start_0, pos_end_0), (pos_start_1, pos_end_1)], same as above but for
           multiple contiguous spans.
        """
        source_spans = cls.format_spans(source_spans)
        target_spans = cls.format_spans(target_spans)
        return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)

    @staticmethod
    def format_spans(spans) -> list[tuple[int, int]]:
        """Normalize a single [start, end] span into a list of spans; pass lists through unchanged."""
        if not spans:
            return spans
        return [spans] if isinstance(spans[0], int) else spans

    @classmethod
    def validate_spans(cls, span_sequence: "FeatureAttributionSequenceOutput", spans: Optional[IndexSpan] = None):
        """Assert that spans are homogeneous, well-formed, sorted and within sequence bounds."""
        if not spans:
            return
        # Entries must be homogeneous: all plain indices (int) or all (start, end) spans (tuple).
        assert all(isinstance(x, int) for x in spans) or all(
            isinstance(x, tuple) for x in spans
        ), f"All items must be either indices (int) or spans (tuple), got {spans}"
        spans = cls.format_spans(spans)
        prev_span_max = -1
        for span in spans:
            assert len(span) == 2, f"Spans must contain two indexes, got {spans}"
            # A span must cover at least two positions: merging a single token is a no-op.
            assert span[1] > span[0] + 1, f"Spans must be non-empty, got {spans}"
            # Fix: "postive" typo in the original assertion message.
            assert (
                span[0] >= prev_span_max
            ), f"Spans must be positive-valued, non-overlapping and in ascending order, got {spans}"
            assert span[1] < len(span_sequence), f"Span values must be indexes of the original span, got {spans}"
            prev_span_max = span[1]

    @staticmethod
    def _aggregate_sequential_scores(scores, x_spans, y_spans, aggregate_fn):
        # First aggregate alongside the y-axis
        scores_aggregated_y = aggregate_contiguous(scores, y_spans, aggregate_fn, aggregate_dim=1)
        # Then aggregate alongside the x-axis
        scores_aggregated_x = aggregate_contiguous(scores_aggregated_y, x_spans, aggregate_fn, aggregate_dim=0)
        return scores_aggregated_x

    @staticmethod
    def _relativize_target_spans(spans: list[tuple[int, int]], start: int):
        if start != 0 and spans:
            # Remove target spans referring to the unattributed prefix, rescale remaining spans to relative
            # idxs of the generated sequences and set 0 if the span starts before the generation begins.
            spans = [(s[0] - start if s[0] > start else 0, s[1] - start) for s in spans if s[1] > start]
        return spans

    @staticmethod
    def aggregate_source(attr, source_spans, **kwargs):
        return aggregate_token_sequence(attr.source, source_spans)

    @staticmethod
    def aggregate_target(attr, target_spans, **kwargs):
        return aggregate_token_sequence(attr.target, target_spans)

    @staticmethod
    def aggregate_source_attributions(attr, source_spans, target_spans, aggregate_fn, **kwargs):
        if attr.source_attributions is None:
            return attr.source_attributions
        # Handle the case in which generation starts from a prefix
        target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
        # First aggregate along generated target sequence, then along attributed source
        return ContiguousSpanAggregator._aggregate_sequential_scores(
            attr.source_attributions, source_spans, target_spans, aggregate_fn
        )

    @staticmethod
    def aggregate_target_attributions(attr, target_spans, aggregate_fn, **kwargs):
        if attr.target_attributions is None:
            return attr.target_attributions
        # Handle the case in which generation starts from a prefix
        gen_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
        # First aggregate along generated target sequence, then along attributed prefix
        return ContiguousSpanAggregator._aggregate_sequential_scores(
            attr.target_attributions, target_spans, gen_spans, aggregate_fn
        )

    @staticmethod
    def aggregate_step_scores(attr, target_spans, aggregate_fn, **kwargs):
        if not attr.step_scores:
            return attr.step_scores
        out_dict = {}
        # Handle the case in which generation starts from a prefix
        target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
        for name, step_scores in attr.step_scores.items():
            agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
            out_dict[name] = aggregate_contiguous(step_scores, target_spans, agg_fn, aggregate_dim=0)
        return out_dict

    @staticmethod
    def aggregate_sequence_scores(attr, source_spans, target_spans, aggregate_fn, **kwargs):
        # Assume sequence scores are shaped like source attributions
        if not attr.sequence_scores:
            return attr.sequence_scores
        out_dict = {}
        # Handle the case in which generation starts from a prefix
        target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
        for name, step_scores in attr.sequence_scores.items():
            # Fix: use a per-entry local instead of rebinding aggregate_fn. The original rebound
            # aggregate_fn on the first dict lookup, so all subsequent entries silently reused the
            # function resolved for the first entry.
            agg_fn = aggregate_fn[name] if isinstance(aggregate_fn, dict) else aggregate_fn
            if name.startswith("decoder"):
                out_dict[name] = ContiguousSpanAggregator._aggregate_sequential_scores(
                    step_scores, target_spans, target_spans, agg_fn
                )
            elif name.startswith("encoder"):
                out_dict[name] = ContiguousSpanAggregator._aggregate_sequential_scores(
                    step_scores, source_spans, source_spans, agg_fn
                )
            else:
                out_dict[name] = ContiguousSpanAggregator._aggregate_sequential_scores(
                    step_scores, source_spans, target_spans, agg_fn
                )
        return out_dict

    @staticmethod
    def aggregate_attr_pos_start(attr, target_spans, **kwargs):
        if not target_spans:
            return attr.attr_pos_start
        # Every merged span of length k collapses to a single token, removing k - 1 positions.
        tot_merged_prefix = sum(s[1] - s[0] - 1 for s in target_spans if s[1] <= attr.attr_pos_start)
        new_pos_start = attr.attr_pos_start - tot_merged_prefix
        # Handle the case in which tokens before and after the starting position are merged.
        # The resulting merged span will include the full merged token, but merged scores will reflect only
        # the portion that was actually attributed. E.g. if "Hello world" is the prefix, ", how are you?" is
        # the generation and the token "world," is formed during merging, the "world," token will be included
        # in the attributed targets, but only scores of "," will be used for aggregation (i.e. no aggregation
        # since it's a single token).
        overlapping = [s for s in target_spans if s[0] < attr.attr_pos_start < s[1]]
        if overlapping and len(overlapping) == 1:
            new_pos_start -= attr.attr_pos_start - overlapping[0][0]
        elif len(overlapping) > 1:
            raise RuntimeError(f"Multiple overlapping spans detected for the starting position {attr.attr_pos_start}.")
        return new_pos_start

    @staticmethod
    def aggregate_attr_pos_end(attr, target_spans, **kwargs):
        if not target_spans:
            return attr.attr_pos_end
        new_start = ContiguousSpanAggregator.aggregate_attr_pos_start(attr, target_spans, **kwargs)
        target_spans = ContiguousSpanAggregator._relativize_target_spans(target_spans, attr.attr_pos_start)
        tot_merged_sequence = sum(s[1] - s[0] - 1 for s in target_spans)
        return new_start + ((attr.attr_pos_end - attr.attr_pos_start) - tot_merged_sequence)
class SubwordAggregator(ContiguousSpanAggregator):
    """Aggregates over subwords by automatic detecting contiguous subword spans.

    Args:
        attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object to aggregate.
        aggregate_fn (:obj:`Callable`, optional): Function to aggregate over the subwords.
            Defaults to the highest absolute value score across the aggregated span, with original sign
            preserved (e.g. [0.3, -0.7, 0.1] -> -0.7).
        aggregate_source (bool, optional): Whether to aggregate over the source sequence. Defaults to True.
        aggregate_target (bool, optional): Whether to aggregate over the target sequence. Defaults to True.
        special_chars (str or tuple of str, optional): One or more characters used to identify subword
            boundaries. Defaults to '▁', used by SentencePiece. If is_suffix_symbol=True, then this symbol
            is used to identify parts to be aggregated (e.g. # in WordPiece, ['phen', '##omen', '##al']).
            Otherwise, it identifies the roots that should be preserved (e.g. ▁ in SentencePiece,
            ['▁phen', 'omen', 'al']).
        is_suffix_symbol (bool, optional): Whether the special symbol is used to identify suffixes or
            prefixes. Defaults to False.
    """

    aggregator_name = "subwords"

    @classmethod
    def aggregate(
        cls,
        attr: "FeatureAttributionSequenceOutput",
        aggregate_source: bool = True,
        aggregate_target: bool = True,
        special_chars: Union[str, tuple[str, ...]] = "▁",
        is_suffix_symbol: bool = False,
        **kwargs,
    ):
        # Detect subword spans on each side that is enabled, then delegate to span aggregation.
        source_spans = cls.get_spans(attr.source, special_chars, is_suffix_symbol) if aggregate_source else []
        target_spans = cls.get_spans(attr.target, special_chars, is_suffix_symbol) if aggregate_target else []
        return super().aggregate(attr, source_spans=source_spans, target_spans=target_spans, **kwargs)

    @staticmethod
    def get_spans(tokens: list[TokenWithId], special_chars: Union[str, tuple[str, ...]], is_suffix_symbol: bool):
        """Detect contiguous spans of subword tokens based on the given boundary symbol convention."""
        spans = []
        last_prefix_idx = 0
        has_special_chars = any(sym in token.token for token in tokens for sym in special_chars)
        if not has_special_chars:
            logger.warning(
                f"The {special_chars} character is currently used for subword aggregation, but no instances "
                "have been detected in the sequence. Change the special symbols using e.g. special_chars=('Ġ', 'Ċ')"
                ", and set is_suffix_symbol=True if they are used as suffix word separators (e.g. Hello</w> world</w>)"
            )
            return spans
        final_idx = len(tokens) - 1
        for curr_idx, token in enumerate(tokens):
            # A token continues the current word when its marker status matches the suffix convention:
            # it starts with the special symbol(s) and they denote suffixes, or lacks them and they
            # denote prefixes.
            if token.token.startswith(special_chars) == is_suffix_symbol:
                # Close a trailing multi-token span when the sequence ends on a continuation token.
                if curr_idx == final_idx and curr_idx - last_prefix_idx > 1:
                    spans.append((last_prefix_idx, curr_idx))
                continue
            # A new word starts here: emit the previous span if it merged more than one token.
            if curr_idx - last_prefix_idx > 1:
                spans.append((last_prefix_idx, curr_idx))
            last_prefix_idx = curr_idx
        return spans
class PairAggregator(SequenceAttributionAggregator):
    """Aggregates two FeatureAttributionSequenceOutput object into a single one containing the diff.

    Args:
        attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The starting attribution object.
        paired_attr (:class:`~inseq.data.FeatureAttributionSequenceOutput`): The attribution object with
            whom the diff is computed, representing a change from `attr_start` (e.g. minimal pair edit).
        aggregate_fn (:obj:`Callable`, optional): Function to aggregate elementwise values of the pair.
            Defaults to the difference between the two elements.
    """

    aggregator_name = "pair"
    aggregator_family = "pair"
    default_fn = lambda x, y: y - x

    @classmethod
    def pre_aggregate_hook(
        cls, attr: "FeatureAttributionSequenceOutput", paired_attr: "FeatureAttributionSequenceOutput", **kwargs
    ):
        super().pre_aggregate_hook(attr, **kwargs)
        cls.validate_pair(attr, paired_attr)

    @classmethod
    def validate_pair(cls, attr, paired_attr):
        """Assert that the paired attribution matches `attr` in lengths, shapes and score keys."""
        assert len(attr.source) == len(paired_attr.source), "Source sequences must be the same length."
        assert len(attr.target) == len(paired_attr.target), "Target sequences must be the same length."
        if attr.source_attributions is not None:
            assert (
                attr.source_attributions.shape == paired_attr.source_attributions.shape
            ), "Source attributions must be the same shape."
        if attr.target_attributions is not None:
            assert (
                attr.target_attributions.shape == paired_attr.target_attributions.shape
            ), "Target attributions must be the same shape."
        if attr.step_scores is not None:
            assert paired_attr.step_scores is not None, "Paired attribution must have step scores."
            for score_name, score_val in attr.step_scores.items():
                assert score_name in paired_attr.step_scores, f"Step score {score_name} must be in paired attribution."
                assert (
                    score_val.shape == paired_attr.step_scores[score_name].shape
                ), f"Step score {score_name} must be the same shape."
        if attr.sequence_scores is not None:
            assert paired_attr.sequence_scores is not None, "Paired attribution must have sequence scores."
            for score_name, score_val in attr.sequence_scores.items():
                assert (
                    score_name in paired_attr.sequence_scores
                ), f"Sequence score {score_name} must be in paired attribution."
                assert (
                    score_val.shape == paired_attr.sequence_scores[score_name].shape
                ), f"Sequence score {score_name} must be the same shape."

    @staticmethod
    def aggregate_source(attr, paired_attr, **kwargs):
        return aggregate_token_pair(attr.source, paired_attr.source)

    @staticmethod
    def aggregate_target(attr, paired_attr, **kwargs):
        return aggregate_token_pair(attr.target, paired_attr.target)

    @staticmethod
    def aggregate_source_attributions(attr, paired_attr, aggregate_fn, **kwargs):
        if attr.source_attributions is None:
            return attr.source_attributions
        return aggregate_fn(attr.source_attributions, paired_attr.source_attributions)

    @staticmethod
    def aggregate_target_attributions(attr, paired_attr, aggregate_fn, **kwargs):
        if attr.target_attributions is None:
            return attr.target_attributions
        return aggregate_fn(attr.target_attributions, paired_attr.target_attributions)

    @staticmethod
    def aggregate_step_scores(attr, paired_attr, aggregate_fn, **kwargs):
        if not attr.step_scores:
            return attr.step_scores
        fn_is_dict = isinstance(aggregate_fn, dict)
        return {
            name: (aggregate_fn[name] if fn_is_dict else aggregate_fn)(scores, paired_attr.step_scores[name])
            for name, scores in attr.step_scores.items()
        }

    @staticmethod
    def aggregate_sequence_scores(attr, paired_attr, aggregate_fn, **kwargs):
        if not attr.sequence_scores:
            return attr.sequence_scores
        fn_is_dict = isinstance(aggregate_fn, dict)
        return {
            name: (aggregate_fn[name] if fn_is_dict else aggregate_fn)(scores, paired_attr.sequence_scores[name])
            for name, scores in attr.sequence_scores.items()
        }