"""
ShopAgent - Streamlit UI
Multi-modal fashion shopping assistant with conversational AI
"""
import html
import logging
import re
import uuid
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional
import streamlit as st
import streamlit.components.v1 as st_components
from PIL import Image, ImageOps
from app.agents.shopping_agent import ShoppingAgent
from app.search_registry import ProductItem, SearchResult, global_registry
# Matches [SEARCH_RESULTS_REF:sr_xxxxxxxx] tokens embedded in AI responses.
# Case-insensitive, optional spaces around the id. The single capture group is
# the ref id, so .split() with this pattern alternates text / ref id / text ...
SEARCH_RESULTS_REF_PATTERN = re.compile(r"\[SEARCH_RESULTS_REF:\s*([a-zA-Z0-9_]+)\s*\]", re.IGNORECASE)
# Configure logging (root logger at INFO with a timestamped format)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# In-memory LRU image cache (url or "local:path" -> PIL Image), max 100 entries.
# NOTE(review): Streamlit re-executes the entry script on every rerun, which would
# re-create this dict each time if this file is that script — confirm, and
# consider st.cache_resource if the cache is meant to survive reruns.
_IMAGE_CACHE: OrderedDict = OrderedDict()
_IMAGE_CACHE_MAX = 100
# Page config: browser tab title/icon, centered layout, sidebar initially collapsed
st.set_page_config(
    page_title="ShopAgent",
    page_icon="👗",
    layout="centered",
    initial_sidebar_state="collapsed",
)
# Custom CSS - ChatGPT-like style
# NOTE(review): the injected string holds no markup in this copy; the stylesheet
# appears to be missing.
st.markdown(
    """
""",
    unsafe_allow_html=True,
)
# Initialize session state
def initialize_session():
    """Populate ``st.session_state`` with a default for every key the UI reads.

    Keys that already exist are left untouched, so calling this at the top of
    every Streamlit rerun is safe.
    """
    state = st.session_state

    def _closed_side_panel() -> dict:
        # Right-hand drawer: "mode" is "similar" or "compare"; "payload" holds the
        # mode-specific data (a query dict, or a list of selected product dicts).
        return {
            "visible": False,
            "mode": None,
            "payload": None,
        }

    # (key, factory) pairs, created lazily and in this order. "shopping_agent"
    # must follow "session_id" because its factory reads the session id.
    defaults = (
        ("session_id", lambda: str(uuid.uuid4())),
        ("shopping_agent", lambda: ShoppingAgent(session_id=state.session_id)),
        ("messages", list),
        ("uploaded_image", lambda: None),
        ("show_image_upload", lambda: False),
        # Debug panel is on by default, so the sidebar debug checkbox starts checked
        ("show_debug", lambda: True),
        # Products picked for Ask/Compare: product key -> product info dict
        ("selected_products", dict),
        ("side_panel", _closed_side_panel),
        # Product summary dicts currently attached to the chat input
        ("referenced_products", list),
    )
    for key, make_default in defaults:
        if key not in state:
            state[key] = make_default()
def save_uploaded_image(uploaded_file) -> Optional[str]:
    """Persist an uploaded image under ``temp_uploads/`` and return its path.

    The file is stored as ``<session_id>_<original filename>``. Only the final
    component of the client-supplied filename is used, so a name such as
    ``../../evil.png`` cannot place the file outside the upload directory.

    Args:
        uploaded_file: The object returned by ``st.file_uploader`` (must expose
            ``.name`` and ``.getbuffer()``), or ``None``.

    Returns:
        The saved path as a string, or ``None`` when nothing was uploaded or the
        write failed. On failure an error is logged and shown in the UI.
    """
    if uploaded_file is None:
        return None
    try:
        temp_dir = Path("temp_uploads")
        temp_dir.mkdir(exist_ok=True)
        # The filename comes from the browser: strip any directory components.
        safe_name = Path(uploaded_file.name).name or "upload"
        image_path = temp_dir / f"{st.session_state.session_id}_{safe_name}"
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        logger.info("Saved uploaded image to %s", image_path)
        return str(image_path)
    except Exception as e:
        logger.error("Error saving uploaded image: %s", e, exc_info=True)
        st.error(f"Failed to save image: {str(e)}")
        return None
