b401ef94
tangwang
third-party/xinfe...
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
|
#!/usr/bin/env python3
"""
电商搜索实战示例
演示如何使用 Qwen3-Embedding 和 Qwen3-Reranker 构建两阶段搜索系统
"""
import time
from typing import List, Tuple
from xinference_client import RESTfulClient as Client
class EcommerceSearchEngine:
"""电商搜索引擎"""
def __init__(self, host="http://localhost:9997"):
"""
初始化搜索引擎
Args:
host: Xinference 服务地址
"""
print("🔗 连接到 Xinference 服务...")
self.client = Client(host)
self.embedding_model = None
self.reranker_model = None
print("✅ 连接成功!\n")
def load_models(self, embedding_uid=None, reranker_uid=None):
"""
加载模型
Args:
embedding_uid: Embedding 模型 UID
reranker_uid: Reranker 模型 UID
"""
# 列出所有模型
models = self.client.list_models()
model_dict = {m.get("model_type"): m.get("model_uid") for m in models}
# 使用提供的 UID 或自动查找
self.embedding_uid = embedding_uid or model_dict.get("embedding")
self.reranker_uid = reranker_uid or model_dict.get("rerank")
if not self.embedding_uid:
raise ValueError("❌ 未找到 Embedding 模型,请先运行: python deploy_models.py")
if not self.reranker_uid:
raise ValueError("❌ 未找到 Reranker 模型,请先运行: python deploy_models.py")
print(f"📦 加载 Embedding 模型: {self.embedding_uid}")
self.embedding_model = self.client.get_model(self.embedding_uid)
print("✅ Embedding 模型加载完成\n")
print(f"📦 加载 Reranker 模型: {self.reranker_uid}")
self.reranker_model = self.client.get_model(self.reranker_uid)
print("✅ Reranker 模型加载完成\n")
def dense_retrieval(self, query: str, candidates: List[str], top_k: int = 200) -> List[Tuple[str, float]]:
"""
密集检索阶段(第一阶:粗筛)
在实际生产环境中,这里会使用 Faiss 或向量数据库进行 ANN 搜索
从百万级商品中快速召回 Top-K 候选
Args:
query: 用户查询
candidates: 候选商品列表
top_k: 返回的数量
Returns:
[(商品, 相似度分数), ...]
"""
start_time = time.time()
# 生成 query 向量
query_embedding = self.embedding_model.create_embedding(query)["data"][0]["embedding"]
# 为所有候选商品生成向量
# 注意:生产环境中这些向量应该预计算并存储在向量数据库中
candidate_embeddings = []
for product in candidates:
emb = self.embedding_model.create_embedding(product)["data"][0]["embedding"]
candidate_embeddings.append((product, emb))
# 计算余弦相似度(简化版,生产环境使用 Faiss)
import numpy as np
query_vec = np.array(query_embedding)
query_vec = query_vec / np.linalg.norm(query_vec) # 归一化
similarities = []
for product, emb in candidate_embeddings:
emb_vec = np.array(emb)
emb_vec = emb_vec / np.linalg.norm(emb_vec)
similarity = float(np.dot(query_vec, emb_vec))
similarities.append((product, similarity))
# 按 similarity 排序,取 Top-K
similarities.sort(key=lambda x: x[1], reverse=True)
top_results = similarities[:top_k]
elapsed = time.time() - start_time
print(f"⏱️ 密集检索耗时: {elapsed:.2f}秒")
return top_results
def cross_encoder_rerank(self, query: str, candidates: List[str]) -> List[Tuple[str, float]]:
"""
精排阶段(第二阶:细排)
使用 Cross-Encoder 对密集检索的结果进行精确打分
Args:
query: 用户查询
candidates: 候选商品列表
Returns:
[(商品, 相关性分数), ...]
"""
start_time = time.time()
# 构建 query-document 对
pairs = [(query, product) for product in candidates]
# 批量打分
rerank_results = self.reranker_model.rerank(pairs)
# 组合结果
results = list(zip(candidates, rerank_results))
# 按相关性分数排序
results.sort(key=lambda x: x[1]["relevance_score"], reverse=True)
elapsed = time.time() - start_time
print(f"⏱️ 精排耗时: {elapsed:.2f}秒")
return results
def search(self, query: str, product_catalog: List[str], top_k: int = 10) -> List[Tuple[str, float]]:
"""
完整的两阶段搜索流程
Args:
query: 用户查询
product_catalog: 商品目录(假设有数万到数百万商品)
top_k: 最终返回的结果数
Returns:
[(商品, 相关性分数), ...]
"""
print(f"\n{'='*70}")
print(f"🔍 搜索查询: {query}")
print(f"{'='*70}\n")
# 阶段1:密集检索召回 Top-200
print("📊 阶段1: 密集检索(召回 Top-200)")
print("-" * 70)
recall_top_k = min(200, len(product_catalog))
retrieved = self.dense_retrieval(query, product_catalog, top_k=recall_top_k)
retrieved_products = [p for p, s in retrieved]
print(f"✅ 召回 {len(retrieved)} 个候选商品\n")
# 阶段2:Cross-Encoder 精排
print("🎯 阶段2: 精排(Cross-Encoder 打分)")
print("-" * 70)
reranked = self.cross_encoder_rerank(query, retrieved_products)
# 取最终 Top-K
final_results = reranked[:top_k]
return final_results
def demo_ecommerce_search():
"""电商搜索演示"""
print("\n" + "="*70)
print(" 🛒 电商搜索实战演示 - Qwen3 双塔架构")
print("="*70 + "\n")
# 初始化搜索引擎
engine = EcommerceSearchEngine(host="http://localhost:9997")
# 加载模型
print("⏳ 加载模型...")
engine.load_models()
# 模拟商品数据库(实际应用中可能有数百万商品)
product_catalog = [
"红米Note12 5000mAh大电量 6.67英寸大屏 老人模式",
"iPhone 15 Pro Max 专业摄影旗舰 A17芯片",
"华为畅享60 6000mAh超长续航 护眼大屏 鸿蒙系统",
"OPPO A1 5000mAh电池 简易模式适合长辈",
"小米手环8 智能运动监测 血氧心率",
"vivo Y78 5000mAh大电池 120Hz高刷屏",
"三星Galaxy A54 5000mAh 防水防尘",
"荣耀Play7T 6000mAh巨量电池 双卡双待",
"真我11 Pro 2亿像素 100W快充",
"诺基亚C31 5050mAh电池 耐用三防",
"联想拯救者Y70 8GB+256GB 骁龙8+",
"摩托罗拉edge S30 骁龙888+ 144Hz",
"一加Ace 2V 天玑9000 80W快充",
"iQOO Neo8 独显芯片双芯 120W闪充",
"Redmi K60 2K高光屏 骁龙8+",
"华为Mate 60 Pro 卫星通信 鸿蒙4.0",
"iPhone 14 128GB A15芯片",
"OPPO Reno10 Pro 人像镜头 100W",
"vivo X90s 天玑9200+ 蔡司影像",
"小米13 Pro 徕卡光学镜头",
]
# 测试查询
test_queries = [
"适合老人用的智能手机大屏幕长续航",
"拍照效果好的手机推荐",
"性价比高的游戏手机",
]
# 执行搜索
for query in test_queries:
results = engine.search(query, product_catalog, top_k=5)
# 显示结果
print(f"\n🎯 搜索结果 (Top 5):")
print("-" * 70)
for i, (product, score) in enumerate(results, 1):
print(f"{i}. [{score['relevance_score']:.4f}] {product}")
print()
# 性能统计
print("\n" + "="*70)
print("📊 生产环境部署建议")
print("="*70)
print("""
1. 离线批量处理:
- 每天凌晨使用 Qwen3-Embedding 为全量商品生成向量
- 存储到 Milvus/Pinecone 等向量数据库
- 预计耗时: 2亿商品约 4-6 小时
2. 在线实时搜索:
- 用户 query 实时生成 embedding
- 向量数据库 ANN 检索召回 Top-1000 (耗时 < 50ms)
- Qwen3-Reranker 精排 Top-1000 → Top-50 (耗时 < 200ms)
- 总体延迟: < 300ms
3. 缓存优化:
- Top 10000 热搜 query 的 embedding 和结果缓存到 Redis
- QPS 提升 10-20 倍
4. 混合检索:
- 结合 BM25 关键词召回(头部 Query 准确率更高)
- 向量召回 + 关键词召回 → 合并去重 → 精排
5. A/B 测试建议:
- 对照组: 纯 BM25 或传统 embedding 模型
- 实验组: Qwen3-Embedding + Qwen3-Reranker
- 核心指标: CTR, CVR, GMV, 用户停留时间
""")
def demo_simple_usage():
"""简单的使用示例"""
print("\n" + "="*70)
print(" 📝 快速使用示例")
print("="*70 + "\n")
# 连接到服务
client = Client("http://localhost:9997")
# 列出可用模型
models = client.list_models()
print("可用模型:")
for model in models:
print(f" - {model.get('model_type')}: {model.get('model_uid')}")
# 假设模型已部署
if models:
embedding_model = next((m for m in models if m.get("model_type") == "embedding"), None)
reranker_model = next((m for m in models if m.get("model_type") == "rerank"), None)
if embedding_model:
print(f"\n使用 Embedding 模型: {embedding_model['model_uid']}")
model = client.get_model(embedding_model['model_uid'])
result = model.create_embedding("测试文本")
print(f"向量维度: {len(result['data'][0]['embedding'])}")
if reranker_model:
print(f"\n使用 Reranker 模型: {reranker_model['model_uid']}")
model = client.get_model(reranker_model['model_uid'])
query = "适合老人用的智能手机"
docs = ["华为畅享60 6000mAh", "小米手环8"]
result = model.rerank([(query, d) for d in docs])
for doc, score in zip(docs, result):
print(f" [{score['relevance_score']:.4f}] {doc}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="电商搜索实战演示")
parser.add_argument("--host", default="http://localhost:9997", help="Xinference 服务地址")
parser.add_argument("--simple", action="store_true", help="运行简单示例")
parser.add_argument("--embedding", help="指定 Embedding 模型 UID")
parser.add_argument("--reranker", help="指定 Reranker 模型 UID")
args = parser.parse_args()
try:
if args.simple:
demo_simple_usage()
else:
demo_ecommerce_search()
except Exception as e:
print(f"\n❌ 错误: {e}")
print("\n💡 请确保:")
print(" 1. Xinference 服务正在运行: ./start.sh")
print(" 2. 模型已部署: python deploy_models.py")
import sys
sys.exit(1)
|