""" OmniShopAgent - Streamlit UI Multi-modal fashion shopping assistant with conversational AI """ import logging import re import uuid from pathlib import Path from typing import Optional import streamlit as st from PIL import Image, ImageOps from app.agents.shopping_agent import ShoppingAgent # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) logger = logging.getLogger(__name__) # Page config st.set_page_config( page_title="OmniShopAgent", page_icon="👗", layout="centered", initial_sidebar_state="collapsed", ) # Custom CSS - ChatGPT-like style st.markdown( """ """, unsafe_allow_html=True, ) # Initialize session state def initialize_session(): """Initialize session state variables""" if "session_id" not in st.session_state: st.session_state.session_id = str(uuid.uuid4()) if "shopping_agent" not in st.session_state: st.session_state.shopping_agent = ShoppingAgent( session_id=st.session_state.session_id ) if "messages" not in st.session_state: st.session_state.messages = [] if "uploaded_image" not in st.session_state: st.session_state.uploaded_image = None if "show_image_upload" not in st.session_state: st.session_state.show_image_upload = False def save_uploaded_image(uploaded_file) -> Optional[str]: """Save uploaded image to temp directory""" if uploaded_file is None: return None try: temp_dir = Path("temp_uploads") temp_dir.mkdir(exist_ok=True) image_path = temp_dir / f"{st.session_state.session_id}_{uploaded_file.name}" with open(image_path, "wb") as f: f.write(uploaded_file.getbuffer()) logger.info(f"Saved uploaded image to {image_path}") return str(image_path) except Exception as e: logger.error(f"Error saving uploaded image: {e}") st.error(f"Failed to save image: {str(e)}") return None def extract_products_from_response(response: str) -> list: """Extract product information from agent response Returns list of dicts with product info """ products = [] # Pattern to match product blocks in the response # Looking for ID, name, and other details lines = response.split("\n") current_product = {} for line in lines: line = line.strip() # Match product number (e.g., "1. Product Name" or "**1. Product Name**") if re.match(r"^\*?\*?\d+\.\s+", line): if current_product: products.append(current_product) current_product = {} # Extract product name name = re.sub(r"^\*?\*?\d+\.\s+", "", line) name = name.replace("**", "").strip() current_product["name"] = name # Match ID elif "ID:" in line or "id:" in line: id_match = re.search(r"(?:ID|id):\s*(\d+)", line) if id_match: current_product["id"] = id_match.group(1) # Match Category elif "Category:" in line: cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line) if cat_match: current_product["category"] = cat_match.group(1).strip() # Match Color elif "Color:" in line: color_match = re.search(r"Color:\s*(\w+)", line) if color_match: current_product["color"] = color_match.group(1) # Match Gender elif "Gender:" in line: gender_match = re.search(r"Gender:\s*(\w+)", line) if gender_match: current_product["gender"] = gender_match.group(1) # Match Season elif "Season:" in line: season_match = re.search(r"Season:\s*(\w+)", line) if season_match: current_product["season"] = season_match.group(1) # Match Usage elif "Usage:" in line: usage_match = re.search(r"Usage:\s*(\w+)", line) if usage_match: current_product["usage"] = usage_match.group(1) # Match Similarity/Relevance score elif "Similarity:" in line or "Relevance:" in line: score_match = re.search(r"(?:Similarity|Relevance):\s*([\d.]+)%", line) if score_match: current_product["score"] = score_match.group(1) # Add last product if current_product: products.append(current_product) return products def display_product_card(product: dict): """Display a product card with image and name""" product_id = product.get("id", "") name = product.get("name", "Unknown Product") # Debug: log what we got logger.info(f"Displaying product: ID={product_id}, Name={name}") # Try to load image from data/images directory if product_id: image_path = Path(f"data/images/{product_id}.jpg") if image_path.exists(): try: img = Image.open(image_path) # Fixed size for all images target_size = (200, 200) try: # Try new Pillow API img_processed = ImageOps.fit( img, target_size, method=Image.Resampling.LANCZOS ) except AttributeError: # Fallback for older Pillow versions img_processed = ImageOps.fit( img, target_size, method=Image.LANCZOS ) # Display image with fixed width st.image(img_processed, use_container_width=False, width=200) st.markdown(f"**{name}**") st.caption(f"ID: {product_id}") return except Exception as e: logger.warning(f"Failed to load image {image_path}: {e}") else: logger.warning(f"Image not found: {image_path}") # Fallback: no image st.markdown(f"**📷 {name}**") if product_id: st.caption(f"ID: {product_id}") else: st.caption("ID not available") def display_message(message: dict): """Display a chat message""" role = message["role"] content = message["content"] image_path = message.get("image_path") tool_calls = message.get("tool_calls", []) if role == "user": st.markdown('
", unsafe_allow_html=True) else: # assistant # Display tool calls horizontally - only tool names if tool_calls: tool_names = [tc['name'] for tc in tool_calls] st.caption(" → ".join(tool_names)) st.markdown("") # Extract and display products if any products = extract_products_from_response(content) # Debug logging logger.info(f"Extracted {len(products)} products from response") for p in products: logger.info(f"Product: {p}") if products: def parse_score(product: dict) -> float: score = product.get("score") if score is None: return 0.0 try: return float(score) except (TypeError, ValueError): return 0.0 # Sort by score and limit to 3 products = sorted(products, key=parse_score, reverse=True)[:3] logger.info(f"Displaying top {len(products)} products") # Display the text response first (without product details) text_lines = [] for line in content.split("\n"): # Skip product detail lines if not any( keyword in line for keyword in [ "ID:", "Category:", "Color:", "Gender:", "Season:", "Usage:", "Similarity:", "Relevance:", ] ): if not re.match(r"^\*?\*?\d+\.\s+", line): text_lines.append(line) intro_text = "\n".join(text_lines).strip() if intro_text: st.markdown(intro_text) # Display product cards in grid st.markdown("