# Source: inseq.attr.feat.gradient_attribution

# Copyright 2021 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient-based feature attribution methods."""

import logging
from typing import Any

from captum.attr import (
    DeepLift,
    GradientShap,
    InputXGradient,
    IntegratedGradients,
    LayerDeepLift,
    LayerGradientXActivation,
    LayerIntegratedGradients,
    Saliency,
)

from ...data import GranularFeatureAttributionStepOutput
from ...utils import Registry, extract_signature_args, rgetattr
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .feature_attribution import FeatureAttribution
from .ops import DiscretetizedIntegratedGradients, SequentialIntegratedGradients

logger = logging.getLogger(__name__)


class GradientAttributionRegistry(FeatureAttribution, Registry):
    r"""Gradient-based attribution method registry."""

    @set_hook
    def hook(self, **kwargs):
        r"""Hooks the attribution method to the model by replacing the normal :obj:`nn.Embedding`
        with Captum's `InterpretableEmbeddingBase
        <https://captum.ai/api/utilities.html#captum.attr.InterpretableEmbeddingBase>`__.
        """
        super().hook(**kwargs)
        if self.attribute_batch_ids and not self.forward_batch_embeds:
            self.target_layer = kwargs.pop("target_layer", self.attribution_model.get_embedding_layer())
            logger.debug(f"target_layer={self.target_layer}")
            if isinstance(self.target_layer, str):
                self.target_layer = rgetattr(self.attribution_model.model, self.target_layer)
        if not self.attribute_batch_ids:
            self.attribution_model.configure_interpretable_embeddings()

    @unset_hook
    def unhook(self, **kwargs):
        r"""Unhooks the attribution method by restoring the model's original embeddings."""
        super().unhook(**kwargs)
        if self.attribute_batch_ids and not self.forward_batch_embeds:
            self.target_layer = None
        else:
            self.attribution_model.remove_interpretable_embeddings()

    def attribute_step(
        self,
        attribute_fn_main_args: dict[str, Any],
        attribution_args: dict[str, Any] = {},
    ) -> GranularFeatureAttributionStepOutput:
        r"""Performs a single attribution step for the specified attribution arguments.

        Args:
            attribute_fn_main_args (:obj:`dict`): Main arguments used for the attribution method.
                These are built from model inputs at the current step of the feature attribution process.
            attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution
                method. These can be specified by the user while calling the top-level `attribute` methods.
                Defaults to {}.

        Returns:
            :class:`~inseq.data.GranularFeatureAttributionStepOutput`: A dataclass containing a tensor of
            source attributions of size ``(batch_size, source_length)``, possibly a tensor of target
            attributions of size ``(batch_size, prefix_length)`` if ``attribute_target=True``, and possibly
            a tensor of deltas of size ``(batch_size,)`` if the attribution step supports convergence deltas
            and they are requested. At this point the batch information is empty, and will later be filled
            by the ``enrich_step_output`` function.
        """
        attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
        deltas = None
        if (
            attribution_args.get("return_convergence_delta", False)
            and hasattr(self.method, "has_convergence_delta")
            and self.method.has_convergence_delta()
        ):
            attr, deltas = attr
        source_attributions, target_attributions = get_source_target_attributions(
            attr, self.attribution_model.is_encoder_decoder
        )
        return GranularFeatureAttributionStepOutput(
            source_attributions=source_attributions,
            target_attributions=target_attributions,
            step_scores={"deltas": deltas} if deltas is not None else None,
        )

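
# Illustrative sketch (not part of the original module): how the convergence deltas
# computed in ``attribute_step`` surface to end users. ``inseq.load_model`` and
# ``model.attribute`` are inseq's public entry points; the model name, input text,
# and the exact step-score access path below are illustrative assumptions.
def _example_convergence_deltas():  # pragma: no cover
    import inseq

    # "integrated_gradients" supports Captum's has_convergence_delta() check,
    # so a "deltas" entry is added to the step scores when requested.
    model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "integrated_gradients")
    out = model.attribute("Hello world!", n_steps=100, return_convergence_delta=True)
    print(out.sequence_attributions[0].step_scores["deltas"])
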
class DeepLiftAttribution(GradientAttributionRegistry):
    """DeepLIFT attribution method.

    Reference implementation:
    `https://captum.ai/api/deep_lift.html <https://captum.ai/api/deep_lift.html>`__.
    """

    method_name = "deeplift"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = DeepLift(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True

class GradientShapAttribution(GradientAttributionRegistry):
    """GradientSHAP attribution method.

    Reference implementation:
    `https://captum.ai/api/gradient_shap.html <https://captum.ai/api/gradient_shap.html>`__.
    """

    method_name = "gradient_shap"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = GradientShap(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True

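
# Illustrative sketch: using the method registered above through inseq's public API.
# ``n_samples`` and ``stdevs`` are parameters of Captum's ``GradientShap.attribute``,
# forwarded through ``attribution_args``; the model name and input are placeholders.
def _example_gradient_shap():  # pragma: no cover
    import inseq

    model = inseq.load_model("gpt2", "gradient_shap")
    # Baselines are built by inseq, since the method sets ``use_baselines = True``.
    out = model.attribute("The capital of France is", n_samples=20, stdevs=0.1)
    out.show()
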
class DiscretizedIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Discretized Integrated Gradients attribution method.

    Reference: https://arxiv.org/abs/2108.13654
    Original implementation: https://github.com/INK-USC/DIG
    """

    method_name = "discretized_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = False, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribution_model = attribution_model
        self.attribute_batch_ids = True
        self.use_baselines = True
        self.method = DiscretetizedIntegratedGradients(
            self.attribution_model,
            multiply_by_inputs,
        )
        self.hook(**kwargs)

    @set_hook
    def hook(self, **kwargs):
        load_kwargs, other_kwargs = extract_signature_args(
            kwargs,
            self.method.load_monotonic_path_builder,
            return_remaining=True,
        )
        self.method.load_monotonic_path_builder(
            self.attribution_model.model_name,
            vocabulary_embeddings=self.attribution_model.vocabulary_embeddings.detach(),
            special_tokens=self.attribution_model.special_tokens_ids,
            embedding_scaling=self.attribution_model.embed_scale,
            **load_kwargs,
        )
        super().hook(**other_kwargs)

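
# Illustrative sketch: DIG is selected by its registered ``method_name``. At hook time,
# ``load_monotonic_path_builder`` prepares monotonic interpolation paths over the model's
# vocabulary embeddings, so no extra setup is required here. The model name, input text,
# and ``n_steps`` value are illustrative assumptions.
def _example_discretized_ig():  # pragma: no cover
    import inseq

    model = inseq.load_model("gpt2", "discretized_integrated_gradients")
    out = model.attribute("The capital of France is", n_steps=50)
    out.show()
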
class IntegratedGradientsAttribution(GradientAttributionRegistry):
    """Integrated Gradients attribution method.

    Reference implementation:
    `https://captum.ai/api/integrated_gradients.html <https://captum.ai/api/integrated_gradients.html>`__.
    """

    method_name = "integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = IntegratedGradients(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True

class InputXGradientAttribution(GradientAttributionRegistry):
    """Input x Gradient attribution method.

    Reference implementation:
    `https://captum.ai/api/input_x_gradient.html <https://captum.ai/api/input_x_gradient.html>`__.
    """

    method_name = "input_x_gradient"

    def __init__(self, attribution_model):
        super().__init__(attribution_model)
        self.method = InputXGradient(self.attribution_model)

class SaliencyAttribution(GradientAttributionRegistry):
    """Saliency attribution method.

    Reference implementation:
    `https://captum.ai/api/saliency.html <https://captum.ai/api/saliency.html>`__.
    """

    method_name = "saliency"

    def __init__(self, attribution_model):
        super().__init__(attribution_model)
        self.method = Saliency(self.attribution_model)

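
# Illustrative sketch: since every subclass registers a ``method_name``, available methods
# can be listed and swapped per call. ``list_feature_attribution_methods`` and the per-call
# ``method`` override are part of inseq's public API; model and input are placeholders.
def _example_method_switching():  # pragma: no cover
    import inseq

    print(inseq.list_feature_attribution_methods())  # includes "saliency", "input_x_gradient", ...
    model = inseq.load_model("gpt2", "saliency")
    out_saliency = model.attribute("Hello world!")
    out_ixg = model.attribute("Hello world!", method="input_x_gradient")
    out_saliency.show()
    out_ixg.show()
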
class SequentialIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Sequential Integrated Gradients attribution method.

    Reference: https://aclanthology.org/2023.findings-acl.477/
    Original implementation:
    https://github.com/josephenguehard/time_interpret/blob/main/tint/attr/seq_ig.py
    """

    method_name = "sequential_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = SequentialIntegratedGradients(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True

# Layer methods
class LayerIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Layer Integrated Gradients attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-integrated-gradients
    <https://captum.ai/api/layer.html#layer-integrated-gradients>`__.
    """

    method_name = "layer_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = True
        self.hook(**kwargs)
        self.method = LayerIntegratedGradients(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )

class LayerGradientXActivationAttribution(GradientAttributionRegistry):
    """Layer Gradient x Activation attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-gradient-x-activation
    <https://captum.ai/api/layer.html#layer-gradient-x-activation>`__.
    """

    method_name = "layer_gradient_x_activation"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = False
        self.hook(**kwargs)
        self.method = LayerGradientXActivation(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )

class LayerDeepLiftAttribution(GradientAttributionRegistry):
    """Layer DeepLIFT attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-deeplift
    <https://captum.ai/api/layer.html#layer-deeplift>`__.
    """

    method_name = "layer_deeplift"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = True
        self.hook(**kwargs)
        self.method = LayerDeepLift(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )
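
# Illustrative sketch: layer methods resolve ``target_layer`` in ``hook()``, so a submodule
# can be selected by its attribute path (resolved with ``rgetattr``) instead of the default
# embedding layer. The GPT-2 module path below is an illustrative assumption.
def _example_layer_target():  # pragma: no cover
    import inseq

    model = inseq.load_model("gpt2")
    # Attach the layer method manually, targeting the first transformer block
    # rather than the embedding layer.
    method = LayerGradientXActivationAttribution(model, target_layer="transformer.h.0")
    return method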