""" ShopAgent - Streamlit UI Multi-modal fashion shopping assistant with conversational AI """ import html import logging import re import uuid from collections import OrderedDict from pathlib import Path from typing import Any, Optional import streamlit as st import streamlit.components.v1 as st_components from PIL import Image, ImageOps from app.agents.shopping_agent import ShoppingAgent from app.search_registry import ProductItem, SearchResult, global_registry # Matches [SEARCH_RESULTS_REF:sr_xxxxxxxx] tokens embedded in AI responses. # Case-insensitive, optional spaces around the id. SEARCH_RESULTS_REF_PATTERN = re.compile(r"\[SEARCH_RESULTS_REF:\s*([a-zA-Z0-9_]+)\s*\]", re.IGNORECASE) # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # In-memory image cache (url or "local:path" -> PIL Image), max 100 entries _IMAGE_CACHE: OrderedDict = OrderedDict() _IMAGE_CACHE_MAX = 100 # Page config st.set_page_config( page_title="ShopAgent", page_icon="👗", layout="centered", initial_sidebar_state="collapsed", ) # Custom CSS - ChatGPT-like style st.markdown( """ """, unsafe_allow_html=True, ) # Initialize session state def initialize_session(): """Initialize session state variables""" if "session_id" not in st.session_state: st.session_state.session_id = str(uuid.uuid4()) if "shopping_agent" not in st.session_state: st.session_state.shopping_agent = ShoppingAgent( session_id=st.session_state.session_id ) if "messages" not in st.session_state: st.session_state.messages = [] if "uploaded_image" not in st.session_state: st.session_state.uploaded_image = None if "show_image_upload" not in st.session_state: st.session_state.show_image_upload = False # Debug panel toggle (default True so 显示调试过程 is checked by default) if "show_debug" not in st.session_state: st.session_state.show_debug = True # Selected products for ask/compare (key -> product info dict) if "selected_products" not in st.session_state: st.session_state.selected_products = {} # Right side panel: visible, mode in ("similar", "compare"), payload (e.g. ref_id, query, or list of selected items) if "side_panel" not in st.session_state: st.session_state.side_panel = { "visible": False, "mode": None, "payload": None, } # Products currently referenced in chat input (list of product summary dicts) if "referenced_products" not in st.session_state: st.session_state.referenced_products = [] def save_uploaded_image(uploaded_file) -> Optional[str]: """Save uploaded image to temp directory""" if uploaded_file is None: return None try: temp_dir = Path("temp_uploads") temp_dir.mkdir(exist_ok=True) image_path = temp_dir / f"{st.session_state.session_id}_{uploaded_file.name}" with open(image_path, "wb") as f: f.write(uploaded_file.getbuffer()) logger.info(f"Saved uploaded image to {image_path}") return str(image_path) except Exception as e: logger.error(f"Error saving uploaded image: {e}") st.error(f"Failed to save image: {str(e)}") return None def _product_key(ref_id: str, index: int, product: ProductItem) -> str: """Stable unique key for a product in the session (for selection and side panel).""" return f"{ref_id}_{index}_{product.spu_id or index}" def _product_to_info(product: ProductItem, ref_id: str) -> dict: """Serialize product to a small dict for selected_products and ask/compare.""" return { "ref_id": ref_id, "spu_id": product.spu_id, "sku_id": product.spu_id, "title": product.title or "未知商品", "price": product.price, "tags": product.tags or [], "specifications": product.specifications or [], } def _compact_field(value: Any) -> str: """Format a field into one readable line for chat reference payload.""" if value is None: return "-" if isinstance(value, list): if not value: return "-" parts = [] for item in value: if isinstance(item, dict): text = ", ".join(f"{k}:{v}" for k, v in item.items()) parts.append(text if text else str(item)) else: parts.append(str(item)) return " | ".join(p for p in parts if p) or "-" return str(value) def _build_reference_prefix(products: list[dict]) -> str: """Build backend prompt prefix for 'chat with referenced products'.""" lines = [f"引用 {len(products)} 款商品:"] for i, p in enumerate(products, 1): sku_id = _compact_field(p.get("sku_id") or p.get("spu_id")) title = _compact_field(p.get("title")) price = _compact_field(p.get("price")) tags = _compact_field(p.get("tags")) specifications = _compact_field(p.get("specifications")) lines.append( f"{i}. sku_id={sku_id}; title={title}; price={price}; " f"tags={tags}; specifications={specifications}" ) return "\n".join(lines) @st.fragment def render_referenced_products_in_input() -> None: """Render referenced products above chat input, each with remove button.""" refs = st.session_state.get("referenced_products", []) if not refs: return st.markdown("**已引用商品**") remove_idx = None for idx, item in enumerate(refs): with st.container(border=True): c1, c2 = st.columns([12, 1]) with c1: title = (item.get("title") or "未知商品")[:80] st.markdown(f"**{title}**") st.caption( f"sku_id={item.get('sku_id') or item.get('spu_id') or '-'}; " f"price={_compact_field(item.get('price'))}; " f"tags={_compact_field(item.get('tags'))}; " f"specifications={_compact_field(item.get('specifications'))}" ) with c2: if st.button("✕", key=f"remove_ref_{idx}", help="删除该引用"): remove_idx = idx if remove_idx is not None: refs.pop(remove_idx) st.session_state.referenced_products = refs st.rerun() def _load_product_image(product: ProductItem) -> Optional[Image.Image]: """Load product image with cache: image_url or local data/images. Cache key = url or 'local:path'.""" cache_key: Optional[str] = None if product.image_url: cache_key = product.image_url if cache_key in _IMAGE_CACHE: _IMAGE_CACHE.move_to_end(cache_key) return _IMAGE_CACHE[cache_key] try: import io import requests resp = requests.get(product.image_url, timeout=10) if resp.status_code == 200: img = Image.open(io.BytesIO(resp.content)) _IMAGE_CACHE[cache_key] = img _IMAGE_CACHE.move_to_end(cache_key) if len(_IMAGE_CACHE) > _IMAGE_CACHE_MAX: _IMAGE_CACHE.popitem(last=False) return img except Exception as e: logger.debug(f"Remote image fetch failed for {product.spu_id}: {e}") local = Path(f"data/images/{product.spu_id}.jpg") if local.exists(): cache_key = f"local:{local}" if cache_key in _IMAGE_CACHE: _IMAGE_CACHE.move_to_end(cache_key) return _IMAGE_CACHE[cache_key] try: img = Image.open(local) _IMAGE_CACHE[cache_key] = img _IMAGE_CACHE.move_to_end(cache_key) if len(_IMAGE_CACHE) > _IMAGE_CACHE_MAX: _IMAGE_CACHE.popitem(last=False) return img except Exception as e: logger.debug(f"Local image load failed {local}: {e}") return None def display_product_card_from_item( product: ProductItem, ref_id: str, index: int, widget_prefix: str = "", ) -> None: """Render a single product card with hover actions: Similar products + checkbox.""" pkey = _product_key(ref_id, index, product) key_suffix = f"{widget_prefix}_{pkey}" if widget_prefix else pkey info = _product_to_info(product, ref_id) selected = st.session_state.selected_products st.markdown('
', unsafe_allow_html=True) img = _load_product_image(product) if img: target = (220, 220) try: img = ImageOps.fit(img, target, method=Image.Resampling.LANCZOS) except AttributeError: img = ImageOps.fit(img, target, method=Image.LANCZOS) st.image(img, width="stretch") else: st.markdown( '
🛍️
', unsafe_allow_html=True, ) title = product.title or "未知商品" st.markdown(f"**{title[:40]}**" + ("…" if len(title) > 40 else "")) if product.price is not None: st.caption(f"¥{product.price:.2f}") label_style = "⭐" if product.match_label == "Relevant" else "✦" st.caption(f"{label_style} {product.match_label}") st.markdown('
', unsafe_allow_html=True) col_a, col_b = st.columns([1, 1]) with col_a: similar_clicked = st.button( "Similar products", key=f"similar_{key_suffix}", help="Search by this product title and show in side panel", ) with col_b: is_checked = st.checkbox( "Select", key=f"select_{key_suffix}", value=(pkey in selected), label_visibility="collapsed", ) st.markdown("
", unsafe_allow_html=True) st.markdown("
", unsafe_allow_html=True) if similar_clicked: search_query = (product.title or "").strip() or "商品" st.session_state.side_panel = { "visible": True, "mode": "similar", "payload": {"query": search_query, "loading": True}, } st.rerun() if is_checked: if pkey not in selected: selected[pkey] = info else: selected.pop(pkey, None) def render_search_result_block(result: SearchResult, widget_prefix: str = "") -> None: """ Render a full search result block in place of a [SEARCH_RESULTS_REF:ref_id] token. widget_prefix: unique per (message, ref block) so Streamlit widget keys stay unique. """ summary_line = f'  · {result.quality_summary}' if result.quality_summary else '' header_html = ( f'
' f'' f'🔍 {result.query}' f' · Relevant {result.perfect_count} 件' f' · Partially Relevant {result.partial_count} 件' f'{summary_line}' f'
' ) st.markdown(header_html, unsafe_allow_html=True) # Perfect matches first, fall back to partials if none perfect = [p for p in result.products if p.match_label == "Relevant"] partial = [p for p in result.products if p.match_label == "Partially Relevant"] to_show = (perfect + partial)[:6] if perfect else partial[:6] if not to_show: st.caption("(本次搜索未找到可展示的商品)") return cols = st.columns(min(len(to_show), 3)) for i, product in enumerate(to_show): with cols[i % 3]: display_product_card_from_item( product, result.ref_id, i, widget_prefix=widget_prefix ) def render_message_with_refs( content: str, session_id: str, fallback_refs: Optional[dict] = None, msg_index: int = 0, ) -> None: """ Render an assistant message that may contain [SEARCH_RESULTS_REF:ref_id] tokens. msg_index: message index in chat, used to keep widget keys unique across messages. """ fallback_refs = fallback_refs or {} parts = SEARCH_RESULTS_REF_PATTERN.split(content) for i, segment in enumerate(parts): if i % 2 == 0: text = segment.strip() if text: st.markdown(text) else: ref_id = segment.strip() result = global_registry.get(session_id, ref_id) or fallback_refs.get(ref_id) if result: widget_prefix = f"m{msg_index}_r{i}" render_search_result_block(result, widget_prefix=widget_prefix) else: st.caption(f"[搜索结果 {ref_id} 不可用]") def render_debug_steps_panel(debug_steps: list[dict], expanded: bool = True) -> None: """Render debug steps with thinking/tool details.""" with st.expander("思考 & 工具调用详细过程", expanded=expanded): for idx, step in enumerate(debug_steps, 1): node = step.get("node", "unknown") st.markdown(f"**Step {idx} – {node}**") if node == "agent": msgs = step.get("messages", []) if msgs: st.markdown("**Agent Messages**") for m in msgs: st.markdown(f"- `{m.get('type', 'assistant')}`: {m.get('content', '')}") if m.get("thinking"): st.markdown(" - `thinking`:") st.code(m.get("thinking", ""), language="text") tcs = step.get("tool_calls", []) if tcs: st.markdown("**Planned Tool Calls**") for j, tc in enumerate(tcs, 1): st.markdown(f"- **{j}. {tc.get('name')}**") st.code(tc.get("args", {}), language="json") elif node == "tools": results = step.get("results", []) if results: st.markdown("**Tool Results**") for j, r in enumerate(results, 1): st.markdown(f"- **Result {j}:**") st.code(r.get("content", ""), language="text") st.markdown("---") def display_message(message: dict, msg_index: int = 0): """Display a chat message. msg_index keeps widget keys unique across messages.""" role = message["role"] content = message["content"] image_path = message.get("image_path") tool_calls = message.get("tool_calls", []) debug_steps = message.get("debug_steps", []) if role == "user": st.markdown('
', unsafe_allow_html=True) if image_path and Path(image_path).exists(): try: img = Image.open(image_path) st.image(img, width=200) except Exception: logger.warning(f"Failed to load user uploaded image: {image_path}") st.markdown(content) st.markdown("
", unsafe_allow_html=True) else: # assistant # Tool call breadcrumb if tool_calls: tool_names = [tc["name"] for tc in tool_calls] st.caption(" → ".join(tool_names)) st.markdown("") # Debug panel if debug_steps and st.session_state.get("show_debug"): render_debug_steps_panel(debug_steps, expanded=True) # Render message: expand [SEARCH_RESULTS_REF:ref_id] tokens into product card blocks session_id = st.session_state.get("session_id", "") render_message_with_refs( content, session_id, fallback_refs=message.get("search_refs"), msg_index=msg_index ) st.markdown("", unsafe_allow_html=True) @st.fragment def render_bottom_actions_bar() -> None: """Show Ask and Compare when there are selected products. Disabled when none selected.""" selected = st.session_state.selected_products n = len(selected) if n == 0: return st.markdown( '
', unsafe_allow_html=True, ) col_sel, col_ask, col_cmp = st.columns([2, 1, 1]) with col_sel: st.caption(f"Selected: {n}") with col_ask: ask_clicked = st.button("Ask", key="bottom_ask", help="Continue conversation with selected products") with col_cmp: compare_clicked = st.button("Compare", key="bottom_compare", help="Compare selected products") st.markdown("
", unsafe_allow_html=True) if ask_clicked: st.session_state.referenced_products = list(selected.values()) st.rerun() if compare_clicked: st.session_state.side_panel = { "visible": True, "mode": "compare", "payload": list(selected.values()), } st.rerun() def render_side_drawer() -> None: """Render a fixed overlay side drawer that does not change background layout.""" panel = st.session_state.side_panel if not panel.get("visible") or not panel.get("mode"): return mode = panel["mode"] payload = panel.get("payload") or {} session_id = st.session_state.get("session_id", "") title = "Similar products" if mode == "similar" else "Compare" body_html = "" if mode == "similar": query = html.escape((payload.get("query") or "")) if payload.get("loading"): body_html = '

加载中…

' elif payload.get("products") is not None: to_show = payload["products"][:12] cards = [] for product in to_show: p_title = html.escape((product.title or "未知商品")[:80]) price = ( f"¥{product.price:.2f}" if product.price is not None else "价格待更新" ) image_html = ( f'{p_title}' if product.image_url else '
🛍️
' ) cards.append( '
' f"{image_html}" '
' f'
{p_title}
' f'
{price}
' "
" ) cards_html = "".join(cards) if cards else '

(未找到可展示的商品)

' body_html = ( f'
' f'基于「{query}」的搜索结果:
' '
' + cards_html + "
" ) else: # Legacy: ref_id from registry (e.g. from chat) ref_id = payload.get("ref_id") if ref_id: result = global_registry.get(session_id, ref_id) if result: to_show = (result.products or [])[:12] cards = [] for product in to_show: p_title = html.escape((product.title or "未知商品")[:80]) price = f"¥{product.price:.2f}" if product.price is not None else "价格待更新" image_html = ( f'{p_title}' if product.image_url else '
🛍️
' ) cards.append( '
' f"{image_html}" f'
{p_title}
' f'
{price}
' ) body_html = ( f'
基于「{query}」的搜索结果:
' '
' + "".join(cards) + "
" ) else: body_html = f'

[搜索结果 {html.escape(ref_id)} 不可用]

' else: body_html = '

搜索失败或暂无结果。

' else: items = payload if isinstance(payload, list) else [] if items: rows = [] for item in items: t = html.escape((item.get("title") or "未知商品")[:80]) p = item.get("price") ptext = f"¥{p:.2f}" if p is not None else "价格待更新" rows.append( '
' f'
{t}
' f'
{ptext}
' "
" ) items_html = "".join(rows) else: items_html = '

当前未选中商品。

' body_html = ( '
已选商品:
' f'
{items_html}
' '
对比功能暂未实现。
' ) st.markdown( f"""
{html.escape(title)}
{body_html}
""", unsafe_allow_html=True, ) st_components.html(""" """, height=0) def display_welcome(): """Display welcome screen""" col1, col2, col3, col4 = st.columns(4) with col1: st.markdown( """
💗
懂你
能记住你的偏好,给你推荐适合的
""", unsafe_allow_html=True, ) with col2: st.markdown( """
🛍️
懂商品
深度理解店铺内所有商品,智能匹配你的需求
""", unsafe_allow_html=True, ) with col3: st.markdown( """
💭
贴心
任意聊
""", unsafe_allow_html=True, ) with col4: st.markdown( """
👗
懂时尚
穿搭顾问 + 轻松对比
""", unsafe_allow_html=True, ) st.markdown("

", unsafe_allow_html=True) def main(): """Main Streamlit app""" initialize_session() # Sync drawer close state from JS (set by JS via history.replaceState, no reload) if st.query_params.get("close_side_panel"): st.session_state.side_panel = {"visible": False, "mode": None, "payload": None} st.query_params.clear() # "Similar" panel: if loading, run API-only search and rerun panel = st.session_state.side_panel if panel.get("visible") and panel.get("mode") == "similar": payload = panel.get("payload") or {} if payload.get("loading") and payload.get("query"): from app.tools.search_tools import search_products_api_only products = search_products_api_only(payload["query"], limit=12) st.session_state.side_panel["payload"] = { "query": payload["query"], "products": products, "loading": False, } st.rerun() # Drawer rendered early so fixed positioning works from top of DOM render_side_drawer() # Header st.markdown( """
👗 ShopAgent
AI Fashion Shopping Assistant
""", unsafe_allow_html=True, ) # Sidebar (collapsed by default, but accessible) @st.fragment def _sidebar_fragment(): st.markdown("### ⚙️ Settings") if st.button("🗑️ Clear Chat", width="stretch"): if "shopping_agent" in st.session_state: st.session_state.shopping_agent.clear_history() session_id = st.session_state.get("session_id", "") if session_id: global_registry.clear_session(session_id) st.session_state.messages = [] st.session_state.uploaded_image = None st.session_state.selected_products = {} st.session_state.referenced_products = [] st.session_state.side_panel = {"visible": False, "mode": None, "payload": None} st.rerun() st.markdown("---") st.checkbox( "显示调试过程 (debug)", key="show_debug", value=True, help="展开后可查看中间思考过程及工具调用详情", ) st.markdown("---") st.caption(f"Session: `{st.session_state.session_id[:8]}...`") with st.sidebar: _sidebar_fragment() MAX_MESSAGES = 50 messages_container = st.container() with messages_container: if not st.session_state.messages: display_welcome() else: messages = st.session_state.messages start_idx = max(0, len(messages) - MAX_MESSAGES) to_show = messages[start_idx:] if len(messages) > MAX_MESSAGES: st.caption(f"(仅显示最近 {MAX_MESSAGES} 条,共 {len(messages)} 条消息)") for i, message in enumerate(to_show): display_message(message, msg_index=start_idx + i) render_bottom_actions_bar() # Fixed input area at bottom (using container to simulate fixed position) st.markdown('
', unsafe_allow_html=True) input_container = st.container() with input_container: # Image upload area (shown when + is clicked) if st.session_state.show_image_upload: uploaded_file = st.file_uploader( "Choose an image", type=["jpg", "jpeg", "png"], key="file_uploader", ) if uploaded_file: st.session_state.uploaded_image = uploaded_file # Show preview col1, col2 = st.columns([1, 4]) with col1: img = Image.open(uploaded_file) st.image(img, width=100) with col2: if st.button("❌ Remove"): st.session_state.uploaded_image = None st.session_state.show_image_upload = False st.rerun() # Referenced products area (shown above chat input, each can be removed) render_referenced_products_in_input() # Input row col1, col2 = st.columns([1, 12]) with col1: # Image upload toggle button if st.button("➕", help="Add image", width="stretch"): st.session_state.show_image_upload = ( not st.session_state.show_image_upload ) st.rerun() with col2: # Text input user_query = st.chat_input( "Ask about fashion products...", key="chat_input", ) st.markdown("
", unsafe_allow_html=True) # Process user input if user_query: raw_user_query = user_query referenced_products = list(st.session_state.get("referenced_products", [])) agent_query = raw_user_query if referenced_products: agent_query = f"{_build_reference_prefix(referenced_products)}\n\n{raw_user_query}" # Ensure shopping agent is initialized if "shopping_agent" not in st.session_state: st.error("Session not initialized. Please refresh the page.") st.stop() # Save uploaded image if present, or get from recent history image_path = None if st.session_state.uploaded_image: # User explicitly uploaded an image for this query image_path = save_uploaded_image(st.session_state.uploaded_image) else: # Check if query refers to a previous image if any( ref in raw_user_query.lower() for ref in [ "this", "that", "the image", "the shirt", "the product", "it", ] ): # Find the most recent message with an image for msg in reversed(st.session_state.messages): if msg.get("role") == "user" and msg.get("image_path"): image_path = msg["image_path"] logger.info(f"Using image from previous message: {image_path}") break # Add user message st.session_state.messages.append( { "role": "user", "content": raw_user_query, "image_path": image_path, } ) # References are consumed once this message is sent st.session_state.referenced_products = [] # Display user message immediately with messages_container: display_message(st.session_state.messages[-1]) # Process with shopping agent try: shopping_agent = st.session_state.shopping_agent # Stream assistant updates to UI immediately with messages_container: live_container = st.container() with live_container: live_tool_caption = st.empty() live_debug_placeholder = st.empty() live_response_placeholder = st.empty() live_response = "" live_tool_calls: list[dict] = [] live_debug_steps: list[dict] = [] result = None def _render_live() -> None: if live_tool_calls: tool_names = [tc.get("name", "") for tc in live_tool_calls if tc.get("name")] live_tool_caption.caption(" → ".join(tool_names)) else: live_tool_caption.empty() if st.session_state.get("show_debug") and live_debug_steps: with live_debug_placeholder.container(): render_debug_steps_panel(live_debug_steps, expanded=True) else: live_debug_placeholder.empty() if live_response: live_response_placeholder.markdown(live_response) else: live_response_placeholder.markdown("…") for event in shopping_agent.chat_stream(query=agent_query, image_path=image_path): event_type = event.get("type") if event_type in {"debug_update", "response_delta", "response_replace"}: if "tool_calls" in event: live_tool_calls = event.get("tool_calls", live_tool_calls) if "debug_steps" in event: live_debug_steps = event.get("debug_steps", live_debug_steps) if event_type == "response_delta": live_response = event.get("response", live_response) elif event_type == "response_replace": live_response = event.get("response", live_response) _render_live() elif event_type == "done": result = event.get("result") if not result: result = { "response": live_response or "抱歉,处理您的请求时未返回结果。", "tool_calls": live_tool_calls, "debug_steps": live_debug_steps, "search_refs": {}, "error": True, } response = result["response"] tool_calls = result.get("tool_calls", []) debug_steps = result.get("debug_steps", []) # Add assistant message (store search_refs so refs resolve after rerun) st.session_state.messages.append( { "role": "assistant", "content": response, "tool_calls": tool_calls, "debug_steps": debug_steps, "search_refs": result.get("search_refs", {}), } ) # Clear uploaded image and hide upload area after sending st.session_state.uploaded_image = None st.session_state.show_image_upload = False # Auto-scroll to bottom with JavaScript st.markdown( """ """, unsafe_allow_html=True, ) except Exception as e: logger.error(f"Error processing query: {e}", exc_info=True) error_msg = f"I apologize, I encountered an error: {str(e)}" st.session_state.messages.append( { "role": "assistant", "content": error_msg, } ) # Rerun to update UI st.rerun() if __name__ == "__main__": main()