#!/usr/bin/env python3
"""
Interactive Elasticsearch debug search (standalone; not part of the main API).

Flow: query -> mode 1-5 -> choose display columns (default: all of
title.zh/en, qanchors.zh/en, tags) -> result count -> table.

Text search modes (1-3) use ES highlight; matched fragments are marked in red
(ANSI) in the terminal.

Mode 5 (image_embedding): an image URL / local path goes through
POST /embed/image (6008); plain text goes through clip-as-service gRPC
(consistent with `embedding.image_backends.clip_as_service`) and does not
download a local CN-CLIP. If only local_cnclip is configured, switch to
clip_as_service or enter image URLs only.

Usage:
    source activate.sh
    python scripts/es_debug_search.py [--tenant-id ID] [--index NAME]
"""

from __future__ import annotations

import argparse
import curses
import os
import re
import shutil
import sys
import unicodedata
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Make the project root importable so the lazy project imports further down
# (embeddings.*, indexer.*, utils.*) resolve when this file is run as a script.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Search modes, in menu order: (mode key, human-readable description).
# The 1-based position in this tuple is the mode number used by `main`.
OPTIONS: Sequence[tuple[str, str]] = (
    ("title", "title.zh / title.en"),
    ("qanchors", "qanchors.zh / qanchors.en"),
    ("tags", "tags (keyword)"),
    ("title_embedding", "KNN: title_embedding (text service)"),
    ("image_embedding", "KNN: image_embedding (HTTP 图 / grpc 文本)"),
)

# Column definitions: (column id, short header label).
COLUMN_DEFS: Tuple[Tuple[str, str], ...] = (
    ("title.zh", "title.zh"),
    ("title.en", "title.en"),
    ("qanchors.zh", "qanchors.zh"),
    ("qanchors.en", "qanchors.en"),
    ("tags", "tags"),
)

# Highlight fields used by the text-search modes (mode number -> ES fields).
HIGHLIGHT_FIELDS_BY_MODE: Dict[int, List[str]] = {
    1: ["title.zh", "title.en"],
    2: ["qanchors.zh", "qanchors.en"],
    3: ["tags"],
}

# Matches the SGR colour sequences this script injects via ES highlight tags.
ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")

# SGR sequences that switch all styling off.
_ANSI_RESETS = ("\x1b[0m", "\x1b[m")

# Columns whose value lives in a per-language sub-object of `_source`:
# column id -> (object key in _source, language key).
_LANG_COLUMNS: Dict[str, Tuple[str, str]] = {
    "title.zh": ("title", "zh"),
    "title.en": ("title", "en"),
    "qanchors.zh": ("qanchors", "zh"),
    "qanchors.en": ("qanchors", "en"),
}


def _strip_ansi(s: str) -> str:
    """Return *s* with every ANSI SGR colour sequence removed."""
    return ANSI_RE.sub("", s)


def _char_width(ch: str) -> int:
    """Return the number of terminal columns *ch* occupies.

    East Asian wide/fullwidth characters count as 2 columns, everything else
    as 1. Zero-width/combining characters are not special-cased.
    """
    return 2 if unicodedata.east_asian_width(ch) in ("W", "F") else 1


def _visible_len(s: str) -> int:
    """Return the display width of *s* in terminal columns, ignoring ANSI codes."""
    return sum(_char_width(ch) for ch in _strip_ansi(s))


