Source code for inseq.models

import logging
from typing import Optional, Union

from rich.status import Status

from ..utils import isnotebook, optional
from ..utils.typing import ModelClass, ModelIdentifier
from .attribution_model import AttributionModel, InputFormatter
from .decoder_only import DecoderOnlyAttributionModel
from .encoder_decoder import EncoderDecoderAttributionModel
from .huggingface_model import HuggingfaceDecoderOnlyModel, HuggingfaceEncoderDecoderModel, HuggingfaceModel
from .model_config import ModelConfig, register_model_config

logger = logging.getLogger(__name__)

FRAMEWORKS_MAP = {
    "hf_transformers": HuggingfaceModel,
}


[docs] def load_model( model: Union[ModelIdentifier, ModelClass], attribution_method: Optional[str] = None, framework: str = "hf_transformers", **kwargs, ) -> AttributionModel: """Factory function to load a model with or without attribution methods. Args: model (`Union[ModelIdentifier, ModelClass]`): Either a model identifier (e.g. `gpt2` in HF transformers) or an instance of a model class supported by the selected modeling framework. attribution_method (`Optional[str]`, *optional*, defaults to None): Identifier for the attribution method to use. If `None`, the model will be loaded without any attribution methods, which can be added during attribution. framework (`str`, *optional*, defaults to "hf_transformers"): The framework to use for loading the model. Currently, only HF transformers is supported. Returns: `AttributionModel`: An instance of one of `AttributionModel` children classes matching the selected framework and model architecture. """ model_name = model if isinstance(model, str) else "model" method_desc = f"with {attribution_method} method..." if attribution_method else " without attribution methods..." load_msg = f"Loading {model_name} {method_desc}" with optional(not isnotebook(), Status(load_msg), logger.info, msg=load_msg): return FRAMEWORKS_MAP[framework].load(model, attribution_method, **kwargs)
def list_supported_frameworks() -> list[str]: """Lists identifiers for all available frameworks. These can be used to load models with the `framework` argument in the :meth:`~inseq.load_model` function. """ return list(FRAMEWORKS_MAP.keys()) __all__ = [ "AttributionModel", "InputFormatter", "HuggingfaceModel", "HuggingfaceEncoderDecoderModel", "HuggingfaceDecoderOnlyModel", "DecoderOnlyAttributionModel", "EncoderDecoderAttributionModel", "load_model", "list_supported_frameworks", "ModelConfig", "register_model_config", ]