"""
Conversational Shopping Agent with LangGraph
True ReAct agent with autonomous tool calling and message accumulation
"""
import logging
from pathlib import Path
from typing import Optional, Sequence
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from typing_extensions import Annotated, TypedDict
from app.config import settings
from app.tools.search_tools import get_all_tools
logger = logging.getLogger(__name__)
def _extract_message_text(msg) -> str:
"""Extract text from message content.
LangChain 1.0: content may be str or content_blocks (list) for multimodal."""
content = getattr(msg, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts = []
for block in content:
if isinstance(block, dict):
parts.append(block.get("text", block.get("content", "")))
else:
parts.append(str(block))
return "".join(str(p) for p in parts)
return str(content) if content else ""
class AgentState(TypedDict):
    """State for the shopping agent with message accumulation"""
    # add_messages is a reducer: each graph step's returned messages are
    # appended to (not replacing) the existing list, which is what gives
    # the ReAct loop its accumulated conversation context.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Path of the image uploaded for the current turn, if any; also embedded
    # in the human message text by ShoppingAgent.chat().
    current_image_path: Optional[str]  # Track uploaded image
class ShoppingAgent:
    """True ReAct agent with autonomous decision making.

    Wraps an OpenAI chat model (with the project's search tools bound) in a
    LangGraph agent/tools loop. Conversation state is persisted per
    ``session_id`` via an in-memory checkpointer, so successive ``chat()``
    calls on the same instance accumulate history.
    """

    def __init__(self, session_id: Optional[str] = None):
        """Create an agent bound to one conversation thread.

        Args:
            session_id: Checkpointer thread identifier; defaults to "default".
        """
        self.session_id = session_id or "default"
        # Initialize LLM from application settings
        self.llm = ChatOpenAI(
            model=settings.openai_model,
            temperature=settings.openai_temperature,
            api_key=settings.openai_api_key,
        )
        # Bind tools so the model can emit structured tool calls
        self.tools = get_all_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # Build and compile the state graph once per agent
        self.graph = self._build_graph()
        logger.info(f"Shopping agent initialized for session: {self.session_id}")

    def _build_graph(self):
        """Build the LangGraph StateGraph (agent node <-> tool node loop)."""
        # System prompt: the strict product format matters because downstream
        # UI code parses the ID field to display product images.
        system_prompt = """You are an intelligent fashion shopping assistant. You can:
1. Search for products by text description (use search_products)
2. Find visually similar products from images (use search_by_image)
3. Analyze image style and attributes (use analyze_image_style)
When a user asks about products:
- For text queries: use search_products directly
- For image uploads: decide if you need to analyze_image_style first, then search
- You can call multiple tools in sequence if needed
- Always provide helpful, friendly responses
CRITICAL FORMATTING RULES:
When presenting product results, you MUST use this EXACT format for EACH product:
1. [Product Name]
ID: [Product ID Number]
Category: [Category]
Color: [Color]
Gender: [Gender]
(Include Season, Usage, Relevance if available)
Example:
1. Puma Men White 3/4 Length Pants
ID: 12345
Category: Apparel > Bottomwear > Track Pants
Color: White
Gender: Men
Season: Summer
Usage: Sports
Relevance: 95.2%
DO NOT skip the ID field! It is essential for displaying product images.
Be conversational in your introduction, but preserve the exact product format."""

        def agent_node(state: AgentState):
            """Agent decision node - decides which tools to call or when to respond."""
            messages = state["messages"]
            # Prepend the system prompt for this invocation only. It is never
            # persisted: the node returns just the LLM response, so the
            # checkpointer's history stays prompt-free and the check below
            # re-adds it on every turn.
            if not any(isinstance(m, SystemMessage) for m in messages):
                messages = [SystemMessage(content=system_prompt)] + list(messages)
            # NOTE: current_image_path needs no handling here — chat() already
            # embeds the path in the human message text, which is how the LLM
            # passes it to image tools.
            response = self.llm_with_tools.invoke(messages)
            return {"messages": [response]}

        # Executes whatever tool calls the LLM emitted in its last message
        tool_node = ToolNode(self.tools)

        def should_continue(state: AgentState):
            """Route to tools while the last message carries tool calls; else end."""
            last_message = state["messages"][-1]
            if hasattr(last_message, "tool_calls") and last_message.tool_calls:
                return "tools"
            # No tool calls: the agent produced its final answer
            return END

        # Wire the ReAct loop: agent -> (tools -> agent)* -> END
        workflow = StateGraph(AgentState)
        workflow.add_node("agent", agent_node)
        workflow.add_node("tools", tool_node)
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges("agent", should_continue, ["tools", END])
        workflow.add_edge("tools", "agent")
        # In-memory checkpointing: each thread_id keeps its own history
        checkpointer = MemorySaver()
        return workflow.compile(checkpointer=checkpointer)

    def chat(self, query: str, image_path: Optional[str] = None) -> dict:
        """Process user query with the agent.

        Args:
            query: User's text query.
            image_path: Optional path to uploaded image.

        Returns:
            Dict with ``response`` (str) and ``error`` (bool); on success also
            ``tool_calls``: a list of ``{"name", "args"[, "result"]}`` entries,
            one per tool call the agent made, in order.
        """
        try:
            logger.info(
                f"[{self.session_id}] Processing: '{query}' (image={'Yes' if image_path else 'No'})"
            )
            # Validate the image up front so the agent never sees a dead path
            if image_path and not Path(image_path).exists():
                return {
                    "response": f"Error: Image file not found at '{image_path}'",
                    "error": True,
                }
            # Embed the image path in the message text so the LLM can forward
            # it to image tools as a plain argument.
            message_content = query
            if image_path:
                message_content = f"{query}\n[User uploaded image: {image_path}]"
            config = {"configurable": {"thread_id": self.session_id}}
            input_state = {
                "messages": [HumanMessage(content=message_content)],
                "current_image_path": image_path,
            }
            # Tool calls accumulated across ALL ReAct iterations of this turn
            tool_calls = []
            # BUGFIX: results were previously indexed from 0 on every "tools"
            # event (enumerate within the event), so a second agent/tools
            # round overwrote the first round's results instead of attaching
            # to the newer calls. Track the next unfilled slot globally.
            next_result_idx = 0
            for event in self.graph.stream(input_state, config=config):
                logger.info(f"Event: {event}")
                # Agent node output: record any tool calls the LLM requested
                if "agent" in event:
                    agent_output = event["agent"]
                    if "messages" in agent_output:
                        for msg in agent_output["messages"]:
                            if hasattr(msg, "tool_calls") and msg.tool_calls:
                                for tc in msg.tool_calls:
                                    tool_calls.append({
                                        "name": tc["name"],
                                        "args": tc.get("args", {}),
                                    })
                # Tool node output: attach each result to the next pending call
                if "tools" in event:
                    tools_output = event["tools"]
                    if "messages" in tools_output:
                        for msg in tools_output["messages"]:
                            if next_result_idx < len(tool_calls):
                                text = str(msg.content)
                                # Only mark truncation when it actually happened
                                tool_calls[next_result_idx]["result"] = (
                                    text[:200] + "..." if len(text) > 200 else text
                                )
                                next_result_idx += 1
            # The final answer is the last message in the persisted thread state
            final_state = self.graph.get_state(config)
            final_message = final_state.values["messages"][-1]
            response_text = _extract_message_text(final_message)
            logger.info(f"[{self.session_id}] Response generated with {len(tool_calls)} tool calls")
            return {
                "response": response_text,
                "tool_calls": tool_calls,
                "error": False,
            }
        except Exception as e:
            # Top-level boundary: log with traceback and return a friendly
            # message instead of propagating to the caller.
            logger.error(f"Error in agent chat: {e}", exc_info=True)
            return {
                "response": f"I apologize, I encountered an error: {str(e)}",
                "error": True,
            }

    def get_conversation_history(self) -> list:
        """Get conversation history for this session.

        Returns:
            List of ``{"role": "user"|"assistant", "content": str}`` in order;
            system and tool messages are filtered out. Empty list when the
            session has no state or on any error.
        """
        try:
            config = {"configurable": {"thread_id": self.session_id}}
            state = self.graph.get_state(config)
            if not state or not state.values.get("messages"):
                return []
            messages = state.values["messages"]
            result = []
            for msg in messages:
                # Skip system and tool messages — only user/assistant turns
                if isinstance(msg, SystemMessage):
                    continue
                if hasattr(msg, "type") and msg.type in ["system", "tool"]:
                    continue
                role = "user" if msg.type == "human" else "assistant"
                result.append({"role": role, "content": _extract_message_text(msg)})
            return result
        except Exception as e:
            logger.error(f"Error getting history: {e}")
            return []

    def clear_history(self):
        """Clear conversation history for this session.

        NOTE(review): intentionally a no-op beyond logging — the pinned
        MemorySaver exposes no obvious per-thread delete here; confirm
        whether ``delete_thread`` is available in the installed langgraph
        version. Until then, start a fresh conversation by creating a new
        agent with a new session_id.
        """
        logger.info(f"[{self.session_id}] History clear requested")
def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
    """Factory: build a :class:`ShoppingAgent` for the given session.

    Args:
        session_id: Conversation thread identifier; ``None`` lets the
            agent fall back to its default session.
    """
    agent = ShoppingAgent(session_id=session_id)
    return agent