ct2_conversion.py
"""Helpers for converting Hugging Face translation models to CTranslate2."""
from __future__ import annotations
import copy
import logging
logger = logging.getLogger(__name__)
def convert_transformers_model(
model_name_or_path: str,
output_dir: str,
quantization: str,
*,
force: bool = False,
) -> str:
from ctranslate2.converters import TransformersConverter
from transformers import AutoConfig
class _CompatibleTransformersConverter(TransformersConverter):
def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
try:
return super().load_model(model_class, resolved_model_name_or_path, **kwargs)
except TypeError as exc:
if "unexpected keyword argument 'dtype'" not in str(exc):
raise
if kwargs.get("dtype") is None and kwargs.get("torch_dtype") is None:
raise
logger.warning(
"Retrying CTranslate2 model load without dtype hints | model=%s class=%s",
resolved_model_name_or_path,
getattr(model_class, "__name__", model_class),
)
retry_kwargs = dict(kwargs)
retry_kwargs.pop("dtype", None)
retry_kwargs.pop("torch_dtype", None)
config = retry_kwargs.get("config")
if config is None:
config = AutoConfig.from_pretrained(resolved_model_name_or_path)
else:
config = copy.deepcopy(config)
if hasattr(config, "dtype"):
config.dtype = None
if hasattr(config, "torch_dtype"):
config.torch_dtype = None
retry_kwargs["config"] = config
return super().load_model(model_class, resolved_model_name_or_path, **retry_kwargs)
converter = _CompatibleTransformersConverter(model_name_or_path)
return converter.convert(output_dir=output_dir, quantization=quantization, force=force)
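

# A minimal usage sketch. The model name, output directory, and quantization
# mode below are illustrative assumptions, not values this module requires.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    converted_dir = convert_transformers_model(
        "Helsinki-NLP/opus-mt-en-de",
        output_dir="opus-mt-en-de-ct2",
        quantization="int8",
        force=True,
    )
    logger.info("CTranslate2 model written to %s", converted_dir)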