Blame view

app/agents/shopping_agent.py 15.2 KB
e7f2b240   tangwang   first commit
1
2
  """
  Conversational Shopping Agent with LangGraph
66442668   tangwang   feat: 搜索结果引用与并行搜索...
3
4
5
6
7
8
  
  Architecture:
  - ReAct-style agent: plan → search → evaluate → re-plan or respond
  - search_products is session-bound, writing curated results to SearchResultRegistry
  - Final AI message references results via [SEARCH_REF:xxx] tokens instead of
    re-listing product details; the UI renders product cards from the registry
e7f2b240   tangwang   first commit
9
10
  """
  
825828c4   tangwang   fix: search image...
11
  import json
e7f2b240   tangwang   first commit
12
13
  import logging
  from pathlib import Path
825828c4   tangwang   fix: search image...
14
  from typing import Any, Optional, Sequence
e7f2b240   tangwang   first commit
15
16
17
18
19
20
21
22
23
24
  
  from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
  from langchain_openai import ChatOpenAI
  from langgraph.checkpoint.memory import MemorySaver
  from langgraph.graph import END, START, StateGraph
  from langgraph.graph.message import add_messages
  from langgraph.prebuilt import ToolNode
  from typing_extensions import Annotated, TypedDict
  
  from app.config import settings
66442668   tangwang   feat: 搜索结果引用与并行搜索...
25
  from app.search_registry import global_registry
e7f2b240   tangwang   first commit
26
27
28
29
  from app.tools.search_tools import get_all_tools
  
  logger = logging.getLogger(__name__)

  # ── System prompt ──────────────────────────────────────────────────────────────
  # Written for an apparel / fashion e-commerce shopping guide: the role definition
  # and the examples (outfit advice, styling trends) are fashion-specific, so other
  # verticals would need their own role section.
  # Key design decisions:
  #   1. Guides multi-query search planning with explicit evaluate-and-decide loop
  #   2. Forbids re-listing product details in the final response
  #   3. Mandates [SEARCH_REF:xxx] inline citation as the only product presentation mechanism
  SYSTEM_PROMPT = """角色定义
  你是一名专业的服装电商导购,是一个善于倾听、主动引导、懂得搭配的“时尚顾问”,通过有温度的对话,给用户提供有价值的信息,包括需求引导、方案推荐、搜索结果推荐,最终促成满意的购物决策或转化行为。

  一些原则:
  1. 你是一个真人导购,是一个贴心、专业的销售,保持灵活,根据上下文,基于常识灵活的切换策略,在合适的上下文询问合适的问题、给出有价值的方案和搜索结果的呈现。
  2. 商品搜索结果推荐与信息收集:
    1. 根据上下文、用户诉求,灵活的切换侧重点,何时需要进行搜索、何时要引导客户完善需求,你需要站在用户角度进行思考。比如已经有较为清晰的意图,则以搜索、方案推荐为主,有必要的时候,思考该方向下重要的决策因素,进行提议和问题收集,让用户既得到相关信息、又得到下一步的方向引导、同时也有机会修正或者细化诉求。如果存在重大的需求方向缺口,主动通过1-2个关键问题进行引导,并提供初步方向。
    2. 适时的提供有价值的信息,如商品推荐、穿搭建议、趋势信息,在推荐方向上有需求缺口、需要明确的重要信息时,要适时的做“信息收集”,引导式的帮助用户更清晰的呈现需求、提高商品发现的效率,形成“提供-反馈”的良性循环。
    3. 对于复杂需求时,要能基于上下文,将导购任务进行合理拆解。
  3. 引导或者收集需求时,需要站在用户立场,比如询问用户期待的效果或感觉、使用的场合、偏好的风格等用户立场需求,而不是询问具体的款式或参数,你需要将用户立场的需求理解/翻译/转化为具体的搜索计划,最后筛选产品、结合需求+结果特性组织推荐理由、呈现方案。
  4. 如何使用search_products:在需要搜索商品的时候,可以将需求分解为 2-4 个搜索查询,每个 query 聚焦一个明确的商品子类或搜索角度。每次调用 search_products 后,工具会返回以下内容,你需要决策是否要调整搜索策略,比如结果质量太差,可能需要调整搜索词、或者加大试探的query数量(不要超过3-5个)。可以进行多轮搜索,但是要适时的总结和反馈信息避免用户等待过长时间:
    - 各层级数量:完美匹配 / 部分匹配 / 不相关 的条数
    - 整体质量判断:优质 / 一般 / 较差
    - 简短质量说明
    - 结果引用标识:[SEARCH_REF:xxx]
  5. 撰写最终回复的时候,使用 [SEARCH_REF:xxx] 内联引用
    1. 用自然流畅的语言组织回复,将 [SEARCH_REF:xxx] 嵌入叙述中
    2. 系统会自动在 [SEARCH_REF:xxx] 位置渲染对应的商品卡片列表
    3. 禁止在回复文本中列出商品名称、ID、价格、分类、规格等字段
    4. 禁止用编号列表逐条复述搜索结果中的商品
  """
  
  
  # ── Agent state ────────────────────────────────────────────────────────────────
  
  class AgentState(TypedDict):
      """LangGraph state schema shared by the ``agent`` and ``tools`` nodes."""

      # Conversation history. The ``add_messages`` reducer merges the messages a
      # node returns into this list instead of replacing the whole list.
      messages: Annotated[Sequence[BaseMessage], add_messages]
      # Path of the image uploaded with the current turn, or None.
      # NOTE(review): no node visible in this file reads this field — confirm a consumer exists.
      current_image_path: Optional[str]
  
  
  # ── Helper ─────────────────────────────────────────────────────────────────────
