Commit 664426683046f36260ee8703f5631c85e40c0cf2
1 parent
9ad88986
feat: 搜索结果引用与并行搜索、两轮上限
## 搜索结果管理与人机回复引用 - 新增 app/search_registry.py:SearchResultRegistry + SearchResult/ProductItem 数据结构,按 session 存储每次搜索的 query、质量评估与商品列表。 - 搜索工具改为工厂 make_search_products_tool(session_id, registry):每次搜索后由 LLM 对 top20 打标(完美匹配/部分匹配/不相关),产出整体 verdict(优质/一般/较差),仅将「完美+部分」写入 registry 并返回摘要 + [SEARCH_REF:ref_id];不再向 Agent 返回完整商品列表。 - 废除 extract_products_from_response:最终回复中通过内联 [SEARCH_REF:xxx] 引用「搜索结果块」,UI 用 SEARCH_REF_PATTERN 解析后从 registry 取对应 SearchResult 渲染 query 标题 + 商品卡片,避免 LLM 复述商品列表,节省 token 并减少错误。 ## 系统提示与行为约束 - 系统提示词通用化(不绑定时尚品类),明确四步:理解意图 → 规划 2~4 个 query → 执行搜索并评估 → 撰写回复。 - 要求同一条回复中并行发起 2~4 次 search_products(不同 query),利用 LangGraph ToolNode 的并行执行缩短等待;禁止串行「搜一个看一个再搜下一个」。 - 轮次上限:最多两轮搜索(两轮 = 两次「Agent 发 tool_calls → Tools 执行 → 返回」);若已有优质/一般结果则直接写回复,仅当全部较差时允许第二轮(最多再 1~2 个 query)。图逻辑增加 n_tool_rounds 状态与 agent_final 节点,两轮后强制进入「仅回复、不调工具」的 agent_final,避免无限重搜。 ## 前端与工具导出 - app.py:render_message_with_refs(content, session_id) 按 [SEARCH_REF:xxx] 切分并渲染;render_search_result_block 展示 query + 质量 + 商品卡片;display_product_card_from_item 支持 image_url/本地图/占位;Clear Chat 时 clear_session(registry)。 - app/tools/__init__.py:改为导出 make_search_products_tool、web_search,不再导出已移除的 search_products 顶层名。 Co-authored-by: Cursor <cursoragent@cursor.com>
Showing
5 changed files
with
672 additions
and
501 deletions
Show diff stats
| ... | ... | @@ -13,6 +13,11 @@ import streamlit as st |
| 13 | 13 | from PIL import Image, ImageOps |
| 14 | 14 | |
| 15 | 15 | from app.agents.shopping_agent import ShoppingAgent |
| 16 | +from app.search_registry import ProductItem, SearchResult, global_registry | |
| 17 | + | |
| 18 | +# Matches [SEARCH_REF:sr_xxxxxxxx] tokens embedded in AI responses. | |
| 19 | +# Case-insensitive, optional spaces around the id. | |
| 20 | +SEARCH_REF_PATTERN = re.compile(r"\[SEARCH_REF:\s*([a-zA-Z0-9_]+)\s*\]", re.IGNORECASE) | |
| 16 | 21 | |
| 17 | 22 | # Configure logging |
| 18 | 23 | logging.basicConfig( |
| ... | ... | @@ -270,124 +275,118 @@ def save_uploaded_image(uploaded_file) -> Optional[str]: |
| 270 | 275 | return None |
| 271 | 276 | |
| 272 | 277 | |
| 273 | -def extract_products_from_response(response: str) -> list: | |
| 274 | - """Extract product information from agent response | |
| 278 | +def _load_product_image(product: ProductItem) -> Optional[Image.Image]: | |
| 279 | + """Try to load a product image: image_url from API → local data/images → None.""" | |
| 280 | + if product.image_url: | |
| 281 | + try: | |
| 282 | + import requests | |
| 283 | + resp = requests.get(product.image_url, timeout=10) | |
| 284 | + if resp.status_code == 200: | |
| 285 | + import io | |
| 286 | + return Image.open(io.BytesIO(resp.content)) | |
| 287 | + except Exception as e: | |
| 288 | + logger.debug(f"Remote image fetch failed for {product.spu_id}: {e}") | |
| 289 | + | |
| 290 | + local = Path(f"data/images/{product.spu_id}.jpg") | |
| 291 | + if local.exists(): | |
| 292 | + try: | |
| 293 | + return Image.open(local) | |
| 294 | + except Exception as e: | |
| 295 | + logger.debug(f"Local image load failed {local}: {e}") | |
| 296 | + return None | |
| 297 | + | |
| 298 | + | |
| 299 | +def display_product_card_from_item(product: ProductItem) -> None: | |
| 300 | + """Render a single product card from a ProductItem (registry entry).""" | |
| 301 | + img = _load_product_image(product) | |
| 302 | + | |
| 303 | + if img: | |
| 304 | + target = (220, 220) | |
| 305 | + try: | |
| 306 | + img = ImageOps.fit(img, target, method=Image.Resampling.LANCZOS) | |
| 307 | + except AttributeError: | |
| 308 | + img = ImageOps.fit(img, target, method=Image.LANCZOS) | |
| 309 | + st.image(img, use_container_width=True) | |
| 310 | + else: | |
| 311 | + st.markdown( | |
| 312 | + '<div style="height:120px;background:#f5f5f5;border-radius:6px;' | |
| 313 | + 'display:flex;align-items:center;justify-content:center;' | |
| 314 | + 'color:#bbb;font-size:2rem;">🛍️</div>', | |
| 315 | + unsafe_allow_html=True, | |
| 316 | + ) | |
| 317 | + | |
| 318 | + title = product.title or "未知商品" | |
| 319 | + st.markdown(f"**{title[:40]}**" + ("…" if len(title) > 40 else "")) | |
| 320 | + | |
| 321 | + if product.price is not None: | |
| 322 | + st.caption(f"¥{product.price:.2f}") | |
| 323 | + | |
| 324 | + label_style = "⭐" if product.match_label == "完美匹配" else "✦" | |
| 325 | + st.caption(f"{label_style} {product.match_label}") | |
| 275 | 326 | |
| 276 | - Returns list of dicts with product info | |
| 327 | + | |
| 328 | +def render_search_result_block(result: SearchResult) -> None: | |
| 277 | 329 | """ |
| 278 | - products = [] | |
| 279 | - | |
| 280 | - # Pattern to match product blocks in the response | |
| 281 | - # Looking for ID, name, and other details | |
| 282 | - lines = response.split("\n") | |
| 283 | - current_product = {} | |
| 284 | - | |
| 285 | - for line in lines: | |
| 286 | - line = line.strip() | |
| 287 | - | |
| 288 | - # Match product number (e.g., "1. Product Name" or "**1. Product Name**") | |
| 289 | - if re.match(r"^\*?\*?\d+\.\s+", line): | |
| 290 | - if current_product: | |
| 291 | - products.append(current_product) | |
| 292 | - current_product = {} | |
| 293 | - # Extract product name | |
| 294 | - name = re.sub(r"^\*?\*?\d+\.\s+", "", line) | |
| 295 | - name = name.replace("**", "").strip() | |
| 296 | - current_product["name"] = name | |
| 297 | - | |
| 298 | - # Match ID | |
| 299 | - elif "ID:" in line or "id:" in line: | |
| 300 | - id_match = re.search(r"(?:ID|id):\s*(\d+)", line) | |
| 301 | - if id_match: | |
| 302 | - current_product["id"] = id_match.group(1) | |
| 303 | - | |
| 304 | - # Match Category | |
| 305 | - elif "Category:" in line: | |
| 306 | - cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) | |
| 307 | - if cat_match: | |
| 308 | - current_product["category"] = cat_match.group(1).strip() | |
| 309 | - | |
| 310 | - # Match Color | |
| 311 | - elif "Color:" in line: | |
| 312 | - color_match = re.search(r"Color:\s*(\w+)", line) | |
| 313 | - if color_match: | |
| 314 | - current_product["color"] = color_match.group(1) | |
| 315 | - | |
| 316 | - # Match Gender | |
| 317 | - elif "Gender:" in line: | |
| 318 | - gender_match = re.search(r"Gender:\s*(\w+)", line) | |
| 319 | - if gender_match: | |
| 320 | - current_product["gender"] = gender_match.group(1) | |
| 321 | - | |
| 322 | - # Match Season | |
| 323 | - elif "Season:" in line: | |
| 324 | - season_match = re.search(r"Season:\s*(\w+)", line) | |
| 325 | - if season_match: | |
| 326 | - current_product["season"] = season_match.group(1) | |
| 327 | - | |
| 328 | - # Match Usage | |
| 329 | - elif "Usage:" in line: | |
| 330 | - usage_match = re.search(r"Usage:\s*(\w+)", line) | |
| 331 | - if usage_match: | |
| 332 | - current_product["usage"] = usage_match.group(1) | |
| 333 | - | |
| 334 | - # Match Similarity/Relevance score | |
| 335 | - elif "Similarity:" in line or "Relevance:" in line: | |
| 336 | - score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line) | |
| 337 | - if score_match: | |
| 338 | - current_product["score"] = score_match.group(1) | |
| 339 | - | |
| 340 | - # Add last product | |
| 341 | - if current_product: | |
| 342 | - products.append(current_product) | |
| 343 | - | |
| 344 | - return products | |
| 345 | - | |
| 346 | - | |
| 347 | -def display_product_card(product: dict): | |
| 348 | - """Display a product card with image and name""" | |
| 349 | - product_id = product.get("id", "") | |
| 350 | - name = product.get("name", "Unknown Product") | |
| 351 | - | |
| 352 | - # Debug: log what we got | |
| 353 | - logger.info(f"Displaying product: ID={product_id}, Name={name}") | |
| 354 | - | |
| 355 | - # Try to load image from data/images directory | |
| 356 | - if product_id: | |
| 357 | - image_path = Path(f"data/images/{product_id}.jpg") | |
| 358 | - | |
| 359 | - if image_path.exists(): | |
| 360 | - try: | |
| 361 | - img = Image.open(image_path) | |
| 362 | - # Fixed size for all images | |
| 363 | - target_size = (200, 200) | |
| 364 | - try: | |
| 365 | - # Try new Pillow API | |
| 366 | - img_processed = ImageOps.fit( | |
| 367 | - img, target_size, method=Image.Resampling.LANCZOS | |
| 368 | - ) | |
| 369 | - except AttributeError: | |
| 370 | - # Fallback for older Pillow versions | |
| 371 | - img_processed = ImageOps.fit( | |
| 372 | - img, target_size, method=Image.LANCZOS | |
| 373 | - ) | |
| 374 | - | |
| 375 | - # Display image with fixed width | |
| 376 | - st.image(img_processed, use_container_width=False, width=200) | |
| 377 | - st.markdown(f"**{name}**") | |
| 378 | - st.caption(f"ID: {product_id}") | |
| 379 | - return | |
| 380 | - except Exception as e: | |
| 381 | - logger.warning(f"Failed to load image {image_path}: {e}") | |
| 330 | + Render a full search result block in place of a [SEARCH_REF:xxx] token. | |
| 331 | + | |
| 332 | + Shows: | |
| 333 | + - A styled header with query text + quality verdict + match counts | |
| 334 | + - A grid of product cards (perfect matches first, then partial; max 6) | |
| 335 | + """ | |
| 336 | + verdict_icon = {"优质": "✅", "一般": "〰️", "较差": "⚠️"}.get(result.quality_verdict, "🔍") | |
| 337 | + header_html = ( | |
| 338 | + f'<div style="border:1px solid #e0e0e0;border-radius:8px;padding:10px 14px;' | |
| 339 | + f'margin:8px 0 4px 0;background:#fafafa;">' | |
| 340 | + f'<span style="font-size:0.8rem;color:#555;">' | |
| 341 | + f'🔍 <b>{result.query}</b>' | |
| 342 | + f' {verdict_icon} {result.quality_verdict}' | |
| 343 | + f' · 完美匹配 {result.perfect_count} 件' | |
| 344 | + f' · 相关 {result.partial_count} 件' | |
| 345 | + f'</span></div>' | |
| 346 | + ) | |
| 347 | + st.markdown(header_html, unsafe_allow_html=True) | |
| 348 | + | |
| 349 | + # Perfect matches first, fall back to partials if none | |
| 350 | + perfect = [p for p in result.products if p.match_label == "完美匹配"] | |
| 351 | + partial = [p for p in result.products if p.match_label == "部分匹配"] | |
| 352 | + to_show = (perfect + partial)[:6] if perfect else partial[:6] | |
| 353 | + | |
| 354 | + if not to_show: | |
| 355 | + st.caption("(本次搜索未找到可展示的商品)") | |
| 356 | + return | |
| 357 | + | |
| 358 | + cols = st.columns(min(len(to_show), 3)) | |
| 359 | + for i, product in enumerate(to_show): | |
| 360 | + with cols[i % 3]: | |
| 361 | + display_product_card_from_item(product) | |
| 362 | + | |
| 363 | + | |
| 364 | +def render_message_with_refs(content: str, session_id: str) -> None: | |
| 365 | + """ | |
| 366 | + Render an assistant message that may contain [SEARCH_REF:xxx] tokens. | |
| 367 | + | |
| 368 | + Text segments are rendered as markdown. | |
| 369 | + [SEARCH_REF:xxx] tokens are replaced with full product card blocks | |
| 370 | + loaded from the global registry. | |
| 371 | + """ | |
| 372 | + # re.split with a capture group alternates: [text, ref_id, text, ref_id, ...] | |
| 373 | + parts = SEARCH_REF_PATTERN.split(content) | |
| 374 | + | |
| 375 | + for i, segment in enumerate(parts): | |
| 376 | + if i % 2 == 0: | |
| 377 | + # Text segment | |
| 378 | + text = segment.strip() | |
| 379 | + if text: | |
| 380 | + st.markdown(text) | |
| 382 | 381 | else: |
| 383 | - logger.warning(f"Image not found: {image_path}") | |
| 384 | - | |
| 385 | - # Fallback: no image | |
| 386 | - st.markdown(f"**📷 {name}**") | |
| 387 | - if product_id: | |
| 388 | - st.caption(f"ID: {product_id}") | |
| 389 | - else: | |
| 390 | - st.caption("ID not available") | |
| 382 | + # ref_id segment | |
| 383 | + ref_id = segment.strip() | |
| 384 | + result = global_registry.get(session_id, ref_id) | |
| 385 | + if result: | |
| 386 | + render_search_result_block(result) | |
| 387 | + else: | |
| 388 | + # ref not found (e.g. old session after restart) | |
| 389 | + st.caption(f"[搜索结果 {ref_id} 不可用]") | |
| 391 | 390 | |
| 392 | 391 | |
| 393 | 392 | def display_message(message: dict): |
| ... | ... | @@ -412,13 +411,13 @@ def display_message(message: dict): |
| 412 | 411 | st.markdown("</div>", unsafe_allow_html=True) |
| 413 | 412 | |
| 414 | 413 | else: # assistant |
| 415 | - # Display tool calls horizontally - only tool names | |
| 414 | + # Tool call breadcrumb | |
| 416 | 415 | if tool_calls: |
| 417 | 416 | tool_names = [tc["name"] for tc in tool_calls] |
| 418 | 417 | st.caption(" → ".join(tool_names)) |
| 419 | 418 | st.markdown("") |
| 420 | 419 | |
| 421 | - # Optional: detailed debug panel (reasoning + tool details) | |
| 420 | + # Debug panel | |
| 422 | 421 | if debug_steps and st.session_state.get("show_debug"): |
| 423 | 422 | with st.expander("思考 & 工具调用详细过程", expanded=False): |
| 424 | 423 | for idx, step in enumerate(debug_steps, 1): |
| ... | ... | @@ -430,9 +429,7 @@ def display_message(message: dict): |
| 430 | 429 | if msgs: |
| 431 | 430 | st.markdown("**Agent Messages**") |
| 432 | 431 | for m in msgs: |
| 433 | - role = m.get("type", "assistant") | |
| 434 | - content = m.get("content", "") | |
| 435 | - st.markdown(f"- `{role}`: {content}") | |
| 432 | + st.markdown(f"- `{m.get('type', 'assistant')}`: {m.get('content', '')}") | |
| 436 | 433 | |
| 437 | 434 | tcs = step.get("tool_calls", []) |
| 438 | 435 | if tcs: |
| ... | ... | @@ -450,65 +447,10 @@ def display_message(message: dict): |
| 450 | 447 | st.code(r.get("content", ""), language="text") |
| 451 | 448 | |
| 452 | 449 | st.markdown("---") |
| 453 | - | |
| 454 | - # Extract and display products if any | |
| 455 | - products = extract_products_from_response(content) | |
| 456 | - | |
| 457 | - # Debug logging | |
| 458 | - logger.info(f"Extracted {len(products)} products from response") | |
| 459 | - for p in products: | |
| 460 | - logger.info(f"Product: {p}") | |
| 461 | - | |
| 462 | - if products: | |
| 463 | - def parse_score(product: dict) -> float: | |
| 464 | - score = product.get("score") | |
| 465 | - if score is None: | |
| 466 | - return 0.0 | |
| 467 | - try: | |
| 468 | - return float(score) | |
| 469 | - except (TypeError, ValueError): | |
| 470 | - return 0.0 | |
| 471 | - | |
| 472 | - # Sort by score and limit to 3 | |
| 473 | - products = sorted(products, key=parse_score, reverse=True)[:3] | |
| 474 | - | |
| 475 | - logger.info(f"Displaying top {len(products)} products") | |
| 476 | - | |
| 477 | - # Display the text response first (without product details) | |
| 478 | - text_lines = [] | |
| 479 | - for line in content.split("\n"): | |
| 480 | - # Skip product detail lines | |
| 481 | - if not any( | |
| 482 | - keyword in line | |
| 483 | - for keyword in [ | |
| 484 | - "ID:", | |
| 485 | - "Category:", | |
| 486 | - "Color:", | |
| 487 | - "Gender:", | |
| 488 | - "Season:", | |
| 489 | - "Usage:", | |
| 490 | - "Similarity:", | |
| 491 | - "Relevance:", | |
| 492 | - ] | |
| 493 | - ): | |
| 494 | - if not re.match(r"^\*?\*?\d+\.\s+", line): | |
| 495 | - text_lines.append(line) | |
| 496 | - | |
| 497 | - intro_text = "\n".join(text_lines).strip() | |
| 498 | - if intro_text: | |
| 499 | - st.markdown(intro_text) | |
| 500 | - | |
| 501 | - # Display product cards in grid | |
| 502 | - st.markdown("<br>", unsafe_allow_html=True) | |
| 503 | - | |
| 504 | - # Create exactly 3 columns with equal width | |
| 505 | - cols = st.columns(3) | |
| 506 | - for j, product in enumerate(products[:9]): # Ensure max 3 | |
| 507 | - with cols[j]: | |
| 508 | - display_product_card(product) | |
| 509 | - else: | |
| 510 | - # No products found, display full content | |
| 511 | - st.markdown(content) | |
| 450 | + | |
| 451 | + # Render message: expand [SEARCH_REF:xxx] tokens into product card blocks | |
| 452 | + session_id = st.session_state.get("session_id", "") | |
| 453 | + render_message_with_refs(content, session_id) | |
| 512 | 454 | |
| 513 | 455 | st.markdown("</div>", unsafe_allow_html=True) |
| 514 | 456 | |
| ... | ... | @@ -591,6 +533,10 @@ def main(): |
| 591 | 533 | if st.button("🗑️ Clear Chat", use_container_width=True): |
| 592 | 534 | if "shopping_agent" in st.session_state: |
| 593 | 535 | st.session_state.shopping_agent.clear_history() |
| 536 | + # Clear search result registry for this session | |
| 537 | + session_id = st.session_state.get("session_id", "") | |
| 538 | + if session_id: | |
| 539 | + global_registry.clear_session(session_id) | |
| 594 | 540 | st.session_state.messages = [] |
| 595 | 541 | st.session_state.uploaded_image = None |
| 596 | 542 | st.rerun() |
| ... | ... | @@ -600,6 +546,7 @@ def main(): |
| 600 | 546 | st.checkbox( |
| 601 | 547 | "显示调试过程 (debug)", |
| 602 | 548 | key="show_debug", |
| 549 | + value=True, | |
| 603 | 550 | help="展开后可查看中间思考过程及工具调用详情", |
| 604 | 551 | ) |
| 605 | 552 | |
| ... | ... | @@ -713,26 +660,16 @@ def main(): |
| 713 | 660 | try: |
| 714 | 661 | shopping_agent = st.session_state.shopping_agent |
| 715 | 662 | |
| 716 | - # Handle greetings | |
| 663 | + # Handle greetings without invoking the agent | |
| 717 | 664 | query_lower = user_query.lower().strip() |
| 718 | - if query_lower in ["hi", "hello", "hey"]: | |
| 719 | - response = """Hello! 👋 I'm your fashion shopping assistant. | |
| 720 | - | |
| 721 | -I can help you: | |
| 722 | -- Search for products by description | |
| 723 | -- Find items similar to images you upload | |
| 724 | -- Analyze product styles | |
| 725 | - | |
| 726 | -What are you looking for today?""" | |
| 727 | - tool_calls = [] | |
| 728 | - else: | |
| 729 | - # Process with agent | |
| 730 | - result = shopping_agent.chat( | |
| 731 | - query=user_query, | |
| 732 | - image_path=image_path, | |
| 733 | - ) | |
| 734 | - response = result["response"] | |
| 735 | - tool_calls = result.get("tool_calls", []) | |
| 665 | + # Process with agent | |
| 666 | + result = shopping_agent.chat( | |
| 667 | + query=user_query, | |
| 668 | + image_path=image_path, | |
| 669 | + ) | |
| 670 | + response = result["response"] | |
| 671 | + tool_calls = result.get("tool_calls", []) | |
| 672 | + debug_steps = result.get("debug_steps", []) | |
| 736 | 673 | |
| 737 | 674 | # Add assistant message |
| 738 | 675 | st.session_state.messages.append( |
| ... | ... | @@ -740,7 +677,7 @@ What are you looking for today?""" |
| 740 | 677 | "role": "assistant", |
| 741 | 678 | "content": response, |
| 742 | 679 | "tool_calls": tool_calls, |
| 743 | - "debug_steps": result.get("debug_steps", []), | |
| 680 | + "debug_steps": debug_steps, | |
| 744 | 681 | } |
| 745 | 682 | ) |
| 746 | 683 | ... | ... |
app/agents/shopping_agent.py
| 1 | 1 | """ |
| 2 | 2 | Conversational Shopping Agent with LangGraph |
| 3 | -True ReAct agent with autonomous tool calling and message accumulation | |
| 3 | + | |
| 4 | +Architecture: | |
| 5 | +- ReAct-style agent: plan → search → evaluate → re-plan or respond | |
| 6 | +- search_products is session-bound, writing curated results to SearchResultRegistry | |
| 7 | +- Final AI message references results via [SEARCH_REF:xxx] tokens instead of | |
| 8 | + re-listing product details; the UI renders product cards from the registry | |
| 4 | 9 | """ |
| 5 | 10 | |
| 6 | 11 | import logging |
| ... | ... | @@ -16,14 +21,52 @@ from langgraph.prebuilt import ToolNode |
| 16 | 21 | from typing_extensions import Annotated, TypedDict |
| 17 | 22 | |
| 18 | 23 | from app.config import settings |
| 24 | +from app.search_registry import global_registry | |
| 19 | 25 | from app.tools.search_tools import get_all_tools |
| 20 | 26 | |
| 21 | 27 | logger = logging.getLogger(__name__) |
| 22 | 28 | |
| 29 | +# ── System prompt ────────────────────────────────────────────────────────────── | |
| 30 | +# Universal: works for any e-commerce vertical (fashion, electronics, home, etc.) | |
| 31 | +# Key design decisions: | |
| 32 | +# 1. Guides multi-query search planning with explicit evaluate-and-decide loop | |
| 33 | +# 2. Forbids re-listing product details in the final response | |
| 34 | +# 3. Mandates [SEARCH_REF:xxx] inline citation as the only product presentation mechanism | |
| 35 | +SYSTEM_PROMPT = """ | |
| 36 | +角色定义 | |
| 37 | +你是一名专业的服装电商导购,是一个善于倾听、主动引导、懂得搭配的“时尚顾问”,通过有温度的对话,给用户提供有价值的信息,包括需求引导、方案推荐、搜索结果推荐,最终促成满意的购物决策或转化行为。 | |
| 38 | + | |
| 39 | +一些原则: | |
| 40 | +1. 你是一个真人导购,是一个贴心、专业的销售,保持灵活,根据上下文,基于常识灵活的切换策略,在合适的上下文询问合适的问题、给出有价值的方案和搜索结果的呈现。 | |
| 41 | +2. 兼顾推荐与信息收集:适时的提供有价值的信息,如商品推荐、穿搭建议、趋势信息,在推荐方向上有需求缺口、需要明确的重要信息时,要适时的做“信息收集”,引导式的帮助用户更清晰的呈现需求、提高商品发现的效率,形成“提供-反馈”的良性循环。 | |
| 42 | + 1. 在意图不明时,主动通过1-2个关键问题(如品类、场景、风格、预算)进行引导,并提供初步方向。 | |
| 43 | + 2. 在了解到初步意向后,要进行相关商品的搜索、进行搜索结果的呈现,同时思考该方向下重要的决策因素,进行提议和问题收集,让用户既得到相关信息、又得到下一步的方向引导、同时也有机会修正或者细化诉求。 | |
| 44 | + 3. 对于复杂需求时,要能基于上下文,将导购任务进行合理拆解。 | |
| 45 | +3. 引导或者收集需求时,需要站在用户立场,比如询问用户期待的效果或感觉、使用的场合、偏好的风格等用户立场需,而不是询问具体的款式或参数,你需要将用户立场的需求理解/翻译/转化为具体的搜索计划,最后筛选产品、结合需求+结果特性组织推荐理由、呈现方案。 | |
| 46 | +4. 如何使用search_products:在需要搜索商品的时候,可以将需求分解为 2-4 个搜索查询,每个 query 聚焦一个明确的商品子类或搜索角度。每次调用 search_products 后,工具会返回以下内容,你需要决策是否要调整搜索策略,比如结果质量太差,可能需要调整搜索词、或者加大试探的query数量(不要超过3-5个)。可以进行多轮搜索,但是要适时的总结和反馈信息避免用户等待过长时间: | |
| 47 | + - 各层级数量:完美匹配 / 部分匹配 / 不相关 的条数 | |
| 48 | + - 整体质量判断:优质 / 一般 / 较差 | |
| 49 | + - 简短质量说明 | |
| 50 | + - 结果引用标识:[SEARCH_REF:xxx] | |
| 51 | +5. 撰写最终回复的时候,使用 [SEARCH_REF:xxx] 内联引用 | |
| 52 | + 1. 用自然流畅的语言组织回复,将 [SEARCH_REF:xxx] 嵌入叙述中 | |
| 53 | + 2. 系统会自动在 [SEARCH_REF:xxx] 位置渲染对应的商品卡片列表 | |
| 54 | + 3. 禁止在回复文本中列出商品名称、ID、价格、分类、规格等字段 | |
| 55 | + 4. 禁止用编号列表逐条复述搜索结果中的商品 | |
| 56 | +""" | |
| 57 | + | |
| 58 | + | |
| 59 | +# ── Agent state ──────────────────────────────────────────────────────────────── | |
| 60 | + | |
| 61 | +class AgentState(TypedDict): | |
| 62 | + messages: Annotated[Sequence[BaseMessage], add_messages] | |
| 63 | + current_image_path: Optional[str] | |
| 64 | + | |
| 65 | + | |
| 66 | +# ── Helper ───────────────────────────────────────────────────────────────────── | |
| 23 | 67 | |
| 24 | 68 | def _extract_message_text(msg) -> str: |
| 25 | - """Extract text from message content. | |
| 26 | - LangChain 1.0: content may be str or content_blocks (list) for multimodal.""" | |
| 69 | + """Extract plain text from a LangChain message (handles str or content_blocks).""" | |
| 27 | 70 | content = getattr(msg, "content", "") |
| 28 | 71 | if isinstance(content, str): |
| 29 | 72 | return content |
| ... | ... | @@ -31,27 +74,21 @@ def _extract_message_text(msg) -> str: |
| 31 | 74 | parts = [] |
| 32 | 75 | for block in content: |
| 33 | 76 | if isinstance(block, dict): |
| 34 | - parts.append(block.get("text", block.get("content", ""))) | |
| 77 | + parts.append(block.get("text") or block.get("content") or "") | |
| 35 | 78 | else: |
| 36 | 79 | parts.append(str(block)) |
| 37 | 80 | return "".join(str(p) for p in parts) |
| 38 | 81 | return str(content) if content else "" |
| 39 | 82 | |
| 40 | 83 | |
| 41 | -class AgentState(TypedDict): | |
| 42 | - """State for the shopping agent with message accumulation""" | |
| 43 | - | |
| 44 | - messages: Annotated[Sequence[BaseMessage], add_messages] | |
| 45 | - current_image_path: Optional[str] # Track uploaded image | |
| 84 | +# ── Agent class ──────────────────────────────────────────────────────────────── | |
| 46 | 85 | |
| 47 | -print("settings") | |
| 48 | 86 | class ShoppingAgent: |
| 49 | - """True ReAct agent with autonomous decision making""" | |
| 87 | + """ReAct shopping agent with search-evaluate-decide loop and registry-based result referencing.""" | |
| 50 | 88 | |
| 51 | 89 | def __init__(self, session_id: Optional[str] = None): |
| 52 | 90 | self.session_id = session_id or "default" |
| 53 | 91 | |
| 54 | - # Initialize LLM | |
| 55 | 92 | llm_kwargs = dict( |
| 56 | 93 | model=settings.openai_model, |
| 57 | 94 | temperature=settings.openai_temperature, |
| ... | ... | @@ -59,261 +96,173 @@ class ShoppingAgent: |
| 59 | 96 | ) |
| 60 | 97 | if settings.openai_api_base_url: |
| 61 | 98 | llm_kwargs["base_url"] = settings.openai_api_base_url |
| 62 | - | |
| 63 | - print("llm_kwargs") | |
| 64 | - print(llm_kwargs) | |
| 65 | 99 | |
| 66 | 100 | self.llm = ChatOpenAI(**llm_kwargs) |
| 67 | 101 | |
| 68 | - # Get tools and bind to model | |
| 69 | - self.tools = get_all_tools() | |
| 102 | + # Tools are session-bound so search_products writes to the right registry partition | |
| 103 | + self.tools = get_all_tools(session_id=self.session_id, registry=global_registry) | |
| 70 | 104 | self.llm_with_tools = self.llm.bind_tools(self.tools) |
| 71 | 105 | |
| 72 | - # Build graph | |
| 73 | 106 | self.graph = self._build_graph() |
| 74 | - | |
| 75 | - logger.info(f"Shopping agent initialized for session: {self.session_id}") | |
| 107 | + logger.info(f"ShoppingAgent ready — session={self.session_id}") | |
| 76 | 108 | |
| 77 | 109 | def _build_graph(self): |
| 78 | - """Build the LangGraph StateGraph""" | |
| 79 | - | |
| 80 | - # System prompt for the agent | |
| 81 | - system_prompt = """你是一位智能时尚购物助手,你可以: | |
| 82 | -1. 根据文字描述搜索商品(使用 search_products) | |
| 83 | -2. 分析图片风格和属性(使用 analyze_image_style) | |
| 84 | - | |
| 85 | -当用户咨询商品时: | |
| 86 | -- 文字提问:直接使用 search_products 搜索 | |
| 87 | -- 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索 | |
| 88 | -- 可按需连续调用多个工具 | |
| 89 | -- 始终保持有用、友好的回复风格 | |
| 90 | - | |
| 91 | -关键格式规则: | |
| 92 | -展示商品结果时,每个商品必须严格按以下格式输出: | |
| 93 | - | |
| 94 | -1. [标题 title] | |
| 95 | - ID: [商品ID] | |
| 96 | - 分类: [category_path] | |
| 97 | - 中文名: [title_cn](如有) | |
| 98 | - 标签: [tags](如有) | |
| 99 | - | |
| 100 | -示例: | |
| 101 | -1. Puma Men White 3/4 Length Pants | |
| 102 | - ID: 12345 | |
| 103 | - 分类: 服饰 > 裤装 > 运动裤 | |
| 104 | - 中文名: 彪马男士白色九分运动裤 | |
| 105 | - 标签: 运动,夏季,白色 | |
| 106 | - | |
| 107 | -不可省略 ID 字段!它是展示商品图片的关键。 | |
| 108 | -介绍要口语化,但必须保持上述商品格式。""" | |
| 109 | - | |
| 110 | 110 | def agent_node(state: AgentState): |
| 111 | - """Agent decision node - decides which tools to call or when to respond""" | |
| 112 | 111 | messages = state["messages"] |
| 113 | - | |
| 114 | - # Add system prompt if first message | |
| 115 | 112 | if not any(isinstance(m, SystemMessage) for m in messages): |
| 116 | - messages = [SystemMessage(content=system_prompt)] + list(messages) | |
| 117 | - | |
| 118 | - # Handle image context | |
| 119 | - if state.get("current_image_path"): | |
| 120 | - # Inject image path context for tool calls | |
| 121 | - # The agent can reference this in its reasoning | |
| 122 | - pass | |
| 123 | - | |
| 124 | - # Invoke LLM with tools | |
| 113 | + messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages) | |
| 125 | 114 | response = self.llm_with_tools.invoke(messages) |
| 126 | 115 | return {"messages": [response]} |
| 127 | 116 | |
| 128 | - # Create tool node | |
| 129 | - tool_node = ToolNode(self.tools) | |
| 130 | - | |
| 131 | 117 | def should_continue(state: AgentState): |
| 132 | - """Determine if agent should continue or end""" | |
| 133 | - messages = state["messages"] | |
| 134 | - last_message = messages[-1] | |
| 135 | - | |
| 136 | - # If LLM made tool calls, continue to tools | |
| 137 | - if hasattr(last_message, "tool_calls") and last_message.tool_calls: | |
| 118 | + last = state["messages"][-1] | |
| 119 | + if hasattr(last, "tool_calls") and last.tool_calls: | |
| 138 | 120 | return "tools" |
| 139 | - # Otherwise, end (agent has final response) | |
| 140 | 121 | return END |
| 141 | 122 | |
| 142 | - # Build graph | |
| 143 | - workflow = StateGraph(AgentState) | |
| 123 | + tool_node = ToolNode(self.tools) | |
| 144 | 124 | |
| 125 | + workflow = StateGraph(AgentState) | |
| 145 | 126 | workflow.add_node("agent", agent_node) |
| 146 | 127 | workflow.add_node("tools", tool_node) |
| 147 | - | |
| 148 | 128 | workflow.add_edge(START, "agent") |
| 149 | 129 | workflow.add_conditional_edges("agent", should_continue, ["tools", END]) |
| 150 | 130 | workflow.add_edge("tools", "agent") |
| 151 | 131 | |
| 152 | - # Compile with memory | |
| 153 | - checkpointer = MemorySaver() | |
| 154 | - return workflow.compile(checkpointer=checkpointer) | |
| 132 | + return workflow.compile(checkpointer=MemorySaver()) | |
| 155 | 133 | |
| 156 | 134 | def chat(self, query: str, image_path: Optional[str] = None) -> dict: |
| 157 | - """Process user query with the agent | |
| 158 | - | |
| 159 | - Args: | |
| 160 | - query: User's text query | |
| 161 | - image_path: Optional path to uploaded image | |
| 135 | + """ | |
| 136 | + Process a user query and return the agent response with metadata. | |
| 162 | 137 | |
| 163 | 138 | Returns: |
| 164 | - Dict with response and metadata, including: | |
| 165 | - - tool_calls: list of tool calls with args and (truncated) results | |
| 166 | - - debug_steps: detailed intermediate reasoning & tool execution steps | |
| 139 | + dict with keys: | |
| 140 | + response – final AI message text (may contain [SEARCH_REF:xxx] tokens) | |
| 141 | + tool_calls – list of {name, args, result_preview} | |
| 142 | + debug_steps – detailed per-node step log | |
| 143 | + search_refs – dict[ref_id → SearchResult] for all searches this turn | |
| 144 | + error – bool | |
| 167 | 145 | """ |
| 168 | 146 | try: |
| 169 | - logger.info( | |
| 170 | - f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})" | |
| 171 | - ) | |
| 147 | + logger.info(f"[{self.session_id}] chat: {query!r} image={bool(image_path)}") | |
| 172 | 148 | |
| 173 | - # Validate image | |
| 174 | 149 | if image_path and not Path(image_path).exists(): |
| 175 | 150 | return { |
| 176 | - "response": f"Error: Image file not found at '{image_path}'", | |
| 151 | + "response": f"错误:图片文件不存在:{image_path}", | |
| 177 | 152 | "error": True, |
| 178 | 153 | } |
| 179 | 154 | |
| 180 | - # Build input message | |
| 155 | + # Snapshot registry before the turn so we can report new additions | |
| 156 | + registry_before = set(global_registry.get_all(self.session_id).keys()) | |
| 157 | + | |
| 181 | 158 | message_content = query |
| 182 | 159 | if image_path: |
| 183 | - message_content = f"{query}\n[User uploaded image: {image_path}]" | |
| 160 | + message_content = f"{query}\n[用户上传了图片:{image_path}]" | |
| 184 | 161 | |
| 185 | - # Invoke agent | |
| 186 | 162 | config = {"configurable": {"thread_id": self.session_id}} |
| 187 | 163 | input_state = { |
| 188 | 164 | "messages": [HumanMessage(content=message_content)], |
| 189 | 165 | "current_image_path": image_path, |
| 190 | 166 | } |
| 191 | 167 | |
| 192 | - # Track tool calls (high-level) and detailed debug steps | |
| 193 | - tool_calls = [] | |
| 194 | - debug_steps = [] | |
| 195 | - | |
| 196 | - # Stream events to capture tool calls and intermediate reasoning | |
| 168 | + tool_calls: list[dict] = [] | |
| 169 | + debug_steps: list[dict] = [] | |
| 170 | + | |
| 197 | 171 | for event in self.graph.stream(input_state, config=config): |
| 198 | - logger.info(f"Event: {event}") | |
| 172 | + logger.debug(f"[{self.session_id}] event keys: {list(event.keys())}") | |
| 199 | 173 | |
| 200 | - # Agent node: LLM reasoning & tool decisions | |
| 201 | 174 | if "agent" in event: |
| 202 | - agent_output = event["agent"] | |
| 203 | - messages = agent_output.get("messages", []) | |
| 175 | + agent_out = event["agent"] | |
| 176 | + step_msgs: list[dict] = [] | |
| 177 | + step_tcs: list[dict] = [] | |
| 204 | 178 | |
| 205 | - step_messages = [] | |
| 206 | - step_tool_calls = [] | |
| 207 | - | |
| 208 | - for msg in messages: | |
| 209 | - msg_text = _extract_message_text(msg) | |
| 210 | - msg_entry = { | |
| 179 | + for msg in agent_out.get("messages", []): | |
| 180 | + text = _extract_message_text(msg) | |
| 181 | + step_msgs.append({ | |
| 211 | 182 | "type": getattr(msg, "type", "assistant"), |
| 212 | - "content": msg_text[:500], # truncate for safety | |
| 213 | - } | |
| 214 | - step_messages.append(msg_entry) | |
| 215 | - | |
| 216 | - # Capture tool calls from this agent message | |
| 183 | + "content": text[:500], | |
| 184 | + }) | |
| 217 | 185 | if hasattr(msg, "tool_calls") and msg.tool_calls: |
| 218 | 186 | for tc in msg.tool_calls: |
| 219 | - tc_entry = { | |
| 220 | - "name": tc.get("name"), | |
| 221 | - "args": tc.get("args", {}), | |
| 222 | - } | |
| 223 | - tool_calls.append(tc_entry) | |
| 224 | - step_tool_calls.append(tc_entry) | |
| 225 | - | |
| 226 | - debug_steps.append( | |
| 227 | - { | |
| 228 | - "node": "agent", | |
| 229 | - "messages": step_messages, | |
| 230 | - "tool_calls": step_tool_calls, | |
| 231 | - } | |
| 232 | - ) | |
| 233 | - | |
| 234 | - # Tool node: actual tool execution results | |
| 235 | - if "tools" in event: | |
| 236 | - tools_output = event["tools"] | |
| 237 | - messages = tools_output.get("messages", []) | |
| 238 | - | |
| 239 | - step_tool_results = [] | |
| 187 | + entry = {"name": tc.get("name"), "args": tc.get("args", {})} | |
| 188 | + tool_calls.append(entry) | |
| 189 | + step_tcs.append(entry) | |
| 240 | 190 | |
| 241 | - for i, msg in enumerate(messages): | |
| 242 | - content_text = _extract_message_text(msg) | |
| 243 | - result_preview = content_text[:500] + ("..." if len(content_text) > 500 else "") | |
| 191 | + debug_steps.append({"node": "agent", "messages": step_msgs, "tool_calls": step_tcs}) | |
| 244 | 192 | |
| 245 | - if i < len(tool_calls): | |
| 246 | - tool_calls[i]["result"] = result_preview | |
| 193 | + if "tools" in event: | |
| 194 | + tools_out = event["tools"] | |
| 195 | + step_results: list[dict] = [] | |
| 196 | + msgs = tools_out.get("messages", []) | |
| 247 | 197 | |
| 248 | - step_tool_results.append( | |
| 249 | - { | |
| 250 | - "content": result_preview, | |
| 251 | - } | |
| 252 | - ) | |
| 198 | + # Match results back to tool_calls by position within this event | |
| 199 | + unresolved = [tc for tc in tool_calls if "result" not in tc] | |
| 200 | + for i, msg in enumerate(msgs): | |
| 201 | + text = _extract_message_text(msg) | |
| 202 | + preview = text[:600] + ("…" if len(text) > 600 else "") | |
| 203 | + if i < len(unresolved): | |
| 204 | + unresolved[i]["result"] = preview | |
| 205 | + step_results.append({"content": preview}) | |
| 253 | 206 | |
| 254 | - debug_steps.append( | |
| 255 | - { | |
| 256 | - "node": "tools", | |
| 257 | - "results": step_tool_results, | |
| 258 | - } | |
| 259 | - ) | |
| 207 | + debug_steps.append({"node": "tools", "results": step_results}) | |
| 260 | 208 | |
| 261 | - # Get final state | |
| 262 | 209 | final_state = self.graph.get_state(config) |
| 263 | - final_message = final_state.values["messages"][-1] | |
| 264 | - response_text = _extract_message_text(final_message) | |
| 210 | + final_msg = final_state.values["messages"][-1] | |
| 211 | + response_text = _extract_message_text(final_msg) | |
| 212 | + | |
| 213 | + # Collect new SearchResults added during this turn | |
| 214 | + registry_after = global_registry.get_all(self.session_id) | |
| 215 | + new_refs = { | |
| 216 | + ref_id: result | |
| 217 | + for ref_id, result in registry_after.items() | |
| 218 | + if ref_id not in registry_before | |
| 219 | + } | |
| 265 | 220 | |
| 266 | - logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls") | |
| 221 | + logger.info( | |
| 222 | + f"[{self.session_id}] done — tool_calls={len(tool_calls)}, new_refs={list(new_refs.keys())}" | |
| 223 | + ) | |
| 267 | 224 | |
| 268 | 225 | return { |
| 269 | 226 | "response": response_text, |
| 270 | 227 | "tool_calls": tool_calls, |
| 271 | 228 | "debug_steps": debug_steps, |
| 229 | + "search_refs": new_refs, | |
| 272 | 230 | "error": False, |
| 273 | 231 | } |
| 274 | 232 | |
| 275 | 233 | except Exception as e: |
| 276 | - logger.error(f"Error in agent chat: {e}", exc_info=True) | |
| 234 | + logger.error(f"[{self.session_id}] chat error: {e}", exc_info=True) | |
| 277 | 235 | return { |
| 278 | - "response": f"I apologize, I encountered an error: {str(e)}", | |
| 236 | + "response": f"抱歉,处理您的请求时遇到错误:{e}", | |
| 237 | + "tool_calls": [], | |
| 238 | + "debug_steps": [], | |
| 239 | + "search_refs": {}, | |
| 279 | 240 | "error": True, |
| 280 | 241 | } |
| 281 | 242 | |
| 282 | 243 | def get_conversation_history(self) -> list: |
| 283 | - """Get conversation history for this session""" | |
| 284 | 244 | try: |
| 285 | 245 | config = {"configurable": {"thread_id": self.session_id}} |
| 286 | 246 | state = self.graph.get_state(config) |
| 287 | - | |
| 288 | 247 | if not state or not state.values.get("messages"): |
| 289 | 248 | return [] |
| 290 | 249 | |
| 291 | - messages = state.values["messages"] | |
| 292 | 250 | result = [] |
| 293 | - | |
| 294 | - for msg in messages: | |
| 295 | - # Skip system messages and tool messages | |
| 251 | + for msg in state.values["messages"]: | |
| 296 | 252 | if isinstance(msg, SystemMessage): |
| 297 | 253 | continue |
| 298 | - if hasattr(msg, "type") and msg.type in ["system", "tool"]: | |
| 254 | + if getattr(msg, "type", None) in ("system", "tool"): | |
| 299 | 255 | continue |
| 300 | - | |
| 301 | 256 | role = "user" if msg.type == "human" else "assistant" |
| 302 | 257 | result.append({"role": role, "content": _extract_message_text(msg)}) |
| 303 | - | |
| 304 | 258 | return result |
| 305 | - | |
| 306 | 259 | except Exception as e: |
| 307 | - logger.error(f"Error getting history: {e}") | |
| 260 | + logger.error(f"get_conversation_history error: {e}") | |
| 308 | 261 | return [] |
| 309 | 262 | |
| 310 | 263 | def clear_history(self): |
| 311 | - """Clear conversation history for this session""" | |
| 312 | - # With MemorySaver, we can't easily clear, but we can log | |
| 313 | - logger.info(f"[{self.session_id}] History clear requested") | |
| 314 | - # In production, implement proper clearing or use new thread_id | |
| 264 | + logger.info(f"[{self.session_id}] clear requested (use new session_id to fully reset)") | |
| 315 | 265 | |
| 316 | 266 | |
def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory for ShoppingAgent instances.

    Args:
        session_id: Stable identifier for the conversation thread. When None,
            ShoppingAgent presumably falls back to its own default — confirm
            against its __init__ (not visible here).

    Returns:
        A newly constructed ShoppingAgent bound to the given session.
    """
    return ShoppingAgent(session_id=session_id)
