# image_encoder.py (6.04 KB)
"""
Image embedding encoder backed by a network service.

Generates embeddings via an HTTP API service (by default at http://localhost:5001).
"""

import sys
import os
import requests
import numpy as np
from PIL import Image
import logging
import threading
from typing import List, Optional, Union, Dict, Any

logger = logging.getLogger(__name__)


class CLIPImageEncoder:
    """
    Image Encoder for generating image embeddings using network service.

    Singleton: the first construction fixes ``service_url``; every later
    ``CLIPImageEncoder(...)`` call returns that same instance (instance
    creation is guarded by a class-level lock).

    Embeddings are requested by POSTing a JSON list of ``{"id", "pic_url"}``
    objects to ``{service_url}/embedding/generate_image_embeddings``. Only
    image URLs are supported; PIL images are rejected with a warning.
    """

    _instance = None
    _lock = threading.Lock()

    # Per-request HTTP timeout in seconds. Override on the class (or the
    # singleton instance) to change it without altering the constructor.
    request_timeout: float = 60

    def __new__(cls, service_url='http://localhost:5001'):
        with cls._lock:
            if cls._instance is None:
                instance = super(CLIPImageEncoder, cls).__new__(cls)
                logger.info(f"Creating CLIPImageEncoder instance with service URL: {service_url}")
                instance.service_url = service_url
                # Strip a trailing slash so "http://host:5001/" does not yield "//embedding/...".
                instance.endpoint = f"{service_url.rstrip('/')}/embedding/generate_image_embeddings"
                # Publish the singleton only once it is fully initialised.
                cls._instance = instance
            elif service_url != cls._instance.service_url:
                # The singleton is already bound to another URL; the requested one is ignored.
                logger.warning(
                    f"CLIPImageEncoder already initialised with service URL "
                    f"{cls._instance.service_url}; ignoring requested URL {service_url}"
                )
        return cls._instance

    def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Call the embedding service API.

        Args:
            request_data: List of dictionaries with ``id`` and ``pic_url`` fields.

        Returns:
            The decoded JSON list; callers read the ``id``, ``embedding`` and
            ``error`` fields of its items.

        Raises:
            requests.exceptions.RequestException: connection failure, timeout
                or HTTP error status.
            ValueError: the body is not valid JSON, or is not a JSON list.
        """
        try:
            response = requests.post(
                self.endpoint,
                json=request_data,
                timeout=self.request_timeout
            )
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            # Callers log the full traceback; only record the failure here to
            # avoid emitting the same stack trace twice.
            logger.error(f"CLIPImageEncoder service request failed: {e}")
            raise

        if not isinstance(data, list):
            raise ValueError(
                f"Unexpected response payload type from embedding service: {type(data).__name__}"
            )
        return data

    @staticmethod
    def _to_embedding(item: Any) -> Optional[np.ndarray]:
        """Return ``item["embedding"]`` as a float32 array, or None if missing/empty/not a dict."""
        if isinstance(item, dict) and item.get("embedding"):
            return np.array(item["embedding"], dtype=np.float32)
        return None

    @staticmethod
    def _error_message(item: Any) -> str:
        """Describe why a response item carried no embedding (for logging)."""
        if item is None:
            return "No response"
        if isinstance(item, dict):
            return str(item.get("error", "Unknown error"))
        return "Unknown error"

    def encode_image(self, image: Image.Image) -> Optional[np.ndarray]:
        """
        Encode image to embedding vector using network service.

        Note: This method is kept for compatibility but the service only works
        with URLs, so it always logs a warning and returns None.
        """
        logger.warning("encode_image with PIL Image not supported by service, returning None")
        return None

    def encode_image_from_url(self, url: str) -> Optional[np.ndarray]:
        """
        Generate image embedding via network service using URL.

        Args:
            url: Image URL to process.

        Returns:
            float32 embedding vector, or None if the request failed or the
            service returned no embedding. Errors are logged, never raised.
        """
        try:
            response_data = self._call_service([{
                "id": "image_0",
                "pic_url": url
            }])

            if not response_data:
                logger.warning(f"No response for URL {url}")
                return None

            response_item = response_data[0]
            embedding = self._to_embedding(response_item)
            if embedding is None:
                logger.warning(f"No embedding for URL {url}, error: {self._error_message(response_item)}")
            return embedding

        except Exception as e:
            # Deliberate best-effort: any failure maps to None for the caller.
            logger.error(f"Failed to process image from URL {url}: {str(e)}", exc_info=True)
            return None

    def encode_batch(
        self,
        images: List[Union[str, Image.Image]],
        batch_size: int = 8
    ) -> List[Optional[np.ndarray]]:
        """
        Encode a batch of images efficiently via network service.

        Args:
            images: List of image URLs or PIL Images. Only URLs (``str``) are
                sent to the service; every other entry yields None.
            batch_size: Maximum number of URLs per service request (>= 1).

        Returns:
            List aligned with ``images``: an embedding per entry, or None for
            entries that were unsupported or failed to encode.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        results: List[Optional[np.ndarray]] = [None] * len(images)

        # Collect URL inputs with their original positions; the service only supports URLs.
        url_images: List[str] = []
        url_indices: List[int] = []
        for i, img in enumerate(images):
            if isinstance(img, str):
                url_images.append(img)
                url_indices.append(i)
            elif isinstance(img, Image.Image):
                logger.warning(f"PIL Image at index {i} not supported by service, returning None")
            else:
                logger.warning(
                    f"Unsupported input type {type(img).__name__} at index {i}, returning None"
                )

        for start in range(0, len(url_images), batch_size):
            batch_urls = url_images[start:start + batch_size]
            batch_indices = url_indices[start:start + batch_size]

            # Ids are local to the batch and used to match response items back to URLs.
            request_data = [
                {"id": f"image_{j}", "pic_url": url}
                for j, url in enumerate(batch_urls)
            ]

            try:
                response_data = self._call_service(request_data)

                # Index the response once by id instead of rescanning it per URL.
                # setdefault keeps the first item for a duplicated id, like a linear scan would.
                items_by_id: Dict[str, Any] = {}
                for item in response_data:
                    if isinstance(item, dict):
                        items_by_id.setdefault(str(item.get("id")), item)

                batch_results: List[Optional[np.ndarray]] = []
                for j, url in enumerate(batch_urls):
                    response_item = items_by_id.get(f"image_{j}")
                    embedding = self._to_embedding(response_item)
                    if embedding is None:
                        logger.warning(f"Failed to encode URL {url}: {self._error_message(response_item)}")
                    batch_results.append(embedding)

            except Exception as e:
                # Best-effort per batch: this batch's entries stay None, later batches still run.
                logger.error(f"Batch processing failed: {e}", exc_info=True)
                continue

            for index, embedding in zip(batch_indices, batch_results):
                results[index] = embedding

        return results