# cloud_text_encoder.py
"""
Text embedding encoder using Aliyun DashScope API.

Generates embeddings via Aliyun's text-embedding-v4 model.
"""

import logging
import os
import threading
import time
from typing import List, Optional, Union

import numpy as np
from openai import OpenAI

logger = logging.getLogger(__name__)


class CloudTextEncoder:
    """
    Singleton text encoder using Aliyun DashScope API.

    Thread-safe singleton pattern ensures only one instance exists.
    Uses text-embedding-v4 model for generating embeddings.
    """

    # text-embedding-v4 returns 1024-dimensional vectors.
    EMBEDDING_DIM = 1024

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """
        Create (on first call) or return the shared encoder instance.

        Args:
            api_key: DashScope API key; falls back to the DASHSCOPE_API_KEY
                environment variable. Ignored once the singleton exists.
            base_url: API endpoint; defaults to the Beijing-region
                compatible-mode endpoint. Ignored once the singleton exists.

        Raises:
            ValueError: If no API key is available on first construction.
        """
        # Fast path: once the singleton exists, skip the lock entirely
        # (double-checked locking).
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    instance = super().__new__(cls)

                    # Get API key from parameter or environment variable.
                    api_key = api_key or os.getenv("DASHSCOPE_API_KEY")
                    if not api_key:
                        raise ValueError("DASHSCOPE_API_KEY must be set in environment or passed as parameter")

                    # Use Beijing region by default.
                    base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1"

                    instance.client = OpenAI(
                        api_key=api_key,
                        base_url=base_url
                    )
                    instance.model = "text-embedding-v4"

                    # Publish only after full initialization. The previous
                    # version assigned cls._instance before validating the key
                    # or building the client, so any failure here left a
                    # half-built singleton that every later call returned.
                    cls._instance = instance
                    logger.info("Created CloudTextEncoder instance with base_url: %s", base_url)

        return cls._instance

    def encode(
        self,
        sentences: Union[str, List[str]],
        normalize_embeddings: bool = True,
        device: str = 'cpu',
        batch_size: int = 32
    ) -> np.ndarray:
        """
        Encode text into embeddings via Aliyun DashScope API.

        Args:
            sentences: Single string or list of strings to encode
            normalize_embeddings: Whether to normalize embeddings (handled by API)
            device: Device parameter (ignored, for compatibility)
            batch_size: Batch size for processing (currently processes all at once)

        Returns:
            numpy array of shape (n, EMBEDDING_DIM) containing embeddings;
            zero vectors on API failure.
        """
        # Convert single string to list.
        if isinstance(sentences, str):
            sentences = [sentences]

        # Empty input: return a correctly-shaped (0, dim) 2-D array so callers
        # can vstack/slice without special-casing (np.array([]) was 1-D).
        if not sentences:
            return np.zeros((0, self.EMBEDDING_DIM), dtype=np.float32)

        try:
            # Call DashScope API and time the round trip.
            start_time = time.time()
            completion = self.client.embeddings.create(
                model=self.model,
                input=sentences
            )
            elapsed_time = time.time() - start_time

            logger.info("Generated embeddings for %d texts in %.3fs", len(sentences), elapsed_time)

            # Extract embeddings from response, preserving input order.
            return np.array(
                [item.embedding for item in completion.data],
                dtype=np.float32
            )

        except Exception as e:
            # Best-effort fallback: log and return zero vectors instead of
            # raising, preserving the original contract for callers.
            logger.error("Failed to encode texts via DashScope API: %s", e, exc_info=True)
            return np.zeros((len(sentences), self.EMBEDDING_DIM), dtype=np.float32)

    def encode_batch(
        self,
        texts: List[str],
        batch_size: int = 32,
        device: str = 'cpu'
    ) -> np.ndarray:
        """
        Encode a batch of texts via Aliyun DashScope API.

        Args:
            texts: List of texts to encode
            batch_size: Number of texts sent per API request
            device: Device parameter (ignored, for compatibility)

        Returns:
            numpy array of shape (len(texts), EMBEDDING_DIM)
        """
        if not texts:
            # Same (0, dim) convention as encode().
            return np.zeros((0, self.EMBEDDING_DIM), dtype=np.float32)

        # Process in batches to avoid API limits.
        all_embeddings = []

        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            all_embeddings.append(self.encode(batch, device=device))

            # Small delay between requests to avoid rate limiting.
            if i + batch_size < len(texts):
                time.sleep(0.1)

        return np.vstack(all_embeddings)

    def get_embedding_dimension(self) -> int:
        """
        Get the dimension of embeddings produced by this encoder.

        Returns:
            Embedding dimension (1024 for text-embedding-v4)
        """
        return self.EMBEDDING_DIM