Compare View
Commits (4)
Showing
34 changed files
Show diff stats
api/app.py
| ... | ... | @@ -27,6 +27,8 @@ from slowapi.errors import RateLimitExceeded |
| 27 | 27 | # Configure backend logging |
| 28 | 28 | import pathlib |
| 29 | 29 | |
| 30 | +from request_log_context import LOG_LINE_FORMAT, RequestLogContextFilter | |
| 31 | + | |
| 30 | 32 | |
| 31 | 33 | def configure_backend_logging() -> None: |
| 32 | 34 | log_dir = pathlib.Path("logs") |
| ... | ... | @@ -34,9 +36,8 @@ def configure_backend_logging() -> None: |
| 34 | 36 | log_level = os.getenv("LOG_LEVEL", "INFO").upper() |
| 35 | 37 | numeric_level = getattr(logging, log_level, logging.INFO) |
| 36 | 38 | |
| 37 | - default_formatter = logging.Formatter( | |
| 38 | - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" | |
| 39 | - ) | |
| 39 | + default_formatter = logging.Formatter(LOG_LINE_FORMAT) | |
| 40 | + request_filter = RequestLogContextFilter() | |
| 40 | 41 | |
| 41 | 42 | root_logger = logging.getLogger() |
| 42 | 43 | root_logger.setLevel(numeric_level) |
| ... | ... | @@ -45,6 +46,7 @@ def configure_backend_logging() -> None: |
| 45 | 46 | console_handler = logging.StreamHandler() |
| 46 | 47 | console_handler.setLevel(numeric_level) |
| 47 | 48 | console_handler.setFormatter(default_formatter) |
| 49 | + console_handler.addFilter(request_filter) | |
| 48 | 50 | root_logger.addHandler(console_handler) |
| 49 | 51 | |
| 50 | 52 | backend_handler = TimedRotatingFileHandler( |
| ... | ... | @@ -56,6 +58,7 @@ def configure_backend_logging() -> None: |
| 56 | 58 | ) |
| 57 | 59 | backend_handler.setLevel(numeric_level) |
| 58 | 60 | backend_handler.setFormatter(default_formatter) |
| 61 | + backend_handler.addFilter(request_filter) | |
| 59 | 62 | root_logger.addHandler(backend_handler) |
| 60 | 63 | |
| 61 | 64 | verbose_logger = logging.getLogger("backend.verbose") |
| ... | ... | @@ -71,11 +74,16 @@ def configure_backend_logging() -> None: |
| 71 | 74 | encoding="utf-8", |
| 72 | 75 | ) |
| 73 | 76 | verbose_handler.setLevel(numeric_level) |
| 74 | - verbose_handler.setFormatter( | |
| 75 | - logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") | |
| 76 | - ) | |
| 77 | + verbose_handler.setFormatter(logging.Formatter(LOG_LINE_FORMAT)) | |
| 78 | + verbose_handler.addFilter(request_filter) | |
| 77 | 79 | verbose_logger.addHandler(verbose_handler) |
| 78 | 80 | |
| 81 | + for logger_name in ("uvicorn", "uvicorn.error", "uvicorn.access"): | |
| 82 | + uvicorn_logger = logging.getLogger(logger_name) | |
| 83 | + uvicorn_logger.handlers.clear() | |
| 84 | + uvicorn_logger.setLevel(numeric_level) | |
| 85 | + uvicorn_logger.propagate = True | |
| 86 | + | |
| 79 | 87 | |
| 80 | 88 | configure_backend_logging() |
| 81 | 89 | logger = logging.getLogger(__name__) |
| ... | ... | @@ -101,6 +109,16 @@ _suggestion_service: Optional[SuggestionService] = None |
| 101 | 109 | _app_config = None |
| 102 | 110 | |
| 103 | 111 | |
| 112 | +def _request_log_extra_from_http(request: Request) -> dict: | |
| 113 | + reqid = getattr(getattr(request, "state", None), "reqid", None) or request.headers.get("X-Request-ID") | |
| 114 | + uid = ( | |
| 115 | + getattr(getattr(request, "state", None), "uid", None) | |
| 116 | + or request.headers.get("X-User-ID") | |
| 117 | + or request.headers.get("User-ID") | |
| 118 | + ) | |
| 119 | + return {"reqid": reqid or "-1", "uid": uid or "-1"} | |
| 120 | + | |
| 121 | + | |
| 104 | 122 | def init_service(es_host: str = "http://localhost:9200"): |
| 105 | 123 | """ |
| 106 | 124 | Initialize search service with unified configuration. |
| ... | ... | @@ -261,7 +279,11 @@ async def shutdown_event(): |
| 261 | 279 | async def global_exception_handler(request: Request, exc: Exception): |
| 262 | 280 | """Global exception handler with detailed logging.""" |
| 263 | 281 | client_ip = request.client.host if request.client else "unknown" |
| 264 | - logger.error(f"Unhandled exception from {client_ip}: {exc}", exc_info=True) | |
| 282 | + logger.error( | |
| 283 | + f"Unhandled exception from {client_ip}: {exc}", | |
| 284 | + exc_info=True, | |
| 285 | + extra=_request_log_extra_from_http(request), | |
| 286 | + ) | |
| 265 | 287 | |
| 266 | 288 | return JSONResponse( |
| 267 | 289 | status_code=500, |
| ... | ... | @@ -276,7 +298,10 @@ async def global_exception_handler(request: Request, exc: Exception): |
| 276 | 298 | @app.exception_handler(HTTPException) |
| 277 | 299 | async def http_exception_handler(request: Request, exc: HTTPException): |
| 278 | 300 | """HTTP exception handler.""" |
| 279 | - logger.warning(f"HTTP exception from {request.client.host if request.client else 'unknown'}: {exc.status_code} - {exc.detail}") | |
| 301 | + logger.warning( | |
| 302 | + f"HTTP exception from {request.client.host if request.client else 'unknown'}: {exc.status_code} - {exc.detail}", | |
| 303 | + extra=_request_log_extra_from_http(request), | |
| 304 | + ) | |
| 280 | 305 | |
| 281 | 306 | return JSONResponse( |
| 282 | 307 | status_code=exc.status_code, | ... | ... |
api/routes/search.py
| ... | ... | @@ -59,6 +59,8 @@ async def search(request: SearchRequest, http_request: Request): |
| 59 | 59 | Requires tenant_id in header (X-Tenant-ID) or query parameter (tenant_id). |
| 60 | 60 | """ |
| 61 | 61 | reqid, uid = extract_request_info(http_request) |
| 62 | + http_request.state.reqid = reqid | |
| 63 | + http_request.state.uid = uid | |
| 62 | 64 | |
| 63 | 65 | # Extract tenant_id (required) |
| 64 | 66 | tenant_id = http_request.headers.get('X-Tenant-ID') |
| ... | ... | @@ -213,6 +215,8 @@ async def search_by_image(request: ImageSearchRequest, http_request: Request): |
| 213 | 215 | Requires tenant_id in header (X-Tenant-ID) or query parameter (tenant_id). |
| 214 | 216 | """ |
| 215 | 217 | reqid, uid = extract_request_info(http_request) |
| 218 | + http_request.state.reqid = reqid | |
| 219 | + http_request.state.uid = uid | |
| 216 | 220 | |
| 217 | 221 | # Extract tenant_id (required) |
| 218 | 222 | tenant_id = http_request.headers.get('X-Tenant-ID') | ... | ... |
config/config.yaml
| ... | ... | @@ -17,9 +17,9 @@ runtime: |
| 17 | 17 | embedding_port: 6005 |
| 18 | 18 | embedding_text_port: 6005 |
| 19 | 19 | embedding_image_port: 6008 |
| 20 | - translator_host: "127.0.0.1" | |
| 20 | + translator_host: "0.0.0.0" | |
| 21 | 21 | translator_port: 6006 |
| 22 | - reranker_host: "127.0.0.1" | |
| 22 | + reranker_host: "0.0.0.0" | |
| 23 | 23 | reranker_port: 6007 |
| 24 | 24 | |
| 25 | 25 | # 基础设施连接(敏感项优先读环境变量:ES_*、REDIS_*、DB_*、DASHSCOPE_API_KEY、DEEPL_AUTH_KEY) |
| ... | ... | @@ -116,6 +116,14 @@ query_config: |
| 116 | 116 | translation_embedding_wait_budget_ms_source_in_index: 500 # 80 |
| 117 | 117 | translation_embedding_wait_budget_ms_source_not_in_index: 500 #200 |
| 118 | 118 | |
| 119 | + style_intent: | |
| 120 | + enabled: true | |
| 121 | + color_dictionary_path: "config/dictionaries/style_intent_color.csv" | |
| 122 | + size_dictionary_path: "config/dictionaries/style_intent_size.csv" | |
| 123 | + dimension_aliases: | |
| 124 | + color: ["color", "colors", "colour", "colours", "颜色", "色", "色系"] | |
| 125 | + size: ["size", "sizes", "sizing", "尺码", "尺寸", "码数", "号码", "码"] | |
| 126 | + | |
| 119 | 127 | # 动态多语言检索字段配置 |
| 120 | 128 | # multilingual_fields 会被拼成 title.{lang}/brief.{lang}/... 形式; |
| 121 | 129 | # shared_fields 为无语言后缀字段。 |
| ... | ... | @@ -186,6 +194,10 @@ query_config: |
| 186 | 194 | - total_inventory |
| 187 | 195 | - option1_name |
| 188 | 196 | - option1_values |
| 197 | + - option2_name | |
| 198 | + - option2_values | |
| 199 | + - option3_name | |
| 200 | + - option3_values | |
| 189 | 201 | - specifications |
| 190 | 202 | - skus |
| 191 | 203 | ... | ... |
| ... | ... | @@ -0,0 +1,15 @@ |
| 1 | +black,black,blk,黑,黑色 | |
| 2 | +white,white,wht,白,白色 | |
| 3 | +red,red,reddish,红,红色 | |
| 4 | +blue,blue,blu,蓝,蓝色 | |
| 5 | +green,green,grn,绿,绿色 | |
| 6 | +yellow,yellow,ylw,黄,黄色 | |
| 7 | +pink,pink,粉,粉色 | |
| 8 | +purple,purple,violet,紫,紫色 | |
| 9 | +gray,gray,grey,灰,灰色 | |
| 10 | +brown,brown,棕,棕色,咖啡色 | |
| 11 | +beige,beige,khaki,米色,卡其色 | |
| 12 | +navy,navy,navy blue,藏青,藏蓝,深蓝 | |
| 13 | +silver,silver,银,银色 | |
| 14 | +gold,gold,金,金色 | |
| 15 | +orange,orange,橙,橙色 | ... | ... |
config/loader.py
| ... | ... | @@ -95,6 +95,29 @@ def _read_rewrite_dictionary(path: Path) -> Dict[str, str]: |
| 95 | 95 | return rewrite_dict |
| 96 | 96 | |
| 97 | 97 | |
| 98 | +def _read_synonym_csv_dictionary(path: Path) -> List[List[str]]: | |
| 99 | + rows: List[List[str]] = [] | |
| 100 | + if not path.exists(): | |
| 101 | + return rows | |
| 102 | + | |
| 103 | + with open(path, "r", encoding="utf-8") as handle: | |
| 104 | + for raw_line in handle: | |
| 105 | + line = raw_line.strip() | |
| 106 | + if not line or line.startswith("#"): | |
| 107 | + continue | |
| 108 | + parts = [segment.strip() for segment in line.split(",")] | |
| 109 | + normalized = [segment for segment in parts if segment] | |
| 110 | + if normalized: | |
| 111 | + rows.append(normalized) | |
| 112 | + return rows | |
| 113 | + | |
| 114 | + | |
| 115 | +_DEFAULT_STYLE_INTENT_DIMENSION_ALIASES: Dict[str, List[str]] = { | |
| 116 | + "color": ["color", "colors", "colour", "colours", "颜色", "色", "色系"], | |
| 117 | + "size": ["size", "sizes", "sizing", "尺码", "尺寸", "码数", "号码", "码"], | |
| 118 | +} | |
| 119 | + | |
| 120 | + | |
| 98 | 121 | class AppConfigLoader: |
| 99 | 122 | """Load the unified application configuration.""" |
| 100 | 123 | |
| ... | ... | @@ -253,6 +276,45 @@ class AppConfigLoader: |
| 253 | 276 | if isinstance(query_cfg.get("text_query_strategy"), dict) |
| 254 | 277 | else {} |
| 255 | 278 | ) |
| 279 | + style_intent_cfg = ( | |
| 280 | + query_cfg.get("style_intent") | |
| 281 | + if isinstance(query_cfg.get("style_intent"), dict) | |
| 282 | + else {} | |
| 283 | + ) | |
| 284 | + | |
| 285 | + def _resolve_project_path(value: Any, default_path: Path) -> Path: | |
| 286 | + if value in (None, ""): | |
| 287 | + return default_path | |
| 288 | + candidate = Path(str(value)) | |
| 289 | + if candidate.is_absolute(): | |
| 290 | + return candidate | |
| 291 | + return self.project_root / candidate | |
| 292 | + | |
| 293 | + style_color_path = _resolve_project_path( | |
| 294 | + style_intent_cfg.get("color_dictionary_path"), | |
| 295 | + self.config_dir / "dictionaries" / "style_intent_color.csv", | |
| 296 | + ) | |
| 297 | + style_size_path = _resolve_project_path( | |
| 298 | + style_intent_cfg.get("size_dictionary_path"), | |
| 299 | + self.config_dir / "dictionaries" / "style_intent_size.csv", | |
| 300 | + ) | |
| 301 | + configured_dimension_aliases = ( | |
| 302 | + style_intent_cfg.get("dimension_aliases") | |
| 303 | + if isinstance(style_intent_cfg.get("dimension_aliases"), dict) | |
| 304 | + else {} | |
| 305 | + ) | |
| 306 | + style_dimension_aliases: Dict[str, List[str]] = {} | |
| 307 | + for intent_type, default_aliases in _DEFAULT_STYLE_INTENT_DIMENSION_ALIASES.items(): | |
| 308 | + aliases = configured_dimension_aliases.get(intent_type) | |
| 309 | + if isinstance(aliases, list) and aliases: | |
| 310 | + style_dimension_aliases[intent_type] = [str(alias) for alias in aliases if str(alias).strip()] | |
| 311 | + else: | |
| 312 | + style_dimension_aliases[intent_type] = list(default_aliases) | |
| 313 | + | |
| 314 | + style_intent_terms = { | |
| 315 | + "color": _read_synonym_csv_dictionary(style_color_path), | |
| 316 | + "size": _read_synonym_csv_dictionary(style_size_path), | |
| 317 | + } | |
| 256 | 318 | query_config = QueryConfig( |
| 257 | 319 | supported_languages=list(query_cfg.get("supported_languages") or ["zh", "en"]), |
| 258 | 320 | default_language=str(query_cfg.get("default_language") or "en"), |
| ... | ... | @@ -324,6 +386,9 @@ class AppConfigLoader: |
| 324 | 386 | translation_embedding_wait_budget_ms_source_not_in_index=int( |
| 325 | 387 | query_cfg.get("translation_embedding_wait_budget_ms_source_not_in_index", 200) |
| 326 | 388 | ), |
| 389 | + style_intent_enabled=bool(style_intent_cfg.get("enabled", True)), | |
| 390 | + style_intent_terms=style_intent_terms, | |
| 391 | + style_intent_dimension_aliases=style_dimension_aliases, | |
| 327 | 392 | ) |
| 328 | 393 | |
| 329 | 394 | function_score_cfg = raw.get("function_score") if isinstance(raw.get("function_score"), dict) else {} | ... | ... |
config/schema.py
| ... | ... | @@ -64,6 +64,9 @@ class QueryConfig: |
| 64 | 64 | # 检测语言不在 index_languages 内:翻译对召回更关键,预算较长。 |
| 65 | 65 | translation_embedding_wait_budget_ms_source_in_index: int = 80 |
| 66 | 66 | translation_embedding_wait_budget_ms_source_not_in_index: int = 200 |
| 67 | + style_intent_enabled: bool = True | |
| 68 | + style_intent_terms: Dict[str, List[List[str]]] = field(default_factory=dict) | |
| 69 | + style_intent_dimension_aliases: Dict[str, List[str]] = field(default_factory=dict) | |
| 67 | 70 | |
| 68 | 71 | |
| 69 | 72 | @dataclass(frozen=True) | ... | ... |
context/request_context.py
| ... | ... | @@ -12,6 +12,8 @@ from typing import Dict, Any, Optional, List |
| 12 | 12 | from dataclasses import dataclass, field |
| 13 | 13 | import uuid |
| 14 | 14 | |
| 15 | +from request_log_context import bind_request_log_context, reset_request_log_context | |
| 16 | + | |
| 15 | 17 | |
| 16 | 18 | class RequestContextStage(Enum): |
| 17 | 19 | """搜索阶段枚举""" |
| ... | ... | @@ -375,9 +377,15 @@ def get_current_request_context() -> Optional[RequestContext]: |
| 375 | 377 | def set_current_request_context(context: RequestContext) -> None: |
| 376 | 378 | """设置当前线程的请求上下文""" |
| 377 | 379 | threading.current_thread().request_context = context |
| 380 | + _, _, tokens = bind_request_log_context(context.reqid, context.uid) | |
| 381 | + threading.current_thread().request_log_tokens = tokens | |
| 378 | 382 | |
| 379 | 383 | |
| 380 | 384 | def clear_current_request_context() -> None: |
| 381 | 385 | """清除当前线程的请求上下文""" |
| 386 | + tokens = getattr(threading.current_thread(), 'request_log_tokens', None) | |
| 387 | + if tokens is not None: | |
| 388 | + reset_request_log_context(tokens) | |
| 389 | + delattr(threading.current_thread(), 'request_log_tokens') | |
| 382 | 390 | if hasattr(threading.current_thread(), 'request_context'): |
| 383 | 391 | delattr(threading.current_thread(), 'request_context') |
| 384 | 392 | \ No newline at end of file | ... | ... |
| ... | ... | @@ -0,0 +1,40 @@ |
| 1 | + | |
| 2 | +一、 增加款式意图识别模块 | |
| 3 | +意图类型: 颜色,尺码(目前只需要支持这两种) | |
| 4 | + | |
| 5 | + | |
| 6 | +二、 意图判断 | |
| 7 | +- 意图召回层: | |
| 8 | +每种意图,有一个召回词集合 | |
| 9 | +对query(包括原始query、各种翻译query 都做匹配) | |
| 10 | +- 以颜色意图为例: | |
| 11 | +有一个词表,每一行 都逗号分割,互为同义词,行内第一个为标准化词 | |
| 12 | +query匹配了其中任何一个词,都认为,具有颜色意图 | |
| 13 | +匹配规则: 用细粒度、粗粒度分词,看是否有在词表中的。原始query分词、和每种翻译的分词,都要用。 | |
| 14 | + | |
| 15 | + | |
| 16 | +三、 意图使用: | |
| 17 | + 当前 SKU 置顶逻辑在「分页 + 详情回填」之后 | |
| 18 | +流程是:run_rerank → 按 from/size 切片 → page fill → _apply_sku_sorting_for_page_hits → ResultFormatter | |
| 19 | + 要改为: | |
| 20 | + 1. 有款式意图的时候,才做sku筛选 | |
| 21 | + 2. sku筛选的时机,改为在reranker之前,对所有内容(rerank输入的所有spus)做sku筛选 | |
| 22 | + 3. 从仅 option1 扩展到多个维度,识别的意图,包含意图的维度名(color)和维度名的泛化词list(color、颜色、colour、colors...),遍历spu的option1_name,option2_name,option3_name字段,看哪个能匹配上意图的维度名list,哪个匹配上了,则在这个维度筛选。 | |
| 23 | + 1. 比如匹配到option2_name,那么取每一个sku的option2_values。如果没匹配到任何一个,那么把三个属性值都用空格拼接起来。这个值要记录下来。有两个作用: | |
| 24 | + 1. 用来跟query匹配,看哪个与query的相关性更高,以此进行最优sku筛选,把选出来的sku置顶,并替换spu的image_url | |
| 25 | + 2. 用来做rerank doc的title补充,从而参与rerank | |
| 26 | + 4. Rerank doc (有款式意图的时候)要带上属性后缀,拼接到title后面。在调用 run_rerank 前,对每条 hit 生成「用于重排的 doc 文本」(标题 + 可选后缀) | |
| 27 | + | |
| 28 | +- sku筛选的规则也要优化: | |
| 29 | +现在的逻辑是,先做包含的判断,找到第一个 option_value被query包含的,则直接认为匹配。没有匹配的再用embedding相似度。 | |
| 30 | +改为: | |
| 31 | + 1. 第一轮:遍历完,如果有且仅有一个被query包含,那么认为匹配。 | |
| 32 | + 2. 第二轮:如果有多个符合(被query包含),跳到3。如果没有,对每个词都走泛化词表进行匹配。 | |
| 33 | + 3. 第三轮:如果有多个,那么对这多个,走embedding相关性取最高的。如果一个也没有,则对所有的走embedding相关性取最高的 | |
| 34 | + 这个sku筛选也需要提取为一个独立的模块。 | |
| 35 | + | |
| 36 | +细节备注: | |
| 37 | +在重排窗口内,第一次 ES 查询会把 _source 裁成「重排模板需要的字段」,默认只有 title 等,不包含 skus / option*_name。因此,有意图的时候,需要给这一次的_source加上 skus / option*_name | |
| 38 | + | |
| 39 | +5. TODO: 搜索接口里,results[].skus 不是全量子 SKU:由 sku_filter_dimension 控制在应用层按维度分组折叠,每个「维度取值组合」只保留一条 SKU(组内第一条)。请求未传该字段时,Pydantic 默认是 ["option1"],等价于只按 option1_value 去重;服务端不会读取店铺主题的「主展示维」,需调用方与装修配置对齐并传入正确维度。因此当用户有款式等更细粒度意图、而款式落在 option2/option3(或对应 option*_name)时,若仍用默认只按 option1(常见为颜色)折叠,同一颜色下多种款式只会出现一条代表 SKU,无法从返回的 skus 里拿到该颜色下的全部款式行。(若业务需要全量子款,需传包含对应维度的 sku_filter_dimension,或传 null/[] 跳过折叠——以当前 ResultFormatter 实现为准。) | |
| 40 | + | ... | ... |
| ... | ... | @@ -0,0 +1,53 @@ |
| 1 | + | |
| 2 | + | |
| 3 | +增加款式意图识别模块。意图类型: 颜色,尺码(目前只需要支持这两种) | |
| 4 | + | |
| 5 | +一、 意图判断 | |
| 6 | +- 意图召回层: | |
| 7 | +每种意图,有一个召回词集合 | |
| 8 | +对query(包括原始query、各种翻译query 都做匹配) | |
| 9 | +- 以颜色意图为例: | |
| 10 | +有一个词表,每一行 都逗号分割,互为同义词,行内第一个为标准化词 | |
| 11 | +query匹配了其中任何一个词,都认为,具有颜色意图 | |
| 12 | +匹配规则: 用细粒度、粗粒度分词,看是否有在词表中的。原始query分词、和每种翻译的分词,都要用。 | |
| 13 | + | |
| 14 | +二、 意图使用: | |
| 15 | + 当前 SKU 置顶逻辑在「分页 + 详情回填」之后 | |
| 16 | +流程是:run_rerank → 按 from/size 切片 → page fill → _apply_sku_sorting_for_page_hits → ResultFormatter | |
| 17 | + 要改为: | |
| 18 | + 1. 有款式意图的时候,才做sku筛选 | |
| 19 | + 2. sku筛选的时机,改为在reranker之前,对所有内容(rerank输入的所有spus)做sku筛选 | |
| 20 | + 3. 从仅 option1 扩展到多个维度,识别的意图,包含意图的维度名(color)和维度名的泛化词list(color、颜色、colour、colors...),遍历spu的option1_name,option2_name,option3_name字段,看哪个能匹配上意图的维度名list,哪个匹配上了,则在这个维度筛选。 | |
| 21 | + 1. 比如匹配到option2_name,那么取每一个sku的option2_values。如果没匹配到任何一个,那么把三个属性值都用空格拼接起来。这个值要记录下来。有两个作用: | |
| 22 | + 1. 用来跟query匹配,看哪个与query的相关性更高,以此进行最优sku筛选,把选出来的sku置顶,并替换spu的image_url | |
| 23 | + 2. 用来做rerank doc的title补充,从而参与rerank | |
| 24 | + 4. Rerank doc (有款式意图的时候)要带上属性后缀,拼接到title后面。在调用 run_rerank 前,对每条 hit 生成「用于重排的 doc 文本」(标题 + 可选后缀) | |
| 25 | + | |
| 26 | +- sku筛选的规则也要优化: | |
| 27 | +现在的逻辑是,先做包含的判断,找到第一个 option_value被query包含的,则直接认为匹配。没有匹配的再用embedding相似度。 | |
| 28 | +改为: | |
| 29 | + 1. 第一轮:遍历完,如果有且仅有一个被query包含,那么认为匹配。 | |
| 30 | + 2. 第二轮:如果有多个符合(被query包含),跳到3。如果没有,对每个词都走泛化词表进行匹配。 | |
| 31 | + 3. 第三轮:如果有多个,那么对这多个,走embedding相关性取最高的。如果一个也没有,则对所有的走embedding相关性取最高的 | |
| 32 | + 这个sku筛选也需要提取为一个独立的模块。 | |
| 33 | + | |
| 34 | +细节备注: | |
| 35 | +intent 考虑由 QueryParser 编排、具体实现拆成独立模块,注意做好现有的分词等基础设施的复用,缺失的英文分词可以补充。 | |
| 36 | +在重排窗口内,第一次 ES 查询会把 _source 裁成「重排模板需要的字段」,默认只有 title 等,不包含 skus / option*_name。因此,有意图的时候,需要给这一次的_source加上 skus / option*_name | |
| 37 | + | |
| 38 | +先仔细理解需求,查看代码,深度思考应该如何设计,和当前的系统较好的融合,给出统一的设计,可以根据需要适当改造当前的实现,降低整个系统的复杂度,提高模块化程度,而不是打补丁。修改后的最终状态应该是要足够简单、清晰、无冗余和分叉,模块间低耦合。多步思考确认最佳施工方案之后才进行代码修改。 | |
| 39 | + | |
| 40 | +5. TODO: 搜索接口里,results[].skus 不是全量子 SKU:由 sku_filter_dimension 控制在应用层按维度分组折叠,每个「维度取值组合」只保留一条 SKU(组内第一条)。请求未传该字段时,Pydantic 默认是 ["option1"],等价于只按 option1_value 去重;服务端不会读取店铺主题的「主展示维」,需调用方与装修配置对齐并传入正确维度。因此当用户有款式等更细粒度意图、而款式落在 option2/option3(或对应 option*_name)时,若仍用默认只按 option1(常见为颜色)折叠,同一颜色下多种款式只会出现一条代表 SKU,无法从返回的 skus 里拿到该颜色下的全部款式行。(若业务需要全量子款,需传包含对应维度的 sku_filter_dimension,或传 null/[] 跳过折叠——以当前 ResultFormatter 实现为准。) | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | +GPT5.4 主要思路: | |
| 47 | +我准备按这条线改,尽量把逻辑收束成两个新模块而不是继续堆在 Searcher 里: | |
| 48 | + | |
| 49 | +QueryParser 负责产出统一的“款式意图分析结果”,包含:命中的意图类型、标准化值、命中的 query 变体、分词结果、维度候选词。 | |
| 50 | +独立的 SKU 选择模块负责:维度解析、候选值抽取、三阶段匹配、选中 SKU 后的主图替换,以及生成 rerank 用的属性后缀。 | |
| 51 | +Searcher 只编排:解析 query -> 首次 ES 召回 -> 若有款式意图则对 rerank 窗口内全部 hits 做 SKU 预筛选并补充 rerank doc 文本 -> run_rerank -> 分页 -> page fill -> ResultFormatter。这样旧的页内 SKU 置顶逻辑可以下沉或删除,避免前后两个阶段重复做同一件事。我再补看一下你最近写的意图草稿文档,确认有没有现成词表示例或边界条件,避免我这边把词表格式和你预期做偏。 | |
| 52 | + | |
| 53 | + | ... | ... |
| ... | ... | @@ -0,0 +1,38 @@ |
| 1 | + | |
| 2 | +一、 增加款式意图识别模块 | |
| 3 | +意图类型: 颜色,尺码(目前只需要支持这两种) | |
| 4 | + | |
| 5 | + | |
| 6 | +二、 意图判断 | |
| 7 | +- 意图召回层: | |
| 8 | +每种意图,有一个召回词集合 | |
| 9 | +对query(包括原始query、各种翻译query 都做匹配) | |
| 10 | +- 以颜色意图为例: | |
| 11 | +有一个词表,每一行 都逗号分割,互为同义词,行内第一个为标准化词 | |
| 12 | +query匹配了其中任何一个词,都认为,具有颜色意图 | |
| 13 | +匹配规则: 用细粒度、粗粒度分词,看是否有在词表中的。原始query分词、和每种翻译的分词,都要用。 | |
| 14 | + | |
| 15 | + | |
| 16 | +三、 意图使用: | |
| 17 | + 当前 SKU 置顶逻辑在「分页 + 详情回填」之后 | |
| 18 | +流程是:run_rerank → 按 from/size 切片 → page fill → _apply_sku_sorting_for_page_hits → ResultFormatter | |
| 19 | + 要改为: | |
| 20 | + 1. 有款式意图的时候,才做sku筛选 | |
| 21 | + 2. sku筛选的时机,改为在reranker之前,对所有内容(rerank输入的所有spus)做sku筛选 | |
| 22 | + 3. 从仅 option1 扩展到多个维度,识别的意图,包含意图的维度名(color)和维度名的泛化词list(color、颜色、colour、colors...),遍历spu的option1_name,option2_name,option3_name字段,看哪个能匹配上意图的维度名list,哪个匹配上了,则在这个维度筛选。 | |
| 23 | + 1. 比如匹配到option2_name,那么取每一个sku的option2_values。如果没匹配到任何一个,那么把三个属性值都用空格拼接起来。这个值要记录下来。有两个作用: | |
| 24 | + 1. 用来跟query匹配,看哪个与query的相关性更高,以此进行最优sku筛选,把选出来的sku置顶,并替换spu的image_url | |
| 25 | + 2. 用来做rerank doc的title补充,从而参与rerank | |
| 26 | + 4. Rerank doc (有款式意图的时候)要带上属性后缀,拼接到title后面。在调用 run_rerank 前,对每条 hit 生成「用于重排的 doc 文本」(标题 + 可选后缀) | |
| 27 | + | |
| 28 | +- sku筛选的规则也要优化: | |
| 29 | +现在的逻辑是,先做包含的判断,找到第一个 option_value被query包含的,则直接认为匹配。没有匹配的再用embedding相似度。 | |
| 30 | +改为: | |
| 31 | + 1. 第一轮:遍历完,如果有且仅有一个被query包含,那么认为匹配。 | |
| 32 | + 2. 第二轮:如果有多个符合(被query包含),跳到3。如果没有,对每个词都走泛化词表进行匹配。 | |
| 33 | + 3. 第三轮:如果有多个,那么对这多个,走embedding相关性取最高的。如果一个也没有,则对所有的走embedding相关性取最高的 | |
| 34 | + 这个sku筛选也需要提取为一个独立的模块。 | |
| 35 | + | |
| 36 | +细节备注: | |
| 37 | +intent 考虑由 QueryParser 编排、具体实现拆成独立模块,注意做好现有的分词等基础设施的复用,缺失的英文分词可以补充。 | |
| 38 | +在重排窗口内,第一次 ES 查询会把 _source 裁成「重排模板需要的字段」,默认只有 title 等,不包含 skus / option*_name。因此,有意图的时候,需要给这一次的_source加上 skus / option*_name | ... | ... |
| ... | ... | @@ -0,0 +1,118 @@ |
| 1 | + | |
| 2 | +## prompt | |
| 3 | +参考 @docs/常用查询 - ES.md 帮我查询 tenant_Id=163的@search_products.json (1314-1332) 这些字段的统计分布,也就是这里每个字段的所有的值及其对应的文档数 | |
| 4 | + | |
| 5 | + | |
| 6 | +已按 `docs/常用查询 - ES.md` 的写法,在索引 **`search_products_tenant_163`** 上做了 `terms` 聚合(租户已按索引拆分,用 `match_all` 即可,等价于只查 tenant 163)。集群在本机 `localhost:9200` 上执行成功。 | |
| 7 | + | |
| 8 | +## 结果概要 | |
| 9 | + | |
| 10 | +- **总文档数**:8977 | |
| 11 | +- 六个字段的 **`sum_other_doc_count` 均为 0**,在 `size: 20000` 下没有截断,即:**当前数据里每个字段的取值都已全部列出**。 | |
| 12 | +- **`__missing__`**:来自聚合里的 `"missing": "__missing__"`,表示该字段缺失或为空的文档数(便于和「有值」一起看)。 | |
| 13 | +- **`option1_values` / `option2_values`** 多为多值数组:同一篇文档可落入多个桶,因此 **各桶 `doc_count` 之和会大于 8977**(这是 ES `terms` 的正常语义,表示「含该取值的文档数」,不是互斥分区)。 | |
| 14 | +- **`option1_name` / `option2_name` / `option3_name`** 及 **`option3_values`** 的桶计数之和与文档关系更接近单值/稀疏(见下方校验)。 | |
| 15 | + | |
| 16 | +### `option1_name`(13 个取值) | |
| 17 | + | |
| 18 | +| 值 | 文档数 | | |
| 19 | +|---:|---:| | |
| 20 | +| __missing__ | 4538 | | |
| 21 | +| Size | 2916 | | |
| 22 | +| Color | 1174 | | |
| 23 | +| 颜色 | 244 | | |
| 24 | +| COLOR | 56 | | |
| 25 | +| color | 16 | | |
| 26 | +| Colour | 15 | | |
| 27 | +| Variant | 9 | | |
| 28 | +| Style | 3 | | |
| 29 | +| Colors | 2 | | |
| 30 | +| Scent | 2 | | |
| 31 | +| Cup Size | 1 | | |
| 32 | +| Pattern Name | 1 | | |
| 33 | + | |
| 34 | +### `option2_name`(13 个取值) | |
| 35 | + | |
| 36 | +| 值 | 文档数 | | |
| 37 | +|---:|---:| | |
| 38 | +| __missing__ | 4666 | | |
| 39 | +| Color | 2879 | | |
| 40 | +| Size | 1134 | | |
| 41 | +| 尺码 | 244 | | |
| 42 | +| SIZE | 17 | | |
| 43 | +| size | 16 | | |
| 44 | +| Style | 12 | | |
| 45 | +| Cup Size | 4 | | |
| 46 | +| Item Package Quantity | 1 | | |
| 47 | +| Number of Items | 1 | | |
| 48 | +| Ring Size | 1 | | |
| 49 | +| Scent | 1 | | |
| 50 | +| Team Name | 1 | | |
| 51 | + | |
| 52 | +### `option3_name`(11 个取值) | |
| 53 | + | |
| 54 | +| 值 | 文档数 | | |
| 55 | +|---:|---:| | |
| 56 | +| __missing__ | 8831 | | |
| 57 | +| Fit Type | 52 | | |
| 58 | +| Special Size | 31 | | |
| 59 | +| Number of Items | 17 | | |
| 60 | +| Special Size Type | 16 | | |
| 61 | +| Team Name | 15 | | |
| 62 | +| Material Type | 8 | | |
| 63 | +| Style | 4 | | |
| 64 | +| Configuration | 1 | | |
| 65 | +| Hand Orientation | 1 | | |
| 66 | +| Pattern Name | 1 | | |
| 67 | + | |
| 68 | +### `option1_values` / `option2_values` | |
| 69 | + | |
| 70 | +- **桶数量**:`option1_values` **3463**,`option2_values` **3605**(完整列表过长,不宜贴在对话里)。 | |
| 71 | +- **文档中出现最多的取值(前 10)** | |
| 72 | + - **option1_values**:`__missing__` 4538,其后如 Medium 749、Black 653、Large 635、8/9/White/10/7/Small… | |
| 73 | + - **option2_values**:`__missing__` 4666,其后如 L 1161、M 1158、S 1064、XL 1036、Black 876… | |
| 74 | + | |
| 75 | +### `option3_values`(94 个桶,已全部枚举) | |
| 76 | + | |
| 77 | +除 `__missing__` 8831 外,另有 93 个具体取值;上文 Markdown 表里已列全(从 Standard 到 `standard` 等)。 | |
| 78 | + | |
| 79 | +--- | |
| 80 | + | |
| 81 | +## 可复用的查询(与文档一致) | |
| 82 | + | |
| 83 | +```bash | |
| 84 | +curl -sS -u "saas:${ES_PASSWORD}" -X GET 'http://localhost:9200/search_products_tenant_163/_search?pretty' \ | |
| 85 | + -H 'Content-Type: application/json' -d '{ | |
| 86 | + "query": { "match_all": {} }, | |
| 87 | + "size": 0, | |
| 88 | + "track_total_hits": true, | |
| 89 | + "aggs": { | |
| 90 | + "option1_name_dist": { | |
| 91 | + "terms": { "field": "option1_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 92 | + }, | |
| 93 | + "option2_name_dist": { | |
| 94 | + "terms": { "field": "option2_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 95 | + }, | |
| 96 | + "option3_name_dist": { | |
| 97 | + "terms": { "field": "option3_name", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 98 | + }, | |
| 99 | + "option1_values_dist": { | |
| 100 | + "terms": { "field": "option1_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 101 | + }, | |
| 102 | + "option2_values_dist": { | |
| 103 | + "terms": { "field": "option2_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 104 | + }, | |
| 105 | + "option3_values_dist": { | |
| 106 | + "terms": { "field": "option3_values", "size": 20000, "order": { "_count": "desc" }, "missing": "__missing__" } | |
| 107 | + } | |
| 108 | + } | |
| 109 | +}' | |
| 110 | +``` | |
| 111 | + | |
| 112 | +若某租户取值超过 2 万,需把 `size` 再加大或用 **composite** 分页扫完全部桶。 | |
| 113 | + | |
| 114 | +## 完整 JSON(含 3463 + 3605 条 value 分布) | |
| 115 | + | |
| 116 | +已整理为:`/tmp/tenant163_option_fields_distribution.json`(约 620KB),结构为按字段分组的 `values: [{ "value", "doc_count" }, ...]`,便于你用脚本或表格工具打开。 | |
| 117 | + | |
| 118 | +如需去掉 `__missing__` 桶,从请求里删掉各聚合中的 `"missing": "__missing__"` 即可。 | |
| 0 | 119 | \ No newline at end of file | ... | ... |
docs/相关性检索优化说明.md
| ... | ... | @@ -17,9 +17,9 @@ |
| 17 | 17 | 查询链路(文本相关): |
| 18 | 18 | |
| 19 | 19 | 1. `QueryParser.parse()` |
| 20 | - 负责产出解析事实:`query_normalized`、`rewritten_query`、`detected_language`、`translations`、`query_vector`、`query_tokens`、`contains_chinese`、`contains_english`。 | |
| 20 | + 负责产出解析事实:`query_normalized`、`rewritten_query`、`detected_language`、`translations`、`query_vector`、`query_tokens`。 | |
| 21 | 21 | 2. `Searcher.search()` |
| 22 | - 负责读取租户 `index_languages`,并将其一方面传给 `QueryParser` 作为 `target_languages`,另一方面传给 `ESQueryBuilder` 作为字段展开约束。 | |
| 22 | + 负责读取租户 `index_languages`,并将其传给 `QueryParser` 作为 `target_languages`(控制翻译目标语种);`ESQueryBuilder` 仅根据 `detected_language` 与各条译文构建子句字段,不再接收 `index_languages`。 | |
| 23 | 23 | 2. `ESQueryBuilder._build_advanced_text_query()` |
| 24 | 24 | 基于 `rewritten_query + detected_language + translations + index_languages` 构建 `base_query` 与 `base_query_trans_*`;并按语言动态拼接 `title/brief/description/vendor/category_*` 的 `.{lang}` 字段,叠加 shared 字段(`tags`、`option*_values`)。 |
| 25 | 25 | 3. `build_query()` |
| ... | ... | @@ -76,9 +76,6 @@ |
| 76 | 76 | |
| 77 | 77 | 最终按 `bool.should` 组合,`minimum_should_match: 1`。 |
| 78 | 78 | |
| 79 | -> **附 — 混写辅助召回** | |
| 80 | -> 当中英(或多脚本)混写时,为略抬召回:`QueryParser` 用 `contains_chinese`(文中有汉字)、`contains_english`(分词中有长度 ≥3 的纯英文 token)打标;`ESQueryBuilder` 在某一语言的 `multi_match` 上,按规则把**另一语种**的同类字段并入同一 `fields`(受 `index_languages` 限制),并入列的 boost 为配置值再乘 **`mixed_script_merged_field_boost_scale`(默认 0.6,`ESQueryBuilder` 构造参数)**。字段在内部以 `(path, boost)` 列表合并后再格式化为 ES 字符串。 | |
| 81 | - | |
| 82 | 79 | ## 5. 关键配置项(文本策略) |
| 83 | 80 | |
| 84 | 81 | `query_config` 下与解析等待相关的项: |
| ... | ... | @@ -147,11 +144,9 @@ |
| 147 | 144 | - `translations` |
| 148 | 145 | - `query_vector` |
| 149 | 146 | - `query_tokens` |
| 150 | - - `contains_chinese` / `contains_english` | |
| 151 | 147 | - `Searcher` 负责“租户语境”: |
| 152 | 148 | - `index_languages` |
| 153 | 149 | - 将其传给 parser 作为 `target_languages` |
| 154 | - - 将其传给 builder 作为字段展开约束 | |
| 155 | 150 | - `ESQueryBuilder` 负责“表达式展开”: |
| 156 | 151 | - 动态字段组装 |
| 157 | 152 | - 子句权重分配 | ... | ... |
embeddings/README.md
| ... | ... | @@ -5,6 +5,8 @@ |
| 5 | 5 | - `../docs/TEI_SERVICE说明文档.md` |
| 6 | 6 | - `../docs/CNCLIP_SERVICE说明文档.md` |
| 7 | 7 | |
| 8 | +**请求日志串联(reqid / uid)**:统一实现在仓库根目录的 `request_log_context.py`(勿放到 `utils/` 下,以免 `.venv-embedding` 因 `utils/__init__.py` 拉取数据库依赖)。Uvicorn 日志配置见 `config/uvicorn_embedding_logging.json`。 | |
| 9 | + | |
| 8 | 10 | --- |
| 9 | 11 | |
| 10 | 12 | 这个目录是一个完整的“向量化模块”,包含: | ... | ... |
embeddings/config.py
| ... | ... | @@ -2,6 +2,7 @@ |
| 2 | 2 | |
| 3 | 3 | from __future__ import annotations |
| 4 | 4 | |
| 5 | +import os | |
| 5 | 6 | from typing import Optional |
| 6 | 7 | |
| 7 | 8 | from config.loader import get_app_config |
| ... | ... | @@ -25,6 +26,11 @@ class EmbeddingConfig(object): |
| 25 | 26 | self.TEXT_NORMALIZE_EMBEDDINGS = bool(text_backend.get("normalize_embeddings", True)) |
| 26 | 27 | self.TEI_BASE_URL = str(text_backend.get("base_url") or "http://127.0.0.1:8080") |
| 27 | 28 | self.TEI_TIMEOUT_SEC = int(text_backend.get("timeout_sec", 60)) |
| 29 | + self.TEI_MAX_CLIENT_BATCH_SIZE = int( | |
| 30 | + os.getenv("TEI_MAX_CLIENT_BATCH_SIZE") | |
| 31 | + or text_backend.get("max_client_batch_size") | |
| 32 | + or 24 | |
| 33 | + ) | |
| 28 | 34 | |
| 29 | 35 | self.USE_CLIP_AS_SERVICE = services.image_backend == "clip_as_service" |
| 30 | 36 | self.CLIP_AS_SERVICE_SERVER = str(image_backend.get("server") or "grpc://127.0.0.1:51000") | ... | ... |
embeddings/image_encoder.py
| ... | ... | @@ -13,6 +13,7 @@ from config.loader import get_app_config |
| 13 | 13 | from config.services_config import get_embedding_image_base_url |
| 14 | 14 | from embeddings.cache_keys import build_image_cache_key |
| 15 | 15 | from embeddings.redis_embedding_cache import RedisEmbeddingCache |
| 16 | +from request_log_context import build_downstream_request_headers, build_request_log_extra | |
| 16 | 17 | |
| 17 | 18 | |
| 18 | 19 | class CLIPImageEncoder: |
| ... | ... | @@ -40,6 +41,8 @@ class CLIPImageEncoder: |
| 40 | 41 | request_data: List[str], |
| 41 | 42 | normalize_embeddings: bool = True, |
| 42 | 43 | priority: int = 0, |
| 44 | + request_id: Optional[str] = None, | |
| 45 | + user_id: Optional[str] = None, | |
| 43 | 46 | ) -> List[Any]: |
| 44 | 47 | """ |
| 45 | 48 | Call the embedding service API. |
| ... | ... | @@ -50,6 +53,7 @@ class CLIPImageEncoder: |
| 50 | 53 | Returns: |
| 51 | 54 | List of embeddings (list[float]) or nulls (None), aligned to input order |
| 52 | 55 | """ |
| 56 | + response = None | |
| 53 | 57 | try: |
| 54 | 58 | response = requests.post( |
| 55 | 59 | self.endpoint, |
| ... | ... | @@ -58,12 +62,26 @@ class CLIPImageEncoder: |
| 58 | 62 | "priority": max(0, int(priority)), |
| 59 | 63 | }, |
| 60 | 64 | json=request_data, |
| 65 | + headers=build_downstream_request_headers(request_id=request_id, user_id=user_id), | |
| 61 | 66 | timeout=60 |
| 62 | 67 | ) |
| 63 | 68 | response.raise_for_status() |
| 64 | 69 | return response.json() |
| 65 | 70 | except requests.exceptions.RequestException as e: |
| 66 | - logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True) | |
| 71 | + body_preview = "" | |
| 72 | + if response is not None: | |
| 73 | + try: | |
| 74 | + body_preview = (response.text or "")[:300] | |
| 75 | + except Exception: | |
| 76 | + body_preview = "" | |
| 77 | + logger.error( | |
| 78 | + "CLIPImageEncoder service request failed | status=%s body=%s error=%s", | |
| 79 | + getattr(response, "status_code", "n/a"), | |
| 80 | + body_preview, | |
| 81 | + e, | |
| 82 | + exc_info=True, | |
| 83 | + extra=build_request_log_extra(request_id=request_id, user_id=user_id), | |
| 84 | + ) | |
| 67 | 85 | raise |
| 68 | 86 | |
| 69 | 87 | def encode_image(self, image: Image.Image) -> np.ndarray: |
| ... | ... | @@ -79,6 +97,8 @@ class CLIPImageEncoder: |
| 79 | 97 | url: str, |
| 80 | 98 | normalize_embeddings: bool = True, |
| 81 | 99 | priority: int = 0, |
| 100 | + request_id: Optional[str] = None, | |
| 101 | + user_id: Optional[str] = None, | |
| 82 | 102 | ) -> np.ndarray: |
| 83 | 103 | """ |
| 84 | 104 | Generate image embedding via network service using URL. |
| ... | ... | @@ -98,6 +118,8 @@ class CLIPImageEncoder: |
| 98 | 118 | [url], |
| 99 | 119 | normalize_embeddings=normalize_embeddings, |
| 100 | 120 | priority=priority, |
| 121 | + request_id=request_id, | |
| 122 | + user_id=user_id, | |
| 101 | 123 | ) |
| 102 | 124 | if not response_data or len(response_data) != 1 or response_data[0] is None: |
| 103 | 125 | raise RuntimeError(f"No image embedding returned for URL: {url}") |
| ... | ... | @@ -113,6 +135,8 @@ class CLIPImageEncoder: |
| 113 | 135 | batch_size: int = 8, |
| 114 | 136 | normalize_embeddings: bool = True, |
| 115 | 137 | priority: int = 0, |
| 138 | + request_id: Optional[str] = None, | |
| 139 | + user_id: Optional[str] = None, | |
| 116 | 140 | ) -> List[np.ndarray]: |
| 117 | 141 | """ |
| 118 | 142 | Encode a batch of images efficiently via network service. |
| ... | ... | @@ -151,6 +175,8 @@ class CLIPImageEncoder: |
| 151 | 175 | batch_urls, |
| 152 | 176 | normalize_embeddings=normalize_embeddings, |
| 153 | 177 | priority=priority, |
| 178 | + request_id=request_id, | |
| 179 | + user_id=user_id, | |
| 154 | 180 | ) |
| 155 | 181 | if not response_data or len(response_data) != len(batch_urls): |
| 156 | 182 | raise RuntimeError( |
| ... | ... | @@ -176,6 +202,8 @@ class CLIPImageEncoder: |
| 176 | 202 | batch_size: Optional[int] = None, |
| 177 | 203 | normalize_embeddings: bool = True, |
| 178 | 204 | priority: int = 0, |
| 205 | + request_id: Optional[str] = None, | |
| 206 | + user_id: Optional[str] = None, | |
| 179 | 207 | ) -> List[np.ndarray]: |
| 180 | 208 | """ |
| 181 | 209 | 与 ClipImageModel / ClipAsServiceImageEncoder 一致的接口,供索引器 document_transformer 调用。 |
| ... | ... | @@ -192,4 +220,6 @@ class CLIPImageEncoder: |
| 192 | 220 | batch_size=batch_size or 8, |
| 193 | 221 | normalize_embeddings=normalize_embeddings, |
| 194 | 222 | priority=priority, |
| 223 | + request_id=request_id, | |
| 224 | + user_id=user_id, | |
| 195 | 225 | ) | ... | ... |
embeddings/server.py
| ... | ... | @@ -26,17 +26,17 @@ from embeddings.cache_keys import build_image_cache_key, build_text_cache_key |
| 26 | 26 | from embeddings.config import CONFIG |
| 27 | 27 | from embeddings.protocols import ImageEncoderProtocol |
| 28 | 28 | from embeddings.redis_embedding_cache import RedisEmbeddingCache |
| 29 | +from request_log_context import ( | |
| 30 | + LOG_LINE_FORMAT, | |
| 31 | + RequestLogContextFilter, | |
| 32 | + bind_request_log_context, | |
| 33 | + build_request_log_extra, | |
| 34 | + reset_request_log_context, | |
| 35 | +) | |
| 29 | 36 | |
| 30 | 37 | app = FastAPI(title="saas-search Embedding Service", version="1.0.0") |
| 31 | 38 | |
| 32 | 39 | |
| 33 | -class _DefaultRequestIdFilter(logging.Filter): | |
| 34 | - def filter(self, record: logging.LogRecord) -> bool: | |
| 35 | - if not hasattr(record, "reqid"): | |
| 36 | - record.reqid = "-1" | |
| 37 | - return True | |
| 38 | - | |
| 39 | - | |
| 40 | 40 | def configure_embedding_logging() -> None: |
| 41 | 41 | root_logger = logging.getLogger() |
| 42 | 42 | if getattr(root_logger, "_embedding_logging_configured", False): |
| ... | ... | @@ -47,17 +47,15 @@ def configure_embedding_logging() -> None: |
| 47 | 47 | |
| 48 | 48 | log_level = os.getenv("LOG_LEVEL", "INFO").upper() |
| 49 | 49 | numeric_level = getattr(logging, log_level, logging.INFO) |
| 50 | - formatter = logging.Formatter( | |
| 51 | - "%(asctime)s | reqid:%(reqid)s | %(name)s | %(levelname)s | %(message)s" | |
| 52 | - ) | |
| 53 | - request_filter = _DefaultRequestIdFilter() | |
| 50 | + formatter = logging.Formatter(LOG_LINE_FORMAT) | |
| 51 | + context_filter = RequestLogContextFilter() | |
| 54 | 52 | |
| 55 | 53 | root_logger.setLevel(numeric_level) |
| 56 | 54 | root_logger.handlers.clear() |
| 57 | 55 | stream_handler = logging.StreamHandler() |
| 58 | 56 | stream_handler.setLevel(numeric_level) |
| 59 | 57 | stream_handler.setFormatter(formatter) |
| 60 | - stream_handler.addFilter(request_filter) | |
| 58 | + stream_handler.addFilter(context_filter) | |
| 61 | 59 | root_logger.addHandler(stream_handler) |
| 62 | 60 | |
| 63 | 61 | verbose_logger = logging.getLogger("embedding.verbose") |
| ... | ... | @@ -231,6 +229,7 @@ class _TextDispatchTask: |
| 231 | 229 | normalized: List[str] |
| 232 | 230 | effective_normalize: bool |
| 233 | 231 | request_id: str |
| 232 | + user_id: str | |
| 234 | 233 | priority: int |
| 235 | 234 | created_at: float |
| 236 | 235 | done: threading.Event |
| ... | ... | @@ -321,12 +320,13 @@ def _text_dispatch_worker_loop(worker_idx: int) -> None: |
| 321 | 320 | _priority_label(task.priority), |
| 322 | 321 | len(task.normalized), |
| 323 | 322 | queue_wait_ms, |
| 324 | - extra=_request_log_extra(task.request_id), | |
| 323 | + extra=build_request_log_extra(task.request_id, task.user_id), | |
| 325 | 324 | ) |
| 326 | 325 | task.result = _embed_text_impl( |
| 327 | 326 | task.normalized, |
| 328 | 327 | task.effective_normalize, |
| 329 | 328 | task.request_id, |
| 329 | + task.user_id, | |
| 330 | 330 | task.priority, |
| 331 | 331 | ) |
| 332 | 332 | except Exception as exc: |
| ... | ... | @@ -339,6 +339,7 @@ def _submit_text_dispatch_and_wait( |
| 339 | 339 | normalized: List[str], |
| 340 | 340 | effective_normalize: bool, |
| 341 | 341 | request_id: str, |
| 342 | + user_id: str, | |
| 342 | 343 | priority: int, |
| 343 | 344 | ) -> _EmbedResult: |
| 344 | 345 | if not any(worker.is_alive() for worker in _text_dispatch_workers): |
| ... | ... | @@ -347,6 +348,7 @@ def _submit_text_dispatch_and_wait( |
| 347 | 348 | normalized=normalized, |
| 348 | 349 | effective_normalize=effective_normalize, |
| 349 | 350 | request_id=request_id, |
| 351 | + user_id=user_id, | |
| 350 | 352 | priority=_effective_priority(priority), |
| 351 | 353 | created_at=time.perf_counter(), |
| 352 | 354 | done=threading.Event(), |
| ... | ... | @@ -380,6 +382,7 @@ class _SingleTextTask: |
| 380 | 382 | priority: int |
| 381 | 383 | created_at: float |
| 382 | 384 | request_id: str |
| 385 | + user_id: str | |
| 383 | 386 | done: threading.Event |
| 384 | 387 | result: Optional[List[float]] = None |
| 385 | 388 | error: Optional[Exception] = None |
| ... | ... | @@ -435,10 +438,6 @@ def _preview_vector(vec: Optional[List[float]], max_dims: int = _VECTOR_PREVIEW_ |
| 435 | 438 | return [round(float(v), 6) for v in vec[:max_dims]] |
| 436 | 439 | |
| 437 | 440 | |
| 438 | -def _request_log_extra(request_id: str) -> Dict[str, str]: | |
| 439 | - return {"reqid": request_id} | |
| 440 | - | |
| 441 | - | |
| 442 | 441 | def _resolve_request_id(http_request: Request) -> str: |
| 443 | 442 | header_value = http_request.headers.get("X-Request-ID") |
| 444 | 443 | if header_value and header_value.strip(): |
| ... | ... | @@ -446,6 +445,13 @@ def _resolve_request_id(http_request: Request) -> str: |
| 446 | 445 | return str(uuid.uuid4())[:8] |
| 447 | 446 | |
| 448 | 447 | |
def _resolve_user_id(http_request: Request) -> str:
    """Resolve the caller's user id from the incoming request headers.

    Prefers the "X-User-ID" header and falls back to "User-ID"; the chosen
    value is stripped of surrounding whitespace and truncated to 64
    characters. Returns the sentinel "-1" when neither header carries a
    usable (non-blank) value.
    """
    candidate = http_request.headers.get("X-User-ID")
    if not candidate:
        # Mirror the `a or b` fallback: empty/None primary falls through.
        candidate = http_request.headers.get("User-ID")
    if candidate is None:
        return "-1"
    cleaned = candidate.strip()
    # A whitespace-only header counts as missing.
    return cleaned[:64] if cleaned else "-1"
| 453 | + | |
| 454 | + | |
| 449 | 455 | def _request_client(http_request: Request) -> str: |
| 450 | 456 | client = getattr(http_request, "client", None) |
| 451 | 457 | host = getattr(client, "host", None) |
| ... | ... | @@ -522,18 +528,21 @@ def _text_batch_worker_loop() -> None: |
| 522 | 528 | try: |
| 523 | 529 | queue_wait_ms = [(time.perf_counter() - task.created_at) * 1000.0 for task in batch] |
| 524 | 530 | reqids = [task.request_id for task in batch] |
| 531 | + uids = [task.user_id for task in batch] | |
| 525 | 532 | logger.info( |
| 526 | - "text microbatch dispatch | size=%d priority=%s queue_wait_ms_min=%.2f queue_wait_ms_max=%.2f reqids=%s preview=%s", | |
| 533 | + "text microbatch dispatch | size=%d priority=%s queue_wait_ms_min=%.2f queue_wait_ms_max=%.2f reqids=%s uids=%s preview=%s", | |
| 527 | 534 | len(batch), |
| 528 | 535 | _priority_label(max(task.priority for task in batch)), |
| 529 | 536 | min(queue_wait_ms) if queue_wait_ms else 0.0, |
| 530 | 537 | max(queue_wait_ms) if queue_wait_ms else 0.0, |
| 531 | 538 | reqids, |
| 539 | + uids, | |
| 532 | 540 | _preview_inputs( |
| 533 | 541 | [task.text for task in batch], |
| 534 | 542 | _LOG_PREVIEW_COUNT, |
| 535 | 543 | _LOG_TEXT_PREVIEW_CHARS, |
| 536 | 544 | ), |
| 545 | + extra=build_request_log_extra(), | |
| 537 | 546 | ) |
| 538 | 547 | batch_t0 = time.perf_counter() |
| 539 | 548 | embs = _encode_local_st([task.text for task in batch], normalize_embeddings=False) |
| ... | ... | @@ -548,19 +557,23 @@ def _text_batch_worker_loop() -> None: |
| 548 | 557 | raise RuntimeError("Text model returned empty embedding in micro-batch") |
| 549 | 558 | task.result = vec |
| 550 | 559 | logger.info( |
| 551 | - "text microbatch done | size=%d reqids=%s dim=%d backend_elapsed_ms=%.2f", | |
| 560 | + "text microbatch done | size=%d reqids=%s uids=%s dim=%d backend_elapsed_ms=%.2f", | |
| 552 | 561 | len(batch), |
| 553 | 562 | reqids, |
| 563 | + uids, | |
| 554 | 564 | len(batch[0].result) if batch and batch[0].result is not None else 0, |
| 555 | 565 | (time.perf_counter() - batch_t0) * 1000.0, |
| 566 | + extra=build_request_log_extra(), | |
| 556 | 567 | ) |
| 557 | 568 | except Exception as exc: |
| 558 | 569 | logger.error( |
| 559 | - "text microbatch failed | size=%d reqids=%s error=%s", | |
| 570 | + "text microbatch failed | size=%d reqids=%s uids=%s error=%s", | |
| 560 | 571 | len(batch), |
| 561 | 572 | [task.request_id for task in batch], |
| 573 | + [task.user_id for task in batch], | |
| 562 | 574 | exc, |
| 563 | 575 | exc_info=True, |
| 576 | + extra=build_request_log_extra(), | |
| 564 | 577 | ) |
| 565 | 578 | for task in batch: |
| 566 | 579 | task.error = exc |
| ... | ... | @@ -573,6 +586,7 @@ def _encode_single_text_with_microbatch( |
| 573 | 586 | text: str, |
| 574 | 587 | normalize: bool, |
| 575 | 588 | request_id: str, |
| 589 | + user_id: str, | |
| 576 | 590 | priority: int, |
| 577 | 591 | ) -> List[float]: |
| 578 | 592 | task = _SingleTextTask( |
| ... | ... | @@ -581,6 +595,7 @@ def _encode_single_text_with_microbatch( |
| 581 | 595 | priority=_effective_priority(priority), |
| 582 | 596 | created_at=time.perf_counter(), |
| 583 | 597 | request_id=request_id, |
| 598 | + user_id=user_id, | |
| 584 | 599 | done=threading.Event(), |
| 585 | 600 | ) |
| 586 | 601 | with _text_single_queue_cv: |
| ... | ... | @@ -632,6 +647,9 @@ def load_models(): |
| 632 | 647 | _text_model = TEITextModel( |
| 633 | 648 | base_url=str(base_url), |
| 634 | 649 | timeout_sec=timeout_sec, |
| 650 | + max_client_batch_size=int( | |
| 651 | + backend_cfg.get("max_client_batch_size") or CONFIG.TEI_MAX_CLIENT_BATCH_SIZE | |
| 652 | + ), | |
| 635 | 653 | ) |
| 636 | 654 | elif backend_name == "local_st": |
| 637 | 655 | from embeddings.text_embedding_sentence_transformers import Qwen3TextModel |
| ... | ... | @@ -823,6 +841,7 @@ def _embed_text_impl( |
| 823 | 841 | normalized: List[str], |
| 824 | 842 | effective_normalize: bool, |
| 825 | 843 | request_id: str, |
| 844 | + user_id: str, | |
| 826 | 845 | priority: int = 0, |
| 827 | 846 | ) -> _EmbedResult: |
| 828 | 847 | if _text_model is None: |
| ... | ... | @@ -854,7 +873,7 @@ def _embed_text_impl( |
| 854 | 873 | effective_normalize, |
| 855 | 874 | len(out[0]) if out and out[0] is not None else 0, |
| 856 | 875 | cache_hits, |
| 857 | - extra=_request_log_extra(request_id), | |
| 876 | + extra=build_request_log_extra(request_id, user_id), | |
| 858 | 877 | ) |
| 859 | 878 | return _EmbedResult( |
| 860 | 879 | vectors=out, |
| ... | ... | @@ -873,6 +892,7 @@ def _embed_text_impl( |
| 873 | 892 | missing_texts[0], |
| 874 | 893 | normalize=effective_normalize, |
| 875 | 894 | request_id=request_id, |
| 895 | + user_id=user_id, | |
| 876 | 896 | priority=priority, |
| 877 | 897 | ) |
| 878 | 898 | ] |
| ... | ... | @@ -905,7 +925,7 @@ def _embed_text_impl( |
| 905 | 925 | "Text embedding backend failure: %s", |
| 906 | 926 | e, |
| 907 | 927 | exc_info=True, |
| 908 | - extra=_request_log_extra(request_id), | |
| 928 | + extra=build_request_log_extra(request_id, user_id), | |
| 909 | 929 | ) |
| 910 | 930 | raise RuntimeError(f"Text embedding backend failure: {e}") from e |
| 911 | 931 | |
| ... | ... | @@ -931,7 +951,7 @@ def _embed_text_impl( |
| 931 | 951 | cache_hits, |
| 932 | 952 | len(missing_texts), |
| 933 | 953 | backend_elapsed_ms, |
| 934 | - extra=_request_log_extra(request_id), | |
| 954 | + extra=build_request_log_extra(request_id, user_id), | |
| 935 | 955 | ) |
| 936 | 956 | return _EmbedResult( |
| 937 | 957 | vectors=out, |
| ... | ... | @@ -954,75 +974,79 @@ async def embed_text( |
| 954 | 974 | raise HTTPException(status_code=503, detail="Text embedding model not loaded in this service") |
| 955 | 975 | |
| 956 | 976 | request_id = _resolve_request_id(http_request) |
| 977 | + user_id = _resolve_user_id(http_request) | |
| 978 | + _, _, log_tokens = bind_request_log_context(request_id, user_id) | |
| 957 | 979 | response.headers["X-Request-ID"] = request_id |
| 958 | - | |
| 959 | - if priority < 0: | |
| 960 | - raise HTTPException(status_code=400, detail="priority must be >= 0") | |
| 961 | - effective_priority = _effective_priority(priority) | |
| 962 | - effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 963 | - normalized: List[str] = [] | |
| 964 | - for i, t in enumerate(texts): | |
| 965 | - if not isinstance(t, str): | |
| 966 | - raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string") | |
| 967 | - s = t.strip() | |
| 968 | - if not s: | |
| 969 | - raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") | |
| 970 | - normalized.append(s) | |
| 971 | - | |
| 972 | - cache_check_started = time.perf_counter() | |
| 973 | - cache_only = _try_full_text_cache_hit(normalized, effective_normalize) | |
| 974 | - if cache_only is not None: | |
| 975 | - latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 976 | - _text_stats.record_completed( | |
| 977 | - success=True, | |
| 978 | - latency_ms=latency_ms, | |
| 979 | - backend_latency_ms=0.0, | |
| 980 | - cache_hits=cache_only.cache_hits, | |
| 981 | - cache_misses=0, | |
| 982 | - ) | |
| 983 | - logger.info( | |
| 984 | - "embed_text response | backend=%s mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 985 | - _text_backend_name, | |
| 986 | - _priority_label(effective_priority), | |
| 987 | - len(normalized), | |
| 988 | - effective_normalize, | |
| 989 | - len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 990 | - cache_only.cache_hits, | |
| 991 | - _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 992 | - latency_ms, | |
| 993 | - extra=_request_log_extra(request_id), | |
| 994 | - ) | |
| 995 | - return cache_only.vectors | |
| 996 | - | |
| 997 | - accepted, active = _text_request_limiter.try_acquire(bypass_limit=effective_priority > 0) | |
| 998 | - if not accepted: | |
| 999 | - _text_stats.record_rejected() | |
| 1000 | - logger.warning( | |
| 1001 | - "embed_text rejected | client=%s backend=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | |
| 1002 | - _request_client(http_request), | |
| 1003 | - _text_backend_name, | |
| 1004 | - _priority_label(effective_priority), | |
| 1005 | - len(normalized), | |
| 1006 | - effective_normalize, | |
| 1007 | - active, | |
| 1008 | - _TEXT_MAX_INFLIGHT, | |
| 1009 | - _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), | |
| 1010 | - extra=_request_log_extra(request_id), | |
| 1011 | - ) | |
| 1012 | - raise HTTPException( | |
| 1013 | - status_code=_OVERLOAD_STATUS_CODE, | |
| 1014 | - detail=( | |
| 1015 | - "Text embedding service busy for priority=0 requests: " | |
| 1016 | - f"active={active}, limit={_TEXT_MAX_INFLIGHT}" | |
| 1017 | - ), | |
| 1018 | - ) | |
| 1019 | - | |
| 980 | + response.headers["X-User-ID"] = user_id | |
| 1020 | 981 | request_started = time.perf_counter() |
| 1021 | 982 | success = False |
| 1022 | 983 | backend_elapsed_ms = 0.0 |
| 1023 | 984 | cache_hits = 0 |
| 1024 | 985 | cache_misses = 0 |
| 986 | + limiter_acquired = False | |
| 987 | + | |
| 1025 | 988 | try: |
| 989 | + if priority < 0: | |
| 990 | + raise HTTPException(status_code=400, detail="priority must be >= 0") | |
| 991 | + effective_priority = _effective_priority(priority) | |
| 992 | + effective_normalize = bool(CONFIG.TEXT_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 993 | + normalized: List[str] = [] | |
| 994 | + for i, t in enumerate(texts): | |
| 995 | + if not isinstance(t, str): | |
| 996 | + raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: must be string") | |
| 997 | + s = t.strip() | |
| 998 | + if not s: | |
| 999 | + raise HTTPException(status_code=400, detail=f"Invalid text at index {i}: empty string") | |
| 1000 | + normalized.append(s) | |
| 1001 | + | |
| 1002 | + cache_check_started = time.perf_counter() | |
| 1003 | + cache_only = _try_full_text_cache_hit(normalized, effective_normalize) | |
| 1004 | + if cache_only is not None: | |
| 1005 | + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 1006 | + _text_stats.record_completed( | |
| 1007 | + success=True, | |
| 1008 | + latency_ms=latency_ms, | |
| 1009 | + backend_latency_ms=0.0, | |
| 1010 | + cache_hits=cache_only.cache_hits, | |
| 1011 | + cache_misses=0, | |
| 1012 | + ) | |
| 1013 | + logger.info( | |
| 1014 | + "embed_text response | backend=%s mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 1015 | + _text_backend_name, | |
| 1016 | + _priority_label(effective_priority), | |
| 1017 | + len(normalized), | |
| 1018 | + effective_normalize, | |
| 1019 | + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 1020 | + cache_only.cache_hits, | |
| 1021 | + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 1022 | + latency_ms, | |
| 1023 | + extra=build_request_log_extra(request_id, user_id), | |
| 1024 | + ) | |
| 1025 | + return cache_only.vectors | |
| 1026 | + | |
| 1027 | + accepted, active = _text_request_limiter.try_acquire(bypass_limit=effective_priority > 0) | |
| 1028 | + if not accepted: | |
| 1029 | + _text_stats.record_rejected() | |
| 1030 | + logger.warning( | |
| 1031 | + "embed_text rejected | client=%s backend=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | |
| 1032 | + _request_client(http_request), | |
| 1033 | + _text_backend_name, | |
| 1034 | + _priority_label(effective_priority), | |
| 1035 | + len(normalized), | |
| 1036 | + effective_normalize, | |
| 1037 | + active, | |
| 1038 | + _TEXT_MAX_INFLIGHT, | |
| 1039 | + _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), | |
| 1040 | + extra=build_request_log_extra(request_id, user_id), | |
| 1041 | + ) | |
| 1042 | + raise HTTPException( | |
| 1043 | + status_code=_OVERLOAD_STATUS_CODE, | |
| 1044 | + detail=( | |
| 1045 | + "Text embedding service busy for priority=0 requests: " | |
| 1046 | + f"active={active}, limit={_TEXT_MAX_INFLIGHT}" | |
| 1047 | + ), | |
| 1048 | + ) | |
| 1049 | + limiter_acquired = True | |
| 1026 | 1050 | logger.info( |
| 1027 | 1051 | "embed_text request | client=%s backend=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| 1028 | 1052 | _request_client(http_request), |
| ... | ... | @@ -1033,7 +1057,7 @@ async def embed_text( |
| 1033 | 1057 | active, |
| 1034 | 1058 | _TEXT_MAX_INFLIGHT, |
| 1035 | 1059 | _preview_inputs(normalized, _LOG_PREVIEW_COUNT, _LOG_TEXT_PREVIEW_CHARS), |
| 1036 | - extra=_request_log_extra(request_id), | |
| 1060 | + extra=build_request_log_extra(request_id, user_id), | |
| 1037 | 1061 | ) |
| 1038 | 1062 | verbose_logger.info( |
| 1039 | 1063 | "embed_text detail | payload=%s normalize=%s backend=%s priority=%s", |
| ... | ... | @@ -1041,13 +1065,14 @@ async def embed_text( |
| 1041 | 1065 | effective_normalize, |
| 1042 | 1066 | _text_backend_name, |
| 1043 | 1067 | _priority_label(effective_priority), |
| 1044 | - extra=_request_log_extra(request_id), | |
| 1068 | + extra=build_request_log_extra(request_id, user_id), | |
| 1045 | 1069 | ) |
| 1046 | 1070 | result = await run_in_threadpool( |
| 1047 | 1071 | _submit_text_dispatch_and_wait, |
| 1048 | 1072 | normalized, |
| 1049 | 1073 | effective_normalize, |
| 1050 | 1074 | request_id, |
| 1075 | + user_id, | |
| 1051 | 1076 | effective_priority, |
| 1052 | 1077 | ) |
| 1053 | 1078 | success = True |
| ... | ... | @@ -1074,7 +1099,7 @@ async def embed_text( |
| 1074 | 1099 | cache_misses, |
| 1075 | 1100 | _preview_vector(result.vectors[0] if result.vectors else None), |
| 1076 | 1101 | latency_ms, |
| 1077 | - extra=_request_log_extra(request_id), | |
| 1102 | + extra=build_request_log_extra(request_id, user_id), | |
| 1078 | 1103 | ) |
| 1079 | 1104 | verbose_logger.info( |
| 1080 | 1105 | "embed_text result detail | count=%d priority=%s first_vector=%s latency_ms=%.2f", |
| ... | ... | @@ -1084,7 +1109,7 @@ async def embed_text( |
| 1084 | 1109 | if result.vectors and result.vectors[0] is not None |
| 1085 | 1110 | else [], |
| 1086 | 1111 | latency_ms, |
| 1087 | - extra=_request_log_extra(request_id), | |
| 1112 | + extra=build_request_log_extra(request_id, user_id), | |
| 1088 | 1113 | ) |
| 1089 | 1114 | return result.vectors |
| 1090 | 1115 | except HTTPException: |
| ... | ... | @@ -1107,24 +1132,27 @@ async def embed_text( |
| 1107 | 1132 | latency_ms, |
| 1108 | 1133 | e, |
| 1109 | 1134 | exc_info=True, |
| 1110 | - extra=_request_log_extra(request_id), | |
| 1135 | + extra=build_request_log_extra(request_id, user_id), | |
| 1111 | 1136 | ) |
| 1112 | 1137 | raise HTTPException(status_code=502, detail=str(e)) from e |
| 1113 | 1138 | finally: |
| 1114 | - remaining = _text_request_limiter.release(success=success) | |
| 1115 | - logger.info( | |
| 1116 | - "embed_text finalize | success=%s priority=%s active_after=%d", | |
| 1117 | - success, | |
| 1118 | - _priority_label(effective_priority), | |
| 1119 | - remaining, | |
| 1120 | - extra=_request_log_extra(request_id), | |
| 1121 | - ) | |
| 1139 | + if limiter_acquired: | |
| 1140 | + remaining = _text_request_limiter.release(success=success) | |
| 1141 | + logger.info( | |
| 1142 | + "embed_text finalize | success=%s priority=%s active_after=%d", | |
| 1143 | + success, | |
| 1144 | + _priority_label(effective_priority), | |
| 1145 | + remaining, | |
| 1146 | + extra=build_request_log_extra(request_id, user_id), | |
| 1147 | + ) | |
| 1148 | + reset_request_log_context(log_tokens) | |
| 1122 | 1149 | |
| 1123 | 1150 | |
| 1124 | 1151 | def _embed_image_impl( |
| 1125 | 1152 | urls: List[str], |
| 1126 | 1153 | effective_normalize: bool, |
| 1127 | 1154 | request_id: str, |
| 1155 | + user_id: str, | |
| 1128 | 1156 | ) -> _EmbedResult: |
| 1129 | 1157 | if _image_model is None: |
| 1130 | 1158 | raise RuntimeError("Image model not loaded") |
| ... | ... | @@ -1154,7 +1182,7 @@ def _embed_image_impl( |
| 1154 | 1182 | effective_normalize, |
| 1155 | 1183 | len(out[0]) if out and out[0] is not None else 0, |
| 1156 | 1184 | cache_hits, |
| 1157 | - extra=_request_log_extra(request_id), | |
| 1185 | + extra=build_request_log_extra(request_id, user_id), | |
| 1158 | 1186 | ) |
| 1159 | 1187 | return _EmbedResult( |
| 1160 | 1188 | vectors=out, |
| ... | ... | @@ -1194,7 +1222,7 @@ def _embed_image_impl( |
| 1194 | 1222 | cache_hits, |
| 1195 | 1223 | len(missing_urls), |
| 1196 | 1224 | backend_elapsed_ms, |
| 1197 | - extra=_request_log_extra(request_id), | |
| 1225 | + extra=build_request_log_extra(request_id, user_id), | |
| 1198 | 1226 | ) |
| 1199 | 1227 | return _EmbedResult( |
| 1200 | 1228 | vectors=out, |
| ... | ... | @@ -1217,74 +1245,78 @@ async def embed_image( |
| 1217 | 1245 | raise HTTPException(status_code=503, detail="Image embedding model not loaded in this service") |
| 1218 | 1246 | |
| 1219 | 1247 | request_id = _resolve_request_id(http_request) |
| 1248 | + user_id = _resolve_user_id(http_request) | |
| 1249 | + _, _, log_tokens = bind_request_log_context(request_id, user_id) | |
| 1220 | 1250 | response.headers["X-Request-ID"] = request_id |
| 1221 | - | |
| 1222 | - if priority < 0: | |
| 1223 | - raise HTTPException(status_code=400, detail="priority must be >= 0") | |
| 1224 | - effective_priority = _effective_priority(priority) | |
| 1225 | - | |
| 1226 | - effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 1227 | - urls: List[str] = [] | |
| 1228 | - for i, url_or_path in enumerate(images): | |
| 1229 | - if not isinstance(url_or_path, str): | |
| 1230 | - raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: must be string URL/path") | |
| 1231 | - s = url_or_path.strip() | |
| 1232 | - if not s: | |
| 1233 | - raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") | |
| 1234 | - urls.append(s) | |
| 1235 | - | |
| 1236 | - cache_check_started = time.perf_counter() | |
| 1237 | - cache_only = _try_full_image_cache_hit(urls, effective_normalize) | |
| 1238 | - if cache_only is not None: | |
| 1239 | - latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 1240 | - _image_stats.record_completed( | |
| 1241 | - success=True, | |
| 1242 | - latency_ms=latency_ms, | |
| 1243 | - backend_latency_ms=0.0, | |
| 1244 | - cache_hits=cache_only.cache_hits, | |
| 1245 | - cache_misses=0, | |
| 1246 | - ) | |
| 1247 | - logger.info( | |
| 1248 | - "embed_image response | mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 1249 | - _priority_label(effective_priority), | |
| 1250 | - len(urls), | |
| 1251 | - effective_normalize, | |
| 1252 | - len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 1253 | - cache_only.cache_hits, | |
| 1254 | - _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 1255 | - latency_ms, | |
| 1256 | - extra=_request_log_extra(request_id), | |
| 1257 | - ) | |
| 1258 | - return cache_only.vectors | |
| 1259 | - | |
| 1260 | - accepted, active = _image_request_limiter.try_acquire(bypass_limit=effective_priority > 0) | |
| 1261 | - if not accepted: | |
| 1262 | - _image_stats.record_rejected() | |
| 1263 | - logger.warning( | |
| 1264 | - "embed_image rejected | client=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | |
| 1265 | - _request_client(http_request), | |
| 1266 | - _priority_label(effective_priority), | |
| 1267 | - len(urls), | |
| 1268 | - effective_normalize, | |
| 1269 | - active, | |
| 1270 | - _IMAGE_MAX_INFLIGHT, | |
| 1271 | - _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), | |
| 1272 | - extra=_request_log_extra(request_id), | |
| 1273 | - ) | |
| 1274 | - raise HTTPException( | |
| 1275 | - status_code=_OVERLOAD_STATUS_CODE, | |
| 1276 | - detail=( | |
| 1277 | - "Image embedding service busy for priority=0 requests: " | |
| 1278 | - f"active={active}, limit={_IMAGE_MAX_INFLIGHT}" | |
| 1279 | - ), | |
| 1280 | - ) | |
| 1281 | - | |
| 1251 | + response.headers["X-User-ID"] = user_id | |
| 1282 | 1252 | request_started = time.perf_counter() |
| 1283 | 1253 | success = False |
| 1284 | 1254 | backend_elapsed_ms = 0.0 |
| 1285 | 1255 | cache_hits = 0 |
| 1286 | 1256 | cache_misses = 0 |
| 1257 | + limiter_acquired = False | |
| 1258 | + | |
| 1287 | 1259 | try: |
| 1260 | + if priority < 0: | |
| 1261 | + raise HTTPException(status_code=400, detail="priority must be >= 0") | |
| 1262 | + effective_priority = _effective_priority(priority) | |
| 1263 | + | |
| 1264 | + effective_normalize = bool(CONFIG.IMAGE_NORMALIZE_EMBEDDINGS) if normalize is None else bool(normalize) | |
| 1265 | + urls: List[str] = [] | |
| 1266 | + for i, url_or_path in enumerate(images): | |
| 1267 | + if not isinstance(url_or_path, str): | |
| 1268 | + raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: must be string URL/path") | |
| 1269 | + s = url_or_path.strip() | |
| 1270 | + if not s: | |
| 1271 | + raise HTTPException(status_code=400, detail=f"Invalid image at index {i}: empty URL/path") | |
| 1272 | + urls.append(s) | |
| 1273 | + | |
| 1274 | + cache_check_started = time.perf_counter() | |
| 1275 | + cache_only = _try_full_image_cache_hit(urls, effective_normalize) | |
| 1276 | + if cache_only is not None: | |
| 1277 | + latency_ms = (time.perf_counter() - cache_check_started) * 1000.0 | |
| 1278 | + _image_stats.record_completed( | |
| 1279 | + success=True, | |
| 1280 | + latency_ms=latency_ms, | |
| 1281 | + backend_latency_ms=0.0, | |
| 1282 | + cache_hits=cache_only.cache_hits, | |
| 1283 | + cache_misses=0, | |
| 1284 | + ) | |
| 1285 | + logger.info( | |
| 1286 | + "embed_image response | mode=cache-only priority=%s inputs=%d normalize=%s dim=%d cache_hits=%d cache_misses=0 first_vector=%s latency_ms=%.2f", | |
| 1287 | + _priority_label(effective_priority), | |
| 1288 | + len(urls), | |
| 1289 | + effective_normalize, | |
| 1290 | + len(cache_only.vectors[0]) if cache_only.vectors and cache_only.vectors[0] is not None else 0, | |
| 1291 | + cache_only.cache_hits, | |
| 1292 | + _preview_vector(cache_only.vectors[0] if cache_only.vectors else None), | |
| 1293 | + latency_ms, | |
| 1294 | + extra=build_request_log_extra(request_id, user_id), | |
| 1295 | + ) | |
| 1296 | + return cache_only.vectors | |
| 1297 | + | |
| 1298 | + accepted, active = _image_request_limiter.try_acquire(bypass_limit=effective_priority > 0) | |
| 1299 | + if not accepted: | |
| 1300 | + _image_stats.record_rejected() | |
| 1301 | + logger.warning( | |
| 1302 | + "embed_image rejected | client=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", | |
| 1303 | + _request_client(http_request), | |
| 1304 | + _priority_label(effective_priority), | |
| 1305 | + len(urls), | |
| 1306 | + effective_normalize, | |
| 1307 | + active, | |
| 1308 | + _IMAGE_MAX_INFLIGHT, | |
| 1309 | + _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), | |
| 1310 | + extra=build_request_log_extra(request_id, user_id), | |
| 1311 | + ) | |
| 1312 | + raise HTTPException( | |
| 1313 | + status_code=_OVERLOAD_STATUS_CODE, | |
| 1314 | + detail=( | |
| 1315 | + "Image embedding service busy for priority=0 requests: " | |
| 1316 | + f"active={active}, limit={_IMAGE_MAX_INFLIGHT}" | |
| 1317 | + ), | |
| 1318 | + ) | |
| 1319 | + limiter_acquired = True | |
| 1288 | 1320 | logger.info( |
| 1289 | 1321 | "embed_image request | client=%s priority=%s inputs=%d normalize=%s active=%d limit=%d preview=%s", |
| 1290 | 1322 | _request_client(http_request), |
| ... | ... | @@ -1294,16 +1326,16 @@ async def embed_image( |
| 1294 | 1326 | active, |
| 1295 | 1327 | _IMAGE_MAX_INFLIGHT, |
| 1296 | 1328 | _preview_inputs(urls, _LOG_PREVIEW_COUNT, _LOG_IMAGE_PREVIEW_CHARS), |
| 1297 | - extra=_request_log_extra(request_id), | |
| 1329 | + extra=build_request_log_extra(request_id, user_id), | |
| 1298 | 1330 | ) |
| 1299 | 1331 | verbose_logger.info( |
| 1300 | 1332 | "embed_image detail | payload=%s normalize=%s priority=%s", |
| 1301 | 1333 | urls, |
| 1302 | 1334 | effective_normalize, |
| 1303 | 1335 | _priority_label(effective_priority), |
| 1304 | - extra=_request_log_extra(request_id), | |
| 1336 | + extra=build_request_log_extra(request_id, user_id), | |
| 1305 | 1337 | ) |
| 1306 | - result = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id) | |
| 1338 | + result = await run_in_threadpool(_embed_image_impl, urls, effective_normalize, request_id, user_id) | |
| 1307 | 1339 | success = True |
| 1308 | 1340 | backend_elapsed_ms = result.backend_elapsed_ms |
| 1309 | 1341 | cache_hits = result.cache_hits |
| ... | ... | @@ -1327,7 +1359,7 @@ async def embed_image( |
| 1327 | 1359 | cache_misses, |
| 1328 | 1360 | _preview_vector(result.vectors[0] if result.vectors else None), |
| 1329 | 1361 | latency_ms, |
| 1330 | - extra=_request_log_extra(request_id), | |
| 1362 | + extra=build_request_log_extra(request_id, user_id), | |
| 1331 | 1363 | ) |
| 1332 | 1364 | verbose_logger.info( |
| 1333 | 1365 | "embed_image result detail | count=%d first_vector=%s latency_ms=%.2f", |
| ... | ... | @@ -1336,7 +1368,7 @@ async def embed_image( |
| 1336 | 1368 | if result.vectors and result.vectors[0] is not None |
| 1337 | 1369 | else [], |
| 1338 | 1370 | latency_ms, |
| 1339 | - extra=_request_log_extra(request_id), | |
| 1371 | + extra=build_request_log_extra(request_id, user_id), | |
| 1340 | 1372 | ) |
| 1341 | 1373 | return result.vectors |
| 1342 | 1374 | except HTTPException: |
| ... | ... | @@ -1358,15 +1390,17 @@ async def embed_image( |
| 1358 | 1390 | latency_ms, |
| 1359 | 1391 | e, |
| 1360 | 1392 | exc_info=True, |
| 1361 | - extra=_request_log_extra(request_id), | |
| 1393 | + extra=build_request_log_extra(request_id, user_id), | |
| 1362 | 1394 | ) |
| 1363 | 1395 | raise HTTPException(status_code=502, detail=f"Image embedding backend failure: {e}") from e |
| 1364 | 1396 | finally: |
| 1365 | - remaining = _image_request_limiter.release(success=success) | |
| 1366 | - logger.info( | |
| 1367 | - "embed_image finalize | success=%s priority=%s active_after=%d", | |
| 1368 | - success, | |
| 1369 | - _priority_label(effective_priority), | |
| 1370 | - remaining, | |
| 1371 | - extra=_request_log_extra(request_id), | |
| 1372 | - ) | |
| 1397 | + if limiter_acquired: | |
| 1398 | + remaining = _image_request_limiter.release(success=success) | |
| 1399 | + logger.info( | |
| 1400 | + "embed_image finalize | success=%s priority=%s active_after=%d", | |
| 1401 | + success, | |
| 1402 | + _priority_label(effective_priority), | |
| 1403 | + remaining, | |
| 1404 | + extra=build_request_log_extra(request_id, user_id), | |
| 1405 | + ) | |
| 1406 | + reset_request_log_context(log_tokens) | ... | ... |
embeddings/text_embedding_tei.py
| ... | ... | @@ -2,11 +2,14 @@ |
| 2 | 2 | |
| 3 | 3 | from __future__ import annotations |
| 4 | 4 | |
| 5 | +import logging | |
| 5 | 6 | from typing import Any, List, Union |
| 6 | 7 | |
| 7 | 8 | import numpy as np |
| 8 | 9 | import requests |
| 9 | 10 | |
| 11 | +logger = logging.getLogger(__name__) | |
| 12 | + | |
| 10 | 13 | |
| 11 | 14 | class TEITextModel: |
| 12 | 15 | """ |
| ... | ... | @@ -18,12 +21,13 @@ class TEITextModel: |
| 18 | 21 | response: [[...], [...], ...] |
| 19 | 22 | """ |
| 20 | 23 | |
| 21 | - def __init__(self, base_url: str, timeout_sec: int = 60): | |
| 24 | + def __init__(self, base_url: str, timeout_sec: int = 60, max_client_batch_size: int = 24): | |
| 22 | 25 | if not base_url or not str(base_url).strip(): |
| 23 | 26 | raise ValueError("TEI base_url must not be empty") |
| 24 | 27 | self.base_url = str(base_url).rstrip("/") |
| 25 | 28 | self.endpoint = f"{self.base_url}/embed" |
| 26 | 29 | self.timeout_sec = int(timeout_sec) |
| 30 | + self.max_client_batch_size = max(1, int(max_client_batch_size)) | |
| 27 | 31 | self._health_check() |
| 28 | 32 | |
| 29 | 33 | def _health_check(self) -> None: |
| ... | ... | @@ -72,16 +76,28 @@ class TEITextModel: |
| 72 | 76 | if not isinstance(t, str) or not t.strip(): |
| 73 | 77 | raise ValueError(f"Invalid input text at index {i}: {t!r}") |
| 74 | 78 | |
| 75 | - response = requests.post( | |
| 76 | - self.endpoint, | |
| 77 | - json={"inputs": texts}, | |
| 78 | - timeout=self.timeout_sec, | |
| 79 | - ) | |
| 80 | - response.raise_for_status() | |
| 81 | - payload = response.json() | |
| 82 | - vectors = self._parse_payload(payload, expected_len=len(texts)) | |
| 83 | - if normalize_embeddings: | |
| 84 | - vectors = [self._normalize(vec) for vec in vectors] | |
| 79 | + if len(texts) > self.max_client_batch_size: | |
| 80 | + logger.info( | |
| 81 | + "TEI batch split | total_inputs=%d chunk_size=%d chunks=%d", | |
| 82 | + len(texts), | |
| 83 | + self.max_client_batch_size, | |
| 84 | + (len(texts) + self.max_client_batch_size - 1) // self.max_client_batch_size, | |
| 85 | + ) | |
| 86 | + | |
| 87 | + vectors: List[np.ndarray] = [] | |
| 88 | + for start in range(0, len(texts), self.max_client_batch_size): | |
| 89 | + batch = texts[start : start + self.max_client_batch_size] | |
| 90 | + response = requests.post( | |
| 91 | + self.endpoint, | |
| 92 | + json={"inputs": batch}, | |
| 93 | + timeout=self.timeout_sec, | |
| 94 | + ) | |
| 95 | + response.raise_for_status() | |
| 96 | + payload = response.json() | |
| 97 | + parsed = self._parse_payload(payload, expected_len=len(batch)) | |
| 98 | + if normalize_embeddings: | |
| 99 | + parsed = [self._normalize(vec) for vec in parsed] | |
| 100 | + vectors.extend(parsed) | |
| 85 | 101 | return np.array(vectors, dtype=object) |
| 86 | 102 | |
| 87 | 103 | def _parse_payload(self, payload: Any, expected_len: int) -> List[np.ndarray]: | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -13,6 +13,7 @@ from config.loader import get_app_config |
| 13 | 13 | from config.services_config import get_embedding_text_base_url |
| 14 | 14 | from embeddings.cache_keys import build_text_cache_key |
| 15 | 15 | from embeddings.redis_embedding_cache import RedisEmbeddingCache |
| 16 | +from request_log_context import build_downstream_request_headers, build_request_log_extra | |
| 16 | 17 | |
| 17 | 18 | |
| 18 | 19 | class TextEmbeddingEncoder: |
| ... | ... | @@ -40,6 +41,8 @@ class TextEmbeddingEncoder: |
| 40 | 41 | request_data: List[str], |
| 41 | 42 | normalize_embeddings: bool = True, |
| 42 | 43 | priority: int = 0, |
| 44 | + request_id: Optional[str] = None, | |
| 45 | + user_id: Optional[str] = None, | |
| 43 | 46 | ) -> List[Any]: |
| 44 | 47 | """ |
| 45 | 48 | Call the embedding service API. |
| ... | ... | @@ -50,6 +53,7 @@ class TextEmbeddingEncoder: |
| 50 | 53 | Returns: |
| 51 | 54 | List of embeddings (list[float]) or nulls (None), aligned to input order |
| 52 | 55 | """ |
| 56 | + response = None | |
| 53 | 57 | try: |
| 54 | 58 | response = requests.post( |
| 55 | 59 | self.endpoint, |
| ... | ... | @@ -58,12 +62,26 @@ class TextEmbeddingEncoder: |
| 58 | 62 | "priority": max(0, int(priority)), |
| 59 | 63 | }, |
| 60 | 64 | json=request_data, |
| 65 | + headers=build_downstream_request_headers(request_id=request_id, user_id=user_id), | |
| 61 | 66 | timeout=60 |
| 62 | 67 | ) |
| 63 | 68 | response.raise_for_status() |
| 64 | 69 | return response.json() |
| 65 | 70 | except requests.exceptions.RequestException as e: |
| 66 | - logger.error(f"TextEmbeddingEncoder service request failed: {e}", exc_info=True) | |
| 71 | + body_preview = "" | |
| 72 | + if response is not None: | |
| 73 | + try: | |
| 74 | + body_preview = (response.text or "")[:300] | |
| 75 | + except Exception: | |
| 76 | + body_preview = "" | |
| 77 | + logger.error( | |
| 78 | + "TextEmbeddingEncoder service request failed | status=%s body=%s error=%s", | |
| 79 | + getattr(response, "status_code", "n/a"), | |
| 80 | + body_preview, | |
| 81 | + e, | |
| 82 | + exc_info=True, | |
| 83 | + extra=build_request_log_extra(request_id=request_id, user_id=user_id), | |
| 84 | + ) | |
| 67 | 85 | raise |
| 68 | 86 | |
| 69 | 87 | def encode( |
| ... | ... | @@ -72,7 +90,9 @@ class TextEmbeddingEncoder: |
| 72 | 90 | normalize_embeddings: bool = True, |
| 73 | 91 | priority: int = 0, |
| 74 | 92 | device: str = 'cpu', |
| 75 | - batch_size: int = 32 | |
| 93 | + batch_size: int = 32, | |
| 94 | + request_id: Optional[str] = None, | |
| 95 | + user_id: Optional[str] = None, | |
| 76 | 96 | ) -> np.ndarray: |
| 77 | 97 | """ |
| 78 | 98 | Encode text into embeddings via network service with Redis caching. |
| ... | ... | @@ -113,6 +133,8 @@ class TextEmbeddingEncoder: |
| 113 | 133 | request_data, |
| 114 | 134 | normalize_embeddings=normalize_embeddings, |
| 115 | 135 | priority=priority, |
| 136 | + request_id=request_id, | |
| 137 | + user_id=user_id, | |
| 116 | 138 | ) |
| 117 | 139 | |
| 118 | 140 | # Process response | ... | ... |
query/query_parser.py
| ... | ... | @@ -12,7 +12,6 @@ from dataclasses import dataclass, field |
| 12 | 12 | from typing import Any, Callable, Dict, List, Optional, Tuple |
| 13 | 13 | import numpy as np |
| 14 | 14 | import logging |
| 15 | -import re | |
| 16 | 15 | from concurrent.futures import ThreadPoolExecutor, wait |
| 17 | 16 | |
| 18 | 17 | from embeddings.text_encoder import TextEmbeddingEncoder |
| ... | ... | @@ -20,25 +19,14 @@ from config import SearchConfig |
| 20 | 19 | from translation import create_translation_client |
| 21 | 20 | from .language_detector import LanguageDetector |
| 22 | 21 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| 22 | +from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry | |
| 23 | +from .tokenization import extract_token_strings, simple_tokenize_query | |
| 23 | 24 | |
| 24 | 25 | logger = logging.getLogger(__name__) |
| 25 | 26 | |
| 26 | 27 | import hanlp # type: ignore |
| 27 | 28 | |
| 28 | 29 | |
| 29 | -def simple_tokenize_query(text: str) -> List[str]: | |
| 30 | - """ | |
| 31 | - Lightweight tokenizer for suggestion-side heuristics only. | |
| 32 | - | |
| 33 | - - Consecutive CJK characters form one token | |
| 34 | - - Latin / digit runs (with internal hyphens) form tokens | |
| 35 | - """ | |
| 36 | - if not text: | |
| 37 | - return [] | |
| 38 | - pattern = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*") | |
| 39 | - return pattern.findall(text) | |
| 40 | - | |
| 41 | - | |
| 42 | 30 | @dataclass(slots=True) |
| 43 | 31 | class ParsedQuery: |
| 44 | 32 | """Container for query parser facts.""" |
| ... | ... | @@ -50,8 +38,7 @@ class ParsedQuery: |
| 50 | 38 | translations: Dict[str, str] = field(default_factory=dict) |
| 51 | 39 | query_vector: Optional[np.ndarray] = None |
| 52 | 40 | query_tokens: List[str] = field(default_factory=list) |
| 53 | - contains_chinese: bool = False | |
| 54 | - contains_english: bool = False | |
| 41 | + style_intent_profile: Optional[StyleIntentProfile] = None | |
| 55 | 42 | |
| 56 | 43 | def to_dict(self) -> Dict[str, Any]: |
| 57 | 44 | """Convert to dictionary representation.""" |
| ... | ... | @@ -62,8 +49,9 @@ class ParsedQuery: |
| 62 | 49 | "detected_language": self.detected_language, |
| 63 | 50 | "translations": self.translations, |
| 64 | 51 | "query_tokens": self.query_tokens, |
| 65 | - "contains_chinese": self.contains_chinese, | |
| 66 | - "contains_english": self.contains_english, | |
| 52 | + "style_intent_profile": ( | |
| 53 | + self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None | |
| 54 | + ), | |
| 67 | 55 | } |
| 68 | 56 | |
| 69 | 57 | |
| ... | ... | @@ -101,6 +89,11 @@ class QueryParser: |
| 101 | 89 | self.language_detector = LanguageDetector() |
| 102 | 90 | self.rewriter = QueryRewriter(config.query_config.rewrite_dictionary) |
| 103 | 91 | self._tokenizer = tokenizer or self._build_tokenizer() |
| 92 | + self.style_intent_registry = StyleIntentRegistry.from_query_config(config.query_config) | |
| 93 | + self.style_intent_detector = StyleIntentDetector( | |
| 94 | + self.style_intent_registry, | |
| 95 | + tokenizer=self._tokenizer, | |
| 96 | + ) | |
| 104 | 97 | |
| 105 | 98 | # Eager initialization (startup-time failure visibility, no lazy init in request path) |
| 106 | 99 | if self.config.query_config.enable_text_embedding and self._text_encoder is None: |
| ... | ... | @@ -176,47 +169,11 @@ class QueryParser: |
| 176 | 169 | @staticmethod |
| 177 | 170 | def _extract_tokens(tokenizer_result: Any) -> List[str]: |
| 178 | 171 | """Normalize tokenizer output into a flat token string list.""" |
| 179 | - if not tokenizer_result: | |
| 180 | - return [] | |
| 181 | - if isinstance(tokenizer_result, str): | |
| 182 | - token = tokenizer_result.strip() | |
| 183 | - return [token] if token else [] | |
| 184 | - | |
| 185 | - tokens: List[str] = [] | |
| 186 | - for item in tokenizer_result: | |
| 187 | - token: Optional[str] = None | |
| 188 | - if isinstance(item, str): | |
| 189 | - token = item | |
| 190 | - elif isinstance(item, (list, tuple)) and item: | |
| 191 | - token = str(item[0]) | |
| 192 | - elif item is not None: | |
| 193 | - token = str(item) | |
| 194 | - | |
| 195 | - if token is None: | |
| 196 | - continue | |
| 197 | - token = token.strip() | |
| 198 | - if token: | |
| 199 | - tokens.append(token) | |
| 200 | - return tokens | |
| 172 | + return extract_token_strings(tokenizer_result) | |
| 201 | 173 | |
| 202 | 174 | def _get_query_tokens(self, query: str) -> List[str]: |
| 203 | 175 | return self._extract_tokens(self._tokenizer(query)) |
| 204 | 176 | |
| 205 | - @staticmethod | |
| 206 | - def _contains_cjk(text: str) -> bool: | |
| 207 | - """Whether query contains any CJK ideograph.""" | |
| 208 | - return bool(re.search(r"[\u4e00-\u9fff]", text or "")) | |
| 209 | - | |
| 210 | - @staticmethod | |
| 211 | - def _is_pure_english_word_token(token: str) -> bool: | |
| 212 | - """ | |
| 213 | - A tokenizer token counts as English iff it is letters only (optional internal hyphens) | |
| 214 | - and length >= 3. | |
| 215 | - """ | |
| 216 | - if not token or len(token) < 3: | |
| 217 | - return False | |
| 218 | - return bool(re.fullmatch(r"[A-Za-z]+(?:-[A-Za-z]+)*", token)) | |
| 219 | - | |
| 220 | 177 | def parse( |
| 221 | 178 | self, |
| 222 | 179 | query: str, |
| ... | ... | @@ -285,19 +242,12 @@ class QueryParser: |
| 285 | 242 | log_info(f"Language detection | Detected language: {detected_lang}") |
| 286 | 243 | if context: |
| 287 | 244 | context.store_intermediate_result('detected_language', detected_lang) |
| 288 | - # Stage 4: Query analysis (tokenization + script flags) | |
| 245 | + # Stage 4: Query analysis (tokenization) | |
| 289 | 246 | query_tokens = self._get_query_tokens(query_text) |
| 290 | - contains_chinese = self._contains_cjk(query_text) | |
| 291 | - contains_english = any(self._is_pure_english_word_token(t) for t in query_tokens) | |
| 292 | 247 | |
| 293 | - log_debug( | |
| 294 | - f"Query analysis | Query tokens: {query_tokens} | " | |
| 295 | - f"contains_chinese={contains_chinese} | contains_english={contains_english}" | |
| 296 | - ) | |
| 248 | + log_debug(f"Query analysis | Query tokens: {query_tokens}") | |
| 297 | 249 | if context: |
| 298 | 250 | context.store_intermediate_result('query_tokens', query_tokens) |
| 299 | - context.store_intermediate_result('contains_chinese', contains_chinese) | |
| 300 | - context.store_intermediate_result('contains_english', contains_english) | |
| 301 | 251 | |
| 302 | 252 | # Stage 5: Translation + embedding. Parser only coordinates async enrichment work; the |
| 303 | 253 | # caller decides translation targets and later search-field planning. |
| ... | ... | @@ -351,7 +301,12 @@ class QueryParser: |
| 351 | 301 | log_debug("Submitting query vector generation") |
| 352 | 302 | |
| 353 | 303 | def _encode_query_vector() -> Optional[np.ndarray]: |
| 354 | - arr = self.text_encoder.encode([query_text], priority=1) | |
| 304 | + arr = self.text_encoder.encode( | |
| 305 | + [query_text], | |
| 306 | + priority=1, | |
| 307 | + request_id=(context.reqid if context else None), | |
| 308 | + user_id=(context.uid if context else None), | |
| 309 | + ) | |
| 355 | 310 | if arr is None or len(arr) == 0: |
| 356 | 311 | return None |
| 357 | 312 | vec = arr[0] |
| ... | ... | @@ -451,6 +406,22 @@ class QueryParser: |
| 451 | 406 | context.store_intermediate_result("translations", translations) |
| 452 | 407 | |
| 453 | 408 | # Build result |
| 409 | + base_result = ParsedQuery( | |
| 410 | + original_query=query, | |
| 411 | + query_normalized=normalized, | |
| 412 | + rewritten_query=query_text, | |
| 413 | + detected_language=detected_lang, | |
| 414 | + translations=translations, | |
| 415 | + query_vector=query_vector, | |
| 416 | + query_tokens=query_tokens, | |
| 417 | + ) | |
| 418 | + style_intent_profile = self.style_intent_detector.detect(base_result) | |
| 419 | + if context: | |
| 420 | + context.store_intermediate_result( | |
| 421 | + "style_intent_profile", | |
| 422 | + style_intent_profile.to_dict(), | |
| 423 | + ) | |
| 424 | + | |
| 454 | 425 | result = ParsedQuery( |
| 455 | 426 | original_query=query, |
| 456 | 427 | query_normalized=normalized, |
| ... | ... | @@ -459,8 +430,7 @@ class QueryParser: |
| 459 | 430 | translations=translations, |
| 460 | 431 | query_vector=query_vector, |
| 461 | 432 | query_tokens=query_tokens, |
| 462 | - contains_chinese=contains_chinese, | |
| 463 | - contains_english=contains_english, | |
| 433 | + style_intent_profile=style_intent_profile, | |
| 464 | 434 | ) |
| 465 | 435 | |
| 466 | 436 | if context and hasattr(context, 'logger'): | ... | ... |
| ... | ... | @@ -0,0 +1,261 @@ |
| 1 | +""" | |
| 2 | +Style intent detection for query understanding. | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from __future__ import annotations | |
| 6 | + | |
| 7 | +from dataclasses import dataclass, field | |
| 8 | +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple | |
| 9 | + | |
| 10 | +from .tokenization import TokenizedText, normalize_query_text, tokenize_text | |
| 11 | + | |
| 12 | + | |
@dataclass(frozen=True)
class StyleIntentDefinition:
    """One style-intent vocabulary: synonym groups plus dimension aliases."""

    intent_type: str
    # One tuple per synonym row; the first entry is the canonical value.
    term_groups: Tuple[Tuple[str, ...], ...]
    dimension_aliases: Tuple[str, ...]
    # Normalized synonym -> canonical value for O(1) matching.
    synonym_to_canonical: Dict[str, str]
    # Longest term length in words; bounds phrase-candidate generation.
    max_term_ngram: int = 3

    @classmethod
    def from_rows(
        cls,
        intent_type: str,
        rows: Sequence[Sequence[str]],
        dimension_aliases: Sequence[str],
    ) -> "StyleIntentDefinition":
        """Build a definition from synonym rows.

        Each row is one synonym group; its first non-empty normalized term
        becomes the canonical value. Rows that normalize to nothing are
        dropped. Duplicate terms across rows keep the last row's canonical.
        """
        groups: List[Tuple[str, ...]] = []
        mapping: Dict[str, str] = {}
        longest = 1

        for row in rows:
            terms: List[str] = []
            for raw in row:
                candidate = normalize_query_text(raw)
                if candidate and candidate not in terms:
                    terms.append(candidate)
            if not terms:
                continue

            groups.append(tuple(terms))
            canonical = terms[0]
            for term in terms:
                mapping[term] = canonical
                longest = max(longest, len(term.split()))

        normalized_aliases = (normalize_query_text(a) for a in dimension_aliases)
        aliases = tuple(dict.fromkeys(a for a in normalized_aliases if a))

        return cls(
            intent_type=intent_type,
            term_groups=tuple(groups),
            dimension_aliases=aliases,
            synonym_to_canonical=mapping,
            max_term_ngram=longest,
        )

    def match_candidates(self, candidates: Iterable[str]) -> Set[str]:
        """Return the canonical values matched by any of *candidates*."""
        hits: Set[str] = set()
        for candidate in candidates:
            canonical = self.synonym_to_canonical.get(normalize_query_text(candidate))
            if canonical:
                hits.add(canonical)
        return hits

    def match_text(
        self,
        text: str,
        *,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> Set[str]:
        """Tokenize *text* and return the canonical values matched in it."""
        tokens = tokenize_text(text, tokenizer=tokenizer, max_ngram=self.max_term_ngram)
        return self.match_candidates(tokens.candidates)
| 83 | + | |
| 84 | + | |
@dataclass(frozen=True)
class DetectedStyleIntent:
    """A single (intent type, canonical value) hit found in a query variant."""

    intent_type: str
    canonical_value: str
    # Normalized synonym that actually matched.
    matched_term: str
    # The query-variant text the match came from.
    matched_query_text: str
    dimension_aliases: Tuple[str, ...]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to plain JSON-friendly types."""
        return {
            "intent_type": self.intent_type,
            "canonical_value": self.canonical_value,
            "matched_term": self.matched_term,
            "matched_query_text": self.matched_query_text,
            "dimension_aliases": list(self.dimension_aliases),
        }
| 101 | + | |
| 102 | + | |
@dataclass(frozen=True)
class StyleIntentProfile:
    """Detection result: tokenized query variants plus matched intents."""

    query_variants: Tuple[TokenizedText, ...] = field(default_factory=tuple)
    intents: Tuple[DetectedStyleIntent, ...] = field(default_factory=tuple)

    @property
    def is_active(self) -> bool:
        """True when at least one intent was detected."""
        return len(self.intents) > 0

    def get_intents(self, intent_type: Optional[str] = None) -> List[DetectedStyleIntent]:
        """All intents, or only those matching the normalized *intent_type*."""
        if intent_type is None:
            return list(self.intents)
        wanted = normalize_query_text(intent_type)
        return [item for item in self.intents if item.intent_type == wanted]

    def get_canonical_values(self, intent_type: str) -> Set[str]:
        """Distinct canonical values detected for *intent_type*."""
        return {item.canonical_value for item in self.get_intents(intent_type)}

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the profile, variants included, to JSON-friendly types."""
        serialized_variants = [
            {
                "text": variant.text,
                "normalized_text": variant.normalized_text,
                "fine_tokens": list(variant.fine_tokens),
                "coarse_tokens": list(variant.coarse_tokens),
                "candidates": list(variant.candidates),
            }
            for variant in self.query_variants
        ]
        return {
            "active": self.is_active,
            "intents": [item.to_dict() for item in self.intents],
            "query_variants": serialized_variants,
        }
| 136 | + | |
| 137 | + | |
class StyleIntentRegistry:
    """Holds style intent vocabularies and matching helpers."""

    def __init__(
        self,
        definitions: Dict[str, StyleIntentDefinition],
        *,
        enabled: bool = True,
    ) -> None:
        self.definitions = definitions
        self.enabled = bool(enabled)

    @classmethod
    def from_query_config(cls, query_config: Any) -> "StyleIntentRegistry":
        """Build a registry from query-config vocabularies.

        Reads ``style_intent_terms`` (intent -> synonym rows) and
        ``style_intent_dimension_aliases`` (intent -> alias list); intents
        whose rows normalize to nothing are dropped.
        """
        terms_by_intent = getattr(query_config, "style_intent_terms", {}) or {}
        aliases_by_intent = getattr(query_config, "style_intent_dimension_aliases", {}) or {}

        definitions: Dict[str, StyleIntentDefinition] = {}
        for raw_intent, rows in terms_by_intent.items():
            definition = StyleIntentDefinition.from_rows(
                intent_type=normalize_query_text(raw_intent),
                rows=rows or [],
                dimension_aliases=aliases_by_intent.get(raw_intent, []),
            )
            if definition.synonym_to_canonical:
                definitions[definition.intent_type] = definition

        return cls(
            definitions,
            enabled=bool(getattr(query_config, "style_intent_enabled", True)),
        )

    def get_definition(self, intent_type: str) -> Optional[StyleIntentDefinition]:
        """Look up a definition by normalized intent type."""
        return self.definitions.get(normalize_query_text(intent_type))

    def get_dimension_aliases(self, intent_type: str) -> Tuple[str, ...]:
        """Aliases for *intent_type*, or an empty tuple when unknown."""
        definition = self.get_definition(intent_type)
        if definition is None:
            return tuple()
        return definition.dimension_aliases
| 176 | + | |
| 177 | + | |
class StyleIntentDetector:
    """Detects style intents from parsed query variants."""

    def __init__(
        self,
        registry: StyleIntentRegistry,
        *,
        tokenizer: Optional[Callable[[str], Any]] = None,
    ) -> None:
        self.registry = registry
        self.tokenizer = tokenizer

    def _build_query_variants(self, parsed_query: Any) -> Tuple[TokenizedText, ...]:
        """Tokenize every distinct textual form of the query: original,
        normalized, rewritten, plus all translation values.

        Variants are deduplicated by their normalized text.
        """
        raw_texts: List[Any] = [
            getattr(parsed_query, "original_query", None),
            getattr(parsed_query, "query_normalized", None),
            getattr(parsed_query, "rewritten_query", None),
        ]
        translations = getattr(parsed_query, "translations", {}) or {}
        if isinstance(translations, dict):
            raw_texts.extend(translations.values())

        # Longest phrase any vocabulary needs; invariant across variants.
        ngram_limit = max(
            (d.max_term_ngram for d in self.registry.definitions.values()),
            default=3,
        )

        variants: List[TokenizedText] = []
        seen_normalized = set()
        for raw in raw_texts:
            text = str(raw or "").strip()
            if not text:
                continue
            normalized = normalize_query_text(text)
            if not normalized or normalized in seen_normalized:
                continue
            seen_normalized.add(normalized)
            variants.append(
                tokenize_text(text, tokenizer=self.tokenizer, max_ngram=ngram_limit)
            )
        return tuple(variants)

    def detect(self, parsed_query: Any) -> StyleIntentProfile:
        """Match every query variant against every registered vocabulary.

        Per definition per variant, at most one new (intent_type, canonical)
        pair is recorded: the first matching candidate wins and the scan
        moves to the next definition.
        # NOTE(review): the early break drops further canonicals matched in
        # the same variant even when match_candidates found several — confirm
        # this single-hit-per-variant behavior is intended.
        """
        if not (self.registry.enabled and self.registry.definitions):
            return StyleIntentProfile()

        variants = self._build_query_variants(parsed_query)
        recorded = set()
        hits: List[DetectedStyleIntent] = []

        for variant in variants:
            for intent_type, definition in self.registry.definitions.items():
                canonicals = definition.match_candidates(variant.candidates)
                if not canonicals:
                    continue

                for raw_candidate in variant.candidates:
                    term = normalize_query_text(raw_candidate)
                    canonical = definition.synonym_to_canonical.get(term)
                    if not canonical or canonical not in canonicals:
                        continue
                    pair = (intent_type, canonical)
                    if pair in recorded:
                        continue
                    recorded.add(pair)
                    hits.append(
                        DetectedStyleIntent(
                            intent_type=intent_type,
                            canonical_value=canonical,
                            matched_term=term,
                            matched_query_text=variant.text,
                            dimension_aliases=definition.dimension_aliases,
                        )
                    )
                    break

        return StyleIntentProfile(query_variants=variants, intents=tuple(hits))
| ... | ... | @@ -0,0 +1,122 @@ |
| 1 | +""" | |
| 2 | +Shared tokenization helpers for query understanding. | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from __future__ import annotations | |
| 6 | + | |
| 7 | +from dataclasses import dataclass | |
| 8 | +import re | |
| 9 | +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple | |
| 10 | + | |
| 11 | + | |
| 12 | +_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*") | |
| 13 | + | |
| 14 | + | |
def normalize_query_text(text: Optional[str]) -> str:
    """Casefold *text* and collapse every whitespace run to a single space.

    ``None`` is treated as an empty string; non-string inputs are
    stringified first.
    """
    if text is None:
        return ""
    folded = str(text).strip().casefold()
    return " ".join(folded.split())
| 19 | + | |
| 20 | + | |
def simple_tokenize_query(text: str) -> List[str]:
    """
    Lightweight tokenizer for coarse query matching.

    - Consecutive CJK characters form one token
    - Latin / digit runs (with internal hyphens) form tokens
    """
    if not text:
        return []
    # CJK run, or an alphanumeric/underscore run with optional internal hyphens.
    return re.findall(
        r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*",
        text,
    )
| 31 | + | |
| 32 | + | |
def extract_token_strings(tokenizer_result: Any) -> List[str]:
    """Normalize tokenizer output into a flat token string list.

    Accepts a bare string, or an iterable whose items are strings,
    ``(token, tag)`` pairs, or anything stringifiable. Empty strings are
    dropped after stripping.
    """
    if not tokenizer_result:
        return []
    if isinstance(tokenizer_result, str):
        single = tokenizer_result.strip()
        return [single] if single else []

    def _coerce(item: Any) -> Optional[str]:
        # Mirrors the accepted item shapes; an *empty* list/tuple falls
        # through to plain stringification, matching the original chain.
        if isinstance(item, str):
            return item
        if isinstance(item, (list, tuple)) and item:
            return str(item[0])
        if item is not None:
            return str(item)
        return None

    stripped = (c.strip() for c in map(_coerce, tokenizer_result) if c is not None)
    return [tok for tok in stripped if tok]
| 57 | + | |
| 58 | + | |
| 59 | +def _dedupe_preserve_order(values: Iterable[str]) -> List[str]: | |
| 60 | + result: List[str] = [] | |
| 61 | + seen = set() | |
| 62 | + for value in values: | |
| 63 | + normalized = normalize_query_text(value) | |
| 64 | + if not normalized or normalized in seen: | |
| 65 | + continue | |
| 66 | + seen.add(normalized) | |
| 67 | + result.append(normalized) | |
| 68 | + return result | |
| 69 | + | |
| 70 | + | |
| 71 | +def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str]: | |
| 72 | + if not tokens: | |
| 73 | + return [] | |
| 74 | + | |
| 75 | + phrases: List[str] = [] | |
| 76 | + upper = max(1, int(max_ngram)) | |
| 77 | + for size in range(1, upper + 1): | |
| 78 | + if size > len(tokens): | |
| 79 | + break | |
| 80 | + for start in range(0, len(tokens) - size + 1): | |
| 81 | + phrase = " ".join(tokens[start:start + size]).strip() | |
| 82 | + if phrase: | |
| 83 | + phrases.append(phrase) | |
| 84 | + return phrases | |
| 85 | + | |
| 86 | + | |
@dataclass(frozen=True)
class TokenizedText:
    """Immutable bundle of one query text and its derived token views."""

    # The original input text, untouched.
    text: str
    # Casefolded, whitespace-collapsed form of ``text``.
    normalized_text: str
    # Tokens produced by the injected (fine-grained) tokenizer, if any.
    fine_tokens: Tuple[str, ...]
    # Tokens from the lightweight regex tokenizer.
    coarse_tokens: Tuple[str, ...]
    # Deduplicated match candidates: tokens, n-gram phrases, full text.
    candidates: Tuple[str, ...]
| 94 | + | |
| 95 | + | |
def tokenize_text(
    text: str,
    *,
    tokenizer: Optional[Callable[[str], Any]] = None,
    max_ngram: int = 3,
) -> TokenizedText:
    """Tokenize *text* with both tokenizers and assemble match candidates.

    Candidates are, in order: fine tokens, coarse tokens, fine n-gram
    phrases, coarse n-gram phrases, then the whole normalized text —
    deduplicated while preserving first occurrence.
    """
    normalized = normalize_query_text(text)
    coarse = _dedupe_preserve_order(simple_tokenize_query(text))

    if tokenizer is not None and text:
        fine = _dedupe_preserve_order(extract_token_strings(tokenizer(text)))
    else:
        fine = []

    pool: List[str] = []
    pool.extend(fine)
    pool.extend(coarse)
    pool.extend(_build_phrase_candidates(fine, max_ngram=max_ngram))
    pool.extend(_build_phrase_candidates(coarse, max_ngram=max_ngram))
    if normalized:
        pool.append(normalized)

    return TokenizedText(
        text=text,
        normalized_text=normalized,
        fine_tokens=tuple(fine),
        coarse_tokens=tuple(coarse),
        candidates=tuple(_dedupe_preserve_order(pool)),
    )
| ... | ... | @@ -0,0 +1,107 @@ |
| 1 | +""" | |
| 2 | +Request-scoped reqid/uid for logging and downstream HTTP headers. | |
| 3 | + | |
| 4 | +Kept as a **top-level module** (not under ``utils/``) because ``utils/__init__.py`` | |
| 5 | +pulls optional deps (e.g. sqlalchemy) that are not installed in ``.venv-embedding``. | |
| 6 | +Uvicorn ``--log-config`` and the embedding service must be able to import this module | |
| 7 | +without importing the full ``utils`` package. | |
| 8 | +""" | |
| 9 | + | |
| 10 | +from __future__ import annotations | |
| 11 | + | |
| 12 | +import logging | |
| 13 | +from contextvars import ContextVar, Token | |
| 14 | +from typing import Dict, Optional, Tuple | |
| 15 | + | |
| 16 | +_DEFAULT_REQUEST_ID = "-1" | |
| 17 | +_DEFAULT_USER_ID = "-1" | |
| 18 | + | |
| 19 | +_request_id_var: ContextVar[str] = ContextVar("request_log_reqid", default=_DEFAULT_REQUEST_ID) | |
| 20 | +_user_id_var: ContextVar[str] = ContextVar("request_log_uid", default=_DEFAULT_USER_ID) | |
| 21 | + | |
| 22 | +LOG_LINE_FORMAT = ( | |
| 23 | + "%(asctime)s | reqid:%(reqid)s | uid:%(uid)s | %(levelname)-8s | %(name)s | %(message)s" | |
| 24 | +) | |
| 25 | + | |
| 26 | + | |
| 27 | +def _normalize_value(value: Optional[str], *, fallback: str) -> str: | |
| 28 | + text = str(value or "").strip() | |
| 29 | + return text[:64] if text else fallback | |
| 30 | + | |
| 31 | + | |
| 32 | +def bind_request_log_context( | |
| 33 | + request_id: Optional[str] = None, | |
| 34 | + user_id: Optional[str] = None, | |
| 35 | +) -> Tuple[str, str, Tuple[Token[str], Token[str]]]: | |
| 36 | + """Bind reqid/uid to contextvars for the current execution context.""" | |
| 37 | + normalized_reqid = _normalize_value(request_id, fallback=_DEFAULT_REQUEST_ID) | |
| 38 | + normalized_uid = _normalize_value(user_id, fallback=_DEFAULT_USER_ID) | |
| 39 | + req_token = _request_id_var.set(normalized_reqid) | |
| 40 | + uid_token = _user_id_var.set(normalized_uid) | |
| 41 | + return normalized_reqid, normalized_uid, (req_token, uid_token) | |
| 42 | + | |
| 43 | + | |
| 44 | +def reset_request_log_context(tokens: Tuple[Token[str], Token[str]]) -> None: | |
| 45 | + """Reset reqid/uid contextvars back to their previous values.""" | |
| 46 | + req_token, uid_token = tokens | |
| 47 | + _request_id_var.reset(req_token) | |
| 48 | + _user_id_var.reset(uid_token) | |
| 49 | + | |
| 50 | + | |
| 51 | +def current_request_log_context() -> Tuple[str, str]: | |
| 52 | + """Return the currently bound reqid/uid pair.""" | |
| 53 | + return _request_id_var.get(), _user_id_var.get() | |
| 54 | + | |
| 55 | + | |
| 56 | +def build_request_log_extra( | |
| 57 | + request_id: Optional[str] = None, | |
| 58 | + user_id: Optional[str] = None, | |
| 59 | +) -> Dict[str, str]: | |
| 60 | + """Build logging extras, defaulting to the current bound context.""" | |
| 61 | + current_reqid, current_uid = current_request_log_context() | |
| 62 | + return { | |
| 63 | + "reqid": _normalize_value(request_id, fallback=current_reqid), | |
| 64 | + "uid": _normalize_value(user_id, fallback=current_uid), | |
| 65 | + } | |
| 66 | + | |
| 67 | + | |
| 68 | +def build_downstream_request_headers( | |
| 69 | + request_id: Optional[str] = None, | |
| 70 | + user_id: Optional[str] = None, | |
| 71 | +) -> Dict[str, str]: | |
| 72 | + """Build reqid/uid headers for downstream calls; returns an empty dict when neither is bound.""" | 
| 73 | + extra = build_request_log_extra(request_id=request_id, user_id=user_id) | |
| 74 | + if extra["reqid"] == _DEFAULT_REQUEST_ID and extra["uid"] == _DEFAULT_USER_ID: | |
| 75 | + return {} | |
| 76 | + headers = {"X-Request-ID": extra["reqid"]} | |
| 77 | + if extra["uid"]: | |
| 78 | + headers["X-User-ID"] = extra["uid"] | |
| 79 | + return headers | |
| 80 | + | |
| 81 | + | |
| 82 | +class RequestLogContextFilter(logging.Filter): | |
| 83 | + """Inject reqid/uid defaults into records that lack them; always returns True (never drops a record).""" | 
| 84 | + | |
| 85 | + def filter(self, record: logging.LogRecord) -> bool: | |
| 86 | + reqid = getattr(record, "reqid", None) | |
| 87 | + uid = getattr(record, "uid", None) | |
| 88 | + | |
| 89 | + if reqid is None or uid is None: | |
| 90 | + bound_reqid, bound_uid = current_request_log_context() | |
| 91 | + reqid = reqid if reqid is not None else bound_reqid | |
| 92 | + uid = uid if uid is not None else bound_uid | |
| 93 | + | |
| 94 | + if reqid == _DEFAULT_REQUEST_ID and uid == _DEFAULT_USER_ID: | |
| 95 | + try: | |
| 96 | + from context.request_context import get_current_request_context | |
| 97 | + | |
| 98 | + context = get_current_request_context() | |
| 99 | + except Exception: | |
| 100 | + context = None | |
| 101 | + if context is not None: | |
| 102 | + reqid = getattr(context, "reqid", None) or reqid | |
| 103 | + uid = getattr(context, "uid", None) or uid | |
| 104 | + | |
| 105 | + record.reqid = _normalize_value(reqid, fallback=_DEFAULT_REQUEST_ID) | |
| 106 | + record.uid = _normalize_value(uid, fallback=_DEFAULT_USER_ID) | |
| 107 | + return True | ... | ... |
search/es_query_builder.py
| ... | ... | @@ -8,14 +8,11 @@ Simplified architecture: |
| 8 | 8 | - function_score wrapper for boosting fields |
| 9 | 9 | """ |
| 10 | 10 | |
| 11 | -from typing import Dict, Any, List, Optional, Union, Tuple | |
| 11 | +from typing import Dict, Any, List, Optional, Tuple | |
| 12 | 12 | |
| 13 | 13 | import numpy as np |
| 14 | 14 | from config import FunctionScoreConfig |
| 15 | 15 | |
| 16 | -# (Elasticsearch field path, boost before formatting as "path^boost") | |
| 17 | -MatchFieldSpec = Tuple[str, float] | |
| 18 | - | |
| 19 | 16 | |
| 20 | 17 | class ESQueryBuilder: |
| 21 | 18 | """Builds Elasticsearch DSL queries.""" |
| ... | ... | @@ -39,7 +36,6 @@ class ESQueryBuilder: |
| 39 | 36 | tie_breaker_base_query: float = 0.9, |
| 40 | 37 | best_fields_boosts: Optional[Dict[str, float]] = None, |
| 41 | 38 | best_fields_clause_boost: float = 2.0, |
| 42 | - mixed_script_merged_field_boost_scale: float = 0.6, | |
| 43 | 39 | phrase_field_boosts: Optional[Dict[str, float]] = None, |
| 44 | 40 | phrase_match_base_fields: Optional[Tuple[str, ...]] = None, |
| 45 | 41 | phrase_match_slop: int = 0, |
| ... | ... | @@ -60,7 +56,6 @@ class ESQueryBuilder: |
| 60 | 56 | function_score_config: Function score configuration |
| 61 | 57 | default_language: Default language to use when detection fails or returns "unknown" |
| 62 | 58 | knn_boost: Boost value for KNN (embedding recall) |
| 63 | - mixed_script_merged_field_boost_scale: Multiply per-field ^boost for cross-script merged fields | |
| 64 | 59 | """ |
| 65 | 60 | self.match_fields = match_fields |
| 66 | 61 | self.field_boosts = field_boosts or {} |
| ... | ... | @@ -77,7 +72,6 @@ class ESQueryBuilder: |
| 77 | 72 | self.translation_minimum_should_match = translation_minimum_should_match |
| 78 | 73 | self.translation_boost = float(translation_boost) |
| 79 | 74 | self.tie_breaker_base_query = float(tie_breaker_base_query) |
| 80 | - self.mixed_script_merged_field_boost_scale = float(mixed_script_merged_field_boost_scale) | |
| 81 | 75 | default_best_fields = { |
| 82 | 76 | base: self._get_field_boost(base) |
| 83 | 77 | for base in self.core_multilingual_fields |
| ... | ... | @@ -180,7 +174,6 @@ class ESQueryBuilder: |
| 180 | 174 | knn_num_candidates: int = 200, |
| 181 | 175 | min_score: Optional[float] = None, |
| 182 | 176 | parsed_query: Optional[Any] = None, |
| 183 | - index_languages: Optional[List[str]] = None, | |
| 184 | 177 | ) -> Dict[str, Any]: |
| 185 | 178 | """ |
| 186 | 179 | Build complete ES query with post_filter support for multi-select faceting. |
| ... | ... | @@ -223,11 +216,7 @@ class ESQueryBuilder: |
| 223 | 216 | # Text recall (always include if query_text exists) |
| 224 | 217 | if query_text: |
| 225 | 218 | # Unified text query strategy |
| 226 | - text_query = self._build_advanced_text_query( | |
| 227 | - query_text, | |
| 228 | - parsed_query, | |
| 229 | - index_languages=index_languages, | |
| 230 | - ) | |
| 219 | + text_query = self._build_advanced_text_query(query_text, parsed_query) | |
| 231 | 220 | recall_clauses.append(text_query) |
| 232 | 221 | |
| 233 | 222 | # Embedding recall (KNN - separate from query, handled below) |
| ... | ... | @@ -434,90 +423,36 @@ class ESQueryBuilder: |
| 434 | 423 | return float(self.field_boosts[base_field]) |
| 435 | 424 | return 1.0 |
| 436 | 425 | |
| 437 | - def _build_match_field_specs( | |
| 426 | + def _match_field_strings( | |
| 438 | 427 | self, |
| 439 | 428 | language: str, |
| 440 | 429 | *, |
| 441 | 430 | multilingual_fields: Optional[List[str]] = None, |
| 442 | 431 | shared_fields: Optional[List[str]] = None, |
| 443 | 432 | boost_overrides: Optional[Dict[str, float]] = None, |
| 444 | - ) -> List[MatchFieldSpec]: | |
| 445 | - """ | |
| 446 | - Per-language match targets as (field_path, boost). Single source of truth before | |
| 447 | - formatting as Elasticsearch ``fields`` strings. | |
| 448 | - """ | |
| 433 | + ) -> List[str]: | |
| 434 | + """Build ``multi_match`` / ``combined_fields`` field entries for one language code.""" | |
| 449 | 435 | lang = (language or "").strip().lower() |
| 450 | - specs: List[MatchFieldSpec] = [] | |
| 451 | - text_fields = multilingual_fields if multilingual_fields is not None else self.multilingual_fields | |
| 436 | + text_bases = multilingual_fields if multilingual_fields is not None else self.multilingual_fields | |
| 452 | 437 | term_fields = shared_fields if shared_fields is not None else self.shared_fields |
| 453 | 438 | overrides = boost_overrides or {} |
| 454 | - | |
| 455 | - for base in text_fields: | |
| 456 | - field = f"{base}.{lang}" | |
| 439 | + out: List[str] = [] | |
| 440 | + for base in text_bases: | |
| 441 | + path = f"{base}.{lang}" | |
| 457 | 442 | boost = float(overrides.get(base, self._get_field_boost(base, lang))) |
| 458 | - specs.append((field, boost)) | |
| 459 | - | |
| 443 | + out.append(self._format_field_with_boost(path, boost)) | |
| 460 | 444 | for shared in term_fields: |
| 461 | 445 | boost = float(overrides.get(shared, self._get_field_boost(shared, None))) |
| 462 | - specs.append((shared, boost)) | |
| 463 | - return specs | |
| 464 | - | |
| 465 | - def _format_match_field_specs(self, specs: List[MatchFieldSpec]) -> List[str]: | |
| 466 | - """Format (field_path, boost) pairs for Elasticsearch multi_match ``fields``.""" | |
| 467 | - return [self._format_field_with_boost(path, boost) for path, boost in specs] | |
| 468 | - | |
| 469 | - def _merge_supplemental_lang_field_specs( | |
| 470 | - self, | |
| 471 | - specs: List[MatchFieldSpec], | |
| 472 | - supplemental_lang: str, | |
| 473 | - ) -> List[MatchFieldSpec]: | |
| 474 | - """Append supplemental-language columns; boosts multiplied by mixed_script scale.""" | |
| 475 | - scale = float(self.mixed_script_merged_field_boost_scale) | |
| 476 | - extra_all = self._build_match_field_specs(supplemental_lang) | |
| 477 | - seen = {path for path, _ in specs} | |
| 478 | - out = list(specs) | |
| 479 | - for path, boost in extra_all: | |
| 480 | - if path not in seen: | |
| 481 | - out.append((path, boost * scale)) | |
| 482 | - seen.add(path) | |
| 483 | - return out | |
| 484 | - | |
| 485 | - def _expand_match_field_specs_for_mixed_script( | |
| 486 | - self, | |
| 487 | - lang: str, | |
| 488 | - specs: List[MatchFieldSpec], | |
| 489 | - contains_chinese: bool, | |
| 490 | - contains_english: bool, | |
| 491 | - index_languages: List[str], | |
| 492 | - is_source: bool = False | |
| 493 | - ) -> List[MatchFieldSpec]: | |
| 494 | - """ | |
| 495 | - When the query mixes scripts, widen each clause to indexed fields for the other script | |
| 496 | - (e.g. zh clause also searches title.en when the query contains an English word token). | |
| 497 | - """ | |
| 498 | - norm = {str(x or "").strip().lower() for x in (index_languages or []) if str(x or "").strip()} | |
| 499 | - allow = norm or {"zh", "en"} | |
| 500 | - | |
| 501 | - def can_use(lcode: str) -> bool: | |
| 502 | - return lcode in allow if norm else True | |
| 503 | - | |
| 504 | - out = list(specs) | |
| 505 | - lnorm = (lang or "").strip().lower() | |
| 506 | - if is_source: | |
| 507 | - if contains_english and lnorm != "en" and can_use("en"): | |
| 508 | - out = self._merge_supplemental_lang_field_specs(out, "en") | |
| 509 | - if contains_chinese and lnorm != "zh" and can_use("zh"): | |
| 510 | - out = self._merge_supplemental_lang_field_specs(out, "zh") | |
| 446 | + out.append(self._format_field_with_boost(shared, boost)) | |
| 511 | 447 | return out |
| 512 | 448 | |
| 513 | 449 | def _build_best_fields_clause(self, language: str, query_text: str) -> Optional[Dict[str, Any]]: |
| 514 | - specs = self._build_match_field_specs( | |
| 450 | + fields = self._match_field_strings( | |
| 515 | 451 | language, |
| 516 | 452 | multilingual_fields=list(self.best_fields_boosts), |
| 517 | 453 | shared_fields=[], |
| 518 | 454 | boost_overrides=self.best_fields_boosts, |
| 519 | 455 | ) |
| 520 | - fields = self._format_match_field_specs(specs) | |
| 521 | 456 | if not fields: |
| 522 | 457 | return None |
| 523 | 458 | return { |
| ... | ... | @@ -530,13 +465,12 @@ class ESQueryBuilder: |
| 530 | 465 | } |
| 531 | 466 | |
| 532 | 467 | def _build_phrase_clause(self, language: str, query_text: str) -> Optional[Dict[str, Any]]: |
| 533 | - specs = self._build_match_field_specs( | |
| 468 | + fields = self._match_field_strings( | |
| 534 | 469 | language, |
| 535 | 470 | multilingual_fields=list(self.phrase_field_boosts), |
| 536 | 471 | shared_fields=[], |
| 537 | 472 | boost_overrides=self.phrase_field_boosts, |
| 538 | 473 | ) |
| 539 | - fields = self._format_match_field_specs(specs) | |
| 540 | 474 | if not fields: |
| 541 | 475 | return None |
| 542 | 476 | clause: Dict[str, Any] = { |
| ... | ... | @@ -560,20 +494,8 @@ class ESQueryBuilder: |
| 560 | 494 | clause_name: str, |
| 561 | 495 | *, |
| 562 | 496 | is_source: bool, |
| 563 | - contains_chinese: bool, | |
| 564 | - contains_english: bool, | |
| 565 | - index_languages: List[str], | |
| 566 | 497 | ) -> Optional[Dict[str, Any]]: |
| 567 | - all_specs = self._build_match_field_specs(lang) | |
| 568 | - expanded_specs = self._expand_match_field_specs_for_mixed_script( | |
| 569 | - lang, | |
| 570 | - all_specs, | |
| 571 | - contains_chinese, | |
| 572 | - contains_english, | |
| 573 | - index_languages, | |
| 574 | - is_source, | |
| 575 | - ) | |
| 576 | - combined_fields = self._format_match_field_specs(expanded_specs) | |
| 498 | + combined_fields = self._match_field_strings(lang) | |
| 577 | 499 | if not combined_fields: |
| 578 | 500 | return None |
| 579 | 501 | minimum_should_match = ( |
| ... | ... | @@ -607,29 +529,10 @@ class ESQueryBuilder: |
| 607 | 529 | clause["bool"]["boost"] = float(self.translation_boost) |
| 608 | 530 | return clause |
| 609 | 531 | |
| 610 | - def _get_embedding_field(self, language: str) -> str: | |
| 611 | - """Get embedding field name for a language.""" | |
| 612 | - # Currently using unified embedding field | |
| 613 | - return self.text_embedding_field or "title_embedding" | |
| 614 | - | |
| 615 | - @staticmethod | |
| 616 | - def _normalize_language_list(languages: Optional[List[str]]) -> List[str]: | |
| 617 | - normalized: List[str] = [] | |
| 618 | - seen = set() | |
| 619 | - for language in languages or []: | |
| 620 | - token = str(language or "").strip().lower() | |
| 621 | - if not token or token in seen: | |
| 622 | - continue | |
| 623 | - seen.add(token) | |
| 624 | - normalized.append(token) | |
| 625 | - return normalized | |
| 626 | - | |
| 627 | 532 | def _build_advanced_text_query( |
| 628 | 533 | self, |
| 629 | 534 | query_text: str, |
| 630 | 535 | parsed_query: Optional[Any] = None, |
| 631 | - *, | |
| 632 | - index_languages: Optional[List[str]] = None, | |
| 633 | 536 | ) -> Dict[str, Any]: |
| 634 | 537 | """ |
| 635 | 538 | Build advanced text query using base and translated lexical clauses. |
| ... | ... | @@ -649,39 +552,26 @@ class ESQueryBuilder: |
| 649 | 552 | should_clauses = [] |
| 650 | 553 | source_lang = self.default_language |
| 651 | 554 | translations: Dict[str, str] = {} |
| 652 | - contains_chinese = False | |
| 653 | - contains_english = False | |
| 654 | - normalized_index_languages = self._normalize_language_list(index_languages) | |
| 655 | 555 | |
| 656 | 556 | if parsed_query: |
| 657 | 557 | detected_lang = getattr(parsed_query, "detected_language", None) |
| 658 | 558 | source_lang = detected_lang if detected_lang and detected_lang != "unknown" else self.default_language |
| 659 | 559 | translations = getattr(parsed_query, "translations", None) or {} |
| 660 | - contains_chinese = bool(getattr(parsed_query, "contains_chinese", False)) | |
| 661 | - contains_english = bool(getattr(parsed_query, "contains_english", False)) | |
| 662 | 560 | |
| 663 | 561 | source_lang = str(source_lang or self.default_language).strip().lower() or self.default_language |
| 664 | 562 | base_query_text = ( |
| 665 | 563 | getattr(parsed_query, "rewritten_query", None) if parsed_query else None |
| 666 | 564 | ) or query_text |
| 667 | 565 | |
| 668 | - def append_clause(lang: str, lang_query: str, clause_name: str, is_source: bool) -> None: | |
| 669 | - nonlocal should_clauses | |
| 670 | - clause = self._build_lexical_language_clause( | |
| 671 | - lang, | |
| 672 | - lang_query, | |
| 673 | - clause_name, | |
| 674 | - is_source=is_source, | |
| 675 | - contains_chinese=contains_chinese, | |
| 676 | - contains_english=contains_english, | |
| 677 | - index_languages=normalized_index_languages, | |
| 678 | - ) | |
| 679 | - if not clause: | |
| 680 | - return | |
| 681 | - should_clauses.append(clause) | |
| 682 | - | |
| 683 | 566 | if base_query_text: |
| 684 | - append_clause(source_lang, base_query_text, "base_query", True) | |
| 567 | + base_clause = self._build_lexical_language_clause( | |
| 568 | + source_lang, | |
| 569 | + base_query_text, | |
| 570 | + "base_query", | |
| 571 | + is_source=True, | |
| 572 | + ) | |
| 573 | + if base_clause: | |
| 574 | + should_clauses.append(base_clause) | |
| 685 | 575 | |
| 686 | 576 | for lang, translated_text in translations.items(): |
| 687 | 577 | normalized_lang = str(lang or "").strip().lower() |
| ... | ... | @@ -690,7 +580,14 @@ class ESQueryBuilder: |
| 690 | 580 | continue |
| 691 | 581 | if normalized_lang == source_lang and normalized_text == base_query_text: |
| 692 | 582 | continue |
| 693 | - append_clause(normalized_lang, normalized_text, f"base_query_trans_{normalized_lang}", False) | |
| 583 | + trans_clause = self._build_lexical_language_clause( | |
| 584 | + normalized_lang, | |
| 585 | + normalized_text, | |
| 586 | + f"base_query_trans_{normalized_lang}", | |
| 587 | + is_source=False, | |
| 588 | + ) | |
| 589 | + if trans_clause: | |
| 590 | + should_clauses.append(trans_clause) | |
| 694 | 591 | |
| 695 | 592 | # Fallback to a simple query when language fields cannot be resolved. |
| 696 | 593 | if not should_clauses: | ... | ... |
search/rerank_client.py
| ... | ... | @@ -62,11 +62,19 @@ def build_docs_from_hits( |
| 62 | 62 | need_category_path = "{category_path}" in doc_template |
| 63 | 63 | for hit in es_hits: |
| 64 | 64 | src = hit.get("_source") or {} |
| 65 | + title_suffix = str(hit.get("_style_rerank_suffix") or "").strip() | |
| 65 | 66 | if only_title: |
| 66 | - docs.append(pick_lang_text(src.get("title"))) | |
| 67 | + title = pick_lang_text(src.get("title")) | |
| 68 | + if title_suffix: | |
| 69 | + title = f"{title} {title_suffix}".strip() | |
| 70 | + docs.append(title) | |
| 67 | 71 | else: |
| 68 | 72 | values = _SafeDict( |
| 69 | - title=pick_lang_text(src.get("title")), | |
| 73 | + title=( | |
| 74 | + f"{pick_lang_text(src.get('title'))} {title_suffix}".strip() | |
| 75 | + if title_suffix | |
| 76 | + else pick_lang_text(src.get("title")) | |
| 77 | + ), | |
| 70 | 78 | brief=pick_lang_text(src.get("brief")) if need_brief else "", |
| 71 | 79 | vendor=pick_lang_text(src.get("vendor")) if need_vendor else "", |
| 72 | 80 | description=pick_lang_text(src.get("description")) if need_description else "", | ... | ... |
search/searcher.py
| ... | ... | @@ -10,12 +10,13 @@ import time, json |
| 10 | 10 | import logging |
| 11 | 11 | import hashlib |
| 12 | 12 | from string import Formatter |
| 13 | -import numpy as np | |
| 14 | 13 | |
| 15 | 14 | from utils.es_client import ESClient |
| 16 | 15 | from query import QueryParser, ParsedQuery |
| 16 | +from query.style_intent import StyleIntentRegistry | |
| 17 | 17 | from embeddings.image_encoder import CLIPImageEncoder |
| 18 | 18 | from .es_query_builder import ESQueryBuilder |
| 19 | +from .sku_intent_selector import SkuSelectionDecision, StyleSkuSelector | |
| 19 | 20 | from config import SearchConfig |
| 20 | 21 | from config.tenant_config_loader import get_tenant_config_loader |
| 21 | 22 | from context.request_context import RequestContext, RequestContextStage |
| ... | ... | @@ -115,6 +116,12 @@ class Searcher: |
| 115 | 116 | else: |
| 116 | 117 | self.image_encoder = image_encoder |
| 117 | 118 | self.source_fields = config.query_config.source_fields |
| 119 | + self.style_intent_registry = StyleIntentRegistry.from_query_config(self.config.query_config) | |
| 120 | + self.style_sku_selector = StyleSkuSelector( | |
| 121 | + self.style_intent_registry, | |
| 122 | + text_encoder_getter=lambda: getattr(self.query_parser, "text_encoder", None), | |
| 123 | + tokenizer_getter=lambda: getattr(self.query_parser, "_tokenizer", None), | |
| 124 | + ) | |
| 118 | 125 | |
| 119 | 126 | # Query builder - simplified single-layer architecture |
| 120 | 127 | self.query_builder = ESQueryBuilder( |
| ... | ... | @@ -155,7 +162,11 @@ class Searcher: |
| 155 | 162 | return |
| 156 | 163 | es_query["_source"] = {"includes": self.source_fields} |
| 157 | 164 | |
| 158 | - def _resolve_rerank_source_filter(self, doc_template: str) -> Dict[str, Any]: | |
| 165 | + def _resolve_rerank_source_filter( | |
| 166 | + self, | |
| 167 | + doc_template: str, | |
| 168 | + parsed_query: Optional[ParsedQuery] = None, | |
| 169 | + ) -> Dict[str, Any]: | |
| 159 | 170 | """ |
| 160 | 171 | Build a lightweight _source filter for rerank prefetch. |
| 161 | 172 | |
| ... | ... | @@ -182,6 +193,16 @@ class Searcher: |
| 182 | 193 | if not includes: |
| 183 | 194 | includes.add("title") |
| 184 | 195 | |
| 196 | + if self._has_style_intent(parsed_query): | |
| 197 | + includes.update( | |
| 198 | + { | |
| 199 | + "skus", | |
| 200 | + "option1_name", | |
| 201 | + "option2_name", | |
| 202 | + "option3_name", | |
| 203 | + } | |
| 204 | + ) | |
| 205 | + | |
| 185 | 206 | return {"includes": sorted(includes)} |
| 186 | 207 | |
| 187 | 208 | def _fetch_hits_by_ids( |
| ... | ... | @@ -225,256 +246,23 @@ class Searcher: |
| 225 | 246 | return hits_by_id, int(resp.get("took", 0) or 0) |
| 226 | 247 | |
| 227 | 248 | @staticmethod |
| 228 | - def _normalize_sku_match_text(value: Optional[str]) -> str: | |
| 229 | - """Normalize free text for lightweight SKU option matching.""" | |
| 230 | - if value is None: | |
| 231 | - return "" | |
| 232 | - return " ".join(str(value).strip().casefold().split()) | |
| 233 | - | |
| 234 | - @staticmethod | |
| 235 | - def _sku_option1_embedding_key( | |
| 236 | - sku: Dict[str, Any], | |
| 237 | - spu_option1_name: Optional[Any] = None, | |
| 238 | - ) -> Optional[str]: | |
| 239 | - """ | |
| 240 | - Text sent to the embedding service for option1 must be "name:value" | |
| 241 | - (option name from SKU row or SPU-level option1_name). | |
| 242 | - """ | |
| 243 | - value_raw = sku.get("option1_value") | |
| 244 | - if value_raw is None: | |
| 245 | - return None | |
| 246 | - value = str(value_raw).strip() | |
| 247 | - if not value: | |
| 248 | - return None | |
| 249 | - name = sku.get("option1_name") | |
| 250 | - if name is None or not str(name).strip(): | |
| 251 | - name = spu_option1_name | |
| 252 | - name_str = str(name).strip() if name is not None and str(name).strip() else "" | |
| 253 | - if name_str: | |
| 254 | - value = f"{name_str}:{value}" | |
| 255 | - return value.casefold() | |
| 256 | - | |
| 257 | - def _build_sku_query_texts(self, parsed_query: ParsedQuery) -> List[str]: | |
| 258 | - """Collect original and translated query texts for SKU option matching.""" | |
| 259 | - candidates: List[str] = [] | |
| 260 | - for text in ( | |
| 261 | - getattr(parsed_query, "original_query", None), | |
| 262 | - getattr(parsed_query, "query_normalized", None), | |
| 263 | - getattr(parsed_query, "rewritten_query", None), | |
| 264 | - ): | |
| 265 | - normalized = self._normalize_sku_match_text(text) | |
| 266 | - if normalized: | |
| 267 | - candidates.append(normalized) | |
| 268 | - | |
| 269 | - translations = getattr(parsed_query, "translations", {}) or {} | |
| 270 | - if isinstance(translations, dict): | |
| 271 | - for text in translations.values(): | |
| 272 | - normalized = self._normalize_sku_match_text(text) | |
| 273 | - if normalized: | |
| 274 | - candidates.append(normalized) | |
| 275 | - | |
| 276 | - deduped: List[str] = [] | |
| 277 | - seen = set() | |
| 278 | - for text in candidates: | |
| 279 | - if text in seen: | |
| 280 | - continue | |
| 281 | - seen.add(text) | |
| 282 | - deduped.append(text) | |
| 283 | - return deduped | |
| 284 | - | |
| 285 | - def _find_query_matching_sku_index( | |
| 286 | - self, | |
| 287 | - skus: List[Dict[str, Any]], | |
| 288 | - query_texts: List[str], | |
| 289 | - spu_option1_name: Optional[Any] = None, | |
| 290 | - ) -> Optional[int]: | |
| 291 | - """Return the first SKU whose option1_value (or name:value) appears in query texts.""" | |
| 292 | - if not skus or not query_texts: | |
| 293 | - return None | |
| 294 | - | |
| 295 | - for index, sku in enumerate(skus): | |
| 296 | - option1_value = self._normalize_sku_match_text(sku.get("option1_value")) | |
| 297 | - if not option1_value: | |
| 298 | - continue | |
| 299 | - if any(option1_value in query_text for query_text in query_texts): | |
| 300 | - return index | |
| 301 | - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) | |
| 302 | - if embed_key and embed_key != option1_value: | |
| 303 | - composite_norm = self._normalize_sku_match_text(embed_key.replace(":", " ")) | |
| 304 | - if any(composite_norm in query_text for query_text in query_texts): | |
| 305 | - return index | |
| 306 | - if any(embed_key.casefold() in query_text for query_text in query_texts): | |
| 307 | - return index | |
| 308 | - return None | |
| 309 | - | |
| 310 | - def _encode_query_vector_for_sku_matching( | |
| 311 | - self, | |
| 312 | - parsed_query: ParsedQuery, | |
| 313 | - context: Optional[RequestContext] = None, | |
| 314 | - ) -> Optional[np.ndarray]: | |
| 315 | - """Best-effort fallback query embedding for final-page SKU matching.""" | |
| 316 | - query_text = ( | |
| 317 | - getattr(parsed_query, "rewritten_query", None) | |
| 318 | - or getattr(parsed_query, "query_normalized", None) | |
| 319 | - or getattr(parsed_query, "original_query", None) | |
| 320 | - ) | |
| 321 | - if not query_text: | |
| 322 | - return None | |
| 323 | - | |
| 324 | - text_encoder = getattr(self.query_parser, "text_encoder", None) | |
| 325 | - if text_encoder is None: | |
| 326 | - return None | |
| 327 | - | |
| 328 | - try: | |
| 329 | - vectors = text_encoder.encode([query_text], priority=1) | |
| 330 | - except Exception as exc: | |
| 331 | - logger.warning("Failed to encode query vector for SKU matching: %s", exc, exc_info=True) | |
| 332 | - if context is not None: | |
| 333 | - context.add_warning(f"SKU query embedding failed: {exc}") | |
| 334 | - return None | |
| 335 | - | |
| 336 | - if vectors is None or len(vectors) == 0: | |
| 337 | - return None | |
| 338 | - | |
| 339 | - vector = vectors[0] | |
| 340 | - if vector is None: | |
| 341 | - return None | |
| 342 | - return np.asarray(vector, dtype=np.float32) | |
| 343 | - | |
| 344 | - def _select_sku_by_embedding( | |
| 345 | - self, | |
| 346 | - skus: List[Dict[str, Any]], | |
| 347 | - option1_vectors: Dict[str, np.ndarray], | |
| 348 | - query_vector: np.ndarray, | |
| 349 | - spu_option1_name: Optional[Any] = None, | |
| 350 | - ) -> Tuple[Optional[int], Optional[float]]: | |
| 351 | - """Select the SKU whose option1 embedding key (name:value) is most similar to the query.""" | |
| 352 | - best_index: Optional[int] = None | |
| 353 | - best_score: Optional[float] = None | |
| 354 | - | |
| 355 | - for index, sku in enumerate(skus): | |
| 356 | - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) | |
| 357 | - if not embed_key: | |
| 358 | - continue | |
| 359 | - option_vector = option1_vectors.get(embed_key) | |
| 360 | - if option_vector is None: | |
| 361 | - continue | |
| 362 | - score = float(np.inner(query_vector, option_vector)) | |
| 363 | - if best_score is None or score > best_score: | |
| 364 | - best_index = index | |
| 365 | - best_score = score | |
| 366 | - | |
| 367 | - return best_index, best_score | |
| 368 | - | |
| 369 | - @staticmethod | |
| 370 | - def _promote_matching_sku(source: Dict[str, Any], match_index: int) -> Optional[Dict[str, Any]]: | |
| 371 | - """Move the matched SKU to the front and swap the SPU image.""" | |
| 372 | - skus = source.get("skus") | |
| 373 | - if not isinstance(skus, list) or match_index < 0 or match_index >= len(skus): | |
| 374 | - return None | |
| 375 | - | |
| 376 | - matched_sku = skus.pop(match_index) | |
| 377 | - skus.insert(0, matched_sku) | |
| 249 | + def _has_style_intent(parsed_query: Optional[ParsedQuery]) -> bool: | |
| 250 | + profile = getattr(parsed_query, "style_intent_profile", None) | |
| 251 | + return bool(getattr(profile, "is_active", False)) | |
| 378 | 252 | |
| 379 | - image_src = matched_sku.get("image_src") or matched_sku.get("imageSrc") | |
| 380 | - if image_src: | |
| 381 | - source["image_url"] = image_src | |
| 382 | - return matched_sku | |
| 383 | - | |
| 384 | - def _apply_sku_sorting_for_page_hits( | |
| 253 | + def _apply_style_intent_to_hits( | |
| 385 | 254 | self, |
| 386 | 255 | es_hits: List[Dict[str, Any]], |
| 387 | 256 | parsed_query: ParsedQuery, |
| 388 | 257 | context: Optional[RequestContext] = None, |
| 389 | - ) -> None: | |
| 390 | - """Sort each page hit's SKUs so the best-matching SKU is first.""" | |
| 391 | - if not es_hits: | |
| 392 | - return | |
| 393 | - | |
| 394 | - query_texts = self._build_sku_query_texts(parsed_query) | |
| 395 | - unmatched_hits: List[Dict[str, Any]] = [] | |
| 396 | - option1_values_to_encode: List[str] = [] | |
| 397 | - seen_option1_values = set() | |
| 398 | - text_matched = 0 | |
| 399 | - embedding_matched = 0 | |
| 400 | - | |
| 401 | - for hit in es_hits: | |
| 402 | - source = hit.get("_source") | |
| 403 | - if not isinstance(source, dict): | |
| 404 | - continue | |
| 405 | - skus = source.get("skus") | |
| 406 | - if not isinstance(skus, list) or not skus: | |
| 407 | - continue | |
| 408 | - | |
| 409 | - spu_option1_name = source.get("option1_name") | |
| 410 | - match_index = self._find_query_matching_sku_index( | |
| 411 | - skus, query_texts, spu_option1_name=spu_option1_name | |
| 412 | - ) | |
| 413 | - if match_index is not None: | |
| 414 | - self._promote_matching_sku(source, match_index) | |
| 415 | - text_matched += 1 | |
| 416 | - continue | |
| 417 | - | |
| 418 | - unmatched_hits.append(hit) | |
| 419 | - for sku in skus: | |
| 420 | - embed_key = self._sku_option1_embedding_key(sku, spu_option1_name) | |
| 421 | - if not embed_key or embed_key in seen_option1_values: | |
| 422 | - continue | |
| 423 | - seen_option1_values.add(embed_key) | |
| 424 | - option1_values_to_encode.append(embed_key) | |
| 425 | - | |
| 426 | - if not unmatched_hits or not option1_values_to_encode: | |
| 427 | - return | |
| 428 | - | |
| 429 | - query_vector = getattr(parsed_query, "query_vector", None) | |
| 430 | - if query_vector is None: | |
| 431 | - query_vector = self._encode_query_vector_for_sku_matching(parsed_query, context=context) | |
| 432 | - if query_vector is None: | |
| 433 | - return | |
| 434 | - | |
| 435 | - text_encoder = getattr(self.query_parser, "text_encoder", None) | |
| 436 | - if text_encoder is None: | |
| 437 | - return | |
| 438 | - | |
| 439 | - try: | |
| 440 | - encoded_option_vectors = text_encoder.encode(option1_values_to_encode, priority=1) | |
| 441 | - except Exception as exc: | |
| 442 | - logger.warning("Failed to encode SKU option1 values for final-page sorting: %s", exc, exc_info=True) | |
| 443 | - if context is not None: | |
| 444 | - context.add_warning(f"SKU option embedding failed: {exc}") | |
| 445 | - return | |
| 446 | - | |
| 447 | - option1_vectors: Dict[str, np.ndarray] = {} | |
| 448 | - for option1_value, vector in zip(option1_values_to_encode, encoded_option_vectors): | |
| 449 | - if vector is None: | |
| 450 | - continue | |
| 451 | - option1_vectors[option1_value] = np.asarray(vector, dtype=np.float32) | |
| 452 | - | |
| 453 | - query_vector_array = np.asarray(query_vector, dtype=np.float32) | |
| 454 | - for hit in unmatched_hits: | |
| 455 | - source = hit.get("_source") | |
| 456 | - if not isinstance(source, dict): | |
| 457 | - continue | |
| 458 | - skus = source.get("skus") | |
| 459 | - if not isinstance(skus, list) or not skus: | |
| 460 | - continue | |
| 461 | - match_index, _ = self._select_sku_by_embedding( | |
| 462 | - skus, | |
| 463 | - option1_vectors, | |
| 464 | - query_vector_array, | |
| 465 | - spu_option1_name=source.get("option1_name"), | |
| 466 | - ) | |
| 467 | - if match_index is None: | |
| 468 | - continue | |
| 469 | - self._promote_matching_sku(source, match_index) | |
| 470 | - embedding_matched += 1 | |
| 471 | - | |
| 472 | - if text_matched or embedding_matched: | |
| 473 | - logger.info( | |
| 474 | - "Final-page SKU sorting completed | text_matched=%s | embedding_matched=%s", | |
| 475 | - text_matched, | |
| 476 | - embedding_matched, | |
| 258 | + ) -> Dict[str, SkuSelectionDecision]: | |
| 259 | + decisions = self.style_sku_selector.prepare_hits(es_hits, parsed_query) | |
| 260 | + if decisions and context is not None: | |
| 261 | + context.store_intermediate_result( | |
| 262 | + "style_intent_sku_decisions", | |
| 263 | + {doc_id: decision.to_dict() for doc_id, decision in decisions.items()}, | |
| 477 | 264 | ) |
| 265 | + return decisions | |
| 478 | 266 | |
| 479 | 267 | def search( |
| 480 | 268 | self, |
| ... | ... | @@ -583,7 +371,8 @@ class Searcher: |
| 583 | 371 | context.metadata['feature_flags'] = { |
| 584 | 372 | 'translation_enabled': enable_translation, |
| 585 | 373 | 'embedding_enabled': enable_embedding, |
| 586 | - 'rerank_enabled': do_rerank | |
| 374 | + 'rerank_enabled': do_rerank, | |
| 375 | + 'style_intent_enabled': bool(self.style_intent_registry.enabled), | |
| 587 | 376 | } |
| 588 | 377 | |
| 589 | 378 | # Step 1: Parse query |
| ... | ... | @@ -607,6 +396,7 @@ class Searcher: |
| 607 | 396 | domain="default", |
| 608 | 397 | is_simple_query=True |
| 609 | 398 | ) |
| 399 | + context.metadata["feature_flags"]["style_intent_active"] = self._has_style_intent(parsed_query) | |
| 610 | 400 | |
| 611 | 401 | context.logger.info( |
| 612 | 402 | f"查询解析完成 | 原查询: '{parsed_query.original_query}' | " |
| ... | ... | @@ -645,7 +435,6 @@ class Searcher: |
| 645 | 435 | enable_knn=enable_embedding and parsed_query.query_vector is not None, |
| 646 | 436 | min_score=min_score, |
| 647 | 437 | parsed_query=parsed_query, |
| 648 | - index_languages=index_langs, | |
| 649 | 438 | ) |
| 650 | 439 | |
| 651 | 440 | # Add facets for faceted search |
| ... | ... | @@ -668,7 +457,10 @@ class Searcher: |
| 668 | 457 | es_query_for_fetch = es_query |
| 669 | 458 | rerank_prefetch_source = None |
| 670 | 459 | if in_rerank_window: |
| 671 | - rerank_prefetch_source = self._resolve_rerank_source_filter(effective_doc_template) | |
| 460 | + rerank_prefetch_source = self._resolve_rerank_source_filter( | |
| 461 | + effective_doc_template, | |
| 462 | + parsed_query=parsed_query, | |
| 463 | + ) | |
| 672 | 464 | es_query_for_fetch = dict(es_query) |
| 673 | 465 | es_query_for_fetch["_source"] = rerank_prefetch_source |
| 674 | 466 | |
| ... | ... | @@ -752,6 +544,20 @@ class Searcher: |
| 752 | 544 | finally: |
| 753 | 545 | context.end_stage(RequestContextStage.ELASTICSEARCH_SEARCH_PRIMARY) |
| 754 | 546 | |
| 547 | + style_intent_decisions: Dict[str, SkuSelectionDecision] = {} | |
| 548 | + if self._has_style_intent(parsed_query) and in_rerank_window: | |
| 549 | + style_intent_decisions = self._apply_style_intent_to_hits( | |
| 550 | + es_response.get("hits", {}).get("hits") or [], | |
| 551 | + parsed_query, | |
| 552 | + context=context, | |
| 553 | + ) | |
| 554 | + if style_intent_decisions: | |
| 555 | + context.logger.info( | |
| 556 | + "款式意图 SKU 预筛选完成 | hits=%s", | |
| 557 | + len(style_intent_decisions), | |
| 558 | + extra={'reqid': context.reqid, 'uid': context.uid} | |
| 559 | + ) | |
| 560 | + | |
| 755 | 561 | # Optional Step 4.5: AI reranking(仅当请求范围在重排窗口内时执行) |
| 756 | 562 | if do_rerank and in_rerank_window: |
| 757 | 563 | context.start_stage(RequestContextStage.RERANKING) |
| ... | ... | @@ -842,6 +648,11 @@ class Searcher: |
| 842 | 648 | if "_source" in detail_hit: |
| 843 | 649 | hit["_source"] = detail_hit.get("_source") or {} |
| 844 | 650 | filled += 1 |
| 651 | + if style_intent_decisions: | |
| 652 | + self.style_sku_selector.apply_precomputed_decisions( | |
| 653 | + sliced, | |
| 654 | + style_intent_decisions, | |
| 655 | + ) | |
| 845 | 656 | if fill_took: |
| 846 | 657 | es_response["took"] = int((es_response.get("took", 0) or 0) + fill_took) |
| 847 | 658 | context.logger.info( |
| ... | ... | @@ -884,7 +695,18 @@ class Searcher: |
| 884 | 695 | continue |
| 885 | 696 | rerank_debug_by_doc[str(doc_id)] = item |
| 886 | 697 | |
| 887 | - self._apply_sku_sorting_for_page_hits(es_hits, parsed_query, context=context) | |
| 698 | + if self._has_style_intent(parsed_query): | |
| 699 | + if in_rerank_window and style_intent_decisions: | |
| 700 | + self.style_sku_selector.apply_precomputed_decisions( | |
| 701 | + es_hits, | |
| 702 | + style_intent_decisions, | |
| 703 | + ) | |
| 704 | + elif not in_rerank_window: | |
| 705 | + style_intent_decisions = self._apply_style_intent_to_hits( | |
| 706 | + es_hits, | |
| 707 | + parsed_query, | |
| 708 | + context=context, | |
| 709 | + ) | |
| 888 | 710 | |
| 889 | 711 | # Format results using ResultFormatter |
| 890 | 712 | formatted_results = ResultFormatter.format_search_results( |
| ... | ... | @@ -903,6 +725,11 @@ class Searcher: |
| 903 | 725 | rerank_debug = None |
| 904 | 726 | if doc_id is not None: |
| 905 | 727 | rerank_debug = rerank_debug_by_doc.get(str(doc_id)) |
| 728 | + style_intent_debug = None | |
| 729 | + if doc_id is not None and style_intent_decisions: | |
| 730 | + decision = style_intent_decisions.get(str(doc_id)) | |
| 731 | + if decision is not None: | |
| 732 | + style_intent_debug = decision.to_dict() | |
| 906 | 733 | |
| 907 | 734 | raw_score = hit.get("_score") |
| 908 | 735 | try: |
| ... | ... | @@ -941,6 +768,9 @@ class Searcher: |
| 941 | 768 | debug_entry["fused_score"] = rerank_debug.get("fused_score") |
| 942 | 769 | debug_entry["matched_queries"] = rerank_debug.get("matched_queries") |
| 943 | 770 | |
| 771 | + if style_intent_debug: | |
| 772 | + debug_entry["style_intent_sku"] = style_intent_debug | |
| 773 | + | |
| 944 | 774 | per_result_debug.append(debug_entry) |
| 945 | 775 | |
| 946 | 776 | # Format facets |
| ... | ... | @@ -988,7 +818,8 @@ class Searcher: |
| 988 | 818 | "translations": context.query_analysis.translations, |
| 989 | 819 | "has_vector": context.query_analysis.query_vector is not None, |
| 990 | 820 | "is_simple_query": context.query_analysis.is_simple_query, |
| 991 | - "domain": context.query_analysis.domain | |
| 821 | + "domain": context.query_analysis.domain, | |
| 822 | + "style_intent_profile": context.get_intermediate_result("style_intent_profile"), | |
| 992 | 823 | }, |
| 993 | 824 | "es_query": context.get_intermediate_result('es_query', {}), |
| 994 | 825 | "es_response": { | ... | ... |
| ... | ... | @@ -0,0 +1,405 @@ |
| 1 | +""" | |
| 2 | +SKU selection for style-intent-aware search results. | |
| 3 | +""" | |
| 4 | + | |
| 5 | +from __future__ import annotations | |
| 6 | + | |
| 7 | +from dataclasses import dataclass, field | |
| 8 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple | |
| 9 | + | |
| 10 | +import numpy as np | |
| 11 | + | |
| 12 | +from query.style_intent import StyleIntentProfile, StyleIntentRegistry | |
| 13 | +from query.tokenization import normalize_query_text | |
| 14 | + | |
| 15 | + | |
| 16 | +@dataclass(frozen=True) | |
| 17 | +class SkuSelectionDecision: | |
| 18 | + selected_sku_id: Optional[str] | |
| 19 | + rerank_suffix: str | |
| 20 | + selected_text: str | |
| 21 | + matched_stage: str | |
| 22 | + similarity_score: Optional[float] = None | |
| 23 | + resolved_dimensions: Dict[str, Optional[str]] = field(default_factory=dict) | |
| 24 | + | |
| 25 | + def to_dict(self) -> Dict[str, Any]: | |
| 26 | + return { | |
| 27 | + "selected_sku_id": self.selected_sku_id, | |
| 28 | + "rerank_suffix": self.rerank_suffix, | |
| 29 | + "selected_text": self.selected_text, | |
| 30 | + "matched_stage": self.matched_stage, | |
| 31 | + "similarity_score": self.similarity_score, | |
| 32 | + "resolved_dimensions": dict(self.resolved_dimensions), | |
| 33 | + } | |
| 34 | + | |
| 35 | + | |
| 36 | +@dataclass | |
| 37 | +class _SkuCandidate: | |
| 38 | + index: int | |
| 39 | + sku_id: str | |
| 40 | + sku: Dict[str, Any] | |
| 41 | + selection_text: str | |
| 42 | + intent_texts: Dict[str, str] | |
| 43 | + | |
| 44 | + | |
| 45 | +class StyleSkuSelector: | |
| 46 | + """Selects the best SKU for an SPU based on detected style intent.""" | |
| 47 | + | |
| 48 | + def __init__( | |
| 49 | + self, | |
| 50 | + registry: StyleIntentRegistry, | |
| 51 | + *, | |
| 52 | + text_encoder_getter: Optional[Callable[[], Any]] = None, | |
| 53 | + tokenizer_getter: Optional[Callable[[], Any]] = None, | |
| 54 | + ) -> None: | |
| 55 | + self.registry = registry | |
| 56 | + self._text_encoder_getter = text_encoder_getter | |
| 57 | + self._tokenizer_getter = tokenizer_getter | |
| 58 | + | |
| 59 | + def prepare_hits( | |
| 60 | + self, | |
| 61 | + es_hits: List[Dict[str, Any]], | |
| 62 | + parsed_query: Any, | |
| 63 | + ) -> Dict[str, SkuSelectionDecision]: | |
| 64 | + decisions: Dict[str, SkuSelectionDecision] = {} | |
| 65 | + style_profile = getattr(parsed_query, "style_intent_profile", None) | |
| 66 | + if not isinstance(style_profile, StyleIntentProfile) or not style_profile.is_active: | |
| 67 | + return decisions | |
| 68 | + | |
| 69 | + query_texts = self._build_query_texts(parsed_query, style_profile) | |
| 70 | + query_vector = self._get_query_vector(parsed_query) | |
| 71 | + tokenizer = self._get_tokenizer() | |
| 72 | + | |
| 73 | + for hit in es_hits: | |
| 74 | + source = hit.get("_source") | |
| 75 | + if not isinstance(source, dict): | |
| 76 | + continue | |
| 77 | + | |
| 78 | + decision = self._select_for_source( | |
| 79 | + source, | |
| 80 | + style_profile=style_profile, | |
| 81 | + query_texts=query_texts, | |
| 82 | + query_vector=query_vector, | |
| 83 | + tokenizer=tokenizer, | |
| 84 | + ) | |
| 85 | + if decision is None: | |
| 86 | + continue | |
| 87 | + | |
| 88 | + self._apply_decision_to_source(source, decision) | |
| 89 | + if decision.rerank_suffix: | |
| 90 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 91 | + | |
| 92 | + doc_id = hit.get("_id") | |
| 93 | + if doc_id is not None: | |
| 94 | + decisions[str(doc_id)] = decision | |
| 95 | + | |
| 96 | + return decisions | |
| 97 | + | |
| 98 | + def apply_precomputed_decisions( | |
| 99 | + self, | |
| 100 | + es_hits: List[Dict[str, Any]], | |
| 101 | + decisions: Dict[str, SkuSelectionDecision], | |
| 102 | + ) -> None: | |
| 103 | + if not es_hits or not decisions: | |
| 104 | + return | |
| 105 | + | |
| 106 | + for hit in es_hits: | |
| 107 | + doc_id = hit.get("_id") | |
| 108 | + if doc_id is None: | |
| 109 | + continue | |
| 110 | + decision = decisions.get(str(doc_id)) | |
| 111 | + if decision is None: | |
| 112 | + continue | |
| 113 | + source = hit.get("_source") | |
| 114 | + if not isinstance(source, dict): | |
| 115 | + continue | |
| 116 | + self._apply_decision_to_source(source, decision) | |
| 117 | + if decision.rerank_suffix: | |
| 118 | + hit["_style_rerank_suffix"] = decision.rerank_suffix | |
| 119 | + | |
| 120 | + def _build_query_texts( | |
| 121 | + self, | |
| 122 | + parsed_query: Any, | |
| 123 | + style_profile: StyleIntentProfile, | |
| 124 | + ) -> List[str]: | |
| 125 | + texts = [variant.normalized_text for variant in style_profile.query_variants if variant.normalized_text] | |
| 126 | + if texts: | |
| 127 | + return list(dict.fromkeys(texts)) | |
| 128 | + | |
| 129 | + fallbacks: List[str] = [] | |
| 130 | + for value in ( | |
| 131 | + getattr(parsed_query, "original_query", None), | |
| 132 | + getattr(parsed_query, "query_normalized", None), | |
| 133 | + getattr(parsed_query, "rewritten_query", None), | |
| 134 | + ): | |
| 135 | + normalized = normalize_query_text(value) | |
| 136 | + if normalized: | |
| 137 | + fallbacks.append(normalized) | |
| 138 | + translations = getattr(parsed_query, "translations", {}) or {} | |
| 139 | + if isinstance(translations, dict): | |
| 140 | + for value in translations.values(): | |
| 141 | + normalized = normalize_query_text(value) | |
| 142 | + if normalized: | |
| 143 | + fallbacks.append(normalized) | |
| 144 | + return list(dict.fromkeys(fallbacks)) | |
| 145 | + | |
| 146 | + def _get_query_vector(self, parsed_query: Any) -> Optional[np.ndarray]: | |
| 147 | + query_vector = getattr(parsed_query, "query_vector", None) | |
| 148 | + if query_vector is not None: | |
| 149 | + return np.asarray(query_vector, dtype=np.float32) | |
| 150 | + | |
| 151 | + text_encoder = self._get_text_encoder() | |
| 152 | + if text_encoder is None: | |
| 153 | + return None | |
| 154 | + | |
| 155 | + query_text = ( | |
| 156 | + getattr(parsed_query, "rewritten_query", None) | |
| 157 | + or getattr(parsed_query, "query_normalized", None) | |
| 158 | + or getattr(parsed_query, "original_query", None) | |
| 159 | + ) | |
| 160 | + if not query_text: | |
| 161 | + return None | |
| 162 | + | |
| 163 | + vectors = text_encoder.encode([query_text], priority=1) | |
| 164 | + if vectors is None or len(vectors) == 0 or vectors[0] is None: | |
| 165 | + return None | |
| 166 | + return np.asarray(vectors[0], dtype=np.float32) | |
| 167 | + | |
| 168 | + def _get_text_encoder(self) -> Any: | |
| 169 | + if self._text_encoder_getter is None: | |
| 170 | + return None | |
| 171 | + return self._text_encoder_getter() | |
| 172 | + | |
| 173 | + def _get_tokenizer(self) -> Any: | |
| 174 | + if self._tokenizer_getter is None: | |
| 175 | + return None | |
| 176 | + return self._tokenizer_getter() | |
| 177 | + | |
| 178 | + @staticmethod | |
| 179 | + def _fallback_sku_text(sku: Dict[str, Any]) -> str: | |
| 180 | + parts = [] | |
| 181 | + for field_name in ("option1_value", "option2_value", "option3_value"): | |
| 182 | + value = str(sku.get(field_name) or "").strip() | |
| 183 | + if value: | |
| 184 | + parts.append(value) | |
| 185 | + return " ".join(parts) | |
| 186 | + | |
| 187 | + def _resolve_dimensions( | |
| 188 | + self, | |
| 189 | + source: Dict[str, Any], | |
| 190 | + style_profile: StyleIntentProfile, | |
| 191 | + ) -> Dict[str, Optional[str]]: | |
| 192 | + option_names = { | |
| 193 | + "option1_value": normalize_query_text(source.get("option1_name")), | |
| 194 | + "option2_value": normalize_query_text(source.get("option2_name")), | |
| 195 | + "option3_value": normalize_query_text(source.get("option3_name")), | |
| 196 | + } | |
| 197 | + resolved: Dict[str, Optional[str]] = {} | |
| 198 | + for intent in style_profile.intents: | |
| 199 | + if intent.intent_type in resolved: | |
| 200 | + continue | |
| 201 | + aliases = set(intent.dimension_aliases or self.registry.get_dimension_aliases(intent.intent_type)) | |
| 202 | + matched_field = None | |
| 203 | + for field_name, option_name in option_names.items(): | |
| 204 | + if option_name and option_name in aliases: | |
| 205 | + matched_field = field_name | |
| 206 | + break | |
| 207 | + resolved[intent.intent_type] = matched_field | |
| 208 | + return resolved | |
| 209 | + | |
| 210 | + def _build_candidates( | |
| 211 | + self, | |
| 212 | + skus: List[Dict[str, Any]], | |
| 213 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 214 | + ) -> List[_SkuCandidate]: | |
| 215 | + candidates: List[_SkuCandidate] = [] | |
| 216 | + for index, sku in enumerate(skus): | |
| 217 | + fallback_text = self._fallback_sku_text(sku) | |
| 218 | + intent_texts: Dict[str, str] = {} | |
| 219 | + for intent_type, field_name in resolved_dimensions.items(): | |
| 220 | + if field_name: | |
| 221 | + value = str(sku.get(field_name) or "").strip() | |
| 222 | + intent_texts[intent_type] = value or fallback_text | |
| 223 | + else: | |
| 224 | + intent_texts[intent_type] = fallback_text | |
| 225 | + | |
| 226 | + selection_parts: List[str] = [] | |
| 227 | + seen = set() | |
| 228 | + for value in intent_texts.values(): | |
| 229 | + normalized = normalize_query_text(value) | |
| 230 | + if not normalized or normalized in seen: | |
| 231 | + continue | |
| 232 | + seen.add(normalized) | |
| 233 | + selection_parts.append(str(value).strip()) | |
| 234 | + | |
| 235 | + selection_text = " ".join(selection_parts).strip() or fallback_text | |
| 236 | + candidates.append( | |
| 237 | + _SkuCandidate( | |
| 238 | + index=index, | |
| 239 | + sku_id=str(sku.get("sku_id") or ""), | |
| 240 | + sku=sku, | |
| 241 | + selection_text=selection_text, | |
| 242 | + intent_texts=intent_texts, | |
| 243 | + ) | |
| 244 | + ) | |
| 245 | + return candidates | |
| 246 | + | |
| 247 | + @staticmethod | |
| 248 | + def _is_direct_match( | |
| 249 | + candidate: _SkuCandidate, | |
| 250 | + query_texts: Sequence[str], | |
| 251 | + ) -> bool: | |
| 252 | + if not candidate.intent_texts or not query_texts: | |
| 253 | + return False | |
| 254 | + for value in candidate.intent_texts.values(): | |
| 255 | + normalized_value = normalize_query_text(value) | |
| 256 | + if not normalized_value: | |
| 257 | + return False | |
| 258 | + if not any(normalized_value in query_text for query_text in query_texts): | |
| 259 | + return False | |
| 260 | + return True | |
| 261 | + | |
| 262 | + def _is_generalized_match( | |
| 263 | + self, | |
| 264 | + candidate: _SkuCandidate, | |
| 265 | + style_profile: StyleIntentProfile, | |
| 266 | + tokenizer: Any, | |
| 267 | + ) -> bool: | |
| 268 | + if not candidate.intent_texts: | |
| 269 | + return False | |
| 270 | + | |
| 271 | + for intent_type, value in candidate.intent_texts.items(): | |
| 272 | + definition = self.registry.get_definition(intent_type) | |
| 273 | + if definition is None: | |
| 274 | + return False | |
| 275 | + matched_canonicals = definition.match_text(value, tokenizer=tokenizer) | |
| 276 | + if not matched_canonicals.intersection(style_profile.get_canonical_values(intent_type)): | |
| 277 | + return False | |
| 278 | + return True | |
| 279 | + | |
| 280 | + def _select_by_embedding( | |
| 281 | + self, | |
| 282 | + candidates: Sequence[_SkuCandidate], | |
| 283 | + query_vector: Optional[np.ndarray], | |
| 284 | + ) -> Tuple[Optional[_SkuCandidate], Optional[float]]: | |
| 285 | + if not candidates: | |
| 286 | + return None, None | |
| 287 | + text_encoder = self._get_text_encoder() | |
| 288 | + if query_vector is None or text_encoder is None: | |
| 289 | + return candidates[0], None | |
| 290 | + | |
| 291 | + unique_texts = list( | |
| 292 | + dict.fromkeys( | |
| 293 | + normalize_query_text(candidate.selection_text) | |
| 294 | + for candidate in candidates | |
| 295 | + if normalize_query_text(candidate.selection_text) | |
| 296 | + ) | |
| 297 | + ) | |
| 298 | + if not unique_texts: | |
| 299 | + return candidates[0], None | |
| 300 | + | |
| 301 | + vectors = text_encoder.encode(unique_texts, priority=1) | |
| 302 | + vector_map: Dict[str, np.ndarray] = {} | |
| 303 | + for key, vector in zip(unique_texts, vectors): | |
| 304 | + if vector is None: | |
| 305 | + continue | |
| 306 | + vector_map[key] = np.asarray(vector, dtype=np.float32) | |
| 307 | + | |
| 308 | + best_candidate: Optional[_SkuCandidate] = None | |
| 309 | + best_score: Optional[float] = None | |
| 310 | + query_vector_array = np.asarray(query_vector, dtype=np.float32) | |
| 311 | + for candidate in candidates: | |
| 312 | + normalized_text = normalize_query_text(candidate.selection_text) | |
| 313 | + candidate_vector = vector_map.get(normalized_text) | |
| 314 | + if candidate_vector is None: | |
| 315 | + continue | |
| 316 | + score = float(np.inner(query_vector_array, candidate_vector)) | |
| 317 | + if best_score is None or score > best_score: | |
| 318 | + best_candidate = candidate | |
| 319 | + best_score = score | |
| 320 | + | |
| 321 | + return best_candidate or candidates[0], best_score | |
| 322 | + | |
| 323 | + def _select_for_source( | |
| 324 | + self, | |
| 325 | + source: Dict[str, Any], | |
| 326 | + *, | |
| 327 | + style_profile: StyleIntentProfile, | |
| 328 | + query_texts: Sequence[str], | |
| 329 | + query_vector: Optional[np.ndarray], | |
| 330 | + tokenizer: Any, | |
| 331 | + ) -> Optional[SkuSelectionDecision]: | |
| 332 | + skus = source.get("skus") | |
| 333 | + if not isinstance(skus, list) or not skus: | |
| 334 | + return None | |
| 335 | + | |
| 336 | + resolved_dimensions = self._resolve_dimensions(source, style_profile) | |
| 337 | + candidates = self._build_candidates(skus, resolved_dimensions) | |
| 338 | + if not candidates: | |
| 339 | + return None | |
| 340 | + | |
| 341 | + direct_matches = [candidate for candidate in candidates if self._is_direct_match(candidate, query_texts)] | |
| 342 | + if len(direct_matches) == 1: | |
| 343 | + chosen = direct_matches[0] | |
| 344 | + return self._build_decision(chosen, resolved_dimensions, matched_stage="direct") | |
| 345 | + | |
| 346 | + generalized_matches: List[_SkuCandidate] = [] | |
| 347 | + if not direct_matches: | |
| 348 | + generalized_matches = [ | |
| 349 | + candidate | |
| 350 | + for candidate in candidates | |
| 351 | + if self._is_generalized_match(candidate, style_profile, tokenizer) | |
| 352 | + ] | |
| 353 | + if len(generalized_matches) == 1: | |
| 354 | + chosen = generalized_matches[0] | |
| 355 | + return self._build_decision(chosen, resolved_dimensions, matched_stage="generalized") | |
| 356 | + | |
| 357 | + embedding_pool = direct_matches or generalized_matches or candidates | |
| 358 | + chosen, similarity_score = self._select_by_embedding(embedding_pool, query_vector) | |
| 359 | + if chosen is None: | |
| 360 | + return None | |
| 361 | + stage = "embedding_from_matches" if direct_matches or generalized_matches else "embedding_from_all" | |
| 362 | + return self._build_decision( | |
| 363 | + chosen, | |
| 364 | + resolved_dimensions, | |
| 365 | + matched_stage=stage, | |
| 366 | + similarity_score=similarity_score, | |
| 367 | + ) | |
| 368 | + | |
| 369 | + @staticmethod | |
| 370 | + def _build_decision( | |
| 371 | + candidate: _SkuCandidate, | |
| 372 | + resolved_dimensions: Dict[str, Optional[str]], | |
| 373 | + *, | |
| 374 | + matched_stage: str, | |
| 375 | + similarity_score: Optional[float] = None, | |
| 376 | + ) -> SkuSelectionDecision: | |
| 377 | + return SkuSelectionDecision( | |
| 378 | + selected_sku_id=candidate.sku_id or None, | |
| 379 | + rerank_suffix=str(candidate.selection_text or "").strip(), | |
| 380 | + selected_text=str(candidate.selection_text or "").strip(), | |
| 381 | + matched_stage=matched_stage, | |
| 382 | + similarity_score=similarity_score, | |
| 383 | + resolved_dimensions=dict(resolved_dimensions), | |
| 384 | + ) | |
| 385 | + | |
| 386 | + @staticmethod | |
| 387 | + def _apply_decision_to_source(source: Dict[str, Any], decision: SkuSelectionDecision) -> None: | |
| 388 | + skus = source.get("skus") | |
| 389 | + if not isinstance(skus, list) or not skus or not decision.selected_sku_id: | |
| 390 | + return | |
| 391 | + | |
| 392 | + selected_index = None | |
| 393 | + for index, sku in enumerate(skus): | |
| 394 | + if str(sku.get("sku_id") or "") == decision.selected_sku_id: | |
| 395 | + selected_index = index | |
| 396 | + break | |
| 397 | + if selected_index is None: | |
| 398 | + return | |
| 399 | + | |
| 400 | + selected_sku = skus.pop(selected_index) | |
| 401 | + skus.insert(0, selected_sku) | |
| 402 | + | |
| 403 | + image_src = selected_sku.get("image_src") or selected_sku.get("imageSrc") | |
| 404 | + if image_src: | |
| 405 | + source["image_url"] = image_src | ... | ... |
tests/test_embedding_pipeline.py
| ... | ... | @@ -13,9 +13,11 @@ from config import ( |
| 13 | 13 | ) |
| 14 | 14 | from embeddings.text_encoder import TextEmbeddingEncoder |
| 15 | 15 | from embeddings.image_encoder import CLIPImageEncoder |
| 16 | +from embeddings.text_embedding_tei import TEITextModel | |
| 16 | 17 | from embeddings.bf16 import encode_embedding_for_redis |
| 17 | 18 | from embeddings.cache_keys import build_image_cache_key, build_text_cache_key |
| 18 | 19 | from query import QueryParser |
| 20 | +from context.request_context import create_request_context, set_current_request_context, clear_current_request_context | |
| 19 | 21 | |
| 20 | 22 | |
| 21 | 23 | class _FakeRedis: |
| ... | ... | @@ -168,6 +170,30 @@ def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 168 | 170 | assert np.allclose(out[1], np.array([0.3, 0.4], dtype=np.float32)) |
| 169 | 171 | |
| 170 | 172 | |
| 173 | +def test_text_embedding_encoder_forwards_request_headers(monkeypatch): | |
| 174 | + fake_cache = _FakeEmbeddingCache() | |
| 175 | + monkeypatch.setattr("embeddings.text_encoder.RedisEmbeddingCache", lambda **kwargs: fake_cache) | |
| 176 | + | |
| 177 | + captured = {} | |
| 178 | + | |
| 179 | + def _fake_post(url, json, timeout, **kwargs): | |
| 180 | + captured["headers"] = dict(kwargs.get("headers") or {}) | |
| 181 | + return _FakeResponse([[0.1, 0.2]]) | |
| 182 | + | |
| 183 | + monkeypatch.setattr("embeddings.text_encoder.requests.post", _fake_post) | |
| 184 | + | |
| 185 | + context = create_request_context(reqid="req-ctx-1", uid="user-ctx-1") | |
| 186 | + set_current_request_context(context) | |
| 187 | + try: | |
| 188 | + encoder = TextEmbeddingEncoder(service_url="http://127.0.0.1:6005") | |
| 189 | + encoder.encode(["hello"]) | |
| 190 | + finally: | |
| 191 | + clear_current_request_context() | |
| 192 | + | |
| 193 | + assert captured["headers"]["X-Request-ID"] == "req-ctx-1" | |
| 194 | + assert captured["headers"]["X-User-ID"] == "user-ctx-1" | |
| 195 | + | |
| 196 | + | |
| 171 | 197 | def test_image_embedding_encoder_cache_hit(monkeypatch): |
| 172 | 198 | fake_cache = _FakeEmbeddingCache() |
| 173 | 199 | cached = np.array([0.5, 0.6], dtype=np.float32) |
| ... | ... | @@ -234,3 +260,37 @@ def test_query_parser_skips_query_vector_when_disabled(): |
| 234 | 260 | |
| 235 | 261 | parsed = parser.parse("red dress", tenant_id="162", generate_vector=False) |
| 236 | 262 | assert parsed.query_vector is None |
| 263 | + | |
| 264 | + | |
| 265 | +def test_tei_text_model_splits_batches_over_client_limit(monkeypatch): | |
| 266 | + monkeypatch.setattr(TEITextModel, "_health_check", lambda self: None) | |
| 267 | + calls = [] | |
| 268 | + | |
| 269 | + class _Response: | |
| 270 | + def __init__(self, payload): | |
| 271 | + self._payload = payload | |
| 272 | + | |
| 273 | + def raise_for_status(self): | |
| 274 | + return None | |
| 275 | + | |
| 276 | + def json(self): | |
| 277 | + return self._payload | |
| 278 | + | |
| 279 | + def _fake_post(url, json, timeout): | |
| 280 | + inputs = list(json["inputs"]) | |
| 281 | + calls.append(inputs) | |
| 282 | + return _Response([[float(idx)] for idx, _ in enumerate(inputs, start=1)]) | |
| 283 | + | |
| 284 | + monkeypatch.setattr("embeddings.text_embedding_tei.requests.post", _fake_post) | |
| 285 | + | |
| 286 | + model = TEITextModel( | |
| 287 | + base_url="http://127.0.0.1:8080", | |
| 288 | + timeout_sec=20, | |
| 289 | + max_client_batch_size=24, | |
| 290 | + ) | |
| 291 | + vectors = model.encode([f"text-{idx}" for idx in range(25)], normalize_embeddings=False) | |
| 292 | + | |
| 293 | + assert len(calls) == 2 | |
| 294 | + assert len(calls[0]) == 24 | |
| 295 | + assert len(calls[1]) == 1 | |
| 296 | + assert len(vectors) == 25 | ... | ... |
tests/test_es_query_builder.py
| ... | ... | @@ -9,6 +9,9 @@ from search.es_query_builder import ESQueryBuilder |
| 9 | 9 | def _builder() -> ESQueryBuilder: |
| 10 | 10 | return ESQueryBuilder( |
| 11 | 11 | match_fields=["title.en^3.0", "brief.en^1.0"], |
| 12 | + multilingual_fields=["title", "brief"], | |
| 13 | + core_multilingual_fields=["title", "brief"], | |
| 14 | + shared_fields=[], | |
| 12 | 15 | text_embedding_field="title_embedding", |
| 13 | 16 | default_language="en", |
| 14 | 17 | ) |
| ... | ... | @@ -25,10 +28,6 @@ def _lexical_clause(query_root: Dict[str, Any]) -> Dict[str, Any]: |
| 25 | 28 | raise AssertionError("no lexical bool clause in query_root") |
| 26 | 29 | |
| 27 | 30 | |
| 28 | -def _lexical_combined_fields(query_root: Dict[str, Any]) -> list: | |
| 29 | - return _lexical_clause(query_root)["must"][0]["combined_fields"]["fields"] | |
| 30 | - | |
| 31 | - | |
| 32 | 31 | def test_knn_prefilter_includes_range_filters(): |
| 33 | 32 | qb = _builder() |
| 34 | 33 | q = qb.build_query( |
| ... | ... | @@ -93,7 +92,6 @@ def test_text_query_contains_only_base_and_translation_named_queries(): |
| 93 | 92 | query_text="dress", |
| 94 | 93 | parsed_query=parsed_query, |
| 95 | 94 | enable_knn=False, |
| 96 | - index_languages=["en", "zh", "fr"], | |
| 97 | 95 | ) |
| 98 | 96 | should = q["query"]["bool"]["should"] |
| 99 | 97 | names = [clause["bool"]["_name"] for clause in should] |
| ... | ... | @@ -115,120 +113,8 @@ def test_text_query_skips_duplicate_translation_same_as_base(): |
| 115 | 113 | query_text="dress", |
| 116 | 114 | parsed_query=parsed_query, |
| 117 | 115 | enable_knn=False, |
| 118 | - index_languages=["en", "zh"], | |
| 119 | 116 | ) |
| 120 | 117 | |
| 121 | 118 | root = q["query"] |
| 122 | 119 | assert root["bool"]["_name"] == "base_query" |
| 123 | 120 | assert [clause["multi_match"]["type"] for clause in root["bool"]["should"]] == ["best_fields", "phrase"] |
| 124 | - | |
| 125 | - | |
| 126 | -def test_mixed_script_merges_en_fields_into_zh_clause(): | |
| 127 | - qb = ESQueryBuilder( | |
| 128 | - match_fields=["title.en^3.0"], | |
| 129 | - multilingual_fields=["title", "brief"], | |
| 130 | - shared_fields=[], | |
| 131 | - text_embedding_field="title_embedding", | |
| 132 | - default_language="en", | |
| 133 | - ) | |
| 134 | - parsed_query = SimpleNamespace( | |
| 135 | - rewritten_query="法式 dress", | |
| 136 | - detected_language="zh", | |
| 137 | - translations={}, | |
| 138 | - contains_chinese=True, | |
| 139 | - contains_english=True, | |
| 140 | - ) | |
| 141 | - q = qb.build_query( | |
| 142 | - query_text="法式 dress", | |
| 143 | - parsed_query=parsed_query, | |
| 144 | - enable_knn=False, | |
| 145 | - index_languages=["zh", "en"], | |
| 146 | - ) | |
| 147 | - fields = _lexical_combined_fields(q["query"]) | |
| 148 | - bases = {f.split("^", 1)[0] for f in fields} | |
| 149 | - assert "title.zh" in bases and "title.en" in bases | |
| 150 | - assert "brief.zh" in bases and "brief.en" in bases | |
| 151 | - # Merged supplemental language fields use boost * 0.6 by default. | |
| 152 | - assert "title.en^0.6" in fields | |
| 153 | - assert "brief.en^0.6" in fields | |
| 154 | - | |
| 155 | - | |
| 156 | -def test_mixed_script_merges_zh_fields_into_en_clause(): | |
| 157 | - qb = ESQueryBuilder( | |
| 158 | - match_fields=["title.en^3.0"], | |
| 159 | - multilingual_fields=["title"], | |
| 160 | - shared_fields=[], | |
| 161 | - text_embedding_field="title_embedding", | |
| 162 | - default_language="en", | |
| 163 | - ) | |
| 164 | - parsed_query = SimpleNamespace( | |
| 165 | - rewritten_query="red 连衣裙", | |
| 166 | - detected_language="en", | |
| 167 | - translations={}, | |
| 168 | - contains_chinese=True, | |
| 169 | - contains_english=True, | |
| 170 | - ) | |
| 171 | - q = qb.build_query( | |
| 172 | - query_text="red 连衣裙", | |
| 173 | - parsed_query=parsed_query, | |
| 174 | - enable_knn=False, | |
| 175 | - index_languages=["zh", "en"], | |
| 176 | - ) | |
| 177 | - fields = _lexical_combined_fields(q["query"]) | |
| 178 | - bases = {f.split("^", 1)[0] for f in fields} | |
| 179 | - assert "title.en" in bases and "title.zh" in bases | |
| 180 | - assert "title.zh^0.6" in fields | |
| 181 | - | |
| 182 | - | |
| 183 | -def test_mixed_script_merged_fields_scale_configured_boosts(): | |
| 184 | - qb = ESQueryBuilder( | |
| 185 | - match_fields=["title.en^3.0"], | |
| 186 | - multilingual_fields=["title"], | |
| 187 | - shared_fields=[], | |
| 188 | - field_boosts={"title.zh": 5.0, "title.en": 10.0}, | |
| 189 | - text_embedding_field="title_embedding", | |
| 190 | - default_language="en", | |
| 191 | - ) | |
| 192 | - parsed_query = SimpleNamespace( | |
| 193 | - rewritten_query="法式 dress", | |
| 194 | - detected_language="zh", | |
| 195 | - translations={}, | |
| 196 | - contains_chinese=True, | |
| 197 | - contains_english=True, | |
| 198 | - ) | |
| 199 | - q = qb.build_query( | |
| 200 | - query_text="法式 dress", | |
| 201 | - parsed_query=parsed_query, | |
| 202 | - enable_knn=False, | |
| 203 | - index_languages=["zh", "en"], | |
| 204 | - ) | |
| 205 | - fields = _lexical_combined_fields(q["query"]) | |
| 206 | - assert "title.zh^5.0" in fields | |
| 207 | - assert "title.en^6.0" in fields # 10.0 * 0.6 | |
| 208 | - | |
| 209 | - | |
| 210 | -def test_mixed_script_does_not_merge_en_when_not_in_index_languages(): | |
| 211 | - qb = ESQueryBuilder( | |
| 212 | - match_fields=["title.zh^3.0"], | |
| 213 | - multilingual_fields=["title"], | |
| 214 | - shared_fields=[], | |
| 215 | - text_embedding_field="title_embedding", | |
| 216 | - default_language="zh", | |
| 217 | - ) | |
| 218 | - parsed_query = SimpleNamespace( | |
| 219 | - rewritten_query="法式 dress", | |
| 220 | - detected_language="zh", | |
| 221 | - translations={}, | |
| 222 | - contains_chinese=True, | |
| 223 | - contains_english=True, | |
| 224 | - ) | |
| 225 | - q = qb.build_query( | |
| 226 | - query_text="法式 dress", | |
| 227 | - parsed_query=parsed_query, | |
| 228 | - enable_knn=False, | |
| 229 | - index_languages=["zh"], | |
| 230 | - ) | |
| 231 | - fields = _lexical_combined_fields(q["query"]) | |
| 232 | - bases = {f.split("^", 1)[0] for f in fields} | |
| 233 | - assert "title.zh" in bases | |
| 234 | - assert "title.en" not in bases | ... | ... |
tests/test_es_query_builder_text_recall_languages.py
| 1 | 1 | """ |
| 2 | 2 | ES text recall: base_query (rewritten @ detected_language) + base_query_trans_*. |
| 3 | 3 | |
| 4 | -Covers combinations of query language vs tenant index_languages, translations, | |
| 5 | -and mixed Chinese/English queries. Asserts named lexical clause boundaries, | |
| 6 | -combined_fields payloads, and per-language target fields (title.{lang}). | |
| 4 | +Covers translation routing, mixed-script queries (per-clause language fields only), | |
| 5 | +and clause naming. Asserts named lexical clause boundaries, combined_fields payloads, | |
| 6 | +and per-language target fields (title.{lang}). | |
| 7 | 7 | """ |
| 8 | 8 | |
| 9 | 9 | from types import SimpleNamespace |
| ... | ... | @@ -14,11 +14,7 @@ import numpy as np |
| 14 | 14 | from search.es_query_builder import ESQueryBuilder |
| 15 | 15 | |
| 16 | 16 | |
| 17 | -def _builder_multilingual_title_only( | |
| 18 | - *, | |
| 19 | - default_language: str = "en", | |
| 20 | - mixed_script_scale: float = 0.6, | |
| 21 | -) -> ESQueryBuilder: | |
| 17 | +def _builder_multilingual_title_only(*, default_language: str = "en") -> ESQueryBuilder: | |
| 22 | 18 | """Minimal builder: only title.{lang} for easy field assertions.""" |
| 23 | 19 | return ESQueryBuilder( |
| 24 | 20 | match_fields=["title.en^1.0"], |
| ... | ... | @@ -26,7 +22,6 @@ def _builder_multilingual_title_only( |
| 26 | 22 | shared_fields=[], |
| 27 | 23 | text_embedding_field="title_embedding", |
| 28 | 24 | default_language=default_language, |
| 29 | - mixed_script_merged_field_boost_scale=mixed_script_scale, | |
| 30 | 25 | function_score_config=None, |
| 31 | 26 | ) |
| 32 | 27 | |
| ... | ... | @@ -101,22 +96,16 @@ def _build( |
| 101 | 96 | rewritten: str, |
| 102 | 97 | detected_language: str, |
| 103 | 98 | translations: Dict[str, str], |
| 104 | - index_languages: List[str], | |
| 105 | - contains_chinese: bool = False, | |
| 106 | - contains_english: bool = False, | |
| 107 | 99 | ) -> Dict[str, Any]: |
| 108 | 100 | parsed = SimpleNamespace( |
| 109 | 101 | rewritten_query=rewritten, |
| 110 | 102 | detected_language=detected_language, |
| 111 | 103 | translations=dict(translations), |
| 112 | - contains_chinese=contains_chinese, | |
| 113 | - contains_english=contains_english, | |
| 114 | 104 | ) |
| 115 | 105 | return qb.build_query( |
| 116 | 106 | query_text=query_text, |
| 117 | 107 | parsed_query=parsed, |
| 118 | 108 | enable_knn=False, |
| 119 | - index_languages=index_languages, | |
| 120 | 109 | ) |
| 121 | 110 | |
| 122 | 111 | |
| ... | ... | @@ -131,7 +120,6 @@ def test_zh_query_index_zh_en_includes_base_zh_and_trans_en(): |
| 131 | 120 | rewritten="连衣裙", |
| 132 | 121 | detected_language="zh", |
| 133 | 122 | translations={"en": "dress"}, |
| 134 | - index_languages=["zh", "en"], | |
| 135 | 123 | ) |
| 136 | 124 | idx = _clauses_index(q) |
| 137 | 125 | assert set(idx) == {"base_query", "base_query_trans_en"} |
| ... | ... | @@ -149,7 +137,6 @@ def test_en_query_index_zh_en_includes_base_en_and_trans_zh(): |
| 149 | 137 | rewritten="dress", |
| 150 | 138 | detected_language="en", |
| 151 | 139 | translations={"zh": "连衣裙"}, |
| 152 | - index_languages=["en", "zh"], | |
| 153 | 140 | ) |
| 154 | 141 | idx = _clauses_index(q) |
| 155 | 142 | assert set(idx) == {"base_query", "base_query_trans_zh"} |
| ... | ... | @@ -167,7 +154,6 @@ def test_de_query_index_de_en_fr_includes_base_and_two_translations(): |
| 167 | 154 | rewritten="kleid", |
| 168 | 155 | detected_language="de", |
| 169 | 156 | translations={"en": "dress", "fr": "robe"}, |
| 170 | - index_languages=["de", "en", "fr"], | |
| 171 | 157 | ) |
| 172 | 158 | idx = _clauses_index(q) |
| 173 | 159 | assert set(idx) == {"base_query", "base_query_trans_en", "base_query_trans_fr"} |
| ... | ... | @@ -188,7 +174,6 @@ def test_de_query_index_only_en_zh_base_on_de_translations_on_target_fields(): |
| 188 | 174 | rewritten="schuh", |
| 189 | 175 | detected_language="de", |
| 190 | 176 | translations={"en": "shoe", "zh": "鞋"}, |
| 191 | - index_languages=["en", "zh"], | |
| 192 | 177 | ) |
| 193 | 178 | idx = _clauses_index(q) |
| 194 | 179 | assert set(idx) == {"base_query", "base_query_trans_en", "base_query_trans_zh"} |
| ... | ... | @@ -201,10 +186,10 @@ def test_de_query_index_only_en_zh_base_on_de_translations_on_target_fields(): |
| 201 | 186 | assert idx["base_query_trans_zh"]["boost"] == qb.translation_boost |
| 202 | 187 | |
| 203 | 188 | |
| 204 | -# --- 中英混写:原文在 base_query;翻译子句独立;混写时 base 子句扩列 --- | |
| 189 | +# --- 中英混写:base 打在检测语种字段;翻译子句打在译文语种字段 --- | |
| 205 | 190 | |
| 206 | 191 | |
| 207 | -def test_mixed_zh_primary_with_en_translation_merges_en_into_zh_base_clause(): | |
| 192 | +def test_mixed_zh_detected_base_clause_zh_fields_only_with_en_translation(): | |
| 208 | 193 | qb = _builder_multilingual_title_only(default_language="en") |
| 209 | 194 | q = _build( |
| 210 | 195 | qb, |
| ... | ... | @@ -212,19 +197,16 @@ def test_mixed_zh_primary_with_en_translation_merges_en_into_zh_base_clause(): |
| 212 | 197 | rewritten="红色 dress", |
| 213 | 198 | detected_language="zh", |
| 214 | 199 | translations={"en": "red dress"}, |
| 215 | - index_languages=["zh", "en"], | |
| 216 | - contains_chinese=True, | |
| 217 | - contains_english=True, | |
| 218 | 200 | ) |
| 219 | 201 | idx = _clauses_index(q) |
| 220 | 202 | assert set(idx) == {"base_query", "base_query_trans_en"} |
| 221 | 203 | assert _combined_fields_clause(idx["base_query"])["query"] == "红色 dress" |
| 222 | - assert _has_title_lang(idx["base_query"], "zh") and _has_title_lang(idx["base_query"], "en") | |
| 204 | + assert _has_title_lang(idx["base_query"], "zh") and not _has_title_lang(idx["base_query"], "en") | |
| 223 | 205 | assert _combined_fields_clause(idx["base_query_trans_en"])["query"] == "red dress" |
| 224 | 206 | assert _has_title_lang(idx["base_query_trans_en"], "en") |
| 225 | 207 | |
| 226 | 208 | |
| 227 | -def test_mixed_en_primary_with_zh_translation_merges_zh_into_en_base_clause(): | |
| 209 | +def test_mixed_en_detected_base_clause_en_fields_only_with_zh_translation(): | |
| 228 | 210 | qb = _builder_multilingual_title_only(default_language="en") |
| 229 | 211 | q = _build( |
| 230 | 212 | qb, |
| ... | ... | @@ -232,18 +214,15 @@ def test_mixed_en_primary_with_zh_translation_merges_zh_into_en_base_clause(): |
| 232 | 214 | rewritten="nike 运动鞋", |
| 233 | 215 | detected_language="en", |
| 234 | 216 | translations={"zh": "耐克运动鞋"}, |
| 235 | - index_languages=["zh", "en"], | |
| 236 | - contains_chinese=True, | |
| 237 | - contains_english=True, | |
| 238 | 217 | ) |
| 239 | 218 | idx = _clauses_index(q) |
| 240 | 219 | assert set(idx) == {"base_query", "base_query_trans_zh"} |
| 241 | 220 | assert _combined_fields_clause(idx["base_query"])["query"] == "nike 运动鞋" |
| 242 | - assert _has_title_lang(idx["base_query"], "en") and _has_title_lang(idx["base_query"], "zh") | |
| 221 | + assert _has_title_lang(idx["base_query"], "en") and not _has_title_lang(idx["base_query"], "zh") | |
| 243 | 222 | assert _combined_fields_clause(idx["base_query_trans_zh"])["query"] == "耐克运动鞋" |
| 244 | 223 | |
| 245 | 224 | |
| 246 | -def test_mixed_zh_query_index_zh_only_no_en_merge_in_base(): | |
| 225 | +def test_zh_query_no_translations_only_zh_fields(): | |
| 247 | 226 | qb = _builder_multilingual_title_only(default_language="en") |
| 248 | 227 | q = _build( |
| 249 | 228 | qb, |
| ... | ... | @@ -251,9 +230,6 @@ def test_mixed_zh_query_index_zh_only_no_en_merge_in_base(): |
| 251 | 230 | rewritten="法式 dress", |
| 252 | 231 | detected_language="zh", |
| 253 | 232 | translations={}, |
| 254 | - index_languages=["zh"], | |
| 255 | - contains_chinese=True, | |
| 256 | - contains_english=True, | |
| 257 | 233 | ) |
| 258 | 234 | idx = _clauses_index(q) |
| 259 | 235 | assert set(idx) == {"base_query"} |
| ... | ... | @@ -272,7 +248,6 @@ def test_skips_translation_when_same_lang_and_same_text_as_base(): |
| 272 | 248 | rewritten="NIKE", |
| 273 | 249 | detected_language="en", |
| 274 | 250 | translations={"en": "NIKE", "zh": "耐克"}, |
| 275 | - index_languages=["en", "zh"], | |
| 276 | 251 | ) |
| 277 | 252 | idx = _clauses_index(q) |
| 278 | 253 | assert set(idx) == {"base_query", "base_query_trans_zh"} |
| ... | ... | @@ -286,7 +261,6 @@ def test_keeps_translation_when_same_text_but_different_lang_than_base(): |
| 286 | 261 | rewritten="NIKE", |
| 287 | 262 | detected_language="en", |
| 288 | 263 | translations={"zh": "NIKE"}, |
| 289 | - index_languages=["en", "zh"], | |
| 290 | 264 | ) |
| 291 | 265 | idx = _clauses_index(q) |
| 292 | 266 | assert set(idx) == {"base_query", "base_query_trans_zh"} |
| ... | ... | @@ -304,7 +278,6 @@ def test_translation_language_key_is_normalized_case_insensitive(): |
| 304 | 278 | rewritten="dress", |
| 305 | 279 | detected_language="en", |
| 306 | 280 | translations={"ZH": "连衣裙"}, |
| 307 | - index_languages=["en", "zh"], | |
| 308 | 281 | ) |
| 309 | 282 | idx = _clauses_index(q) |
| 310 | 283 | assert "base_query_trans_zh" in idx |
| ... | ... | @@ -319,17 +292,16 @@ def test_empty_translation_value_is_skipped(): |
| 319 | 292 | rewritten="dress", |
| 320 | 293 | detected_language="en", |
| 321 | 294 | translations={"zh": " ", "fr": "robe"}, |
| 322 | - index_languages=["en", "zh", "fr"], | |
| 323 | 295 | ) |
| 324 | 296 | idx = _clauses_index(q) |
| 325 | 297 | assert "base_query_trans_zh" not in idx |
| 326 | 298 | assert "base_query_trans_fr" in idx |
| 327 | 299 | |
| 328 | 300 | |
| 329 | -# --- index_languages 为空:视为「未约束」source_in_index 为 True --- | |
| 301 | +# --- base 子句无 bool.boost;翻译子句带 translation_boost;phrase should 继承 phrase_match_boost --- | |
| 330 | 302 | |
| 331 | 303 | |
| 332 | -def test_empty_index_languages_treats_source_as_in_index_boosts(): | |
| 304 | +def test_de_base_and_en_translation_phrase_boosts(): | |
| 333 | 305 | qb = _builder_multilingual_title_only(default_language="en") |
| 334 | 306 | q = _build( |
| 335 | 307 | qb, |
| ... | ... | @@ -337,7 +309,6 @@ def test_empty_index_languages_treats_source_as_in_index_boosts(): |
| 337 | 309 | rewritten="x", |
| 338 | 310 | detected_language="de", |
| 339 | 311 | translations={"en": "y"}, |
| 340 | - index_languages=[], | |
| 341 | 312 | ) |
| 342 | 313 | idx = _clauses_index(q) |
| 343 | 314 | assert "boost" not in idx["base_query"] |
| ... | ... | @@ -359,7 +330,6 @@ def test_no_translations_only_base_query(): |
| 359 | 330 | rewritten="hello", |
| 360 | 331 | detected_language="en", |
| 361 | 332 | translations={}, |
| 362 | - index_languages=["en", "zh"], | |
| 363 | 333 | ) |
| 364 | 334 | idx = _clauses_index(q) |
| 365 | 335 | assert set(idx) == {"base_query"} |
| ... | ... | @@ -374,15 +344,12 @@ def test_text_clauses_present_alongside_knn(): |
| 374 | 344 | rewritten_query="dress", |
| 375 | 345 | detected_language="en", |
| 376 | 346 | translations={"zh": "连衣裙"}, |
| 377 | - contains_chinese=False, | |
| 378 | - contains_english=True, | |
| 379 | 347 | ) |
| 380 | 348 | q = qb.build_query( |
| 381 | 349 | query_text="dress", |
| 382 | 350 | query_vector=np.array([0.1, 0.2, 0.3], dtype=np.float32), |
| 383 | 351 | parsed_query=parsed, |
| 384 | 352 | enable_knn=True, |
| 385 | - index_languages=["en", "zh"], | |
| 386 | 353 | ) |
| 387 | 354 | assert "knn" in q |
| 388 | 355 | idx = _clauses_index(q) |
| ... | ... | @@ -396,14 +363,11 @@ def test_detected_language_unknown_falls_back_to_default_language(): |
| 396 | 363 | rewritten_query="shirt", |
| 397 | 364 | detected_language="unknown", |
| 398 | 365 | translations={"zh": "衬衫"}, |
| 399 | - contains_chinese=False, | |
| 400 | - contains_english=True, | |
| 401 | 366 | ) |
| 402 | 367 | q = qb.build_query( |
| 403 | 368 | query_text="shirt", |
| 404 | 369 | parsed_query=parsed, |
| 405 | 370 | enable_knn=False, |
| 406 | - index_languages=["en", "zh"], | |
| 407 | 371 | ) |
| 408 | 372 | idx = _clauses_index(q) |
| 409 | 373 | assert set(idx) == {"base_query", "base_query_trans_zh"} |
| ... | ... | @@ -419,7 +383,6 @@ def test_ru_query_index_ru_en_includes_base_ru_and_trans_en(): |
| 419 | 383 | rewritten="платье", |
| 420 | 384 | detected_language="ru", |
| 421 | 385 | translations={"en": "dress"}, |
| 422 | - index_languages=["ru", "en"], | |
| 423 | 386 | ) |
| 424 | 387 | idx = _clauses_index(q) |
| 425 | 388 | assert set(idx) == {"base_query", "base_query_trans_en"} |
| ... | ... | @@ -428,11 +391,8 @@ def test_ru_query_index_ru_en_includes_base_ru_and_trans_en(): |
| 428 | 391 | assert _combined_fields_clause(idx["base_query_trans_en"])["query"] == "dress" |
| 429 | 392 | |
| 430 | 393 | |
| 431 | -def test_translation_for_lang_not_listed_in_index_languages_still_generates_clause(): | |
| 432 | - """ | |
| 433 | - 当前实现:凡是 translations 里非空的条目都会生成子句; | |
| 434 | - index_languages 只约束混写扩列,不用于过滤翻译子句。 | |
| 435 | - """ | |
| 394 | +def test_translation_generates_clause_for_any_target_lang_key(): | |
| 395 | + """translations 里非空的每个语种键都会生成对应 base_query_trans_* 子句。""" | |
| 436 | 396 | qb = _builder_multilingual_title_only(default_language="en") |
| 437 | 397 | q = _build( |
| 438 | 398 | qb, |
| ... | ... | @@ -440,7 +400,6 @@ def test_translation_for_lang_not_listed_in_index_languages_still_generates_clau |
| 440 | 400 | rewritten="dress", |
| 441 | 401 | detected_language="en", |
| 442 | 402 | translations={"zh": "连衣裙", "de": "Kleid"}, |
| 443 | - index_languages=["en", "zh"], | |
| 444 | 403 | ) |
| 445 | 404 | idx = _clauses_index(q) |
| 446 | 405 | assert "base_query_trans_de" in idx |
| ... | ... | @@ -457,9 +416,6 @@ def test_mixed_detected_zh_rewrite_differs_from_query_text_uses_rewritten_in_bas |
| 457 | 416 | rewritten="红色连衣裙", |
| 458 | 417 | detected_language="zh", |
| 459 | 418 | translations={"en": "red dress"}, |
| 460 | - index_languages=["zh", "en"], | |
| 461 | - contains_chinese=True, | |
| 462 | - contains_english=False, | |
| 463 | 419 | ) |
| 464 | 420 | idx = _clauses_index(q) |
| 465 | 421 | assert _combined_fields_clause(idx["base_query"])["query"] == "红色连衣裙" | ... | ... |
tests/test_query_parser_mixed_language.py
| ... | ... | @@ -11,14 +11,6 @@ def _tokenizer(text): |
| 11 | 11 | return str(text).split() |
| 12 | 12 | |
| 13 | 13 | |
| 14 | -def test_pure_english_word_token_length_and_script(): | |
| 15 | - assert QueryParser._is_pure_english_word_token("ab") is False | |
| 16 | - assert QueryParser._is_pure_english_word_token("abc") is True | |
| 17 | - assert QueryParser._is_pure_english_word_token("wi-fi") is True | |
| 18 | - assert QueryParser._is_pure_english_word_token("连衣裙") is False | |
| 19 | - assert QueryParser._is_pure_english_word_token("ab12") is False | |
| 20 | - | |
| 21 | - | |
| 22 | 14 | def _build_config() -> SearchConfig: |
| 23 | 15 | return SearchConfig( |
| 24 | 16 | es_index_name="test_products", |
| ... | ... | @@ -36,7 +28,7 @@ def _build_config() -> SearchConfig: |
| 36 | 28 | ) |
| 37 | 29 | |
| 38 | 30 | |
| 39 | -def test_parse_adds_en_fields_for_mixed_chinese_query_with_meaningful_english(monkeypatch): | |
| 31 | +def test_parse_mixed_zh_query_translates_to_en(monkeypatch): | |
| 40 | 32 | parser = QueryParser(_build_config(), translator=_DummyTranslator(), tokenizer=_tokenizer) |
| 41 | 33 | monkeypatch.setattr(parser.language_detector, "detect", lambda text: "zh") |
| 42 | 34 | |
| ... | ... | @@ -48,15 +40,13 @@ def test_parse_adds_en_fields_for_mixed_chinese_query_with_meaningful_english(mo |
| 48 | 40 | ) |
| 49 | 41 | |
| 50 | 42 | assert result.detected_language == "zh" |
| 51 | - assert result.contains_chinese is True | |
| 52 | - assert result.contains_english is True | |
| 53 | 43 | assert result.translations == {"en": "法式 dress 连衣裙-en"} |
| 54 | 44 | assert result.query_tokens == ["法式", "dress", "连衣裙"] |
| 55 | 45 | assert not hasattr(result, "query_text_by_lang") |
| 56 | 46 | assert not hasattr(result, "search_langs") |
| 57 | 47 | |
| 58 | 48 | |
| 59 | -def test_parse_adds_zh_fields_for_english_query_when_cjk_present(monkeypatch): | |
| 49 | +def test_parse_mixed_en_query_translates_to_zh(monkeypatch): | |
| 60 | 50 | parser = QueryParser(_build_config(), translator=_DummyTranslator(), tokenizer=_tokenizer) |
| 61 | 51 | monkeypatch.setattr(parser.language_detector, "detect", lambda text: "en") |
| 62 | 52 | |
| ... | ... | @@ -68,8 +58,6 @@ def test_parse_adds_zh_fields_for_english_query_when_cjk_present(monkeypatch): |
| 68 | 58 | ) |
| 69 | 59 | |
| 70 | 60 | assert result.detected_language == "en" |
| 71 | - assert result.contains_chinese is True | |
| 72 | - assert result.contains_english is True | |
| 73 | 61 | assert result.translations == {"zh": "red 连衣裙-zh"} |
| 74 | 62 | assert result.query_tokens == ["red", "连衣裙"] |
| 75 | 63 | |
| ... | ... | @@ -87,7 +75,5 @@ def test_parse_waits_for_translation_when_source_in_index_languages(monkeypatch) |
| 87 | 75 | ) |
| 88 | 76 | |
| 89 | 77 | assert result.detected_language == "en" |
| 90 | - assert result.contains_chinese is False | |
| 91 | - assert result.contains_english is True | |
| 92 | 78 | assert result.translations.get("zh") == "off shoulder top-zh" |
| 93 | 79 | assert not hasattr(result, "source_in_index_languages") | ... | ... |
tests/test_search_rerank_window.py
| ... | ... | @@ -18,6 +18,7 @@ from config import ( |
| 18 | 18 | SearchConfig, |
| 19 | 19 | ) |
| 20 | 20 | from context import create_request_context |
| 21 | +from query.style_intent import DetectedStyleIntent, StyleIntentProfile | |
| 21 | 22 | from search.searcher import Searcher |
| 22 | 23 | |
| 23 | 24 | |
| ... | ... | @@ -30,6 +31,7 @@ class _FakeParsedQuery: |
| 30 | 31 | translations: Dict[str, str] = None |
| 31 | 32 | query_vector: Any = None |
| 32 | 33 | domain: str = "default" |
| 34 | + style_intent_profile: Any = None | |
| 33 | 35 | |
| 34 | 36 | def to_dict(self) -> Dict[str, Any]: |
| 35 | 37 | return { |
| ... | ... | @@ -39,9 +41,27 @@ class _FakeParsedQuery: |
| 39 | 41 | "detected_language": self.detected_language, |
| 40 | 42 | "translations": self.translations or {}, |
| 41 | 43 | "domain": self.domain, |
| 44 | + "style_intent_profile": ( | |
| 45 | + self.style_intent_profile.to_dict() if self.style_intent_profile is not None else None | |
| 46 | + ), | |
| 42 | 47 | } |
| 43 | 48 | |
| 44 | 49 | |
| 50 | +def _build_style_intent_profile(intent_type: str, canonical_value: str, *dimension_aliases: str) -> StyleIntentProfile: | |
| 51 | + aliases = dimension_aliases or (intent_type,) | |
| 52 | + return StyleIntentProfile( | |
| 53 | + intents=( | |
| 54 | + DetectedStyleIntent( | |
| 55 | + intent_type=intent_type, | |
| 56 | + canonical_value=canonical_value, | |
| 57 | + matched_term=canonical_value, | |
| 58 | + matched_query_text=canonical_value, | |
| 59 | + dimension_aliases=tuple(aliases), | |
| 60 | + ), | |
| 61 | + ) | |
| 62 | + ) | |
| 63 | + | |
| 64 | + | |
| 45 | 65 | class _FakeQueryParser: |
| 46 | 66 | def parse( |
| 47 | 67 | self, |
| ... | ... | @@ -340,6 +360,57 @@ def test_searcher_rerank_prefetch_source_follows_doc_template(monkeypatch): |
| 340 | 360 | assert es_client.calls[0]["body"]["_source"] == {"includes": ["brief", "title", "vendor"]} |
| 341 | 361 | |
| 342 | 362 | |
| 363 | +def test_searcher_rerank_prefetch_source_includes_sku_fields_when_style_intent_active(monkeypatch): | |
| 364 | + es_client = _FakeESClient() | |
| 365 | + searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) | |
| 366 | + context = create_request_context(reqid="t1c", uid="u1c") | |
| 367 | + | |
| 368 | + monkeypatch.setattr( | |
| 369 | + "search.searcher.get_tenant_config_loader", | |
| 370 | + lambda: SimpleNamespace(get_tenant_config=lambda tenant_id: {"index_languages": ["en"]}), | |
| 371 | + ) | |
| 372 | + monkeypatch.setattr( | |
| 373 | + "search.rerank_client.run_rerank", | |
| 374 | + lambda **kwargs: (kwargs["es_response"], None, []), | |
| 375 | + ) | |
| 376 | + | |
| 377 | + class _IntentQueryParser: | |
| 378 | + text_encoder = None | |
| 379 | + | |
| 380 | + def parse( | |
| 381 | + self, | |
| 382 | + query: str, | |
| 383 | + tenant_id: str, | |
| 384 | + generate_vector: bool, | |
| 385 | + context: Any, | |
| 386 | + target_languages: Any = None, | |
| 387 | + ): | |
| 388 | + return _FakeParsedQuery( | |
| 389 | + original_query=query, | |
| 390 | + query_normalized=query, | |
| 391 | + rewritten_query=query, | |
| 392 | + translations={}, | |
| 393 | + style_intent_profile=_build_style_intent_profile( | |
| 394 | + "color", "black", "color", "colors", "颜色" | |
| 395 | + ), | |
| 396 | + ) | |
| 397 | + | |
| 398 | + searcher.query_parser = _IntentQueryParser() | |
| 399 | + | |
| 400 | + searcher.search( | |
| 401 | + query="black dress", | |
| 402 | + tenant_id="162", | |
| 403 | + from_=0, | |
| 404 | + size=5, | |
| 405 | + context=context, | |
| 406 | + enable_rerank=None, | |
| 407 | + ) | |
| 408 | + | |
| 409 | + assert es_client.calls[0]["body"]["_source"] == { | |
| 410 | + "includes": ["option1_name", "option2_name", "option3_name", "skus", "title"] | |
| 411 | + } | |
| 412 | + | |
| 413 | + | |
| 343 | 414 | def test_searcher_skips_rerank_when_request_explicitly_false(monkeypatch): |
| 344 | 415 | es_client = _FakeESClient() |
| 345 | 416 | searcher = _build_searcher(_build_search_config(rerank_enabled=True), es_client) |
| ... | ... | @@ -434,6 +505,9 @@ def test_searcher_promotes_sku_when_option1_matches_translated_query(monkeypatch |
| 434 | 505 | query_normalized=query, |
| 435 | 506 | rewritten_query=query, |
| 436 | 507 | translations={"en": "black dress"}, |
| 508 | + style_intent_profile=_build_style_intent_profile( | |
| 509 | + "color", "black", "color", "colors", "颜色" | |
| 510 | + ), | |
| 437 | 511 | ) |
| 438 | 512 | |
| 439 | 513 | searcher.query_parser = _TranslatedQueryParser() |
| ... | ... | @@ -481,8 +555,8 @@ def test_searcher_promotes_sku_by_embedding_when_query_has_no_direct_option_matc |
| 481 | 555 | encoder = _FakeTextEncoder( |
| 482 | 556 | { |
| 483 | 557 | "linen summer dress": [0.8, 0.2], |
| 484 | - "color:red": [1.0, 0.0], | |
| 485 | - "color:blue": [0.0, 1.0], | |
| 558 | + "red": [1.0, 0.0], | |
| 559 | + "blue": [0.0, 1.0], | |
| 486 | 560 | } |
| 487 | 561 | ) |
| 488 | 562 | |
| ... | ... | @@ -503,6 +577,9 @@ def test_searcher_promotes_sku_by_embedding_when_query_has_no_direct_option_matc |
| 503 | 577 | rewritten_query=query, |
| 504 | 578 | translations={}, |
| 505 | 579 | query_vector=np.array([0.0, 1.0], dtype=np.float32), |
| 580 | + style_intent_profile=_build_style_intent_profile( | |
| 581 | + "color", "blue", "color", "colors", "颜色" | |
| 582 | + ), | |
| 506 | 583 | ) |
| 507 | 584 | |
| 508 | 585 | searcher.query_parser = _EmbeddingQueryParser() | ... | ... |
| ... | ... | @@ -0,0 +1,35 @@ |
| 1 | +from types import SimpleNamespace | |
| 2 | + | |
| 3 | +from config import QueryConfig | |
| 4 | +from query.style_intent import StyleIntentDetector, StyleIntentRegistry | |
| 5 | + | |
| 6 | + | |
| 7 | +def test_style_intent_detector_matches_original_and_translated_queries(): | |
| 8 | + query_config = QueryConfig( | |
| 9 | + style_intent_terms={ | |
| 10 | + "color": [["black", "黑色", "black"]], | |
| 11 | + "size": [["xl", "x-large", "加大码"]], | |
| 12 | + }, | |
| 13 | + style_intent_dimension_aliases={ | |
| 14 | + "color": ["color", "颜色"], | |
| 15 | + "size": ["size", "尺码"], | |
| 16 | + }, | |
| 17 | + ) | |
| 18 | + detector = StyleIntentDetector( | |
| 19 | + StyleIntentRegistry.from_query_config(query_config), | |
| 20 | + tokenizer=lambda text: text.split(), | |
| 21 | + ) | |
| 22 | + | |
| 23 | + parsed_query = SimpleNamespace( | |
| 24 | + original_query="黑色 连衣裙", | |
| 25 | + query_normalized="黑色 连衣裙", | |
| 26 | + rewritten_query="黑色 连衣裙", | |
| 27 | + translations={"en": "black dress xl"}, | |
| 28 | + ) | |
| 29 | + | |
| 30 | + profile = detector.detect(parsed_query) | |
| 31 | + | |
| 32 | + assert profile.is_active is True | |
| 33 | + assert profile.get_canonical_values("color") == {"black"} | |
| 34 | + assert profile.get_canonical_values("size") == {"xl"} | |
| 35 | + assert len(profile.query_variants) == 2 | ... | ... |
utils/logger.py
| ... | ... | @@ -14,6 +14,8 @@ from datetime import datetime |
| 14 | 14 | from typing import Any, Dict, Optional |
| 15 | 15 | from pathlib import Path |
| 16 | 16 | |
| 17 | +from request_log_context import LOG_LINE_FORMAT, RequestLogContextFilter | |
| 18 | + | |
| 17 | 19 | |
| 18 | 20 | class StructuredFormatter(logging.Formatter): |
| 19 | 21 | """Structured JSON formatter with request context support""" |
| ... | ... | @@ -89,25 +91,6 @@ def _log_with_context(logger: logging.Logger, level: int, msg: str, **kwargs): |
| 89 | 91 | logging.setLogRecordFactory(old_factory) |
| 90 | 92 | |
| 91 | 93 | |
| 92 | -class RequestContextFilter(logging.Filter): | |
| 93 | - """Filter that automatically injects request context from thread-local storage""" | |
| 94 | - | |
| 95 | - def filter(self, record: logging.LogRecord) -> bool: | |
| 96 | - """Inject request context from thread-local storage""" | |
| 97 | - try: | |
| 98 | - # Import here to avoid circular imports | |
| 99 | - from context.request_context import get_current_request_context | |
| 100 | - context = get_current_request_context() | |
| 101 | - if context: | |
| 102 | - # Ensure every request-scoped log record carries reqid/uid. | |
| 103 | - # If they are missing in the context, fall back to "-1". | |
| 104 | - record.reqid = getattr(context, "reqid", None) or "-1" | |
| 105 | - record.uid = getattr(context, "uid", None) or "-1" | |
| 106 | - except (ImportError, AttributeError): | |
| 107 | - pass | |
| 108 | - return True | |
| 109 | - | |
| 110 | - | |
| 111 | 94 | class ContextAwareConsoleFormatter(logging.Formatter): |
| 112 | 95 | """ |
| 113 | 96 | Console formatter that injects reqid/uid into the log line. |
| ... | ... | @@ -156,9 +139,7 @@ def setup_logging( |
| 156 | 139 | |
| 157 | 140 | # Create formatters |
| 158 | 141 | structured_formatter = StructuredFormatter() |
| 159 | - console_formatter = ContextAwareConsoleFormatter( | |
| 160 | - '%(asctime)s | reqid:%(reqid)s | uid:%(uid)s | %(levelname)-8s | %(name)-15s | %(message)s' | |
| 161 | - ) | |
| 142 | + console_formatter = ContextAwareConsoleFormatter(LOG_LINE_FORMAT) | |
| 162 | 143 | |
| 163 | 144 | # Add console handler |
| 164 | 145 | if enable_console: | ... | ... |