english_query_bucketing_demo.py
17.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
#!/usr/bin/env python3
"""
Offline experiment: English query bucketing (intersection / boost / drop).
Scheme A: spaCy noun_chunks + head + lemma + rule buckets
Scheme B: spaCy NP candidates + KeyBERT rerank → intersection vs boost
Scheme C: YAKE + spaCy noun/POS filter
Run (after deps): python scripts/experiments/english_query_bucketing_demo.py
Optional: pip install -r scripts/experiments/requirements_query_bucketing_experiments.txt
"""
from __future__ import annotations
import argparse
import json
import re
import sys
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple
# --- shared -----------------------------------------------------------------
# Singular possessive suffix: straight or curly apostrophe + "s" at a word end.
_POSSESSIVE_RE = re.compile(r"(['’]s)\b", re.IGNORECASE)
# Bare apostrophe of a plural possessive ("kids' shoes"): preceded by "s" and
# not followed by a word character, so contractions like "s'more" are left alone.
_PLURAL_POSSESSIVE_RE = re.compile(r"(?<=s)['’](?!\w)", re.IGNORECASE)


def normalize_query(s: str) -> str:
    """Strip surrounding whitespace and possessive markers from a query.

    Both the singular form ("women's" -> "women") and the plural form
    ("kids'" -> "kids") are removed; straight and curly apostrophes are
    handled alike.

    Args:
        s: Raw query text; ``None`` or empty is treated as "".

    Returns:
        The stripped query with possessive markers removed.
    """
    s = (s or "").strip()
    s = _POSSESSIVE_RE.sub("", s)
    s = _PLURAL_POSSESSIVE_RE.sub("", s)
    return s
