"""
ShopAgent - Streamlit UI
Multi-modal fashion shopping assistant with conversational AI
"""
import html
import logging
import re
import uuid
from pathlib import Path
from typing import Any, Optional
import streamlit as st
import streamlit.components.v1 as st_components
from PIL import Image, ImageOps
from app.agents.shopping_agent import ShoppingAgent
from app.search_registry import ProductItem, SearchResult, global_registry
# Pattern for the "[SEARCH_REF:<id>]" placeholders the assistant embeds in its
# replies, e.g. "[SEARCH_REF:sr_1a2b3c4d]". Matching ignores case and allows
# whitespace on either side of the id; capture group 1 is the id itself, so
# `split()` alternates text / id / text / ...
SEARCH_REF_PATTERN = re.compile(
    r"\[SEARCH_REF:\s*([a-zA-Z0-9_]+)\s*\]",
    flags=re.I,
)
# Configure logging: INFO level, each record prefixed with timestamp, logger
# name and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Page config: browser tab title/icon, centered layout, sidebar collapsed on load.
st.set_page_config(
    page_title="ShopAgent",
    page_icon="👗",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# Custom CSS - ChatGPT-like style.
# NOTE(review): the injected string is only a newline here, so no stylesheet is
# actually applied; the CSS appears to be missing from this literal - confirm.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)
# Initialize session state
def initialize_session():
    """Create every ``st.session_state`` key the app relies on, if missing.

    Existing values are never overwritten, so this is safe to call on every
    rerun. Keys, in creation order:

    - ``session_id``: random UUID string for this browser session.
    - ``shopping_agent``: ``ShoppingAgent`` bound to ``session_id``.
    - ``messages``: chat history (list of message dicts).
    - ``uploaded_image``: pending uploaded file, or ``None``.
    - ``show_image_upload``: whether the image uploader is visible.
    - ``show_debug``: debug-panel toggle; defaults to ``True`` so the
      "显示调试过程" checkbox starts checked.
    - ``selected_products``: product key -> product info dict (Ask / Compare).
    - ``side_panel``: right drawer state: ``visible``, ``mode`` (``"similar"``
      or ``"compare"``) and ``payload`` (e.g. ref_id/query, or the list of
      selected items).
    - ``referenced_products``: product summary dicts attached to the chat input.
    """
    state = st.session_state
    # Factories rather than values: each call yields a fresh container, and the
    # agent is only constructed when it is actually missing. Order matters -
    # the shopping_agent factory reads session_id.
    defaults = (
        ("session_id", lambda: str(uuid.uuid4())),
        ("shopping_agent", lambda: ShoppingAgent(session_id=state.session_id)),
        ("messages", list),
        ("uploaded_image", lambda: None),
        ("show_image_upload", lambda: False),
        ("show_debug", lambda: True),
        ("selected_products", dict),
        (
            "side_panel",
            lambda: {"visible": False, "mode": None, "payload": None},
        ),
        ("referenced_products", list),
    )
    for key, make_default in defaults:
        if key not in state:
            state[key] = make_default()
def save_uploaded_image(uploaded_file) -> Optional[str]:
    """Persist an uploaded image under ``temp_uploads/`` and return its path.

    The file name is ``<session_id>_<random token>_<original file name>``.

    Args:
        uploaded_file: Uploaded-file object exposing ``.name`` and
            ``.getbuffer()`` (as used below), or ``None``.

    Returns:
        The saved path as a string; ``None`` when nothing was uploaded or the
        write failed (in which case an error is also shown in the UI).
    """
    if uploaded_file is None:
        return None
    try:
        temp_dir = Path("temp_uploads")
        temp_dir.mkdir(exist_ok=True)
        # The name is client-supplied: keep only its final path component so
        # it cannot address anything outside temp_dir.
        safe_name = Path(uploaded_file.name).name or "upload"
        # A per-save token keeps a later upload with the same name (e.g. the
        # generic "image.png") from overwriting a file that an earlier chat
        # message still references via its image_path.
        token = uuid.uuid4().hex[:8]
        image_path = temp_dir / f"{st.session_state.session_id}_{token}_{safe_name}"
        image_path.write_bytes(uploaded_file.getbuffer())
        logger.info("Saved uploaded image to %s", image_path)
        return str(image_path)
    except Exception as e:
        logger.error("Error saving uploaded image: %s", e)
        st.error(f"Failed to save image: {str(e)}")
        return None
def _product_key(ref_id: str, index: int, product: ProductItem) -> str:
    """Return the session-wide key identifying one rendered product card.

    Shape: ``"<ref_id>_<index>_<spu_id>"``. When the product's ``spu_id`` is
    falsy (missing or empty), the index is repeated in its place. Used as the
    key of ``selected_products`` and inside widget keys.
    """
    identity = product.spu_id if product.spu_id else index
    return "{}_{}_{}".format(ref_id, index, identity)
def _product_to_info(product: ProductItem, ref_id: str) -> dict:
    """Serialize a product into the small dict kept in ``selected_products``.

    This dict is what the Ask / Compare flows consume (for example
    ``_build_reference_prefix`` reads ``sku_id``/``spu_id``, ``title``,
    ``price``, ``tags`` and ``specifications``).

    Args:
        product: The product to summarize.
        ref_id: Search-result ref id the product was shown under.

    Returns:
        Dict with ``ref_id``, ``spu_id``, ``sku_id``, ``title`` (defaults to
        "未知商品"), ``price``, ``tags`` and ``specifications`` (both default
        to ``[]``).
    """
    spu_id = product.spu_id
    # Use the product's own SKU id when the model carries one; otherwise fall
    # back to the SPU id so the field is still populated.
    sku_id = getattr(product, "sku_id", None) or spu_id
    return {
        "ref_id": ref_id,
        "spu_id": spu_id,
        "sku_id": sku_id,
        "title": product.title or "未知商品",
        "price": product.price,
        "tags": product.tags or [],
        "specifications": product.specifications or [],
    }
def _compact_field(value: Any) -> str:
    """Flatten a product field into one line of text for the reference payload.

    - ``None`` becomes ``"-"``.
    - A list has each element rendered (dicts as ``"k:v, k:v"``, falling back
      to ``str(dict)`` when the dict is empty; anything else via ``str``).
      Empty renderings are dropped, and the rest are joined with ``" | "``.
      The result is ``"-"`` if nothing remains.
    - Any other value is returned as ``str(value)``.
    """
    if value is None:
        return "-"
    if not isinstance(value, list):
        return str(value)
    rendered = [
        (", ".join(f"{k}:{v}" for k, v in entry.items()) or str(entry))
        if isinstance(entry, dict)
        else str(entry)
        for entry in value
    ]
    joined = " | ".join(filter(None, rendered))
    return joined or "-"
def _build_reference_prefix(products: list[dict]) -> str:
    """Build the prompt prefix sent to the agent for referenced products.

    The first line states how many products are referenced. It is followed by
    one numbered line per product (numbering starts at 1) shaped like
    ``"<n>. sku_id=…; title=…; price=…; tags=…; specifications=…"``. Every
    value goes through ``_compact_field``, and ``sku_id`` falls back to
    ``spu_id`` when empty. Lines are joined with ``"\\n"``.
    """

    def describe(position: int, prod: dict) -> str:
        fields = (
            ("sku_id", prod.get("sku_id") or prod.get("spu_id")),
            ("title", prod.get("title")),
            ("price", prod.get("price")),
            ("tags", prod.get("tags")),
            ("specifications", prod.get("specifications")),
        )
        body = "; ".join(f"{name}={_compact_field(raw)}" for name, raw in fields)
        return f"{position}. {body}"

    header = f"引用 {len(products)} 款商品:"
    numbered = (describe(n, prod) for n, prod in enumerate(products, start=1))
    return "\n".join([header, *numbered])
def render_referenced_products_in_input() -> None:
    """Show the products attached to the next chat message above the input.

    Each product gets a bordered row containing its title (first 80 chars), a
    caption with sku_id / price / tags / specifications, and a "✕" button.
    Clicking a button removes that entry from
    ``st.session_state.referenced_products`` and reruns the script. Nothing is
    rendered when no product is referenced.
    """
    referenced = st.session_state.get("referenced_products", [])
    if not referenced:
        return
    st.markdown("**已引用商品**")
    to_remove = None
    for position, ref in enumerate(referenced):
        with st.container(border=True):
            info_col, button_col = st.columns([12, 1])
            with info_col:
                st.markdown(f"**{(ref.get('title') or '未知商品')[:80]}**")
                id_text = ref.get("sku_id") or ref.get("spu_id") or "-"
                st.caption(
                    f"sku_id={id_text}; "
                    f"price={_compact_field(ref.get('price'))}; "
                    f"tags={_compact_field(ref.get('tags'))}; "
                    f"specifications={_compact_field(ref.get('specifications'))}"
                )
            with button_col:
                if st.button("✕", key=f"remove_ref_{position}", help="删除该引用"):
                    to_remove = position
    if to_remove is None:
        return
    referenced.pop(to_remove)
    st.session_state.referenced_products = referenced
    st.rerun()
@st.cache_data(show_spinner=False, max_entries=512)
def _fetch_remote_image_bytes(url: str) -> bytes:
    """Download the raw bytes at *url* (10 s timeout), memoized across reruns.

    The script re-executes on each interaction. Without this cache, every
    rerun would re-download every product image on screen. Failures (network
    errors or a non-200 status) raise instead of returning a sentinel, so they
    are not cached and a later rerun can retry.
    """
    import requests

    resp = requests.get(url, timeout=10)
    if resp.status_code != 200:
        raise ValueError(f"unexpected HTTP status {resp.status_code}")
    return resp.content


def _load_product_image(product: ProductItem) -> Optional[Image.Image]:
    """Load a displayable image for *product*, or return ``None``.

    Sources, in order:
    1. ``product.image_url`` (downloaded through the cached fetcher above);
    2. the local file ``data/images/<spu_id>.jpg``.

    Each image is fully decoded here. As a result, corrupt or truncated data
    falls through to the next source (or ``None``) instead of raising later
    when the caller resizes it, and the local file handle is closed before
    returning. Failures are logged at DEBUG level.
    """
    import io

    if product.image_url:
        try:
            img = Image.open(io.BytesIO(_fetch_remote_image_bytes(product.image_url)))
            img.load()  # force decoding now so bad data is caught here
            return img
        except Exception as e:
            logger.debug("Remote image fetch failed for %s: %s", product.spu_id, e)
    local = Path(f"data/images/{product.spu_id}.jpg")
    if local.exists():
        try:
            with Image.open(local) as img:
                # copy() decodes into memory, so the returned image stays
                # usable after the `with` closes the underlying file.
                return img.copy()
        except Exception as e:
            logger.debug("Local image load failed %s: %s", local, e)
    return None
def _run_similar_search(query: str) -> Optional[str]:
    """Search products for *query* and return the ref id of the stored result.

    Builds the session's product-search tool (which registers its results in
    ``global_registry``), invokes it with the stripped query, and extracts the
    first ``[SEARCH_REF:...]`` id from its output.

    Returns ``None`` when:
    - the query is empty or whitespace;
    - there is no session id;
    - the tool raises (a warning is logged);
    - the output contains no ref token.
    """
    cleaned = query.strip() if query else ""
    if not cleaned:
        return None
    from app.tools.search_tools import make_search_products_tool

    session_id = st.session_state.get("session_id", "")
    if not session_id:
        return None
    search_tool = make_search_products_tool(session_id, global_registry)
    try:
        tool_output = search_tool.invoke({"query": cleaned})
        found = SEARCH_REF_PATTERN.search(tool_output)
    except Exception as exc:
        logger.warning(f"Similar search failed: {exc}")
        return None
    return found.group(1).strip() if found else None
def display_product_card_from_item(
    product: ProductItem,
    ref_id: str,
    index: int,
    widget_prefix: str = "",
) -> None:
    """Render one product card with a "Similar products" button and a select checkbox.

    The card shows the image (fitted to 220x220) or a placeholder, the title
    (truncated to 40 chars), the price when known, and the match label. The
    buttons are presumably revealed on hover by CSS that is not visible here.

    Side effects:
    - "Similar products" searches by the product title, opens the side drawer
      in "similar" mode (recording an error payload if the search yields no
      ref id), and reruns the script.
    - The checkbox keeps ``st.session_state.selected_products`` in sync:
      checked adds ``_product_to_info(...)`` under the product key; unchecked
      removes it.

    Args:
        product: Product to render.
        ref_id: Ref id of the search result the product belongs to.
        index: Position of the product within that result.
        widget_prefix: Extra prefix that keeps widget keys unique when the
            same result is rendered in several places.
    """
    pkey = _product_key(ref_id, index, product)
    key_suffix = f"{widget_prefix}_{pkey}" if widget_prefix else pkey
    info = _product_to_info(product, ref_id)
    selected = st.session_state.selected_products
    # NOTE(review): this and the other multi-line literals in this function
    # look like HTML wrapper markup that was stripped; as shown they are not
    # valid Python strings. Restore the markup.
    st.markdown('
', unsafe_allow_html=True)
    img = _load_product_image(product)
    if img:
        target = (220, 220)
        try:
            img = ImageOps.fit(img, target, method=Image.Resampling.LANCZOS)
        except AttributeError:
            # Older Pillow without the Image.Resampling enum.
            img = ImageOps.fit(img, target, method=Image.LANCZOS)
        st.image(img, width="stretch")
    else:
        # Placeholder shown when no image could be loaded.
        st.markdown(
            '
🛍️
',
            unsafe_allow_html=True,
        )
    title = product.title or "未知商品"
    st.markdown(f"**{title[:40]}**" + ("…" if len(title) > 40 else ""))
    if product.price is not None:
        st.caption(f"¥{product.price:.2f}")
    label_style = "⭐" if product.match_label == "Relevant" else "✦"
    st.caption(f"{label_style} {product.match_label}")
    st.markdown('
', unsafe_allow_html=True)
    col_a, col_b = st.columns([1, 1])
    with col_a:
        similar_clicked = st.button(
            "Similar products",
            key=f"similar_{key_suffix}",
            help="Search by this product title and show in side panel",
        )
    with col_b:
        # Initial state mirrors selected_products; the key includes the widget
        # prefix so duplicate renders of the same product do not clash.
        is_checked = st.checkbox(
            "Select",
            key=f"select_{key_suffix}",
            value=(pkey in selected),
            label_visibility="collapsed",
        )
    st.markdown("
", unsafe_allow_html=True)
    st.markdown("
", unsafe_allow_html=True)
    if similar_clicked:
        # Fall back to a generic query ("商品" = "product") when the title is empty.
        search_query = (product.title or "").strip() or "商品"
        new_ref = _run_similar_search(search_query)
        if new_ref:
            st.session_state.side_panel = {
                "visible": True,
                "mode": "similar",
                "payload": {"ref_id": new_ref, "query": search_query},
            }
        else:
            st.session_state.side_panel = {
                "visible": True,
                "mode": "similar",
                "payload": {"ref_id": None, "query": search_query, "error": True},
            }
        st.rerun()
    # Sync the selection dict with the checkbox on every render.
    if is_checked:
        if pkey not in selected:
            selected[pkey] = info
    else:
        selected.pop(pkey, None)
def render_search_result_block(result: SearchResult, widget_prefix: str = "") -> None:
    """
    Render a full search result block in place of a [SEARCH_REF:ref_id] token.

    The block has two parts:
    - A header line with the query, the Relevant / Partially Relevant counts,
      and the optional quality summary.
    - Up to six product cards in a grid of at most three columns. "Relevant"
      products come first, then "Partially Relevant" ones; products with any
      other label are not shown.

    Args:
        result: Registered search result to render.
        widget_prefix: Unique per (message, ref block) so Streamlit widget
            keys stay unique.
    """
    # query / quality_summary originate from tool or model output and are
    # rendered with unsafe_allow_html=True, so they must be HTML-escaped.
    safe_query = html.escape(str(result.query))
    summary_line = (
        f" · {html.escape(str(result.quality_summary))}"
        if result.quality_summary
        else ""
    )
    # NOTE(review): only the text fragments of this header remain; its HTML
    # wrapper markup appears to have been stripped - confirm intended layout.
    header_html = (
        f"🔍 {safe_query}"
        f" · Relevant {result.perfect_count} 件"
        f" · Partially Relevant {result.partial_count} 件"
        f"{summary_line}"
        "\n"
    )
    st.markdown(header_html, unsafe_allow_html=True)
    perfect = [p for p in result.products if p.match_label == "Relevant"]
    partial = [p for p in result.products if p.match_label == "Partially Relevant"]
    # Relevant first, then partially relevant (just the partials when there
    # are no relevant hits).
    to_show = (perfect + partial)[:6]
    if not to_show:
        st.caption("(本次搜索未找到可展示的商品)")
        return
    n_cols = min(len(to_show), 3)
    cols = st.columns(n_cols)
    for i, product in enumerate(to_show):
        with cols[i % n_cols]:
            display_product_card_from_item(
                product, result.ref_id, i, widget_prefix=widget_prefix
            )
def render_message_with_refs(
    content: str,
    session_id: str,
    fallback_refs: Optional[dict] = None,
    msg_index: int = 0,
) -> None:
    """
    Render assistant text, expanding every [SEARCH_REF:ref_id] token into a
    search result block.

    Text between tokens is stripped and rendered as Markdown; empty pieces are
    skipped. Each ref id is resolved from ``global_registry`` for
    *session_id* first, then from *fallback_refs* (the refs stored on the
    message). An id found in neither renders a "[搜索结果 … 不可用]" caption.

    msg_index: message index in chat, used to keep widget keys unique across
    messages.
    """
    refs_lookup = fallback_refs or {}
    pieces = SEARCH_REF_PATTERN.split(content)
    for position, piece in enumerate(pieces):
        # The pattern has one capture group, so odd positions hold ref ids.
        if position % 2 == 0:
            prose = piece.strip()
            if prose:
                st.markdown(prose)
            continue
        ref_id = piece.strip()
        result = global_registry.get(session_id, ref_id) or refs_lookup.get(ref_id)
        if not result:
            st.caption(f"[搜索结果 {ref_id} 不可用]")
            continue
        render_search_result_block(result, widget_prefix=f"m{msg_index}_r{position}")
def display_message(message: dict, msg_index: int = 0):
    """Render one chat message.

    User messages show the attached image (200 px wide, only if the file still
    exists) followed by the text.

    Assistant messages show, in order:
    - a breadcrumb of tool names;
    - the debug expander, only when the message has ``debug_steps`` and
      ``st.session_state.show_debug`` is on;
    - the reply, with each ``[SEARCH_REF:ref_id]`` token expanded into
      product cards.

    Args:
        message: Chat message dict. ``role`` and ``content`` are required;
            ``image_path``, ``tool_calls``, ``debug_steps`` and
            ``search_refs`` are optional.
        msg_index: Position of the message in the chat; keeps widget keys
            unique across messages.
    """
    import json

    role = message["role"]
    content = message["content"]
    image_path = message.get("image_path")
    tool_calls = message.get("tool_calls", [])
    debug_steps = message.get("debug_steps", [])
    if role == "user":
        # NOTE(review): this empty literal and the closing one below look like
        # an HTML wrapper whose markup was stripped - confirm.
        st.markdown('', unsafe_allow_html=True)
        if image_path and Path(image_path).exists():
            try:
                img = Image.open(image_path)
                st.image(img, width=200)
            except Exception:
                logger.warning(f"Failed to load user uploaded image: {image_path}")
        st.markdown(content)
        st.markdown("\n", unsafe_allow_html=True)
    else:  # assistant
        # Tool call breadcrumb; tolerate a call record without a name rather
        # than failing to render the whole chat history.
        if tool_calls:
            tool_names = [str(tc.get("name") or "unknown") for tc in tool_calls]
            st.caption(" → ".join(tool_names))
            st.markdown("")
        # Debug panel
        if debug_steps and st.session_state.get("show_debug"):
            with st.expander("思考 & 工具调用详细过程", expanded=False):
                for idx, step in enumerate(debug_steps, 1):
                    node = step.get("node", "unknown")
                    st.markdown(f"**Step {idx} – {node}**")
                    if node == "agent":
                        msgs = step.get("messages", [])
                        if msgs:
                            st.markdown("**Agent Messages**")
                            for m in msgs:
                                st.markdown(f"- `{m.get('type', 'assistant')}`: {m.get('content', '')}")
                        tcs = step.get("tool_calls", [])
                        if tcs:
                            st.markdown("**Planned Tool Calls**")
                            for j, tc in enumerate(tcs, 1):
                                st.markdown(f"- **{j}. {tc.get('name')}**")
                                args = tc.get("args", {})
                                # Highlighted as JSON, so emit real JSON rather
                                # than the Python repr of a dict; keep non-ASCII
                                # readable and stringify anything json can't encode.
                                args_text = (
                                    args
                                    if isinstance(args, str)
                                    else json.dumps(args, ensure_ascii=False, indent=2, default=str)
                                )
                                st.code(args_text, language="json")
                    elif node == "tools":
                        results = step.get("results", [])
                        if results:
                            st.markdown("**Tool Results**")
                            for j, r in enumerate(results, 1):
                                st.markdown(f"- **Result {j}:**")
                                st.code(r.get("content", ""), language="text")
                    st.markdown("---")
        # Render message: expand [SEARCH_REF:ref_id] tokens into product card blocks
        session_id = st.session_state.get("session_id", "")
        render_message_with_refs(
            content, session_id, fallback_refs=message.get("search_refs"), msg_index=msg_index
        )
        st.markdown("", unsafe_allow_html=True)
def render_bottom_actions_bar() -> None:
    """Show a "Selected: N" bar with Ask and Compare buttons.

    The bar is not rendered at all (rather than disabled) when no product is
    selected. Button behavior:
    - "Ask" copies the selected product infos into
      ``st.session_state.referenced_products``, attaching them to the next
      chat message.
    - "Compare" opens the side drawer in "compare" mode with the selected
      infos as payload.
    Either click reruns the script. The selection itself is left unchanged.
    """
    selected = st.session_state.selected_products
    n = len(selected)
    if n == 0:
        return
    # NOTE(review): this empty literal and the one after the buttons look like
    # an HTML wrapper whose markup was stripped; the latter is not a valid
    # Python string as shown - confirm.
    st.markdown(
        '',
        unsafe_allow_html=True,
    )
    col_sel, col_ask, col_cmp = st.columns([2, 1, 1])
    with col_sel:
        st.caption(f"Selected: {n}")
    with col_ask:
        ask_clicked = st.button("Ask", key="bottom_ask", help="Continue conversation with selected products")
    with col_cmp:
        compare_clicked = st.button("Compare", key="bottom_compare", help="Compare selected products")
    st.markdown("
", unsafe_allow_html=True)
    if ask_clicked:
        st.session_state.referenced_products = list(selected.values())
        st.rerun()
    if compare_clicked:
        st.session_state.side_panel = {
            "visible": True,
            "mode": "compare",
            "payload": list(selected.values()),
        }
        st.rerun()
def render_side_drawer() -> None:
    """Render the right-hand side drawer described by ``st.session_state.side_panel``.

    Does nothing unless the panel is visible and has a mode.
    - In ``"similar"`` mode, it builds an HTML body from the registered
      search result named by ``payload["ref_id"]``: up to 12 products,
      "Relevant" before "Partially Relevant".
    - In any other mode, ``payload`` is treated as the list of selected
      product info dicts, and a compare body listing them is built.
    The query, ref id, titles and labels are HTML-escaped before being
    interpolated.

    NOTE(review): as shown, the final ``st.markdown`` renders an f-string
    whose body is only a newline, so ``title`` and ``body_html`` are computed
    but never emitted. The ``st_components.html`` payload is also just a
    newline. Many literals below appear to have had their HTML markup stripped
    (several span lines and are not valid Python strings). Restore the markup
    and drawer template before relying on this function.
    """
    panel = st.session_state.side_panel
    if not panel.get("visible") or not panel.get("mode"):
        return
    mode = panel["mode"]
    # A None/empty payload becomes {}; a non-empty list (compare mode) is kept.
    payload = panel.get("payload") or {}
    session_id = st.session_state.get("session_id", "")
    title = "Similar products" if mode == "similar" else "Compare"
    body_html = ""
    if mode == "similar":
        ref_id = payload.get("ref_id")
        query = html.escape(payload.get("query", ""))
        if payload.get("error") or not ref_id:
            body_html = '搜索失败或暂无结果。
'
        else:
            result = global_registry.get(session_id, ref_id)
            if not result:
                body_html = f'[搜索结果 {html.escape(ref_id)} 不可用]
'
            else:
                perfect = [p for p in result.products if p.match_label == "Relevant"]
                partial = [p for p in result.products if p.match_label == "Partially Relevant"]
                to_show = (perfect + partial)[:12] if perfect else partial[:12]
                cards = []
                for product in to_show:
                    p_title = html.escape((product.title or "未知商品")[:80])
                    p_label = html.escape(product.match_label or "Partially Relevant")
                    price = (
                        f"¥{product.price:.2f}"
                        if product.price is not None
                        else "价格待更新"
                    )
                    # NOTE(review): product.image_url is only tested here; the
                    # f-string branch no longer interpolates it, which suggests
                    # its <img> markup was stripped.
                    image_html = (
                        f'
'
                        if product.image_url
                        else '🛍️
'
                    )
                    cards.append(
                        ''
                        f"{image_html}"
                        '
'
                        f'
{p_title}
'
                        f'
{price}
'
                        f'
{p_label}
'
                        "
"
                    )
                cards_html = "".join(cards) if cards else '(未找到可展示的商品)
'
                body_html = (
                    f''
                    f'基于「{query}」的搜索结果:
'
                    ''
                    f"{cards_html}"
                    "
"
                )
    else:
        items = payload if isinstance(payload, list) else []
        if items:
            rows = []
            for item in items:
                t = html.escape((item.get("title") or "未知商品")[:80])
                p = item.get("price")
                # assumes price is numeric when present - confirm upstream.
                ptext = f"¥{p:.2f}" if p is not None else "价格待更新"
                # NOTE(review): `t` and `ptext` are computed but the literal
                # below does not use them as shown; the row markup is missing.
                rows.append(
                    '"
                )
            items_html = "".join(rows)
        else:
            items_html = '当前未选中商品。
'
        body_html = (
            '已选商品:
'
            f'{items_html}
'
            '对比功能暂未实现。
'
        )
    st.markdown(
        f"""
""",
        unsafe_allow_html=True,
    )
    st_components.html("""
""", height=0)
def display_welcome():
    """Render the welcome screen shown while the chat history is empty.

    The screen is four equal columns, each filled with raw HTML via
    st.markdown. In the code as shown, only the second column's literal
    carries visible text ("懂商品" - roughly "knows the products" - plus a
    one-line description).

    NOTE(review): the other columns' literals and the trailing markdown call
    appear to have lost their HTML markup; the trailing one is not a valid
    Python string as shown - confirm.
    """
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    with col2:
        st.markdown(
            """
🛍️
懂商品
深度理解店铺内所有商品,智能匹配你的需求
""",
            unsafe_allow_html=True,
        )
    with col3:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    with col4:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    st.markdown("
", unsafe_allow_html=True)
# Whole-word references ("this", "it", "the shirt", ...) suggesting the user is
# talking about an image from an earlier message. Word boundaries matter: a
# plain substring test for "it" also matches "with", "white", "suit", ...
_PREVIOUS_IMAGE_REF_PATTERN = re.compile(
    r"\b(?:this|that|its?|the (?:image|shirt|product)s?)\b",
    re.IGNORECASE,
)


def main():
    """Build the page for one Streamlit run and handle a submitted query.

    Each run does the following:
    1. Initializes session state and applies a pending "close drawer"
       request from the ``close_side_panel`` query parameter.
    2. Renders the side drawer, header, sidebar (Clear Chat, debug toggle,
       session id), the chat history (or welcome screen), and the
       Ask/Compare bar.
    3. Renders the input area: optional image uploader with preview,
       referenced products, the "➕" upload toggle and the chat input.
    4. When a query was submitted, it:
       - prefixes the query with any referenced products;
       - picks the image to send: the fresh upload, or the latest user image
         when the query appears to refer to an earlier one;
       - calls ``ShoppingAgent.chat`` and stores the reply (or an error
         message);
       - reruns the script.
    """
    initialize_session()
    # A "close drawer" request arrives as the close_side_panel query parameter
    # (presumably set by the drawer's JavaScript, which is not visible here).
    if st.query_params.get("close_side_panel"):
        st.session_state.side_panel = {"visible": False, "mode": None, "payload": None}
        st.query_params.clear()
    # Drawer rendered early so fixed positioning works from top of DOM
    render_side_drawer()
    # Header
    # NOTE(review): the header literal is only a newline; its markup appears
    # to be missing - confirm.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
    # Sidebar (collapsed by default, but accessible)
    with st.sidebar:
        st.markdown("### ⚙️ Settings")
        if st.button("🗑️ Clear Chat", width="stretch"):
            if "shopping_agent" in st.session_state:
                st.session_state.shopping_agent.clear_history()
            # Clear search result registry for this session
            session_id = st.session_state.get("session_id", "")
            if session_id:
                global_registry.clear_session(session_id)
            st.session_state.messages = []
            st.session_state.uploaded_image = None
            st.session_state.selected_products = {}
            st.session_state.referenced_products = []
            st.session_state.side_panel = {"visible": False, "mode": None, "payload": None}
            st.rerun()
        # Debug toggle. Its value lives in st.session_state.show_debug, which
        # initialize_session() already sets to True. Also passing value= would
        # make Streamlit warn that the widget's default conflicts with a value
        # set through the Session State API.
        st.markdown("---")
        st.checkbox(
            "显示调试过程 (debug)",
            key="show_debug",
            help="展开后可查看中间思考过程及工具调用详情",
        )
        st.markdown("---")
        st.caption(f"Session: `{st.session_state.session_id[:8]}...`")
    messages_container = st.container()
    with messages_container:
        if not st.session_state.messages:
            display_welcome()
        else:
            for msg_idx, message in enumerate(st.session_state.messages):
                display_message(message, msg_index=msg_idx)
    render_bottom_actions_bar()
    # Fixed input area at bottom (using container to simulate fixed position)
    st.markdown('', unsafe_allow_html=True)
    input_container = st.container()
    with input_container:
        # Image upload area (shown when + is clicked)
        if st.session_state.show_image_upload:
            uploaded_file = st.file_uploader(
                "Choose an image",
                type=["jpg", "jpeg", "png"],
                key="file_uploader",
            )
            if uploaded_file:
                st.session_state.uploaded_image = uploaded_file
                # Show preview
                col1, col2 = st.columns([1, 4])
                with col1:
                    img = Image.open(uploaded_file)
                    st.image(img, width=100)
                with col2:
                    if st.button("❌ Remove"):
                        st.session_state.uploaded_image = None
                        st.session_state.show_image_upload = False
                        st.rerun()
        # Referenced products area (shown above chat input, each can be removed)
        render_referenced_products_in_input()
        # Input row
        col1, col2 = st.columns([1, 12])
        with col1:
            # Image upload toggle button
            if st.button("➕", help="Add image", width="stretch"):
                st.session_state.show_image_upload = (
                    not st.session_state.show_image_upload
                )
                st.rerun()
        with col2:
            # Text input
            user_query = st.chat_input(
                "Ask about fashion products...",
                key="chat_input",
            )
        # NOTE(review): closing markup for the input area appears to have been
        # stripped; the literal is kept as the lone newline it contained.
        st.markdown("\n", unsafe_allow_html=True)
    if not user_query:
        return

    # Process user input
    raw_user_query = user_query
    referenced_products = list(st.session_state.get("referenced_products", []))
    agent_query = raw_user_query
    if referenced_products:
        # The agent sees the referenced products as a prefix; the chat history
        # keeps only what the user typed.
        agent_query = f"{_build_reference_prefix(referenced_products)}\n\n{raw_user_query}"
    # Ensure shopping agent is initialized
    if "shopping_agent" not in st.session_state:
        st.error("Session not initialized. Please refresh the page.")
        st.stop()
    # Use the image uploaded for this query, else reuse the most recent user
    # image when the query appears to refer back to it.
    image_path = None
    if st.session_state.uploaded_image:
        image_path = save_uploaded_image(st.session_state.uploaded_image)
    elif _PREVIOUS_IMAGE_REF_PATTERN.search(raw_user_query):
        for msg in reversed(st.session_state.messages):
            if msg.get("role") == "user" and msg.get("image_path"):
                image_path = msg["image_path"]
                logger.info("Using image from previous message: %s", image_path)
                break
    # Add user message
    st.session_state.messages.append(
        {
            "role": "user",
            "content": raw_user_query,
            "image_path": image_path,
        }
    )
    # References are consumed once this message is sent
    st.session_state.referenced_products = []
    # Display user message immediately
    with messages_container:
        display_message(
            st.session_state.messages[-1],
            msg_index=len(st.session_state.messages) - 1,
        )
    # Process with shopping agent
    try:
        shopping_agent = st.session_state.shopping_agent
        result = shopping_agent.chat(
            query=agent_query,
            image_path=image_path,
        )
        response = result["response"]
        tool_calls = result.get("tool_calls", [])
        debug_steps = result.get("debug_steps", [])
        # Add assistant message (store search_refs so refs resolve after rerun)
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": response,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "search_refs": result.get("search_refs", {}),
            }
        )
        # Clear uploaded image and hide upload area after sending
        st.session_state.uploaded_image = None
        st.session_state.show_image_upload = False
        # Auto-scroll to bottom.
        # NOTE(review): the script body appears to be missing (the literal is
        # a lone newline) - confirm.
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    except Exception as e:
        logger.error("Error processing query: %s", e, exc_info=True)
        error_msg = f"I apologize, I encountered an error: {str(e)}"
        st.session_state.messages.append(
            {
                "role": "assistant",
                "content": error_msg,
            }
        )
    # Rerun to update UI
    st.rerun()
# Entry point when this file is executed as the main script (presumably via
# `streamlit run`, which runs it with __name__ == "__main__").
if __name__ == "__main__":
    main()