Commit d47889b91a609498be116c1147568a1e4c01f77a

Authored by tangwang
1 parent f86c5fee

ES 字段查询工具 scripts/es_debug_search.py

Showing 1 changed file with 569 additions and 0 deletions   Show diff stats
scripts/es_debug_search.py 0 → 100644
... ... @@ -0,0 +1,569 @@
  1 +#!/usr/bin/env python3
  2 +"""
  3 +Interactive Elasticsearch debug search (standalone; not part of the main API).
  4 +
  5 +Flow: query → mode 1–5 → 选择显示列 (默认全选 title.zh/en, qanchors.zh/en, tags) → 条数 → 表格。
  6 +
  7 +文本检索 (1–3) 使用 ES highlight,终端内红色 (ANSI) 标记匹配片段。
  8 +
  9 +mode 5(image_embedding):图片 URL/本地路径走 POST /embed/image(6008);纯文本走 clip-as-service
  10 +gRPC(与 `embedding.image_backends.clip_as_service` 一致),不下载本地 CN-CLIP。若仅配置
  11 +local_cnclip,请改用 clip_as_service 或只输入图片 URL。
  12 +
  13 +Usage:
  14 + source activate.sh
  15 + python scripts/es_debug_search.py [--tenant-id ID] [--index NAME]
  16 +"""
  17 +
  18 +from __future__ import annotations
  19 +
import argparse
import curses
import re
import shutil
import sys
import unicodedata
from pathlib import Path
from typing import Any, Callable, Dict, List, Sequence, Tuple
  27 +
  28 +PROJECT_ROOT = Path(__file__).resolve().parents[1]
  29 +if str(PROJECT_ROOT) not in sys.path:
  30 + sys.path.insert(0, str(PROJECT_ROOT))
  31 +
# Search modes offered by the mode menu: (mode key, human-readable description).
# The 1-based position of an entry is the mode number returned by the mode
# selectors and used to index this tuple when printing the run summary.
OPTIONS: Sequence[tuple[str, str]] = (
    ("title", "title.zh / title.en"),
    ("qanchors", "qanchors.zh / qanchors.en"),
    ("tags", "tags (keyword)"),
    ("title_embedding", "KNN: title_embedding (text service)"),
    ("image_embedding", "KNN: image_embedding (HTTP 图 / grpc 文本)"),
)

# Column definitions: (column id, short header label).
COLUMN_DEFS: Tuple[Tuple[str, str], ...] = (
    ("title.zh", "title.zh"),
    ("title.en", "title.en"),
    ("qanchors.zh", "qanchors.zh"),
    ("qanchors.en", "qanchors.en"),
    ("tags", "tags"),
)