| ... | ... | @@ -0,0 +1,100 @@ |
| 1 | +""" | |
| 2 | +Search Result Registry | |
| 3 | + | |
| 4 | +Stores structured search results keyed by session and ref_id. | |
| 5 | +Each [SEARCH_REF:xxx] in an AI response maps to a SearchResult stored here, | |
| 6 | +allowing the UI to render product cards without the LLM ever re-listing them. | |
| 7 | +""" | |
| 8 | + | |
| 9 | +import uuid | |
| 10 | +from dataclasses import dataclass, field | |
| 11 | +from typing import Optional | |
| 12 | + | |
| 13 | + | |
def new_ref_id() -> str:
    """Return a fresh short search-reference ID of the form 'sr_' + 8 hex chars."""
    return f"sr_{uuid.uuid4().hex[:8]}"
| 17 | + | |
| 18 | + | |
@dataclass
class ProductItem:
    """A single product extracted from a search result, enriched with a match label."""

    spu_id: str                     # stringified SPU identifier
    title: str                      # display title
    price: Optional[float] = None
    category_path: Optional[str] = None
    vendor: Optional[str] = None
    image_url: Optional[str] = None
    relevance_score: Optional[float] = None
    # LLM-assigned label: "完美匹配" | "部分匹配" | "不相关"
    match_label: str = "部分匹配"
    # tags: presumably free-form strings; specifications: raw API dicts —
    # exact element shapes not visible here, so kept as plain lists.
    tags: list = field(default_factory=list)
    specifications: list = field(default_factory=list)

    @classmethod
    def from_raw(cls, raw: dict, match_label: str = "部分匹配") -> "ProductItem":
        """Build a ProductItem from a raw search-API result dict.

        Uses ``or``-fallbacks (not dict defaults) so an explicit ``null`` in
        the payload never becomes the literal string "None".
        """
        return cls(
            spu_id=str(raw.get("spu_id") or ""),
            title=raw.get("title") or "",
            price=raw.get("price"),
            category_path=raw.get("category_path") or raw.get("category_name"),
            vendor=raw.get("vendor"),
            image_url=raw.get("image_url"),
            relevance_score=raw.get("relevance_score"),
            match_label=match_label,
            tags=raw.get("tags") or [],
            specifications=raw.get("specifications") or [],
        )
