Commit 41f0b2e9fb4fe2301a797139f1d9530783139f72
1 parent: 86d8358b
product_enrich支持并发
Showing 7 changed files with 121 additions and 99 deletions
Show diff stats
api/routes/admin.py
| @@ -3,7 +3,6 @@ Admin API routes for configuration and management. | @@ -3,7 +3,6 @@ Admin API routes for configuration and management. | ||
| 3 | """ | 3 | """ |
| 4 | 4 | ||
| 5 | from fastapi import APIRouter, HTTPException, Request | 5 | from fastapi import APIRouter, HTTPException, Request |
| 6 | -from typing import Dict | ||
| 7 | 6 | ||
| 8 | from ..models import HealthResponse, ErrorResponse | 7 | from ..models import HealthResponse, ErrorResponse |
| 9 | from indexer.mapping_generator import get_tenant_index_name | 8 | from indexer.mapping_generator import get_tenant_index_name |
| @@ -74,49 +73,6 @@ async def get_configuration_meta(): | @@ -74,49 +73,6 @@ async def get_configuration_meta(): | ||
| 74 | raise HTTPException(status_code=500, detail=str(e)) | 73 | raise HTTPException(status_code=500, detail=str(e)) |
| 75 | 74 | ||
| 76 | 75 | ||
| 77 | -@router.post("/rewrite-rules") | ||
| 78 | -async def update_rewrite_rules(rules: Dict[str, str]): | ||
| 79 | - """ | ||
| 80 | - Update query rewrite rules. | ||
| 81 | - | ||
| 82 | - Args: | ||
| 83 | - rules: Dictionary of pattern -> replacement mappings | ||
| 84 | - """ | ||
| 85 | - try: | ||
| 86 | - from ..app import get_query_parser | ||
| 87 | - | ||
| 88 | - query_parser = get_query_parser() | ||
| 89 | - query_parser.update_rewrite_rules(rules) | ||
| 90 | - | ||
| 91 | - return { | ||
| 92 | - "status": "success", | ||
| 93 | - "message": f"Updated {len(rules)} rewrite rules" | ||
| 94 | - } | ||
| 95 | - | ||
| 96 | - except Exception as e: | ||
| 97 | - raise HTTPException(status_code=500, detail=str(e)) | ||
| 98 | - | ||
| 99 | - | ||
| 100 | -@router.get("/rewrite-rules") | ||
| 101 | -async def get_rewrite_rules(): | ||
| 102 | - """ | ||
| 103 | - Get current query rewrite rules. | ||
| 104 | - """ | ||
| 105 | - try: | ||
| 106 | - from ..app import get_query_parser | ||
| 107 | - | ||
| 108 | - query_parser = get_query_parser() | ||
| 109 | - rules = query_parser.get_rewrite_rules() | ||
| 110 | - | ||
| 111 | - return { | ||
| 112 | - "rules": rules, | ||
| 113 | - "count": len(rules) | ||
| 114 | - } | ||
| 115 | - | ||
| 116 | - except Exception as e: | ||
| 117 | - raise HTTPException(status_code=500, detail=str(e)) | ||
| 118 | - | ||
| 119 | - | ||
| 120 | @router.get("/stats") | 76 | @router.get("/stats") |
| 121 | async def get_index_stats(http_request: Request): | 77 | async def get_index_stats(http_request: Request): |
| 122 | """ | 78 | """ |
config/config.yaml
| @@ -9,6 +9,10 @@ es_index_name: "search_products" | @@ -9,6 +9,10 @@ es_index_name: "search_products" | ||
| 9 | assets: | 9 | assets: |
| 10 | query_rewrite_dictionary_path: "config/dictionaries/query_rewrite.dict" | 10 | query_rewrite_dictionary_path: "config/dictionaries/query_rewrite.dict" |
| 11 | 11 | ||
| 12 | +# Product content understanding (LLM enrich-content) configuration | ||
| 13 | +product_enrich: | ||
| 14 | + max_workers: 40 | ||
| 15 | + | ||
| 12 | # ES Index Settings (基础设置) | 16 | # ES Index Settings (基础设置) |
| 13 | es_settings: | 17 | es_settings: |
| 14 | number_of_shards: 1 | 18 | number_of_shards: 1 |
config/loader.py
| @@ -34,6 +34,7 @@ from config.schema import ( | @@ -34,6 +34,7 @@ from config.schema import ( | ||
| 34 | IndexConfig, | 34 | IndexConfig, |
| 35 | InfrastructureConfig, | 35 | InfrastructureConfig, |
| 36 | QueryConfig, | 36 | QueryConfig, |
| 37 | + ProductEnrichConfig, | ||
| 37 | RedisSettings, | 38 | RedisSettings, |
| 38 | RerankConfig, | 39 | RerankConfig, |
| 39 | RerankServiceConfig, | 40 | RerankServiceConfig, |
| @@ -188,6 +189,11 @@ class AppConfigLoader: | @@ -188,6 +189,11 @@ class AppConfigLoader: | ||
| 188 | runtime_config = self._build_runtime_config() | 189 | runtime_config = self._build_runtime_config() |
| 189 | infrastructure_config = self._build_infrastructure_config(runtime_config.environment) | 190 | infrastructure_config = self._build_infrastructure_config(runtime_config.environment) |
| 190 | 191 | ||
| 192 | + product_enrich_raw = raw.get("product_enrich") if isinstance(raw.get("product_enrich"), dict) else {} | ||
| 193 | + product_enrich_config = ProductEnrichConfig( | ||
| 194 | + max_workers=int(product_enrich_raw.get("max_workers", 40)), | ||
| 195 | + ) | ||
| 196 | + | ||
| 191 | metadata = ConfigMetadata( | 197 | metadata = ConfigMetadata( |
| 192 | loaded_files=tuple(loaded_files), | 198 | loaded_files=tuple(loaded_files), |
| 193 | config_hash="", | 199 | config_hash="", |
| @@ -197,6 +203,7 @@ class AppConfigLoader: | @@ -197,6 +203,7 @@ class AppConfigLoader: | ||
| 197 | app_config = AppConfig( | 203 | app_config = AppConfig( |
| 198 | runtime=runtime_config, | 204 | runtime=runtime_config, |
| 199 | infrastructure=infrastructure_config, | 205 | infrastructure=infrastructure_config, |
| 206 | + product_enrich=product_enrich_config, | ||
| 200 | search=search_config, | 207 | search=search_config, |
| 201 | services=services_config, | 208 | services=services_config, |
| 202 | tenants=tenants_config, | 209 | tenants=tenants_config, |
| @@ -208,6 +215,7 @@ class AppConfigLoader: | @@ -208,6 +215,7 @@ class AppConfigLoader: | ||
| 208 | return AppConfig( | 215 | return AppConfig( |
| 209 | runtime=app_config.runtime, | 216 | runtime=app_config.runtime, |
| 210 | infrastructure=app_config.infrastructure, | 217 | infrastructure=app_config.infrastructure, |
| 218 | + product_enrich=app_config.product_enrich, | ||
| 211 | search=app_config.search, | 219 | search=app_config.search, |
| 212 | services=app_config.services, | 220 | services=app_config.services, |
| 213 | tenants=app_config.tenants, | 221 | tenants=app_config.tenants, |
| @@ -547,20 +555,9 @@ class AppConfigLoader: | @@ -547,20 +555,9 @@ class AppConfigLoader: | ||
| 547 | return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] | 555 | return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] |
| 548 | 556 | ||
| 549 | def _detect_deprecated_keys(self, raw: Dict[str, Any]) -> Iterable[str]: | 557 | def _detect_deprecated_keys(self, raw: Dict[str, Any]) -> Iterable[str]: |
| 550 | - tenant_raw = raw.get("tenant_config") if isinstance(raw.get("tenant_config"), dict) else {} | ||
| 551 | - for key in ("default",): | ||
| 552 | - item = tenant_raw.get(key) | ||
| 553 | - if isinstance(item, dict): | ||
| 554 | - for deprecated in ("translate_to_en", "translate_to_zh"): | ||
| 555 | - if deprecated in item: | ||
| 556 | - yield f"tenant_config.{key}.{deprecated}" | ||
| 557 | - tenants = tenant_raw.get("tenants") if isinstance(tenant_raw.get("tenants"), dict) else {} | ||
| 558 | - for tenant_id, cfg in tenants.items(): | ||
| 559 | - if not isinstance(cfg, dict): | ||
| 560 | - continue | ||
| 561 | - for deprecated in ("translate_to_en", "translate_to_zh"): | ||
| 562 | - if deprecated in cfg: | ||
| 563 | - yield f"tenant_config.tenants.{tenant_id}.{deprecated}" | 558 | + # Translation-era legacy flags have been removed; keep the hook for future |
| 559 | + # deprecations, but currently no deprecated keys are detected. | ||
| 560 | + return () | ||
| 564 | 561 | ||
| 565 | 562 | ||
| 566 | @lru_cache(maxsize=1) | 563 | @lru_cache(maxsize=1) |
config/schema.py
| @@ -240,6 +240,13 @@ class InfrastructureConfig: | @@ -240,6 +240,13 @@ class InfrastructureConfig: | ||
| 240 | 240 | ||
| 241 | 241 | ||
| 242 | @dataclass(frozen=True) | 242 | @dataclass(frozen=True) |
| 243 | +class ProductEnrichConfig: | ||
| 244 | + """Configuration for LLM-based product content understanding (enrich-content).""" | ||
| 245 | + | ||
| 246 | + max_workers: int = 40 | ||
| 247 | + | ||
| 248 | + | ||
| 249 | +@dataclass(frozen=True) | ||
| 243 | class RuntimeConfig: | 250 | class RuntimeConfig: |
| 244 | environment: str = "prod" | 251 | environment: str = "prod" |
| 245 | index_namespace: str = "" | 252 | index_namespace: str = "" |
| @@ -275,6 +282,7 @@ class AppConfig: | @@ -275,6 +282,7 @@ class AppConfig: | ||
| 275 | 282 | ||
| 276 | runtime: RuntimeConfig | 283 | runtime: RuntimeConfig |
| 277 | infrastructure: InfrastructureConfig | 284 | infrastructure: InfrastructureConfig |
| 285 | + product_enrich: ProductEnrichConfig | ||
| 278 | search: SearchConfig | 286 | search: SearchConfig |
| 279 | services: ServicesConfig | 287 | services: ServicesConfig |
| 280 | tenants: TenantCatalogConfig | 288 | tenants: TenantCatalogConfig |
indexer/product_enrich.py
| @@ -12,8 +12,11 @@ import logging | @@ -12,8 +12,11 @@ import logging | ||
| 12 | import re | 12 | import re |
| 13 | import time | 13 | import time |
| 14 | import hashlib | 14 | import hashlib |
| 15 | +import uuid | ||
| 16 | +import threading | ||
| 15 | from collections import OrderedDict | 17 | from collections import OrderedDict |
| 16 | from datetime import datetime | 18 | from datetime import datetime |
| 19 | +from concurrent.futures import ThreadPoolExecutor | ||
| 17 | from typing import List, Dict, Tuple, Any, Optional | 20 | from typing import List, Dict, Tuple, Any, Optional |
| 18 | 21 | ||
| 19 | import redis | 22 | import redis |
| @@ -31,6 +34,9 @@ from indexer.product_enrich_prompts import ( | @@ -31,6 +34,9 @@ from indexer.product_enrich_prompts import ( | ||
| 31 | 34 | ||
| 32 | # 配置 | 35 | # 配置 |
| 33 | BATCH_SIZE = 20 | 36 | BATCH_SIZE = 20 |
| 37 | +# enrich-content LLM 批次并发 worker 上限(线程池;仅对 uncached batch 并发) | ||
| 38 | +_APP_CONFIG = get_app_config() | ||
| 39 | +CONTENT_UNDERSTANDING_MAX_WORKERS = int(_APP_CONFIG.product_enrich.max_workers) | ||
| 34 | # 华北2(北京):https://dashscope.aliyuncs.com/compatible-mode/v1 | 40 | # 华北2(北京):https://dashscope.aliyuncs.com/compatible-mode/v1 |
| 35 | # 新加坡:https://dashscope-intl.aliyuncs.com/compatible-mode/v1 | 41 | # 新加坡:https://dashscope-intl.aliyuncs.com/compatible-mode/v1 |
| 36 | # 美国(弗吉尼亚):https://dashscope-us.aliyuncs.com/compatible-mode/v1 | 42 | # 美国(弗吉尼亚):https://dashscope-us.aliyuncs.com/compatible-mode/v1 |
| @@ -56,6 +62,24 @@ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | @@ -56,6 +62,24 @@ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") | ||
| 56 | log_file = LOG_DIR / f"product_enrich_{timestamp}.log" | 62 | log_file = LOG_DIR / f"product_enrich_{timestamp}.log" |
| 57 | verbose_log_file = LOG_DIR / "product_enrich_verbose.log" | 63 | verbose_log_file = LOG_DIR / "product_enrich_verbose.log" |
| 58 | _logged_shared_context_keys: "OrderedDict[str, None]" = OrderedDict() | 64 | _logged_shared_context_keys: "OrderedDict[str, None]" = OrderedDict() |
| 65 | +_logged_shared_context_lock = threading.Lock() | ||
| 66 | + | ||
| 67 | +_content_understanding_executor: Optional[ThreadPoolExecutor] = None | ||
| 68 | +_content_understanding_executor_lock = threading.Lock() | ||
| 69 | + | ||
| 70 | + | ||
| 71 | +def _get_content_understanding_executor() -> ThreadPoolExecutor: | ||
| 72 | + """ | ||
| 73 | + 使用模块级单例线程池,避免同一进程内多次请求叠加创建线程池导致并发失控。 | ||
| 74 | + """ | ||
| 75 | + global _content_understanding_executor | ||
| 76 | + with _content_understanding_executor_lock: | ||
| 77 | + if _content_understanding_executor is None: | ||
| 78 | + _content_understanding_executor = ThreadPoolExecutor( | ||
| 79 | + max_workers=CONTENT_UNDERSTANDING_MAX_WORKERS, | ||
| 80 | + thread_name_prefix="product-enrich-llm", | ||
| 81 | + ) | ||
| 82 | + return _content_understanding_executor | ||
| 59 | 83 | ||
| 60 | # 主日志 logger:执行流程、批次信息等 | 84 | # 主日志 logger:执行流程、批次信息等 |
| 61 | logger = logging.getLogger("product_enrich") | 85 | logger = logging.getLogger("product_enrich") |
| @@ -91,7 +115,7 @@ logger.info("Verbose LLM logs are written to: %s", verbose_log_file) | @@ -91,7 +115,7 @@ logger.info("Verbose LLM logs are written to: %s", verbose_log_file) | ||
| 91 | 115 | ||
| 92 | 116 | ||
| 93 | # Redis 缓存(用于 anchors / 语义属性) | 117 | # Redis 缓存(用于 anchors / 语义属性) |
| 94 | -_REDIS_CONFIG = get_app_config().infrastructure.redis | 118 | +_REDIS_CONFIG = _APP_CONFIG.infrastructure.redis |
| 95 | ANCHOR_CACHE_PREFIX = _REDIS_CONFIG.anchor_cache_prefix | 119 | ANCHOR_CACHE_PREFIX = _REDIS_CONFIG.anchor_cache_prefix |
| 96 | ANCHOR_CACHE_EXPIRE_DAYS = int(_REDIS_CONFIG.anchor_cache_expire_days) | 120 | ANCHOR_CACHE_EXPIRE_DAYS = int(_REDIS_CONFIG.anchor_cache_expire_days) |
| 97 | _anchor_redis: Optional[redis.Redis] = None | 121 | _anchor_redis: Optional[redis.Redis] = None |
| @@ -243,19 +267,21 @@ def _hash_text(text: str) -> str: | @@ -243,19 +267,21 @@ def _hash_text(text: str) -> str: | ||
| 243 | 267 | ||
| 244 | 268 | ||
| 245 | def _mark_shared_context_logged_once(shared_context_key: str) -> bool: | 269 | def _mark_shared_context_logged_once(shared_context_key: str) -> bool: |
| 246 | - if shared_context_key in _logged_shared_context_keys: | ||
| 247 | - _logged_shared_context_keys.move_to_end(shared_context_key) | ||
| 248 | - return False | 270 | + with _logged_shared_context_lock: |
| 271 | + if shared_context_key in _logged_shared_context_keys: | ||
| 272 | + _logged_shared_context_keys.move_to_end(shared_context_key) | ||
| 273 | + return False | ||
| 249 | 274 | ||
| 250 | - _logged_shared_context_keys[shared_context_key] = None | ||
| 251 | - if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE: | ||
| 252 | - _logged_shared_context_keys.popitem(last=False) | ||
| 253 | - return True | 275 | + _logged_shared_context_keys[shared_context_key] = None |
| 276 | + if len(_logged_shared_context_keys) > LOGGED_SHARED_CONTEXT_CACHE_SIZE: | ||
| 277 | + _logged_shared_context_keys.popitem(last=False) | ||
| 278 | + return True | ||
| 254 | 279 | ||
| 255 | 280 | ||
| 256 | def reset_logged_shared_context_keys() -> None: | 281 | def reset_logged_shared_context_keys() -> None: |
| 257 | """测试辅助:清理已记录的共享 prompt key。""" | 282 | """测试辅助:清理已记录的共享 prompt key。""" |
| 258 | - _logged_shared_context_keys.clear() | 283 | + with _logged_shared_context_lock: |
| 284 | + _logged_shared_context_keys.clear() | ||
| 259 | 285 | ||
| 260 | 286 | ||
| 261 | def create_prompt( | 287 | def create_prompt( |
| @@ -626,7 +652,9 @@ def process_batch( | @@ -626,7 +652,9 @@ def process_batch( | ||
| 626 | "final_results": results_with_ids, | 652 | "final_results": results_with_ids, |
| 627 | } | 653 | } |
| 628 | 654 | ||
| 629 | - batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}.json" | 655 | + # 并发写 batch json 日志时,保证文件名唯一避免覆盖 |
| 656 | + batch_call_id = uuid.uuid4().hex[:12] | ||
| 657 | + batch_log_file = LOG_DIR / f"batch_{batch_num:04d}_{timestamp}_{batch_call_id}.json" | ||
| 630 | with open(batch_log_file, "w", encoding="utf-8") as f: | 658 | with open(batch_log_file, "w", encoding="utf-8") as f: |
| 631 | json.dump(batch_log, f, ensure_ascii=False, indent=2) | 659 | json.dump(batch_log, f, ensure_ascii=False, indent=2) |
| 632 | 660 | ||
| @@ -708,28 +736,70 @@ def analyze_products( | @@ -708,28 +736,70 @@ def analyze_products( | ||
| 708 | bs = max(1, min(req_bs, BATCH_SIZE)) | 736 | bs = max(1, min(req_bs, BATCH_SIZE)) |
| 709 | total_batches = (len(uncached_items) + bs - 1) // bs | 737 | total_batches = (len(uncached_items) + bs - 1) // bs |
| 710 | 738 | ||
| 739 | + batch_jobs: List[Tuple[int, List[Tuple[int, Dict[str, str]]], List[Dict[str, str]]]] = [] | ||
| 711 | for i in range(0, len(uncached_items), bs): | 740 | for i in range(0, len(uncached_items), bs): |
| 712 | batch_num = i // bs + 1 | 741 | batch_num = i // bs + 1 |
| 713 | batch_slice = uncached_items[i : i + bs] | 742 | batch_slice = uncached_items[i : i + bs] |
| 714 | batch = [item for _, item in batch_slice] | 743 | batch = [item for _, item in batch_slice] |
| 744 | + batch_jobs.append((batch_num, batch_slice, batch)) | ||
| 745 | + | ||
| 746 | + # 只有一个批次时走串行,减少线程池创建开销与日志/日志文件的不可控交织 | ||
| 747 | + if total_batches <= 1 or CONTENT_UNDERSTANDING_MAX_WORKERS <= 1: | ||
| 748 | + for batch_num, batch_slice, batch in batch_jobs: | ||
| 749 | + logger.info( | ||
| 750 | + f"[analyze_products] Processing batch {batch_num}/{total_batches}, " | ||
| 751 | + f"size={len(batch)}, target_lang={target_lang}" | ||
| 752 | + ) | ||
| 753 | + batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) | ||
| 754 | + | ||
| 755 | + for (original_idx, product), item in zip(batch_slice, batch_results): | ||
| 756 | + results_by_index[original_idx] = item | ||
| 757 | + title_input = str(item.get("title_input") or "").strip() | ||
| 758 | + if not title_input: | ||
| 759 | + continue | ||
| 760 | + if item.get("error"): | ||
| 761 | + # 不缓存错误结果,避免放大临时故障 | ||
| 762 | + continue | ||
| 763 | + try: | ||
| 764 | + _set_cached_anchor_result(product, target_lang, item) | ||
| 765 | + except Exception: | ||
| 766 | + # 已在内部记录 warning | ||
| 767 | + pass | ||
| 768 | + else: | ||
| 769 | + max_workers = min(CONTENT_UNDERSTANDING_MAX_WORKERS, len(batch_jobs)) | ||
| 715 | logger.info( | 770 | logger.info( |
| 716 | - f"[analyze_products] Processing batch {batch_num}/{total_batches}, " | ||
| 717 | - f"size={len(batch)}, target_lang={target_lang}" | 771 | + "[analyze_products] Using ThreadPoolExecutor for uncached batches: " |
| 772 | + "max_workers=%s, total_batches=%s, bs=%s, target_lang=%s", | ||
| 773 | + max_workers, | ||
| 774 | + total_batches, | ||
| 775 | + bs, | ||
| 776 | + target_lang, | ||
| 718 | ) | 777 | ) |
| 719 | - batch_results = process_batch(batch, batch_num=batch_num, target_lang=target_lang) | ||
| 720 | 778 | ||
| 721 | - for (original_idx, product), item in zip(batch_slice, batch_results): | ||
| 722 | - results_by_index[original_idx] = item | ||
| 723 | - title_input = str(item.get("title_input") or "").strip() | ||
| 724 | - if not title_input: | ||
| 725 | - continue | ||
| 726 | - if item.get("error"): | ||
| 727 | - # 不缓存错误结果,避免放大临时故障 | ||
| 728 | - continue | ||
| 729 | - try: | ||
| 730 | - _set_cached_anchor_result(product, target_lang, item) | ||
| 731 | - except Exception: | ||
| 732 | - # 已在内部记录 warning | ||
| 733 | - pass | 779 | + # 只把“LLM 调用 + markdown 解析”放到线程里;Redis get/set 保持在主线程,避免并发写入带来额外风险。 |
| 780 | + # 注意:线程池是模块级单例,因此这里的 max_workers 主要用于日志语义(实际并发受单例池上限约束)。 | ||
| 781 | + executor = _get_content_understanding_executor() | ||
| 782 | + future_by_batch_num: Dict[int, Any] = {} | ||
| 783 | + for batch_num, _batch_slice, batch in batch_jobs: | ||
| 784 | + future_by_batch_num[batch_num] = executor.submit( | ||
| 785 | + process_batch, batch, batch_num=batch_num, target_lang=target_lang | ||
| 786 | + ) | ||
| 787 | + | ||
| 788 | + # 按 batch_num 回填,确保输出稳定(results_by_index 是按原始 input index 映射的) | ||
| 789 | + for batch_num, batch_slice, _batch in batch_jobs: | ||
| 790 | + batch_results = future_by_batch_num[batch_num].result() | ||
| 791 | + for (original_idx, product), item in zip(batch_slice, batch_results): | ||
| 792 | + results_by_index[original_idx] = item | ||
| 793 | + title_input = str(item.get("title_input") or "").strip() | ||
| 794 | + if not title_input: | ||
| 795 | + continue | ||
| 796 | + if item.get("error"): | ||
| 797 | + # 不缓存错误结果,避免放大临时故障 | ||
| 798 | + continue | ||
| 799 | + try: | ||
| 800 | + _set_cached_anchor_result(product, target_lang, item) | ||
| 801 | + except Exception: | ||
| 802 | + # 已在内部记录 warning | ||
| 803 | + pass | ||
| 734 | 804 | ||
| 735 | return [item for item in results_by_index if item is not None] | 805 | return [item for item in results_by_index if item is not None] |
query/query_parser.py
| @@ -618,17 +618,3 @@ class QueryParser: | @@ -618,17 +618,3 @@ class QueryParser: | ||
| 618 | queries.append(translation) | 618 | queries.append(translation) |
| 619 | 619 | ||
| 620 | return queries | 620 | return queries |
| 621 | - | ||
| 622 | - def update_rewrite_rules(self, rules: Dict[str, str]) -> None: | ||
| 623 | - """ | ||
| 624 | - Update query rewrite rules. | ||
| 625 | - | ||
| 626 | - Args: | ||
| 627 | - rules: Dictionary of pattern -> replacement mappings | ||
| 628 | - """ | ||
| 629 | - for pattern, replacement in rules.items(): | ||
| 630 | - self.rewriter.add_rule(pattern, replacement) | ||
| 631 | - | ||
| 632 | - def get_rewrite_rules(self) -> Dict[str, str]: | ||
| 633 | - """Get current rewrite rules.""" | ||
| 634 | - return self.rewriter.get_rules() |
tests/test_process_products_batching.py
| @@ -45,7 +45,8 @@ def test_analyze_products_caps_batch_size_to_20(monkeypatch): | @@ -45,7 +45,8 @@ def test_analyze_products_caps_batch_size_to_20(monkeypatch): | ||
| 45 | ) | 45 | ) |
| 46 | 46 | ||
| 47 | assert len(out) == 45 | 47 | assert len(out) == 45 |
| 48 | - assert seen_batch_sizes == [20, 20, 5] | 48 | + # 并发执行时 batch 调用顺序可能变化,因此校验“批大小集合”而不是严格顺序 |
| 49 | + assert sorted(seen_batch_sizes) == [5, 20, 20] | ||
| 49 | 50 | ||
| 50 | 51 | ||
| 51 | def test_analyze_products_uses_min_batch_size_1(monkeypatch): | 52 | def test_analyze_products_uses_min_batch_size_1(monkeypatch): |