def _truncate(s: str, max_len: int) -> str:
    """Shorten *s* to at most *max_len* terminal columns.

    Strings that already fit are returned unchanged. Otherwise the text is cut
    so that it plus a trailing "…" fits in *max_len* columns. ANSI colour
    sequences inside the kept prefix are preserved (so highlights survive
    truncation), and a reset is appended if a colour is still open at the cut.

    Returns "" when *max_len* is not positive.
    """
    if max_len <= 0:
        return ""
    if _visible_len(s) <= max_len:
        return s

    budget = max_len - 1  # reserve one column for the ellipsis
    pieces: List[str] = []
    used = 0
    colour_open = False
    idx = 0
    while idx < len(s):
        seq = ANSI_RE.match(s, idx)
        if seq:
            # Escape sequences take no columns; copy them through verbatim.
            code = seq.group(0)
            pieces.append(code)
            colour_open = code not in _ANSI_RESETS
            idx = seq.end()
            continue
        width = _char_width(s[idx])
        if used + width > budget:
            break
        pieces.append(s[idx])
        used += width
        idx += 1

    if colour_open:
        # Close the highlight so it does not bleed into the table border.
        pieces.append("\x1b[0m")
    pieces.append("…")
    return "".join(pieces)


def _lang_field(source: Dict[str, Any], obj_key: str, lang: str) -> str:
    """Read ``source[obj_key][lang]`` as a stripped string.

    If ``source[obj_key]`` is not a dict, its string form is returned instead;
    a missing or ``None`` value yields "".
    """
    obj = source.get(obj_key)
    if isinstance(obj, dict):
        return str(obj.get(lang) or "").strip()
    if obj is None:
        return ""
    return str(obj).strip()


def _tags_str(source: Dict[str, Any]) -> str:
    """Render ``source["tags"]`` as text; lists are joined with ", " (None skipped)."""
    raw = source.get("tags")
    if raw is None:
        return ""
    if isinstance(raw, list):
        return ", ".join(str(x) for x in raw if x is not None)
    return str(raw).strip()


def _cell_from_hit(hit: Dict[str, Any], field_id: str, source: Dict[str, Any]) -> str:
    """Return the table cell text for *field_id*.

    Prefers the ES highlight for the field when present, otherwise falls back
    to the value in ``_source``. Unknown field ids yield "".
    """
    hl = hit.get("highlight") or {}
    if field_id in hl:
        parts = hl[field_id]
        if isinstance(parts, list):
            if field_id == "tags":
                # Several tags may match; show every highlighted one.
                return ", ".join(parts)
            return parts[0] if parts else ""
        return str(parts)

    if field_id in _LANG_COLUMNS:
        obj_key, lang = _LANG_COLUMNS[field_id]
        return _lang_field(source, obj_key, lang)
    if field_id == "tags":
        return _tags_str(source)
    return ""


def _highlight_clause(field_names: Sequence[str]) -> Dict[str, Any]:
    """Build the ES ``highlight`` clause for *field_names*.

    Matches are wrapped in ANSI red / reset so they show coloured in the
    terminal. ``number_of_fragments`` is 0 for every field, which per the ES
    highlighting docs returns the whole field value rather than fragments.
    """
    return {
        "require_field_match": True,
        "pre_tags": ["\x1b[31m"],
        "post_tags": ["\x1b[0m"],
        "fields": {
            f: {"number_of_fragments": 0, "fragment_size": 8000}
            for f in field_names
        },
    }


def _source_includes() -> List[str]:
    """Return the ``_source`` fields requested from ES for every search."""
    return ["title", "qanchors", "tags", "spu_id"]


def _select_mode_curses() -> int:
    """Pick a search mode with a full-screen curses menu.

    Up/Down (or k/j) move the cursor, Enter confirms, Esc selects mode 1.

    Returns:
        The 1-based mode number (position in ``OPTIONS``).
    """
    labels = [f"{key} — {desc}" for key, desc in OPTIONS]

    def run(stdscr: Any) -> int:
        curses.curs_set(0)
        stdscr.keypad(True)
        current = 0
        while True:
            stdscr.erase()
            stdscr.addstr(
                0,
                0,
                "选择模式 (↑↓ 移动, Enter 确认; 默认第一项 title)",
                curses.A_BOLD,
            )
            for i, line in enumerate(labels):
                attr = curses.A_REVERSE if i == current else curses.A_NORMAL
                prefix = ">" if i == current else " "
                stdscr.addstr(2 + i, 0, f"{prefix} {i + 1}. {line}", attr)
            stdscr.refresh()
            ch = stdscr.getch()
            if ch in (curses.KEY_UP, ord("k")):
                current = (current - 1) % len(labels)
            elif ch in (curses.KEY_DOWN, ord("j")):
                current = (current + 1) % len(labels)
            elif ch in (10, 13):  # Enter (LF / CR)
                return current + 1
            elif ch in (27,):  # Esc: fall back to the default mode
                return 1

    return int(curses.wrapper(run))