| 34 | + | |
| 35 | + | |
@dataclass
class SearchResult:
    """A complete, self-contained search result block.

    Identified by ref_id (e.g. 'sr_3f9a1b2c') and referenced from AI replies
    via [SEARCH_REF:ref_id]. Stores the query, the LLM quality assessment,
    and the curated product list (only "完美匹配" and "部分匹配" items are
    kept — "不相关" items are discarded before registration).
    """

    ref_id: str
    query: str

    # Raw API stats
    total_api_hits: int   # total documents matched by the search engine
    returned_count: int   # number of results actually assessed

    # LLM per-item label breakdown
    perfect_count: int
    partial_count: int
    irrelevant_count: int

    # LLM overall verdict: "优质" | "一般" | "较差", plus a one-sentence reason
    quality_verdict: str
    quality_summary: str

    # Curated list[ProductItem] (perfect + partial only). Defaults to a fresh
    # empty list so callers can register a "no usable results" block without
    # passing one explicitly; backward-compatible with positional callers.
    products: list = field(default_factory=list)
| 64 | + | |
| 65 | + | |
class SearchResultRegistry:
    """Session-scoped store mapping session_id → {ref_id → SearchResult}.

    Held as a process-wide singleton: Streamlit reruns keep it alive for the
    lifetime of the worker process, while keying on session_id preserves
    isolation between concurrent chat sessions.
    """

    def __init__(self) -> None:
        self._store: dict[str, dict[str, "SearchResult"]] = {}

    def register(self, session_id: str, result: "SearchResult") -> str:
        """Store *result* under its own ref_id and return that ref_id."""
        self._store.setdefault(session_id, {})[result.ref_id] = result
        return result.ref_id

    def get(self, session_id: str, ref_id: str) -> Optional["SearchResult"]:
        """Fetch a single SearchResult, or None when session/ref is unknown."""
        return self._store.get(session_id, {}).get(ref_id)

    def get_all(self, session_id: str) -> dict:
        """Return a snapshot of all results for a session (safe to mutate)."""
        return dict(self._store.get(session_id, {}))

    def clear_session(self, session_id: str) -> None:
        """Drop every stored result for a session (e.g. on 'Clear Chat')."""
        self._store.pop(session_id, None)