def _product_key(ref_id: str, index: int, product: ProductItem) -> str:
    """Return a session-unique key for one product card.

    Combines the search-result ref id, the card's position within that result
    and the product's SPU id. The position is reused when the SPU id is empty.
    The key indexes ``selected_products`` and is embedded in widget keys.
    """
    identity = product.spu_id or index
    return "{}_{}_{}".format(ref_id, index, identity)
def _product_to_info(product: ProductItem, ref_id: str) -> dict:
    """Flatten a product into the small dict stored in ``selected_products``.

    This dict is what Ask/Compare and the chat reference prefix consume.
    A missing title, tags or specifications is replaced with a display default.
    """
    spu = product.spu_id
    info = {"ref_id": ref_id, "spu_id": spu}
    # NOTE(review): "sku_id" is filled from spu_id; confirm whether ProductItem
    # carries a distinct SKU identifier that should be used instead.
    info["sku_id"] = spu
    info["title"] = product.title or "未知商品"
    info["price"] = product.price
    info["tags"] = product.tags or []
    info["specifications"] = product.specifications or []
    return info
def _compact_field(value: Any) -> str:
"""Format a field into one readable line for chat reference payload."""
if value is None:
return "-"
if isinstance(value, list):
if not value:
return "-"
parts = []
for item in value:
if isinstance(item, dict):
text = ", ".join(f"{k}:{v}" for k, v in item.items())
parts.append(text if text else str(item))
else:
parts.append(str(item))
return " | ".join(p for p in parts if p) or "-"
return str(value)
def _build_reference_prefix(products: list[dict]) -> str:
    """Compose the text prepended to the user's query when products are referenced.

    The result is a header line with the product count, then one numbered line
    per product. Each line carries its id (sku_id, else spu_id), title, price,
    tags and specifications, each flattened with ``_compact_field``.
    """
    header = f"引用 {len(products)} 款商品:"
    numbered = [
        f"{n}. sku_id={_compact_field(item.get('sku_id') or item.get('spu_id'))}; "
        f"title={_compact_field(item.get('title'))}; "
        f"price={_compact_field(item.get('price'))}; "
        f"tags={_compact_field(item.get('tags'))}; "
        f"specifications={_compact_field(item.get('specifications'))}"
        for n, item in enumerate(products, start=1)
    ]
    return "\n".join([header, *numbered])
@st.fragment
def render_referenced_products_in_input() -> None:
    """Show the products attached to the next chat message, above the chat input.

    Each product gets a bordered card with its title (first 80 characters) and a
    one-line summary. A "✕" button removes it from ``referenced_products`` and
    triggers a rerun. Nothing is rendered when no products are referenced.
    """
    references = st.session_state.get("referenced_products", [])
    if not references:
        return

    st.markdown("**已引用商品**")
    to_remove: Optional[int] = None
    for position, ref in enumerate(references):
        with st.container(border=True):
            text_col, button_col = st.columns([12, 1])
            with text_col:
                st.markdown(f"**{(ref.get('title') or '未知商品')[:80]}**")
                identity = ref.get("sku_id") or ref.get("spu_id") or "-"
                summary = (
                    f"sku_id={identity}; "
                    f"price={_compact_field(ref.get('price'))}; "
                    f"tags={_compact_field(ref.get('tags'))}; "
                    f"specifications={_compact_field(ref.get('specifications'))}"
                )
                st.caption(summary)
            with button_col:
                if st.button("✕", key=f"remove_ref_{position}", help="删除该引用"):
                    to_remove = position

    if to_remove is None:
        return
    references.pop(to_remove)
    st.session_state.referenced_products = references
    st.rerun()
def _cache_image(cache_key: str, img: Image.Image) -> Image.Image:
    """Store *img* as the most recently used cache entry and return it.

    Evicts least-recently-used entries while the cache exceeds ``_IMAGE_CACHE_MAX``.
    """
    _IMAGE_CACHE[cache_key] = img
    _IMAGE_CACHE.move_to_end(cache_key)
    while len(_IMAGE_CACHE) > _IMAGE_CACHE_MAX:
        _IMAGE_CACHE.popitem(last=False)
    return img