e7f2b240   tangwang   first commit
67
  
825828c4   tangwang   fix: search image...
68
69
70
71
72
  # Character caps applied before message content / tool results are written to
  # the log, so one oversized payload cannot flood the log output.
  _LOG_CONTENT_MAX = 8_000
  _LOG_TOOL_RESULT_MAX = 4_000
  
  
e7f2b240   tangwang   first commit
73
  def _extract_message_text(msg) -> str:
      """Return the plain-text content of a LangChain message.

      ``msg.content`` may be a plain string or a list of content blocks. A dict
      block contributes its ``"text"`` value (falling back to ``"content"``); any
      other block is stringified. Missing or falsy non-list content yields ``""``.
      """
      raw = getattr(msg, "content", "")
      if isinstance(raw, str):
          return raw
      if not isinstance(raw, list):
          return str(raw) if raw else ""

      def _block_text(block) -> str:
          # Dict blocks (e.g. {"type": "text", "text": ...}) carry their text under a key.
          if isinstance(block, dict):
              return str(block.get("text") or block.get("content") or "")
          return str(block)

      return "".join(_block_text(b) for b in raw)
  
  
825828c4   tangwang   fix: search image...
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
  def _message_for_log(msg: BaseMessage) -> dict:
      """Build a JSON-friendly summary of ``msg`` for structured logging.

      The text content is capped at ``_LOG_CONTENT_MAX`` characters (a marker
      records the original length). Requested tool calls, if any, are listed by
      name and args.
      """
      body = _extract_message_text(msg)
      full_len = len(body)
      if full_len > _LOG_CONTENT_MAX:
          body = f"{body[:_LOG_CONTENT_MAX]}... [truncated, total {full_len} chars]"

      record: dict[str, Any] = {"type": getattr(msg, "type", "unknown"), "content": body}

      requested = getattr(msg, "tool_calls", None)
      if requested:
          record["tool_calls"] = [
              {"name": call.get("name"), "args": call.get("args", {})} for call in requested
          ]
      return record
  
  
66442668   tangwang   feat: 搜索结果引用与并行搜索...
106
  # ── Agent class ────────────────────────────────────────────────────────────────
e7f2b240   tangwang   first commit
107
  
