Source code for inseq.models.model_config
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
import yaml
logger = logging.getLogger(__name__)
@dataclass
class ModelConfig:
"""Configuration used by the methods for which the attribute ``use_model_config=True``.
Args:
self_attention_module (:obj:`str`):
The name of the module performing the self-attention computation (e.g.``attn`` for the GPT-2 model in
transformers). Can be identified by looking at the name of the self-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.gpt2.GPT2Block` for GPT-2).
cross_attention_module (:obj:`str`):
The name of the module performing the cross-attention computation (e.g.``encoder_attn`` for MarianMT models
in transformers). Can be identified by looking at the name of the cross-attention module attribute
in the model's transformer block class (e.g. :obj:`transformers.models.marian.MarianDecoderLayer`).
value_vector (:obj:`str`):
The name of the variable in the forward pass of the attention module containing the value vector
(e.g. ``value`` for the GPT-2 model in transformers). Can be identified by looking at the forward pass of
the attention module (e.g. :obj:`transformers.models.gpt2.modeling_gpt2.GPT2Attention.forward` for GPT-2).
"""
self_attention_module: str
value_vector: str
cross_attention_module: Optional[str] = None
MODEL_CONFIGS = {
model_type: ModelConfig(**cfg)
for model_type, cfg in yaml.safe_load(open(Path(__file__).parent / "model_config.yaml", encoding="utf8")).items()
}
def get_model_config(model_type: str) -> ModelConfig:
if model_type not in MODEL_CONFIGS:
raise ValueError(
f"A configuration for the {model_type} model is not defined. "
"You can register a configuration with :meth:`~inseq.models.register_model_config`, "
"or request it to be added to the library by opening an issue on GitHub: "
"https://github.com/inseq-team/inseq/issues"
)
return MODEL_CONFIGS[model_type]
[docs]
def register_model_config(
model_type: str,
config: dict,
overwrite: bool = False,
allow_partial: bool = False,
) -> None:
"""Allows to register a model configuration for a given model type. The configuration is a dictionary containing
information required the methods for which the attribute ``use_model_config=True``.
Args:
model_type (`str`):
The class of the model for which the configuration is registered, used as key in the stored configuration.
E.g. GPT2LMHeadModel for the GPT-2 model in HuggingFace Transformers.
config (`dict`):
A dictionary containing the configuration for the model. The fields should match those of the
:class:`~inseq.models.ModelConfig` class.
overwrite (`bool`, *optional*, defaults to False):
If `True`, the configuration will be overwritten if it already exists.
allow_partial (`bool`, *optional*, defaults to False):
If `True`, the configuration can be partial, i.e. it can contain only a subset of the fields of the
:class:`~inseq.models.ModelConfig` class. The missing fields will be set to `None`.
Raises:
`ValueError`: If the model type is already registered and `overwrite=False`, or if the configuration is partial
and `allow_partial=False`.
"""
if model_type in MODEL_CONFIGS:
if not overwrite:
raise ValueError(
f"{model_type} is already registered in model configurations.Override with overwrite=True."
)
logger.warning(f"Overwriting {model_type} config.")
all_fields = set(ModelConfig.__dataclass_fields__.keys())
config_fields = set(config.keys())
diff = all_fields - config_fields
if diff and not allow_partial:
raise ValueError(
f"Missing fields {','.join(diff)} in model configuration for {model_type}."
"Set allow_partial=True to allow partial configuration."
)
if allow_partial:
config = {**{field: None for field in diff}, **config}
MODEL_CONFIGS[model_type] = ModelConfig(**config)