def _select_mode_fallback() -> int:
    """Pick a search mode via plain ``input()``.

    Empty or non-numeric input selects mode 1; out-of-range numbers are
    clamped to the valid range. ``EOFError`` from ``input()`` propagates to
    the caller.
    """
    print("选择模式 (直接回车 = 1 title):")
    for i, (_k, desc) in enumerate(OPTIONS, 1):
        print(f" {i}. {desc}")
    raw = input("编号 [1]: ").strip() or "1"
    try:
        n = int(raw)
    except ValueError:
        n = 1
    return max(1, min(n, len(OPTIONS)))


def _select_mode() -> int:
    """Pick a search mode: curses menu on a TTY, plain prompt otherwise."""
    if not sys.stdin.isatty():
        return _select_mode_fallback()
    try:
        return _select_mode_curses()
    except Exception:
        # Deliberate best-effort: any curses failure (unsupported or too-small
        # terminal, ...) degrades to the plain prompt instead of aborting.
        return _select_mode_fallback()


def _select_fields_curses() -> List[str]:
    """Pick display columns with a curses checklist.

    Space toggles the current row, ``a`` selects all, ``n`` selects none,
    Enter confirms (an empty selection is treated as "all"), Esc returns all.

    Returns:
        The selected column ids, in ``COLUMN_DEFS`` order.
    """
    ids = [c[0] for c in COLUMN_DEFS]
    labels = [c[1] for c in COLUMN_DEFS]
    selected = [True] * len(ids)

    def run(stdscr: Any) -> List[str]:
        curses.curs_set(0)
        stdscr.keypad(True)
        cur = 0
        while True:
            stdscr.erase()
            stdscr.addstr(
                0,
                0,
                "选择显示列 (空格切换, Enter 确认; 默认全选)",
                curses.A_BOLD,
            )
            stdscr.addstr(1, 0, "a: 全选 / n: 全不选", curses.A_DIM)
            row = 3
            for i, lab in enumerate(labels):
                mark = "[x]" if selected[i] else "[ ]"
                attr = curses.A_REVERSE if i == cur else curses.A_NORMAL
                stdscr.addstr(row + i, 0, f"{mark} {lab}", attr)
            stdscr.refresh()
            ch = stdscr.getch()
            if ch in (curses.KEY_UP, ord("k")):
                cur = (cur - 1) % len(ids)
            elif ch in (curses.KEY_DOWN, ord("j")):
                cur = (cur + 1) % len(ids)
            elif ch in (32,):  # space
                selected[cur] = not selected[cur]
            elif ch in (ord("a"), ord("A")):
                selected[:] = [True] * len(selected)
            elif ch in (ord("n"), ord("N")):
                selected[:] = [False] * len(selected)
            elif ch in (10, 13):  # Enter (LF / CR)
                chosen = [cid for cid, on in zip(ids, selected) if on]
                # Confirming with nothing ticked means "show everything".
                return chosen or list(ids)
            elif ch in (27,):  # Esc: keep the default (all columns)
                return list(ids)

    return curses.wrapper(run)


