Commit 506c39b751130ec541c23b7fdabb1f5d373a97f8
1 parent
d90e7428
feat(search): 统一重排逻辑,仅由 ai_search 控制并调用外部 BGE 重排服务
- API:新增请求参数 ai_search,开启时在窗口内走重排流程 - 配置:RerankConfig 移除 enabled/expression/description,仅保留 rerank_window 及 service_url/timeout_sec/weight_es/weight_ai;默认超时 15s - 重排流程:ai_search 且 from+size<=rerank_window 时,ES 取前 rerank_window 条, 调用外部 /rerank 服务,融合 ES 与重排分数后按 from/size 分页;否则不重排 - search/rerank_client:新增模块,封装 build_docs、call_rerank_service、 fuse_scores_and_resort、run_rerank;超时单独捕获并简短日志 - search/searcher:移除 RerankEngine,enable_rerank=ai_search,使用 config.rerank 参数 - 删除 search/rerank_engine.py(本地表达式重排),统一为外部服务一种实现 - 文档:搜索 API 对接指南补充 ai_search 与 relevance_score 说明 - 测试:conftest 中 rerank 配置改为新结构 Co-authored-by: Cursor <cursoragent@cursor.com>
Showing
10 changed files
with
362 additions
and
203 deletions
Show diff stats
api/models.py
| @@ -151,6 +151,10 @@ class SearchRequest(BaseModel): | @@ -151,6 +151,10 @@ class SearchRequest(BaseModel): | ||
| 151 | min_score: Optional[float] = Field(None, ge=0, description="最小相关性分数阈值") | 151 | min_score: Optional[float] = Field(None, ge=0, description="最小相关性分数阈值") |
| 152 | highlight: bool = Field(False, description="是否高亮搜索关键词(暂不实现)") | 152 | highlight: bool = Field(False, description="是否高亮搜索关键词(暂不实现)") |
| 153 | debug: bool = Field(False, description="是否返回调试信息") | 153 | debug: bool = Field(False, description="是否返回调试信息") |
| 154 | + ai_search: bool = Field( | ||
| 155 | + False, | ||
| 156 | + description="是否开启 AI 搜索(调用本地重排服务对 ES 结果进行二次排序)" | ||
| 157 | + ) | ||
| 154 | 158 | ||
| 155 | # SKU筛选参数 | 159 | # SKU筛选参数 |
| 156 | sku_filter_dimension: Optional[List[str]] = Field( | 160 | sku_filter_dimension: Optional[List[str]] = Field( |
api/routes/search.py
| @@ -84,6 +84,7 @@ async def search(request: SearchRequest, http_request: Request): | @@ -84,6 +84,7 @@ async def search(request: SearchRequest, http_request: Request): | ||
| 84 | f"min_score: {request.min_score} | " | 84 | f"min_score: {request.min_score} | " |
| 85 | f"language: {request.language} | " | 85 | f"language: {request.language} | " |
| 86 | f"debug: {request.debug} | " | 86 | f"debug: {request.debug} | " |
| 87 | + f"ai_search: {request.ai_search} | " | ||
| 87 | f"sku_filter_dimension: {request.sku_filter_dimension} | " | 88 | f"sku_filter_dimension: {request.sku_filter_dimension} | " |
| 88 | f"filters: {request.filters} | " | 89 | f"filters: {request.filters} | " |
| 89 | f"range_filters: {request.range_filters} | " | 90 | f"range_filters: {request.range_filters} | " |
| @@ -111,6 +112,7 @@ async def search(request: SearchRequest, http_request: Request): | @@ -111,6 +112,7 @@ async def search(request: SearchRequest, http_request: Request): | ||
| 111 | debug=request.debug, | 112 | debug=request.debug, |
| 112 | language=request.language, | 113 | language=request.language, |
| 113 | sku_filter_dimension=request.sku_filter_dimension, | 114 | sku_filter_dimension=request.sku_filter_dimension, |
| 115 | + ai_search=request.ai_search, | ||
| 114 | ) | 116 | ) |
| 115 | 117 | ||
| 116 | # Include performance summary in response | 118 | # Include performance summary in response |
config/config.yaml
| @@ -133,11 +133,14 @@ function_score: | @@ -133,11 +133,14 @@ function_score: | ||
| 133 | boost_mode: "multiply" | 133 | boost_mode: "multiply" |
| 134 | functions: [] | 134 | functions: [] |
| 135 | 135 | ||
| 136 | -# Rerank配置(本地重排,当前禁用) | 136 | +# 重排配置(唯一实现:外部 BGE 重排服务,由请求参数 ai_search 控制是否执行) |
| 137 | +# ai_search 且 from+size<=rerank_window 时:从 ES 取前 rerank_window 条、重排后再按 from/size 分页 | ||
| 137 | rerank: | 138 | rerank: |
| 138 | - enabled: false | ||
| 139 | - expression: "" | ||
| 140 | - description: "Local reranking (disabled, use ES function_score instead)" | 139 | + rerank_window: 1000 |
| 140 | + # service_url: "http://127.0.0.1:6007/rerank" # 可选,不填则用默认端口 6007 | ||
| 141 | + timeout_sec: 15.0 # 文档多时重排耗时长,可按需调大 | ||
| 142 | + weight_es: 0.4 | ||
| 143 | + weight_ai: 0.6 | ||
| 141 | 144 | ||
| 142 | # SPU配置(已启用,使用嵌套skus) | 145 | # SPU配置(已启用,使用嵌套skus) |
| 143 | spu_config: | 146 | spu_config: |
config/config_loader.py
| @@ -88,10 +88,14 @@ class RankingConfig: | @@ -88,10 +88,14 @@ class RankingConfig: | ||
| 88 | 88 | ||
| 89 | @dataclass | 89 | @dataclass |
| 90 | class RerankConfig: | 90 | class RerankConfig: |
| 91 | - """本地重排配置(当前禁用)""" | ||
| 92 | - enabled: bool = False | ||
| 93 | - expression: str = "" | ||
| 94 | - description: str = "" | 91 | + """重排配置(唯一实现:调用外部 BGE 重排服务,由请求参数 ai_search 控制是否执行)""" |
| 92 | + # 重排窗口:ai_search 且 from+size<=rerank_window 时,从 ES 取前 rerank_window 条重排后再分页 | ||
| 93 | + rerank_window: int = 1000 | ||
| 94 | + # 可选:重排服务 URL,为空时使用 reranker 模块默认端口 6007 | ||
| 95 | + service_url: Optional[str] = None | ||
| 96 | + timeout_sec: float = 15.0 | ||
| 97 | + weight_es: float = 0.4 | ||
| 98 | + weight_ai: float = 0.6 | ||
| 95 | 99 | ||
| 96 | 100 | ||
| 97 | @dataclass | 101 | @dataclass |
| @@ -263,12 +267,14 @@ class ConfigLoader: | @@ -263,12 +267,14 @@ class ConfigLoader: | ||
| 263 | functions=fs_data.get("functions") or [] | 267 | functions=fs_data.get("functions") or [] |
| 264 | ) | 268 | ) |
| 265 | 269 | ||
| 266 | - # Parse Rerank configuration | 270 | + # Parse Rerank configuration(唯一实现:外部重排服务,由 ai_search 控制) |
| 267 | rerank_data = config_data.get("rerank", {}) | 271 | rerank_data = config_data.get("rerank", {}) |
| 268 | rerank = RerankConfig( | 272 | rerank = RerankConfig( |
| 269 | - enabled=rerank_data.get("enabled", False), | ||
| 270 | - expression=rerank_data.get("expression") or "", | ||
| 271 | - description=rerank_data.get("description") or "" | 273 | + rerank_window=int(rerank_data.get("rerank_window", 1000)), |
| 274 | + service_url=rerank_data.get("service_url") or None, | ||
| 275 | + timeout_sec=float(rerank_data.get("timeout_sec", 15.0)), | ||
| 276 | + weight_es=float(rerank_data.get("weight_es", 0.4)), | ||
| 277 | + weight_ai=float(rerank_data.get("weight_ai", 0.6)), | ||
| 272 | ) | 278 | ) |
| 273 | 279 | ||
| 274 | # Parse SPU config | 280 | # Parse SPU config |
| @@ -399,9 +405,11 @@ class ConfigLoader: | @@ -399,9 +405,11 @@ class ConfigLoader: | ||
| 399 | "functions": config.function_score.functions | 405 | "functions": config.function_score.functions |
| 400 | }, | 406 | }, |
| 401 | "rerank": { | 407 | "rerank": { |
| 402 | - "enabled": config.rerank.enabled, | ||
| 403 | - "expression": config.rerank.expression, | ||
| 404 | - "description": config.rerank.description | 408 | + "rerank_window": config.rerank.rerank_window, |
| 409 | + "service_url": config.rerank.service_url, | ||
| 410 | + "timeout_sec": config.rerank.timeout_sec, | ||
| 411 | + "weight_es": config.rerank.weight_es, | ||
| 412 | + "weight_ai": config.rerank.weight_ai, | ||
| 405 | }, | 413 | }, |
| 406 | "spu_config": { | 414 | "spu_config": { |
| 407 | "enabled": config.spu_config.enabled, | 415 | "enabled": config.spu_config.enabled, |
docs/搜索API对接指南.md
| @@ -167,6 +167,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ | @@ -167,6 +167,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ | ||
| 167 | "min_score": 0.0, | 167 | "min_score": 0.0, |
| 168 | "sku_filter_dimension": ["string"], | 168 | "sku_filter_dimension": ["string"], |
| 169 | "debug": false, | 169 | "debug": false, |
| 170 | + "ai_search": false, | ||
| 170 | "user_id": "string", | 171 | "user_id": "string", |
| 171 | "session_id": "string" | 172 | "session_id": "string" |
| 172 | } | 173 | } |
| @@ -188,6 +189,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ | @@ -188,6 +189,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ | ||
| 188 | | `min_score` | float | N | null | 最小相关性分数阈值 | | 189 | | `min_score` | float | N | null | 最小相关性分数阈值 | |
| 189 | | `sku_filter_dimension` | array[string] | N | null | 子SKU筛选维度列表(见[SKU筛选维度](#35-sku筛选维度)) | | 190 | | `sku_filter_dimension` | array[string] | N | null | 子SKU筛选维度列表(见[SKU筛选维度](#35-sku筛选维度)) | |
| 190 | | `debug` | boolean | N | false | 是否返回调试信息 | | 191 | | `debug` | boolean | N | false | 是否返回调试信息 | |
| 192 | +| `ai_search` | boolean | N | false | 是否开启 AI 搜索(调用本地重排服务对 ES 结果进行二次排序) | | ||
| 191 | | `user_id` | string | N | null | 用户ID(用于个性化,预留) | | 193 | | `user_id` | string | N | null | 用户ID(用于个性化,预留) | |
| 192 | | `session_id` | string | N | null | 会话ID(用于分析,预留) | | 194 | | `session_id` | string | N | null | 会话ID(用于分析,预留) | |
| 193 | 195 | ||
| @@ -787,7 +789,7 @@ curl "http://localhost:6002/search/12345" | @@ -787,7 +789,7 @@ curl "http://localhost:6002/search/12345" | ||
| 787 | | `option3_name` | string | 选项3名称 | | 789 | | `option3_name` | string | 选项3名称 | |
| 788 | | `specifications` | array[object] | 规格列表(与ES specifications字段对应) | | 790 | | `specifications` | array[object] | 规格列表(与ES specifications字段对应) | |
| 789 | | `skus` | array | SKU 列表 | | 791 | | `skus` | array | SKU 列表 | |
| 790 | -| `relevance_score` | float | 相关性分数 | | 792 | +| `relevance_score` | float | 相关性分数(默认为 ES 原始分数;当开启 AI 搜索时为融合后的最终分数) | |
| 791 | 793 | ||
| 792 | ### 4.4 SkuResult字段说明 | 794 | ### 4.4 SkuResult字段说明 |
| 793 | 795 |
search/__init__.py
| @@ -2,14 +2,12 @@ | @@ -2,14 +2,12 @@ | ||
| 2 | 2 | ||
| 3 | from .boolean_parser import BooleanParser, QueryNode | 3 | from .boolean_parser import BooleanParser, QueryNode |
| 4 | from .es_query_builder import ESQueryBuilder | 4 | from .es_query_builder import ESQueryBuilder |
| 5 | -from .rerank_engine import RerankEngine | ||
| 6 | from .searcher import Searcher, SearchResult | 5 | from .searcher import Searcher, SearchResult |
| 7 | 6 | ||
| 8 | __all__ = [ | 7 | __all__ = [ |
| 9 | 'BooleanParser', | 8 | 'BooleanParser', |
| 10 | 'QueryNode', | 9 | 'QueryNode', |
| 11 | 'ESQueryBuilder', | 10 | 'ESQueryBuilder', |
| 12 | - 'RerankEngine', | ||
| 13 | 'Searcher', | 11 | 'Searcher', |
| 14 | 'SearchResult', | 12 | 'SearchResult', |
| 15 | ] | 13 | ] |
| @@ -0,0 +1,244 @@ | @@ -0,0 +1,244 @@ | ||
| 1 | +""" | ||
| 2 | +重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。 | ||
| 3 | + | ||
| 4 | +流程: | ||
| 5 | +1. 从 ES hits 构造用于重排的文档文本列表 | ||
| 6 | +2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 | ||
| 7 | +3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序 | ||
| 8 | +""" | ||
| 9 | + | ||
| 10 | +from typing import Dict, Any, List, Optional, Tuple | ||
| 11 | +import logging | ||
| 12 | + | ||
| 13 | +logger = logging.getLogger(__name__) | ||
| 14 | + | ||
| 15 | +# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1) | ||
| 16 | +DEFAULT_WEIGHT_ES = 0.4 | ||
| 17 | +DEFAULT_WEIGHT_AI = 0.6 | ||
| 18 | +# 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) | ||
| 19 | +DEFAULT_TIMEOUT_SEC = 15.0 | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +def build_docs_from_hits( | ||
| 23 | + es_hits: List[Dict[str, Any]], | ||
| 24 | + language: str = "zh", | ||
| 25 | +) -> List[str]: | ||
| 26 | + """ | ||
| 27 | + 从 ES 命中结果构造重排服务所需的文档文本列表(与 hits 一一对应)。 | ||
| 28 | + | ||
| 29 | + 文本由 title、brief、description、vendor、category_path 等多语言字段拼接, | ||
| 30 | + 按 language 优先选取对应语言;若无内容则用 spu_id 兜底。 | ||
| 31 | + | ||
| 32 | + Args: | ||
| 33 | + es_hits: ES 返回的 hits 列表,每项含 _source | ||
| 34 | + language: 语言代码,如 "zh"、"en" | ||
| 35 | + | ||
| 36 | + Returns: | ||
| 37 | + 与 es_hits 等长的字符串列表,用于 POST /rerank 的 docs | ||
| 38 | + """ | ||
| 39 | + lang = (language or "zh").strip().lower() | ||
| 40 | + if lang not in ("zh", "en"): | ||
| 41 | + lang = "zh" | ||
| 42 | + | ||
| 43 | + def pick_lang_text(obj: Any) -> str: | ||
| 44 | + if obj is None: | ||
| 45 | + return "" | ||
| 46 | + if isinstance(obj, dict): | ||
| 47 | + return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip() | ||
| 48 | + return str(obj).strip() | ||
| 49 | + | ||
| 50 | + docs: List[str] = [] | ||
| 51 | + for hit in es_hits: | ||
| 52 | + src = hit.get("_source") or {} | ||
| 53 | + parts: List[str] = [] | ||
| 54 | + for key in ("title", "brief", "description", "vendor", "category_path"): | ||
| 55 | + parts.append(pick_lang_text(src.get(key))) | ||
| 56 | + text = " ".join(p for p in parts if p).strip() | ||
| 57 | + if not text: | ||
| 58 | + text = str(src.get("spu_id", "")) | ||
| 59 | + docs.append(text) | ||
| 60 | + return docs | ||
| 61 | + | ||
| 62 | + | ||
| 63 | +def call_rerank_service( | ||
| 64 | + query: str, | ||
| 65 | + docs: List[str], | ||
| 66 | + service_url: str, | ||
| 67 | + timeout_sec: float = DEFAULT_TIMEOUT_SEC, | ||
| 68 | +) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]: | ||
| 69 | + """ | ||
| 70 | + 调用重排服务 POST /rerank,返回分数列表与 meta。 | ||
| 71 | + | ||
| 72 | + Args: | ||
| 73 | + query: 搜索查询字符串 | ||
| 74 | + docs: 文档文本列表(与 ES hits 顺序一致) | ||
| 75 | + service_url: 完整 URL,如 http://127.0.0.1:6007/rerank | ||
| 76 | + timeout_sec: 请求超时秒数 | ||
| 77 | + | ||
| 78 | + Returns: | ||
| 79 | + (scores, meta):成功时 scores 与 docs 等长,meta 为服务返回的 meta; | ||
| 80 | + 失败时返回 (None, None) | ||
| 81 | + """ | ||
| 82 | + if not docs: | ||
| 83 | + return [], {} | ||
| 84 | + try: | ||
| 85 | + import requests | ||
| 86 | + payload = {"query": (query or "").strip(), "docs": docs} | ||
| 87 | + response = requests.post(service_url, json=payload, timeout=timeout_sec) | ||
| 88 | + if response.status_code != 200: | ||
| 89 | + logger.warning( | ||
| 90 | + "Rerank service HTTP %s: %s", | ||
| 91 | + response.status_code, | ||
| 92 | + (response.text or "")[:200], | ||
| 93 | + ) | ||
| 94 | + return None, None | ||
| 95 | + data = response.json() | ||
| 96 | + scores = data.get("scores") | ||
| 97 | + if not isinstance(scores, list): | ||
| 98 | + return None, None | ||
| 99 | + return scores, data.get("meta") or {} | ||
| 100 | + except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as e: | ||
| 101 | + logger.warning( | ||
| 102 | + "Rerank request timed out after %.1fs (docs=%d); returning ES order. %s", | ||
| 103 | + timeout_sec, len(docs), e, | ||
| 104 | + ) | ||
| 105 | + return None, None | ||
| 106 | + except Exception as e: | ||
| 107 | + logger.warning("Rerank request failed: %s", e, exc_info=True) | ||
| 108 | + return None, None | ||
| 109 | + | ||
| 110 | + | ||
| 111 | +def fuse_scores_and_resort( | ||
| 112 | + es_hits: List[Dict[str, Any]], | ||
| 113 | + rerank_scores: List[float], | ||
| 114 | + weight_es: float = DEFAULT_WEIGHT_ES, | ||
| 115 | + weight_ai: float = DEFAULT_WEIGHT_AI, | ||
| 116 | +) -> List[Dict[str, Any]]: | ||
| 117 | + """ | ||
| 118 | + 将 ES 分数与重排分数线性融合,写回每条 hit 的 _score,并按融合分数降序重排。 | ||
| 119 | + | ||
| 120 | + 对每条 hit 会写入: | ||
| 121 | + - _original_score: 原始 ES 分数 | ||
| 122 | + - _ai_rerank_score: 重排服务返回的分数 | ||
| 123 | + - _fused_score: 融合分数 | ||
| 124 | + - _score: 置为融合分数(供后续 ResultFormatter 使用) | ||
| 125 | + | ||
| 126 | + Args: | ||
| 127 | + es_hits: ES hits 列表(会被原地修改) | ||
| 128 | + rerank_scores: 与 es_hits 等长的重排分数列表 | ||
| 129 | + weight_es: ES 归一化分数权重 | ||
| 130 | + weight_ai: 重排分数权重 | ||
| 131 | + | ||
| 132 | + Returns: | ||
| 133 | + 每条文档的融合调试信息列表,用于 debug_info | ||
| 134 | + """ | ||
| 135 | + n = len(es_hits) | ||
| 136 | + if n == 0 or len(rerank_scores) != n: | ||
| 137 | + return [] | ||
| 138 | + | ||
| 139 | + # 收集 ES 原始分数 | ||
| 140 | + es_scores: List[float] = [] | ||
| 141 | + for hit in es_hits: | ||
| 142 | + raw = hit.get("_score") | ||
| 143 | + try: | ||
| 144 | + es_scores.append(float(raw) if raw is not None else 0.0) | ||
| 145 | + except (TypeError, ValueError): | ||
| 146 | + es_scores.append(0.0) | ||
| 147 | + | ||
| 148 | + max_es = max(es_scores) if es_scores else 0.0 | ||
| 149 | + fused_debug: List[Dict[str, Any]] = [] | ||
| 150 | + | ||
| 151 | + for idx, hit in enumerate(es_hits): | ||
| 152 | + es_score = es_scores[idx] | ||
| 153 | + ai_score_raw = rerank_scores[idx] | ||
| 154 | + try: | ||
| 155 | + ai_score = float(ai_score_raw) | ||
| 156 | + except (TypeError, ValueError): | ||
| 157 | + ai_score = 0.0 | ||
| 158 | + | ||
| 159 | + es_norm = (es_score / max_es) if max_es > 0 else 0.0 | ||
| 160 | + fused = weight_es * es_norm + weight_ai * ai_score | ||
| 161 | + | ||
| 162 | + hit["_original_score"] = hit.get("_score") | ||
| 163 | + hit["_ai_rerank_score"] = ai_score | ||
| 164 | + hit["_fused_score"] = fused | ||
| 165 | + hit["_score"] = fused | ||
| 166 | + | ||
| 167 | + fused_debug.append({ | ||
| 168 | + "doc_id": hit.get("_id"), | ||
| 169 | + "es_score": es_score, | ||
| 170 | + "es_score_norm": es_norm, | ||
| 171 | + "ai_rerank_score": ai_score, | ||
| 172 | + "fused_score": fused, | ||
| 173 | + }) | ||
| 174 | + | ||
| 175 | + # 按融合分数降序重排 | ||
| 176 | + es_hits.sort( | ||
| 177 | + key=lambda h: h.get("_fused_score", h.get("_score", 0.0)), | ||
| 178 | + reverse=True, | ||
| 179 | + ) | ||
| 180 | + return fused_debug | ||
| 181 | + | ||
| 182 | + | ||
| 183 | +def run_rerank( | ||
| 184 | + query: str, | ||
| 185 | + es_response: Dict[str, Any], | ||
| 186 | + language: str = "zh", | ||
| 187 | + service_url: Optional[str] = None, | ||
| 188 | + timeout_sec: float = DEFAULT_TIMEOUT_SEC, | ||
| 189 | + weight_es: float = DEFAULT_WEIGHT_ES, | ||
| 190 | + weight_ai: float = DEFAULT_WEIGHT_AI, | ||
| 191 | +) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]: | ||
| 192 | + """ | ||
| 193 | + 完整重排流程:从 es_response 取 hits -> 构造 docs -> 调服务 -> 融合分数并重排 -> 更新 max_score。 | ||
| 194 | + | ||
| 195 | + Args: | ||
| 196 | + query: 搜索查询 | ||
| 197 | + es_response: ES 原始响应(其中的 hits["hits"] 会被原地修改) | ||
| 198 | + language: 文档文本使用的语言 | ||
| 199 | + service_url: 重排服务 URL,为 None 时使用默认 127.0.0.1:6007 | ||
| 200 | + timeout_sec: 请求超时 | ||
| 201 | + weight_es: ES 分数权重 | ||
| 202 | + weight_ai: 重排分数权重 | ||
| 203 | + | ||
| 204 | + Returns: | ||
| 205 | + (es_response, rerank_meta, fused_debug): | ||
| 206 | + - es_response: 已更新 hits 与 max_score 的响应(同一引用) | ||
| 207 | + - rerank_meta: 重排服务返回的 meta,失败时为 None | ||
| 208 | + - fused_debug: 每条文档的融合信息,供 debug 使用 | ||
| 209 | + """ | ||
| 210 | + try: | ||
| 211 | + from reranker.config import CONFIG as RERANKER_CONFIG | ||
| 212 | + except Exception: | ||
| 213 | + RERANKER_CONFIG = None | ||
| 214 | + | ||
| 215 | + url = service_url | ||
| 216 | + if not url and RERANKER_CONFIG is not None: | ||
| 217 | + url = f"http://127.0.0.1:{RERANKER_CONFIG.PORT}/rerank" | ||
| 218 | + if not url: | ||
| 219 | + url = "http://127.0.0.1:6007/rerank" | ||
| 220 | + | ||
| 221 | + hits = es_response.get("hits", {}).get("hits") or [] | ||
| 222 | + if not hits: | ||
| 223 | + return es_response, None, [] | ||
| 224 | + | ||
| 225 | + docs = build_docs_from_hits(hits, language=language) | ||
| 226 | + scores, meta = call_rerank_service(query, docs, url, timeout_sec=timeout_sec) | ||
| 227 | + | ||
| 228 | + if scores is None or len(scores) != len(hits): | ||
| 229 | + return es_response, None, [] | ||
| 230 | + | ||
| 231 | + fused_debug = fuse_scores_and_resort( | ||
| 232 | + hits, | ||
| 233 | + scores, | ||
| 234 | + weight_es=weight_es, | ||
| 235 | + weight_ai=weight_ai, | ||
| 236 | + ) | ||
| 237 | + | ||
| 238 | + # 更新 max_score 为融合后的最高分 | ||
| 239 | + if hits: | ||
| 240 | + top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0 | ||
| 241 | + if "hits" in es_response: | ||
| 242 | + es_response["hits"]["max_score"] = top | ||
| 243 | + | ||
| 244 | + return es_response, meta, fused_debug |
search/rerank_engine.py deleted
| @@ -1,171 +0,0 @@ | @@ -1,171 +0,0 @@ | ||
| 1 | -""" | ||
| 2 | -Reranking engine for post-processing search result scoring. | ||
| 3 | - | ||
| 4 | -本地重排引擎,用于ES返回结果后的二次排序。 | ||
| 5 | -当前状态:已禁用,优先使用ES的function_score。 | ||
| 6 | - | ||
| 7 | -Supports expression-based ranking with functions like: | ||
| 8 | -- bm25(): Base BM25 text relevance score | ||
| 9 | -- text_embedding_relevance(): KNN embedding similarity | ||
| 10 | -- field_value(field): Use field value in scoring | ||
| 11 | -- timeliness(date_field): Time decay function | ||
| 12 | -""" | ||
| 13 | - | ||
| 14 | -import re | ||
| 15 | -from typing import Dict, Any, List, Optional | ||
| 16 | -import math | ||
| 17 | - | ||
| 18 | - | ||
| 19 | -class RerankEngine: | ||
| 20 | - """ | ||
| 21 | - 本地重排引擎(当前禁用) | ||
| 22 | - | ||
| 23 | - 功能:对ES返回的结果进行二次打分和排序 | ||
| 24 | - 用途:复杂的自定义排序逻辑、实时个性化等 | ||
| 25 | - """ | ||
| 26 | - | ||
| 27 | - def __init__(self, ranking_expression: str, enabled: bool = False): | ||
| 28 | - """ | ||
| 29 | - Initialize rerank engine. | ||
| 30 | - | ||
| 31 | - Args: | ||
| 32 | - ranking_expression: Ranking expression string | ||
| 33 | - Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2" | ||
| 34 | - enabled: Whether local reranking is enabled (default: False) | ||
| 35 | - """ | ||
| 36 | - self.enabled = enabled | ||
| 37 | - self.expression = ranking_expression | ||
| 38 | - self.parsed_terms = [] | ||
| 39 | - | ||
| 40 | - if enabled: | ||
| 41 | - self.parsed_terms = self._parse_expression(ranking_expression) | ||
| 42 | - | ||
| 43 | - def _parse_expression(self, expression: str) -> List[Dict[str, Any]]: | ||
| 44 | - """ | ||
| 45 | - Parse ranking expression into terms. | ||
| 46 | - | ||
| 47 | - Args: | ||
| 48 | - expression: Ranking expression | ||
| 49 | - | ||
| 50 | - Returns: | ||
| 51 | - List of term dictionaries | ||
| 52 | - """ | ||
| 53 | - terms = [] | ||
| 54 | - | ||
| 55 | - # Pattern to match: coefficient * function() or field_name | ||
| 56 | - # Example: "0.2*text_embedding_relevance()" or "general_score*2" | ||
| 57 | - pattern = r'([+-]?\s*\d*\.?\d*)\s*\*?\s*([a-zA-Z_]\w*(?:\([^)]*\))?)' | ||
| 58 | - | ||
| 59 | - for match in re.finditer(pattern, expression): | ||
| 60 | - coef_str = match.group(1).strip() | ||
| 61 | - func_str = match.group(2).strip() | ||
| 62 | - | ||
| 63 | - # Parse coefficient | ||
| 64 | - if coef_str in ['', '+']: | ||
| 65 | - coefficient = 1.0 | ||
| 66 | - elif coef_str == '-': | ||
| 67 | - coefficient = -1.0 | ||
| 68 | - else: | ||
| 69 | - try: | ||
| 70 | - coefficient = float(coef_str) | ||
| 71 | - except ValueError: | ||
| 72 | - coefficient = 1.0 | ||
| 73 | - | ||
| 74 | - # Check if function or field | ||
| 75 | - if '(' in func_str: | ||
| 76 | - # Function call | ||
| 77 | - func_name = func_str[:func_str.index('(')] | ||
| 78 | - args_str = func_str[func_str.index('(') + 1:func_str.rindex(')')] | ||
| 79 | - args = [arg.strip() for arg in args_str.split(',')] if args_str else [] | ||
| 80 | - | ||
| 81 | - terms.append({ | ||
| 82 | - 'type': 'function', | ||
| 83 | - 'name': func_name, | ||
| 84 | - 'args': args, | ||
| 85 | - 'coefficient': coefficient | ||
| 86 | - }) | ||
| 87 | - else: | ||
| 88 | - # Field reference | ||
| 89 | - terms.append({ | ||
| 90 | - 'type': 'field', | ||
| 91 | - 'name': func_str, | ||
| 92 | - 'coefficient': coefficient | ||
| 93 | - }) | ||
| 94 | - | ||
| 95 | - return terms | ||
| 96 | - | ||
| 97 | - def calculate_score( | ||
| 98 | - self, | ||
| 99 | - hit: Dict[str, Any], | ||
| 100 | - base_score: float, | ||
| 101 | - knn_score: Optional[float] = None | ||
| 102 | - ) -> float: | ||
| 103 | - """ | ||
| 104 | - Calculate final score for a search result. | ||
| 105 | - | ||
| 106 | - Args: | ||
| 107 | - hit: ES hit document | ||
| 108 | - base_score: Base BM25 score | ||
| 109 | - knn_score: KNN similarity score (if available) | ||
| 110 | - | ||
| 111 | - Returns: | ||
| 112 | - Final calculated score | ||
| 113 | - """ | ||
| 114 | - if not self.enabled: | ||
| 115 | - return base_score | ||
| 116 | - | ||
| 117 | - score = 0.0 | ||
| 118 | - source = hit.get('_source', {}) | ||
| 119 | - | ||
| 120 | - for term in self.parsed_terms: | ||
| 121 | - term_value = 0.0 | ||
| 122 | - | ||
| 123 | - if term['type'] == 'function': | ||
| 124 | - func_name = term['name'] | ||
| 125 | - | ||
| 126 | - if func_name == 'bm25': | ||
| 127 | - term_value = base_score | ||
| 128 | - | ||
| 129 | - elif func_name == 'text_embedding_relevance': | ||
| 130 | - term_value = knn_score if knn_score is not None else 0.0 | ||
| 131 | - | ||
| 132 | - elif func_name == 'timeliness': | ||
| 133 | - # Time decay function | ||
| 134 | - if term['args']: | ||
| 135 | - date_field = term['args'][0] | ||
| 136 | - if date_field in source: | ||
| 137 | - # Simple time decay (would need actual implementation) | ||
| 138 | - term_value = 1.0 | ||
| 139 | - else: | ||
| 140 | - term_value = 1.0 | ||
| 141 | - | ||
| 142 | - elif func_name == 'field_value': | ||
| 143 | - # Get field value | ||
| 144 | - if term['args'] and term['args'][0] in source: | ||
| 145 | - field_value = source[term['args'][0]] | ||
| 146 | - try: | ||
| 147 | - term_value = float(field_value) | ||
| 148 | - except (ValueError, TypeError): | ||
| 149 | - term_value = 0.0 | ||
| 150 | - | ||
| 151 | - elif term['type'] == 'field': | ||
| 152 | - # Direct field reference | ||
| 153 | - field_name = term['name'] | ||
| 154 | - if field_name in source: | ||
| 155 | - try: | ||
| 156 | - term_value = float(source[field_name]) | ||
| 157 | - except (ValueError, TypeError): | ||
| 158 | - term_value = 0.0 | ||
| 159 | - | ||
| 160 | - score += term['coefficient'] * term_value | ||
| 161 | - | ||
| 162 | - return score | ||
| 163 | - | ||
| 164 | - def get_expression(self) -> str: | ||
| 165 | - """Get ranking expression.""" | ||
| 166 | - return self.expression | ||
| 167 | - | ||
| 168 | - def get_terms(self) -> List[Dict[str, Any]]: | ||
| 169 | - """Get parsed expression terms.""" | ||
| 170 | - return self.parsed_terms | ||
| 171 | - |
search/searcher.py
| @@ -13,7 +13,6 @@ from query import QueryParser, ParsedQuery | @@ -13,7 +13,6 @@ from query import QueryParser, ParsedQuery | ||
| 13 | from embeddings import CLIPImageEncoder | 13 | from embeddings import CLIPImageEncoder |
| 14 | from .boolean_parser import BooleanParser, QueryNode | 14 | from .boolean_parser import BooleanParser, QueryNode |
| 15 | from .es_query_builder import ESQueryBuilder | 15 | from .es_query_builder import ESQueryBuilder |
| 16 | -from .rerank_engine import RerankEngine | ||
| 17 | from config import SearchConfig | 16 | from config import SearchConfig |
| 18 | from config.tenant_config_loader import get_tenant_config_loader | 17 | from config.tenant_config_loader import get_tenant_config_loader |
| 19 | from config.utils import get_match_fields_for_index | 18 | from config.utils import get_match_fields_for_index |
| @@ -99,7 +98,6 @@ class Searcher: | @@ -99,7 +98,6 @@ class Searcher: | ||
| 99 | 98 | ||
| 100 | # Initialize components | 99 | # Initialize components |
| 101 | self.boolean_parser = BooleanParser() | 100 | self.boolean_parser = BooleanParser() |
| 102 | - self.rerank_engine = RerankEngine(config.ranking.expression, enabled=False) | ||
| 103 | 101 | ||
| 104 | # Get match fields from config | 102 | # Get match fields from config |
| 105 | self.match_fields = get_match_fields_for_index(config, "default") | 103 | self.match_fields = get_match_fields_for_index(config, "default") |
| @@ -137,6 +135,7 @@ class Searcher: | @@ -137,6 +135,7 @@ class Searcher: | ||
| 137 | debug: bool = False, | 135 | debug: bool = False, |
| 138 | language: str = "en", | 136 | language: str = "en", |
| 139 | sku_filter_dimension: Optional[List[str]] = None, | 137 | sku_filter_dimension: Optional[List[str]] = None, |
| 138 | + ai_search: bool = False, | ||
| 140 | ) -> SearchResult: | 139 | ) -> SearchResult: |
| 141 | """ | 140 | """ |
| 142 | Execute search query (外部友好格式). | 141 | Execute search query (外部友好格式). |
| @@ -168,15 +167,21 @@ class Searcher: | @@ -168,15 +167,21 @@ class Searcher: | ||
| 168 | index_langs = tenant_cfg.get("index_languages") or [] | 167 | index_langs = tenant_cfg.get("index_languages") or [] |
| 169 | enable_translation = len(index_langs) > 0 | 168 | enable_translation = len(index_langs) > 0 |
| 170 | enable_embedding = self.config.query_config.enable_text_embedding | 169 | enable_embedding = self.config.query_config.enable_text_embedding |
| 171 | - enable_rerank = False # Temporarily disabled | 170 | + # 重排仅由请求参数 ai_search 控制,唯一实现为调用外部 BGE 重排服务 |
| 171 | + enable_rerank = bool(ai_search) | ||
| 172 | + rerank_window = self.config.rerank.rerank_window or 1000 | ||
| 173 | + # 若开启重排且请求范围在窗口内:从 ES 取前 rerank_window 条、重排后再按 from/size 分页;否则不重排,按原 from/size 查 ES | ||
| 174 | + in_rerank_window = enable_rerank and (from_ + size) <= rerank_window | ||
| 175 | + es_fetch_from = 0 if in_rerank_window else from_ | ||
| 176 | + es_fetch_size = rerank_window if in_rerank_window else size | ||
| 172 | 177 | ||
| 173 | # Start timing | 178 | # Start timing |
| 174 | context.start_stage(RequestContextStage.TOTAL) | 179 | context.start_stage(RequestContextStage.TOTAL) |
| 175 | 180 | ||
| 176 | context.logger.info( | 181 | context.logger.info( |
| 177 | f"开始搜索请求 | 查询: '{query}' | 参数: size={size}, from_={from_}, " | 182 | f"开始搜索请求 | 查询: '{query}' | 参数: size={size}, from_={from_}, " |
| 178 | - f"enable_translation={enable_translation}, enable_embedding={enable_embedding}, " | ||
| 179 | - f"enable_rerank={enable_rerank}, min_score={min_score}", | 183 | + f"enable_rerank={enable_rerank}, in_rerank_window={in_rerank_window}, es_fetch=({es_fetch_from},{es_fetch_size}) | " |
| 184 | + f"enable_translation={enable_translation}, enable_embedding={enable_embedding}, min_score={min_score}", | ||
| 180 | extra={'reqid': context.reqid, 'uid': context.uid} | 185 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 181 | ) | 186 | ) |
| 182 | 187 | ||
| @@ -184,6 +189,9 @@ class Searcher: | @@ -184,6 +189,9 @@ class Searcher: | ||
| 184 | context.metadata['search_params'] = { | 189 | context.metadata['search_params'] = { |
| 185 | 'size': size, | 190 | 'size': size, |
| 186 | 'from_': from_, | 191 | 'from_': from_, |
| 192 | + 'es_fetch_from': es_fetch_from, | ||
| 193 | + 'es_fetch_size': es_fetch_size, | ||
| 194 | + 'in_rerank_window': in_rerank_window, | ||
| 187 | 'filters': filters, | 195 | 'filters': filters, |
| 188 | 'range_filters': range_filters, | 196 | 'range_filters': range_filters, |
| 189 | 'facets': facets, | 197 | 'facets': facets, |
| @@ -287,8 +295,8 @@ class Searcher: | @@ -287,8 +295,8 @@ class Searcher: | ||
| 287 | filters=filters, | 295 | filters=filters, |
| 288 | range_filters=range_filters, | 296 | range_filters=range_filters, |
| 289 | facet_configs=facets, | 297 | facet_configs=facets, |
| 290 | - size=size, | ||
| 291 | - from_=from_, | 298 | + size=es_fetch_size, |
| 299 | + from_=es_fetch_from, | ||
| 292 | enable_knn=enable_embedding and parsed_query.query_vector is not None, | 300 | enable_knn=enable_embedding and parsed_query.query_vector is not None, |
| 293 | min_score=min_score, | 301 | min_score=min_score, |
| 294 | parsed_query=parsed_query | 302 | parsed_query=parsed_query |
| @@ -336,12 +344,12 @@ class Searcher: | @@ -336,12 +344,12 @@ class Searcher: | ||
| 336 | # Step 4: Elasticsearch search | 344 | # Step 4: Elasticsearch search |
| 337 | context.start_stage(RequestContextStage.ELASTICSEARCH_SEARCH) | 345 | context.start_stage(RequestContextStage.ELASTICSEARCH_SEARCH) |
| 338 | try: | 346 | try: |
| 339 | - # Use tenant-specific index name | 347 | + # Use tenant-specific index name(开启重排且在窗口内时已用 es_fetch_size/es_fetch_from) |
| 340 | es_response = self.es_client.search( | 348 | es_response = self.es_client.search( |
| 341 | index_name=index_name, | 349 | index_name=index_name, |
| 342 | body=body_for_es, | 350 | body=body_for_es, |
| 343 | - size=size, | ||
| 344 | - from_=from_ | 351 | + size=es_fetch_size, |
| 352 | + from_=es_fetch_from | ||
| 345 | ) | 353 | ) |
| 346 | 354 | ||
| 347 | # Store ES response in context | 355 | # Store ES response in context |
| @@ -365,6 +373,69 @@ class Searcher: | @@ -365,6 +373,69 @@ class Searcher: | ||
| 365 | finally: | 373 | finally: |
| 366 | context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH) | 374 | context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH) |
| 367 | 375 | ||
| 376 | + # Optional Step 4.5: AI reranking(仅当请求范围在重排窗口内时执行) | ||
| 377 | + if enable_rerank and in_rerank_window: | ||
| 378 | + context.start_stage(RequestContextStage.RERANKING) | ||
| 379 | + try: | ||
| 380 | + from .rerank_client import run_rerank | ||
| 381 | + | ||
| 382 | + rerank_query = parsed_query.original_query if parsed_query else query | ||
| 383 | + rc = self.config.rerank | ||
| 384 | + es_response, rerank_meta, fused_debug = run_rerank( | ||
| 385 | + query=rerank_query, | ||
| 386 | + es_response=es_response, | ||
| 387 | + language=language, | ||
| 388 | + service_url=rc.service_url, | ||
| 389 | + timeout_sec=rc.timeout_sec, | ||
| 390 | + weight_es=rc.weight_es, | ||
| 391 | + weight_ai=rc.weight_ai, | ||
| 392 | + ) | ||
| 393 | + | ||
| 394 | + if rerank_meta is not None: | ||
| 395 | + try: | ||
| 396 | + from reranker.config import CONFIG as RERANKER_CONFIG | ||
| 397 | + rerank_url = f"http://127.0.0.1:{RERANKER_CONFIG.PORT}/rerank" | ||
| 398 | + except Exception: | ||
| 399 | + rerank_url = "http://127.0.0.1:6007/rerank" | ||
| 400 | + context.metadata.setdefault("rerank_info", {}) | ||
| 401 | + context.metadata["rerank_info"].update({ | ||
| 402 | + "service_url": rerank_url, | ||
| 403 | + "docs": len(es_response.get("hits", {}).get("hits") or []), | ||
| 404 | + "meta": rerank_meta, | ||
| 405 | + }) | ||
| 406 | + context.store_intermediate_result("rerank_scores", fused_debug) | ||
| 407 | + context.logger.info( | ||
| 408 | + f"重排完成 | docs={len(fused_debug)} | meta={rerank_meta}", | ||
| 409 | + extra={'reqid': context.reqid, 'uid': context.uid} | ||
| 410 | + ) | ||
| 411 | + except Exception as e: | ||
| 412 | + context.add_warning(f"Rerank failed: {e}") | ||
| 413 | + context.logger.warning( | ||
| 414 | + f"调用重排服务失败 | error: {e}", | ||
| 415 | + extra={'reqid': context.reqid, 'uid': context.uid}, | ||
| 416 | + exc_info=True, | ||
| 417 | + ) | ||
| 418 | + finally: | ||
| 419 | + context.end_stage(RequestContextStage.RERANKING) | ||
| 420 | + | ||
| 421 | + # 当本次请求在重排窗口内时:已从 ES 取了 rerank_window 条并可能已重排,需按请求的 from/size 做分页切片 | ||
| 422 | + if in_rerank_window: | ||
| 423 | + hits = es_response.get("hits", {}).get("hits") or [] | ||
| 424 | + sliced = hits[from_ : from_ + size] | ||
| 425 | + es_response.setdefault("hits", {})["hits"] = sliced | ||
| 426 | + if sliced: | ||
| 427 | + slice_max = max((h.get("_score") for h in sliced), default=0.0) | ||
| 428 | + try: | ||
| 429 | + es_response["hits"]["max_score"] = float(slice_max) | ||
| 430 | + except (TypeError, ValueError): | ||
| 431 | + es_response["hits"]["max_score"] = 0.0 | ||
| 432 | + else: | ||
| 433 | + es_response["hits"]["max_score"] = 0.0 | ||
| 434 | + context.logger.info( | ||
| 435 | + f"重排分页切片 | from={from_}, size={size}, 返回={len(sliced)}条", | ||
| 436 | + extra={'reqid': context.reqid, 'uid': context.uid} | ||
| 437 | + ) | ||
| 438 | + | ||
| 368 | # Step 5: Result processing | 439 | # Step 5: Result processing |
| 369 | context.start_stage(RequestContextStage.RESULT_PROCESSING) | 440 | context.start_stage(RequestContextStage.RESULT_PROCESSING) |
| 370 | try: | 441 | try: |
| @@ -379,7 +450,7 @@ class Searcher: | @@ -379,7 +450,7 @@ class Searcher: | ||
| 379 | total_value = total.get('value', 0) | 450 | total_value = total.get('value', 0) |
| 380 | else: | 451 | else: |
| 381 | total_value = total | 452 | total_value = total |
| 382 | - | 453 | + # max_score 会在启用 AI 搜索时被更新为融合分数的最大值 |
| 383 | max_score = es_response.get('hits', {}).get('max_score') or 0.0 | 454 | max_score = es_response.get('hits', {}).get('max_score') or 0.0 |
| 384 | 455 | ||
| 385 | # Format results using ResultFormatter | 456 | # Format results using ResultFormatter |
tests/conftest.py
| @@ -191,9 +191,7 @@ def temp_config_file() -> Generator[str, None, None]: | @@ -191,9 +191,7 @@ def temp_config_file() -> Generator[str, None, None]: | ||
| 191 | "functions": [] | 191 | "functions": [] |
| 192 | }, | 192 | }, |
| 193 | "rerank": { | 193 | "rerank": { |
| 194 | - "enabled": False, | ||
| 195 | - "expression": "", | ||
| 196 | - "description": "" | 194 | + "rerank_window": 1000 |
| 197 | } | 195 | } |
| 198 | } | 196 | } |
| 199 | 197 |