Blame view

tests/test_translation_converter_resolution.py 3.58 KB
f07947a5   tangwang   Improve portabili...
1
2
3
4
5
6
7
8
9
  from __future__ import annotations
  
  import sys
  import types
  
  import pytest
  
  import translation.ct2_conversion as ct2_conversion
  
99b72698   tangwang   测试回归钩子梳理
10
11
  # Module-level pytest marks: every test collected from this file is tagged
  # with the ``translation`` marker, so it can be selected with ``-m translation``.
  pytestmark = [pytest.mark.translation]
  
f07947a5   tangwang   Improve portabili...
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
  
  class _FakeTransformersConverter:
      """Test double installed as ``ctranslate2.converters.TransformersConverter``.

      Every ``load_model`` invocation is recorded in ``load_calls``. Any call
      that carries a ``dtype`` or ``torch_dtype`` keyword is rejected with a
      ``TypeError`` after being recorded.
      """

      def __init__(self, model_name_or_path):
          self.load_calls = []
          self.model_name_or_path = model_name_or_path

      def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
          """Record the call, then fail if a dtype-style keyword was passed.

          Returns a small dict describing the "loaded" model when no
          ``dtype``/``torch_dtype`` keyword is present.

          Raises:
              TypeError: when ``dtype`` or ``torch_dtype`` is among ``kwargs``.
          """
          call_record = {
              "model_class": model_class,
              "resolved_model_name_or_path": resolved_model_name_or_path,
              # Snapshot the keywords so later mutation cannot alter the log.
              "kwargs": dict(kwargs),
          }
          self.load_calls.append(call_record)

          rejected_keywords = {"dtype", "torch_dtype"}
          if not rejected_keywords.isdisjoint(kwargs):
              raise TypeError("M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'")
          return {"loaded": True, "path": resolved_model_name_or_path}

      def convert(self, output_dir, quantization=None, force=False):
          """Call ``load_model`` once with ``dtype="float32"`` and summarise the run.

          The returned dict echoes the arguments and includes a copy of
          ``load_calls`` taken after ``load_model`` has returned.
          """
          load_result = self.load_model("FakeModel", self.model_name_or_path, dtype="float32")
          summary = dict(
              loaded=load_result,
              output_dir=output_dir,
              quantization=quantization,
              force=force,
          )
          summary["load_calls"] = list(self.load_calls)
          return summary
  
  
  def _install_fake_ctranslate2(monkeypatch, base_converter):
      """Register a stub ``ctranslate2`` package in ``sys.modules``.

      The stub exposes ``base_converter`` as
      ``ctranslate2.converters.TransformersConverter``. Both entries are set
      through ``monkeypatch.setitem``.
      """
      fake_converters = types.ModuleType("ctranslate2.converters")
      fake_converters.TransformersConverter = base_converter

      fake_package = types.ModuleType("ctranslate2")
      fake_package.converters = fake_converters

      # Register the parent package before its submodule; each module is keyed
      # by the name it was created with.
      for stub_module in (fake_package, fake_converters):
          monkeypatch.setitem(sys.modules, stub_module.__name__, stub_module)
  
  
  def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch):
      """A dtype-related ``TypeError`` from ``load_model`` triggers one retry.

      The first recorded call carries ``dtype="float32"``; the second must carry
      a ``config`` whose ``torch_dtype`` attribute is ``None``.
      """

      def _from_pretrained(path):
          # Minimal config object that starts out with a torch_dtype set.
          return types.SimpleNamespace(torch_dtype="float32", path=path)

      _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter)
      transformers_stub = types.ModuleType("transformers")
      transformers_stub.AutoConfig = types.SimpleNamespace(from_pretrained=_from_pretrained)
      monkeypatch.setitem(sys.modules, "transformers", transformers_stub)

      outcome = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")

      # The conversion arguments are passed through to the converter unchanged.
      assert outcome["loaded"] == {"loaded": True, "path": "fake-model"}
      assert outcome["output_dir"] == "/tmp/out"
      assert outcome["quantization"] == "float16"
      assert outcome["force"] is False

      recorded_calls = outcome["load_calls"]
      assert len(recorded_calls) == 2
      initial_call, retry_call = recorded_calls

      assert initial_call == {
          "model_class": "FakeModel",
          "resolved_model_name_or_path": "fake-model",
          "kwargs": {"dtype": "float32"},
      }
      assert retry_call["model_class"] == "FakeModel"
      assert retry_call["resolved_model_name_or_path"] == "fake-model"
      # The "missing" default distinguishes an absent attribute from one set to None.
      assert getattr(retry_call["kwargs"]["config"], "torch_dtype", "missing") is None
  
  
  def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch):
      """A ``TypeError`` whose message is not about dtype must reach the caller."""

      class _AlwaysFailingConverter(_FakeTransformersConverter):
          # Every load attempt fails with a message unrelated to dtype handling.
          def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
              raise TypeError("different constructor error")

      def _from_pretrained(path):
          return types.SimpleNamespace(path=path)

      transformers_stub = types.ModuleType("transformers")
      transformers_stub.AutoConfig = types.SimpleNamespace(from_pretrained=_from_pretrained)

      _install_fake_ctranslate2(monkeypatch, _AlwaysFailingConverter)
      monkeypatch.setitem(sys.modules, "transformers", transformers_stub)

      with pytest.raises(TypeError, match="different constructor error"):
          ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")