"""
Search Tools for Product Discovery

Key design:
- search_products is created via a factory (make_search_products_tool) that
  closes over (session_id, registry), so each agent session has its own tool
  instance pointing to the shared registry.
- After calling the search API, an LLM quality-assessment step labels every
  result as 完美匹配 / 部分匹配 / 不相关 and produces an overall verdict.
- The curated product list is stored in the registry under a unique ref_id.
- The tool returns ONLY the quality summary + [SEARCH_REF:ref_id], never the
  raw product list. The LLM references the result in its final response via
  the [SEARCH_REF:...] token; the UI renders the product cards from the
  registry.
"""

import base64
import json
import logging
import mimetypes
import os
from collections import Counter
from pathlib import Path
from typing import Optional

import requests
from langchain_core.tools import tool
from openai import OpenAI

from app.config import settings
from app.search_registry import (
    ProductItem,
    SearchResult,
    SearchResultRegistry,
    global_registry,
    new_ref_id,
)

logger = logging.getLogger(__name__)

_openai_client: Optional[OpenAI] = None


def get_openai_client() -> OpenAI:
    """Return the module-wide OpenAI client, creating it on first use.

    The client is configured from ``settings.openai_api_key`` and, when it is
    set, ``settings.openai_api_base_url``.
    """
    global _openai_client
    if _openai_client is None:
        kwargs = {"api_key": settings.openai_api_key}
        if settings.openai_api_base_url:
            kwargs["base_url"] = settings.openai_api_base_url
        _openai_client = OpenAI(**kwargs)
    return _openai_client


# ── LLM quality assessment ─────────────────────────────────────────────────────

_LABEL_PERFECT = "完美匹配"
_LABEL_PARTIAL = "部分匹配"
_LABEL_IRRELEVANT = "不相关"
_VALID_LABELS = frozenset({_LABEL_PERFECT, _LABEL_PARTIAL, _LABEL_IRRELEVANT})


def _as_float(value: object, default: float = 0.0) -> float:
    """Coerce *value* to float, returning *default* when it is missing or malformed."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _format_product_row(index: int, product: dict) -> str:
    """Render one raw search hit as a compact prompt line (score/title/category/tags)."""
    title = str(product.get("title") or "")[:60]
    category = product.get("category_path") or product.get("category_name") or ""
    tags_raw = product.get("tags") or []
    if not isinstance(tags_raw, (list, tuple)):
        # A bare string would otherwise be sliced character by character.
        tags_raw = [tags_raw]
    tags = ", ".join(str(t) for t in tags_raw[:5])
    score = _as_float(product.get("relevance_score"))
    row = f"{index}. [{score:.1f}] {title} | {category}"
    if tags:
        row += f" | 标签:{tags}"
    return row


def _parse_llm_json(raw: str) -> dict:
    """Parse a JSON object from an LLM reply, tolerating code fences and stray prose.

    Raises:
        ValueError: if no JSON object can be decoded (``json.JSONDecodeError``
            is a ``ValueError`` subclass) or the decoded value is not an object.
    """
    text = raw.strip()
    if text.startswith("```"):
        # "```json\n{...}\n```"  ->  "{...}"
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
        text = text.strip()
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span when the model wrapped the JSON in prose.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        data = json.loads(text[start : end + 1])
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    return data


def _normalize_label(value: object) -> str:
    """Return *value* as a valid match label; anything unrecognised becomes 部分匹配."""
    if isinstance(value, str) and value.strip() in _VALID_LABELS:
        return value.strip()
    return _LABEL_PARTIAL


def _verdict_for(perfect_count: int) -> str:
    """Map a perfect-match count to a verdict using the thresholds stated in the prompt."""
    if perfect_count >= 5:
        return "优质"
    if perfect_count >= 2:
        return "一般"
    return "较差"


def _assess_search_quality(
    query: str,
    raw_products: list,
) -> tuple[list[str], str, str]:
    """
    Ask the LLM to evaluate how well each search result matches the query.

    Args:
        query: the user's natural-language search query.
        raw_products: raw hit dicts from the search API; only title, category,
            tags and relevance_score are sent to the LLM to save tokens.

    Returns:
        labels  – list[str], exactly one per product:
                  "完美匹配" | "部分匹配" | "不相关"
        verdict – str: "优质" | "一般" | "较差". Derived locally from the number
                  of "完美匹配" labels (same thresholds as the prompt), so it
                  always agrees with the label counts.
        summary – str: one-sentence explanation from the LLM (may be empty)

    If the LLM call, its output parsing, or prompt construction fails, the
    error is logged and every product is labelled "部分匹配" with verdict "一般".
    """
    n = len(raw_products)
    if n == 0:
        return [], "较差", "搜索未返回任何商品。"

    try:
        # Built inside the try so one malformed hit degrades to the fallback
        # labels instead of aborting the whole search.
        product_text = "\n".join(
            _format_product_row(i, p) for i, p in enumerate(raw_products, 1)
        )

        prompt = f"""你是商品搜索质量评估专家。请评估以下搜索结果与用户查询的匹配程度。

