ct2_conversion.py 2.04 KB
"""Helpers for converting Hugging Face translation models to CTranslate2."""

from __future__ import annotations

import copy
import logging

logger = logging.getLogger(__name__)


def convert_transformers_model(
    model_name_or_path: str,
    output_dir: str,
    quantization: str | None,
    *,
    force: bool = False,
) -> str:
    """Convert a Hugging Face Transformers model to the CTranslate2 format.

    The conversion is delegated to ``ctranslate2.converters.TransformersConverter``.
    Its ``load_model`` hook is wrapped for compatibility. If the installed
    ``transformers`` rejects a dtype keyword, the model is loaded a second time
    with every dtype hint removed. A rejection is a ``TypeError`` about an
    unexpected ``dtype``/``torch_dtype`` keyword argument. The hints are removed
    both from the keyword arguments and from the model config. A caller-supplied
    config is copied first, so the original object is never mutated.

    Args:
        model_name_or_path: Model identifier or local path handed to the converter.
        output_dir: Directory passed to ``TransformersConverter.convert``.
        quantization: Quantization scheme passed to ``convert``; ``None`` requests
            no quantization.
        force: Passed to ``convert`` (lets CTranslate2 overwrite an existing
            ``output_dir``).

    Returns:
        The value returned by ``TransformersConverter.convert``.

    Raises:
        TypeError: From the first load attempt when it is not a dtype-keyword
            rejection, or when no dtype keyword was passed. A ``TypeError``
            raised by the retry propagates as well.
    """
    # Imported here so the heavy optional dependencies are only needed when a
    # conversion actually runs.
    from ctranslate2.converters import TransformersConverter
    from transformers import AutoConfig

    dtype_kwargs = ("dtype", "torch_dtype")
    # Python reports these as "<callable>() got an unexpected keyword argument '<name>'".
    dtype_error_markers = tuple(
        f"unexpected keyword argument '{name}'" for name in dtype_kwargs
    )
    # Load options that must also apply when the config is re-read, so the config
    # matches the model that is being loaded.
    config_passthrough = ("revision", "trust_remote_code")

    class _CompatibleTransformersConverter(TransformersConverter):
        def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
            try:
                return super().load_model(model_class, resolved_model_name_or_path, **kwargs)
            except TypeError as exc:
                message = str(exc)
                if not any(marker in message for marker in dtype_error_markers):
                    raise
                # Check key presence rather than value: a keyword passed as
                # ``dtype=None`` is still rejected by transformers releases that
                # do not know it.
                if not any(name in kwargs for name in dtype_kwargs):
                    raise

                logger.warning(
                    "Retrying CTranslate2 model load without dtype hints | model=%s class=%s",
                    resolved_model_name_or_path,
                    getattr(model_class, "__name__", model_class),
                )

            retry_kwargs = {
                key: value for key, value in kwargs.items() if key not in dtype_kwargs
            }
            config = retry_kwargs.get("config")
            if config is None:
                config_kwargs = {
                    key: retry_kwargs[key]
                    for key in config_passthrough
                    if key in retry_kwargs
                }
                config = AutoConfig.from_pretrained(
                    resolved_model_name_or_path, **config_kwargs
                )
            else:
                # Never mutate the caller's config object.
                config = copy.deepcopy(config)
            # Clear dtype information carried by the config itself, so the
            # retry does not reintroduce the hint through the config.
            for name in dtype_kwargs:
                if hasattr(config, name):
                    setattr(config, name, None)
            retry_kwargs["config"] = config
            return super().load_model(model_class, resolved_model_name_or_path, **retry_kwargs)

    converter = _CompatibleTransformersConverter(model_name_or_path)
    return converter.convert(output_dir=output_dir, quantization=quantization, force=force)