Commit a47416ec0d2f11e71a6e1b7f24c1dfe34bf48cf1

Authored by tangwang
1 parent 76e1f088

把融合逻辑改成乘法公式,并把 ES 命名子句分数回传链路补上了。

核心改动在 rerank_client.py (line 99):fuse_scores_and_resort 现在按
rerank * knn * text 的平滑乘法公式计算,优先从 hit["matched_queries"]
里取 base_query 和 knn_query,并把 _text_score / _knn_score
一并写回调试字段。为了让 KNN 也有名字,我给 top-level knn 加了 name:
"knn_query",见 es_query_builder.py (line 273)。搜索执行时会在 rerank
窗口内打开 include_named_queries_score,并在显式排序时加上
track_scores,见 searcher.py (line 400) 和 es_client.py (line 224)。
api/routes/indexer.py
... ... @@ -7,7 +7,7 @@
7 7 import asyncio
8 8 import re
9 9 from fastapi import APIRouter, HTTPException
10   -from typing import Any, Dict, List
  10 +from typing import Any, Dict, List, Optional
11 11 from pydantic import BaseModel, Field
12 12 import logging
13 13 from sqlalchemy import text
... ... @@ -78,9 +78,12 @@ class BuildDocsFromDbRequest(BaseModel):
78 78  
79 79  
80 80 class EnrichContentItem(BaseModel):
81   - """单条待生成内容理解字段的商品(仅需 spu_id + 标题)。"""
  81 + """单条待生成内容理解字段的商品。"""
82 82 spu_id: str = Field(..., description="SPU ID")
83 83 title: str = Field(..., description="商品标题,用于 LLM 分析生成 qanchors / tags 等")
  84 + image_url: Optional[str] = Field(None, description="商品主图 URL(预留给多模态/内容理解扩展)")
  85 + brief: Optional[str] = Field(None, description="商品简介/短描述")
  86 + description: Optional[str] = Field(None, description="商品详情/长描述")