@dataclass
class BucketResult:
    """Holds the three term buckets a bucketing scheme produces for one query."""

    # One list per bucket; each instance gets its own fresh lists by default.
    intersection_terms: List[str] = field(default_factory=list)
    boost_terms: List[str] = field(default_factory=list)
    drop_terms: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the buckets keyed by field name (the lists are shared, not copied)."""
        bucket_names = ("intersection_terms", "boost_terms", "drop_terms")
        return {name: getattr(self, name) for name in bucket_names}
def _dedupe_preserve(seq: Sequence[str]) -> List[str]:
seen: Set[str] = set()
out: List[str] = []
for x in seq:
k = x.strip().lower()
if not k or k in seen:
continue
seen.add(k)
out.append(x.strip())
return out
# --- Scheme A: spaCy + rules -------------------------------------------------
# Generic evaluative adjectives; each scheme merges these into its stop-word set.
WEAK_BOOST_ADJS = frozenset(
    "best good great new free cheap top fine real".split()
)
# Dependency labels of function words; Scheme A drops any token carrying one.
FUNCTIONAL_DEP = frozenset(
    "det aux auxpass prep mark expl cc punct case".split()
)
# Second pobj under list-like INTJ roots often encodes audience/size (boost, not must-match).
_DEMOGRAPHIC_NOUNS = frozenset(
{
"women",
"woman",
"men",
"man",
"kids",
"kid",
"boys",
"boy",
"girls",
"girl",
"baby",
"babies",
"toddler",
"adult",
"adults",
}
)
def _lemma_lower(t) -> str:
return ((t.lemma_ or t.text) or "").lower().strip()
def _surface_lower(t) -> str:
"""Lowercased surface form (keeps plural 'headphones' vs lemma 'headphone')."""
return (t.text or "").lower().strip()
_PRICE_PREP_LEMMAS = frozenset({"under", "over", "below", "above", "within", "between", "near"})
def bucket_scheme_a_spacy(query: str, nlp) -> BucketResult:
    """
    Scheme A: dependency-first bucketing of an English query with spaCy rules.

    noun_chunks alone mis-parse verbal queries like "noise cancelling
    headphones" (ROOT verb), so the rules prefer dobj / ROOT product nouns,
    purpose PPs ("for …"), price/range PPs and brand INTJ/PROPN tokens, and
    fall back to noun_chunks only when none of them produced an intersection
    term.

    Args:
        query: Raw user query; ``None`` or empty is treated as "".
        nlp: A loaded spaCy pipeline. It is called on the text and read for
            ``nlp.Defaults.stop_words``; spaCy itself is not imported here.

    Returns:
        A BucketResult whose three lists are lowercased, sorted and
        de-duplicated. Boost terms never repeat an intersection term, and
        drop terms never repeat a term kept in either other bucket.

    NOTE(review): this block was recovered from a copy that had lost its
    indentation; the nesting below is reconstructed from the logic and should
    be confirmed against the canonical file.
    """
    # Do not strip possessives ('s) before spaCy: it changes the parse tree
    # (e.g. "women's running shoes size 8" vs "women running shoes size 8").
    text = (query or "").strip()
    doc = nlp(text)
    intersection: Set[str] = set()
    boost: Set[str] = set()
    drop: Set[str] = set()
    stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS

    def mark_drop(t) -> None:
        """Record the token's lowercased text as dropped (whitespace/punctuation ignored)."""
        if not t.is_space and not t.is_punct:
            drop.add(t.text.lower())

    # --- Drops: function words / question words ---
    for token in doc:
        if token.is_space or token.is_punct:
            continue
        lem = _lemma_lower(token)
        if token.pos_ in ("DET", "PRON", "AUX", "ADP", "PART", "SCONJ", "CCONJ"):
            mark_drop(token)
            continue
        if token.dep_ in FUNCTIONAL_DEP:
            mark_drop(token)
            continue
        if token.pos_ == "ADV" and lem in {"where", "how", "when", "why", "what", "which"}:
            mark_drop(token)
            continue
        if token.text.lower() in ("'s", "’s"):
            mark_drop(token)
            continue
        # Proper nouns are kept even when their lemma is a stop word.
        if lem in stops and token.pos_ != "PROPN":
            mark_drop(token)

    # Token indices of pobj heads that belong to a purpose/price clause and
    # must not end up in the intersection bucket.
    pobj_heads_to_demote: Set[int] = set()

    # Purpose / context: "for airplane travel" → boost phrase; demote bare head from intersection
    for token in doc:
        if token.dep_ == "prep" and token.text.lower() == "for":
            for c in token.children:
                if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN"):
                    span = doc[c.left_edge.i : c.right_edge.i + 1]
                    phrase = span.text.strip().lower()
                    if phrase:
                        boost.add(phrase)
                    pobj_heads_to_demote.add(c.i)

    # Price / range: "under 500 dollars" → boost only
    for token in doc:
        if token.dep_ != "prep" or _lemma_lower(token) not in _PRICE_PREP_LEMMAS:
            continue
        for c in token.children:
            if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN"):
                span = doc[c.left_edge.i : c.right_edge.i + 1]
                phrase = span.text.strip().lower()
                if phrase:
                    boost.add(phrase)
                pobj_heads_to_demote.add(c.i)

    # Direct object product nouns (handles "noise cancelling … headphones")
    for token in doc:
        if token.dep_ == "dobj" and token.pos_ in ("NOUN", "PROPN"):
            if token.i in pobj_heads_to_demote:
                continue
            intersection.add(_surface_lower(token))

    # Copular questions / definitions: "what is the best smartphone …"
    for token in doc:
        if token.dep_ != "nsubj" or token.pos_ not in ("NOUN", "PROPN"):
            continue
        h = token.head
        if h.pos_ == "AUX" and h.dep_ == "ROOT":
            intersection.add(_surface_lower(token))

    # Verbal ROOT: modifiers left of dobj → boost phrase (e.g. "noise cancelling")
    roots = [t for t in doc if t.dep_ == "ROOT"]
    if roots and roots[0].pos_ == "VERB":
        for t in doc:
            if t.dep_ != "dobj" or t.pos_ not in ("NOUN", "PROPN"):
                continue
            if t.i in pobj_heads_to_demote:
                continue
            parts: List[str] = []
            for x in doc[: t.i]:
                if x.is_punct or x.is_space:
                    continue
                if x.pos_ in ("DET", "ADP", "PRON"):
                    continue
                xl = _lemma_lower(x)
                if xl in stops:
                    continue
                parts.append(x.text.lower())
            if len(parts) >= 1:
                boost.add(" ".join(parts))

    # Brand / query lead: INTJ/PROPN ROOT (e.g. Nike …)
    for token in doc:
        if token.dep_ == "ROOT" and token.pos_ in ("INTJ", "PROPN"):
            intersection.add(_surface_lower(token))
        # NOTE(review): assumed to apply to every token (any PROPN is kept),
        # matching the noun_chunks fallback below — confirm the nesting.
        if token.pos_ == "PROPN":
            intersection.add(_surface_lower(token))

    _DIMENSION_ROOTS = frozenset({"size", "width", "length", "height", "weight"})

    # "women's running shoes size 8" → shoes ∩, "size 8" boost (not size alone)
    for token in doc:
        if token.dep_ != "ROOT" or token.pos_ != "NOUN":
            continue
        if _lemma_lower(token) not in _DIMENSION_ROOTS:
            continue
        for c in token.children:
            if c.dep_ == "nsubj" and c.pos_ in ("NOUN", "PROPN"):
                intersection.add(_surface_lower(c))
                for ch in c.children:
                    if ch.dep_ == "compound" and ch.pos_ in ("NOUN", "VERB", "ADJ"):
                        boost.add(_surface_lower(ch))
        # Only the dimension head + numbers (not full subtree: left_edge/right_edge is huge)
        dim_parts = [token.text.lower()]
        for ch in token.children:
            if ch.dep_ == "nummod":
                dim_parts.append(ch.text.lower())
        boost.add(" ".join(dim_parts))

    # ROOT noun product (e.g. "plastic toy car")
    for token in doc:
        if token.dep_ == "ROOT" and token.pos_ in ("NOUN", "PROPN"):
            # A dimension ROOT with a nominal subject was handled just above.
            if _lemma_lower(token) in _DIMENSION_ROOTS and any(
                c.dep_ == "nsubj" and c.pos_ in ("NOUN", "PROPN") for c in token.children
            ):
                continue
            intersection.add(_surface_lower(token))
            for c in token.children:
                if c.dep_ == "compound" and c.pos_ == "NOUN":
                    boost.add(c.text.lower())
            if token.i - token.left_edge.i >= 1:
                comps = [x.text.lower() for x in doc[token.left_edge.i : token.i] if x.dep_ == "compound"]
                if len(comps) >= 2:
                    boost.add(" ".join(comps))

    # List-like INTJ head with multiple pobj: first pobj = product head, rest often demographic
    for token in doc:
        if token.dep_ != "ROOT" or token.pos_ not in ("INTJ", "VERB", "NOUN"):
            continue
        pobjs = sorted(
            [c for c in token.children if c.dep_ == "pobj" and c.pos_ in ("NOUN", "PROPN")],
            key=lambda x: x.i,
        )
        if len(pobjs) >= 2 and token.pos_ == "INTJ":
            intersection.add(_surface_lower(pobjs[0]))
            for extra in pobjs[1:]:
                if _lemma_lower(extra) in _DEMOGRAPHIC_NOUNS:
                    boost.add(_surface_lower(extra))
                else:
                    intersection.add(_surface_lower(extra))
        elif len(pobjs) == 1 and token.pos_ == "INTJ":
            intersection.add(_surface_lower(pobjs[0]))

    # amod under pobj (running → shoes)
    for token in doc:
        if token.dep_ == "amod" and token.head.pos_ in ("NOUN", "PROPN"):
            if token.pos_ == "VERB":
                boost.add(_surface_lower(token))
            elif token.pos_ == "ADJ":
                boost.add(_lemma_lower(token))

    # Genitive possessor (women's shoes → women boost)
    for token in doc:
        if token.dep_ == "poss" and token.head.pos_ in ("NOUN", "PROPN"):
            boost.add(_surface_lower(token))

    # noun_chunks fallback when no dobj/ROOT intersection yet
    if not intersection:
        for chunk in doc.noun_chunks:
            head = chunk.root
            if head.pos_ not in ("NOUN", "PROPN"):
                continue
            # Price / range: "under 500 dollars" → boost, not a product head
            if head.dep_ == "pobj" and head.head.dep_ == "prep":
                prep = head.head
                if _lemma_lower(prep) in _PRICE_PREP_LEMMAS:
                    boost.add(chunk.text.strip().lower())
                    continue
            hl = _surface_lower(head)
            if hl:
                intersection.add(hl)
            for t in chunk:
                if t == head or t.pos_ != "PROPN":
                    continue
                intersection.add(_surface_lower(t))
            for t in chunk:
                if t == head:
                    continue
                if t.pos_ == "ADJ" or (t.pos_ == "NOUN" and t.dep_ == "compound"):
                    boost.add(_lemma_lower(t))

    # Remove demoted pobj heads from intersection (purpose / price clause)
    for i in pobj_heads_to_demote:
        t = doc[i]
        intersection.discard(_lemma_lower(t))
        intersection.discard(_surface_lower(t))

    boost -= intersection
    boost = {b for b in boost if b.lower() not in stops and b.strip()}
    # A product head can also be a stop word (e.g. "top"); a term that was kept
    # must not be reported as dropped as well.
    drop -= intersection | boost

    return BucketResult(
        intersection_terms=_dedupe_preserve(sorted(intersection)),
        boost_terms=_dedupe_preserve(sorted(boost)),
        drop_terms=_dedupe_preserve(sorted(drop)),
    )
