translation/ct2_conversion.py

  """Helpers for converting Hugging Face translation models to CTranslate2."""
  
  from __future__ import annotations
  
  import copy
  import logging
  
  logger = logging.getLogger(__name__)
  
  
  def convert_transformers_model(
      model_name_or_path: str,
      output_dir: str,
      quantization: str,
      *,
      force: bool = False,
  ) -> str:
      from ctranslate2.converters import TransformersConverter
      from transformers import AutoConfig
  
      class _CompatibleTransformersConverter(TransformersConverter):
          def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
              try:
                  return super().load_model(model_class, resolved_model_name_or_path, **kwargs)
              except TypeError as exc:
                  if "unexpected keyword argument 'dtype'" not in str(exc):
                      raise
                  if kwargs.get("dtype") is None and kwargs.get("torch_dtype") is None:
                      raise
  
                  logger.warning(
                      "Retrying CTranslate2 model load without dtype hints | model=%s class=%s",
                      resolved_model_name_or_path,
                      getattr(model_class, "__name__", model_class),
                  )
                  retry_kwargs = dict(kwargs)
                  retry_kwargs.pop("dtype", None)
                  retry_kwargs.pop("torch_dtype", None)
                  config = retry_kwargs.get("config")
                  if config is None:
                      config = AutoConfig.from_pretrained(resolved_model_name_or_path)
                  else:
                      config = copy.deepcopy(config)
                  if hasattr(config, "dtype"):
                      config.dtype = None
                  if hasattr(config, "torch_dtype"):
                      config.torch_dtype = None
                  retry_kwargs["config"] = config
                  return super().load_model(model_class, resolved_model_name_or_path, **retry_kwargs)
  
      converter = _CompatibleTransformersConverter(model_name_or_path)
      return converter.convert(output_dir=output_dir, quantization=quantization, force=force)
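

# Illustrative usage sketch (not part of the module above): the model id, output
# directory, and quantization value are placeholder assumptions; any Transformers
# architecture supported by CTranslate2's converter should work the same way.
if __name__ == "__main__":
    converted_dir = convert_transformers_model(
        "Helsinki-NLP/opus-mt-en-de",   # placeholder Hugging Face model id
        "models/opus-mt-en-de-ct2",     # placeholder output directory
        "int8",                         # quantization type passed through to CTranslate2
        force=True,                     # overwrite the output directory if it already exists
    )
    print(converted_dir)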