Commit 506c39b751130ec541c23b7fdabb1f5d373a97f8
1 parent
d90e7428
feat(search): 统一重排逻辑,仅由 ai_search 控制并调用外部 BGE 重排服务
- API:新增请求参数 ai_search,开启时在窗口内走重排流程
- 配置:RerankConfig 移除 enabled/expression/description,仅保留 rerank_window 及 service_url/timeout_sec/weight_es/weight_ai;默认超时 15s
- 重排流程:ai_search 且 from+size<=rerank_window 时,ES 取前 rerank_window 条,调用外部 /rerank 服务,融合 ES 与重排分数后按 from/size 分页;否则不重排
- search/rerank_client:新增模块,封装 build_docs、call_rerank_service、fuse_scores_and_resort、run_rerank;超时单独捕获并简短日志
- search/searcher:移除 RerankEngine,enable_rerank=ai_search,使用 config.rerank 参数
- 删除 search/rerank_engine.py(本地表达式重排),统一为外部服务一种实现
- 文档:搜索 API 对接指南补充 ai_search 与 relevance_score 说明
- 测试:conftest 中 rerank 配置改为新结构

Co-authored-by: Cursor <cursoragent@cursor.com>
Showing
10 changed files
with
362 additions
and
203 deletions
Show diff stats
api/models.py
| ... | ... | @@ -151,6 +151,10 @@ class SearchRequest(BaseModel): |
| 151 | 151 | min_score: Optional[float] = Field(None, ge=0, description="最小相关性分数阈值") |
| 152 | 152 | highlight: bool = Field(False, description="是否高亮搜索关键词(暂不实现)") |
| 153 | 153 | debug: bool = Field(False, description="是否返回调试信息") |
| 154 | + ai_search: bool = Field( | |
| 155 | + False, | |
| 156 | + description="是否开启 AI 搜索(调用本地重排服务对 ES 结果进行二次排序)" | |
| 157 | + ) | |
| 154 | 158 | |
| 155 | 159 | # SKU筛选参数 |
| 156 | 160 | sku_filter_dimension: Optional[List[str]] = Field( | ... | ... |
api/routes/search.py
| ... | ... | @@ -84,6 +84,7 @@ async def search(request: SearchRequest, http_request: Request): |
| 84 | 84 | f"min_score: {request.min_score} | " |
| 85 | 85 | f"language: {request.language} | " |
| 86 | 86 | f"debug: {request.debug} | " |
| 87 | + f"ai_search: {request.ai_search} | " | |
| 87 | 88 | f"sku_filter_dimension: {request.sku_filter_dimension} | " |
| 88 | 89 | f"filters: {request.filters} | " |
| 89 | 90 | f"range_filters: {request.range_filters} | " |
| ... | ... | @@ -111,6 +112,7 @@ async def search(request: SearchRequest, http_request: Request): |
| 111 | 112 | debug=request.debug, |
| 112 | 113 | language=request.language, |
| 113 | 114 | sku_filter_dimension=request.sku_filter_dimension, |
| 115 | + ai_search=request.ai_search, | |
| 114 | 116 | ) |
| 115 | 117 | |
| 116 | 118 | # Include performance summary in response | ... | ... |
config/config.yaml
| ... | ... | @@ -133,11 +133,14 @@ function_score: |
| 133 | 133 | boost_mode: "multiply" |
| 134 | 134 | functions: [] |
| 135 | 135 | |
| 136 | -# Rerank配置(本地重排,当前禁用) | |
| 136 | +# 重排配置(唯一实现:外部 BGE 重排服务,由请求参数 ai_search 控制是否执行) | |
| 137 | +# ai_search 且 from+size<=rerank_window 时:从 ES 取前 rerank_window 条、重排后再按 from/size 分页 | |
| 137 | 138 | rerank: |
| 138 | - enabled: false | |
| 139 | - expression: "" | |
| 140 | - description: "Local reranking (disabled, use ES function_score instead)" | |
| 139 | + rerank_window: 1000 | |
| 140 | + # service_url: "http://127.0.0.1:6007/rerank" # 可选,不填则用默认端口 6007 | |
| 141 | + timeout_sec: 15.0 # 文档多时重排耗时长,可按需调大 | |
| 142 | + weight_es: 0.4 | |
| 143 | + weight_ai: 0.6 | |
| 141 | 144 | |
| 142 | 145 | # SPU配置(已启用,使用嵌套skus) |
| 143 | 146 | spu_config: | ... | ... |
config/config_loader.py
| ... | ... | @@ -88,10 +88,14 @@ class RankingConfig: |
| 88 | 88 | |
| 89 | 89 | @dataclass |
| 90 | 90 | class RerankConfig: |
| 91 | - """本地重排配置(当前禁用)""" | |
| 92 | - enabled: bool = False | |
| 93 | - expression: str = "" | |
| 94 | - description: str = "" | |
| 91 | + """重排配置(唯一实现:调用外部 BGE 重排服务,由请求参数 ai_search 控制是否执行)""" | |
| 92 | + # 重排窗口:ai_search 且 from+size<=rerank_window 时,从 ES 取前 rerank_window 条重排后再分页 | |
| 93 | + rerank_window: int = 1000 | |
| 94 | + # 可选:重排服务 URL,为空时使用 reranker 模块默认端口 6007 | |
| 95 | + service_url: Optional[str] = None | |
| 96 | + timeout_sec: float = 15.0 | |
| 97 | + weight_es: float = 0.4 | |
| 98 | + weight_ai: float = 0.6 | |
| 95 | 99 | |
| 96 | 100 | |
| 97 | 101 | @dataclass |
| ... | ... | @@ -263,12 +267,14 @@ class ConfigLoader: |
| 263 | 267 | functions=fs_data.get("functions") or [] |
| 264 | 268 | ) |
| 265 | 269 | |
| 266 | - # Parse Rerank configuration | |
| 270 | + # Parse Rerank configuration(唯一实现:外部重排服务,由 ai_search 控制) | |
| 267 | 271 | rerank_data = config_data.get("rerank", {}) |
| 268 | 272 | rerank = RerankConfig( |
| 269 | - enabled=rerank_data.get("enabled", False), | |
| 270 | - expression=rerank_data.get("expression") or "", | |
| 271 | - description=rerank_data.get("description") or "" | |
| 273 | + rerank_window=int(rerank_data.get("rerank_window", 1000)), | |
| 274 | + service_url=rerank_data.get("service_url") or None, | |
| 275 | + timeout_sec=float(rerank_data.get("timeout_sec", 15.0)), | |
| 276 | + weight_es=float(rerank_data.get("weight_es", 0.4)), | |
| 277 | + weight_ai=float(rerank_data.get("weight_ai", 0.6)), | |
| 272 | 278 | ) |
| 273 | 279 | |
| 274 | 280 | # Parse SPU config |
| ... | ... | @@ -399,9 +405,11 @@ class ConfigLoader: |
| 399 | 405 | "functions": config.function_score.functions |
| 400 | 406 | }, |
| 401 | 407 | "rerank": { |
| 402 | - "enabled": config.rerank.enabled, | |
| 403 | - "expression": config.rerank.expression, | |
| 404 | - "description": config.rerank.description | |
| 408 | + "rerank_window": config.rerank.rerank_window, | |
| 409 | + "service_url": config.rerank.service_url, | |
| 410 | + "timeout_sec": config.rerank.timeout_sec, | |
| 411 | + "weight_es": config.rerank.weight_es, | |
| 412 | + "weight_ai": config.rerank.weight_ai, | |
| 405 | 413 | }, |
| 406 | 414 | "spu_config": { |
| 407 | 415 | "enabled": config.spu_config.enabled, | ... | ... |
docs/搜索API对接指南.md
| ... | ... | @@ -167,6 +167,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ |
| 167 | 167 | "min_score": 0.0, |
| 168 | 168 | "sku_filter_dimension": ["string"], |
| 169 | 169 | "debug": false, |
| 170 | + "ai_search": false, | |
| 170 | 171 | "user_id": "string", |
| 171 | 172 | "session_id": "string" |
| 172 | 173 | } |
| ... | ... | @@ -188,6 +189,7 @@ curl -X POST "http://120.76.41.98:6002/search/" \ |
| 188 | 189 | | `min_score` | float | N | null | 最小相关性分数阈值 | |
| 189 | 190 | | `sku_filter_dimension` | array[string] | N | null | 子SKU筛选维度列表(见[SKU筛选维度](#35-sku筛选维度)) | |
| 190 | 191 | | `debug` | boolean | N | false | 是否返回调试信息 | |
| 192 | +| `ai_search` | boolean | N | false | 是否开启 AI 搜索(调用本地重排服务对 ES 结果进行二次排序) | | |
| 191 | 193 | | `user_id` | string | N | null | 用户ID(用于个性化,预留) | |
| 192 | 194 | | `session_id` | string | N | null | 会话ID(用于分析,预留) | |
| 193 | 195 | |
| ... | ... | @@ -787,7 +789,7 @@ curl "http://localhost:6002/search/12345" |
| 787 | 789 | | `option3_name` | string | 选项3名称 | |
| 788 | 790 | | `specifications` | array[object] | 规格列表(与ES specifications字段对应) | |
| 789 | 791 | | `skus` | array | SKU 列表 | |
| 790 | -| `relevance_score` | float | 相关性分数 | | |
| 792 | +| `relevance_score` | float | 相关性分数(默认为 ES 原始分数;当开启 AI 搜索时为融合后的最终分数) | | |
| 791 | 793 | |
| 792 | 794 | ### 4.4 SkuResult字段说明 |
| 793 | 795 | ... | ... |
search/__init__.py
| ... | ... | @@ -2,14 +2,12 @@ |
| 2 | 2 | |
| 3 | 3 | from .boolean_parser import BooleanParser, QueryNode |
| 4 | 4 | from .es_query_builder import ESQueryBuilder |
| 5 | -from .rerank_engine import RerankEngine | |
| 6 | 5 | from .searcher import Searcher, SearchResult |
| 7 | 6 | |
| 8 | 7 | __all__ = [ |
| 9 | 8 | 'BooleanParser', |
| 10 | 9 | 'QueryNode', |
| 11 | 10 | 'ESQueryBuilder', |
| 12 | - 'RerankEngine', | |
| 13 | 11 | 'Searcher', |
| 14 | 12 | 'SearchResult', |
| 15 | 13 | ] | ... | ... |
search/rerank_client.py (new file)
| ... | ... | @@ -0,0 +1,244 @@ |
| 1 | +""" | |
| 2 | +重排客户端:调用外部 BGE 重排服务,并对 ES 分数与重排分数进行融合。 | |
| 3 | + | |
| 4 | +流程: | |
| 5 | +1. 从 ES hits 构造用于重排的文档文本列表 | |
| 6 | +2. POST 请求到重排服务 /rerank,获取每条文档的 relevance 分数 | |
| 7 | +3. 将 ES 分数(归一化)与重排分数线性融合,写回 hit["_score"] 并重排序 | |
| 8 | +""" | |
| 9 | + | |
| 10 | +from typing import Dict, Any, List, Optional, Tuple | |
| 11 | +import logging | |
| 12 | + | |
| 13 | +logger = logging.getLogger(__name__) | |
| 14 | + | |
| 15 | +# 默认融合权重:ES 归一化分数权重、重排分数权重(相加为 1) | |
| 16 | +DEFAULT_WEIGHT_ES = 0.4 | |
| 17 | +DEFAULT_WEIGHT_AI = 0.6 | |
| 18 | +# 重排服务默认超时(文档较多时需更大,建议 config 中 timeout_sec 调大) | |
| 19 | +DEFAULT_TIMEOUT_SEC = 15.0 | |
| 20 | + | |
| 21 | + | |
def build_docs_from_hits(
    es_hits: List[Dict[str, Any]],
    language: str = "zh",
) -> List[str]:
    """
    Build the document texts sent to the rerank service, one per ES hit.

    For every hit the values of ``title``, ``brief``, ``description``,
    ``vendor`` and ``category_path`` in ``_source`` are joined with single
    spaces. A field that is a dict is treated as a per-language mapping and
    resolved in the order: requested language, ``"zh"``, ``"en"``. When no
    field yields any text the hit's ``spu_id`` is used as a fallback.

    Args:
        es_hits: ES hits; each item is expected to carry a ``_source`` dict.
        language: Language code. Anything other than ``"zh"``/``"en"``
            (including ``None`` or an empty string) is treated as ``"zh"``.

    Returns:
        A list of strings with the same length and order as ``es_hits``.
    """
    lang = (language or "zh").strip().lower()
    if lang not in ("zh", "en"):
        lang = "zh"

    def pick_lang_text(obj: Any) -> str:
        """Return the text of one field, resolving per-language dicts."""
        if obj is None:
            return ""
        if isinstance(obj, dict):
            return str(obj.get(lang) or obj.get("zh") or obj.get("en") or "").strip()
        return str(obj).strip()

    text_fields = ("title", "brief", "description", "vendor", "category_path")
    docs: List[str] = []
    for hit in es_hits:
        src = hit.get("_source") or {}
        # Every part is already stripped and non-empty, so the joined text
        # needs no further trimming.
        text = " ".join(
            part
            for part in (pick_lang_text(src.get(key)) for key in text_fields)
            if part
        )
        if not text:
            # An explicit null spu_id must not be sent to the reranker as the
            # literal string "None".
            spu_id = src.get("spu_id")
            text = "" if spu_id is None else str(spu_id)
        docs.append(text)
    return docs
| 61 | + | |
| 62 | + | |
def call_rerank_service(
    query: str,
    docs: List[str],
    service_url: str,
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
) -> Tuple[Optional[List[float]], Optional[Dict[str, Any]]]:
    """
    POST the query and documents to the rerank service and return its scores.

    The request body is ``{"query": <stripped query>, "docs": docs}``; the
    response is expected to be a JSON object with a ``scores`` list and an
    optional ``meta`` object.

    Args:
        query: Search query string (``None`` is sent as an empty string).
        docs: Document texts, in the same order as the ES hits.
        service_url: Full endpoint URL, e.g. ``http://127.0.0.1:6007/rerank``.
        timeout_sec: Request timeout in seconds.

    Returns:
        ``(scores, meta)`` on success, where ``scores`` has the same length as
        ``docs``. ``([], {})`` when ``docs`` is empty (no request is made).
        ``(None, None)`` on any failure: missing ``requests`` package,
        non-200 status, malformed payload, timeout or other request error.
        This function never raises for those cases.
    """
    if not docs:
        return [], {}

    # Import outside the request's try block: with the import inside it, an
    # ImportError would make the ``except (requests.exceptions...)`` clause
    # itself fail with UnboundLocalError instead of degrading gracefully.
    try:
        import requests
    except ImportError as e:
        logger.warning("Rerank skipped: 'requests' is not available. %s", e)
        return None, None

    payload = {"query": (query or "").strip(), "docs": docs}
    try:
        response = requests.post(service_url, json=payload, timeout=timeout_sec)
        if response.status_code != 200:
            logger.warning(
                "Rerank service HTTP %s: %s",
                response.status_code,
                (response.text or "")[:200],
            )
            return None, None

        data = response.json()
        scores = data.get("scores") if isinstance(data, dict) else None
        if not isinstance(scores, list):
            logger.warning(
                "Rerank response has no 'scores' list (docs=%d); returning ES order.",
                len(docs),
            )
            return None, None
        if len(scores) != len(docs):
            # Scores are matched to hits by position, so a length mismatch
            # cannot be fused safely.
            logger.warning(
                "Rerank returned %d scores for %d docs; returning ES order.",
                len(scores), len(docs),
            )
            return None, None
        return scores, data.get("meta") or {}
    except (requests.exceptions.ReadTimeout, requests.exceptions.ConnectTimeout) as e:
        # Timeouts are expected under load: log briefly, without a traceback.
        logger.warning(
            "Rerank request timed out after %.1fs (docs=%d); returning ES order. %s",
            timeout_sec, len(docs), e,
        )
        return None, None
    except Exception as e:
        # Best-effort boundary: reranking must never break the search request.
        logger.warning("Rerank request failed: %s", e, exc_info=True)
        return None, None
| 109 | + | |
| 110 | + | |
def fuse_scores_and_resort(
    es_hits: List[Dict[str, Any]],
    rerank_scores: List[float],
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
) -> List[Dict[str, Any]]:
    """
    Linearly fuse ES scores with rerank scores and re-sort the hits in place.

    The ES score of each hit is normalised by the maximum ES score of the
    list (0.0 when that maximum is not positive), then combined as
    ``weight_es * es_norm + weight_ai * ai_score``. Each hit receives:

    - ``_original_score``: the ES score as it was on the hit
    - ``_ai_rerank_score``: the rerank score as a float
    - ``_fused_score``: the fused score
    - ``_score``: overwritten with the fused score

    Scores that are missing, not numeric, or not finite (NaN/inf) count as
    0.0 so that they cannot corrupt the normalisation or the sort order.

    Args:
        es_hits: ES hits; modified and re-ordered in place.
        rerank_scores: Rerank scores, positionally aligned with ``es_hits``.
        weight_es: Weight of the normalised ES score.
        weight_ai: Weight of the rerank score.

    Returns:
        One debug dict per hit (``doc_id``, ``es_score``, ``es_score_norm``,
        ``ai_rerank_score``, ``fused_score``) in the ORIGINAL hit order, or an
        empty list when there are no hits or the lengths differ (in which
        case ``es_hits`` is left untouched).
    """
    import math

    n = len(es_hits)
    if n == 0:
        return []
    if len(rerank_scores) != n:
        logger.warning(
            "Rerank fusion skipped: %d scores for %d hits.",
            len(rerank_scores), n,
        )
        return []

    def to_finite(value: Any) -> float:
        """Coerce a raw score to a finite float, falling back to 0.0."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return 0.0
        return number if math.isfinite(number) else 0.0

    es_scores = [to_finite(hit.get("_score")) for hit in es_hits]
    max_es = max(es_scores)

    fused_debug: List[Dict[str, Any]] = []
    for hit, es_score, raw_ai in zip(es_hits, es_scores, rerank_scores):
        ai_score = to_finite(raw_ai)
        es_norm = (es_score / max_es) if max_es > 0 else 0.0
        fused = weight_es * es_norm + weight_ai * ai_score

        # Keep the raw ES value before _score is overwritten below.
        hit["_original_score"] = hit.get("_score")
        hit["_ai_rerank_score"] = ai_score
        hit["_fused_score"] = fused
        hit["_score"] = fused

        fused_debug.append({
            "doc_id": hit.get("_id"),
            "es_score": es_score,
            "es_score_norm": es_norm,
            "ai_rerank_score": ai_score,
            "fused_score": fused,
        })

    # list.sort is stable, so hits with equal fused scores keep their ES order.
    es_hits.sort(
        key=lambda h: h.get("_fused_score", h.get("_score", 0.0)),
        reverse=True,
    )
    return fused_debug
| 181 | + | |
| 182 | + | |
def run_rerank(
    query: str,
    es_response: Dict[str, Any],
    language: str = "zh",
    service_url: Optional[str] = None,
    timeout_sec: float = DEFAULT_TIMEOUT_SEC,
    weight_es: float = DEFAULT_WEIGHT_ES,
    weight_ai: float = DEFAULT_WEIGHT_AI,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Run the full rerank flow on an ES response.

    Steps: take ``es_response["hits"]["hits"]`` -> build document texts ->
    call the rerank service -> fuse scores and re-sort -> update
    ``max_score``. If there are no hits, or the service call fails or returns
    a score list of the wrong length, the response is returned unchanged.

    Args:
        query: Search query sent to the rerank service.
        es_response: Raw ES response; its hit list is modified in place.
        language: Language used to pick document text from the hits.
        service_url: Rerank endpoint. When empty, the URL is built from
            ``reranker.config.CONFIG.PORT`` on 127.0.0.1, falling back to
            ``http://127.0.0.1:6007/rerank`` if that config cannot be loaded.
        timeout_sec: Request timeout in seconds.
        weight_es: Weight of the normalised ES score.
        weight_ai: Weight of the rerank score.

    Returns:
        ``(es_response, rerank_meta, fused_debug)``:
        - ``es_response``: the same object that was passed in
        - ``rerank_meta``: the service's ``meta``, or ``None`` if no rerank
          was applied
        - ``fused_debug``: per-document fusion details (empty if no rerank)
    """
    # Tolerate a missing or null "hits" container.
    hits = (es_response.get("hits") or {}).get("hits") or []
    if not hits:
        return es_response, None, []

    url = service_url
    if not url:
        # Resolve the default only when needed, so a configured service_url
        # never pays for an import attempt of the optional reranker package.
        try:
            from reranker.config import CONFIG as RERANKER_CONFIG
            url = f"http://127.0.0.1:{RERANKER_CONFIG.PORT}/rerank"
        except Exception:
            url = "http://127.0.0.1:6007/rerank"

    docs = build_docs_from_hits(hits, language=language)
    scores, meta = call_rerank_service(query, docs, url, timeout_sec=timeout_sec)

    # Scores are aligned with hits by position; anything else is unusable.
    if scores is None or len(scores) != len(hits):
        return es_response, None, []

    fused_debug = fuse_scores_and_resort(
        hits,
        scores,
        weight_es=weight_es,
        weight_ai=weight_ai,
    )

    # hits is now sorted by fused score, so the first hit carries the maximum.
    top = hits[0].get("_fused_score", hits[0].get("_score", 0.0)) or 0.0
    es_response["hits"]["max_score"] = top

    return es_response, meta, fused_debug
search/rerank_engine.py deleted
| ... | ... | @@ -1,171 +0,0 @@ |
| 1 | -""" | |
| 2 | -Reranking engine for post-processing search result scoring. | |
| 3 | - | |
| 4 | -本地重排引擎,用于ES返回结果后的二次排序。 | |
| 5 | -当前状态:已禁用,优先使用ES的function_score。 | |
| 6 | - | |
| 7 | -Supports expression-based ranking with functions like: | |
| 8 | -- bm25(): Base BM25 text relevance score | |
| 9 | -- text_embedding_relevance(): KNN embedding similarity | |
| 10 | -- field_value(field): Use field value in scoring | |
| 11 | -- timeliness(date_field): Time decay function | |
| 12 | -""" | |
| 13 | - | |
| 14 | -import re | |
| 15 | -from typing import Dict, Any, List, Optional | |
| 16 | -import math | |
| 17 | - | |
| 18 | - | |
| 19 | -class RerankEngine: | |
| 20 | - """ | |
| 21 | - 本地重排引擎(当前禁用) | |
| 22 | - | |
| 23 | - 功能:对ES返回的结果进行二次打分和排序 | |
| 24 | - 用途:复杂的自定义排序逻辑、实时个性化等 | |
| 25 | - """ | |
| 26 | - | |
| 27 | - def __init__(self, ranking_expression: str, enabled: bool = False): | |
| 28 | - """ | |
| 29 | - Initialize rerank engine. | |
| 30 | - | |
| 31 | - Args: | |
| 32 | - ranking_expression: Ranking expression string | |
| 33 | - Example: "bm25() + 0.2*text_embedding_relevance() + general_score*2" | |
| 34 | - enabled: Whether local reranking is enabled (default: False) | |
| 35 | - """ | |
| 36 | - self.enabled = enabled | |
| 37 | - self.expression = ranking_expression | |
| 38 | - self.parsed_terms = [] | |
| 39 | - | |
| 40 | - if enabled: | |
| 41 | - self.parsed_terms = self._parse_expression(ranking_expression) | |
| 42 | - | |
| 43 | - def _parse_expression(self, expression: str) -> List[Dict[str, Any]]: | |
| 44 | - """ | |
| 45 | - Parse ranking expression into terms. | |
| 46 | - | |
| 47 | - Args: | |
| 48 | - expression: Ranking expression | |
| 49 | - | |
| 50 | - Returns: | |
| 51 | - List of term dictionaries | |
| 52 | - """ | |
| 53 | - terms = [] | |
| 54 | - | |
| 55 | - # Pattern to match: coefficient * function() or field_name | |
| 56 | - # Example: "0.2*text_embedding_relevance()" or "general_score*2" | |
| 57 | - pattern = r'([+-]?\s*\d*\.?\d*)\s*\*?\s*([a-zA-Z_]\w*(?:\([^)]*\))?)' | |
| 58 | - | |
| 59 | - for match in re.finditer(pattern, expression): | |
| 60 | - coef_str = match.group(1).strip() | |
| 61 | - func_str = match.group(2).strip() | |
| 62 | - | |
| 63 | - # Parse coefficient | |
| 64 | - if coef_str in ['', '+']: | |
| 65 | - coefficient = 1.0 | |
| 66 | - elif coef_str == '-': | |
| 67 | - coefficient = -1.0 | |
| 68 | - else: | |
| 69 | - try: | |
| 70 | - coefficient = float(coef_str) | |
| 71 | - except ValueError: | |
| 72 | - coefficient = 1.0 | |
| 73 | - | |
| 74 | - # Check if function or field | |
| 75 | - if '(' in func_str: | |
| 76 | - # Function call | |
| 77 | - func_name = func_str[:func_str.index('(')] | |
| 78 | - args_str = func_str[func_str.index('(') + 1:func_str.rindex(')')] | |
| 79 | - args = [arg.strip() for arg in args_str.split(',')] if args_str else [] | |
| 80 | - | |
| 81 | - terms.append({ | |
| 82 | - 'type': 'function', | |
| 83 | - 'name': func_name, | |
| 84 | - 'args': args, | |
| 85 | - 'coefficient': coefficient | |
| 86 | - }) | |
| 87 | - else: | |
| 88 | - # Field reference | |
| 89 | - terms.append({ | |
| 90 | - 'type': 'field', | |
| 91 | - 'name': func_str, | |
| 92 | - 'coefficient': coefficient | |
| 93 | - }) | |
| 94 | - | |
| 95 | - return terms | |
| 96 | - | |
| 97 | - def calculate_score( | |
| 98 | - self, | |
| 99 | - hit: Dict[str, Any], | |
| 100 | - base_score: float, | |
| 101 | - knn_score: Optional[float] = None | |
| 102 | - ) -> float: | |
| 103 | - """ | |
| 104 | - Calculate final score for a search result. | |
| 105 | - | |
| 106 | - Args: | |
| 107 | - hit: ES hit document | |
| 108 | - base_score: Base BM25 score | |
| 109 | - knn_score: KNN similarity score (if available) | |
| 110 | - | |
| 111 | - Returns: | |
| 112 | - Final calculated score | |
| 113 | - """ | |
| 114 | - if not self.enabled: | |
| 115 | - return base_score | |
| 116 | - | |
| 117 | - score = 0.0 | |
| 118 | - source = hit.get('_source', {}) | |
| 119 | - | |
| 120 | - for term in self.parsed_terms: | |
| 121 | - term_value = 0.0 | |
| 122 | - | |
| 123 | - if term['type'] == 'function': | |
| 124 | - func_name = term['name'] | |
| 125 | - | |
| 126 | - if func_name == 'bm25': | |
| 127 | - term_value = base_score | |
| 128 | - | |
| 129 | - elif func_name == 'text_embedding_relevance': | |
| 130 | - term_value = knn_score if knn_score is not None else 0.0 | |
| 131 | - | |
| 132 | - elif func_name == 'timeliness': | |
| 133 | - # Time decay function | |
| 134 | - if term['args']: | |
| 135 | - date_field = term['args'][0] | |
| 136 | - if date_field in source: | |
| 137 | - # Simple time decay (would need actual implementation) | |
| 138 | - term_value = 1.0 | |
| 139 | - else: | |
| 140 | - term_value = 1.0 | |
| 141 | - | |
| 142 | - elif func_name == 'field_value': | |
| 143 | - # Get field value | |
| 144 | - if term['args'] and term['args'][0] in source: | |
| 145 | - field_value = source[term['args'][0]] | |
| 146 | - try: | |
| 147 | - term_value = float(field_value) | |
| 148 | - except (ValueError, TypeError): | |
| 149 | - term_value = 0.0 | |
| 150 | - | |
| 151 | - elif term['type'] == 'field': | |
| 152 | - # Direct field reference | |
| 153 | - field_name = term['name'] | |
| 154 | - if field_name in source: | |
| 155 | - try: | |
| 156 | - term_value = float(source[field_name]) | |
| 157 | - except (ValueError, TypeError): | |
| 158 | - term_value = 0.0 | |
| 159 | - | |
| 160 | - score += term['coefficient'] * term_value | |
| 161 | - | |
| 162 | - return score | |
| 163 | - | |
| 164 | - def get_expression(self) -> str: | |
| 165 | - """Get ranking expression.""" | |
| 166 | - return self.expression | |
| 167 | - | |
| 168 | - def get_terms(self) -> List[Dict[str, Any]]: | |
| 169 | - """Get parsed expression terms.""" | |
| 170 | - return self.parsed_terms | |
| 171 | - |
search/searcher.py
| ... | ... | @@ -13,7 +13,6 @@ from query import QueryParser, ParsedQuery |
| 13 | 13 | from embeddings import CLIPImageEncoder |
| 14 | 14 | from .boolean_parser import BooleanParser, QueryNode |
| 15 | 15 | from .es_query_builder import ESQueryBuilder |
| 16 | -from .rerank_engine import RerankEngine | |
| 17 | 16 | from config import SearchConfig |
| 18 | 17 | from config.tenant_config_loader import get_tenant_config_loader |
| 19 | 18 | from config.utils import get_match_fields_for_index |
| ... | ... | @@ -99,7 +98,6 @@ class Searcher: |
| 99 | 98 | |
| 100 | 99 | # Initialize components |
| 101 | 100 | self.boolean_parser = BooleanParser() |
| 102 | - self.rerank_engine = RerankEngine(config.ranking.expression, enabled=False) | |
| 103 | 101 | |
| 104 | 102 | # Get match fields from config |
| 105 | 103 | self.match_fields = get_match_fields_for_index(config, "default") |
| ... | ... | @@ -137,6 +135,7 @@ class Searcher: |
| 137 | 135 | debug: bool = False, |
| 138 | 136 | language: str = "en", |
| 139 | 137 | sku_filter_dimension: Optional[List[str]] = None, |
| 138 | + ai_search: bool = False, | |
| 140 | 139 | ) -> SearchResult: |
| 141 | 140 | """ |
| 142 | 141 | Execute search query (外部友好格式). |
| ... | ... | @@ -168,15 +167,21 @@ class Searcher: |
| 168 | 167 | index_langs = tenant_cfg.get("index_languages") or [] |
| 169 | 168 | enable_translation = len(index_langs) > 0 |
| 170 | 169 | enable_embedding = self.config.query_config.enable_text_embedding |
| 171 | - enable_rerank = False # Temporarily disabled | |
| 170 | + # 重排仅由请求参数 ai_search 控制,唯一实现为调用外部 BGE 重排服务 | |
| 171 | + enable_rerank = bool(ai_search) | |
| 172 | + rerank_window = self.config.rerank.rerank_window or 1000 | |
| 173 | + # 若开启重排且请求范围在窗口内:从 ES 取前 rerank_window 条、重排后再按 from/size 分页;否则不重排,按原 from/size 查 ES | |
| 174 | + in_rerank_window = enable_rerank and (from_ + size) <= rerank_window | |
| 175 | + es_fetch_from = 0 if in_rerank_window else from_ | |
| 176 | + es_fetch_size = rerank_window if in_rerank_window else size | |
| 172 | 177 | |
| 173 | 178 | # Start timing |
| 174 | 179 | context.start_stage(RequestContextStage.TOTAL) |
| 175 | 180 | |
| 176 | 181 | context.logger.info( |
| 177 | 182 | f"开始搜索请求 | 查询: '{query}' | 参数: size={size}, from_={from_}, " |
| 178 | - f"enable_translation={enable_translation}, enable_embedding={enable_embedding}, " | |
| 179 | - f"enable_rerank={enable_rerank}, min_score={min_score}", | |
| 183 | + f"enable_rerank={enable_rerank}, in_rerank_window={in_rerank_window}, es_fetch=({es_fetch_from},{es_fetch_size}) | " | |
| 184 | + f"enable_translation={enable_translation}, enable_embedding={enable_embedding}, min_score={min_score}", | |
| 180 | 185 | extra={'reqid': context.reqid, 'uid': context.uid} |
| 181 | 186 | ) |
| 182 | 187 | |
| ... | ... | @@ -184,6 +189,9 @@ class Searcher: |
| 184 | 189 | context.metadata['search_params'] = { |
| 185 | 190 | 'size': size, |
| 186 | 191 | 'from_': from_, |
| 192 | + 'es_fetch_from': es_fetch_from, | |
| 193 | + 'es_fetch_size': es_fetch_size, | |
| 194 | + 'in_rerank_window': in_rerank_window, | |
| 187 | 195 | 'filters': filters, |
| 188 | 196 | 'range_filters': range_filters, |
| 189 | 197 | 'facets': facets, |
| ... | ... | @@ -287,8 +295,8 @@ class Searcher: |
| 287 | 295 | filters=filters, |
| 288 | 296 | range_filters=range_filters, |
| 289 | 297 | facet_configs=facets, |
| 290 | - size=size, | |
| 291 | - from_=from_, | |
| 298 | + size=es_fetch_size, | |
| 299 | + from_=es_fetch_from, | |
| 292 | 300 | enable_knn=enable_embedding and parsed_query.query_vector is not None, |
| 293 | 301 | min_score=min_score, |
| 294 | 302 | parsed_query=parsed_query |
| ... | ... | @@ -336,12 +344,12 @@ class Searcher: |
| 336 | 344 | # Step 4: Elasticsearch search |
| 337 | 345 | context.start_stage(RequestContextStage.ELASTICSEARCH_SEARCH) |
| 338 | 346 | try: |
| 339 | - # Use tenant-specific index name | |
| 347 | + # Use tenant-specific index name(开启重排且在窗口内时已用 es_fetch_size/es_fetch_from) | |
| 340 | 348 | es_response = self.es_client.search( |
| 341 | 349 | index_name=index_name, |
| 342 | 350 | body=body_for_es, |
| 343 | - size=size, | |
| 344 | - from_=from_ | |
| 351 | + size=es_fetch_size, | |
| 352 | + from_=es_fetch_from | |
| 345 | 353 | ) |
| 346 | 354 | |
| 347 | 355 | # Store ES response in context |
| ... | ... | @@ -365,6 +373,69 @@ class Searcher: |
| 365 | 373 | finally: |
| 366 | 374 | context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH) |
| 367 | 375 | |
| 376 | + # Optional Step 4.5: AI reranking(仅当请求范围在重排窗口内时执行) | |
| 377 | + if enable_rerank and in_rerank_window: | |
| 378 | + context.start_stage(RequestContextStage.RERANKING) | |
| 379 | + try: | |
| 380 | + from .rerank_client import run_rerank | |
| 381 | + | |
| 382 | + rerank_query = parsed_query.original_query if parsed_query else query | |
| 383 | + rc = self.config.rerank | |
| 384 | + es_response, rerank_meta, fused_debug = run_rerank( | |
| 385 | + query=rerank_query, | |
| 386 | + es_response=es_response, | |
| 387 | + language=language, | |
| 388 | + service_url=rc.service_url, | |
| 389 | + timeout_sec=rc.timeout_sec, | |
| 390 | + weight_es=rc.weight_es, | |
| 391 | + weight_ai=rc.weight_ai, | |
| 392 | + ) | |
| 393 | + | |
| 394 | + if rerank_meta is not None: | |
| 395 | + try: | |
| 396 | + from reranker.config import CONFIG as RERANKER_CONFIG | |
| 397 | + rerank_url = f"http://127.0.0.1:{RERANKER_CONFIG.PORT}/rerank" | |
| 398 | + except Exception: | |
| 399 | + rerank_url = "http://127.0.0.1:6007/rerank" | |
| 400 | + context.metadata.setdefault("rerank_info", {}) | |
| 401 | + context.metadata["rerank_info"].update({ | |
| 402 | + "service_url": rerank_url, | |
| 403 | + "docs": len(es_response.get("hits", {}).get("hits") or []), | |
| 404 | + "meta": rerank_meta, | |
| 405 | + }) | |
| 406 | + context.store_intermediate_result("rerank_scores", fused_debug) | |
| 407 | + context.logger.info( | |
| 408 | + f"重排完成 | docs={len(fused_debug)} | meta={rerank_meta}", | |
| 409 | + extra={'reqid': context.reqid, 'uid': context.uid} | |
| 410 | + ) | |
| 411 | + except Exception as e: | |
| 412 | + context.add_warning(f"Rerank failed: {e}") | |
| 413 | + context.logger.warning( | |
| 414 | + f"调用重排服务失败 | error: {e}", | |
| 415 | + extra={'reqid': context.reqid, 'uid': context.uid}, | |
| 416 | + exc_info=True, | |
| 417 | + ) | |
| 418 | + finally: | |
| 419 | + context.end_stage(RequestContextStage.RERANKING) | |
| 420 | + | |
| 421 | + # 当本次请求在重排窗口内时:已从 ES 取了 rerank_window 条并可能已重排,需按请求的 from/size 做分页切片 | |
| 422 | + if in_rerank_window: | |
| 423 | + hits = es_response.get("hits", {}).get("hits") or [] | |
| 424 | + sliced = hits[from_ : from_ + size] | |
| 425 | + es_response.setdefault("hits", {})["hits"] = sliced | |
| 426 | + if sliced: | |
| 427 | + slice_max = max((h.get("_score") for h in sliced), default=0.0) | |
| 428 | + try: | |
| 429 | + es_response["hits"]["max_score"] = float(slice_max) | |
| 430 | + except (TypeError, ValueError): | |
| 431 | + es_response["hits"]["max_score"] = 0.0 | |
| 432 | + else: | |
| 433 | + es_response["hits"]["max_score"] = 0.0 | |
| 434 | + context.logger.info( | |
| 435 | + f"重排分页切片 | from={from_}, size={size}, 返回={len(sliced)}条", | |
| 436 | + extra={'reqid': context.reqid, 'uid': context.uid} | |
| 437 | + ) | |
| 438 | + | |
| 368 | 439 | # Step 5: Result processing |
| 369 | 440 | context.start_stage(RequestContextStage.RESULT_PROCESSING) |
| 370 | 441 | try: |
| ... | ... | @@ -379,7 +450,7 @@ class Searcher: |
| 379 | 450 | total_value = total.get('value', 0) |
| 380 | 451 | else: |
| 381 | 452 | total_value = total |
| 382 | - | |
| 453 | + # max_score 会在启用 AI 搜索时被更新为融合分数的最大值 | |
| 383 | 454 | max_score = es_response.get('hits', {}).get('max_score') or 0.0 |
| 384 | 455 | |
| 385 | 456 | # Format results using ResultFormatter | ... | ... |
tests/conftest.py