"""Helpers for converting Hugging Face translation models to CTranslate2.""" from __future__ import annotations import copy import logging logger = logging.getLogger(__name__) def convert_transformers_model( model_name_or_path: str, output_dir: str, quantization: str, *, force: bool = False, ) -> str: from ctranslate2.converters import TransformersConverter from transformers import AutoConfig class _CompatibleTransformersConverter(TransformersConverter): def load_model(self, model_class, resolved_model_name_or_path, **kwargs): try: return super().load_model(model_class, resolved_model_name_or_path, **kwargs) except TypeError as exc: if "unexpected keyword argument 'dtype'" not in str(exc): raise if kwargs.get("dtype") is None and kwargs.get("torch_dtype") is None: raise logger.warning( "Retrying CTranslate2 model load without dtype hints | model=%s class=%s", resolved_model_name_or_path, getattr(model_class, "__name__", model_class), ) retry_kwargs = dict(kwargs) retry_kwargs.pop("dtype", None) retry_kwargs.pop("torch_dtype", None) config = retry_kwargs.get("config") if config is None: config = AutoConfig.from_pretrained(resolved_model_name_or_path) else: config = copy.deepcopy(config) if hasattr(config, "dtype"): config.dtype = None if hasattr(config, "torch_dtype"): config.torch_dtype = None retry_kwargs["config"] = config return super().load_model(model_class, resolved_model_name_or_path, **retry_kwargs) converter = _CompatibleTransformersConverter(model_name_or_path) return converter.convert(output_dir=output_dir, quantization=quantization, force=force)