# search_tools.py (8.73 KB)
"""
Search Tools for Product Discovery
Provides text-based product search via the Search API, general web search via Tavily, and VLM-based image style analysis
"""

import base64
import logging
import mimetypes
import os
from pathlib import Path
from typing import Optional

import requests
from langchain_core.tools import tool
from openai import OpenAI

from app.config import settings

logger = logging.getLogger(__name__)

# Module-wide OpenAI client cache; populated on the first get_openai_client() call.
_openai_client: Optional[OpenAI] = None


def get_openai_client() -> OpenAI:
    """Return the shared OpenAI client, building it the first time it is requested.

    The client is configured with ``settings.openai_api_key``; when
    ``settings.openai_api_base_url`` is set it is passed as ``base_url``.
    Later calls return the same cached instance.
    """
    global _openai_client
    if _openai_client is not None:
        return _openai_client

    client_options = {"api_key": settings.openai_api_key}
    if settings.openai_api_base_url:
        client_options["base_url"] = settings.openai_api_base_url
    _openai_client = OpenAI(**client_options)
    return _openai_client


@tool
def web_search(query: str) -> str:
    """使用 Tavily 进行通用 Web 搜索,补充外部/实时知识。

    触发场景(示例):
    - 需要**外部知识**:流行趋势、新品信息、穿搭文化、品牌故事等
    - 需要**实时/及时信息**:某地某个时节的天气、当季流行元素、最新联名款
    - 需要**宏观参考**:不同城市/国家的穿衣习惯、节日穿搭建议

    Args:
        query: 要搜索的问题,自然语言描述(建议用中文)

    Returns:
        总结后的回答 + 若干来源链接,供模型继续推理使用。
    """
    # NOTE(review): ``@tool`` builds the tool description from the docstring
    # above, so it is prompt text seen by the model and is kept verbatim.
    # English summary for maintainers: run a general Tavily web search to add
    # external / real-time knowledge (trends, new releases, weather, regional
    # dressing habits). ``query`` is a natural-language question (Chinese
    # suggested). Returns Tavily's summarized answer plus up to five source
    # links. All failures are reported as a returned string, never raised.
    try:
        tavily_key = os.getenv("TAVILY_API_KEY")
        if not tavily_key:
            logger.error("TAVILY_API_KEY is not set in environment variables")
            return (
                "无法调用外部 Web 搜索:未检测到 TAVILY_API_KEY 环境变量。\n"
                "请在运行环境中配置 TAVILY_API_KEY 后再重试。"
            )

        logger.info(f"Calling Tavily web search with query: {query!r}")

        resp = requests.post(
            "https://api.tavily.com/search",
            json={
                "query": query,
                "search_depth": "advanced",
                "include_answer": True,
            },
            headers={
                "Authorization": f"Bearer {tavily_key}",
                "Content-Type": "application/json",
            },
            timeout=60,
        )

        if resp.status_code != 200:
            logger.error("Tavily API error: %s - %s", resp.status_code, resp.text)
            return f"调用外部 Web 搜索失败:Tavily 返回状态码 {resp.status_code}"

        body = resp.json()
        summary = body.get("answer") or "(Tavily 未返回直接回答,仅返回了搜索结果。)"
        hits = body.get("results") or []

        lines = ["【外部 Web 搜索结果(Tavily)】", "", "回答摘要:", summary.strip()]

        if hits:
            lines += ["", "参考来源(部分):"]
            # Only the first five sources are listed to keep the tool output short.
            for rank, hit in enumerate(hits[:5], start=1):
                heading = hit.get("title") or "无标题"
                link = hit.get("url") or ""
                lines.append(f"{rank}. {heading}")
                if link:
                    lines.append(f"   链接: {link}")

        return "\n".join(lines).strip()

    except requests.exceptions.RequestException as e:
        logger.error("Error calling Tavily web search (network): %s", e, exc_info=True)
        return f"调用外部 Web 搜索失败(网络错误):{e}"
    except Exception as e:
        logger.error("Error calling Tavily web search: %s", e, exc_info=True)
        return f"调用外部 Web 搜索失败:{e}"


def _first_present(record, *keys, default="N/A"):
    """Return the first value in ``record`` under ``keys`` that is not None or "".

    Falls back to ``default`` so missing/null fields never render as "None".
    """
    for key in keys:
        value = record.get(key)
        if value is not None and value != "":
            return value
    return default


def _format_product_lines(idx, product):
    """Render one Search API product record as a list of indented text lines."""
    lines = [
        f"{idx}. {_first_present(product, 'title', default='Unknown Product')}",
        f"   ID: {_first_present(product, 'spu_id')}",
        # Prefer the full category path, falling back to the leaf category name.
        f"   Category: {_first_present(product, 'category_path', 'category_name')}",
    ]
    if product.get("vendor"):
        lines.append(f"   Brand: {product.get('vendor')}")
    if product.get("price") is not None:
        lines.append(f"   Price: {product.get('price')}")

    # Specification / color info. A spec without a usable "name" used to raise
    # AttributeError here, which aborted the whole search; such specs are skipped.
    specs = product.get("specifications") or []
    color_spec = next(
        (
            spec
            for spec in specs
            if isinstance(spec, dict) and str(spec.get("name") or "").lower() == "color"
        ),
        None,
    )
    if color_spec:
        lines.append(f"   Color: {color_spec.get('value', 'N/A')}")
    return lines


