# english_keyword_extractor.py
"""
Lightweight English core-term extraction for lexical keyword constraints.
"""

from __future__ import annotations

import logging
from typing import List, Optional, Sequence, Set

from .tokenization import normalize_query_text, simple_tokenize_query

logger = logging.getLogger(__name__)

# Generic booster adjectives ("best", "cheap", ...) that carry little product
# meaning; merged with spaCy's stop words when filtering candidate terms.
_WEAK_BOOST_ADJS = frozenset(
    {
        "best",
        "good",
        "great",
        "new",
        "free",
        "cheap",
        "top",
        "fine",
        "real",
    }
)

# Dependency labels of function words (determiners, auxiliaries, prepositions,
# punctuation, ...); tokens carrying these labels are never emitted as keywords.
_FUNCTIONAL_DEP = frozenset(
    {
        "det",
        "aux",
        "auxpass",
        "prep",
        "mark",
        "expl",
        "cc",
        "punct",
        "case",
    }
)

# Audience/demographic nouns that qualify a product ("shoes for women") but are
# not the product itself; excluded from the extracted core terms.
_DEMOGRAPHIC_NOUNS = frozenset(
    {
        "women",
        "woman",
        "men",
        "man",
        "kids",
        "kid",
        "boys",
        "boy",
        "girls",
        "girl",
        "baby",
        "babies",
        "toddler",
        "adult",
        "adults",
    }
)

# Prepositions (by lemma) that usually introduce a price/range constraint
# ("under $50"); their noun objects are demoted rather than kept as heads.
_PRICE_PREP_LEMMAS = frozenset({"under", "over", "below", "above", "within", "between", "near"})
# Measurement roots: when one of these heads the parse ("what size is X"),
# the interesting product term is its nominal subject, not the root itself.
_DIMENSION_ROOTS = frozenset({"size", "width", "length", "height", "weight"})


def _dedupe_preserve(seq: Sequence[str]) -> List[str]:
    """Normalize each item and drop duplicates/empties, keeping first-seen order."""
    result: List[str] = []
    encountered: Set[str] = set()
    for raw in seq:
        cleaned = normalize_query_text(raw)
        if cleaned and cleaned not in encountered:
            encountered.add(cleaned)
            result.append(cleaned)
    return result


def _lemma_lower(token) -> str:
    return ((token.lemma_ or token.text) or "").lower().strip()


def _surface_lower(token) -> str:
    return (token.text or "").lower().strip()


def _project_terms_to_query_tokens(query: str, terms: Sequence[str]) -> List[str]:
    """Map extracted *terms* back onto tokens actually present in *query*.

    Prefers an exact token match, then a substring match against a query
    token, and finally keeps the normalized term itself.
    """
    query_tokens = _dedupe_preserve(simple_tokenize_query(query))
    mapped: List[str] = []
    for term in terms:
        candidate = normalize_query_text(term)
        # Skip trivially short terms and demographic qualifiers outright.
        if len(candidate) < 2 or candidate in _DEMOGRAPHIC_NOUNS:
            continue
        if candidate in query_tokens:
            mapped.append(candidate)
            continue
        # Substring projection only for terms long enough to be distinctive.
        match = None
        if len(candidate) >= 3:
            for token in query_tokens:
                if candidate in token and token not in _DEMOGRAPHIC_NOUNS:
                    match = token
                    break
        mapped.append(match if match is not None else candidate)
    return _dedupe_preserve(mapped)


