# shopping_agent.py
"""
Conversational Shopping Agent with LangGraph
True ReAct agent with autonomous tool calling and message accumulation
"""
import logging
from pathlib import Path
from typing import Optional, Sequence
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict
from app.config import settings
from app.tools.search_tools import get_all_tools
logger = logging.getLogger(__name__)
def _extract_message_text(msg) -> str:
"""Extract text from message content.
LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
content = getattr(msg, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict):
parts.append(block.get("text", block.get("content", "")))
else:
parts.append(str(block))
return "".join(str(p) for p in parts)
return str(content) if content else ""
class AgentState(TypedDict):
    """State for the shopping agent with message accumulation"""
    # add_messages is a reducer: each node's returned messages are APPENDED
    # to this list rather than replacing it, giving full conversation history.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Track uploaded image; set by chat() for the current turn (may be None).
    current_image_path: Optional[str]
# FIX: removed stray `print("settings")` debug statement that executed at
# module import time and polluted stdout.
class ShoppingAgent:
    """True ReAct agent with autonomous decision making.

    A LangGraph loop: the LLM node decides whether to call tools, the tool
    node executes them, and control returns to the LLM until it produces a
    final answer. Per-session conversation state is checkpointed in memory
    (MemorySaver), keyed by ``thread_id == session_id``.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Initialize the agent for one conversation session.

        Args:
            session_id: Identifier used as the LangGraph thread id.
                Defaults to "default" when omitted.
        """
        self.session_id = session_id or "default"
        # Initialize LLM from application settings
        llm_kwargs = dict(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )
        if settings.openai_api_base_url:
            llm_kwargs["base_url"] = settings.openai_api_base_url
        # FIX: the full kwargs dict (including the API key) was previously
        # printed to stdout. Log only non-secret fields, at debug level.
        logger.debug(
            "LLM config: model=%s temperature=%s base_url=%s",
            llm_kwargs.get("model"),
            llm_kwargs.get("temperature"),
            llm_kwargs.get("base_url"),
        )
        self.llm = ChatOpenAI(**llm_kwargs)
        # Get tools and bind them so the model can emit tool calls
        self.tools = get_all_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Build graph (compiled once per agent instance)
        self.graph = self._build_graph()
        logger.info(f"Shopping agent initialized for session: {self.session_id}")

    def _build_graph(self):
        """Build and compile the LangGraph StateGraph (agent <-> tools loop)."""
        # System prompt for the agent (runtime string — kept verbatim)
        system_prompt = """你是一位智能时尚购物助手,你可以:
