#!/usr/bin/env python3
"""
Purge Redis caches used by this repo (exclude trans:deepl*).

Default behavior (db=0):
- Delete embedding cache keys: {embedding_cache_prefix}:*
  (includes :image: and :clip_text: namespaces)
- Delete legacy embedding keys: embed:*  (older deployments wrote raw logical keys)
- Delete anchors cache keys: {anchor_cache_prefix}:*
- Delete translation cache keys: trans:* EXCEPT those starting with "trans:deepl"

Usage:
  source activate.sh
  python scripts/redis/purge_caches.py --dry-run
  python scripts/redis/purge_caches.py
  python scripts/redis/purge_caches.py --db 1
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Iterable, List, Optional

import redis

PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from config.env_config import REDIS_CONFIG  # type: ignore


def get_redis_client(db: int) -> redis.Redis:
    """Create a Redis client for database *db* using settings from REDIS_CONFIG."""
    # NOTE(review): the port fallback 6479 is non-standard (Redis default is
    # 6379) — presumably deliberate for this deployment; confirm before changing.
    host = REDIS_CONFIG.get("host", "localhost")
    port = REDIS_CONFIG.get("port", 6479)
    password = REDIS_CONFIG.get("password")
    return redis.Redis(
        host=host,
        port=port,
        password=password,
        db=db,
        decode_responses=True,  # return str keys, not bytes
        socket_timeout=10,
        socket_connect_timeout=10,
    )


def iter_scan_keys(client: redis.Redis, pattern: str, scan_count: int = 2000) -> Iterable[str]:
    """Yield every key matching *pattern* via non-blocking SCAN iteration.

    Walks the SCAN cursor until the server signals completion (cursor 0),
    streaming keys as they arrive so large keyspaces never materialize at once.
    """
    cursor = 0
    while True:
        cursor, keys = client.scan(cursor=cursor, match=pattern, count=scan_count)
        yield from keys
        if cursor == 0:
            return


def delete_keys(
    *,
    client: redis.Redis,
    keys: Iterable[str],
    dry_run: bool,
    batch_size: int = 1000,
) -> int:
    """Delete *keys* in pipelined batches and return how many were removed.

    In dry-run mode nothing is sent to the server; every buffered key is
    counted as if it had been deleted. Otherwise only keys for which DEL
    reported an actual removal (reply > 0) are counted.
    """
    total = 0
    pending: List[str] = []

    def _drain() -> int:
        """Flush the pending buffer, returning the (would-be) deletion count."""
        nonlocal pending
        if not pending:
            return 0
        if dry_run:
            count = len(pending)
        else:
            pipe = client.pipeline(transaction=False)
            for key in pending:
                pipe.delete(key)
            # DEL replies with the number of keys removed; count real deletions.
            count = int(sum(1 for reply in pipe.execute() if isinstance(reply, int) and reply > 0))
        pending = []
        return count

    for key in keys:
        pending.append(key)
        if len(pending) >= batch_size:
            total += _drain()
    total += _drain()
    return total


def build_tasks(
    *,
    embedding_prefix: str,
    anchor_prefix: str,
    keep_translation_prefix: str,
    include_translation: bool,
) -> List[dict]:
    """Assemble the ordered list of purge-task descriptors.

    Each task is a dict with keys ``name``, ``pattern`` and ``exclude_prefix``.
    The translation task is appended only when *include_translation* is set,
    and carries *keep_translation_prefix* so matching keys are preserved.
    """
    def _task(name: str, pattern: str, exclude: Optional[str] = None) -> dict:
        return {"name": name, "pattern": pattern, "exclude_prefix": exclude}

    tasks = [
        _task("embedding", f"{embedding_prefix}:*"),
        _task("embedding_legacy_embed", "embed:*"),
        _task("anchors", f"{anchor_prefix}:*"),
    ]
    if include_translation:
        tasks.append(_task("translation", "trans:*", keep_translation_prefix))
    return tasks


def main() -> None:
    """Parse CLI args, connect to Redis, and purge each configured cache pattern.

    Exits by raising on connection failure (client.ping()); otherwise prints a
    per-task summary and a grand total. With --dry-run nothing is deleted.
    """
    parser = argparse.ArgumentParser(description="Purge Redis caches (skip trans:deepl*)")
    parser.add_argument("--db", type=int, default=0, help="Redis database number (default: 0)")
    parser.add_argument("--dry-run", action="store_true", help="Only count keys; do not delete")
    # NOTE: --include-translation is effectively a no-op (the default is already
    # True); it is kept as the documented positive counterpart of --no-translation.
    parser.add_argument(
        "--include-translation",
        action="store_true",
        default=True,
        help="Also purge translation cache (default: true)",
    )
    parser.add_argument(
        "--no-translation",
        dest="include_translation",
        action="store_false",
        help="Do not purge translation cache",
    )
    parser.add_argument(
        "--keep-translation-prefix",
        type=str,
        default="trans:deepl",
        help='Do not delete translation keys starting with this prefix (default: "trans:deepl")',
    )
    parser.add_argument(
        "--embedding-prefix",
        type=str,
        default=str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")),
        help='Embedding cache prefix (default from REDIS_CONFIG["embedding_cache_prefix"])',
    )
    parser.add_argument(
        "--anchor-prefix",
        type=str,
        default=str(REDIS_CONFIG.get("anchor_cache_prefix", "product_anchors")),
        help='Anchors cache prefix (default from REDIS_CONFIG["anchor_cache_prefix"])',
    )
    parser.add_argument("--batch-size", type=int, default=1000, help="DEL pipeline batch size")
    args = parser.parse_args()

    client = get_redis_client(db=args.db)
    client.ping()  # fail fast on bad connection/credentials

    tasks = build_tasks(
        embedding_prefix=args.embedding_prefix,
        anchor_prefix=args.anchor_prefix,
        keep_translation_prefix=args.keep_translation_prefix,
        include_translation=args.include_translation,
    )

    def _filtered_keys(pattern: str, exclude_prefix: Optional[str]) -> Iterable[str]:
        """Stream keys matching *pattern*, skipping those under *exclude_prefix*.

        Pattern and exclusion are passed explicitly (not captured from the loop
        below) so the generator cannot be bitten by late-binding closure capture
        if it is ever consumed after the loop variable has moved on.
        """
        for key in iter_scan_keys(client, pattern=pattern):
            if exclude_prefix and key.startswith(exclude_prefix):
                continue
            yield key

    total_matched = 0
    total_deleted = 0

    for t in tasks:
        pattern: str = t["pattern"]
        exclude_prefix: Optional[str] = t["exclude_prefix"]

        n = delete_keys(
            client=client,
            keys=_filtered_keys(pattern, exclude_prefix),
            dry_run=args.dry_run,
            batch_size=args.batch_size,
        )
        total_matched += n
        if not args.dry_run:
            total_deleted += n

        action = "would delete" if args.dry_run else "deleted"
        print(f"[{t['name']}] pattern={pattern} exclude_prefix={exclude_prefix!r} -> {action} {n:,} keys")

    if args.dry_run:
        print(f"\nDry run complete. Total keys that would be deleted: {total_matched:,}")
    else:
        print(f"\nPurge complete. Total keys deleted: {total_deleted:,}")


if __name__ == "__main__":
    main()