84 87  
85 88  
86 89 class EnrichContentRequest(BaseModel):
... ... @@ -88,8 +91,8 @@ class EnrichContentRequest(BaseModel):
88 91 内容理解字段生成请求:根据商品标题批量生成 qanchors、semantic_attributes、tags。
89 92 供外部 indexer 在自行组织 doc 时调用,与翻译、向量化等微服务并列。
90 93 """
91   - tenant_id: str = Field(..., description="租户 ID,用于缓存隔离")
92   - items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title)")
  94 + tenant_id: str = Field(..., description="租户 ID,用于请求路由与结果归属,不参与缓存键")
  95 + items: List[EnrichContentItem] = Field(..., description="待分析的 SPU 列表(spu_id + title,可附带 brief/description/image_url)")
93 96 languages: List[str] = Field(
94 97 default_factory=lambda: ["zh", "en"],
95 98 description="目标语言列表,需在支持范围内(zh/en/de/ru/fr),默认 zh, en",
... ... @@ -450,7 +453,16 @@ def _run_enrich_content(tenant_id: str, items: List[Dict[str, str]], languages:
450 453  
451 454 llm_langs = list(dict.fromkeys(languages)) or ["en"]
452 455  
453   - products = [{"id": it["spu_id"], "title": (it.get("title") or "").strip()} for it in items]
  456 + products = [
  457 + {
  458 + "id": it["spu_id"],
  459 + "title": (it.get("title") or "").strip(),
  460 + "brief": (it.get("brief") or "").strip(),
  461 + "description": (it.get("description") or "").strip(),
  462 + "image_url": (it.get("image_url") or "").strip(),
  463 + }
  464 + for it in items
  465 + ]
454 466 dim_keys = [
455 467 "tags",
456 468 "target_audience",
... ... @@ -545,7 +557,13 @@ async def enrich_content(request: EnrichContentRequest):
545 557 )
546 558  
547 559 items_payload = [
548   - {"spu_id": it.spu_id, "title": it.title or ""}
  560 + {
  561 + "spu_id": it.spu_id,
  562 + "title": it.title or "",
  563 + "brief": it.brief or "",
  564 + "description": it.description or "",
  565 + "image_url": it.image_url or "",
  566 + }
549 567 for it in request.items
550 568 ]
551 569 loop = asyncio.get_event_loop()
... ...
docs/搜索API对接指南.md
... ... @@ -1511,7 +1511,7 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \
1511 1511  
1512 1512 | 参数 | 类型 | 必填 | 默认值 | 说明 |
1513 1513 |------|------|------|--------|------|
1514   -| `tenant_id` | string | Y | - | 租户 ID,用于缓存隔离 |
  1514 +| `tenant_id` | string | Y | - | 租户 ID。目前仅用于请求归属与日志记录,不参与缓存键 |
1515 1515 | `items` | array | Y | - | 待分析列表;**单次最多 50 条** |
1516 1516 | `languages` | array[string] | N | `["zh", "en"]` | 目标语言,需在支持范围内:`zh`、`en`、`de`、`ru`、`fr` |
1517 1517  
... ... @@ -1519,11 +1519,17 @@ curl -X POST "http://127.0.0.1:6004/indexer/build-docs-from-db" \
1519 1519  
1520 1520 | 字段 | 类型 | 必填 | 说明 |
1521 1521 |------|------|------|------|
1522   -| `spu_id` | string | Y | SPU ID,用于回填结果与缓存键 |
  1522 +| `spu_id` | string | Y | SPU ID,用于回填结果;不参与缓存键 |
1523 1523 | `title` | string | Y | 商品标题 |
1524   -| `image_url` | string | N | 商品主图 URL(预留:后续可用于图像/多模态内容理解) |
1525   -| `brief` | string | N | 商品简介/短描述(预留) |
1526   -| `description` | string | N | 商品详情/长描述(预留) |
  1524 +| `image_url` | string | N | 商品主图 URL;当前不参与内容理解输入与缓存键,预留给图像/多模态内容理解 |
  1525 +| `brief` | string | N | 商品简介/短描述;当前会参与内容缓存键 |
  1526 +| `description` | string | N | 商品详情/长描述;当前会参与内容缓存键 |
  1527 +
  1528 +缓存说明:
  1529 +
  1530 +- 内容缓存键仅由 `target_lang` 与 `items[]` 中会影响内容理解结果的输入文本构成,目前包括:`title`、`brief`、`description` 的规范化内容 hash;`image_url` 暂不参与缓存键。
  1531 +- `tenant_id`、`spu_id` 只用于请求归属与结果回填,不参与缓存键。
  1532 +- 因此,输入内容不变时可跨请求直接命中缓存;任一输入字段变化时,会自然落到新的缓存 key。
1527 1533  
1528 1534 批量请求建议:
1529 1535 - **全量**:强烈建议 尽可能 **20 个 SPU/doc** 攒成一个批次后再请求一次。
... ...
indexer/product_enrich.py
... ... @@ -9,6 +9,7 @@
9 9 import os
10 10 import json
11 11 import logging
  12 +import re
12 13 import time
13 14 import hashlib
14 15 from collections import OrderedDict
... ... @@ -40,6 +41,10 @@ MAX_RETRIES = 3
40 41 RETRY_DELAY = 5 # 秒
41 42 REQUEST_TIMEOUT = 180 # 秒
42 43 LOGGED_SHARED_CONTEXT_CACHE_SIZE = 256
  44 +PROMPT_INPUT_MIN_ZH_CHARS = 20
  45 +PROMPT_INPUT_MAX_ZH_CHARS = 100
  46 +PROMPT_INPUT_MIN_WORDS = 16
  47 +PROMPT_INPUT_MAX_WORDS = 80
43 48  
44 49 # 日志路径
45 50 OUTPUT_DIR = Path("output_logs")
... ... @@ -82,6 +87,8 @@ if not verbose_logger.handlers:
82 87 verbose_logger.addHandler(verbose_file_handler)
83 88 verbose_logger.propagate = False
84 89  
  90 +logger.info("Verbose LLM logs are written to: %s", verbose_log_file)
  91 +
85 92  
86 93 # Redis 缓存(用于 anchors / 语义属性)
87 94 ANCHOR_CACHE_PREFIX = REDIS_CONFIG.get("anchor_cache_prefix", "product_anchors")
... ... @@ -112,26 +119,86 @@ if _missing_prompt_langs:
112 119 )
113 120  
114 121  
  122 +def _normalize_space(text: str) -> str:
  123 + return re.sub(r"\s+", " ", (text or "").strip())
  124 +
  125 +
  126 +def _contains_cjk(text: str) -> bool:
  127 + return bool(re.search(r"[\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff]", text or ""))
  128 +
  129 +
  130 +def _truncate_by_chars(text: str, max_chars: int) -> str:
  131 + return text[:max_chars].strip()
  132 +
  133 +
  134 +def _truncate_by_words(text: str, max_words: int) -> str:
  135 + words = re.findall(r"\S+", text or "")
  136 + return " ".join(words[:max_words]).strip()
  137 +
  138 +
  139 +def _detect_prompt_input_lang(text: str) -> str:
  140 + # 简化处理:包含 CJK 时按中文类文本处理,否则统一按空格分词类语言处理。
  141 + return "zh" if _contains_cjk(text) else "en"
  142 +
  143 +
  144 +def _build_prompt_input_text(product: Dict[str, Any]) -> str:
  145 + """
  146 + 生成真正送入 prompt 的商品文本。
  147 +
  148 + 规则:
  149 + - 默认使用 title
  150 + - 若文本过短,则依次补 brief / description
  151 + - 若文本过长,则按语言粗粒度截断
  152 + """
  153 + fields = [
  154 + _normalize_space(str(product.get("title") or "")),
  155 + _normalize_space(str(product.get("brief") or "")),
  156 + _normalize_space(str(product.get("description") or "")),
  157 + ]
  158 + parts: List[str] = []
  159 +
  160 + def join_parts() -> str:
  161 + return " | ".join(part for part in parts if part).strip()
  162 +
  163 + for field in fields:
  164 + if not field:
  165 + continue
  166 + if field not in parts:
  167 + parts.append(field)
  168 + candidate = join_parts()
  169 + if _detect_prompt_input_lang(candidate) == "zh":
  170 + if len(candidate) >= PROMPT_INPUT_MIN_ZH_CHARS:
  171 + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS)
  172 + else:
  173 + if len(re.findall(r"\S+", candidate)) >= PROMPT_INPUT_MIN_WORDS:
  174 + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS)
  175 +
  176 + candidate = join_parts()
  177 + if not candidate:
  178 + return ""
  179 + if _detect_prompt_input_lang(candidate) == "zh":
  180 + return _truncate_by_chars(candidate, PROMPT_INPUT_MAX_ZH_CHARS)
  181 + return _truncate_by_words(candidate, PROMPT_INPUT_MAX_WORDS)
  182 +
  183 +
115 184 def _make_anchor_cache_key(
116   - title: str,
  185 + product: Dict[str, Any],
117 186 target_lang: str,
118   - tenant_id: Optional[str] = None,
119 187 ) -> str:
120   - """构造 anchors/语义属性的缓存 key。"""
121   - base = (tenant_id or "global").strip()
122   - h = hashlib.md5(title.encode("utf-8")).hexdigest()
123   - return f"{ANCHOR_CACHE_PREFIX}:{base}:{target_lang}:{h}"
  188 + """构造缓存 key,仅由 prompt 实际输入文本内容 + 目标语言决定。"""
  189 + prompt_input = _build_prompt_input_text(product)
  190 + h = hashlib.md5(prompt_input.encode("utf-8")).hexdigest()
  191 + return f"{ANCHOR_CACHE_PREFIX}:{target_lang}:{prompt_input[:4]}{h}"
124 192  
125 193  
126 194 def _get_cached_anchor_result(
127   - title: str,
  195 + product: Dict[str, Any],
128 196 target_lang: str,
129   - tenant_id: Optional[str] = None,
130 197 ) -> Optional[Dict[str, Any]]:
131 198 if not _anchor_redis:
132 199 return None
133 200 try:
134   - key = _make_anchor_cache_key(title, target_lang, tenant_id)
  201 + key = _make_anchor_cache_key(product, target_lang)
135 202 raw = _anchor_redis.get(key)
136 203 if not raw:
137 204 return None
... ... @@ -142,15 +209,14 @@ def _get_cached_anchor_result(
142 209  
143 210  
144 211 def _set_cached_anchor_result(
145   - title: str,
  212 + product: Dict[str, Any],
146 213 target_lang: str,
147 214 result: Dict[str, Any],
148   - tenant_id: Optional[str] = None,
149 215 ) -> None:
150 216 if not _anchor_redis:
151 217 return
152 218 try:
153   - key = _make_anchor_cache_key(title, target_lang, tenant_id)
  219 + key = _make_anchor_cache_key(product, target_lang)
154 220 ttl = ANCHOR_CACHE_EXPIRE_DAYS * 24 * 3600
155 221 _anchor_redis.setex(key, ttl, json.dumps(result, ensure_ascii=False))
156 222 except Exception as e:
... ... @@ -166,7 +232,8 @@ def _build_assistant_prefix(headers: List[str]) -> str:
166 232 def _build_shared_context(products: List[Dict[str, str]]) -> str:
167 233 shared_context = SHARED_ANALYSIS_INSTRUCTION
168 234 for idx, product in enumerate(products, 1):
169   - shared_context += f'{idx}. {product["title"]}\n'
  235 + prompt_input = _build_prompt_input_text(product)
  236 + shared_context += f"{idx}. {prompt_input}\n"
170 237 return shared_context
171 238  
172 239  
... ... @@ -619,11 +686,11 @@ def analyze_products(
619 686 uncached_items.append((idx, product))
620 687 continue
621 688  
622   - cached = _get_cached_anchor_result(title, target_lang, tenant_id=tenant_id)
  689 + cached = _get_cached_anchor_result(product, target_lang)
623 690 if cached:
624 691 logger.info(
625 692 f"[analyze_products] Cache hit for title='{title[:50]}...', "
626   - f"lang={target_lang}, tenant_id={tenant_id or 'global'}"
  693 + f"lang={target_lang}"
627 694 )
628 695 results_by_index[idx] = cached
629 696 continue
... ... @@ -650,7 +717,7 @@ def analyze_products(
650 717 )
651 718 batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang)
652 719  
653   - for (original_idx, _), item in zip(batch_slice, batch_results):
  720 + for (original_idx, product), item in zip(batch_slice, batch_results):
654 721 results_by_index[original_idx] = item
655 722 title_input = str(item.get("title_input") or "").strip()
656 723 if not title_input:
... ... @@ -659,7 +726,7 @@ def analyze_products(
659 726 # 不缓存错误结果,避免放大临时故障
660 727 continue
661 728 try:
662   - _set_cached_anchor_result(title_input, target_lang, item, tenant_id=tenant_id)
  729 + _set_cached_anchor_result(product, target_lang, item)
663 730 except Exception:
664 731 # 已在内部记录 warning
665 732 pass
... ...
indexer/product_enrich_prompts.py
... ... @@ -8,9 +8,9 @@ SYSTEM_MESSAGE = (
8 8 "Do not repeat or modify the prefix, and do not add explanations outside the table."
9 9 )
10 10  
11   -SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these columns:
  11 +SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product text and fill these columns:
12 12  
13   -1. Product title: a natural localized product name derived from the input title
  13 +1. Product title: a natural localized product name derived from the input product text
14 14 2. Category path: broad to fine-grained category, separated by ">"
15 15 3. Fine-grained tags: style, features, functions, or notable attributes
16 16 4. Target audience: gender, age group, or suitable users
... ... @@ -23,7 +23,7 @@ SHARED_ANALYSIS_INSTRUCTION = """Analyze each input product title and fill these
23 23  
24 24 Rules:
25 25 - Keep the input order and row count exactly the same.
26   -- Infer from the title only; if uncertain, prefer concise and broadly correct ecommerce wording.
  26 +- Infer only from the provided input product text; if uncertain, prefer concise and broadly correct ecommerce wording.
