# search_tools.py
"""
Search Tools for Product Discovery
Provides text-based search via Search API and VLM style analysis
"""

import base64
import logging
import mimetypes
from pathlib import Path
from typing import Optional

import requests
from langchain_core.tools import tool
from openai import OpenAI

from app.config import settings

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Cached OpenAI client; stays None until get_openai_client() first builds it.
_openai_client: Optional[OpenAI] = None


def get_openai_client() -> OpenAI:
    """Return the shared OpenAI client, building it on the first call.

    The client is constructed with ``settings.openai_api_key`` and, when
    ``settings.openai_api_base_url`` is truthy, that value as ``base_url``.
    The instance is stored in the module-level ``_openai_client`` and handed
    back unchanged on later calls.
    """
    global _openai_client

    # Fast path: a client was already built by an earlier call.
    if _openai_client is not None:
        return _openai_client

    client_options = {"api_key": settings.openai_api_key}
    if settings.openai_api_base_url:
        client_options["base_url"] = settings.openai_api_base_url

    _openai_client = OpenAI(**client_options)
    return _openai_client


# Bounds applied to the caller-supplied ``limit`` before it is sent as ``size``.
_MIN_SEARCH_RESULTS = 1
_MAX_SEARCH_RESULTS = 20


def _format_product(idx: int, product: dict) -> str:
    """Render one search hit as a numbered, indented multi-line text entry."""
    # Fall through to category_name when category_path is missing OR empty/null.
    category = (
        product.get("category_path") or product.get("category_name") or "N/A"
    )
    lines = [
        f"{idx}. {product.get('title', 'Unknown Product')}",
        f"   ID: {product.get('spu_id', 'N/A')}",
        f"   Category: {category}",
    ]

    if product.get("vendor"):
        lines.append(f"   Brand: {product.get('vendor')}")
    if product.get("price") is not None:
        lines.append(f"   Price: {product.get('price')}")

    # Specification / color info: report the first spec named "color", if any.
    specs = product.get("specifications") or []
    color_spec = next(
        (s for s in specs if isinstance(s, dict) and s.get("name") == "color"),
        None,
    )
    if color_spec:
        lines.append(f"   Color: {color_spec.get('value', 'N/A')}")

    score = product.get("relevance_score")
    if score is not None:
        lines.append(f"   Relevance: {score:.2f}")

    return "\n".join(lines)


@tool
def search_products(query: str, limit: int = 5) -> str:
    """Search for fashion products using natural language descriptions.

    Use when users describe what they want:
    - "Find me red summer dresses"
    - "Show me blue running shoes"
    - "I want casual shirts for men"

    Args:
        query: Natural language product description
        limit: Maximum number of results (1-20)

    Returns:
        Formatted string with product information
    """
    # NOTE: the docstring above is left untouched on purpose; LangChain's
    # @tool uses it as the tool description shown to the model.
    #
    # Failures are never raised to the caller: HTTP errors, network errors and
    # unexpected response shapes are logged and reported as an "Error ..."
    # string so the agent can read them.
    try:
        logger.info("Searching products: '%s', limit: %s", query, limit)

        url = f"{settings.search_api_base_url.rstrip('/')}/search/"
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-ID": settings.search_api_tenant_id,
        }
        payload = {
            "query": query,
            # Clamp on both sides so 0 / negative limits never reach the API.
            "size": max(_MIN_SEARCH_RESULTS, min(limit, _MAX_SEARCH_RESULTS)),
            "from": 0,
            # NOTE(review): language is fixed to "zh" although the docstring
            # examples are English; confirm this matches the search index.
            "language": "zh",
        }

        response = requests.post(url, json=payload, headers=headers, timeout=60)

        if response.status_code != 200:
            logger.error(
                "Search API error: %s - %s", response.status_code, response.text
            )
            return f"Error searching products: API returned {response.status_code}"

        data = response.json()
        # Treat a missing or null "results" field the same as an empty list.
        results = data.get("results") or []

        if not results:
            return "No products found matching your search."

        sections = [f"Found {len(results)} product(s):"]
        sections.extend(
            _format_product(idx, product)
            for idx, product in enumerate(results, 1)
        )
        # Header and entries are separated by one blank line each.
        return "\n\n".join(sections).strip()

    except requests.exceptions.RequestException as e:
        logger.exception("Error searching products (network): %s", e)
        return f"Error searching products: {str(e)}"
    except Exception as e:
        # Tool boundary: report any other failure as text instead of raising.
        logger.exception("Error searching products: %s", e)
        return f"Error searching products: {str(e)}"


# MIME type used when the file extension gives no recognisable image type.
_DEFAULT_IMAGE_MIME = "image/jpeg"

# Instruction sent to the vision model together with the image.
_STYLE_ANALYSIS_PROMPT = """Analyze this fashion product image and provide a detailed description.

Include:
- Product type (e.g., shirt, dress, shoes, pants, bag)
- Primary colors
- Style/design (e.g., casual, formal, sporty, vintage, modern)
- Pattern or texture (e.g., plain, striped, checked, floral)
- Key features (e.g., collar type, sleeve length, fit)
- Material appearance (if obvious, e.g., denim, cotton, leather)
- Suitable occasion (e.g., office wear, party, casual, sports)

Provide a comprehensive yet concise description (3-4 sentences)."""


def _image_to_data_url(img_path: Path) -> str:
    """Read an image file and return its contents as a base64 ``data:`` URL."""
    # Derive the MIME type from the extension so PNG/WebP/GIF files are not
    # mislabelled as JPEG; fall back to JPEG when the type is unknown.
    mime_type, _ = mimetypes.guess_type(img_path.name)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = _DEFAULT_IMAGE_MIME

    encoded = base64.b64encode(img_path.read_bytes()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


@tool
def analyze_image_style(image_path: str) -> str:
    """Analyze a fashion product image using AI vision to extract detailed style information.

    Use when you need to understand style/attributes from an image:
    - Understand the style, color, pattern of a product
    - Extract attributes like "casual", "formal", "vintage"
    - Get detailed descriptions for subsequent searches

    Args:
        image_path: Path to the image file

    Returns:
        Detailed text description of the product's visual attributes
    """
    # NOTE: the docstring above is left untouched on purpose; LangChain's
    # @tool uses it as the tool description shown to the model.
    #
    # Failures are never raised to the caller: they are logged and reported
    # as an "Error ..." string.
    try:
        logger.info("Analyzing image with VLM: '%s'", image_path)

        img_path = Path(image_path)
        # is_file() also rejects directories, which exists() would let through.
        if not img_path.is_file():
            return f"Error: Image file not found at '{image_path}'"

        client = get_openai_client()
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": _STYLE_ANALYSIS_PROMPT},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": _image_to_data_url(img_path),
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            max_tokens=500,
            temperature=0.3,
        )

        content = response.choices[0].message.content
        # The message content can be None; report that explicitly rather than
        # failing on ``None.strip()``.
        if not content:
            logger.warning("VLM returned no text content for '%s'", image_path)
            return "Error analyzing image: model returned no description"

        analysis = content.strip()
        logger.info("VLM analysis completed")

        return analysis

    except Exception as e:
        # Tool boundary: report any failure as text instead of raising.
        logger.exception("Error analyzing image: %s", e)
        return f"Error analyzing image: {str(e)}"


def get_all_tools():
    """Return the tools exposed to the agent.

    Returns:
        A new list holding ``search_products`` and ``analyze_image_style``,
        in that order.
    """
    agent_tools = [search_products, analyze_image_style]
    return agent_tools