test_translation_converter_resolution.py
3.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from __future__ import annotations
import sys
import types
import pytest
import translation.ct2_conversion as ct2_conversion
class _FakeTransformersConverter:
    """Test double standing in for ``ctranslate2.converters.TransformersConverter``.

    Every ``load_model`` invocation is recorded in ``load_calls``. Any call that
    passes a ``dtype`` or ``torch_dtype`` keyword fails with ``TypeError``, which
    lets tests observe whether the code under test retries the load without it.
    """

    def __init__(self, model_name_or_path):
        self.model_name_or_path = model_name_or_path
        # One dict per load_model call, appended in call order.
        self.load_calls = []

    def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
        """Record the call, then reject dtype kwargs or return a fake model dict."""
        self.load_calls.append(
            dict(
                model_class=model_class,
                resolved_model_name_or_path=resolved_model_name_or_path,
                kwargs={**kwargs},  # shallow copy so the record is independent of the caller
            )
        )
        if any(name in kwargs for name in ("dtype", "torch_dtype")):
            # Keep this text byte-identical: the code under test may match on it
            # to decide whether a retry applies (NOTE(review): confirm the matcher).
            raise TypeError(
                "M2M100ForConditionalGeneration.__init__() got an unexpected keyword argument 'dtype'"
            )
        return {"loaded": True, "path": resolved_model_name_or_path}

    def convert(self, output_dir, quantization=None, force=False):
        """Load the model with ``dtype="float32"`` and echo the conversion arguments.

        On this class alone, the dtype kwarg makes ``load_model`` raise ``TypeError``.
        A subclass (presumably the one the code under test builds) has to override
        ``load_model`` to recover.

        Returns a dict with the loaded model, the passed-through arguments and a
        snapshot copy of ``load_calls`` taken after the load.
        """
        model = self.load_model("FakeModel", self.model_name_or_path, dtype="float32")
        history = self.load_calls[:]  # copy, so later calls do not leak into the result
        return dict(
            loaded=model,
            output_dir=output_dir,
            quantization=quantization,
            force=force,
            load_calls=history,
        )
def _install_fake_ctranslate2(monkeypatch, base_converter):
    """Register stub ``ctranslate2`` and ``ctranslate2.converters`` modules.

    ``base_converter`` is exposed as ``ctranslate2.converters.TransformersConverter``.
    Both modules are placed in ``sys.modules`` through ``monkeypatch.setitem``,
    so the fixture restores the previous entries when the test ends.
    """
    converters = types.ModuleType("ctranslate2.converters")
    converters.TransformersConverter = base_converter

    package = types.ModuleType("ctranslate2")
    package.converters = converters

    # Parent package first, then the submodule, as a regular import would register them.
    for module_name, module in (("ctranslate2", package), ("ctranslate2.converters", converters)):
        monkeypatch.setitem(sys.modules, module_name, module)
def test_convert_transformers_model_retries_without_torch_dtype(monkeypatch):
    """The dtype-rejecting load is retried with a config whose ``torch_dtype`` is None.

    The fake converter fails on the first ``load_model`` call, which carries
    ``dtype``. The assertions pin that exactly one retry happens, that it succeeds,
    and that it passes a ``config`` whose ``torch_dtype`` is ``None``.
    """
    _install_fake_ctranslate2(monkeypatch, _FakeTransformersConverter)

    def _load_config(path):
        # Stand-in for AutoConfig.from_pretrained; starts with a concrete torch_dtype.
        return types.SimpleNamespace(torch_dtype="float32", path=path)

    transformers_stub = types.ModuleType("transformers")
    transformers_stub.AutoConfig = types.SimpleNamespace(from_pretrained=_load_config)
    monkeypatch.setitem(sys.modules, "transformers", transformers_stub)

    outcome = ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")

    # Arguments are forwarded to the converter unchanged.
    assert outcome["loaded"] == {"loaded": True, "path": "fake-model"}
    assert outcome["output_dir"] == "/tmp/out"
    assert outcome["quantization"] == "float16"
    assert outcome["force"] is False

    calls = outcome["load_calls"]
    assert len(calls) == 2
    first_attempt, retry = calls

    # The initial attempt is the unmodified call made by the fake's convert().
    assert first_attempt == {
        "model_class": "FakeModel",
        "resolved_model_name_or_path": "fake-model",
        "kwargs": {"dtype": "float32"},
    }

    # The retry targets the same model but supplies a config with torch_dtype cleared.
    assert retry["model_class"] == "FakeModel"
    assert retry["resolved_model_name_or_path"] == "fake-model"
    retry_config = retry["kwargs"]["config"]
    assert getattr(retry_config, "torch_dtype", "missing") is None
def test_convert_transformers_model_preserves_unrelated_type_errors(monkeypatch):
    """A TypeError that has nothing to do with dtype kwargs reaches the caller unchanged."""

    class _BrokenConstructorConverter(_FakeTransformersConverter):
        def load_model(self, model_class, resolved_model_name_or_path, **kwargs):
            # Fails regardless of the kwargs, so there is no dtype problem to recover from.
            raise TypeError("different constructor error")

    _install_fake_ctranslate2(monkeypatch, _BrokenConstructorConverter)

    def _load_config(path):
        # Config stub without a torch_dtype attribute.
        return types.SimpleNamespace(path=path)

    transformers_stub = types.ModuleType("transformers")
    transformers_stub.AutoConfig = types.SimpleNamespace(from_pretrained=_load_config)
    monkeypatch.setitem(sys.modules, "transformers", transformers_stub)

    with pytest.raises(TypeError, match="different constructor error"):
        ct2_conversion.convert_transformers_model("fake-model", "/tmp/out", "float16")