"""
Conversational Shopping Agent with LangGraph

Architecture:
- ReAct-style agent: plan → search → evaluate → re-plan or respond
- search_products is session-bound, writing curated results to SearchResultRegistry
- Final AI message references results via [SEARCH_REF:xxx] tokens instead of
  re-listing product details; the UI renders product cards from the registry
"""

import json
import logging
from pathlib import Path
from typing import Any, Optional, Sequence

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict

from app.config import settings
from app.search_registry import global_registry
from app.tools.search_tools import get_all_tools

logger = logging.getLogger(__name__)

# ── System prompt ──────────────────────────────────────────────────────────────
# Key design decisions:
#   1. Guides multi-query search planning with explicit evaluate-and-decide loop
#   2. Forbids re-listing product details in the final response
#   3. Mandates [SEARCH_REF:xxx] inline citation as the only product presentation mechanism
# NOTE: the role text below is currently worded for an apparel / fashion store;
# adapt the first paragraph before reusing it for another e-commerce vertical.
SYSTEM_PROMPT = """角色定义
你是一名专业的服装电商导购,是一个善于倾听、主动引导、懂得搭配的“时尚顾问”,通过有温度的对话,给用户提供有价值的信息,包括需求引导、方案推荐、搜索结果推荐,最终促成满意的购物决策或转化行为。
一些原则:
1. 你是一个真人导购,是一个贴心、专业的销售,保持灵活,根据上下文,基于常识灵活的切换策略,在合适的上下文询问合适的问题、给出有价值的方案和搜索结果的呈现。
2. 商品搜索结果推荐与信息收集:
   1. 根据上下文、用户诉求,灵活的切换侧重点,何时需要进行搜索、何时要引导客户完善需求,你需要站在用户角度进行思考。比如已经有较为清晰的意图,则以搜索、方案推荐为主,有必要的时候,思考该方向下重要的决策因素,进行提议和问题收集,让用户既得到相关信息、又得到下一步的方向引导、同时也有机会修正或者细化诉求。如果存在重大的需求方向缺口,主动通过1-2个关键问题进行引导,并提供初步方向。
   2. 适时的提供有价值的信息,如商品推荐、穿搭建议、趋势信息,在推荐方向上有需求缺口、需要明确的重要信息时,要适时的做“信息收集”,引导式的帮助用户更清晰的呈现需求、提高商品发现的效率,形成“提供-反馈”的良性循环。
   3. 对于复杂需求时,要能基于上下文,将导购任务进行合理拆解。
3. 引导或者收集需求时,需要站在用户立场,比如询问用户期待的效果或感觉、使用的场合、偏好的风格等用户立场需,而不是询问具体的款式或参数,你需要将用户立场的需求理解/翻译/转化为具体的搜索计划,最后筛选产品、结合需求+结果特性组织推荐理由、呈现方案。
4. 如何使用search_products:在需要搜索商品的时候,可以将需求分解为 2-4 个搜索查询,每个 query 聚焦一个明确的商品子类或搜索角度。每次调用 search_products 后,工具会返回以下内容,你需要决策是否要调整搜索策略,比如结果质量太差,可能需要调整搜索词、或者加大试探的query数量(不要超过3-5个)。可以进行多轮搜索,但是要适时的总结和反馈信息避免用户等待过长时间:
   - 各层级数量:完美匹配 / 部分匹配 / 不相关 的条数
   - 整体质量判断:优质 / 一般 / 较差
   - 简短质量说明
   - 结果引用标识:[SEARCH_REF:xxx]
5. 撰写最终回复的时候,使用 [SEARCH_REF:xxx] 内联引用
   1. 用自然流畅的语言组织回复,将 [SEARCH_REF:xxx] 嵌入叙述中
   2. 系统会自动在 [SEARCH_REF:xxx] 位置渲染对应的商品卡片列表
   3. 禁止在回复文本中列出商品名称、ID、价格、分类、规格等字段
   4. 禁止用编号列表逐条复述搜索结果中的商品
"""


# ── Agent state ────────────────────────────────────────────────────────────────

class AgentState(TypedDict):
    """Graph state carried between the ``agent`` and ``tools`` nodes."""

    # Conversation so far; ``add_messages`` is the reducer LangGraph uses to
    # merge the messages returned by a node into this field.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Path of the image attached to the current user turn, if any.
    current_image_path: Optional[str]


# ── Helper ─────────────────────────────────────────────────────────────────────

# Max length for logging single content field (avoid huge logs)
_LOG_CONTENT_MAX = 8000
_LOG_TOOL_RESULT_MAX = 4000

# Max lengths of the text previews returned to the caller in ``debug_steps`` /
# ``tool_calls`` (these go to the UI, not to the log).
_STEP_PREVIEW_MAX = 500
_TOOL_PREVIEW_MAX = 600