27 27 - Keep category paths concise and use ">" as the separator.
28 28 - For columns with multiple values, the localized output requirement will define the delimiter.
29 29  
... ... @@ -515,4 +515,4 @@ LANGUAGE_MARKDOWN_TABLE_HEADERS: Dict[str, Dict[str, Any]] = {
515 515 "Характеристики",
516 516 "Анкор текст"
517 517 ]
518   -}
519 518 \ No newline at end of file
  519 +}
... ...
search/es_query_builder.py
... ... @@ -275,7 +275,8 @@ class ESQueryBuilder:
275 275 "query_vector": query_vector.tolist(),
276 276 "k": knn_k,
277 277 "num_candidates": knn_num_candidates,
278   - "boost": knn_boost
  278 + "boost": knn_boost,
  279 + "name": "knn_query",
279 280 }
280 281 # Top-level knn does not inherit query.bool.filter automatically.
281 282 # Apply conjunctive + range filters here so vector recall respects hard filters.
... ...
search/rerank_client.py
... ... @@ -4,7 +4,7 @@
4 4 流程:
5 5 1. 从 ES hits 构造用于重排的文档文本列表
6 6 2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数
7   -3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序
  7 +3. 提取 ES 文本/向量子句分数,与重排分数做乘法融合并重排序
