# shopping_agent.py
"""
Conversational Shopping Agent with LangGraph
True ReAct agent with autonomous tool calling and message accumulation
"""

import logging
from pathlib import Path
from typing import Optional, Sequence

from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict

from app.config import settings
from app.tools.search_tools import get_all_tools

logger = logging.getLogger(__name__)


def _extract_message_text(msg) -> str:
    """Extract text from message content.
    LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
    content = getattr(msg, "content", "")
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, dict):
                parts.append(block.get("text", block.get("content", "")))
            else:
                parts.append(str(block))
        return "".join(str(p) for p in parts)
    return str(content) if content else ""


class AgentState(TypedDict):
    """State for the shopping agent with message accumulation"""

    # Conversation history. The `add_messages` reducer appends/merges the
    # messages returned by each graph node instead of replacing the list,
    # which is what gives the agent its running transcript.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    current_image_path: Optional[str]  # Track uploaded image

# (Removed stray `print("settings")` debug statement that ran at import time.)
class ShoppingAgent:
    """True ReAct agent with autonomous decision making.

    Wraps a tool-bound ChatOpenAI model in a LangGraph loop
    (agent node -> conditional tool node -> agent node), with per-session
    conversation memory provided by a MemorySaver checkpointer keyed on
    ``session_id``.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Initialize the agent for one conversation session.

        Args:
            session_id: Used as the checkpointer thread id so separate
                sessions keep separate histories; defaults to "default".
        """
        self.session_id = session_id or "default"

        # Initialize LLM from application settings.
        llm_kwargs = dict(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )
        if settings.openai_api_base_url:
            llm_kwargs["base_url"] = settings.openai_api_base_url
        # NOTE: never print or log llm_kwargs -- it contains the API key.

        self.llm = ChatOpenAI(**llm_kwargs)

        # Get tools and bind them so the model can emit tool calls.
        self.tools = get_all_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        # Build graph
        self.graph = self._build_graph()

        logger.info(f"Shopping agent initialized for session: {self.session_id}")

    def _build_graph(self):
        """Build and compile the LangGraph StateGraph for the ReAct loop."""

        # System prompt for the agent (user-facing, intentionally in Chinese).
        system_prompt = """你是一位智能时尚购物助手,你可以:
1. 根据文字描述搜索商品(使用 search_products)
2. 分析图片风格和属性(使用 analyze_image_style)

当用户咨询商品时:
- 文字提问:直接使用 search_products 搜索
- 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索
- 可按需连续调用多个工具
- 始终保持有用、友好的回复风格

关键格式规则:
展示商品结果时,每个商品必须严格按以下格式输出:

1. [标题 title]
   ID: [商品ID]
   分类: [category_path]
   中文名: [title_cn](如有)
   标签: [tags](如有)

示例:
1. Puma Men White 3/4 Length Pants
   ID: 12345
   分类: 服饰 > 裤装 > 运动裤
   中文名: 彪马男士白色九分运动裤
   标签: 运动,夏季,白色

