"""
Request-scoped context manager.

Stores query-analysis results, intermediate results of each retrieval stage,
and performance metrics for a single search request.

Concurrency model: each request gets its own ``RequestContext`` instance,
which can be bound to the current thread with the helper functions at the
bottom of this module. The class itself performs no locking, so a single
instance should not be mutated from several threads at the same time.
"""
import time
import threading
from enum import Enum
from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
import uuid


class RequestContextStage(Enum):
    """Search stages whose duration can be timed."""
    TOTAL = "total_search"
    QUERY_PARSING = "query_parsing"
    BOOLEAN_PARSING = "boolean_parsing"
    QUERY_BUILDING = "query_building"
    ELASTICSEARCH_SEARCH = "elasticsearch_search"
    RESULT_PROCESSING = "result_processing"
    RERANKING = "reranking"


@dataclass
class QueryAnalysisResult:
    """Result of analysing the user query."""
    original_query: Optional[str] = None
    normalized_query: Optional[str] = None
    rewritten_query: Optional[str] = None
    detected_language: Optional[str] = None
    translations: Dict[str, str] = field(default_factory=dict)
    query_vector: Optional[List[float]] = None
    boolean_ast: Optional[str] = None
    is_simple_query: bool = True
    domain: str = "default"


@dataclass
class PerformanceMetrics:
    """Timing data collected during a request (durations in milliseconds)."""
    # stage name -> duration in ms (written by RequestContext.end_stage)
    stage_timings: Dict[str, float] = field(default_factory=dict)
    # stage name -> wall-clock start timestamp from time.time()
    stage_start_times: Dict[str, float] = field(default_factory=dict)
    # duration of the TOTAL stage in ms (set when the `with` block exits)
    total_duration: float = 0.0
    extra_metrics: Dict[str, Any] = field(default_factory=dict)


