Commit 4747e2f40cb1e479219bcc25f420216e04d20092
1 parent
14e67b71
embedding performance
The instability is very likely real overload, but `lsof -i :6005 | wc -l = 75` alone does not prove it. What does matter is the live shape of the service: it is a single `uvicorn` worker on port `6005`, and the code had one shared process handling both text and image requests, with image work serialized behind a single lock. Under bursty image traffic, requests could pile up and sit blocked with almost no useful tracing, which matches the “only blocking observed” symptom. now adds persistent log files, request IDs, per-request request/response/failure logs, text microbatch dispatch logs, health stats with active/rejected counts, and explicit overload admission control. New knobs are `TEXT_MAX_INFLIGHT`, `IMAGE_MAX_INFLIGHT`, and `EMBEDDING_OVERLOAD_STATUS_CODE`. Startup output now shows those limits and log paths in [scripts/start_embedding_service.sh](/data/saas-search/scripts/start_embedding_service.sh#L80). I also added focused tests in [tests/test_embedding_service_limits.py](/data/saas-search/tests/test_embedding_service_limits.py#L1). What this means operationally: - Text and image are still in one process, so this is not the final architecture. - But image spikes will now be rejected quickly once the image lane is full instead of sitting around and consuming the worker pool. - Logs will now show each request, each rejection, each microbatch dispatch, backend time, response time, and request ID. Verification: - Passed: `.venv/bin/python -m pytest -q tests/test_embedding_service_limits.py` - I also ran a wider test command, but 3 failures came from pre-existing drift in [tests/test_embedding_pipeline.py](/data/saas-search/tests/test_embedding_pipeline.py#L95), where the tests still monkeypatch `embeddings.text_encoder.redis.Redis` even though [embeddings/text_encoder.py](/data/saas-search/embeddings/text_encoder.py#L1) no longer imports `redis` that way. 已把 CLIP_AS_SERVICE 的默认模型切到 ViT-L-14,并把这套配置收口成可变更的统一入口了。现在默认值在 embeddings/config.py (line 29) 的 CLIP_AS_SERVICE_MODEL_NAME,当前为 CN-CLIP/ViT-L-14;scripts/start_cnclip_service.sh (line 37) 会自动读取这个配置,不再把默认模型写死在脚本里,同时支持 CNCLIP_MODEL_NAME 和 --model-name 临时覆盖。scripts/start_embedding_service.sh (line 29) 和 embeddings/server.py (line 425) 也补了模型信息输出,方便排查实际连接的配置。 文档也一起更新了,重点在 docs/CNCLIP_SERVICE说明文档.md (line 62) 和 embeddings/README.md (line 58):现在说明的是“以配置为准、可覆盖”的机制,而不是写死某个模型名;相关总结文档和内部说明也同步改成了配置驱动表述。
Showing
15 changed files
with
956 additions
and
106 deletions
Show diff stats
CLAUDE.md
| @@ -195,7 +195,7 @@ The system uses centralized configuration through `config/config.yaml`: | @@ -195,7 +195,7 @@ The system uses centralized configuration through `config/config.yaml`: | ||
| 195 | - Configurable caching to avoid recomputation | 195 | - Configurable caching to avoid recomputation |
| 196 | 196 | ||
| 197 | **Image Embedding** (`embeddings/clip_encoder.py`): | 197 | **Image Embedding** (`embeddings/clip_encoder.py`): |
| 198 | -- Uses CN-CLIP model (ViT-H-14) | 198 | +- Uses a configurable CN-CLIP model (default managed in `embeddings/config.py`) |
| 199 | - Downloads and preprocesses images from URLs | 199 | - Downloads and preprocesses images from URLs |
| 200 | - Supports both local and remote image processing | 200 | - Supports both local and remote image processing |
| 201 | - Generates 1024-dimensional vectors | 201 | - Generates 1024-dimensional vectors |
| @@ -563,7 +563,7 @@ GET /admin/stats # Index statistics | @@ -563,7 +563,7 @@ GET /admin/stats # Index statistics | ||
| 563 | - **Usage**: Semantic search combined with BM25 relevance | 563 | - **Usage**: Semantic search combined with BM25 relevance |
| 564 | 564 | ||
| 565 | **Image Search Pipeline**: | 565 | **Image Search Pipeline**: |
| 566 | -- **Model**: CN-CLIP (ViT-H-14) | 566 | +- **Model**: CN-CLIP (configured in `embeddings/config.py`) |
| 567 | - **Processing**: URL download → preprocessing → vectorization | 567 | - **Processing**: URL download → preprocessing → vectorization |
| 568 | - **Storage**: Nested structure with vector + original URL | 568 | - **Storage**: Nested structure with vector + original URL |
| 569 | - **Application**: Visual similarity search for products | 569 | - **Application**: Visual similarity search for products |
config/config.yaml
| @@ -148,7 +148,7 @@ services: | @@ -148,7 +148,7 @@ services: | ||
| 148 | ct2_decoding_length_min: 32 | 148 | ct2_decoding_length_min: 32 |
| 149 | device: "cuda" | 149 | device: "cuda" |
| 150 | torch_dtype: "float16" | 150 | torch_dtype: "float16" |
| 151 | - batch_size: 16 | 151 | + batch_size: 64 |
| 152 | max_input_length: 256 | 152 | max_input_length: 256 |
| 153 | max_new_tokens: 64 | 153 | max_new_tokens: 64 |
| 154 | num_beams: 1 | 154 | num_beams: 1 |
docs/CNCLIP_SERVICE说明文档.md
| @@ -59,7 +59,26 @@ cd /data/saas-search | @@ -59,7 +59,26 @@ cd /data/saas-search | ||
| 59 | ./scripts/start_cnclip_service.sh --device cpu | 59 | ./scripts/start_cnclip_service.sh --device cpu |
| 60 | ``` | 60 | ``` |
| 61 | 61 | ||
| 62 | -### 5.3 停止服务 | 62 | +### 5.3 模型配置与覆盖 |
| 63 | + | ||
| 64 | +- 仓库默认模型取自 `embeddings/config.py` 的 `CLIP_AS_SERVICE_MODEL_NAME`。 | ||
| 65 | +- `scripts/start_cnclip_service.sh` 会自动读取这个配置,因此修改默认模型时不需要再去脚本里找硬编码。 | ||
| 66 | +- 覆盖优先级:`--model-name` > `CNCLIP_MODEL_NAME` > `CLIP_AS_SERVICE_MODEL_NAME` / `embeddings/config.py`。 | ||
| 67 | + | ||
| 68 | +查看当前默认模型: | ||
| 69 | + | ||
| 70 | +```bash | ||
| 71 | +python3 -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_MODEL_NAME)" | ||
| 72 | +``` | ||
| 73 | + | ||
| 74 | +临时覆盖模型: | ||
| 75 | + | ||
| 76 | +```bash | ||
| 77 | +./scripts/start_cnclip_service.sh --model-name CN-CLIP/ViT-L-14 | ||
| 78 | +CNCLIP_MODEL_NAME=CN-CLIP/ViT-H-14 ./scripts/service_ctl.sh start cnclip | ||
| 79 | +``` | ||
| 80 | + | ||
| 81 | +### 5.4 停止服务 | ||
| 63 | 82 | ||
| 64 | ```bash | 83 | ```bash |
| 65 | ./scripts/stop_cnclip_service.sh | 84 | ./scripts/stop_cnclip_service.sh |
| @@ -110,6 +129,7 @@ cat third-party/clip-as-service/server/torch-flow-temp.yml | @@ -110,6 +129,7 @@ cat third-party/clip-as-service/server/torch-flow-temp.yml | ||
| 110 | 129 | ||
| 111 | - GPU 模式:`device: 'cuda'` | 130 | - GPU 模式:`device: 'cuda'` |
| 112 | - CPU 模式:`device: 'cpu'` | 131 | - CPU 模式:`device: 'cpu'` |
| 132 | +- 模型名:`name: '<当前实际模型名>'` | ||
| 113 | 133 | ||
| 114 | ### 7.2.1 日志与 PID 文件 | 134 | ### 7.2.1 日志与 PID 文件 |
| 115 | 135 |
docs/工作总结-微服务性能优化与架构.md
| @@ -67,13 +67,13 @@ instruction: "Given a shopping query, rank product titles by relevance" | @@ -67,13 +67,13 @@ instruction: "Given a shopping query, rank product titles by relevance" | ||
| 67 | 67 | ||
| 68 | ### 3. 图片向量(Image Embedding) | 68 | ### 3. 图片向量(Image Embedding) |
| 69 | 69 | ||
| 70 | -**方案**:**clip-as-service**(CN-CLIP,ViT-H-14),由独立服务提供图片向量化能力。 | 70 | +**方案**:**clip-as-service**(CN-CLIP,模型由配置控制),由独立服务提供图片向量化能力。 |
| 71 | 71 | ||
| 72 | **具体内容**: | 72 | **具体内容**: |
| 73 | - **端口**:clip-as-service 默认 **51000**(`CNCLIP_PORT`);文本走 TEI(8080),图片走 clip-as-service。 | 73 | - **端口**:clip-as-service 默认 **51000**(`CNCLIP_PORT`);文本走 TEI(8080),图片走 clip-as-service。 |
| 74 | - **API**:embedding 服务(6005)统一暴露 `POST /embed/text` 与 `POST /embed/image`;图片请求由 `embeddings/server.py` 按配置调用实现 `ImageEncoderProtocol` 的后端(clip-as-service 或本地 CN-CLIP)。 | 74 | - **API**:embedding 服务(6005)统一暴露 `POST /embed/text` 与 `POST /embed/image`;图片请求由 `embeddings/server.py` 按配置调用实现 `ImageEncoderProtocol` 的后端(clip-as-service 或本地 CN-CLIP)。 |
| 75 | - **环境与启停**:CN-CLIP 使用独立虚拟环境 `.venv-cnclip`;启动 `scripts/start_cnclip_service.sh`,或 `./scripts/service_ctl.sh start cnclip`;设备可通过 `CNCLIP_DEVICE=cuda`(默认)或 `cpu` 指定。 | 75 | - **环境与启停**:CN-CLIP 使用独立虚拟环境 `.venv-cnclip`;启动 `scripts/start_cnclip_service.sh`,或 `./scripts/service_ctl.sh start cnclip`;设备可通过 `CNCLIP_DEVICE=cuda`(默认)或 `cpu` 指定。 |
| 76 | -- **配置**:图片后端在 `config/config.yaml` 的 `services.embedding` 下配置(若存在 image 相关 backend);clip-as-service 的 flow 配置在 `third-party/clip-as-service/server/torch-flow-temp.yml`。 | 76 | +- **配置**:图片后端在 `config/config.yaml` 的 `services.embedding` 下配置(若存在 image 相关 backend);clip-as-service 默认模型由 `embeddings/config.py` 的 `CLIP_AS_SERVICE_MODEL_NAME` 控制,flow 配置在 `third-party/clip-as-service/server/torch-flow-temp.yml`。 |
| 77 | 77 | ||
| 78 | 详见:`docs/CNCLIP_SERVICE说明文档.md`、`embeddings/README.md`。 | 78 | 详见:`docs/CNCLIP_SERVICE说明文档.md`、`embeddings/README.md`。 |
| 79 | 79 |
embeddings/README.md
| @@ -58,6 +58,8 @@ | @@ -58,6 +58,8 @@ | ||
| 58 | 3. **配置**(`embeddings/config.py` 或环境变量): | 58 | 3. **配置**(`embeddings/config.py` 或环境变量): |
| 59 | - `USE_CLIP_AS_SERVICE=true`(默认) | 59 | - `USE_CLIP_AS_SERVICE=true`(默认) |
| 60 | - `CLIP_AS_SERVICE_SERVER=grpc://127.0.0.1:51000` | 60 | - `CLIP_AS_SERVICE_SERVER=grpc://127.0.0.1:51000` |
| 61 | + - `CLIP_AS_SERVICE_MODEL_NAME=CN-CLIP/ViT-L-14` | ||
| 62 | + - `scripts/start_cnclip_service.sh` 默认会读取同一个 `CLIP_AS_SERVICE_MODEL_NAME`,也可用 `CNCLIP_MODEL_NAME` 或 `--model-name` 临时覆盖 | ||
| 61 | 63 | ||
| 62 | ### 启动服务 | 64 | ### 启动服务 |
| 63 | 65 | ||
| @@ -80,6 +82,6 @@ TEI_DEVICE=cpu ./scripts/start_tei_service.sh | @@ -80,6 +82,6 @@ TEI_DEVICE=cpu ./scripts/start_tei_service.sh | ||
| 80 | - `PORT`: 服务端口(默认 6005) | 82 | - `PORT`: 服务端口(默认 6005) |
| 81 | - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` | 83 | - `TEXT_MODEL_ID`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE`, `TEXT_NORMALIZE_EMBEDDINGS` |
| 82 | - `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true) | 84 | - `IMAGE_NORMALIZE_EMBEDDINGS`(默认 true) |
| 83 | -- `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) | 85 | +- `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`, `CLIP_AS_SERVICE_MODEL_NAME`:图片向量(clip-as-service) |
| 84 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) | 86 | - `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) |
| 85 | - TEI 相关:`TEI_DEVICE`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` | 87 | - TEI 相关:`TEI_DEVICE`、`TEI_VERSION`、`TEI_MAX_BATCH_TOKENS`、`TEI_MAX_CLIENT_BATCH_SIZE`、`TEI_HEALTH_TIMEOUT_SEC` |
embeddings/clip_model.py
| @@ -16,7 +16,7 @@ from cn_clip.clip import load_from_name | @@ -16,7 +16,7 @@ from cn_clip.clip import load_from_name | ||
| 16 | import cn_clip.clip as clip | 16 | import cn_clip.clip as clip |
| 17 | 17 | ||
| 18 | 18 | ||
| 19 | -DEFAULT_MODEL_NAME = "ViT-H-14" | 19 | +DEFAULT_MODEL_NAME = "ViT-L-14" # "ViT-H-14", "ViT-L-14-336" |
| 20 | MODEL_DOWNLOAD_DIR = "/data/" | 20 | MODEL_DOWNLOAD_DIR = "/data/" |
| 21 | 21 | ||
| 22 | 22 |
embeddings/config.py
| @@ -30,9 +30,10 @@ class EmbeddingConfig(object): | @@ -30,9 +30,10 @@ class EmbeddingConfig(object): | ||
| 30 | # Option A: clip-as-service (Jina CLIP server, recommended) | 30 | # Option A: clip-as-service (Jina CLIP server, recommended) |
| 31 | USE_CLIP_AS_SERVICE = os.getenv("USE_CLIP_AS_SERVICE", "true").lower() in ("1", "true", "yes") | 31 | USE_CLIP_AS_SERVICE = os.getenv("USE_CLIP_AS_SERVICE", "true").lower() in ("1", "true", "yes") |
| 32 | CLIP_AS_SERVICE_SERVER = os.getenv("CLIP_AS_SERVICE_SERVER", "grpc://127.0.0.1:51000") | 32 | CLIP_AS_SERVICE_SERVER = os.getenv("CLIP_AS_SERVICE_SERVER", "grpc://127.0.0.1:51000") |
| 33 | + CLIP_AS_SERVICE_MODEL_NAME = os.getenv("CLIP_AS_SERVICE_MODEL_NAME", "CN-CLIP/ViT-L-14") | ||
| 33 | 34 | ||
| 34 | # Option B: local CN-CLIP (when USE_CLIP_AS_SERVICE=false) | 35 | # Option B: local CN-CLIP (when USE_CLIP_AS_SERVICE=false) |
| 35 | - IMAGE_MODEL_NAME = "ViT-H-14" | 36 | + IMAGE_MODEL_NAME = os.getenv("IMAGE_MODEL_NAME", "ViT-L-14") |
| 36 | IMAGE_DEVICE = None # type: Optional[str] # "cuda" / "cpu" / None(auto) | 37 | IMAGE_DEVICE = None # type: Optional[str] # "cuda" / "cpu" / None(auto) |
| 37 | 38 | ||
| 38 | # Service behavior | 39 | # Service behavior |
embeddings/server.py
| @@ -8,23 +8,100 @@ API (simple list-in, list-out; aligned by index): | @@ -8,23 +8,100 @@ API (simple list-in, list-out; aligned by index): | ||
| 8 | 8 | ||
| 9 | import logging | 9 | import logging |
| 10 | import os | 10 | import os |
| 11 | +import pathlib | ||
| 11 | import threading | 12 | import threading |
| 12 | import time | 13 | import time |
| 14 | +import uuid | ||
| 13 | from collections import deque | 15 | from collections import deque |
| 14 | from dataclasses import dataclass | 16 | from dataclasses import dataclass |
| 17 | +from logging.handlers import TimedRotatingFileHandler | ||
| 15 | from typing import Any, Dict, List, Optional | 18 | from typing import Any, Dict, List, Optional |
| 16 | 19 | ||
| 17 | import numpy as np | 20 | import numpy as np |
| 18 | -from fastapi import FastAPI, HTTPException | 21 | +from fastapi import FastAPI, HTTPException, Request, Response |
| 22 | +from fastapi.concurrency import run_in_threadpool | ||
| 19 | 23 | ||
| 24 | +from config.services_config import get_embedding_backend_config | ||
| 20 | from embeddings.config import CONFIG | 25 | from embeddings.config import CONFIG |
| 21 | from embeddings.protocols import ImageEncoderProtocol | 26 | from embeddings.protocols import ImageEncoderProtocol |
| 22 | -from config.services_config import get_embedding_backend_config | ||
| 23 | - | ||
| 24 | -logger = logging.getLogger(__name__) | ||
| 25 | 27 | ||
| 26 | app = FastAPI(title="saas-search Embedding Service", version="1.0.0") | 28 | app = FastAPI(title="saas-search Embedding Service", version="1.0.0") |
| 27 | 29 | ||
| 30 | + | ||
| 31 | +class _DefaultRequestIdFilter(logging.Filter): | ||
| 32 | + def filter(self, record: logging.LogRecord) -> bool: | ||
| 33 | + if not hasattr(record, "reqid"): | ||
| 34 | + record.reqid = "-1" | ||
| 35 | + return True | ||
| 36 | + | ||
| 37 | + | ||
| 38 | +def configure_embedding_logging() -> None: | ||
| 39 | + root_logger = logging.getLogger() | ||
| 40 | + if getattr(root_logger, "_embedding_logging_configured", False): | ||
| 41 | + return | ||
| 42 | + | ||
| 43 | + log_dir = pathlib.Path("logs") | ||
| 44 | + verbose_dir = log_dir / "verbose" | ||
| 45 | + log_dir.mkdir(exist_ok=True) | ||
| 46 | + verbose_dir.mkdir(parents=True, exist_ok=True) | ||
| 47 | + | ||
| 48 | + log_level = os.getenv("LOG_LEVEL", "INFO").upper() | ||
| 49 | + numeric_level = getattr(logging, log_level, logging.INFO) | ||
| 50 | + formatter = logging.Formatter( | ||
| 51 | + "%(asctime)s | reqid:%(reqid)s | %(name)s | %(levelname)s | %(message)s" | ||
| 52 | + ) | ||
| 53 | + request_filter = _DefaultRequestIdFilter() | ||
| 54 | + | ||
| 55 | + root_logger.setLevel(numeric_level) | ||
| 56 | + | ||
| 57 | + file_handler = TimedRotatingFileHandler( | ||
| 58 | + filename=log_dir / "embedding_api.log", | ||
| 59 | + when="midnight", | ||
| 60 | + interval=1, | ||
| 61 | + backupCount=30, | ||
| 62 | + encoding="utf-8", | ||
| 63 | + ) | ||
| 64 | + file_handler.setLevel(numeric_level) | ||
| 65 | + file_handler.setFormatter(formatter) | ||
| 66 | + file_handler.addFilter(request_filter) | ||
| 67 | + root_logger.addHandler(file_handler) | ||
| 68 | + | ||
| 69 | + error_handler = TimedRotatingFileHandler( | ||
| 70 | + filename=log_dir / "embedding_api_error.log", | ||
| 71 | + when="midnight", | ||
| 72 | + interval=1, | ||
| 73 | + backupCount=30, | ||
| 74 | + encoding="utf-8", | ||
| 75 | + ) | ||
| 76 | + error_handler.setLevel(logging.ERROR) | ||
| 77 | + error_handler.setFormatter(formatter) | ||
| 78 | + error_handler.addFilter(request_filter) | ||
| 79 | + root_logger.addHandler(error_handler) | ||
| 80 | + | ||
| 81 | + verbose_logger = logging.getLogger("embedding.verbose") | ||
| 82 | + verbose_logger.setLevel(numeric_level) | ||
| 83 | + verbose_logger.handlers.clear() | ||
| 84 | + verbose_logger.propagate = False | ||
| 85 | + | ||
| 86 | + verbose_handler = TimedRotatingFileHandler( | ||
| 87 | + filename=verbose_dir / "embedding_verbose.log", | ||
| 88 | + when="midnight", | ||
| 89 | + interval=1, | ||
| 90 | + backupCount=30, | ||
| 91 | + encoding="utf-8", | ||
| 92 | + ) | ||
| 93 | + verbose_handler.setLevel(numeric_level) | ||
| 94 | + verbose_handler.setFormatter(formatter) | ||
| 95 | + verbose_handler.addFilter(request_filter) | ||
| 96 | + verbose_logger.addHandler(verbose_handler) | ||
| 97 | + | ||
| 98 | + root_logger._embedding_logging_configured = True # type: ignore[attr-defined] | ||
| 99 | + | ||
| 100 | + | ||
| 101 | +configure_embedding_logging() | ||
| 102 | +logger = logging.getLogger(__name__) | ||
| 103 | +verbose_logger = logging.getLogger("embedding.verbose") | ||
| 104 | + | ||
| 28 | # Models are loaded at startup, not lazily | 105 | # Models are loaded at startup, not lazily |
| 29 | _text_model: Optional[Any] = None | 106 | _text_model: Optional[Any] = None |
| 30 | _image_model: Optional[ImageEncoderProtocol] = None | 107 | _image_model: Optional[ImageEncoderProtocol] = None |
| @@ -35,12 +112,78 @@ open_image_model = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in | @@ -35,12 +112,78 @@ open_image_model = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in | ||
| 35 | _text_encode_lock = threading.Lock() | 112 | _text_encode_lock = threading.Lock() |
| 36 | _image_encode_lock = threading.Lock() | 113 | _image_encode_lock = threading.Lock() |
| 37 | 114 | ||
| 115 | +_TEXT_MICROBATCH_WINDOW_SEC = max( | ||
| 116 | + 0.0, float(os.getenv("TEXT_MICROBATCH_WINDOW_MS", "4")) / 1000.0 | ||
| 117 | +) | ||
| 118 | +_TEXT_REQUEST_TIMEOUT_SEC = max( | ||
| 119 | + 1.0, float(os.getenv("TEXT_REQUEST_TIMEOUT_SEC", "30")) | ||
| 120 | +) | ||
| 121 | +_TEXT_MAX_INFLIGHT = max(1, int(os.getenv("TEXT_MAX_INFLIGHT", "32"))) | ||
| 122 | +_IMAGE_MAX_INFLIGHT = max(1, int(os.getenv("IMAGE_MAX_INFLIGHT", "1"))) | ||
| 123 | +_OVERLOAD_STATUS_CODE = int(os.getenv("EMBEDDING_OVERLOAD_STATUS_CODE", "503")) | ||
| 124 | +_LOG_PREVIEW_COUNT = max(1, int(os.getenv("EMBEDDING_LOG_PREVIEW_COUNT", "3"))) | ||
| 125 | +_LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_TEXT_PREVIEW_CHARS", "120"))) | ||
| 126 | +_LOG_IMAGE_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_IMAGE_PREVIEW_CHARS", "180"))) | ||
| 127 | +_VECTOR_PREVIEW_DIMS = max(1, int(os.getenv("EMBEDDING_VECTOR_PREVIEW_DIMS", "6"))) | ||
| 128 | + | ||
| 129 | + | ||
| 130 | +class _InflightLimiter: | ||
| 131 | + def __init__(self, name: str, limit: int): | ||
| 132 | + self.name = name | ||
| 133 | + self.limit = max(1, int(limit)) | ||
| 134 | + self._sem = threading.BoundedSemaphore(self.limit) | ||
| 135 | + self._lock = threading.Lock() | ||
| 136 | + self._active = 0 | ||
| 137 | + self._rejected = 0 | ||
| 138 | + self._completed = 0 | ||
| 139 | + self._failed = 0 | ||
| 140 | + self._max_active = 0 | ||
| 141 | + | ||
| 142 | + def try_acquire(self) -> tuple[bool, int]: | ||
| 143 | + if not self._sem.acquire(blocking=False): | ||
| 144 | + with self._lock: | ||
| 145 | + self._rejected += 1 | ||
| 146 | + active = self._active | ||
| 147 | + return False, active | ||
| 148 | + with self._lock: | ||
| 149 | + self._active += 1 | ||
| 150 | + self._max_active = max(self._max_active, self._active) | ||
| 151 | + active = self._active | ||
| 152 | + return True, active | ||
| 153 | + | ||
| 154 | + def release(self, *, success: bool) -> int: | ||
| 155 | + with self._lock: | ||
| 156 | + self._active = max(0, self._active - 1) | ||
| 157 | + if success: | ||
| 158 | + self._completed += 1 | ||
| 159 | + else: | ||
| 160 | + self._failed += 1 | ||
| 161 | + active = self._active | ||
| 162 | + self._sem.release() | ||
| 163 | + return active | ||
| 164 | + | ||
| 165 | + def snapshot(self) -> Dict[str, int]: | ||
| 166 | + with self._lock: | ||
| 167 | + return { | ||
| 168 | + "limit": self.limit, | ||
| 169 | + "active": self._active, | ||
| 170 | + "rejected_total": self._rejected, | ||
| 171 | + "completed_total": self._completed, | ||
| 172 | + "failed_total": self._failed, | ||
| 173 | + "max_active": self._max_active, | ||
| 174 | + } | ||
| 175 | + | ||
| 176 | + | ||
| 177 | +_text_request_limiter = _InflightLimiter(name="text", limit=_TEXT_MAX_INFLIGHT) | ||
| 178 | +_image_request_limiter = _InflightLimiter(name="image", limit=_IMAGE_MAX_INFLIGHT) | ||
| 179 | + | ||
| 38 | 180 | ||
| 39 | @dataclass | 181 | @dataclass |
| 40 | class _SingleTextTask: | 182 | class _SingleTextTask: |
| 41 | text: str | 183 | text: str |
| 42 | normalize: bool | 184 | normalize: bool |
| 43 | created_at: float | 185 | created_at: float |
| 186 | + request_id: str | ||
| 44 | done: threading.Event | 187 | done: threading.Event |
| 45 | result: Optional[List[float]] = None | 188 | result: Optional[List[float]] = None |
| 46 | error: Optional[Exception] = None | 189 | error: Optional[Exception] = None |
| @@ -50,15 +193,6 @@ _text_single_queue: "deque[_SingleTextTask]" = deque() | @@ -50,15 +193,6 @@ _text_single_queue: "deque[_SingleTextTask]" = deque() | ||
| 50 | _text_single_queue_cv = threading.Condition() | 193 | _text_single_queue_cv = threading.Condition() |
| 51 | _text_batch_worker: Optional[threading.Thread] = None | 194 | _text_batch_worker: Optional[threading.Thread] = None |
| 52 | _text_batch_worker_stop = False | 195 | _text_batch_worker_stop = False |
| 53 | -_TEXT_MICROBATCH_WINDOW_SEC = max( | ||
| 54 | - 0.0, float(os.getenv("TEXT_MICROBATCH_WINDOW_MS", "4")) / 1000.0 | ||
| 55 | -) | ||
| 56 | -_TEXT_REQUEST_TIMEOUT_SEC = max( | ||
| 57 | - 1.0, float(os.getenv("TEXT_REQUEST_TIMEOUT_SEC", "30")) | ||
| 58 | -) | ||
| 59 | -_LOG_PREVIEW_COUNT = max(1, int(os.getenv("EMBEDDING_LOG_PREVIEW_COUNT", "3"))) | ||
| 60 | -_LOG_TEXT_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_TEXT_PREVIEW_CHARS", "120"))) | ||
| 61 | -_LOG_IMAGE_PREVIEW_CHARS = max(32, int(os.getenv("EMBEDDING_LOG_IMAGE_PREVIEW_CHARS", "180"))) | ||
| 62 | 196 | ||
| 63 | 197 | ||
| 64 | def _compact_preview(text: str, max_chars: int) -> str: | 198 | def _compact_preview(text: str, max_chars: int) -> str: |
| @@ -81,6 +215,29 @@ def _preview_inputs(items: List[str], max_items: int, max_chars: int) -> List[Di | @@ -81,6 +215,29 @@ def _preview_inputs(items: List[str], max_items: int, max_chars: int) -> List[Di | ||
| 81 | return previews | 215 | return previews |
| 82 | 216 | ||
| 83 | 217 | ||
| 218 | +def _preview_vector(vec: Optional[List[float]], max_dims: int = _VECTOR_PREVIEW_DIMS) -> List[float]: | ||
| 219 | + if not vec: | ||
| 220 | + return [] | ||
| 221 | + return [round(float(v), 6) for v in vec[:max_dims]] | ||
| 222 | + | ||
| 223 | + | ||
| 224 | +def _request_log_extra(request_id: str) -> Dict[str, str]: | ||
| 225 | + return {"reqid": request_id} | ||
| 226 | + | ||
| 227 | + | ||
| 228 | +def _resolve_request_id(http_request: Request) -> str: | ||
| 229 | + header_value = http_request.headers.get("X-Request-ID") | ||
| 230 | + if header_value and header_value.strip(): | ||
| 231 | + return header_value.strip()[:32] | ||
| 232 | + return str(uuid.uuid4())[:8] | ||
| 233 | + | ||
| 234 | + | ||
| 235 | +def _request_client(http_request: Request) -> str: | ||
| 236 | + client = getattr(http_request, "client", None) | ||
| 237 | + host = getattr(client, "host", None) | ||
| 238 | + return str(host or "-") | ||
| 239 | + | ||
| 240 | + | ||
| 84 | def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any: | 241 | def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any: |
| 85 | with _text_encode_lock: | 242 | with _text_encode_lock: |
| 86 | return _text_model.encode( | 243 | return _text_model.encode( |
| @@ -139,6 +296,21 @@ def _text_batch_worker_loop() -> None: | @@ -139,6 +296,21 @@ def _text_batch_worker_loop() -> None: | ||
| 139 | batch.append(_text_single_queue.popleft()) | 296 | batch.append(_text_single_queue.popleft()) |
| 140 | 297 | ||
| 141 | try: | 298 | try: |
| 299 | + queue_wait_ms = [(time.perf_counter() - task.created_at) * 1000.0 for task in batch] | ||
| 300 | + reqids = [task.request_id for task in batch] | ||
| 301 | + logger.info( | ||
| 302 | + "text microbatch dispatch | size=%d queue_wait_ms_min=%.2f queue_wait_ms_max=%.2f reqids=%s preview=%s", | ||
| 303 | + len(batch), | ||
| 304 | + min(queue_wait_ms) if queue_wait_ms else 0.0, | ||
| 305 | + max(queue_wait_ms) if queue_wait_ms else 0.0, | ||
| 306 | + reqids, | ||
| 307 | + _preview_inputs( | ||
| 308 | + [task.text for task in batch], | ||
| 309 | + _LOG_PREVIEW_COUNT, | ||
| 310 | + _LOG_TEXT_PREVIEW_CHARS, | ||
| 311 | + ), | ||
| 312 | + ) | ||
| 313 | + batch_t0 = time.perf_counter() | ||
| 142 | embs = _encode_local_st([task.text for task in batch], normalize_embeddings=False) | 314 | embs = _encode_local_st([task.text for task in batch], normalize_embeddings=False) |
| 143 | if embs is None or len(embs) != len(batch): | 315 | if embs is None or len(embs) != len(batch): |
| 144 | raise RuntimeError( | 316 | raise RuntimeError( |
| @@ -150,7 +322,21 @@ def _text_batch_worker_loop() -> None: | @@ -150,7 +322,21 @@ def _text_batch_worker_loop() -> None: | ||
| 150 | if vec is None: | 322 | if vec is None: |
| 151 | raise RuntimeError("Text model returned empty embedding in micro-batch") | 323 | raise RuntimeError("Text model returned empty embedding in micro-batch") |
| 152 | task.result = vec | 324 | task.result = vec |
| 325 | + logger.info( | ||
| 326 | + "text microbatch done | size=%d reqids=%s dim=%d backend_elapsed_ms=%.2f", | ||
| 327 | + len(batch), | ||
| 328 | + reqids, | ||
| 329 | + len(batch[0].result) if batch and batch[0].result is not None else 0, | ||
| 330 | + (time.perf_counter() - batch_t0) * 1000.0, | ||
| 331 | + ) | ||
| 153 | except Exception as exc: | 332 | except Exception as exc: |
| 333 | + logger.error( | ||
| 334 | + "text microbatch failed | size=%d reqids=%s error=%s", | ||
| 335 | + len(batch), | ||
| 336 | + [task.request_id for task in batch], | ||
| 337 | + exc, | ||
| 338 | + exc_info=True, | ||
| 339 | + ) | ||
| 154 | for task in batch: | 340 | for task in batch: |
| 155 | task.error = exc | 341 | task.error = exc |
| 156 | finally: | 342 | finally: |
| @@ -158,11 +344,12 @@ def _text_batch_worker_loop() -> None: | @@ -158,11 +344,12 @@ def _text_batch_worker_loop() -> None: | ||
| 158 | task.done.set() | 344 | task.done.set() |
| 159 | 345 | ||
| 160 | 346 | ||
| 161 | -def _encode_single_text_with_microbatch(text: str, normalize: bool) -> List[float]: | 347 | +def _encode_single_text_with_microbatch(text: str, normalize: bool, request_id: str) -> List[float]: |
| 162 | task = _SingleTextTask( | 348 | task = _SingleTextTask( |
| 163 | text=text, | 349 | text=text, |
| 164 | normalize=normalize, | 350 | normalize=normalize, |
| 165 | created_at=time.perf_counter(), | 351 | created_at=time.perf_counter(), |
| 352 | + request_id=request_id, | ||
| 166 | done=threading.Event(), | 353 | done=threading.Event(), |
| 167 | ) | 354 | ) |
| 168 | with _text_single_queue_cv: | 355 | with _text_single_queue_cv: |
| @@ -192,7 +379,6 @@ def load_models(): | @@ -192,7 +379,6 @@ def load_models(): | ||
| 192 | 379 | ||
| 193 | logger.info("Loading embedding models at startup...") | 380 | logger.info("Loading embedding models at startup...") |
| 194 | 381 | ||
| 195 | - # Load text model | ||
| 196 | if open_text_model: | 382 | if open_text_model: |
| 197 | try: | 383 | try: |
| 198 | backend_name, backend_cfg = get_embedding_backend_config() | 384 | backend_name, backend_cfg = get_embedding_backend_config() |
| @@ -233,17 +419,19 @@ def load_models(): | @@ -233,17 +419,19 @@ def load_models(): | ||
| 233 | ) | 419 | ) |
| 234 | logger.info("Text backend loaded successfully: %s", _text_backend_name) | 420 | logger.info("Text backend loaded successfully: %s", _text_backend_name) |
| 235 | except Exception as e: | 421 | except Exception as e: |
| 236 | - logger.error(f"Failed to load text model: {e}", exc_info=True) | 422 | + logger.error("Failed to load text model: %s", e, exc_info=True) |
| 237 | raise | 423 | raise |
| 238 | - | ||
| 239 | 424 | ||
| 240 | - # Load image model: clip-as-service (recommended) or local CN-CLIP | ||
| 241 | if open_image_model: | 425 | if open_image_model: |
| 242 | try: | 426 | try: |
| 243 | if CONFIG.USE_CLIP_AS_SERVICE: | 427 | if CONFIG.USE_CLIP_AS_SERVICE: |
| 244 | from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder | 428 | from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder |
| 245 | 429 | ||
| 246 | - logger.info(f"Loading image encoder via clip-as-service: {CONFIG.CLIP_AS_SERVICE_SERVER}") | 430 | + logger.info( |
| 431 | + "Loading image encoder via clip-as-service: %s (configured model: %s)", | ||
| 432 | + CONFIG.CLIP_AS_SERVICE_SERVER, | ||
| 433 | + CONFIG.CLIP_AS_SERVICE_MODEL_NAME, | ||
| 434 | + ) | ||
| 247 | _image_model = ClipAsServiceImageEncoder( | 435 | _image_model = ClipAsServiceImageEncoder( |
| 248 | server=CONFIG.CLIP_AS_SERVICE_SERVER, | 436 | server=CONFIG.CLIP_AS_SERVICE_SERVER, |
| 249 | batch_size=CONFIG.IMAGE_BATCH_SIZE, | 437 | batch_size=CONFIG.IMAGE_BATCH_SIZE, |
| @@ -252,7 +440,11 @@ def load_models(): | @@ -252,7 +440,11 @@ def load_models(): | ||
| 252 | else: | 440 | else: |
| 253 | from embeddings.clip_model import ClipImageModel | 441 | from embeddings.clip_model import ClipImageModel |
| 254 | 442 | ||
| 255 | - logger.info(f"Loading local image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})") | 443 | + logger.info( |
| 444 | + "Loading local image model: %s (device: %s)", | ||
| 445 | + CONFIG.IMAGE_MODEL_NAME, | ||
| 446 | + CONFIG.IMAGE_DEVICE, | ||
| 447 | + ) | ||
| 256 | _image_model = ClipImageModel( | 448 | _image_model = ClipImageModel( |
| 257 | model_name=CONFIG.IMAGE_MODEL_NAME, | 449 | model_name=CONFIG.IMAGE_MODEL_NAME, |
| 258 | device=CONFIG.IMAGE_DEVICE, | 450 | device=CONFIG.IMAGE_DEVICE, |
| @@ -292,55 +484,56 @@ def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Option | @@ -292,55 +484,56 @@ def _as_list(embedding: Optional[np.ndarray], normalize: bool = False) -> Option | ||
| 292 | 484 | ||
| 293 | @app.get("/health") | 485 | @app.get("/health") |
| 294 | def health() -> Dict[str, Any]: | 486 | def health() -> Dict[str, Any]: |
| 295 | - """Health check endpoint. Returns status and model loading state.""" | 487 | + """Health check endpoint. Returns status and current throttling stats.""" |
| 296 | return { | 488 | return { |
| 297 | "status": "ok", | 489 | "status": "ok", |
| 298 | "text_model_loaded": _text_model is not None, | 490 | "text_model_loaded": _text_model is not None, |
| 299 | "text_backend": _text_backend_name, | 491 | "text_backend": _text_backend_name, |
| 300 | "image_model_loaded": _image_model is not None, | 492 | "image_model_loaded": _image_model is not None, |
| 493 | + "limits": { | ||
| 494 | + "text": _text_request_limiter.snapshot(), | ||
| 495 | + "image": _image_request_limiter.snapshot(), | ||
| 496 | + }, | ||
| 497 | + "text_microbatch": { | ||
| 498 | + "window_ms": round(_TEXT_MICROBATCH_WINDOW_SEC * 1000.0, 3), | ||
| 499 | + "queue_depth": len(_text_single_queue), | ||
| 500 | + "worker_alive": bool(_text_batch_worker is not None and _text_batch_worker.is_alive()), | ||
| 501 | + "request_timeout_sec": _TEXT_REQUEST_TIMEOUT_SEC, | ||
| 502 | + }, | ||
| 301 | } | 503 | } |
| 302 | 504 | ||
| 303 | 505 | ||
| 304 | -@app.post("/embed/text") | ||
| 305 | -def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: | 506 | +def _embed_text_impl( |
| 507 | + normalized: List[str], | ||
| 508 | + effective_normalize: bool, | ||
| 509 | + request_id: str, | ||
| 510 | +) -> List[Optional[List[float]]]: | ||
| 306 | if _text_model is None: | 511 | if _text_model is None: |
| 307 | raise RuntimeError("Text model not loaded") | 512 | raise RuntimeError("Text model not loaded") |
| 308 | - effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 309 | - normalized: List[str] = [] | ||
| 310 | - for i, t in enumerate(texts): | ||
| 311 | - if not isinstance(t, str): | ||
| 312 | - raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string") | ||
| 313 | - s = t.strip() | ||
| 314 | - if not s: | ||
| 315 | - raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") | ||
| 316 | - normalized.append(s) | ||
| 317 | - | ||
| 318 | - logger.info( | ||
| 319 | - "embed_text request | backend=%s inputs=%d normalize=%s preview=%s", | ||
| 320 | - _text_backend_name, | ||
| 321 | - len(normalized), | ||
| 322 | - effective_normalize, | ||
| 323 | - _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), | ||
| 324 | - ) | ||
| 325 | 513 | ||
| 326 | t0 = time.perf_counter() | 514 | t0 = time.perf_counter() |
| 327 | try: | 515 | try: |
| 328 | - # local_st backend uses in-process torch model, keep serialized encode for safety; | ||
| 329 | - # TEI backend is an HTTP client and supports concurrent requests. | ||
| 330 | if _text_backend_name == "local_st": | 516 | if _text_backend_name == "local_st": |
| 331 | if len(normalized) == 1 and _text_batch_worker is not None: | 517 | if len(normalized) == 1 and _text_batch_worker is not None: |
| 332 | - out = [_encode_single_text_with_microbatch(normalized[0], normalize=effective_normalize)] | ||
| 333 | - elapsed_ms = (time.perf_counter() - t0) * 1000.0 | 518 | + out = [ |
| 519 | + _encode_single_text_with_microbatch( | ||
| 520 | + normalized[0], | ||
| 521 | + normalize=effective_normalize, | ||
| 522 | + request_id=request_id, | ||
| 523 | + ) | ||
| 524 | + ] | ||
| 334 | logger.info( | 525 | logger.info( |
| 335 | - "embed_text done | backend=%s mode=microbatch-single inputs=%d normalize=%s dim=%d elapsed_ms=%.2f", | 526 | + "text backend done | backend=%s mode=microbatch-single inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", |
| 336 | _text_backend_name, | 527 | _text_backend_name, |
| 337 | len(normalized), | 528 | len(normalized), |
| 338 | effective_normalize, | 529 | effective_normalize, |
| 339 | len(out[0]) if out and out[0] is not None else 0, | 530 | len(out[0]) if out and out[0] is not None else 0, |
| 340 | - elapsed_ms, | 531 | + (time.perf_counter() - t0) * 1000.0, |
| 532 | + extra=_request_log_extra(request_id), | ||
| 341 | ) | 533 | ) |
| 342 | return out | 534 | return out |
| 343 | embs = _encode_local_st(normalized, normalize_embeddings=False) | 535 | embs = _encode_local_st(normalized, normalize_embeddings=False) |
| 536 | + mode = "direct-batch" | ||
| 344 | else: | 537 | else: |
| 345 | embs = _text_model.encode( | 538 | embs = _text_model.encode( |
| 346 | normalized, | 539 | normalized, |
| @@ -348,55 +541,154 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio | @@ -348,55 +541,154 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio | ||
| 348 | device=CONFIG.TEXT_DEVICE, | 541 | device=CONFIG.TEXT_DEVICE, |
| 349 | normalize_embeddings=effective_normalize, | 542 | normalize_embeddings=effective_normalize, |
| 350 | ) | 543 | ) |
| 544 | + mode = "backend-batch" | ||
| 351 | except Exception as e: | 545 | except Exception as e: |
| 352 | - logger.error("Text embedding backend failure: %s", e, exc_info=True) | ||
| 353 | - raise HTTPException( | ||
| 354 | - status_code=502, | ||
| 355 | - detail=f"Text embedding backend failure: {e}", | ||
| 356 | - ) from e | 546 | + logger.error( |
| 547 | + "Text embedding backend failure: %s", | ||
| 548 | + e, | ||
| 549 | + exc_info=True, | ||
| 550 | + extra=_request_log_extra(request_id), | ||
| 551 | + ) | ||
| 552 | + raise RuntimeError(f"Text embedding backend failure: {e}") from e | ||
| 553 | + | ||
| 357 | if embs is None or len(embs) != len(normalized): | 554 | if embs is None or len(embs) != len(normalized): |
| 358 | raise RuntimeError( | 555 | raise RuntimeError( |
| 359 | f"Text model response length mismatch: expected {len(normalized)}, " | 556 | f"Text model response length mismatch: expected {len(normalized)}, " |
| 360 | f"got {0 if embs is None else len(embs)}" | 557 | f"got {0 if embs is None else len(embs)}" |
| 361 | ) | 558 | ) |
| 559 | + | ||
| 362 | out: List[Optional[List[float]]] = [] | 560 | out: List[Optional[List[float]]] = [] |
| 363 | for i, emb in enumerate(embs): | 561 | for i, emb in enumerate(embs): |
| 364 | vec = _as_list(emb, normalize=effective_normalize) | 562 | vec = _as_list(emb, normalize=effective_normalize) |
| 365 | if vec is None: | 563 | if vec is None: |
| 366 | raise RuntimeError(f"Text model returned empty embedding for index {i}") | 564 | raise RuntimeError(f"Text model returned empty embedding for index {i}") |
| 367 | out.append(vec) | 565 | out.append(vec) |
| 368 | - elapsed_ms = (time.perf_counter() - t0) * 1000.0 | 566 | + |
| 369 | logger.info( | 567 | logger.info( |
| 370 | - "embed_text done | backend=%s inputs=%d normalize=%s dim=%d elapsed_ms=%.2f", | 568 | + "text backend done | backend=%s mode=%s inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", |
| 371 | _text_backend_name, | 569 | _text_backend_name, |
| 570 | + mode, | ||
| 372 | len(normalized), | 571 | len(normalized), |
| 373 | effective_normalize, | 572 | effective_normalize, |
| 374 | len(out[0]) if out and out[0] is not None else 0, | 573 | len(out[0]) if out and out[0] is not None else 0, |
| 375 | - elapsed_ms, | 574 | + (time.perf_counter() - t0) * 1000.0, |
| 575 | + extra=_request_log_extra(request_id), | ||
| 376 | ) | 576 | ) |
| 377 | return out | 577 | return out |
| 378 | 578 | ||
| 379 | 579 | ||
| 380 | -@app.post("/embed/image") | ||
| 381 | -def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Optional[List[float]]]: | ||
| 382 | - if _image_model is None: | ||
| 383 | - raise RuntimeError("Image model not loaded") | ||
| 384 | - effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 385 | - urls: List[str] = [] | ||
| 386 | - for i, url_or_path in enumerate(images): | ||
| 387 | - if not isinstance(url_or_path, str): | ||
| 388 | - raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: must be string URL/path") | ||
| 389 | - s = url_or_path.strip() | 580 | +@app.post("/embed/text") |
| 581 | +async def embed_text( | ||
| 582 | + texts: List[str], | ||
| 583 | + http_request: Request, | ||
| 584 | + response: Response, | ||
| 585 | + normalize: Optional[bool] = None, | ||
| 586 | +) -> List[Optional[List[float]]]: | ||
| 587 | + request_id = _resolve_request_id(http_request) | ||
| 588 | + response.headers["X-Request-ID"] = request_id | ||
| 589 | + | ||
| 590 | + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 591 | + normalized: List[str] = [] | ||
| 592 | + for i, t in enumerate(texts): | ||
| 593 | + if not isinstance(t, str): | ||
| 594 | + raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string") | ||
| 595 | + s = t.strip() | ||
| 390 | if not s: | 596 | if not s: |
| 391 | - raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") | ||
| 392 | - urls.append(s) | 597 | + raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") |
| 598 | + normalized.append(s) | ||
| 393 | 599 | ||
| 394 | - logger.info( | ||
| 395 | - "embed_image request | inputs=%d normalize=%s preview=%s", | ||
| 396 | - len(urls), | ||
| 397 | - effective_normalize, | ||
| 398 | - _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), | ||
| 399 | - ) | 600 | + accepted, active = _text_request_limiter.try_acquire() |
| 601 | + if not accepted: | ||
| 602 | + logger.warning( | ||
| 603 | + "embed_text rejected | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | ||
| 604 | + _request_client(http_request), | ||
| 605 | + _text_backend_name, | ||
| 606 | + len(normalized), | ||
| 607 | + effective_normalize, | ||
| 608 | + active, | ||
| 609 | + _TEXT_MAX_INFLIGHT, | ||
| 610 | + _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), | ||
| 611 | + extra=_request_log_extra(request_id), | ||
| 612 | + ) | ||
| 613 | + raise HTTPException( | ||
| 614 | + status_code=_OVERLOAD_STATUS_CODE, | ||
| 615 | + detail=f"Text embedding service busy: active={active}, limit={_TEXT_MAX_INFLIGHT}", | ||
| 616 | + ) | ||
| 617 | + | ||
| 618 | + request_started = time.perf_counter() | ||
| 619 | + success = False | ||
| 620 | + try: | ||
| 621 | + logger.info( | ||
| 622 | + "embed_text request | client=%s backend=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | ||
| 623 | + _request_client(http_request), | ||
| 624 | + _text_backend_name, | ||
| 625 | + len(normalized), | ||
| 626 | + effective_normalize, | ||
| 627 | + active, | ||
| 628 | + _TEXT_MAX_INFLIGHT, | ||
| 629 | + _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), | ||
| 630 | + extra=_request_log_extra(request_id), | ||
| 631 | + ) | ||
| 632 | + verbose_logger.info( | ||
| 633 | + "embed_text detail | payload=%s normalize=%s backend=%s", | ||
| 634 | + normalized, | ||
| 635 | + effective_normalize, | ||
| 636 | + _text_backend_name, | ||
| 637 | + extra=_request_log_extra(request_id), | ||
| 638 | + ) | ||
| 639 | + out = await run_in_threadpool(_embed_text_impl, normalized, effective_normalize, request_id) | ||
| 640 | + success = True | ||
| 641 | + latency_ms = (time.perf_counter() - request_started) * 1000.0 | ||
| 642 | + logger.info( | ||
| 643 | + "embed_text response | backend=%s inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", | ||
| 644 | + _text_backend_name, | ||
| 645 | + len(normalized), | ||
| 646 | + effective_normalize, | ||
| 647 | + len(out[0]) if out and out[0] is not None else 0, | ||
| 648 | + _preview_vector(out[0] if out else None), | ||
| 649 | + latency_ms, | ||
| 650 | + extra=_request_log_extra(request_id), | ||
| 651 | + ) | ||
| 652 | + verbose_logger.info( | ||
| 653 | + "embed_text result detail | count=%d first_vector=%s latency_ms=%.2f", | ||
| 654 | + len(out), | ||
| 655 | + out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], | ||
| 656 | + latency_ms, | ||
| 657 | + extra=_request_log_extra(request_id), | ||
| 658 | + ) | ||
| 659 | + return out | ||
| 660 | + except HTTPException: | ||
| 661 | + raise | ||
| 662 | + except Exception as e: | ||
| 663 | + latency_ms = (time.perf_counter() - request_started) * 1000.0 | ||
| 664 | + logger.error( | ||
| 665 | + "embed_text failed | backend=%s inputs=%d normalize=%s latency_ms=%.2f error=%s", | ||
| 666 | + _text_backend_name, | ||
| 667 | + len(normalized), | ||
| 668 | + effective_normalize, | ||
| 669 | + latency_ms, | ||
| 670 | + e, | ||
| 671 | + exc_info=True, | ||
| 672 | + extra=_request_log_extra(request_id), | ||
| 673 | + ) | ||
| 674 | + raise HTTPException(status_code=502, detail=str(e)) from e | ||
| 675 | + finally: | ||
| 676 | + remaining = _text_request_limiter.release(success=success) | ||
| 677 | + logger.info( | ||
| 678 | + "embed_text finalize | success=%s active_after=%d", | ||
| 679 | + success, | ||
| 680 | + remaining, | ||
| 681 | + extra=_request_log_extra(request_id), | ||
| 682 | + ) | ||
| 683 | + | ||
| 684 | + | ||
| 685 | +def _embed_image_impl( | ||
| 686 | + urls: List[str], | ||
| 687 | + effective_normalize: bool, | ||
| 688 | + request_id: str, | ||
| 689 | +) -> List[Optional[List[float]]]: | ||
| 690 | + if _image_model is None: | ||
| 691 | + raise RuntimeError("Image model not loaded") | ||
| 400 | 692 | ||
| 401 | t0 = time.perf_counter() | 693 | t0 = time.perf_counter() |
| 402 | with _image_encode_lock: | 694 | with _image_encode_lock: |
| @@ -410,18 +702,120 @@ def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Opt | @@ -410,18 +702,120 @@ def embed_image(images: List[str], normalize: Optional[bool] = None) -> List[Opt | ||
| 410 | f"Image model response length mismatch: expected {len(urls)}, " | 702 | f"Image model response length mismatch: expected {len(urls)}, " |
| 411 | f"got {0 if vectors is None else len(vectors)}" | 703 | f"got {0 if vectors is None else len(vectors)}" |
| 412 | ) | 704 | ) |
| 705 | + | ||
| 413 | out: List[Optional[List[float]]] = [] | 706 | out: List[Optional[List[float]]] = [] |
| 414 | for i, vec in enumerate(vectors): | 707 | for i, vec in enumerate(vectors): |
| 415 | out_vec = _as_list(vec, normalize=effective_normalize) | 708 | out_vec = _as_list(vec, normalize=effective_normalize) |
| 416 | if out_vec is None: | 709 | if out_vec is None: |
| 417 | raise RuntimeError(f"Image model returned empty embedding for index {i}") | 710 | raise RuntimeError(f"Image model returned empty embedding for index {i}") |
| 418 | out.append(out_vec) | 711 | out.append(out_vec) |
| 419 | - elapsed_ms = (time.perf_counter() - t0) * 1000.0 | 712 | + |
| 420 | logger.info( | 713 | logger.info( |
| 421 | - "embed_image done | inputs=%d normalize=%s dim=%d elapsed_ms=%.2f", | 714 | + "image backend done | inputs=%d normalize=%s dim=%d backend_elapsed_ms=%.2f", |
| 422 | len(urls), | 715 | len(urls), |
| 423 | effective_normalize, | 716 | effective_normalize, |
| 424 | len(out[0]) if out and out[0] is not None else 0, | 717 | len(out[0]) if out and out[0] is not None else 0, |
| 425 | - elapsed_ms, | 718 | + (time.perf_counter() - t0) * 1000.0, |
| 719 | + extra=_request_log_extra(request_id), | ||
| 426 | ) | 720 | ) |
| 427 | return out | 721 | return out |
| 722 | + | ||
| 723 | + | ||
| 724 | +@app.post("/embed/image") | ||
| 725 | +async def embed_image( | ||
| 726 | + images: List[str], | ||
| 727 | + http_request: Request, | ||
| 728 | + response: Response, | ||
| 729 | + normalize: Optional[bool] = None, | ||
| 730 | +) -> List[Optional[List[float]]]: | ||
| 731 | + request_id = _resolve_request_id(http_request) | ||
| 732 | + response.headers["X-Request-ID"] = request_id | ||
| 733 | + | ||
| 734 | + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | ||
| 735 | + urls: List[str] = [] | ||
| 736 | + for i, url_or_path in enumerate(images): | ||
| 737 | + if not isinstance(url_or_path, str): | ||
| 738 | + raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: must be string URL/path") | ||
| 739 | + s = url_or_path.strip() | ||
| 740 | + if not s: | ||
| 741 | + raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") | ||
| 742 | + urls.append(s) | ||
| 743 | + | ||
| 744 | + accepted, active = _image_request_limiter.try_acquire() | ||
| 745 | + if not accepted: | ||
| 746 | + logger.warning( | ||
| 747 | + "embed_image rejected | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | ||
| 748 | + _request_client(http_request), | ||
| 749 | + len(urls), | ||
| 750 | + effective_normalize, | ||
| 751 | + active, | ||
| 752 | + _IMAGE_MAX_INFLIGHT, | ||
| 753 | + _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), | ||
| 754 | + extra=_request_log_extra(request_id), | ||
| 755 | + ) | ||
| 756 | + raise HTTPException( | ||
| 757 | + status_code=_OVERLOAD_STATUS_CODE, | ||
| 758 | + detail=f"Image embedding service busy: active={active}, limit={_IMAGE_MAX_INFLIGHT}", | ||
| 759 | + ) | ||
| 760 | + | ||
| 761 | + request_started = time.perf_counter() | ||
| 762 | + success = False | ||
| 763 | + try: | ||
| 764 | + logger.info( | ||
| 765 | + "embed_image request | client=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | ||
| 766 | + _request_client(http_request), | ||
| 767 | + len(urls), | ||
| 768 | + effective_normalize, | ||
| 769 | + active, | ||
| 770 | + _IMAGE_MAX_INFLIGHT, | ||
| 771 | + _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), | ||
| 772 | + extra=_request_log_extra(request_id), | ||
| 773 | + ) | ||
| 774 | + verbose_logger.info( | ||
| 775 | + "embed_image detail | payload=%s normalize=%s", | ||
| 776 | + urls, | ||
| 777 | + effective_normalize, | ||
| 778 | + extra=_request_log_extra(request_id), | ||
| 779 | + ) | ||
| 780 | + out = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) | ||
| 781 | + success = True | ||
| 782 | + latency_ms = (time.perf_counter() - request_started) * 1000.0 | ||
| 783 | + logger.info( | ||
| 784 | + "embed_image response | inputs=%d normalize=%s dim=%d first_vector=%s latency_ms=%.2f", | ||
| 785 | + len(urls), | ||
| 786 | + effective_normalize, | ||
| 787 | + len(out[0]) if out and out[0] is not None else 0, | ||
| 788 | + _preview_vector(out[0] if out else None), | ||
| 789 | + latency_ms, | ||
| 790 | + extra=_request_log_extra(request_id), | ||
| 791 | + ) | ||
| 792 | + verbose_logger.info( | ||
| 793 | + "embed_image result detail | count=%d first_vector=%s latency_ms=%.2f", | ||
| 794 | + len(out), | ||
| 795 | + out[0][: _VECTOR_PREVIEW_DIMS] if out and out[0] is not None else [], | ||
| 796 | + latency_ms, | ||
| 797 | + extra=_request_log_extra(request_id), | ||
| 798 | + ) | ||
| 799 | + return out | ||
| 800 | + except HTTPException: | ||
| 801 | + raise | ||
| 802 | + except Exception as e: | ||
| 803 | + latency_ms = (time.perf_counter() - request_started) * 1000.0 | ||
| 804 | + logger.error( | ||
| 805 | + "embed_image failed | inputs=%d normalize=%s latency_ms=%.2f error=%s", | ||
| 806 | + len(urls), | ||
| 807 | + effective_normalize, | ||
| 808 | + latency_ms, | ||
| 809 | + e, | ||
| 810 | + exc_info=True, | ||
| 811 | + extra=_request_log_extra(request_id), | ||
| 812 | + ) | ||
| 813 | + raise HTTPException(status_code=502, detail=f"Image embedding backend failure: {e}") from e | ||
| 814 | + finally: | ||
| 815 | + remaining = _image_request_limiter.release(success=success) | ||
| 816 | + logger.info( | ||
| 817 | + "embed_image finalize | success=%s active_after=%d", | ||
| 818 | + success, | ||
| 819 | + remaining, | ||
| 820 | + extra=_request_log_extra(request_id), | ||
| 821 | + ) |
perf_reports/20260319/nllb_t4_longtext_reassessment.md
0 → 100644
| @@ -0,0 +1,97 @@ | @@ -0,0 +1,97 @@ | ||
| 1 | +# NLLB T4 Long-Text Reassessment | ||
| 2 | + | ||
| 3 | +Date: 2026-03-19 | ||
| 4 | +Model: `nllb-200-distilled-600m` | ||
| 5 | +Backend: `CTranslate2 + float16` | ||
| 6 | +Direction: `zh -> en` | ||
| 7 | + | ||
| 8 | +## Why This Reassessment Exists | ||
| 9 | + | ||
| 10 | +Earlier notes mixed two different ideas: | ||
| 11 | + | ||
| 12 | +- `batch_size=64` was the highest-throughput point in the original product-title sweeps. | ||
| 13 | +- `batch_size=16` was only a more conservative default candidate when trying to balance throughput with tail latency for online use. | ||
| 14 | + | ||
| 15 | +That distinction was not carried forward clearly enough. We re-checked the current long-text segmented workload instead of reusing the product-title conclusion mechanically. | ||
| 16 | + | ||
| 17 | +## Current Long-Text Workload Observed in Logs | ||
| 18 | + | ||
| 19 | +The clearest apples-to-apples evidence came from repeated uncached requests of the same long Chinese input: | ||
| 20 | + | ||
| 21 | +- input length: about `3944` to `3966` chars | ||
| 22 | +- segmented into `60` pieces | ||
| 23 | +- target language: `en` | ||
| 24 | +- source language: `zh` | ||
| 25 | + | ||
| 26 | +### Log-Derived Comparison | ||
| 27 | + | ||
| 28 | +`batch_size=16` samples from [`logs/translator-2026-03-19.log`](/data/saas-search/logs/translator-2026-03-19.log): | ||
| 29 | + | ||
| 30 | +- `reqid=181f00ae` -> `1586.87 ms` | ||
| 31 | +- `reqid=d6c1213f` -> `1732.95 ms` | ||
| 32 | +- `reqid=26f8acd1` -> `4745.32 ms` | ||
| 33 | + | ||
| 34 | +`batch_size=64` samples from the same log: | ||
| 35 | + | ||
| 36 | +- `reqid=28262f1e` -> `752.96 ms` | ||
| 37 | +- `reqid=737fc848` -> `815.66 ms` | ||
| 38 | +- `reqid=8d05fa20` -> `835.25 ms` | ||
| 39 | +- `reqid=e29d2629` -> `3927.87 ms` | ||
| 40 | +- `reqid=c2b1df14` -> `4049.31 ms` | ||
| 41 | + | ||
| 42 | +### Summary | ||
| 43 | + | ||
| 44 | +For this `~3950 char / 60 segment` workload: | ||
| 45 | + | ||
| 46 | +- `batch_size=16` | ||
| 47 | + - median end-to-end latency: `1732.95 ms` | ||
| 48 | + - median `segmentation_summary -> response`: `1672 ms` | ||
| 49 | +- `batch_size=64` | ||
| 50 | + - median end-to-end latency: `835.25 ms` | ||
| 51 | + - median `segmentation_summary -> response`: `782 ms` | ||
| 52 | + | ||
| 53 | +This means the steady-state inference portion was cut by about half after moving from `16` to `64`. | ||
| 54 | + | ||
| 55 | +## Important Environment Finding | ||
| 56 | + | ||
| 57 | +This machine was not in an isolated benchmark state while re-checking: | ||
| 58 | + | ||
| 59 | +- the single T4 was shared with translator, embedding, CN-CLIP, and reranker processes | ||
| 60 | +- `nvidia-smi` showed about `15157 / 16384 MiB` in use during the re-check | ||
| 61 | + | ||
| 62 | +That explains the multi-second outliers in both the `16` and `64` groups. The outliers mainly appeared before the segmentation summary log, so they should be treated as shared-GPU contention noise, not pure model execution time. | ||
| 63 | + | ||
| 64 | +## Current Config Drift | ||
| 65 | + | ||
| 66 | +During this review, the live config had already been moved again to `batch_size=256`. | ||
| 67 | + | ||
| 68 | +That larger value is not yet backed by the same quality of evidence: | ||
| 69 | + | ||
| 70 | +- for `60` segments, `256` cannot improve on `64` in any meaningful way because both already fit the whole request into one inference batch | ||
| 71 | +- for much larger requests such as `11847` chars and `180` segments, `256` may help, but we do not yet have a clean isolated comparison against `64` | ||
| 72 | +- on a shared T4, larger batches also reduce memory headroom and make benchmarking less stable | ||
| 73 | + | ||
| 74 | +## Recommendation | ||
| 75 | + | ||
| 76 | +For the current shared-T4 deployment, keep the general NLLB default at: | ||
| 77 | + | ||
| 78 | +- `batch_size=64` | ||
| 79 | +- `ct2_inter_threads=4` | ||
| 80 | +- `ct2_max_queued_batches=32` | ||
| 81 | +- `ct2_batch_type=examples` | ||
| 82 | +- `max_new_tokens=64` | ||
| 83 | +- `ct2_decoding_length_mode=source` | ||
| 84 | +- `ct2_decoding_length_extra=8` | ||
| 85 | +- `ct2_decoding_length_min=32` | ||
| 86 | + | ||
| 87 | +Treat `batch_size=128` or `256` as workload-specific experiments, not as the default baseline. | ||
| 88 | + | ||
| 89 | +## Best Practices Going Forward | ||
| 90 | + | ||
| 91 | +- Benchmark long-text segmented translation separately from product-title translation. | ||
| 92 | +- Use uncached repeated requests with the same long sample when checking single-request latency. | ||
| 93 | +- Split latency analysis into: | ||
| 94 | + - `request -> segmentation summary` | ||
| 95 | + - `segmentation summary -> response` | ||
| 96 | +- Do not treat shared-GPU results as a clean config ranking. | ||
| 97 | +- Before promoting a larger batch like `128` or `256` to default, re-run in a translator-only GPU window. |
| @@ -0,0 +1,186 @@ | @@ -0,0 +1,186 @@ | ||
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +"""Benchmark a single long-text translation request for local models.""" | ||
| 3 | + | ||
| 4 | +from __future__ import annotations | ||
| 5 | + | ||
| 6 | +import argparse | ||
| 7 | +import copy | ||
| 8 | +import json | ||
| 9 | +import logging | ||
| 10 | +import statistics | ||
| 11 | +import time | ||
| 12 | +from pathlib import Path | ||
| 13 | + | ||
| 14 | +import torch | ||
| 15 | + | ||
| 16 | +PROJECT_ROOT = Path(__file__).resolve().parent.parent | ||
| 17 | + | ||
| 18 | +import sys | ||
| 19 | + | ||
| 20 | +if str(PROJECT_ROOT) not in sys.path: | ||
| 21 | + sys.path.insert(0, str(PROJECT_ROOT)) | ||
| 22 | + | ||
| 23 | +from config.services_config import get_translation_config # noqa: E402 | ||
| 24 | +from translation.service import TranslationService # noqa: E402 | ||
| 25 | +from translation.text_splitter import compute_safe_input_token_limit # noqa: E402 | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +def parse_args() -> argparse.Namespace: | ||
| 29 | + parser = argparse.ArgumentParser(description="Benchmark a long-text translation request") | ||
| 30 | + parser.add_argument("--model", default="nllb-200-distilled-600m") | ||
| 31 | + parser.add_argument("--source-lang", default="zh") | ||
| 32 | + parser.add_argument("--target-lang", default="en") | ||
| 33 | + parser.add_argument("--scene", default="sku_name") | ||
| 34 | + parser.add_argument("--source-md", default="docs/搜索API对接指南.md") | ||
| 35 | + parser.add_argument("--paragraph-min-chars", type=int, default=250) | ||
| 36 | + parser.add_argument("--target-doc-chars", type=int, default=4500) | ||
| 37 | + parser.add_argument("--min-doc-chars", type=int, default=2400) | ||
| 38 | + parser.add_argument("--runs", type=int, default=3) | ||
| 39 | + parser.add_argument("--batch-size", type=int, default=64) | ||
| 40 | + parser.add_argument("--ct2-inter-threads", type=int, default=4) | ||
| 41 | + parser.add_argument("--ct2-max-queued-batches", type=int, default=32) | ||
| 42 | + parser.add_argument("--ct2-batch-type", default="examples") | ||
| 43 | + parser.add_argument("--max-new-tokens", type=int, default=64) | ||
| 44 | + parser.add_argument("--ct2-decoding-length-mode", default="source") | ||
| 45 | + parser.add_argument("--ct2-decoding-length-extra", type=int, default=8) | ||
| 46 | + parser.add_argument("--ct2-decoding-length-min", type=int, default=32) | ||
| 47 | + return parser.parse_args() | ||
| 48 | + | ||
| 49 | + | ||
| 50 | +def build_long_document(args: argparse.Namespace) -> str: | ||
| 51 | + source_path = (PROJECT_ROOT / args.source_md).resolve() | ||
| 52 | + text = source_path.read_text(encoding="utf-8") | ||
| 53 | + paragraphs = [] | ||
| 54 | + for raw in text.split("\n\n"): | ||
| 55 | + normalized = " ".join(line.strip() for line in raw.splitlines() if line.strip()) | ||
| 56 | + if len(normalized) >= args.paragraph_min_chars and not normalized.startswith("```"): | ||
| 57 | + paragraphs.append(normalized) | ||
| 58 | + | ||
| 59 | + parts = [] | ||
| 60 | + total = 0 | ||
| 61 | + for paragraph in paragraphs: | ||
| 62 | + parts.append(paragraph) | ||
| 63 | + total += len(paragraph) + 2 | ||
| 64 | + if total >= args.target_doc_chars: | ||
| 65 | + break | ||
| 66 | + document = "\n\n".join(parts) | ||
| 67 | + if len(document) < args.min_doc_chars: | ||
| 68 | + raise ValueError( | ||
| 69 | + f"Prepared long document is too short: {len(document)} chars < {args.min_doc_chars}" | ||
| 70 | + ) | ||
| 71 | + return document | ||
| 72 | + | ||
| 73 | + | ||
| 74 | +def build_service(args: argparse.Namespace) -> TranslationService: | ||
| 75 | + config = copy.deepcopy(get_translation_config()) | ||
| 76 | + for name, capability in config["capabilities"].items(): | ||
| 77 | + capability["enabled"] = name == args.model | ||
| 78 | + | ||
| 79 | + capability = config["capabilities"][args.model] | ||
| 80 | + capability["use_cache"] = False | ||
| 81 | + capability["batch_size"] = args.batch_size | ||
| 82 | + capability["ct2_inter_threads"] = args.ct2_inter_threads | ||
| 83 | + capability["ct2_max_queued_batches"] = args.ct2_max_queued_batches | ||
| 84 | + capability["ct2_batch_type"] = args.ct2_batch_type | ||
| 85 | + capability["max_new_tokens"] = args.max_new_tokens | ||
| 86 | + capability["ct2_decoding_length_mode"] = args.ct2_decoding_length_mode | ||
| 87 | + capability["ct2_decoding_length_extra"] = args.ct2_decoding_length_extra | ||
| 88 | + capability["ct2_decoding_length_min"] = args.ct2_decoding_length_min | ||
| 89 | + config["default_model"] = args.model | ||
| 90 | + return TranslationService(config) | ||
| 91 | + | ||
| 92 | + | ||
| 93 | +def percentile(values: list[float], p: float) -> float: | ||
| 94 | + if not values: | ||
| 95 | + return 0.0 | ||
| 96 | + ordered = sorted(values) | ||
| 97 | + if len(ordered) == 1: | ||
| 98 | + return float(ordered[0]) | ||
| 99 | + index = min(len(ordered) - 1, max(0, round((len(ordered) - 1) * p))) | ||
| 100 | + return float(ordered[index]) | ||
| 101 | + | ||
| 102 | + | ||
| 103 | +def main() -> None: | ||
| 104 | + args = parse_args() | ||
| 105 | + logging.getLogger().setLevel(logging.WARNING) | ||
| 106 | + | ||
| 107 | + document = build_long_document(args) | ||
| 108 | + load_started = time.perf_counter() | ||
| 109 | + service = build_service(args) | ||
| 110 | + backend = service.get_backend(args.model) | ||
| 111 | + load_seconds = time.perf_counter() - load_started | ||
| 112 | + | ||
| 113 | + safe_input_limit = compute_safe_input_token_limit( | ||
| 114 | + max_input_length=backend.max_input_length, | ||
| 115 | + max_new_tokens=backend.max_new_tokens, | ||
| 116 | + decoding_length_mode=backend.ct2_decoding_length_mode, | ||
| 117 | + decoding_length_extra=backend.ct2_decoding_length_extra, | ||
| 118 | + ) | ||
| 119 | + segments = backend._split_text_if_needed( | ||
| 120 | + document, | ||
| 121 | + target_lang=args.target_lang, | ||
| 122 | + source_lang=args.source_lang, | ||
| 123 | + ) | ||
| 124 | + | ||
| 125 | + # Warm up once before measurements. | ||
| 126 | + _ = service.translate( | ||
| 127 | + document, | ||
| 128 | + source_lang=args.source_lang, | ||
| 129 | + target_lang=args.target_lang, | ||
| 130 | + model=args.model, | ||
| 131 | + scene=args.scene, | ||
| 132 | + ) | ||
| 133 | + if torch.cuda.is_available(): | ||
| 134 | + torch.cuda.synchronize() | ||
| 135 | + | ||
| 136 | + latencies_ms: list[float] = [] | ||
| 137 | + output_chars = 0 | ||
| 138 | + for _ in range(args.runs): | ||
| 139 | + started = time.perf_counter() | ||
| 140 | + output = service.translate( | ||
| 141 | + document, | ||
| 142 | + source_lang=args.source_lang, | ||
| 143 | + target_lang=args.target_lang, | ||
| 144 | + model=args.model, | ||
| 145 | + scene=args.scene, | ||
| 146 | + ) | ||
| 147 | + if torch.cuda.is_available(): | ||
| 148 | + torch.cuda.synchronize() | ||
| 149 | + latencies_ms.append((time.perf_counter() - started) * 1000) | ||
| 150 | + output_chars += len(output or "") | ||
| 151 | + | ||
| 152 | + total_seconds = sum(latencies_ms) / 1000.0 | ||
| 153 | + payload = { | ||
| 154 | + "model": args.model, | ||
| 155 | + "source_lang": args.source_lang, | ||
| 156 | + "target_lang": args.target_lang, | ||
| 157 | + "doc_chars": len(document), | ||
| 158 | + "runs": args.runs, | ||
| 159 | + "load_seconds": round(load_seconds, 3), | ||
| 160 | + "batch_size": backend.batch_size, | ||
| 161 | + "ct2_inter_threads": backend.ct2_inter_threads, | ||
| 162 | + "ct2_max_queued_batches": backend.ct2_max_queued_batches, | ||
| 163 | + "ct2_batch_type": backend.ct2_batch_type, | ||
| 164 | + "max_new_tokens": backend.max_new_tokens, | ||
| 165 | + "ct2_decoding_length_mode": backend.ct2_decoding_length_mode, | ||
| 166 | + "ct2_decoding_length_extra": backend.ct2_decoding_length_extra, | ||
| 167 | + "ct2_decoding_length_min": backend.ct2_decoding_length_min, | ||
| 168 | + "safe_input_limit": safe_input_limit, | ||
| 169 | + "segment_count": len(segments), | ||
| 170 | + "segment_char_lengths": { | ||
| 171 | + "min": min(len(segment) for segment in segments), | ||
| 172 | + "max": max(len(segment) for segment in segments), | ||
| 173 | + "avg": round(statistics.fmean(len(segment) for segment in segments), 1), | ||
| 174 | + }, | ||
| 175 | + "latency_avg_ms": round(statistics.fmean(latencies_ms), 2), | ||
| 176 | + "latency_p50_ms": round(percentile(latencies_ms, 0.50), 2), | ||
| 177 | + "latency_p95_ms": round(percentile(latencies_ms, 0.95), 2), | ||
| 178 | + "latency_max_ms": round(max(latencies_ms), 2), | ||
| 179 | + "input_chars_per_second": round((len(document) * args.runs) / total_seconds, 2), | ||
| 180 | + "output_chars_per_second": round(output_chars / total_seconds, 2), | ||
| 181 | + } | ||
| 182 | + print(json.dumps(payload, ensure_ascii=False)) | ||
| 183 | + | ||
| 184 | + | ||
| 185 | +if __name__ == "__main__": | ||
| 186 | + main() |
scripts/start_cnclip_service.sh
| @@ -12,7 +12,7 @@ | @@ -12,7 +12,7 @@ | ||
| 12 | # 选项: | 12 | # 选项: |
| 13 | # --port PORT 服务端口(默认:51000) | 13 | # --port PORT 服务端口(默认:51000) |
| 14 | # --device DEVICE 设备类型:cuda 或 cpu(默认:cuda) | 14 | # --device DEVICE 设备类型:cuda 或 cpu(默认:cuda) |
| 15 | -# --model-name NAME 模型名称(默认:CN-CLIP/ViT-H-14) | 15 | +# --model-name NAME 模型名称(默认读取 embeddings/config.py) |
| 16 | # --replicas NUM 副本数(默认:1) | 16 | # --replicas NUM 副本数(默认:1) |
| 17 | # --help 显示帮助信息 | 17 | # --help 显示帮助信息 |
| 18 | # | 18 | # |
| @@ -31,15 +31,31 @@ YELLOW='\033[1;33m' | @@ -31,15 +31,31 @@ YELLOW='\033[1;33m' | ||
| 31 | BLUE='\033[0;34m' | 31 | BLUE='\033[0;34m' |
| 32 | NC='\033[0m' # No Color | 32 | NC='\033[0m' # No Color |
| 33 | 33 | ||
| 34 | +# 项目路径(以仓库实际路径为准,避免写死 /data/tw/...) | ||
| 35 | +PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" | ||
| 36 | + | ||
| 37 | +resolve_default_model_name() { | ||
| 38 | + local python_bin | ||
| 39 | + local resolved_model_name | ||
| 40 | + for python_bin in python3 python; do | ||
| 41 | + if command -v "${python_bin}" >/dev/null 2>&1; then | ||
| 42 | + if resolved_model_name="$(PYTHONPATH="${PROJECT_ROOT}${PYTHONPATH:+:${PYTHONPATH}}" "${python_bin}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_MODEL_NAME)" 2>/dev/null)"; then | ||
| 43 | + if [ -n "${resolved_model_name}" ]; then | ||
| 44 | + echo "${resolved_model_name}" | ||
| 45 | + return 0 | ||
| 46 | + fi | ||
| 47 | + fi | ||
| 48 | + fi | ||
| 49 | + done | ||
| 50 | + echo "CN-CLIP/ViT-L-14" | ||
| 51 | +} | ||
| 52 | + | ||
| 34 | # 默认配置 | 53 | # 默认配置 |
| 35 | DEFAULT_PORT=51000 | 54 | DEFAULT_PORT=51000 |
| 36 | DEFAULT_DEVICE="cuda" | 55 | DEFAULT_DEVICE="cuda" |
| 37 | -DEFAULT_MODEL_NAME="CN-CLIP/ViT-H-14" | ||
| 38 | -# DEFAULT_MODEL_NAME="CN-CLIP/ViT-L-14-336" | 56 | +DEFAULT_MODEL_NAME="$(resolve_default_model_name)" |
| 39 | DEFAULT_REPLICAS=1 # 副本数 | 57 | DEFAULT_REPLICAS=1 # 副本数 |
| 40 | 58 | ||
| 41 | -# 项目路径(以仓库实际路径为准,避免写死 /data/tw/...) | ||
| 42 | -PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)" | ||
| 43 | CLIP_SERVER_DIR="${PROJECT_ROOT}/third-party/clip-as-service/server" | 59 | CLIP_SERVER_DIR="${PROJECT_ROOT}/third-party/clip-as-service/server" |
| 44 | LOG_DIR="${PROJECT_ROOT}/logs" | 60 | LOG_DIR="${PROJECT_ROOT}/logs" |
| 45 | PID_FILE="${LOG_DIR}/cnclip.pid" | 61 | PID_FILE="${LOG_DIR}/cnclip.pid" |
| @@ -64,20 +80,37 @@ show_help() { | @@ -64,20 +80,37 @@ show_help() { | ||
| 64 | echo " $0 # 使用默认配置启动" | 80 | echo " $0 # 使用默认配置启动" |
| 65 | echo " $0 --port 52000 --device cuda # 指定 CUDA 模式,端口 52000" | 81 | echo " $0 --port 52000 --device cuda # 指定 CUDA 模式,端口 52000" |
| 66 | echo " $0 --port 52000 --device cpu # 显式使用 CPU 模式" | 82 | echo " $0 --port 52000 --device cpu # 显式使用 CPU 模式" |
| 83 | + echo " $0 --model-name CN-CLIP/ViT-L-14 # 临时覆盖模型" | ||
| 67 | echo " $0 --replicas 2 # 启动2个副本(需8-10GB显存)" | 84 | echo " $0 --replicas 2 # 启动2个副本(需8-10GB显存)" |
| 68 | echo "" | 85 | echo "" |
| 86 | + echo "说明:" | ||
| 87 | + echo " - 默认模型取自 embeddings/config.py 的 CLIP_AS_SERVICE_MODEL_NAME" | ||
| 88 | + echo " - 也可通过环境变量 CNCLIP_MODEL_NAME 覆盖,再由 --model-name 最终覆盖" | ||
| 89 | + echo "" | ||
| 69 | echo "支持的模型:" | 90 | echo "支持的模型:" |
| 70 | - echo " - CN-CLIP/ViT-B-16 基础版本,速度快" | ||
| 71 | - echo " - CN-CLIP/ViT-L-14 平衡版本" | ||
| 72 | - echo " - CN-CLIP/ViT-L-14-336 高分辨率版本" | ||
| 73 | - echo " - CN-CLIP/ViT-H-14 大型版本,精度高(默认)" | ||
| 74 | - echo " - CN-CLIP/RN50 ResNet-50 版本" | 91 | + local supported_models=( |
| 92 | + "CN-CLIP/ViT-B-16|基础版本,速度快" | ||
| 93 | + "CN-CLIP/ViT-L-14|平衡版本" | ||
| 94 | + "CN-CLIP/ViT-L-14-336|高分辨率版本" | ||
| 95 | + "CN-CLIP/ViT-H-14|大型版本,精度高" | ||
| 96 | + "CN-CLIP/RN50|ResNet-50 版本" | ||
| 97 | + ) | ||
| 98 | + local item model desc suffix | ||
| 99 | + for item in "${supported_models[@]}"; do | ||
| 100 | + model="${item%%|*}" | ||
| 101 | + desc="${item#*|}" | ||
| 102 | + suffix="" | ||
| 103 | + if [ "${model}" = "${DEFAULT_MODEL_NAME}" ]; then | ||
| 104 | + suffix="(当前默认)" | ||
| 105 | + fi | ||
| 106 | + echo " - ${model} ${desc}${suffix}" | ||
| 107 | + done | ||
| 75 | } | 108 | } |
| 76 | 109 | ||
| 77 | # 解析命令行参数 | 110 | # 解析命令行参数 |
| 78 | PORT="${CNCLIP_PORT:-${DEFAULT_PORT}}" | 111 | PORT="${CNCLIP_PORT:-${DEFAULT_PORT}}" |
| 79 | DEVICE=${DEFAULT_DEVICE} | 112 | DEVICE=${DEFAULT_DEVICE} |
| 80 | -MODEL_NAME=${DEFAULT_MODEL_NAME} | 113 | +MODEL_NAME="${CNCLIP_MODEL_NAME:-${DEFAULT_MODEL_NAME}}" |
| 81 | REPLICAS=${DEFAULT_REPLICAS} | 114 | REPLICAS=${DEFAULT_REPLICAS} |
| 82 | 115 | ||
| 83 | while [[ $# -gt 0 ]]; do | 116 | while [[ $# -gt 0 ]]; do |
scripts/start_embedding_service.sh
| @@ -30,6 +30,7 @@ DEFAULT_EMBEDDING_SERVICE_HOST=$("${PYTHON_BIN}" -c "from embeddings.config impo | @@ -30,6 +30,7 @@ DEFAULT_EMBEDDING_SERVICE_HOST=$("${PYTHON_BIN}" -c "from embeddings.config impo | ||
| 30 | DEFAULT_EMBEDDING_SERVICE_PORT=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.PORT)") | 30 | DEFAULT_EMBEDDING_SERVICE_PORT=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.PORT)") |
| 31 | USE_CLIP_AS_SERVICE=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print('1' if CONFIG.USE_CLIP_AS_SERVICE else '0')") | 31 | USE_CLIP_AS_SERVICE=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print('1' if CONFIG.USE_CLIP_AS_SERVICE else '0')") |
| 32 | CLIP_AS_SERVICE_SERVER=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_SERVER)") | 32 | CLIP_AS_SERVICE_SERVER=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_SERVER)") |
| 33 | +CLIP_AS_SERVICE_MODEL_NAME=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_MODEL_NAME)") | ||
| 33 | TEXT_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_embedding_backend_config; print(get_embedding_backend_config()[0])") | 34 | TEXT_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_embedding_backend_config; print(get_embedding_backend_config()[0])") |
| 34 | TEI_BASE_URL=$("${PYTHON_BIN}" -c "import os; from config.services_config import get_embedding_backend_config; from embeddings.config import CONFIG; _, cfg = get_embedding_backend_config(); print(os.getenv('TEI_BASE_URL') or cfg.get('base_url') or CONFIG.TEI_BASE_URL)") | 35 | TEI_BASE_URL=$("${PYTHON_BIN}" -c "import os; from config.services_config import get_embedding_backend_config; from embeddings.config import CONFIG; _, cfg = get_embedding_backend_config(); print(os.getenv('TEI_BASE_URL') or cfg.get('base_url') or CONFIG.TEI_BASE_URL)") |
| 35 | ENABLE_IMAGE_MODEL="${EMBEDDING_ENABLE_IMAGE_MODEL:-true}" | 36 | ENABLE_IMAGE_MODEL="${EMBEDDING_ENABLE_IMAGE_MODEL:-true}" |
| @@ -84,14 +85,17 @@ echo "Python: ${PYTHON_BIN}" | @@ -84,14 +85,17 @@ echo "Python: ${PYTHON_BIN}" | ||
| 84 | echo "Host: ${EMBEDDING_SERVICE_HOST}" | 85 | echo "Host: ${EMBEDDING_SERVICE_HOST}" |
| 85 | echo "Port: ${EMBEDDING_SERVICE_PORT}" | 86 | echo "Port: ${EMBEDDING_SERVICE_PORT}" |
| 86 | echo "Text backend: ${TEXT_BACKEND}" | 87 | echo "Text backend: ${TEXT_BACKEND}" |
| 88 | +echo "Text max inflight: ${TEXT_MAX_INFLIGHT:-32}" | ||
| 87 | if [[ "${TEXT_BACKEND}" == "tei" ]]; then | 89 | if [[ "${TEXT_BACKEND}" == "tei" ]]; then |
| 88 | echo "TEI URL: ${TEI_BASE_URL}" | 90 | echo "TEI URL: ${TEI_BASE_URL}" |
| 89 | fi | 91 | fi |
| 90 | if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then | 92 | if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then |
| 91 | echo "Image backend: disabled" | 93 | echo "Image backend: disabled" |
| 92 | elif [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then | 94 | elif [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then |
| 93 | - echo "Image backend: clip-as-service (${CLIP_AS_SERVICE_SERVER})" | 95 | + echo "Image backend: clip-as-service (${CLIP_AS_SERVICE_SERVER}, model=${CLIP_AS_SERVICE_MODEL_NAME})" |
| 94 | fi | 96 | fi |
| 97 | +echo "Image max inflight: ${IMAGE_MAX_INFLIGHT:-1}" | ||
| 98 | +echo "Logs: logs/embedding_api.log, logs/embedding_api_error.log, logs/verbose/embedding_verbose.log" | ||
| 95 | echo | 99 | echo |
| 96 | echo "Tips:" | 100 | echo "Tips:" |
| 97 | echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." | 101 | echo " - Use a single worker (GPU models cannot be safely duplicated across workers)." |
| @@ -0,0 +1,93 @@ | @@ -0,0 +1,93 @@ | ||
| 1 | +import asyncio | ||
| 2 | + | ||
| 3 | +import numpy as np | ||
| 4 | +import pytest | ||
| 5 | + | ||
| 6 | +import embeddings.server as embedding_server | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +class _DummyClient: | ||
| 10 | + host = "127.0.0.1" | ||
| 11 | + | ||
| 12 | + | ||
| 13 | +class _DummyRequest: | ||
| 14 | + def __init__(self, headers=None): | ||
| 15 | + self.headers = headers or {} | ||
| 16 | + self.client = _DummyClient() | ||
| 17 | + | ||
| 18 | + | ||
| 19 | +class _DummyResponse: | ||
| 20 | + def __init__(self): | ||
| 21 | + self.headers = {} | ||
| 22 | + | ||
| 23 | + | ||
| 24 | +class _FakeTextModel: | ||
| 25 | + def encode(self, texts, batch_size, device, normalize_embeddings): | ||
| 26 | + assert texts == ["hello world"] | ||
| 27 | + assert normalize_embeddings is False | ||
| 28 | + return [np.array([1.0, 2.0, 3.0], dtype=np.float32)] | ||
| 29 | + | ||
| 30 | + | ||
| 31 | +def test_health_exposes_limit_stats(monkeypatch): | ||
| 32 | + monkeypatch.setattr( | ||
| 33 | + embedding_server, | ||
| 34 | + "_text_request_limiter", | ||
| 35 | + embedding_server._InflightLimiter("text", 2), | ||
| 36 | + ) | ||
| 37 | + monkeypatch.setattr( | ||
| 38 | + embedding_server, | ||
| 39 | + "_image_request_limiter", | ||
| 40 | + embedding_server._InflightLimiter("image", 1), | ||
| 41 | + ) | ||
| 42 | + | ||
| 43 | + payload = embedding_server.health() | ||
| 44 | + | ||
| 45 | + assert payload["status"] == "ok" | ||
| 46 | + assert payload["limits"]["text"]["limit"] == 2 | ||
| 47 | + assert payload["limits"]["image"]["limit"] == 1 | ||
| 48 | + assert "queue_depth" in payload["text_microbatch"] | ||
| 49 | + | ||
| 50 | + | ||
| 51 | +def test_embed_image_rejects_when_image_lane_is_full(monkeypatch): | ||
| 52 | + limiter = embedding_server._InflightLimiter("image", 1) | ||
| 53 | + acquired, _ = limiter.try_acquire() | ||
| 54 | + assert acquired is True | ||
| 55 | + monkeypatch.setattr(embedding_server, "_image_request_limiter", limiter) | ||
| 56 | + | ||
| 57 | + response = _DummyResponse() | ||
| 58 | + with pytest.raises(embedding_server.HTTPException) as exc_info: | ||
| 59 | + asyncio.run( | ||
| 60 | + embedding_server.embed_image( | ||
| 61 | + ["https://example.com/a.jpg"], | ||
| 62 | + _DummyRequest(), | ||
| 63 | + response, | ||
| 64 | + ) | ||
| 65 | + ) | ||
| 66 | + | ||
| 67 | + assert exc_info.value.status_code == embedding_server._OVERLOAD_STATUS_CODE | ||
| 68 | + assert "busy" in exc_info.value.detail | ||
| 69 | + assert limiter.snapshot()["rejected_total"] == 1 | ||
| 70 | + | ||
| 71 | + | ||
| 72 | +def test_embed_text_returns_request_id_and_vector(monkeypatch): | ||
| 73 | + monkeypatch.setattr( | ||
| 74 | + embedding_server, | ||
| 75 | + "_text_request_limiter", | ||
| 76 | + embedding_server._InflightLimiter("text", 2), | ||
| 77 | + ) | ||
| 78 | + monkeypatch.setattr(embedding_server, "_text_model", _FakeTextModel()) | ||
| 79 | + monkeypatch.setattr(embedding_server, "_text_backend_name", "tei") | ||
| 80 | + | ||
| 81 | + request = _DummyRequest(headers={"X-Request-ID": "req-123456"}) | ||
| 82 | + response = _DummyResponse() | ||
| 83 | + result = asyncio.run( | ||
| 84 | + embedding_server.embed_text( | ||
| 85 | + ["hello world"], | ||
| 86 | + request, | ||
| 87 | + response, | ||
| 88 | + normalize=False, | ||
| 89 | + ) | ||
| 90 | + ) | ||
| 91 | + | ||
| 92 | + assert response.headers["X-Request-ID"] == "req-123456" | ||
| 93 | + assert result == [[1.0, 2.0, 3.0]] |
translation/backends/local_ctranslate2.py
| @@ -353,14 +353,24 @@ class LocalCTranslate2TranslationBackend: | @@ -353,14 +353,24 @@ class LocalCTranslate2TranslationBackend: | ||
| 353 | source_lang: Optional[str] = None, | 353 | source_lang: Optional[str] = None, |
| 354 | ) -> List[str]: | 354 | ) -> List[str]: |
| 355 | limit = self._effective_input_token_limit(target_lang, source_lang) | 355 | limit = self._effective_input_token_limit(target_lang, source_lang) |
| 356 | - return split_text_for_translation( | ||
| 357 | - text, | ||
| 358 | - max_tokens=limit, | ||
| 359 | - token_length_fn=lambda value: self._token_count( | 356 | + token_count_cache: Dict[str, int] = {} |
| 357 | + | ||
| 358 | + def _cached_token_count(value: str) -> int: | ||
| 359 | + cached = token_count_cache.get(value) | ||
| 360 | + if cached is not None: | ||
| 361 | + return cached | ||
| 362 | + count = self._token_count( | ||
| 360 | value, | 363 | value, |
| 361 | target_lang=target_lang, | 364 | target_lang=target_lang, |
| 362 | source_lang=source_lang, | 365 | source_lang=source_lang, |
| 363 | - ), | 366 | + ) |
| 367 | + token_count_cache[value] = count | ||
| 368 | + return count | ||
| 369 | + | ||
| 370 | + return split_text_for_translation( | ||
| 371 | + text, | ||
| 372 | + max_tokens=limit, | ||
| 373 | + token_length_fn=_cached_token_count, | ||
| 364 | ) | 374 | ) |
| 365 | 375 | ||
| 366 | def _log_segmentation_summary( | 376 | def _log_segmentation_summary( |
translation/backends/local_seq2seq.py
| @@ -203,14 +203,24 @@ class LocalSeq2SeqTranslationBackend: | @@ -203,14 +203,24 @@ class LocalSeq2SeqTranslationBackend: | ||
| 203 | source_lang: Optional[str] = None, | 203 | source_lang: Optional[str] = None, |
| 204 | ) -> List[str]: | 204 | ) -> List[str]: |
| 205 | limit = self._effective_input_token_limit(target_lang, source_lang) | 205 | limit = self._effective_input_token_limit(target_lang, source_lang) |
| 206 | - return split_text_for_translation( | ||
| 207 | - text, | ||
| 208 | - max_tokens=limit, | ||
| 209 | - token_length_fn=lambda value: self._token_count( | 206 | + token_count_cache: Dict[str, int] = {} |
| 207 | + | ||
| 208 | + def _cached_token_count(value: str) -> int: | ||
| 209 | + cached = token_count_cache.get(value) | ||
| 210 | + if cached is not None: | ||
| 211 | + return cached | ||
| 212 | + count = self._token_count( | ||
| 210 | value, | 213 | value, |
| 211 | target_lang=target_lang, | 214 | target_lang=target_lang, |
| 212 | source_lang=source_lang, | 215 | source_lang=source_lang, |
| 213 | - ), | 216 | + ) |
| 217 | + token_count_cache[value] = count | ||
| 218 | + return count | ||
| 219 | + | ||
| 220 | + return split_text_for_translation( | ||
| 221 | + text, | ||
| 222 | + max_tokens=limit, | ||
| 223 | + token_length_fn=_cached_token_count, | ||
| 214 | ) | 224 | ) |
| 215 | 225 | ||
| 216 | def _log_segmentation_summary( | 226 | def _log_segmentation_summary( |