8 8 """
9 9  
10 10 from typing import Dict, Any, List, Optional, Tuple
... ... @@ -14,7 +14,7 @@ from providers import create_rerank_provider
14 14  
15 15 logger = logging.getLogger(__name__)
16 16  
17   -# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1)
  17 +# 历史配置项,保留签名兼容;当前乘法融合公式不再使用线性权重。
18 18 DEFAULT_WEIGHT_ES = 0.4
19 19 DEFAULT_WEIGHT_AI = 0.6
20 20 # 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大)
... ... @@ -103,18 +103,20 @@ def fuse_scores_and_resort(
103 103 weight_ai: float = DEFAULT_WEIGHT_AI,
104 104 ) -> List[Dict[str, Any]]:
105 105 """
106   - 将 ES 分数与重排分数线性融合(不修改原始 _score),并按融合分数降序重排。
  106 + 将 ES 分数与重排分数按乘法公式融合(不修改原始 _score),并按融合分数降序重排。
107 107  
108 108 对每条 hit 会写入:
109 109 - _original_score: 原始 ES 分数
110 110 - _rerank_score: 重排服务返回的分数
111 111 - _fused_score: 融合分数
  112 + - _text_score: 文本相关性分数(优先取 named queries 的 base_query 分数)
  113 + - _knn_score: KNN 分数(优先取 named queries 的 knn_query 分数)