class EnglishKeywordExtractor:
    """Extracts a small set of English core product terms with spaCy."""

    def __init__(self, nlp: Optional[object] = None) -> None:
        # Accept an injected pipeline (useful for tests); otherwise try to load one.
        self._nlp = nlp if nlp is not None else self._load_nlp()

    @staticmethod
    def _load_nlp() -> Optional[object]:
        """Load the small English spaCy pipeline, or return None if unavailable.

        NER and textcat are disabled; only the tagger/parser output is used.
        """
        try:
            import spacy

            return spacy.load("en_core_web_sm", disable=["ner", "textcat"])
        except Exception as exc:
            # Broad catch is deliberate: a missing package/model must degrade to
            # the regex fallback instead of breaking callers at import time.
            logger.warning("English keyword extractor disabled; failed to load spaCy model: %s", exc)
            return None

    def extract_keywords(self, query: str) -> str:
        """Return up to a few space-joined core terms for *query* ('' if empty).

        Uses the spaCy pipeline when available; otherwise (or on any spaCy
        failure) falls back to simple token heuristics.
        """
        text = str(query or "").strip()
        if not text:
            return ""
        if self._nlp is None:
            return self._fallback_keywords(text)
        try:
            return self._extract_keywords_with_spacy(text)
        except Exception as exc:
            # Any parsing error downgrades to the fallback rather than raising.
            logger.warning("spaCy English keyword extraction failed; using fallback: %s", exc)
            return self._fallback_keywords(text)

    def _extract_keywords_with_spacy(self, query: str) -> str:
        """Select candidate head nouns from the dependency parse of *query*.

        Runs several passes over the parse, accumulating surface forms in
        ``intersection``, then filters, dedupes, and projects them back onto
        the raw query tokens. Returns at most three space-joined terms.
        """
        doc = self._nlp(query)
        intersection: Set[str] = set()
        # Stop list = spaCy's English stop words plus weak booster adjectives.
        stops = self._nlp.Defaults.stop_words | _WEAK_BOOST_ADJS
        pobj_heads_to_demote: Set[int] = set()

        # Pass 1: demote noun objects of "for" (audience phrases: "shoes for women").
        for token in doc:
            if token.dep_ == "prep" and token.text.lower() == "for":
                for child in token.children:
                    if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN"):
                        pobj_heads_to_demote.add(child.i)

        # Pass 2: demote noun objects of price/range prepositions ("under $50").
        for token in doc:
            if token.dep_ != "prep" or _lemma_lower(token) not in _PRICE_PREP_LEMMAS:
                continue
            for child in token.children:
                if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN"):
                    pobj_heads_to_demote.add(child.i)

        # Pass 3: keep direct objects ("find running shoes") unless demoted above.
        for token in doc:
            if token.dep_ == "dobj" and token.pos_ in ("NOUN", "PROPN") and token.i not in pobj_heads_to_demote:
                intersection.add(_surface_lower(token))

        # Pass 4: keep noun subjects of an auxiliary root ("is this jacket ...").
        for token in doc:
            if token.dep_ == "nsubj" and token.pos_ in ("NOUN", "PROPN"):
                head = token.head
                if head.pos_ == "AUX" and head.dep_ == "ROOT":
                    intersection.add(_surface_lower(token))

        # Pass 5: keep interjection/proper-noun roots and standalone proper nouns,
        # except proper nouns that merely modify a demographic noun.
        for token in doc:
            if token.dep_ == "ROOT" and token.pos_ in ("INTJ", "PROPN"):
                intersection.add(_surface_lower(token))
            if token.pos_ == "PROPN":
                if token.dep_ == "compound" and _lemma_lower(token.head) in _DEMOGRAPHIC_NOUNS:
                    continue
                intersection.add(_surface_lower(token))

        # Pass 6: noun roots. Dimension roots ("size") yield their subject instead;
        # demographic roots yield their compound modifier; demoted roots are skipped.
        for token in doc:
            if token.dep_ == "ROOT" and token.pos_ in ("NOUN", "PROPN"):
                if _lemma_lower(token) in _DIMENSION_ROOTS:
                    for child in token.children:
                        if child.dep_ == "nsubj" and child.pos_ in ("NOUN", "PROPN"):
                            intersection.add(_surface_lower(child))
                    continue
                if _lemma_lower(token) in _DEMOGRAPHIC_NOUNS:
                    for child in token.children:
                        if child.dep_ == "compound" and child.pos_ == "NOUN":
                            intersection.add(_surface_lower(child))
                    continue
                if token.i in pobj_heads_to_demote:
                    continue
                intersection.add(_surface_lower(token))

        # Pass 7: pobj children attached directly to an interjection root
        # (common in terse queries parsed oddly); skip demographic extras.
        for token in doc:
            if token.dep_ != "ROOT" or token.pos_ not in ("INTJ", "VERB", "NOUN"):
                continue
            pobjs = sorted(
                [child for child in token.children if child.dep_ == "pobj" and child.pos_ in ("NOUN", "PROPN")],
                key=lambda item: item.i,
            )
            if len(pobjs) >= 2 and token.pos_ == "INTJ":
                intersection.add(_surface_lower(pobjs[0]))
                for extra in pobjs[1:]:
                    if _lemma_lower(extra) not in _DEMOGRAPHIC_NOUNS:
                        intersection.add(_surface_lower(extra))
            elif len(pobjs) == 1 and token.pos_ == "INTJ":
                intersection.add(_surface_lower(pobjs[0]))

        # Last resort: take noun-chunk heads (plus proper-noun chunk members),
        # still excluding objects of "for" and price prepositions.
        if not intersection:
            for chunk in doc.noun_chunks:
                head = chunk.root
                if head.pos_ not in ("NOUN", "PROPN"):
                    continue
                if head.dep_ == "pobj" and head.head.dep_ == "prep":
                    prep = head.head
                    if _lemma_lower(prep) in _PRICE_PREP_LEMMAS or prep.text.lower() == "for":
                        continue
                head_text = _surface_lower(head)
                if head_text:
                    intersection.add(head_text)
                for token in chunk:
                    if token == head or token.pos_ != "PROPN":
                        continue
                    intersection.add(_surface_lower(token))

        # Filter the selected surfaces: no stops, no demographics, no function
        # words, minimum length 2 — then dedupe preserving document order.
        core_terms = _dedupe_preserve(
            token.text.lower()
            for token in doc
            if _surface_lower(token) in intersection
            and _surface_lower(token) not in stops
            and _surface_lower(token) not in _DEMOGRAPHIC_NOUNS
            and token.dep_ not in _FUNCTIONAL_DEP
            and len(_surface_lower(token)) >= 2
        )
        projected_terms = _project_terms_to_query_tokens(query, core_terms)
        if projected_terms:
            return " ".join(projected_terms[:3])
        return self._fallback_keywords(query)

    def _fallback_keywords(self, query: str) -> str:
        """Heuristic extraction without spaCy: last one or two useful tokens."""
        tokens = [
            normalize_query_text(token)
            for token in simple_tokenize_query(query)
            if normalize_query_text(token)
        ]
        if not tokens:
            return ""

        # Drop demographic qualifiers unless that would empty the list.
        filtered = [token for token in tokens if token not in _DEMOGRAPHIC_NOUNS]
        if not filtered:
            filtered = tokens

        # Keep the right-most likely product head plus one close modifier.
        head = filtered[-1]
        if len(filtered) >= 2:
            return " ".join(filtered[-2:])
        return head