#!/usr/bin/env python3
"""
临时脚本:从 ES 遍历指定租户的 image_url,批量调用图片 embedding 服务。
5 进程并发,每请求最多 8 条 URL。日志打印到标准输出。

用法:
  source activate.sh   # 会加载 .env,提供 ES_HOST / ES_USERNAME / ES_PASSWORD
  python scripts/temp_embed_tenant_image_urls.py

未 source 时脚本也会尝试加载项目根目录 .env。
"""

from __future__ import annotations

import json
import multiprocessing as mp
import os
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlencode

import requests
from elasticsearch import Elasticsearch
from elasticsearch.helpers import scan

# If activate.sh was not sourced, still load .env from the project root
# (same ES_HOST / ES_USERNAME / ES_PASSWORD variables).
try:
    from dotenv import load_dotenv

    _ROOT = Path(__file__).resolve().parents[1]
    load_dotenv(_ROOT / ".env")
except ImportError:
    pass

# ---------------------------------------------------------------------------
# Configuration (edit as needed; defaults match the ES_* values in .env, see config/loader.py)
# ---------------------------------------------------------------------------

# Elasticsearch (defaults read from env vars: ES_HOST, ES_USERNAME, ES_PASSWORD)
ES_HOST: str = os.getenv("ES_HOST", "http://localhost:9200")
ES_USERNAME: Optional[str] = os.getenv("ES_USERNAME") or None
ES_PASSWORD: Optional[str] = os.getenv("ES_PASSWORD") or None
ES_INDEX: str = "search_products_tenant_163"

# Tenant ID (a keyword field, so a string)
TENANT_ID: str = "163"

# Image embedding service (matches doc section 7.1.2)
EMBED_BASE_URL: str = "http://localhost:6008"
EMBED_PATH: str = "/embed/image"
EMBED_QUERY: Dict[str, Any] = {
    "normalize": "true",
    "priority": "1",  # 与对接文档 curl 一致;批量离线可改为 "0"
}
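# Assumed request/response shape (mirrors the integration doc's curl, section 7.1.2):
#   POST /embed/image?normalize=true&priority=1  with a JSON array of image URLs
#   as the body, returning a JSON array of embedding vectors, one per URL.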

# Concurrency and batching
WORKER_PROCESSES: int = 5
URLS_PER_REQUEST: int = 8

# HTTP
REQUEST_TIMEOUT_SEC: float = 120.0

# ES scan (elasticsearch-py 8+ / ES 9: `scan(..., query=...)` is expanded into
# `client.search(**kwargs)`, so keys must use Search API parameter names, e.g. a
# top-level `query` holding the DSL query clause -- not a bare `match_all`.)
SCROLL_CHUNK_SIZE: int = 500
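# For reference, a minimal sketch of the equivalent direct call (same names as above):
#   es.search(index=ES_INDEX, size=SCROLL_CHUNK_SIZE, scroll="5m",
#             query={"term": {"tenant_id": TENANT_ID}}, source_includes=["image_url"])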

# ---------------------------------------------------------------------------


@dataclass
class BatchResult:
    """Outcome of one embedding request for a batch of URLs."""

    batch_index: int
    url_count: int
    ok: bool
    status_code: Optional[int]
    elapsed_sec: float
    error: Optional[str] = None


def _build_embed_url() -> str:
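    # e.g. http://localhost:6008/embed/image?normalize=true&priority=1 with the defaults above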
    q = urlencode(EMBED_QUERY)
    return f"{EMBED_BASE_URL.rstrip('/')}{EMBED_PATH}?{q}"


def _process_batch(payload: Tuple[int, List[str]]) -> BatchResult:
    """POST one batch of image URLs to the embedding service and report the outcome."""
    batch_index, urls = payload
    if not urls:
        return BatchResult(batch_index, 0, True, None, 0.0, None)

    url = _build_embed_url()
    t0 = time.perf_counter()
    try:
        resp = requests.post(
            url,
            headers={"Content-Type": "application/json"},
            data=json.dumps(urls),
            timeout=REQUEST_TIMEOUT_SEC,
        )
        elapsed = time.perf_counter() - t0
        ok = resp.status_code == 200
        err: Optional[str] = None
        if ok:
            try:
                body = resp.json()
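                # The service is assumed to return a JSON array with exactly one
                # embedding vector per input URL; anything else counts as a failure.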
                if not isinstance(body, list) or len(body) != len(urls):
                    ok = False
                    err = f"response length mismatch or not list: got {type(body).__name__}"
            except Exception as e:
                ok = False
                err = f"json decode: {e}"
        else:
            err = resp.text[:500] if resp.text else f"HTTP {resp.status_code}"

        worker = mp.current_process().name
        # Note: `resp.status_code if resp else None` would be wrong here, because
        # requests.Response is falsy for 4xx/5xx and would mask the real status.
        status = resp.status_code
        ms = elapsed * 1000.0
        if ok:
            print(
                f"[embed] worker={worker} batch={batch_index} urls={len(urls)} "
                f"http={status} elapsed_ms={ms:.2f} ok",
                flush=True,
            )
        else:
            print(
                f"[embed] worker={worker} batch={batch_index} urls={len(urls)} "
                f"http={status} elapsed_ms={ms:.2f} FAIL err={err}",
                flush=True,
            )
        return BatchResult(batch_index, len(urls), ok, status, elapsed, err)
    except Exception as e:
        elapsed = time.perf_counter() - t0
        worker = mp.current_process().name
        print(
            f"[embed] worker={worker} batch={batch_index} urls={len(urls)} "
            f"http=None elapsed_ms={elapsed * 1000.0:.2f} FAIL err={e}",
            flush=True,
        )
        return BatchResult(batch_index, len(urls), False, None, elapsed, str(e))


def _iter_image_urls(es: Elasticsearch) -> List[str]:
    """Collect every non-empty image_url for the tenant via a scroll scan."""
    # Equivalent search body: { "query": { "term": { "tenant_id": "..." } } }
    search_kw: Dict[str, Any] = {
        "query": {"term": {"tenant_id": TENANT_ID}},
        "source_includes": ["image_url"],
    }
    urls: List[str] = []
    for hit in scan(
        es,
        query=search_kw,
        index=ES_INDEX,
        size=SCROLL_CHUNK_SIZE,
    ):
        src = hit.get("_source") or {}
        u = src.get("image_url")
        if u is None:
            continue
        s = str(u).strip()
        if not s:
            continue
        urls.append(s)
    return urls


def main() -> int:
    t_wall0 = time.perf_counter()

    auth = None
    if ES_USERNAME and ES_PASSWORD:
        auth = (ES_USERNAME, ES_PASSWORD)

    es = Elasticsearch([ES_HOST], basic_auth=auth)
    if not es.ping():
        print("ERROR: Elasticsearch ping failed", file=sys.stderr)
        return 1

    print(
        f"[main] ES={ES_HOST} basic_auth={'yes' if auth else 'no'} "
        f"index={ES_INDEX} tenant_id={TENANT_ID} "
        f"workers={WORKER_PROCESSES} urls_per_req={URLS_PER_REQUEST}",
        flush=True,
    )
    print(f"[main] embed_url={_build_embed_url()}", flush=True)

    t_fetch0 = time.perf_counter()
    all_urls = _iter_image_urls(es)
    fetch_elapsed = time.perf_counter() - t_fetch0
    print(
        f"[main] collected image_url count={len(all_urls)} es_scan_elapsed_sec={fetch_elapsed:.3f}",
        flush=True,
    )

    batches: List[List[str]] = []
    for i in range(0, len(all_urls), URLS_PER_REQUEST):
        batches.append(all_urls[i : i + URLS_PER_REQUEST])

    if not batches:
        print("[main] no URLs to process; done.", flush=True)
        return 0

    tasks = [(idx, batch) for idx, batch in enumerate(batches)]
    print(f"[main] batches={len(tasks)} (parallel processes={WORKER_PROCESSES})", flush=True)

    t_run0 = time.perf_counter()
    total_urls = 0
    success_urls = 0
    failed_urls = 0
    ok_batches = 0
    fail_batches = 0
    sum_req_sec = 0.0

    with mp.Pool(processes=WORKER_PROCESSES) as pool:
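        # imap_unordered yields results as soon as each batch finishes; completion
        # order does not matter here because we only aggregate counters.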
        for res in pool.imap_unordered(_process_batch, tasks, chunksize=1):
            total_urls += res.url_count
            sum_req_sec += res.elapsed_sec
            if res.ok:
                ok_batches += 1
                success_urls += res.url_count
            else:
                fail_batches += 1
                failed_urls += res.url_count

    wall_total = time.perf_counter() - t_wall0
    run_elapsed = time.perf_counter() - t_run0

    print("---------- summary ----------", flush=True)
    print(f"tenant_id:              {TENANT_ID}", flush=True)
    print(f"total documents w/ url: {len(all_urls)}", flush=True)
    print(f"total batches:          {len(batches)}", flush=True)
    print(f"batches succeeded:      {ok_batches}", flush=True)
    print(f"batches failed:         {fail_batches}", flush=True)
    print(f"urls (success path):    {success_urls}", flush=True)
    print(f"urls (failed path):     {failed_urls}", flush=True)
    print(f"ES scan elapsed (s):    {fetch_elapsed:.3f}", flush=True)
    print(f"embed phase wall (s):   {run_elapsed:.3f}", flush=True)
    print(f"sum request time (s):   {sum_req_sec:.3f}  (sequential sum, for reference)", flush=True)
    print(f"total wall time (s):    {wall_total:.3f}", flush=True)
    print("-----------------------------", flush=True)
    return 0 if fail_batches == 0 else 2


if __name__ == "__main__":
    raise SystemExit(main())