Commit a47416ec0d2f11e71a6e1b7f24c1dfe34bf48cf1
1 parent
76e1f088
把融合逻辑改成乘法公式,并把 ES 命名子句分数回传链路补上了。
核心改动在 rerank_client.py (line 99):fuse_scores_and_resort 现在按 rerank * knn * text 的平滑乘法公式计算,优先从 hit["matched_queries"] 里取 base_query 和 knn_query,并把 _text_score / _knn_score 一并写回调试字段。为了让 KNN 也有名字,我给 top-level knn 加了 name: "knn_query",见 es_query_builder.py (line 273)。搜索执行时会在 rerank 窗口内打开 include_named_queries_score,并在显式排序时加上 track_scores,见 searcher.py (line 400) 和 es_client.py (line 224)。
Showing
12 changed files
with
322 additions
and
61 deletions
Show diff stats
api/routes/indexer.py
| ... | ... | @@ -7,7 +7,7 @@ |
| 7 | 7 | import asyncio |
| 8 | 8 | import re |
| 9 | 9 | from fastapi import APIRouter, HTTPException |
| 10 | -from typing import Any, Dict, List | |
| 10 | +from typing import Any, Dict, List, Optional | |
| 11 | 11 | from pydantic import BaseModel, Field |
| 12 | 12 | import logging |
| 13 | 13 | from sqlalchemy import text |
| ... | ... | @@ -78,9 +78,12 @@ class BuildDocsFromDbRequest(BaseModel): |
| 78 | 78 | |
| 79 | 79 | |
| 80 | 80 | class EnrichContentItem(BaseModel): |
| 81 | - """单条待生成内容理解字段的商品(仅需 spu_id + 标题)。""" | |
| 81 | + """单条待生成内容理解字段的商品。""" | |
| 82 | 82 | spu_id: str = Field(..., description="SPU ID") |
| 83 | 83 | title: str = Field(..., description="商品标题,用于 LLM 分析生成 qanchors / tags 等") |
| 84 | + image_url: Optional[str] = Field(None, description="商品主图 URL(预留给多模态/内容理解扩展)") | |
| 85 | + brief: Optional[str] = Field(None, description="商品简介/短描述") | |
| 86 | + description: Optional[str] = Field(None, description="商品详情/长描述") | |
| 84 | 87 | |
| 85 | 88 | |
| 86 | 89 | class EnrichContentRequest(BaseModel): |
| ... | ... | @@ -88,8 +91,8 @@ class EnrichContentRequest(BaseModel): |
| 88 | 91 | 内容理解字段生成请求:根据商品标题批量生成 qanchors、semantic_attributes、tags。 |
| 89 | 92 | 供外部 indexer 在自行组织 doc 时调用,与翻译、向量化等微服务并列。 |
| 90 | 93 | """ |
| 91 | - tenant_id: str = Field(..., description="租户 ID,用于缓存隔离") | |
| 92 | - items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title)") | |
| 94 | + tenant_id: str = Field(..., description="租户 ID,用于请求路由与结果归属,不参与缓存键") | |
| 95 | + items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title,可附带 brief/description/image_url)") | |
| 93 | 96 | languages: List[str] = Field( |
| 94 | 97 | default_factory=lambda: ["zh", "en"], |
| 95 | 98 | description="目标语言列表,需在支持范围内(zh/en/de/ru/fr),默认 zh, en", |
| ... | ... | @@ -450,7 +453,16 @@ def _run_enrich_content(tenant_id: str, items: List[Dict[str, str]], languages: |
| 450 | 453 | |
| 451 | 454 | llm_langs = list(dict.fromkeys(languages)) or ["en"] |
| 452 | 455 | |
| 453 | - products = [{"id": it["spu_id"], "title": (it.get("title") or "").strip()} for it in items] | |
| 456 | + products = [ | |
| 457 | + { | |
| 458 | + "id": it["spu_id"], | |
| 459 | + "title": (it.get("title") or "").strip(), | |
| 460 | + "brief": (it.get("brief") or "").strip(), | |
| 461 | + "description": (it.get("description") or "").strip(), | |
| 462 | + "image_url": (it.get("image_url") or "").strip(), | |
| 463 | + } | |
| 464 | + for it in items | |
| 465 | + ] | |
| 454 | 466 | dim_keys = [ |
| 455 | 467 | "tags", |
| 456 | 468 | "target_audience", |
| ... | ... | @@ -545,7 +557,13 @@ async def enrich_content(request: EnrichContentRequest): |
| 545 | 557 | ) |
| 546 | 558 | |
| 547 | 559 | items_payload = [ |
| 548 | - {"spu_id": it.spu_id, "title": it.title or ""} | |
| 560 | + { | |
| 561 | + "spu_id": it.spu_id, | |
| 562 | + "title": it.title or "", | |
| 563 | + "brief": it.brief or "", | |
| 564 | + "description": it.description or "", | |
| 565 | + "image_url": it.image_url or "", | |
| 566 | + } | |
| 549 | 567 | for it in request.items |
| 550 | 568 | ] |
| 551 | 569 | loop = asyncio.get_event_loop() | ... | ... |
docs/搜索API对接指南.md
| ... | ... | @@ -1511,7 +1511,7 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \ |
| 1511 | 1511 | |
| 1512 | 1512 | | 参数 | 类型 | 必填 | 默认值 | 说明 | |
| 1513 | 1513 | |------|------|------|--------|------| |
| 1514 | -| `tenant_id` | string | Y | - | 租户 ID,用于缓存隔离 | | |
| 1514 | +| `tenant_id` | string | Y | - | 租户 ID,用于请求归属与日志记录,不参与缓存键 | |
| 1515 | 1515 | | `items` | array | Y | - | 待分析列表;**单次最多 50 条** | |
| 1516 | 1516 | | `languages` | array[string] | N | `["zh", "en"]` | 目标语言,需在支持范围内:`zh`、`en`、`de`、`ru`、`fr` | |
| 1517 | 1517 | |
| ... | ... | @@ -1519,11 +1519,17 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \ |
| 1519 | 1519 | |
| 1520 | 1520 | | 字段 | 类型 | 必填 | 说明 | |
| 1521 | 1521 | |------|------|------|------| |
| 1522 | -| `spu_id` | string | Y | SPU ID,用于回填结果与缓存键 | | |
| 1522 | +| `spu_id` | string | Y | SPU ID,用于回填结果与日志记录,不参与缓存键 | |
| 1523 | 1523 | | `title` | string | Y | 商品标题 | |
| 1524 | -| `image_url` | string | N | 商品主图 URL(预留:后续可用于图像/多模态内容理解) | | |
| 1525 | -| `brief` | string | N | 商品简介/短描述(预留) | | |
| 1526 | -| `description` | string | N | 商品详情/长描述(预留) | | |
| 1524 | +| `image_url` | string | N | 商品主图 URL;当前不参与 prompt 输入与缓存键,预留给后续图像/多模态内容理解 | |
| 1525 | +| `brief` | string | N | 商品简介/短描述;仅在标题文本过短时补入 prompt 输入并因此影响缓存键 | |
| 1526 | +| `description` | string | N | 商品详情/长描述;仅在标题+简介文本仍过短时补入 prompt 输入并因此影响缓存键 | |
| 1527 | + | |
| 1528 | +缓存说明: | |
| 1529 | + | |
| 1530 | +- 内容缓存键仅由 `target_lang` + 实际送入 prompt 的商品文本 hash 构成:默认仅取 `title`,文本过短时依次补充 `brief`、`description`;`image_url` 暂不参与缓存键。 | |
| 1531 | +- `tenant_id`、`spu_id` 只用于请求归属与结果回填,不参与缓存键。 | |
| 1532 | +- 因此,输入内容不变时可跨请求直接命中缓存;任一输入字段变化时,会自然落到新的缓存 key。 | |
| 1527 | 1533 | |
| 1528 | 1534 | 批量请求建议: |
| 1529 | 1535 | - **全量**:强烈建议 尽可能 **20 个 SPU/doc** 攒成一个批次后再请求一次。 | ... | ... |
indexer/product_enrich.py
| ... | ... | @@ -9,6 +9,7 @@ |
| 9 | 9 | import os |
| 10 | 10 | import json |
| 11 | 11 | import logging |
| 12 | +import re | |
| 12 | 13 | import time |
| 13 | 14 | import hashlib |
| 14 | 15 | from collections import OrderedDict |
| ... | ... | @@ -40,6 +41,10 @@ MAX_RETRIES = 3 |
| 40 | 41 | RETRY_DELAY = 5 # 秒 |
| 41 | 42 | REQUEST_TIMEOUT = 180 # 秒 |
| 42 | 43 | LOGGED_SHARED_CONTEXT_CACHE_SIZE = 256 |
| 44 | +PROMPT_INPUT_MIN_ZH_CHARS = 20 | |
| 45 | +PROMPT_INPUT_MAX_ZH_CHARS = 100 | |
| 46 | +PROMPT_INPUT_MIN_WORDS = 16 | |
| 47 | +PROMPT_INPUT_MAX_WORDS = 80 | |
| 43 | 48 | |
| 44 | 49 | # 日志路径 |
| 45 | 50 | OUTPUT_DIR = Path("output_logs") |
| ... | ... | @@ -82,6 +87,8 @@ if not verbose_logger.handlers: |
| 82 | 87 | verbose_logger.addHandler(verbose_file_handler) |
| 83 | 88 | verbose_logger.propagate = False |
| 84 | 89 | |
| 90 | +logger.info("Verbose LLM logs are written to: %s", verbose_log_file) | |
| 91 | + | |
| 85 | 92 | |
| 86 | 93 | # Redis 缓存(用于 anchors / 语义属性) |
| 87 | 94 | ANCHOR_CACHE_PREFIX = REDIS_CONFIG.get("anchor_cache_prefix", "product_anchors") |
| ... | ... | @@ -112,26 +119,86 @@ if _missing_prompt_langs: |
| 112 | 119 | ) |
| 113 | 120 | |
| 114 | 121 | |
| 122 | +def _normalize_space(text: str) -> str: | |
| 123 | + return re.sub(r"\s+", " ", (text or "").strip()) | |
| 124 | + | |
| 125 | + | |
| 126 | +def _contains_cjk(text: str) -> bool: | |
| 127 | + return bool(re.search(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]", text or "")) | |
| 128 | + | |
| 129 | + | |
| 130 | +def _truncate_by_chars(text: str, max_chars: int) -> str: | |
| 131 | + return text[:max_chars].strip() | |
| 132 | + | |
| 133 | + | |
| 134 | +def _truncate_by_words(text: str, max_words: int) -> str: | |
| 135 | + words = re.findall(r"\S+", text or "") | |
| 136 | + return " ".join(words[:max_words]).strip() | |
| 137 | + | |
| 138 | + | |
| 139 | +def _detect_prompt_input_lang(text: str) -> str: | |
| 140 | + # 简化处理:包含 CJK 时按中文类文本处理,否则统一按空格分词类语言处理。 | |
| 141 | + return "zh" if _contains_cjk(text) else "en" | |
| 142 | + | |
| 143 | + | |
| 144 | +def _build_prompt_input_text(product: Dict[str, Any]) -> str: | |
| 145 | + """ | |
| 146 | + 生成真正送入 prompt 的商品文本。 | |
| 147 | + | |
| 148 | + 规则: | |
| 149 | + - 默认使用 title | |
| 150 | + - 若文本过短,则依次补 brief / description | |
| 151 | + - 若文本过长,则按语言粗粒度截断 | |
| 152 | + """ | |
| 153 | + fields = [ | |
| 154 | + _normalize_space(str(product.get("title") or "")), | |
| 155 | + _normalize_space(str(product.get("brief") or "")), | |
| 156 | + _normalize_space(str(product.get("description") or "")), | |
| 157 | + ] | |
| 158 | + parts: List[str] = [] | |
| 159 | + | |
| 160 | + def join_parts() -> str: | |
| 161 | + return " | ".join(part for part in parts if part).strip() | |
| 162 | + | |
| 163 | + for field in fields: | |
| 164 | + if not field: | |
| 165 | + continue | |
| 166 | + if field not in parts: | |
| 167 | + parts.append(field) | |
| 168 | + candidate = join_parts() | |
| 169 | + if _detect_prompt_input_lang(candidate) == "zh": | |
| 170 | + if len(candidate) >= PROMPT_INPUT_MIN_ZH_CHARS: | |
| 171 | + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS) | |
| 172 | + else: | |
| 173 | + if len(re.findall(r"\S+", candidate)) >= PROMPT_INPUT_MIN_WORDS: | |
| 174 | + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS) | |
| 175 | + | |
| 176 | + candidate = join_parts() | |
| 177 | + if not candidate: | |
| 178 | + return "" | |
| 179 | + if _detect_prompt_input_lang(candidate) == "zh": | |
| 180 | + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS) | |
| 181 | + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS) | |
| 182 | + | |
| 183 | + | |
| 115 | 184 | def _make_anchor_cache_key( |
| 116 | - title: str, | |
| 185 | + product: Dict[str, Any], | |
| 117 | 186 | target_lang: str, |
| 118 | - tenant_id: Optional[str] = None, | |
| 119 | 187 | ) -> str: |
| 120 | - """构造 anchors/语义属性的缓存 key。""" | |
| 121 | - base = (tenant_id or "global").strip() | |
| 122 | - h = hashlib.md5(title.encode("utf-8")).hexdigest() | |
| 123 | - return f"{ANCHOR_CACHE_PREFIX}:{base}:{target_lang}:{h}" | |
| 188 | + """构造缓存 key,仅由 prompt 实际输入文本内容 + 目标语言决定。""" | |
| 189 | + prompt_input = _build_prompt_input_text(product) | |
| 190 | + h = hashlib.md5(prompt_input.encode("utf-8")).hexdigest() | |
| 191 | + return f"{ANCHOR_CACHE_PREFIX}:{target_lang}:{prompt_input[:4]}{h}" | |
| 124 | 192 | |
| 125 | 193 | |
| 126 | 194 | def _get_cached_anchor_result( |
| 127 | - title: str, | |
| 195 | + product: Dict[str, Any], | |
| 128 | 196 | target_lang: str, |
| 129 | - tenant_id: Optional[str] = None, | |
| 130 | 197 | ) -> Optional[Dict[str, Any]]: |
| 131 | 198 | if not _anchor_redis: |
| 132 | 199 | return None |
| 133 | 200 | try: |
| 134 | - key = _make_anchor_cache_key(title, target_lang, tenant_id) | |
| 201 | + key = _make_anchor_cache_key(product, target_lang) | |
| 135 | 202 | raw = _anchor_redis.get(key) |
| 136 | 203 | if not raw: |
| 137 | 204 | return None |
| ... | ... | @@ -142,15 +209,14 @@ def _get_cached_anchor_result( |
| 142 | 209 | |
| 143 | 210 | |
| 144 | 211 | def _set_cached_anchor_result( |
| 145 | - title: str, | |
| 212 | + product: Dict[str, Any], | |
| 146 | 213 | target_lang: str, |
| 147 | 214 | result: Dict[str, Any], |
| 148 | - tenant_id: Optional[str] = None, | |
| 149 | 215 | ) -> None: |
| 150 | 216 | if not _anchor_redis: |
| 151 | 217 | return |
| 152 | 218 | try: |
| 153 | - key = _make_anchor_cache_key(title, target_lang, tenant_id) | |
| 219 | + key = _make_anchor_cache_key(product, target_lang) | |
| 154 | 220 | ttl = ANCHOR_CACHE_EXPIRE_DAYS * 24 * 3600 |
| 155 | 221 | _anchor_redis.setex(key, ttl, json.dumps(result, ensure_ascii=False)) |
| 156 | 222 | except Exception as e: |
| ... | ... | @@ -166,7 +232,8 @@ def _build_assistant_prefix(headers: List[str]) -> str: |
| 166 | 232 | def _build_shared_context(products: List[Dict[str, str]]) -> str: |
| 167 | 233 | shared_context = SHARED_ANALYSIS_INSTRUCTION |
| 168 | 234 | for idx, product in enumerate(products, 1): |
| 169 | - shared_context += f'{idx}. {product["title"]}\n' | |
| 235 | + prompt_input = _build_prompt_input_text(product) | |
| 236 | + shared_context += f"{idx}. {prompt_input}\n" | |
| 170 | 237 | return shared_context |
| 171 | 238 | |
| 172 | 239 | |
| ... | ... | @@ -619,11 +686,11 @@ def analyze_products( |
| 619 | 686 | uncached_items.append((idx, product)) |
| 620 | 687 | continue |
| 621 | 688 | |
| 622 | - cached = _get_cached_anchor_result(title, target_lang, tenant_id=tenant_id) | |
| 689 | + cached = _get_cached_anchor_result(product, target_lang) | |
| 623 | 690 | if cached: |
| 624 | 691 | logger.info( |
| 625 | 692 | f"[analyze_products] Cache hit for title='{title[:50]}...', " |
| 626 | - f"lang={target_lang}, tenant_id={tenant_id or 'global'}" | |
| 693 | + f"lang={target_lang}" | |
| 627 | 694 | ) |
| 628 | 695 | results_by_index[idx] = cached |
| 629 | 696 | continue |
| ... | ... | @@ -650,7 +717,7 @@ def analyze_products( |
| 650 | 717 | ) |
| 651 | 718 | batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) |
| 652 | 719 | |
| 653 | - for (original_idx, _), item in zip(batch_slice, batch_results): | |
| 720 | + for (original_idx, product), item in zip(batch_slice, batch_results): | |
| 654 | 721 | results_by_index[original_idx] = item |
| 655 | 722 | title_input = str(item.get("title_input") or "").strip() |
| 656 | 723 | if not title_input: |
| ... | ... | @@ -659,7 +726,7 @@ def analyze_products( |
| 659 | 726 | # 不缓存错误结果,避免放大临时故障 |
| 660 | 727 | continue |
| 661 | 728 | try: |
| 662 | - _set_cached_anchor_result(title_input, target_lang, item, tenant_id=tenant_id) | |
| 729 | + _set_cached_anchor_result(product, target_lang, item) | |
| 663 | 730 | except Exception: |
| 664 | 731 | # 已在内部记录 warning |
| 665 | 732 | pass | ... | ... |
indexer/product_enrich_prompts.py
| ... | ... | @@ -8,9 +8,9 @@ SYSTEM_MESSAGE = ( |
| 8 | 8 | "Do not repeat or modify the prefix, and do not add explanations outside the table." |
| 9 | 9 | ) |
| 10 | 10 | |
| 11 | -SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these columns: | |
| 11 | +SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns: | |
| 12 | 12 | |
| 13 | -1. Product title: a natural localized product name derived from the input title | |
| 13 | +1. Product title: a natural localized product name derived from the input product text | |
| 14 | 14 | 2. Category path: broad to fine-grained category, separated by ">" |
| 15 | 15 | 3. Fine-grained tags: style, features, functions, or notable attributes |
| 16 | 16 | 4. Target audience: gender, age group, or suitable users |
| ... | ... | @@ -23,7 +23,7 @@ SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these |
| 23 | 23 | |
| 24 | 24 | Rules: |
| 25 | 25 | - Keep the input order and row count exactly the same. |
| 26 | -- Infer from the title only; if uncertain, prefer concise and broadly correct ecommerce wording. | |
| 26 | +- Infer only from the provided input product text; if uncertain, prefer concise and broadly correct ecommerce wording. | |
| 27 | 27 | - Keep category paths concise and use ">" as the separator. |
| 28 | 28 | - For columns with multiple values, the localized output requirement will define the delimiter. |
| 29 | 29 | |
| ... | ... | @@ -515,4 +515,4 @@ LANGUAGE_MARKDOWN_TABLE_HEADERS: Dict[str, Dict[str, Any]] = { |
| 515 | 515 | "Характеристики", |
| 516 | 516 | "Анкор текст" |
| 517 | 517 | ] |
| 518 | -} | |
| 519 | 518 | \ No newline at end of file |
| 519 | +} | ... | ... |
search/es_query_builder.py
| ... | ... | @@ -275,7 +275,8 @@ class ESQueryBuilder: |
| 275 | 275 | "query_vector": query_vector.tolist(), |
| 276 | 276 | "k": knn_k, |
| 277 | 277 | "num_candidates": knn_num_candidates, |
| 278 | - "boost": knn_boost | |
| 278 | + "boost": knn_boost, | |
| 279 | + "name": "knn_query", | |
| 279 | 280 | } |
| 280 | 281 | # Top-level knn does not inherit query.bool.filter automatically. |
| 281 | 282 | # Apply conjunctive + range filters here so vector recall respects hard filters. | ... | ... |
search/rerank_client.py
| ... | ... | @@ -4,7 +4,7 @@ |
| 4 | 4 | 流程: |
| 5 | 5 | 1. 从 ES hits 构造用于重排的文档文本列表 |
| 6 | 6 | 2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 |
| 7 | -3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序 | |
| 7 | +3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序 | |
| 8 | 8 | """ |
| 9 | 9 | |
| 10 | 10 | from typing import Dict, Any, List, Optional, Tuple |
| ... | ... | @@ -14,7 +14,7 @@ from providers import create_rerank_provider |
| 14 | 14 | |
| 15 | 15 | logger = logging.getLogger(__name__) |
| 16 | 16 | |
| 17 | -# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1) | |
| 17 | +# 历史配置项,保留签名兼容;当前乘法融合公式不再使用线性权重。 | |
| 18 | 18 | DEFAULT_WEIGHT_ES = 0.4 |
| 19 | 19 | DEFAULT_WEIGHT_AI = 0.6 |
| 20 | 20 | # 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) |
| ... | ... | @@ -103,18 +103,20 @@ def fuse_scores_and_resort( |
| 103 | 103 | weight_ai: float = DEFAULT_WEIGHT_AI, |
| 104 | 104 | ) -> List[Dict[str, Any]]: |
| 105 | 105 | """ |
| 106 | - 将 ES 分数与重排分数线性融合(不修改原始 _score),并按融合分数降序重排。 | |
| 106 | + 将 ES 分数与重排分数按乘法公式融合(不修改原始 _score),并按融合分数降序重排。 | |
| 107 | 107 | |
| 108 | 108 | 对每条 hit 会写入: |
| 109 | 109 | - _original_score: 原始 ES 分数 |
| 110 | 110 | - _rerank_score: 重排服务返回的分数 |
| 111 | 111 | - _fused_score: 融合分数 |
| 112 | + - _text_score: 文本相关性分数(优先取 named queries 的 base_query 分数) | |
| 113 | + - _knn_score: KNN 分数(优先取 named queries 的 knn_query 分数) | |
| 112 | 114 | |
| 113 | 115 | Args: |
| 114 | 116 | es_hits: ES hits 列表(会被原地修改) |
| 115 | 117 | rerank_scores: 与 es_hits 等长的重排分数列表 |
| 116 | - weight_es: ES 归一化分数权重 | |
| 117 | - weight_ai: 重排分数权重 | |
| 118 | + weight_es: 兼容保留,当前未使用 | |
| 119 | + weight_ai: 兼容保留,当前未使用 | |
| 118 | 120 | |
| 119 | 121 | Returns: |
| 120 | 122 | 每条文档的融合调试信息列表,用于 debug_info |
| ... | ... | @@ -123,38 +125,62 @@ def fuse_scores_and_resort( |
| 123 | 125 | if n == 0 or len(rerank_scores) != n: |
| 124 | 126 | return [] |
| 125 | 127 | |
| 126 | - # 收集 ES 原始分数 | |
| 127 | - es_scores: List[float] = [] | |
| 128 | - for hit in es_hits: | |
| 129 | - raw = hit.get("_score") | |
| 130 | - try: | |
| 131 | - es_scores.append(float(raw) if raw is not None else 0.0) | |
| 132 | - except (TypeError, ValueError): | |
| 133 | - es_scores.append(0.0) | |
| 134 | - | |
| 135 | - max_es = max(es_scores) if es_scores else 0.0 | |
| 136 | 128 | fused_debug: List[Dict[str, Any]] = [] |
| 137 | 129 | |
| 138 | 130 | for idx, hit in enumerate(es_hits): |
| 139 | - es_score = es_scores[idx] | |
| 131 | + raw_es_score = hit.get("_score") | |
| 132 | + try: | |
| 133 | + es_score = float(raw_es_score) if raw_es_score is not None else 0.0 | |
| 134 | + except (TypeError, ValueError): | |
| 135 | + es_score = 0.0 | |
| 136 | + | |
| 140 | 137 | ai_score_raw = rerank_scores[idx] |
| 141 | 138 | try: |
| 142 | 139 | rerank_score = float(ai_score_raw) |
| 143 | 140 | except (TypeError, ValueError): |
| 144 | 141 | rerank_score = 0.0 |
| 145 | 142 | |
| 146 | - es_norm = (es_score / max_es) if max_es > 0 else 0.0 | |
| 147 | - fused = weight_es * es_norm + weight_ai * rerank_score | |
| 143 | + matched_queries = hit.get("matched_queries") | |
| 144 | + text_score = 0.0 | |
| 145 | + knn_score = 0.0 | |
| 146 | + if isinstance(matched_queries, dict): | |
| 147 | + try: | |
| 148 | + text_score = float(matched_queries.get("base_query", 0.0) or 0.0) | |
| 149 | + except (TypeError, ValueError): | |
| 150 | + text_score = 0.0 | |
| 151 | + try: | |
| 152 | + knn_score = float(matched_queries.get("knn_query", 0.0) or 0.0) | |
| 153 | + except (TypeError, ValueError): | |
| 154 | + knn_score = 0.0 | |
| 155 | + elif isinstance(matched_queries, list): | |
| 156 | + text_score = 1.0 if "base_query" in matched_queries else 0.0 | |
| 157 | + knn_score = 1.0 if "knn_query" in matched_queries else 0.0 | |
| 158 | + | |
| 159 | + # 回退逻辑: | |
| 160 | + # - text_score 缺失时,退回原始 _score,避免纯文本召回被错误打成 0。 | |
| 161 | + # - knn_score 缺失时保持 0,由平滑项 0.6 兜底。 | |
| 162 | + if text_score <= 0.0: | |
| 163 | + text_score = es_score | |
| 164 | + | |
| 165 | + fused = ( | |
| 166 | + (rerank_score + 0.00001) ** 1.0 * | |
| 167 | + (knn_score + 0.6) ** 0.2 * | |
| 168 | + (text_score + 0.1) ** 0.75 | |
| 169 | + ) | |
| 148 | 170 | |
| 149 | 171 | hit["_original_score"] = hit.get("_score") |
| 150 | 172 | hit["_rerank_score"] = rerank_score |
| 173 | + hit["_text_score"] = text_score | |
| 174 | + hit["_knn_score"] = knn_score | |
| 151 | 175 | hit["_fused_score"] = fused |
| 152 | 176 | |
| 153 | 177 | fused_debug.append({ |
| 154 | 178 | "doc_id": hit.get("_id"), |
| 155 | 179 | "es_score": es_score, |
| 156 | - "es_score_norm": es_norm, | |
| 157 | 180 | "rerank_score": rerank_score, |
| 181 | + "text_score": text_score, | |
| 182 | + "knn_score": knn_score, | |
| 183 | + "matched_queries": matched_queries, | |
| 158 | 184 | "fused_score": fused, |
| 159 | 185 | }) |
| 160 | 186 | ... | ... |
search/searcher.py
| ... | ... | @@ -400,6 +400,7 @@ class Searcher: |
| 400 | 400 | # Add sorting if specified |
| 401 | 401 | if sort_by: |
| 402 | 402 | es_query = self.query_builder.add_sorting(es_query, sort_by, sort_order) |
| 403 | + es_query["track_scores"] = True | |
| 403 | 404 | |
| 404 | 405 | # Keep requested response _source semantics for the final response fill. |
| 405 | 406 | response_source_spec = es_query.get("_source") |
| ... | ... | @@ -467,7 +468,8 @@ class Searcher: |
| 467 | 468 | index_name=index_name, |
| 468 | 469 | body=body_for_es, |
| 469 | 470 | size=es_fetch_size, |
| 470 | - from_=es_fetch_from | |
| 471 | + from_=es_fetch_from, | |
| 472 | + include_named_queries_score=bool(do_rerank and in_rerank_window), | |
| 471 | 473 | ) |
| 472 | 474 | |
| 473 | 475 | # Store ES response in context | ... | ... |
tests/test_es_query_builder.py
tests/test_product_enrich_partial_mode.py
| ... | ... | @@ -62,7 +62,7 @@ def test_create_prompt_splits_shared_context_and_localized_tail(): |
| 62 | 62 | shared_en, user_en, prefix_en = product_enrich.create_prompt(products, target_lang="en") |
| 63 | 63 | |
| 64 | 64 | assert shared_zh == shared_en |
| 65 | - assert "Analyze each input product title" in shared_zh | |
| 65 | + assert "Analyze each input product text" in shared_zh | |
| 66 | 66 | assert "1. dress" in shared_zh |
| 67 | 67 | assert "2. linen shirt" in shared_zh |
| 68 | 68 | assert "Product list" not in user_zh |
| ... | ... | @@ -232,11 +232,20 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): |
| 232 | 232 | cache_store = {} |
| 233 | 233 | process_calls = [] |
| 234 | 234 | |
| 235 | - def fake_get_cached_anchor_result(title, target_lang, tenant_id=None): | |
| 236 | - return cache_store.get((tenant_id, target_lang, title)) | |
| 235 | + def _cache_key(product, target_lang): | |
| 236 | + return ( | |
| 237 | + target_lang, | |
| 238 | + product.get("title", ""), | |
| 239 | + product.get("brief", ""), | |
| 240 | + product.get("description", ""), | |
| 241 | + product.get("image_url", ""), | |
| 242 | + ) | |
| 243 | + | |
| 244 | + def fake_get_cached_anchor_result(product, target_lang): | |
| 245 | + return cache_store.get(_cache_key(product, target_lang)) | |
| 237 | 246 | |
| 238 | - def fake_set_cached_anchor_result(title, target_lang, result, tenant_id=None): | |
| 239 | - cache_store[(tenant_id, target_lang, title)] = result | |
| 247 | + def fake_set_cached_anchor_result(product, target_lang, result): | |
| 248 | + cache_store[_cache_key(product, target_lang)] = result | |
| 240 | 249 | |
| 241 | 250 | def fake_process_batch(batch_data, batch_num, target_lang="zh"): |
| 242 | 251 | process_calls.append( |
| ... | ... | @@ -291,7 +300,7 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): |
| 291 | 300 | second = product_enrich.analyze_products( |
| 292 | 301 | products, |
| 293 | 302 | target_lang="zh", |
| 294 | - tenant_id="170", | |
| 303 | + tenant_id="999", | |
| 295 | 304 | ) |
| 296 | 305 | third = product_enrich.analyze_products( |
| 297 | 306 | products, |
| ... | ... | @@ -311,3 +320,63 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests(): |
| 311 | 320 | assert second[1]["anchor_text"] == "anchor:shirt" |
| 312 | 321 | assert third[0]["anchor_text"] == "anchor:dress" |
| 313 | 322 | assert third[1]["anchor_text"] == "anchor:shirt" |
| 323 | + | |
| 324 | + | |
| 325 | +def test_anchor_cache_key_depends_on_product_input_not_identifiers(): | |
| 326 | + product_a = { | |
| 327 | + "id": "1", | |
| 328 | + "spu_id": "1001", | |
| 329 | + "title": "dress", | |
| 330 | + "brief": "soft cotton", | |
| 331 | + "description": "summer dress", | |
| 332 | + "image_url": "https://img/a.jpg", | |
| 333 | + } | |
| 334 | + product_b = { | |
| 335 | + "id": "2", | |
| 336 | + "spu_id": "9999", | |
| 337 | + "title": "dress", | |
| 338 | + "brief": "soft cotton", | |
| 339 | + "description": "summer dress", | |
| 340 | + "image_url": "https://img/a.jpg", | |
| 341 | + } | |
| 342 | + product_c = { | |
| 343 | + "id": "1", | |
| 344 | + "spu_id": "1001", | |
| 345 | + "title": "dress", | |
| 346 | + "brief": "soft cotton updated", | |
| 347 | + "description": "summer dress", | |
| 348 | + "image_url": "https://img/a.jpg", | |
| 349 | + } | |
| 350 | + | |
| 351 | + key_a = product_enrich._make_anchor_cache_key(product_a, "zh") | |
| 352 | + key_b = product_enrich._make_anchor_cache_key(product_b, "zh") | |
| 353 | + key_c = product_enrich._make_anchor_cache_key(product_c, "zh") | |
| 354 | + | |
| 355 | + assert key_a == key_b | |
| 356 | + assert key_a != key_c | |
| 357 | + | |
| 358 | + | |
| 359 | +def test_build_prompt_input_text_appends_brief_and_description_for_short_title(): | |
| 360 | + product = { | |
| 361 | + "title": "T恤", | |
| 362 | + "brief": "夏季透气纯棉短袖,舒适亲肤", | |
| 363 | + "description": "100%棉,圆领版型,适合日常通勤与休闲穿搭。", | |
| 364 | + } | |
| 365 | + | |
| 366 | + text = product_enrich._build_prompt_input_text(product) | |
| 367 | + | |
| 368 | + assert text.startswith("T恤") | |
| 369 | + assert "夏季透气纯棉短袖" in text | |
| 370 | + assert "100%棉" in text | |
| 371 | + | |
| 372 | + | |
| 373 | +def test_build_prompt_input_text_truncates_non_cjk_by_words(): | |
| 374 | + product = { | |
| 375 | + "title": "dress", | |
| 376 | + "brief": " ".join(f"brief{i}" for i in range(50)), | |
| 377 | + "description": " ".join(f"desc{i}" for i in range(50)), | |
| 378 | + } | |
| 379 | + | |
| 380 | + text = product_enrich._build_prompt_input_text(product) | |
| 381 | + | |
| 382 | + assert len(text.split()) <= product_enrich.PROMPT_INPUT_MAX_WORDS | ... | ... |
| ... | ... | @@ -0,0 +1,53 @@ |
| 1 | +from math import isclose | |
| 2 | + | |
| 3 | +from search.rerank_client import fuse_scores_and_resort | |
| 4 | + | |
| 5 | + | |
| 6 | +def test_fuse_scores_and_resort_uses_multiplicative_formula_with_named_query_scores(): | |
| 7 | + hits = [ | |
| 8 | + { | |
| 9 | + "_id": "1", | |
| 10 | + "_score": 3.2, | |
| 11 | + "matched_queries": { | |
| 12 | + "base_query": 2.4, | |
| 13 | + "knn_query": 0.8, | |
| 14 | + }, | |
| 15 | + }, | |
| 16 | + { | |
| 17 | + "_id": "2", | |
| 18 | + "_score": 2.8, | |
| 19 | + "matched_queries": { | |
| 20 | + "base_query": 1.6, | |
| 21 | + "knn_query": 0.2, | |
| 22 | + }, | |
| 23 | + }, | |
| 24 | + ] | |
| 25 | + | |
| 26 | + debug = fuse_scores_and_resort(hits, [0.9, 0.7]) | |
| 27 | + | |
| 28 | + expected_1 = (0.9 + 0.00001) * ((0.8 + 0.6) ** 0.2) * ((2.4 + 0.1) ** 0.75) | |
| 29 | + expected_2 = (0.7 + 0.00001) * ((0.2 + 0.6) ** 0.2) * ((1.6 + 0.1) ** 0.75) | |
| 30 | + | |
| 31 | + assert isclose(hits[0]["_fused_score"], expected_1, rel_tol=1e-9) | |
| 32 | + assert isclose(hits[1]["_fused_score"], expected_2, rel_tol=1e-9) | |
| 33 | + assert debug[0]["text_score"] == 2.4 | |
| 34 | + assert debug[0]["knn_score"] == 0.8 | |
| 35 | + assert [hit["_id"] for hit in hits] == ["1", "2"] | |
| 36 | + | |
| 37 | + | |
| 38 | +def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing(): | |
| 39 | + hits = [ | |
| 40 | + {"_id": "1", "_score": 0.5}, | |
| 41 | + {"_id": "2", "_score": 2.0}, | |
| 42 | + ] | |
| 43 | + | |
| 44 | + fuse_scores_and_resort(hits, [0.4, 0.3]) | |
| 45 | + | |
| 46 | + expected_1 = (0.4 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((0.5 + 0.1) ** 0.75) | |
| 47 | + expected_2 = (0.3 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((2.0 + 0.1) ** 0.75) | |
| 48 | + | |
| 49 | + assert isclose(hits[0]["_text_score"], 2.0, rel_tol=1e-9) | |
| 50 | + assert isclose(hits[0]["_fused_score"], expected_2, rel_tol=1e-9) | |
| 51 | + assert isclose(hits[1]["_text_score"], 0.5, rel_tol=1e-9) | |
| 52 | + assert isclose(hits[1]["_fused_score"], expected_1, rel_tol=1e-9) | |
| 53 | + assert [hit["_id"] for hit in hits] == ["2", "1"] | ... | ... |
tests/test_search_rerank_window.py
| ... | ... | @@ -97,9 +97,22 @@ class _FakeESClient: |
| 97 | 97 | "skus": [], |
| 98 | 98 | } |
| 99 | 99 | |
| 100 | - def search(self, index_name: str, body: Dict[str, Any], size: int, from_: int): | |
| 100 | + def search( | |
| 101 | + self, | |
| 102 | + index_name: str, | |
| 103 | + body: Dict[str, Any], | |
| 104 | + size: int, | |
| 105 | + from_: int, | |
| 106 | + include_named_queries_score: bool = False, | |
| 107 | + ): | |
| 101 | 108 | self.calls.append( |
| 102 | - {"index_name": index_name, "body": body, "size": size, "from_": from_} | |
| 109 | + { | |
| 110 | + "index_name": index_name, | |
| 111 | + "body": body, | |
| 112 | + "size": size, | |
| 113 | + "from_": from_, | |
| 114 | + "include_named_queries_score": include_named_queries_score, | |
| 115 | + } | |
| 103 | 116 | ) |
| 104 | 117 | ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values") |
| 105 | 118 | source_spec = (body or {}).get("_source") |
| ... | ... | @@ -213,6 +226,7 @@ def test_searcher_reranks_top_window_by_default(monkeypatch): |
| 213 | 226 | assert called["docs"] == window |
| 214 | 227 | assert es_client.calls[0]["from_"] == 0 |
| 215 | 228 | assert es_client.calls[0]["size"] == window |
| 229 | + assert es_client.calls[0]["include_named_queries_score"] is True | |
| 216 | 230 | assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]} |
| 217 | 231 | assert len(es_client.calls) == 2 |
| 218 | 232 | assert es_client.calls[1]["size"] == 10 |
| ... | ... | @@ -277,6 +291,7 @@ def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): |
| 277 | 291 | assert called["count"] == 0 |
| 278 | 292 | assert es_client.calls[0]["from_"] == 20 |
| 279 | 293 | assert es_client.calls[0]["size"] == 10 |
| 294 | + assert es_client.calls[0]["include_named_queries_score"] is False | |
| 280 | 295 | assert len(es_client.calls) == 1 |
| 281 | 296 | |
| 282 | 297 | |
| ... | ... | @@ -310,4 +325,5 @@ def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch): |
| 310 | 325 | assert called["count"] == 0 |
| 311 | 326 | assert es_client.calls[0]["from_"] == 995 |
| 312 | 327 | assert es_client.calls[0]["size"] == 10 |
| 328 | + assert es_client.calls[0]["include_named_queries_score"] is False | |
| 313 | 329 | assert len(es_client.calls) == 1 | ... | ... |
utils/es_client.py
| ... | ... | @@ -228,6 +228,7 @@ class ESClient: |
| 228 | 228 | size: int = 10, |
| 229 | 229 | from_: int = 0, |
| 230 | 230 | routing: Optional[str] = None, |
| 231 | + include_named_queries_score: bool = False, | |
| 231 | 232 | ) -> Dict[str, Any]: |
| 232 | 233 | """ |
| 233 | 234 | Execute search query. |
| ... | ... | @@ -260,6 +261,7 @@ class ESClient: |
| 260 | 261 | size=size, |
| 261 | 262 | from_=from_, |
| 262 | 263 | routing=routing, |
| 264 | + include_named_queries_score=include_named_queries_score, | |
| 263 | 265 | ) |
| 264 | 266 | # elasticsearch-py 8.x returns ObjectApiResponse; normalize to mutable dict |
| 265 | 267 | # so caller can safely patch hits/took during post-processing. | ... | ... |