From 45b397964fb80661b13dbc49c8fd03990123ea41 Mon Sep 17 00:00:00 2001 From: tangwang Date: Tue, 31 Mar 2026 10:55:53 +0800 Subject: [PATCH] qp性能优化 --- config/config.yaml | 4 ++-- query/english_keyword_extractor.py | 256 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ query/keyword_extractor.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++++++------ query/product_title_exclusion.py | 25 +++++++++++++++++++++---- query/query_parser.py | 135 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------- query/style_intent.py | 117 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------- query/tokenization.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------- suggestion/service.py | 2 +- tests/test_query_parser_mixed_language.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ tests/test_style_intent.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_tokenization.py | 13 +++++++++++++ 11 files changed, 760 insertions(+), 76 deletions(-) create mode 100644 query/english_keyword_extractor.py create mode 100644 tests/test_tokenization.py diff --git a/config/config.yaml b/config/config.yaml index 6093e4b..a2656f6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -116,8 +116,8 @@ query_config: # 查询解析阶段:翻译与 query 向量并发执行,共用同一等待预算(毫秒)。 # 检测语言已在租户 index_languages 内:较短;不在索引语言内:较长(翻译对召回更关键)。 - translation_embedding_wait_budget_ms_source_in_index: 500 # 80 - translation_embedding_wait_budget_ms_source_not_in_index: 700 #200 + translation_embedding_wait_budget_ms_source_in_index: 200 # 80 + translation_embedding_wait_budget_ms_source_not_in_index: 300 #200 style_intent: enabled: true diff --git a/query/english_keyword_extractor.py b/query/english_keyword_extractor.py new file mode 100644 index 0000000..41a2f74 --- /dev/null +++ b/query/english_keyword_extractor.py @@ -0,0 +1,256 @@ +""" +Lightweight English core-term extraction for lexical keyword constraints. +""" + +from __future__ import annotations + +import logging +from typing import List, Optional, Sequence, Set + +from .tokenization import normalize_query_text, simple_tokenize_query + +logger = logging.getLogger(__name__) + +_WEAK_BOOST_ADJS = frozenset( + { + "best", + "good", + "great", + "new", + "free", + "cheap", + "top", + "fine", + "real", + } +) + +_FUNCTIONAL_DEP = frozenset( + { + "det", + "aux", + "auxpass", + "prep", + "mark", + "expl", + "cc", + "punct", + "case", + } +) + +_DEMOGRAPHIC_NOUNS = frozenset( + { + "women", + "woman", + "men", + "man", + "kids", + "kid", + "boys", + "boy", + "girls", + "girl", + "baby", + "babies", + "toddler", + "adult", + "adults", + } +) + +_PRICE_PREP_LEMMAS = frozenset({"under", "over", "below", "above", "within", "between", "near"}) +_DIMENSION_ROOTS = frozenset({"size", "width", "length", "height", "weight"}) + + +def _dedupe_preserve(seq: Sequence[str]) -> List[str]: + seen: Set[str] = set() + out: List[str] = [] + for item in seq: + normalized = normalize_query_text(item) + if not normalized or normalized in seen: + continue + seen.add(normalized) + out.append(normalized) + return out + + +def _lemma_lower(token) -> str: + return ((token.lemma_ or token.text) or "").lower().strip() + + +def _surface_lower(token) -> str: + return (token.text or "").lower().strip() + + +def _project_terms_to_query_tokens(query: str, terms: Sequence[str]) -> List[str]: + simple_tokens = _dedupe_preserve(simple_tokenize_query(query)) + projected: List[str] = [] + for term in terms: + normalized = normalize_query_text(term) + if len(normalized) < 2 or normalized in _DEMOGRAPHIC_NOUNS: + continue + exact = next((token for token in simple_tokens if token == normalized), None) + if exact is not None: + projected.append(exact) + continue + partial = next( + ( + token + for token in simple_tokens + if len(normalized) >= 3 and normalized in token and token not in _DEMOGRAPHIC_NOUNS + ), + None, + ) + if partial is not None: + projected.append(partial) + continue + projected.append(normalized) + return _dedupe_preserve(projected) + + +class EnglishKeywordExtractor: + """Extracts a small set of English core product terms with spaCy.""" + + def __init__(self, nlp: Optional[object] = None) -> None: + self._nlp = nlp if nlp is not None else self._load_nlp() + + @staticmethod + def _load_nlp() -> Optional[object]: + try: + import spacy + + return spacy.load("en_core_web_sm", disable=["ner", "textcat"]) + except Exception as exc: + logger.warning("English keyword extractor disabled; failed to load spaCy model: %s", exc) + return None + + def extract_keywords(self, query: str) -> str: + text = str(query or "").strip() + if not text: + return "" + if self._nlp is None: + return self._fallback_keywords(text) + try: + return self._extract_keywords_with_spacy(text) + except Exception as exc: + logger.warning("spaCy English keyword extraction failed; using fallback: %s", exc) + return self._fallback_keywords(text) + + def _extract_keywords_with_spacy(self, query: str) -> str: + doc = self._nlp(query) + intersection: Set[str] = set() + stops = self._nlp.Defaults.stop_words | _WEAK_BOOST_ADJS + pobj_heads_to_demote: Set[int] = set() + + for token in doc: + if token.dep_ == "prep" and token.text.lower() == "for": + for child in token.children: + if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN"): + pobj_heads_to_demote.add(child.i) + + for token in doc: + if token.dep_ != "prep" or _lemma_lower(token) not in _PRICE_PREP_LEMMAS: + continue + for child in token.children: + if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN"): + pobj_heads_to_demote.add(child.i) + + for token in doc: + if token.dep_ == "dobj" and token.pos_ in ("NOUN", "PROPN") and token.i not in pobj_heads_to_demote: + intersection.add(_surface_lower(token)) + + for token in doc: + if token.dep_ == "nsubj" and token.pos_ in ("NOUN", "PROPN"): + head = token.head + if head.pos_ == "AUX" and head.dep_ == "ROOT": + intersection.add(_surface_lower(token)) + + for token in doc: + if token.dep_ == "ROOT" and token.pos_ in ("INTJ", "PROPN"): + intersection.add(_surface_lower(token)) + if token.pos_ == "PROPN": + if token.dep_ == "compound" and _lemma_lower(token.head) in _DEMOGRAPHIC_NOUNS: + continue + intersection.add(_surface_lower(token)) + + for token in doc: + if token.dep_ == "ROOT" and token.pos_ in ("NOUN", "PROPN"): + if _lemma_lower(token) in _DIMENSION_ROOTS: + for child in token.children: + if child.dep_ == "nsubj" and child.pos_ in ("NOUN", "PROPN"): + intersection.add(_surface_lower(child)) + continue + if _lemma_lower(token) in _DEMOGRAPHIC_NOUNS: + for child in token.children: + if child.dep_ == "compound" and child.pos_ == "NOUN": + intersection.add(_surface_lower(child)) + continue + if token.i in pobj_heads_to_demote: + continue + intersection.add(_surface_lower(token)) + + for token in doc: + if token.dep_ != "ROOT" or token.pos_ not in ("INTJ", "VERB", "NOUN"): + continue + pobjs = sorted( + [child for child in token.children if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN")], + key=lambda item: item.i, + ) + if len(pobjs) >= 2 and token.pos_ == "INTJ": + intersection.add(_surface_lower(pobjs[0])) + for extra in pobjs[1:]: + if _lemma_lower(extra) not in _DEMOGRAPHIC_NOUNS: + intersection.add(_surface_lower(extra)) + elif len(pobjs) == 1 and token.pos_ == "INTJ": + intersection.add(_surface_lower(pobjs[0])) + + if not intersection: + for chunk in doc.noun_chunks: + head = chunk.root + if head.pos_ not in ("NOUN", "PROPN"): + continue + if head.dep_ == "pobj" and head.head.dep_ == "prep": + prep = head.head + if _lemma_lower(prep) in _PRICE_PREP_LEMMAS or prep.text.lower() == "for": + continue + head_text = _surface_lower(head) + if head_text: + intersection.add(head_text) + for token in chunk: + if token == head or token.pos_ != "PROPN": + continue + intersection.add(_surface_lower(token)) + + core_terms = _dedupe_preserve( + token.text.lower() + for token in doc + if _surface_lower(token) in intersection + and _surface_lower(token) not in stops + and _surface_lower(token) not in _DEMOGRAPHIC_NOUNS + and token.dep_ not in _FUNCTIONAL_DEP + and len(_surface_lower(token)) >= 2 + ) + projected_terms = _project_terms_to_query_tokens(query, core_terms) + if projected_terms: + return " ".join(projected_terms[:3]) + return self._fallback_keywords(query) + + def _fallback_keywords(self, query: str) -> str: + tokens = [ + normalize_query_text(token) + for token in simple_tokenize_query(query) + if normalize_query_text(token) + ] + if not tokens: + return "" + + filtered = [token for token in tokens if token not in _DEMOGRAPHIC_NOUNS] + if not filtered: + filtered = tokens + + # Keep the right-most likely product head plus one close modifier. + head = filtered[-1] + if len(filtered) >= 2: + return " ".join(filtered[-2:]) + return head diff --git a/query/keyword_extractor.py b/query/keyword_extractor.py index b435aec..999f15c 100644 --- a/query/keyword_extractor.py +++ b/query/keyword_extractor.py @@ -11,6 +11,9 @@ from __future__ import annotations import logging from typing import Any, Dict, List, Optional +from .english_keyword_extractor import EnglishKeywordExtractor +from .tokenization import QueryTextAnalysisCache + logger = logging.getLogger(__name__) import hanlp # type: ignore @@ -21,7 +24,7 @@ KEYWORDS_QUERY_BASE_KEY = "base" # | 场景 | 推荐模型 | # | :--------- | :------------------------------------------- | # | 纯中文 + 最高精度 | CTB9_TOK_ELECTRA_BASE_CRF 或 MSR_TOK_ELECTRA_BASE_CRF | -# | 纯中文 + 速度优先 | FINE_ELECTRA_SMALL_ZH(细粒度)或 COARSE_ELECTRA_SMALL_ZH(粗粒度) | +# | 纯中文 + 速度优先 | FINE_ELECTRA_SMALL_ZH (细粒度)或 COARSE_ELECTRA_SMALL_ZH (粗粒度) | # | **中英文混合** | `UD_TOK_MMINILMV2L6` 或 `UD_TOK_MMINILMV2L12` ( Transformer 编码器的层数不同)| @@ -33,23 +36,38 @@ class KeywordExtractor: tokenizer: Optional[Any] = None, *, ignore_keywords: Optional[List[str]] = None, + english_extractor: Optional[EnglishKeywordExtractor] = None, ): if tokenizer is not None: self.tok = tokenizer else: - self.tok = hanlp.load(hanlp.pretrained.tok.UD_TOK_MMINILMV2L6) + self.tok = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH) self.tok.config.output_spans = True self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL) self.ignore_keywords = frozenset(ignore_keywords or ["玩具"]) + self.english_extractor = english_extractor or EnglishKeywordExtractor() - def extract_keywords(self, query: str) -> str: + def extract_keywords( + self, + query: str, + *, + language_hint: Optional[str] = None, + tokenizer_result: Optional[Any] = None, + ) -> str: """ 从查询中提取关键词(名词,长度 ≥ 2),以空格分隔非连续片段。 """ query = (query or "").strip() if not query: return "" - tok_result_with_position = self.tok(query) + normalized_language = str(language_hint or "").strip().lower() + if normalized_language == "en": + return self.english_extractor.extract_keywords(query) + if normalized_language and normalized_language != "zh": + return "" + tok_result_with_position = ( + tokenizer_result if tokenizer_result is not None else self.tok(query) + ) tok_result = [x[0] for x in tok_result_with_position] if not tok_result: return "" @@ -72,6 +90,10 @@ def collect_keywords_queries( extractor: KeywordExtractor, rewritten_query: str, translations: Dict[str, str], + *, + source_language: Optional[str] = None, + text_analysis_cache: Optional[QueryTextAnalysisCache] = None, + base_keywords_query: Optional[str] = None, ) -> Dict[str, str]: """ Build the keyword map for all lexical variants (base + translations). @@ -79,14 +101,40 @@ def collect_keywords_queries( Omits entries when extraction yields an empty string. """ out: Dict[str, str] = {} - base_kw = extractor.extract_keywords(rewritten_query) + base_kw = base_keywords_query + if base_kw is None: + base_kw = extractor.extract_keywords( + rewritten_query, + language_hint=source_language or ( + text_analysis_cache.get_language_hint(rewritten_query) + if text_analysis_cache is not None + else None + ), + tokenizer_result=( + text_analysis_cache.get_tokenizer_result(rewritten_query) + if text_analysis_cache is not None + else None + ), + ) if base_kw: out[KEYWORDS_QUERY_BASE_KEY] = base_kw for lang, text in translations.items(): lang_key = str(lang or "").strip().lower() if not lang_key or not (text or "").strip(): continue - kw = extractor.extract_keywords(text) + kw = extractor.extract_keywords( + text, + language_hint=lang_key or ( + text_analysis_cache.get_language_hint(text) + if text_analysis_cache is not None + else None + ), + tokenizer_result=( + text_analysis_cache.get_tokenizer_result(text) + if text_analysis_cache is not None + else None + ), + ) if kw: out[lang_key] = kw return out diff --git a/query/product_title_exclusion.py b/query/product_title_exclusion.py index 66dfacd..0f147ce 100644 --- a/query/product_title_exclusion.py +++ b/query/product_title_exclusion.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple -from .tokenization import TokenizedText, normalize_query_text, tokenize_text +from .tokenization import QueryTextAnalysisCache, TokenizedText, normalize_query_text, tokenize_text def _dedupe_terms(terms: Iterable[str]) -> List[str]: @@ -158,9 +158,27 @@ class ProductTitleExclusionDetector: self.registry = registry self.tokenizer = tokenizer + def _tokenize_text( + self, + text: str, + *, + analysis_cache: Optional[QueryTextAnalysisCache] = None, + ) -> TokenizedText: + if analysis_cache is not None: + return analysis_cache.get_tokenized_text( + text, + max_ngram=self.registry.max_term_ngram, + ) + return tokenize_text( + text, + tokenizer=self.tokenizer, + max_ngram=self.registry.max_term_ngram, + ) + def _build_query_variants(self, parsed_query: Any) -> Tuple[TokenizedText, ...]: seen = set() variants: List[TokenizedText] = [] + analysis_cache = getattr(parsed_query, "_text_analysis_cache", None) texts = [ getattr(parsed_query, "original_query", None), getattr(parsed_query, "query_normalized", None), @@ -180,10 +198,9 @@ class ProductTitleExclusionDetector: continue seen.add(normalized) variants.append( - tokenize_text( + self._tokenize_text( text, - tokenizer=self.tokenizer, - max_ngram=self.registry.max_term_ngram, + analysis_cache=analysis_cache, ) ) diff --git a/query/query_parser.py b/query/query_parser.py index 89ee0d6..28279c7 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -27,7 +27,7 @@ from .product_title_exclusion import ( ) from .query_rewriter import QueryRewriter, QueryNormalizer from .style_intent import StyleIntentDetector, StyleIntentProfile, StyleIntentRegistry -from .tokenization import extract_token_strings, simple_tokenize_query +from .tokenization import QueryTextAnalysisCache, contains_han_text, extract_token_strings from .keyword_extractor import KeywordExtractor, collect_keywords_queries logger = logging.getLogger(__name__) @@ -119,6 +119,7 @@ class ParsedQuery: keywords_queries: Dict[str, str] = field(default_factory=dict) style_intent_profile: Optional[StyleIntentProfile] = None product_title_exclusion_profile: Optional[ProductTitleExclusionProfile] = None + _text_analysis_cache: Optional[QueryTextAnalysisCache] = field(default=None, repr=False) def text_for_rerank(self) -> str: """See :func:`rerank_query_text`.""" @@ -238,7 +239,7 @@ class QueryParser: if hanlp is None: raise RuntimeError("HanLP is required for QueryParser tokenization") logger.info("Initializing HanLP tokenizer...") - tokenizer = hanlp.load(hanlp.pretrained.tok.CTB9_TOK_ELECTRA_BASE_CRF) + tokenizer = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH) tokenizer.config.output_spans = True logger.info("HanLP tokenizer initialized") return tokenizer @@ -288,6 +289,33 @@ class QueryParser: def _get_query_tokens(self, query: str) -> List[str]: return self._extract_tokens(self._tokenizer(query)) + @staticmethod + def _is_ascii_latin_query(text: str) -> bool: + candidate = str(text or "").strip() + if not candidate or contains_han_text(candidate): + return False + try: + candidate.encode("ascii") + except UnicodeEncodeError: + return False + return any(ch.isalpha() for ch in candidate) + + def _detect_query_language( + self, + query_text: str, + *, + target_languages: Optional[List[str]] = None, + ) -> str: + normalized_targets = self._normalize_language_codes(target_languages) + supported_languages = self._normalize_language_codes( + getattr(self.config.query_config, "supported_languages", None) + ) + active_languages = normalized_targets or supported_languages + if active_languages and set(active_languages).issubset({"en", "zh"}): + if self._is_ascii_latin_query(query_text): + return "en" + return self.language_detector.detect(query_text) + def parse( self, query: str, @@ -332,12 +360,15 @@ class QueryParser: active_logger.debug(msg) # Stage 1: Normalize + normalize_t0 = time.perf_counter() normalized = self.normalizer.normalize(query) + normalize_ms = (time.perf_counter() - normalize_t0) * 1000.0 log_debug(f"Normalization completed | '{query}' -> '{normalized}'") if context: context.store_intermediate_result('query_normalized', normalized) # Stage 2: Query rewriting + rewrite_t0 = time.perf_counter() query_text = normalized rewritten = normalized if self.config.query_config.rewrite_dictionary: # Enable rewrite if dictionary exists @@ -348,21 +379,26 @@ class QueryParser: if context: context.store_intermediate_result('rewritten_query', rewritten) context.add_warning(f"Query was rewritten: {query_text}") + rewrite_ms = (time.perf_counter() - rewrite_t0) * 1000.0 + + normalized_targets = self._normalize_language_codes(target_languages) # Stage 3: Language detection - detected_lang = self.language_detector.detect(query_text) + language_detect_t0 = time.perf_counter() + detected_lang = self._detect_query_language( + query_text, + target_languages=normalized_targets, + ) # Use default language if detection failed (None or "unknown") if not detected_lang or detected_lang == "unknown": detected_lang = self.config.query_config.default_language + language_detect_ms = (time.perf_counter() - language_detect_t0) * 1000.0 log_info(f"Language detection | Detected language: {detected_lang}") if context: context.store_intermediate_result('detected_language', detected_lang) - # Stage 4: Query analysis (tokenization) - query_tokens = self._get_query_tokens(query_text) - - log_debug(f"Query analysis | Query tokens: {query_tokens}") - if context: - context.store_intermediate_result('query_tokens', query_tokens) + text_analysis_cache = QueryTextAnalysisCache(tokenizer=self._tokenizer) + for text_variant in (query, normalized, query_text): + text_analysis_cache.set_language_hint(text_variant, detected_lang) # Stage 5: Translation + embedding. Parser only coordinates async enrichment work; the # caller decides translation targets and later search-field planning. @@ -371,7 +407,6 @@ class QueryParser: future_submit_at: Dict[Any, float] = {} async_executor: Optional[ThreadPoolExecutor] = None detected_norm = str(detected_lang or "").strip().lower() - normalized_targets = self._normalize_language_codes(target_languages) translation_targets = [lang for lang in normalized_targets if lang != detected_norm] source_language_in_index = bool(normalized_targets) and detected_norm in normalized_targets @@ -398,7 +433,9 @@ class QueryParser: thread_name_prefix="query-enrichment", ) + async_submit_ms = 0.0 try: + async_submit_t0 = time.perf_counter() if async_executor is not None: for lang in translation_targets: model_name = self._pick_query_translation_model( @@ -466,6 +503,7 @@ class QueryParser: future = async_executor.submit(_encode_image_query_vector) future_to_task[future] = ("image_embedding", None) future_submit_at[future] = time.perf_counter() + async_submit_ms = (time.perf_counter() - async_submit_t0) * 1000.0 except Exception as e: error_msg = f"Async query enrichment submission failed | Error: {str(e)}" log_info(error_msg) @@ -477,6 +515,33 @@ class QueryParser: future_to_task.clear() future_submit_at.clear() + # Stage 4: Query analysis (tokenization) now overlaps with async enrichment work. + query_analysis_t0 = time.perf_counter() + query_tokenizer_t0 = time.perf_counter() + query_tokenizer_result = text_analysis_cache.get_tokenizer_result(query_text) + query_tokenizer_ms = (time.perf_counter() - query_tokenizer_t0) * 1000.0 + query_token_extract_t0 = time.perf_counter() + query_tokens = self._extract_tokens(query_tokenizer_result) + query_token_extract_ms = (time.perf_counter() - query_token_extract_t0) * 1000.0 + query_analysis_ms = (time.perf_counter() - query_analysis_t0) * 1000.0 + + log_debug(f"Query analysis | Query tokens: {query_tokens}") + if context: + context.store_intermediate_result('query_tokens', query_tokens) + + keywords_base_query = "" + keywords_base_ms = 0.0 + try: + keywords_base_t0 = time.perf_counter() + keywords_base_query = self._keyword_extractor.extract_keywords( + query_text, + language_hint=detected_lang, + tokenizer_result=text_analysis_cache.get_tokenizer_result(query_text), + ) + keywords_base_ms = (time.perf_counter() - keywords_base_t0) * 1000.0 + except Exception as e: + log_info(f"Base keyword extraction failed | Error: {e}") + # Wait for translation + embedding concurrently; shared budget depends on whether # the detected language belongs to caller-provided target_languages. qc = self.config.query_config @@ -501,7 +566,10 @@ class QueryParser: f"source_in_target_languages={source_in_target_languages}" ) + async_wait_t0 = time.perf_counter() done, not_done = wait(list(future_to_task.keys()), timeout=budget_sec) + async_wait_ms = (time.perf_counter() - async_wait_t0) * 1000.0 + async_collect_t0 = time.perf_counter() for future in done: task_type, lang = future_to_task[future] t0 = future_submit_at.pop(future, None) @@ -511,6 +579,7 @@ class QueryParser: if task_type == "translation": if result: translations[lang] = result + text_analysis_cache.set_language_hint(result, lang) if context: context.store_intermediate_result(f"translation_{lang}", result) elif task_type == "embedding": @@ -561,20 +630,31 @@ class QueryParser: log_info(timeout_msg) if context: context.add_warning(timeout_msg) + async_collect_ms = (time.perf_counter() - async_collect_t0) * 1000.0 if async_executor: async_executor.shutdown(wait=False) if translations and context: context.store_intermediate_result("translations", translations) + else: + async_wait_ms = 0.0 + async_collect_ms = 0.0 + tail_sync_t0 = time.perf_counter() keywords_queries: Dict[str, str] = {} + keyword_tail_ms = 0.0 try: + keywords_t0 = time.perf_counter() keywords_queries = collect_keywords_queries( self._keyword_extractor, query_text, translations, + source_language=detected_lang, + text_analysis_cache=text_analysis_cache, + base_keywords_query=keywords_base_query, ) + keyword_tail_ms = (time.perf_counter() - keywords_t0) * 1000.0 except Exception as e: log_info(f"Keyword extraction failed | Error: {e}") @@ -589,9 +669,43 @@ class QueryParser: image_query_vector=image_query_vector, query_tokens=query_tokens, keywords_queries=keywords_queries, + _text_analysis_cache=text_analysis_cache, ) + style_intent_t0 = time.perf_counter() style_intent_profile = self.style_intent_detector.detect(base_result) + style_intent_ms = (time.perf_counter() - style_intent_t0) * 1000.0 + product_title_exclusion_t0 = time.perf_counter() product_title_exclusion_profile = self.product_title_exclusion_detector.detect(base_result) + product_title_exclusion_ms = ( + (time.perf_counter() - product_title_exclusion_t0) * 1000.0 + ) + tail_sync_ms = (time.perf_counter() - tail_sync_t0) * 1000.0 + before_wait_ms = ( + normalize_ms + + rewrite_ms + + language_detect_ms + + async_submit_ms + + query_analysis_ms + + keywords_base_ms + ) + log_info( + "Query parse stage timings | " + f"normalize_ms={normalize_ms:.1f} | " + f"rewrite_ms={rewrite_ms:.1f} | " + f"language_detect_ms={language_detect_ms:.1f} | " + f"query_tokenizer_ms={query_tokenizer_ms:.1f} | " + f"query_token_extract_ms={query_token_extract_ms:.1f} | " + f"query_analysis_ms={query_analysis_ms:.1f} | " + f"async_submit_ms={async_submit_ms:.1f} | " + f"before_wait_ms={before_wait_ms:.1f} | " + f"async_wait_ms={async_wait_ms:.1f} | " + f"async_collect_ms={async_collect_ms:.1f} | " + f"base_keywords_ms={keywords_base_ms:.1f} | " + f"keyword_tail_ms={keyword_tail_ms:.1f} | " + f"style_intent_ms={style_intent_ms:.1f} | " + f"product_title_exclusion_ms={product_title_exclusion_ms:.1f} | " + f"tail_sync_ms={tail_sync_ms:.1f}" + ) if context: context.store_intermediate_result( "style_intent_profile", @@ -614,6 +728,7 @@ class QueryParser: keywords_queries=keywords_queries, style_intent_profile=style_intent_profile, product_title_exclusion_profile=product_title_exclusion_profile, + _text_analysis_cache=text_analysis_cache, ) parse_total_ms = (time.perf_counter() - parse_t0) * 1000.0 diff --git a/query/style_intent.py b/query/style_intent.py index 96e4234..c9a1f50 100644 --- a/query/style_intent.py +++ b/query/style_intent.py @@ -7,7 +7,7 @@ from __future__ import annotations from dataclasses import dataclass, field from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple -from .tokenization import TokenizedText, normalize_query_text, tokenize_text +from .tokenization import QueryTextAnalysisCache, TokenizedText, normalize_query_text, tokenize_text @dataclass(frozen=True) @@ -233,32 +233,63 @@ class StyleIntentDetector: self.registry = registry self.tokenizer = tokenizer - def _build_query_variants(self, parsed_query: Any) -> Tuple[TokenizedText, ...]: - seen = set() - variants: List[TokenizedText] = [] - texts = [ - self._get_language_query_text(parsed_query, "zh"), - self._get_language_query_text(parsed_query, "en"), - ] + def _max_term_ngram(self) -> int: + return max( + (definition.max_term_ngram for definition in self.registry.definitions.values()), + default=3, + ) + + def _tokenize_text( + self, + text: str, + *, + analysis_cache: Optional[QueryTextAnalysisCache] = None, + ) -> TokenizedText: + max_term_ngram = self._max_term_ngram() + if analysis_cache is not None: + return analysis_cache.get_tokenized_text(text, max_ngram=max_term_ngram) + return tokenize_text( + text, + tokenizer=self.tokenizer, + max_ngram=max_term_ngram, + ) - for raw_text in texts: - text = str(raw_text or "").strip() + def _build_language_variants( + self, + parsed_query: Any, + *, + analysis_cache: Optional[QueryTextAnalysisCache] = None, + ) -> Dict[str, TokenizedText]: + variants: Dict[str, TokenizedText] = {} + for language in ("zh", "en"): + text = self._get_language_query_text(parsed_query, language).strip() if not text: continue - normalized = normalize_query_text(text) + variants[language] = self._tokenize_text( + text, + analysis_cache=analysis_cache, + ) + return variants + + def _build_query_variants( + self, + parsed_query: Any, + *, + language_variants: Optional[Dict[str, TokenizedText]] = None, + analysis_cache: Optional[QueryTextAnalysisCache] = None, + ) -> Tuple[TokenizedText, ...]: + seen = set() + variants: List[TokenizedText] = [] + + for variant in (language_variants or self._build_language_variants( + parsed_query, + analysis_cache=analysis_cache, + )).values(): + normalized = variant.normalized_text if not normalized or normalized in seen: continue seen.add(normalized) - variants.append( - tokenize_text( - text, - tokenizer=self.tokenizer, - max_ngram=max( - (definition.max_term_ngram for definition in self.registry.definitions.values()), - default=3, - ), - ) - ) + variants.append(variant) return tuple(variants) @@ -271,26 +302,50 @@ class StyleIntentDetector: return str(translated) return str(getattr(parsed_query, "original_query", "") or "") - def _tokenize_language_query(self, parsed_query: Any, language: str) -> Optional[TokenizedText]: + def _tokenize_language_query( + self, + parsed_query: Any, + language: str, + *, + language_variants: Optional[Dict[str, TokenizedText]] = None, + analysis_cache: Optional[QueryTextAnalysisCache] = None, + ) -> Optional[TokenizedText]: + if language_variants is not None: + return language_variants.get(language) text = self._get_language_query_text(parsed_query, language).strip() if not text: return None - return tokenize_text( + return self._tokenize_text( text, - tokenizer=self.tokenizer, - max_ngram=max( - (definition.max_term_ngram for definition in self.registry.definitions.values()), - default=3, - ), + analysis_cache=analysis_cache, ) def detect(self, parsed_query: Any) -> StyleIntentProfile: if not self.registry.enabled or not self.registry.definitions: return StyleIntentProfile() - query_variants = self._build_query_variants(parsed_query) - zh_variant = self._tokenize_language_query(parsed_query, "zh") - en_variant = self._tokenize_language_query(parsed_query, "en") + analysis_cache = getattr(parsed_query, "_text_analysis_cache", None) + language_variants = self._build_language_variants( + parsed_query, + analysis_cache=analysis_cache, + ) + query_variants = self._build_query_variants( + parsed_query, + language_variants=language_variants, + analysis_cache=analysis_cache, + ) + zh_variant = self._tokenize_language_query( + parsed_query, + "zh", + language_variants=language_variants, + analysis_cache=analysis_cache, + ) + en_variant = self._tokenize_language_query( + parsed_query, + "en", + language_variants=language_variants, + analysis_cache=analysis_cache, + ) detected: List[DetectedStyleIntent] = [] seen_pairs = set() diff --git a/query/tokenization.py b/query/tokenization.py index 61beaf2..e33a31d 100644 --- a/query/tokenization.py +++ b/query/tokenization.py @@ -6,10 +6,11 @@ from __future__ import annotations from dataclasses import dataclass import re -from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple -_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[A-Za-z0-9_]+(?:-[A-Za-z0-9_]+)*") +_HAN_PATTERN = re.compile(r"[\u4e00-\u9fff]") +_TOKEN_PATTERN = re.compile(r"[\u4e00-\u9fff]+|[^\W_]+(?:[-'][^\W_]+)*", re.UNICODE) def normalize_query_text(text: Optional[str]) -> str: @@ -30,6 +31,10 @@ def simple_tokenize_query(text: str) -> List[str]: return _TOKEN_PATTERN.findall(text) +def contains_han_text(text: Optional[str]) -> bool: + return bool(text and _HAN_PATTERN.search(str(text))) + + def extract_token_strings(tokenizer_result: Any) -> List[str]: """Normalize tokenizer output into a flat token string list.""" if not tokenizer_result: @@ -84,6 +89,13 @@ def _build_phrase_candidates(tokens: Sequence[str], max_ngram: int) -> List[str] return phrases +def _build_coarse_tokens(text: str, fine_tokens: Sequence[str]) -> List[str]: + coarse_tokens = _dedupe_preserve_order(simple_tokenize_query(text)) + if contains_han_text(text) and fine_tokens: + return list(_dedupe_preserve_order(fine_tokens)) + return coarse_tokens + + @dataclass(frozen=True) class TokenizedText: text: str @@ -93,30 +105,88 @@ class TokenizedText: candidates: Tuple[str, ...] +class QueryTextAnalysisCache: + """Per-parse cache for tokenizer output and derived token bundles.""" + + def __init__(self, *, tokenizer: Optional[Callable[[str], Any]] = None) -> None: + self.tokenizer = tokenizer + self._tokenizer_results: Dict[str, Any] = {} + self._tokenized_texts: Dict[Tuple[str, int], TokenizedText] = {} + self._language_hints: Dict[str, str] = {} + + @staticmethod + def _normalize_input(text: Optional[str]) -> str: + return str(text or "").strip() + + def set_language_hint(self, text: Optional[str], language: Optional[str]) -> None: + normalized_input = self._normalize_input(text) + normalized_language = normalize_query_text(language) + if normalized_input and normalized_language: + self._language_hints[normalized_input] = normalized_language + + def get_language_hint(self, text: Optional[str]) -> Optional[str]: + normalized_input = self._normalize_input(text) + if not normalized_input: + return None + return self._language_hints.get(normalized_input) + + def _should_use_model_tokenizer(self, text: str) -> bool: + if self.tokenizer is None: + return False + language_hint = self.get_language_hint(text) + has_han = contains_han_text(text) + if language_hint == "zh": + return has_han + return has_han + + def get_tokenizer_result(self, text: Optional[str]) -> Any: + normalized_input = self._normalize_input(text) + if not normalized_input: + return [] + if not self._should_use_model_tokenizer(normalized_input): + return simple_tokenize_query(normalized_input) + if normalized_input not in self._tokenizer_results: + self._tokenizer_results[normalized_input] = self.tokenizer(normalized_input) + return self._tokenizer_results[normalized_input] + + def get_tokenized_text(self, text: Optional[str], *, max_ngram: int = 3) -> TokenizedText: + normalized_input = self._normalize_input(text) + cache_key = (normalized_input, max(1, int(max_ngram))) + cached = self._tokenized_texts.get(cache_key) + if cached is not None: + return cached + + normalized_text = normalize_query_text(normalized_input) + fine_raw = extract_token_strings(self.get_tokenizer_result(normalized_input)) + fine_tokens = _dedupe_preserve_order(fine_raw) + coarse_tokens = _build_coarse_tokens(normalized_input, fine_tokens) + + bundle = TokenizedText( + text=normalized_input, + normalized_text=normalized_text, + fine_tokens=tuple(fine_tokens), + coarse_tokens=tuple(coarse_tokens), + candidates=tuple( + _dedupe_preserve_order( + list(fine_tokens) + + list(coarse_tokens) + + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram) + + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram) + + ([normalized_text] if normalized_text else []) + ) + ), + ) + self._tokenized_texts[cache_key] = bundle + return bundle + + def tokenize_text( text: str, *, tokenizer: Optional[Callable[[str], Any]] = None, max_ngram: int = 3, ) -> TokenizedText: - normalized_text = normalize_query_text(text) - coarse_tokens = _dedupe_preserve_order(simple_tokenize_query(text)) - - fine_raw = extract_token_strings(tokenizer(text)) if tokenizer is not None and text else [] - fine_tokens = _dedupe_preserve_order(fine_raw) - - candidates = _dedupe_preserve_order( - list(fine_tokens) - + list(coarse_tokens) - + _build_phrase_candidates(fine_tokens, max_ngram=max_ngram) - + _build_phrase_candidates(coarse_tokens, max_ngram=max_ngram) - + ([normalized_text] if normalized_text else []) - ) - - return TokenizedText( - text=text, - normalized_text=normalized_text, - fine_tokens=tuple(fine_tokens), - coarse_tokens=tuple(coarse_tokens), - candidates=tuple(candidates), + return QueryTextAnalysisCache(tokenizer=tokenizer).get_tokenized_text( + text, + max_ngram=max_ngram, ) diff --git a/suggestion/service.py b/suggestion/service.py index 15671d3..0506a23 100644 --- a/suggestion/service.py +++ b/suggestion/service.py @@ -7,7 +7,7 @@ import time from typing import Any, Dict, List, Optional from config.tenant_config_loader import get_tenant_config_loader -from query.query_parser import simple_tokenize_query +from query.tokenization import simple_tokenize_query from suggestion.builder import get_suggestion_alias_name from utils.es_client import ESClient diff --git a/tests/test_query_parser_mixed_language.py b/tests/test_query_parser_mixed_language.py index dcae93a..ec8e385 100644 --- a/tests/test_query_parser_mixed_language.py +++ b/tests/test_query_parser_mixed_language.py @@ -77,3 +77,79 @@ def test_parse_waits_for_translation_when_source_in_index_languages(monkeypatch) assert result.detected_language == "en" assert result.translations.get("zh") == "off shoulder top-zh" assert not hasattr(result, "source_in_index_languages") + + +def test_parse_reuses_tokenization_across_tail_stages(monkeypatch): + tokenize_calls = [] + + def counting_tokenizer(text): + tokenize_calls.append(str(text)) + return str(text).split() + + config = SearchConfig( + es_index_name="test_products", + field_boosts={"title.en": 3.0, "title.zh": 3.0}, + indexes=[IndexConfig(name="default", label="default", fields=["title.en", "title.zh"])], + query_config=QueryConfig( + enable_text_embedding=False, + enable_query_rewrite=False, + supported_languages=["en", "zh"], + default_language="en", + style_intent_terms={ + "color": [ + {"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]} + ], + }, + style_intent_dimension_aliases={"color": ["color", "颜色"]}, + product_title_exclusion_rules=[ + { + "zh_trigger_terms": ["修身"], + "en_trigger_terms": ["fitted"], + "zh_title_exclusions": ["宽松"], + "en_title_exclusions": ["loose"], + } + ], + ), + function_score=FunctionScoreConfig(), + rerank=RerankConfig(), + spu_config=SPUConfig(enabled=False), + ) + parser = QueryParser( + config, + translator=_DummyTranslator(), + tokenizer=counting_tokenizer, + ) + monkeypatch.setattr(parser.language_detector, "detect", lambda text: "en") + + result = parser.parse( + "black fitted dress", + tenant_id="0", + generate_vector=False, + target_languages=["en", "zh"], + ) + + assert result.translations == {"zh": "black fitted dress-zh"} + assert result.style_intent_profile is not None + assert result.style_intent_profile.is_active is True + assert result.product_title_exclusion_profile is not None + assert result.product_title_exclusion_profile.is_active is True + assert tokenize_calls == [] + + +def test_parse_fast_path_detects_ascii_query_as_english_without_lingua(monkeypatch): + parser = QueryParser(_build_config(), translator=_DummyTranslator(), tokenizer=_tokenizer) + monkeypatch.setattr( + parser.language_detector, + "detect", + lambda text: (_ for _ in ()).throw(AssertionError("Lingua path should not be used")), + ) + + result = parser.parse( + "street t-shirt women", + tenant_id="0", + generate_vector=False, + target_languages=["en", "zh"], + ) + + assert result.detected_language == "en" + assert result.query_tokens == ["street", "t-shirt", "women"] diff --git a/tests/test_style_intent.py b/tests/test_style_intent.py index 6fe19db..4c40fc7 100644 --- a/tests/test_style_intent.py +++ b/tests/test_style_intent.py @@ -58,3 +58,37 @@ def test_style_intent_detector_uses_original_query_when_language_translation_mis assert profile.get_canonical_values("color") == {"black"} assert profile.intents[0].attribute_terms == ("black",) + + +def test_style_intent_detector_tokenizes_each_language_once(): + query_config = QueryConfig( + style_intent_terms={ + "color": [{"en_terms": ["black"], "zh_terms": ["黑色"], "attribute_terms": ["black"]}], + "size": [{"en_terms": ["xl"], "zh_terms": ["加大码"], "attribute_terms": ["xl"]}], + }, + style_intent_dimension_aliases={ + "color": ["color", "颜色"], + "size": ["size", "尺码"], + }, + ) + tokenize_calls = [] + + def counting_tokenizer(text): + tokenize_calls.append(text) + return str(text).split() + + detector = StyleIntentDetector( + StyleIntentRegistry.from_query_config(query_config), + tokenizer=counting_tokenizer, + ) + parsed_query = SimpleNamespace( + original_query="黑色 连衣裙", + query_normalized="黑色 连衣裙", + rewritten_query="黑色 连衣裙", + translations={"en": "black dress xl"}, + ) + + profile = detector.detect(parsed_query) + + assert profile.is_active is True + assert tokenize_calls == ["黑色 连衣裙"] diff --git a/tests/test_tokenization.py b/tests/test_tokenization.py new file mode 100644 index 0000000..cbf2901 --- /dev/null +++ b/tests/test_tokenization.py @@ -0,0 +1,13 @@ +from query.tokenization import QueryTextAnalysisCache + + +def test_han_coarse_tokens_follow_model_tokens_instead_of_whole_sentence(): + cache = QueryTextAnalysisCache( + tokenizer=lambda text: [("路上", 0, 2), ("穿着", 2, 4), ("女性", 4, 6), ("黑色", 10, 12)] + ) + cache.set_language_hint("路上穿着女性的衣服是黑色的", "zh") + + tokenized = cache.get_tokenized_text("路上穿着女性的衣服是黑色的") + + assert tokenized.fine_tokens == ("路上", "穿着", "女性", "黑色") + assert tokenized.coarse_tokens == ("路上", "穿着", "女性", "黑色") -- libgit2 0.21.2