From efd435cfd693c23f8ee6eeb65adf4b63f08e54e3 Mon Sep 17 00:00:00 2001 From: tangwang Date: Wed, 11 Mar 2026 13:12:44 +0800 Subject: [PATCH] tei性能调优: ./scripts/start_tei_service.sh START_TEI=0 ./scripts/service_ctl.sh restart embedding --- config/config.yaml | 4 ++-- docs/TEI_SERVICE说明文档.md | 34 ++++++++++++++++++++++++++++------ docs/搜索API对接指南.md | 24 ++++++++++++++++++++++++ embeddings/qwen3_model.py | 37 ++++++++++++++++++++++++++----------- embeddings/server.py | 165 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--- reranker/backends/qwen3_vllm.py | 9 +++++++-- scripts/perf_api_benchmark.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++- scripts/start_embedding_service.sh | 28 +++++++++++++++++++++++----- scripts/start_tei_service.sh | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------ suggestion/service.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------ 10 files changed, 469 insertions(+), 96 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index a409140..4624e83 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -146,11 +146,11 @@ services: http: base_url: "http://127.0.0.1:6005" # 服务内文本后端(embedding 进程启动时读取) - backend: "local_st" # tei | local_st + backend: "tei" # tei | local_st backends: tei: base_url: "http://127.0.0.1:8080" - timeout_sec: 60 + timeout_sec: 20 model_id: "Qwen/Qwen3-Embedding-0.6B" local_st: model_id: "Qwen/Qwen3-Embedding-0.6B" diff --git a/docs/TEI_SERVICE说明文档.md b/docs/TEI_SERVICE说明文档.md index 23e2717..dda7dea 100644 --- a/docs/TEI_SERVICE说明文档.md +++ b/docs/TEI_SERVICE说明文档.md @@ -68,9 +68,13 @@ TEI_USE_GPU=1 ./scripts/start_tei_service.sh 预期输出包含: -- `Image: ghcr.io/huggingface/text-embeddings-inference:cuda-...` +- `Image: ghcr.io/huggingface/text-embeddings-inference:turing-...` 或 `cuda-...`(脚本按 GPU 架构自动选择) - `Mode: gpu` -- `TEI is ready: http://127.0.0.1:8080` +- `TEI is ready and output probe passed: http://127.0.0.1:8080` + +说明: +- T4(compute capability 7.5)会自动使用 `turing-*` 镜像。 +- Ampere 及更新架构(compute capability >= 8)会自动使用 `cuda-*` 镜像。 ### 5.2 CPU 模式启动(显式) @@ -82,7 +86,7 @@ TEI_USE_GPU=0 ./scripts/start_tei_service.sh - `Image: ghcr.io/huggingface/text-embeddings-inference:1.9`(非 `cuda-`) - `Mode: cpu` -- `TEI is ready: http://127.0.0.1:8080` +- `TEI is ready and output probe passed: http://127.0.0.1:8080` ### 5.3 停止服务 @@ -108,6 +112,8 @@ curl -sS http://127.0.0.1:8080/embed \ 返回应为二维数组(每条输入对应一个向量)。 +建议再连续请求一次,确认不是“首个请求正常,后续返回 null/NaN”。 + ### 6.3 与 embedding 服务联调 ```bash @@ -120,6 +126,11 @@ curl -sS -X POST "http://127.0.0.1:6005/embed/text" \ 返回应为 1024 维向量数组。 +### 6.4 运行建议(单服务兼顾在线与索引) + +- 在线 query(低延迟优先):客户端建议 `batch=1~4` +- 索引构建(吞吐优先):客户端建议 `batch=15~20` + ## 7. 配置项(环境变量) `scripts/start_tei_service.sh` 支持下列变量: @@ -130,10 +141,11 @@ curl -sS -X POST "http://127.0.0.1:6005/embed/text" \ - `TEI_MODEL_ID`:默认 `Qwen/Qwen3-Embedding-0.6B` - `TEI_VERSION`:镜像版本,默认 `1.9` - `TEI_DTYPE`:默认 `float16` -- `TEI_MAX_BATCH_TOKENS`:默认 `2048` -- `TEI_MAX_CLIENT_BATCH_SIZE`:默认 `8` +- `TEI_MAX_BATCH_TOKENS`:默认 `4096` +- `TEI_MAX_CLIENT_BATCH_SIZE`:默认 `24` - `HF_CACHE_DIR`:HF 缓存目录,默认 `$HOME/.cache/huggingface` - `HF_TOKEN`:可选,避免匿名限速 +- `TEI_IMAGE`:可选,手动指定镜像(通常不需要,建议使用脚本自动选择) ## 8. service_ctl 使用方式 @@ -184,10 +196,20 @@ curl -sS http://127.0.0.1:8080/health - `TEI_BASE_URL` - `services.embedding.backends.tei.base_url`(`config/config.yaml`) +### 9.4 `/embed/text` 第二次请求开始出现 NaN/null + +- 常见原因:在 T4 这类 pre-Ampere GPU 上误用了 `cuda-*` TEI 镜像。 +- 处理: + +```bash +./scripts/start_tei_service.sh +``` + +该脚本会自动按 GPU 架构选择镜像,并在启动后做两次输出探测;若发现 `null/NaN/Inf` 会直接失败并清理错误容器。 + ## 10. 相关文档 - 开发总览:`docs/QUICKSTART.md` - 体系规范:`docs/DEVELOPER_GUIDE.md` - embedding 模块:`embeddings/README.md` - CN-CLIP 专项:`docs/CNCLIP_SERVICE说明文档.md` - diff --git a/docs/搜索API对接指南.md b/docs/搜索API对接指南.md index bcc0601..10d9396 100644 --- a/docs/搜索API对接指南.md +++ b/docs/搜索API对接指南.md @@ -1586,6 +1586,28 @@ curl -X POST "http://localhost:6005/embed/image" \ curl "http://localhost:6005/health" ``` +#### 7.1.4 TEI 统一调优建议(主服务) + +使用单套主服务即可同时兼顾: +- 在线 query 向量化(低延迟,常见 `batch=1~4`) +- 索引构建向量化(高吞吐,常见 `batch=15~20`) + +统一启动(主链路): + +```bash +./scripts/start_tei_service.sh +START_TEI=0 ./scripts/service_ctl.sh restart embedding +``` + +默认端口: +- TEI: `http://127.0.0.1:8080` +- 向量服务(`/embed/text`): `http://127.0.0.1:6005` + +当前主 TEI 启动默认值(已按 T4/短文本场景调优): +- `TEI_MAX_BATCH_TOKENS=4096` +- `TEI_MAX_CLIENT_BATCH_SIZE=24` +- `TEI_DTYPE=float16` + ### 7.2 重排服务(Reranker) - **Base URL**: `http://localhost:6007`(可通过 `RERANKER_SERVICE_URL` 覆盖) @@ -2094,6 +2116,8 @@ curl "http://localhost:6006/health" - 翻译服务:`POST /translate` - 重排服务:`POST /rerank` +说明:脚本对 `embed_text` 场景会校验返回向量内容有效性(必须是有限数值,不允许 `null/NaN/Inf`),不是只看 HTTP 200。 + ### 10.1 快速示例 ```bash diff --git a/embeddings/qwen3_model.py b/embeddings/qwen3_model.py index 454e443..9e858f3 100644 --- a/embeddings/qwen3_model.py +++ b/embeddings/qwen3_model.py @@ -9,6 +9,7 @@ from typing import List, Union import numpy as np from sentence_transformers import SentenceTransformer +import torch class Qwen3TextModel(object): @@ -24,8 +25,21 @@ class Qwen3TextModel(object): if cls._instance is None: cls._instance = super(Qwen3TextModel, cls).__new__(cls) cls._instance.model = SentenceTransformer(model_id, trust_remote_code=True) + cls._instance._current_device = None + cls._instance._encode_lock = threading.Lock() return cls._instance + def _ensure_device(self, device: str) -> str: + target = (device or "cpu").strip().lower() + if target == "gpu": + target = "cuda" + if target == "cuda" and not torch.cuda.is_available(): + target = "cpu" + if target != self._current_device: + self.model = self.model.to(target) + self._current_device = target + return target + def encode( self, sentences: Union[str, List[str]], @@ -33,17 +47,18 @@ class Qwen3TextModel(object): device: str = "cuda", batch_size: int = 32, ) -> np.ndarray: - if device == "gpu": - device = "cuda" - self.model = self.model.to(device) - embeddings = self.model.encode( - sentences, - normalize_embeddings=normalize_embeddings, - device=device, - show_progress_bar=False, - batch_size=batch_size, - ) - return embeddings + # SentenceTransformer + CUDA inference is not thread-safe in our usage; + # keep one in-flight encode call while avoiding repeated .to(device) hops. + with self._encode_lock: + run_device = self._ensure_device(device) + embeddings = self.model.encode( + sentences, + normalize_embeddings=normalize_embeddings, + device=run_device, + show_progress_bar=False, + batch_size=batch_size, + ) + return embeddings def encode_batch( self, diff --git a/embeddings/server.py b/embeddings/server.py index ee16d04..ebcaec8 100644 --- a/embeddings/server.py +++ b/embeddings/server.py @@ -9,6 +9,9 @@ API (simple list-in, list-out; aligned by index): import logging import os import threading +import time +from collections import deque +from dataclasses import dataclass from typing import Any, Dict, List, Optional import numpy as np @@ -26,13 +29,139 @@ app = FastAPI(title="saas-search Embedding Service", version="1.0.0") _text_model: Optional[Any] = None _image_model: Optional[ImageEncoderProtocol] = None _text_backend_name: str = "" -open_text_model = True -open_image_model = True # Enable image embedding when using clip-as-service +open_text_model = os.getenv("EMBEDDING_ENABLE_TEXT_MODEL", "true").lower() in ("1", "true", "yes") +open_image_model = os.getenv("EMBEDDING_ENABLE_IMAGE_MODEL", "true").lower() in ("1", "true", "yes") _text_encode_lock = threading.Lock() _image_encode_lock = threading.Lock() +@dataclass +class _SingleTextTask: + text: str + normalize: bool + created_at: float + done: threading.Event + result: Optional[List[float]] = None + error: Optional[Exception] = None + + +_text_single_queue: "deque[_SingleTextTask]" = deque() +_text_single_queue_cv = threading.Condition() +_text_batch_worker: Optional[threading.Thread] = None +_text_batch_worker_stop = False +_TEXT_MICROBATCH_WINDOW_SEC = max( + 0.0, float(os.getenv("TEXT_MICROBATCH_WINDOW_MS", "4")) / 1000.0 +) +_TEXT_REQUEST_TIMEOUT_SEC = max( + 1.0, float(os.getenv("TEXT_REQUEST_TIMEOUT_SEC", "30")) +) + + +def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any: + with _text_encode_lock: + return _text_model.encode_batch( + texts, + batch_size=int(CONFIG.TEXT_BATCH_SIZE), + device=CONFIG.TEXT_DEVICE, + normalize_embeddings=normalize_embeddings, + ) + + +def _start_text_batch_worker() -> None: + global _text_batch_worker, _text_batch_worker_stop + if _text_batch_worker is not None and _text_batch_worker.is_alive(): + return + _text_batch_worker_stop = False + _text_batch_worker = threading.Thread( + target=_text_batch_worker_loop, + name="embed-text-microbatch-worker", + daemon=True, + ) + _text_batch_worker.start() + logger.info( + "Started local_st text micro-batch worker | window_ms=%.1f max_batch=%d", + _TEXT_MICROBATCH_WINDOW_SEC * 1000.0, + int(CONFIG.TEXT_BATCH_SIZE), + ) + + +def _stop_text_batch_worker() -> None: + global _text_batch_worker_stop + with _text_single_queue_cv: + _text_batch_worker_stop = True + _text_single_queue_cv.notify_all() + + +def _text_batch_worker_loop() -> None: + max_batch = max(1, int(CONFIG.TEXT_BATCH_SIZE)) + while True: + with _text_single_queue_cv: + while not _text_single_queue and not _text_batch_worker_stop: + _text_single_queue_cv.wait() + if _text_batch_worker_stop: + return + + batch: List[_SingleTextTask] = [_text_single_queue.popleft()] + deadline = time.perf_counter() + _TEXT_MICROBATCH_WINDOW_SEC + + while len(batch) < max_batch: + remaining = deadline - time.perf_counter() + if remaining <= 0: + break + if not _text_single_queue: + _text_single_queue_cv.wait(timeout=remaining) + continue + while _text_single_queue and len(batch) < max_batch: + batch.append(_text_single_queue.popleft()) + + try: + embs = _encode_local_st([task.text for task in batch], normalize_embeddings=False) + if embs is None or len(embs) != len(batch): + raise RuntimeError( + f"Text model response length mismatch in micro-batch: " + f"expected {len(batch)}, got {0 if embs is None else len(embs)}" + ) + for task, emb in zip(batch, embs): + vec = _as_list(emb, normalize=task.normalize) + if vec is None: + raise RuntimeError("Text model returned empty embedding in micro-batch") + task.result = vec + except Exception as exc: + for task in batch: + task.error = exc + finally: + for task in batch: + task.done.set() + + +def _encode_single_text_with_microbatch(text: str, normalize: bool) -> List[float]: + task = _SingleTextTask( + text=text, + normalize=normalize, + created_at=time.perf_counter(), + done=threading.Event(), + ) + with _text_single_queue_cv: + _text_single_queue.append(task) + _text_single_queue_cv.notify() + + if not task.done.wait(timeout=_TEXT_REQUEST_TIMEOUT_SEC): + with _text_single_queue_cv: + try: + _text_single_queue.remove(task) + except ValueError: + pass + raise RuntimeError( + f"Timed out waiting for text micro-batch worker ({_TEXT_REQUEST_TIMEOUT_SEC:.1f}s)" + ) + if task.error is not None: + raise task.error + if task.result is None: + raise RuntimeError("Text micro-batch worker returned empty result") + return task.result + + @app.on_event("startup") def load_models(): """Load models at service startup to avoid first-request latency.""" @@ -73,6 +202,7 @@ def load_models(): ) logger.info("Loading text backend: local_st (model=%s)", model_id) _text_model = Qwen3TextModel(model_id=str(model_id)) + _start_text_batch_worker() else: raise ValueError( f"Unsupported embedding backend: {backend_name}. " @@ -112,6 +242,11 @@ def load_models(): logger.info("All embedding models loaded successfully, service ready") +@app.on_event("shutdown") +def stop_workers() -> None: + _stop_text_batch_worker() + + def _normalize_vector(vec: np.ndarray) -> np.ndarray: norm = float(np.linalg.norm(vec)) if not np.isfinite(norm) or norm <= 0.0: @@ -157,8 +292,24 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") normalized.append(s) + t0 = time.perf_counter() try: - with _text_encode_lock: + # local_st backend uses in-process torch model, keep serialized encode for safety; + # TEI backend is an HTTP client and supports concurrent requests. + if _text_backend_name == "local_st": + if len(normalized) == 1 and _text_batch_worker is not None: + out = [_encode_single_text_with_microbatch(normalized[0], normalize=effective_normalize)] + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "embed_text done | backend=%s mode=microbatch-single inputs=%d normalize=%s elapsed_ms=%.2f", + _text_backend_name, + len(normalized), + effective_normalize, + elapsed_ms, + ) + return out + embs = _encode_local_st(normalized, normalize_embeddings=False) + else: embs = _text_model.encode_batch( normalized, batch_size=int(CONFIG.TEXT_BATCH_SIZE), @@ -182,6 +333,14 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio if vec is None: raise RuntimeError(f"Text model returned empty embedding for index {i}") out.append(vec) + elapsed_ms = (time.perf_counter() - t0) * 1000.0 + logger.info( + "embed_text done | backend=%s inputs=%d normalize=%s elapsed_ms=%.2f", + _text_backend_name, + len(normalized), + effective_normalize, + elapsed_ms, + ) return out diff --git a/reranker/backends/qwen3_vllm.py b/reranker/backends/qwen3_vllm.py index b26aefd..f6ebce4 100644 --- a/reranker/backends/qwen3_vllm.py +++ b/reranker/backends/qwen3_vllm.py @@ -9,6 +9,7 @@ from __future__ import annotations import logging import math +import threading import time from typing import Any, Dict, List, Optional, Tuple @@ -102,6 +103,9 @@ class Qwen3VLLMRerankerBackend: logprobs=20, allowed_token_ids=[self._true_token, self._false_token], ) + # vLLM generate path is unstable under concurrent calls in this process model. + # Serialize infer calls to avoid engine-core protocol corruption. + self._infer_lock = threading.Lock() self._model_name = model_name logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name) @@ -209,8 +213,9 @@ class Qwen3VLLMRerankerBackend: position_to_unique.append(len(unique_texts) - 1) pairs = [(query, t) for t in unique_texts] - prompts = self._process_inputs(pairs) - unique_scores = self._compute_scores(prompts) + with self._infer_lock: + prompts = self._process_inputs(pairs) + unique_scores = self._compute_scores(prompts) for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): # Score is already P(yes) in [0,1] from yes/(yes+no) diff --git a/scripts/perf_api_benchmark.py b/scripts/perf_api_benchmark.py index f2dba7d..ceb4ef0 100755 --- a/scripts/perf_api_benchmark.py +++ b/scripts/perf_api_benchmark.py @@ -55,6 +55,43 @@ class RequestResult: error: str = "" +def _is_finite_number(v: Any) -> bool: + if isinstance(v, bool): + return False + if isinstance(v, (int, float)): + return math.isfinite(float(v)) + return False + + +def validate_response_payload( + scenario_name: str, + tpl: RequestTemplate, + payload: Any, +) -> Tuple[bool, str]: + """ + Lightweight payload validation for correctness-aware perf tests. + Currently strict for embed_text to catch NaN/null vector regressions. + """ + if scenario_name != "embed_text": + return True, "" + + expected_len = len(tpl.json_body) if isinstance(tpl.json_body, list) else None + if not isinstance(payload, list): + return False, "invalid_payload_non_list" + if expected_len is not None and len(payload) != expected_len: + return False, "invalid_payload_length" + if len(payload) == 0: + return False, "invalid_payload_empty" + + for i, vec in enumerate(payload): + if not isinstance(vec, list) or len(vec) == 0: + return False, f"invalid_vector_{i}_shape" + for x in vec: + if not _is_finite_number(x): + return False, f"invalid_vector_{i}_non_finite" + return True, "" + + def percentile(sorted_values: List[float], p: float) -> float: if not sorted_values: return 0.0 @@ -259,7 +296,22 @@ async def run_single_scenario( ) status = int(resp.status_code) ok = 200 <= status < 300 - if not ok: + if ok: + try: + payload = resp.json() + except Exception: + ok = False + err = "invalid_json_response" + else: + valid, reason = validate_response_payload( + scenario_name=scenario.name, + tpl=tpl, + payload=payload, + ) + if not valid: + ok = False + err = reason or "invalid_payload" + if not ok and not err: err = f"http_{status}" except Exception as e: err = type(e).__name__ diff --git a/scripts/start_embedding_service.sh b/scripts/start_embedding_service.sh index d0d8c62..56a3d34 100755 --- a/scripts/start_embedding_service.sh +++ b/scripts/start_embedding_service.sh @@ -54,6 +54,13 @@ USE_CLIP_AS_SERVICE=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; CLIP_AS_SERVICE_SERVER=$("${PYTHON_BIN}" -c "from embeddings.config import CONFIG; print(CONFIG.CLIP_AS_SERVICE_SERVER)") TEXT_BACKEND=$("${PYTHON_BIN}" -c "from config.services_config import get_embedding_backend_config; print(get_embedding_backend_config()[0])") TEI_BASE_URL=$("${PYTHON_BIN}" -c "import os; from config.services_config import get_embedding_backend_config; from embeddings.config import CONFIG; _, cfg = get_embedding_backend_config(); print(os.getenv('TEI_BASE_URL') or cfg.get('base_url') or CONFIG.TEI_BASE_URL)") +ENABLE_IMAGE_MODEL="${EMBEDDING_ENABLE_IMAGE_MODEL:-true}" +ENABLE_IMAGE_MODEL="$(echo "${ENABLE_IMAGE_MODEL}" | tr '[:upper:]' '[:lower:]')" +if [[ "${ENABLE_IMAGE_MODEL}" == "1" || "${ENABLE_IMAGE_MODEL}" == "true" || "${ENABLE_IMAGE_MODEL}" == "yes" ]]; then + IMAGE_MODEL_ENABLED=1 +else + IMAGE_MODEL_ENABLED=0 +fi EMBEDDING_SERVICE_HOST="${EMBEDDING_HOST:-${DEFAULT_EMBEDDING_SERVICE_HOST}}" EMBEDDING_SERVICE_PORT="${EMBEDDING_PORT:-${DEFAULT_EMBEDDING_SERVICE_PORT}}" @@ -66,7 +73,7 @@ if [[ "${TEXT_BACKEND}" == "tei" ]]; then fi fi -if [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then +if [[ "${IMAGE_MODEL_ENABLED}" == "1" && "${USE_CLIP_AS_SERVICE}" == "1" ]]; then CLIP_SERVER="${CLIP_AS_SERVICE_SERVER#*://}" CLIP_HOST="${CLIP_SERVER%:*}" CLIP_PORT="${CLIP_SERVER##*:}" @@ -102,7 +109,9 @@ echo "Text backend: ${TEXT_BACKEND}" if [[ "${TEXT_BACKEND}" == "tei" ]]; then echo "TEI URL: ${TEI_BASE_URL}" fi -if [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then +if [[ "${IMAGE_MODEL_ENABLED}" == "0" ]]; then + echo "Image backend: disabled" +elif [[ "${USE_CLIP_AS_SERVICE}" == "1" ]]; then echo "Image backend: clip-as-service (${CLIP_AS_SERVICE_SERVER})" fi echo @@ -111,7 +120,16 @@ echo " - Use a single worker (GPU models cannot be safely duplicated across wor echo " - Clients can set EMBEDDING_SERVICE_URL=http://localhost:${EMBEDDING_SERVICE_PORT}" echo -exec "${PYTHON_BIN}" -m uvicorn embeddings.server:app \ - --host "${EMBEDDING_SERVICE_HOST}" \ - --port "${EMBEDDING_SERVICE_PORT}" \ +UVICORN_LOG_LEVEL="${EMBEDDING_UVICORN_LOG_LEVEL:-info}" +UVICORN_ACCESS_LOG="${EMBEDDING_UVICORN_ACCESS_LOG:-true}" +UVICORN_ARGS=( + --host "${EMBEDDING_SERVICE_HOST}" + --port "${EMBEDDING_SERVICE_PORT}" --workers 1 + --log-level "${UVICORN_LOG_LEVEL}" +) +if [[ "${UVICORN_ACCESS_LOG}" == "0" || "${UVICORN_ACCESS_LOG}" == "false" || "${UVICORN_ACCESS_LOG}" == "no" ]]; then + UVICORN_ARGS+=(--no-access-log) +fi + +exec "${PYTHON_BIN}" -m uvicorn embeddings.server:app "${UVICORN_ARGS[@]}" diff --git a/scripts/start_tei_service.sh b/scripts/start_tei_service.sh index 049c8c8..ea3124e 100755 --- a/scripts/start_tei_service.sh +++ b/scripts/start_tei_service.sh @@ -43,8 +43,8 @@ TEI_CONTAINER_NAME="${TEI_CONTAINER_NAME:-saas-search-tei}" TEI_PORT="${TEI_PORT:-8080}" TEI_MODEL_ID="${TEI_MODEL_ID:-Qwen/Qwen3-Embedding-0.6B}" TEI_VERSION="${TEI_VERSION:-1.9}" -TEI_MAX_BATCH_TOKENS="${TEI_MAX_BATCH_TOKENS:-2048}" -TEI_MAX_CLIENT_BATCH_SIZE="${TEI_MAX_CLIENT_BATCH_SIZE:-8}" +TEI_MAX_BATCH_TOKENS="${TEI_MAX_BATCH_TOKENS:-4096}" +TEI_MAX_CLIENT_BATCH_SIZE="${TEI_MAX_CLIENT_BATCH_SIZE:-24}" TEI_DTYPE="${TEI_DTYPE:-float16}" HF_CACHE_DIR="${HF_CACHE_DIR:-$HOME/.cache/huggingface}" TEI_HEALTH_TIMEOUT_SEC="${TEI_HEALTH_TIMEOUT_SEC:-300}" @@ -60,6 +60,18 @@ else exit 1 fi +detect_gpu_tei_image() { + # Prefer turing image for pre-Ampere GPUs (e.g. Tesla T4, compute capability 7.5). + local compute_cap major + compute_cap="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null | head -n1 || true)" + major="${compute_cap%%.*}" + if [[ -n "${major}" && "${major}" -lt 8 ]]; then + echo "ghcr.io/huggingface/text-embeddings-inference:turing-${TEI_VERSION}" + else + echo "ghcr.io/huggingface/text-embeddings-inference:cuda-${TEI_VERSION}" + fi +} + if [[ "${USE_GPU}" == "1" ]]; then if ! command -v nvidia-smi >/dev/null 2>&1 || ! nvidia-smi >/dev/null 2>&1; then echo "ERROR: TEI_USE_GPU=1 but NVIDIA GPU is not available. No CPU fallback." >&2 @@ -70,7 +82,7 @@ if [[ "${USE_GPU}" == "1" ]]; then echo "Install and configure nvidia-container-toolkit, then restart Docker." >&2 exit 1 fi - TEI_IMAGE="${TEI_IMAGE:-ghcr.io/huggingface/text-embeddings-inference:cuda-${TEI_VERSION}}" + TEI_IMAGE="${TEI_IMAGE:-$(detect_gpu_tei_image)}" GPU_ARGS=(--gpus all) TEI_MODE="gpu" else @@ -87,28 +99,33 @@ if [[ -n "${existing_id}" ]]; then if [[ -n "${running_id}" ]]; then current_image="$(docker inspect "${TEI_CONTAINER_NAME}" --format '{{.Config.Image}}' 2>/dev/null || true)" device_req="$(docker inspect "${TEI_CONTAINER_NAME}" --format '{{json .HostConfig.DeviceRequests}}' 2>/dev/null || true)" + current_is_gpu_image=0 + if [[ "${current_image}" == *":cuda-"* || "${current_image}" == *":turing-"* ]]; then + current_is_gpu_image=1 + fi if [[ "${USE_GPU}" == "1" ]]; then - if [[ "${current_image}" != *":cuda-"* ]] || [[ "${device_req}" == "null" ]]; then - echo "ERROR: existing TEI container mode mismatch (need GPU): ${TEI_CONTAINER_NAME}" >&2 - echo " image=${current_image:-unknown}" >&2 - echo " device_requests=${device_req:-unknown}" >&2 - echo "Stop it first: ./scripts/stop_tei_service.sh" >&2 - exit 1 + if [[ "${current_is_gpu_image}" -eq 1 ]] && [[ "${device_req}" != "null" ]] && [[ "${current_image}" == "${TEI_IMAGE}" ]]; then + echo "TEI already running (GPU): ${TEI_CONTAINER_NAME}" + exit 0 fi - echo "TEI already running (GPU): ${TEI_CONTAINER_NAME}" + echo "TEI running with different mode/image; recreating container ${TEI_CONTAINER_NAME}" + echo " current_image=${current_image:-unknown}" + echo " target_image=${TEI_IMAGE}" + docker rm -f "${TEI_CONTAINER_NAME}" >/dev/null 2>&1 || true else - if [[ "${current_image}" == *":cuda-"* ]] || [[ "${device_req}" != "null" ]]; then - echo "ERROR: existing TEI container mode mismatch (need CPU): ${TEI_CONTAINER_NAME}" >&2 - echo " image=${current_image:-unknown}" >&2 - echo " device_requests=${device_req:-unknown}" >&2 - echo "Stop it first: ./scripts/stop_tei_service.sh" >&2 - exit 1 + if [[ "${current_is_gpu_image}" -eq 0 ]] && [[ "${device_req}" == "null" ]] && [[ "${current_image}" == "${TEI_IMAGE}" ]]; then + echo "TEI already running (CPU): ${TEI_CONTAINER_NAME}" + exit 0 fi - echo "TEI already running (CPU): ${TEI_CONTAINER_NAME}" + echo "TEI running with different mode/image; recreating container ${TEI_CONTAINER_NAME}" + echo " current_image=${current_image:-unknown}" + echo " target_image=${TEI_IMAGE}" + docker rm -f "${TEI_CONTAINER_NAME}" >/dev/null 2>&1 || true fi - exit 0 fi - docker rm "${TEI_CONTAINER_NAME}" >/dev/null + if docker ps -aq -f name=^/${TEI_CONTAINER_NAME}$ | grep -q .; then + docker rm "${TEI_CONTAINER_NAME}" >/dev/null + fi fi echo "Starting TEI container: ${TEI_CONTAINER_NAME}" @@ -132,12 +149,37 @@ docker run -d \ echo "Waiting for TEI health..." for i in $(seq 1 "${TEI_HEALTH_TIMEOUT_SEC}"); do if curl -sf "http://127.0.0.1:${TEI_PORT}/health" >/dev/null 2>&1; then - echo "TEI is ready: http://127.0.0.1:${TEI_PORT}" - exit 0 + echo "TEI health is ready: http://127.0.0.1:${TEI_PORT}" + break fi sleep 1 + if [[ "${i}" == "${TEI_HEALTH_TIMEOUT_SEC}" ]]; then + echo "ERROR: TEI failed to become healthy in time." >&2 + docker logs --tail 100 "${TEI_CONTAINER_NAME}" >&2 || true + exit 1 + fi +done + +echo "Running TEI output probe..." +for probe_idx in 1 2; do + probe_resp="$(curl -sf -X POST "http://127.0.0.1:${TEI_PORT}/embed" \ + -H "Content-Type: application/json" \ + -d '{"inputs":["health check","芭比娃娃 儿童玩具"]}' || true)" + if [[ -z "${probe_resp}" ]]; then + echo "ERROR: TEI probe ${probe_idx} failed: empty response" >&2 + docker logs --tail 120 "${TEI_CONTAINER_NAME}" >&2 || true + docker rm -f "${TEI_CONTAINER_NAME}" >/dev/null 2>&1 || true + exit 1 + fi + # Detect non-finite-like payloads (observed as null/NaN on incompatible CUDA image + GPU). + if echo "${probe_resp}" | rg -qi '(null|nan|inf)'; then + echo "ERROR: TEI probe ${probe_idx} detected invalid embedding values (null/NaN/Inf)." >&2 + echo "Response preview: $(echo "${probe_resp}" | head -c 220)" >&2 + docker logs --tail 120 "${TEI_CONTAINER_NAME}" >&2 || true + docker rm -f "${TEI_CONTAINER_NAME}" >/dev/null 2>&1 || true + exit 1 + fi done -echo "ERROR: TEI failed to become healthy in time." >&2 -docker logs --tail 100 "${TEI_CONTAINER_NAME}" >&2 || true -exit 1 +echo "TEI is ready and output probe passed: http://127.0.0.1:${TEI_PORT}" +exit 0 diff --git a/suggestion/service.py b/suggestion/service.py index 3804cee..5e0900a 100644 --- a/suggestion/service.py +++ b/suggestion/service.py @@ -123,6 +123,7 @@ class SuggestionService: size: int = 10, ) -> Dict[str, Any]: start = time.time() + query_text = str(query or "").strip() resolved_lang = self._resolve_language(tenant_id, language) index_name = self._resolve_search_target(tenant_id) if not index_name: @@ -137,6 +138,55 @@ class SuggestionService: "took_ms": took_ms, } + # Recall path A: completion suggester (fast path, usually enough for short prefix typing) + t_completion_start = time.time() + completion_items = self._completion_suggest( + index_name=index_name, + query=query_text, + lang=resolved_lang, + size=size, + tenant_id=tenant_id, + ) + completion_ms = int((time.time() - t_completion_start) * 1000) + + suggestions: List[Dict[str, Any]] = [] + seen_text_norm: set = set() + + def _norm_text(v: Any) -> str: + return str(v or "").strip().lower() + + def _append_items(items: List[Dict[str, Any]]) -> None: + for item in items: + text_val = item.get("text") + norm = _norm_text(text_val) + if not norm or norm in seen_text_norm: + continue + seen_text_norm.add(norm) + suggestions.append(dict(item)) + + _append_items(completion_items) + + # Fast path: avoid a second ES query for short prefixes or when completion already full. + if len(query_text) <= 2 or len(suggestions) >= size: + took_ms = int((time.time() - start) * 1000) + logger.info( + "suggest completion-fast-return | tenant=%s lang=%s q=%s completion=%d took_ms=%d completion_ms=%d", + tenant_id, + resolved_lang, + query_text, + len(suggestions), + took_ms, + completion_ms, + ) + return { + "query": query, + "language": language, + "resolved_language": resolved_lang, + "suggestions": suggestions[:size], + "took_ms": took_ms, + } + + # Recall path B: bool_prefix on search_as_you_type (fallback/recall补全) sat_field = f"sat.{resolved_lang}" dsl = { "track_total_hits": False, @@ -151,7 +201,7 @@ class SuggestionService: "should": [ { "multi_match": { - "query": query, + "query": query_text, "type": "bool_prefix", "fields": [sat_field, f"{sat_field}._2gram", f"{sat_field}._3gram"], } @@ -180,7 +230,7 @@ class SuggestionService: "lang_conflict", ], } - # Recall path A: bool_prefix on search_as_you_type + t_sat_start = time.time() es_resp = self.es_client.search( index_name=index_name, body=dsl, @@ -188,52 +238,38 @@ class SuggestionService: from_=0, routing=str(tenant_id), ) + sat_ms = int((time.time() - t_sat_start) * 1000) hits = es_resp.get("hits", {}).get("hits", []) or [] - # Recall path B: completion suggester (optional optimization) - completion_items = self._completion_suggest( - index_name=index_name, - query=query, - lang=resolved_lang, - size=size, - tenant_id=tenant_id, - ) - - suggestions: List[Dict[str, Any]] = [] - seen_text_norm: set = set() - - def _norm_text(v: Any) -> str: - return str(v or "").strip().lower() - - # Put completion results first (usually better prefix UX), then fill with sat results. - for item in completion_items: - text_val = item.get("text") - norm = _norm_text(text_val) - if not norm or norm in seen_text_norm: - continue - seen_text_norm.add(norm) - suggestions.append(dict(item)) - + sat_items: List[Dict[str, Any]] = [] for hit in hits: src = hit.get("_source", {}) or {} - text_val = src.get("text") - norm = _norm_text(text_val) - if not norm or norm in seen_text_norm: - continue - seen_text_norm.add(norm) - item = { - "text": text_val, - "lang": src.get("lang"), - "score": hit.get("_score", 0.0), - "rank_score": src.get("rank_score"), - "sources": src.get("sources", []), - "lang_source": src.get("lang_source"), - "lang_confidence": src.get("lang_confidence"), - "lang_conflict": src.get("lang_conflict", False), - } - suggestions.append(item) + sat_items.append( + { + "text": src.get("text"), + "lang": src.get("lang"), + "score": hit.get("_score", 0.0), + "rank_score": src.get("rank_score"), + "sources": src.get("sources", []), + "lang_source": src.get("lang_source"), + "lang_confidence": src.get("lang_confidence"), + "lang_conflict": src.get("lang_conflict", False), + } + ) + _append_items(sat_items) took_ms = int((time.time() - start) * 1000) + logger.info( + "suggest completion+sat-return | tenant=%s lang=%s q=%s completion=%d sat_hits=%d took_ms=%d completion_ms=%d sat_ms=%d", + tenant_id, + resolved_lang, + query_text, + len(completion_items), + len(hits), + took_ms, + completion_ms, + sat_ms, + ) return { "query": query, "language": language, -- libgit2 0.21.2