112 114  
113 115 Args:
114 116 es_hits: ES hits 列表(会被原地修改)
115 117 rerank_scores: 与 es_hits 等长的重排分数列表
116   - weight_es: ES 归一化分数权重
117   - weight_ai: 重排分数权重
  118 + weight_es: 兼容保留,当前未使用
  119 + weight_ai: 兼容保留,当前未使用
118 120  
119 121 Returns:
120 122 每条文档的融合调试信息列表,用于 debug_info
... ... @@ -123,38 +125,62 @@ def fuse_scores_and_resort(
123 125 if n == 0 or len(rerank_scores) != n:
124 126 return []
125 127  
126   - # 收集 ES 原始分数
127   - es_scores: List[float] = []
128   - for hit in es_hits:
129   - raw = hit.get("_score")
130   - try:
131   - es_scores.append(float(raw) if raw is not None else 0.0)
132   - except (TypeError, ValueError):
133   - es_scores.append(0.0)
134   -
135   - max_es = max(es_scores) if es_scores else 0.0
136 128 fused_debug: List[Dict[str, Any]] = []
137 129  
138 130 for idx, hit in enumerate(es_hits):
139   - es_score = es_scores[idx]
  131 + raw_es_score = hit.get("_score")
  132 + try:
  133 + es_score = float(raw_es_score) if raw_es_score is not None else 0.0
  134 + except (TypeError, ValueError):
  135 + es_score = 0.0
  136 +
140 137 ai_score_raw = rerank_scores[idx]
141 138 try:
142 139 rerank_score = float(ai_score_raw)
143 140 except (TypeError, ValueError):
144 141 rerank_score = 0.0
145 142  
146   - es_norm = (es_score / max_es) if max_es > 0 else 0.0
147   - fused = weight_es * es_norm + weight_ai * rerank_score
  143 + matched_queries = hit.get("matched_queries")
  144 + text_score = 0.0
  145 + knn_score = 0.0
  146 + if isinstance(matched_queries, dict):
  147 + try:
  148 + text_score = float(matched_queries.get("base_query", 0.0) or 0.0)
  149 + except (TypeError, ValueError):
  150 + text_score = 0.0
  151 + try:
  152 + knn_score = float(matched_queries.get("knn_query", 0.0) or 0.0)
  153 + except (TypeError, ValueError):
  154 + knn_score = 0.0
  155 + elif isinstance(matched_queries, list):
  156 + text_score = 1.0 if "base_query" in matched_queries else 0.0
  157 + knn_score = 1.0 if "knn_query" in matched_queries else 0.0
  158 +
  159 + # 回退逻辑:
  160 + # - text_score 缺失时,退回原始 _score,避免纯文本召回被错误打成 0。
  161 + # - knn_score 缺失时保持 0,由平滑项 0.6 兜底。
  162 + if text_score <= 0.0:
  163 + text_score = es_score
  164 +
  165 + fused = (
  166 + (rerank_score + 0.00001) ** 1.0 *
  167 + (knn_score + 0.6) ** 0.2 *
  168 + (text_score + 0.1) ** 0.75
  169 + )
148 170  
149 171 hit["_original_score"] = hit.get("_score")
150 172 hit["_rerank_score"] = rerank_score
  173 + hit["_text_score"] = text_score
  174 + hit["_knn_score"] = knn_score
151 175 hit["_fused_score"] = fused
152 176  
153 177 fused_debug.append({
154 178 "doc_id": hit.get("_id"),
155 179 "es_score": es_score,
156   - "es_score_norm": es_norm,
157 180 "rerank_score": rerank_score,
  181 + "text_score": text_score,
  182 + "knn_score": knn_score,
  183 + "matched_queries": matched_queries,
158 184 "fused_score": fused,
159 185 })
160 186  
... ...
search/searcher.py
... ... @@ -400,6 +400,7 @@ class Searcher:
400 400 # Add sorting if specified
401 401 if sort_by:
402 402 es_query = self.query_builder.add_sorting(es_query, sort_by, sort_order)
  403 + es_query["track_scores"] = True
403 404  
404 405 # Keep requested response _source semantics for the final response fill.
405 406 response_source_spec = es_query.get("_source")
... ... @@ -467,7 +468,8 @@ class Searcher:
467 468 index_name=index_name,
468 469 body=body_for_es,
469 470 size=es_fetch_size,
470   - from_=es_fetch_from
  471 + from_=es_fetch_from,
  472 + include_named_queries_score=bool(do_rerank and in_rerank_window),
471 473 )
472 474  
473 475 # Store ES response in context
... ...
tests/test_es_query_builder.py
... ... @@ -62,3 +62,4 @@ def test_knn_prefilter_not_added_without_filters():
62 62  
63 63 assert "knn" in q
64 64 assert "filter" not in q["knn"]
  65 + assert q["knn"]["name"] == "knn_query"
