Compare View
Commits (2)
Showing 14 changed files
api/app.py
| ... | ... | @@ -40,14 +40,13 @@ limiter = Limiter(key_func=get_remote_address) |
| 40 | 40 | # Add parent directory to path |
| 41 | 41 | sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| 42 | 42 | |
| 43 | -from config import ConfigLoader, SearchConfig | |
| 44 | 43 | from config.env_config import ES_CONFIG |
| 45 | 44 | from utils import ESClient |
| 46 | 45 | from search import Searcher |
| 46 | +from search.query_config import DEFAULT_INDEX_NAME | |
| 47 | 47 | from query import QueryParser |
| 48 | 48 | |
| 49 | 49 | # Global instances |
| 50 | -_config: Optional[SearchConfig] = None | |
| 51 | 50 | _es_client: Optional[ESClient] = None |
| 52 | 51 | _searcher: Optional[Searcher] = None |
| 53 | 52 | _query_parser: Optional[QueryParser] = None |
| ... | ... | @@ -60,20 +59,11 @@ def init_service(es_host: str = "http://localhost:9200"): |
| 60 | 59 | Args: |
| 61 | 60 | es_host: Elasticsearch host URL |
| 62 | 61 | """ |
| 63 | - global _config, _es_client, _searcher, _query_parser | |
| 62 | + global _es_client, _searcher, _query_parser | |
| 64 | 63 | |
| 65 | 64 | start_time = time.time() |
| 66 | 65 | logger.info("Initializing search service (multi-tenant)") |
| 67 | 66 | |
| 68 | - # Load and validate configuration | |
| 69 | - logger.info("Loading configuration...") | |
| 70 | - config_loader = ConfigLoader("config/config.yaml") | |
| 71 | - _config = config_loader.load_config() | |
| 72 | - errors = config_loader.validate_config(_config) | |
| 73 | - if errors: | |
| 74 | - raise ValueError(f"Configuration validation failed: {errors}") | |
| 75 | - logger.info(f"Configuration loaded: {_config.es_index_name}") | |
| 76 | - | |
| 77 | 67 | # Get ES credentials |
| 78 | 68 | es_username = os.getenv('ES_USERNAME') or ES_CONFIG.get('username') |
| 79 | 69 | es_password = os.getenv('ES_PASSWORD') or ES_CONFIG.get('password') |
| ... | ... | @@ -91,20 +81,15 @@ def init_service(es_host: str = "http://localhost:9200"): |
| 91 | 81 | |
| 92 | 82 | # Initialize components |
| 93 | 83 | logger.info("Initializing query parser...") |
| 94 | - _query_parser = QueryParser(_config) | |
| 84 | + _query_parser = QueryParser() | |
| 95 | 85 | |
| 96 | 86 | logger.info("Initializing searcher...") |
| 97 | - _searcher = Searcher(_config, _es_client, _query_parser) | |
| 87 | + _searcher = Searcher(_es_client, _query_parser, index_name=DEFAULT_INDEX_NAME) | |
| 98 | 88 | |
| 99 | 89 | elapsed = time.time() - start_time |
| 100 | - logger.info(f"Search service ready! (took {elapsed:.2f}s)") | |
| 90 | + logger.info(f"Search service ready! (took {elapsed:.2f}s) | Index: {DEFAULT_INDEX_NAME}") | |
| 101 | 91 | |
| 102 | 92 | |
| 103 | -def get_config() -> SearchConfig: | |
| 104 | - """Get search engine configuration.""" | |
| 105 | - if _config is None: | |
| 106 | - raise RuntimeError("Service not initialized") | |
| 107 | - return _config | |
| 108 | 93 | |
| 109 | 94 | |
| 110 | 95 | def get_es_client() -> ESClient: |
| ... | ... | @@ -243,8 +228,8 @@ async def health_check(request: Request): |
| 243 | 228 | """Health check endpoint.""" |
| 244 | 229 | try: |
| 245 | 230 | # Check if services are initialized |
| 246 | - get_config() | |
| 247 | 231 | get_es_client() |
| 232 | + get_searcher() | |
| 248 | 233 | |
| 249 | 234 | return { |
| 250 | 235 | "status": "healthy", | ... | ... |
api/models.py
| ... | ... | @@ -69,6 +69,10 @@ class SearchRequest(BaseModel): |
| 69 | 69 | query: str = Field(..., description="搜索查询字符串,支持布尔表达式(AND, OR, RANK, ANDNOT)") |
| 70 | 70 | size: int = Field(10, ge=1, le=100, description="返回结果数量") |
| 71 | 71 | from_: int = Field(0, ge=0, alias="from", description="分页偏移量") |
| 72 | + language: Literal["zh", "en"] = Field( | |
| 73 | + "zh", | |
| 74 | + description="响应语言:'zh'(中文)或 'en'(英文),用于选择 title/description/vendor 等多语言字段" | |
| 75 | + ) | |
| 72 | 76 | |
| 73 | 77 | # 过滤器 - 精确匹配和多值匹配 |
| 74 | 78 | filters: Optional[Dict[str, Union[str, int, bool, List[Union[str, int]]]]] = Field( |
| ... | ... | @@ -175,28 +179,53 @@ class FacetResult(BaseModel): |
| 175 | 179 | class SkuResult(BaseModel): |
| 176 | 180 | """SKU 结果""" |
| 177 | 181 | sku_id: str = Field(..., description="SKU ID") |
| 178 | - title: Optional[str] = Field(None, description="SKU标题") | |
| 182 | + # 与 ES nested skus 结构对齐 | |
| 179 | 183 | price: Optional[float] = Field(None, description="价格") |
| 180 | 184 | compare_at_price: Optional[float] = Field(None, description="原价") |
| 181 | - sku: Optional[str] = Field(None, description="SKU编码") | |
| 185 | + sku_code: Optional[str] = Field(None, description="SKU编码") | |
| 182 | 186 | stock: int = Field(0, description="库存数量") |
| 183 | - options: Optional[Dict[str, Any]] = Field(None, description="选项(颜色、尺寸等)") | |
| 187 | + weight: Optional[float] = Field(None, description="重量") | |
| 188 | + weight_unit: Optional[str] = Field(None, description="重量单位") | |
| 189 | + option1_value: Optional[str] = Field(None, description="选项1取值(如颜色)") | |
| 190 | + option2_value: Optional[str] = Field(None, description="选项2取值(如尺码)") | |
| 191 | + option3_value: Optional[str] = Field(None, description="选项3取值") | |
| 192 | + image_src: Optional[str] = Field(None, description="SKU图片地址") | |
| 184 | 193 | |
| 185 | 194 | |
| 186 | 195 | class SpuResult(BaseModel): |
| 187 | 196 | """SPU 搜索结果""" |
| 188 | 197 | spu_id: str = Field(..., description="SPU ID") |
| 189 | 198 | title: Optional[str] = Field(None, description="商品标题") |
| 199 | + brief: Optional[str] = Field(None, description="商品短描述") | |
| 190 | 200 | handle: Optional[str] = Field(None, description="商品handle") |
| 191 | 201 | description: Optional[str] = Field(None, description="商品描述") |
| 192 | 202 | vendor: Optional[str] = Field(None, description="供应商/品牌") |
| 193 | - category: Optional[str] = Field(None, description="类目") | |
| 194 | - tags: Optional[str] = Field(None, description="标签") | |
| 203 | + category: Optional[str] = Field(None, description="类目(兼容字段,等同于category_name)") | |
| 204 | + category_path: Optional[str] = Field(None, description="类目路径(多级,用于面包屑)") | |
| 205 | + category_name: Optional[str] = Field(None, description="类目名称(展示用)") | |
| 206 | + category_id: Optional[str] = Field(None, description="类目ID") | |
| 207 | + category_level: Optional[int] = Field(None, description="类目层级") | |
| 208 | + category1_name: Optional[str] = Field(None, description="一级类目名称") | |
| 209 | + category2_name: Optional[str] = Field(None, description="二级类目名称") | |
| 210 | + category3_name: Optional[str] = Field(None, description="三级类目名称") | |
| 211 | + tags: Optional[List[str]] = Field(None, description="标签列表") | |
| 195 | 212 | price: Optional[float] = Field(None, description="价格(min_price)") |
| 196 | 213 | compare_at_price: Optional[float] = Field(None, description="原价") |
| 197 | 214 | currency: str = Field("USD", description="货币单位") |
| 198 | 215 | image_url: Optional[str] = Field(None, description="主图URL") |
| 199 | 216 | in_stock: bool = Field(True, description="是否有库存") |
| 217 | + # SKU 扁平化信息 | |
| 218 | + sku_prices: Optional[List[float]] = Field(None, description="所有SKU价格列表") | |
| 219 | + sku_weights: Optional[List[int]] = Field(None, description="所有SKU重量列表") | |
| 220 | + sku_weight_units: Optional[List[str]] = Field(None, description="所有SKU重量单位列表") | |
| 221 | + total_inventory: Optional[int] = Field(None, description="总库存") | |
| 222 | + option1_name: Optional[str] = Field(None, description="选项1名称(如颜色)") | |
| 223 | + option2_name: Optional[str] = Field(None, description="选项2名称(如尺码)") | |
| 224 | + option3_name: Optional[str] = Field(None, description="选项3名称") | |
| 225 | + specifications: Optional[List[Dict[str, Any]]] = Field( | |
| 226 | + None, | |
| 227 | + description="规格列表(与 ES specifications 字段对应)" | |
| 228 | + ) | |
| 200 | 229 | skus: List[SkuResult] = Field(default_factory=list, description="SKU列表") |
| 201 | 230 | relevance_score: float = Field(..., ge=0.0, description="相关性分数(ES原始分数)") |
| 202 | 231 | ... | ... |
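Note: with the new language field, clients only name the desired response language and the backend resolves the *_zh/*_en variants. An illustrative request body (filter field names and values are made up):

    payload = {
        "query": "连衣裙 AND 红色",                 # boolean operators per the field description
        "size": 10,
        "from": 0,                                  # serialized name of from_
        "language": "en",                           # "zh" (default) or "en"
        "filters": {"vendor": ["Nike", "Adidas"]},  # illustrative filter
    }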
api/result_formatter.py
| ... | ... | @@ -12,7 +12,8 @@ class ResultFormatter: |
| 12 | 12 | @staticmethod |
| 13 | 13 | def format_search_results( |
| 14 | 14 | es_hits: List[Dict[str, Any]], |
| 15 | - max_score: float = 1.0 | |
| 15 | + max_score: float = 1.0, | |
| 16 | + language: str = "zh" | |
| 16 | 17 | ) -> List[SpuResult]: |
| 17 | 18 | """ |
| 18 | 19 | Convert ES hits to SpuResult list. |
| ... | ... | @@ -25,6 +26,18 @@ class ResultFormatter: |
| 25 | 26 | List of SpuResult objects |
| 26 | 27 | """ |
| 27 | 28 | results = [] |
| 29 | + lang = (language or "zh").lower() | |
| 30 | + if lang not in ("zh", "en"): | |
| 31 | + lang = "en" | |
| 32 | + | |
| 33 | + def pick_lang_field(src: Dict[str, Any], base: str) -> Optional[str]: | |
| 34 | + """从 *_zh / *_en 字段中按语言选择一个值,若目标语言缺失则回退到另一种。""" | |
| 35 | + zh_val = src.get(f"{base}_zh") | |
| 36 | + en_val = src.get(f"{base}_en") | |
| 37 | + if lang == "zh": | |
| 38 | + return zh_val or en_val | |
| 39 | + else: | |
| 40 | + return en_val or zh_val | |
| 28 | 41 | |
| 29 | 42 | for hit in es_hits: |
| 30 | 43 | source = hit.get('_source', {}) |
| ... | ... | @@ -40,6 +53,14 @@ class ResultFormatter: |
| 40 | 53 | except (ValueError, TypeError): |
| 41 | 54 | relevance_score = 0.0 |
| 42 | 55 | |
| 56 | + # Multi-language fields | |
| 57 | + title = pick_lang_field(source, "title") | |
| 58 | + brief = pick_lang_field(source, "brief") | |
| 59 | + description = pick_lang_field(source, "description") | |
| 60 | + vendor = pick_lang_field(source, "vendor") | |
| 61 | + category_path = pick_lang_field(source, "category_path") | |
| 62 | + category_name = pick_lang_field(source, "category_name") | |
| 63 | + | |
| 43 | 64 | # Extract SKUs |
| 44 | 65 | skus = [] |
| 45 | 66 | skus_data = source.get('skus', []) |
| ... | ... | @@ -62,17 +83,33 @@ class ResultFormatter: |
| 62 | 83 | # Build SpuResult |
| 63 | 84 | spu = SpuResult( |
| 64 | 85 | spu_id=str(source.get('spu_id', '')), |
| 65 | - title=source.get('title'), | |
| 86 | + title=title, | |
| 87 | + brief=brief, | |
| 66 | 88 | handle=source.get('handle'), |
| 67 | - description=source.get('description'), | |
| 68 | - vendor=source.get('vendor'), | |
| 69 | - category=source.get('category'), | |
| 89 | + description=description, | |
| 90 | + vendor=vendor, | |
| 91 | + category=category_name, | |
| 92 | + category_path=category_path, | |
| 93 | + category_name=category_name, | |
| 94 | + category_id=source.get('category_id'), | |
| 95 | + category_level=source.get('category_level'), | |
| 96 | + category1_name=source.get('category1_name'), | |
| 97 | + category2_name=source.get('category2_name'), | |
| 98 | + category3_name=source.get('category3_name'), | |
| 70 | 99 | tags=source.get('tags'), |
| 71 | 100 | price=source.get('min_price'), |
| 72 | 101 | compare_at_price=source.get('compare_at_price'), |
| 73 | 102 | currency="USD", # Default currency |
| 74 | 103 | image_url=source.get('image_url'), |
| 75 | 104 | in_stock=in_stock, |
| 105 | + sku_prices=source.get('sku_prices'), | |
| 106 | + sku_weights=source.get('sku_weights'), | |
| 107 | + sku_weight_units=source.get('sku_weight_units'), | |
| 108 | + total_inventory=source.get('total_inventory'), | |
| 109 | + option1_name=source.get('option1_name'), | |
| 110 | + option2_name=source.get('option2_name'), | |
| 111 | + option3_name=source.get('option3_name'), | |
| 112 | + specifications=source.get('specifications'), | |
| 76 | 113 | skus=skus, |
| 77 | 114 | relevance_score=relevance_score |
| 78 | 115 | ) |
| ... | ... | @@ -89,6 +126,11 @@ class ResultFormatter: |
| 89 | 126 | """ |
| 90 | 127 | Format ES aggregations to FacetResult list. |
| 91 | 128 | |
| 129 | + 支持: | |
| 130 | + 1. 普通terms聚合 | |
| 131 | + 2. range聚合 | |
| 132 | + 3. specifications嵌套聚合(按name分组,然后按value聚合) | |
| 133 | + | |
| 92 | 134 | Args: |
| 93 | 135 | es_aggregations: ES aggregations response |
| 94 | 136 | facet_configs: Facet configurations (optional) |
| ... | ... | @@ -100,6 +142,38 @@ class ResultFormatter: |
| 100 | 142 | |
| 101 | 143 | for field_name, agg_data in es_aggregations.items(): |
| 102 | 144 | display_field = field_name[:-6] if field_name.endswith("_facet") else field_name |
| 145 | + | |
| 146 | + # 处理specifications嵌套分面 | |
| 147 | + if field_name == "specifications_facet" and 'by_name' in agg_data: | |
| 148 | + # specifications嵌套聚合:按name分组,每个name下有value_counts | |
| 149 | + by_name_agg = agg_data['by_name'] | |
| 150 | + if 'buckets' in by_name_agg: | |
| 151 | + for name_bucket in by_name_agg['buckets']: | |
| 152 | + name = name_bucket['key'] | |
| 153 | + value_counts = name_bucket.get('value_counts', {}) | |
| 154 | + | |
| 155 | + values = [] | |
| 156 | + if 'buckets' in value_counts: | |
| 157 | + for value_bucket in value_counts['buckets']: | |
| 158 | + value = FacetValue( | |
| 159 | + value=value_bucket['key'], | |
| 160 | + label=str(value_bucket['key']), | |
| 161 | + count=value_bucket['doc_count'], | |
| 162 | + selected=False | |
| 163 | + ) | |
| 164 | + values.append(value) | |
| 165 | + | |
| 166 | + # 为每个name创建一个分面结果 | |
| 167 | + facet = FacetResult( | |
| 168 | + field=f"specifications.{name}", | |
| 169 | + label=str(name), # 使用name作为label,如"颜色"、"尺寸" | |
| 170 | + type="terms", | |
| 171 | + values=values, | |
| 172 | + total_count=name_bucket['doc_count'] | |
| 173 | + ) | |
| 174 | + facets.append(facet) | |
| 175 | + continue | |
| 176 | + | |
| 103 | 177 | # Handle terms aggregation |
| 104 | 178 | if 'buckets' in agg_data: |
| 105 | 179 | values = [] | ... | ... |
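Note: the pick_lang_field helper gives per-field fallback across languages; a quick illustration with a made-up _source:

    source = {"title_zh": "红色连衣裙", "title_en": None, "vendor_en": "Acme"}
    # language="en": title_en is empty, so the zh value is returned -> "红色连衣裙"
    # language="zh": vendor_zh is missing, so the en value is returned -> "Acme"
    # Only the resolved values are exposed as SpuResult.title / SpuResult.vendor;
    # the raw *_zh / *_en fields are not passed through to the client.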
api/routes/search.py
| ... | ... | @@ -94,7 +94,8 @@ async def search(request: SearchRequest, http_request: Request): |
| 94 | 94 | context=context, |
| 95 | 95 | sort_by=request.sort_by, |
| 96 | 96 | sort_order=request.sort_order, |
| 97 | - debug=request.debug | |
| 97 | + debug=request.debug, | |
| 98 | + language=request.language, | |
| 98 | 99 | ) |
| 99 | 100 | |
| 100 | 101 | # Include performance summary in response | ... | ... |
docs/ES常用表达式.md
| ... | ... | @@ -7,3 +7,14 @@ GET /search_products/_search |
| 7 | 7 | } |
| 8 | 8 | } |
| 9 | 9 | |
| 10 | + | |
| 11 | +curl -u 'essa:4hOaLaf41y2VuI8y' -X GET 'http://localhost:9200/search_products/_search?pretty' -H 'Content-Type: application/json' -d '{ | |
| 12 | + "size": 5, | |
| 13 | + "query": { | |
| 14 | + "bool": { | |
| 15 | + "filter": [ | |
| 16 | + { "term": { "tenant_id": "162" } } | |
| 17 | + ] | |
| 18 | + } | |
| 19 | + } | |
| 20 | + }' | |
| 10 | 21 | \ No newline at end of file | ... | ... |
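Note: the same tenant-scoped query issued through the project's ESClient wrapper would look roughly like this (a sketch; it reuses the constructor and the search(index_name=..., body=..., size=...) signature that appear elsewhere in this change set, with the password redacted):

    from utils.es_client import ESClient

    es = ESClient(hosts=["http://localhost:9200"], username="essa", password="***")
    body = {"query": {"bool": {"filter": [{"term": {"tenant_id": "162"}}]}}}
    resp = es.search(index_name="search_products", body=body, size=5)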
indexer/mapping_generator.py
| ... | ... | @@ -51,23 +51,26 @@ def create_index_if_not_exists(es_client, index_name: str, mapping: Dict[str, An |
| 51 | 51 | Create Elasticsearch index if it doesn't exist. |
| 52 | 52 | |
| 53 | 53 | Args: |
| 54 | - es_client: Elasticsearch client instance | |
| 54 | + es_client: ESClient instance | |
| 55 | 55 | index_name: Name of the index to create |
| 56 | 56 | mapping: Index mapping configuration. If None, loads from default file. |
| 57 | 57 | |
| 58 | 58 | Returns: |
| 59 | 59 | True if index was created, False if it already exists |
| 60 | 60 | """ |
| 61 | - if es_client.indices.exists(index=index_name): | |
| 61 | + if es_client.index_exists(index_name): | |
| 62 | 62 | logger.info(f"Index '{index_name}' already exists") |
| 63 | 63 | return False |
| 64 | 64 | |
| 65 | 65 | if mapping is None: |
| 66 | 66 | mapping = load_mapping() |
| 67 | 67 | |
| 68 | - es_client.indices.create(index=index_name, body=mapping) | |
| 69 | - logger.info(f"Index '{index_name}' created successfully") | |
| 70 | - return True | |
| 68 | + if es_client.create_index(index_name, mapping): | |
| 69 | + logger.info(f"Index '{index_name}' created successfully") | |
| 70 | + return True | |
| 71 | + else: | |
| 72 | + logger.error(f"Failed to create index '{index_name}'") | |
| 73 | + return False | |
| 71 | 74 | |
| 72 | 75 | |
| 73 | 76 | def delete_index_if_exists(es_client, index_name: str) -> bool: |
| ... | ... | @@ -75,19 +78,22 @@ def delete_index_if_exists(es_client, index_name: str) -> bool: |
| 75 | 78 | Delete Elasticsearch index if it exists. |
| 76 | 79 | |
| 77 | 80 | Args: |
| 78 | - es_client: Elasticsearch client instance | |
| 81 | + es_client: ESClient instance | |
| 79 | 82 | index_name: Name of the index to delete |
| 80 | 83 | |
| 81 | 84 | Returns: |
| 82 | 85 | True if index was deleted, False if it didn't exist |
| 83 | 86 | """ |
| 84 | - if not es_client.indices.exists(index=index_name): | |
| 87 | + if not es_client.index_exists(index_name): | |
| 85 | 88 | logger.warning(f"Index '{index_name}' does not exist") |
| 86 | 89 | return False |
| 87 | 90 | |
| 88 | - es_client.indices.delete(index=index_name) | |
| 89 | - logger.info(f"Index '{index_name}' deleted successfully") | |
| 90 | - return True | |
| 91 | + if es_client.delete_index(index_name): | |
| 92 | + logger.info(f"Index '{index_name}' deleted successfully") | |
| 93 | + return True | |
| 94 | + else: | |
| 95 | + logger.error(f"Failed to delete index '{index_name}'") | |
| 96 | + return False | |
| 91 | 97 | |
| 92 | 98 | |
| 93 | 99 | def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bool: |
| ... | ... | @@ -95,18 +101,21 @@ def update_mapping(es_client, index_name: str, new_fields: Dict[str, Any]) -> bo |
| 95 | 101 | Update mapping for existing index (only adding new fields). |
| 96 | 102 | |
| 97 | 103 | Args: |
| 98 | - es_client: Elasticsearch client instance | |
| 104 | + es_client: ESClient instance | |
| 99 | 105 | index_name: Name of the index |
| 100 | 106 | new_fields: New field mappings to add |
| 101 | 107 | |
| 102 | 108 | Returns: |
| 103 | 109 | True if successful |
| 104 | 110 | """ |
| 105 | - if not es_client.indices.exists(index=index_name): | |
| 111 | + if not es_client.index_exists(index_name): | |
| 106 | 112 | logger.error(f"Index '{index_name}' does not exist") |
| 107 | 113 | return False |
| 108 | 114 | |
| 109 | 115 | mapping = {"properties": new_fields} |
| 110 | - es_client.indices.put_mapping(index=index_name, body=mapping) | |
| 111 | - logger.info(f"Mapping updated for index '{index_name}'") | |
| 112 | - return True | |
| 116 | + if es_client.update_mapping(index_name, mapping): | |
| 117 | + logger.info(f"Mapping updated for index '{index_name}'") | |
| 118 | + return True | |
| 119 | + else: | |
| 120 | + logger.error(f"Failed to update mapping for index '{index_name}'") | |
| 121 | + return False | ... | ... |
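Note: since the module now goes through the ESClient wrapper (index_exists / create_index / delete_index / update_mapping), a full rebuild can be expressed with the helpers shown in this diff; a short sketch:

    from indexer.mapping_generator import (
        DEFAULT_INDEX_NAME,
        load_mapping,
        create_index_if_not_exists,
        delete_index_if_exists,
    )

    mapping = load_mapping()                               # default mapping file
    delete_index_if_exists(es_client, DEFAULT_INDEX_NAME)  # returns False if the index was absent
    create_index_if_not_exists(es_client, DEFAULT_INDEX_NAME, mapping)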
indexer/spu_transformer.py
| ... | ... | @@ -124,7 +124,7 @@ class SPUTransformer: |
| 124 | 124 | query = text(""" |
| 125 | 125 | SELECT |
| 126 | 126 | id, spu_id, shop_id, shoplazza_id, shoplazza_product_id, |
| 127 | - position, name, values, tenant_id, | |
| 127 | + position, name, `values`, tenant_id, | |
| 128 | 128 | creator, create_time, updater, update_time, deleted |
| 129 | 129 | FROM shoplazza_product_option |
| 130 | 130 | WHERE tenant_id = :tenant_id AND deleted = 0 | ... | ... |
query/query_parser.py
| ... | ... | @@ -8,8 +8,14 @@ from typing import Dict, List, Optional, Any |
| 8 | 8 | import numpy as np |
| 9 | 9 | import logging |
| 10 | 10 | |
| 11 | -from config import SearchConfig, QueryConfig | |
| 12 | 11 | from embeddings import BgeEncoder |
| 12 | +from search.query_config import ( | |
| 13 | + ENABLE_TEXT_EMBEDDING, | |
| 14 | + ENABLE_TRANSLATION, | |
| 15 | + REWRITE_DICTIONARY, | |
| 16 | + TRANSLATION_API_KEY, | |
| 17 | + TRANSLATION_SERVICE | |
| 18 | +) | |
| 13 | 19 | from .language_detector import LanguageDetector |
| 14 | 20 | from .translator import Translator |
| 15 | 21 | from .query_rewriter import QueryRewriter, QueryNormalizer |
| ... | ... | @@ -64,7 +70,6 @@ class QueryParser: |
| 64 | 70 | |
| 65 | 71 | def __init__( |
| 66 | 72 | self, |
| 67 | - config: SearchConfig, | |
| 68 | 73 | text_encoder: Optional[BgeEncoder] = None, |
| 69 | 74 | translator: Optional[Translator] = None |
| 70 | 75 | ): |
| ... | ... | @@ -72,24 +77,21 @@ class QueryParser: |
| 72 | 77 | Initialize query parser. |
| 73 | 78 | |
| 74 | 79 | Args: |
| 75 | - config: Search configuration | |
| 76 | 80 | text_encoder: Text embedding encoder (lazy loaded if not provided) |
| 77 | 81 | translator: Translator instance (lazy loaded if not provided) |
| 78 | 82 | """ |
| 79 | - self.config = config | |
| 80 | - self.query_config = config.query_config | |
| 81 | 83 | self._text_encoder = text_encoder |
| 82 | 84 | self._translator = translator |
| 83 | 85 | |
| 84 | 86 | # Initialize components |
| 85 | 87 | self.normalizer = QueryNormalizer() |
| 86 | 88 | self.language_detector = LanguageDetector() |
| 87 | - self.rewriter = QueryRewriter(self.query_config.rewrite_dictionary) | |
| 89 | + self.rewriter = QueryRewriter(REWRITE_DICTIONARY) | |
| 88 | 90 | |
| 89 | 91 | @property |
| 90 | 92 | def text_encoder(self) -> BgeEncoder: |
| 91 | 93 | """Lazy load text encoder.""" |
| 92 | - if self._text_encoder is None and self.query_config.enable_text_embedding: | |
| 94 | + if self._text_encoder is None and ENABLE_TEXT_EMBEDDING: | |
| 93 | 95 | logger.info("Initializing text encoder (lazy load)...") |
| 94 | 96 | self._text_encoder = BgeEncoder() |
| 95 | 97 | return self._text_encoder |
| ... | ... | @@ -97,13 +99,13 @@ class QueryParser: |
| 97 | 99 | @property |
| 98 | 100 | def translator(self) -> Translator: |
| 99 | 101 | """Lazy load translator.""" |
| 100 | - if self._translator is None and self.query_config.enable_translation: | |
| 102 | + if self._translator is None and ENABLE_TRANSLATION: | |
| 101 | 103 | logger.info("Initializing translator (lazy load)...") |
| 102 | 104 | self._translator = Translator( |
| 103 | - api_key=self.query_config.translation_api_key, | |
| 105 | + api_key=TRANSLATION_API_KEY, | |
| 104 | 106 | use_cache=True, |
| 105 | - glossary_id=getattr(self.query_config, 'translation_glossary_id', None), | |
| 106 | - translation_context=getattr(self.query_config, 'translation_context', 'e-commerce product search') | |
| 107 | + glossary_id=None, # Can be added to query_config if needed | |
| 108 | + translation_context='e-commerce product search' | |
| 107 | 109 | ) |
| 108 | 110 | return self._translator |
| 109 | 111 | |
| ... | ... | @@ -154,7 +156,7 @@ class QueryParser: |
| 154 | 156 | |
| 155 | 157 | # Stage 2: Query rewriting |
| 156 | 158 | rewritten = None |
| 157 | - if self.query_config.enable_query_rewrite: | |
| 159 | + if REWRITE_DICTIONARY: # Enable rewrite if dictionary exists | |
| 158 | 160 | rewritten = self.rewriter.rewrite(query_text) |
| 159 | 161 | if rewritten != query_text: |
| 160 | 162 | log_info(f"查询重写 | '{query_text}' -> '{rewritten}'") |
| ... | ... | @@ -171,26 +173,11 @@ class QueryParser: |
| 171 | 173 | |
| 172 | 174 | # Stage 4: Translation |
| 173 | 175 | translations = {} |
| 174 | - if self.query_config.enable_translation: | |
| 176 | + if ENABLE_TRANSLATION: | |
| 175 | 177 | try: |
| 176 | 178 | # Determine target languages for translation |
| 177 | - # If domain has language_field_mapping, only translate to languages in the mapping | |
| 178 | - # Otherwise, use all supported languages | |
| 179 | - target_langs_for_translation = self.query_config.supported_languages | |
| 180 | - | |
| 181 | - # Check if domain has language_field_mapping | |
| 182 | - domain_config = next( | |
| 183 | - (idx for idx in self.config.indexes if idx.name == domain), | |
| 184 | - None | |
| 185 | - ) | |
| 186 | - if domain_config and domain_config.language_field_mapping: | |
| 187 | - # Only translate to languages that exist in the mapping | |
| 188 | - available_languages = set(domain_config.language_field_mapping.keys()) | |
| 189 | - target_langs_for_translation = [ | |
| 190 | - lang for lang in self.query_config.supported_languages | |
| 191 | - if lang in available_languages | |
| 192 | - ] | |
| 193 | - log_debug(f"域 '{domain}' 有语言字段映射,将翻译到: {target_langs_for_translation}") | |
| 179 | + # Simplified: always translate to Chinese and English | |
| 180 | + target_langs_for_translation = ['zh', 'en'] | |
| 194 | 181 | |
| 195 | 182 | target_langs = self.translator.get_translation_needs( |
| 196 | 183 | detected_lang, |
| ... | ... | @@ -200,7 +187,7 @@ class QueryParser: |
| 200 | 187 | if target_langs: |
| 201 | 188 | log_info(f"开始翻译 | 源语言: {detected_lang} | 目标语言: {target_langs}") |
| 202 | 189 | # Use e-commerce context for better disambiguation |
| 203 | - translation_context = getattr(self.query_config, 'translation_context', 'e-commerce product search') | |
| 190 | + translation_context = 'e-commerce product search' | |
| 204 | 191 | translations = self.translator.translate_multi( |
| 205 | 192 | query_text, |
| 206 | 193 | target_langs, |
| ... | ... | @@ -223,7 +210,7 @@ class QueryParser: |
| 223 | 210 | # Stage 5: Text embedding |
| 224 | 211 | query_vector = None |
| 225 | 212 | if (generate_vector and |
| 226 | - self.query_config.enable_text_embedding and | |
| 213 | + ENABLE_TEXT_EMBEDDING and | |
| 227 | 214 | domain == "default"): # Only generate vector for default domain |
| 228 | 215 | try: |
| 229 | 216 | log_debug("开始生成查询向量") | ... | ... |
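Note: QueryParser no longer takes a SearchConfig; translation, rewriting and embedding are gated by the module-level flags imported above, so construction reduces to QueryParser(). A minimal sketch (only the configuration source moved; parse-time behaviour is unchanged). The unnamed hunk that follows adds a new rebuild script whose path is not shown in this view.

    from query import QueryParser

    parser = QueryParser()   # no config argument any more
    # Translator is only built lazily when ENABLE_TRANSLATION is true,
    # BgeEncoder only when ENABLE_TEXT_EMBEDDING is true, and query rewriting
    # now activates whenever REWRITE_DICTIONARY is non-empty.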
| ... | ... | @@ -0,0 +1,184 @@ |
| 1 | +#!/usr/bin/env python3 | |
| 2 | +""" | |
| 3 | +重建索引并导入数据的脚本。 | |
| 4 | + | |
| 5 | +清除旧索引,使用新的mapping重建索引,然后导入数据。 | |
| 6 | +""" | |
| 7 | + | |
| 8 | +import sys | |
| 9 | +import os | |
| 10 | +import argparse | |
| 11 | +from pathlib import Path | |
| 12 | + | |
| 13 | +# Add parent directory to path | |
| 14 | +sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| 15 | + | |
| 16 | +from utils.db_connector import create_db_connection | |
| 17 | +from utils.es_client import ESClient | |
| 18 | +from indexer.mapping_generator import load_mapping, delete_index_if_exists, DEFAULT_INDEX_NAME | |
| 19 | +from indexer.spu_transformer import SPUTransformer | |
| 20 | +from indexer.bulk_indexer import BulkIndexer | |
| 21 | + | |
| 22 | + | |
| 23 | +def main(): | |
| 24 | + parser = argparse.ArgumentParser(description='重建ES索引并导入数据') | |
| 25 | + | |
| 26 | + # Database connection | |
| 27 | + parser.add_argument('--db-host', help='MySQL host (或使用环境变量 DB_HOST)') | |
| 28 | + parser.add_argument('--db-port', type=int, help='MySQL port (或使用环境变量 DB_PORT, 默认: 3306)') | |
| 29 | + parser.add_argument('--db-database', help='MySQL database (或使用环境变量 DB_DATABASE)') | |
| 30 | + parser.add_argument('--db-username', help='MySQL username (或使用环境变量 DB_USERNAME)') | |
| 31 | + parser.add_argument('--db-password', help='MySQL password (或使用环境变量 DB_PASSWORD)') | |
| 32 | + | |
| 33 | + # Tenant and ES | |
| 34 | + parser.add_argument('--tenant-id', required=True, help='Tenant ID (必需)') | |
| 35 | + parser.add_argument('--es-host', help='Elasticsearch host (或使用环境变量 ES_HOST, 默认: http://localhost:9200)') | |
| 36 | + | |
| 37 | + # Options | |
| 38 | + parser.add_argument('--batch-size', type=int, default=500, help='批量导入大小 (默认: 500)') | |
| 39 | + parser.add_argument('--skip-delete', action='store_true', help='跳过删除旧索引步骤') | |
| 40 | + | |
| 41 | + args = parser.parse_args() | |
| 42 | + | |
| 43 | + print("=" * 60) | |
| 44 | + print("重建ES索引并导入数据") | |
| 45 | + print("=" * 60) | |
| 46 | + | |
| 47 | + # 加载mapping | |
| 48 | +    print("\n[1/6] 加载mapping配置...") | 
| 49 | + try: | |
| 50 | + mapping = load_mapping() | |
| 51 | + print(f"✓ 成功加载mapping配置") | |
| 52 | + except Exception as e: | |
| 53 | + print(f"✗ 加载mapping失败: {e}") | |
| 54 | + return 1 | |
| 55 | + | |
| 56 | + index_name = DEFAULT_INDEX_NAME | |
| 57 | + print(f"索引名称: {index_name}") | |
| 58 | + | |
| 59 | + # 连接Elasticsearch | |
| 60 | +    print("\n[2/6] 连接Elasticsearch...") | 
| 61 | + es_host = args.es_host or os.environ.get('ES_HOST', 'http://localhost:9200') | |
| 62 | + es_username = os.environ.get('ES_USERNAME') | |
| 63 | + es_password = os.environ.get('ES_PASSWORD') | |
| 64 | + | |
| 65 | + print(f"ES地址: {es_host}") | |
| 66 | + if es_username: | |
| 67 | + print(f"ES用户名: {es_username}") | |
| 68 | + | |
| 69 | + try: | |
| 70 | + if es_username and es_password: | |
| 71 | + es_client = ESClient(hosts=[es_host], username=es_username, password=es_password) | |
| 72 | + else: | |
| 73 | + es_client = ESClient(hosts=[es_host]) | |
| 74 | + | |
| 75 | + if not es_client.ping(): | |
| 76 | + print(f"✗ 无法连接到Elasticsearch: {es_host}") | |
| 77 | + return 1 | |
| 78 | + print("✓ Elasticsearch连接成功") | |
| 79 | + except Exception as e: | |
| 80 | + print(f"✗ 连接Elasticsearch失败: {e}") | |
| 81 | + return 1 | |
| 82 | + | |
| 83 | + # 删除旧索引 | |
| 84 | + if not args.skip_delete: | |
| 85 | +        print("\n[3/6] 删除旧索引...") | 
| 86 | + if es_client.index_exists(index_name): | |
| 87 | + print(f"发现已存在的索引: {index_name}") | |
| 88 | + if delete_index_if_exists(es_client, index_name): | |
| 89 | + print(f"✓ 成功删除索引: {index_name}") | |
| 90 | + else: | |
| 91 | + print(f"✗ 删除索引失败: {index_name}") | |
| 92 | + return 1 | |
| 93 | + else: | |
| 94 | + print(f"索引不存在,跳过删除: {index_name}") | |
| 95 | + else: | |
| 96 | +        print("\n[3/6] 跳过删除旧索引步骤") | 
| 97 | + | |
| 98 | + # 创建新索引 | |
| 99 | +    print("\n[4/6] 创建新索引...") | 
| 100 | + try: | |
| 101 | + if es_client.index_exists(index_name): | |
| 102 | + print(f"✓ 索引已存在: {index_name},跳过创建") | |
| 103 | + else: | |
| 104 | + print(f"创建索引: {index_name}") | |
| 105 | + if es_client.create_index(index_name, mapping): | |
| 106 | + print(f"✓ 成功创建索引: {index_name}") | |
| 107 | + else: | |
| 108 | + print(f"✗ 创建索引失败: {index_name}") | |
| 109 | + return 1 | |
| 110 | + except Exception as e: | |
| 111 | + print(f"✗ 创建索引失败: {e}") | |
| 112 | + import traceback | |
| 113 | + traceback.print_exc() | |
| 114 | + return 1 | |
| 115 | + | |
| 116 | + # 连接MySQL | |
| 117 | +    print("\n[5/6] 连接MySQL...") | 
| 118 | + db_host = args.db_host or os.environ.get('DB_HOST') | |
| 119 | + db_port = args.db_port or int(os.environ.get('DB_PORT', 3306)) | |
| 120 | + db_database = args.db_database or os.environ.get('DB_DATABASE') | |
| 121 | + db_username = args.db_username or os.environ.get('DB_USERNAME') | |
| 122 | + db_password = args.db_password or os.environ.get('DB_PASSWORD') | |
| 123 | + | |
| 124 | + if not all([db_host, db_database, db_username, db_password]): | |
| 125 | + print("✗ MySQL连接参数不完整") | |
| 126 | + print("请提供 --db-host, --db-database, --db-username, --db-password") | |
| 127 | + print("或设置环境变量: DB_HOST, DB_DATABASE, DB_USERNAME, DB_PASSWORD") | |
| 128 | + return 1 | |
| 129 | + | |
| 130 | + print(f"MySQL: {db_host}:{db_port}/{db_database}") | |
| 131 | + try: | |
| 132 | + db_engine = create_db_connection( | |
| 133 | + host=db_host, | |
| 134 | + port=db_port, | |
| 135 | + database=db_database, | |
| 136 | + username=db_username, | |
| 137 | + password=db_password | |
| 138 | + ) | |
| 139 | + print("✓ MySQL连接成功") | |
| 140 | + except Exception as e: | |
| 141 | + print(f"✗ 连接MySQL失败: {e}") | |
| 142 | + return 1 | |
| 143 | + | |
| 144 | + # 导入数据 | |
| 145 | + print("\n[6/6] 导入数据...") | |
| 146 | + print(f"Tenant ID: {args.tenant_id}") | |
| 147 | + print(f"批量大小: {args.batch_size}") | |
| 148 | + | |
| 149 | + try: | |
| 150 | + transformer = SPUTransformer(db_engine, args.tenant_id) | |
| 151 | + print("正在转换数据...") | |
| 152 | + documents = transformer.transform_batch() | |
| 153 | + print(f"✓ 转换完成: {len(documents)} 个文档") | |
| 154 | + | |
| 155 | + if not documents: | |
| 156 | + print("⚠ 没有数据需要导入") | |
| 157 | + return 0 | |
| 158 | + | |
| 159 | + print(f"正在导入数据到ES (批量大小: {args.batch_size})...") | |
| 160 | + indexer = BulkIndexer(es_client, index_name, batch_size=args.batch_size) | |
| 161 | + results = indexer.index_documents(documents, id_field="spu_id", show_progress=True) | |
| 162 | + | |
| 163 | + print(f"\n{'='*60}") | |
| 164 | + print("导入完成!") | |
| 165 | + print(f"{'='*60}") | |
| 166 | + print(f"成功: {results['success']}") | |
| 167 | + print(f"失败: {results['failed']}") | |
| 168 | + print(f"耗时: {results.get('elapsed_time', 0):.2f}秒") | |
| 169 | + | |
| 170 | + if results['failed'] > 0: | |
| 171 | + print(f"\n⚠ 警告: {results['failed']} 个文档导入失败") | |
| 172 | + return 1 | |
| 173 | + | |
| 174 | + return 0 | |
| 175 | + except Exception as e: | |
| 176 | + print(f"✗ 导入数据失败: {e}") | |
| 177 | + import traceback | |
| 178 | + traceback.print_exc() | |
| 179 | + return 1 | |
| 180 | + | |
| 181 | + | |
| 182 | +if __name__ == '__main__': | |
| 183 | + sys.exit(main()) | |
| 184 | + | ... | ... |
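Note: a typical invocation of the new rebuild script (the file's path is not shown in this compare view, so the filename below is a placeholder; the tenant id mirrors the curl example earlier):

    # hypothetical filename -- the actual script path is not visible in this diff
    python rebuild_index.py --tenant-id 162 --es-host http://localhost:9200 --batch-size 500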
search/es_query_builder.py
| ... | ... | @@ -330,10 +330,15 @@ class ESQueryBuilder: |
| 330 | 330 | """ |
| 331 | 331 | 构建分面聚合。 |
| 332 | 332 | |
| 333 | + 支持: | |
| 334 | + 1. 分类分面:category1_name, category2_name, category3_name, category_name | |
| 335 | + 2. specifications分面:嵌套聚合,按name聚合,然后按value聚合 | |
| 336 | + | |
| 333 | 337 | Args: |
| 334 | 338 | facet_configs: 分面配置列表(标准格式): |
| 335 | 339 | - str: 字段名,使用默认 terms 配置 |
| 336 | 340 | - FacetConfig: 详细的分面配置对象 |
| 341 | + - 特殊值 "specifications": 构建specifications嵌套分面 | |
| 337 | 342 | |
| 338 | 343 | Returns: |
| 339 | 344 | ES aggregations 字典 |
| ... | ... | @@ -344,6 +349,34 @@ class ESQueryBuilder: |
| 344 | 349 | aggs = {} |
| 345 | 350 | |
| 346 | 351 | for config in facet_configs: |
| 352 | + # 特殊处理:specifications嵌套分面 | |
| 353 | + if isinstance(config, str) and config == "specifications": | |
| 354 | + # 构建specifications嵌套分面(按name聚合,然后按value聚合) | |
| 355 | + aggs["specifications_facet"] = { | |
| 356 | + "nested": { | |
| 357 | + "path": "specifications" | |
| 358 | + }, | |
| 359 | + "aggs": { | |
| 360 | + "by_name": { | |
| 361 | + "terms": { | |
| 362 | + "field": "specifications.name", | |
| 363 | + "size": 20, | |
| 364 | + "order": {"_count": "desc"} | |
| 365 | + }, | |
| 366 | + "aggs": { | |
| 367 | + "value_counts": { | |
| 368 | + "terms": { | |
| 369 | + "field": "specifications.value", | |
| 370 | + "size": 10, | |
| 371 | + "order": {"_count": "desc"} | |
| 372 | + } | |
| 373 | + } | |
| 374 | + } | |
| 375 | + } | |
| 376 | + } | |
| 377 | + } | |
| 378 | + continue | |
| 379 | + | |
| 347 | 380 | # 简单模式:只有字段名(字符串) |
| 348 | 381 | if isinstance(config, str): |
| 349 | 382 | field = config | ... | ... |
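Note: the nested aggregation above comes back from ES in the shape that ResultFormatter's specifications_facet branch walks (by_name.buckets, then value_counts.buckets per name); an illustrative slice of such a response (keys and counts are made up):

    aggregations = {
        "specifications_facet": {
            "doc_count": 240,
            "by_name": {
                "buckets": [
                    {
                        "key": "颜色",
                        "doc_count": 180,
                        "value_counts": {
                            "buckets": [
                                {"key": "红色", "doc_count": 95},
                                {"key": "蓝色", "doc_count": 85},
                            ]
                        },
                    }
                ]
            },
        }
    }
    # format_facets emits one FacetResult per name bucket, e.g.
    # FacetResult(field="specifications.颜色", label="颜色", type="terms", total_count=180, ...)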
search/multilang_query_builder.py
| ... | ... | @@ -11,9 +11,9 @@ import numpy as np |
| 11 | 11 | import logging |
| 12 | 12 | import re |
| 13 | 13 | |
| 14 | -from config import SearchConfig, IndexConfig | |
| 15 | 14 | from query import ParsedQuery |
| 16 | 15 | from .es_query_builder import ESQueryBuilder |
| 16 | +from .query_config import DEFAULT_MATCH_FIELDS, DOMAIN_FIELDS, FUNCTION_SCORE_CONFIG | |
| 17 | 17 | |
| 18 | 18 | logger = logging.getLogger(__name__) |
| 19 | 19 | |
| ... | ... | @@ -30,8 +30,8 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 30 | 30 | |
| 31 | 31 | def __init__( |
| 32 | 32 | self, |
| 33 | - config: SearchConfig, | |
| 34 | 33 | index_name: str, |
| 34 | + match_fields: Optional[List[str]] = None, | |
| 35 | 35 | text_embedding_field: Optional[str] = None, |
| 36 | 36 | image_embedding_field: Optional[str] = None, |
| 37 | 37 | source_fields: Optional[List[str]] = None |
| ... | ... | @@ -40,53 +40,32 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 40 | 40 | Initialize multi-language query builder. |
| 41 | 41 | |
| 42 | 42 | Args: |
| 43 | - config: Search configuration | |
| 44 | 43 | index_name: ES index name |
| 44 | + match_fields: Fields to search for text matching (default: from query_config) | |
| 45 | 45 | text_embedding_field: Field name for text embeddings |
| 46 | 46 | image_embedding_field: Field name for image embeddings |
| 47 | 47 | source_fields: Fields to return in search results (_source includes) |
| 48 | 48 | """ |
| 49 | - self.config = config | |
| 50 | - self.function_score_config = config.function_score | |
| 49 | + self.function_score_config = FUNCTION_SCORE_CONFIG | |
| 51 | 50 | |
| 52 | - # For default domain, use all fields as fallback | |
| 53 | - default_fields = self._get_domain_fields("default") | |
| 51 | + # Use provided match_fields or default | |
| 52 | + if match_fields is None: | |
| 53 | + match_fields = DEFAULT_MATCH_FIELDS | |
| 54 | 54 | |
| 55 | 55 | super().__init__( |
| 56 | 56 | index_name=index_name, |
| 57 | - match_fields=default_fields, | |
| 57 | + match_fields=match_fields, | |
| 58 | 58 | text_embedding_field=text_embedding_field, |
| 59 | 59 | image_embedding_field=image_embedding_field, |
| 60 | 60 | source_fields=source_fields |
| 61 | 61 | ) |
| 62 | 62 | |
| 63 | - # Build domain configurations | |
| 64 | - self.domain_configs = self._build_domain_configs() | |
| 65 | - | |
| 66 | - def _build_domain_configs(self) -> Dict[str, IndexConfig]: | |
| 67 | - """Build mapping of domain name to IndexConfig.""" | |
| 68 | - return {index.name: index for index in self.config.indexes} | |
| 63 | + # Build domain configurations from query_config | |
| 64 | + self.domain_configs = DOMAIN_FIELDS | |
| 69 | 65 | |
| 70 | 66 | def _get_domain_fields(self, domain_name: str) -> List[str]: |
| 71 | 67 | """Get fields for a specific domain with boost notation.""" |
| 72 | - for index in self.config.indexes: | |
| 73 | - if index.name == domain_name: | |
| 74 | - result = [] | |
| 75 | - for field_name in index.fields: | |
| 76 | - field = self._get_field_by_name(field_name) | |
| 77 | - if field and field.boost != 1.0: | |
| 78 | - result.append(f"{field_name}^{field.boost}") | |
| 79 | - else: | |
| 80 | - result.append(field_name) | |
| 81 | - return result | |
| 82 | - return [] | |
| 83 | - | |
| 84 | - def _get_field_by_name(self, field_name: str): | |
| 85 | - """Get field configuration by name.""" | |
| 86 | - for field in self.config.fields: | |
| 87 | - if field.name == field_name: | |
| 88 | - return field | |
| 89 | - return None | |
| 68 | + return self.domain_configs.get(domain_name, DEFAULT_MATCH_FIELDS) | |
| 90 | 69 | |
| 91 | 70 | def build_multilang_query( |
| 92 | 71 | self, |
| ... | ... | @@ -103,7 +82,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 103 | 82 | min_score: Optional[float] = None |
| 104 | 83 | ) -> Dict[str, Any]: |
| 105 | 84 | """ |
| 106 | - Build ES query with multi-language support (重构版). | |
| 85 | + Build ES query with multi-language support (简化版). | |
| 107 | 86 | |
| 108 | 87 | Args: |
| 109 | 88 | parsed_query: Parsed query with language info and translations |
| ... | ... | @@ -120,19 +99,18 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 120 | 99 | Returns: |
| 121 | 100 | ES query DSL dictionary |
| 122 | 101 | """ |
| 123 | - domain = parsed_query.domain | |
| 124 | - domain_config = self.domain_configs.get(domain) | |
| 125 | - | |
| 126 | - if not domain_config: | |
| 127 | - # Fallback to default domain | |
| 128 | - domain = "default" | |
| 129 | - domain_config = self.domain_configs.get("default") | |
| 130 | - | |
| 131 | - if not domain_config: | |
| 132 | - # Use original behavior | |
| 102 | + # 1. 根据域选择匹配字段(默认域使用 DEFAULT_MATCH_FIELDS) | |
| 103 | + domain = parsed_query.domain or "default" | |
| 104 | + domain_fields = self.domain_configs.get(domain) or DEFAULT_MATCH_FIELDS | |
| 105 | + | |
| 106 | + # 2. 临时切换 match_fields,复用基类 build_query 逻辑 | |
| 107 | + original_match_fields = self.match_fields | |
| 108 | + self.match_fields = domain_fields | |
| 109 | + try: | |
| 133 | 110 | return super().build_query( |
| 134 | - query_text=parsed_query.rewritten_query, | |
| 111 | + query_text=parsed_query.rewritten_query or parsed_query.normalized_query, | |
| 135 | 112 | query_vector=query_vector, |
| 113 | + query_node=query_node, | |
| 136 | 114 | filters=filters, |
| 137 | 115 | range_filters=range_filters, |
| 138 | 116 | size=size, |
| ... | ... | @@ -142,95 +120,9 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 142 | 120 | knn_num_candidates=knn_num_candidates, |
| 143 | 121 | min_score=min_score |
| 144 | 122 | ) |
| 145 | - | |
| 146 | - logger.debug(f"Building query for domain: {domain}, language: {parsed_query.detected_language}") | |
| 147 | - | |
| 148 | - # Build query clause with multi-language support | |
| 149 | - if query_node and isinstance(query_node, tuple) and len(query_node) > 0: | |
| 150 | - # Handle boolean query from tuple (AST, score) | |
| 151 | - ast_node = query_node[0] | |
| 152 | - query_clause = self._build_boolean_query_from_tuple(ast_node) | |
| 153 | - logger.debug(f"Using boolean query") | |
| 154 | - elif query_node and hasattr(query_node, 'operator') and query_node.operator != 'TERM': | |
| 155 | - # Handle boolean query using base class method | |
| 156 | - query_clause = self._build_boolean_query(query_node) | |
| 157 | - logger.debug(f"Using boolean query") | |
| 158 | - else: | |
| 159 | - # Handle text query with multi-language support | |
| 160 | - query_clause = self._build_multilang_text_query(parsed_query, domain_config) | |
| 161 | - | |
| 162 | - # 构建内层bool: 文本和KNN二选一 | |
| 163 | - inner_bool_should = [query_clause] | |
| 164 | - | |
| 165 | - # 如果启用KNN,添加到should | |
| 166 | - if enable_knn and query_vector is not None and self.text_embedding_field: | |
| 167 | - knn_query = { | |
| 168 | - "knn": { | |
| 169 | - "field": self.text_embedding_field, | |
| 170 | - "query_vector": query_vector.tolist(), | |
| 171 | - "k": knn_k, | |
| 172 | - "num_candidates": knn_num_candidates | |
| 173 | - } | |
| 174 | - } | |
| 175 | - inner_bool_should.append(knn_query) | |
| 176 | - logger.info(f"KNN query added: field={self.text_embedding_field}, k={knn_k}") | |
| 177 | - else: | |
| 178 | - # Debug why KNN is not added | |
| 179 | - reasons = [] | |
| 180 | - if not enable_knn: | |
| 181 | - reasons.append("enable_knn=False") | |
| 182 | - if query_vector is None: | |
| 183 | - reasons.append("query_vector is None") | |
| 184 | - if not self.text_embedding_field: | |
| 185 | - reasons.append(f"text_embedding_field is not set (current: {self.text_embedding_field})") | |
| 186 | - logger.debug(f"KNN query NOT added. Reasons: {', '.join(reasons) if reasons else 'unknown'}") | |
| 187 | - | |
| 188 | - # 构建内层bool结构 | |
| 189 | - inner_bool = { | |
| 190 | - "bool": { | |
| 191 | - "should": inner_bool_should, | |
| 192 | - "minimum_should_match": 1 | |
| 193 | - } | |
| 194 | - } | |
| 195 | - | |
| 196 | - # 构建外层bool: 包含filter | |
| 197 | - filter_clauses = self._build_filters(filters, range_filters) if (filters or range_filters) else [] | |
| 198 | - | |
| 199 | - outer_bool = { | |
| 200 | - "bool": { | |
| 201 | - "must": [inner_bool] | |
| 202 | - } | |
| 203 | - } | |
| 204 | - | |
| 205 | - if filter_clauses: | |
| 206 | - outer_bool["bool"]["filter"] = filter_clauses | |
| 207 | - | |
| 208 | - # 包裹function_score(从配置读取score_mode和boost_mode) | |
| 209 | - function_score_query = { | |
| 210 | - "function_score": { | |
| 211 | - "query": outer_bool, | |
| 212 | - "functions": self._build_score_functions(), | |
| 213 | - "score_mode": self.function_score_config.score_mode if self.function_score_config else "sum", | |
| 214 | - "boost_mode": self.function_score_config.boost_mode if self.function_score_config else "multiply" | |
| 215 | - } | |
| 216 | - } | |
| 217 | - | |
| 218 | - es_query = { | |
| 219 | - "size": size, | |
| 220 | - "from": from_, | |
| 221 | - "query": function_score_query | |
| 222 | - } | |
| 223 | - | |
| 224 | - # Add _source filtering if source_fields are configured | |
| 225 | - if self.source_fields: | |
| 226 | - es_query["_source"] = { | |
| 227 | - "includes": self.source_fields | |
| 228 | - } | |
| 229 | - | |
| 230 | - if min_score is not None: | |
| 231 | - es_query["min_score"] = min_score | |
| 232 | - | |
| 233 | - return es_query | |
| 123 | + finally: | |
| 124 | + # 恢复原始配置,避免影响后续查询 | |
| 125 | + self.match_fields = original_match_fields | |
| 234 | 126 | |
| 235 | 127 | def _build_score_functions(self) -> List[Dict[str, Any]]: |
| 236 | 128 | """ |
| ... | ... | @@ -291,7 +183,7 @@ class MultiLanguageQueryBuilder(ESQueryBuilder): |
| 291 | 183 | def _build_multilang_text_query( |
| 292 | 184 | self, |
| 293 | 185 | parsed_query: ParsedQuery, |
| 294 | - domain_config: IndexConfig | |
| 186 | + domain_config: Dict[str, Any] | |
| 295 | 187 | ) -> Dict[str, Any]: |
| 296 | 188 | """ |
| 297 | 189 | Build text query with multi-language field routing. | ... | ... |
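Note: domain routing is now a plain dictionary lookup with a fallback to DEFAULT_MATCH_FIELDS, and the builder temporarily swaps match_fields before delegating to the base build_query. The assumed shape of the lookup table (keys and boosted field lists below are illustrative, not copied from query_config.py):

    from typing import Dict, List

    # search/query_config.py -- illustrative values only
    DOMAIN_FIELDS: Dict[str, List[str]] = {
        "default": ["title_zh^3", "title_en^3", "description_zh", "description_en"],
        "vendor":  ["vendor_zh^2", "vendor_en^2"],
    }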
search/query_config.py
| ... | ... | @@ -37,17 +37,32 @@ DOMAIN_FIELDS: Dict[str, List[str]] = { |
| 37 | 37 | } |
| 38 | 38 | |
| 39 | 39 | # Source fields to return in search results |
| 40 | +# 注意:为了在后端做多语言选择,_zh / _en 字段仍然需要从 ES 取出, | |
| 41 | +# 但不会原样透出给前端,而是统一映射到 title / description / vendor 等字段。 | |
| 40 | 42 | SOURCE_FIELDS = [ |
| 43 | + # 基本标识 | |
| 41 | 44 | "tenant_id", |
| 42 | 45 | "spu_id", |
| 46 | + "create_time", | |
| 47 | + "update_time", | |
| 48 | + | |
| 49 | + # 多语言文本字段(仅用于后端选择,不直接返回给前端) | |
| 43 | 50 | "title_zh", |
| 51 | + "title_en", | |
| 44 | 52 | "brief_zh", |
| 53 | + "brief_en", | |
| 45 | 54 | "description_zh", |
| 55 | + "description_en", | |
| 46 | 56 | "vendor_zh", |
| 47 | - "tags", | |
| 48 | - "image_url", | |
| 57 | + "vendor_en", | |
| 49 | 58 | "category_path_zh", |
| 59 | + "category_path_en", | |
| 50 | 60 | "category_name_zh", |
| 61 | + "category_name_en", | |
| 62 | + | |
| 63 | + # 语言无关字段(直接返回给前端) | |
| 64 | + "tags", | |
| 65 | + "image_url", | |
| 51 | 66 | "category_id", |
| 52 | 67 | "category_name", |
| 53 | 68 | "category_level", |
| ... | ... | @@ -60,11 +75,12 @@ SOURCE_FIELDS = [ |
| 60 | 75 | "min_price", |
| 61 | 76 | "max_price", |
| 62 | 77 | "compare_at_price", |
| 78 | + "sku_prices", | |
| 79 | + "sku_weights", | |
| 80 | + "sku_weight_units", | |
| 63 | 81 | "total_inventory", |
| 64 | - "create_time", | |
| 65 | - "update_time", | |
| 66 | 82 | "skus", |
| 67 | - "specifications" | |
| 83 | + "specifications", | |
| 68 | 84 | ] |
| 69 | 85 | |
| 70 | 86 | # Query processing settings |
| ... | ... | @@ -112,3 +128,13 @@ def load_rewrite_dictionary() -> Dict[str, str]: |
| 112 | 128 | |
| 113 | 129 | REWRITE_DICTIONARY = load_rewrite_dictionary() |
| 114 | 130 | |
| 131 | +# Default facets for faceted search | |
| 132 | +# 分类分面:使用category1_name, category2_name, category3_name | |
| 133 | +# specifications分面:使用嵌套聚合,按name分组,然后按value聚合 | |
| 134 | +DEFAULT_FACETS = [ | |
| 135 | + "category1_name", # 一级分类 | |
| 136 | + "category2_name", # 二级分类 | |
| 137 | + "category3_name", # 三级分类 | |
| 138 | + "specifications" # 规格分面(特殊处理:嵌套聚合) | |
| 139 | +] | |
| 140 | + | ... | ... |
search/searcher.py
| ... | ... | @@ -8,15 +8,23 @@ from typing import Dict, Any, List, Optional, Union |
| 8 | 8 | import time |
| 9 | 9 | import logging |
| 10 | 10 | |
| 11 | -from config import SearchConfig | |
| 12 | 11 | from utils.es_client import ESClient |
| 13 | 12 | from query import QueryParser, ParsedQuery |
| 14 | -from indexer import MappingGenerator | |
| 15 | 13 | from embeddings import CLIPImageEncoder |
| 16 | 14 | from .boolean_parser import BooleanParser, QueryNode |
| 17 | 15 | from .es_query_builder import ESQueryBuilder |
| 18 | 16 | from .multilang_query_builder import MultiLanguageQueryBuilder |
| 19 | 17 | from .rerank_engine import RerankEngine |
| 18 | +from .query_config import ( | |
| 19 | + DEFAULT_INDEX_NAME, | |
| 20 | + DEFAULT_MATCH_FIELDS, | |
| 21 | + TEXT_EMBEDDING_FIELD, | |
| 22 | + IMAGE_EMBEDDING_FIELD, | |
| 23 | + SOURCE_FIELDS, | |
| 24 | + ENABLE_TRANSLATION, | |
| 25 | + ENABLE_TEXT_EMBEDDING, | |
| 26 | + RANKING_EXPRESSION | |
| 27 | +) | |
| 20 | 28 | from context.request_context import RequestContext, RequestContextStage, create_request_context |
| 21 | 29 | from api.models import FacetResult, FacetValue |
| 22 | 30 | from api.result_formatter import ResultFormatter |
| ... | ... | @@ -79,39 +87,38 @@ class Searcher: |
| 79 | 87 | |
| 80 | 88 | def __init__( |
| 81 | 89 | self, |
| 82 | - config: SearchConfig, | |
| 83 | 90 | es_client: ESClient, |
| 84 | - query_parser: Optional[QueryParser] = None | |
| 91 | + query_parser: Optional[QueryParser] = None, | |
| 92 | + index_name: str = DEFAULT_INDEX_NAME | |
| 85 | 93 | ): |
| 86 | 94 | """ |
| 87 | 95 | Initialize searcher. |
| 88 | 96 | |
| 89 | 97 | Args: |
| 90 | - config: Search configuration | |
| 91 | 98 | es_client: Elasticsearch client |
| 92 | 99 | query_parser: Query parser (created if not provided) |
| 100 | + index_name: ES index name (default: search_products) | |
| 93 | 101 | """ |
| 94 | - self.config = config | |
| 95 | 102 | self.es_client = es_client |
| 96 | - self.query_parser = query_parser or QueryParser(config) | |
| 103 | + self.index_name = index_name | |
| 104 | + self.query_parser = query_parser or QueryParser() | |
| 97 | 105 | |
| 98 | 106 | # Initialize components |
| 99 | 107 | self.boolean_parser = BooleanParser() |
| 100 | - self.rerank_engine = RerankEngine(config.ranking.expression, enabled=False) | |
| 108 | + self.rerank_engine = RerankEngine(RANKING_EXPRESSION, enabled=False) | |
| 101 | 109 | |
| 102 | - # Get mapping info | |
| 103 | - mapping_gen = MappingGenerator(config) | |
| 104 | - self.match_fields = mapping_gen.get_match_fields_for_domain("default") | |
| 105 | - self.text_embedding_field = mapping_gen.get_text_embedding_field() | |
| 106 | - self.image_embedding_field = mapping_gen.get_image_embedding_field() | |
| 110 | + # Use constants from query_config | |
| 111 | + self.match_fields = DEFAULT_MATCH_FIELDS | |
| 112 | + self.text_embedding_field = TEXT_EMBEDDING_FIELD | |
| 113 | + self.image_embedding_field = IMAGE_EMBEDDING_FIELD | |
| 107 | 114 | |
| 108 | 115 | # Query builder - use multi-language version |
| 109 | 116 | self.query_builder = MultiLanguageQueryBuilder( |
| 110 | - config=config, | |
| 111 | - index_name=config.es_index_name, | |
| 117 | + index_name=index_name, | |
| 118 | + match_fields=self.match_fields, | |
| 112 | 119 | text_embedding_field=self.text_embedding_field, |
| 113 | 120 | image_embedding_field=self.image_embedding_field, |
| 114 | - source_fields=config.query_config.source_fields | |
| 121 | + source_fields=SOURCE_FIELDS | |
| 115 | 122 | ) |
| 116 | 123 | |
| 117 | 124 | def search( |
| ... | ... | @@ -127,7 +134,8 @@ class Searcher: |
| 127 | 134 | context: Optional[RequestContext] = None, |
| 128 | 135 | sort_by: Optional[str] = None, |
| 129 | 136 | sort_order: Optional[str] = "desc", |
| 130 | - debug: bool = False | |
| 137 | + debug: bool = False, | |
| 138 | + language: str = "zh", | |
| 131 | 139 | ) -> SearchResult: |
| 132 | 140 | """ |
| 133 | 141 | Execute search query (外部友好格式). |
| ... | ... | @@ -154,8 +162,8 @@ class Searcher: |
| 154 | 162 | context = create_request_context() |
| 155 | 163 | |
| 156 | 164 | # Always use config defaults (these are backend configuration, not user parameters) |
| 157 | - enable_translation = self.config.query_config.enable_translation | |
| 158 | - enable_embedding = self.config.query_config.enable_text_embedding | |
| 165 | + enable_translation = ENABLE_TRANSLATION | |
| 166 | + enable_embedding = ENABLE_TEXT_EMBEDDING | |
| 159 | 167 | enable_rerank = False # Temporarily disabled |
| 160 | 168 | |
| 161 | 169 | # Start timing |
| ... | ... | @@ -278,14 +286,6 @@ class Searcher: |
| 278 | 286 | min_score=min_score |
| 279 | 287 | ) |
| 280 | 288 | |
| 281 | - # Add SPU collapse if configured | |
| 282 | - if self.config.spu_config.enabled: | |
| 283 | - es_query = self.query_builder.add_spu_collapse( | |
| 284 | - es_query, | |
| 285 | - self.config.spu_config.spu_field, | |
| 286 | - self.config.spu_config.inner_hits_size | |
| 287 | - ) | |
| 288 | - | |
| 289 | 289 | # Add facets for faceted search |
| 290 | 290 | if facets: |
| 291 | 291 | facet_aggs = self.query_builder.build_facets(facets) |
| ... | ... | @@ -329,7 +329,7 @@ class Searcher: |
| 329 | 329 | context.start_stage(RequestContextStage.ELASTICSEARCH_SEARCH) |
| 330 | 330 | try: |
| 331 | 331 | es_response = self.es_client.search( |
| 332 | - index_name=self.config.es_index_name, | |
| 332 | + index_name=self.index_name, | |
| 333 | 333 | body=body_for_es, |
| 334 | 334 | size=size, |
| 335 | 335 | from_=from_ |
| ... | ... | @@ -374,7 +374,11 @@ class Searcher: |
| 374 | 374 | max_score = es_response.get('hits', {}).get('max_score') or 0.0 |
| 375 | 375 | |
| 376 | 376 | # Format results using ResultFormatter |
| 377 | - formatted_results = ResultFormatter.format_search_results(es_hits, max_score) | |
| 377 | + formatted_results = ResultFormatter.format_search_results( | |
| 378 | + es_hits, | |
| 379 | + max_score, | |
| 380 | + language=language | |
| 381 | + ) | |
| 378 | 382 | |
| 379 | 383 | # Format facets |
| 380 | 384 | standardized_facets = None |
| ... | ... | @@ -503,9 +507,9 @@ class Searcher: |
| 503 | 507 | } |
| 504 | 508 | |
| 505 | 509 | # Add _source filtering if source_fields are configured |
| 506 | - if self.config.query_config.source_fields: | |
| 510 | + if SOURCE_FIELDS: | |
| 507 | 511 | es_query["_source"] = { |
| 508 | - "includes": self.config.query_config.source_fields | |
| 512 | + "includes": SOURCE_FIELDS | |
| 509 | 513 | } |
| 510 | 514 | |
| 511 | 515 | if filters or range_filters: |
| ... | ... | @@ -519,7 +523,7 @@ class Searcher: |
| 519 | 523 | |
| 520 | 524 | # Execute search |
| 521 | 525 | es_response = self.es_client.search( |
| 522 | - index_name=self.config.es_index_name, | |
| 526 | + index_name=self.index_name, | |
| 523 | 527 | body=es_query, |
| 524 | 528 | size=size |
| 525 | 529 | ) |
| ... | ... | @@ -573,7 +577,7 @@ class Searcher: |
| 573 | 577 | """ |
| 574 | 578 | try: |
| 575 | 579 | response = self.es_client.client.get( |
| 576 | - index=self.config.es_index_name, | |
| 580 | + index=self.index_name, | |
| 577 | 581 | id=doc_id |
| 578 | 582 | ) |
| 579 | 583 | return response.get('_source') |
| ... | ... | @@ -657,10 +661,11 @@ class Searcher: |
| 657 | 661 | |
| 658 | 662 | def _get_field_label(self, field: str) -> str: |
| 659 | 663 | """获取字段的显示标签""" |
| 660 | - # 从配置中获取字段标签 | |
| 661 | - for field_config in self.config.fields: | |
| 662 | - if field_config.name == field: | |
| 663 | - # 尝试获取 label 属性 | |
| 664 | - return getattr(field_config, 'label', field) | |
| 665 | - # 如果没有配置,返回字段名 | |
| 666 | - return field | |
| 664 | + # 字段标签映射(简化版,不再从配置读取) | |
| 665 | + field_labels = { | |
| 666 | + "category1_name": "一级分类", | |
| 667 | + "category2_name": "二级分类", | |
| 668 | + "category3_name": "三级分类", | |
| 669 | + "specifications": "规格" | |
| 670 | + } | |
| 671 | + return field_labels.get(field, field) | ... | ... |
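Note: end to end, language flows from the route into Searcher.search and on to ResultFormatter.format_search_results. An illustrative call (keyword names other than the ones visible in this diff are assumptions, since the full signature is truncated above):

    result = searcher.search(
        query="dress",     # assumed keyword name
        size=10,
        language="en",     # forwarded to ResultFormatter for *_en field selection
        sort_order="desc",
        debug=False,
    )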
utils/es_client.py
| ... | ... | @@ -172,6 +172,18 @@ class ESClient: |
| 172 | 172 | Returns: |
| 173 | 173 | Search results |
| 174 | 174 | """ |
| 175 | + # Safety guard: collapse is no longer needed (index is already SPU-level). | |
| 176 | + # If any caller accidentally adds a collapse clause (e.g. on product_id), | |
| 177 | + # strip it here to avoid 400 errors like: | |
| 178 | + # "no mapping found for `product_id` in order to collapse on" | |
| 179 | + if isinstance(body, dict) and "collapse" in body: | |
| 180 | + logger.warning( | |
| 181 | + "Removing unsupported 'collapse' clause from ES query body: %s", | |
| 182 | + body.get("collapse") | |
| 183 | + ) | |
| 184 | + body = dict(body) # shallow copy to avoid mutating caller | |
| 185 | + body.pop("collapse", None) | |
| 186 | + | |
| 175 | 187 | try: |
| 176 | 188 | return self.client.search( |
| 177 | 189 | index=index_name, | ... | ... |
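Note: the guard means a caller that still attaches the old SPU collapse clause degrades gracefully instead of triggering a 400 from ES; for example:

    body = {
        "query": {"match_all": {}},
        "collapse": {"field": "product_id"},  # stale clause from the removed SPU-collapse path
    }
    # search() logs a warning, pops "collapse" from a shallow copy of the body,
    # and runs the remaining query against the SPU-level index unchanged.
    resp = es_client.search(index_name="search_products", body=body, size=10)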