e7f2b240
tangwang
first commit
# OmniShopAgent Project Technical Implementation Report
## 1. Project Overview
OmniShopAgent is an autonomous multimodal fashion shopping agent built on **LangGraph** and the **ReAct pattern**. The system decides on its own which tools to call, maintains conversation state, and determines when to respond, enabling intelligent product discovery and recommendation.
### Core Features
- **Autonomous tool selection and execution**: the agent selects and invokes tools based on user intent
- **Multimodal search**: supports text search + image search
- **Conversation context awareness**: maintains context across multi-turn dialogues
- **Real-time visual analysis**: VLM-based image style analysis
---
## 2. Technology Stack
| Component | Technology |
|------|----------|
| Runtime | Python 3.12 |
| Agent framework | LangGraph 1.x |
| LLM framework | LangChain 1.x (any LLM; gpt-4o-mini by default) |
| Text embeddings | text-embedding-3-small |
| Image embeddings | CLIP ViT-B/32 |
| Vector database | Milvus |
| Frontend | Streamlit |
| Dataset | Kaggle Fashion Products |
## 3. System Architecture
### 3.1 Overall Architecture
```
┌─────────────────────────────────────────────────────────────────┐
│                   Streamlit frontend (app.py)                   │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│               ShoppingAgent (shopping_agent.py)                 │
│  ┌───────────────────────────────────────────────────────────┐  │
│  │  LangGraph StateGraph + ReAct pattern                     │  │
│  │  START → Agent → [Has tool_calls?] → Tools → Agent → END  │  │
│  └───────────────────────────────────────────────────────────┘  │
└─────────────────────────────────────────────────────────────────┘
        │                     │                      │
        ▼                     ▼                      ▼
┌──────────────┐   ┌──────────────────┐   ┌─────────────────────┐
│  search_     │   │  search_by_image │   │ analyze_image_style │
│  products    │   │                  │   │  (OpenAI Vision)    │
└──────┬───────┘   └────────┬─────────┘   └──────────┬──────────┘
       │                    │                        │
       ▼                    ▼                        ▼
┌─────────────────────────────────────────────────────────────────┐
│             EmbeddingService (embedding_service.py)             │
│       OpenAI API (text)      │      CLIP server (images)        │
└─────────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────────┐
│                MilvusService (milvus_service.py)                │
│  text_embeddings collection  │  image_embeddings collection     │
└─────────────────────────────────────────────────────────────────┘
```
### 3.2 Agent Flow (LangGraph)
```mermaid
graph LR
START --> Agent
Agent -->|Has tool_calls| Tools
Agent -->|No tool_calls| END
Tools --> Agent
```
---
## 4. Key Code Implementation
### 4.1 Agent Core (shopping_agent.py)
#### 4.1.1 State Definition
```python
from typing import Optional, Sequence
from typing_extensions import Annotated, TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

class AgentState(TypedDict):
    """State for the shopping agent with message accumulation"""
    messages: Annotated[Sequence[BaseMessage], add_messages]
    current_image_path: Optional[str]  # Track uploaded image
```
- `messages` uses the `add_messages` reducer to accumulate messages across turns, enabling multi-turn dialogue
- `current_image_path` stores the path of the currently uploaded image for tools to use
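The `add_messages` reducer merges each node's returned messages into the accumulated history: new messages are appended, and a message with an already-seen ID replaces the old one. A stdlib-only sketch of that append-or-replace behavior (a mimic for illustration, not LangGraph's actual implementation):

```python
from dataclasses import dataclass, field
import uuid

@dataclass
class Msg:
    content: str
    id: str = field(default_factory=lambda: str(uuid.uuid4()))

def add_messages_sketch(existing: list, new: list) -> list:
    """Append new messages; replace any existing message with a matching id."""
    merged = list(existing)
    index = {m.id: i for i, m in enumerate(merged)}
    for m in new:
        if m.id in index:
            merged[index[m.id]] = m  # same id: replace in place
        else:
            merged.append(m)
    return merged

history = add_messages_sketch([], [Msg("hi", id="1")])
history = add_messages_sketch(history, [Msg("hello!", id="2")])
history = add_messages_sketch(history, [Msg("hi again", id="1")])  # replaces id "1"
```

This is why a node can simply `return {"messages": [response]}`: the reducer, not the node, is responsible for merging into the running history.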
#### 4.1.2 Building the LangGraph Graph
```python
from langgraph.graph import StateGraph, START, END
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import SystemMessage

def _build_graph(self):
    """Build the LangGraph StateGraph"""
    def agent_node(state: AgentState):
        """Agent decision node - decides which tools to call or when to respond"""
        messages = state["messages"]
        if not any(isinstance(m, SystemMessage) for m in messages):
            messages = [SystemMessage(content=system_prompt)] + list(messages)
        response = self.llm_with_tools.invoke(messages)
        return {"messages": [response]}

    tool_node = ToolNode(self.tools)

    def should_continue(state: AgentState):
        """Determine whether the agent should continue with tools or end"""
        last_message = state["messages"][-1]
        if hasattr(last_message, "tool_calls") and last_message.tool_calls:
            return "tools"
        return END

    workflow = StateGraph(AgentState)
    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", tool_node)
    workflow.add_edge(START, "agent")
    workflow.add_conditional_edges("agent", should_continue, ["tools", END])
    workflow.add_edge("tools", "agent")

    checkpointer = MemorySaver()
    return workflow.compile(checkpointer=checkpointer)
```
Key points:
- **agent_node**: feeds the messages to the LLM, which decides whether to call tools
- **should_continue**: routes to the tool node when `tool_calls` are present, otherwise ends
- **MemorySaver**: persists conversation state per `thread_id`
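The per-thread persistence boils down to a store keyed by `thread_id`; a stdlib-only mimic of that idea (not LangGraph's actual `MemorySaver`, which also snapshots full graph state):

```python
from collections import defaultdict

class MemoryCheckpointerSketch:
    """Dictionary keyed by thread_id: each session accumulates its own history."""
    def __init__(self):
        self._store = defaultdict(list)  # thread_id -> accumulated messages

    def update(self, thread_id: str, new_messages: list) -> list:
        self._store[thread_id].extend(new_messages)
        return self._store[thread_id]

cp = MemoryCheckpointerSketch()
cp.update("session-a", ["show me red dresses"])
cp.update("session-b", ["winter coats"])               # independent thread
history = cp.update("session-a", ["make them formal"])  # continues session-a
```

Because the store is keyed by thread, two browser sessions with different `session_id`s never see each other's context, which is exactly what the `{"configurable": {"thread_id": ...}}` config achieves in the real agent.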
#### 4.1.3 System Prompt Design
```python
system_prompt = """You are an intelligent fashion shopping assistant. You can:
1. Search for products by text description (use search_products)
2. Find visually similar products from images (use search_by_image)
3. Analyze image style and attributes (use analyze_image_style)
When a user asks about products:
- For text queries: use search_products directly
- For image uploads: decide if you need to analyze_image_style first, then search
- You can call multiple tools in sequence if needed
- Always provide helpful, friendly responses
CRITICAL FORMATTING RULES:
When presenting product results, you MUST use this EXACT format for EACH product:
1. [Product Name]
ID: [Product ID Number]
Category: [Category]
Color: [Color]
Gender: [Gender]
(Include Season, Usage, Relevance if available)
..."""
```
The system prompt constrains tool usage and output format, ensuring the frontend can parse product information reliably.
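A hypothetical helper makes the mandated shape concrete: each product is rendered as a numbered name line followed by indented `Key: value` lines, which is the contract the frontend parser depends on.

```python
def format_product(idx: int, p: dict) -> str:
    """Hypothetical helper rendering one product in the format the system prompt mandates."""
    lines = [
        f"{idx}. {p['name']}",
        f"   ID: {p['id']}",
        f"   Category: {p['category']}",
        f"   Color: {p['color']}",
        f"   Gender: {p['gender']}",
    ]
    return "\n".join(lines)

block = format_product(1, {"name": "Nike Men Blue T-shirt", "id": 1163,
                           "category": "Apparel", "color": "Blue", "gender": "Men"})
```

The example product values are illustrative; the point is that the format is strict enough to be parsed with simple regexes (see the Streamlit section below).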
#### 4.1.4 Chat Entry Point and Stream Processing
```python
def chat(self, query: str, image_path: Optional[str] = None) -> dict:
    # Build the input message
    message_content = query
    if image_path:
        message_content = f"{query}\n[User uploaded image: {image_path}]"
    config = {"configurable": {"thread_id": self.session_id}}
    input_state = {
        "messages": [HumanMessage(content=message_content)],
        "current_image_path": image_path,
    }
    tool_calls = []
    for event in self.graph.stream(input_state, config=config):
        if "agent" in event:
            for msg in event["agent"].get("messages", []):
                if hasattr(msg, "tool_calls") and msg.tool_calls:
                    for tc in msg.tool_calls:
                        tool_calls.append({"name": tc["name"], "args": tc.get("args", {})})
        if "tools" in event:
            # Record tool execution results
            ...
    final_state = self.graph.get_state(config)
    response_text = final_state.values["messages"][-1].content
    return {"response": response_text, "tool_calls": tool_calls, "error": False}
```
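The tool-call collection loop can be exercised in isolation against mock events; the event dicts below are simplified stand-ins for LangGraph's stream output, with `SimpleNamespace` playing the role of an AI message:

```python
from types import SimpleNamespace

def collect_tool_calls(events) -> list:
    """Gather tool-call records from agent events, mirroring the loop in chat()."""
    tool_calls = []
    for event in events:
        if "agent" in event:
            for msg in event["agent"].get("messages", []):
                if getattr(msg, "tool_calls", None):
                    for tc in msg.tool_calls:
                        tool_calls.append({"name": tc["name"], "args": tc.get("args", {})})
    return tool_calls

events = [
    {"agent": {"messages": [SimpleNamespace(tool_calls=[
        {"name": "search_products", "args": {"query": "red dresses"}}])]}},
    {"tools": {"messages": []}},                          # tool execution event
    {"agent": {"messages": [SimpleNamespace(tool_calls=None)]}},  # final answer, no calls
]
calls = collect_tool_calls(events)
```

Collecting the calls as plain dicts lets the frontend display "which tools ran" without re-walking the message history.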
---
### 4.2 Search Tools (search_tools.py)
#### 4.2.1 Semantic Text Search
```python
@tool
def search_products(query: str, limit: int = 5) -> str:
    """Search for fashion products using natural language descriptions."""
    try:
        embedding_service = get_embedding_service()
        milvus_service = get_milvus_service()
        query_embedding = embedding_service.get_text_embedding(query)
        results = milvus_service.search_similar_text(
            query_embedding=query_embedding,
            limit=min(limit, 20),
            filters=None,
            output_fields=[
                "id", "productDisplayName", "gender", "masterCategory",
                "subCategory", "articleType", "baseColour", "season", "usage",
            ],
        )
        if not results:
            return "No products found matching your search."
        output = f"Found {len(results)} product(s):\n\n"
        for idx, product in enumerate(results, 1):
            output += f"{idx}. {product.get('productDisplayName', 'Unknown Product')}\n"
            output += f"   ID: {product.get('id', 'N/A')}\n"
            output += f"   Category: {product.get('masterCategory')} > {product.get('subCategory')} > {product.get('articleType')}\n"
            output += f"   Color: {product.get('baseColour')}\n"
            output += f"   Gender: {product.get('gender')}\n"
            if "distance" in product:
                similarity = 1 - product["distance"]
                output += f"   Relevance: {similarity:.2%}\n"
            output += "\n"
        return output.strip()
    except Exception as e:
        return f"Error searching products: {str(e)}"
```
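The `Relevance` figure shown to users is derived directly from the vector distance, assuming a metric where smaller distance means more similar:

```python
def relevance_pct(distance: float) -> str:
    """Convert a vector distance into the 'Relevance' percentage shown to users."""
    similarity = 1 - distance
    return f"{similarity:.2%}"
```

Note this presentation only makes sense for distances in [0, 1] (e.g. cosine distance on normalized vectors); with an L2 metric the raw distance would need rescaling first.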
#### 4.2.2 Image Similarity Search
```python
@tool
def search_by_image(image_path: str, limit: int = 5) -> str:
    """Find similar fashion products using an image."""
    if not Path(image_path).exists():
        return f"Error: Image file not found at '{image_path}'"
    embedding_service = get_embedding_service()
    milvus_service = get_milvus_service()
    if not embedding_service.clip_client:
        embedding_service.connect_clip()
    image_embedding = embedding_service.get_image_embedding(image_path)
    results = milvus_service.search_similar_images(
        query_embedding=image_embedding,
        limit=min(limit + 1, 21),  # fetch one extra hit in case the query image itself comes back
        output_fields=[...],
    )
    # Filter out the query image itself (e.g. when the upload is a catalog image)
    query_id = Path(image_path).stem
    filtered_results = [r for r in results if Path(r.get("image_path", "")).stem != query_id]
    filtered_results = filtered_results[:limit]
    # ... format filtered_results into the product list returned to the agent
```
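The stem-based self-filtering can be checked standalone; the mock result dicts below are illustrative stand-ins for Milvus hits:

```python
from pathlib import Path

def drop_query_image(results: list, image_path: str, limit: int) -> list:
    """Remove the hit whose filename stem matches the query image, then truncate."""
    query_id = Path(image_path).stem
    kept = [r for r in results if Path(r.get("image_path", "")).stem != query_id]
    return kept[:limit]

hits = [{"image_path": "data/images/1163.jpg"},   # the query image itself
        {"image_path": "data/images/2152.jpg"},
        {"image_path": "data/images/3941.jpg"}]
top = drop_query_image(hits, "uploads/1163.jpg", limit=2)
```

Matching on `Path(...).stem` rather than the full path is what makes the filter work even though the upload lives in a different directory than the indexed catalog image.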
#### 4.2.3 Visual Analysis (VLM)
```python
@tool
def analyze_image_style(image_path: str) -> str:
    """Analyze a fashion product image using AI vision to extract detailed style information."""
    with open(image_path, "rb") as image_file:
        image_data = base64.b64encode(image_file.read()).decode("utf-8")
    prompt = """Analyze this fashion product image and provide a detailed description.
Include:
- Product type (e.g., shirt, dress, shoes, pants, bag)
- Primary colors
- Style/design (e.g., casual, formal, sporty, vintage, modern)
- Pattern or texture (e.g., plain, striped, checked, floral)
- Key features (e.g., collar type, sleeve length, fit)
- Material appearance (if obvious, e.g., denim, cotton, leather)
- Suitable occasion (e.g., office wear, party, casual, sports)
Provide a comprehensive yet concise description (3-4 sentences)."""
    client = get_openai_client()
    response = client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}", "detail": "high"}},
            ],
        }],
        max_tokens=500,
        temperature=0.3,
    )
    return response.choices[0].message.content.strip()
```
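The image is shipped inline as a base64 data URL rather than a file upload; the encoding step in isolation:

```python
import base64

def to_data_url(raw_bytes: bytes, mime: str = "image/jpeg") -> str:
    """Encode raw image bytes as the data URL embedded in the vision request."""
    b64 = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:{mime};base64,{b64}"

url = to_data_url(b"\xff\xd8\xff")  # first bytes of a JPEG header, just for illustration
```

Inlining avoids hosting the user's upload anywhere public, at the cost of a larger request body (base64 inflates the payload by roughly a third).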
---
### 4.3 向量服务实现
#### 4.3.1 EmbeddingService(embedding_service.py)
```python
class EmbeddingService:
    def get_text_embedding(self, text: str) -> List[float]:
        """OpenAI text-embedding-3-small"""
        response = self.openai_client.embeddings.create(
            input=text, model=self.text_embedding_model
        )
        return response.data[0].embedding

    def get_image_embedding(self, image_path: Union[str, Path]) -> List[float]:
        """CLIP image embedding"""
        if not self.clip_client:
            raise RuntimeError("CLIP client not connected. Call connect_clip() first.")
        result = self.clip_client.encode([str(image_path)])
        if isinstance(result, np.ndarray):
            embedding = result[0].tolist() if len(result.shape) > 1 else result.tolist()
        else:
            embedding = result[0].embedding.tolist()
        return embedding

    def get_text_embeddings_batch(self, texts: List[str], batch_size: int = 100) -> List[List[float]]:
        """Batched text embeddings, used during indexing"""
        all_embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = self.openai_client.embeddings.create(input=batch, ...)
            embeddings = [item.embedding for item in response.data]
            all_embeddings.extend(embeddings)
        return all_embeddings
```
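The batching loop is just fixed-size chunking; extracted as a pure function it is easy to verify:

```python
def chunks(items: list, batch_size: int) -> list:
    """Split a list into consecutive batches, as the batch-embedding loop does."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

batches = chunks(list(range(5)), batch_size=2)
```

The final batch may be shorter than `batch_size`, which the embedding API accepts; keeping batches bounded avoids request-size limits during indexing.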
#### 4.3.2 MilvusService (milvus_service.py)
**Text collection schema:**
```python
schema = MilvusClient.create_schema(auto_id=False, enable_dynamic_field=True)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="text", datatype=DataType.VARCHAR, max_length=2000)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.text_dim) # 1536
schema.add_field(field_name="productDisplayName", datatype=DataType.VARCHAR, max_length=500)
schema.add_field(field_name="gender", datatype=DataType.VARCHAR, max_length=50)
schema.add_field(field_name="masterCategory", datatype=DataType.VARCHAR, max_length=100)
# ... additional metadata fields
```
**Image collection schema:**
```python
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="image_path", datatype=DataType.VARCHAR, max_length=500)
schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=self.image_dim) # 512
# ... product metadata
```
**Similarity search:**
```python
def search_similar_text(self, query_embedding, limit=10, output_fields=None):
    results = self.client.search(
        collection_name=self.text_collection_name,
        data=[query_embedding],
        limit=limit,
        output_fields=output_fields,
    )
    formatted_results = []
    for hit in results[0]:
        result = {"id": hit.get("id"), "distance": hit.get("distance")}
        entity = hit.get("entity", {})
        for field in (output_fields or []):  # guard against output_fields=None
            if field in entity:
                result[field] = entity.get(field)
        formatted_results.append(result)
    return formatted_results
```
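The flattening step can be checked against mock hits; the dict shape below is an assumption standing in for pymilvus search results:

```python
def format_hits(raw_hits: list, output_fields: list) -> list:
    """Flatten hit dicts into plain result dicts, as search_similar_text does."""
    formatted = []
    for hit in raw_hits:
        result = {"id": hit.get("id"), "distance": hit.get("distance")}
        entity = hit.get("entity", {})
        for field in output_fields:
            if field in entity:
                result[field] = entity[field]
        formatted.append(result)
    return formatted

rows = format_hits(
    [{"id": 1163, "distance": 0.12, "entity": {"baseColour": "Blue", "gender": "Men"}}],
    ["baseColour", "gender", "season"],  # "season" is absent from the entity, so it is skipped
)
```

Returning flat dicts keeps the tool layer independent of the vector-database client's result objects.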
---
### 4.4 Data Indexing Script (index_data.py)
#### 4.4.1 Loading Product Data
```python
def _load_products_from_csv(self) -> Dict[int, Dict[str, Any]]:
    products = {}
    # Load the id -> URL mapping from images.csv
    with open(self.images_csv, "r") as f:
        images_dict = {int(row["filename"].split(".")[0]): row["link"] for row in csv.DictReader(f)}
    # Load product metadata from styles.csv
    with open(self.styles_csv, "r") as f:
        for row in csv.DictReader(f):
            product_id = int(row["id"])
            products[product_id] = {
                "id": product_id,
                "gender": row.get("gender", ""),
                "masterCategory": row.get("masterCategory", ""),
                "subCategory": row.get("subCategory", ""),
                "articleType": row.get("articleType", ""),
                "baseColour": row.get("baseColour", ""),
                "season": row.get("season", ""),
                "usage": row.get("usage", ""),
                "productDisplayName": row.get("productDisplayName", ""),
                "imagePath": f"{product_id}.jpg",
            }
    return products
```
#### 4.4.2 Text Indexing
```python
def _create_product_text(self, product: Dict[str, Any]) -> str:
    """Build the text used for the product embedding"""
    parts = [
        product.get("productDisplayName", ""),
        f"Gender: {product.get('gender', '')}",
        f"Category: {product.get('masterCategory', '')} > {product.get('subCategory', '')}",
        f"Type: {product.get('articleType', '')}",
        f"Color: {product.get('baseColour', '')}",
        f"Season: {product.get('season', '')}",
        f"Usage: {product.get('usage', '')}",
    ]
    return " | ".join([p for p in parts if p and p != "Gender: " and p != "Color: "])
```
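The join-and-filter behavior can be verified standalone; this is a simplified copy of the method above (fewer fields, no class), showing how empty attributes are dropped from the embedding text:

```python
def create_product_text(product: dict) -> str:
    """Join product attributes into one embedding string, dropping empty fields."""
    parts = [
        product.get("productDisplayName", ""),
        f"Gender: {product.get('gender', '')}",
        f"Color: {product.get('baseColour', '')}",
    ]
    return " | ".join(p for p in parts if p and p not in ("Gender: ", "Color: "))

text = create_product_text({"productDisplayName": "Blue T-shirt", "gender": "Men", "baseColour": ""})
```

Embedding a labeled, delimited string like this ("Name | Gender: … | Color: …") gives the text encoder explicit attribute context, which helps queries like "blue shirts for men" match.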
#### 4.4.3 Batch Indexing Flow
```python
# Text indexing
texts = [self._create_product_text(p) for p in products]
embeddings = self.embedding_service.get_text_embeddings_batch(texts, batch_size=50)
milvus_data = [{
    "id": product_id,
    "text": text[:2000],
    "embedding": embedding,
    "productDisplayName": product["productDisplayName"][:500],
    "gender": product["gender"][:50],
    # ... other metadata
} for product_id, text, embedding in zip(...)]
self.milvus_service.insert_text_embeddings(milvus_data)

# Image indexing
image_paths = [self.image_dir / p["imagePath"] for p in products]
embeddings = self.embedding_service.get_image_embeddings_batch(image_paths, batch_size=32)
# Inserted into the image_embeddings collection analogously
```
---
### 4.5 Streamlit Frontend (app.py)
#### 4.5.1 Session and Agent Initialization
```python
def initialize_session():
    if "session_id" not in st.session_state:
        st.session_state.session_id = str(uuid.uuid4())
    if "shopping_agent" not in st.session_state:
        st.session_state.shopping_agent = ShoppingAgent(session_id=st.session_state.session_id)
    if "messages" not in st.session_state:
        st.session_state.messages = []
    if "uploaded_image" not in st.session_state:
        st.session_state.uploaded_image = None
```
#### 4.5.2 Parsing Product Information
```python
def extract_products_from_response(response: str) -> list:
    """Parse product entries out of the agent's formatted reply"""
    products = []
    current_product = None
    for line in response.split("\n"):
        if re.match(r"^\*?\*?\d+\.\s+", line):
            if current_product:
                products.append(current_product)
            current_product = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()}
        elif current_product is None:
            continue  # skip any preamble before the first numbered product
        elif "ID:" in line:
            id_match = re.search(r"(?:ID|id):\s*(\d+)", line)
            if id_match:
                current_product["id"] = id_match.group(1)
        elif "Category:" in line:
            cat_match = re.search(r"Category:\s*(.+?)(?:\n|$)", line)
            if cat_match:
                current_product["category"] = cat_match.group(1).strip()
        # ... Color, Gender, Season, Usage, Similarity/Relevance
    if current_product:
        products.append(current_product)  # don't drop the last product
    return products
```
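Fed a reply in the mandated format, the parser yields one dict per product. A self-contained check with a trimmed-down copy of the parsing loop (name and ID only):

```python
import re

def parse_products(response: str) -> list:
    """Minimal copy of the parsing loop for the 'N. Name / ID: ...' reply format."""
    products, current = [], None
    for line in response.split("\n"):
        if re.match(r"^\*?\*?\d+\.\s+", line):
            if current:
                products.append(current)
            current = {"name": re.sub(r"^\*?\*?\d+\.\s+", "", line).replace("**", "").strip()}
        elif current is not None and "ID:" in line:
            m = re.search(r"(?:ID|id):\s*(\d+)", line)
            if m:
                current["id"] = m.group(1)
    if current:
        products.append(current)
    return products

reply = "1. Nike Men Blue T-shirt\n   ID: 1163\n\n2. Puma Women Red Dress\n   ID: 2152"
items = parse_products(reply)
```

The product names in the sample reply are illustrative. This is also why the system prompt's "CRITICAL FORMATTING RULES" matter: a free-form reply would silently yield an empty product list.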
#### 4.5.3 Image References in Multi-turn Dialogue
```python
# When the user types "make them formal" and an earlier message carried an image, reuse that image
if any(ref in query_lower for ref in ["this", "that", "the image", "it"]):
    for msg in reversed(st.session_state.messages):
        if msg.get("role") == "user" and msg.get("image_path"):
            image_path = msg["image_path"]
            break
```
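Wrapped as a function over a mock message history, the reference-resolution logic is easy to test (the history dicts mirror the `st.session_state.messages` entries used above):

```python
def resolve_referenced_image(query: str, messages: list):
    """Return the most recent user-uploaded image path when the query refers back to it."""
    query_lower = query.lower()
    if any(ref in query_lower for ref in ["this", "that", "the image", "it"]):
        for msg in reversed(messages):
            if msg.get("role") == "user" and msg.get("image_path"):
                return msg["image_path"]
    return None

history = [
    {"role": "user", "image_path": "uploads/jacket.jpg"},
    {"role": "assistant"},
    {"role": "user"},  # text-only follow-up
]
path = resolve_referenced_image("make this formal", history)
```

Note the substring check is a heuristic: any query containing "this"/"that"/"it" as a substring triggers the lookup, which trades precision for simplicity.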
---
### 4.6 Configuration Management (config.py)
```python
class Settings(BaseSettings):
    openai_api_key: str
    openai_model: str = "gpt-4o-mini"
    openai_embedding_model: str = "text-embedding-3-small"
    clip_server_url: str = "grpc://localhost:51000"
    milvus_uri: str = "http://localhost:19530"
    text_collection_name: str = "text_embeddings"
    image_collection_name: str = "image_embeddings"
    text_dim: int = 1536
    image_dim: int = 512

    @property
    def milvus_uri_absolute(self) -> str:
        """Supports both Milvus Standalone and Milvus Lite"""
        if self.milvus_uri.startswith(("http://", "https://")):
            return self.milvus_uri
        if self.milvus_uri.startswith("./"):
            return os.path.join(base_dir, self.milvus_uri[2:])
        return self.milvus_uri

    class Config:
        env_file = ".env"
```
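Since `pydantic` `BaseSettings` maps each field to an upper-cased environment variable, a `.env` file for local development might look like this (hypothetical values; only `OPENAI_API_KEY` has no default and is required):

```shell
# .env — hypothetical local values matching the Settings fields above
OPENAI_API_KEY=sk-your-key-here
OPENAI_MODEL=gpt-4o-mini
CLIP_SERVER_URL=grpc://localhost:51000
MILVUS_URI=http://localhost:19530
```

Fields left out fall back to the defaults declared on the `Settings` class.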
---
## 5. Deployment and Running
### 5.1 Dependent Services
```yaml
# provided by docker-compose.yml
- etcd: metadata store
- minio: object storage
- milvus-standalone: vector database
- attu: Milvus admin UI
```
### 5.2 Startup
```bash
# 1. Environment
pip install -r requirements.txt
cp .env.example .env  # set OPENAI_API_KEY
# 2. Download the dataset
python scripts/download_dataset.py  # Kaggle Fashion Product Images Dataset
# 3. Start the CLIP server (runs as a separate process)
python -m clip_server
# 4. Start Milvus (detached, so the shell stays free for the next steps)
docker-compose up -d
# 5. Index the data
python scripts/index_data.py
# 6. Launch the app
streamlit run app.py
```
---
## 6. Typical Interaction Flows
| Scenario | User Input | Agent Behavior | Tool Calls |
|------|----------|------------|----------|
| Text search | "winter coats for women" | direct text search | `search_products("winter coats women")` |
| Image search | [upload image] "find similar" | image similarity search | `search_by_image(path)` |
| Style analysis + search | [upload vintage jacket] "what style? find matching pants" | analyze the style first, then search | `analyze_image_style(path)` → `search_products("vintage pants casual")` |
| Multi-turn context | [turn 1] "show me red dresses"<br>[turn 2] "make them formal" | combine with prior context | `search_products("red formal dresses")` |
---
## 7. Design Highlights
1. **ReAct pattern**: the agent autonomously decides when to call tools, which tools to call, and whether to keep calling.
2. **LangGraph state graph**: `START → Agent → [condition] → Tools → Agent → END`, supporting multiple rounds of tool calls.
3. **Multimodality**: text + image + VLM analysis, covering text search, image-to-image search, and style understanding.
4. **Dual vector collections**: text_embeddings and image_embeddings are stored separately in Milvus, serving retrieval for each modality.
5. **Session persistence**: `MemorySaver` + `thread_id` give the agent multi-turn conversational memory.
6. **Format constraints**: the system prompt strictly fixes the product output format so the frontend can parse and render it.
---
## 8. Appendix: Project Structure
```
OmniShopAgent/
├── app/
│ ├── agents/
│ │ └── shopping_agent.py
│ ├── config.py
│ ├── services/
│ │ ├── embedding_service.py
│ │ └── milvus_service.py
│ └── tools/
│ └── search_tools.py
├── scripts/
│ ├── download_dataset.py
│ └── index_data.py
├── app.py
├── docker-compose.yml
└── requirements.txt
```