... ...
tests/test_product_enrich_partial_mode.py
... ... @@ -62,7 +62,7 @@ def test_create_prompt_splits_shared_context_and_localized_tail():
62 62 shared_en, user_en, prefix_en = product_enrich.create_prompt(products, target_lang="en")
63 63  
64 64 assert shared_zh == shared_en
65   - assert "Analyze each input product title" in shared_zh
  65 + assert "Analyze each input product text" in shared_zh
66 66 assert "1. dress" in shared_zh
67 67 assert "2. linen shirt" in shared_zh
68 68 assert "Product list" not in user_zh
... ... @@ -232,11 +232,20 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests():
232 232 cache_store = {}
233 233 process_calls = []
234 234  
235   - def fake_get_cached_anchor_result(title, target_lang, tenant_id=None):
236   - return cache_store.get((tenant_id, target_lang, title))
  235 + def _cache_key(product, target_lang):
  236 + return (
  237 + target_lang,
  238 + product.get("title", ""),
  239 + product.get("brief", ""),
  240 + product.get("description", ""),
  241 + product.get("image_url", ""),
  242 + )
  243 +
  244 + def fake_get_cached_anchor_result(product, target_lang):
  245 + return cache_store.get(_cache_key(product, target_lang))
237 246  
238   - def fake_set_cached_anchor_result(title, target_lang, result, tenant_id=None):
239   - cache_store[(tenant_id, target_lang, title)] = result
  247 + def fake_set_cached_anchor_result(product, target_lang, result):
  248 + cache_store[_cache_key(product, target_lang)] = result
