Blame view

app/agents/shopping_agent.py 9.75 KB
e7f2b240   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  """
  Conversational Shopping Agent with LangGraph
  True ReAct agent with autonomous tool calling and message accumulation
  """
  
  import logging
  from pathlib import Path
  from typing import Optional, Sequence
  
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
  from langchain_openai import ChatOpenAI
  from langgraph.checkpoint.memory import MemorySaver
  from langgraph.graph import END, START, StateGraph
  from langgraph.graph.message import add_messages
  from langgraph.prebuilt import ToolNode
  from typing_extensions import Annotated, TypedDict
  
  from app.config import settings
  from app.tools.search_tools import get_all_tools
  
  # Module-level logger, named after this module's import path.
  logger = logging.getLogger(name=__name__)
  
  
  def _extract_message_text(msg) -> str:
      """Extract text from message content.
      LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
      content = getattr(msg, "content", "")
      if isinstance(content, str):
          return content
      if isinstance(content, list):
          parts = []
          for block in content:
              if isinstance(block, dict):
                  parts.append(block.get("text", block.get("content", "")))
              else:
                  parts.append(str(block))
          return "".join(str(p) for p in parts)
      return str(content) if content else ""
  
  
  class AgentState(TypedDict):
      """Graph state shared by the shopping agent's nodes.

      ``messages`` is annotated with LangGraph's ``add_messages`` reducer, so the
      messages a node returns are accumulated into the running conversation.
      ``current_image_path`` has no reducer annotation.
      NOTE(review): per LangGraph semantics it is presumably overwritten on every
      update — confirm against the LangGraph version in use.
      """

      # Conversation so far; node updates are combined via add_messages.
      messages: Annotated[Sequence[BaseMessage], add_messages]
      # Path of the uploaded image for the current chat() turn, or None.
      current_image_path: Optional[str]
  
bad17b15   tangwang   调通baseline
47
  # Debug breadcrumb left from bring-up. Routed through logging instead of
  # print() so that importing this module does not write to stdout.
  logger.debug("settings")
e7f2b240   tangwang   first commit
48
49
50
51
52
53
54
  class ShoppingAgent:
      """ReAct-style shopping agent built on a LangGraph ``StateGraph``.

      The graph alternates between an ``agent`` node and a ``tools`` node. The
      ``agent`` node runs the tool-bound chat model, which either requests tool
      calls or answers. The ``tools`` node executes the requested calls. The loop
      ends when the model replies without tool calls. Conversation state is
      checkpointed per thread in a ``MemorySaver`` created for each instance.
      """

      def __init__(self, session_id: Optional[str] = None):
          self.session_id = session_id or "default"
          # Checkpoint thread holding this session's messages. It starts equal to
          # the session id. clear_history() may switch it to a fresh thread when
          # the checkpointer cannot delete threads.
          self._thread_id = self.session_id
          self._history_epoch = 0

          # Initialize LLM
          llm_kwargs = dict(
              model=settings.openai_model,
              temperature=settings.openai_temperature,
              api_key=settings.openai_api_key,
          )
          if settings.openai_api_base_url:
              llm_kwargs["base_url"] = settings.openai_api_base_url

          # Log the config field by field. llm_kwargs holds the API key and must
          # never be printed or logged wholesale.
          logger.debug(
              "LLM config: model=%s temperature=%s base_url=%s",
              llm_kwargs["model"],
              llm_kwargs["temperature"],
              llm_kwargs.get("base_url"),
          )
          self.llm = ChatOpenAI(**llm_kwargs)

          # Get tools and bind to model
          self.tools = get_all_tools()
          self.llm_with_tools = self.llm.bind_tools(self.tools)

          # Build graph (this also sets self._checkpointer)
          self.graph = self._build_graph()

          logger.info(f"Shopping agent initialized for session: {self.session_id}")

      def _build_graph(self):
          """Build and compile the agent <-> tools LangGraph loop.

          Returns:
              The compiled graph. It is checkpointed by a new ``MemorySaver``,
              which is also stored on ``self._checkpointer`` so that
              clear_history() can reach it.
          """

          # System prompt for the agent
          system_prompt = """You are an intelligent fashion shopping assistant. You can:
  1. Search for products by text description (use search_products)
  2. Analyze image style and attributes (use analyze_image_style)
  
  When a user asks about products:
  - For text queries: use search_products directly
  - For image uploads: use analyze_image_style first to understand the product, then use search_products with the extracted description
  - You can call multiple tools in sequence if needed
  - Always provide helpful, friendly responses
  
  CRITICAL FORMATTING RULES:
  When presenting product results, you MUST use this EXACT format for EACH product:
  
  1. [Product Name]
     ID: [Product ID Number]
     Category: [Category]
     Color: [Color]
     Gender: [Gender]
     (Include Season, Usage, Relevance if available)
  
  Example:
  1. Puma Men White 3/4 Length Pants
     ID: 12345
     Category: Apparel > Bottomwear > Track Pants
     Color: White
     Gender: Men
     Season: Summer
     Usage: Sports
     Relevance: 95.2%
  
  DO NOT skip the ID field! It is essential for displaying product images.
  Be conversational in your introduction, but preserve the exact product format."""

          def agent_node(state: AgentState):
              """Agent decision node: decide which tools to call, or answer."""
              messages = state["messages"]

              # The system prompt is prepended per call rather than stored in the
              # state: only the model's reply is returned as a state update.
              if not any(isinstance(m, SystemMessage) for m in messages):
                  messages = [SystemMessage(content=system_prompt)] + list(messages)

              # An uploaded image path reaches the model through the HumanMessage
              # text built in chat(). state["current_image_path"] is not read here.

              # Invoke LLM with tools
              response = self.llm_with_tools.invoke(messages)
              return {"messages": [response]}

          # Create tool node
          tool_node = ToolNode(self.tools)

          def should_continue(state: AgentState):
              """Route to the tools node while the latest reply requests tool calls."""
              last_message = state["messages"][-1]
              if getattr(last_message, "tool_calls", None):
                  return "tools"
              # No tool calls: the model produced its final answer.
              return END

          workflow = StateGraph(AgentState)
          workflow.add_node("agent", agent_node)
          workflow.add_node("tools", tool_node)

          workflow.add_edge(START, "agent")
          workflow.add_conditional_edges("agent", should_continue, ["tools", END])
          workflow.add_edge("tools", "agent")

          # Compile with an in-memory checkpointer. Keep a handle to it so that
          # clear_history() can drop this session's thread.
          self._checkpointer = MemorySaver()
          return workflow.compile(checkpointer=self._checkpointer)

      def _config(self) -> dict:
          """Return the LangGraph run config that selects this session's thread."""
          return {"configurable": {"thread_id": self._thread_id}}

      @staticmethod
      def _record_tool_event(
          event: dict, tool_calls: list, pending: list, result_chars: int = 200
      ) -> None:
          """Fold one ``graph.stream()`` update event into the tool-call summary.

          Args:
              event: Mapping of node name to that node's state update.
              tool_calls: Summary list (mutated). It gets one ``{"name", "args"}``
                  dict per tool call the ``agent`` node requests. A ``"result"``
                  key is added once the ``tools`` node answers that call.
              pending: ``(tool_call_id, entry)`` pairs still awaiting a result,
                  in request order (mutated). Share it across all events of one
                  chat() call.
              result_chars: Results longer than this are truncated and suffixed
                  with ``"..."``.
          """
          agent_update = event.get("agent") or {}
          for msg in agent_update.get("messages", []):
              for tc in getattr(msg, "tool_calls", None) or []:
                  entry = {"name": tc["name"], "args": tc.get("args", {})}
                  tool_calls.append(entry)
                  pending.append((tc.get("id"), entry))

          tools_update = event.get("tools") or {}
          for msg in tools_update.get("messages", []):
              if not pending:
                  break
              # Match each result to its call by tool_call_id. A message's
              # position restarts at 0 in every tools round, so it cannot index
              # tool_calls directly. Without an id, take the oldest pending call.
              call_id = getattr(msg, "tool_call_id", None)
              slot = next(
                  (
                      i
                      for i, (pending_id, _) in enumerate(pending)
                      if call_id is not None and pending_id == call_id
                  ),
                  0,
              )
              _, entry = pending.pop(slot)
              text = str(msg.content)
              if len(text) > result_chars:
                  text = text[:result_chars] + "..."
              entry["result"] = text

      def chat(self, query: str, image_path: Optional[str] = None) -> dict:
          """Process user query with the agent

          Args:
              query: User's text query
              image_path: Optional path to uploaded image; it must exist on disk.

          Returns:
              Dict with ``response`` (str), ``tool_calls`` (list of
              ``{"name", "args"[, "result"]}`` dicts) and ``error`` (bool).
              Failures are reported with ``error=True`` rather than raised.
          """
          try:
              logger.info(
                  f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
              )

              # Validate image
              if image_path and not Path(image_path).exists():
                  return {
                      "response": f"Error: Image file not found at '{image_path}'",
                      "tool_calls": [],
                      "error": True,
                  }

              # The image path is embedded in the user message text so that the
              # model can pass it to its tools, as the system prompt instructs.
              message_content = query
              if image_path:
                  message_content = f"{query}\n[User uploaded image: {image_path}]"

              config = self._config()
              input_state = {
                  "messages": [HumanMessage(content=message_content)],
                  "current_image_path": image_path,
              }

              # Summarize tool activity from the streamed node updates.
              tool_calls: list = []
              pending: list = []
              for event in self.graph.stream(input_state, config=config):
                  logger.debug("[%s] Event: %s", self.session_id, event)
                  self._record_tool_event(event, tool_calls, pending)

              # Get final state
              final_state = self.graph.get_state(config)
              final_message = final_state.values["messages"][-1]
              response_text = _extract_message_text(final_message)

              logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")

              return {
                  "response": response_text,
                  "tool_calls": tool_calls,
                  "error": False,
              }

          except Exception as e:
              logger.error(f"Error in agent chat: {e}", exc_info=True)
              return {
                  "response": f"I apologize, I encountered an error: {str(e)}",
                  "tool_calls": [],
                  "error": True,
              }

      def get_conversation_history(self) -> list:
          """Return this session's visible turns as ``{"role", "content"}`` dicts.

          System and tool messages are skipped. Assistant messages that only
          request tool calls and carry no text are skipped too. Returns ``[]``
          when there is no history or when reading it fails.
          """
          try:
              state = self.graph.get_state(self._config())

              if not state or not state.values.get("messages"):
                  return []

              result = []
              for msg in state.values["messages"]:
                  # Skip system messages and tool messages
                  if isinstance(msg, SystemMessage):
                      continue
                  if getattr(msg, "type", None) in ("system", "tool"):
                      continue

                  text = _extract_message_text(msg)
                  # An intermediate tool-calling turn with empty text would show
                  # up as a blank assistant entry.
                  if not text and getattr(msg, "tool_calls", None):
                      continue

                  role = "user" if msg.type == "human" else "assistant"
                  result.append({"role": role, "content": text})

              return result

          except Exception as e:
              logger.error(f"Error getting history: {e}", exc_info=True)
              return []

      def clear_history(self):
          """Clear conversation history for this session.

          Deletes the session's checkpoint thread when the checkpointer supports
          ``delete_thread``. Otherwise the agent switches to a new, empty thread,
          so that later chat() and get_conversation_history() calls no longer
          see the old messages.
          """
          old_thread = self._thread_id
          delete_thread = getattr(self._checkpointer, "delete_thread", None)
          if callable(delete_thread):
              try:
                  delete_thread(old_thread)
              except NotImplementedError:
                  # Checkpointers without thread deletion may raise this;
                  # fall back to rotating the thread id below.
                  pass
              else:
                  logger.info(f"[{self.session_id}] History cleared (deleted thread {old_thread})")
                  return

          self._history_epoch += 1
          self._thread_id = f"{self.session_id}#{self._history_epoch}"
          logger.info(
              f"[{self.session_id}] History cleared (thread {old_thread} -> {self._thread_id})"
          )
  
  
  def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
      """Construct and return a new ShoppingAgent for ``session_id``.

      ``session_id`` is forwarded unchanged, including ``None``.
      """
      agent = ShoppingAgent(session_id=session_id)
      return agent