# --- Scheme B: spaCy candidates + KeyBERT -----------------------------------
def _spacy_np_candidates(doc) -> List[str]:
phrases: List[str] = []
for chunk in doc.noun_chunks:
t = chunk.text.strip()
if len(t) < 2:
continue
root = chunk.root
if root.pos_ not in ("NOUN", "PROPN"):
continue
phrases.append(t)
return phrases
def bucket_scheme_b_keybert(query: str, nlp, kw_model) -> BucketResult:
    """
    Scheme B: spaCy noun-phrase candidates reranked by KeyBERT.

    The top-ranked candidate becomes the single intersection term and the
    remaining ranked candidates become boost terms. Noun-chunk heads that
    KeyBERT did not return are appended to boost; if KeyBERT returns nothing,
    the first head is used as the intersection term.

    Args:
        query: Raw user query; ``None`` or blank yields an empty result.
        nlp: A loaded spaCy pipeline (called on the text, read for
            ``nlp.Defaults.stop_words``).
        kw_model: Object exposing a KeyBERT-style ``extract_keywords``.

    Returns:
        A BucketResult with lowercased terms. Drop terms never repeat a term
        kept in intersection or boost.
    """
    text = (query or "").strip()
    if not text:
        # Nothing to parse or rank; skip the model calls.
        return BucketResult()
    doc = nlp(text)

    # KeyBERT matches candidates through a scikit-learn CountVectorizer
    # vocabulary, which lowercases the document and rejects duplicate terms,
    # so the candidates are normalised the same way up front.
    candidates = _dedupe_preserve([c.lower() for c in _spacy_np_candidates(doc)])
    if not candidates:
        candidates = [text.lower()]

    top_n = min(8, max(4, len(candidates) + 2))
    # The n-gram range defaults to unigrams and English stop words are removed
    # by default; multi-word chunks ("the best smartphone") can only match when
    # the range spans the longest candidate and stop words are kept.
    max_words = max(1, max(len(c.split()) for c in candidates))

    # KeyBERT API: candidates=... ; the TypeError fallback keeps the original
    # best-effort retry for models that spell the argument differently.
    try:
        keywords = kw_model.extract_keywords(
            text,
            candidates=candidates,
            keyphrase_ngram_range=(1, max_words),
            stop_words=None,
            top_n=top_n,
        )
    except TypeError:
        keywords = kw_model.extract_keywords(
            text,
            candidate_keywords=candidates,
            top_n=top_n,
        )

    ranked = [k[0].lower().strip() for k in (keywords or []) if k and k[0].strip()]
    intersection: List[str] = []
    boost: List[str] = []
    if ranked:
        intersection.append(ranked[0])
    if len(ranked) > 1:
        boost.extend(ranked[1:])

    # Add remaining spaCy heads not in lists
    heads: List[str] = []
    for ch in doc.noun_chunks:
        h = ch.root
        if h.pos_ in ("NOUN", "PROPN"):
            heads.append(_surface_lower(h))
    for h in heads:
        if h and h not in intersection and h not in boost:
            boost.append(h)
    if not intersection and heads:
        # KeyBERT gave nothing usable: promote the first head instead.
        intersection.append(heads[0])
        boost = [x for x in boost if x != heads[0]]

    drop_tokens: Set[str] = set()
    stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS
    for token in doc:
        if token.is_punct:
            continue
        lem = (token.lemma_ or token.text).lower()
        if token.pos_ in ("DET", "ADP", "PART", "PRON", "AUX") or lem in stops:
            drop_tokens.add(token.text.lower())
    # A kept head can also be a stop word (e.g. "top"); never report a kept
    # term as dropped too.
    drop_tokens -= set(intersection) | set(boost)

    return BucketResult(
        intersection_terms=_dedupe_preserve(intersection),
        boost_terms=_dedupe_preserve(boost),
        drop_terms=sorted(drop_tokens),
    )
