Blame view

app/tools/search_tools.py 8.73 KB
e7f2b240   tangwang   first commit
1
2
  """
  Search Tools for Product Discovery
46f8dd12   tangwang   1. add prod under...
3
  Provides text-based search via Search API, web search, and VLM style analysis
e7f2b240   tangwang   first commit
4
5
6
7
  """
  
  import base64
  import logging
46f8dd12   tangwang   1. add prod under...
8
  import os
e7f2b240   tangwang   first commit
9
10
11
  from pathlib import Path
  from typing import Optional
  
8810a6fa   tangwang   重构
12
  import requests
e7f2b240   tangwang   first commit
13
14
15
16
  from langchain_core.tools import tool
  from openai import OpenAI
  
  from app.config import settings
e7f2b240   tangwang   first commit
17
18
19
  
  # Module-level logger used by every tool function in this file.
  logger = logging.getLogger(__name__)
  
e7f2b240   tangwang   first commit
20
21
22
  # Cached OpenAI client; stays None until get_openai_client() builds it on first use.
  _openai_client: Optional[OpenAI] = None
  
  
e7f2b240   tangwang   first commit
23
24
25
  def get_openai_client() -> OpenAI:
      """Return the module's shared OpenAI client, constructing it on the first call.

      The client is built from ``settings.openai_api_key``. ``settings.openai_api_base_url``
      is passed as ``base_url`` only when it is set (truthy).
      """
      global _openai_client
      if _openai_client is not None:
          return _openai_client

      client_kwargs = {"api_key": settings.openai_api_key}
      custom_base_url = settings.openai_api_base_url
      if custom_base_url:
          client_kwargs["base_url"] = custom_base_url

      _openai_client = OpenAI(**client_kwargs)
      return _openai_client
  
  
  @tool
46f8dd12   tangwang   1. add prod under...
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  def web_search(query: str) -> str:
      """使用 Tavily 进行通用 Web 搜索,补充外部/实时知识。
  
      触发场景(示例):
      - 需要**外部知识**:流行趋势、新品信息、穿搭文化、品牌故事等
      - 需要**实时/及时信息**:某地某个时节的天气、当季流行元素、最新联名款
      - 需要**宏观参考**:不同城市/国家的穿衣习惯、节日穿搭建议
  
      Args:
          query: 要搜索的问题,自然语言描述(建议用中文)
  
      Returns:
          总结后的回答 + 若干来源链接,供模型继续推理使用。
      """
      try:
          api_key = os.getenv("TAVILY_API_KEY")
          if not api_key:
              logger.error("TAVILY_API_KEY is not set in environment variables")
              return (
                  "无法调用外部 Web 搜索:未检测到 TAVILY_API_KEY 环境变量。\n"
                  "请在运行环境中配置 TAVILY_API_KEY 后再重试。"
              )
  
          logger.info(f"Calling Tavily web search with query: {query!r}")
  
          url = "https://api.tavily.com/search"
          headers = {
              "Authorization": f"Bearer {api_key}",
              "Content-Type": "application/json",
          }
          payload = {
              "query": query,
              "search_depth": "advanced",
              "include_answer": True,
          }
  
          response = requests.post(url, json=payload, headers=headers, timeout=60)
  
          if response.status_code != 200:
              logger.error(
                  "Tavily API error: %s - %s",
                  response.status_code,
                  response.text,
              )
              return f"调用外部 Web 搜索失败:Tavily 返回状态码 {response.status_code}"
  
          data = response.json()
          answer = data.get("answer") or "(Tavily 未返回直接回答,仅返回了搜索结果。)"
          results = data.get("results") or []
  
          output_lines = [
              "【外部 Web 搜索结果(Tavily)】",
              "",
              "回答摘要:",
              answer.strip(),
          ]
  
          if results:
              output_lines.append("")
              output_lines.append("参考来源(部分):")
              for idx, item in enumerate(results[:5], 1):
                  title = item.get("title") or "无标题"
                  url = item.get("url") or ""
                  output_lines.append(f"{idx}. {title}")
                  if url:
                      output_lines.append(f"   链接: {url}")
  
          return "\n".join(output_lines).strip()
  
      except requests.exceptions.RequestException as e:
          logger.error("Error calling Tavily web search (network): %s", e, exc_info=True)
          return f"调用外部 Web 搜索失败(网络错误):{e}"
      except Exception as e:
          logger.error("Error calling Tavily web search: %s", e, exc_info=True)
          return f"调用外部 Web 搜索失败:{e}"
  
  
  @tool
e7f2b240   tangwang   first commit
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  def search_products(query: str, limit: int = 5) -> str:
      """Search for fashion products using natural language descriptions.
  
      Use when users describe what they want:
      - "Find me red summer dresses"
      - "Show me blue running shoes"
      - "I want casual shirts for men"
  
      Args:
          query: Natural language product description
          limit: Maximum number of results (1-20)
  
      Returns:
          Formatted string with product information
      """
      try:
          logger.info(f"Searching products: '{query}', limit: {limit}")
  