# Highlight fields requested by the text-search modes, keyed by mode number.
HIGHLIGHT_FIELDS_BY_MODE: Dict[int, List[str]] = {
    1: ["title.zh", "title.en"],
    2: ["qanchors.zh", "qanchors.en"],
    3: ["tags"],
}
  55 +
  56 +ANSI_RE = re.compile(r"\x1b\[[0-9;]*m")
  57 +
  58 +
  59 +def _strip_ansi(s: str) -> str:
  60 + return ANSI_RE.sub("", s)
  61 +
  62 +
  63 +def _visible_len(s: str) -> int:
  64 + return len(_strip_ansi(s))
  65 +
  66 +
  67 +def _truncate(s: str, max_len: int) -> str:
  68 + if max_len <= 0:
  69 + return ""
  70 + if _visible_len(s) <= max_len:
  71 + return s
  72 + # 在纯文本长度上截断(忽略 ANSI 近似按字符截断)
  73 + plain = _strip_ansi(s)
  74 + if len(plain) <= max_len:
  75 + return s
  76 + return plain[: max_len - 1] + "…"
  77 +
  78 +
  79 +def _lang_field(source: Dict[str, Any], obj_key: str, lang: str) -> str:
  80 + obj = source.get(obj_key)
  81 + if isinstance(obj, dict):
  82 + return str(obj.get(lang) or "").strip()
  83 + if obj is None:
  84 + return ""
  85 + return str(obj).strip()
  86 +
  87 +
  88 +def _tags_str(source: Dict[str, Any]) -> str:
  89 + raw = source.get("tags")
  90 + if raw is None:
  91 + return ""
  92 + if isinstance(raw, list):
  93 + return ", ".join(str(x) for x in raw if x is not None)
  94 + return str(raw).strip()
  95 +
  96 +
  97 +def _cell_from_hit(hit: Dict[str, Any], field_id: str, source: Dict[str, Any]) -> str:
  98 + """优先使用 highlight,否则 _source。"""
  99 + hl = hit.get("highlight") or {}
  100 + if field_id in hl:
  101 + parts = hl[field_id]
  102 + if isinstance(parts, list):
  103 + if field_id == "tags":
  104 + return ", ".join(parts)
  105 + return parts[0] if parts else ""
  106 + return str(parts)
  107 + if field_id == "title.zh":
  108 + return _lang_field(source, "title", "zh")
  109 + if field_id == "title.en":
  110 + return _lang_field(source, "title", "en")
  111 + if field_id == "qanchors.zh":
  112 + return _lang_field(source, "qanchors", "zh")
  113 + if field_id == "qanchors.en":
  114 + return _lang_field(source, "qanchors", "en")
  115 + if field_id == "tags":
  116 + return _tags_str(source)
  117 + return ""
  118 +
  119 +
  120 +def _highlight_clause(field_names: Sequence[str]) -> Dict[str, Any]:
  121 + return {
  122 + "require_field_match": True,
  123 + "pre_tags": ["\x1b[31m"],
  124 + "post_tags": ["\x1b[0m"],
  125 + "fields": {
  126 + f: {"number_of_fragments": 0, "fragment_size": 8000} for f in field_names
  127 + },
  128 + }
  129 +
  130 +
  131 +def _source_includes() -> List[str]:
  132 + return ["title", "qanchors", "tags", "spu_id"]
  133 +
  134 +
def _select_mode_curses() -> int:
    """Show a full-screen curses menu of OPTIONS and return the 1-based choice.

    Up/Down (or k/j) move the cursor with wrap-around, Enter confirms, and
    Esc returns mode 1.
    """
    entries = [f"{key} — {desc}" for key, desc in OPTIONS]
    count = len(entries)

    def _menu(screen: Any) -> int:
        curses.curs_set(0)
        screen.keypad(True)
        # Key code -> cursor movement.
        step_for_key = {
            curses.KEY_UP: -1,
            ord("k"): -1,
            curses.KEY_DOWN: 1,
            ord("j"): 1,
        }
        cursor = 0
        while True:
            screen.erase()
            screen.addstr(
                0,
                0,
                "选择模式 (↑↓ 移动, Enter 确认; 默认第一项 title)",
                curses.A_BOLD,
            )
            for idx, text in enumerate(entries):
                active = idx == cursor
                marker = ">" if active else " "
                style = curses.A_REVERSE if active else curses.A_NORMAL
                screen.addstr(2 + idx, 0, f"{marker} {idx + 1}. {text}", style)
            screen.refresh()

            key = screen.getch()
            if key in step_for_key:
                cursor = (cursor + step_for_key[key]) % count
            elif key in (10, 13):  # Enter (LF / CR)
                return cursor + 1
            elif key == 27:  # Esc
                return 1

    return int(curses.wrapper(_menu))
  166 +
  167 +
def _select_mode_fallback() -> int:
    """Prompt for a mode number on plain stdin/stdout and return it.

    Empty or non-numeric input selects mode 1; numbers are clamped to the
    range of available OPTIONS.
    """
    print("选择模式 (直接回车 = 1 title):")
    for number, option in enumerate(OPTIONS, 1):
        print(f" {number}. {option[1]}")
    answer = input("编号 [1]: ").strip()
    try:
        choice = int(answer) if answer else 1
    except ValueError:
        choice = 1
    upper = len(OPTIONS)
    return max(1, min(choice, upper))
  178 +
  179 +
