0fd2f875
tangwang
translate
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
|
import torch
from translation.backends.local_seq2seq import MarianMTTranslationBackend, NLLBTranslationBackend
from translation.service import TranslationService
class _FakeBatch(dict):
def to(self, device):
self["device"] = device
return self
class _FakeTokenizer:
def __init__(self):
self.src_lang = None
self.pad_token = "</s>"
self.eos_token = "</s>"
self.lang_code_to_id = {"eng_Latn": 101, "zho_Hans": 202}
self.last_call = None
def __call__(self, texts, **kwargs):
self.last_call = {"texts": list(texts), **kwargs}
return _FakeBatch({"input_ids": torch.tensor([[1, 2, 3]])})
def batch_decode(self, generated, skip_special_tokens=True):
del generated, skip_special_tokens
return ["translated" for _ in range(len(self.last_call["texts"]))]
def convert_tokens_to_ids(self, token):
return self.lang_code_to_id[token]
class _FakeModel:
def to(self, device):
self.device = device
return self
def eval(self):
return self
def generate(self, **kwargs):
self.last_generate_kwargs = kwargs
return [[42]]
def _stub_load_model(self):
    """Replacement for a backend's ``_load_model`` that installs fakes.

    The tests patch this onto the backend classes with ``monkeypatch.setattr``.
    Constructing a backend then wires in ``_FakeTokenizer`` and ``_FakeModel``
    instead of whatever the real ``_load_model`` does (presumably loading
    weights from ``model_dir``; not visible here).
    """
    self.seq2seq_model = _FakeModel()
    self.tokenizer = _FakeTokenizer()
def test_marian_language_validation(monkeypatch):
    """Marian backend translates a supported pair and rejects other sources.

    The backend is configured for zh -> en only. A request whose source
    language is "en" must raise ValueError mentioning "source languages".
    Model loading is replaced by ``_stub_load_model`` so the fake tokenizer
    and model are used.
    """
    monkeypatch.setattr(MarianMTTranslationBackend, "_load_model", _stub_load_model)
    backend_config = dict(
        name="opus-mt-zh-en",
        model_id="Helsinki-NLP/opus-mt-zh-en",
        model_dir="./models/translation/Helsinki-NLP/opus-mt-zh-en",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
        source_langs=["zh"],
        target_langs=["en"],
    )
    backend = MarianMTTranslationBackend(**backend_config)

    # Supported direction ("测试" is Chinese for "test").
    assert backend.translate("测试", source_lang="zh", target_lang="en") == "translated"

    # Unsupported source language: capture the error, then check it.
    caught = None
    try:
        backend.translate("test", source_lang="en", target_lang="zh")
    except ValueError as err:
        caught = err
    if caught is None:
        raise AssertionError("Expected unsupported source language to raise")
    assert "source languages" in str(caught)
def test_nllb_uses_src_lang_and_forced_bos(monkeypatch):
    """NLLB backend sets the tokenizer's src_lang and forces the target BOS token.

    With the fake tokenizer, source "en" is expected to map to ``eng_Latn``.
    Target "zh" is expected to reach ``generate`` as
    ``forced_bos_token_id == 202``, the fake's id for ``zho_Hans``.
    """
    monkeypatch.setattr(NLLBTranslationBackend, "_load_model", _stub_load_model)
    backend_config = dict(
        name="nllb-200-distilled-600m",
        model_id="facebook/nllb-200-distilled-600M",
        model_dir="./models/translation/facebook/nllb-200-distilled-600M",
        device="cpu",
        torch_dtype="float32",
        batch_size=1,
        max_input_length=16,
        max_new_tokens=16,
        num_beams=1,
    )
    backend = NLLBTranslationBackend(**backend_config)

    translated = backend.translate("test", source_lang="en", target_lang="zh")
    assert translated == "translated"
    assert backend.tokenizer.src_lang == "eng_Latn"

    generate_kwargs = backend.seq2seq_model.last_generate_kwargs
    assert generate_kwargs["forced_bos_token_id"] == 202
|
0fd2f875
tangwang
translate
|
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
created = []
# Test double patched over TranslationService._create_backend. It appends each
# requested (name, backend_type) pair to the enclosing `created` list and
# returns a trivial in-memory backend instead of constructing a real one.
def _fake_create_backend(self, *, name, backend_type, cfg):
# The service instance and the per-backend config are not needed by the fake.
del self, cfg
created.append((name, backend_type))
# Minimal backend: exposes the requested name as `model`, reports batch
# support, and translate() echoes the input text back unchanged.
class _Backend:
model = name
@property
def supports_batch(self):
return True
def translate(self, text, target_lang, source_lang=None, scene=None):
# Language and scene arguments are accepted for interface parity only.
del target_lang, source_lang, scene
return text
return _Backend()
monkeypatch.setattr(TranslationService, "_create_backend", _fake_create_backend)
config = {
"service_url": "http://127.0.0.1:6006",
"timeout_sec": 10.0,
"default_model": "opus-mt-en-zh",
"default_scene": "general",
"capabilities": {
"opus-mt-en-zh": {
"enabled": True,
"backend": "local_marian",
|
0fd2f875
tangwang
translate
|
130
131
132
133
134
135
136
137
138
139
140
141
|
"model_id": "dummy",
"model_dir": "dummy",
"device": "cpu",
"torch_dtype": "float32",
"batch_size": 1,
"max_input_length": 8,
"max_new_tokens": 8,
"num_beams": 1,
},
"nllb-200-distilled-600m": {
"enabled": True,
"backend": "local_nllb",
|
0fd2f875
tangwang
translate
|
143
144
145
146
147
148
149
150
151
152
153
|
"model_id": "dummy",
"model_dir": "dummy",
"device": "cpu",
"torch_dtype": "float32",
"batch_size": 1,
"max_input_length": 8,
"max_new_tokens": 8,
"num_beams": 1,
},
},
"cache": {
|