config_loader.py 13.9 KB
"""
Configuration loader and validator for customer-specific search configurations.

This module handles loading, parsing, and validating YAML configuration files
that define how each customer's data should be indexed and searched.
"""

import yaml
import os
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from pathlib import Path

from .field_types import (
    FieldConfig, FieldType, AnalyzerType,
    FIELD_TYPE_MAP, ANALYZER_MAP
)


@dataclass
class IndexConfig:
    """A named index domain, e.g. ``default``, ``title`` or ``brand``.

    Records which configured fields the domain covers, the analyzer it uses,
    a boost (1.0 unless overridden) and an optional example.
    """
    name: str                # domain identifier
    label: str               # display label for the domain
    fields: List[str]        # names of the fields included in this domain
    analyzer: AnalyzerType
    boost: float = field(default=1.0)
    example: Optional[str] = field(default=None)


@dataclass
class RankingConfig:
    """A ranking expression paired with a human-readable description."""

    # Expression text, e.g. "bm25() + 0.2*text_embedding_relevance()"
    expression: str
    # Free-text explanation of the expression
    description: str


@dataclass
class QueryConfig:
    """Query-processing options; every setting has a default."""

    # Language handling
    supported_languages: List[str] = field(default_factory=lambda: ["zh", "en"])
    default_language: str = field(default="zh")

    # Feature toggles and the rewrite table used by query rewriting
    enable_translation: bool = field(default=True)
    enable_text_embedding: bool = field(default=True)
    enable_query_rewrite: bool = field(default=True)
    rewrite_dictionary: Dict[str, str] = field(default_factory=dict)

    # Translation API settings; service name e.g. "deepl" or "google"
    translation_api_key: Optional[str] = field(default=None)
    translation_service: str = field(default="deepl")

    # Translation API settings
    translation_api_key: Optional[str] = None
    translation_service: str = "deepl"  # deepl, google, etc.


@dataclass
class SPUConfig:
    """Settings for SPU aggregation; aggregation is disabled by default."""

    enabled: bool = field(default=False)
    # Name of the field that holds the SPU ID
    spu_field: Optional[str] = field(default=None)
    inner_hits_size: int = field(default=3)


@dataclass
class CustomerConfig:
    """Complete configuration for a customer.

    Required fields are declared first and defaulted fields last: a dataclass
    may not declare a field without a default after one that has a default
    (that raises TypeError when the class is created). Construct instances
    with keyword arguments.
    """
    customer_id: str
    customer_name: str

    # Database settings (connection parameters)
    mysql_config: Dict[str, Any]

    # Field definitions
    fields: List["FieldConfig"]

    # Index structure (query domains)
    indexes: List["IndexConfig"]

    # Query processing
    query_config: "QueryConfig"

    # Ranking configuration
    ranking: "RankingConfig"

    # SPU configuration
    spu_config: "SPUConfig"

    # ES index name
    es_index_name: str

    # --- Optional settings with defaults (must follow all required fields) ---
    main_table: str = "shoplazza_product_sku"
    extension_table: Optional[str] = None
    es_settings: Dict[str, Any] = field(default_factory=dict)


class ConfigurationError(Exception):
    """Signals that a customer configuration is missing, malformed or invalid."""


