Commit 41f0b2e9fb4fe2301a797139f1d9530783139f72

Authored by tangwang
1 parent 86d8358b

product_enrich支持并发

api/routes/admin.py
... ... @@ -3,7 +3,6 @@ Admin API routes for configuration and management.
3 3 """
4 4  
5 5 from fastapi import APIRouter, HTTPException, Request
6   -from typing import Dict
7 6  
8 7 from ..models import HealthResponse, ErrorResponse
9 8 from indexer.mapping_generator import get_tenant_index_name
... ... @@ -74,49 +73,6 @@ async def get_configuration_meta():
74 73 raise HTTPException(status_code=500, detail=str(e))
75 74  
76 75  
77   -@router.post("/rewrite-rules")
78   -async def update_rewrite_rules(rules: Dict[str, str]):
79   - """
80   - Update query rewrite rules.
81   -
82   - Args:
83   - rules: Dictionary of pattern -> replacement mappings
84   - """
85   - try:
86   - from ..app import get_query_parser
87   -
88   - query_parser = get_query_parser()
89   - query_parser.update_rewrite_rules(rules)
90   -
91   - return {
92   - "status": "success",
93   - "message": f"Updated {len(rules)} rewrite rules"
94   - }
95   -
96   - except Exception as e:
97   - raise HTTPException(status_code=500, detail=str(e))
98   -
99   -
100   -@router.get("/rewrite-rules")
101   -async def get_rewrite_rules():
102   - """
103   - Get current query rewrite rules.
104   - """
105   - try:
106   - from ..app import get_query_parser
107   -
108   - query_parser = get_query_parser()
109   - rules = query_parser.get_rewrite_rules()
110   -
111   - return {
112   - "rules": rules,
113   - "count": len(rules)
114   - }
115   -
116   - except Exception as e:
117   - raise HTTPException(status_code=500, detail=str(e))
118   -
119   -
120 76 @router.get("/stats")
121 77 async def get_index_stats(http_request: Request):
122 78 """
... ...
config/config.yaml
... ... @@ -9,6 +9,10 @@ es_index_name: "search_products"
9 9 assets:
10 10 query_rewrite_dictionary_path: "config/dictionaries/query_rewrite.dict"
11 11  
  12 +# Product content understanding (LLM enrich-content) configuration
  13 +product_enrich:
  14 + max_workers: 40
  15 +
12 16 # ES Index Settings (基础设置)
13 17 es_settings:
14 18 number_of_shards: 1
... ...
config/loader.py
... ... @@ -34,6 +34,7 @@ from config.schema import (
34 34 IndexConfig,
35 35 InfrastructureConfig,
36 36 QueryConfig,
  37 + ProductEnrichConfig,
37 38 RedisSettings,
38 39 RerankConfig,
39 40 RerankServiceConfig,
... ... @@ -188,6 +189,11 @@ class AppConfigLoader:
188 189 runtime_config = self._build_runtime_config()
189 190 infrastructure_config = self._build_infrastructure_config(runtime_config.environment)
190 191  
  192 + product_enrich_raw = raw.get("product_enrich") if isinstance(raw.get("product_enrich"), dict) else {}
  193 + product_enrich_config = ProductEnrichConfig(
  194 + max_workers=int(product_enrich_raw.get("max_workers", 40)),
  195 + )
  196 +
191 197 metadata = ConfigMetadata(
192 198 loaded_files=tuple(loaded_files),
193 199 config_hash="",
... ... @@ -197,6 +203,7 @@ class AppConfigLoader:
197 203 app_config = AppConfig(
198 204 runtime=runtime_config,
199 205 infrastructure=infrastructure_config,
  206 + product_enrich=product_enrich_config,
200 207 search=search_config,
201 208 services=services_config,
202 209 tenants=tenants_config,
... ... @@ -208,6 +215,7 @@ class AppConfigLoader:
208 215 return AppConfig(
209 216 runtime=app_config.runtime,
210 217 infrastructure=app_config.infrastructure,
  218 + product_enrich=app_config.product_enrich,
211 219 search=app_config.search,
212 220 services=app_config.services,
213 221 tenants=app_config.tenants,
... ... @@ -547,20 +555,9 @@ class AppConfigLoader:
547 555 return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]
548 556  
549 557 def _detect_deprecated_keys(self, raw: Dict[str, Any]) -> Iterable[str]:
550   - tenant_raw = raw.get("tenant_config") if isinstance(raw.get("tenant_config"), dict) else {}
551   - for key in ("default",):
552   - item = tenant_raw.get(key)
553   - if isinstance(item, dict):
554   - for deprecated in ("translate_to_en", "translate_to_zh"):
555   - if deprecated in item:
556   - yield f"tenant_config.{key}.{deprecated}"
557   - tenants = tenant_raw.get("tenants") if isinstance(tenant_raw.get("tenants"), dict) else {}
558   - for tenant_id, cfg in tenants.items():
559   - if not isinstance(cfg, dict):
560   - continue
561   - for deprecated in ("translate_to_en", "translate_to_zh"):
562   - if deprecated in cfg:
563   - yield f"tenant_config.tenants.{tenant_id}.{deprecated}"
  558 + # Translation-era legacy flags have been removed; keep the hook for future
  559 + # deprecations, but currently no deprecated keys are detected.
  560 + return ()  # empty tuple: nothing to report, and still a valid Iterable[str]; `raw` is intentionally unused
564 561  
565 562  
566 563 @lru_cache(maxsize=1)
... ...
config/schema.py
... ... @@ -240,6 +240,13 @@ class InfrastructureConfig:
240 240  
241 241  
242 242 @dataclass(frozen=True)
  243 +class ProductEnrichConfig:
  244 + """Configuration for LLM-based product content understanding (enrich-content)."""
  245 +
  246 + max_workers: int = 40  # Thread-pool size cap for concurrent enrich-content LLM batches; defaults to 40.
  247 +
  248 +
  249 +@dataclass(frozen=True)
243 250 class RuntimeConfig:
244 251 environment: str = "prod"
245 252 index_namespace: str = ""
... ... @@ -275,6 +282,7 @@ class AppConfig:
275 282  
276 283 runtime: RuntimeConfig
277 284 infrastructure: InfrastructureConfig
  285 + product_enrich: ProductEnrichConfig
278 286 search: SearchConfig
279 287 services: ServicesConfig
280 288 tenants: TenantCatalogConfig
... ...
indexer/product_enrich.py
... ... @@ -12,8 +12,11 @@ import logging
12 12 import re
13 13 import time
14 14 import hashlib
  15 +import uuid
  16 +import threading
15 17 from collections import OrderedDict
16 18 from datetime import datetime
  19 +from concurrent.futures import ThreadPoolExecutor
17 20 from typing import List, Dict, Tuple, Any, Optional
18 21  
19 22 import redis
... ... @@ -31,6 +34,9 @@ from indexer.product_enrich_prompts import (
31 34  
32 35 # 配置
33 36 BATCH_SIZE = 20
  37 +# Upper bound on concurrent workers for enrich-content LLM batches (thread pool; only uncached batches run concurrently)
  38 +_APP_CONFIG = get_app_config()
  39 +CONTENT_UNDERSTANDING_MAX_WORKERS = int(_APP_CONFIG.product_enrich.max_workers)
34 40 # 华北2(北京):https://dashscope.aliyuncs.com/compatible-mode/v1
35 41 # 新加坡:https://dashscope-intl.aliyuncs.com/compatible-mode/v1
36 42 # 美国(弗吉尼亚):https://dashscope-us.aliyuncs.com/compatible-mode/v1
... ... @@ -56,6 +62,24 @@ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
56 62 log_file = LOG_DIR / f"product_enrich_{timestamp}.log"
57 63 verbose_log_file = LOG_DIR / "product_enrich_verbose.log"
58 64 _logged_shared_context_keys: "OrderedDict[str, None]" = OrderedDict()
  65 +_logged_shared_context_lock = threading.Lock()
  66 +
  67 +_content_understanding_executor: Optional[ThreadPoolExecutor] = None
  68 +_content_understanding_executor_lock = threading.Lock()
  69 +
  70 +
  71 +def _get_content_understanding_executor() -> ThreadPoolExecutor:
  72 + """
  73 + Return the module-level singleton thread pool, creating it on first call; sharing one pool keeps repeated requests in the same process from each creating a pool and letting concurrency grow out of control.
  74 + """
  75 + global _content_understanding_executor
  76 + with _content_understanding_executor_lock:  # guard lazy creation so concurrent first calls build only one pool
  77 + if _content_understanding_executor is None:
  78 + _content_understanding_executor = ThreadPoolExecutor(
  79 + max_workers=CONTENT_UNDERSTANDING_MAX_WORKERS,
  80 + thread_name_prefix="product-enrich-llm",
  81 + )
  82 + return _content_understanding_executor
59 83  
60 84 # 主日志 logger:执行流程、批次信息等
61 85 logger = logging.getLogger("product_enrich")
... ... @@ -91,7 +115,7 @@ logger.info("Verbose LLM logs are written to: %s", verbose_log_file)
91 115  
92 116  
93 117 # Redis 缓存(用于 anchors / 语义属性)
94   -_REDIS_CONFIG = get_app_config().infrastructure.redis
  118 +_REDIS_CONFIG = _APP_CONFIG.infrastructure.redis
95 119 ANCHOR_CACHE_PREFIX = _REDIS_CONFIG.anchor_cache_prefix
96 120 ANCHOR_CACHE_EXPIRE_DAYS = int(_REDIS_CONFIG.anchor_cache_expire_days)
97 121 _anchor_redis: Optional[redis.Redis] = None
... ... @@ -243,19 +267,21 @@ def _hash_text(text: str) -> str:
243 267  
244 268  
245 269 def _mark_shared_context_logged_once(shared_context_key: str) -> bool:  # True only when the key was not already recorded
246   - if shared_context_key in _logged_shared_context_keys:
247   - _logged_shared_context_keys.move_to_end(shared_context_key)
248   - return False
  270 + with _logged_shared_context_lock:  # the membership check and the insert happen under one lock
  271 + if shared_context_key in _logged_shared_context_keys:
  272 + _logged_shared_context_keys.move_to_end(shared_context_key)  # already recorded: mark as most recently used
  273 + return False
249 274  
250   - _logged_shared_context_keys[shared_context_key] = None
251   - if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE:
252   - _logged_shared_context_keys.popitem(last=False)
253   - return True
  275 + _logged_shared_context_keys[shared_context_key] = None
  276 + if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE:
  277 + _logged_shared_context_keys.popitem(last=False)  # over capacity: drop the entry at the front (least recently used)
  278 + return True
254 279  
255 280  
256 281 def reset_logged_shared_context_keys() -> None:
257 282 """Test helper: clear the recorded shared-prompt keys."""
258   - _logged_shared_context_keys.clear()
  283 + with _logged_shared_context_lock:
  284 + _logged_shared_context_keys.clear()
259 285  
260 286  
261 287 def create_prompt(
... ... @@ -626,7 +652,9 @@ def process_batch(
626 652 "final_results": results_with_ids,
627 653 }
628 654  
629   - batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}.json"
  655 + # 并发写 batch json 日志时,保证文件名唯一避免覆盖
  656 + batch_call_id = uuid.uuid4().hex[:12]
  657 + batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}_{batch_call_id}.json"
630 658 with open(batch_log_file, "w", encoding="utf-8") as f:
631 659 json.dump(batch_log, f, ensure_ascii=False, indent=2)
632 660  
... ... @@ -708,28 +736,70 @@ def analyze_products(
708 736 bs = max(1, min(req_bs, BATCH_SIZE))
709 737 total_batches = (len(uncached_items) + bs - 1) // bs
710 738  
  739 + batch_jobs: List[Tuple[int, List[Tuple[int, Dict[str, str]]], List[Dict[str, str]]]] = []
711 740 for i in range(0, len(uncached_items), bs):
712 741 batch_num = i // bs + 1
713 742 batch_slice = uncached_items[i : i + bs]
714 743 batch = [item for _, item in batch_slice]
  744 + batch_jobs.append((batch_num, batch_slice, batch))
  745 +
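  746 +# Run serially when there is at most one batch or concurrency is disabled (max_workers <= 1): skips thread-pool overhead and avoids uncontrolled interleaving of log output / log files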
  747 + if total_batches <= 1 or CONTENT_UNDERSTANDING_MAX_WORKERS <= 1:
  748 + for batch_num, batch_slice, batch in batch_jobs:
  749 + logger.info(
  750 + f"[analyze_products] Processing batch {batch_num}/{total_batches}, "
  751 + f"size={len(batch)}, target_lang={target_lang}"
  752 + )
  753 + batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang)
  754 +
  755 + for (original_idx, product), item in zip(batch_slice, batch_results):
  756 + results_by_index[original_idx] = item
  757 + title_input = str(item.get("title_input") or "").strip()
  758 + if not title_input:
  759 + continue
  760 + if item.get("error"):
  761 + # 不缓存错误结果,避免放大临时故障
  762 + continue
  763 + try:
  764 + _set_cached_anchor_result(product, target_lang, item)
  765 + except Exception:
  766 + # 已在内部记录 warning
  767 + pass
  768 + else:
  769 + max_workers = min(CONTENT_UNDERSTANDING_MAX_WORKERS, len(batch_jobs))
715 770 logger.info(
716   - f"[analyze_products] Processing batch {batch_num}/{total_batches}, "
717   - f"size={len(batch)}, target_lang={target_lang}"
  771 + "[analyze_products] Using ThreadPoolExecutor for uncached batches: "
  772 + "max_workers=%s, total_batches=%s, bs=%s, target_lang=%s",
  773 + max_workers,
  774 + total_batches,
  775 + bs,
  776 + target_lang,
718 777 )
719   - batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang)
720 778  
721   - for (original_idx, product), item in zip(batch_slice, batch_results):
722   - results_by_index[original_idx] = item
723   - title_input = str(item.get("title_input") or "").strip()
724   - if not title_input:
725   - continue
726   - if item.get("error"):
727   - # 不缓存错误结果,避免放大临时故障
728   - continue
729   - try:
730   - _set_cached_anchor_result(product, target_lang, item)
731   - except Exception:
732   - # 已在内部记录 warning
733   - pass
  779 + # 只把“LLM 调用 + markdown 解析”放到线程里;Redis get/set 保持在主线程,避免并发写入带来额外风险。
  780 + # 注意:线程池是模块级单例,因此这里的 max_workers 主要用于日志语义(实际并发受单例池上限约束)。
  781 + executor = _get_content_understanding_executor()
  782 + future_by_batch_num: Dict[int, Any] = {}
  783 + for batch_num, _batch_slice, batch in batch_jobs:
  784 + future_by_batch_num[batch_num] = executor.submit(
  785 + process_batch, batch, batch_num=batch_num, target_lang=target_lang
  786 + )
  787 +
  788 + # 按 batch_num 回填,确保输出稳定(results_by_index 是按原始 input index 映射的)
  789 + for batch_num, batch_slice, _batch in batch_jobs:
  790 + batch_results = future_by_batch_num[batch_num].result()
  791 + for (original_idx, product), item in zip(batch_slice, batch_results):
  792 + results_by_index[original_idx] = item
  793 + title_input = str(item.get("title_input") or "").strip()
  794 + if not title_input:
  795 + continue
  796 + if item.get("error"):
  797 + # 不缓存错误结果,避免放大临时故障
  798 + continue
  799 + try:
  800 + _set_cached_anchor_result(product, target_lang, item)
  801 + except Exception:
  802 + # 已在内部记录 warning
  803 + pass
734 804  
735 805 return [item for item in results_by_index if item is not None]
... ...
query/query_parser.py
... ... @@ -618,17 +618,3 @@ class QueryParser:
618 618 queries.append(translation)
619 619  
620 620 return queries
621   -
622   - def update_rewrite_rules(self, rules: Dict[str, str]) -> None:
623   - """
624   - Update query rewrite rules.
625   -
626   - Args:
627   - rules: Dictionary of pattern -> replacement mappings
628   - """
629   - for pattern, replacement in rules.items():
630   - self.rewriter.add_rule(pattern, replacement)
631   -
632   - def get_rewrite_rules(self) -> Dict[str, str]:
633   - """Get current rewrite rules."""
634   - return self.rewriter.get_rules()
... ...
tests/test_process_products_batching.py
... ... @@ -45,7 +45,8 @@ def test_analyze_products_caps_batch_size_to_20(monkeypatch):
45 45 )
46 46  
47 47 assert len(out) == 45
48   - assert seen_batch_sizes == [20, 20, 5]
  48 + # 并发执行时 batch 调用顺序可能变化,因此校验“批大小集合”而不是严格顺序
  49 + assert sorted(seen_batch_sizes) == [5, 20, 20]
49 50  
50 51  
51 52 def test_analyze_products_uses_min_batch_size_1(monkeypatch):
... ...