diff --git a/api/routes/search.py b/api/routes/search.py index 9cc10df..63aef84 100644 --- a/api/routes/search.py +++ b/api/routes/search.py @@ -18,6 +18,7 @@ from ..models import ( ErrorResponse ) from context.request_context import create_request_context, set_current_request_context, clear_current_request_context +from indexer.mapping_generator import get_tenant_index_name router = APIRouter(prefix="/search", tags=["search"]) backend_verbose_logger = logging.getLogger("backend.verbose") @@ -437,3 +438,57 @@ async def get_document(doc_id: str, http_request: Request): raise except Exception as e: raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/es-doc/{spu_id}") +async def get_es_raw_document(spu_id: str, http_request: Request): + """ + Get raw Elasticsearch document(s) for a given SPU ID. + + This is intended for debugging in the test frontend: + it queries the tenant-specific ES index with a term filter on spu_id + and returns the raw ES search response. + """ + # Extract tenant_id (required) + tenant_id = http_request.headers.get("X-Tenant-ID") + if not tenant_id: + from urllib.parse import parse_qs + query_string = http_request.url.query + if query_string: + params = parse_qs(query_string) + tenant_id = params.get("tenant_id", [None])[0] + + if not tenant_id: + raise HTTPException( + status_code=400, + detail="tenant_id is required. Provide it via header 'X-Tenant-ID' or query parameter 'tenant_id'", + ) + + try: + from api.app import get_searcher + + searcher = get_searcher() + es_client = searcher.es_client + index_name = get_tenant_index_name(tenant_id) + + body = { + "size": 5, + "query": { + "bool": { + "filter": [ + { + "term": { + "spu_id": spu_id, + } + } + ] + } + }, + } + + es_response = es_client.search(index_name=index_name, body=body, size=5, from_=0) + return es_response + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) diff --git a/docs/翻译模块说明.md b/docs/翻译模块说明.md index 03a70cf..91bd643 100644 --- a/docs/翻译模块说明.md +++ b/docs/翻译模块说明.md @@ -15,6 +15,12 @@ DEEPL_AUTH_KEY=xxx TRANSLATION_MODEL=qwen # 或 deepl ``` +> **重要限速说明(Qwen 机翻)** +> 当前默认的 Qwen 翻译后端使用 `qwen-mt-flash` 云端模型,**官方限速较低,约 RPM=60(每分钟约 60 请求)**。 +> - 推荐通过 Redis 翻译缓存复用结果,避免对相同文本重复打云端 +> - 高并发场景需要在调用端做限流 / 去抖,或改为离线批量翻译 +> - 如需更高吞吐,可考虑 DeepL 或自建翻译服务 + ## Provider 配置 Provider 与 URL 在 `config/config.yaml` 的 `services.translation`。详见 [QUICKSTART.md](./QUICKSTART.md) §3 与 [DEVELOPER_GUIDE.md](./DEVELOPER_GUIDE.md) §7.2。 diff --git a/frontend/static/css/style.css b/frontend/static/css/style.css index 87a3f11..73478aa 100644 --- a/frontend/static/css/style.css +++ b/frontend/static/css/style.css @@ -322,9 +322,9 @@ body { /* Product Grid */ .product-grid { - display: grid; - grid-template-columns: repeat(auto-fill, minmax(220px, 1fr)); - gap: 20px; + display: flex; + flex-direction: column; + gap: 16px; padding: 30px; background: #f8f8f8; min-height: 400px; @@ -336,9 +336,51 @@ body { border-radius: 8px; padding: 15px; transition: all 0.3s; - cursor: pointer; + cursor: default; + display: flex; + flex-direction: row; + align-items: flex-start; + gap: 16px; +} + +.product-main { display: flex; flex-direction: column; + width: 260px; + flex-shrink: 0; +} + +.product-debug { + flex: 1; + font-family: Menlo, Consolas, "Courier New", monospace; + font-size: 12px; + color: #555; + border-left: 1px dashed #eee; + padding-left: 12px; + max-height: 260px; + overflow: auto; +} + +.product-debug-title { + font-weight: 600; + margin-bottom: 6px; + color: #333; +} + +.product-debug-line { + margin-bottom: 2px; +} + +.product-debug-link { + display: inline-block; + margin-top: 6px; + font-size: 12px; + color: #e67e22; + text-decoration: none; +} + +.product-debug-link:hover { + text-decoration: underline; } .product-card:hover { @@ -347,7 +389,7 @@ body { } .product-image-wrapper { - width: 100%; + width: 220px; height: 180px; display: flex; align-items: center; diff --git a/frontend/static/js/app.js b/frontend/static/js/app.js index 29a0b91..6c17c74 100644 --- a/frontend/static/js/app.js +++ b/frontend/static/js/app.js @@ -287,7 +287,8 @@ async function performSearch(page = 1) { sort_by: state.sortBy || null, sort_order: state.sortOrder, sku_filter_dimension: skuFilterDimension, - debug: state.debug + // 测试前端始终开启后端调试信息 + debug: true }) }); @@ -336,7 +337,19 @@ function displayResults(data) { } let html = ''; - + + // Build per-SPU debug lookup from debug_info.per_result (if present) + let perResultDebugBySpu = {}; + if (state.debug && data.debug_info && Array.isArray(data.debug_info.per_result)) { + data.debug_info.per_result.forEach((item) => { + if (item && item.spu_id) { + perResultDebugBySpu[String(item.spu_id)] = item; + } + }); + } + + const tenantId = getTenantId(); + data.results.forEach((result) => { const product = result; const title = product.title || product.name || 'N/A'; @@ -344,32 +357,71 @@ function displayResults(data) { const imageUrl = product.image_url || product.imageUrl || ''; const category = product.category || product.categoryName || ''; const vendor = product.vendor || product.brandName || ''; + const spuId = product.spu_id || ''; + const debug = spuId ? perResultDebugBySpu[String(spuId)] : null; + + let debugHtml = ''; + if (debug) { + const esScore = typeof debug.es_score === 'number' ? debug.es_score.toFixed(4) : String(debug.es_score ?? ''); + const esNorm = typeof debug.es_score_normalized === 'number' + ? debug.es_score_normalized.toFixed(4) + : (debug.es_score_normalized == null ? '' : String(debug.es_score_normalized)); + + // Build multilingual title info + let titleLines = ''; + if (debug.title_multilingual && typeof debug.title_multilingual === 'object') { + Object.entries(debug.title_multilingual).forEach(([lang, val]) => { + if (val) { + titleLines += `
title.${escapeHtml(String(lang))}: ${escapeHtml(String(val))}
`; + } + }); + } + + const rawUrl = `${API_BASE_URL}/search/es-doc/${encodeURIComponent(spuId)}?tenant_id=${encodeURIComponent(tenantId)}`; + + debugHtml = ` +
+
Ranking Debug
+
spu_id: ${escapeHtml(String(spuId || ''))}
+
es_id: ${escapeHtml(String(debug.es_id || ''))}
+
ES score: ${esScore}
+
ES normalized: ${esNorm}
+ ${titleLines} + + 查看 ES 原始文档 + +
+ `; + } html += `
-
- ${imageUrl ? ` - ${escapeHtml(title)} - ` : ` -
No Image
- `} -
- -
- ${price !== 'N/A' ? `¥${price}` : 'N/A'} -
- -
- ${escapeHtml(title)} -
- -
- ${category ? escapeHtml(category) : ''} - ${vendor ? ' | ' + escapeHtml(vendor) : ''} +
+
+ ${imageUrl ? ` + ${escapeHtml(title)} + ` : ` +
No Image
+ `} +
+ +
+ ${price !== 'N/A' ? `¥${price}` : 'N/A'} +
+ +
+ ${escapeHtml(title)} +
+ +
+ ${category ? escapeHtml(category) : ''} + ${vendor ? ' | ' + escapeHtml(vendor) : ''} +
+ ${debugHtml}
`; }); diff --git a/providers/translation.py b/providers/translation.py index 10bdef5..e69de29 100644 --- a/providers/translation.py +++ b/providers/translation.py @@ -1,170 +0,0 @@ -""" -Translation provider - direct (in-process) or HTTP service. -""" - -from __future__ import annotations - -import logging -from typing import Any, Dict, List, Optional, Union - -from concurrent.futures import Future, ThreadPoolExecutor -import requests - -from config.services_config import get_translation_config, get_translation_base_url - -logger = logging.getLogger(__name__) - - -class HttpTranslationProvider: - """Translation via HTTP service.""" - - def __init__( - self, - base_url: str, - model: str = "qwen", - timeout_sec: float = 10.0, - translation_context: Optional[str] = None, - ): - self.base_url = (base_url or "").rstrip("/") - self.model = model or "qwen" - self.timeout_sec = float(timeout_sec or 10.0) - self.translation_context = translation_context or "e-commerce product search" - self.executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="http-translator") - - def _translate_once( - self, - text: str, - target_lang: str, - source_lang: Optional[str] = None, - ) -> Optional[str]: - if not text or not str(text).strip(): - return text - try: - url = f"{self.base_url}/translate" - payload = { - "text": text, - "target_lang": target_lang, - "source_lang": source_lang or "auto", - "model": self.model, - } - response = requests.post(url, json=payload, timeout=self.timeout_sec) - if response.status_code != 200: - logger.warning( - "HTTP translator failed: status=%s body=%s", - response.status_code, - (response.text or "")[:200], - ) - return None - data = response.json() - translated = data.get("translated_text") - return translated if translated is not None else None - except Exception as exc: - logger.warning("HTTP translator request failed: %s", exc, exc_info=True) - return None - - def translate( - self, - text: str, - target_lang: str, - source_lang: Optional[str] = None, - context: Optional[str] = None, - prompt: Optional[str] = None, - ) -> Optional[str]: - del context, prompt - result = self._translate_once(text=text, target_lang=target_lang, source_lang=source_lang) - return result if result is not None else text - - def translate_multi( - self, - text: str, - target_langs: List[str], - source_lang: Optional[str] = None, - context: Optional[str] = None, - async_mode: bool = True, - prompt: Optional[str] = None, - ) -> Dict[str, Optional[str]]: - del context, async_mode, prompt - out: Dict[str, Optional[str]] = {} - for lang in target_langs: - out[lang] = self.translate(text, lang, source_lang=source_lang) - return out - - def translate_multi_async( - self, - text: str, - target_langs: List[str], - source_lang: Optional[str] = None, - context: Optional[str] = None, - prompt: Optional[str] = None, - ) -> Dict[str, Union[str, Future]]: - del context, prompt - out: Dict[str, Union[str, Future]] = {} - for lang in target_langs: - out[lang] = self.executor.submit(self.translate, text, lang, source_lang) - return out - - def translate_for_indexing( - self, - text: str, - shop_language: str, - source_lang: Optional[str] = None, - context: Optional[str] = None, - prompt: Optional[str] = None, - index_languages: Optional[List[str]] = None, - ) -> Dict[str, Optional[str]]: - del context, prompt - langs = index_languages if index_languages else ["en", "zh"] - source = source_lang or shop_language or "auto" - out: Dict[str, Optional[str]] = {} - for lang in langs: - if lang == shop_language: - out[lang] = text - else: - out[lang] = self.translate(text, target_lang=lang, source_lang=source) - return out - - -def create_translation_provider(query_config: Any = None) -> Any: - """ - Create translation provider from services config. - - query_config: optional, for api_key/glossary_id/context (used by direct provider). - """ - cfg = get_translation_config() - provider = cfg.provider - pc = cfg.get_provider_cfg() - - if provider in ("direct", "local", "inprocess"): - from query.translator import Translator - model = pc.get("model") or "qwen" - qc = query_config or _empty_query_config() - return Translator( - model=model, - api_key=getattr(qc, "translation_api_key", None), - use_cache=True, - glossary_id=getattr(qc, "translation_glossary_id", None), - translation_context=getattr(qc, "translation_context", "e-commerce product search"), - ) - - if provider in ("http", "service"): - base_url = get_translation_base_url() - model = pc.get("model") or "qwen" - timeout = pc.get("timeout_sec", 10.0) - qc = query_config or _empty_query_config() - return HttpTranslationProvider( - base_url=base_url, - model=model, - timeout_sec=float(timeout), - translation_context=getattr(qc, "translation_context", "e-commerce product search"), - ) - - raise ValueError(f"Unsupported translation provider: {provider}") - - -def _empty_query_config() -> Any: - """Minimal object with default translation attrs.""" - class _QC: - translation_api_key = None - translation_glossary_id = None - translation_context = "e-commerce product search" - return _QC() diff --git a/query/llm_translate.py b/query/llm_translate.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/query/llm_translate.py diff --git a/query/translator.py b/query/translator.py index 1ad90e5..ee39071 100644 --- a/query/translator.py +++ b/query/translator.py @@ -5,6 +5,11 @@ Supports multiple translation models: - Qwen (default): Alibaba Cloud DashScope API using qwen-mt-flash model - DeepL: DeepL API for high-quality translations +重要说明(Qwen 机翻限速): +- 当前默认使用的 `qwen-mt-flash` 为云端机翻模型,**官方限速较低,约 RPM=60(每分钟约 60 请求)** +- 在高并发场景必须依赖 Redis 翻译缓存与批量预热,避免在用户实时请求路径上直接打满 DashScope 限流 +- 若业务侧存在大规模离线翻译或更高吞吐需求,建议评估 DeepL 或自建翻译后端 + 使用方法 (Usage): ```python diff --git a/reranker/backends/batching_utils.py b/reranker/backends/batching_utils.py deleted file mode 100644 index ed48f96..0000000 --- a/reranker/backends/batching_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Utilities for reranker batching and deduplication.""" - -from __future__ import annotations - -from typing import Iterable, List, Sequence, Tuple - - -def deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]: - """ - Deduplicate texts globally while preserving first-seen order. - - Returns: - unique_texts: deduplicated texts in first-seen order - position_to_unique: mapping from each original position to unique index - """ - unique_texts: List[str] = [] - position_to_unique: List[int] = [] - seen: dict[str, int] = {} - - for text in texts: - idx = seen.get(text) - if idx is None: - idx = len(unique_texts) - seen[text] = idx - unique_texts.append(text) - position_to_unique.append(idx) - - return unique_texts, position_to_unique - - -def sort_indices_by_length(lengths: Sequence[int]) -> List[int]: - """Return stable ascending indices by lengths.""" - return sorted(range(len(lengths)), key=lambda i: lengths[i]) - - -def iter_batches(indices: Sequence[int], batch_size: int) -> Iterable[List[int]]: - """Yield consecutive batches from indices.""" - if batch_size <= 0: - raise ValueError(f"batch_size must be > 0, got {batch_size}") - for i in range(0, len(indices), batch_size): - yield list(indices[i : i + batch_size]) diff --git a/reranker/backends/dashscope_rerank.py b/reranker/backends/dashscope_rerank.py index eb5ef24..fefa67c 100644 --- a/reranker/backends/dashscope_rerank.py +++ b/reranker/backends/dashscope_rerank.py @@ -21,11 +21,33 @@ from typing import Any, Dict, List, Tuple from urllib import error as urllib_error from urllib import request as urllib_request -from reranker.backends.batching_utils import deduplicate_with_positions, iter_batches logger = logging.getLogger("reranker.backends.dashscope_rerank") +def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: + """ + Deduplicate texts globally while preserving first-seen order. + + Returns: + unique_texts: deduplicated texts in first-seen order + position_to_unique: mapping from each original position to unique index + """ + unique_texts: List[str] = [] + position_to_unique: List[int] = [] + seen: Dict[str, int] = {} + + for text in texts: + idx = seen.get(text) + if idx is None: + idx = len(unique_texts) + seen[text] = idx + unique_texts.append(text) + position_to_unique.append(idx) + + return unique_texts, position_to_unique + + class DashScopeRerankBackend: """ DashScope cloud reranker backend. @@ -206,7 +228,10 @@ class DashScopeRerankBackend: then apply global top_n/top_n_cap truncation after merge if needed. """ indices = list(range(len(unique_texts))) - batches = list(iter_batches(indices, batch_size=self._batchsize)) + batches = [ + indices[i : i + self._batchsize] + for i in range(0, len(indices), self._batchsize) + ] num_batches = len(batches) max_workers = min(8, num_batches) if num_batches > 0 else 1 unique_scores: List[float] = [0.0] * len(unique_texts) diff --git a/reranker/backends/qwen3_vllm.py b/reranker/backends/qwen3_vllm.py index 9ca830d..4b23f8d 100644 --- a/reranker/backends/qwen3_vllm.py +++ b/reranker/backends/qwen3_vllm.py @@ -14,12 +14,6 @@ import threading import time from typing import Any, Dict, List, Tuple -from reranker.backends.batching_utils import ( - deduplicate_with_positions, - iter_batches, - sort_indices_by_length, -) - logger = logging.getLogger("reranker.backends.qwen3_vllm") try: @@ -34,6 +28,29 @@ except ImportError as e: ) from e +def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: + """ + Deduplicate texts globally while preserving first-seen order. + + Returns: + unique_texts: deduplicated texts in first-seen order + position_to_unique: mapping from each original position to unique index + """ + unique_texts: List[str] = [] + position_to_unique: List[int] = [] + seen: Dict[str, int] = {} + + for text in texts: + idx = seen.get(text) + if idx is None: + idx = len(unique_texts) + seen[text] = idx + unique_texts.append(text) + position_to_unique.append(idx) + + return unique_texts, position_to_unique + + def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: """Build chat messages for one (query, doc) pair.""" return [ @@ -71,19 +88,17 @@ class Qwen3VLLMRerankerBackend: sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") if sort_by_doc_length is None: sort_by_doc_length = self._config.get("sort_by_doc_length", True) - length_sort_mode = os.getenv("RERANK_VLLM_LENGTH_SORT_MODE") or self._config.get("length_sort_mode", "char") self._infer_batch_size = int(infer_batch_size) self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in {"1", "true", "yes", "y", "on"} - self._length_sort_mode = str(length_sort_mode).strip().lower() if not torch.cuda.is_available(): raise RuntimeError("qwen3_vllm backend requires CUDA GPU, but torch.cuda.is_available() is False") if dtype not in {"float16", "half", "auto"}: raise ValueError(f"Unsupported dtype for qwen3_vllm: {dtype!r}. Use float16/half/auto.") if self._infer_batch_size <= 0: - raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}") - if self._length_sort_mode not in {"char", "token"}: - raise ValueError(f"length_sort_mode must be 'char' or 'token', got {self._length_sort_mode!r}") + raise ValueError( + f"infer_batch_size must be > 0, got {self._infer_batch_size}" + ) logger.info( "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", @@ -196,21 +211,7 @@ class Qwen3VLLMRerankerBackend: """ if not docs: return [] - if self._length_sort_mode == "char": - return [len(text) for text in docs] - try: - enc = self._tokenizer( - docs, - add_special_tokens=False, - truncation=True, - max_length=self._max_prompt_len, - return_length=True, - ) - lengths = enc.get("length") - if isinstance(lengths, list) and len(lengths) == len(docs): - return [int(x) for x in lengths] - except Exception as exc: - logger.debug("Length estimation fallback to char length: %s", exc) + # Use simple character length to approximate document length. return [len(text) for text in docs] def score_with_meta( @@ -247,7 +248,6 @@ class Qwen3VLLMRerankerBackend: "infer_batch_size": self._infer_batch_size, "inference_batches": 0, "sort_by_doc_length": self._sort_by_doc_length, - "length_sort_mode": self._length_sort_mode, } # Deduplicate globally by text, keep mapping to original indices. @@ -257,11 +257,12 @@ class Qwen3VLLMRerankerBackend: lengths = self._estimate_doc_lengths(unique_texts) order = list(range(len(unique_texts))) if self._sort_by_doc_length and len(unique_texts) > 1: - order = sort_indices_by_length(lengths) + order = sorted(order, key=lambda i: lengths[i]) unique_scores: List[float] = [0.0] * len(unique_texts) inference_batches = 0 - for batch_indices in iter_batches(order, self._infer_batch_size): + for start in range(0, len(order), self._infer_batch_size): + batch_indices = order[start : start + self._infer_batch_size] inference_batches += 1 pairs = [(query, unique_texts[i]) for i in batch_indices] prompts = self._process_inputs(pairs) diff --git a/scripts/service_ctl.sh b/scripts/service_ctl.sh index bdf7f06..5887483 100755 --- a/scripts/service_ctl.sh +++ b/scripts/service_ctl.sh @@ -684,6 +684,7 @@ status_one() { local pid_info="-" local health="down" local health_body="" + local curl_timeout_opts=(--connect-timeout 8 --max-time 8) if [ "${service}" = "tei" ]; then local cid @@ -696,7 +697,7 @@ status_one() { local path path="$(health_path_for_service "${service}")" if [ -n "${port}" ] && [ -n "${path}" ]; then - if health_body="$(curl -fsS "http://127.0.0.1:${port}${path}" 2>/dev/null)"; then + if health_body="$(curl -fsS "${curl_timeout_opts[@]}" "http://127.0.0.1:${port}${path}" 2>/dev/null)"; then health="ok" else health="fail" @@ -723,7 +724,7 @@ status_one() { local path path="$(health_path_for_service "${service}")" if [ -n "${port}" ] && [ -n "${path}" ]; then - if health_body="$(curl -fsS "http://127.0.0.1:${port}${path}" 2>/dev/null)"; then + if health_body="$(curl -fsS "${curl_timeout_opts[@]}" "http://127.0.0.1:${port}${path}" 2>/dev/null)"; then health="ok" else health="fail" diff --git a/search/searcher.py b/search/searcher.py index 6c4057e..e1641fb 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -598,7 +598,6 @@ class Searcher: es_hits = [] if 'hits' in es_response and 'hits' in es_response['hits']: es_hits = es_response['hits']['hits'] - # Extract total and max_score total = es_response.get('hits', {}).get('total', {}) if isinstance(total, dict): @@ -616,6 +615,37 @@ class Searcher: sku_filter_dimension=sku_filter_dimension ) + # Build per-result debug info (per SPU) when debug mode is enabled + per_result_debug = [] + if debug and es_hits and formatted_results: + for hit, spu in zip(es_hits, formatted_results): + source = hit.get("_source", {}) or {} + raw_score = hit.get("_score") + try: + es_score = float(raw_score) if raw_score is not None else 0.0 + except (TypeError, ValueError): + es_score = 0.0 + try: + normalized = float(es_score) / float(max_score) if max_score else None + except (TypeError, ValueError, ZeroDivisionError): + normalized = None + + title_multilingual = source.get("title") if isinstance(source.get("title"), dict) else None + brief_multilingual = source.get("brief") if isinstance(source.get("brief"), dict) else None + vendor_multilingual = source.get("vendor") if isinstance(source.get("vendor"), dict) else None + + per_result_debug.append( + { + "spu_id": spu.spu_id, + "es_id": hit.get("_id"), + "es_score": es_score, + "es_score_normalized": normalized, + "title_multilingual": title_multilingual, + "brief_multilingual": brief_multilingual, + "vendor_multilingual": vendor_multilingual, + } + ) + # Format facets standardized_facets = None if facets: @@ -676,6 +706,8 @@ class Searcher: }, "search_params": context.metadata.get('search_params', {}) } + if per_result_debug: + debug_info["per_result"] = per_result_debug # Build result result = SearchResult( diff --git a/tests/test_reranker_batching_utils.py b/tests/test_reranker_batching_utils.py deleted file mode 100644 index 709d327..0000000 --- a/tests/test_reranker_batching_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import pytest - -from reranker.backends.batching_utils import ( - deduplicate_with_positions, - iter_batches, - sort_indices_by_length, -) - - -def test_deduplicate_with_positions_global_not_adjacent(): - texts = ["a", "b", "a", "c", "b", "a"] - unique, mapping = deduplicate_with_positions(texts) - assert unique == ["a", "b", "c"] - assert mapping == [0, 1, 0, 2, 1, 0] - - -def test_sort_indices_by_length_stable(): - lengths = [5, 2, 2, 9, 4] - order = sort_indices_by_length(lengths) - # Stable sort: index 1 remains ahead of index 2 when lengths are equal. - assert order == [1, 2, 4, 0, 3] - - -def test_iter_batches(): - indices = list(range(10)) - batches = list(iter_batches(indices, 4)) - assert batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] - - -def test_iter_batches_invalid_batch_size(): - with pytest.raises(ValueError, match="batch_size must be > 0"): - list(iter_batches([0, 1], 0)) -- libgit2 0.21.2