""" Configuration loader and validator for customer-specific search configurations. This module handles loading, parsing, and validating YAML configuration files that define how each customer's data should be indexed and searched. """ import yaml import os from typing import Dict, Any, List, Optional from dataclasses import dataclass, field from pathlib import Path from .field_types import ( FieldConfig, FieldType, AnalyzerType, FIELD_TYPE_MAP, ANALYZER_MAP ) @dataclass class IndexConfig: """Configuration for an index domain (e.g., default, title, brand).""" name: str label: str fields: List[str] # List of field names to include analyzer: AnalyzerType boost: float = 1.0 example: Optional[str] = None @dataclass class RankingConfig: """Configuration for ranking expressions.""" expression: str # e.g., "bm25() + 0.2*text_embedding_relevance()" description: str @dataclass class QueryConfig: """Configuration for query processing.""" supported_languages: List[str] = field(default_factory=lambda: ["zh", "en"]) default_language: str = "zh" enable_translation: bool = True enable_text_embedding: bool = True enable_query_rewrite: bool = True rewrite_dictionary: Dict[str, str] = field(default_factory=dict) # Translation API settings translation_api_key: Optional[str] = None translation_service: str = "deepl" # deepl, google, etc. 
@dataclass
class SPUConfig:
    """Configuration for SPU aggregation."""
    enabled: bool = False
    spu_field: Optional[str] = None  # Field containing the SPU ID
    inner_hits_size: int = 3


@dataclass
class CustomerConfig:
    """Complete configuration for a customer.

    Every field without a default is declared before the defaulted ones:
    ``dataclasses`` raises ``TypeError`` when a non-default field follows a
    default one, so the optional settings are grouped at the end. Construct
    instances with keyword arguments.
    """
    customer_id: str
    customer_name: str

    # Database settings (connection parameters)
    mysql_config: Dict[str, Any]

    # Field definitions
    fields: List[FieldConfig]

    # Index structure (query domains)
    indexes: List[IndexConfig]

    # Query processing
    query_config: QueryConfig

    # Ranking configuration
    ranking: RankingConfig

    # SPU configuration
    spu_config: SPUConfig

    # ES index name
    es_index_name: str

    # Optional settings with defaults (must follow all required fields)
    main_table: str = "shoplazza_product_sku"
    extension_table: Optional[str] = None
    es_settings: Dict[str, Any] = field(default_factory=dict)


class ConfigurationError(Exception):
    """Raised when a configuration cannot be loaded, parsed or validated."""


