image_encoder.py 13 KB
"""Image embedding client for the local embedding HTTP service."""

import logging
from typing import Any, List, Optional, Union

import numpy as np
import requests
from PIL import Image

logger = logging.getLogger(__name__)

from config.loader import get_app_config
from config.services_config import get_embedding_image_backend_config, get_embedding_image_base_url
from embeddings.cache_keys import build_clip_text_cache_key, build_image_cache_key
from embeddings.config import CONFIG
from embeddings.redis_embedding_cache import RedisEmbeddingCache
from request_log_context import build_downstream_request_headers, build_request_log_extra


class CLIPImageEncoder:
    """
    Image Encoder for generating image embeddings using network service.

    Image embeddings are fetched from ``POST {service_url}/embed/image``; CN-CLIP
    text-tower embeddings (same vector space as the image embeddings) from
    ``POST {service_url}/embed/clip_text``. Successful results are stored in
    Redis-backed caches under separate ``image`` / ``clip_text`` namespaces.

    This client is stateless and safe to instantiate per caller.
    """

    # Default per-request HTTP timeout in seconds; can be overridden per instance.
    request_timeout: float = 60

    def __init__(self, service_url: Optional[str] = None, request_timeout: float = 60):
        """
        Args:
            service_url: Base URL of the embedding service. Defaults to
                ``get_embedding_image_base_url()``; a trailing ``/`` is removed.
            request_timeout: Timeout in seconds for every HTTP request made to the
                service (default 60, the previously hard-coded value).
        """
        resolved_url = service_url or get_embedding_image_base_url()
        redis_config = get_app_config().infrastructure.redis
        self.service_url = str(resolved_url).rstrip("/")
        self.endpoint = f"{self.service_url}/embed/image"
        self.clip_text_endpoint = f"{self.service_url}/embed/clip_text"
        self.request_timeout = request_timeout
        # Reuse embedding cache prefix, but separate namespace for images to avoid collisions.
        # A missing (None) prefix must fall back to "embedding" rather than becoming "None".
        raw_prefix = redis_config.embedding_cache_prefix
        self.cache_prefix = ("" if raw_prefix is None else str(raw_prefix)).strip() or "embedding"
        self._mm_model_name = CONFIG.MULTIMODAL_MODEL_NAME
        logger.info("Creating CLIPImageEncoder instance with service URL: %s", self.service_url)
        self.cache = RedisEmbeddingCache(
            key_prefix=self.cache_prefix,
            namespace="image",
        )
        self._clip_text_cache = RedisEmbeddingCache(
            key_prefix=self.cache_prefix,
            namespace="clip_text",
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _request_params(normalize_embeddings: bool, priority: int) -> dict:
        """Build the query parameters shared by both service endpoints."""
        return {
            # Sent as the lowercase strings "true"/"false".
            "normalize": "true" if normalize_embeddings else "false",
            # Negative priorities are clamped to 0.
            "priority": max(0, int(priority)),
        }

    @staticmethod
    def _log_request_failure(
        message: str,
        response: Optional[requests.Response],
        error: Exception,
        request_id: Optional[str],
        user_id: Optional[str],
    ) -> None:
        """
        Log a failed service request with its status code and a short body preview.

        Must be called from inside an ``except`` block so ``exc_info=True`` picks up
        the active exception.

        Args:
            message: %-style template taking (status, body preview, error).
            response: The HTTP response, or ``None`` if no response was received.
            error: The exception being handled.
            request_id: Request id forwarded into the log record's ``extra``.
            user_id: User id forwarded into the log record's ``extra``.
        """
        body_preview = ""
        if response is not None:
            try:
                body_preview = (response.text or "")[:300]
            except Exception:
                # Best effort only: an undecodable body must not mask the original error.
                body_preview = ""
        logger.error(
            message,
            getattr(response, "status_code", "n/a"),
            body_preview,
            error,
            exc_info=True,
            extra=build_request_log_extra(request_id=request_id, user_id=user_id),
        )

    @staticmethod
    def _as_embedding(raw: Any, invalid_message: str) -> np.ndarray:
        """
        Convert one raw embedding from the service into a float32 vector.

        Raises:
            RuntimeError: (with ``invalid_message``) if the result is not a
                non-empty 1-D vector of finite values.
        """
        vec = np.array(raw, dtype=np.float32)
        if vec.ndim != 1 or vec.size == 0 or not np.isfinite(vec).all():
            raise RuntimeError(invalid_message)
        return vec

    # ------------------------------------------------------------ service calls

    def _call_service(
        self,
        request_data: List[str],
        normalize_embeddings: bool = True,
        priority: int = 0,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> List[Any]:
        """
        Call the embedding service API.

        Args:
            request_data: List of image URLs / local file paths
            normalize_embeddings: Forwarded as the ``normalize`` query parameter.
            priority: Forwarded as the ``priority`` query parameter (clamped to >= 0).
            request_id: Propagated via downstream request headers and log context.
            user_id: Propagated via downstream request headers and log context.

        Returns:
            List of embeddings (list[float]) or nulls (None), aligned to input order

        Raises:
            requests.exceptions.RequestException: on transport errors or non-2xx
                responses (logged before re-raising).
        """
        response = None
        try:
            response = requests.post(
                self.endpoint,
                params=self._request_params(normalize_embeddings, priority),
                json=request_data,
                headers=build_downstream_request_headers(request_id=request_id, user_id=user_id),
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            self._log_request_failure(
                "CLIPImageEncoder service request failed | status=%s body=%s error=%s",
                response,
                e,
                request_id,
                user_id,
            )
            raise

    def _clip_text_via_grpc(
        self,
        request_data: List[str],
        normalize_embeddings: bool,
    ) -> List[Any]:
        """
        Fallback for an older image-embedding service (6008) that has no
        ``/embed/clip_text`` route: encode via clip-as-service gRPC instead.

        Requires ``services.embedding.image_backend: clip_as_service``.

        Raises:
            RuntimeError: if the configured image backend is not ``clip_as_service``.
        """
        backend, cfg = get_embedding_image_backend_config()
        if backend != "clip_as_service":
            raise RuntimeError(
                "POST /embed/clip_text 返回 404:请重启图片向量服务(6008)以加载新路由;"
                "或配置 services.embedding.image_backend=clip_as_service 并启动 grpc cnclip。"
            )
        # Imported lazily so the gRPC client is only required when this fallback is used.
        from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder

        enc = ClipAsServiceImageEncoder(
            server=str(cfg.get("server") or CONFIG.CLIP_AS_SERVICE_SERVER),
            batch_size=int(cfg.get("batch_size") or CONFIG.IMAGE_BATCH_SIZE),
        )
        arrs = enc.encode_clip_texts(
            request_data,
            # Send everything as one batch; never pass a zero batch size.
            batch_size=max(1, len(request_data)),
            normalize_embeddings=normalize_embeddings,
        )
        return [v.tolist() for v in arrs]

    def _call_clip_text_service(
        self,
        request_data: List[str],
        normalize_embeddings: bool = True,
        priority: int = 1,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> List[Any]:
        """
        Call ``POST /embed/clip_text``, falling back to gRPC when the route returns 404.

        Returns:
            List of embeddings (list[float]) or nulls (None), aligned to input order.

        Raises:
            requests.exceptions.RequestException: on transport errors or non-2xx,
                non-404 responses (logged before re-raising).
            RuntimeError: from the gRPC fallback when it is not configured.
        """
        response = None
        try:
            response = requests.post(
                self.clip_text_endpoint,
                params=self._request_params(normalize_embeddings, priority),
                json=request_data,
                headers=build_downstream_request_headers(request_id=request_id, user_id=user_id),
                timeout=self.request_timeout,
            )
            if response.status_code == 404:
                # A 404 means the running service predates the clip_text route.
                logger.warning(
                    "POST %s returned 404; using clip-as-service gRPC fallback (restart 6008 after deploy to use HTTP)",
                    self.clip_text_endpoint,
                )
                return self._clip_text_via_grpc(request_data, normalize_embeddings)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            self._log_request_failure(
                "CLIPImageEncoder clip_text request failed | status=%s body=%s error=%s",
                response,
                e,
                request_id,
                user_id,
            )
            raise

    # ------------------------------------------------------------- public API

    def encode_clip_text(
        self,
        text: str,
        normalize_embeddings: bool = True,
        priority: int = 1,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> np.ndarray:
        """
        Encode text with the CN-CLIP text tower (same vector space as ``/embed/image``),
        via the service's ``POST /embed/clip_text``.

        Leading/trailing whitespace is stripped before both the cache lookup and the
        request, so padded variants of the same text share one cache entry.

        Returns:
            A 1-D float32 embedding vector.

        Raises:
            RuntimeError: if the service returns no embedding or an invalid one.
        """
        query = text.strip()
        cache_key = build_clip_text_cache_key(
            query, normalize=normalize_embeddings, model_name=self._mm_model_name
        )
        cached = self._clip_text_cache.get(cache_key)
        if cached is not None:
            return cached

        response_data = self._call_clip_text_service(
            [query],
            normalize_embeddings=normalize_embeddings,
            priority=priority,
            request_id=request_id,
            user_id=user_id,
        )
        if not response_data or len(response_data) != 1 or response_data[0] is None:
            raise RuntimeError(f"No CLIP text embedding returned for: {text[:80]!r}")
        vec = self._as_embedding(response_data[0], "Invalid CLIP text embedding returned")
        self._clip_text_cache.set(cache_key, vec)
        return vec

    def encode_image(self, image: Image.Image) -> np.ndarray:
        """
        Encode image to embedding vector using network service.

        Note: This method is kept for compatibility but the service only works with URLs.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("encode_image with PIL Image is not supported by embedding service")

    def encode_image_from_url(
        self,
        url: str,
        normalize_embeddings: bool = True,
        priority: int = 0,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> np.ndarray:
        """
        Generate image embedding via network service using URL.

        The URL is normalized (``str(url).strip()``) exactly as in ``encode_batch`` so
        single and batch calls share cache entries.

        Args:
            url: Image URL to process

        Returns:
            Embedding vector (1-D float32)

        Raises:
            RuntimeError: if the service returns no embedding or an invalid one.
        """
        url = str(url).strip()
        cache_key = build_image_cache_key(
            url, normalize=normalize_embeddings, model_name=self._mm_model_name
        )
        cached = self.cache.get(cache_key)
        if cached is not None:
            return cached

        response_data = self._call_service(
            [url],
            normalize_embeddings=normalize_embeddings,
            priority=priority,
            request_id=request_id,
            user_id=user_id,
        )
        if not response_data or len(response_data) != 1 or response_data[0] is None:
            raise RuntimeError(f"No image embedding returned for URL: {url}")
        vec = self._as_embedding(
            response_data[0], f"Invalid image embedding returned for URL: {url}"
        )
        self.cache.set(cache_key, vec)
        return vec

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8,
        normalize_embeddings: bool = True,
        priority: int = 0,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> List[np.ndarray]:
        """
        Encode a batch of images efficiently via network service.

        Cached embeddings are served from Redis; only cache misses are sent to the
        service, in chunks of ``batch_size``. Output order matches ``images``.

        Args:
            images: List of image URLs or PIL Images (PIL Images are rejected)
            batch_size: Batch size for processing (used for service requests); must be >= 1

        Returns:
            List of embeddings, aligned with ``images``

        Raises:
            NotImplementedError: if any element is a PIL Image.
            ValueError: if any element is not a non-blank string, or batch_size < 1.
            RuntimeError: if the service response is missing or malformed.
        """
        for i, img in enumerate(images):
            if isinstance(img, Image.Image):
                raise NotImplementedError(f"PIL Image at index {i} is not supported by service")
            if not isinstance(img, str) or not img.strip():
                raise ValueError(f"Invalid image URL/path at index {i}: {img!r}")
        # A negative step would make the request loop below empty and silently
        # return placeholder vectors; zero would fail inside range().
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

        results: List[np.ndarray] = []
        pending_positions: List[int] = []
        pending_urls: List[str] = []
        pending_keys: List[str] = []

        normalized_urls = [str(u).strip() for u in images]  # type: ignore[list-item]
        for pos, url in enumerate(normalized_urls):
            cache_key = build_image_cache_key(
                url, normalize=normalize_embeddings, model_name=self._mm_model_name
            )
            cached = self.cache.get(cache_key)
            if cached is not None:
                results.append(cached)
                continue
            results.append(np.array([], dtype=np.float32))  # placeholder, filled below
            pending_positions.append(pos)
            pending_urls.append(url)
            pending_keys.append(cache_key)

        for start in range(0, len(pending_urls), batch_size):
            batch_urls = pending_urls[start : start + batch_size]
            response_data = self._call_service(
                batch_urls,
                normalize_embeddings=normalize_embeddings,
                priority=priority,
                request_id=request_id,
                user_id=user_id,
            )
            if not response_data or len(response_data) != len(batch_urls):
                raise RuntimeError(
                    f"Image embedding response length mismatch: expected {len(batch_urls)}, "
                    f"got {0 if response_data is None else len(response_data)}"
                )
            for offset, (url, embedding) in enumerate(zip(batch_urls, response_data)):
                if embedding is None:
                    raise RuntimeError(f"No image embedding returned for URL: {url}")
                vec = self._as_embedding(
                    embedding, f"Invalid image embedding returned for URL: {url}"
                )
                self.cache.set(pending_keys[start + offset], vec)
                results[pending_positions[start + offset]] = vec

        return results

    def encode_image_urls(
        self,
        urls: List[str],
        batch_size: Optional[int] = None,
        normalize_embeddings: bool = True,
        priority: int = 0,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> List[np.ndarray]:
        """
        Same interface as ClipImageModel / ClipAsServiceImageEncoder; called by the
        indexer's document_transformer.

        Args:
            urls: List of image URLs
            batch_size: Batch size (falsy values such as None or 0 mean 8)

        Returns:
            List of vectors with the same length and order as ``urls``
        """
        return self.encode_batch(
            urls,
            batch_size=batch_size or 8,
            normalize_embeddings=normalize_embeddings,
            priority=priority,
            request_id=request_id,
            user_id=user_id,
        )