# shopping_agent.py
"""
Conversational Shopping Agent with LangGraph

Architecture:
- ReAct-style agent: plan → search → evaluate → re-plan or respond
- search_products is session-bound, writing curated results to SearchResultRegistry
- Final AI message references results via [SEARCH_REF:xxx] tokens instead of
  re-listing product details; the UI renders product cards from the registry
"""

import json
import logging
from pathlib import Path
from typing import Any, Optional, Sequence

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict

from app.config import settings
from app.search_registry import global_registry
from app.tools.search_tools import get_all_tools

# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)

# ── System prompt ──────────────────────────────────────────────────────────────
# Universal: works for any e-commerce vertical (fashion, electronics, home, etc.)
# Key design decisions:
#   1. Guides multi-query search planning with explicit evaluate-and-decide loop
#   2. Forbids re-listing product details in the final response
#   3. Mandates [SEARCH_REF:xxx] inline citation as the only product presentation mechanism
# NOTE(review): the prompt below is runtime behavior (sent verbatim to the LLM,
# in Chinese) — do not translate or reformat it. It defines the advisor persona,
# the value-first/clarify-second dialogue principles, multi-query search usage,
# and the [SEARCH_REF:xxx] citation contract described in the module docstring.
SYSTEM_PROMPT = """  角色定义
  你是我们店铺的一名专业的电商导购,是一个善于倾听、主动引导、懂得搭配的“时尚顾问”,通过有温度的对话,给用户提供有价值的信息,包括需求引导、方案推荐、搜索结果推荐,最终促成满意的购物决策或转化行为。
  作为我们店铺的一名专业的销售,除了本店铺的商品的推荐,你可以给用户提供有帮助的信息,但是不要虚构商品、提供本商店搜索结果以外的商品。
  
  一些原则:
  1. 价值提供与信息收集的原则:
    1. 优先价值提供:适时的提供有价值的信息,如商品推荐、穿搭建议、趋势信息,在推荐方向上有需求缺口、需要明确的重要信息时,要适时的做“信息收集”,引导式的澄清需求、提高商品发现的效率,形成“提供-反馈”的良性循环。
    2. 缺口大(比如品类或者使用人群都不能确定)→ 给出方案推荐 + 1-2个关键问题让用户选择;缺口小→直接检索+方案呈现,根据情况,可以考虑该方向下重要的决策因素,进行提议和问题收集,让用户既得到相关信息、又得到下一步的方向引导、同时也有机会修正或者细化诉求。
    3. 选项驱动式澄清:推荐几个清晰的方向,呈现方案或商品搜索结果,再做澄清
    4. 单轮对话最好只提一个问题,最多两个,禁止多问题堆叠。
    5. 站在用户立场思考:比如询问用户期待的效果或感觉、使用的场合、想解决的问题,而不是询问具体的款式、参数,你需要将用户表达的需求翻译为具体可检索的商品特征(版型、材质、设计元素、风格标签等),并据此筛选商品、组织推荐逻辑。
  2. 如何使用make_search_products_tool:
    1. 可以生成多个query进行搜索:在需要搜索商品的时候,可以将需求分解为 2-4 个搜索查询,每个 query 聚焦一个明确的商品子类或搜索角度。
    2. 可以根据搜索结果调整搜索策略:每次调用 search_products 后,工具会返回搜索结果的相关性的判断、以及搜索结果的topN的title,你需要决策是否要调整搜索策略,比如结果质量太差,可能需要调整搜索词、或者加大试探的query数量(不要超过3-5个)。
    3. 使用 [SEARCH_REF:xxx] 内联引用搜索结果:搜索工具会返回一个结果引用标识[SEARCH_REF:xxx],撰写最终答复的时候可以直接引用将 [SEARCH_REF:xxx] ,系统会自动在该位置渲染对应的商品卡片列表,无需复述搜索结果。
"""


# ── Agent state ────────────────────────────────────────────────────────────────

class AgentState(TypedDict):
    """LangGraph state carried between graph nodes for one conversation thread."""

    # Chat transcript; the add_messages reducer appends node outputs instead
    # of replacing the list, giving the ReAct loop its running history.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Filesystem path of the user-uploaded image for the current turn, if any.
    current_image_path: Optional[str]


# ── Helper ─────────────────────────────────────────────────────────────────────

# Truncation limits for structured logging, so a single huge prompt or tool
# payload cannot flood the logs.
_LOG_CONTENT_MAX = 8000  # max chars of a message/request dump per log line
_LOG_TOOL_RESULT_MAX = 4000  # max chars of a tool result per log line


