e7f2b240
tangwang
first commit
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
"""
Conversational Shopping Agent with LangGraph
True ReAct agent with autonomous tool calling and message accumulation
"""
import logging
from pathlib import Path
from typing import Optional, Sequence
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict
from app.config import settings
from app.tools.search_tools import get_all_tools
logger = logging.getLogger(__name__)
def _extract_message_text(msg) -> str:
"""Extract text from message content.
LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
content = getattr(msg, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict):
parts.append(block.get("text", block.get("content", "")))
else:
parts.append(str(block))
return "".join(str(p) for p in parts)
return str(content) if content else ""
class AgentState(TypedDict):
    """Graph state shared by every node of the shopping agent."""

    # Conversation so far. The ``add_messages`` reducer attached through
    # ``Annotated`` is what merges each node's returned messages into this list.
    messages: Annotated[Sequence[BaseMessage], add_messages]

    # Path of the image attached to the current turn, or None when absent.
    current_image_path: Optional[str]
|
bad17b15
tangwang
调通baseline
|
47
|
# Import-time trace, kept from the original debug print but sent through
# logging at DEBUG level so importing this module no longer writes to stdout.
logging.getLogger(__name__).debug("settings")
|
e7f2b240
tangwang
first commit
|
48
49
50
51
52
53
54
|
class ShoppingAgent:
    """ReAct-style shopping agent: the LLM decides when to call tools.

    Each instance owns its own compiled LangGraph graph and its own in-memory
    checkpointer; the conversation is stored under ``thread_id == session_id``.
    """

    # Maximum number of characters of message/tool text kept in debug output.
    _PREVIEW_LIMIT = 500

    # System prompt sent to the model. This is runtime text, reproduced as-is;
    # continuation lines are flush-left because they are part of the string.
    _SYSTEM_PROMPT = """你是一位智能时尚购物助手,你可以:
1. 根据文字描述搜索商品(使用 search_products)
2. 分析图片风格和属性(使用 analyze_image_style)
当用户咨询商品时:
- 文字提问:直接使用 search_products 搜索
- 图片上传:先用 analyze_image_style 理解商品,再用提取的描述调用 search_products 搜索
- 可按需连续调用多个工具
- 始终保持有用、友好的回复风格
关键格式规则:
展示商品结果时,每个商品必须严格按以下格式输出:
1. [标题 title]
ID: [商品ID]
分类: [category_path]
中文名: [title_cn](如有)
标签: [tags](如有)
示例:
1. Puma Men White 3/4 Length Pants
ID: 12345
分类: 服饰 > 裤装 > 运动裤
中文名: 彪马男士白色九分运动裤
标签: 运动,夏季,白色

