From 3eff49b7015264da642c0effb871c60dc5b68129 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 17 Mar 2026 21:29:18 +0800 Subject: [PATCH] trans nllb-200-distilled-600M性能提升 --- config/config.yaml | 5 +++-- docs/翻译模块说明.md | 42 +++--------------------------------------- perf_reports/20260317/translation_local_models/README.md | 167 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------------------------------------------------------------------------------- scripts/benchmark_translation_local_models.py | 15 +++++++++++++++ translation/README.md | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------- translation/backends/local_seq2seq.py | 9 +++++++++ translation/service.py | 2 ++ 7 files changed, 162 insertions(+), 160 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index ef70ab7..b54acb4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -141,10 +141,11 @@ services: model_dir: "./models/translation/facebook/nllb-200-distilled-600M" device: "cuda" torch_dtype: "float16" - batch_size: 8 + batch_size: 16 max_input_length: 256 - max_new_tokens: 256 + max_new_tokens: 64 num_beams: 1 + attn_implementation: "sdpa" opus-mt-zh-en: enabled: true backend: "local_marian" diff --git a/docs/翻译模块说明.md b/docs/翻译模块说明.md index 9734723..6ff372b 100644 --- a/docs/翻译模块说明.md +++ b/docs/翻译模块说明.md @@ -31,43 +31,7 @@ DEEPL_AUTH_KEY=xxx - `service_url`、`default_model`、`default_scene` 只从 `config/config.yaml` 读取,不再接受环境变量静默覆盖 - 外部接口通过 `model + scene` 指定本次使用哪种能力、哪个场景 -配置入口在 `config/config.yaml -> services.translation`,核心字段示例: - -```yaml -services: - translation: - service_url: "http://127.0.0.1:6006" - default_model: "llm" - default_scene: "general" - timeout_sec: 10.0 - capabilities: - qwen-mt: - enabled: true - backend: "qwen_mt" - model: "qwen-mt-flash" - base_url: "https://dashscope-us.aliyuncs.com/compatible-mode/v1" - llm: - enabled: true - backend: "llm" - model: "qwen-flash" - base_url: 
"https://dashscope-us.aliyuncs.com/compatible-mode/v1" - deepl: - enabled: false - backend: "deepl" - api_url: "https://api.deepl.com/v2/translate" - nllb-200-distilled-600m: - enabled: false - backend: "local_nllb" - model_id: "facebook/nllb-200-distilled-600M" - opus-mt-zh-en: - enabled: false - backend: "local_marian" - model_id: "Helsinki-NLP/opus-mt-zh-en" - opus-mt-en-zh: - enabled: false - backend: "local_marian" - model_id: "Helsinki-NLP/opus-mt-en-zh" -``` +配置入口在 `config/config.yaml -> services.translation` ## 本地模型部署 @@ -163,9 +127,9 @@ services: ## 开发者接口约定(代码调用) -代码侧(如 query/indexer)通过 `translation.create_translation_client()` 获取实例并调用 `translate()`;业务侧不再存在翻译 provider 选择逻辑。 +代码侧(如 query/indexer)通过 `translation.create_translation_client()` 获取实例并调用 `translate()`; -### 输入输出形状(Shape) +### 输入输出Shape - `translate(text=...)` 支持: - **单条**:`text: str` → 返回 `Optional[str]` diff --git a/perf_reports/20260317/translation_local_models/README.md b/perf_reports/20260317/translation_local_models/README.md index 77f19b8..bc1bf96 100644 --- a/perf_reports/20260317/translation_local_models/README.md +++ b/perf_reports/20260317/translation_local_models/README.md @@ -9,105 +9,72 @@ Environment: - Driver / CUDA: `570.158.01 / 12.8` - Python env: `.venv-translator` - Dataset: [`products_analyzed.csv`](/data/saas-search/products_analyzed.csv) -- Rows in dataset: `18,576` Method: -- `opus-mt-zh-en` and `opus-mt-en-zh` were benchmarked on the full dataset using their configured runtime settings from [`config/config.yaml`](/data/saas-search/config/config.yaml). -- `nllb-200-distilled-600m` could not complete GPU cold start in the current co-resident environment because GPU memory was already heavily occupied by other long-running services. -- For `nllb-200-distilled-600m`, I therefore ran CPU baselines on a `128`-row sample from the same CSV, using `device=cpu`, `torch_dtype=float32`, `batch_size=4`, and then estimated full-dataset runtime from measured throughput. 
-- Quality was intentionally not evaluated; this report is performance-only. - -Current GPU co-residency at benchmark time: -- `text-embeddings-router`: about `1.3 GiB` -- `clip_server`: about `2.0 GiB` -- `VLLM::EngineCore`: about `7.2 GiB` -- `api.translator_app` process: about `2.8 GiB` -- Total occupied before `nllb` cold start: about `13.4 / 16 GiB` - -Operational finding: -- `facebook/nllb-200-distilled-600M` cannot be reliably loaded on the current shared T4 node together with the existing long-running services above. -- This is not a model-quality issue; it is a deployment-capacity issue. - -## Summary - -| Model | Direction | Device | Rows | Load s | Translate s | Items/s | Avg item ms | Batch p50 ms | Batch p95 ms | Peak GPU GiB | Success | -|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| -| `opus-mt-zh-en` | `zh -> en` | `cuda` | 18,576 | 3.1435 | 497.7513 | 37.32 | 26.795 | 301.99 | 1835.81 | 0.382 | 1.000000 | -| `opus-mt-en-zh` | `en -> zh` | `cuda` | 18,576 | 3.1867 | 987.3994 | 18.81 | 53.155 | 449.14 | 2012.12 | 0.379 | 0.999569 | -| `nllb-200-distilled-600m` | `zh -> en` | `cpu` | 128 | 4.4589 | 132.3088 | 0.97 | 1033.662 | 3853.39 | 6896.14 | 0.0 | 1.000000 | -| `nllb-200-distilled-600m` | `en -> zh` | `cpu` | 128 | 4.5039 | 317.8845 | 0.40 | 2483.473 | 6138.87 | 35134.11 | 0.0 | 1.000000 | - -## Detailed Findings - -### 1. `opus-mt-zh-en` - -- Full dataset, `title_cn -> en`, scene=`sku_name` -- Throughput: `37.32 items/s` -- Average per-item latency: `26.795 ms` -- Batch latency: `p50 301.99 ms`, `p95 1835.81 ms`, `max 2181.61 ms` -- Input throughput: `1179.47 chars/s` -- Peak GPU allocated: `0.382 GiB` -- Peak GPU reserved: `0.473 GiB` -- Max RSS: `1355.21 MB` -- Success count: `18576/18576` - -Interpretation: -- This was the fastest of the three new local models in this benchmark. -- It is a strong candidate for large-scale `zh -> en` title translation on the current machine. - -### 2. 
`opus-mt-en-zh` - -- Full dataset, `title -> zh`, scene=`sku_name` -- Throughput: `18.81 items/s` -- Average per-item latency: `53.155 ms` -- Batch latency: `p50 449.14 ms`, `p95 2012.12 ms`, `max 2210.03 ms` -- Input throughput: `2081.66 chars/s` -- Peak GPU allocated: `0.379 GiB` -- Peak GPU reserved: `0.473 GiB` -- Max RSS: `1376.72 MB` -- Success count: `18568/18576` -- Failure count: `8` - -Interpretation: -- Roughly half the item throughput of `opus-mt-zh-en`. -- Still practical on this T4 for offline bulk translation. -- The `8` failed items are a runtime-stability signal worth keeping an eye on for production batch jobs, even though quality was not checked here. - -### 3. `nllb-200-distilled-600m` - -GPU result in the current shared environment: -- Cold start failed with CUDA OOM before benchmark could begin. -- Root cause was insufficient free VRAM on the shared T4, not a script error. - -CPU baseline, `zh -> en`: -- Sample size: `128` -- Throughput: `0.97 items/s` -- Average per-item latency: `1033.662 ms` -- Batch latency: `p50 3853.39 ms`, `p95 6896.14 ms`, `max 8039.91 ms` -- Max RSS: `3481.75 MB` -- Estimated full-dataset runtime at this throughput: about `19,150.52 s` = `319.18 min` = `5.32 h` - -CPU baseline, `en -> zh`: -- Sample size: `128` -- Throughput: `0.40 items/s` -- Average per-item latency: `2483.473 ms` -- Batch latency: `p50 6138.87 ms`, `p95 35134.11 ms`, `max 37388.36 ms` -- Max RSS: `3483.60 MB` -- Estimated full-dataset runtime at this throughput: about `46,440 s` = `774 min` = `12.9 h` - -Interpretation: -- In the current node layout, `nllb` is not a good fit for shared-GPU online service. -- CPU fallback is functionally available but far slower than the Marian models. -- If `nllb` is still desired, it should be considered for isolated GPU deployment, dedicated batch nodes, or lower-frequency offline tasks. - -## Practical Ranking On This Machine - -By usable real-world performance on the current node: -1. `opus-mt-zh-en` -2. 
`opus-mt-en-zh` -3. `nllb-200-distilled-600m` - -By deployment friendliness on the current shared T4: -1. `opus-mt-zh-en` -2. `opus-mt-en-zh` -3. `nllb-200-distilled-600m` because it currently cannot cold-start on GPU alongside the existing resident services +- `opus-mt-zh-en` and `opus-mt-en-zh` were benchmarked on the full dataset using their configured production settings. +- `nllb-200-distilled-600m` was benchmarked on a `500`-row subset after optimization. +- This report only keeps the final optimized results and final deployment recommendation. +- Quality was intentionally not evaluated; this is a performance-only report. + +## Final Production-Like Config + +For `nllb-200-distilled-600m`, the final recommended config on `Tesla T4` is: + +```yaml +nllb-200-distilled-600m: + enabled: true + backend: "local_nllb" + model_id: "facebook/nllb-200-distilled-600M" + model_dir: "./models/translation/facebook/nllb-200-distilled-600M" + device: "cuda" + torch_dtype: "float16" + batch_size: 16 + max_input_length: 256 + max_new_tokens: 64 + num_beams: 1 + attn_implementation: "sdpa" +``` + +What actually helped: +- `cuda + float16` +- `batch_size=16` +- `max_new_tokens=64` +- `attn_implementation=sdpa` + +What did not become the final recommendation: +- `batch_size=32` + Throughput can improve further, but tail latency degrades too much for a balanced default. 
+ ## Final Results + | Model | Direction | Device | Rows | Load s | Translate s | Items/s | Avg item ms | Batch p50 ms | Batch p95 ms | +|---|---|---:|---:|---:|---:|---:|---:|---:|---:| +| `opus-mt-zh-en` | `zh -> en` | `cuda` | 18,576 | 3.1435 | 497.7513 | 37.32 | 26.795 | 301.99 | 1835.81 | +| `opus-mt-en-zh` | `en -> zh` | `cuda` | 18,576 | 3.1867 | 987.3994 | 18.81 | 53.155 | 449.14 | 2012.12 | +| `nllb-200-distilled-600m` | `zh -> en` | `cuda` | 500 | 7.3397 | 25.9577 | 19.26 | 51.915 | 832.64 | 1263.01 | +| `nllb-200-distilled-600m` | `en -> zh` | `cuda` | 500 | 7.4152 | 42.0405 | 11.89 | 84.081 | 1093.87 | 2107.44 | + +## NLLB Resource Reality + +The commonly quoted figure of about `1.25GB` for this model in `float16` covers roughly the weights alone, not end-to-end runtime memory. + +Actual runtime on this machine: +- loaded on `cuda:0` +- actual parameter dtype verified as `torch.float16` +- steady GPU memory after load: about `2.6 GiB` +- benchmark peak GPU memory: about `2.8-3.0 GiB` + +The difference comes from: +- CUDA context +- allocator reserved memory +- runtime activations and temporary tensors +- batch size +- input length and generation length +- framework overhead + +## Final Takeaways + +1. `opus-mt-zh-en` remains the fastest model on this machine. +2. `opus-mt-en-zh` is slower but still very practical for bulk translation. +3. `nllb-200-distilled-600m` is now fully usable on T4 after optimization. +4. `nllb` is still slower than the two Marian models, but it is the better choice when broad multilingual coverage matters more than peak throughput. 
diff --git a/scripts/benchmark_translation_local_models.py b/scripts/benchmark_translation_local_models.py index 8e73b6f..911f3d0 100644 --- a/scripts/benchmark_translation_local_models.py +++ b/scripts/benchmark_translation_local_models.py @@ -80,6 +80,9 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--batch-size", type=int, default=0, help="Override configured batch size") parser.add_argument("--device-override", default="", help="Override configured device, for example cpu or cuda") parser.add_argument("--torch-dtype-override", default="", help="Override configured torch dtype, for example float32 or float16") + parser.add_argument("--max-new-tokens", type=int, default=0, help="Override configured max_new_tokens") + parser.add_argument("--num-beams", type=int, default=0, help="Override configured num_beams") + parser.add_argument("--attn-implementation", default="", help="Override attention implementation, for example sdpa") parser.add_argument("--warmup-batches", type=int, default=1, help="Warmup batches before measuring") return parser.parse_args() @@ -155,6 +158,12 @@ def benchmark_single_scenario(args: argparse.Namespace) -> Dict[str, Any]: capability["torch_dtype"] = args.torch_dtype_override if args.batch_size: capability["batch_size"] = args.batch_size + if args.max_new_tokens: + capability["max_new_tokens"] = args.max_new_tokens + if args.num_beams: + capability["num_beams"] = args.num_beams + if args.attn_implementation: + capability["attn_implementation"] = args.attn_implementation config["capabilities"][args.model] = capability configured_batch_size = int(capability.get("batch_size") or 1) batch_size = configured_batch_size @@ -296,6 +305,12 @@ def run_all_scenarios(args: argparse.Namespace) -> Dict[str, Any]: cmd.extend(["--device-override", args.device_override]) if args.torch_dtype_override: cmd.extend(["--torch-dtype-override", args.torch_dtype_override]) + if args.max_new_tokens: + cmd.extend(["--max-new-tokens", 
str(args.max_new_tokens)]) + if args.num_beams: + cmd.extend(["--num-beams", str(args.num_beams)]) + if args.attn_implementation: + cmd.extend(["--attn-implementation", args.attn_implementation]) completed = subprocess.run(cmd, capture_output=True, text=True, check=True) result_line = "" diff --git a/translation/README.md b/translation/README.md index 485b04a..275a8e7 100644 --- a/translation/README.md +++ b/translation/README.md @@ -17,14 +17,13 @@ ## 1. 设计目标 -翻译模块已经从旧的 provider 体系中独立出来,采用: +翻译模块采用: - 一个 translator service - 多个 capability backend - 一个统一外部接口:`model + scene` 这套设计的目标是: -- 业务侧不再关心具体翻译 provider 细节 - 翻译能力可以独立扩展、独立启停 - scene、语言码、prompt 模板、模型方向约束等翻译域知识集中在 `translation/` - 配置尽量集中在 [`config/config.yaml`](/data/saas-search/config/config.yaml) 的 `services.translation` @@ -108,10 +107,11 @@ services: model_dir: "./models/translation/facebook/nllb-200-distilled-600M" device: "cuda" torch_dtype: "float16" - batch_size: 8 + batch_size: 16 max_input_length: 256 - max_new_tokens: 256 + max_new_tokens: 64 num_beams: 1 + attn_implementation: "sdpa" opus-mt-zh-en: enabled: true backend: "local_marian" @@ -332,23 +332,31 @@ results = translator.translate( - 本地目录:`models/translation/facebook/nllb-200-distilled-600M` - 当前磁盘占用:约 `2.4G` - 模型类型:多语种 Seq2Seq 机器翻译模型 +- 来源:Meta NLLB(No Language Left Behind)系列的 600M 蒸馏版 +- 目标:用一个模型覆盖大规模多语言互译,而不是只服务某一个固定语言对 - 结构特点: - - encoder-decoder 架构 - - 面向多语种互译 - - 通过语言码控制源语言和目标语言 + - Transformer encoder-decoder 架构 + - 12 层 encoder + 12 层 decoder + - `d_model=1024` + - 多头注意力,适合多语统一建模 + - 通过 `source_lang + forced_bos_token_id` 控制翻译方向 + - 语言标识采用 `language_script` 形式,例如 `eng_Latn`、`zho_Hans` + +模型定位: +- 优势是多语覆盖面广,一个模型可以支撑很多语言方向 +- 劣势是相较于 Marian 这种双语专用模型,推理更重、延迟更高 +- 在我们当前业务里,它更适合“多语覆盖优先”的场景,不适合拿来和专用中英模型拼极致吞吐 + +显存占用情况: +- 600M 模型的 float16(半精度)权重约 `1.25G`,推理时还会叠加 CUDA context、allocator reserve、激活张量、batch、输入长度、生成长度等开销 +- 当前这台 `Tesla T4` 上,优化后的实际运行峰值大约在 `2.8-3.0 GiB` 当前实现特点: - backend 类型:`local_nllb` - 支持多语 - 调用时必须显式传 `source_lang` - 语言码映射定义在 
[`translation/languages.py`](/data/saas-search/translation/languages.py) - -适合场景: -- 需要多语覆盖 -- 需要一个模型处理多语言对 - -不太适合: -- 当前共享 GPU 环境下的常驻在线服务 +- 当前 T4 推荐配置:`device=cuda`、`torch_dtype=float16`、`batch_size=16`、`max_new_tokens=64`、`attn_implementation=sdpa` ### 8.5 `opus-mt-zh-en` @@ -424,6 +432,7 @@ models/translation/Helsinki-NLP/opus-mt-en-zh - 避免多 worker 重复加载模型 - GPU 机器上优先使用 `cuda + float16` - CPU 只建议用于功能验证或离线低频任务 +- 对 NLLB,T4 上优先采用 `batch_size=16 + max_new_tokens=64 + attn_implementation=sdpa` ### 9.5 验证 @@ -479,21 +488,56 @@ cd /data/saas-search - Python env:`.venv-translator` - 数据量:`18,576` 条商品标题 -性能结果摘要: +最终性能结果: | Model | Direction | Device | Rows | Load s | Translate s | Items/s | Avg item ms | Batch p50 ms | Batch p95 ms | |---|---|---:|---:|---:|---:|---:|---:|---:|---:| | `opus-mt-zh-en` | `zh -> en` | `cuda` | 18,576 | 3.1435 | 497.7513 | 37.32 | 26.795 | 301.99 | 1835.81 | | `opus-mt-en-zh` | `en -> zh` | `cuda` | 18,576 | 3.1867 | 987.3994 | 18.81 | 53.155 | 449.14 | 2012.12 | -| `nllb-200-distilled-600m` | `zh -> en` | `cpu` | 128 | 4.4589 | 132.3088 | 0.97 | 1033.662 | 3853.39 | 6896.14 | -| `nllb-200-distilled-600m` | `en -> zh` | `cpu` | 128 | 4.5039 | 317.8845 | 0.40 | 2483.473 | 6138.87 | 35134.11 | +| `nllb-200-distilled-600m` | `zh -> en` | `cuda` | 500 | 7.3397 | 25.9577 | 19.26 | 51.915 | 832.64 | 1263.01 | +| `nllb-200-distilled-600m` | `en -> zh` | `cuda` | 500 | 7.4152 | 42.0405 | 11.89 | 84.081 | 1093.87 | 2107.44 | + +NLLB 性能优化经验: + +- 起作用的优化点 1:`float16 + cuda` + - 模型确认以 `torch.float16` 实际加载到 `cuda:0` + - 优化后在 T4 上的峰值显存约 `2.8-3.0 GiB` +- 起作用的优化点 2:`batch_size=16` + - 相比 `batch_size=8`,吞吐提升明显 + - 继续提升到 `32` 虽然还能增吞吐,但 batch p95 和 batch max 会恶化很多 +- 起作用的优化点 3:`max_new_tokens=64` + - 商品标题翻译通常不需要 `256` 的生成上限 + - 收紧生成长度后,`zh->en` 与 `en->zh` 都有明显收益 +- 起作用的优化点 4:`attn_implementation=sdpa` + - 对当前 PyTorch + T4 环境有效 + - 配合半精度和较合理 batch size 后,整体延迟进一步下降 + +为什么最终没有采用其它方案: + +- 当前 HF 原生方案已经能在 T4 上稳定跑通 +- 在 `10G+` 可用显存下,原生 `float16` 已足够支撑 NLLB-600M +- 
因此暂时不需要为这个模型额外引入 GGUF 或 CT2 的新运行栈 +- 如果未来目标变成“继续压缩显存”或“进一步追求更低延迟”,再评估 `ct2-int8` 会更合适 关键结论: - 当前机器上,`opus-mt-zh-en` 是三个新增本地模型里最快的 - `opus-mt-en-zh` 大约是 `opus-mt-zh-en` 吞吐的一半 -- `nllb-200-distilled-600M` 在当前共享 T4 环境下无法完成 GPU 冷启动,会 OOM -- `nllb` 的 CPU fallback 可用,但明显更慢,更适合隔离部署或离线任务 +- `nllb-200-distilled-600M` 在显存充足时可以用 `cuda + float16 + batch_size=16 + max_new_tokens=64 + sdpa` 正常运行 +- `nllb` 最终可用,但吞吐仍明显低于两个 Marian 模型,更适合多语覆盖或独立资源环境 + +最终推荐部署方案: + +- 模型:`facebook/nllb-200-distilled-600M` +- 设备:`cuda` +- 精度:`float16` +- 推荐卡型:至少 `Tesla T4 16GB` 这一级别 +- 推荐 batch:`16` +- 推荐 `max_input_length`:`256` +- 推荐 `max_new_tokens`:`64` +- 推荐 `num_beams`:`1` +- 推荐注意力实现:`sdpa` +- 运行方式:单 worker,避免重复加载 更详细的性能说明见: - [`perf_reports/20260317/translation_local_models/README.md`](/data/saas-search/perf_reports/20260317/translation_local_models/README.md) diff --git a/translation/backends/local_seq2seq.py b/translation/backends/local_seq2seq.py index 5ef9475..b5109cd 100644 --- a/translation/backends/local_seq2seq.py +++ b/translation/backends/local_seq2seq.py @@ -50,6 +50,7 @@ class LocalSeq2SeqTranslationBackend: max_input_length: int, max_new_tokens: int, num_beams: int, + attn_implementation: Optional[str] = None, ) -> None: self.model = name self.model_id = model_id @@ -60,6 +61,7 @@ class LocalSeq2SeqTranslationBackend: self.max_input_length = int(max_input_length) self.max_new_tokens = int(max_new_tokens) self.num_beams = int(num_beams) + self.attn_implementation = str(attn_implementation or "").strip() or None self._lock = threading.Lock() self._load_model() @@ -92,6 +94,9 @@ class LocalSeq2SeqTranslationBackend: kwargs: Dict[str, object] = {} if self.torch_dtype is not None: kwargs["dtype"] = self.torch_dtype + kwargs["low_cpu_mem_usage"] = True + if self.attn_implementation: + kwargs["attn_implementation"] = self.attn_implementation return kwargs def _normalize_texts(self, text: Union[str, Sequence[str]]) -> List[str]: @@ -178,6 +183,7 @@ class 
MarianMTTranslationBackend(LocalSeq2SeqTranslationBackend): num_beams: int, source_langs: Sequence[str], target_langs: Sequence[str], + attn_implementation: Optional[str] = None, ) -> None: self.source_langs = {str(lang).strip().lower() for lang in source_langs if str(lang).strip()} self.target_langs = {str(lang).strip().lower() for lang in target_langs if str(lang).strip()} @@ -191,6 +197,7 @@ class MarianMTTranslationBackend(LocalSeq2SeqTranslationBackend): max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, + attn_implementation=attn_implementation, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: @@ -222,6 +229,7 @@ class NLLBTranslationBackend(LocalSeq2SeqTranslationBackend): max_new_tokens: int, num_beams: int, language_codes: Optional[Dict[str, str]] = None, + attn_implementation: Optional[str] = None, ) -> None: overrides = language_codes or {} self.language_codes = { @@ -238,6 +246,7 @@ class NLLBTranslationBackend(LocalSeq2SeqTranslationBackend): max_input_length=max_input_length, max_new_tokens=max_new_tokens, num_beams=num_beams, + attn_implementation=attn_implementation, ) def _validate_languages(self, source_lang: Optional[str], target_lang: str) -> None: diff --git a/translation/service.py b/translation/service.py index 91ba0de..f0ed6a0 100644 --- a/translation/service.py +++ b/translation/service.py @@ -105,6 +105,7 @@ class TranslationService: max_input_length=int(cfg["max_input_length"]), max_new_tokens=int(cfg["max_new_tokens"]), num_beams=int(cfg["num_beams"]), + attn_implementation=cfg.get("attn_implementation"), ) def _create_local_marian_backend(self, *, name: str, cfg: Dict[str, object]) -> TranslationBackendProtocol: @@ -124,6 +125,7 @@ class TranslationService: num_beams=int(cfg["num_beams"]), source_langs=[source_lang], target_langs=[target_lang], + attn_implementation=cfg.get("attn_implementation"), ) @property -- libgit2 0.21.2