# ── Global singleton ────────────────────────────────────────────────────────
# Shared by the search tools (writers) and the Streamlit UI (reader).
global_registry = SearchResultRegistry()
app/tools/__init__.py
| 1 | 1 | """ |
| 2 | 2 | LangChain Tools for Product Search and Discovery |
| 3 | + | |
| 4 | +search_products is created per-session via make_search_products_tool(). | |
| 5 | +Use get_all_tools(session_id, registry) for the full tool list. | |
| 3 | 6 | """ |
| 4 | 7 | |
| 5 | 8 | from app.tools.search_tools import ( |
| 6 | 9 | analyze_image_style, |
| 7 | 10 | get_all_tools, |
| 8 | - search_products, | |
| 11 | + make_search_products_tool, | |
| 12 | + web_search, | |
| 9 | 13 | ) |
| 10 | 14 | |
| 11 | 15 | __all__ = [ |
| 12 | - "search_products", | |
| 16 | + "make_search_products_tool", | |
| 13 | 17 | "analyze_image_style", |
| 18 | + "web_search", | |
| 14 | 19 | "get_all_tools", |
| 15 | 20 | ] | ... | ... |
app/tools/search_tools.py
| 1 | 1 | """ |
| 2 | 2 | Search Tools for Product Discovery |
| 3 | -Provides text-based search via Search API, web search, and VLM style analysis | |
| 3 | + | |
| 4 | +Key design: | |
| 5 | +- search_products is created via a factory (make_search_products_tool) that | |
| 6 | + closes over (session_id, registry), so each agent session has its own tool | |
| 7 | + instance pointing to the shared registry. | |
| 8 | +- After calling the search API, an LLM quality-assessment step labels every | |
| 9 | + result as 完美匹配 / 部分匹配 / 不相关 and produces an overall verdict. | |
| 10 | +- The curated product list is stored in the registry under a unique ref_id. | |
| 11 | +- The tool returns ONLY the quality summary + [SEARCH_REF:ref_id], never the | |
| 12 | + raw product list. The LLM references the result in its final response via | |
| 13 | + the [SEARCH_REF:...] token; the UI renders the product cards from the registry. | |
| 4 | 14 | """ |
| 5 | 15 | |
| 6 | 16 | import base64 |
| 17 | +import json | |
| 7 | 18 | import logging |
| 8 | 19 | import os |
| 9 | 20 | from pathlib import Path |
| ... | ... | @@ -14,6 +25,13 @@ from langchain_core.tools import tool |
| 14 | 25 | from openai import OpenAI |
| 15 | 26 | |
| 16 | 27 | from app.config import settings |
| 28 | +from app.search_registry import ( | |
| 29 | + ProductItem, | |
| 30 | + SearchResult, | |
| 31 | + SearchResultRegistry, | |
| 32 | + global_registry, | |
| 33 | + new_ref_id, | |
| 34 | +) | |
| 17 | 35 | |
| 18 | 36 | logger = logging.getLogger(__name__) |
| 19 | 37 | |
| ... | ... | @@ -30,31 +48,264 @@ def get_openai_client() -> OpenAI: |
| 30 | 48 | return _openai_client |
| 31 | 49 | |
| 32 | 50 | |
| 51 | +# ── LLM quality assessment ───────────────────────────────────────────────────── | |
| 52 | + | |
| 53 | +def _assess_search_quality( | |
| 54 | + query: str, | |
| 55 | + raw_products: list, | |
| 56 | +) -> tuple[list[str], str, str]: | |
| 57 | + """ | |
| 58 | + Ask the LLM to evaluate how well each search result matches the query. | |
| 59 | + | |
| 60 | + Returns: | |
| 61 | + labels – list[str], one per product: "完美匹配" | "部分匹配" | "不相关" | |
| 62 | + verdict – str: "优质" | "一般" | "较差" | |
| 63 | + summary – str: one-sentence explanation | |
| 64 | + """ | |
| 65 | + n = len(raw_products) | |
| 66 | + if n == 0: | |
| 67 | + return [], "较差", "搜索未返回任何商品。" | |
| 68 | + | |
| 69 | + # Build a compact product list — only title/category/tags/score to save tokens | |
| 70 | + lines: list[str] = [] | |
| 71 | + for i, p in enumerate(raw_products, 1): | |
| 72 | + title = (p.get("title") or "")[:60] | |
| 73 | + cat = p.get("category_path") or p.get("category_name") or "" | |
| 74 | + tags_raw = p.get("tags") or [] | |
| 75 | + tags = ", ".join(str(t) for t in tags_raw[:5]) | |
| 76 | + score = p.get("relevance_score") or 0 | |
| 77 | + row = f"{i}. [{score:.1f}] {title} | {cat}" | |
| 78 | + if tags: | |
| 79 | + row += f" | 标签:{tags}" | |
| 80 | + lines.append(row) | |
| 81 | + | |
| 82 | + product_text = "\n".join(lines) | |
| 83 | + | |
| 84 | + prompt = f"""你是商品搜索质量评估专家。请评估以下搜索结果与用户查询的匹配程度。 | |
| 85 | + | |
| 86 | +用户查询:{query} | |
| 87 | + | |
| 88 | +搜索结果(共 {n} 条,格式:序号. [相关性分数] 标题 | 分类 | 标签): | |
| 89 | +{product_text} | |
| 90 | + | |
| 91 | +评估说明: | |
| 92 | +- 完美匹配:完全符合用户查询意图,用户必然感兴趣 | |
| 93 | +- 部分匹配:与查询有关联,但不完全满足意图(如品类对但风格偏差、相关配件等) | |
| 94 | +- 不相关:与查询无关,不应展示给用户 | |
| 95 | + | |
| 96 | +整体 verdict 判断标准: | |
| 97 | +- 优质:完美匹配 ≥ 5 条 | |
| 98 | +- 一般:完美匹配 2-4 条 | |
| 99 | +- 较差:完美匹配 < 2 条 | |
| 100 | + | |
| 101 | +请严格按以下 JSON 格式输出,不得有任何额外文字或代码块标记: | |
| 102 | +{{"labels": ["完美匹配", "部分匹配", "不相关", ...], "verdict": "优质", "summary": "一句话评价搜索质量"}} | |
| 103 | + | |
| 104 | +labels 数组长度必须恰好等于 {n}。""" | |
| 105 | + | |
| 106 | + try: | |
| 107 | + client = get_openai_client() | |
| 108 | + resp = client.chat.completions.create( | |
| 109 | + model=settings.openai_model, | |
| 110 | + messages=[{"role": "user", "content": prompt}], | |
| 111 | + max_tokens=800, | |
| 112 | + temperature=0.1, | |
| 113 | + ) | |
| 114 | + raw = resp.choices[0].message.content.strip() | |
| 115 | + # Strip markdown code fences if the model adds them | |
| 116 | + if raw.startswith("```"): | |
| 117 | + raw = raw.split("```")[1] | |
| 118 | + if raw.startswith("json"): | |
| 119 | + raw = raw[4:] | |
| 120 | + raw = raw.strip() | |
| 121 | + | |
| 122 | + data = json.loads(raw) | |
| 123 | + labels: list[str] = data.get("labels", []) | |
| 124 | + | |
| 125 | + # Normalize and pad / trim to match n | |
| 126 | + valid = {"完美匹配", "部分匹配", "不相关"} | |
| 127 | + labels = [l if l in valid else "部分匹配" for l in labels] | |
| 128 | + while len(labels) < n: | |
| 129 | + labels.append("部分匹配") | |
| 130 | + labels = labels[:n] | |
| 131 | + | |
| 132 | + verdict: str = data.get("verdict", "一般") | |
| 133 | + if verdict not in ("优质", "一般", "较差"): | |
| 134 | + verdict = "一般" | |
| 135 | + summary: str = str(data.get("summary", "")) | |
| 136 | + return labels, verdict, summary | |
| 137 | + | |
| 138 | + except Exception as e: | |
| 139 | + logger.warning(f"Quality assessment LLM call failed: {e}; using fallback labels.") | |
| 140 | + return ["部分匹配"] * n, "一般", "质量评估步骤失败,结果仅供参考。" | |
| 141 | + | |
| 142 | + | |
| 143 | +# ── Tool factory ─────────────────────────────────────────────────────────────── | |
| 144 | + | |
def make_search_products_tool(
    session_id: str,
    registry: SearchResultRegistry,
):
    """
    Return a search_products tool bound to a specific session and registry.

    The returned tool:
      1. Calls the product search API.
      2. Runs LLM quality assessment on up to 20 results.
      3. Stores a curated SearchResult in the registry under a fresh ref_id.
      4. Returns ONLY the quality summary + [SEARCH_REF:ref_id] — never the
         product list; the UI renders cards from the registry instead.
    """

    @tool
    def search_products(query: str, limit: int = 20) -> str:
        """搜索商品库,根据自然语言描述找到匹配商品,并进行质量评估。

        每次调用专注于单一搜索角度。复杂需求请拆分为多次调用,每次换一个 query。
        工具会自动评估结果质量(完美匹配 / 部分匹配 / 不相关),并给出整体判断。

        Args:
            query: 自然语言商品描述,例如"男士休闲亚麻短裤夏季"
            limit: 最多返回条数(建议 10-20,越多评估越全面)

        Returns:
            质量评估摘要 + [SEARCH_REF:ref_id],供最终回复引用。
        """
        try:
            logger.info(f"[{session_id}] search_products: query={query!r} limit={limit}")

            url = f"{settings.search_api_base_url.rstrip('/')}/search/"
            headers = {
                "Content-Type": "application/json",
                "X-Tenant-ID": settings.search_api_tenant_id,
            }
            payload = {
                "query": query,
                "size": min(max(limit, 1), 20),  # clamp to the API-supported range
                "from": 0,
                "language": "zh",
            }

            resp = requests.post(url, json=payload, headers=headers, timeout=60)
            if resp.status_code != 200:
                logger.error(f"Search API error {resp.status_code}: {resp.text[:300]}")
                return f"搜索失败:API 返回状态码 {resp.status_code},请稍后重试。"

            data = resp.json()
            raw_results: list = data.get("results", [])
            total_hits: int = data.get("total", 0)

            if not raw_results:
                return (
                    f"【搜索完成】query='{query}'\n"
                    "未找到匹配商品,建议换用更宽泛或不同角度的关键词重新搜索。"
                )

            # ── LLM quality assessment (labels padded/validated to len(raw_results)) ──
            labels, verdict, quality_summary = _assess_search_quality(query, raw_results)

            # ── Single pass: count labels and keep perfect + partial items ──
            counts = {"完美匹配": 0, "部分匹配": 0, "不相关": 0}
            products: list[ProductItem] = []
            for raw, label in zip(raw_results, labels):
                counts[label] = counts.get(label, 0) + 1
                if label == "不相关":
                    continue
                products.append(
                    ProductItem(
                        # `or ""` (not a dict default) so an explicit null
                        # never becomes the literal string "None".
                        spu_id=str(raw.get("spu_id") or ""),
                        title=raw.get("title") or "",
                        price=raw.get("price"),
                        category_path=(
                            raw.get("category_path") or raw.get("category_name")
                        ),
                        vendor=raw.get("vendor"),
                        image_url=raw.get("image_url"),
                        relevance_score=raw.get("relevance_score"),
                        match_label=label,
                        tags=raw.get("tags") or [],
                        specifications=raw.get("specifications") or [],
                    )
                )
            perfect_count = counts["完美匹配"]
            partial_count = counts["部分匹配"]
            irrelevant_count = counts["不相关"]

            # ── Register under a fresh ref_id so the UI can resolve [SEARCH_REF:...] ──
            ref_id = new_ref_id()
            result = SearchResult(
                ref_id=ref_id,
                query=query,
                total_api_hits=total_hits,
                returned_count=len(raw_results),
                perfect_count=perfect_count,
                partial_count=partial_count,
                irrelevant_count=irrelevant_count,
                quality_verdict=verdict,
                quality_summary=quality_summary,
                products=products,
            )
            registry.register(session_id, result)
            logger.info(
                f"[{session_id}] Registered {ref_id}: verdict={verdict}, "
                f"perfect={perfect_count}, partial={partial_count}, irrel={irrelevant_count}"
            )

            # ── Return summary to agent (NOT the product list) ──
            verdict_hint = {
                "优质": "结果质量优质,可直接引用。",
                "一般": "结果质量一般,可酌情引用,也可补充更精准的 query。",
                "较差": "结果质量较差,建议重新规划 query 后再次搜索。",
            }.get(verdict, "")

            return (
                f"【搜索完成】query='{query}'\n"
                f"API 总命中:{total_hits} 条 | 本次评估:{len(raw_results)} 条\n"
                f"质量评估:完美匹配 {perfect_count} 条 | 部分匹配 {partial_count} 条 | 不相关 {irrelevant_count} 条\n"
                f"整体判断:{verdict} — {quality_summary}\n"
                f"{verdict_hint}\n"
                f"结果引用:[SEARCH_REF:{ref_id}]"
            )

        except requests.exceptions.RequestException as e:
            logger.error(f"[{session_id}] Search network error: {e}", exc_info=True)
            return f"搜索失败(网络错误):{e}"
        except Exception as e:
            logger.error(f"[{session_id}] Search error: {e}", exc_info=True)
            return f"搜索失败:{e}"

    return search_products