不可省略 ID 字段!它是展示商品图片的关键。
介绍要口语化,但必须保持上述商品格式。"""

    def __init__(self, session_id: Optional[str] = None):
        """Create the LLM client, bind the tools and compile the graph.

        Args:
            session_id: Conversation identifier; falls back to ``"default"``.
        """
        self.session_id = session_id or "default"

        llm_kwargs = {
            "model": settings.openai_model,
            "temperature": settings.openai_temperature,
            "api_key": settings.openai_api_key,
        }
        if settings.openai_api_base_url:
            llm_kwargs["base_url"] = settings.openai_api_base_url
        # Never print or log llm_kwargs wholesale: it contains the API key.
        logger.debug(
            "LLM config: model=%s temperature=%s base_url=%s",
            llm_kwargs["model"],
            llm_kwargs["temperature"],
            llm_kwargs.get("base_url"),
        )
        self.llm = ChatOpenAI(**llm_kwargs)

        # Tools are bound to the model (so it can request them) and also
        # handed to the graph's ToolNode (so they can be executed).
        self.tools = get_all_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)

        self.graph = self._build_graph()
        logger.info(f"Shopping agent initialized for session: {self.session_id}")

    def _build_graph(self):
        """Build and compile the agent <-> tools loop.

        Returns:
            The compiled graph, backed by a fresh ``MemorySaver`` checkpointer.
        """
        system_prompt = self._SYSTEM_PROMPT

        def agent_node(state: AgentState):
            """Ask the LLM for the next step (tool calls or a final answer)."""
            messages = state["messages"]
            # The system prompt is prepended for the call only; it is not
            # returned, so it is re-added whenever the history lacks one.
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=system_prompt), *messages]
            # The uploaded image path, if any, is already embedded in the
            # user message text by chat(), so no extra handling is needed here.
            response = self.llm_with_tools.invoke(messages)
            return {"messages": [response]}

        def should_continue(state: AgentState):
            """Route to the tool node while the last message requests tools."""
            last_message = state["messages"][-1]
            if getattr(last_message, "tool_calls", None):
                return "tools"
            return END

        workflow = StateGraph(AgentState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", ToolNode(self.tools))
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")

        return workflow.compile(checkpointer=MemorySaver())

    def _thread_config(self) -> dict:
        """Return the graph config that selects this session's thread."""
        return {"configurable": {"thread_id": self.session_id}}

    @staticmethod
    def _error_result(message: str) -> dict:
        """Build an error response with the same keys as a successful one."""
        return {
            "response": message,
            "tool_calls": [],
            "debug_steps": [],
            "error": True,
        }

    @classmethod
    def _record_agent_step(cls, agent_output: dict, tool_calls: list, pending_by_id: dict) -> dict:
        """Summarise one agent event and register the tool calls it requested.

        New tool-call entries are appended to ``tool_calls`` and indexed in
        ``pending_by_id`` by call id so their results can be attached later.
        """
        step_messages = []
        step_tool_calls = []
        for msg in agent_output.get("messages", []):
            step_messages.append(
                {
                    "type": getattr(msg, "type", "assistant"),
                    # Truncated so debug payloads stay small.
                    "content": _extract_message_text(msg)[: cls._PREVIEW_LIMIT],
                }
            )
            for tc in getattr(msg, "tool_calls", None) or []:
                entry = {"name": tc.get("name"), "args": tc.get("args", {})}
                # The same dict is shared by both lists, so a result attached
                # later is visible in the summary and in the debug step.
                tool_calls.append(entry)
                step_tool_calls.append(entry)
                call_id = tc.get("id")
                if call_id:
                    pending_by_id[call_id] = entry
        return {
            "node": "agent",
            "messages": step_messages,
            "tool_calls": step_tool_calls,
        }

    @classmethod
    def _record_tool_step(cls, tools_output: dict, tool_calls: list, pending_by_id: dict) -> dict:
        """Summarise one tools event and attach each result to its tool call."""
        limit = cls._PREVIEW_LIMIT
        results = []
        for msg in tools_output.get("messages", []):
            text = _extract_message_text(msg)
            preview = text[:limit] + ("..." if len(text) > limit else "")
            # Match by tool_call_id: indexing by position within this event
            # would overwrite earlier rounds once the agent calls tools twice.
            entry = pending_by_id.pop(getattr(msg, "tool_call_id", None), None)
            if entry is None:
                # No usable id: fall back to the oldest call still lacking a result.
                entry = next((c for c in tool_calls if "result" not in c), None)
            if entry is not None:
                entry["result"] = preview
            results.append({"content": preview})
        return {"node": "tools", "results": results}

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """Process one user turn with the agent.

        Args:
            query: User's text query.
            image_path: Optional path to an uploaded image; it must exist.

        Returns:
            Dict with:
            - response: final assistant text (or an error message)
            - tool_calls: list of tool calls with args and (truncated) results
            - debug_steps: per-node intermediate messages and tool results
            - error: True when the turn failed, else False
        """
        try:
            logger.info(
                f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
            )

            if image_path and not Path(image_path).exists():
                return self._error_result(f"Error: Image file not found at '{image_path}'")

            # The image path is passed to the model inside the message text.
            message_content = query
            if image_path:
                message_content = f"{query}\n[User uploaded image: {image_path}]"

            config = self._thread_config()
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }

            tool_calls = []      # high-level summary across all rounds
            debug_steps = []     # one entry per graph event
            pending_by_id = {}   # tool call id -> entry awaiting its result

            for event in self.graph.stream(input_state, config=config):
                # Full events include message bodies; keep them out of INFO logs.
                logger.debug("Event: %s", event)
                if "agent" in event:
                    debug_steps.append(
                        self._record_agent_step(event["agent"], tool_calls, pending_by_id)
                    )
                if "tools" in event:
                    debug_steps.append(
                        self._record_tool_step(event["tools"], tool_calls, pending_by_id)
                    )

            # The last message in the persisted state is the agent's final answer.
            final_state = self.graph.get_state(config)
            final_message = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_message)

            logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")
            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "debug_steps": debug_steps,
                "error": False,
            }
        except Exception as e:
            # Top-level boundary: callers get an error payload, never an exception.
            logger.exception("Error in agent chat: %s", e)
            return self._error_result(f"I apologize, I encountered an error: {str(e)}")

    def get_conversation_history(self) -> list:
        """Return this session's user/assistant messages.

        Returns:
            List of ``{"role": "user" | "assistant", "content": str}`` dicts;
            system and tool messages are omitted. Empty list on any failure.
        """
        try:
            state = self.graph.get_state(self._thread_config())
            if not state or not state.values.get("messages"):
                return []

            history = []
            for msg in state.values["messages"]:
                msg_type = getattr(msg, "type", None)
                if isinstance(msg, SystemMessage) or msg_type in ("system", "tool"):
                    continue
                role = "user" if msg_type == "human" else "assistant"
                history.append({"role": role, "content": _extract_message_text(msg)})
            return history
        except Exception as e:
            logger.error(f"Error getting history: {e}")
            return []

    def clear_history(self):
        """Clear conversation history for this session.

        The checkpointer is created inside ``_build_graph`` and is private to
        this instance, so recompiling the graph replaces it with an empty one.
        """
        self.graph = self._build_graph()
        logger.info(f"[{self.session_id}] History cleared")
def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Build a ``ShoppingAgent`` for the given session.

    Args:
        session_id: Conversation identifier, passed through unchanged.

    Returns:
        A newly constructed ``ShoppingAgent``.
    """
    agent = ShoppingAgent(session_id=session_id)
    return agent
|