class ConfigLoader:
    """Loads, validates and saves customer configurations stored as YAML files.

    A customer's configuration is read from, and by default written to,
    ``<config_dir>/<customer_id>_config.yaml``.
    """

    def __init__(self, config_dir: str = "config/schema"):
        """
        Args:
            config_dir: Directory holding the per-customer YAML files.
        """
        self.config_dir = Path(config_dir)

    def _config_path(self, customer_id: str) -> Path:
        """Return the default YAML path for ``customer_id`` inside ``config_dir``."""
        return self.config_dir / f"{customer_id}_config.yaml"

    # ------------------------------------------------------------------
    # Loading / parsing
    # ------------------------------------------------------------------

    def load_customer_config(self, customer_id: str) -> CustomerConfig:
        """
        Load customer configuration from YAML file.

        Args:
            customer_id: Customer identifier (used to find config file)

        Returns:
            CustomerConfig object

        Raises:
            ConfigurationError: If the config file is missing, unreadable,
                not valid UTF-8, not valid YAML, or has an invalid structure.
        """
        config_file = self._config_path(customer_id)

        # EAFP: open directly rather than exists()-then-open, so a file that
        # vanishes, is a directory, or is unreadable is reported uniformly.
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                config_data = yaml.safe_load(f)
        except FileNotFoundError as e:
            raise ConfigurationError(f"Configuration file not found: {config_file}") from e
        except OSError as e:
            raise ConfigurationError(
                f"Cannot read configuration file {config_file}: {e}"
            ) from e
        except UnicodeDecodeError as e:
            raise ConfigurationError(
                f"Configuration file {config_file} is not valid UTF-8: {e}"
            ) from e
        except yaml.YAMLError as e:
            raise ConfigurationError(f"Invalid YAML in {config_file}: {e}") from e

        return self._parse_config(config_data, customer_id)

    @staticmethod
    def _section(config_data: Dict[str, Any], key: str, expected: type) -> Any:
        """Return ``config_data[key]``, or a fresh empty ``expected()`` if absent or null.

        Args:
            config_data: Mapping to read from.
            key: Section name.
            expected: Container type the section must have (``dict`` or ``list``).

        Raises:
            ConfigurationError: If the section is present with another type.
        """
        value = config_data.get(key)
        if value is None:
            # Covers both a missing key and an explicit `key: ~` / `key:` in YAML.
            return expected()
        if not isinstance(value, expected):
            raise ConfigurationError(
                f"Section '{key}' must be a {expected.__name__}, "
                f"got {type(value).__name__}"
            )
        return value

    @staticmethod
    def _require(data: Any, key: str, context: str) -> Any:
        """Return ``data[key]``; raise ConfigurationError if ``data`` is not a
        mapping or lacks ``key``.

        Args:
            data: Parsed YAML node expected to be a mapping.
            key: Required key.
            context: Human-readable description used in error messages.
        """
        if not isinstance(data, dict):
            raise ConfigurationError(
                f"{context} must be a mapping, got {type(data).__name__}"
            )
        if key not in data:
            raise ConfigurationError(f"{context} is missing required key '{key}'")
        return data[key]

    @staticmethod
    def _optional_analyzer(analyzer_str: Any) -> Optional[AnalyzerType]:
        """Map an analyzer name to its enum, or None if empty or not in ANALYZER_MAP."""
        if analyzer_str and analyzer_str in ANALYZER_MAP:
            return ANALYZER_MAP[analyzer_str]
        return None

    def _parse_config(self, config_data: Dict[str, Any], customer_id: str) -> CustomerConfig:
        """Parse configuration dictionary into CustomerConfig object.

        Absent or ``null`` sections fall back to their defaults.

        Raises:
            ConfigurationError: If ``config_data`` is not a mapping, a section
                has the wrong container type, or a field/index entry is invalid.
        """
        # yaml.safe_load returns None for an empty file and may return a
        # scalar or list; every lookup below requires a mapping.
        if not isinstance(config_data, dict):
            raise ConfigurationError(
                f"Configuration for '{customer_id}' must be a YAML mapping, "
                f"got {type(config_data).__name__}"
            )

        fields = [
            self._parse_field_config(field_data)
            for field_data in self._section(config_data, "fields", list)
        ]
        indexes = [
            self._parse_index_config(index_data)
            for index_data in self._section(config_data, "indexes", list)
        ]

        # Parse query config
        query_config_data = self._section(config_data, "query_config", dict)
        query_config = QueryConfig(
            supported_languages=query_config_data.get("supported_languages", ["zh", "en"]),
            default_language=query_config_data.get("default_language", "zh"),
            enable_translation=query_config_data.get("enable_translation", True),
            enable_text_embedding=query_config_data.get("enable_text_embedding", True),
            enable_query_rewrite=query_config_data.get("enable_query_rewrite", True),
            rewrite_dictionary=self._section(query_config_data, "rewrite_dictionary", dict),
            translation_api_key=query_config_data.get("translation_api_key"),
            translation_service=query_config_data.get("translation_service", "deepl"),
        )

        # Parse ranking config
        ranking_data = self._section(config_data, "ranking", dict)
        ranking = RankingConfig(
            expression=ranking_data.get("expression", "bm25() + 0.2*text_embedding_relevance()"),
            description=ranking_data.get("description", "Default BM25 + text embedding ranking"),
        )

        # Parse SPU config
        spu_data = self._section(config_data, "spu_config", dict)
        spu_config = SPUConfig(
            enabled=spu_data.get("enabled", False),
            spu_field=spu_data.get("spu_field"),
            inner_hits_size=spu_data.get("inner_hits_size", 3),
        )

        return CustomerConfig(
            customer_id=customer_id,
            customer_name=config_data.get("customer_name", customer_id),
            mysql_config=self._section(config_data, "mysql_config", dict),
            main_table=config_data.get("main_table", "shoplazza_product_sku"),
            extension_table=config_data.get("extension_table"),
            fields=fields,
            indexes=indexes,
            query_config=query_config,
            ranking=ranking,
            spu_config=spu_config,
            es_index_name=config_data.get("es_index_name", f"search_{customer_id}"),
            es_settings=self._section(config_data, "es_settings", dict),
        )

    def _parse_field_config(self, field_data: Dict[str, Any]) -> FieldConfig:
        """Parse one entry of the ``fields`` list into a FieldConfig.

        Raises:
            ConfigurationError: If the entry is not a mapping, lacks ``name``
                or ``type``, or its type is not in FIELD_TYPE_MAP.
        """
        name = self._require(field_data, "name", "Field definition")
        field_type_str = self._require(field_data, "type", f"Field '{name}'")

        # Map field type string to enum
        if field_type_str not in FIELD_TYPE_MAP:
            raise ConfigurationError(f"Unknown field type: {field_type_str}")
        field_type = FIELD_TYPE_MAP[field_type_str]

        # NOTE(review): unlike _parse_index_config, an unknown analyzer name is
        # silently treated as "not set" here rather than rejected. Kept for
        # backward compatibility with existing configs; consider raising.
        analyzer = self._optional_analyzer(field_data.get("analyzer"))
        search_analyzer = self._optional_analyzer(field_data.get("search_analyzer"))

        return FieldConfig(
            name=name,
            field_type=field_type,
            source_table=field_data.get("source_table"),
            source_column=field_data.get("source_column", name),
            analyzer=analyzer,
            search_analyzer=search_analyzer,
            required=field_data.get("required", False),
            multi_language=field_data.get("multi_language", False),
            languages=field_data.get("languages"),
            boost=field_data.get("boost", 1.0),
            store=field_data.get("store", False),
            index=field_data.get("index", True),
            embedding_dims=field_data.get("embedding_dims", 1024),
            embedding_similarity=field_data.get("embedding_similarity", "dot_product"),
            nested=field_data.get("nested", False),
            nested_properties=field_data.get("nested_properties"),
        )

    def _parse_index_config(self, index_data: Dict[str, Any]) -> IndexConfig:
        """Parse one entry of the ``indexes`` list into an IndexConfig.

        The analyzer defaults to ``chinese_ecommerce`` when missing or null.

        Raises:
            ConfigurationError: If the entry is not a mapping, lacks ``name``
                or ``fields``, ``fields`` is not a list, or the analyzer is unknown.
        """
        name = self._require(index_data, "name", "Index definition")
        index_fields = self._require(index_data, "fields", f"Index '{name}'")
        if not isinstance(index_fields, list):
            # A bare string would otherwise be iterated character by character.
            raise ConfigurationError(f"Index '{name}': 'fields' must be a list of field names")

        analyzer_str = index_data.get("analyzer")
        if analyzer_str is None:
            analyzer_str = "chinese_ecommerce"
        if analyzer_str not in ANALYZER_MAP:
            raise ConfigurationError(f"Unknown analyzer: {analyzer_str}")

        return IndexConfig(
            name=name,
            label=index_data.get("label", name),
            fields=index_fields,
            analyzer=ANALYZER_MAP[analyzer_str],
            boost=index_data.get("boost", 1.0),
            example=index_data.get("example"),
        )

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate_config(self, config: CustomerConfig) -> List[str]:
        """
        Validate customer configuration.

        Args:
            config: Customer configuration to validate

        Returns:
            List of validation error messages (empty if valid)
        """
        errors: List[str] = []

        # Every field referenced by an index must be declared in `fields`.
        known_fields = {fc.name for fc in config.fields}
        for index in config.indexes:
            errors.extend(
                f"Index '{index.name}' references unknown field '{field_name}'"
                for field_name in index.fields
                if field_name not in known_fields
            )

        # SPU aggregation needs a declared field to aggregate on.
        spu = config.spu_config
        if spu.enabled:
            if not spu.spu_field:
                errors.append("SPU aggregation enabled but no spu_field specified")
            elif spu.spu_field not in known_fields:
                errors.append(f"SPU field '{spu.spu_field}' not found in fields")

        # Embedding fields need sane vector settings.
        for fc in config.fields:
            if fc.field_type not in (FieldType.TEXT_EMBEDDING, FieldType.IMAGE_EMBEDDING):
                continue
            dims = fc.embedding_dims
            # isinstance guard: a YAML string such as "1024" would make `<=` raise.
            if not isinstance(dims, int) or dims <= 0:
                errors.append(f"Field '{fc.name}': embedding_dims must be positive")
            if fc.embedding_similarity not in ("dot_product", "cosine", "l2_norm"):
                errors.append(f"Field '{fc.name}': invalid embedding_similarity")

        # Required MySQL connection keys; tolerate a None mapping.
        mysql_config = config.mysql_config or {}
        for key in ("host", "username", "password", "database"):
            if key not in mysql_config:
                errors.append(f"MySQL configuration missing '{key}'")

        return errors

    # ------------------------------------------------------------------
    # Saving / serialization
    # ------------------------------------------------------------------

    def save_config(self, config: CustomerConfig, output_path: Optional[str] = None) -> None:
        """
        Save customer configuration to YAML file.

        ``customer_id`` itself is not written; load_customer_config derives it
        from the file name, so a custom ``output_path`` must follow the
        ``<customer_id>_config.yaml`` pattern to be reloadable by id.

        NOTE(review): mysql_config (including any password) and
        translation_api_key are written in plain text.

        Args:
            config: Configuration to save
            output_path: Optional output path (defaults to config dir)
        """
        if output_path is None:
            output_path = self._config_path(config.customer_id)

        # Convert config back to dictionary format
        config_dict = {
            "customer_name": config.customer_name,
            "mysql_config": config.mysql_config,
            "main_table": config.main_table,
            "extension_table": config.extension_table,
            "es_index_name": config.es_index_name,
            "es_settings": config.es_settings,
            "fields": [self._field_to_dict(fc) for fc in config.fields],
            "indexes": [self._index_to_dict(index) for index in config.indexes],
            "query_config": {
                "supported_languages": config.query_config.supported_languages,
                "default_language": config.query_config.default_language,
                "enable_translation": config.query_config.enable_translation,
                "enable_text_embedding": config.query_config.enable_text_embedding,
                "enable_query_rewrite": config.query_config.enable_query_rewrite,
                "rewrite_dictionary": config.query_config.rewrite_dictionary,
                "translation_api_key": config.query_config.translation_api_key,
                "translation_service": config.query_config.translation_service,
            },
            "ranking": {
                "expression": config.ranking.expression,
                "description": config.ranking.description
            },
            "spu_config": {
                "enabled": config.spu_config.enabled,
                "spu_field": config.spu_config.spu_field,
                "inner_hits_size": config.spu_config.inner_hits_size
            }
        }

        with open(output_path, 'w', encoding='utf-8') as f:
            yaml.dump(config_dict, f, default_flow_style=False, allow_unicode=True)

    def _field_to_dict(self, field: FieldConfig) -> Dict[str, Any]:
        """Convert FieldConfig to dictionary.

        Optional attributes are emitted only when they differ from the
        defaults used by _parse_field_config.

        NOTE(review): enum ``.value`` strings are written for ``type`` and the
        analyzers; reloading assumes they are keys of FIELD_TYPE_MAP /
        ANALYZER_MAP — confirm against field_types. ``languages`` and
        ``nested_properties`` are dropped unless their flag is set.
        """
        result = {
            "name": field.name,
            "type": field.field_type.value,
            "source_table": field.source_table,
            "source_column": field.source_column,
            "required": field.required,
            "boost": field.boost,
            "store": field.store,
            "index": field.index,
        }

        if field.analyzer:
            result["analyzer"] = field.analyzer.value
        if field.search_analyzer:
            result["search_analyzer"] = field.search_analyzer.value
        if field.multi_language:
            result["multi_language"] = field.multi_language
            result["languages"] = field.languages
        if field.embedding_dims != 1024:
            result["embedding_dims"] = field.embedding_dims
        if field.embedding_similarity != "dot_product":
            result["embedding_similarity"] = field.embedding_similarity
        if field.nested:
            result["nested"] = field.nested
            result["nested_properties"] = field.nested_properties

        return result

    def _index_to_dict(self, index: IndexConfig) -> Dict[str, Any]:
        """Convert IndexConfig to dictionary (analyzer written as its enum value)."""
        return {
            "name": index.name,
            "label": index.label,
            "fields": index.fields,
            "analyzer": index.analyzer.value,
            "boost": index.boost,
            "example": index.example
        }