def _select_fields_fallback() -> List[str]:
    """Pick display columns via plain ``input()``.

    Accepts 1-based column numbers separated by ASCII or fullwidth commas.
    Invalid or out-of-range entries are ignored; empty input or no valid
    entry selects all columns. ``EOFError`` from ``input()`` propagates.

    Returns:
        The selected column ids in the order they were typed, de-duplicated.
    """
    print("显示列 (编号 1-5 逗号分隔; 回车=全选):")
    for i, (_cid, lab) in enumerate(COLUMN_DEFS, 1):
        print(f" {i}. {lab}")
    raw = input("列 [1,2,3,4,5]: ").strip()
    if not raw:
        return [c[0] for c in COLUMN_DEFS]
    out: List[str] = []
    # Normalise the fullwidth comma so input typed with a Chinese IME parses too.
    for part in raw.replace(",", ",").split(","):
        part = part.strip()
        if not part:
            continue
        try:
            n = int(part)
        except ValueError:
            continue
        if 1 <= n <= len(COLUMN_DEFS):
            cid = COLUMN_DEFS[n - 1][0]
            if cid not in out:
                out.append(cid)
    return out if out else [c[0] for c in COLUMN_DEFS]


def _select_fields() -> List[str]:
    """Pick display columns: curses checklist on a TTY, plain prompt otherwise."""
    if not sys.stdin.isatty():
        return _select_fields_fallback()
    try:
        return _select_fields_curses()
    except Exception:
        # Deliberate best-effort: any curses failure degrades to the prompt.
        return _select_fields_fallback()


def _ordered_columns(selected: List[str]) -> List[str]:
    """Return the selected column ids re-ordered to follow ``COLUMN_DEFS``."""
    id_set = set(selected)
    return [c[0] for c in COLUMN_DEFS if c[0] in id_set]


def _run_es(
    es: Any,
    index_name: str,
    body: Dict[str, Any],
    size: int,
) -> List[Dict[str, Any]]:
    """Run ``es.search`` and return the list of hits (possibly empty).

    The response may expose the payload as a ``.body`` attribute, be a plain
    dict, or be some other mapping convertible with ``dict()``.
    """
    resp = es.search(index=index_name, body=body, size=size)
    if hasattr(resp, "body"):
        payload = resp.body
    else:
        payload = resp if isinstance(resp, dict) else dict(resp)
    return (payload.get("hits") or {}).get("hits") or []