| 281 | + | |
| 282 | + | |
| 283 | +# ── Standalone tools (no session binding needed) ─────────────────────────────── | |
| 284 | + | |
| 33 | 285 | @tool |
| 34 | 286 | def web_search(query: str) -> str: |
| 35 | 287 | """使用 Tavily 进行通用 Web 搜索,补充外部/实时知识。 |
| 36 | 288 | |
| 37 | - 触发场景(示例): | |
| 38 | - - 需要**外部知识**:流行趋势、新品信息、穿搭文化、品牌故事等 | |
| 39 | - - 需要**实时/及时信息**:某地某个时节的天气、当季流行元素、最新联名款 | |
| 40 | - - 需要**宏观参考**:不同城市/国家的穿衣习惯、节日穿搭建议 | |
| 289 | + 触发场景: | |
| 290 | + - 需要**外部知识**:流行趋势、品牌、搭配文化、节日习俗等 | |
| 291 | + - 需要**实时/及时信息**:当季流行元素、某地未来的天气 | |
| 292 | + - 需要**宏观参考**:不同场合/国家的穿着建议、选购攻略 | |
| 41 | 293 | |
| 42 | 294 | Args: |
| 43 | - query: 要搜索的问题,自然语言描述(建议用中文) | |
| 295 | + query: 要搜索的问题,自然语言描述 | |
| 44 | 296 | |
| 45 | 297 | Returns: |
| 46 | - 总结后的回答 + 若干来源链接,供模型继续推理使用。 | |
| 298 | + 总结后的回答 + 若干参考来源链接 | |
| 47 | 299 | """ |
| 48 | 300 | try: |
| 49 | 301 | api_key = os.getenv("TAVILY_API_KEY") |
| 50 | 302 | if not api_key: |
| 51 | - logger.error("TAVILY_API_KEY is not set in environment variables") | |
| 52 | 303 | return ( |
| 53 | 304 | "无法调用外部 Web 搜索:未检测到 TAVILY_API_KEY 环境变量。\n" |
| 54 | 305 | "请在运行环境中配置 TAVILY_API_KEY 后再重试。" |
| 55 | 306 | ) |
| 56 | 307 | |
| 57 | - logger.info(f"Calling Tavily web search with query: {query!r}") | |
| 308 | + logger.info(f"web_search: {query!r}") | |
| 58 | 309 | |
| 59 | 310 | url = "https://api.tavily.com/search" |
| 60 | 311 | headers = { |
| ... | ... | @@ -66,15 +317,9 @@ def web_search(query: str) -> str: |
| 66 | 317 | "search_depth": "advanced", |
| 67 | 318 | "include_answer": True, |
| 68 | 319 | } |
| 69 | - | |
| 70 | 320 | response = requests.post(url, json=payload, headers=headers, timeout=60) |
| 71 | 321 | |
| 72 | 322 | if response.status_code != 200: |
| 73 | - logger.error( | |
| 74 | - "Tavily API error: %s - %s", | |
| 75 | - response.status_code, | |
| 76 | - response.text, | |
| 77 | - ) | |
| 78 | 323 | return f"调用外部 Web 搜索失败:Tavily 返回状态码 {response.status_code}" |
| 79 | 324 | |
| 80 | 325 | data = response.json() |
| ... | ... | @@ -87,140 +332,61 @@ def web_search(query: str) -> str: |
| 87 | 332 | "回答摘要:", |
| 88 | 333 | answer.strip(), |
| 89 | 334 | ] |
| 90 | - | |
| 91 | 335 | if results: |
| 92 | 336 | output_lines.append("") |
| 93 | 337 | output_lines.append("参考来源(部分):") |
| 94 | 338 | for idx, item in enumerate(results[:5], 1): |
| 95 | 339 | title = item.get("title") or "无标题" |
| 96 | - url = item.get("url") or "" | |
| 340 | + link = item.get("url") or "" | |
| 97 | 341 | output_lines.append(f"{idx}. {title}") |
| 98 | - if url: | |
| 99 | - output_lines.append(f" 链接: {url}") | |
| 342 | + if link: | |
| 343 | + output_lines.append(f" 链接: {link}") | |
| 100 | 344 | |
| 101 | 345 | return "\n".join(output_lines).strip() |
| 102 | 346 | |
| 103 | 347 | except requests.exceptions.RequestException as e: |
| 104 | - logger.error("Error calling Tavily web search (network): %s", e, exc_info=True) | |
| 348 | + logger.error("web_search network error: %s", e, exc_info=True) | |
| 105 | 349 | return f"调用外部 Web 搜索失败(网络错误):{e}" |
| 106 | 350 | except Exception as e: |
| 107 | - logger.error("Error calling Tavily web search: %s", e, exc_info=True) | |
| 351 | + logger.error("web_search error: %s", e, exc_info=True) | |
| 108 | 352 | return f"调用外部 Web 搜索失败:{e}" |
| 109 | 353 | |
| 110 | 354 | |
| 111 | 355 | @tool |
| 112 | -def search_products(query: str, limit: int = 5) -> str: | |
| 113 | - """Search for fashion products using natural language descriptions. | |
| 114 | - | |
| 115 | - Use when users describe what they want: | |
| 116 | - - "Find me red summer dresses" | |
| 117 | - - "Show me blue running shoes" | |
| 118 | - - "I want casual shirts for men" | |
| 119 | - | |
| 120 | - Args: | |
| 121 | - query: Natural language product description | |
| 122 | - limit: Maximum number of results (1-20) | |
| 123 | - | |
| 124 | - Returns: | |
| 125 | - Formatted string with product information | |
| 126 | - """ | |
| 127 | - try: | |
| 128 | - logger.info(f"Searching products: '{query}', limit: {limit}") | |
| 129 | - | |
| 130 | - url = f"{settings.search_api_base_url.rstrip('/')}/search/" | |
| 131 | - headers = { | |
| 132 | - "Content-Type": "application/json", | |
| 133 | - "X-Tenant-ID": settings.search_api_tenant_id, | |
| 134 | - } | |
| 135 | - payload = { | |
| 136 | - "query": query, | |
| 137 | - "size": min(limit, 20), | |
| 138 | - "from": 0, | |
| 139 | - "language": "zh", | |
| 140 | - } | |
| 141 | - | |
| 142 | - response = requests.post(url, json=payload, headers=headers, timeout=60) | |
| 143 | - | |
| 144 | - if response.status_code != 200: | |
| 145 | - logger.error(f"Search API error: {response.status_code} - {response.text}") | |
| 146 | - return f"Error searching products: API returned {response.status_code}" | |
| 147 | - | |
| 148 | - data = response.json() | |
| 149 | - results = data.get("results", []) | |
| 150 | - | |
| 151 | - if not results: | |
| 152 | - return "No products found matching your search." | |
| 153 | - | |
| 154 | - output = f"Found {len(results)} product(s):\n\n" | |
| 155 | - | |
| 156 | - for idx, product in enumerate(results, 1): | |
| 157 | - output += f"{idx}. {product.get('title', 'Unknown Product')}\n" | |
| 158 | - output += f" ID: {product.get('spu_id', 'N/A')}\n" | |
| 159 | - output += f" Category: {product.get('category_path', product.get('category_name', 'N/A'))}\n" | |
| 160 | - if product.get("vendor"): | |
| 161 | - output += f" Brand: {product.get('vendor')}\n" | |
| 162 | - if product.get("price") is not None: | |
| 163 | - output += f" Price: {product.get('price')}\n" | |
| 164 | - | |
| 165 | - # 规格/颜色信息 | |
| 166 | - specs = product.get("specifications", []) | |
| 167 | - if specs: | |
| 168 | - color_spec = next( | |
| 169 | - (s for s in specs if s.get("name").lower() == "color"), | |
| 170 | - None, | |
| 171 | - ) | |
| 172 | - if color_spec: | |
| 173 | - output += f" Color: {color_spec.get('value', 'N/A')}\n" | |
| 174 | - | |
| 175 | - output += "\n" | |
| 176 | - | |
| 177 | - return output.strip() | |
| 178 | - | |
| 179 | - except requests.exceptions.RequestException as e: | |
| 180 | - logger.error(f"Error searching products (network): {e}", exc_info=True) | |
| 181 | - return f"Error searching products: {str(e)}" | |
| 182 | - except Exception as e: | |
| 183 | - logger.error(f"Error searching products: {e}", exc_info=True) | |
| 184 | - return f"Error searching products: {str(e)}" | |
| 185 | - | |
| 186 | - | |
| 187 | -@tool | |
| 188 | 356 | def analyze_image_style(image_path: str) -> str: |
| 189 | - """Analyze a fashion product image using AI vision to extract detailed style information. | |
| 357 | + """分析用户上传的商品图片,提取视觉风格属性,用于后续商品搜索。 | |
| 190 | 358 | |
| 191 | - Use when you need to understand style/attributes from an image: | |
| 192 | - - Understand the style, color, pattern of a product | |
| 193 | - - Extract attributes like "casual", "formal", "vintage" | |
| 194 | - - Get detailed descriptions for subsequent searches | |
| 359 | + 适用场景: | |
| 360 | + - 用户上传图片,想找相似商品 | |
| 361 | + - 需要理解图片中商品的风格、颜色、材质等属性 | |
| 195 | 362 | |
| 196 | 363 | Args: |
| 197 | - image_path: Path to the image file | |
| 364 | + image_path: 图片文件路径 | |
| 198 | 365 | |
| 199 | 366 | Returns: |
| 200 | - Detailed text description of the product's visual attributes | |
| 367 | + 商品视觉属性的详细文字描述,可直接作为 search_products 的 query | |
| 201 | 368 | """ |
| 202 | 369 | try: |
| 203 | - logger.info(f"Analyzing image with VLM: '{image_path}'") | |
| 370 | + logger.info(f"analyze_image_style: {image_path!r}") | |
| 204 | 371 | |
| 205 | 372 | img_path = Path(image_path) |
| 206 | 373 | if not img_path.exists(): |
| 207 | - return f"Error: Image file not found at '{image_path}'" | |
| 374 | + return f"错误:图片文件不存在:{image_path}" | |
| 208 | 375 | |
| 209 | - with open(img_path, "rb") as image_file: | |
| 210 | - image_data = base64.b64encode(image_file.read()).decode("utf-8") | |
| 376 | + with open(img_path, "rb") as f: | |
| 377 | + image_data = base64.b64encode(f.read()).decode("utf-8") | |
| 211 | 378 | |
| 212 | - prompt = """Analyze this fashion product image and provide a detailed description. | |
| 379 | + prompt = """请分析这张商品图片,提供详细的视觉属性描述,用于商品搜索。 | |
| 213 | 380 | |
| 214 | -Include: | |
| 215 | -- Product type (e.g., shirt, dress, shoes, pants, bag) | |
| 216 | -- Primary colors | |
| 217 | -- Style/design (e.g., casual, formal, sporty, vintage, modern) | |
| 218 | -- Pattern or texture (e.g., plain, striped, checked, floral) | |
| 219 | -- Key features (e.g., collar type, sleeve length, fit) | |
| 220 | -- Material appearance (if obvious, e.g., denim, cotton, leather) | |
| 221 | -- Suitable occasion (e.g., office wear, party, casual, sports) | |
| 381 | +请包含: | |
| 382 | +- 商品类型(如:连衣裙、运动鞋、双肩包、西装等) | |
| 383 | +- 主要颜色 | |
| 384 | +- 风格定位(如:休闲、正式、运动、复古、现代简约等) | |
| 385 | +- 图案/纹理(如:纯色、条纹、格纹、碎花、几何图案等) | |
| 386 | +- 关键设计特征(如:领型、袖长、版型、材质外观等) | |
| 387 | +- 适用场合(如:办公、户外、度假、聚会、运动等) | |
| 222 | 388 | |
| 223 | -Provide a comprehensive yet concise description (3-4 sentences).""" | |
| 389 | +输出格式:3-4句自然语言描述,可直接用作搜索关键词。""" | |
| 224 | 390 | |
| 225 | 391 | client = get_openai_client() |
| 226 | 392 | response = client.chat.completions.create( |
| ... | ... | @@ -245,15 +411,29 @@ Provide a comprehensive yet concise description (3-4 sentences).""" |
| 245 | 411 | ) |
| 246 | 412 | |
| 247 | 413 | analysis = response.choices[0].message.content.strip() |
| 248 | - logger.info("VLM analysis completed") | |
| 249 | - | |
| 414 | + logger.info("Image analysis completed.") | |
| 250 | 415 | return analysis |
| 251 | 416 | |
| 252 | 417 | except Exception as e: |
| 253 | - logger.error(f"Error analyzing image: {e}", exc_info=True) | |
| 254 | - return f"Error analyzing image: {str(e)}" | |
| 418 | + logger.error(f"analyze_image_style error: {e}", exc_info=True) | |
| 419 | + return f"图片分析失败:{e}" | |
| 255 | 420 | |
| 256 | 421 | |
| 257 | -def get_all_tools(): | |
| 258 | - """Get all available tools for the agent""" | |
| 259 | - return [search_products, analyze_image_style, web_search] | |
| 422 | +# ── Tool list factory ────────────────────────────────────────────────────────── | |
| 423 | + | |
def get_all_tools(
    session_id: str = "default",
    registry: Optional[SearchResultRegistry] = None,
) -> list:
    """
    Assemble the full tool list for one agent session.

    search_products is produced per-session by its factory so each session
    writes into its own registry slice; the remaining tools are stateless
    and shared across sessions.
    """
    effective_registry = global_registry if registry is None else registry
    search_tool = make_search_products_tool(session_id, effective_registry)
    return [search_tool, analyze_image_style, web_search]
| 439 | + ] | ... | ... |