e7f2b240   tangwang   first commit
108
  class ShoppingAgent:
      """ReAct shopping agent with search-evaluate-decide loop and registry-based result referencing.

      Each instance serves one conversation. ``session_id`` is used as the LangGraph
      ``thread_id`` for the in-memory checkpointer. It is also passed to the tools,
      so search results are written to that session's registry partition.
      """

      def __init__(self, session_id: Optional[str] = None):
          """Create the LLM client, session-bound tools, and compiled graph.

          Args:
              session_id: Conversation identifier; falls back to ``"default"`` when falsy.
          """
          self.session_id = session_id or "default"

          llm_kwargs = dict(
              model=settings.openai_model,
              temperature=settings.openai_temperature,
              api_key=settings.openai_api_key,
          )
          # Only override the endpoint when a base URL is configured.
          if settings.openai_api_base_url:
              llm_kwargs["base_url"] = settings.openai_api_base_url

          self.llm = ChatOpenAI(**llm_kwargs)

          # Tools are session-bound so search_products writes to the right registry partition
          self.tools = get_all_tools(session_id=self.session_id, registry=global_registry)
          self.llm_with_tools = self.llm.bind_tools(self.tools)

          self.graph = self._build_graph()
          logger.info(f"ShoppingAgent ready — session={self.session_id}")

      def _build_graph(self):
          """Compile the agent <-> tools loop with an in-memory checkpointer.

          Control flow: START -> agent; agent -> tools while the last AI message
          requests tool calls, otherwise END; tools -> agent.
          """

          def agent_node(state: AgentState):
              """Invoke the tool-bound LLM on the conversation and return its reply."""
              messages = state["messages"]
              # The system prompt is prepended for each call but not returned, so it
              # is never written into the checkpointed history.
              if not any(isinstance(m, SystemMessage) for m in messages):
                  messages = [SystemMessage(content=SYSTEM_PROMPT)] + list(messages)
              request_log = [_message_for_log(m) for m in messages]
              req_json = json.dumps(request_log, ensure_ascii=False)
              if len(req_json) > _LOG_CONTENT_MAX:
                  req_json = req_json[:_LOG_CONTENT_MAX] + f"... [truncated total {len(req_json)}]"
              logger.info("[%s] LLM_REQUEST messages=%s", self.session_id, req_json)
              response = self.llm_with_tools.invoke(messages)
              response_log = _message_for_log(response)
              logger.info(
                  "[%s] LLM_RESPONSE %s",
                  self.session_id,
                  json.dumps(response_log, ensure_ascii=False),
              )
              return {"messages": [response]}

          def should_continue(state: AgentState):
              """Route to the tools node while the last message requests tool calls."""
              last = state["messages"][-1]
              if hasattr(last, "tool_calls") and last.tool_calls:
                  return "tools"
              return END

          tool_node = ToolNode(self.tools)

          workflow = StateGraph(AgentState)
          workflow.add_node("agent", agent_node)
          workflow.add_node("tools", tool_node)
          workflow.add_edge(START, "agent")
          workflow.add_conditional_edges("agent", should_continue, ["tools", END])
          workflow.add_edge("tools", "agent")

          return workflow.compile(checkpointer=MemorySaver())

      @staticmethod
      def _error_result(message: str) -> dict:
          """Build an error payload with the same keys as a successful ``chat`` result."""
          return {
              "response": message,
              "tool_calls": [],
              "debug_steps": [],
              "search_refs": {},
              "error": True,
          }

      def _record_agent_step(
          self, agent_out: dict, tool_calls: list[dict], pending_by_id: dict[str, dict]
      ) -> dict:
          """Summarize one agent-node update and register the tool calls it requested.

          Each requested call is appended to ``tool_calls``. Calls that carry an id
          are also indexed in ``pending_by_id``, so their results can be attached later.
          """
          step_msgs: list[dict] = []
          step_tcs: list[dict] = []
          for msg in agent_out.get("messages", []):
              text = _extract_message_text(msg)
              step_msgs.append({
                  "type": getattr(msg, "type", "assistant"),
                  "content": text[:500],
              })
              for tc in getattr(msg, "tool_calls", None) or []:
                  entry = {"name": tc.get("name"), "args": tc.get("args", {})}
                  tool_calls.append(entry)
                  step_tcs.append(entry)
                  tc_id = tc.get("id")
                  if tc_id:
                      pending_by_id[tc_id] = entry
          return {"node": "agent", "messages": step_msgs, "tool_calls": step_tcs}

      def _record_tools_step(
          self, tools_out: dict, tool_calls: list[dict], pending_by_id: dict[str, dict]
      ) -> dict:
          """Attach tool results to the calls that produced them and summarize the update.

          Results are matched by ``tool_call_id``. A message whose id is missing or
          unknown falls back to the earliest call of this turn that has no result yet.
          """
          step_results: list[dict] = []
          for msg in tools_out.get("messages", []):
              text = _extract_message_text(msg)
              preview = text[:600] + ("…" if len(text) > 600 else "")
              entry = pending_by_id.pop(getattr(msg, "tool_call_id", None), None)
              if entry is None:
                  entry = next((tc for tc in tool_calls if "result" not in tc), None)
              if entry is not None:
                  entry["result"] = preview
                  result_log = (
                      text
                      if len(text) <= _LOG_TOOL_RESULT_MAX
                      else text[:_LOG_TOOL_RESULT_MAX] + f"... [truncated total {len(text)}]"
                  )
                  logger.info(
                      "[%s] TOOL_CALL_RESULT name=%s args=%s result=%s",
                      self.session_id,
                      entry.get("name", ""),
                      json.dumps(entry.get("args", {}), ensure_ascii=False),
                      result_log,
                  )
              step_results.append({"content": preview})
          return {"node": "tools", "results": step_results}

      def chat(self, query: str, image_path: Optional[str] = None) -> dict:
          """
          Process a user query and return the agent response with metadata.

          Args:
              query: The user's message text.
              image_path: Optional path to an uploaded image. It is mentioned in the
                  human message text and stored as ``current_image_path`` in state.
                  A path that does not exist returns an error result without
                  running the graph.

          Returns:
              dict with keys:
                response       → final AI message text (may contain [SEARCH_REF:xxx] tokens)
                tool_calls     → list of {name, args, result}; ``result`` is a truncated
                                 preview, present once the tool output has been seen
                debug_steps    → detailed per-node step log
                search_refs    → dict[ref_id → SearchResult] for all searches this turn
                error          → bool

              Every return path (success, missing image, exception) carries all five
              keys, so callers can index them unconditionally.
          """
          try:
              logger.info(f"[{self.session_id}] chat: {query!r} image={bool(image_path)}")

              if image_path and not Path(image_path).exists():
                  return self._error_result(f"错误:图片文件不存在:{image_path}")

              # Snapshot registry before the turn so we can report new additions
              registry_before = set(global_registry.get_all(self.session_id).keys())

              message_content = query
              if image_path:
                  message_content = f"{query}\n[用户上传了图片:{image_path}]"

              config = {"configurable": {"thread_id": self.session_id}}
              input_state = {
                  "messages": [HumanMessage(content=message_content)],
                  "current_image_path": image_path,
              }

              tool_calls: list[dict] = []
              debug_steps: list[dict] = []
              # tool_call_id -> entry in tool_calls (this turn only), for result matching
              pending_by_id: dict[str, dict] = {}

              for event in self.graph.stream(input_state, config=config):
                  logger.debug(f"[{self.session_id}] event keys: {list(event.keys())}")

                  if "agent" in event:
                      debug_steps.append(
                          self._record_agent_step(event["agent"], tool_calls, pending_by_id)
                      )

                  if "tools" in event:
                      debug_steps.append(
                          self._record_tools_step(event["tools"], tool_calls, pending_by_id)
                      )

              final_state = self.graph.get_state(config)
              final_msg = final_state.values["messages"][-1]
              response_text = _extract_message_text(final_msg)

              # Collect new SearchResults added during this turn
              registry_after = global_registry.get_all(self.session_id)
              new_refs = {
                  ref_id: result
                  for ref_id, result in registry_after.items()
                  if ref_id not in registry_before
              }

              logger.info(
                  f"[{self.session_id}] done — tool_calls={len(tool_calls)}, new_refs={list(new_refs.keys())}"
              )

              return {
                  "response": response_text,
                  "tool_calls": tool_calls,
                  "debug_steps": debug_steps,
                  "search_refs": new_refs,
                  "error": False,
              }

          except Exception as e:
              logger.error(f"[{self.session_id}] chat error: {e}", exc_info=True)
              return self._error_result(f"抱歉,处理您的请求时遇到错误:{e}")

      def get_conversation_history(self) -> list:
          """Return the checkpointed conversation as ``{"role", "content"}`` dicts.

          System and tool messages are skipped. Human messages map to ``"user"``;
          every other message maps to ``"assistant"``. Returns ``[]`` when there is
          no state or reading it fails.
          """
          try:
              config = {"configurable": {"thread_id": self.session_id}}
              state = self.graph.get_state(config)
              if not state or not state.values.get("messages"):
                  return []

              result = []
              for msg in state.values["messages"]:
                  if isinstance(msg, SystemMessage):
                      continue
                  if getattr(msg, "type", None) in ("system", "tool"):
                      continue
                  # NOTE(review): AI messages that only request tool calls have empty
                  # text and are still emitted as empty "assistant" entries.
                  role = "user" if msg.type == "human" else "assistant"
                  result.append({"role": role, "content": _extract_message_text(msg)})
              return result
          except Exception as e:
              logger.error(
                  f"[{self.session_id}] get_conversation_history error: {e}", exc_info=True
              )
              return []

      def clear_history(self):
          """Log a clear request; the conversation state is NOT actually cleared.

          The checkpointer still holds this session's thread, so a new ``session_id``
          is required to start a fresh conversation.
          """
          logger.info(f"[{self.session_id}] clear requested (use new session_id to fully reset)")
e7f2b240   tangwang   first commit
308
309
310
  
  
  def create_shopping_agent(session_id: Optional[str] = None) -> ShoppingAgent:
      """Construct a ShoppingAgent bound to ``session_id`` (falsy ids use the "default" session)."""
      agent = ShoppingAgent(session_id=session_id)
      return agent