def _extract_message_text(msg) -> str:
    """Extract plain text from a LangChain message.

    ``msg.content`` may be a plain string or a list of content blocks. For a
    list, dict blocks contribute their ``"text"`` (or, failing that,
    ``"content"``) value and any other block is stringified; the pieces are
    concatenated without a separator.
    """
    content = getattr(msg, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict):
                parts.append(block.get("text") or block.get("content") or "")
            else:
                parts.append(str(block))
        return "".join(str(p) for p in parts)
    return str(content) if content else ""


def _message_for_log(msg: BaseMessage) -> dict:
    """Serialize a message for structured logging (content truncated).

    Returns a dict with ``type`` and ``content`` and, when the message carries
    tool calls, a ``tool_calls`` list of ``{"name", "args"}`` dicts.
    """
    text = _extract_message_text(msg)
    if len(text) > _LOG_CONTENT_MAX:
        # Report the original length, so compute it before slicing.
        total = len(text)
        text = text[:_LOG_CONTENT_MAX] + f"... [truncated, total {total} chars]"
    out: dict[str, Any] = {
        "type": getattr(msg, "type", "unknown"),
        "content": text,
    }
    calls = getattr(msg, "tool_calls", None)
    if calls:
        out["tool_calls"] = [
            {"name": tc.get("name"), "args": tc.get("args", {})}
            for tc in calls
        ]
    return out


# ── Agent class ────────────────────────────────────────────────────────────────

