"""
ShopAgent - Streamlit UI

Multi-modal fashion shopping assistant with conversational AI.
"""

import json
import logging
import re
import uuid
from pathlib import Path
from typing import Optional

import streamlit as st
from PIL import Image, ImageOps

from app.agents.shopping_agent import ShoppingAgent

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Page config
st.set_page_config(
    page_title="ShopAgent",
    page_icon="👗",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Custom CSS - ChatGPT-like style.
# NOTE(review): this payload, like the other markup strings passed to
# st.markdown(..., unsafe_allow_html=True) in this file, contains only
# whitespace as written. They look like placeholders for HTML wrappers/styles;
# restore them if the layout depends on them.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# ---------------------------------------------------------------------------
# Module constants
# ---------------------------------------------------------------------------

# Newer Pillow exposes resampling filters on Image.Resampling; older releases
# only provide the module-level constant.
_LANCZOS = getattr(Image, "Resampling", Image).LANCZOS

_UPLOAD_DIR = Path("temp_uploads")
_IMAGE_DIR = Path("data/images")
_CARD_SIZE = (200, 200)
_MAX_PRODUCT_CARDS = 3

# A numbered product header such as "1. Name" or "**1. Name**".
_PRODUCT_HEADER_RE = re.compile(r"^\*?\*?\d+\.\s+")

# (trigger substrings, value pattern, product key). Checked in order; the first
# rule whose trigger appears in a line handles that line, even if its value
# pattern then fails to match.
_PRODUCT_FIELD_RULES = (
    (("ID:", "id:"), re.compile(r"(?:ID|id):\s*(\d+)"), "id"),
    (("Category:",), re.compile(r"Category:\s*(.+?)(?:\n|$)"), "category"),
    (("Color:",), re.compile(r"Color:\s*(\w+)"), "color"),
    (("Gender:",), re.compile(r"Gender:\s*(\w+)"), "gender"),
    (("Season:",), re.compile(r"Season:\s*(\w+)"), "season"),
    (("Usage:",), re.compile(r"Usage:\s*(\w+)"), "usage"),
    (
        ("Similarity:", "Relevance:"),
        re.compile(r"(?:Similarity|Relevance):\s*([\d.]+)%"),
        "score",
    ),
)

# Lines containing any of these are hidden from the intro text when product
# cards are rendered.
_DETAIL_KEYWORDS = (
    "ID:",
    "Category:",
    "Color:",
    "Gender:",
    "Season:",
    "Usage:",
    "Similarity:",
    "Relevance:",
)

# Whole-word references that suggest the user means a previously uploaded
# image. Word boundaries prevent "it" from matching inside "white", "outfit",
# "with", etc.
_IMAGE_REFERENCE_RE = re.compile(
    r"\b(?:this|that|the image|the shirt|the product|it)\b"
)

_GREETINGS = frozenset({"hi", "hello", "hey"})

_GREETING_RESPONSE = """Hello! 👋 I'm your fashion shopping assistant.

I can help you:
- Search for products by description
- Find items similar to images you upload
- Analyze product styles

What are you looking for today?"""

# (icon, title, blurb) for each welcome-screen card.
_WELCOME_CARDS = (
    ("💗", "懂你", "能记住你的偏好,给你推荐适合的"),
    ("🛍️", "懂商品", "深度理解店铺内所有商品,智能匹配你的需求"),
    ("💭", "贴心", "任意聊"),
    ("👗", "懂时尚", "穿搭顾问 + 轻松对比"),
)


# ---------------------------------------------------------------------------
# Session and file helpers
# ---------------------------------------------------------------------------


def initialize_session():
    """Populate ``st.session_state`` with the defaults this app relies on.

    Only missing keys are set, so existing values survive Streamlit reruns.
    """
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    if "shopping_agent" not in st.session_state:
        st.session_state.shopping_agent = ShoppingAgent(
            session_id=st.session_state.session_id
        )
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_image" not in st.session_state:
        st.session_state.uploaded_image = None
    if "show_image_upload" not in st.session_state:
        st.session_state.show_image_upload = False
    # Debug panel toggle (bound to the sidebar checkbox with key="show_debug").
    if "show_debug" not in st.session_state:
        st.session_state.show_debug = False


def save_uploaded_image(uploaded_file) -> Optional[str]:
    """Save an uploaded image under ``temp_uploads/`` and return its path.

    The stored file name is the session id followed by the uploaded file's
    name. Only the final component of the client-supplied name is used, so a
    name containing directory parts cannot escape the upload directory.

    Returns:
        The saved path as a string. Returns ``None`` when nothing was uploaded
        or writing failed; on failure an error is also shown in the UI.
    """
    if uploaded_file is None:
        return None
    try:
        _UPLOAD_DIR.mkdir(exist_ok=True)
        safe_name = Path(uploaded_file.name).name
        image_path = _UPLOAD_DIR / f"{st.session_state.session_id}_{safe_name}"
        image_path.write_bytes(uploaded_file.getbuffer())
        logger.info("Saved uploaded image to %s", image_path)
        return str(image_path)
    except Exception as e:
        logger.error("Error saving uploaded image: %s", e)
        st.error(f"Failed to save image: {str(e)}")
        return None


# ---------------------------------------------------------------------------
# Response parsing
# ---------------------------------------------------------------------------


def extract_products_from_response(response: str) -> list:
    """Extract product information from the agent's free-text response.

    A (stripped) line starting with ``1. Name`` or ``**1. Name**`` opens a new
    product. Subsequent ``Key: value`` lines fill in its fields:
    ID, Category, Color, Gender, Season, Usage and Similarity/Relevance.
    Each line contributes at most one field.

    Returns:
        A list of dicts with any of the keys ``name``, ``id``, ``category``,
        ``color``, ``gender``, ``season``, ``usage`` and ``score``
        (all values are strings).
    """
    products = []
    current = {}

    for raw_line in response.split("\n"):
        line = raw_line.strip()

        if _PRODUCT_HEADER_RE.match(line):
            if current:
                products.append(current)
            name = _PRODUCT_HEADER_RE.sub("", line).replace("**", "").strip()
            current = {"name": name}
            continue

        for triggers, pattern, key in _PRODUCT_FIELD_RULES:
            if any(trigger in line for trigger in triggers):
                match = pattern.search(line)
                if match:
                    current[key] = match.group(1).strip()
                break

    # Add last product
    if current:
        products.append(current)

    return products


def _parse_score(product: dict) -> float:
    """Return the product's score as a float; missing or invalid scores are 0.0."""
    try:
        return float(product.get("score"))
    except (TypeError, ValueError):
        return 0.0


def _strip_product_details(content: str) -> str:
    """Return *content* without product header lines and product detail lines."""
    kept = [
        line
        for line in content.split("\n")
        if not any(keyword in line for keyword in _DETAIL_KEYWORDS)
        # Strip before matching, like extract_products_from_response does, so
        # an indented header is not both rendered as a card and left in the text.
        and not _PRODUCT_HEADER_RE.match(line.strip())
    ]
    return "\n".join(kept).strip()


def _format_tool_args(args) -> str:
    """Render tool-call arguments as pretty JSON text for the debug panel."""
    if isinstance(args, str):
        return args
    try:
        return json.dumps(args, ensure_ascii=False, indent=2, default=str)
    except (TypeError, ValueError):
        return str(args)


# ---------------------------------------------------------------------------
# Rendering
# ---------------------------------------------------------------------------


def display_product_card(product: dict):
    """Display a product card with a 200x200 image, its name and its ID.

    The image is read from ``data/images/<id>.jpg``. A text-only card is shown
    when the product has no ID, the file is missing, or rendering fails.
    """
    product_id = product.get("id", "")
    name = product.get("name", "Unknown Product")
    logger.debug("Displaying product: ID=%s, Name=%s", product_id, name)

    if product_id:
        image_path = _IMAGE_DIR / f"{product_id}.jpg"
        if image_path.exists():
            try:
                # ImageOps.fit returns a new image, so the source file can be
                # closed as soon as the resized copy exists.
                with Image.open(image_path) as img:
                    img_processed = ImageOps.fit(img, _CARD_SIZE, method=_LANCZOS)
                st.image(img_processed, use_container_width=False, width=200)
                st.markdown(f"**{name}**")
                st.caption(f"ID: {product_id}")
                return
            except Exception as e:
                logger.warning("Failed to load image %s: %s", image_path, e)
        else:
            logger.warning("Image not found: %s", image_path)

    # Fallback: no image
    st.markdown(f"**📷 {name}**")
    st.caption(f"ID: {product_id}" if product_id else "ID not available")


def _render_debug_panel(debug_steps: list):
    """Show each recorded agent step: its messages, planned tool calls and tool results."""
    with st.expander("思考 & 工具调用详细过程", expanded=False):
        for idx, step in enumerate(debug_steps, 1):
            node = step.get("node", "unknown")
            st.markdown(f"**Step {idx} – {node}**")

            if node == "agent":
                step_messages = step.get("messages", [])
                if step_messages:
                    st.markdown("**Agent Messages**")
                    # Distinct names so the caller's message role/content are
                    # never shadowed.
                    for step_msg in step_messages:
                        msg_type = step_msg.get("type", "assistant")
                        msg_text = step_msg.get("content", "")
                        st.markdown(f"- `{msg_type}`: {msg_text}")

                planned_calls = step.get("tool_calls", [])
                if planned_calls:
                    st.markdown("**Planned Tool Calls**")
                    for j, tc in enumerate(planned_calls, 1):
                        st.markdown(f"- **{j}. {tc.get('name')}**")
                        st.code(
                            _format_tool_args(tc.get("args", {})), language="json"
                        )

            elif node == "tools":
                results = step.get("results", [])
                if results:
                    st.markdown("**Tool Results**")
                    for j, r in enumerate(results, 1):
                        st.markdown(f"- **Result {j}:**")
                        st.code(r.get("content", ""), language="text")

            st.markdown("---")


def display_message(message: dict):
    """Display one chat message.

    User messages show an optional attached image followed by the text.
    Assistant messages show the tool-call chain, an optional debug panel, and
    either the full text or an intro plus up to three product cards ranked by
    score.
    """
    role = message["role"]
    content = message["content"]
    image_path = message.get("image_path")
    tool_calls = message.get("tool_calls", [])
    debug_steps = message.get("debug_steps", [])

    if role == "user":
        st.markdown("\n", unsafe_allow_html=True)
        if image_path and Path(image_path).exists():
            try:
                with Image.open(image_path) as img:
                    st.image(img, width=200)
            except Exception:
                logger.warning("Failed to load user uploaded image: %s", image_path)
        st.markdown(content)
        st.markdown("\n", unsafe_allow_html=True)
        return

    # Assistant: show only the names of the tools that were called.
    if tool_calls:
        st.caption(" → ".join(tc.get("name", "unknown") for tc in tool_calls))
        st.markdown("")

    # Optional: detailed debug panel (reasoning + tool details)
    if debug_steps and st.session_state.get("show_debug"):
        _render_debug_panel(debug_steps)

    products = extract_products_from_response(content)
    logger.info("Extracted %d products from response", len(products))
    for p in products:
        logger.debug("Product: %s", p)

    if products:
        top_products = sorted(products, key=_parse_score, reverse=True)[
            :_MAX_PRODUCT_CARDS
        ]
        logger.info("Displaying top %d products", len(top_products))

        # Show the text response first, without the product detail lines.
        intro_text = _strip_product_details(content)
        if intro_text:
            st.markdown(intro_text)

        # Display product cards in a fixed-width grid.
        st.markdown("\n", unsafe_allow_html=True)
        cols = st.columns(_MAX_PRODUCT_CARDS)
        for col, product in zip(cols, top_products):
            with col:
                display_product_card(product)
    else:
        # No products found, display full content
        st.markdown(content)

    st.markdown("", unsafe_allow_html=True)


def display_welcome():
    """Display the welcome screen: four feature cards side by side."""
    cols = st.columns(len(_WELCOME_CARDS))
    for col, (icon, title, blurb) in zip(cols, _WELCOME_CARDS):
        with col:
            st.markdown(
                f"\n{icon}\n{title}\n{blurb}\n",
                unsafe_allow_html=True,
            )

    st.markdown("\n\n", unsafe_allow_html=True)


# ---------------------------------------------------------------------------
# App sections
# ---------------------------------------------------------------------------


def _render_sidebar():
    """Render the settings sidebar: clear-chat button, debug toggle, session id."""
    with st.sidebar:
        st.markdown("### ⚙️ Settings")

        if st.button("🗑️ Clear Chat", use_container_width=True):
            if "shopping_agent" in st.session_state:
                st.session_state.shopping_agent.clear_history()
            st.session_state.messages = []
            st.session_state.uploaded_image = None
            st.rerun()

        # Debug toggle
        st.markdown("---")
        st.checkbox(
            "显示调试过程 (debug)",
            key="show_debug",
            help="展开后可查看中间思考过程及工具调用详情",
        )

        st.markdown("---")
        st.caption(f"Session: `{st.session_state.session_id[:8]}...`")


def _render_input_area() -> Optional[str]:
    """Render the optional image uploader plus the input row.

    Returns:
        The submitted chat text, or ``None`` when nothing was submitted on
        this run.
    """
    st.markdown("\n", unsafe_allow_html=True)

    with st.container():
        # Image upload area (shown when + is clicked)
        if st.session_state.show_image_upload:
            uploaded_file = st.file_uploader(
                "Choose an image",
                type=["jpg", "jpeg", "png"],
                key="file_uploader",
            )
            if uploaded_file:
                st.session_state.uploaded_image = uploaded_file

                preview_col, action_col = st.columns([1, 4])
                with preview_col:
                    st.image(Image.open(uploaded_file), width=100)
                with action_col:
                    if st.button("❌ Remove"):
                        st.session_state.uploaded_image = None
                        st.session_state.show_image_upload = False
                        st.rerun()

        toggle_col, input_col = st.columns([1, 12])
        with toggle_col:
            if st.button("➕", help="Add image", use_container_width=True):
                st.session_state.show_image_upload = (
                    not st.session_state.show_image_upload
                )
                st.rerun()
        with input_col:
            user_query = st.chat_input(
                "Ask about fashion products...",
                key="chat_input",
            )

    st.markdown("\n", unsafe_allow_html=True)
    return user_query


def _resolve_image_path(user_query: str) -> Optional[str]:
    """Pick the image to send with *user_query*.

    A freshly uploaded image is saved and used. Otherwise, if the query refers
    to an earlier image ("this", "it", "the shirt", ...), the most recent user
    message that carried an image is reused.
    """
    if st.session_state.uploaded_image:
        return save_uploaded_image(st.session_state.uploaded_image)

    if _IMAGE_REFERENCE_RE.search(user_query.lower()):
        for msg in reversed(st.session_state.messages):
            if msg.get("role") == "user" and msg.get("image_path"):
                logger.info("Using image from previous message: %s", msg["image_path"])
                return msg["image_path"]
    return None


def _process_query(user_query: str, messages_container) -> None:
    """Record *user_query*, get the assistant's reply, and rerun the script.

    Any exception from the agent is logged and turned into an assistant
    apology message instead of crashing the page.
    """
    # Ensure shopping agent is initialized
    if "shopping_agent" not in st.session_state:
        st.error("Session not initialized. Please refresh the page.")
        st.stop()

    image_path = _resolve_image_path(user_query)

    st.session_state.messages.append(
        {
            "role": "user",
            "content": user_query,
            "image_path": image_path,
        }
    )

    # Display user message immediately
    with messages_container:
        display_message(st.session_state.messages[-1])

    try:
        shopping_agent = st.session_state.shopping_agent

        if user_query.lower().strip() in _GREETINGS:
            response = _GREETING_RESPONSE
            tool_calls = []
            # Defined on this path too; previously a greeting raised
            # UnboundLocalError on `result`.
            debug_steps = []
        else:
            result = shopping_agent.chat(
                query=user_query,
                image_path=image_path,
            )
            response = result["response"]
            tool_calls = result.get("tool_calls", [])
            debug_steps = result.get("debug_steps", [])

        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": response,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
            }
        )

        # Clear uploaded image and hide upload area after sending
        st.session_state.uploaded_image = None
        st.session_state.show_image_upload = False

        # Auto-scroll to bottom.
        # NOTE(review): the script payload is empty as written, and st.rerun()
        # below restarts the script right after this, so this likely has no
        # visible effect — confirm.
        st.markdown(
            """ """,
            unsafe_allow_html=True,
        )

    except Exception as e:
        logger.exception("Error processing query: %s", e)
        error_msg = f"I apologize, I encountered an error: {str(e)}"
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": error_msg,
            }
        )

    # Rerun to update UI
    st.rerun()


def main():
    """Main Streamlit app: header, sidebar, chat history, input, query handling."""
    initialize_session()

    # Header
    st.markdown(
        """
👗 ShopAgent
AI Fashion Shopping Assistant
""",
        unsafe_allow_html=True,
    )

    # Sidebar (collapsed by default, but accessible)
    _render_sidebar()

    # Chat messages container
    messages_container = st.container()
    with messages_container:
        if not st.session_state.messages:
            display_welcome()
        else:
            for message in st.session_state.messages:
                display_message(message)

    user_query = _render_input_area()

    if user_query:
        _process_query(user_query, messages_container)


if __name__ == "__main__":
    main()