def _get_cached_image(cache_key: str) -> Optional[Image.Image]:
    """Return the cached image for *cache_key*, marking it most recently used, or None."""
    img = _IMAGE_CACHE.get(cache_key)
    if img is not None:
        _IMAGE_CACHE.move_to_end(cache_key)
    return img


def _load_product_image(product: ProductItem) -> Optional[Image.Image]:
    """Load a product's image, trying ``image_url`` first and then a local file.

    Sources, in order:
      1. ``product.image_url``, fetched over HTTP with a 10 s timeout. Only a
         200 response whose body decodes as an image is accepted.
      2. ``data/images/<spu_id>.jpg`` on disk, when the product has an spu_id.

    Decoded images are kept in the LRU cache ``_IMAGE_CACHE``, keyed by the URL
    or by ``"local:<path>"``. Every failure is logged at DEBUG level and falls
    through to the next source.

    Returns:
        The decoded PIL image, or ``None`` when no source produced one.
    """
    if product.image_url:
        cache_key = product.image_url
        cached = _get_cached_image(cache_key)
        if cached is not None:
            return cached
        try:
            import io

            import requests

            resp = requests.get(product.image_url, timeout=10)
            if resp.status_code == 200:
                img = Image.open(io.BytesIO(resp.content))
                # Image.open() is lazy. Decode now so corrupt or truncated data
                # raises here (and falls back to the local file) rather than
                # later, when the caller first touches the pixels.
                img.load()
                return _cache_image(cache_key, img)
            logger.debug(
                "Remote image fetch for %s returned HTTP %s",
                product.spu_id,
                resp.status_code,
            )
        except Exception as e:
            logger.debug("Remote image fetch failed for %s: %s", product.spu_id, e)
    # Without an spu_id there is no local file name to try (avoids "None.jpg").
    if not product.spu_id:
        return None
    local = Path(f"data/images/{product.spu_id}.jpg")
    if not local.exists():
        return None
    cache_key = f"local:{local}"
    cached = _get_cached_image(cache_key)
    if cached is not None:
        return cached
    try:
        img = Image.open(local)
        # Decode eagerly: bad files fail here, and Pillow closes the file it
        # opened instead of keeping the handle alive for as long as it is cached.
        img.load()
        return _cache_image(cache_key, img)
    except Exception as e:
        logger.debug("Local image load failed %s: %s", local, e)
    return None