不可省略 ID 字段!它是展示商品图片的关键。
介绍要口语化,但必须保持上述商品格式。"""

        def agent_node(state: AgentState):
            """Agent decision node - decides which tools to call or when to respond."""
            messages = state["messages"]

            # Prepend the system prompt if it is not already present. The
            # prepended copy is local only -- the node returns just the LLM
            # response, so the system message is never persisted to state.
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=system_prompt)] + list(messages)

            # Handle image context
            if state.get("current_image_path"):
                # Image path is already embedded in the human message text
                # (see chat()); nothing extra to inject here.
                pass

            # Invoke LLM with tools
            response = self.llm_with_tools.invoke(messages)
            return {"messages": [response]}

        # Create tool node that executes whatever tool calls the LLM emitted.
        tool_node = ToolNode(self.tools)

        def should_continue(state: AgentState):
            """Route to the tool node while the LLM keeps requesting tools."""
            messages = state["messages"]
            last_message = messages[-1]

            # If LLM made tool calls, continue to tools
            if hasattr(last_message, "tool_calls") and last_message.tool_calls:
                return "tools"
            # Otherwise, end (agent has final response)
            return END

        # Build graph
        workflow = StateGraph(AgentState)

        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)

        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")

        # Compile with in-memory conversation persistence.
        checkpointer = MemorySaver()
        return workflow.compile(checkpointer=checkpointer)

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """Process user query with the agent

        Args:
            query: User's text query
            image_path: Optional path to uploaded image

        Returns:
            Dict with response and metadata, including:
                - tool_calls: list of tool calls with args and (truncated) results
                - debug_steps: detailed intermediate reasoning & tool execution steps
                - error: True when the request could not be processed
        """
        try:
            logger.info(
                f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
            )

            # Validate image
            if image_path and not Path(image_path).exists():
                return {
                    "response": f"Error: Image file not found at '{image_path}'",
                    "error": True,
                }

            # Build input message; the image path is passed inline so the
            # LLM can hand it to analyze_image_style.
            message_content = query
            if image_path:
                message_content = f"{query}\n[User uploaded image: {image_path}]"

            # Invoke agent
            config = {"configurable": {"thread_id": self.session_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }

            # Track tool calls (high-level) and detailed debug steps
            tool_calls = []
            debug_steps = []

            # Stream events to capture tool calls and intermediate reasoning
            for event in self.graph.stream(input_state, config=config):
                logger.info(f"Event: {event}")

                # Agent node: LLM reasoning & tool decisions
                if "agent" in event:
                    agent_output = event["agent"]
                    messages = agent_output.get("messages", [])

                    step_messages = []
                    step_tool_calls = []

                    for msg in messages:
                        msg_text = _extract_message_text(msg)
                        msg_entry = {
                            "type": getattr(msg, "type", "assistant"),
                            "content": msg_text[:500],  # truncate for safety
                        }
                        step_messages.append(msg_entry)

                        # Capture tool calls from this agent message
                        if hasattr(msg, "tool_calls") and msg.tool_calls:
                            for tc in msg.tool_calls:
                                tc_entry = {
                                    "name": tc.get("name"),
                                    "args": tc.get("args", {}),
                                }
                                tool_calls.append(tc_entry)
                                step_tool_calls.append(tc_entry)

                    debug_steps.append(
                        {
                            "node": "agent",
                            "messages": step_messages,
                            "tool_calls": step_tool_calls,
                        }
                    )

                # Tool node: actual tool execution results
                if "tools" in event:
                    tools_output = event["tools"]
                    messages = tools_output.get("messages", [])

                    step_tool_results = []

                    for msg in messages:
                        content_text = _extract_message_text(msg)
                        result_preview = content_text[:500] + ("..." if len(content_text) > 500 else "")

                        # Attach the result to the first tool call that has
                        # none yet. Indexing by position within this event
                        # would mis-assign results after the first
                        # agent->tools round, overwriting earlier results.
                        for tc_entry in tool_calls:
                            if "result" not in tc_entry:
                                tc_entry["result"] = result_preview
                                break

                        step_tool_results.append(
                            {
                                "content": result_preview,
                            }
                        )

                    debug_steps.append(
                        {
                            "node": "tools",
                            "results": step_tool_results,
                        }
                    )

            # Get final state; the last accumulated message is the answer.
            final_state = self.graph.get_state(config)
            final_message = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_message)

            logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")

            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "error": False,
            }

        except Exception as e:
            # Top-level boundary: surface the failure to the caller instead
            # of raising through the chat API.
            logger.error(f"Error in agent chat: {e}", exc_info=True)
            return {
                "response": f"I apologize, I encountered an error: {str(e)}",
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Get conversation history for this session.

        Returns:
            List of {"role": "user"|"assistant", "content": str} dicts,
            excluding system and tool messages; empty list on any failure.
        """
        try:
            config = {"configurable": {"thread_id": self.session_id}}
            state = self.graph.get_state(config)

            if not state or not state.values.get("messages"):
                return []

            messages = state.values["messages"]
            result = []

            for msg in messages:
                # Skip system messages and tool messages
                if isinstance(msg, SystemMessage):
                    continue
                if hasattr(msg, "type") and msg.type in ["system", "tool"]:
                    continue

                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})

            return result

        except Exception as e:
            logger.error(f"Error getting history: {e}")
            return []

    def clear_history(self):
        """Clear conversation history for this session"""
        # With MemorySaver, we can't easily clear, but we can log
        logger.info(f"[{self.session_id}] History clear requested")
        # In production, implement proper clearing or use new thread_id


def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Build and return a :class:`ShoppingAgent` for the given session id."""
    agent = ShoppingAgent(session_id=session_id)
    return agent