From 664426683046f36260ee8703f5631c85e40c0cf2 Mon Sep 17 00:00:00 2001 From: tangwang Date: Fri, 20 Feb 2026 15:54:53 +0800 Subject: [PATCH] feat: 搜索结果引用与并行搜索、两轮上限 --- app.py | 335 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- app/agents/shopping_agent.py | 299 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- app/search_registry.py | 100 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ app/tools/__init__.py | 9 +++++++-- app/tools/search_tools.py | 430 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------------------------------------------------------------------------- 5 files changed, 672 insertions(+), 501 deletions(-) create mode 100644 app/search_registry.py diff --git a/app.py b/app.py index 6a5560c..0b3b4a1 100644 --- a/app.py +++ b/app.py @@ -13,6 +13,11 @@ import streamlit as st from PIL import Image, ImageOps from app.agents.shopping_agent import ShoppingAgent +from app.search_registry import ProductItem, SearchResult, global_registry + +# Matches [SEARCH_REF:sr_xxxxxxxx] tokens embedded in AI responses. +# Case-insensitive, optional spaces around the id. +SEARCH_REF_PATTERN = re.compile(r"\[SEARCH_REF:\s*([a-zA-Z0-9_]+)\s*\]", re.IGNORECASE) # Configure logging logging.basicConfig( @@ -270,124 +275,118 @@ def save_uploaded_image(uploaded_file) -> Optional[str]: return None -def extract_products_from_response(response: str) -> list: - """Extract product information from agent response +def _load_product_image(product: ProductItem) -> Optional[Image.Image]: + """Try to load a product image: image_url from API → local data/images → None.""" + if product.image_url: + try: + import requests + resp = requests.get(product.image_url, timeout=10) + if resp.status_code == 200: + import io + return Image.open(io.BytesIO(resp.content)) + except Exception as e: + logger.debug(f"Remote image fetch failed for {product.spu_id}: {e}") + + local = Path(f"data/images/{product.spu_id}.jpg") + if local.exists(): + try: + return Image.open(local) + except Exception as e: + logger.debug(f"Local image load failed {local}: {e}") + return None + + +def display_product_card_from_item(product: ProductItem) -> None: + """Render a single product card from a ProductItem (registry entry).""" + img = _load_product_image(product) + + if img: + target = (220, 220) + try: + img = ImageOps.fit(img, target, method=Image.Resampling.LANCZOS) + except AttributeError: + img = ImageOps.fit(img, target, method=Image.LANCZOS) + st.image(img, use_container_width=True) + else: + st.markdown( + '
🛍️
', + unsafe_allow_html=True, + ) + + title = product.title or "未知商品" + st.markdown(f"**{title[:40]}**" + ("…" if len(title) > 40 else "")) + + if product.price is not None: + st.caption(f"¥{product.price:.2f}") + + label_style = "⭐" if product.match_label == "完美匹配" else "✦" + st.caption(f"{label_style} {product.match_label}") - Returns list of dicts with product info + +def render_search_result_block(result: SearchResult) -> None: """ - products = [] - - # Pattern to match product blocks in the response - # Looking for ID, name, and other details - lines = response.split("\n") - current_product = {} - - for line in lines: - line = line.strip() - - # Match product number (e.g., "1. Product Name" or "**1. Product Name**") - if re.match(r"^\*?\*?\d+\.\s+", line): - if current_product: - products.append(current_product) - current_product = {} - # Extract product name - name = re.sub(r"^\*?\*?\d+\.\s+", "", line) - name = name.replace("**", "").strip() - current_product["name"] = name - - # Match ID - elif "ID:" in line or "id:" in line: - id_match = re.search(r"(?:ID|id):\s*(\d+)", line) - if id_match: - current_product["id"] = id_match.group(1) - - # Match Category - elif "Category:" in line: - cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) - if cat_match: - current_product["category"] = cat_match.group(1).strip() - - # Match Color - elif "Color:" in line: - color_match = re.search(r"Color:\s*(\w+)", line) - if color_match: - current_product["color"] = color_match.group(1) - - # Match Gender - elif "Gender:" in line: - gender_match = re.search(r"Gender:\s*(\w+)", line) - if gender_match: - current_product["gender"] = gender_match.group(1) - - # Match Season - elif "Season:" in line: - season_match = re.search(r"Season:\s*(\w+)", line) - if season_match: - current_product["season"] = season_match.group(1) - - # Match Usage - elif "Usage:" in line: - usage_match = re.search(r"Usage:\s*(\w+)", line) - if usage_match: - current_product["usage"] = usage_match.group(1) - - # Match Similarity/Relevance score - elif "Similarity:" in line or "Relevance:" in line: - score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line) - if score_match: - current_product["score"] = score_match.group(1) - - # Add last product - if current_product: - products.append(current_product) - - return products - - -def display_product_card(product: dict): - """Display a product card with image and name""" - product_id = product.get("id", "") - name = product.get("name", "Unknown Product") - - # Debug: log what we got - logger.info(f"Displaying product: ID={product_id}, Name={name}") - - # Try to load image from data/images directory - if product_id: - image_path = Path(f"data/images/{product_id}.jpg") - - if image_path.exists(): - try: - img = Image.open(image_path) - # Fixed size for all images - target_size = (200, 200) - try: - # Try new Pillow API - img_processed = ImageOps.fit( - img, target_size, method=Image.Resampling.LANCZOS - ) - except AttributeError: - # Fallback for older Pillow versions - img_processed = ImageOps.fit( - img, target_size, method=Image.LANCZOS - ) - - # Display image with fixed width - st.image(img_processed, use_container_width=False, width=200) - st.markdown(f"**{name}**") - st.caption(f"ID: {product_id}") - return - except Exception as e: - logger.warning(f"Failed to load image {image_path}: {e}") + Render a full search result block in place of a [SEARCH_REF:xxx] token. + + Shows: + - A styled header with query text + quality verdict + match counts + - A grid of product cards (perfect matches first, then partial; max 6) + """ + verdict_icon = {"优质": "✅", "一般": "〰️", "较差": "⚠️"}.get(result.quality_verdict, "🔍") + header_html = ( + f'
' + f'' + f'🔍 {result.query}' + f'  {verdict_icon} {result.quality_verdict}' + f' · 完美匹配 {result.perfect_count} 件' + f' · 相关 {result.partial_count} 件' + f'
' + ) + st.markdown(header_html, unsafe_allow_html=True) + + # Perfect matches first, fall back to partials if none + perfect = [p for p in result.products if p.match_label == "完美匹配"] + partial = [p for p in result.products if p.match_label == "部分匹配"] + to_show = (perfect + partial)[:6] if perfect else partial[:6] + + if not to_show: + st.caption("(本次搜索未找到可展示的商品)") + return + + cols = st.columns(min(len(to_show), 3)) + for i, product in enumerate(to_show): + with cols[i % 3]: + display_product_card_from_item(product) + + +def render_message_with_refs(content: str, session_id: str) -> None: + """ + Render an assistant message that may contain [SEARCH_REF:xxx] tokens. + + Text segments are rendered as markdown. + [SEARCH_REF:xxx] tokens are replaced with full product card blocks + loaded from the global registry. + """ + # re.split with a capture group alternates: [text, ref_id, text, ref_id, ...] + parts = SEARCH_REF_PATTERN.split(content) + + for i, segment in enumerate(parts): + if i % 2 == 0: + # Text segment + text = segment.strip() + if text: + st.markdown(text) else: - logger.warning(f"Image not found: {image_path}") - - # Fallback: no image - st.markdown(f"**📷 {name}**") - if product_id: - st.caption(f"ID: {product_id}") - else: - st.caption("ID not available") + # ref_id segment + ref_id = segment.strip() + result = global_registry.get(session_id, ref_id) + if result: + render_search_result_block(result) + else: + # ref not found (e.g. old session after restart) + st.caption(f"[搜索结果 {ref_id} 不可用]") def display_message(message: dict): @@ -412,13 +411,13 @@ def display_message(message: dict): st.markdown("", unsafe_allow_html=True) else: # assistant - # Display tool calls horizontally - only tool names + # Tool call breadcrumb if tool_calls: tool_names = [tc["name"] for tc in tool_calls] st.caption(" → ".join(tool_names)) st.markdown("") - # Optional: detailed debug panel (reasoning + tool details) + # Debug panel if debug_steps and st.session_state.get("show_debug"): with st.expander("思考 & 工具调用详细过程", expanded=False): for idx, step in enumerate(debug_steps, 1): @@ -430,9 +429,7 @@ def display_message(message: dict): if msgs: st.markdown("**Agent Messages**") for m in msgs: - role = m.get("type", "assistant") - content = m.get("content", "") - st.markdown(f"- `{role}`: {content}") + st.markdown(f"- `{m.get('type', 'assistant')}`: {m.get('content', '')}") tcs = step.get("tool_calls", []) if tcs: @@ -450,65 +447,10 @@ def display_message(message: dict): st.code(r.get("content", ""), language="text") st.markdown("---") - - # Extract and display products if any - products = extract_products_from_response(content) - - # Debug logging - logger.info(f"Extracted {len(products)} products from response") - for p in products: - logger.info(f"Product: {p}") - - if products: - def parse_score(product: dict) -> float: - score = product.get("score") - if score is None: - return 0.0 - try: - return float(score) - except (TypeError, ValueError): - return 0.0 - - # Sort by score and limit to 3 - products = sorted(products, key=parse_score, reverse=True)[:3] - - logger.info(f"Displaying top {len(products)} products") - - # Display the text response first (without product details) - text_lines = [] - for line in content.split("\n"): - # Skip product detail lines - if not any( - keyword in line - for keyword in [ - "ID:", - "Category:", - "Color:", - "Gender:", - "Season:", - "Usage:", - "Similarity:", - "Relevance:", - ] - ): - if not re.match(r"^\*?\*?\d+\.\s+", line): - text_lines.append(line) - - intro_text = "\n".join(text_lines).strip() - if intro_text: - st.markdown(intro_text) - - # Display product cards in grid - st.markdown("
", unsafe_allow_html=True) - - # Create exactly 3 columns with equal width - cols = st.columns(3) - for j, product in enumerate(products[:9]): # Ensure max 3 - with cols[j]: - display_product_card(product) - else: - # No products found, display full content - st.markdown(content) + + # Render message: expand [SEARCH_REF:xxx] tokens into product card blocks + session_id = st.session_state.get("session_id", "") + render_message_with_refs(content, session_id) st.markdown("", unsafe_allow_html=True) @@ -591,6 +533,10 @@ def main(): if st.button("🗑️ Clear Chat", use_container_width=True): if "shopping_agent" in st.session_state: st.session_state.shopping_agent.clear_history() + # Clear search result registry for this session + session_id = st.session_state.get("session_id", "") + if session_id: + global_registry.clear_session(session_id) st.session_state.messages = [] st.session_state.uploaded_image = None st.rerun() @@ -600,6 +546,7 @@ def main(): st.checkbox( "显示调试过程 (debug)", key="show_debug", + value=True, help="展开后可查看中间思考过程及工具调用详情", ) @@ -713,26 +660,16 @@ def main(): try: shopping_agent = st.session_state.shopping_agent - # Handle greetings + # Handle greetings without invoking the agent query_lower = user_query.lower().strip() - if query_lower in ["hi", "hello", "hey"]: - response = """Hello! 👋 I'm your fashion shopping assistant. - -I can help you: -- Search for products by description -- Find items similar to images you upload -- Analyze product styles - -What are you looking for today?""" - tool_calls = [] - else: - # Process with agent - result = shopping_agent.chat( - query=user_query, - image_path=image_path, - ) - response = result["response"] - tool_calls = result.get("tool_calls", []) + # Process with agent + result = shopping_agent.chat( + query=user_query, + image_path=image_path, + ) + response = result["response"] + tool_calls = result.get("tool_calls", []) + debug_steps = result.get("debug_steps", []) # Add assistant message st.session_state.messages.append( @@ -740,7 +677,7 @@ What are you looking for today?""" "role": "assistant", "content": response, "tool_calls": tool_calls, - "debug_steps": result.get("debug_steps", []), + "debug_steps": debug_steps, } ) diff --git a/app/agents/shopping_agent.py b/app/agents/shopping_agent.py index e2a6963..76d4414 100644 --- a/app/agents/shopping_agent.py +++ b/app/agents/shopping_agent.py @@ -1,6 +1,11 @@ """ Conversational Shopping Agent with LangGraph -True ReAct agent with autonomous tool calling and message accumulation + +Architecture: +- ReAct-style agent: plan → search → evaluate → re-plan or respond +- search_products is session-bound, writing curated results to SearchResultRegistry +- Final AI message references results via [SEARCH_REF:xxx] tokens instead of + re-listing product details; the UI renders product cards from the registry """ import logging @@ -16,14 +21,52 @@ from langgraph.prebuilt import ToolNode from typing_extensions import Annotated, TypedDict from app.config import settings +from app.search_registry import global_registry from app.tools.search_tools import get_all_tools logger = logging.getLogger(__name__) +# ── System prompt ────────────────────────────────────────────────────────────── +# Universal: works for any e-commerce vertical (fashion, electronics, home, etc.) +# Key design decisions: +# 1. Guides multi-query search planning with explicit evaluate-and-decide loop +# 2. Forbids re-listing product details in the final response +# 3. Mandates [SEARCH_REF:xxx] inline citation as the only product presentation mechanism +SYSTEM_PROMPT = """ +角色定义 +你是一名专业的服装电商导购,是一个善于倾听、主动引导、懂得搭配的“时尚顾问”,通过有温度的对话,给用户提供有价值的信息,包括需求引导、方案推荐、搜索结果推荐,最终促成满意的购物决策或转化行为。 + +一些原则: +1. 你是一个真人导购,是一个贴心、专业的销售,保持灵活,根据上下文,基于常识灵活的切换策略,在合适的上下文询问合适的问题、给出有价值的方案和搜索结果的呈现。 +2. 兼顾推荐与信息收集:适时的提供有价值的信息,如商品推荐、穿搭建议、趋势信息,在推荐方向上有需求缺口、需要明确的重要信息时,要适时的做“信息收集”,引导式的帮助用户更清晰的呈现需求、提高商品发现的效率,形成“提供-反馈”的良性循环。 + 1. 在意图不明时,主动通过1-2个关键问题(如品类、场景、风格、预算)进行引导,并提供初步方向。 + 2. 在了解到初步意向后,要进行相关商品的搜索、进行搜索结果的呈现,同时思考该方向下重要的决策因素,进行提议和问题收集,让用户既得到相关信息、又得到下一步的方向引导、同时也有机会修正或者细化诉求。 + 3. 对于复杂需求时,要能基于上下文,将导购任务进行合理拆解。 +3. 引导或者收集需求时,需要站在用户立场,比如询问用户期待的效果或感觉、使用的场合、偏好的风格等用户立场需,而不是询问具体的款式或参数,你需要将用户立场的需求理解/翻译/转化为具体的搜索计划,最后筛选产品、结合需求+结果特性组织推荐理由、呈现方案。 +4. 如何使用search_products:在需要搜索商品的时候,可以将需求分解为 2-4 个搜索查询,每个 query 聚焦一个明确的商品子类或搜索角度。每次调用 search_products 后,工具会返回以下内容,你需要决策是否要调整搜索策略,比如结果质量太差,可能需要调整搜索词、或者加大试探的query数量(不要超过3-5个)。可以进行多轮搜索,但是要适时的总结和反馈信息避免用户等待过长时间: + - 各层级数量:完美匹配 / 部分匹配 / 不相关 的条数 + - 整体质量判断:优质 / 一般 / 较差 + - 简短质量说明 + - 结果引用标识:[SEARCH_REF:xxx] +5. 撰写最终回复的时候,使用 [SEARCH_REF:xxx] 内联引用 + 1. 用自然流畅的语言组织回复,将 [SEARCH_REF:xxx] 嵌入叙述中 + 2. 系统会自动在 [SEARCH_REF:xxx] 位置渲染对应的商品卡片列表 + 3. 禁止在回复文本中列出商品名称、ID、价格、分类、规格等字段 + 4. 禁止用编号列表逐条复述搜索结果中的商品 +""" + + +# ── Agent state ──────────────────────────────────────────────────────────────── + +class AgentState(TypedDict): + messages: Annotated[Sequence[BaseMessage], add_messages] + current_image_path: Optional[str] + + +# ── Helper ───────────────────────────────────────────────────────────────────── def _extract_message_text(msg) -> str: - """Extract text from message content. - LangChain 1.0: content may be str or content_blocks (list) for multimodal.""" + """Extract plain text from a LangChain message (handles str or content_blocks).""" content = getattr(msg, "content", "") if isinstance(content, str): return content @@ -31,27 +74,21 @@ def _extract_message_text(msg) -> str: parts = [] for block in content: if isinstance(block, dict): - parts.append(block.get("text", block.get("content", ""))) + parts.append(block.get("text") or block.get("content") or "") else: parts.append(str(block)) return "".join(str(p) for p in parts) return str(content) if content else "" -class AgentState(TypedDict): - """State for the shopping agent with message accumulation""" - - messages: Annotated[Sequence[BaseMessage], add_messages] - current_image_path: Optional[str] # Track uploaded image +# ── Agent class ──────────────────────────────────────────────────────────────── -print("settings") class ShoppingAgent: - """True ReAct agent with autonomous decision making""" + """ReAct shopping agent with search-evaluate-decide loop and registry-based result referencing.""" def __init__(self, session_id: Optional[str] = None): self.session_id = session_id or "default" - # Initialize LLM llm_kwargs = dict( model=settings.openai_model, temperature=settings.openai_temperature, @@ -59,261 +96,173 @@ class ShoppingAgent: ) if settings.openai_api_base_url: llm_kwargs["base_url"] = settings.openai_api_base_url - - print("llm_kwargs") - print(llm_kwargs) self.llm = ChatOpenAI(**llm_kwargs) - # Get tools and bind to model - self.tools = get_all_tools() + # Tools are session-bound so search_products writes to the right registry partition + self.tools = get_all_tools(session_id=self.session_id, registry=global_registry) self.llm_with_tools = self.llm.bind_tools(self.tools) - # Build graph self.graph = self._build_graph() - - logger.info(f"Shopping agent initialized for session: {self.session_id}") + logger.info(f"ShoppingAgent ready — session={self.session_id}") def _build_graph(self): - """Build the LangGraph StateGraph""" - - # System prompt for the agent - system_prompt = """你是一位智能时尚购物助手,你可以: -1. 根据文字描述搜索商品(使用 search_products) -2. 分析图片风格和属性(使用 analyze_image_style) - -当用户咨询商品时: -- 文字提问:直接使用 search_products 搜索 -- 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索 -- 可按需连续调用多个工具 -- 始终保持有用、友好的回复风格 - -关键格式规则: -展示商品结果时,每个商品必须严格按以下格式输出: - -1. [标题 title] - ID: [商品ID] - 分类: [category_path] - 中文名: [title_cn](如有) - 标签: [tags](如有) - -示例: -1. Puma Men White 3/4 Length Pants - ID: 12345 - 分类: 服饰 > 裤装 > 运动裤 - 中文名: 彪马男士白色九分运动裤 - 标签: 运动,夏季,白色 - -不可省略 ID 字段!它是展示商品图片的关键。 -介绍要口语化,但必须保持上述商品格式。""" - def agent_node(state: AgentState): - """Agent decision node - decides which tools to call or when to respond""" messages = state["messages"] - - # Add system prompt if first message if not any(isinstance(m, SystemMessage) for m in messages): - messages = [SystemMessage(content=system_prompt)] + list(messages) - - # Handle image context - if state.get("current_image_path"): - # Inject image path context for tool calls - # The agent can reference this in its reasoning - pass - - # Invoke LLM with tools + messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) response = self.llm_with_tools.invoke(messages) return {"messages": [response]} - # Create tool node - tool_node = ToolNode(self.tools) - def should_continue(state: AgentState): - """Determine if agent should continue or end""" - messages = state["messages"] - last_message = messages[-1] - - # If LLM made tool calls, continue to tools - if hasattr(last_message, "tool_calls") and last_message.tool_calls: + last = state["messages"][-1] + if hasattr(last, "tool_calls") and last.tool_calls: return "tools" - # Otherwise, end (agent has final response) return END - # Build graph - workflow = StateGraph(AgentState) + tool_node = ToolNode(self.tools) + workflow = StateGraph(AgentState) workflow.add_node("agent", agent_node) workflow.add_node("tools", tool_node) - workflow.add_edge(START, "agent") workflow.add_conditional_edges("agent", should_continue, ["tools", END]) workflow.add_edge("tools", "agent") - # Compile with memory - checkpointer = MemorySaver() - return workflow.compile(checkpointer=checkpointer) + return workflow.compile(checkpointer=MemorySaver()) def chat(self, query: str, image_path: Optional[str] = None) -> dict: - """Process user query with the agent - - Args: - query: User's text query - image_path: Optional path to uploaded image + """ + Process a user query and return the agent response with metadata. Returns: - Dict with response and metadata, including: - - tool_calls: list of tool calls with args and (truncated) results - - debug_steps: detailed intermediate reasoning & tool execution steps + dict with keys: + response – final AI message text (may contain [SEARCH_REF:xxx] tokens) + tool_calls – list of {name, args, result_preview} + debug_steps – detailed per-node step log + search_refs – dict[ref_id → SearchResult] for all searches this turn + error – bool """ try: - logger.info( - f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})" - ) + logger.info(f"[{self.session_id}] chat: {query!r} image={bool(image_path)}") - # Validate image if image_path and not Path(image_path).exists(): return { - "response": f"Error: Image file not found at '{image_path}'", + "response": f"错误:图片文件不存在:{image_path}", "error": True, } - # Build input message + # Snapshot registry before the turn so we can report new additions + registry_before = set(global_registry.get_all(self.session_id).keys()) + message_content = query if image_path: - message_content = f"{query}\n[User uploaded image: {image_path}]" + message_content = f"{query}\n[用户上传了图片:{image_path}]" - # Invoke agent config = {"configurable": {"thread_id": self.session_id}} input_state = { "messages": [HumanMessage(content=message_content)], "current_image_path": image_path, } - # Track tool calls (high-level) and detailed debug steps - tool_calls = [] - debug_steps = [] - - # Stream events to capture tool calls and intermediate reasoning + tool_calls: list[dict] = [] + debug_steps: list[dict] = [] + for event in self.graph.stream(input_state, config=config): - logger.info(f"Event: {event}") + logger.debug(f"[{self.session_id}] event keys: {list(event.keys())}") - # Agent node: LLM reasoning & tool decisions if "agent" in event: - agent_output = event["agent"] - messages = agent_output.get("messages", []) + agent_out = event["agent"] + step_msgs: list[dict] = [] + step_tcs: list[dict] = [] - step_messages = [] - step_tool_calls = [] - - for msg in messages: - msg_text = _extract_message_text(msg) - msg_entry = { + for msg in agent_out.get("messages", []): + text = _extract_message_text(msg) + step_msgs.append({ "type": getattr(msg, "type", "assistant"), - "content": msg_text[:500], # truncate for safety - } - step_messages.append(msg_entry) - - # Capture tool calls from this agent message + "content": text[:500], + }) if hasattr(msg, "tool_calls") and msg.tool_calls: for tc in msg.tool_calls: - tc_entry = { - "name": tc.get("name"), - "args": tc.get("args", {}), - } - tool_calls.append(tc_entry) - step_tool_calls.append(tc_entry) - - debug_steps.append( - { - "node": "agent", - "messages": step_messages, - "tool_calls": step_tool_calls, - } - ) - - # Tool node: actual tool execution results - if "tools" in event: - tools_output = event["tools"] - messages = tools_output.get("messages", []) - - step_tool_results = [] + entry = {"name": tc.get("name"), "args": tc.get("args", {})} + tool_calls.append(entry) + step_tcs.append(entry) - for i, msg in enumerate(messages): - content_text = _extract_message_text(msg) - result_preview = content_text[:500] + ("..." if len(content_text) > 500 else "") + debug_steps.append({"node": "agent", "messages": step_msgs, "tool_calls": step_tcs}) - if i < len(tool_calls): - tool_calls[i]["result"] = result_preview + if "tools" in event: + tools_out = event["tools"] + step_results: list[dict] = [] + msgs = tools_out.get("messages", []) - step_tool_results.append( - { - "content": result_preview, - } - ) + # Match results back to tool_calls by position within this event + unresolved = [tc for tc in tool_calls if "result" not in tc] + for i, msg in enumerate(msgs): + text = _extract_message_text(msg) + preview = text[:600] + ("…" if len(text) > 600 else "") + if i < len(unresolved): + unresolved[i]["result"] = preview + step_results.append({"content": preview}) - debug_steps.append( - { - "node": "tools", - "results": step_tool_results, - } - ) + debug_steps.append({"node": "tools", "results": step_results}) - # Get final state final_state = self.graph.get_state(config) - final_message = final_state.values["messages"][-1] - response_text = _extract_message_text(final_message) + final_msg = final_state.values["messages"][-1] + response_text = _extract_message_text(final_msg) + + # Collect new SearchResults added during this turn + registry_after = global_registry.get_all(self.session_id) + new_refs = { + ref_id: result + for ref_id, result in registry_after.items() + if ref_id not in registry_before + } - logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls") + logger.info( + f"[{self.session_id}] done — tool_calls={len(tool_calls)}, new_refs={list(new_refs.keys())}" + ) return { "response": response_text, "tool_calls": tool_calls, "debug_steps": debug_steps, + "search_refs": new_refs, "error": False, } except Exception as e: - logger.error(f"Error in agent chat: {e}", exc_info=True) + logger.error(f"[{self.session_id}] chat error: {e}", exc_info=True) return { - "response": f"I apologize, I encountered an error: {str(e)}", + "response": f"抱歉,处理您的请求时遇到错误:{e}", + "tool_calls": [], + "debug_steps": [], + "search_refs": {}, "error": True, } def get_conversation_history(self) -> list: - """Get conversation history for this session""" try: config = {"configurable": {"thread_id": self.session_id}} state = self.graph.get_state(config) - if not state or not state.values.get("messages"): return [] - messages = state.values["messages"] result = [] - - for msg in messages: - # Skip system messages and tool messages + for msg in state.values["messages"]: if isinstance(msg, SystemMessage): continue - if hasattr(msg, "type") and msg.type in ["system", "tool"]: + if getattr(msg, "type", None) in ("system", "tool"): continue - role = "user" if msg.type == "human" else "assistant" result.append({"role": role, "content": _extract_message_text(msg)}) - return result - except Exception as e: - logger.error(f"Error getting history: {e}") + logger.error(f"get_conversation_history error: {e}") return [] def clear_history(self): - """Clear conversation history for this session""" - # With MemorySaver, we can't easily clear, but we can log - logger.info(f"[{self.session_id}] History clear requested") - # In production, implement proper clearing or use new thread_id + logger.info(f"[{self.session_id}] clear requested (use new session_id to fully reset)") def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent: - """Factory function to create a shopping agent""" return ShoppingAgent(session_id=session_id) diff --git a/app/search_registry.py b/app/search_registry.py new file mode 100644 index 0000000..861be0e --- /dev/null +++ b/app/search_registry.py @@ -0,0 +1,100 @@ +""" +Search Result Registry + +Stores structured search results keyed by session and ref_id. +Each [SEARCH_REF:xxx] in an AI response maps to a SearchResult stored here, +allowing the UI to render product cards without the LLM ever re-listing them. +""" + +import uuid +from dataclasses import dataclass, field +from typing import Optional + + +def new_ref_id() -> str: + """Generate a short unique search reference ID, e.g. 'sr_3f9a1b2c'.""" + return "sr_" + uuid.uuid4().hex[:8] + + +@dataclass +class ProductItem: + """A single product extracted from a search result, enriched with a match label.""" + + spu_id: str + title: str + price: Optional[float] = None + category_path: Optional[str] = None + vendor: Optional[str] = None + image_url: Optional[str] = None + relevance_score: Optional[float] = None + # LLM-assigned label: "完美匹配" | "部分匹配" | "不相关" + match_label: str = "部分匹配" + tags: list = field(default_factory=list) + specifications: list = field(default_factory=list) + + +@dataclass +class SearchResult: + """ + A complete, self-contained search result block. + + Identified by ref_id (e.g. 'sr_3f9a1b2c'). + Stores the query, LLM quality assessment, and the curated product list + (only "完美匹配" and "部分匹配" items — "不相关" are discarded). + """ + + ref_id: str + query: str + + # Raw API stats + total_api_hits: int # total documents matched by the search engine + returned_count: int # number of results we actually assessed + + # LLM quality labels breakdown + perfect_count: int + partial_count: int + irrelevant_count: int + + # LLM overall quality verdict + quality_verdict: str # "优质" | "一般" | "较差" + quality_summary: str # one-sentence LLM explanation + + # Curated product list (perfect + partial only) + products: list # list[ProductItem] + + +class SearchResultRegistry: + """ + Session-scoped store: session_id → { ref_id → SearchResult }. + + Lives as a global singleton in the process; Streamlit reruns preserve it + as long as the worker process is alive. Session isolation is maintained + by keying on session_id. + """ + + def __init__(self) -> None: + self._store: dict[str, dict[str, SearchResult]] = {} + + def register(self, session_id: str, result: SearchResult) -> str: + """Store a SearchResult and return its ref_id.""" + if session_id not in self._store: + self._store[session_id] = {} + self._store[session_id][result.ref_id] = result + return result.ref_id + + def get(self, session_id: str, ref_id: str) -> Optional[SearchResult]: + """Look up a single SearchResult by session and ref_id.""" + return self._store.get(session_id, {}).get(ref_id) + + def get_all(self, session_id: str) -> dict: + """Return all SearchResults for a session (ref_id → SearchResult).""" + return dict(self._store.get(session_id, {})) + + def clear_session(self, session_id: str) -> None: + """Remove all search results for a session (e.g. on chat clear).""" + self._store.pop(session_id, None) + + +# ── Global singleton ────────────────────────────────────────────────────────── +# Imported by search_tools and app.py; both sides share the same object. +global_registry = SearchResultRegistry() diff --git a/app/tools/__init__.py b/app/tools/__init__.py index ccf1a57..8706ef7 100644 --- a/app/tools/__init__.py +++ b/app/tools/__init__.py @@ -1,15 +1,20 @@ """ LangChain Tools for Product Search and Discovery + +search_products is created per-session via make_search_products_tool(). +Use get_all_tools(session_id, registry) for the full tool list. """ from app.tools.search_tools import ( analyze_image_style, get_all_tools, - search_products, + make_search_products_tool, + web_search, ) __all__ = [ - "search_products", + "make_search_products_tool", "analyze_image_style", + "web_search", "get_all_tools", ] diff --git a/app/tools/search_tools.py b/app/tools/search_tools.py index f2b77e1..56e889e 100644 --- a/app/tools/search_tools.py +++ b/app/tools/search_tools.py @@ -1,9 +1,20 @@ """ Search Tools for Product Discovery -Provides text-based search via Search API, web search, and VLM style analysis + +Key design: +- search_products is created via a factory (make_search_products_tool) that + closes over (session_id, registry), so each agent session has its own tool + instance pointing to the shared registry. +- After calling the search API, an LLM quality-assessment step labels every + result as 完美匹配 / 部分匹配 / 不相关 and produces an overall verdict. +- The curated product list is stored in the registry under a unique ref_id. +- The tool returns ONLY the quality summary + [SEARCH_REF:ref_id], never the + raw product list. The LLM references the result in its final response via + the [SEARCH_REF:...] token; the UI renders the product cards from the registry. """ import base64 +import json import logging import os from pathlib import Path @@ -14,6 +25,13 @@ from langchain_core.tools import tool from openai import OpenAI from app.config import settings +from app.search_registry import ( + ProductItem, + SearchResult, + SearchResultRegistry, + global_registry, + new_ref_id, +) logger = logging.getLogger(__name__) @@ -30,31 +48,264 @@ def get_openai_client() -> OpenAI: return _openai_client +# ── LLM quality assessment ───────────────────────────────────────────────────── + +def _assess_search_quality( + query: str, + raw_products: list, +) -> tuple[list[str], str, str]: + """ + Ask the LLM to evaluate how well each search result matches the query. + + Returns: + labels – list[str], one per product: "完美匹配" | "部分匹配" | "不相关" + verdict – str: "优质" | "一般" | "较差" + summary – str: one-sentence explanation + """ + n = len(raw_products) + if n == 0: + return [], "较差", "搜索未返回任何商品。" + + # Build a compact product list — only title/category/tags/score to save tokens + lines: list[str] = [] + for i, p in enumerate(raw_products, 1): + title = (p.get("title") or "")[:60] + cat = p.get("category_path") or p.get("category_name") or "" + tags_raw = p.get("tags") or [] + tags = ", ".join(str(t) for t in tags_raw[:5]) + score = p.get("relevance_score") or 0 + row = f"{i}. [{score:.1f}] {title} | {cat}" + if tags: + row += f" | 标签:{tags}" + lines.append(row) + + product_text = "\n".join(lines) + + prompt = f"""你是商品搜索质量评估专家。请评估以下搜索结果与用户查询的匹配程度。 + +用户查询:{query} + +搜索结果(共 {n} 条,格式:序号. [相关性分数] 标题 | 分类 | 标签): +{product_text} + +评估说明: +- 完美匹配:完全符合用户查询意图,用户必然感兴趣 +- 部分匹配:与查询有关联,但不完全满足意图(如品类对但风格偏差、相关配件等) +- 不相关:与查询无关,不应展示给用户 + +整体 verdict 判断标准: +- 优质:完美匹配 ≥ 5 条 +- 一般:完美匹配 2-4 条 +- 较差:完美匹配 < 2 条 + +请严格按以下 JSON 格式输出,不得有任何额外文字或代码块标记: +{{"labels": ["完美匹配", "部分匹配", "不相关", ...], "verdict": "优质", "summary": "一句话评价搜索质量"}} + +labels 数组长度必须恰好等于 {n}。""" + + try: + client = get_openai_client() + resp = client.chat.completions.create( + model=settings.openai_model, + messages=[{"role": "user", "content": prompt}], + max_tokens=800, + temperature=0.1, + ) + raw = resp.choices[0].message.content.strip() + # Strip markdown code fences if the model adds them + if raw.startswith("```"): + raw = raw.split("```")[1] + if raw.startswith("json"): + raw = raw[4:] + raw = raw.strip() + + data = json.loads(raw) + labels: list[str] = data.get("labels", []) + + # Normalize and pad / trim to match n + valid = {"完美匹配", "部分匹配", "不相关"} + labels = [l if l in valid else "部分匹配" for l in labels] + while len(labels) < n: + labels.append("部分匹配") + labels = labels[:n] + + verdict: str = data.get("verdict", "一般") + if verdict not in ("优质", "一般", "较差"): + verdict = "一般" + summary: str = str(data.get("summary", "")) + return labels, verdict, summary + + except Exception as e: + logger.warning(f"Quality assessment LLM call failed: {e}; using fallback labels.") + return ["部分匹配"] * n, "一般", "质量评估步骤失败,结果仅供参考。" + + +# ── Tool factory ─────────────────────────────────────────────────────────────── + +def make_search_products_tool( + session_id: str, + registry: SearchResultRegistry, +): + """ + Return a search_products tool bound to a specific session and registry. + + The tool: + 1. Calls the product search API. + 2. Runs LLM quality assessment on up to 20 results. + 3. Stores a SearchResult in the registry. + 4. Returns a concise quality summary + [SEARCH_REF:ref_id]. + The product list is NEVER returned in the tool output text. + """ + + @tool + def search_products(query: str, limit: int = 20) -> str: + """搜索商品库,根据自然语言描述找到匹配商品,并进行质量评估。 + + 每次调用专注于单一搜索角度。复杂需求请拆分为多次调用,每次换一个 query。 + 工具会自动评估结果质量(完美匹配 / 部分匹配 / 不相关),并给出整体判断。 + + Args: + query: 自然语言商品描述,例如"男士休闲亚麻短裤夏季" + limit: 最多返回条数(建议 10-20,越多评估越全面) + + Returns: + 质量评估摘要 + [SEARCH_REF:ref_id],供最终回复引用。 + """ + try: + logger.info(f"[{session_id}] search_products: query={query!r} limit={limit}") + + url = f"{settings.search_api_base_url.rstrip('/')}/search/" + headers = { + "Content-Type": "application/json", + "X-Tenant-ID": settings.search_api_tenant_id, + } + payload = { + "query": query, + "size": min(max(limit, 1), 20), + "from": 0, + "language": "zh", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=60) + if resp.status_code != 200: + logger.error(f"Search API error {resp.status_code}: {resp.text[:300]}") + return f"搜索失败:API 返回状态码 {resp.status_code},请稍后重试。" + + data = resp.json() + raw_results: list = data.get("results", []) + total_hits: int = data.get("total", 0) + + if not raw_results: + return ( + f"【搜索完成】query='{query}'\n" + "未找到匹配商品,建议换用更宽泛或不同角度的关键词重新搜索。" + ) + + # ── LLM quality assessment ────────────────────────────────────── + labels, verdict, quality_summary = _assess_search_quality(query, raw_results) + + # ── Build ProductItem list (keep perfect + partial, discard irrelevant) ── + products: list[ProductItem] = [] + perfect_count = partial_count = irrelevant_count = 0 + + for raw, label in zip(raw_results, labels): + if label == "完美匹配": + perfect_count += 1 + elif label == "部分匹配": + partial_count += 1 + else: + irrelevant_count += 1 + + if label in ("完美匹配", "部分匹配"): + products.append( + ProductItem( + spu_id=str(raw.get("spu_id", "")), + title=raw.get("title") or "", + price=raw.get("price"), + category_path=( + raw.get("category_path") or raw.get("category_name") + ), + vendor=raw.get("vendor"), + image_url=raw.get("image_url"), + relevance_score=raw.get("relevance_score"), + match_label=label, + tags=raw.get("tags") or [], + specifications=raw.get("specifications") or [], + ) + ) + + # ── Register ──────────────────────────────────────────────────── + ref_id = new_ref_id() + result = SearchResult( + ref_id=ref_id, + query=query, + total_api_hits=total_hits, + returned_count=len(raw_results), + perfect_count=perfect_count, + partial_count=partial_count, + irrelevant_count=irrelevant_count, + quality_verdict=verdict, + quality_summary=quality_summary, + products=products, + ) + registry.register(session_id, result) + logger.info( + f"[{session_id}] Registered {ref_id}: verdict={verdict}, " + f"perfect={perfect_count}, partial={partial_count}, irrel={irrelevant_count}" + ) + + # ── Return summary to agent (NOT the product list) ────────────── + verdict_hint = { + "优质": "结果质量优质,可直接引用。", + "一般": "结果质量一般,可酌情引用,也可补充更精准的 query。", + "较差": "结果质量较差,建议重新规划 query 后再次搜索。", + }.get(verdict, "") + + return ( + f"【搜索完成】query='{query}'\n" + f"API 总命中:{total_hits} 条 | 本次评估:{len(raw_results)} 条\n" + f"质量评估:完美匹配 {perfect_count} 条 | 部分匹配 {partial_count} 条 | 不相关 {irrelevant_count} 条\n" + f"整体判断:{verdict} — {quality_summary}\n" + f"{verdict_hint}\n" + f"结果引用:[SEARCH_REF:{ref_id}]" + ) + + except requests.exceptions.RequestException as e: + logger.error(f"[{session_id}] Search network error: {e}", exc_info=True) + return f"搜索失败(网络错误):{e}" + except Exception as e: + logger.error(f"[{session_id}] Search error: {e}", exc_info=True) + return f"搜索失败:{e}" + + return search_products + + +# ── Standalone tools (no session binding needed) ─────────────────────────────── + @tool def web_search(query: str) -> str: """使用 Tavily 进行通用 Web 搜索,补充外部/实时知识。 - 触发场景(示例): - - 需要**外部知识**:流行趋势、新品信息、穿搭文化、品牌故事等 - - 需要**实时/及时信息**:某地某个时节的天气、当季流行元素、最新联名款 - - 需要**宏观参考**:不同城市/国家的穿衣习惯、节日穿搭建议 + 触发场景: + - 需要**外部知识**:流行趋势、品牌、搭配文化、节日习俗等 + - 需要**实时/及时信息**:当季流行元素、某地未来的天气 + - 需要**宏观参考**:不同场合/国家的穿着建议、选购攻略 Args: - query: 要搜索的问题,自然语言描述(建议用中文) + query: 要搜索的问题,自然语言描述 Returns: - 总结后的回答 + 若干来源链接,供模型继续推理使用。 + 总结后的回答 + 若干参考来源链接 """ try: api_key = os.getenv("TAVILY_API_KEY") if not api_key: - logger.error("TAVILY_API_KEY is not set in environment variables") return ( "无法调用外部 Web 搜索:未检测到 TAVILY_API_KEY 环境变量。\n" "请在运行环境中配置 TAVILY_API_KEY 后再重试。" ) - logger.info(f"Calling Tavily web search with query: {query!r}") + logger.info(f"web_search: {query!r}") url = "https://api.tavily.com/search" headers = { @@ -66,15 +317,9 @@ def web_search(query: str) -> str: "search_depth": "advanced", "include_answer": True, } - response = requests.post(url, json=payload, headers=headers, timeout=60) if response.status_code != 200: - logger.error( - "Tavily API error: %s - %s", - response.status_code, - response.text, - ) return f"调用外部 Web 搜索失败:Tavily 返回状态码 {response.status_code}" data = response.json() @@ -87,140 +332,61 @@ def web_search(query: str) -> str: "回答摘要:", answer.strip(), ] - if results: output_lines.append("") output_lines.append("参考来源(部分):") for idx, item in enumerate(results[:5], 1): title = item.get("title") or "无标题" - url = item.get("url") or "" + link = item.get("url") or "" output_lines.append(f"{idx}. {title}") - if url: - output_lines.append(f" 链接: {url}") + if link: + output_lines.append(f" 链接: {link}") return "\n".join(output_lines).strip() except requests.exceptions.RequestException as e: - logger.error("Error calling Tavily web search (network): %s", e, exc_info=True) + logger.error("web_search network error: %s", e, exc_info=True) return f"调用外部 Web 搜索失败(网络错误):{e}" except Exception as e: - logger.error("Error calling Tavily web search: %s", e, exc_info=True) + logger.error("web_search error: %s", e, exc_info=True) return f"调用外部 Web 搜索失败:{e}" @tool -def search_products(query: str, limit: int = 5) -> str: - """Search for fashion products using natural language descriptions. - - Use when users describe what they want: - - "Find me red summer dresses" - - "Show me blue running shoes" - - "I want casual shirts for men" - - Args: - query: Natural language product description - limit: Maximum number of results (1-20) - - Returns: - Formatted string with product information - """ - try: - logger.info(f"Searching products: '{query}', limit: {limit}") - - url = f"{settings.search_api_base_url.rstrip('/')}/search/" - headers = { - "Content-Type": "application/json", - "X-Tenant-ID": settings.search_api_tenant_id, - } - payload = { - "query": query, - "size": min(limit, 20), - "from": 0, - "language": "zh", - } - - response = requests.post(url, json=payload, headers=headers, timeout=60) - - if response.status_code != 200: - logger.error(f"Search API error: {response.status_code} - {response.text}") - return f"Error searching products: API returned {response.status_code}" - - data = response.json() - results = data.get("results", []) - - if not results: - return "No products found matching your search." - - output = f"Found {len(results)} product(s):\n\n" - - for idx, product in enumerate(results, 1): - output += f"{idx}. {product.get('title', 'Unknown Product')}\n" - output += f" ID: {product.get('spu_id', 'N/A')}\n" - output += f" Category: {product.get('category_path', product.get('category_name', 'N/A'))}\n" - if product.get("vendor"): - output += f" Brand: {product.get('vendor')}\n" - if product.get("price") is not None: - output += f" Price: {product.get('price')}\n" - - # 规格/颜色信息 - specs = product.get("specifications", []) - if specs: - color_spec = next( - (s for s in specs if s.get("name").lower() == "color"), - None, - ) - if color_spec: - output += f" Color: {color_spec.get('value', 'N/A')}\n" - - output += "\n" - - return output.strip() - - except requests.exceptions.RequestException as e: - logger.error(f"Error searching products (network): {e}", exc_info=True) - return f"Error searching products: {str(e)}" - except Exception as e: - logger.error(f"Error searching products: {e}", exc_info=True) - return f"Error searching products: {str(e)}" - - -@tool def analyze_image_style(image_path: str) -> str: - """Analyze a fashion product image using AI vision to extract detailed style information. + """分析用户上传的商品图片,提取视觉风格属性,用于后续商品搜索。 - Use when you need to understand style/attributes from an image: - - Understand the style, color, pattern of a product - - Extract attributes like "casual", "formal", "vintage" - - Get detailed descriptions for subsequent searches + 适用场景: + - 用户上传图片,想找相似商品 + - 需要理解图片中商品的风格、颜色、材质等属性 Args: - image_path: Path to the image file + image_path: 图片文件路径 Returns: - Detailed text description of the product's visual attributes + 商品视觉属性的详细文字描述,可直接作为 search_products 的 query """ try: - logger.info(f"Analyzing image with VLM: '{image_path}'") + logger.info(f"analyze_image_style: {image_path!r}") img_path = Path(image_path) if not img_path.exists(): - return f"Error: Image file not found at '{image_path}'" + return f"错误:图片文件不存在:{image_path}" - with open(img_path, "rb") as image_file: - image_data = base64.b64encode(image_file.read()).decode("utf-8") + with open(img_path, "rb") as f: + image_data = base64.b64encode(f.read()).decode("utf-8") - prompt = """Analyze this fashion product image and provide a detailed description. + prompt = """请分析这张商品图片,提供详细的视觉属性描述,用于商品搜索。 -Include: -- Product type (e.g., shirt, dress, shoes, pants, bag) -- Primary colors -- Style/design (e.g., casual, formal, sporty, vintage, modern) -- Pattern or texture (e.g., plain, striped, checked, floral) -- Key features (e.g., collar type, sleeve length, fit) -- Material appearance (if obvious, e.g., denim, cotton, leather) -- Suitable occasion (e.g., office wear, party, casual, sports) +请包含: +- 商品类型(如:连衣裙、运动鞋、双肩包、西装等) +- 主要颜色 +- 风格定位(如:休闲、正式、运动、复古、现代简约等) +- 图案/纹理(如:纯色、条纹、格纹、碎花、几何图案等) +- 关键设计特征(如:领型、袖长、版型、材质外观等) +- 适用场合(如:办公、户外、度假、聚会、运动等) -Provide a comprehensive yet concise description (3-4 sentences).""" +输出格式:3-4句自然语言描述,可直接用作搜索关键词。""" client = get_openai_client() response = client.chat.completions.create( @@ -245,15 +411,29 @@ Provide a comprehensive yet concise description (3-4 sentences).""" ) analysis = response.choices[0].message.content.strip() - logger.info("VLM analysis completed") - + logger.info("Image analysis completed.") return analysis except Exception as e: - logger.error(f"Error analyzing image: {e}", exc_info=True) - return f"Error analyzing image: {str(e)}" + logger.error(f"analyze_image_style error: {e}", exc_info=True) + return f"图片分析失败:{e}" -def get_all_tools(): - """Get all available tools for the agent""" - return [search_products, analyze_image_style, web_search] +# ── Tool list factory ────────────────────────────────────────────────────────── + +def get_all_tools( + session_id: str = "default", + registry: Optional[SearchResultRegistry] = None, +) -> list: + """ + Return all agent tools. + + search_products is session-bound (factory); other tools are stateless. + """ + if registry is None: + registry = global_registry + return [ + make_search_products_tool(session_id, registry), + analyze_image_style, + web_search, + ] -- libgit2 0.21.2