用户查询:{query}

搜索结果(共 {n} 条,格式:序号. [相关性分数] 标题 | 分类 | 标签):
{product_text}

评估说明:
- 完美匹配:完全符合用户查询意图,用户必然感兴趣
- 部分匹配:与查询有关联,但不完全满足意图(如品类对但风格偏差、相关配件等)
- 不相关:与查询无关,不应展示给用户

整体 verdict 判断标准:
- 优质:完美匹配 ≥ 5 条
- 一般:完美匹配 2-4 条
- 较差:完美匹配 < 2 条

请严格按以下 JSON 格式输出,不得有任何额外文字或代码块标记:
{{"labels": ["完美匹配", "部分匹配", "不相关", ...], "verdict": "优质", "summary": "一句话评价搜索质量"}}

labels 数组长度必须恰好等于 {n}。"""

        client = get_openai_client()
        resp = client.chat.completions.create(
            model=settings.openai_model,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=800,
            temperature=0.1,
        )
        # content may be None (e.g. refusals); treat it as an unparsable reply.
        data = _parse_llm_json(resp.choices[0].message.content or "")

        raw_labels = data.get("labels")
        if not isinstance(raw_labels, list):
            raw_labels = []
        # Normalize, then trim / pad so there is exactly one label per product.
        labels = [_normalize_label(label) for label in raw_labels[:n]]
        labels += [_LABEL_PARTIAL] * (n - len(labels))

        verdict = _verdict_for(labels.count(_LABEL_PERFECT))
        # `or ""` so a JSON null does not become the literal string "None".
        summary = str(data.get("summary") or "")
        return labels, verdict, summary

    except Exception as e:
        # Best-effort step: assessment failures must not block the search itself.
        logger.warning("Quality assessment LLM call failed: %s; using fallback labels.", e)
        return [_LABEL_PARTIAL] * n, "一般", "质量评估步骤失败,结果仅供参考。"


# ── Tool factory ───────────────────────────────────────────────────────────────

# Guidance appended to the tool output so the agent knows how to act on a verdict.
_VERDICT_HINTS = {
    "优质": "结果质量优质,可直接引用。",
    "一般": "结果质量一般,可酌情引用,也可补充更精准的 query。",
    "较差": "结果质量较差,建议重新规划 query 后再次搜索。",
}


def _to_product_item(raw: dict, label: str) -> ProductItem:
    """Convert one raw search-API hit into a ProductItem carrying its match label."""
    return ProductItem(
        spu_id=str(raw.get("spu_id", "")),
        title=raw.get("title") or "",
        price=raw.get("price"),
        category_path=raw.get("category_path") or raw.get("category_name"),
        vendor=raw.get("vendor"),
        image_url=raw.get("image_url"),
        relevance_score=raw.get("relevance_score"),
        match_label=label,
        tags=raw.get("tags") or [],
        specifications=raw.get("specifications") or [],
    )


def make_search_products_tool(
    session_id: str,
    registry: SearchResultRegistry,
):
    """
    Return a search_products tool bound to a specific session and registry.

    The tool:
    1. Calls the product search API (at most 20 hits).
    2. Runs LLM quality assessment on the returned results.
    3. Stores a SearchResult (perfect + partial matches only) in the registry.
    4. Returns a concise quality summary + [SEARCH_REF:ref_id].

    The product list is NEVER returned in the tool output text. Errors are
    reported to the agent as a string rather than raised.
    """

    @tool
    def search_products(query: str, limit: int = 20) -> str:
        """搜索商品库,根据自然语言描述找到匹配商品,并进行质量评估。

        每次调用专注于单一搜索角度。复杂需求请拆分为多次调用,每次换一个 query。
        工具会自动评估结果质量(完美匹配 / 部分匹配 / 不相关),并给出整体判断。

        Args:
            query: 自然语言商品描述,例如"男士休闲亚麻短裤夏季"
            limit: 最多返回条数(建议 10-20,越多评估越全面)

        Returns:
            质量评估摘要 + [SEARCH_REF:ref_id],供最终回复引用。
        """
        try:
            logger.info("[%s] search_products: query=%r limit=%s", session_id, query, limit)

            url = f"{settings.search_api_base_url.rstrip('/')}/search/"
            headers = {
                "Content-Type": "application/json",
                "X-Tenant-ID": settings.search_api_tenant_id,
            }
            payload = {
                "query": query,
                # Clamp to 1..20: the assessment step evaluates at most 20 results.
                "size": min(max(limit, 1), 20),
                "from": 0,
                "language": "zh",
            }

            resp = requests.post(url, json=payload, headers=headers, timeout=60)
            if resp.status_code != 200:
                logger.error("Search API error %s: %s", resp.status_code, resp.text[:300])
                return f"搜索失败:API 返回状态码 {resp.status_code},请稍后重试。"

            # Decode separately: requests' JSONDecodeError subclasses
            # RequestException and would otherwise be reported as a network error.
            try:
                data = resp.json()
            except ValueError as e:
                logger.error("[%s] Search API returned invalid JSON: %s", session_id, e)
                return "搜索失败:API 返回了无法解析的响应,请稍后重试。"

            # `or` defaults also cover explicit JSON nulls.
            raw_results: list = data.get("results") or []
            total_hits: int = data.get("total") or 0

            if not raw_results:
                return (
                    f"【搜索完成】query='{query}'\n"
                    "未找到匹配商品,建议换用更宽泛或不同角度的关键词重新搜索。"
                )

            # ── LLM quality assessment ──────────────────────────────────────
            labels, verdict, quality_summary = _assess_search_quality(query, raw_results)

            # ── Count labels; keep perfect + partial, discard irrelevant ────
            counts = Counter(labels)
            perfect_count = counts[_LABEL_PERFECT]
            partial_count = counts[_LABEL_PARTIAL]
            irrelevant_count = len(labels) - perfect_count - partial_count

            products: list[ProductItem] = [
                _to_product_item(raw, label)
                for raw, label in zip(raw_results, labels)
                if label in (_LABEL_PERFECT, _LABEL_PARTIAL)
            ]

            # ── Register ────────────────────────────────────────────────────
            ref_id = new_ref_id()
            result = SearchResult(
                ref_id=ref_id,
                query=query,
                total_api_hits=total_hits,
                returned_count=len(raw_results),
                perfect_count=perfect_count,
                partial_count=partial_count,
                irrelevant_count=irrelevant_count,
                quality_verdict=verdict,
                quality_summary=quality_summary,
                products=products,
            )
            registry.register(session_id, result)
            logger.info(
                "[%s] Registered %s: verdict=%s, perfect=%s, partial=%s, irrel=%s",
                session_id,
                ref_id,
                verdict,
                perfect_count,
                partial_count,
                irrelevant_count,
            )

            # ── Return summary to agent (NOT the product list) ──────────────
            verdict_hint = _VERDICT_HINTS.get(verdict, "")
            return (
                f"【搜索完成】query='{query}'\n"
                f"API 总命中:{total_hits} 条 | 本次评估:{len(raw_results)} 条\n"
                f"质量评估:完美匹配 {perfect_count} 条 | 部分匹配 {partial_count} 条 | 不相关 {irrelevant_count} 条\n"
                f"整体判断:{verdict} — {quality_summary}\n"
                f"{verdict_hint}\n"
                f"结果引用:[SEARCH_REF:{ref_id}]"
            )

        except requests.exceptions.RequestException as e:
            logger.error("[%s] Search network error: %s", session_id, e, exc_info=True)
            return f"搜索失败(网络错误):{e}"
        except Exception as e:
            logger.error("[%s] Search error: %s", session_id, e, exc_info=True)
            return f"搜索失败:{e}"

    return search_products


# ── Standalone tools (no session binding needed) ───────────────────────────────


@tool
def web_search(query: str) -> str:
    """使用 Tavily 进行通用 Web 搜索,补充外部/实时知识。

    触发场景:
    - 需要**外部知识**:流行趋势、品牌、搭配文化、节日习俗等
    - 需要**实时/及时信息**:当季流行元素、某地未来的天气
    - 需要**宏观参考**:不同场合/国家的穿着建议、选购攻略

    Args:
        query: 要搜索的问题,自然语言描述

    Returns:
        总结后的回答 + 若干参考来源链接
    """
    try:
        api_key = os.getenv("TAVILY_API_KEY")
        if not api_key:
            return (
                "无法调用外部 Web 搜索:未检测到 TAVILY_API_KEY 环境变量。\n"
                "请在运行环境中配置 TAVILY_API_KEY 后再重试。"
            )

        logger.info("web_search: %r", query)

        url = "https://api.tavily.com/search"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "query": query,
            "search_depth": "advanced",
            "include_answer": True,
        }

        response = requests.post(url, json=payload, headers=headers, timeout=60)
        if response.status_code != 200:
            logger.error(
                "web_search: Tavily error %s: %s", response.status_code, response.text[:300]
            )
            return f"调用外部 Web 搜索失败:Tavily 返回状态码 {response.status_code}"

        # Decode separately so a bad body is not misreported as a network error.
        try:
            data = response.json()
        except ValueError as e:
            logger.error("web_search: Tavily returned invalid JSON: %s", e)
            return "调用外部 Web 搜索失败:Tavily 返回了无法解析的响应"

        answer = data.get("answer") or "(Tavily 未返回直接回答,仅返回了搜索结果。)"
        results = data.get("results") or []

        output_lines = [
            "【外部 Web 搜索结果(Tavily)】",
            "",
            "回答摘要:",
            str(answer).strip(),
        ]

        if results:
            output_lines.append("")
            output_lines.append("参考来源(部分):")
            for idx, item in enumerate(results[:5], 1):
                title = item.get("title") or "无标题"
                link = item.get("url") or ""
                output_lines.append(f"{idx}. {title}")
                if link:
                    output_lines.append(f" 链接: {link}")

        return "\n".join(output_lines).strip()

    except requests.exceptions.RequestException as e:
        logger.error("web_search network error: %s", e, exc_info=True)
        return f"调用外部 Web 搜索失败(网络错误):{e}"
    except Exception as e:
        logger.error("web_search error: %s", e, exc_info=True)
        return f"调用外部 Web 搜索失败:{e}"


_IMAGE_STYLE_PROMPT = """请分析这张商品图片,提供详细的视觉属性描述,用于商品搜索。