240 249  
241 250 def fake_process_batch(batch_data, batch_num, target_lang="zh"):
242 251 process_calls.append(
... ... @@ -291,7 +300,7 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests():
291 300 second = product_enrich.analyze_products(
292 301 products,
293 302 target_lang="zh",
294   - tenant_id="170",
  303 + tenant_id="999",
295 304 )
296 305 third = product_enrich.analyze_products(
297 306 products,
... ... @@ -311,3 +320,63 @@ def test_analyze_products_uses_product_level_cache_across_batch_requests():
311 320 assert second[1]["anchor_text"] == "anchor:shirt"
312 321 assert third[0]["anchor_text"] == "anchor:dress"
313 322 assert third[1]["anchor_text"] == "anchor:shirt"
  323 +
  324 +
  325 +def test_anchor_cache_key_depends_on_product_input_not_identifiers():
  326 + product_a = {
  327 + "id": "1",
  328 + "spu_id": "1001",
  329 + "title": "dress",
  330 + "brief": "soft cotton",
  331 + "description": "summer dress",
  332 + "image_url": "https://img/a.jpg",
  333 + }
  334 + product_b = {
  335 + "id": "2",
  336 + "spu_id": "9999",
  337 + "title": "dress",
  338 + "brief": "soft cotton",
  339 + "description": "summer dress",
  340 + "image_url": "https://img/a.jpg",
  341 + }
  342 + product_c = {
  343 + "id": "1",
  344 + "spu_id": "1001",
  345 + "title": "dress",
  346 + "brief": "soft cotton updated",
  347 + "description": "summer dress",
  348 + "image_url": "https://img/a.jpg",
  349 + }
  350 +
  351 + key_a = product_enrich._make_anchor_cache_key(product_a, "zh")
  352 + key_b = product_enrich._make_anchor_cache_key(product_b, "zh")
  353 + key_c = product_enrich._make_anchor_cache_key(product_c, "zh")
  354 +
  355 + assert key_a == key_b
  356 + assert key_a != key_c
  357 +
  358 +
  359 +def test_build_prompt_input_text_appends_brief_and_description_for_short_title():
  360 + product = {
  361 + "title": "T恤",
  362 + "brief": "夏季透气纯棉短袖,舒适亲肤",
  363 + "description": "100%棉,圆领版型,适合日常通勤与休闲穿搭。",
  364 + }
  365 +
  366 + text = product_enrich._build_prompt_input_text(product)
  367 +
  368 + assert text.startswith("T恤")
  369 + assert "夏季透气纯棉短袖" in text
  370 + assert "100%棉" in text
  371 +
  372 +
  373 +def test_build_prompt_input_text_truncates_non_cjk_by_words():
  374 + product = {
  375 + "title": "dress",
  376 + "brief": " ".join(f"brief{i}" for i in range(50)),
  377 + "description": " ".join(f"desc{i}" for i in range(50)),
  378 + }
  379 +
  380 + text = product_enrich._build_prompt_input_text(product)
  381 +
  382 + assert len(text.split()) <= product_enrich.PROMPT_INPUT_MAX_WORDS
... ...
tests/test_rerank_client.py 0 → 100644
... ... @@ -0,0 +1,53 @@
  1 +from math import isclose
  2 +
  3 +from search.rerank_client import fuse_scores_and_resort
  4 +
  5 +
  6 +def test_fuse_scores_and_resort_uses_multiplicative_formula_with_named_query_scores():
  7 + hits = [
  8 + {
  9 + "_id": "1",
  10 + "_score": 3.2,
  11 + "matched_queries": {
  12 + "base_query": 2.4,
  13 + "knn_query": 0.8,
  14 + },
  15 + },
  16 + {
  17 + "_id": "2",
  18 + "_score": 2.8,
  19 + "matched_queries": {
  20 + "base_query": 1.6,
  21 + "knn_query": 0.2,
  22 + },
  23 + },
  24 + ]
  25 +
  26 + debug = fuse_scores_and_resort(hits, [0.9, 0.7])
  27 +
  28 + expected_1 = (0.9 + 0.00001) * ((0.8 + 0.6) ** 0.2) * ((2.4 + 0.1) ** 0.75)
  29 + expected_2 = (0.7 + 0.00001) * ((0.2 + 0.6) ** 0.2) * ((1.6 + 0.1) ** 0.75)
  30 +
  31 + assert isclose(hits[0]["_fused_score"], expected_1, rel_tol=1e-9)
  32 + assert isclose(hits[1]["_fused_score"], expected_2, rel_tol=1e-9)
  33 + assert debug[0]["text_score"] == 2.4
  34 + assert debug[0]["knn_score"] == 0.8
  35 + assert [hit["_id"] for hit in hits] == ["1", "2"]
  36 +
  37 +
  38 +def test_fuse_scores_and_resort_falls_back_when_matched_queries_missing():
  39 + hits = [
  40 + {"_id": "1", "_score": 0.5},
  41 + {"_id": "2", "_score": 2.0},
  42 + ]
  43 +
  44 + fuse_scores_and_resort(hits, [0.4, 0.3])
  45 +
  46 + expected_1 = (0.4 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((0.5 + 0.1) ** 0.75)
  47 + expected_2 = (0.3 + 0.00001) * ((0.0 + 0.6) ** 0.2) * ((2.0 + 0.1) ** 0.75)
  48 +
  49 + assert isclose(hits[0]["_text_score"], 2.0, rel_tol=1e-9)
  50 + assert isclose(hits[0]["_fused_score"], expected_2, rel_tol=1e-9)
  51 + assert isclose(hits[1]["_text_score"], 0.5, rel_tol=1e-9)
  52 + assert isclose(hits[1]["_fused_score"], expected_1, rel_tol=1e-9)
  53 + assert [hit["_id"] for hit in hits] == ["2", "1"]
... ...
tests/test_search_rerank_window.py
... ... @@ -97,9 +97,22 @@ class _FakeESClient:
97 97 "skus": [],
98 98 }
99 99  
100   - def search(self, index_name: str, body: Dict[str, Any], size: int, from_: int):
  100 + def search(
  101 + self,
  102 + index_name: str,
  103 + body: Dict[str, Any],
  104 + size: int,
  105 + from_: int,
  106 + include_named_queries_score: bool = False,
  107 + ):
101 108 self.calls.append(
102   - {"index_name": index_name, "body": body, "size": size, "from_": from_}
  109 + {
  110 + "index_name": index_name,
  111 + "body": body,
  112 + "size": size,
  113 + "from_": from_,
  114 + "include_named_queries_score": include_named_queries_score,
  115 + }
103 116 )
104 117 ids_query = (((body or {}).get("query") or {}).get("ids") or {}).get("values")
105 118 source_spec = (body or {}).get("_source")
... ... @@ -213,6 +226,7 @@ def test_searcher_reranks_top_window_by_default(monkeypatch):
213 226 assert called["docs"] == window
214 227 assert es_client.calls[0]["from_"] == 0
215 228 assert es_client.calls[0]["size"] == window
  229 + assert es_client.calls[0]["include_named_queries_score"] is True
216 230 assert es_client.calls[0]["body"]["_source"] == {"includes": ["title"]}
217 231 assert len(es_client.calls) == 2
218 232 assert es_client.calls[1]["size"] == 10
... ... @@ -277,6 +291,7 @@ def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch):
277 291 assert called["count"] == 0
278 292 assert es_client.calls[0]["from_"] == 20
279 293 assert es_client.calls[0]["size"] == 10
  294 + assert es_client.calls[0]["include_named_queries_score"] is False
280 295 assert len(es_client.calls) == 1
281 296  
282 297  
... ... @@ -310,4 +325,5 @@ def test_searcher_skips_rerank_when_page_exceeds_window(monkeypatch):
310 325 assert called["count"] == 0
311 326 assert es_client.calls[0]["from_"] == 995
312 327 assert es_client.calls[0]["size"] == 10
  328 + assert es_client.calls[0]["include_named_queries_score"] is False
313 329 assert len(es_client.calls) == 1
... ...
utils/es_client.py
... ... @@ -228,6 +228,7 @@ class ESClient:
228 228 size: int = 10,
229 229 from_: int = 0,
230 230 routing: Optional[str] = None,
  231 + include_named_queries_score: bool = False,
231 232 ) -> Dict[str, Any]:
232 233 """
233 234 Execute search query.
... ... @@ -260,6 +261,7 @@ class ESClient:
260 261 size=size,
261 262 from_=from_,
262 263 routing=routing,
  264 + include_named_queries_score=include_named_queries_score,
263 265 )
264 266 # elasticsearch-py 8.x returns ObjectApiResponse; normalize to mutable dict
265 267 # so caller can safely patch hits/took during post-processing.
... ...