# --- Scheme C: YAKE + noun filter --------------------------------------------
def bucket_scheme_c_yake(query: str, nlp, yake_extractor) -> BucketResult:
    """
    Scheme C: YAKE keyphrases filtered by spaCy part-of-speech tags.

    Keyphrases are visited from most to least important (lowest YAKE score
    first). A phrase is kept only if it contains a NOUN/PROPN. At most two
    short phrases (one or two words) whose noun is the parse ROOT or the last
    token go to intersection; every other kept phrase goes to boost.

    Args:
        query: Raw user query; ``None`` or blank yields an empty result.
        nlp: A loaded spaCy pipeline (called on the text and on each phrase,
            read for ``nlp.Defaults.stop_words``).
        yake_extractor: Object exposing ``extract_keywords(text)`` that returns
            (keyword, score) or (score, keyword) pairs.

    Returns:
        A BucketResult with lowercased terms. Drop terms never repeat a term
        kept in intersection or boost.
    """
    text = (query or "").strip()
    if not text:
        # Nothing to extract; skip the model calls.
        return BucketResult()
    doc = nlp(text)
    kws = yake_extractor.extract_keywords(text)

    # Accept either pair ordering: whichever element is the string is the
    # keyword. NOTE(review): some YAKE releases are reported to return
    # (score, keyword) rather than (keyword, score) — confirm for the pinned
    # version.
    scored: List[Tuple[str, float]] = []
    for item in kws or []:
        if not isinstance(item, (list, tuple)) or len(item) < 2:
            continue
        first, second = item[0], item[1]
        if isinstance(first, str):
            scored.append((first.strip(), float(second)))
        else:
            scored.append((str(second).strip(), float(first)))

    boost: List[str] = []
    intersection: List[str] = []
    for phrase, _score in sorted(scored, key=lambda x: x[1]):  # lower score = more important in YAKE
        phrase = phrase.lower().strip()
        if not phrase or len(phrase) < 2:
            continue
        sub = nlp(phrase)
        keep = False
        head_noun = False
        for t in sub:
            if t.is_punct or t.is_space:
                continue
            if t.pos_ in ("NOUN", "PROPN"):
                keep = True
                if t.dep_ == "ROOT" or t == sub[-1]:
                    head_noun = True
        if not keep:
            continue
        # top 1–2 important → intersection (very small)
        if len(intersection) < 2 and head_noun and len(phrase.split()) <= 2:
            intersection.append(phrase)
        else:
            boost.append(phrase)

    drop: Set[str] = set()
    stops = nlp.Defaults.stop_words | WEAK_BOOST_ADJS
    for token in doc:
        if token.is_punct:
            continue
        lem = (token.lemma_ or token.text).lower()
        if token.pos_ in ("DET", "ADP", "PART", "PRON", "AUX") or lem in stops:
            drop.add(token.text.lower())
    # A kept keyphrase can also be a stop word (e.g. "top"); never report a
    # kept term as dropped too.
    drop -= set(intersection) | set(boost)

    return BucketResult(
        intersection_terms=_dedupe_preserve(intersection),
        boost_terms=_dedupe_preserve(boost),
        drop_terms=sorted(drop),
    )
