# Copyright 2021 The Inseq Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Gradient-based feature attribution methods."""

import logging
from typing import Any

from captum.attr import (
    DeepLift,
    GradientShap,
    InputXGradient,
    IntegratedGradients,
    LayerDeepLift,
    LayerGradientXActivation,
    LayerIntegratedGradients,
    Saliency,
)

from ...data import GranularFeatureAttributionStepOutput
from ...utils import Registry, extract_signature_args, rgetattr
from ..attribution_decorators import set_hook, unset_hook
from .attribution_utils import get_source_target_attributions
from .feature_attribution import FeatureAttribution
from .ops import DiscretetizedIntegratedGradients, SequentialIntegratedGradients

logger = logging.getLogger(__name__)


class GradientAttributionRegistry(FeatureAttribution, Registry):
    r"""Gradient-based attribution method registry."""

    @set_hook
    def hook(self, **kwargs):
        r"""Hooks the attribution method to the model. For standard gradient methods, the regular
        :obj:`nn.Embedding` is replaced with Captum's `InterpretableEmbeddingBase
        <https://captum.ai/api/utilities.html#captum.attr.InterpretableEmbeddingBase>`__; for layer
        attribution methods, the target layer is resolved instead.
        """
        super().hook(**kwargs)
        if self.attribute_batch_ids and not self.forward_batch_embeds:
            # Layer methods attribute token ids directly: resolve the target layer, which can be
            # passed either as a module or as a dot-separated attribute path on the wrapped model.
            self.target_layer = kwargs.pop("target_layer", self.attribution_model.get_embedding_layer())
            logger.debug(f"target_layer={self.target_layer}")
            if isinstance(self.target_layer, str):
                self.target_layer = rgetattr(self.attribution_model.model, self.target_layer)
        if not self.attribute_batch_ids:
            self.attribution_model.configure_interpretable_embeddings()

    @unset_hook
    def unhook(self, **kwargs):
        r"""Unhooks the attribution method by restoring the model's original embeddings."""
        super().unhook(**kwargs)
        if self.attribute_batch_ids and not self.forward_batch_embeds:
            self.target_layer = None
        else:
            self.attribution_model.remove_interpretable_embeddings()

    def attribute_step(
        self,
        attribute_fn_main_args: dict[str, Any],
        attribution_args: dict[str, Any] = {},
    ) -> GranularFeatureAttributionStepOutput:
        r"""Performs a single attribution step for the specified attribution arguments.

        Args:
            attribute_fn_main_args (:obj:`dict`): Main arguments used for the attribution method. These are built
                from model inputs at the current step of the feature attribution process.
            attribution_args (:obj:`dict`, `optional`): Additional arguments to pass to the attribution method.
                These can be specified by the user while calling the top level `attribute` methods. Defaults to {}.

        Returns:
            :class:`~inseq.data.GranularFeatureAttributionStepOutput`: A dataclass containing a tensor of source
            attributions of size `(batch_size, source_length)`, possibly a tensor of target attributions of size
            `(batch_size, prefix_length)` if `attribute_target=True`, and possibly a tensor of deltas of size
            `(batch_size,)` if the attribution step supports deltas and they are requested. At this point the
            batch information is empty, and will be filled later by the `enrich_step_output` function.
        """
        attr = self.method.attribute(**attribute_fn_main_args, **attribution_args)
        deltas = None
        if (
            attribution_args.get("return_convergence_delta", False)
            and hasattr(self.method, "has_convergence_delta")
            and self.method.has_convergence_delta()
        ):
            attr, deltas = attr
        source_attributions, target_attributions = get_source_target_attributions(
            attr, self.attribution_model.is_encoder_decoder
        )
        return GranularFeatureAttributionStepOutput(
            source_attributions=source_attributions,
            target_attributions=target_attributions,
            step_scores={"deltas": deltas} if deltas is not None else None,
        )
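
    # Note on the unpacking above: Captum methods that support convergence deltas (e.g.
    # IntegratedGradients) return an (attributions, delta) tuple when called with
    # return_convergence_delta=True; the has_convergence_delta() check guards against
    # methods that do not implement this contract.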


class DeepLiftAttribution(GradientAttributionRegistry):
    """DeepLIFT attribution method.

    Reference implementation:
    `https://captum.ai/api/deep_lift.html <https://captum.ai/api/deep_lift.html>`__.
    """

    method_name = "deeplift"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = DeepLift(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True


class GradientShapAttribution(GradientAttributionRegistry):
    """GradientSHAP attribution method.

    Reference implementation:
    `https://captum.ai/api/gradient_shap.html <https://captum.ai/api/gradient_shap.html>`__.
    """

    method_name = "gradient_shap"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = GradientShap(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True


class DiscretizedIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Discretized Integrated Gradients attribution method.

    Reference: https://arxiv.org/abs/2108.13654
    Original implementation: https://github.com/INK-USC/DIG
    """

    method_name = "discretized_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = False, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribution_model = attribution_model
        self.attribute_batch_ids = True
        self.use_baselines = True
        self.method = DiscretetizedIntegratedGradients(
            self.attribution_model,
            multiply_by_inputs,
        )
        self.hook(**kwargs)

    @set_hook
    def hook(self, **kwargs):
        # Split kwargs between those accepted by load_monotonic_path_builder and the remainder,
        # which is forwarded to the parent hook.
        load_kwargs, other_kwargs = extract_signature_args(
            kwargs,
            self.method.load_monotonic_path_builder,
            return_remaining=True,
        )
        self.method.load_monotonic_path_builder(
            self.attribution_model.model_name,
            vocabulary_embeddings=self.attribution_model.vocabulary_embeddings.detach(),
            special_tokens=self.attribution_model.special_tokens_ids,
            embedding_scaling=self.attribution_model.embed_scale,
            **load_kwargs,
        )
        super().hook(**other_kwargs)
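
    # Hedged usage sketch, assuming the public inseq.load_model API and an illustrative model
    # identifier; DIG-specific options (e.g. path-builder settings) would be consumed by the
    # hook above:
    #
    #   import inseq
    #
    #   model = inseq.load_model("gpt2", "discretized_integrated_gradients")
    #   out = model.attribute("Hello world!")
    #   out.show()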


class IntegratedGradientsAttribution(GradientAttributionRegistry):
    """Integrated Gradients attribution method.

    Reference implementation:
    `https://captum.ai/api/integrated_gradients.html <https://captum.ai/api/integrated_gradients.html>`__.
    """

    method_name = "integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = IntegratedGradients(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True
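
    # A minimal end-to-end sketch, assuming the public inseq.load_model API; the model
    # identifier is illustrative, and extra attribution kwargs such as n_steps are forwarded
    # to the underlying Captum method via attribution_args in attribute_step:
    #
    #   import inseq
    #
    #   model = inseq.load_model("Helsinki-NLP/opus-mt-en-it", "integrated_gradients")
    #   out = model.attribute("Hello world!", n_steps=50, return_convergence_delta=True)
    #   out.show()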


class SaliencyAttribution(GradientAttributionRegistry):
    """Saliency attribution method.

    Reference implementation:
    `https://captum.ai/api/saliency.html <https://captum.ai/api/saliency.html>`__.
    """

    method_name = "saliency"

    def __init__(self, attribution_model):
        super().__init__(attribution_model)
        self.method = Saliency(self.attribution_model)


class SequentialIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Sequential Integrated Gradients attribution method.

    Reference: https://aclanthology.org/2023.findings-acl.477/
    Original implementation: https://github.com/josephenguehard/time_interpret/blob/main/tint/attr/seq_ig.py
    """

    method_name = "sequential_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model)
        self.method = SequentialIntegratedGradients(self.attribution_model, multiply_by_inputs)
        self.use_baselines = True


# Layer methods


class LayerIntegratedGradientsAttribution(GradientAttributionRegistry):
    """Layer Integrated Gradients attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-integrated-gradients <https://captum.ai/api/layer.html#layer-integrated-gradients>`__.
    """  # noqa E501

    method_name = "layer_integrated_gradients"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = True
        # hook() must run before method construction: it resolves self.target_layer used below.
        self.hook(**kwargs)
        self.method = LayerIntegratedGradients(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )


class LayerGradientXActivationAttribution(GradientAttributionRegistry):
    """Layer Gradient x Activation attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-gradient-x-activation <https://captum.ai/api/layer.html#layer-gradient-x-activation>`__.
    """  # noqa E501

    method_name = "layer_gradient_x_activation"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = False
        self.hook(**kwargs)
        self.method = LayerGradientXActivation(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )


class LayerDeepLiftAttribution(GradientAttributionRegistry):
    """Layer DeepLIFT attribution method.

    Reference implementation:
    `https://captum.ai/api/layer.html#layer-deeplift <https://captum.ai/api/layer.html#layer-deeplift>`__.
    """

    method_name = "layer_deeplift"

    def __init__(self, attribution_model, multiply_by_inputs: bool = True, **kwargs):
        super().__init__(attribution_model, hook_to_model=False)
        self.attribute_batch_ids = True
        self.forward_batch_embeds = False
        self.use_baselines = True
        self.hook(**kwargs)
        self.method = LayerDeepLift(
            self.attribution_model,
            self.target_layer,
            multiply_by_inputs=multiply_by_inputs,
        )
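
    # Hedged sketch for layer methods: target_layer may be given as a module or as a
    # dot-separated attribute path resolved via rgetattr in GradientAttributionRegistry.hook,
    # defaulting to the model's embedding layer when omitted. The wrapped model and layer path
    # below are hypothetical:
    #
    #   attr_method = LayerDeepLiftAttribution(wrapped_model, target_layer="model.encoder.layers.0")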