# test_cloud_embedding.py (5.51 KB)
"""
Test script for cloud text embedding using Aliyun DashScope API.

Reads queries from queries.txt and tests embedding generation,
logging send time, receive time, and duration for each request.
"""

import os
import sys
import time
from datetime import datetime
from pathlib import Path

# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from embeddings.cloud_text_encoder import CloudTextEncoder


def format_timestamp(ts: float) -> str:
    """Render a POSIX timestamp as local time with millisecond precision.

    The result looks like ``2024-01-31 13:45:07.123``; digits below the
    millisecond are truncated, not rounded.
    """
    moment = datetime.fromtimestamp(ts)
    millis = moment.microsecond // 1000
    return f"{moment.strftime('%Y-%m-%d %H:%M:%S')}.{millis:03d}"


def read_queries(file_path: str, limit: int = 100) -> list:
    """
    Read up to ``limit`` non-empty queries from a text file, one per line.

    Blank (or whitespace-only) lines are skipped and do NOT count towards
    ``limit``, so the result contains ``limit`` queries whenever the file
    has at least that many.

    Args:
        file_path: Path to queries file (``str`` or any path-like object)
        limit: Maximum number of queries to return; ``0`` or a negative
            value yields an empty list

    Returns:
        List of stripped query strings, in file order

    Raises:
        OSError: If the file cannot be opened or read
        UnicodeDecodeError: If the file is not valid UTF-8
    """
    queries = []
    # 'utf-8-sig' decodes plain UTF-8 unchanged but also drops a leading BOM,
    # which str.strip() would otherwise leave attached to the first query.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        for line in f:
            # Limit on collected queries, not on physical lines read.
            if len(queries) >= limit:
                break
            query = line.strip()
            if query:  # Skip empty lines
                queries.append(query)
    return queries


def _mask_secret(secret: str, visible: int = 4) -> str:
    """Return a display form of *secret* revealing at most its last *visible* chars."""
    # Keys too short to safely reveal a suffix are hidden entirely.
    if len(secret) <= visible * 2:
        return "****"
    return "****" + secret[-visible:]


def _shorten(text: str, width: int = 50) -> str:
    """Cut *text* to *width* characters, appending '...' only when it was cut."""
    return text[:width] + "..." if len(text) > width else text


def _print_summary(total: int, successes: int, failures: int,
                   elapsed: float, api_time: float) -> None:
    """Print aggregate statistics for a run of *total* requests.

    Every ratio is guarded so that an empty query list (``total == 0``) or a
    zero-length wall-clock measurement reports 0 instead of raising
    ``ZeroDivisionError``.
    """
    average = api_time / total if total else 0.0
    success_rate = successes / total * 100 if total else 0.0
    throughput = total / elapsed if elapsed > 0 else 0.0

    print("=" * 80)
    print("Test Summary")
    print("=" * 80)
    print(f"Total Queries:     {total}")
    print(f"Successful:        {successes}")
    print(f"Failed:            {failures}")
    print(f"Success Rate:      {success_rate:.1f}%")
    print(f"Total Time:        {elapsed:.3f}s")
    print(f"Total API Time:    {api_time:.3f}s")
    print(f"Average Duration:  {average:.3f}s per query")
    print(f"Throughput:        {throughput:.2f} queries/second")
    print("=" * 80)


def test_cloud_embedding(queries_file: str, num_queries: int = 100):
    """
    Test cloud embedding with queries from file.

    Each query is sent to ``CloudTextEncoder.encode`` one at a time. The
    function prints per-request send/receive timestamps and latency, then a
    summary. Setup problems are reported and the function returns early:
    a missing ``DASHSCOPE_API_KEY``, an unreadable queries file, or a failed
    encoder construction. Per-query failures are counted and the loop
    continues.

    Args:
        queries_file: Path to queries file
        num_queries: Number of queries to test
    """
    print("=" * 80)
    print("Cloud Text Embedding Test - Aliyun DashScope API")
    print("=" * 80)
    print()

    # Check if API key is set
    api_key = os.getenv("DASHSCOPE_API_KEY")
    if not api_key:
        print("ERROR: DASHSCOPE_API_KEY environment variable is not set!")
        print("Please set it using: export DASHSCOPE_API_KEY='your-api-key'")
        return

    # Never echo a meaningful portion of the credential to stdout/logs.
    print(f"API Key: {_mask_secret(api_key)}")
    print()

    # Read queries
    print(f"Reading queries from: {queries_file}")
    try:
        queries = read_queries(queries_file, limit=num_queries)
        print(f"Successfully read {len(queries)} queries")
        print()
    except Exception as e:
        print(f"ERROR: Failed to read queries file: {e}")
        return

    # Initialize encoder
    print("Initializing CloudTextEncoder...")
    try:
        encoder = CloudTextEncoder()
        print("CloudTextEncoder initialized successfully")
        print()
    except Exception as e:
        print(f"ERROR: Failed to initialize encoder: {e}")
        return

    total = len(queries)

    # Test embeddings
    print("=" * 80)
    print(f"Testing {total} queries (one by one)")
    print("=" * 80)
    print()

    total_start = time.time()
    success_count = 0
    failure_count = 0
    total_duration = 0.0

    for i, query in enumerate(queries, 1):
        # Taken before the try so the error path can always report them.
        send_time = time.time()
        send_time_str = format_timestamp(send_time)

        try:
            embedding = encoder.encode(query)
            receive_time = time.time()
            # A non-empty first dimension is treated as a valid embedding.
            is_valid = embedding.shape[0] > 0
        except Exception as e:  # one bad request must not abort the whole run
            receive_time = time.time()
            duration = receive_time - send_time
            # Failed requests still consumed API time. Counting them keeps
            # the per-query average consistent with its denominator.
            total_duration += duration
            failure_count += 1

            print(f"[{i:3d}/{total}] ✗ ERROR")
            print(f"  Query: {_shorten(query)}")
            print(f"  Send Time:    {send_time_str}")
            print(f"  Receive Time: {format_timestamp(receive_time)}")
            print(f"  Duration:     {duration:.3f}s")
            print(f"  Error: {str(e)}")
            print()
            continue

        duration = receive_time - send_time
        total_duration += duration

        if is_valid:
            success_count += 1
            status = "✓ SUCCESS"
        else:
            failure_count += 1
            status = "✗ FAILED"

        print(f"[{i:3d}/{total}] {status}")
        print(f"  Query: {_shorten(query)}")
        print(f"  Send Time:    {send_time_str}")
        print(f"  Receive Time: {format_timestamp(receive_time)}")
        print(f"  Duration:     {duration:.3f}s")
        print(f"  Embedding Shape: {embedding.shape}")
        print()

    _print_summary(total, success_count, failure_count,
                   time.time() - total_start, total_duration)


# This is a manually-run script, not a pytest test. Its name matches pytest's
# ``test_*`` pattern, but its parameters are not fixtures, so opt it out of
# collection to avoid a "fixture not found" error.
test_cloud_embedding.__test__ = False


def main():
    """Locate the bundled queries file and run the 100-query embedding test.

    The file is expected at ``<project root>/data_crawling/queries.txt``,
    where the project root is two directory levels above this script. If the
    file is absent, an error is printed and nothing else runs.
    """
    project_root = Path(__file__).parent.parent
    queries_path = project_root.joinpath("data_crawling", "queries.txt")

    if queries_path.exists():
        test_cloud_embedding(str(queries_path), num_queries=100)
    else:
        print(f"ERROR: Queries file not found: {queries_path}")


# Run the test only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()