@tool
def search_products(query: str, limit: int = 5) -> str:
    """Search for fashion products using natural language descriptions.

    Use when users describe what they want:
    - "Find me red summer dresses"
    - "Show me blue running shoes"
    - "I want casual shirts for men"

    Args:
        query: Natural language product description
        limit: Maximum number of results (1-20)

    Returns:
        Formatted string with product information
    """
    # NOTE(review): ``@tool`` builds the tool description from the docstring
    # above, so it is kept verbatim. All failures are returned as an
    # "Error searching products: ..." string rather than raised.
    try:
        logger.info("Searching products: '%s', limit: %s", query, limit)

        # Enforce the documented 1-20 range on both ends (previously only the
        # upper bound was applied, so 0/negative sizes reached the API).
        size = max(1, min(limit, 20))

        url = f"{settings.search_api_base_url.rstrip('/')}/search/"
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-ID": settings.search_api_tenant_id,
        }
        payload = {
            "query": query,
            "size": size,
            "from": 0,
            "language": "zh",
        }

        response = requests.post(url, json=payload, headers=headers, timeout=60)

        if response.status_code != 200:
            logger.error("Search API error: %s - %s", response.status_code, response.text)
            return f"Error searching products: API returned {response.status_code}"

        data = response.json()
        results = data.get("results") or []

        if not results:
            return "No products found matching your search."

        output_lines = [f"Found {len(results)} product(s):", ""]
        for idx, product in enumerate(results, 1):
            output_lines.extend(_format_product_lines(idx, product))
            output_lines.append("")  # blank line between products

        return "\n".join(output_lines).strip()

    except requests.exceptions.RequestException as e:
        logger.error(f"Error searching products (network): {e}", exc_info=True)
        return f"Error searching products: {e}"
    except Exception as e:
        logger.error(f"Error searching products: {e}", exc_info=True)
        return f"Error searching products: {e}"


@tool
def analyze_image_style(image_path: str) -> str:
    """Analyze a fashion product image using AI vision to extract detailed style information.

    Use when you need to understand style/attributes from an image:
    - Understand the style, color, pattern of a product
    - Extract attributes like "casual", "formal", "vintage"
    - Get detailed descriptions for subsequent searches

    Args:
        image_path: Path to the image file

    Returns:
        Detailed text description of the product's visual attributes
    """
    # NOTE(review): ``@tool`` builds the tool description from the docstring
    # above, so it is kept verbatim. The image is sent inline as a base64 data
    # URL to ``settings.openai_vision_model``; every failure is returned as an
    # "Error ..." string rather than raised.
    try:
        logger.info(f"Analyzing image with VLM: '{image_path}'")

        img_path = Path(image_path)
        # is_file() rather than exists(): a directory path would pass exists()
        # and then fail inside the read with a less helpful message.
        if not img_path.is_file():
            return f"Error: Image file not found at '{image_path}'"

        image_data = base64.b64encode(img_path.read_bytes()).decode("utf-8")

        # Label the data URL with the real image type (it was always image/jpeg,
        # mislabelling PNG/WebP/GIF files); keep JPEG as the fallback.
        mime_type, _ = mimetypes.guess_type(img_path.name)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        prompt = """Analyze this fashion product image and provide a detailed description.

Include:
- Product type (e.g., shirt, dress, shoes, pants, bag)
- Primary colors
- Style/design (e.g., casual, formal, sporty, vintage, modern)
- Pattern or texture (e.g., plain, striped, checked, floral)
- Key features (e.g., collar type, sleeve length, fit)
- Material appearance (if obvious, e.g., denim, cotton, leather)
- Suitable occasion (e.g., office wear, party, casual, sports)

Provide a comprehensive yet concise description (3-4 sentences)."""

        client = get_openai_client()
        response = client.chat.completions.create(
            model=settings.openai_vision_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_data}",
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            max_tokens=500,
            temperature=0.3,
        )

        # The model may return no choices or a None content (e.g. a refusal);
        # report that explicitly instead of failing on None.strip().
        content = response.choices[0].message.content if response.choices else None
        if content is None:
            logger.warning("VLM returned no content for '%s'", image_path)
            return "Error analyzing image: model returned no content"

        analysis = content.strip()
        logger.info("VLM analysis completed")

        return analysis

    except Exception as e:
        logger.error(f"Error analyzing image: {e}", exc_info=True)
        return f"Error analyzing image: {str(e)}"


def get_all_tools():
    """Return a new list of every tool the agent may call.

    Order: product search, image style analysis, then Tavily web search.
    """
    toolset = (search_products, analyze_image_style, web_search)
    return list(toolset)