1. 根据文字描述搜索商品(使用 search_products)
2. 分析图片风格和属性(使用 analyze_image_style)
当用户咨询商品时:
- 文字提问:直接使用 search_products 搜索
- 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索
- 可按需连续调用多个工具
- 始终保持有用、友好的回复风格
关键格式规则:
展示商品结果时,每个商品必须严格按以下格式输出:
1. [标题 title]
ID: [商品ID]
分类: [category_path]
中文名: [title_cn](如有)
标签: [tags](如有)
示例:
1. Puma Men White 3/4 Length Pants
ID: 12345
分类: 服饰 > 裤装 > 运动裤
中文名: 彪马男士白色九分运动裤
标签: 运动,夏季,白色
不可省略 ID 字段!它是展示商品图片的关键。
介绍要口语化,但必须保持上述商品格式。"""

        def agent_node(state: AgentState):
            """Agent decision node - decides which tools to call or when to respond."""
            messages = state["messages"]
            # Prepend the system prompt once per thread (checkpointed state
            # keeps it for subsequent turns).
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=system_prompt)] + list(messages)
            # NOTE: current_image_path travels in state, but the image path is
            # already embedded in the user message text by chat(), so nothing
            # extra needs to be injected here.
            response = self.llm_with_tools.invoke(messages)
            return {"messages": [response]}

        # Tool node executes whatever tool calls the LLM emitted
        tool_node = ToolNode(self.tools)

        def should_continue(state: AgentState):
            """Route to the tool node while the last LLM message has tool calls."""
            last_message = state["messages"][-1]
            if getattr(last_message, "tool_calls", None):
                return "tools"
            # Otherwise the agent has produced its final response
            return END

        # Graph shape: START -> agent -> (tools -> agent)* -> END
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")
        # Compile with in-memory checkpointing so each thread keeps its history
        return workflow.compile(checkpointer=MemorySaver())

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """Process user query with the agent.

        Args:
            query: User's text query.
            image_path: Optional path to uploaded image.
        Returns:
            Dict with response and metadata, including:
            - tool_calls: list of tool calls with args and (truncated) results
            - debug_steps: detailed intermediate reasoning & tool execution steps
            - error: True when validation failed or an exception occurred
        """
        try:
            logger.info(
                f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
            )
            # Validate image before invoking the graph
            if image_path and not Path(image_path).exists():
                return {
                    "response": f"Error: Image file not found at '{image_path}'",
                    "error": True,
                }
            # Build input message; the path is appended as text so the LLM can
            # pass it to analyze_image_style
            message_content = query
            if image_path:
                message_content = f"{query}\n[User uploaded image: {image_path}]"
            # Invoke agent on this session's checkpoint thread
            config = {"configurable": {"thread_id": self.session_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }
            # Track tool calls (high-level) and detailed debug steps
            tool_calls = []
            debug_steps = []
            # FIX: results were previously matched by the index *within each
            # tools event*, which overwrote round-1 results whenever the agent
            # looped through the tool node more than once. Use a global cursor
            # over tool_calls instead.
            next_result_idx = 0
            # Stream events to capture tool calls and intermediate reasoning
            for event in self.graph.stream(input_state, config=config):
                logger.info(f"Event: {event}")
                # Agent node: LLM reasoning & tool decisions
                if "agent" in event:
                    agent_output = event["agent"]
                    step_messages = []
                    step_tool_calls = []
                    for msg in agent_output.get("messages", []):
                        msg_text = _extract_message_text(msg)
                        step_messages.append(
                            {
                                "type": getattr(msg, "type", "assistant"),
                                "content": msg_text[:500],  # truncate for safety
                            }
                        )
                        # Capture tool calls from this agent message
                        for tc in getattr(msg, "tool_calls", None) or []:
                            tc_entry = {
                                "name": tc.get("name"),
                                "args": tc.get("args", {}),
                            }
                            tool_calls.append(tc_entry)
                            step_tool_calls.append(tc_entry)
                    debug_steps.append(
                        {
                            "node": "agent",
                            "messages": step_messages,
                            "tool_calls": step_tool_calls,
                        }
                    )
                # Tool node: actual tool execution results
                if "tools" in event:
                    tools_output = event["tools"]
                    step_tool_results = []
                    for msg in tools_output.get("messages", []):
                        content_text = _extract_message_text(msg)
                        result_preview = content_text[:500] + (
                            "..." if len(content_text) > 500 else ""
                        )
                        # Attach the result to the next unmatched tool call
                        if next_result_idx < len(tool_calls):
                            tool_calls[next_result_idx]["result"] = result_preview
                            next_result_idx += 1
                        step_tool_results.append({"content": result_preview})
                    debug_steps.append(
                        {
                            "node": "tools",
                            "results": step_tool_results,
                        }
                    )
            # Get final state; the last message is the agent's final answer
            final_state = self.graph.get_state(config)
            final_message = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_message)
            logger.info(
                f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls"
            )
            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "error": False,
            }
        except Exception as e:
            # Top-level boundary: return a friendly message, keep details in logs
            logger.error(f"Error in agent chat: {e}", exc_info=True)
            return {
                "response": f"I apologize, I encountered an error: {str(e)}",
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Get conversation history for this session.

        Returns:
            List of {"role": "user"|"assistant", "content": str} dicts,
            excluding system and tool messages; [] on error or empty state.
        """
        try:
            config = {"configurable": {"thread_id": self.session_id}}
            state = self.graph.get_state(config)
            if not state or not state.values.get("messages"):
                return []
            result = []
            for msg in state.values["messages"]:
                # Skip system messages and tool messages
                if isinstance(msg, SystemMessage):
                    continue
                if getattr(msg, "type", None) in ("system", "tool"):
                    continue
                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})
            return result
        except Exception as e:
            logger.error(f"Error getting history: {e}")
            return []

    def clear_history(self):
        """Clear conversation history for this session.

        MemorySaver offers no per-thread deletion, so this only logs the
        request. In production, implement proper clearing or start a new
        thread_id (i.e. create a fresh agent) to reset the conversation.
        """
        logger.info(f"[{self.session_id}] History clear requested")
def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory: build a ShoppingAgent bound to the given session id."""
    agent = ShoppingAgent(session_id=session_id)
    return agent