Blame view

app/agents/shopping_agent.py 11.3 KB
e7f2b240   tangwang   first commit
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
  """
  Conversational Shopping Agent with LangGraph
  True ReAct agent with autonomous tool calling and message accumulation
  """
  
  import logging
  from pathlib import Path
  from typing import Optional, Sequence
  
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
  from langchain_openai import ChatOpenAI
  from langgraph.checkpoint.memory import MemorySaver
  from langgraph.graph import END, START, StateGraph
  from langgraph.graph.message import add_messages
  from langgraph.prebuilt import ToolNode
  from typing_extensions import Annotated, TypedDict
  
  from app.config import settings
  from app.tools.search_tools import get_all_tools
  
  logger = logging.getLogger(__name__)
  
  
  def _extract_message_text(msg) -> str:
      """Extract text from message content.
      LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
      content = getattr(msg, "content", "")
      if isinstance(content, str):
          return content
      if isinstance(content, list):
          parts = []
          for block in content:
              if isinstance(block, dict):
                  parts.append(block.get("text", block.get("content", "")))
              else:
                  parts.append(str(block))
          return "".join(str(p) for p in parts)
      return str(content) if content else ""
  
  
class AgentState(TypedDict):
    """State for the shopping agent with message accumulation"""

    # add_messages is a reducer: updates APPEND to the conversation rather
    # than replacing it, which is what enables multi-turn ReAct loops.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    current_image_path: Optional[str]  # Track uploaded image
  
# NOTE(review): leftover debug print from baseline bring-up; emit through the
# module logger instead of stdout so it obeys the app's logging configuration.
logger.debug("settings module loaded")
e7f2b240   tangwang   first commit
48
49
50
51
52
53
54
  class ShoppingAgent:
      """True ReAct agent with autonomous decision making"""
  
      def __init__(self, session_id: Optional[str] = None):
          self.session_id = session_id or "default"
  
          # Initialize LLM
8810a6fa   tangwang   重构
55
          llm_kwargs = dict(
e7f2b240   tangwang   first commit
56
57
58
59
              model=settings.openai_model,
              temperature=settings.openai_temperature,
              api_key=settings.openai_api_key,
          )
8810a6fa   tangwang   重构
60
61
          if settings.openai_api_base_url:
              llm_kwargs["base_url"] = settings.openai_api_base_url
bad17b15   tangwang   调通baseline
62
63
64
65
          
          print("llm_kwargs")
          print(llm_kwargs)
  
8810a6fa   tangwang   重构
66
          self.llm = ChatOpenAI(**llm_kwargs)
e7f2b240   tangwang   first commit
67
68
69
70
71
72
73
74
75
76
77
78
79
80
  
          # Get tools and bind to model
          self.tools = get_all_tools()
          self.llm_with_tools = self.llm.bind_tools(self.tools)
  
          # Build graph
          self.graph = self._build_graph()
  
          logger.info(f"Shopping agent initialized for session: {self.session_id}")
  
      def _build_graph(self):
          """Build the LangGraph StateGraph"""
  
          # System prompt for the agent
01b46131   tangwang   流程跑通
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
          system_prompt = """你是一位智能时尚购物助手,你可以:
  1. 根据文字描述搜索商品(使用 search_products
  2. 分析图片风格和属性(使用 analyze_image_style
  
  当用户咨询商品时:
  - 文字提问:直接使用 search_products 搜索
  - 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索
  - 可按需连续调用多个工具
  - 始终保持有用、友好的回复风格
  
  关键格式规则:
  展示商品结果时,每个商品必须严格按以下格式输出:
  
  1. [标题 title]
     ID: [商品ID]
     分类: [category_path]
     中文名: [title_cn](如有)
     标签: [tags](如有)
  
  示例:
e7f2b240   tangwang   first commit
101
102
  1. Puma Men White 3/4 Length Pants
     ID: 12345
01b46131   tangwang   流程跑通
103
104
105
     分类: 服饰 > 裤装 > 运动裤
     中文名: 彪马男士白色九分运动裤
     标签: 运动,夏季,白色
e7f2b240   tangwang   first commit
106
  
01b46131   tangwang   流程跑通
107
108
  不可省略 ID 字段!它是展示商品图片的关键。
  介绍要口语化,但必须保持上述商品格式。"""
e7f2b240   tangwang   first commit
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
  
          def agent_node(state: AgentState):
              """Agent decision node - decides which tools to call or when to respond"""
              messages = state["messages"]
  
              # Add system prompt if first message
              if not any(isinstance(m, SystemMessage) for m in messages):
                  messages = [SystemMessage(content=system_prompt)] + list(messages)
  
              # Handle image context
              if state.get("current_image_path"):
                  # Inject image path context for tool calls
                  # The agent can reference this in its reasoning
                  pass
  
              # Invoke LLM with tools
              response = self.llm_with_tools.invoke(messages)
              return {"messages": [response]}
  
          # Create tool node
          tool_node = ToolNode(self.tools)
  
          def should_continue(state: AgentState):
              """Determine if agent should continue or end"""
              messages = state["messages"]
              last_message = messages[-1]
  
              # If LLM made tool calls, continue to tools
              if hasattr(last_message, "tool_calls") and last_message.tool_calls:
                  return "tools"
              # Otherwise, end (agent has final response)
              return END
  
          # Build graph
          workflow = StateGraph(AgentState)
  
          workflow.add_node("agent", agent_node)
          workflow.add_node("tools", tool_node)
  
          workflow.add_edge(START, "agent")
          workflow.add_conditional_edges("agent", should_continue, ["tools", END])
          workflow.add_edge("tools", "agent")
  
          # Compile with memory
          checkpointer = MemorySaver()
          return workflow.compile(checkpointer=checkpointer)
  
      def chat(self, query: str, image_path: Optional[str] = None) -> dict:
          """Process user query with the agent
  
          Args:
              query: User's text query
              image_path: Optional path to uploaded image
  
          Returns:
01b46131   tangwang   流程跑通
164
165
166
              Dict with response and metadata, including:
                  - tool_calls: list of tool calls with args and (truncated) results
                  - debug_steps: detailed intermediate reasoning & tool execution steps
e7f2b240   tangwang   first commit
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
          """
          try:
              logger.info(
                  f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
              )
  
              # Validate image
              if image_path and not Path(image_path).exists():
                  return {
                      "response": f"Error: Image file not found at '{image_path}'",
                      "error": True,
                  }
  
              # Build input message
              message_content = query
              if image_path:
                  message_content = f"{query}\n[User uploaded image: {image_path}]"
  
              # Invoke agent
              config = {"configurable": {"thread_id": self.session_id}}
              input_state = {
                  "messages": [HumanMessage(content=message_content)],
                  "current_image_path": image_path,
              }
  
01b46131   tangwang   流程跑通
192
              # Track tool calls (high-level) and detailed debug steps
e7f2b240   tangwang   first commit
193
              tool_calls = []
01b46131   tangwang   流程跑通
194
              debug_steps = []
e7f2b240   tangwang   first commit
195
              
01b46131   tangwang   流程跑通
196
              # Stream events to capture tool calls and intermediate reasoning
e7f2b240   tangwang   first commit
197
198
              for event in self.graph.stream(input_state, config=config):
                  logger.info(f"Event: {event}")
01b46131   tangwang   流程跑通
199
200
  
                  # Agent node: LLM reasoning & tool decisions
e7f2b240   tangwang   first commit
201
202
                  if "agent" in event:
                      agent_output = event["agent"]
01b46131   tangwang   流程跑通
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
                      messages = agent_output.get("messages", [])
  
                      step_messages = []
                      step_tool_calls = []
  
                      for msg in messages:
                          msg_text = _extract_message_text(msg)
                          msg_entry = {
                              "type": getattr(msg, "type", "assistant"),
                              "content": msg_text[:500],  # truncate for safety
                          }
                          step_messages.append(msg_entry)
  
                          # Capture tool calls from this agent message
                          if hasattr(msg, "tool_calls") and msg.tool_calls:
                              for tc in msg.tool_calls:
                                  tc_entry = {
                                      "name": tc.get("name"),
                                      "args": tc.get("args", {}),
                                  }
                                  tool_calls.append(tc_entry)
                                  step_tool_calls.append(tc_entry)
  
                      debug_steps.append(
                          {
                              "node": "agent",
                              "messages": step_messages,
                              "tool_calls": step_tool_calls,
                          }
                      )
  
                  # Tool node: actual tool execution results
e7f2b240   tangwang   first commit
235
236
                  if "tools" in event:
                      tools_output = event["tools"]
01b46131   tangwang   流程跑通
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
                      messages = tools_output.get("messages", [])
  
                      step_tool_results = []
  
                      for i, msg in enumerate(messages):
                          content_text = _extract_message_text(msg)
                          result_preview = content_text[:500] + ("..." if len(content_text) > 500 else "")
  
                          if i < len(tool_calls):
                              tool_calls[i]["result"] = result_preview
  
                          step_tool_results.append(
                              {
                                  "content": result_preview,
                              }
                          )
  
                      debug_steps.append(
                          {
                              "node": "tools",
                              "results": step_tool_results,
                          }
                      )
e7f2b240   tangwang   first commit
260
261
262
263
264
265
266
267
268
269
270
  
              # Get final state
              final_state = self.graph.get_state(config)
              final_message = final_state.values["messages"][-1]
              response_text = _extract_message_text(final_message)
  
              logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")
  
              return {
                  "response": response_text,
                  "tool_calls": tool_calls,
01b46131   tangwang   流程跑通
271
                  "debug_steps": debug_steps,
e7f2b240   tangwang   first commit
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
                  "error": False,
              }
  
          except Exception as e:
              logger.error(f"Error in agent chat: {e}", exc_info=True)
              return {
                  "response": f"I apologize, I encountered an error: {str(e)}",
                  "error": True,
              }
  
      def get_conversation_history(self) -> list:
          """Get conversation history for this session"""
          try:
              config = {"configurable": {"thread_id": self.session_id}}
              state = self.graph.get_state(config)
  
              if not state or not state.values.get("messages"):
                  return []
  
              messages = state.values["messages"]
              result = []
  
              for msg in messages:
                  # Skip system messages and tool messages
                  if isinstance(msg, SystemMessage):
                      continue
                  if hasattr(msg, "type") and msg.type in ["system", "tool"]:
                      continue
  
                  role = "user" if msg.type == "human" else "assistant"
                  result.append({"role": role, "content": _extract_message_text(msg)})
  
              return result
  
          except Exception as e:
              logger.error(f"Error getting history: {e}")
              return []
  
      def clear_history(self):
          """Clear conversation history for this session"""
          # With MemorySaver, we can't easily clear, but we can log
          logger.info(f"[{self.session_id}] History clear requested")
          # In production, implement proper clearing or use new thread_id
  
  
  def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
      """Factory function to create a shopping agent"""
      return ShoppingAgent(session_id=session_id)