#!/usr/bin/env python3
"""
Compare two Elasticsearch indices:
- mapping structure (field paths + types)
- field coverage stats (exists; nested-safe)
- random sample documents (same _id) and diff _source field paths

Usage:
  python scripts/inspect/compare_indices.py INDEX_A INDEX_B --sample-size 25
  python scripts/inspect/compare_indices.py INDEX_A INDEX_B --fields title.zh,vendor.zh,keywords.zh,tags.zh --fields-nested image_embedding.url,enriched_attributes.name
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from utils.es_client import ESClient, get_es_client_from_env


def _walk_mapping_properties(props: Dict[str, Any], prefix: str = "") -> Dict[str, str]:
    """Flatten mapping properties into {field_path: type} including multi-fields."""
    out: Dict[str, str] = {}
    for name, node in (props or {}).items():
        path = f"{prefix}.{name}" if prefix else name
        if not isinstance(node, dict):
            out[path] = "unknown"
            continue
        out[path] = node.get("type") or "object"
        if isinstance(node.get("properties"), dict):
            out.update(_walk_mapping_properties(node["properties"], path))
        if isinstance(node.get("fields"), dict):
            for sub, subnode in node["fields"].items():
                if isinstance(subnode, dict):
                    out[f"{path}.{sub}"] = subnode.get("type") or "object"
                else:
                    out[f"{path}.{sub}"] = "unknown"
    return out


def _get_top_level_field_type(mapping: Dict[str, Any], top_field: str) -> Optional[str]:
    props = mapping.get("mappings", {}).get("properties", {}) or {}
    node = props.get(top_field)
    if not isinstance(node, dict):
        return None
    return node.get("type") or "object"


def _field_paths_from_source(obj: Any, prefix: str = "", list_depth: int = 3) -> Set[str]:
    """Return dotted field paths found in _source. For lists, uses '[]' marker."""
    out: Set[str] = set()
    if isinstance(obj, dict):
        for k, v in obj.items():
            p = f"{prefix}.{k}" if prefix else k
            out.add(p)
            out |= _field_paths_from_source(v, p, list_depth=list_depth)
    elif isinstance(obj, list):
        # Do not explode: just traverse first N elements
        for v in obj[:list_depth]:
            p = f"{prefix}[]" if prefix else "[]"
            out |= _field_paths_from_source(v, p, list_depth=list_depth)
    return out


def _chunks(seq: List[str], size: int) -> Iterable[List[str]]:
    for i in range(0, len(seq), size):
        yield seq[i : i + size]


@dataclass(frozen=True)
class CoverageField:
    field: str
    # If set, use nested query with this path (e.g. "image_embedding").
    nested_path: Optional[str] = None


def _infer_coverage_fields(
    mapping: Dict[str, Any],
    raw_fields: List[str],
    raw_nested_fields: List[str],
) -> List[CoverageField]:
    """
    Build coverage fields list. For fields in raw_nested_fields, always treat as nested
    and infer nested path as first segment.
    For raw_fields, auto-detect nested by checking mapping top-level field type.
    """
    out: List[CoverageField] = []

    nested_set = {f.strip() for f in raw_nested_fields if f.strip()}
    for f in nested_set:
        path = f.split(".", 1)[0]
        out.append(CoverageField(field=f, nested_path=path))

    for f in [x.strip() for x in raw_fields if x.strip()]:
        if f in nested_set:
            continue
        top = f.split(".", 1)[0]
        top_type = _get_top_level_field_type(mapping, top)
        if top_type == "nested":
            out.append(CoverageField(field=f, nested_path=top))
        else:
            out.append(CoverageField(field=f, nested_path=None))

    # stable order (nested first then normal, but preserve user order otherwise)
    seen: Set[Tuple[str, Optional[str]]] = set()
    dedup: List[CoverageField] = []
    for cf in out:
        key = (cf.field, cf.nested_path)
        if key in seen:
            continue
        seen.add(key)
        dedup.append(cf)
    return dedup


def _count_exists(es, index: str, cf: CoverageField) -> int:
    """
    Count docs where field exists.
    - If nested_path is set, uses nested query (safe for nested fields).
    - If nested query fails because path isn't actually nested in that index,
      fall back to a non-nested exists query to avoid crashing the whole report.
    """
    if cf.nested_path:
        nested_body = {
            "query": {
                "nested": {
                    "path": cf.nested_path,
                    "query": {"exists": {"field": cf.field}},
                }
            }
        }
        try:
            return int(es.count(index, body=nested_body))
        except Exception as e:
            # Most common: "[nested] failed to find nested object under path [...]"
            print(f"[warn] nested exists failed for {index} field={cf.field} path={cf.nested_path}: {type(e).__name__}")
            # fall through to exists
    body = {"query": {"exists": {"field": cf.field}}}
    return int(es.count(index, body=body))


def _print_json(obj: Any) -> None:
    print(json.dumps(obj, ensure_ascii=False, indent=2, sort_keys=False))


def compare_mapping(index_a: str, index_b: str, mapping_a: Dict[str, Any], mapping_b: Dict[str, Any]) -> None:
    """Print a diff of the two flattened mappings: fields present on only one
    side, and shared fields whose declared types disagree. Each list is capped
    at 50 printed entries with a "... and N more" trailer.
    """
    flat_a = _walk_mapping_properties(mapping_a.get("mappings", {}).get("properties", {}) or {})
    flat_b = _walk_mapping_properties(mapping_b.get("mappings", {}).get("properties", {}) or {})

    keys_a = set(flat_a)
    keys_b = set(flat_b)
    only_a = sorted(keys_a - keys_b)
    only_b = sorted(keys_b - keys_a)
    type_diff = sorted(k for k in keys_a & keys_b if flat_a[k] != flat_b[k])

    print("\n" + "=" * 90)
    print("Mapping diff (flattened field paths + types)")
    print("=" * 90)
    print(f"index_a: {index_a}")
    print(f"index_b: {index_b}")
    print(f"only_in_a: {len(only_a)}")
    print(f"only_in_b: {len(only_b)}")
    print(f"type_diff: {len(type_diff)}")

    def section(header: str, entries: List[str], render) -> None:
        # Shared renderer: skip empty lists, cap output at 50 lines.
        if not entries:
            return
        print(header)
        for field in entries[:50]:
            print(render(field))
        if len(entries) > 50:
            print(f"  ... and {len(entries) - 50} more")

    section("\nFields only in index_a (first 50):", only_a, lambda f: f"  - {f} ({flat_a.get(f)})")
    section("\nFields only in index_b (first 50):", only_b, lambda f: f"  - {f} ({flat_b.get(f)})")
    section(
        "\nFields with different types (first 50):",
        type_diff,
        lambda f: f"  - {f}: a={flat_a.get(f)} b={flat_b.get(f)}",
    )


def compare_coverage(
    es,
    index_a: str,
    index_b: str,
    mapping_a: Dict[str, Any],
    mapping_b: Dict[str, Any],
    fields: List[str],
    nested_fields: List[str],
) -> None:
    """Print per-field document-coverage counts (exists query) for both indices.

    Field inference runs against each index's mapping; index_a's inference is
    used as the baseline, with a warning if the two inferences disagree.
    """
    inferred_a = _infer_coverage_fields(mapping_a, fields, nested_fields)
    inferred_b = _infer_coverage_fields(mapping_b, fields, nested_fields)

    # keep shared list, but warn if inference differs (it shouldn't)
    if [c.field for c in inferred_a] != [c.field for c in inferred_b]:
        print("\n[warn] coverage field list differs between indices; using index_a inference as baseline")

    print("\n" + "=" * 90)
    print("Field coverage stats (count of docs where field exists)")
    print("=" * 90)
    print(f"index_a: {index_a}")
    print(f"index_b: {index_b}")

    for cf in inferred_a:
        mode = f"nested(path={cf.nested_path})" if cf.nested_path else "exists"
        count_a = _count_exists(es, index_a, cf)
        count_b = _count_exists(es, index_b, cf)
        print(f"\n- {cf.field}  [{mode}]")
        print(f"  {index_a}: {count_a}")
        print(f"  {index_b}: {count_b}")


def compare_random_samples(
    es,
    index_a: str,
    index_b: str,
    sample_size: int,
    random_seed: Optional[int],
) -> None:
    """Sample random doc _ids from index_a, fetch the same ids from both
    indices via mget, and report (as printed JSON) which ids are missing on
    either side plus which _source field paths appear on only one side among
    the matched docs.
    """
    print("\n" + "=" * 90)
    print("Random sample diff (same _id; diff _source field paths)")
    print("=" * 90)
    print(f"sample_size: {sample_size}")

    # Optional deterministic sampling via function_score random_score seed.
    # NOTE(review): some ES versions require a `field` (e.g. "_seq_no")
    # alongside `seed` — confirm against the cluster version in use.
    random_score: Dict[str, Any] = {}
    if random_seed is not None:
        random_score["seed"] = random_seed

    sample_body = {
        "size": sample_size,
        "_source": False,  # ids only; full _source is fetched via mget below
        "query": {"function_score": {"query": {"match_all": {}}, "random_score": random_score}},
    }

    # Use the underlying client directly to avoid passing duplicate `size`
    # parameters through the wrapper.
    resp = es.client.search(index=index_a, body=sample_body)
    hits = (((resp or {}).get("hits") or {}).get("hits") or [])
    ids = [h.get("_id") for h in hits if h.get("_id") is not None]

    if not ids:
        print("No hits returned; cannot sample.")
        return

    # mget in chunks
    def mget(index: str, ids_: List[str]) -> Dict[str, Dict[str, Any]]:
        # Returns {_id: _source} for docs that were found and have a dict _source.
        out: Dict[str, Dict[str, Any]] = {}
        for batch in _chunks(ids_, 500):
            docs = es.client.mget(index=index, body={"ids": batch}).get("docs") or []
            for d in docs:
                if d.get("found") and d.get("_id") and isinstance(d.get("_source"), dict):
                    out[d["_id"]] = d["_source"]
        return out

    a_docs = mget(index_a, ids)
    b_docs = mget(index_b, ids)

    # Sampled ids fetched successfully from one index but not the other.
    missing_in_b = [i for i in ids if i in a_docs and i not in b_docs]
    missing_in_a = [i for i in ids if i in b_docs and i not in a_docs]

    # Field paths observed in only one side's _source, unioned over matched docs.
    only_in_a: Set[str] = set()
    only_in_b: Set[str] = set()

    matched = 0
    for _id in ids:
        if _id in a_docs and _id in b_docs:
            matched += 1
            pa = _field_paths_from_source(a_docs[_id])
            pb = _field_paths_from_source(b_docs[_id])
            only_in_a |= (pa - pb)
            only_in_b |= (pb - pa)

    summary = {
        "sample_size": len(ids),
        "matched": matched,
        "missing_in_index_b_count": len(missing_in_b),
        "missing_in_index_a_count": len(missing_in_a),
        "missing_in_index_b_example": missing_in_b[:5],
        "missing_in_index_a_example": missing_in_a[:5],
        "fields_only_in_index_a_count": len(only_in_a),
        "fields_only_in_index_b_count": len(only_in_b),
        "fields_only_in_index_a_first80": sorted(list(only_in_a))[:80],
        "fields_only_in_index_b_first80": sorted(list(only_in_b))[:80],
    }
    _print_json(summary)


def main() -> int:
    """CLI entry point: parse args, connect to ES, validate both indices, and
    run the mapping / coverage / random-sample comparisons.

    Returns a shell exit code: 0 on success, 2 on connection or lookup failure.
    """
    parser = argparse.ArgumentParser(description="Compare two ES indices (mapping + data coverage + random sample).")
    parser.add_argument("index_a", help="Index A name")
    parser.add_argument("index_b", help="Index B name")
    parser.add_argument("--sample-size", type=int, default=25, help="Random sample size (default: 25)")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for random_score (optional)")
    parser.add_argument(
        "--es-url",
        default=None,
        help="Elasticsearch URL. If omitted, uses env ES (preferred) or config/config.yaml.",
    )
    parser.add_argument(
        "--es-auth",
        default=None,
        help="Basic auth in 'user:pass' form. If omitted, uses env ES_AUTH or config credentials.",
    )
    parser.add_argument(
        "--fields",
        default="title.zh,vendor.zh,keywords.zh,tags.zh,keywords.en,tags.en,enriched_taxonomy_attributes,image_embedding.url,enriched_attributes.name",
        help="Comma-separated fields to compute coverage for (default: a sensible set)",
    )
    parser.add_argument(
        "--fields-nested",
        default="image_embedding.url,enriched_attributes.name",
        help="Comma-separated fields that must be treated as nested exists (default: image_embedding.url,enriched_attributes.name)",
    )
    args = parser.parse_args()

    # Prefer doc-style env vars (ES/ES_AUTH) to match the ops workflow documented
    # in docs/常用查询 - ES.md; fall back to config/config.yaml for repo-local tooling.
    # (Was `__import__("os").environ` — use the ordinary top-level import instead.)
    env = os.environ
    es_url = args.es_url or (env.get("ES") or env.get("ES_HOST") or None)
    es_auth = args.es_auth or env.get("ES_AUTH")
    # Doc convention: if ES is unset but auth is present, default to localhost:9200.
    if not es_url and es_auth:
        es_url = "http://127.0.0.1:9200"

    if es_url:
        username = password = None
        if es_auth and ":" in es_auth:
            # Split only on the first ':' so passwords may contain colons.
            username, password = es_auth.split(":", 1)
        es = ESClient(hosts=[es_url], username=username, password=password)
    else:
        es = get_es_client_from_env()

    if not es.ping():
        print("✗ Cannot connect to Elasticsearch")
        return 2

    # Fail fast on missing indices before running any report.
    if not es.index_exists(args.index_a):
        print(f"✗ index not found: {args.index_a}")
        return 2
    if not es.index_exists(args.index_b):
        print(f"✗ index not found: {args.index_b}")
        return 2

    mapping_all_a = es.get_mapping(args.index_a) or {}
    mapping_all_b = es.get_mapping(args.index_b) or {}
    if args.index_a not in mapping_all_a or args.index_b not in mapping_all_b:
        print("✗ Failed to fetch mappings for both indices")
        return 2

    mapping_a = mapping_all_a[args.index_a]
    mapping_b = mapping_all_b[args.index_b]

    compare_mapping(args.index_a, args.index_b, mapping_a, mapping_b)

    fields = [x for x in (args.fields or "").split(",") if x.strip()]
    nested_fields = [x for x in (args.fields_nested or "").split(",") if x.strip()]
    compare_coverage(es, args.index_a, args.index_b, mapping_a, mapping_b, fields, nested_fields)

    compare_random_samples(es, args.index_a, args.index_b, args.sample_size, args.seed)

    return 0


if __name__ == "__main__":
    raise SystemExit(main())