class ConfigLoader:
    """Loads, validates and saves customer configurations stored as YAML files."""

    # Similarity metrics accepted for embedding fields by validate_config.
    _VALID_EMBEDDING_SIMILARITIES = ("dot_product", "cosine", "l2_norm")

    # Keys validate_config requires in CustomerConfig.mysql_config (checked in this order).
    _REQUIRED_MYSQL_KEYS = ("host", "username", "password", "database")

    def __init__(self, config_dir: str = "config/schema"):
        """
        Args:
            config_dir: Directory holding ``<customer_id>_config.yaml`` files.
        """
        self.config_dir = Path(config_dir)

    def load_customer_config(self, customer_id: str) -> CustomerConfig:
        """
        Load customer configuration from YAML file.

        The file is looked up as ``<config_dir>/<customer_id>_config.yaml``.
        The result is parsed but not validated; call ``validate_config`` for
        cross-reference checks.

        Args:
            customer_id: Customer identifier (used to find config file)

        Returns:
            CustomerConfig object

        Raises:
            ConfigurationError: If the config file is not found, is not valid
                YAML, does not contain a mapping at the top level, or contains
                invalid entries.
        """
        config_file = self.config_dir / f"{customer_id}_config.yaml"

        if not config_file.exists():
            raise ConfigurationError(f"Configuration file not found: {config_file}")

        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Invalid YAML in {config_file}: {e}") from e

        # safe_load returns None for an empty document, and a list/scalar for
        # non-mapping documents; _parse_config would fail on those with an
        # AttributeError instead of the documented ConfigurationError.
        if not isinstance(config_data, dict):
            raise ConfigurationError(
                f"Top level of {config_file} must be a mapping, "
                f"got {type(config_data).__name__}"
            )

        return self._parse_config(config_data, customer_id)

    @staticmethod
    def _get_section(data: Dict[str, Any], key: str, expected_type: type) -> Any:
        """Return ``data[key]``, or an empty ``expected_type()`` if missing or null.

        A YAML key with no value (``fields:``) loads as None; treating it like
        a missing key avoids iterating or calling ``.get`` on None.

        Raises:
            ConfigurationError: If the value is present but not an ``expected_type``.
        """
        value = data.get(key)
        if value is None:
            return expected_type()
        if not isinstance(value, expected_type):
            raise ConfigurationError(
                f"'{key}' must be a {expected_type.__name__}, "
                f"got {type(value).__name__}"
            )
        return value

    @staticmethod
    def _require(data: Dict[str, Any], key: str, context: str) -> Any:
        """Return ``data[key]``, raising ConfigurationError if the key is absent."""
        try:
            return data[key]
        except KeyError:
            raise ConfigurationError(f"{context} is missing required key '{key}'") from None

    def _parse_config(self, config_data: Dict[str, Any], customer_id: str) -> CustomerConfig:
        """Parse configuration dictionary into CustomerConfig object.

        Missing (or null) sections fall back to the defaults used below.

        Raises:
            ConfigurationError: If a section has the wrong type or an entry is invalid.
        """
        fields = [
            self._parse_field_config(field_data)
            for field_data in self._get_section(config_data, "fields", list)
        ]
        indexes = [
            self._parse_index_config(index_data)
            for index_data in self._get_section(config_data, "indexes", list)
        ]
        query_config = self._parse_query_config(
            self._get_section(config_data, "query_config", dict)
        )
        ranking = self._parse_ranking_config(
            self._get_section(config_data, "ranking", dict)
        )
        spu_config = self._parse_spu_config(
            self._get_section(config_data, "spu_config", dict)
        )

        return CustomerConfig(
            customer_id=customer_id,
            customer_name=config_data.get("customer_name", customer_id),
            mysql_config=self._get_section(config_data, "mysql_config", dict),
            main_table=config_data.get("main_table", "shoplazza_product_sku"),
            extension_table=config_data.get("extension_table"),
            fields=fields,
            indexes=indexes,
            query_config=query_config,
            ranking=ranking,
            spu_config=spu_config,
            es_index_name=config_data.get("es_index_name", f"search_{customer_id}"),
            es_settings=self._get_section(config_data, "es_settings", dict),
        )

    def _parse_query_config(self, query_config_data: Dict[str, Any]) -> QueryConfig:
        """Build a QueryConfig from the ``query_config`` section."""
        return QueryConfig(
            supported_languages=query_config_data.get("supported_languages", ["zh", "en"]),
            default_language=query_config_data.get("default_language", "zh"),
            enable_translation=query_config_data.get("enable_translation", True),
            enable_text_embedding=query_config_data.get("enable_text_embedding", True),
            enable_query_rewrite=query_config_data.get("enable_query_rewrite", True),
            rewrite_dictionary=query_config_data.get("rewrite_dictionary", {}),
            translation_api_key=query_config_data.get("translation_api_key"),
            translation_service=query_config_data.get("translation_service", "deepl"),
        )

    def _parse_ranking_config(self, ranking_data: Dict[str, Any]) -> RankingConfig:
        """Build a RankingConfig from the ``ranking`` section."""
        return RankingConfig(
            expression=ranking_data.get("expression", "bm25() + 0.2*text_embedding_relevance()"),
            description=ranking_data.get("description", "Default BM25 + text embedding ranking"),
        )

    def _parse_spu_config(self, spu_data: Dict[str, Any]) -> SPUConfig:
        """Build an SPUConfig from the ``spu_config`` section."""
        return SPUConfig(
            enabled=spu_data.get("enabled", False),
            spu_field=spu_data.get("spu_field"),
            inner_hits_size=spu_data.get("inner_hits_size", 3),
        )

    @staticmethod
    def _lookup_analyzer(analyzer_str: Optional[str]) -> Optional[AnalyzerType]:
        """Map an analyzer name to its AnalyzerType; None when unset or unknown."""
        if analyzer_str and analyzer_str in ANALYZER_MAP:
            return ANALYZER_MAP[analyzer_str]
        return None

    def _parse_field_config(self, field_data: Dict[str, Any]) -> FieldConfig:
        """Parse one entry of the ``fields`` list into a FieldConfig.

        Unknown ``analyzer``/``search_analyzer`` names are ignored (the field
        gets None), whereas an unknown ``type`` is an error.

        Raises:
            ConfigurationError: If the entry is not a mapping, lacks ``name`` or
                ``type``, or names an unknown field type.
        """
        if not isinstance(field_data, dict):
            raise ConfigurationError(f"Field definition must be a mapping, got {field_data!r}")

        name = self._require(field_data, "name", "Field definition")
        field_type_str = self._require(field_data, "type", f"Field '{name}'")

        # Map field type string to enum
        if field_type_str not in FIELD_TYPE_MAP:
            raise ConfigurationError(f"Unknown field type: {field_type_str}")
        field_type = FIELD_TYPE_MAP[field_type_str]

        return FieldConfig(
            name=name,
            field_type=field_type,
            source_table=field_data.get("source_table"),
            source_column=field_data.get("source_column", name),
            analyzer=self._lookup_analyzer(field_data.get("analyzer")),
            search_analyzer=self._lookup_analyzer(field_data.get("search_analyzer")),
            required=field_data.get("required", False),
            multi_language=field_data.get("multi_language", False),
            languages=field_data.get("languages"),
            boost=field_data.get("boost", 1.0),
            store=field_data.get("store", False),
            index=field_data.get("index", True),
            embedding_dims=field_data.get("embedding_dims", 1024),
            embedding_similarity=field_data.get("embedding_similarity", "dot_product"),
            nested=field_data.get("nested", False),
            nested_properties=field_data.get("nested_properties"),
        )

    def _parse_index_config(self, index_data: Dict[str, Any]) -> IndexConfig:
        """Parse one entry of the ``indexes`` list into an IndexConfig.

        Raises:
            ConfigurationError: If the entry is not a mapping, lacks ``name`` or
                ``fields``, has a non-list ``fields``, or names an unknown analyzer.
        """
        if not isinstance(index_data, dict):
            raise ConfigurationError(f"Index definition must be a mapping, got {index_data!r}")

        analyzer_str = index_data.get("analyzer", "chinese_ecommerce")
        if analyzer_str not in ANALYZER_MAP:
            raise ConfigurationError(f"Unknown analyzer: {analyzer_str}")

        name = self._require(index_data, "name", "Index definition")
        field_names = self._require(index_data, "fields", f"Index '{name}'")
        # A bare string here would later be iterated character by character.
        if not isinstance(field_names, list):
            raise ConfigurationError(f"Index '{name}': 'fields' must be a list of field names")

        return IndexConfig(
            name=name,
            label=index_data.get("label", name),
            fields=field_names,
            analyzer=ANALYZER_MAP[analyzer_str],
            boost=index_data.get("boost", 1.0),
            example=index_data.get("example"),
        )

    def validate_config(self, config: CustomerConfig) -> List[str]:
        """
        Validate customer configuration.

        Checks performed: duplicate field names, index references to unknown
        fields, SPU aggregation settings, embedding field parameters, and the
        presence of the required MySQL connection keys.

        Args:
            config: Customer configuration to validate

        Returns:
            List of validation error messages (empty if valid)
        """
        errors: List[str] = []

        # Collect field names, reporting any name defined more than once.
        field_names = set()
        for field_config in config.fields:
            if field_config.name in field_names:
                errors.append(f"Duplicate field name '{field_config.name}'")
            field_names.add(field_config.name)

        # Validate field references in indexes
        for index in config.indexes:
            for field_name in index.fields:
                if field_name not in field_names:
                    errors.append(f"Index '{index.name}' references unknown field '{field_name}'")

        # Validate SPU config
        if config.spu_config.enabled:
            if not config.spu_config.spu_field:
                errors.append("SPU aggregation enabled but no spu_field specified")
            elif config.spu_config.spu_field not in field_names:
                errors.append(f"SPU field '{config.spu_config.spu_field}' not found in fields")

        # Validate embedding fields have proper configuration
        for field_config in config.fields:
            if field_config.field_type in (FieldType.TEXT_EMBEDDING, FieldType.IMAGE_EMBEDDING):
                if field_config.embedding_dims <= 0:
                    errors.append(f"Field '{field_config.name}': embedding_dims must be positive")
                if field_config.embedding_similarity not in self._VALID_EMBEDDING_SIMILARITIES:
                    errors.append(f"Field '{field_config.name}': invalid embedding_similarity")

        # Validate MySQL config
        for key in self._REQUIRED_MYSQL_KEYS:
            if key not in config.mysql_config:
                errors.append(f"MySQL configuration missing '{key}'")

        return errors

    def save_config(self, config: CustomerConfig, output_path: Optional[str] = None) -> None:
        """
        Save customer configuration to YAML file.

        The written layout mirrors what ``load_customer_config`` reads. Parent
        directories of the output path are created if needed.

        NOTE: mysql_config (including any password) and translation_api_key
        are written in plain text.

        Args:
            config: Configuration to save
            output_path: Optional output path (defaults to
                ``<config_dir>/<customer_id>_config.yaml``)
        """
        if output_path is None:
            output_path = self.config_dir / f"{config.customer_id}_config.yaml"

        # Convert config back to dictionary format
        config_dict = {
            "customer_name": config.customer_name,
            "mysql_config": config.mysql_config,
            "main_table": config.main_table,
            "extension_table": config.extension_table,
            "es_index_name": config.es_index_name,
            "es_settings": config.es_settings,
            "fields": [self._field_to_dict(field_config) for field_config in config.fields],
            "indexes": [self._index_to_dict(index) for index in config.indexes],
            "query_config": {
                "supported_languages": config.query_config.supported_languages,
                "default_language": config.query_config.default_language,
                "enable_translation": config.query_config.enable_translation,
                "enable_text_embedding": config.query_config.enable_text_embedding,
                "enable_query_rewrite": config.query_config.enable_query_rewrite,
                "rewrite_dictionary": config.query_config.rewrite_dictionary,
                "translation_api_key": config.query_config.translation_api_key,
                "translation_service": config.query_config.translation_service,
            },
            "ranking": {
                "expression": config.ranking.expression,
                "description": config.ranking.description,
            },
            "spu_config": {
                "enabled": config.spu_config.enabled,
                "spu_field": config.spu_config.spu_field,
                "inner_hits_size": config.spu_config.inner_hits_size,
            },
        }

        # Writing into a not-yet-existing config directory would otherwise
        # fail with FileNotFoundError.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)

    def _field_to_dict(self, field: FieldConfig) -> Dict[str, Any]:
        """Convert a FieldConfig back to its YAML dictionary form.

        Optional keys are emitted only when they differ from the defaults that
        ``_parse_field_config`` applies, keeping saved files compact.
        """
        result = {
            "name": field.name,
            "type": field.field_type.value,
            "source_table": field.source_table,
            "source_column": field.source_column,
            "required": field.required,
            "boost": field.boost,
            "store": field.store,
            "index": field.index,
        }

        if field.analyzer:
            result["analyzer"] = field.analyzer.value
        if field.search_analyzer:
            result["search_analyzer"] = field.search_analyzer.value
        if field.multi_language:
            result["multi_language"] = field.multi_language
            result["languages"] = field.languages
        if field.embedding_dims != 1024:
            result["embedding_dims"] = field.embedding_dims
        if field.embedding_similarity != "dot_product":
            result["embedding_similarity"] = field.embedding_similarity
        if field.nested:
            result["nested"] = field.nested
            result["nested_properties"] = field.nested_properties

        return result

    def _index_to_dict(self, index: IndexConfig) -> Dict[str, Any]:
        """Convert an IndexConfig back to its YAML dictionary form."""
        return {
            "name": index.name,
            "label": index.label,
            "fields": index.fields,
            "analyzer": index.analyzer.value,
            "boost": index.boost,
            "example": index.example,
        }