class ShoppingAgent:
    """ReAct shopping agent with search-evaluate-decide loop and registry-based result referencing."""

    def __init__(self, session_id: Optional[str] = None):
        """Build the LLM, the session-bound tools and the compiled graph.

        Args:
            session_id: Identifier used as the checkpointer thread id and as
                the registry partition for search results. Falls back to
                ``"default"`` when omitted or empty.
        """
        self.session_id = session_id or "default"

        llm_kwargs = dict(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )
        # Only override the endpoint when one is configured.
        if settings.openai_api_base_url:
            llm_kwargs["base_url"] = settings.openai_api_base_url

        self.llm = ChatOpenAI(**llm_kwargs)

        # Tools are session-bound so search_products writes to the right registry partition
        self.tools = get_all_tools(session_id=self.session_id, registry=global_registry)
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        self.graph = self._build_graph()
        logger.info("ShoppingAgent ready — session=%s", self.session_id)

    def _build_graph(self):
        """Compile the two-node graph: ``agent`` ⇄ ``tools`` until no tool calls remain."""

        def agent_node(state: AgentState):
            """Call the tool-enabled LLM on the conversation and log request/response."""
            messages = state["messages"]
            # The system prompt is injected per call rather than stored in state.
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=SYSTEM_PROMPT), *messages]

            request_log = [_message_for_log(m) for m in messages]
            req_json = json.dumps(request_log, ensure_ascii=False)
            if len(req_json) > _LOG_CONTENT_MAX:
                total = len(req_json)
                req_json = req_json[:_LOG_CONTENT_MAX] + f"... [truncated total {total}]"
            logger.info("[%s] LLM_REQUEST messages=%s", self.session_id, req_json)

            response = self.llm_with_tools.invoke(messages)

            response_log = _message_for_log(response)
            logger.info(
                "[%s] LLM_RESPONSE %s",
                self.session_id,
                json.dumps(response_log, ensure_ascii=False),
            )
            return {"messages": [response]}

        def should_continue(state: AgentState):
            """Route to ``tools`` when the last message requests tool calls, else finish."""
            last = state["messages"][-1]
            if getattr(last, "tool_calls", None):
                return "tools"
            return END

        tool_node = ToolNode(self.tools)

        workflow = StateGraph(AgentState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)

        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")

        return workflow.compile(checkpointer=MemorySaver())

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """
        Process a user query and return the agent response with metadata.

        Args:
            query: The user's message text.
            image_path: Optional path of an uploaded image. If given, the file
                must exist; its path is appended to the message text.

        Returns:
            dict with keys (all always present, including on error paths):
                response    – final AI message text (may contain [SEARCH_REF:xxx] tokens)
                tool_calls  – list of {name, args, result}; ``result`` is a
                              truncated preview, set once the tool has run
                debug_steps – detailed per-node step log
                search_refs – dict[ref_id → SearchResult] for all searches this turn
                error       – bool

        Exceptions are not propagated: any failure is logged and reported via
        ``error=True`` with an apology message in ``response``.
        """
        try:
            logger.info(
                "[%s] chat: %r image=%s", self.session_id, query, bool(image_path)
            )

            if image_path and not Path(image_path).exists():
                # Same shape as every other return so callers can index
                # tool_calls / debug_steps / search_refs unconditionally.
                return {
                    "response": f"错误:图片文件不存在:{image_path}",
                    "tool_calls": [],
                    "debug_steps": [],
                    "search_refs": {},
                    "error": True,
                }

            # Snapshot registry before the turn so we can report new additions
            registry_before = set(global_registry.get_all(self.session_id).keys())

            message_content = query
            if image_path:
                message_content = f"{query}\n[用户上传了图片:{image_path}]"

            config = {"configurable": {"thread_id": self.session_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }

            tool_calls: list[dict] = []
            debug_steps: list[dict] = []
            # Tool-call id → its entry in ``tool_calls``, for calls whose
            # result has not arrived yet.
            pending_by_id: dict[str, dict] = {}

            for event in self.graph.stream(input_state, config=config):
                logger.debug("[%s] event keys: %s", self.session_id, list(event.keys()))

                if "agent" in event:
                    agent_out = event["agent"]
                    step_msgs: list[dict] = []
                    step_tcs: list[dict] = []

                    for msg in agent_out.get("messages", []):
                        text = _extract_message_text(msg)
                        step_msgs.append({
                            "type": getattr(msg, "type", "assistant"),
                            "content": text[:_STEP_PREVIEW_MAX],
                        })
                        for tc in getattr(msg, "tool_calls", None) or []:
                            # The same dict object is shared by tool_calls and
                            # the debug step, so the result shows up in both.
                            entry = {"name": tc.get("name"), "args": tc.get("args", {})}
                            tool_calls.append(entry)
                            step_tcs.append(entry)
                            call_id = tc.get("id")
                            if call_id:
                                pending_by_id[call_id] = entry

                    debug_steps.append(
                        {"node": "agent", "messages": step_msgs, "tool_calls": step_tcs}
                    )

                if "tools" in event:
                    tools_out = event["tools"]
                    step_results: list[dict] = []

                    for msg in tools_out.get("messages", []):
                        text = _extract_message_text(msg)
                        preview = text[:_TOOL_PREVIEW_MAX] + (
                            "…" if len(text) > _TOOL_PREVIEW_MAX else ""
                        )

                        # Prefer an exact match on the tool-call id; fall back
                        # to the oldest call still waiting for a result.
                        call_id = getattr(msg, "tool_call_id", None)
                        entry = pending_by_id.pop(call_id, None) if call_id else None
                        if entry is None:
                            entry = next(
                                (tc for tc in tool_calls if "result" not in tc), None
                            )

                        if entry is not None:
                            entry["result"] = preview
                            if len(text) <= _LOG_TOOL_RESULT_MAX:
                                result_log = text
                            else:
                                result_log = (
                                    text[:_LOG_TOOL_RESULT_MAX]
                                    + f"... [truncated total {len(text)}]"
                                )
                            logger.info(
                                "[%s] TOOL_CALL_RESULT name=%s args=%s result=%s",
                                self.session_id,
                                entry.get("name", ""),
                                json.dumps(entry.get("args", {}), ensure_ascii=False),
                                result_log,
                            )
                        step_results.append({"content": preview})

                    debug_steps.append({"node": "tools", "results": step_results})

            final_state = self.graph.get_state(config)
            final_msg = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_msg)

            # Collect new SearchResults added during this turn
            registry_after = global_registry.get_all(self.session_id)
            new_refs = {
                ref_id: result
                for ref_id, result in registry_after.items()
                if ref_id not in registry_before
            }

            logger.info(
                "[%s] done — tool_calls=%s, new_refs=%s",
                self.session_id,
                len(tool_calls),
                list(new_refs.keys()),
            )

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "search_refs": new_refs,
                "error": False,
            }

        except Exception as e:
            # Top-level boundary: log with traceback and report to the caller.
            logger.error("[%s] chat error: %s", self.session_id, e, exc_info=True)
            return {
                "response": f"抱歉,处理您的请求时遇到错误:{e}",
                "tool_calls": [],
                "debug_steps": [],
                "search_refs": {},
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Return the checkpointed conversation as ``[{"role", "content"}, ...]``.

        System and tool messages are skipped; ``human`` messages map to role
        ``"user"`` and everything else to ``"assistant"``. Returns an empty
        list when there is no history or when reading the state fails.
        """
        try:
            config = {"configurable": {"thread_id": self.session_id}}
            state = self.graph.get_state(config)
            if not state or not state.values.get("messages"):
                return []

            result = []
            for msg in state.values["messages"]:
                if isinstance(msg, SystemMessage):
                    continue
                if getattr(msg, "type", None) in ("system", "tool"):
                    continue
                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})
            return result
        except Exception as e:
            logger.error("get_conversation_history error: %s", e, exc_info=True)
            return []

    def clear_history(self):
        """Log a clear request.

        NOTE: this does not delete any checkpointed state; as the log message
        says, use a new ``session_id`` to start from an empty conversation.
        """
        logger.info(
            "[%s] clear requested (use new session_id to fully reset)",
            self.session_id,
        )


def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory returning a ``ShoppingAgent`` bound to ``session_id``."""
    return ShoppingAgent(session_id=session_id)