请包含:
- 商品类型(如:连衣裙、运动鞋、双肩包、西装等)
- 主要颜色
- 风格定位(如:休闲、正式、运动、复古、现代简约等)
- 图案/纹理(如:纯色、条纹、格纹、碎花、几何图案等)
- 关键设计特征(如:领型、袖长、版型、材质外观等)
- 适用场合(如:办公、户外、度假、聚会、运动等)

输出格式:3-4句自然语言描述,可直接用作搜索关键词。"""


@tool
def analyze_image_style(image_path: str) -> str:
    """分析用户上传的商品图片,提取视觉风格属性,用于后续商品搜索。

    适用场景:
    - 用户上传图片,想找相似商品
    - 需要理解图片中商品的风格、颜色、材质等属性

    Args:
        image_path: 图片文件路径

    Returns:
        商品视觉属性的详细文字描述,可直接作为 search_products 的 query
    """
    try:
        logger.info("analyze_image_style: %r", image_path)

        img_path = Path(image_path)
        # is_file() rather than exists(): a directory must be rejected here too.
        if not img_path.is_file():
            return f"错误:图片文件不存在:{image_path}"

        image_data = base64.b64encode(img_path.read_bytes()).decode("utf-8")

        # Label the data URL with the file's real image type; default to JPEG
        # (the previous hard-coded value) when the extension is unknown.
        mime_type, _ = mimetypes.guess_type(img_path.name)
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"

        client = get_openai_client()
        response = client.chat.completions.create(
            model=settings.openai_vision_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": _IMAGE_STYLE_PROMPT},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{image_data}",
                                "detail": "high",
                            },
                        },
                    ],
                }
            ],
            max_tokens=500,
            temperature=0.3,
        )

        content = response.choices[0].message.content
        if not content:
            logger.warning("analyze_image_style: model returned no content")
            return "图片分析失败:模型未返回任何内容。"

        analysis = content.strip()
        logger.info("Image analysis completed.")
        return analysis

    except Exception as e:
        logger.error("analyze_image_style error: %s", e, exc_info=True)
        return f"图片分析失败:{e}"


# ── Tool list factory ──────────────────────────────────────────────────────────


def get_all_tools(
    session_id: str = "default",
    registry: Optional[SearchResultRegistry] = None,
) -> list:
    """
    Return all agent tools.

    search_products is session-bound (built by the factory for *session_id*);
    the other tools are stateless. When *registry* is None, the module-level
    ``global_registry`` is used.
    """
    if registry is None:
        registry = global_registry
    return [
        make_search_products_tool(session_id, registry),
        analyze_image_style,
        web_search,
    ]