# Copyright 2023 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import abstractmethod
from typing import Union
import torch
from torch.linalg import vector_norm
from ..utils import Registry, available_classes
from ..utils.typing import (
ScoreTensor,
)
logger = logging.getLogger(__name__)
class AggregationFunction(Registry):
    """Abstract base for callables that aggregate attribution scores along a dimension.

    Concrete subclasses in this module set ``aggregation_function_name`` and
    implement ``__call__``.
    """

    # Name of the class attribute under which each subclass stores its identifier.
    # NOTE(review): presumably read by ``Registry`` to index subclasses — confirm there.
    registry_attr = "aggregation_function_name"

    def __init__(self):
        # Default capability flags set on every instance.
        # NOTE(review): the code consuming these flags is not visible in this module;
        # their exact semantics should be confirmed against the callers.
        self.takes_single_tensor: bool = True
        self.takes_sequence_scores: bool = False

    @abstractmethod
    def __call__(self, scores: Union[torch.Tensor, tuple[torch.Tensor, ...]], dim: int, **kwargs) -> ScoreTensor:
        """Aggregate ``scores`` along ``dim``.

        Args:
            scores: A tensor, or a tuple of tensors, holding the scores to aggregate.
            dim: The dimension along which the aggregation is performed.
            **kwargs: Extra, function-specific arguments.

        Returns:
            The aggregated scores.
        """
[docs]
class MeanAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"mean"``: averages scores along a dimension."""

    aggregation_function_name = "mean"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the mean of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The tensor of means, with ``dim`` removed.
        """
        averaged = torch.mean(scores, dim)
        return averaged
[docs]
class MaxAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"max"``: keeps the largest score along a dimension."""

    aggregation_function_name = "max"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the maximum of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The tensor of maximum values, with ``dim`` removed.
        """
        # The dim-wise max yields a (values, indices) pair; only the values are needed.
        largest, _positions = scores.max(dim)
        return largest
[docs]
class MinAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"min"``: keeps the smallest score along a dimension."""

    aggregation_function_name = "min"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the minimum of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The tensor of minimum values, with ``dim`` removed.
        """
        # The dim-wise min yields a (values, indices) pair; only the values are needed.
        smallest, _positions = scores.min(dim)
        return smallest
[docs]
class SumAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"sum"``: adds up scores along a dimension."""

    aggregation_function_name = "sum"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the sum of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The tensor of sums, with ``dim`` removed.
        """
        total = torch.sum(scores, dim)
        return total
[docs]
class ProdAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"prod"``: multiplies scores along a dimension."""

    aggregation_function_name = "prod"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return the product of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The tensor of products, with ``dim`` removed.
        """
        product = torch.prod(scores, dim)
        return product
[docs]
class AbsMaxAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"absmax"``: keeps the score of largest magnitude.

    The selected score is returned with its original sign.
    """

    aggregation_function_name = "absmax"

    def __call__(self, scores: torch.Tensor, dim: int) -> ScoreTensor:
        """Return, along ``dim``, the entry of ``scores`` with the largest absolute value.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.

        Returns:
            The selected entries taken from the original ``scores``, with ``dim`` removed.
        """
        # NaNs are replaced via ``nan_to_num`` only for ranking magnitudes; the value
        # that is finally returned is gathered from the untouched input tensor.
        magnitudes = torch.nan_to_num(scores).abs()
        # keepdim=True keeps the index tensor at the same rank as ``scores``, as gather requires.
        top_index = magnitudes.argmax(dim, keepdim=True)
        selected = scores.gather(dim, top_index)
        # Drop the singleton dimension left behind by the gather.
        return selected.squeeze(dim)
[docs]
class VectorNormAggregationFunction(AggregationFunction):
    """Aggregation function registered as ``"vnorm"``: takes a vector norm along a dimension."""

    aggregation_function_name = "vnorm"

    def __call__(self, scores: torch.Tensor, dim: int, vnorm_ord: int = 2) -> ScoreTensor:
        """Return the vector norm of ``scores`` computed along ``dim``.

        Args:
            scores: Tensor of scores to aggregate.
            dim: Dimension that is reduced.
            vnorm_ord: Order of the norm, forwarded as ``ord`` to
                ``torch.linalg.vector_norm``. Defaults to 2.

        Returns:
            The tensor of norms, with ``dim`` removed.
        """
        norms = vector_norm(scores, dim=dim, ord=vnorm_ord)
        return norms
# Default aggregation function names, keyed first by field and then by the "spans" key.
# The values ("absmax", "prod", "sum") match the ``aggregation_function_name``
# identifiers of the aggregation classes defined in this module.
# NOTE(review): the code that consumes this mapping is not visible here; confirm
# how the "spans" level is interpreted against that consumer.
DEFAULT_ATTRIBUTION_AGGREGATE_DICT = {
    # Both attribution fields use "absmax"; the comprehension builds a separate
    # inner dict per field so the two entries never share state.
    **{field: {"spans": "absmax"} for field in ("source_attributions", "target_attributions")},
    # Step scores get a per-score aggregation function name.
    "step_scores": {
        "spans": dict(
            probability="prod",
            entropy="sum",
            crossentropy="sum",
            perplexity="prod",
            contrast_prob_diff="prod",
            contrast_prob="prod",
            pcxmi="sum",
            kl_divergence="sum",
            mc_dropout_prob_avg="prod",
        )
    },
}
[docs]
def list_aggregation_functions() -> list[str]:
    """Return the identifiers of all available aggregation functions.

    Returns:
        Whatever ``available_classes`` reports for ``AggregationFunction``.
    """
    identifiers = available_classes(AggregationFunction)
    return identifiers