def _print_table(
    hits: List[Dict[str, Any]],
    columns: List[str],
    *,
    term_width: int,
) -> None:
    """Print *hits* as a simple Unicode box table: #, doc_id, chosen columns.

    Args:
        hits: ES hit dicts (``_id``, ``_source``, optional ``highlight``).
        columns: Column ids to show; an empty list means all of ``COLUMN_DEFS``.
        term_width: Terminal width in columns, used to size the table.
    """
    if not columns:
        columns = [c[0] for c in COLUMN_DEFS]
    header_by_id = dict(COLUMN_DEFS)
    # Unknown ids fall back to the id itself instead of raising.
    headers = ["#", "doc_id"] + [header_by_id.get(col, col) for col in columns]

    rows: List[List[str]] = []
    for i, hit in enumerate(hits, 1):
        src = hit.get("_source") or {}
        cells = [str(i), str(hit.get("_id", ""))]
        cells.extend(_cell_from_hit(hit, col, src) for col in columns)
        rows.append(cells)

    # Column widths: total width minus borders and separators (never below 40).
    ncols = len(headers)
    inner = max(term_width - 3 * (ncols - 1) - 4, 40)
    base = max(6, inner // ncols)
    col_widths = [
        min(5, base) if j == 0 else (min(26, max(12, base)) if j == 1 else base)
        for j in range(ncols)
    ]
    # Whatever "#" and "doc_id" leave over is shared evenly by the data columns.
    w_rem = max(0, inner - col_widths[0] - col_widths[1])
    rest = ncols - 2
    if rest > 0:
        per = max(10, w_rem // rest)
        for j in range(2, ncols):
            col_widths[j] = per

    top = "┌" + "┬".join("─" * (w + 2) for w in col_widths) + "┐"
    mid = "├" + "┼".join("─" * (w + 2) for w in col_widths) + "┤"
    bot = "└" + "┴".join("─" * (w + 2) for w in col_widths) + "┘"

    def fmt_row(cells: List[str]) -> str:
        out = []
        for cell, w in zip(cells, col_widths):
            t = _truncate(cell.replace("\n", " "), w)
            # Pad by display width so ANSI codes and wide (CJK) characters
            # do not push the column borders out of alignment.
            pad = max(0, w - _visible_len(t))
            out.append(" " + t + " " * pad + " ")
        return "│" + "│".join(out) + "│"

    print(top)
    print(fmt_row(headers))
    print(mid)
    for row in rows:
        print(fmt_row(row))
    print(bot)


def _multi_match_body(query: str, mode: int) -> Dict[str, Any]:
    """Build a best_fields multi_match body over the fields of text mode *mode*.

    The same field list (``HIGHLIGHT_FIELDS_BY_MODE[mode]``) is used both for
    matching and for highlighting.
    """
    fields = HIGHLIGHT_FIELDS_BY_MODE[mode]
    return {
        "query": {
            "multi_match": {
                "query": query,
                "fields": list(fields),
                "type": "best_fields",
            }
        },
        "_source": _source_includes(),
        "highlight": _highlight_clause(fields),
    }


def _build_body_title(query: str) -> Dict[str, Any]:
    """Build the search body for mode 1 (title.zh / title.en)."""
    return _multi_match_body(query, 1)


def _build_body_qanchors(query: str) -> Dict[str, Any]:
    """Build the search body for mode 2 (qanchors.zh / qanchors.en)."""
    return _multi_match_body(query, 2)


def _build_body_tags(query: str) -> Dict[str, Any]:
    """Build the search body for mode 3 (tags).

    Matches either an exact ``term`` on ``tags`` or a case-insensitive
    ``*query*`` wildcard. The query is inserted into the wildcard pattern
    as-is, so ``*`` / ``?`` typed by the user act as wildcards.
    """
    return {
        "query": {
            "bool": {
                "should": [
                    {"term": {"tags": query}},
                    {
                        "wildcard": {
                            "tags": {"value": f"*{query}*", "case_insensitive": True}
                        }
                    },
                ],
                "minimum_should_match": 1,
            }
        },
        "_source": _source_includes(),
        "highlight": _highlight_clause(HIGHLIGHT_FIELDS_BY_MODE[3]),
    }


def _looks_like_image_ref(url: str) -> bool:
    """Return True for an HTTP(S) URL, a ``//`` URL, or an existing local file path."""
    s = url.strip()
    if not s:
        return False
    if s.lower().startswith(("http://", "https://", "//")):
        return True
    return os.path.isfile(s)


def _encode_clip_query_vector(query: str) -> List[float]:
    """Encode *query* into a vector for searching ``image_embedding``.

    Queries that look like an image reference (see ``_looks_like_image_ref``)
    are encoded with ``CLIPImageEncoder.encode_image_from_url``; anything else
    with ``CLIPImageEncoder.encode_clip_text``. Both are called with
    ``normalize_embeddings=True`` and ``priority=1``.

    NOTE(review): the transport behind ``encode_clip_text`` is decided inside
    ``CLIPImageEncoder``. The module docstring describes it as clip-as-service
    gRPC, while an earlier version of this docstring said
    ``POST /embed/clip_text`` (6008) — confirm which is current.

    Returns:
        The embedding flattened to a list of float32 values.

    Raises:
        ValueError: If *query* is empty or whitespace only.
    """
    import numpy as np

    q = (query or "").strip()
    if not q:
        raise ValueError("empty query")

    from embeddings.image_encoder import CLIPImageEncoder

    enc = CLIPImageEncoder()
    if _looks_like_image_ref(q):
        vec = enc.encode_image_from_url(q, normalize_embeddings=True, priority=1)
    else:
        vec = enc.encode_clip_text(q, normalize_embeddings=True, priority=1)
    return vec.astype(np.float32).flatten().tolist()


def _knn_body(field: str, query_vector: List[float], size: int) -> Dict[str, Any]:
    """Build a top-level ``knn`` search body against *field*.

    ``k`` is *size*; ``num_candidates`` is ``max(size * 10, 100)``.
    """
    return {
        "knn": {
            "field": field,
            "query_vector": query_vector,
            "k": size,
            "num_candidates": max(size * 10, 100),
        },
        "_source": _source_includes(),
    }


def search_title_knn(es: Any, index_name: str, query: str, size: int) -> List[Dict[str, Any]]:
    """KNN search on ``title_embedding`` using the text embedding encoder.

    Raises:
        RuntimeError: If the encoder returns ``None`` for the query vector.
    """
    from embeddings.text_encoder import TextEmbeddingEncoder

    enc = TextEmbeddingEncoder()
    arr = enc.encode(query, normalize_embeddings=True)
    vec = arr[0]
    if vec is None:
        raise RuntimeError("text embedding service returned no vector")
    qv = vec.astype("float32").flatten().tolist()
    return _run_es(es, index_name, _knn_body("title_embedding", qv, size), size)


def search_image_knn(es: Any, index_name: str, query: str, size: int) -> List[Dict[str, Any]]:
    """KNN search on ``image_embedding.vector`` using the CLIP query vector."""
    qv = _encode_clip_query_vector(query)
    return _run_es(es, index_name, _knn_body("image_embedding.vector", qv, size), size)


def main() -> None:
    """Run the interactive loop: query -> mode -> columns -> size -> table.

    The loop ends on EOF (Ctrl+D) at any prompt. Errors raised while
    searching or printing are reported on stderr and the loop continues.
    """
    parser = argparse.ArgumentParser(description="Interactive ES debug search")
    parser.add_argument(
        "--tenant-id",
        default=None,
        help="Tenant id for index name search_products_tenant_{id} (default: env TENANT_ID or 170)",
    )
    parser.add_argument(
        "--index",
        default=None,
        help="Override full index name (skips tenant-based naming)",
    )
    args = parser.parse_args()
    tenant = args.tenant_id or os.environ.get("TENANT_ID") or "170"

    # Project imports are deferred so `--help` works without the project deps.
    from indexer.mapping_generator import get_tenant_index_name
    from utils.es_client import get_es_client_from_env

    index_name = args.index or get_tenant_index_name(str(tenant))
    es = get_es_client_from_env().client

    # Mode number -> callable(es, index_name, query, size) returning hits.
    dispatch: Dict[int, Callable[..., List[Dict[str, Any]]]] = {
        1: lambda e, idx, q, s: _run_es(e, idx, _build_body_title(q), s),
        2: lambda e, idx, q, s: _run_es(e, idx, _build_body_qanchors(q), s),
        3: lambda e, idx, q, s: _run_es(e, idx, _build_body_tags(q), s),
        4: search_title_knn,
        5: search_image_knn,
    }

    print(f"索引: {index_name} (Ctrl+D 退出)\n")

    while True:
        # Every prompt lives in one guarded section so EOF at any of them
        # (including the plain-input mode/column fallbacks) exits cleanly.
        try:
            query = input("query> ").strip()
            if not query:
                continue
            mode = _select_mode()
            cols = _ordered_columns(_select_fields())
            raw_size = input("条数 [20]: ").strip() or "20"
        except EOFError:
            print()
            break

        try:
            size = max(1, int(raw_size))
        except ValueError:
            size = 20

        fn = dispatch.get(mode, dispatch[1])
        # Re-read the width each round so a resized terminal is respected.
        term_w = shutil.get_terminal_size((100, 24)).columns
        print(f"--- mode={mode} ({OPTIONS[mode - 1][0]}) columns={','.join(cols)} size={size} ---")

        try:
            hits = fn(es, index_name, query, size)
            if not hits:
                print("(无命中)")
            else:
                _print_table(hits, cols, term_width=term_w)
        except Exception as e:
            # Debug tool: report the failure and keep the session alive.
            print(f"错误: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()