From a47416ec0d2f11e71a6e1b7f24c1dfe34bf48cf1 Mon Sep 17 00:00:00 2001 From: tangwang Date: Wed, 18 Mar 2026 10:24:05 +0800 Subject: [PATCH] 把融合逻辑改成乘法公式,并把 ES 命名子句分数回传链路补上了。 --- api/routes/indexer.py | 30 ++++++++++++++++++++++++------ docs/搜索API对接指南.md | 16 +++++++++++----- indexer/product_enrich.py | 101 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------- indexer/product_enrich_prompts.py | 8 ++++---- search/es_query_builder.py | 3 ++- search/rerank_client.py | 64 +++++++++++++++++++++++++++++++++++++++++++++------------------- search/searcher.py | 4 +++- tests/test_es_query_builder.py | 1 + tests/test_product_enrich_partial_mode.py | 81 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------ tests/test_rerank_client.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_search_rerank_window.py | 20 ++++++++++++++++++-- utils/es_client.py | 2 ++ 12 files changed, 322 insertions(+), 61 deletions(-) create mode 100644 tests/test_rerank_client.py diff --git a/api/routes/indexer.py b/api/routes/indexer.py index d50f578..65bfb3a 100644 --- a/api/routes/indexer.py +++ b/api/routes/indexer.py @@ -7,7 +7,7 @@ import asyncio import re from fastapi import APIRouter, HTTPException -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field import logging from sqlalchemy import text @@ -78,9 +78,12 @@ class BuildDocsFromDbRequest(BaseModel): class EnrichContentItem(BaseModel): - """单条待生成内容理解字段的商品(仅需 spu_id + 标题)。""" + """单条待生成内容理解字段的商品。""" spu_id: str = Field(..., description="SPU ID") title: str = Field(..., description="商品标题,用于 LLM 分析生成 qanchors / tags 等") + image_url: Optional[str] = Field(None, description="商品主图 URL(预留给多模态/内容理解扩展)") + brief: Optional[str] = Field(None, description="商品简介/短描述") + description: Optional[str] = Field(None, description="商品详情/长描述") class EnrichContentRequest(BaseModel): @@ -88,8 +91,8 @@ class EnrichContentRequest(BaseModel): 内容理解字段生成请求:根据商品标题批量生成 qanchors、semantic_attributes、tags。 供外部 indexer 在自行组织 doc 时调用,与翻译、向量化等微服务并列。 """ - tenant_id: str = Field(..., description="租户 ID,用于缓存隔离") - items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title)") + tenant_id: str = Field(..., description="租户 ID,用于请求路由与结果归属,不参与缓存键") + items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title,可附带 brief/description/image_url)") languages: List[str] = Field( default_factory=lambda: ["zh", "en"], description="目标语言列表,需在支持范围内(zh/en/de/ru/fr),默认 zh, en", @@ -450,7 +453,16 @@ def _run_enrich_content(tenant_id: str, items: List[Dict[str, str]], languages: llm_langs = list(dict.fromkeys(languages)) or ["en"] - products = [{"id": it["spu_id"], "title": (it.get("title") or "").strip()} for it in items] + products = [ + { + "id": it["spu_id"], + "title": (it.get("title") or "").strip(), + "brief": (it.get("brief") or "").strip(), + "description": (it.get("description") or "").strip(), + "image_url": (it.get("image_url") or "").strip(), + } + for it in items + ] dim_keys = [ "tags", "target_audience", @@ -545,7 +557,13 @@ async def enrich_content(request: EnrichContentRequest): ) items_payload = [ - {"spu_id": it.spu_id, "title": it.title or ""} + { + "spu_id": it.spu_id, + "title": it.title or "", + "brief": it.brief or "", + "description": it.description or "", + "image_url": it.image_url or "", + } for it in request.items ] loop = asyncio.get_event_loop() diff --git a/docs/搜索API对接指南.md b/docs/搜索API对接指南.md index b72b9ac..ece64e4 100644 --- a/docs/搜索API对接指南.md +++ b/docs/搜索API对接指南.md @@ -1511,7 +1511,7 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \ | 参数 | 类型 | 必填 | 默认值 | 说明 | |------|------|------|--------|------| -| `tenant_id` | string | Y | - | 租户 ID,用于缓存隔离 | +| `tenant_id` | string | Y | - | 租户 ID。目前仅用于记录日志,不产生实际作用| | `items` | array | Y | - | 待分析列表;**单次最多 50 条** | | `languages` | array[string] | N | `["zh", "en"]` | 目标语言,需在支持范围内:`zh`、`en`、`de`、`ru`、`fr` | @@ -1519,11 +1519,17 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \ | 字段 | 类型 | 必填 | 说明 | |------|------|------|------| -| `spu_id` | string | Y | SPU ID,用于回填结果与缓存键 | +| `spu_id` | string | Y | SPU ID,用于回填结果;目前仅用于记录日志,不产生实际作用| | `title` | string | Y | 商品标题 | -| `image_url` | string | N | 商品主图 URL(预留:后续可用于图像/多模态内容理解) | -| `brief` | string | N | 商品简介/短描述(预留) | -| `description` | string | N | 商品详情/长描述(预留) | +| `image_url` | string | N | 商品主图 URL;当前会参与内容缓存键,后续可用于图像/多模态内容理解 | +| `brief` | string | N | 商品简介/短描述;当前会参与内容缓存键 | +| `description` | string | N | 商品详情/长描述;当前会参与内容缓存键 | + +缓存说明: + +- 内容缓存键仅由 `target_lang + items[]` 中会影响内容理解结果的输入文本构成,目前包括:`title`、`brief`、`description`、`image_url` 的规范化内容 hash。 +- `tenant_id`、`spu_id` 只用于请求归属与结果回填,不参与缓存键。 +- 因此,输入内容不变时可跨请求直接命中缓存;任一输入字段变化时,会自然落到新的缓存 key。 批量请求建议: - **全量**:强烈建议 尽可能 **20 个 SPU/doc** 攒成一个批次后再请求一次。 diff --git a/indexer/product_enrich.py b/indexer/product_enrich.py index 445cfeb..35e3567 100644 --- a/indexer/product_enrich.py +++ b/indexer/product_enrich.py @@ -9,6 +9,7 @@ import os import json import logging +import re import time import hashlib from collections import OrderedDict @@ -40,6 +41,10 @@ MAX_RETRIES = 3 RETRY_DELAY = 5 # 秒 REQUEST_TIMEOUT = 180 # 秒 LOGGED_SHARED_CONTEXT_CACHE_SIZE = 256 +PROMPT_INPUT_MIN_ZH_CHARS = 20 +PROMPT_INPUT_MAX_ZH_CHARS = 100 +PROMPT_INPUT_MIN_WORDS = 16 +PROMPT_INPUT_MAX_WORDS = 80 # 日志路径 OUTPUT_DIR = Path("output_logs") @@ -82,6 +87,8 @@ if not verbose_logger.handlers: verbose_logger.addHandler(verbose_file_handler) verbose_logger.propagate = False +logger.info("Verbose LLM logs are written to: %s", verbose_log_file) + # Redis 缓存(用于 anchors / 语义属性) ANCHOR_CACHE_PREFIX = REDIS_CONFIG.get("anchor_cache_prefix", "product_anchors") @@ -112,26 +119,86 @@ if _missing_prompt_langs: ) +def _normalize_space(text: str) -> str: + return re.sub(r"\s+", " ", (text or "").strip()) + + +def _contains_cjk(text: str) -> bool: + return bool(re.search(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]", text or "")) + + +def _truncate_by_chars(text: str, max_chars: int) -> str: + return text[:max_chars].strip() + + +def _truncate_by_words(text: str, max_words: int) -> str: + words = re.findall(r"\S+", text or "") + return " ".join(words[:max_words]).strip() + + +def _detect_prompt_input_lang(text: str) -> str: + # 简化处理:包含 CJK 时按中文类文本处理,否则统一按空格分词类语言处理。 + return "zh" if _contains_cjk(text) else "en" + + +def _build_prompt_input_text(product: Dict[str, Any]) -> str: + """ + 生成真正送入 prompt 的商品文本。 + + 规则: + - 默认使用 title + - 若文本过短,则依次补 brief / description + - 若文本过长,则按语言粗粒度截断 + """ + fields = [ + _normalize_space(str(product.get("title") or "")), + _normalize_space(str(product.get("brief") or "")), + _normalize_space(str(product.get("description") or "")), + ] + parts: List[str] = [] + + def join_parts() -> str: + return " | ".join(part for part in parts if part).strip() + + for field in fields: + if not field: + continue + if field not in parts: + parts.append(field) + candidate = join_parts() + if _detect_prompt_input_lang(candidate) == "zh": + if len(candidate) >= PROMPT_INPUT_MIN_ZH_CHARS: + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS) + else: + if len(re.findall(r"\S+", candidate)) >= PROMPT_INPUT_MIN_WORDS: + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS) + + candidate = join_parts() + if not candidate: + return "" + if _detect_prompt_input_lang(candidate) == "zh": + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS) + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS) + + def _make_anchor_cache_key( - title: str, + product: Dict[str, Any], target_lang: str, - tenant_id: Optional[str] = None, ) -> str: - """构造 anchors/语义属性的缓存 key。""" - base = (tenant_id or "global").strip() - h = hashlib.md5(title.encode("utf-8")).hexdigest() - return f"{ANCHOR_CACHE_PREFIX}:{base}:{target_lang}:{h}" + """构造缓存 key,仅由 prompt 实际输入文本内容 + 目标语言决定。""" + prompt_input = _build_prompt_input_text(product) + h = hashlib.md5(prompt_input.encode("utf-8")).hexdigest() + return f"{ANCHOR_CACHE_PREFIX}:{target_lang}:{prompt_input[:4]}{h}" def _get_cached_anchor_result( - title: str, + product: Dict[str, Any], target_lang: str, - tenant_id: Optional[str] = None, ) -> Optional[Dict[str, Any]]: if not _anchor_redis: return None try: - key = _make_anchor_cache_key(title, target_lang, tenant_id) + key = _make_anchor_cache_key(product, target_lang) raw = _anchor_redis.get(key) if not raw: return None @@ -142,15 +209,14 @@ def _get_cached_anchor_result( def _set_cached_anchor_result( - title: str, + product: Dict[str, Any], target_lang: str, result: Dict[str, Any], - tenant_id: Optional[str] = None, ) -> None: if not _anchor_redis: return try: - key = _make_anchor_cache_key(title, target_lang, tenant_id) + key = _make_anchor_cache_key(product, target_lang) ttl = ANCHOR_CACHE_EXPIRE_DAYS * 24 * 3600 _anchor_redis.setex(key, ttl, json.dumps(result, ensure_ascii=False)) except Exception as e: @@ -166,7 +232,8 @@ def _build_assistant_prefix(headers: List[str]) -> str: def _build_shared_context(products: List[Dict[str, str]]) -> str: shared_context = SHARED_ANALYSIS_INSTRUCTION for idx, product in enumerate(products, 1): - shared_context += f'{idx}. {product["title"]}\n' + prompt_input = _build_prompt_input_text(product) + shared_context += f"{idx}. {prompt_input}\n" return shared_context @@ -619,11 +686,11 @@ def analyze_products( uncached_items.append((idx, product)) continue - cached = _get_cached_anchor_result(title, target_lang, tenant_id=tenant_id) + cached = _get_cached_anchor_result(product, target_lang) if cached: logger.info( f"[analyze_products] Cache hit for title='{title[:50]}...', " - f"lang={target_lang}, tenant_id={tenant_id or 'global'}" + f"lang={target_lang}" ) results_by_index[idx] = cached continue @@ -650,7 +717,7 @@ def analyze_products( ) batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) - for (original_idx, _), item in zip(batch_slice, batch_results): + for (original_idx, product), item in zip(batch_slice, batch_results): results_by_index[original_idx] = item title_input = str(item.get("title_input") or "").strip() if not title_input: @@ -659,7 +726,7 @@ def analyze_products( # 不缓存错误结果,避免放大临时故障 continue try: - _set_cached_anchor_result(title_input, target_lang, item, tenant_id=tenant_id) + _set_cached_anchor_result(product, target_lang, item) except Exception: # 已在内部记录 warning pass diff --git a/indexer/product_enrich_prompts.py b/indexer/product_enrich_prompts.py index 44ae3c4..451925d 100644 --- a/indexer/product_enrich_prompts.py +++ b/indexer/product_enrich_prompts.py @@ -8,9 +8,9 @@ SYSTEM_MESSAGE = ( "Do not repeat or modify the prefix, and do not add explanations outside the table." ) -SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these columns: +SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns: -1. Product title: a natural localized product name derived from the input title +1. Product title: a natural localized product name derived from the input product text 2. Category path: broad to fine-grained category, separated by ">" 3. Fine-grained tags: style, features, functions, or notable attributes 4. Target audience: gender, age group, or suitable users @@ -23,7 +23,7 @@ SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these Rules: - Keep the input order and row count exactly the same. -- Infer from the title only; if uncertain, prefer concise and broadly correct ecommerce wording. +- Infer only from the provided input product text; if uncertain, prefer concise and broadly correct ecommerce wording. - Keep category paths concise and use ">" as the separator. - For columns with multiple values, the localized output requirement will define the delimiter. @@ -515,4 +515,4 @@ LANGUAGE_MARKDOWN_TABLE_HEADERS: Dict[str, Dict[str, Any]] = { "Характеристики", "Анкор текст" ] -} \ No newline at end of file +} diff --git a/search/es_query_builder.py b/search/es_query_builder.py index 8620a1f..854db86 100644 --- a/search/es_query_builder.py +++ b/search/es_query_builder.py @@ -275,7 +275,8 @@ class ESQueryBuilder: "query_vector": query_vector.tolist(), "k": knn_k, "num_candidates": knn_num_candidates, - "boost": knn_boost + "boost": knn_boost, + "name": "knn_query", } # Top-level knn does not inherit query.bool.filter automatically. # Apply conjunctive + range filters here so vector recall respects hard filters. diff --git a/search/rerank_client.py b/search/rerank_client.py index 24c3686..7b3b0ed 100644 --- a/search/rerank_client.py +++ b/search/rerank_client.py @@ -4,7 +4,7 @@ 流程: 1. 从 ES hits 构造用于重排的文档文本列表 2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 -3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序 +3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序 """ from typing import Dict, Any, List, Optional, Tuple @@ -14,7 +14,7 @@ from providers import create_rerank_provider logger = logging.getLogger(__name__) -# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1) +# 历史配置项,保留签名兼容;当前乘法融合公式不再使用线性权重。 DEFAULT_WEIGHT_ES = 0.4 DEFAULT_WEIGHT_AI = 0.6 # 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) @@ -103,18 +103,20 @@ def fuse_scores_and_resort( weight_ai: float = DEFAULT_WEIGHT_AI, ) -> List[Dict[str, Any]]: """ - 将 ES 分数与重排分数线性融合(不修改原始 _score),并按融合分数降序重排。 + 将 ES 分数与重排分数按乘法公式融合(不修改原始 _score),并按融合分数降序重排。 对每条 hit 会写入: - _original_score: 原始 ES 分数 - _rerank_score: 重排服务返回的分数 - _fused_score: 融合分数 + - _text_score: 文本相关性分数(优先取 named queries 的 base_query 分数) + - _knn_score: KNN 分数(优先取 named queries 的 knn_query 分数) Args: es_hits: ES hits 列表(会被原地修改) rerank_scores: 与 es_hits 等长的重排分数列表 - weight_es: ES 归一化分数权重 - weight_ai: 重排分数权重 + weight_es: 兼容保留,当前未使用 + weight_ai: 兼容保留,当前未使用 Returns: 每条文档的融合调试信息列表,用于 debug_info @@ -123,38 +125,62 @@ def fuse_scores_and_resort( if n == 0 or len(rerank_scores) != n: return [] - # 收集 ES 原始分数 - es_scores: List[float] = [] - for hit in es_hits: - raw = hit.get("_score") - try: - es_scores.append(float(raw) if raw is not None else 0.0) - except (TypeError, ValueError): - es_scores.append(0.0) - - max_es = max(es_scores) if es_scores else 0.0 fused_debug: List[Dict[str, Any]] = [] for idx, hit in enumerate(es_hits): - es_score = es_scores[idx] + raw_es_score = hit.get("_score") + try: + es_score = float(raw_es_score) if raw_es_score is not None else 0.0 + except (TypeError, ValueError): + es_score = 0.0 + ai_score_raw = rerank_scores[idx] try: rerank_score = float(ai_score_raw) except (TypeError, ValueError): rerank_score = 0.0 - es_norm = (es_score / max_es) if max_es > 0 else 0.0 - fused = weight_es * es_norm + weight_ai * rerank_score + matched_queries = hit.get("matched_queries") + text_score = 0.0 + knn_score = 0.0 + if isinstance(matched_queries, dict): + try: + text_score = float(matched_queries.get("base_query", 0.0) or 0.0) + except (TypeError, ValueError): + text_score = 0.0 + try: + knn_score = float(matched_queries.get("knn_query", 0.0) or 0.0) + except (TypeError, ValueError): + knn_score = 0.0 + elif isinstance(matched_queries, list): + text_score = 1.0 if "base_query" in matched_queries else 0.0 + knn_score = 1.0 if "knn_query" in matched_queries else 0.0 + + # 回退逻辑: + # - text_score 缺失时,退回原始 _score,避免纯文本召回被错误打成 0。 + # - knn_score 缺失时保持 0,由平滑项 0.6 兜底。 + if text_score <= 0.0: + text_score = es_score + + fused = ( + (rerank_score + 0.00001) ** 1.0 * + (knn_score + 0.6) ** 0.2 * + (text_score + 0.1) ** 0.75 + ) hit["_original_score"] = hit.get("_score") hit["_rerank_score"] = rerank_score + hit["_text_score"] = text_score + hit["_knn_score"] = knn_score hit["_fused_score"] = fused fused_debug.append({ "doc_id": hit.get("_id"), "es_score": es_score, - "es_score_norm": es_norm, "rerank_score": rerank_score, + "text_score": text_score, + "knn_score": knn_score, + "matched_queries": matched_queries, "fused_score": fused, }) diff --git a/search/searcher.py b/search/searcher.py index 33e57f0..ddd2910 100644 --- a/search/searcher.py +++ b/search/searcher.py @@ -400,6 +400,7 @@ class Searcher: # Add sorting if specified if sort_by: es_query = self.query_builder.add_sorting(es_query, sort_by, sort_order) + es_query["track_scores"] = True # Keep requested response _source semantics for the final response fill. response_source_spec = es_query.get("_source") @@ -467,7 +468,8 @@ class Searcher: index_name=index_name, body=body_for_es, size=es_fetch_size, - from_=es_fetch_from + from_=es_fetch_from, + include_named_queries_score=bool(do_rerank and in_rerank_window), ) # Store ES response in context diff --git a/tests/test_es_query_builder.py b/tests/test_es_query_builder.py index 82d940b..5cac1a6 100644 --- a/tests/test_es_query_builder.py +++ b/tests/test_es_query_builder.py @@ -62,3 +62,4 @@ def test_knn_prefilter_not_added_without_filters(): assert "knn" in q assert "filter" not in q["knn"] + assert q["knn"]["name"] == "knn_query" diff --git a/tests/test_product_enrich_partial_mode.py b/tests/test_product_enrich_partial_mode.py index f7dbb3b..705cec5 100644 --- a/tests/test_product_enrich_partial_mode.py +++ b/tests/test_product_enrich_partial_mode.py @@ -62,7 +62,7 @@ def test_create_prompt_splits_shared_context_and_localized_tail(): shared_en, user_en, prefix_en = product_enrich.create_prompt(products, target_lang="en") assert shared_zh == shared_en - assert "Analyze each input product title" in shared_zh + assert "Analyze each input product text" in shared_zh assert "1. dress" in shared_zh assert "2. linen shirt" in shared_zh assert "Product list" not in user_zh @@ -232,11 +232,20 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): cache_store = {} process_calls = [] - def fake_get_cached_anchor_result(title, target_lang, tenant_id=None): - return cache_store.get((tenant_id, target_lang, title)) + def _cache_key(product, target_lang): + return ( + target_lang, + product.get("title", ""), + product.get("brief", ""), + product.get("description", ""), + product.get("image_url", ""), + ) + + def fake_get_cached_anchor_result(product, target_lang): + return cache_store.get(_cache_key(product, target_lang)) - def fake_set_cached_anchor_result(title, target_lang, result, tenant_id=None): - cache_store[(tenant_id, target_lang, title)] = result + def fake_set_cached_anchor_result(product, target_lang, result): + cache_store[_cache_key(product, target_lang)] = result def fake_process_batch(batch_data, batch_num, target_lang="zh"): process_calls.append( @@ -291,7 +300,7 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): second = product_enrich.analyze_products( products, target_lang="zh", - tenant_id="170", + tenant_id="999", ) third = product_enrich.analyze_products( products, @@ -311,3 +320,63 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): assert second[1]["anchor_text"] == "anchor:shirt" assert third[0]["anchor_text"] == "anchor:dress" assert third[1]["anchor_text"] == "anchor:shirt" + + +def test_anchor_cache_key_depends_on_product_input_not_identifiers(): + product_a = { + "id": "1", + "spu_id": "1001", + "title": "dress", + "brief": "soft cotton", + "description": "summer dress", + "image_url": "https://img/a.jpg", + } + product_b = { + "id": "2", + "spu_id": "9999", + "title": "dress", + "brief": "soft cotton", + "description": "summer dress", + "image_url": "https://img/a.jpg", + } + product_c = { + "id": "1", + "spu_id": "1001", + "title": "dress", + "brief": "soft cotton updated", + "description": "summer dress", + "image_url": "https://img/a.jpg", + } + + key_a = product_enrich._make_anchor_cache_key(product_a, "zh") + key_b = product_enrich._make_anchor_cache_key(product_b, "zh") + key_c = product_enrich._make_anchor_cache_key(product_c, "zh") + + assert key_a == key_b + assert key_a != key_c + + +def test_build_prompt_input_text_appends_brief_and_description_for_short_title(): + product = { + "title": "T恤", + "brief": "夏季透气纯棉短袖,舒适亲肤", + "description": "100%棉,圆领版型,适合日常通勤与休闲穿搭。", + } + + text = product_enrich._build_prompt_input_text(product) + + assert text.startswith("T恤") + assert "夏季透气纯棉短袖" in text + assert "100%棉" in text + + +def test_build_prompt_input_text_truncates_non_cjk_by_words(): + product = { + "title": "dress", + "brief": " ".join(f"brief{i}" for i in range(50)), + "description": " ".join(f"desc{i}" for i in range(50)), + } + + text = product_enrich._build_prompt_input_text(product) + + assert len(text.split()) <= product_enrich.PROMPT_INPUT_MAX_WORDS diff --git a/tests/test_rerank_client.py b/tests/test_rerank_client.py new file mode 100644 index 0000000..c83cb79 --- /dev/null +++ b/tests/test_rerank_client.py @@ -0,0 +1,53 @@ +from math import isclose + +from search.rerank_client import fuse_scores_and_resort + + +def test_fuse_scores_and_resort_uses_multiplicative_formula_with_named_query_scores(): + hits = [ + { + "_id": "1", + "_score": 3.2, + "matched_queries": { + "base_query": 2.4, + "knn_query": 0.8, + }, + }, + { + "_id": "2", + "_score": 2.8, + "matched_queries": { + "base_query": 1.6, + "knn_query": 0.2, + }, + }, + ] + + debug = fuse_scores_and_resort(hits, [0.9, 0.7]) + + expected_1 = (0.9 + 0.00001) * ((0.8 + 0.6) ** 0.2) * ((2.4 + 0.1) ** 0.75) + expected_2 = (0.7 + 0.00001) * ((0.2 + 0.6) ** 0.2) * ((1.6 + 0.1) ** 0.75) + + assert isclose(hits[0]["_fused_score"], expected_1, rel_tol=1e-9) + assert isclose(hits[1]["_fused_score"], expected_2, rel_tol=1e-9) + assert debug[0]["text_score"] == 2.4 + assert debug[0]["knn_score"] == 0.8 + assert [hit["_id"] for hit in hits] == ["1", "2"] + + +def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing(): + hits = [ + {"_id": "1", "_score": 0.5}, + {"_id": "2", "_score": 2.0}, + ] + + fuse_scores_and_resort(hits, [0.4, 0.3]) + + expected_1 = (0.4 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((0.5 + 0.1) ** 0.75) + expected_2 = (0.3 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((2.0 + 0.1) ** 0.75) + + assert isclose(hits[0]["_text_score"], 2.0, rel_tol=1e-9) + assert isclose(hits[0]["_fused_score"], expected_2, rel_tol=1e-9) + assert isclose(hits[1]["_text_score"], 0.5, rel_tol=1e-9) + assert isclose(hits[1]["_fused_score"], expected_1, rel_tol=1e-9) + assert [hit["_id"] for hit in hits] == ["2", "1"] diff --git a/tests/test_search_rerank_window.py b/tests/test_search_rerank_window.py index f7e5c3f..bf23c9d 100644 --- a/tests/test_search_rerank_window.py +++ b/tests/test_search_rerank_window.py @@ -97,9 +97,22 @@ class _FakeESClient: "skus": [], } - def search(self, index_name: str, body: Dict[str, Any], size: int, from_: int): + def search( + self, + index_name: str, + body: Dict[str, Any], + size: int, + from_: int, + include_named_queries_score: bool = False, + ): self.calls.append( - {"index_name": index_name, "body": body, "size": size, "from_": from_} + { + "index_name": index_name, + "body": body, + "size": size, + "from_": from_, + "include_named_queries_score": include_named_queries_score, + } ) ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values") source_spec = (body or {}).get("_source") @@ -213,6 +226,7 @@ def test_searcher_reranks_top_window_by_default(monkeypatch): assert called["docs"] == window assert es_client.calls[0]["from_"] == 0 assert es_client.calls[0]["size"] == window + assert es_client.calls[0]["include_named_queries_score"] is True assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]} assert len(es_client.calls) == 2 assert es_client.calls[1]["size"] == 10 @@ -277,6 +291,7 @@ def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): assert called["count"] == 0 assert es_client.calls[0]["from_"] == 20 assert es_client.calls[0]["size"] == 10 + assert es_client.calls[0]["include_named_queries_score"] is False assert len(es_client.calls) == 1 @@ -310,4 +325,5 @@ def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch): assert called["count"] == 0 assert es_client.calls[0]["from_"] == 995 assert es_client.calls[0]["size"] == 10 + assert es_client.calls[0]["include_named_queries_score"] is False assert len(es_client.calls) == 1 diff --git a/utils/es_client.py b/utils/es_client.py index 248fc36..81c4564 100644 --- a/utils/es_client.py +++ b/utils/es_client.py @@ -228,6 +228,7 @@ class ESClient: size: int = 10, from_: int = 0, routing: Optional[str] = None, + include_named_queries_score: bool = False, ) -> Dict[str, Any]: """ Execute search query. @@ -260,6 +261,7 @@ class ESClient: size=size, from_=from_, routing=routing, + include_named_queries_score=include_named_queries_score, ) # elasticsearch-py 8.x returns ObjectApiResponse; normalize to mutable dict # so caller can safely patch hits/took during post-processing. -- libgit2 0.21.2