def _extract_message_text(msg) -> str:
    """Extract plain text from a LangChain message (handles str or content_blocks)."""
    content = getattr(msg, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict):
                parts.append(block.get("text") or block.get("content") or "")
            else:
                parts.append(str(block))
        return "".join(str(p) for p in parts)
    return str(content) if content else ""


def _message_for_log(msg: BaseMessage) -> dict:
    """Render a message as a plain dict for structured logging.

    Content is truncated to _LOG_CONTENT_MAX characters; tool calls, when
    present, are reduced to their name and args.
    """
    body = _extract_message_text(msg)
    if len(body) > _LOG_CONTENT_MAX:
        body = body[:_LOG_CONTENT_MAX] + f"... [truncated, total {len(body)} chars]"

    record: dict[str, Any] = {
        "type": getattr(msg, "type", "unknown"),
        "content": body,
    }

    calls = getattr(msg, "tool_calls", None)
    if calls:
        record["tool_calls"] = [
            {"name": call.get("name"), "args": call.get("args", {})} for call in calls
        ]
    return record


# ── Agent class ────────────────────────────────────────────────────────────────

class ShoppingAgent:
    """ReAct shopping agent with search-evaluate-decide loop and registry-based result referencing."""

    def __init__(self, session_id: Optional[str] = None):
        """Build the LLM client, session-bound tools, and the compiled graph.

        Args:
            session_id: Partitions search results in the global registry and
                keys the conversation checkpoint; defaults to "default".
        """
        self.session_id = session_id or "default"

        llm_kwargs = dict(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )
        # Optional custom endpoint (proxy / OpenAI-compatible server).
        if settings.openai_api_base_url:
            llm_kwargs["base_url"] = settings.openai_api_base_url

        self.llm = ChatOpenAI(**llm_kwargs)

        # Tools are session-bound so search_products writes to the right registry partition
        self.tools = get_all_tools(session_id=self.session_id, registry=global_registry)
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        self.graph = self._build_graph()
        logger.info(f"ShoppingAgent ready — session={self.session_id}")

    def _build_graph(self):
        """Compile the two-node ReAct loop: agent ⇄ tools, ending when the
        model stops requesting tool calls."""

        def agent_node(state: AgentState):
            # Inject the system prompt once; checkpointed history carries it
            # on later turns of the same thread.
            messages = state["messages"]
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages)
            # Log the full (truncated) request for traceability.
            request_log = [_message_for_log(m) for m in messages]
            req_json = json.dumps(request_log, ensure_ascii=False)
            if len(req_json) > _LOG_CONTENT_MAX:
                req_json = req_json[:_LOG_CONTENT_MAX] + f"... [truncated total {len(req_json)}]"
            logger.info("[%s] LLM_REQUEST messages=%s", self.session_id, req_json)
            response = self.llm_with_tools.invoke(messages)
            response_log = _message_for_log(response)
            logger.info(
                "[%s] LLM_RESPONSE %s",
                self.session_id,
                json.dumps(response_log, ensure_ascii=False),
            )
            # add_messages reducer appends this to the transcript.
            return {"messages": [response]}

        def should_continue(state: AgentState):
            # Keep looping through the tool node while the model asks for tools.
            last = state["messages"][-1]
            if hasattr(last, "tool_calls") and last.tool_calls:
                return "tools"
            return END

        tool_node = ToolNode(self.tools)

        workflow = StateGraph(AgentState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")

        # MemorySaver checkpoints state per thread_id, giving in-process
        # multi-turn memory keyed by session_id.
        return workflow.compile(checkpointer=MemorySaver())

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """
        Process a user query and return the agent response with metadata.

        Returns:
            dict with keys:
              response      – final AI message text (may contain [SEARCH_REF:xxx] tokens)
              tool_calls    – list of {name, args, result_preview}
              debug_steps   – detailed per-node step log
              search_refs   – dict[ref_id → SearchResult] for all searches this turn
              error         – bool
        """
        try:
            logger.info(f"[{self.session_id}] chat: {query!r} image={bool(image_path)}")

            # NOTE(review): this early return omits the tool_calls/debug_steps/
            # search_refs keys that the other returns include — callers should
            # use .get(); confirm whether the shape should be unified.
            if image_path and not Path(image_path).exists():
                return {
                    "response": f"错误:图片文件不存在:{image_path}",
                    "error": True,
                }

            # Snapshot registry before the turn so we can report new additions
            registry_before = set(global_registry.get_all(self.session_id).keys())

            # The image is passed to the model as an inline path marker in the
            # text, not as multimodal content.
            message_content = query
            if image_path:
                message_content = f"{query}\n[用户上传了图片:{image_path}]"

            # thread_id selects this session's checkpoint in MemorySaver.
            config = {"configurable": {"thread_id": self.session_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }

            tool_calls: list[dict] = []
            debug_steps: list[dict] = []

            # Stream per-node events so we can build a step-by-step debug log.
            for event in self.graph.stream(input_state, config=config):
                logger.debug(f"[{self.session_id}] event keys: {list(event.keys())}")

                if "agent" in event:
                    agent_out = event["agent"]
                    step_msgs: list[dict] = []
                    step_tcs: list[dict] = []

                    for msg in agent_out.get("messages", []):
                        text = _extract_message_text(msg)
                        step_msgs.append({
                            "type": getattr(msg, "type", "assistant"),
                            "content": text[:500],
                        })
                        if hasattr(msg, "tool_calls") and msg.tool_calls:
                            for tc in msg.tool_calls:
                                # Same dict is shared between the flat list and
                                # the step log; "result" is filled in later.
                                entry = {"name": tc.get("name"), "args": tc.get("args", {})}
                                tool_calls.append(entry)
                                step_tcs.append(entry)

                    debug_steps.append({"node": "agent", "messages": step_msgs, "tool_calls": step_tcs})

                if "tools" in event:
                    tools_out = event["tools"]
                    step_results: list[dict] = []
                    msgs = tools_out.get("messages", [])

                    # Match results back to tool_calls by position within this event
                    unresolved = [tc for tc in tool_calls if "result" not in tc]
                    for i, msg in enumerate(msgs):
                        text = _extract_message_text(msg)
                        preview = text[:600] + ("…" if len(text) > 600 else "")
                        if i < len(unresolved):
                            unresolved[i]["result"] = preview
                            tc_name = unresolved[i].get("name", "")
                            tc_args = unresolved[i].get("args", {})
                            result_log = text if len(text) <= _LOG_TOOL_RESULT_MAX else text[:_LOG_TOOL_RESULT_MAX] + f"... [truncated total {len(text)}]"
                            logger.info(
                                "[%s] TOOL_CALL_RESULT name=%s args=%s result=%s",
                                self.session_id,
                                tc_name,
                                json.dumps(tc_args, ensure_ascii=False),
                                result_log,
                            )
                        step_results.append({"content": preview})

                    debug_steps.append({"node": "tools", "results": step_results})

            # Read the final assistant message from the checkpointed state.
            final_state = self.graph.get_state(config)
            final_msg = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_msg)

            # Collect new SearchResults added during this turn
            registry_after = global_registry.get_all(self.session_id)
            new_refs = {
                ref_id: result
                for ref_id, result in registry_after.items()
                if ref_id not in registry_before
            }

            logger.info(
                f"[{self.session_id}] done — tool_calls={len(tool_calls)}, new_refs={list(new_refs.keys())}"
            )

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "search_refs": new_refs,
                "error": False,
            }

        except Exception as e:
            # Boundary-level catch: surface the error to the UI instead of crashing.
            logger.error(f"[{self.session_id}] chat error: {e}", exc_info=True)
            return {
                "response": f"抱歉,处理您的请求时遇到错误:{e}",
                "tool_calls": [],
                "debug_steps": [],
                "search_refs": {},
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Return the visible transcript as [{"role", "content"}, ...].

        System and tool messages are skipped; any failure yields [].
        """
        try:
            config = {"configurable": {"thread_id": self.session_id}}
            state = self.graph.get_state(config)
            if not state or not state.values.get("messages"):
                return []

            result = []
            for msg in state.values["messages"]:
                if isinstance(msg, SystemMessage):
                    continue
                if getattr(msg, "type", None) in ("system", "tool"):
                    continue
                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})
            return result
        except Exception as e:
            logger.error(f"get_conversation_history error: {e}")
            return []

    def clear_history(self):
        """No-op by design: checkpoint state is keyed by session_id, so a
        full reset requires creating an agent with a new session_id."""
        logger.info(f"[{self.session_id}] clear requested (use new session_id to fully reset)")


def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory helper: build a ShoppingAgent bound to *session_id* (or the default session)."""
    agent = ShoppingAgent(session_id=session_id)
    return agent