def _select_mode() -> int:
    """Pick a search mode: curses menu on a TTY, plain prompt otherwise.

    Any failure of the curses menu also drops back to the plain prompt.
    """
    if sys.stdin.isatty():
        try:
            return _select_mode_curses()
        except Exception:
            # Deliberate best-effort: fall through to the plain prompt below.
            pass
    return _select_mode_fallback()
  187 +
  188 +
def _select_fields_curses() -> List[str]:
    """Curses checklist of COLUMN_DEFS; returns the chosen column ids in order.

    Space toggles the current row, a/n select all/none, Enter confirms (an
    empty selection is treated as "all") and Esc returns every column.
    """
    column_ids = [col_id for col_id, _ in COLUMN_DEFS]
    captions = [caption for _, caption in COLUMN_DEFS]
    total = len(column_ids)
    checked = [True] * total

    def _picker(screen: Any) -> List[str]:
        curses.curs_set(0)
        screen.keypad(True)
        cursor = 0
        first_row = 3
        while True:
            screen.erase()
            screen.addstr(
                0,
                0,
                "选择显示列 (空格切换, Enter 确认; 默认全选)",
                curses.A_BOLD,
            )
            screen.addstr(1, 0, "a: 全选 / n: 全不选", curses.A_DIM)
            for idx, caption in enumerate(captions):
                box = "[x]" if checked[idx] else "[ ]"
                style = curses.A_REVERSE if idx == cursor else curses.A_NORMAL
                screen.addstr(first_row + idx, 0, f"{box} {caption}", style)
            screen.refresh()

            key = screen.getch()
            if key in (curses.KEY_UP, ord("k")):
                cursor = (cursor - 1) % total
            elif key in (curses.KEY_DOWN, ord("j")):
                cursor = (cursor + 1) % total
            elif key == 32:  # space
                checked[cursor] = not checked[cursor]
            elif key in (ord("a"), ord("A")):
                checked[:] = [True] * total
            elif key in (ord("n"), ord("N")):
                checked[:] = [False] * total
            elif key in (10, 13):  # Enter (LF / CR)
                if not any(checked):
                    checked[:] = [True] * total
                return [col_id for col_id, on in zip(column_ids, checked) if on]
            elif key == 27:  # Esc
                return list(column_ids)

    return curses.wrapper(_picker)
  236 +
  237 +
def _select_fields_fallback() -> List[str]:
    """Prompt on plain stdin/stdout for the columns to display.

    The answer is a comma-separated list of 1-based column numbers; both the
    ASCII comma and the fullwidth comma (U+FF0C, produced by Chinese input
    methods) are accepted as separators. Non-numeric or out-of-range entries
    are ignored and duplicates are dropped.

    Returns:
        The chosen column ids in the order typed, or every column id when the
        answer is empty or contains no valid number.
    """
    all_ids = [col_id for col_id, _label in COLUMN_DEFS]
    print("显示列 (编号 1-5 逗号分隔; 回车=全选):")
    for i, (_col_id, lab) in enumerate(COLUMN_DEFS, 1):
        print(f" {i}. {lab}")
    raw = input("列 [1,2,3,4,5]: ").strip()
    if not raw:
        return all_ids

    out: List[str] = []
    # Normalise the fullwidth comma so "1,3" parses the same as "1,3".
    for part in raw.replace("\uff0c", ",").split(","):
        part = part.strip()
        if not part:
            continue
        try:
            n = int(part)
        except ValueError:
            continue
        if 1 <= n <= len(COLUMN_DEFS):
            cid = COLUMN_DEFS[n - 1][0]
            if cid not in out:
                out.append(cid)
    return out if out else all_ids
  259 +
  260 +
def _select_fields() -> List[str]:
    """Pick display columns: curses checklist on a TTY, plain prompt otherwise.

    Any failure of the curses checklist also drops back to the plain prompt.
    """
    if sys.stdin.isatty():
        try:
            return _select_fields_curses()
        except Exception:
            # Deliberate best-effort: fall through to the plain prompt below.
            pass
    return _select_fields_fallback()
  268 +
  269 +
