# keyword_extractor.py
"""
HanLP-based noun keyword string for lexical constraints (token POS starts with N, length >= 2).

``ParsedQuery.keywords_queries`` uses the same key layout as text variants:
``KEYWORDS_QUERY_BASE_KEY`` for the rewritten source query, and ISO-like language
codes for each ``ParsedQuery.translations`` entry (non-empty extractions only).
"""

from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from .english_keyword_extractor import EnglishKeywordExtractor
from .tokenization import QueryTextAnalysisCache

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# HanLP supplies the pretrained tokenizer and POS tagger that
# ``KeywordExtractor.__init__`` loads.
import hanlp  # type: ignore

# Key under which the keywords of the rewritten source query are stored in the
# map returned by ``collect_keywords_queries``.
# Aligns with ``rewritten_query`` / ES ``base_query`` (not a language code).
KEYWORDS_QUERY_BASE_KEY = "base"

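# HanLP tokenizer model selection:
# | Scenario                        | Recommended model                                                                 |
# | :------------------------------ | :-------------------------------------------------------------------------------- |
# | Chinese only, highest accuracy  | CTB9_TOK_ELECTRA_BASE_CRF or MSR_TOK_ELECTRA_BASE_CRF                             |
# | Chinese only, speed first       | FINE_ELECTRA_SMALL_ZH (fine-grained) or COARSE_ELECTRA_SMALL_ZH (coarse-grained)  |
# | **Mixed Chinese and English**   | `UD_TOK_MMINILMV2L6` or `UD_TOK_MMINILMV2L12` (they differ in the number of Transformer encoder layers) |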


class KeywordExtractor:
    """HanLP-based noun keyword extractor.

    Keywords follow the tokenizer's character offsets: nouns whose spans touch
    are concatenated, and a single space is inserted between nouns that are
    not adjacent in the query.
    """

    def __init__(
        self,
        tokenizer: Optional[Any] = None,
        *,
        ignore_keywords: Optional[List[str]] = None,
        english_extractor: Optional[EnglishKeywordExtractor] = None,
    ):
        """Set up the tokenizer, POS tagger, ignore list and English extractor.

        Args:
            tokenizer: Callable used to tokenize Chinese queries. It is used
                as given; when omitted, the pretrained
                ``FINE_ELECTRA_SMALL_ZH`` model is loaded and configured to
                emit spans.
            ignore_keywords: Words that are never emitted as keywords. A
                missing or empty list falls back to ``["玩具"]`` ("toy").
            english_extractor: Extractor used when the language hint is
                ``"en"``; a default ``EnglishKeywordExtractor`` is created
                when none (or a falsy one) is supplied.
        """
        if tokenizer is None:
            # Spans are required because extraction relies on the
            # (token, start, end) triples to detect adjacency.
            tokenizer = hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH)
            tokenizer.config.output_spans = True
        self.tok = tokenizer
        self.pos_tag = hanlp.load(hanlp.pretrained.pos.CTB9_POS_ELECTRA_SMALL)

        blocked_words = ignore_keywords if ignore_keywords else ["玩具"]
        self.ignore_keywords = frozenset(blocked_words)

        if not english_extractor:
            english_extractor = EnglishKeywordExtractor()
        self.english_extractor = english_extractor

    def extract_keywords(
        self,
        query: str,
        *,
        language_hint: Optional[str] = None,
        tokenizer_result: Optional[Any] = None,
    ) -> str:
        """Return the noun keywords of ``query`` as a single string.

        Args:
            query: Text to analyse; surrounding whitespace is stripped.
            language_hint: Language code, compared after trimming and
                lower-casing. ``"en"`` delegates to the English extractor,
                empty or ``"zh"`` uses the HanLP pipeline, and any other
                value yields ``""``.
            tokenizer_result: Optional precomputed sequence of
                ``(token, start, end)`` triples; when ``None`` the query is
                tokenized with ``self.tok``.

        Returns:
            Tokens whose POS tag starts with ``"N"``, whose length is at
            least 2 and which are not in the ignore list. Adjacent keywords
            are concatenated; non-adjacent ones are separated by one space.
            Returns ``""`` when nothing qualifies.
        """
        text = (query or "").strip()
        if not text:
            return ""

        lang = str(language_hint or "").strip().lower()
        if lang == "en":
            return self.english_extractor.extract_keywords(text)
        if lang not in ("", "zh"):
            # Only Chinese (or unspecified) input goes through HanLP.
            return ""

        spans = self.tok(text) if tokenizer_result is None else tokenizer_result
        tokens = [span[0] for span in spans]
        if not tokens:
            return ""

        tags = self.pos_tag(tokens)
        pieces: List[str] = []
        prev_end = 0
        for token, tag, (_, begin, end) in zip(tokens, tags, spans):
            if len(token) < 2 or not str(tag).startswith("N"):
                continue
            if token in self.ignore_keywords:
                continue
            # A gap between this noun and the previously kept one means they
            # were not contiguous in the query, so separate them.
            if pieces and begin != prev_end:
                pieces.append(" ")
            pieces.append(token)
            prev_end = end
        return "".join(pieces).strip()


def collect_keywords_queries(
    extractor: KeywordExtractor,
    rewritten_query: str,
    translations: Optional[Dict[str, str]],
    *,
    source_language: Optional[str] = None,
    text_analysis_cache: Optional[QueryTextAnalysisCache] = None,
    base_keywords_query: Optional[str] = None,
) -> Dict[str, str]:
    """
    Build the keyword map for all lexical variants (base + translations).

    Args:
        extractor: Object whose ``extract_keywords`` produces the keyword
            string for one text.
        rewritten_query: Source query whose keywords are stored under
            ``KEYWORDS_QUERY_BASE_KEY``.
        translations: Mapping of language code to translated text. ``None``
            is treated as an empty mapping. Entries with a blank language
            code or blank text are skipped.
        source_language: Language hint for ``rewritten_query``. When falsy,
            the hint is taken from ``text_analysis_cache`` if one is given.
        text_analysis_cache: Optional cache queried for tokenizer results
            (and for the base language hint, see above).
        base_keywords_query: Precomputed keywords for the base query. When
            not ``None`` it is used as-is and no extraction is run for
            ``rewritten_query``.

    Returns:
        Dict keyed by ``KEYWORDS_QUERY_BASE_KEY`` and by the trimmed,
        lower-cased language codes. Omits entries when extraction yields an
        empty string.
    """
    cache = text_analysis_cache

    def _tokens_for(text: str) -> Optional[Any]:
        # Tokenizer output from the cache, or None so the extractor
        # tokenizes the text itself.
        return cache.get_tokenizer_result(text) if cache is not None else None

    out: Dict[str, str] = {}

    base_kw = base_keywords_query
    if base_kw is None:
        base_language = source_language or (
            cache.get_language_hint(rewritten_query) if cache is not None else None
        )
        base_kw = extractor.extract_keywords(
            rewritten_query,
            language_hint=base_language,
            tokenizer_result=_tokens_for(rewritten_query),
        )
    if base_kw:
        out[KEYWORDS_QUERY_BASE_KEY] = base_kw

    for lang, text in (translations or {}).items():
        lang_key = str(lang or "").strip().lower()
        if not lang_key or not (text or "").strip():
            continue
        # lang_key is non-empty here, so it is always the language hint;
        # no cache-based language detection is needed for translations.
        kw = extractor.extract_keywords(
            text,
            language_hint=lang_key,
            tokenizer_result=_tokens_for(text),
        )
        if kw:
            out[lang_key] = kw
    return out