# --- CLI ---------------------------------------------------------------------
# Built-in sample queries used when --queries is not given on the command line.
DEFAULT_QUERIES: List[str] = [
    # purpose clause ("for ...") plus a verbal modifier
    "best noise cancelling headphones for airplane travel",
    # leading brand token plus a trailing audience noun
    "nike running shoes women",
    # plain compound noun phrase
    "plastic toy car",
    # copular question with a price clause
    "what is the best smartphone under 500 dollars",
    # possessive plus a dimension ("size 8")
    "women's running shoes size 8",
]
def _load_spacy():
    """Load and return the ``en_core_web_sm`` spaCy pipeline.

    If ``spacy.load`` raises OSError, a download hint is written to stderr
    and the error is re-raised.
    """
    import spacy

    try:
        pipeline = spacy.load("en_core_web_sm")
    except OSError:
        sys.stderr.write("Missing model: run: python -m spacy download en_core_web_sm\n")
        raise
    return pipeline
def _load_keybert():
    """Build and return a KeyBERT instance for the demo."""
    from keybert import KeyBERT

    # small & fast for demo; swap for larger if needed
    model_name = "paraphrase-MiniLM-L6-v2"
    return KeyBERT(model=model_name)
def _load_yake():
    """Build and return the YAKE keyword extractor with the demo's fixed settings."""
    import yake

    settings = {
        "lan": "en",
        "n": 3,
        "dedupLim": 0.9,
        "top": 20,
        "features": None,
    }
    return yake.KeywordExtractor(**settings)