def display_product_card_from_item(
    product: ProductItem,
    ref_id: str,
    index: int,
    widget_prefix: str = "",
) -> None:
    """Render a single product card with hover actions: Similar products + checkbox.

    The card shows:
      - the product image fitted to 220x220, or a placeholder when none loads;
      - the title, truncated to 40 characters with an ellipsis;
      - the price, when known;
      - the match label.

    Below the card are two controls:
      - "Similar products" opens the side drawer in "similar" mode, using the
        product title as the query, and reruns;
      - a checkbox adds the product to, or removes it from,
        ``st.session_state.selected_products``.

    Args:
        product: Product to render.
        ref_id: Search-result ref id the product came from.
        index: Position of the product within that result.
        widget_prefix: Extra prefix that keeps widget keys unique when the same
            result is rendered in more than one place on the page.
    """
    # pkey identifies the product in selected_products; key_suffix additionally
    # disambiguates widget keys between repeated renderings.
    pkey = _product_key(ref_id, index, product)
    key_suffix = f"{widget_prefix}_{pkey}" if widget_prefix else pkey
    info = _product_to_info(product, ref_id)
    selected = st.session_state.selected_products
    st.markdown('
', unsafe_allow_html=True)
    img = _load_product_image(product)
    if img:
        target = (220, 220)
        # Image.Resampling is missing on older Pillow releases; fall back to the
        # legacy Image.LANCZOS constant there.
        try:
            img = ImageOps.fit(img, target, method=Image.Resampling.LANCZOS)
        except AttributeError:
            img = ImageOps.fit(img, target, method=Image.LANCZOS)
        st.image(img, width="stretch")
    else:
        st.markdown(
            '
🛍️
',
            unsafe_allow_html=True,
        )
    title = product.title or "未知商品"
    st.markdown(f"**{title[:40]}**" + ("…" if len(title) > 40 else ""))
    if product.price is not None:
        st.caption(f"¥{product.price:.2f}")
    label_style = "⭐" if product.match_label == "Relevant" else "✦"
    st.caption(f"{label_style} {product.match_label}")
    st.markdown('
', unsafe_allow_html=True)
    col_a, col_b = st.columns([1, 1])
    with col_a:
        similar_clicked = st.button(
            "Similar products",
            key=f"similar_{key_suffix}",
            help="Search by this product title and show in side panel",
        )
    with col_b:
        # Initial state mirrors selected_products; afterwards the widget's own
        # keyed state drives the sync below.
        is_checked = st.checkbox(
            "Select",
            key=f"select_{key_suffix}",
            value=(pkey in selected),
            label_visibility="collapsed",
        )
    st.markdown("
", unsafe_allow_html=True)
    st.markdown("
", unsafe_allow_html=True)
    if similar_clicked:
        # Open the drawer in a "loading" state; the actual search runs on the
        # next pass, keyed off payload["loading"].
        search_query = (product.title or "").strip() or "商品"
        st.session_state.side_panel = {
            "visible": True,
            "mode": "similar",
            "payload": {"query": search_query, "loading": True},
        }
        st.rerun()
    # Mirror the checkbox into selected_products: add when checked, drop when unchecked.
    if is_checked:
        if pkey not in selected:
            selected[pkey] = info
    else:
        selected.pop(pkey, None)
def render_search_result_block(result: SearchResult, widget_prefix: str = "") -> None:
"""
Render a full search result block in place of a [SEARCH_RESULTS_REF:ref_id] token.
widget_prefix: unique per (message, ref block) so Streamlit widget keys stay unique.
"""
summary_line = f' · {result.quality_summary}' if result.quality_summary else ''
header_html = (
f''
f''
f'🔍 {result.query}'
f' · Relevant {result.perfect_count} 件'
f' · Partially Relevant {result.partial_count} 件'
f'{summary_line}'
f'
'
)
st.markdown(header_html, unsafe_allow_html=True)
# Perfect matches first, fall back to partials if none
perfect = [p for p in result.products if p.match_label == "Relevant"]
partial = [p for p in result.products if p.match_label == "Partially Relevant"]
to_show = (perfect + partial)[:6] if perfect else partial[:6]
if not to_show:
st.caption("(本次搜索未找到可展示的商品)")
return
cols = st.columns(min(len(to_show), 3))
for i, product in enumerate(to_show):
with cols[i % 3]:
display_product_card_from_item(
product, result.ref_id, i, widget_prefix=widget_prefix
)
def render_message_with_refs(
    content: str,
    session_id: str,
    fallback_refs: Optional[dict] = None,
    msg_index: int = 0,
) -> None:
    """
    Render assistant text, expanding each [SEARCH_RESULTS_REF:ref_id] token into
    a product-card block.

    Text between tokens is rendered as markdown; blank segments are skipped.
    A ref id is looked up in the session's global registry first, then in
    ``fallback_refs`` (the refs stored with the message). An unresolved id
    renders a short "unavailable" caption instead.

    msg_index: message index in chat, used to keep widget keys unique across messages.
    """
    refs = fallback_refs or {}
    # The pattern has one capturing group, so split() alternates text, ref_id, text, ...
    pieces = SEARCH_RESULTS_REF_PATTERN.split(content)
    for pos, piece in enumerate(pieces):
        if pos % 2 == 0:
            body = piece.strip()
            if body:
                st.markdown(body)
            continue
        ref_id = piece.strip()
        found = global_registry.get(session_id, ref_id) or refs.get(ref_id)
        if not found:
            st.caption(f"[搜索结果 {ref_id} 不可用]")
            continue
        render_search_result_block(found, widget_prefix=f"m{msg_index}_r{pos}")
def render_debug_steps_panel(debug_steps: list[dict], expanded: bool = True) -> None:
"""Render debug steps with thinking/tool details."""
with st.expander("思考 & 工具调用详细过程", expanded=expanded):
for idx, step in enumerate(debug_steps, 1):
node = step.get("node", "unknown")
st.markdown(f"**Step {idx} – {node}**")
if node == "agent":
msgs = step.get("messages", [])
if msgs:
st.markdown("**Agent Messages**")
for m in msgs:
st.markdown(f"- `{m.get('type', 'assistant')}`: {m.get('content', '')}")
if m.get("thinking"):
st.markdown(" - `thinking`:")
st.code(m.get("thinking", ""), language="text")
tcs = step.get("tool_calls", [])
if tcs:
st.markdown("**Planned Tool Calls**")
for j, tc in enumerate(tcs, 1):
st.markdown(f"- **{j}. {tc.get('name')}**")
st.code(tc.get("args", {}), language="json")
elif node == "tools":
results = step.get("results", [])
if results:
st.markdown("**Tool Results**")
for j, r in enumerate(results, 1):
st.markdown(f"- **Result {j}:**")
st.code(r.get("content", ""), language="text")
st.markdown("---")
def display_message(message: dict, msg_index: int = 0):
"""Display a chat message. msg_index keeps widget keys unique across messages."""
role = message["role"]
content = message["content"]
image_path = message.get("image_path")
tool_calls = message.get("tool_calls", [])
debug_steps = message.get("debug_steps", [])
if role == "user":
st.markdown('', unsafe_allow_html=True)
if image_path and Path(image_path).exists():
try:
img = Image.open(image_path)
st.image(img, width=200)
except Exception:
logger.warning(f"Failed to load user uploaded image: {image_path}")
st.markdown(content)
st.markdown("
", unsafe_allow_html=True)
else: # assistant
# Tool call breadcrumb
if tool_calls:
tool_names = [tc["name"] for tc in tool_calls]
st.caption(" → ".join(tool_names))
st.markdown("")
# Debug panel
if debug_steps and st.session_state.get("show_debug"):
render_debug_steps_panel(debug_steps, expanded=True)
# Render message: expand [SEARCH_RESULTS_REF:ref_id] tokens into product card blocks
session_id = st.session_state.get("session_id", "")
render_message_with_refs(
content, session_id, fallback_refs=message.get("search_refs"), msg_index=msg_index
)
st.markdown("", unsafe_allow_html=True)
@st.fragment
def render_bottom_actions_bar() -> None:
    """Show Ask and Compare when products are selected; render nothing otherwise.

    "Ask" copies the selected products into ``referenced_products`` and reruns;
    those products are shown above the chat input and prefixed to the next query.
    "Compare" opens the side drawer in "compare" mode with the selected
    products and reruns.
    """
    selected = st.session_state.selected_products
    n = len(selected)
    if n == 0:
        return
    st.markdown(
        '',
        unsafe_allow_html=True,
    )
    col_sel, col_ask, col_cmp = st.columns([2, 1, 1])
    with col_sel:
        st.caption(f"Selected: {n}")
    with col_ask:
        ask_clicked = st.button("Ask", key="bottom_ask", help="Continue conversation with selected products")
    with col_cmp:
        compare_clicked = st.button("Compare", key="bottom_compare", help="Compare selected products")
    st.markdown("
", unsafe_allow_html=True)
    if ask_clicked:
        # Replaces any previously referenced products with the current selection.
        st.session_state.referenced_products = list(selected.values())
        st.rerun()
    if compare_clicked:
        st.session_state.side_panel = {
            "visible": True,
            "mode": "compare",
            "payload": list(selected.values()),
        }
        st.rerun()
def render_side_drawer() -> None:
    """Render a fixed overlay side drawer that does not change background layout.

    Does nothing unless ``st.session_state.side_panel`` is visible and has a mode.

    In "similar" mode the drawer shows, depending on the payload:
      - a loading message;
      - the product list stored in ``payload["products"]`` (at most 12);
      - legacy: the products of the registry search result named by
        ``payload["ref_id"]``.

    Any other mode is rendered as "compare": it lists the selected products
    from a list payload; comparison itself is not implemented.

    The assembled HTML is injected with ``unsafe_allow_html=True``, followed by
    a zero-height components.html snippet.
    """
    panel = st.session_state.side_panel
    # Hidden drawer, or no mode chosen: render nothing.
    if not panel.get("visible") or not panel.get("mode"):
        return
    mode = panel["mode"]
    # "similar" mode stores a dict here; "compare" mode stores a list of product info dicts.
    payload = panel.get("payload") or {}
    session_id = st.session_state.get("session_id", "")
    # NOTE(review): `title` and `body_html` are presumably consumed by the f-string
    # template near the end, whose markup is missing in this copy — confirm.
    title = "Similar products" if mode == "similar" else "Compare"
    body_html = ""
    if mode == "similar":
        # The query is interpolated into raw HTML below, so it is escaped first.
        query = html.escape((payload.get("query") or ""))
        if payload.get("loading"):
            body_html = '加载中…
'
        elif payload.get("products") is not None:
            # Results stored in the payload; show at most 12 cards.
            to_show = payload["products"][:12]
            cards = []
            for product in to_show:
                p_title = html.escape((product.title or "未知商品")[:80])
                price = (
                    f"¥{product.price:.2f}"
                    if product.price is not None
                    else "价格待更新"
                )
                # NOTE(review): the image markup is missing in this copy; make sure
                # product.image_url is escaped before it is placed into HTML.
                image_html = (
                    f'
'
                    if product.image_url
                    else '🛍️
'
                )
                cards.append(
                    ''
                    f"{image_html}"
                    '
'
                    f'
{p_title}
'
                    f'
{price}
'
                    "
"
                )
            cards_html = "".join(cards) if cards else '(未找到可展示的商品)
'
            body_html = (
                f''
                f'基于「{query}」的搜索结果:
'
                '' + cards_html + "
"
            )
        else:
            # Legacy: ref_id from registry (e.g. from chat)
            ref_id = payload.get("ref_id")
            if ref_id:
                result = global_registry.get(session_id, ref_id)
                if result:
                    to_show = (result.products or [])[:12]
                    cards = []
                    for product in to_show:
                        p_title = html.escape((product.title or "未知商品")[:80])
                        price = f"¥{product.price:.2f}" if product.price is not None else "价格待更新"
                        image_html = (
                            f'
'
                            if product.image_url
                            else '🛍️
'
                        )
                        # NOTE(review): the card template is empty in this copy, so
                        # p_title / price / image_html are unused in the visible text.
                        cards.append(
                            ''
                        )
                    body_html = (
                        f'基于「{query}」的搜索结果:
'
                        '' + "".join(cards) + "
"
                    )
                else:
                    body_html = f'[搜索结果 {html.escape(ref_id)} 不可用]
'
            else:
                # Payload carries neither results nor a ref_id.
                body_html = '搜索失败或暂无结果。
'
    else:
        # Compare mode: payload is expected to be a list of product info dicts.
        items = payload if isinstance(payload, list) else []
        if items:
            rows = []
            for item in items:
                t = html.escape((item.get("title") or "未知商品")[:80])
                p = item.get("price")
                ptext = f"¥{p:.2f}" if p is not None else "价格待更新"
                rows.append(
                    '"
                )
            items_html = "".join(rows)
        else:
            items_html = '当前未选中商品。
'
        body_html = (
            '已选商品:
'
            f'{items_html}
'
            '对比功能暂未实现。
'
        )
    st.markdown(
        f"""
""",
        unsafe_allow_html=True,
    )
    st_components.html("""
""", height=0)
def display_welcome():
    """Display welcome screen.

    Shown in place of the chat history while there are no messages: four
    equal-width columns, each filled via ``st.markdown(..., unsafe_allow_html=True)``,
    followed by one more injected snippet below them.
    """
    # NOTE(review): columns 1, 3 and 4 and the trailing snippet inject no markup in
    # this copy, and column 2 holds plain text only — the welcome layout appears to
    # be missing; confirm against the intended design.
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    with col2:
        st.markdown(
            """
🛍️
懂商品
深度理解店铺内所有商品,智能匹配你的需求
""",
            unsafe_allow_html=True,
        )
    with col3:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    with col4:
        st.markdown(
            """
""",
            unsafe_allow_html=True,
        )
    st.markdown("
", unsafe_allow_html=True)
def main():
"""Main Streamlit app"""
initialize_session()
# Sync drawer close state from JS (set by JS via history.replaceState, no reload)
if st.query_params.get("close_side_panel"):
st.session_state.side_panel = {"visible": False, "mode": None, "payload": None}
st.query_params.clear()
# "Similar" panel: if loading, run API-only search and rerun
panel = st.session_state.side_panel
if panel.get("visible") and panel.get("mode") == "similar":
payload = panel.get("payload") or {}
if payload.get("loading") and payload.get("query"):
from app.tools.search_tools import search_products_api_only
products = search_products_api_only(payload["query"], limit=12)
st.session_state.side_panel["payload"] = {
"query": payload["query"],
"products": products,
"loading": False,
}
st.rerun()
# Drawer rendered early so fixed positioning works from top of DOM
render_side_drawer()
# Header
st.markdown(
"""
""",
unsafe_allow_html=True,
)
# Sidebar (collapsed by default, but accessible)
@st.fragment
def _sidebar_fragment():
st.markdown("### ⚙️ Settings")
if st.button("🗑️ Clear Chat", width="stretch"):
if "shopping_agent" in st.session_state:
st.session_state.shopping_agent.clear_history()
session_id = st.session_state.get("session_id", "")
if session_id:
global_registry.clear_session(session_id)
st.session_state.messages = []
st.session_state.uploaded_image = None
st.session_state.selected_products = {}
st.session_state.referenced_products = []
st.session_state.side_panel = {"visible": False, "mode": None, "payload": None}
st.rerun()
st.markdown("---")
st.checkbox(
"显示调试过程 (debug)",
key="show_debug",
value=True,
help="展开后可查看中间思考过程及工具调用详情",
)
st.markdown("---")
st.caption(f"Session: `{st.session_state.session_id[:8]}...`")
with st.sidebar:
_sidebar_fragment()
MAX_MESSAGES = 50
messages_container = st.container()
with messages_container:
if not st.session_state.messages:
display_welcome()
else:
messages = st.session_state.messages
start_idx = max(0, len(messages) - MAX_MESSAGES)
to_show = messages[start_idx:]
if len(messages) > MAX_MESSAGES:
st.caption(f"(仅显示最近 {MAX_MESSAGES} 条,共 {len(messages)} 条消息)")
for i, message in enumerate(to_show):
display_message(message, msg_index=start_idx + i)
render_bottom_actions_bar()
# Fixed input area at bottom (using container to simulate fixed position)
st.markdown('', unsafe_allow_html=True)
input_container = st.container()
with input_container:
# Image upload area (shown when + is clicked)
if st.session_state.show_image_upload:
uploaded_file = st.file_uploader(
"Choose an image",
type=["jpg", "jpeg", "png"],
key="file_uploader",
)
if uploaded_file:
st.session_state.uploaded_image = uploaded_file
# Show preview
col1, col2 = st.columns([1, 4])
with col1:
img = Image.open(uploaded_file)
st.image(img, width=100)
with col2:
if st.button("❌ Remove"):
st.session_state.uploaded_image = None
st.session_state.show_image_upload = False
st.rerun()
# Referenced products area (shown above chat input, each can be removed)
render_referenced_products_in_input()
# Input row
col1, col2 = st.columns([1, 12])
with col1:
# Image upload toggle button
if st.button("➕", help="Add image", width="stretch"):
st.session_state.show_image_upload = (
not st.session_state.show_image_upload
)
st.rerun()
with col2:
# Text input
user_query = st.chat_input(
"Ask about fashion products...",
key="chat_input",
)
st.markdown("
", unsafe_allow_html=True)
# Process user input
if user_query:
raw_user_query = user_query
referenced_products = list(st.session_state.get("referenced_products", []))
agent_query = raw_user_query
if referenced_products:
agent_query = f"{_build_reference_prefix(referenced_products)}\n\n{raw_user_query}"
# Ensure shopping agent is initialized
if "shopping_agent" not in st.session_state:
st.error("Session not initialized. Please refresh the page.")
st.stop()
# Save uploaded image if present, or get from recent history
image_path = None
if st.session_state.uploaded_image:
# User explicitly uploaded an image for this query
image_path = save_uploaded_image(st.session_state.uploaded_image)
else:
# Check if query refers to a previous image
if any(
ref in raw_user_query.lower()
for ref in [
"this",
"that",
"the image",
"the shirt",
"the product",
"it",
]
):
# Find the most recent message with an image
for msg in reversed(st.session_state.messages):
if msg.get("role") == "user" and msg.get("image_path"):
image_path = msg["image_path"]
logger.info(f"Using image from previous message: {image_path}")
break
# Add user message
st.session_state.messages.append(
{
"role": "user",
"content": raw_user_query,
"image_path": image_path,
}
)
# References are consumed once this message is sent
st.session_state.referenced_products = []
# Display user message immediately
with messages_container:
display_message(st.session_state.messages[-1])
# Process with shopping agent
try:
shopping_agent = st.session_state.shopping_agent
# Stream assistant updates to UI immediately
with messages_container:
live_container = st.container()
with live_container:
live_tool_caption = st.empty()
live_debug_placeholder = st.empty()
live_response_placeholder = st.empty()
live_response = ""
live_tool_calls: list[dict] = []
live_debug_steps: list[dict] = []
result = None
def _render_live() -> None:
if live_tool_calls:
tool_names = [tc.get("name", "") for tc in live_tool_calls if tc.get("name")]
live_tool_caption.caption(" → ".join(tool_names))
else:
live_tool_caption.empty()
if st.session_state.get("show_debug") and live_debug_steps:
with live_debug_placeholder.container():
render_debug_steps_panel(live_debug_steps, expanded=True)
else:
live_debug_placeholder.empty()
if live_response:
live_response_placeholder.markdown(live_response)
else:
live_response_placeholder.markdown("…")
for event in shopping_agent.chat_stream(query=agent_query, image_path=image_path):
event_type = event.get("type")
if event_type in {"debug_update", "response_delta", "response_replace"}:
if "tool_calls" in event:
live_tool_calls = event.get("tool_calls", live_tool_calls)
if "debug_steps" in event:
live_debug_steps = event.get("debug_steps", live_debug_steps)
if event_type == "response_delta":
live_response = event.get("response", live_response)
elif event_type == "response_replace":
live_response = event.get("response", live_response)
_render_live()
elif event_type == "done":
result = event.get("result")
if not result:
result = {
"response": live_response or "抱歉,处理您的请求时未返回结果。",
"tool_calls": live_tool_calls,
"debug_steps": live_debug_steps,
"search_refs": {},
"error": True,
}
response = result["response"]
tool_calls = result.get("tool_calls", [])
debug_steps = result.get("debug_steps", [])
# Add assistant message (store search_refs so refs resolve after rerun)
st.session_state.messages.append(
{
"role": "assistant",
"content": response,
"tool_calls": tool_calls,
"debug_steps": debug_steps,
"search_refs": result.get("search_refs", {}),
}
)
# Clear uploaded image and hide upload area after sending
st.session_state.uploaded_image = None
st.session_state.show_image_upload = False
# Auto-scroll to bottom with JavaScript
st.markdown(
"""
""",
unsafe_allow_html=True,
)
except Exception as e:
logger.error(f"Error processing query: {e}", exc_info=True)
error_msg = f"I apologize, I encountered an error: {str(e)}"
st.session_state.messages.append(
{
"role": "assistant",
"content": error_msg,
}
)
# Rerun to update UI
st.rerun()
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    main()