Blame view

tests/test_translation_converter_resolution.py 3.54 KB
f07947a5   tangwang   Improve portabili...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  from __future__ import annotations
  
  import sys
  import types
  
  import pytest
  
  import translation.ct2_conversion as ct2_conversion
  
  
  class _FakeTransformersConverter:
      """Stand-in for ``ctranslate2.converters.TransformersConverter``.

      ``load_model`` records every invocation and rejects any ``dtype`` or
      ``torch_dtype`` keyword with a ``TypeError`` whose message mimics a model
      constructor that does not accept ``dtype``. ``convert`` always loads with
      ``dtype="float32"``, so the first load fails unless an override intercepts it.
      """

      def __init__(self, model_name_or_path):
          # Passed to load_model by convert().
          self.model_name_or_path = model_name_or_path
          # One record per load_model call, in call order.
          self.load_calls = []

      def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
          """Record the call, then fail if a dtype keyword was supplied.

          Returns:
              ``{"loaded": True, "path": resolved_model_name_or_path}`` on success.

          Raises:
              TypeError: if ``dtype`` or ``torch_dtype`` is among ``kwargs``.
          """
          call_record = dict(
              model_class=model_class,
              resolved_model_name_or_path=resolved_model_name_or_path,
              kwargs=dict(kwargs),
          )
          self.load_calls.append(call_record)

          dtype_keywords = {"dtype", "torch_dtype"}
          if not dtype_keywords.isdisjoint(kwargs):
              raise TypeError("M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'")
          return {"loaded": True, "path": resolved_model_name_or_path}

      def convert(self, output_dir, quantization=None, force=False):
          """Load the model once with ``dtype="float32"`` and echo the arguments.

          Returns a dict with the load result, the given arguments, and a
          snapshot (shallow copy) of ``load_calls`` taken after the load.
          """
          # Dispatches through self, so a subclass override of load_model
          # (presumably the wrapper under test) sees this call first.
          loaded_model = self.load_model("FakeModel", self.model_name_or_path, dtype="float32")
          calls_snapshot = self.load_calls[:]
          return dict(
              loaded=loaded_model,
              output_dir=output_dir,
              quantization=quantization,
              force=force,
              load_calls=calls_snapshot,
          )
  
  
  def _install_fake_ctranslate2(monkeypatch, base_converter):
      """Register stub ``ctranslate2`` / ``ctranslate2.converters`` modules.

      The stub ``converters`` module exposes ``base_converter`` as
      ``TransformersConverter``, and it is also attached as the parent's
      ``converters`` attribute. Both entries go into ``sys.modules`` via
      ``monkeypatch.setitem`` so pytest restores the previous entries afterwards.
      """
      parent_stub = types.ModuleType("ctranslate2")
      converters_stub = types.ModuleType("ctranslate2.converters")
      converters_stub.TransformersConverter = base_converter
      parent_stub.converters = converters_stub

      # Parent first, then submodule, matching normal import registration order.
      for module_name, module_obj in (
          ("ctranslate2", parent_stub),
          ("ctranslate2.converters", converters_stub),
      ):
          monkeypatch.setitem(sys.modules, module_name, module_obj)
  
  
  def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch):
      """Conversion recovers when the model constructor rejects ``dtype``.

      The fake converter raises ``TypeError`` for any ``dtype``/``torch_dtype``
      keyword. ``convert_transformers_model`` is expected to retry the load
      exactly once, passing a ``config`` whose ``torch_dtype`` is ``None``.
      """
      _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter)

      def _fake_from_pretrained(path):
          # The config carries a torch_dtype that the retry is expected to clear.
          return types.SimpleNamespace(torch_dtype="float32", path=path)

      transformers_stub = types.ModuleType("transformers")
      transformers_stub.AutoConfig = types.SimpleNamespace(from_pretrained=_fake_from_pretrained)
      monkeypatch.setitem(sys.modules, "transformers", transformers_stub)

      outcome = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")

      assert outcome["loaded"] == {"loaded": True, "path": "fake-model"}
      assert outcome["output_dir"] == "/tmp/out"
      assert outcome["quantization"] == "float16"
      assert outcome["force"] is False

      calls = outcome["load_calls"]
      assert len(calls) == 2
      first_call, retry_call = calls

      # First attempt forwards dtype unchanged, which the fake rejects.
      assert first_call == {
          "model_class": "FakeModel",
          "resolved_model_name_or_path": "fake-model",
          "kwargs": {"dtype": "float32"},
      }

      # Retry targets the same model, with config.torch_dtype reset to None.
      assert retry_call["model_class"] == "FakeModel"
      assert retry_call["resolved_model_name_or_path"] == "fake-model"
      retry_config = retry_call["kwargs"]["config"]
      assert getattr(retry_config, "torch_dtype", "missing") is None
  
  
  def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch):
      """A ``TypeError`` unrelated to ``dtype`` propagates without a retry.

      The converter below fails every load with a message that does not mention
      ``dtype``. The error must surface unchanged, and ``load_model`` must be
      attempted only once. A blanket retry would re-raise the same message and
      would otherwise be indistinguishable from direct propagation.
      """
      load_attempts = []

      class _AlwaysFailingConverter(_FakeTransformersConverter):
          def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
              # Count attempts so a spurious retry is detectable.
              load_attempts.append(dict(kwargs))
              raise TypeError("different constructor error")

      _install_fake_ctranslate2(monkeypatch, _AlwaysFailingConverter)
      fake_transformers = types.ModuleType("transformers")
      fake_transformers.AutoConfig = types.SimpleNamespace(from_pretrained=lambda path: types.SimpleNamespace(path=path))
      monkeypatch.setitem(sys.modules, "transformers", fake_transformers)

      with pytest.raises(TypeError, match="different constructor error"):
          ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")

      # Only the dtype-specific failure should trigger the retry path.
      assert len(load_attempts) == 1