def _ordered_columns(selected: List[str]) -> List[str]:
    """Return the selected column ids rearranged into COLUMN_DEFS order."""
    wanted = set(selected)
    ordered: List[str] = []
    for entry in COLUMN_DEFS:
        col_id = entry[0]
        if col_id in wanted:
            ordered.append(col_id)
    return ordered
  274 +
  275 +
  276 +def _run_es(
  277 + es: Any,
  278 + index_name: str,
  279 + body: Dict[str, Any],
  280 + size: int,
  281 +) -> List[Dict[str, Any]]:
  282 + resp = es.search(index=index_name, body=body, size=size)
  283 + if hasattr(resp, "body"):
  284 + payload = resp.body
  285 + else:
  286 + payload = dict(resp) if not isinstance(resp, dict) else resp
  287 + hits = (payload.get("hits") or {}).get("hits") or []
  288 + return hits
  289 +
  290 +
def _print_table(
    hits: List[Dict[str, Any]],
    columns: List[str],
    *,
    term_width: int,
) -> None:
    """Print a simple Unicode box table: "#", doc_id, then the chosen columns.

    An empty *columns* list means "all columns". Column widths are derived
    from *term_width*; cells are flattened to one line, truncated to their
    column width and padded with spaces.
    """
    if not columns:
        columns = [entry[0] for entry in COLUMN_DEFS]

    headers = ["#", "doc_id"]
    headers += [next(label for cid, label in COLUMN_DEFS if cid == col) for col in columns]

    # Extract every cell before printing anything.
    rows: List[List[str]] = []
    for number, hit in enumerate(hits, 1):
        source = hit.get("_source") or {}
        row = [str(number), str(hit.get("_id", ""))]
        row.extend(_cell_from_hit(hit, col, source) for col in columns)
        rows.append(row)

    # Width budget: terminal width minus borders and separators.
    n = len(headers)
    usable = max(term_width - 3 * (n - 1) - 4, 40)
    share = max(6, usable // n)
    index_w = min(5, share)
    doc_id_w = min(26, max(12, share))
    data_cols = n - 2
    if data_cols > 0:
        leftover = max(0, usable - index_w - doc_id_w)
        data_w = max(10, leftover // data_cols)
    else:
        data_w = share
    widths = [index_w, doc_id_w] + [data_w] * data_cols

    def rule(left: str, joint: str, right: str) -> str:
        return left + joint.join("─" * (w + 2) for w in widths) + right

    def render(cells: List[str]) -> str:
        parts = []
        for text, width in zip(cells, widths):
            shown = _truncate(text.replace("\n", " "), width)
            fill = max(0, width - _visible_len(shown))
            parts.append(" " + shown + " " * fill + " ")
        return "│" + "│".join(parts) + "│"

    print(rule("┌", "┬", "┐"))
    print(render(headers))
    print(rule("├", "┼", "┤"))
    for row in rows:
        print(render(row))
    print(rule("└", "┴", "┘"))
  348 +
  349 +
def _build_body_title(query: str) -> Dict[str, Any]:
    """Build the search body for mode 1: multi_match over title.zh / title.en."""
    match = {
        "query": query,
        "fields": ["title.zh", "title.en"],
        "type": "best_fields",
    }
    body: Dict[str, Any] = {"query": {"multi_match": match}}
    body["_source"] = _source_includes()
    body["highlight"] = _highlight_clause(HIGHLIGHT_FIELDS_BY_MODE[1])
    return body
  362 +
  363 +
def _build_body_qanchors(query: str) -> Dict[str, Any]:
    """Build the search body for mode 2: multi_match over qanchors.zh / qanchors.en."""
    match = {
        "query": query,
        "fields": ["qanchors.zh", "qanchors.en"],
        "type": "best_fields",
    }
    body: Dict[str, Any] = {"query": {"multi_match": match}}
    body["_source"] = _source_includes()
    body["highlight"] = _highlight_clause(HIGHLIGHT_FIELDS_BY_MODE[2])
    return body
  376 +
  377 +
def _build_body_tags(query: str) -> Dict[str, Any]:
    """Build the search body for mode 3: exact or substring match on ``tags``.

    A hit must satisfy at least one of: an exact ``term`` match, or a
    case-insensitive ``wildcard`` match on ``*query*``.
    """
    exact = {"term": {"tags": query}}
    contains = {
        "wildcard": {
            "tags": {"value": f"*{query}*", "case_insensitive": True}
        }
    }
    body: Dict[str, Any] = {
        "query": {"bool": {"should": [exact, contains], "minimum_should_match": 1}}
    }
    body["_source"] = _source_includes()
    body["highlight"] = _highlight_clause(HIGHLIGHT_FIELDS_BY_MODE[3])
    return body
  396 +
  397 +
  398 +def _looks_like_image_ref(url: str) -> bool:
  399 + """HTTP(S) URL、// URL、或存在的本地文件路径。"""
  400 + import os
  401 +
  402 + s = url.strip()
  403 + if not s:
  404 + return False
  405 + sl = s.lower()
  406 + if sl.startswith(("http://", "https://", "//")):
  407 + return True
  408 + if os.path.isfile(s):
  409 + return True
  410 + return False
  411 +
  412 +
  413 +def _encode_clip_query_vector(query: str) -> List[float]:
  414 + """
  415 + 与索引中 image_embedding 同空间:图走 HTTP /embed/image;文本走 clip-as-service gRPC encode。
  416 + """
  417 + import numpy as np
  418 +
  419 + q = (query or "").strip()
  420 + if not q:
  421 + raise ValueError("empty query")
  422 +
  423 + from config.services_config import get_embedding_image_backend_config
  424 +
  425 + backend, cfg = get_embedding_image_backend_config()
  426 +
  427 + if _looks_like_image_ref(q):
  428 + from embeddings.image_encoder import CLIPImageEncoder
  429 +
  430 + enc = CLIPImageEncoder()
  431 + vec = enc.encode_image_from_url(q, normalize_embeddings=True, priority=1)
  432 + return vec.astype(np.float32).flatten().tolist()
  433 +
  434 + if backend != "clip_as_service":
  435 + raise RuntimeError(
  436 + "mode 5 纯文本查询需要 CN-CLIP 文本向量(与 clip-as-service 同空间)。"
  437 + "当前 image_backend 为 local_cnclip,本脚本不加载本地模型。"
  438 + "请将 config 中 services.embedding.image_backend 设为 clip_as_service 并启动 grpc "
  439 + "(默认 51000),或输入图片 URL/路径(将调用 POST /embed/image 到 6008)。"
  440 + )
  441 +
  442 + from embeddings.clip_as_service_encoder import _ensure_clip_client_path
  443 +
  444 + _ensure_clip_client_path()
  445 + from clip_client import Client
  446 +
  447 + server = str(cfg.get("server") or "grpc://127.0.0.1:51000").strip()
  448 + normalize = bool(cfg.get("normalize_embeddings", True))
  449 + client = Client(server)
  450 + arr = client.encode([q], batch_size=1, show_progress=False)
  451 + vec = np.asarray(arr[0], dtype=np.float32).flatten()
  452 + if normalize:
  453 + n = float(np.linalg.norm(vec))
  454 + if np.isfinite(n) and n > 0.0:
  455 + vec = vec / n
  456 + return vec.tolist()
  457 +
  458 +
def search_title_knn(es: Any, index_name: str, query: str, size: int) -> List[Dict[str, Any]]:
    """KNN search on ``title_embedding`` using a text embedding of *query*.

    The result size is passed to Elasticsearch exactly once (by ``_run_es``
    as the ``size`` argument). It is intentionally not repeated inside the
    request body, because newer Elasticsearch Python clients reject a value
    supplied both in ``body`` and as a keyword argument.

    Raises:
        RuntimeError: if the embedding service returns no vector.
    """
    from embeddings.text_encoder import TextEmbeddingEncoder

    enc = TextEmbeddingEncoder()
    arr = enc.encode(query, normalize_embeddings=True)
    vec = arr[0]  # assumes encode() returns one entry per input - TODO confirm
    if vec is None:
        raise RuntimeError("text embedding service returned no vector")
    qv = vec.astype("float32").flatten().tolist()
    # Oversample candidates for recall; Elasticsearch caps num_candidates at 10000.
    num_cand = min(max(size * 10, 100), 10000)
    body: Dict[str, Any] = {
        "knn": {
            "field": "title_embedding",
            "query_vector": qv,
            "k": size,
            "num_candidates": num_cand,
        },
        "_source": _source_includes(),
    }
    return _run_es(es, index_name, body, size)
  480 +
  481 +
def search_image_knn(es: Any, index_name: str, query: str, size: int) -> List[Dict[str, Any]]:
    """KNN search on ``image_embedding.vector`` using a CLIP vector of *query*.

    The result size is passed to Elasticsearch exactly once (by ``_run_es``
    as the ``size`` argument). It is intentionally not repeated inside the
    request body, because newer Elasticsearch Python clients reject a value
    supplied both in ``body`` and as a keyword argument.
    """
    qv = _encode_clip_query_vector(query)
    # Oversample candidates for recall; Elasticsearch caps num_candidates at 10000.
    num_cand = min(max(size * 10, 100), 10000)
    body: Dict[str, Any] = {
        "knn": {
            "field": "image_embedding.vector",
            "query_vector": qv,
            "k": size,
            "num_candidates": num_cand,
        },
        "_source": _source_includes(),
    }
    return _run_es(es, index_name, body, size)
  497 +
  498 +
def main() -> None:
    """Run the interactive debug-search loop.

    Each round asks for a query, a search mode, the columns to show and the
    number of hits, then prints the results as a table. End-of-input
    (Ctrl+D) at any prompt ends the loop cleanly; search errors are reported
    on stderr and the loop continues.
    """
    import os

    parser = argparse.ArgumentParser(description="Interactive ES debug search")
    parser.add_argument(
        "--tenant-id",
        default=None,
        help="Tenant id for index name search_products_tenant_{id} (default: env TENANT_ID or 170)",
    )
    parser.add_argument(
        "--index",
        default=None,
        help="Override full index name (skips tenant-based naming)",
    )
    args = parser.parse_args()

    tenant = args.tenant_id or os.environ.get("TENANT_ID") or "170"

    # Project imports are deferred so that --help works without the project environment.
    from indexer.mapping_generator import get_tenant_index_name
    from utils.es_client import get_es_client_from_env

    index_name = args.index or get_tenant_index_name(str(tenant))
    es = get_es_client_from_env().client

    # Mode number -> search function taking (es, index_name, query, size).
    dispatch: Dict[int, Callable[..., List[Dict[str, Any]]]] = {
        1: lambda e, idx, q, s: _run_es(e, idx, _build_body_title(q), s),
        2: lambda e, idx, q, s: _run_es(e, idx, _build_body_qanchors(q), s),
        3: lambda e, idx, q, s: _run_es(e, idx, _build_body_tags(q), s),
        4: search_title_knn,
        5: search_image_knn,
    }

    print(f"索引: {index_name} (Ctrl+D 退出)\n")
    while True:
        # All prompts share one EOF handler: the plain-input fallbacks of the
        # mode/column selectors can hit end-of-input just like the query prompt.
        try:
            query = input("query> ").strip()
            if not query:
                continue
            mode = _select_mode()
            cols = _ordered_columns(_select_fields())
            raw_size = input("条数 [20]: ").strip() or "20"
        except EOFError:
            print()
            break

        try:
            size = max(1, int(raw_size))
        except ValueError:
            size = 20

        fn = dispatch.get(mode, dispatch[1])
        # Re-read the width every round so a resized terminal is honoured.
        term_w = shutil.get_terminal_size((100, 24)).columns
        print(f"--- mode={mode} ({OPTIONS[mode - 1][0]}) columns={','.join(cols)} size={size} ---")
        try:
            hits = fn(es, index_name, query, size)
            if not hits:
                print("(无命中)")
            else:
                _print_table(hits, cols, term_width=term_w)
        except Exception as e:
            # Top-level boundary of an interactive tool: report and keep going.
            print(f"错误: {e}", file=sys.stderr)
  567 +
  568 +if __name__ == "__main__":
  569 + main()
... ...