# cache.py (4.79 KB)
"""
Cache utilities: a file-based store for embedding vectors (EmbeddingCache) and a JSON-backed key/value store (DictCache).
"""

import hashlib
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any, Optional

import numpy as np


class EmbeddingCache:
    """
    Simple file-based cache for embeddings.

    Each entry is stored as ``<md5(input)>.npy`` inside ``cache_dir``. MD5 is
    used only to derive a short, filesystem-safe file name, not for security.

    Writes go to a temporary file in the same directory, which is then renamed
    over the final path. Readers therefore never observe a half-written
    ``.npy`` file (the rename is atomic on POSIX filesystems).
    """

    def __init__(self, cache_dir: str = ".cache/embeddings"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, input_str: str) -> str:
        """Return the hex MD5 digest of ``input_str`` (UTF-8 encoded)."""
        return hashlib.md5(input_str.encode('utf-8')).hexdigest()

    def _get_cache_path(self, input_str: str) -> Path:
        """Return the ``.npy`` path that holds the entry for ``input_str``."""
        return self.cache_dir / f"{self._get_cache_key(input_str)}.npy"

    def get(self, input_str: str) -> Optional[np.ndarray]:
        """
        Get cached embedding.

        Args:
            input_str: Input text or URL

        Returns:
            Cached embedding, or None if it is not cached or cannot be read
            (a read failure is reported on stdout, never raised).
        """
        # EAFP: try the load directly. Checking exists() first would race
        # with a concurrent clear().
        try:
            return np.load(self._get_cache_path(input_str))
        except FileNotFoundError:
            return None
        except Exception as e:
            # Best-effort cache: a corrupt entry is treated as a miss.
            print(f"Failed to load cache for {input_str}: {e}")
            return None

    def set(self, input_str: str, embedding: np.ndarray) -> bool:
        """
        Store embedding in cache.

        Args:
            input_str: Input text or URL
            embedding: Embedding vector (must not be an object-dtype array)

        Returns:
            True if successful, False if the embedding could not be written.
        """
        cache_file = self._get_cache_path(input_str)
        tmp_path: Optional[Path] = None
        try:
            # The temp file is created in the cache dir so the final replace
            # stays on one filesystem. Its ".tmp" suffix keeps it out of the
            # "*.npy" globs used by clear() and size().
            with tempfile.NamedTemporaryFile(
                dir=self.cache_dir, suffix=".tmp", delete=False
            ) as fh:
                tmp_path = Path(fh.name)
                # get() loads with numpy's default allow_pickle=False, so
                # refuse to write pickled (object) arrays that could never be
                # read back.
                np.save(fh, embedding, allow_pickle=False)
            tmp_path.replace(cache_file)
            return True
        except Exception as e:
            print(f"Failed to cache embedding for {input_str}: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass  # Leftover temp file is harmless; don't mask the error.
            return False

    def exists(self, input_str: str) -> bool:
        """Return True if an entry file for ``input_str`` is present."""
        return self._get_cache_path(input_str).exists()

    def clear(self) -> int:
        """
        Clear all cached embeddings.

        Returns:
            Number of files deleted
        """
        count = 0
        for cache_file in self.cache_dir.glob("*.npy"):
            try:
                cache_file.unlink()
            except FileNotFoundError:
                continue  # Removed concurrently; nothing to delete.
            count += 1
        return count

    def size(self) -> int:
        """Get number of cached embeddings."""
        return sum(1 for _ in self.cache_dir.glob("*.npy"))


class DictCache:
    """
    Simple dictionary-based cache for query rewrite rules, translations, etc.

    Data is kept in memory as ``{category: {key: value}}`` and persisted as a
    JSON file after every mutation. Values must therefore be JSON-serializable.
    Saves are atomic: the JSON is written to a temporary file and then renamed
    over ``cache_file``. A failed save never leaves a truncated file behind.
    """

    def __init__(self, cache_file: str = ".cache/dict_cache.json"):
        self.cache_file = Path(cache_file)
        self.cache_file.parent.mkdir(parents=True, exist_ok=True)
        self.cache = self._load()

    def _load(self) -> dict:
        """Load cache from file. Returns {} if it is missing, unreadable, or not a JSON object."""
        try:
            with open(self.cache_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return {}
        except Exception as e:
            # Best-effort: a damaged cache file starts an empty cache.
            print(f"Failed to load cache: {e}")
            return {}
        if not isinstance(data, dict):
            # A list/str/number root would break get()/set(), which index by category.
            print(f"Failed to load cache: expected a JSON object, got {type(data).__name__}")
            return {}
        return data

    def _save(self) -> bool:
        """Atomically save cache to file. Returns True on success, False on failure."""
        tmp_path: Optional[Path] = None
        try:
            # Serialize fully before touching the disk. json.dump straight into
            # the target file would leave partial JSON there if a value turns
            # out not to be serializable.
            payload = json.dumps(self.cache, ensure_ascii=False, indent=2)
            with tempfile.NamedTemporaryFile(
                'w',
                encoding='utf-8',
                dir=self.cache_file.parent,
                prefix=self.cache_file.name + ".",
                suffix=".tmp",
                delete=False,
            ) as f:
                tmp_path = Path(f.name)
                f.write(payload)
            tmp_path.replace(self.cache_file)
            return True
        except Exception as e:
            print(f"Failed to save cache: {e}")
            if tmp_path is not None:
                try:
                    tmp_path.unlink()
                except OSError:
                    pass  # Leftover temp file is harmless; don't mask the error.
            return False

    def get(self, key: str, category: str = "default") -> Optional[Any]:
        """
        Get cached value.

        Args:
            key: Cache key
            category: Cache category (for organizing different types of data)

        Returns:
            Cached value or None
        """
        return self.cache.get(category, {}).get(key)

    def set(self, key: str, value: Any, category: str = "default") -> bool:
        """
        Store value in cache and persist it.

        Args:
            key: Cache key
            value: Value to cache (must be JSON-serializable)
            category: Cache category

        Returns:
            True if successful. On False the in-memory cache is left exactly
            as it was before the call.
        """
        created_category = category not in self.cache
        bucket = self.cache.setdefault(category, {})
        had_key = key in bucket
        previous = bucket.get(key)
        bucket[key] = value
        if self._save():
            return True
        # Persisting failed (e.g. the value is not JSON-serializable, or a disk
        # error). Undo the change so memory matches disk, and so one bad value
        # cannot make every later save fail too.
        if had_key:
            bucket[key] = previous
        else:
            del bucket[key]
        if created_category:
            del self.cache[category]
        return False

    def exists(self, key: str, category: str = "default") -> bool:
        """Check if key exists in cache."""
        return category in self.cache and key in self.cache[category]

    def clear(self, category: Optional[str] = None) -> bool:
        """
        Clear cache.

        Args:
            category: If specified (including ""), clear only this category.
                If None, clear everything.

        Returns:
            True if successful
        """
        # Compare against None explicitly: clear("") must not wipe all categories.
        if category is not None:
            self.cache.pop(category, None)
        else:
            self.cache = {}
        return self._save()