From 41f0b2e9fb4fe2301a797139f1d9530783139f72 Mon Sep 17 00:00:00 2001 From: tangwang Date: Thu, 19 Mar 2026 23:32:53 +0800 Subject: [PATCH] product_enrich支持并发 --- api/routes/admin.py | 44 -------------------------------------------- config/config.yaml | 4 ++++ config/loader.py | 25 +++++++++++-------------- config/schema.py | 8 ++++++++ indexer/product_enrich.py | 122 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------- query/query_parser.py | 14 -------------- tests/test_process_products_batching.py | 3 ++- 7 files changed, 121 insertions(+), 99 deletions(-) diff --git a/api/routes/admin.py b/api/routes/admin.py index 8a84989..22279a9 100644 --- a/api/routes/admin.py +++ b/api/routes/admin.py @@ -3,7 +3,6 @@ Admin API routes for configuration and management. """ from fastapi import APIRouter, HTTPException, Request -from typing import Dict from ..models import HealthResponse, ErrorResponse from indexer.mapping_generator import get_tenant_index_name @@ -74,49 +73,6 @@ async def get_configuration_meta(): raise HTTPException(status_code=500, detail=str(e)) -@router.post("/rewrite-rules") -async def update_rewrite_rules(rules: Dict[str, str]): - """ - Update query rewrite rules. - - Args: - rules: Dictionary of pattern -> replacement mappings - """ - try: - from ..app import get_query_parser - - query_parser = get_query_parser() - query_parser.update_rewrite_rules(rules) - - return { - "status": "success", - "message": f"Updated {len(rules)} rewrite rules" - } - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/rewrite-rules") -async def get_rewrite_rules(): - """ - Get current query rewrite rules. - """ - try: - from ..app import get_query_parser - - query_parser = get_query_parser() - rules = query_parser.get_rewrite_rules() - - return { - "rules": rules, - "count": len(rules) - } - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @router.get("/stats") async def get_index_stats(http_request: Request): """ diff --git a/config/config.yaml b/config/config.yaml index 0934739..87b7b0d 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -9,6 +9,10 @@ es_index_name: "search_products" assets: query_rewrite_dictionary_path: "config/dictionaries/query_rewrite.dict" +# Product content understanding (LLM enrich-content) configuration +product_enrich: + max_workers: 40 + # ES Index Settings (基础设置) es_settings: number_of_shards: 1 diff --git a/config/loader.py b/config/loader.py index 07f5c17..91553d0 100644 --- a/config/loader.py +++ b/config/loader.py @@ -34,6 +34,7 @@ from config.schema import ( IndexConfig, InfrastructureConfig, QueryConfig, + ProductEnrichConfig, RedisSettings, RerankConfig, RerankServiceConfig, @@ -188,6 +189,11 @@ class AppConfigLoader: runtime_config = self._build_runtime_config() infrastructure_config = self._build_infrastructure_config(runtime_config.environment) + product_enrich_raw = raw.get("product_enrich") if isinstance(raw.get("product_enrich"), dict) else {} + product_enrich_config = ProductEnrichConfig( + max_workers=int(product_enrich_raw.get("max_workers", 40)), + ) + metadata = ConfigMetadata( loaded_files=tuple(loaded_files), config_hash="", @@ -197,6 +203,7 @@ class AppConfigLoader: app_config = AppConfig( runtime=runtime_config, infrastructure=infrastructure_config, + product_enrich=product_enrich_config, search=search_config, services=services_config, tenants=tenants_config, @@ -208,6 +215,7 @@ class AppConfigLoader: return AppConfig( runtime=app_config.runtime, infrastructure=app_config.infrastructure, + product_enrich=app_config.product_enrich, search=app_config.search, services=app_config.services, tenants=app_config.tenants, @@ -547,20 +555,9 @@ class AppConfigLoader: return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] def _detect_deprecated_keys(self, raw: Dict[str, Any]) -> Iterable[str]: - tenant_raw = raw.get("tenant_config") if isinstance(raw.get("tenant_config"), dict) else {} - for key in ("default",): - item = tenant_raw.get(key) - if isinstance(item, dict): - for deprecated in ("translate_to_en", "translate_to_zh"): - if deprecated in item: - yield f"tenant_config.{key}.{deprecated}" - tenants = tenant_raw.get("tenants") if isinstance(tenant_raw.get("tenants"), dict) else {} - for tenant_id, cfg in tenants.items(): - if not isinstance(cfg, dict): - continue - for deprecated in ("translate_to_en", "translate_to_zh"): - if deprecated in cfg: - yield f"tenant_config.tenants.{tenant_id}.{deprecated}" + # Translation-era legacy flags have been removed; keep the hook for future + # deprecations, but currently no deprecated keys are detected. + return () @lru_cache(maxsize=1) diff --git a/config/schema.py b/config/schema.py index 99fa38b..6f081e1 100644 --- a/config/schema.py +++ b/config/schema.py @@ -240,6 +240,13 @@ class InfrastructureConfig: @dataclass(frozen=True) +class ProductEnrichConfig: + """Configuration for LLM-based product content understanding (enrich-content).""" + + max_workers: int = 40 + + +@dataclass(frozen=True) class RuntimeConfig: environment: str = "prod" index_namespace: str = "" @@ -275,6 +282,7 @@ class AppConfig: runtime: RuntimeConfig infrastructure: InfrastructureConfig + product_enrich: ProductEnrichConfig search: SearchConfig services: ServicesConfig tenants: TenantCatalogConfig diff --git a/indexer/product_enrich.py b/indexer/product_enrich.py index 9f54849..9572269 100644 --- a/indexer/product_enrich.py +++ b/indexer/product_enrich.py @@ -12,8 +12,11 @@ import logging import re import time import hashlib +import uuid +import threading from collections import OrderedDict from datetime import datetime +from concurrent.futures import ThreadPoolExecutor from typing import List, Dict, Tuple, Any, Optional import redis @@ -31,6 +34,9 @@ from indexer.product_enrich_prompts import ( # 配置 BATCH_SIZE = 20 +# enrich-content LLM 批次并发 worker 上限(线程池;仅对 uncached batch 并发) +_APP_CONFIG = get_app_config() +CONTENT_UNDERSTANDING_MAX_WORKERS = int(_APP_CONFIG.product_enrich.max_workers) # 华北2(北京):https://dashscope.aliyuncs.com/compatible-mode/v1 # 新加坡:https://dashscope-intl.aliyuncs.com/compatible-mode/v1 # 美国(弗吉尼亚):https://dashscope-us.aliyuncs.com/compatible-mode/v1 @@ -56,6 +62,24 @@ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") log_file = LOG_DIR / f"product_enrich_{timestamp}.log" verbose_log_file = LOG_DIR / "product_enrich_verbose.log" _logged_shared_context_keys: "OrderedDict[str, None]" = OrderedDict() +_logged_shared_context_lock = threading.Lock() + +_content_understanding_executor: Optional[ThreadPoolExecutor] = None +_content_understanding_executor_lock = threading.Lock() + + +def _get_content_understanding_executor() -> ThreadPoolExecutor: + """ + 使用模块级单例线程池,避免同一进程内多次请求叠加创建线程池导致并发失控。 + """ + global _content_understanding_executor + with _content_understanding_executor_lock: + if _content_understanding_executor is None: + _content_understanding_executor = ThreadPoolExecutor( + max_workers=CONTENT_UNDERSTANDING_MAX_WORKERS, + thread_name_prefix="product-enrich-llm", + ) + return _content_understanding_executor # 主日志 logger:执行流程、批次信息等 logger = logging.getLogger("product_enrich") @@ -91,7 +115,7 @@ logger.info("Verbose LLM logs are written to: %s", verbose_log_file) # Redis 缓存(用于 anchors / 语义属性) -_REDIS_CONFIG = get_app_config().infrastructure.redis +_REDIS_CONFIG = _APP_CONFIG.infrastructure.redis ANCHOR_CACHE_PREFIX = _REDIS_CONFIG.anchor_cache_prefix ANCHOR_CACHE_EXPIRE_DAYS = int(_REDIS_CONFIG.anchor_cache_expire_days) _anchor_redis: Optional[redis.Redis] = None @@ -243,19 +267,21 @@ def _hash_text(text: str) -> str: def _mark_shared_context_logged_once(shared_context_key: str) -> bool: - if shared_context_key in _logged_shared_context_keys: - _logged_shared_context_keys.move_to_end(shared_context_key) - return False + with _logged_shared_context_lock: + if shared_context_key in _logged_shared_context_keys: + _logged_shared_context_keys.move_to_end(shared_context_key) + return False - _logged_shared_context_keys[shared_context_key] = None - if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE: - _logged_shared_context_keys.popitem(last=False) - return True + _logged_shared_context_keys[shared_context_key] = None + if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE: + _logged_shared_context_keys.popitem(last=False) + return True def reset_logged_shared_context_keys() -> None: """测试辅助:清理已记录的共享 prompt key。""" - _logged_shared_context_keys.clear() + with _logged_shared_context_lock: + _logged_shared_context_keys.clear() def create_prompt( @@ -626,7 +652,9 @@ def process_batch( "final_results": results_with_ids, } - batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}.json" + # 并发写 batch json 日志时,保证文件名唯一避免覆盖 + batch_call_id = uuid.uuid4().hex[:12] + batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}_{batch_call_id}.json" with open(batch_log_file, "w", encoding="utf-8") as f: json.dump(batch_log, f, ensure_ascii=False, indent=2) @@ -708,28 +736,70 @@ def analyze_products( bs = max(1, min(req_bs, BATCH_SIZE)) total_batches = (len(uncached_items) + bs - 1) // bs + batch_jobs: List[Tuple[int, List[Tuple[int, Dict[str, str]]], List[Dict[str, str]]]] = [] for i in range(0, len(uncached_items), bs): batch_num = i // bs + 1 batch_slice = uncached_items[i : i + bs] batch = [item for _, item in batch_slice] + batch_jobs.append((batch_num, batch_slice, batch)) + + # 只有一个批次时走串行,减少线程池创建开销与日志/日志文件的不可控交织 + if total_batches <= 1 or CONTENT_UNDERSTANDING_MAX_WORKERS <= 1: + for batch_num, batch_slice, batch in batch_jobs: + logger.info( + f"[analyze_products] Processing batch {batch_num}/{total_batches}, " + f"size={len(batch)}, target_lang={target_lang}" + ) + batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) + + for (original_idx, product), item in zip(batch_slice, batch_results): + results_by_index[original_idx] = item + title_input = str(item.get("title_input") or "").strip() + if not title_input: + continue + if item.get("error"): + # 不缓存错误结果,避免放大临时故障 + continue + try: + _set_cached_anchor_result(product, target_lang, item) + except Exception: + # 已在内部记录 warning + pass + else: + max_workers = min(CONTENT_UNDERSTANDING_MAX_WORKERS, len(batch_jobs)) logger.info( - f"[analyze_products] Processing batch {batch_num}/{total_batches}, " - f"size={len(batch)}, target_lang={target_lang}" + "[analyze_products] Using ThreadPoolExecutor for uncached batches: " + "max_workers=%s, total_batches=%s, bs=%s, target_lang=%s", + max_workers, + total_batches, + bs, + target_lang, ) - batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) - for (original_idx, product), item in zip(batch_slice, batch_results): - results_by_index[original_idx] = item - title_input = str(item.get("title_input") or "").strip() - if not title_input: - continue - if item.get("error"): - # 不缓存错误结果,避免放大临时故障 - continue - try: - _set_cached_anchor_result(product, target_lang, item) - except Exception: - # 已在内部记录 warning - pass + # 只把“LLM 调用 + markdown 解析”放到线程里;Redis get/set 保持在主线程,避免并发写入带来额外风险。 + # 注意:线程池是模块级单例,因此这里的 max_workers 主要用于日志语义(实际并发受单例池上限约束)。 + executor = _get_content_understanding_executor() + future_by_batch_num: Dict[int, Any] = {} + for batch_num, _batch_slice, batch in batch_jobs: + future_by_batch_num[batch_num] = executor.submit( + process_batch, batch, batch_num=batch_num, target_lang=target_lang + ) + + # 按 batch_num 回填,确保输出稳定(results_by_index 是按原始 input index 映射的) + for batch_num, batch_slice, _batch in batch_jobs: + batch_results = future_by_batch_num[batch_num].result() + for (original_idx, product), item in zip(batch_slice, batch_results): + results_by_index[original_idx] = item + title_input = str(item.get("title_input") or "").strip() + if not title_input: + continue + if item.get("error"): + # 不缓存错误结果,避免放大临时故障 + continue + try: + _set_cached_anchor_result(product, target_lang, item) + except Exception: + # 已在内部记录 warning + pass return [item for item in results_by_index if item is not None] diff --git a/query/query_parser.py b/query/query_parser.py index 741d93c..a311d00 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -618,17 +618,3 @@ class QueryParser: queries.append(translation) return queries - - def update_rewrite_rules(self, rules: Dict[str, str]) -> None: - """ - Update query rewrite rules. - - Args: - rules: Dictionary of pattern -> replacement mappings - """ - for pattern, replacement in rules.items(): - self.rewriter.add_rule(pattern, replacement) - - def get_rewrite_rules(self) -> Dict[str, str]: - """Get current rewrite rules.""" - return self.rewriter.get_rules() diff --git a/tests/test_process_products_batching.py b/tests/test_process_products_batching.py index d8d7ef4..319ce88 100644 --- a/tests/test_process_products_batching.py +++ b/tests/test_process_products_batching.py @@ -45,7 +45,8 @@ def test_analyze_products_caps_batch_size_to_20(monkeypatch): ) assert len(out) == 45 - assert seen_batch_sizes == [20, 20, 5] + # 并发执行时 batch 调用顺序可能变化,因此校验“批大小集合”而不是严格顺序 + assert sorted(seen_batch_sizes) == [5, 20, 20] def test_analyze_products_uses_min_batch_size_1(monkeypatch): -- libgit2 0.21.2