8810a6fa   tangwang   重构
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
          url = f"{settings.search_api_base_url.rstrip('/')}/search/"
          headers = {
              "Content-Type": "application/json",
              "X-Tenant-ID": settings.search_api_tenant_id,
          }
          payload = {
              "query": query,
              "size": min(limit, 20),
              "from": 0,
              "language": "zh",
          }
  
          response = requests.post(url, json=payload, headers=headers, timeout=60)
  
          if response.status_code != 200:
              logger.error(f"Search API error: {response.status_code} - {response.text}")
              return f"Error searching products: API returned {response.status_code}"
  
          data = response.json()
          results = data.get("results", [])
e7f2b240   tangwang   first commit
150
151
152
153
154
155
156
  
          if not results:
              return "No products found matching your search."
  
          output = f"Found {len(results)} product(s):\n\n"
  
          for idx, product in enumerate(results, 1):
8810a6fa   tangwang   重构
157
158
159
160
161
162
163
164
165
166
167
168
              output += f"{idx}. {product.get('title', 'Unknown Product')}\n"
              output += f"   ID: {product.get('spu_id', 'N/A')}\n"
              output += f"   Category: {product.get('category_path', product.get('category_name', 'N/A'))}\n"
              if product.get("vendor"):
                  output += f"   Brand: {product.get('vendor')}\n"
              if product.get("price") is not None:
                  output += f"   Price: {product.get('price')}\n"
  
              # 规格/颜色信息
              specs = product.get("specifications", [])
              if specs:
                  color_spec = next(
9ad88986   tangwang   up`
169
                      (s for s in specs if s.get("name").lower() == "color"),
8810a6fa   tangwang   重构
170
171
172
173
174
                      None,
                  )
                  if color_spec:
                      output += f"   Color: {color_spec.get('value', 'N/A')}\n"
  
e7f2b240   tangwang   first commit
175
176
177
178
              output += "\n"
  
          return output.strip()
  
8810a6fa   tangwang   重构
179
180
181
      except requests.exceptions.RequestException as e:
          logger.error(f"Error searching products (network): {e}", exc_info=True)
          return f"Error searching products: {str(e)}"
e7f2b240   tangwang   first commit
182
183
184
185
186
187
      except Exception as e:
          logger.error(f"Error searching products: {e}", exc_info=True)
          return f"Error searching products: {str(e)}"
  
  
  @tool
e7f2b240   tangwang   first commit
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
  def analyze_image_style(image_path: str) -> str:
      """Analyze a fashion product image using AI vision to extract detailed style information.
  
      Use when you need to understand style/attributes from an image:
      - Understand the style, color, pattern of a product
      - Extract attributes like "casual", "formal", "vintage"
      - Get detailed descriptions for subsequent searches
  
      Args:
          image_path: Path to the image file
  
      Returns:
          Detailed text description of the product's visual attributes
      """
      try:
          logger.info(f"Analyzing image with VLM: '{image_path}'")
  
          img_path = Path(image_path)
          if not img_path.exists():
              return f"Error: Image file not found at '{image_path}'"
  
          with open(img_path, "rb") as image_file:
              image_data = base64.b64encode(image_file.read()).decode("utf-8")
  
          prompt = """Analyze this fashion product image and provide a detailed description.
  
  Include:
  - Product type (e.g., shirt, dress, shoes, pants, bag)
  - Primary colors
  - Style/design (e.g., casual, formal, sporty, vintage, modern)
  - Pattern or texture (e.g., plain, striped, checked, floral)
  - Key features (e.g., collar type, sleeve length, fit)
  - Material appearance (if obvious, e.g., denim, cotton, leather)
  - Suitable occasion (e.g., office wear, party, casual, sports)
  
  Provide a comprehensive yet concise description (3-4 sentences)."""
  
          client = get_openai_client()
          response = client.chat.completions.create(
46f8dd12   tangwang   1. add prod under...
227
              model=settings.openai_vision_model,
e7f2b240   tangwang   first commit
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
              messages=[
                  {
                      "role": "user",
                      "content": [
                          {"type": "text", "text": prompt},
                          {
                              "type": "image_url",
                              "image_url": {
                                  "url": f"data:image/jpeg;base64,{image_data}",
                                  "detail": "high",
                              },
                          },
                      ],
                  }
              ],
              max_tokens=500,
              temperature=0.3,
          )
  
          analysis = response.choices[0].message.content.strip()
          logger.info("VLM analysis completed")
  
          return analysis
  
      except Exception as e:
          logger.error(f"Error analyzing image: {e}", exc_info=True)
          return f"Error analyzing image: {str(e)}"
  
  
  def get_all_tools():
      """Return the tools exposed to the agent: product search, image style analysis, web search."""
      available_tools = [search_products, analyze_image_style, web_search]
      return available_tools