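"""
Request-scoped context management.

Stores the query-analysis result, the intermediate results of each retrieval
stage, and performance metrics for a single request. Concurrent requests are
supported by giving each request its own RequestContext (optionally bound to
the handling thread via set_current_request_context); an individual
RequestContext instance performs no internal locking.
"""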

import time
import threading
from enum import Enum
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
import uuid


class RequestContextStage(Enum):
    """Search stages that RequestContext can time.

    Each member's value is the string key under which RequestContext stores
    the stage in PerformanceMetrics.stage_start_times / stage_timings.
    """
    # Whole request; started/ended automatically by RequestContext's `with` support.
    TOTAL = "total_search"
    # Individual pipeline stages; presumably timed by callers through
    # RequestContext.start_stage / end_stage — confirm at call sites.
    QUERY_PARSING = "query_parsing"
    BOOLEAN_PARSING = "boolean_parsing"
    QUERY_BUILDING = "query_building"
    ELASTICSEARCH_SEARCH = "elasticsearch_search"
    RESULT_PROCESSING = "result_processing"
    RERANKING = "reranking"


@dataclass
class QueryAnalysisResult:
    """Query-analysis result for one request.

    Every field has a default, so an empty instance is valid. Fields are
    filled through RequestContext.store_query_analysis, which only accepts
    keyword names that exist as attributes here.
    """
    original_query: Optional[str] = None
    normalized_query: Optional[str] = None
    rewritten_query: Optional[str] = None
    detected_language: Optional[str] = None
    # Presumably maps a language code to the translated query — verify against producers.
    translations: Dict[str, str] = field(default_factory=dict)
    # Presumably the query embedding; RequestContext.get_summary only reports whether it is set.
    query_vector: Optional[List[float]] = None
    # String form of the parsed boolean query, if any — TODO confirm the exact format.
    boolean_ast: Optional[str] = None
    is_simple_query: bool = True
    domain: str = "default"


@dataclass
class PerformanceMetrics:
    """Timing data collected by RequestContext for one request."""
    # Stage key (RequestContextStage.value) -> duration in milliseconds.
    stage_timings: Dict[str, float] = field(default_factory=dict)
    # Stage key -> start timestamp taken with time.time() (seconds since the epoch).
    stage_start_times: Dict[str, float] = field(default_factory=dict)
    # Total request duration in milliseconds; RequestContext copies it from the
    # TOTAL stage when its `with` block exits.
    total_duration: float = 0.0
    # Free-form extra metrics; not read by RequestContext itself — TODO confirm consumers.
    extra_metrics: Dict[str, Any] = field(default_factory=dict)


class RequestContext:
    """
    Per-request context container.

    Responsibilities:
    1. Hold the query-analysis result and each retrieval stage's intermediate
       results.
    2. Time stages via ``start_stage`` / ``end_stage`` (durations in ms).
    3. Collect warnings, error info and request metadata, and log a summary.
    4. Act as a context manager: ``with ctx:`` times the TOTAL stage and,
       on exit, records any raised exception and logs a performance summary.

    Thread-safety: an instance performs no internal locking. It is meant to be
    owned by a single request (e.g. bound to the handling thread via
    ``set_current_request_context``); do not mutate one instance from several
    threads concurrently without external synchronization.
    """

    def __init__(self, reqid: Optional[str] = None, uid: Optional[str] = None):
        # Request id for log correlation; when not supplied (or empty), use the
        # first 8 characters of a random UUID4.
        self.reqid = reqid or str(uuid.uuid4())[:8]
        # "-1" is a placeholder used in logs when no uid is available.
        self.uid = uid or "-1"

        # Query-analysis result (filled via store_query_analysis).
        self.query_analysis = QueryAnalysisResult()

        # Intermediate results of each retrieval stage. These keys are seeded
        # defaults; store_intermediate_result may add arbitrary others.
        self.intermediate_results = {
            'parsed_query': None,
            'query_node': None,
            'es_query': {},
            'es_response': {},
            'processed_hits': [],
            'raw_hits': []
        }

        # Stage timings and total duration.
        self.performance_metrics = PerformanceMetrics()

        # Request metadata.
        self.metadata = {
            'search_params': {},  # e.g. size, from_, filters
            'feature_flags': {},  # e.g. enable_translation, enable_embedding
            'config_info': {},   # e.g. index config, field mappings
            'error_info': None,  # set by set_error()
            'warnings': []       # appended to by add_warning()
        }

        # Logger is created lazily on first access (see the `logger` property).
        self._logger = None

    @property
    def logger(self):
        """Return the logger, creating it on first access."""
        if self._logger is None:
            from utils.logger import get_logger
            self._logger = get_logger("request_context")
        return self._logger

    def _log_extra(self, **fields: Any) -> Dict[str, Any]:
        """Build the logging ``extra`` dict: reqid/uid plus any additional fields."""
        extra: Dict[str, Any] = {'reqid': self.reqid, 'uid': self.uid}
        extra.update(fields)
        return extra

    def start_stage(self, stage: RequestContextStage) -> float:
        """
        Start timing a stage.

        Calling it again for the same stage overwrites the earlier start time.

        Args:
            stage: the stage to start.

        Returns:
            The start timestamp from ``time.time()`` (seconds since the epoch).
        """
        start_time = time.time()
        self.performance_metrics.stage_start_times[stage.value] = start_time
        self.logger.debug(
            f"Start stage | {stage.value}",
            extra=self._log_extra()
        )
        return start_time

    def end_stage(self, stage: RequestContextStage) -> float:
        """
        Stop timing a stage and record its duration.

        Args:
            stage: the stage to end.

        Returns:
            The stage duration in milliseconds, or 0.0 (with a warning logged)
            if the stage was never started.
        """
        if stage.value not in self.performance_metrics.stage_start_times:
            self.logger.warning(
                f"Stage not started | {stage.value}",
                extra=self._log_extra()
            )
            return 0.0

        start_time = self.performance_metrics.stage_start_times[stage.value]
        duration_ms = (time.time() - start_time) * 1000
        self.performance_metrics.stage_timings[stage.value] = duration_ms

        self.logger.debug(
            f"End stage | {stage.value} | duration: {duration_ms:.2f}ms",
            extra=self._log_extra()
        )
        return duration_ms

    def get_stage_duration(self, stage: RequestContextStage) -> float:
        """
        Return a stage's recorded duration.

        Args:
            stage: the stage to look up.

        Returns:
            Duration in milliseconds, or 0.0 if the stage has not been ended.
        """
        return self.performance_metrics.stage_timings.get(stage.value, 0.0)

    def store_query_analysis(self, **kwargs) -> None:
        """
        Store query-analysis fields.

        Only keys that already exist as attributes on ``self.query_analysis``
        are set. Unknown keys are skipped with a warning log.

        Args:
            **kwargs: field name -> value.
        """
        for key, value in kwargs.items():
            if hasattr(self.query_analysis, key):
                setattr(self.query_analysis, key, value)
            else:
                self.logger.warning(
                    f"Unknown query analysis field | {key}",
                    extra=self._log_extra()
                )

    def store_intermediate_result(self, key: str, value: Any) -> None:
        """
        Store (or overwrite) an intermediate result.

        Args:
            key: result name.
            value: result value.
        """
        self.intermediate_results[key] = value
        self.logger.debug(
            f"Store intermediate result | {key}",
            extra=self._log_extra()
        )

    def get_intermediate_result(self, key: str, default: Any = None) -> Any:
        """
        Fetch an intermediate result.

        Args:
            key: result name.
            default: value returned when ``key`` is absent.

        Returns:
            The stored value, or ``default``.
        """
        return self.intermediate_results.get(key, default)

    def add_warning(self, warning: str) -> None:
        """
        Record a warning and log it at WARNING level.

        Args:
            warning: warning text.
        """
        self.metadata['warnings'].append(warning)
        self.logger.warning(warning, extra=self._log_extra())

    def set_error(self, error: Exception) -> None:
        """
        Record an error (replacing any previous one) and log it at ERROR level.

        Args:
            error: the exception to record.
        """
        self.metadata['error_info'] = {
            'type': type(error).__name__,
            'message': str(error),
            'details': {}
        }
        self.logger.error(
            f"Set error info | {type(error).__name__}: {str(error)}",
            extra=self._log_extra()
        )

    def has_error(self) -> bool:
        """Return True if an error has been recorded."""
        return self.metadata['error_info'] is not None

    def calculate_stage_percentages(self) -> Dict[str, float]:
        """
        Compute each stage's share of the total duration.

        Returns:
            Stage key -> percentage (rounded to 2 decimals); empty when the
            total duration is not positive.
        """
        total = self.performance_metrics.total_duration
        if total <= 0:
            return {}

        return {
            stage: round((duration / total) * 100, 2)
            for stage, duration in self.performance_metrics.stage_timings.items()
        }

    def get_summary(self) -> Dict[str, Any]:
        """
        Build a summary of the request's key information.

        Returns:
            A dict with request_info, query_analysis, performance, results
            and metadata sections.
        """
        return {
            'request_info': {
                'reqid': self.reqid,
                'uid': self.uid,
                'has_error': self.has_error(),
                'warnings_count': len(self.metadata['warnings'])
            },
            'query_analysis': {
                'original_query': self.query_analysis.original_query,
                'normalized_query': self.query_analysis.normalized_query,
                'rewritten_query': self.query_analysis.rewritten_query,
                'detected_language': self.query_analysis.detected_language,
                'domain': self.query_analysis.domain,
                'has_vector': self.query_analysis.query_vector is not None,
                'is_simple_query': self.query_analysis.is_simple_query
            },
            'performance': {
                'total_duration_ms': round(self.performance_metrics.total_duration, 2),
                'stage_timings_ms': {
                    k: round(v, 2) for k, v in self.performance_metrics.stage_timings.items()
                },
                'stage_percentages': self.calculate_stage_percentages()
            },
            'results': {
                # `or []` also covers a key explicitly stored as None, where
                # len(None) would otherwise raise TypeError.
                'total_hits': len(self.intermediate_results.get('processed_hits') or []),
                'has_es_response': bool(self.intermediate_results.get('es_response')),
                'es_query_size': len(str(self.intermediate_results.get('es_query', {})))
            },
            'metadata': {
                'feature_flags': self.metadata['feature_flags'],
                'search_params': self.metadata['search_params'],
                'config_info': self.metadata['config_info']
            }
        }

    def log_performance_summary(self) -> None:
        """
        Log a one-line performance summary, with the full summary dict attached
        as ``extra_data``.

        The level is ERROR if an error is recorded, WARNING if there are
        warnings, and INFO otherwise.
        """
        summary = self.get_summary()

        msg_parts = [
            f"Search request performance summary | reqid: {self.reqid}",
            f"Total duration: {summary['performance']['total_duration_ms']:.2f}ms"
        ]

        # Per-stage durations with their share of the total.
        if summary['performance']['stage_timings_ms']:
            msg_parts.append("Stage durations:")
            for stage, duration in summary['performance']['stage_timings_ms'].items():
                percentage = summary['performance']['stage_percentages'].get(stage, 0)
                msg_parts.append(f"  - {stage}: {duration:.2f}ms ({percentage}%)")

        # Query info, only when an original query was recorded.
        if summary['query_analysis']['original_query']:
            msg_parts.append(
                "Query: "
                f"'{summary['query_analysis']['original_query']}' "
                f"-> '{summary['query_analysis']['rewritten_query']}' "
                f"({summary['query_analysis']['detected_language']})"
            )

        # Result statistics.
        msg_parts.append(
            f"Results: {summary['results']['total_hits']} hits "
            f"ES query size: {summary['results']['es_query_size']} chars"
        )

        if summary['request_info']['has_error']:
            error_info = self.metadata['error_info']
            msg_parts.append(f"Error: {error_info['type']}: {error_info['message']}")

        if summary['request_info']['warnings_count'] > 0:
            msg_parts.append(f"Warnings: {summary['request_info']['warnings_count']}")

        log_message = " | ".join(msg_parts)
        extra = self._log_extra(extra_data=summary)

        if self.has_error():
            self.logger.error(log_message, extra=extra)
        elif summary['request_info']['warnings_count'] > 0:
            self.logger.warning(log_message, extra=extra)
        else:
            self.logger.info(log_message, extra=extra)

    def __enter__(self):
        """Context-manager entry: start timing the TOTAL stage."""
        self.start_stage(RequestContextStage.TOTAL)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Context-manager exit. Stop the TOTAL timer, record any exception raised
        inside the ``with`` body, then log the performance summary.

        Returns None, so an exception from the body is not suppressed.
        """
        self.end_stage(RequestContextStage.TOTAL)
        self.performance_metrics.total_duration = self.get_stage_duration(RequestContextStage.TOTAL)

        # Record the error BEFORE logging the summary, so the summary includes
        # the error and is emitted at ERROR level. The old code logged first,
        # so failed requests produced an INFO-level summary with no error.
        if exc_val is not None:
            self.set_error(exc_val)

        self.log_performance_summary()


# 便利函数
def create_request_context(reqid: str = None, uid: str = None) -> RequestContext:
    """Build and return a fresh RequestContext for one request.

    Args:
        reqid: optional request id; RequestContext generates one when falsy.
        uid: optional user id; RequestContext substitutes "-1" when falsy.
    """
    new_context = RequestContext(reqid=reqid, uid=uid)
    return new_context


def get_current_request_context() -> Optional[RequestContext]:
    """Return the request context bound to the calling thread, or None if none is bound."""
    this_thread = threading.current_thread()
    try:
        return this_thread.request_context
    except AttributeError:
        return None


def set_current_request_context(context: RequestContext) -> None:
    """Bind *context* to the calling thread, replacing any previous binding."""
    this_thread = threading.current_thread()
    setattr(this_thread, 'request_context', context)


def clear_current_request_context() -> None:
    """Remove the calling thread's request-context binding; a no-op if none is bound."""
    this_thread = threading.current_thread()
    if not hasattr(this_thread, 'request_context'):
        return
    delattr(this_thread, 'request_context')