def main() -> None:
    """Command-line entry point.

    Parses ``--queries`` and ``--scheme``, loads spaCy plus only the models
    the selected scheme needs, then prints one JSON line per selected scheme
    for every query, followed by a blank line.
    """
    cli = argparse.ArgumentParser(description="English query bucketing experiments")
    cli.add_argument(
        "--queries",
        nargs="*",
        default=DEFAULT_QUERIES,
        help="Queries to run (default: built-in examples)",
    )
    cli.add_argument("--scheme", choices=("a", "b", "c", "all"), default="all")
    opts = cli.parse_args()

    def wants(letter: str) -> bool:
        """True when the given scheme was selected directly or via 'all'."""
        return opts.scheme in (letter, "all")

    nlp = _load_spacy()
    # The heavier models are loaded only when their scheme will actually run.
    keybert_model = _load_keybert() if wants("b") else None
    yake_model = _load_yake() if wants("c") else None

    heavy_rule = "=" * 72
    light_rule = "-" * 72
    for query in opts.queries:
        print(heavy_rule)
        print("QUERY:", query)
        print(light_rule)
        if wants("a"):
            result_a = bucket_scheme_a_spacy(query, nlp)
            print("A spaCy+rules:", json.dumps(result_a.to_dict(), ensure_ascii=False))
        if keybert_model is not None:
            result_b = bucket_scheme_b_keybert(query, nlp, keybert_model)
            print("B spaCy+KeyBERT:", json.dumps(result_b.to_dict(), ensure_ascii=False))
        if yake_model is not None:
            result_c = bucket_scheme_c_yake(query, nlp, yake_model)
            print("C YAKE+noun filter:", json.dumps(result_c.to_dict(), ensure_ascii=False))
        print()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()