class RequestContext:
    """
    Per-request context manager.

    Responsibilities:
    1. Store query-analysis results and intermediate results of each stage.
    2. Time individual stages via ``start_stage`` / ``end_stage``.
    3. Collect warnings and error information, and log a performance summary.
    4. Support the ``with`` statement: entering starts the TOTAL stage;
       exiting stops it, records any exception and logs the summary.

    The instance performs no internal locking; use one instance per request.
    """

    def __init__(self, reqid: Optional[str] = None, uid: Optional[str] = None):
        # Short random request id when the caller does not supply one.
        self.reqid = reqid or str(uuid.uuid4())[:8]
        self.uid = uid or 'anonymous'

        # Query analysis results.
        self.query_analysis = QueryAnalysisResult()

        # Intermediate results of each retrieval stage.
        self.intermediate_results: Dict[str, Any] = {
            'parsed_query': None,
            'query_node': None,
            'es_query': {},
            'es_response': {},
            'processed_hits': [],
            'raw_hits': [],
        }

        # Performance metrics.
        self.performance_metrics = PerformanceMetrics()
        # Monotonic (perf_counter) start times used to compute durations.
        # Wall-clock starts are still published in
        # performance_metrics.stage_start_times for backward compatibility.
        self._stage_perf_starts: Dict[str, float] = {}

        # Request metadata.
        self.metadata: Dict[str, Any] = {
            'search_params': {},   # e.g. size, from_, filters
            'feature_flags': {},   # e.g. enable_translation, enable_embedding
            'config_info': {},     # e.g. index config, field mappings
            'error_info': None,
            'warnings': [],
        }

        # Logger reference, created lazily on first access.
        self._logger = None

    @property
    def logger(self):
        """Return the logger, importing and creating it on first use."""
        if self._logger is None:
            from utils.logger import get_logger
            self._logger = get_logger("request_context")
        return self._logger

    def _log_extra(self) -> Dict[str, Any]:
        """Build the ``extra`` dict attached to every log record of this request."""
        return {'reqid': self.reqid, 'uid': self.uid}

    def start_stage(self, stage: RequestContextStage) -> float:
        """
        Start timing a stage.

        Args:
            stage: the stage to time.

        Returns:
            Wall-clock start timestamp (``time.time()``, seconds since the epoch).
        """
        start_time = time.time()
        self.performance_metrics.stage_start_times[stage.value] = start_time
        # Durations are measured with a monotonic clock so that system clock
        # adjustments cannot produce negative or skewed timings.
        self._stage_perf_starts[stage.value] = time.perf_counter()
        self.logger.debug(f"开始阶段 | {stage.value}", extra=self._log_extra())
        return start_time

    def end_stage(self, stage: RequestContextStage) -> float:
        """
        Stop timing a stage and record its duration.

        Args:
            stage: the stage to stop.

        Returns:
            Stage duration in milliseconds, or 0.0 (with a warning log) when the
            stage was never started.
        """
        key = stage.value
        wall_starts = self.performance_metrics.stage_start_times
        if key not in wall_starts:
            self.logger.warning(f"阶段未开始计时 | {stage.value}", extra=self._log_extra())
            return 0.0

        perf_start = self._stage_perf_starts.get(key)
        if perf_start is not None:
            duration_ms = (time.perf_counter() - perf_start) * 1000
        else:
            # No monotonic start recorded (stage_start_times was populated
            # outside start_stage); only a wall-clock reference is available.
            duration_ms = (time.time() - wall_starts[key]) * 1000

        self.performance_metrics.stage_timings[key] = duration_ms
        self.logger.debug(
            f"结束阶段 | {stage.value} | 耗时: {duration_ms:.2f}ms",
            extra=self._log_extra()
        )
        return duration_ms

    def get_stage_duration(self, stage: RequestContextStage) -> float:
        """
        Return the recorded duration of a stage.

        Args:
            stage: the stage to look up.

        Returns:
            Duration in milliseconds, or 0.0 if the stage has not been timed.
        """
        return self.performance_metrics.stage_timings.get(stage.value, 0.0)

    def store_query_analysis(self, **kwargs) -> None:
        """
        Store query-analysis fields on ``self.query_analysis``.

        Args:
            **kwargs: field name/value pairs. Names that are not attributes of
                QueryAnalysisResult are skipped and logged as warnings.
        """
        for key, value in kwargs.items():
            if hasattr(self.query_analysis, key):
                setattr(self.query_analysis, key, value)
            else:
                self.logger.warning(
                    f"未知的查询分析字段 | {key}",
                    extra=self._log_extra()
                )

    def store_intermediate_result(self, key: str, value: Any) -> None:
        """
        Store an intermediate result.

        Args:
            key: result name.
            value: result value (overwrites any previous value for ``key``).
        """
        self.intermediate_results[key] = value
        self.logger.debug(f"存储中间结果 | {key}", extra=self._log_extra())

    def get_intermediate_result(self, key: str, default: Any = None) -> Any:
        """
        Fetch an intermediate result.

        Args:
            key: result name.
            default: value returned when ``key`` is absent.

        Returns:
            The stored value, or ``default``.
        """
        return self.intermediate_results.get(key, default)

    def add_warning(self, warning: str) -> None:
        """
        Append a warning to the request metadata and log it at WARNING level.

        Args:
            warning: warning text.
        """
        self.metadata['warnings'].append(warning)
        self.logger.warning(warning, extra=self._log_extra())

    def set_error(self, error: Exception) -> None:
        """
        Record error information for this request and log it at ERROR level.

        Args:
            error: the exception to record (replaces any previous error).
        """
        self.metadata['error_info'] = {
            'type': type(error).__name__,
            'message': str(error),
            'details': {},
        }
        self.logger.error(
            f"设置错误信息 | {type(error).__name__}: {str(error)}",
            extra=self._log_extra()
        )

    def has_error(self) -> bool:
        """Return True if an error has been recorded."""
        return self.metadata['error_info'] is not None

    def calculate_stage_percentages(self) -> Dict[str, float]:
        """
        Compute each stage's share of the total duration.

        Returns:
            Mapping of stage name to percentage (rounded to 2 decimals); empty
            when the total duration is not positive.
        """
        total = self.performance_metrics.total_duration
        if total <= 0:
            return {}
        return {
            stage: round((duration / total) * 100, 2)
            for stage, duration in self.performance_metrics.stage_timings.items()
        }

    def get_summary(self) -> Dict[str, Any]:
        """
        Build a summary of the whole context.

        Returns:
            Dict with ``request_info``, ``query_analysis``, ``performance``,
            ``results`` and ``metadata`` sections.
        """
        # A caller may store None for processed_hits; treat that as "no hits"
        # instead of crashing (this runs inside __exit__, where a crash would
        # mask the request's original exception).
        processed_hits = self.intermediate_results.get('processed_hits') or []
        return {
            'request_info': {
                'reqid': self.reqid,
                'uid': self.uid,
                'has_error': self.has_error(),
                'warnings_count': len(self.metadata['warnings']),
            },
            'query_analysis': {
                'original_query': self.query_analysis.original_query,
                'normalized_query': self.query_analysis.normalized_query,
                'rewritten_query': self.query_analysis.rewritten_query,
                'detected_language': self.query_analysis.detected_language,
                'domain': self.query_analysis.domain,
                'has_vector': self.query_analysis.query_vector is not None,
                'is_simple_query': self.query_analysis.is_simple_query,
            },
            'performance': {
                'total_duration_ms': round(self.performance_metrics.total_duration, 2),
                'stage_timings_ms': {
                    k: round(v, 2)
                    for k, v in self.performance_metrics.stage_timings.items()
                },
                'stage_percentages': self.calculate_stage_percentages(),
            },
            'results': {
                'total_hits': len(processed_hits),
                'has_es_response': bool(self.intermediate_results.get('es_response')),
                'es_query_size': len(str(self.intermediate_results.get('es_query', {}))),
            },
            'metadata': {
                'feature_flags': self.metadata['feature_flags'],
                'search_params': self.metadata['search_params'],
                'config_info': self.metadata['config_info'],
            },
        }

    def log_performance_summary(self) -> None:
        """
        Log one line summarising timings, query, results, errors and warnings.

        The level is ERROR if an error was recorded, WARNING if there are
        warnings, otherwise INFO. The full summary dict is attached as the
        ``extra_data`` log-record attribute.
        """
        summary = self.get_summary()
        perf = summary['performance']
        query = summary['query_analysis']
        results = summary['results']
        request_info = summary['request_info']

        msg_parts = [
            f"搜索请求性能摘要 | reqid: {self.reqid}",
            f"总耗时: {perf['total_duration_ms']:.2f}ms"
        ]

        # Per-stage timings with their share of the total.
        if perf['stage_timings_ms']:
            msg_parts.append("阶段耗时:")
            for stage, duration in perf['stage_timings_ms'].items():
                percentage = perf['stage_percentages'].get(stage, 0)
                msg_parts.append(f" - {stage}: {duration:.2f}ms ({percentage}%)")

        # Query information.
        if query['original_query']:
            msg_parts.append(
                f"查询: '{query['original_query']}' "
                f"-> '{query['rewritten_query']}' "
                f"({query['detected_language']})"
            )

        # Result statistics.
        msg_parts.append(
            f"结果: {results['total_hits']} hits "
            f"ES查询: {results['es_query_size']} chars"
        )

        # Error information, if any.
        if request_info['has_error']:
            error_info = self.metadata['error_info']
            msg_parts.append(f"错误: {error_info['type']}: {error_info['message']}")

        # Warning count, if any.
        if request_info['warnings_count'] > 0:
            msg_parts.append(f"警告: {request_info['warnings_count']} 个")

        log_message = " | ".join(msg_parts)
        extra = {'extra_data': summary, 'reqid': self.reqid, 'uid': self.uid}

        if self.has_error():
            self.logger.error(log_message, extra=extra)
        elif request_info['warnings_count'] > 0:
            self.logger.warning(log_message, extra=extra)
        else:
            self.logger.info(log_message, extra=extra)

    def __enter__(self):
        """Start timing the TOTAL stage and return this context."""
        self.start_stage(RequestContextStage.TOTAL)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Stop TOTAL timing, record any exception, then log the summary.

        The exception is recorded *before* the summary is built so that the
        summary (and its log level) reflects the failure. Returns None, so any
        exception propagates to the caller.
        """
        self.end_stage(RequestContextStage.TOTAL)
        self.performance_metrics.total_duration = self.get_stage_duration(
            RequestContextStage.TOTAL
        )

        if exc_type is not None and exc_val is not None:
            self.set_error(exc_val)

        self.log_performance_summary()


# Convenience functions

def create_request_context(reqid: Optional[str] = None,
                           uid: Optional[str] = None) -> RequestContext:
    """Create a new request context."""
    return RequestContext(reqid, uid)


def get_current_request_context() -> Optional[RequestContext]:
    """Return the context bound to the current thread, or None if none is set."""
    return getattr(threading.current_thread(), 'request_context', None)


def set_current_request_context(context: RequestContext) -> None:
    """Bind ``context`` to the current thread (stored as a Thread attribute)."""
    threading.current_thread().request_context = context


def clear_current_request_context() -> None:
    """Remove the context bound to the current thread; no-op if none is set."""
    try:
        delattr(threading.current_thread(), 'request_context')
    except AttributeError:
        pass