text_encoder.py 4.94 KB
"""
Text embedding encoder using network service.

Generates embeddings by calling an HTTP embedding service (default URL: http://localhost:5001).
"""

import sys
import requests
import time
import threading
import numpy as np
import logging
from typing import List, Union, Dict, Any

logger = logging.getLogger(__name__)


class BgeEncoder:
    """
    Singleton text encoder backed by an HTTP embedding service.

    Construction is serialized by a class-level lock and always returns the
    same instance. The ``service_url`` given to the *first* construction is
    used for the lifetime of the process; ``service_url`` values passed to
    later constructions are ignored.

    Encoding is best-effort. A text the service fails to embed comes back as
    an all-zero row instead of raising. This covers a failed request, a
    missing response item, or an item with no usable vector.
    """

    # Width of zero-filled fallback rows and of empty results. Assumed to
    # match the service's model output; this is not verified against it.
    EMBEDDING_DIM = 1024
    # Seconds to wait for each HTTP request to the service.
    REQUEST_TIMEOUT = 60
    # Response fields checked, in priority order, for an item's vector.
    _EMBEDDING_FIELDS = ("embedding_zh", "embedding_en", "embedding_ru")

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, service_url='http://localhost:5001'):
        with cls._lock:
            if cls._instance is None:
                logger.info("Creating BgeEncoder instance with service URL: %s", service_url)
                instance = super().__new__(cls)
                instance.service_url = service_url
                instance.endpoint = f"{service_url}/embedding/generate_embeddings"
                # Publish the singleton only after it is fully initialised.
                cls._instance = instance
        return cls._instance

    def _call_service(self, request_data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        POST one batch to the embedding endpoint and return the decoded JSON.

        Args:
            request_data: Items with "id", "name_zh", "name_en" and "name_ru" keys.

        Returns:
            The decoded JSON body. It is expected to be a list of dicts, each
            carrying "id" plus one or more of "embedding_zh", "embedding_en"
            and "embedding_ru".

        Raises:
            requests.exceptions.RequestException: on connection errors,
                timeouts, 4xx/5xx responses, or (in recent ``requests``
                versions) an undecodable JSON body.
        """
        try:
            response = requests.post(
                self.endpoint,
                json=request_data,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            # No traceback here: the exception is re-raised, and the caller
            # that finally handles it logs the traceback once.
            logger.error("BgeEncoder service request failed: %s", e)
            raise

    @classmethod
    def _pick_embedding(cls, item: Dict[str, Any]) -> Any:
        """Return the first non-None vector in ``item`` by field priority, else None."""
        for field in cls._EMBEDDING_FIELDS:
            value = item.get(field)
            if value is not None:
                return value
        return None

    def _encode_chunk(self, texts: List[str], offset: int) -> List[Any]:
        """Embed one batch; return one vector per text, zero-filled where unavailable."""
        # IDs are global indices, so they stay unique across batches.
        # Every text is sent as "name_zh"; the other language fields stay empty.
        request_data = [
            {"id": str(offset + i), "name_zh": text, "name_en": None, "name_ru": None}
            for i, text in enumerate(texts)
        ]
        try:
            response_data = self._call_service(request_data)
            # setdefault keeps the first item for a repeated id (first match wins).
            by_id: Dict[str, Dict[str, Any]] = {}
            for item in response_data:
                by_id.setdefault(str(item.get("id")), item)
        except Exception as e:
            # Best-effort: a failed batch degrades to zero rows instead of raising.
            logger.error("Failed to encode texts: %s", e, exc_info=True)
            return [[0.0] * self.EMBEDDING_DIM for _ in texts]

        rows = []
        for index, text in enumerate(texts, start=offset):
            item = by_id.get(str(index))
            embedding = None if item is None else self._pick_embedding(item)
            if item is None:
                logger.warning("No response found for text %d", index)
            elif embedding is None:
                logger.warning("No embedding found for text %d: %s...", index, str(text)[:50])
            rows.append([0.0] * self.EMBEDDING_DIM if embedding is None else embedding)
        return rows

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = 'cpu',
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Encode text into embeddings via the network service.

        Texts are sent in requests of at most ``batch_size`` items. A batch
        whose request fails contributes all-zero rows; other batches are
        unaffected.

        Args:
            sentences: A single string or an iterable of strings.
            normalize_embeddings: Accepted for API compatibility and ignored.
                Vectors are returned exactly as the service sends them.
            device: Accepted for API compatibility and ignored.
            batch_size: Maximum number of texts per service request. A value
                below 1 (or None) sends all texts in a single request.

        Returns:
            A float32 array with one row per input text. Rows the service
            could not provide are zeros of width ``EMBEDDING_DIM``. If the
            vectors cannot form a rectangular numeric array (for example, they
            have mixed widths), an all-zero ``(n, EMBEDDING_DIM)`` array is
            returned. Empty input yields shape ``(0, EMBEDDING_DIM)``.
        """
        # Materialise once: the texts are iterated more than once below.
        texts = [sentences] if isinstance(sentences, str) else list(sentences)
        if not texts:
            # Nothing to embed: skip the round trip and keep the result 2-D.
            return np.zeros((0, self.EMBEDDING_DIM), dtype=np.float32)

        step = batch_size if batch_size and batch_size > 0 else len(texts)
        rows: List[Any] = []
        for offset in range(0, len(texts), step):
            rows.extend(self._encode_chunk(texts[offset:offset + step], offset))

        try:
            return np.array(rows, dtype=np.float32)
        except (TypeError, ValueError) as e:
            # Ragged or non-numeric vectors cannot form a 2-D float array.
            logger.error("Failed to encode texts: %s", e, exc_info=True)
            return np.zeros((len(texts), self.EMBEDDING_DIM), dtype=np.float32)

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = 'cpu'
    ) -> np.ndarray:
        """
        Encode a list of texts; a thin wrapper around :meth:`encode`.

        Args:
            texts: Texts to encode.
            batch_size: Maximum number of texts per service request.
            device: Accepted for API compatibility and ignored.

        Returns:
            The same float32 array that :meth:`encode` returns for ``texts``.
        """
        return self.encode(texts, batch_size=batch_size, device=device)