Source code for inseq.data.aggregation_functions

# Copyright 2023 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from abc import abstractmethod
from typing import Union

import torch
from torch.linalg import vector_norm

from ..utils import Registry, available_classes
from ..utils.typing import (
    ScoreTensor,
)

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


class AggregationFunction(Registry):
    """Abstract base for callables that reduce attribution scores along one dimension.

    Concrete subclasses define a class-level ``aggregation_function_name`` string.
    ``registry_attr`` names that attribute for the ``Registry`` machinery, presumably so
    that subclasses can be discovered by that identifier (confirm against ``Registry``).
    """

    # Name of the class attribute holding each concrete subclass's identifier.
    registry_attr = "aggregation_function_name"

    def __init__(self):
        # Capability flags describing what ``__call__`` expects; their consumer lives
        # outside this module. NOTE(review): presumably ``takes_single_tensor`` means
        # ``scores`` is a single tensor rather than a tuple, and
        # ``takes_sequence_scores`` means sequence-level scores are also supplied.
        self.takes_single_tensor: bool = True
        self.takes_sequence_scores: bool = False

    @abstractmethod
    def __call__(
        self,
        scores: Union[torch.Tensor, tuple[torch.Tensor, ...]],
        dim: int,
        **kwargs,
    ) -> ScoreTensor:
        """Reduce ``scores`` along dimension ``dim``.

        Args:
            scores: A tensor, or a tuple of tensors, holding the scores to aggregate.
            dim: Dimension along which the aggregation is performed.
            **kwargs: Extra options understood by the concrete aggregation.

        Returns:
            The aggregated scores.
        """


class MeanAggregationFunction(AggregationFunction):
    """Aggregate scores by taking their arithmetic mean along a dimension."""

    aggregation_function_name = "mean"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the mean of ``scores`` over dimension ``dim`` (which is removed)."""
        averaged = torch.mean(scores, dim=dim)
        return averaged
class MaxAggregationFunction(AggregationFunction):
    """Aggregate scores by keeping the largest value along a dimension."""

    aggregation_function_name = "max"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the maximum of ``scores`` over dimension ``dim`` (which is removed)."""
        # Reducing along a dim yields a (values, indices) pair; only values are kept.
        maxima, _ = torch.max(scores, dim=dim)
        return maxima
class MinAggregationFunction(AggregationFunction):
    """Aggregate scores by keeping the smallest value along a dimension."""

    aggregation_function_name = "min"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the minimum of ``scores`` over dimension ``dim`` (which is removed)."""
        # Reducing along a dim yields a (values, indices) pair; only values are kept.
        minima, _ = torch.min(scores, dim=dim)
        return minima
class SumAggregationFunction(AggregationFunction):
    """Aggregate scores by adding them up along a dimension."""

    aggregation_function_name = "sum"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the sum of ``scores`` over dimension ``dim`` (which is removed)."""
        total = torch.sum(scores, dim=dim)
        return total
class ProdAggregationFunction(AggregationFunction):
    """Aggregate scores by multiplying them together along a dimension."""

    aggregation_function_name = "prod"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the product of ``scores`` over dimension ``dim`` (which is removed)."""
        product = torch.prod(scores, dim=dim)
        return product
class AbsMaxAggregationFunction(AggregationFunction):
    """Aggregate scores by keeping, along a dimension, the value of largest magnitude.

    The selected entry keeps its original sign: the magnitude is only used to pick it.
    """

    aggregation_function_name = "absmax"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the signed entry of largest absolute value over ``dim`` (which is removed)."""
        # NaNs are replaced (with 0) only for the selection step, so a NaN counts as
        # zero magnitude when choosing the winner; the value itself is read from the
        # untouched ``scores``.
        magnitudes = torch.nan_to_num(scores).abs()
        winner_idx = magnitudes.argmax(dim, keepdim=True)
        # keepdim=True gives gather an index tensor of matching rank; the singleton
        # dimension is dropped afterwards.
        picked = torch.gather(scores, dim, winner_idx)
        return picked.squeeze(dim)
class VectorNormAggregationFunction(AggregationFunction):
    """Aggregate scores by computing their vector norm along a dimension."""

    aggregation_function_name = "vnorm"

    def __call__(
        self,
        scores: torch.Tensor,
        dim: int,
        vnorm_ord: Union[int, float] = 2,
    ) -> ScoreTensor:
        """Return the ``vnorm_ord``-norm of ``scores`` over dimension ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension along which the norm is computed (removed from the result).
            vnorm_ord: Order of the norm, forwarded as ``ord`` to
                ``torch.linalg.vector_norm``. Any int or float is accepted there,
                including ``float("inf")`` / ``float("-inf")``. Defaults to 2 (Euclidean).

        Returns:
            Tensor of norms with dimension ``dim`` reduced.
        """
        return vector_norm(scores, ord=vnorm_ord, dim=dim)
# Default aggregation function identifiers, keyed by attribution field.
# NOTE(review): "spans" presumably selects the function used when scores for a span of
# tokens are merged into one; confirm against the aggregator that reads this mapping.
DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
    "source_attributions": {"spans": "absmax"},
    "target_attributions": {"spans": "absmax"},
    "step_scores": {
        # Per step-score name: probability-like scores are multiplied ("prod"),
        # entropy/divergence-like scores are added ("sum").
        "spans": {
            "probability": "prod",
            "entropy": "sum",
            "crossentropy": "sum",
            "perplexity": "prod",
            "contrast_prob_diff": "prod",
            "contrast_prob": "prod",
            "pcxmi": "sum",
            "kl_divergence": "sum",
            "mc_dropout_prob_avg": "prod",
        },
    },
}
def list_aggregation_functions() -> list[str]:
    """Return the identifiers of every available aggregation function.

    Returns:
        list[str]: Identifiers collected by ``available_classes`` from the
        subclasses of ``AggregationFunction``.
    """
    identifiers = available_classes(AggregationFunction)
    return identifiers