# image_encoder.py 5.12 KB
"""
Image embedding encoder using network service.

Generates embeddings via HTTP API service (default localhost:6005).
"""

import sys
import os
import requests
import numpy as np
from PIL import Image
import logging
import threading
from typing import List, Optional, Union, Dict, Any

logger = logging.getLogger(__name__)

from config.services_config import get_embedding_base_url


class CLIPImageEncoder:
    """
    Image encoder that obtains embeddings from an HTTP embedding service.

    The class is a process-wide singleton. The first construction resolves the
    service URL (explicit argument, then the ``EMBEDDING_SERVICE_URL``
    environment variable, then ``get_embedding_base_url()``); every later
    construction returns that same instance. Instance creation is guarded by
    a class-level lock.

    Attributes:
        service_url: Base URL exactly as resolved at first construction.
        endpoint: URL that embedding requests are POSTed to
            (``<service_url>/embed/image``).
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, service_url: Optional[str] = None):
        """
        Return the shared encoder, creating it on first use.

        Args:
            service_url: Base URL of the embedding service. Only honoured by
                the call that creates the singleton; a different value passed
                later is ignored and a warning is logged.
        """
        with cls._lock:
            if cls._instance is None:
                instance = super(CLIPImageEncoder, cls).__new__(cls)
                resolved_url = service_url or os.getenv("EMBEDDING_SERVICE_URL") or get_embedding_base_url()
                logger.info(f"Creating CLIPImageEncoder instance with service URL: {resolved_url}")
                instance.service_url = resolved_url
                # Drop trailing slashes so "http://host/" does not become
                # "http://host//embed/image".
                base_url = str(resolved_url).rstrip("/")
                instance.endpoint = f"{base_url}/embed/image"
                # Publish only after full initialisation: if URL resolution
                # raised above, the next call retries instead of returning a
                # half-built object without `endpoint`.
                cls._instance = instance
            elif service_url and service_url != cls._instance.service_url:
                logger.warning(
                    f"CLIPImageEncoder already initialized with service URL "
                    f"{cls._instance.service_url}; ignoring requested URL {service_url}"
                )
        return cls._instance

    def _call_service(self, request_data: List[str]) -> List[Any]:
        """
        Call the embedding service API.

        Args:
            request_data: List of image URLs / local file paths

        Returns:
            List of embeddings (list[float]) or nulls (None), aligned to input order

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts or non-2xx responses.
            ValueError: If the response body is not valid JSON or is not a
                JSON array.
        """
        try:
            response = requests.post(
                self.endpoint,
                json=request_data,
                timeout=60
            )
            response.raise_for_status()
            payload = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # ValueError covers JSON decode failures on requests versions
            # where the decode error is not a RequestException subclass.
            logger.error(f"CLIPImageEncoder service request failed: {e}", exc_info=True)
            raise

        # Callers index the result positionally, so anything but a list
        # (e.g. an error object) is a protocol violation.
        if not isinstance(payload, list):
            raise ValueError(
                f"CLIPImageEncoder service returned {type(payload).__name__}, expected a list"
            )
        return payload

    def encode_image(self, image: Image.Image) -> Optional[np.ndarray]:
        """
        Encode image to embedding vector using network service.

        Note: This method is kept for compatibility but the service only works
        with URLs, so it always logs a warning and returns None.

        Args:
            image: PIL image (unused).

        Returns:
            Always None.
        """
        logger.warning("encode_image with PIL Image not supported by service, returning None")
        return None

    def encode_image_from_url(self, url: str) -> Optional[np.ndarray]:
        """
        Generate image embedding via network service using URL.

        Args:
            url: Image URL to process

        Returns:
            Embedding vector as a float32 array, or None if the request failed
            or the service returned no embedding for this URL.
        """
        try:
            response_data = self._call_service([url])
            if response_data and response_data[0] is not None:
                return np.array(response_data[0], dtype=np.float32)
            logger.warning(f"No embedding for URL {url}")
            return None

        except Exception as e:
            # Best-effort API: failures are logged and reported as None.
            logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True)
            return None

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8
    ) -> List[Optional[np.ndarray]]:
        """
        Encode a batch of images via network service.

        Only string entries (URLs / paths) are sent to the service; PIL images
        and any other types yield None at their position.

        Args:
            images: List of image URLs or PIL Images
            batch_size: Number of URLs sent per service request (must be >= 1)

        Returns:
            List of embeddings aligned with ``images`` (None for entries that
            are unsupported or failed to encode).

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        # A zero step makes range() raise an opaque error and a negative step
        # would silently skip every URL, so reject both up front.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        # Every slot starts as None; only successful encodings overwrite it.
        results: List[Optional[np.ndarray]] = [None] * len(images)

        # Collect the string inputs together with their original positions.
        url_images: List[str] = []
        url_indices: List[int] = []
        for i, img in enumerate(images):
            if isinstance(img, str):
                url_images.append(img)
                url_indices.append(i)
            elif isinstance(img, Image.Image):
                logger.warning(f"PIL Image at index {i} not supported by service, returning None")
            else:
                logger.warning(
                    f"Unsupported input type {type(img).__name__} at index {i}, returning None"
                )

        for start in range(0, len(url_images), batch_size):
            batch_urls = url_images[start:start + batch_size]
            batch_indices = url_indices[start:start + batch_size]

            try:
                response_data = self._call_service(batch_urls)

                if len(response_data) != len(batch_urls):
                    logger.warning(
                        f"Service returned {len(response_data)} results for "
                        f"{len(batch_urls)} URLs; missing entries are treated as failures"
                    )

                # Response is positionally aligned with the request.
                batch_results: List[Optional[np.ndarray]] = []
                for j, url in enumerate(batch_urls):
                    embedding = response_data[j] if j < len(response_data) else None
                    if embedding is not None:
                        batch_results.append(np.array(embedding, dtype=np.float32))
                    else:
                        logger.warning(f"Failed to encode URL {url}: no embedding")
                        batch_results.append(None)

            except Exception as e:
                # One failing batch must not abort the remaining ones; its
                # slots keep their initial None.
                logger.error(f"Batch processing failed: {e}", exc_info=True)
                continue

            # Write back only after the whole batch converted successfully.
            for idx, result in zip(batch_indices, batch_results):
                results[idx] = result

        return results