Commit 775168412f4c1d83386da3d642ff77bd74ca645a

Authored by tangwang
1 parent 3d588bef

tidy embeddings

docs/DEVELOPER_GUIDE.md
@@ -315,7 +315,7 @@ services:
315 315
316 **重排后端协议(服务内)**:所有在 reranker 服务内加载的后端须实现 `score_with_meta(query, docs, normalize=True) -> (scores: List[float], meta: dict)`。返回的 `scores[i]` 与 `docs[i]` 一一对应;meta 至少含 `input_docs`、`usable_docs`、`elapsed_ms` 等。对外 HTTP 契约固定:`POST /rerank` 请求体 `{ "query": str, "docs": [str] }`,响应体 `{ "scores": [float], "meta": object }`;`GET /health` 返回 `status`、`model`、`backend` 等。 316 **重排后端协议(服务内)**:所有在 reranker 服务内加载的后端须实现 `score_with_meta(query, docs, normalize=True) -> (scores: List[float], meta: dict)`。返回的 `scores[i]` 与 `docs[i]` 一一对应;meta 至少含 `input_docs`、`usable_docs`、`elapsed_ms` 等。对外 HTTP 契约固定:`POST /rerank` 请求体 `{ "query": str, "docs": [str] }`,响应体 `{ "scores": [float], "meta": object }`;`GET /health` 返回 `status`、`model`、`backend` 等。
317 317
318 -**向量化后端协议(服务内)**:文本后端需支持 `encode_batch(texts, batch_size, device) -> List[ndarray]`,与 texts 一一对应;图片后端实现 `embeddings/protocols.ImageEncoderProtocol`:`encode_image_urls(urls, batch_size) -> List[Optional[ndarray]]`,与 urls 等长。 318 +**向量化后端协议(服务内)**:文本后端需支持 `encode(sentences: Union[str, List[str]], batch_size, device) -> ndarray | List[ndarray]`,单条与批量输入统一通过一个接口处理;图片后端实现 `embeddings/protocols.ImageEncoderProtocol`:`encode_image_urls(urls, batch_size) -> List[Optional[ndarray]]`,与 urls 等长。
319 319
320 **配置速查**: 320 **配置速查**:
321 321
docs/MySQL到ES文档映射说明.md
@@ -679,7 +679,7 @@ if enable_embedding and encoder and documents:
679 title_doc_indices.append(i) 679 title_doc_indices.append(i)
680 680
681 if title_texts: 681 if title_texts:
682 - embeddings = encoder.encode_batch(title_texts, batch_size=32) 682 + embeddings = encoder.encode(title_texts, batch_size=32)
683 for j, emb in enumerate(embeddings): 683 for j, emb in enumerate(embeddings):
684 doc_idx = title_doc_indices[j] 684 doc_idx = title_doc_indices[j]
685 if isinstance(emb, np.ndarray): 685 if isinstance(emb, np.ndarray):
@@ -731,7 +731,7 @@ if enable_embedding and encoder and documents:
731 731
732 7. **批量生成 Embedding**(如果启用) 732 7. **批量生成 Embedding**(如果启用)
733 - 收集所有文档的标题文本 733 - 收集所有文档的标题文本
734 - - 批量调用 `encoder.encode_batch()` 生成 embedding 734 + - 批量调用 `encoder.encode()`(传入 list[str])生成 embedding
735 - 填充到对应文档 735 - 填充到对应文档
736 736
737 8. **批量写入 ES** 737 8. **批量写入 ES**
embeddings/README.md
@@ -10,7 +10,7 @@
10 这个目录是一个完整的“向量化模块”,包含: 10 这个目录是一个完整的“向量化模块”,包含:
11 11
12 - **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用) 12 - **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用)
13 -- **本地模型实现**:`qwen3_model.py` / `clip_model.py` 13 +- **本地模型实现**:`text_embedding_sentence_transformers.py` / `clip_model.py`
14 - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) 14 - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐)
15 - **向量化服务(FastAPI)**:`server.py` 15 - **向量化服务(FastAPI)**:`server.py`
16 - **统一配置**:`config.py` 16 - **统一配置**:`config.py`
embeddings/server.py
@@ -83,7 +83,7 @@ def _preview_inputs(items: List[str], max_items: int, max_chars: int) -> List[Di
83 83
84 def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any: 84 def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any:
85 with _text_encode_lock: 85 with _text_encode_lock:
86 - return _text_model.encode_batch( 86 + return _text_model.encode(
87 texts, 87 texts,
88 batch_size=int(CONFIG.TEXT_BATCH_SIZE), 88 batch_size=int(CONFIG.TEXT_BATCH_SIZE),
89 device=CONFIG.TEXT_DEVICE, 89 device=CONFIG.TEXT_DEVICE,
@@ -198,7 +198,7 @@ def load_models():
198 backend_name, backend_cfg = get_embedding_backend_config() 198 backend_name, backend_cfg = get_embedding_backend_config()
199 _text_backend_name = backend_name 199 _text_backend_name = backend_name
200 if backend_name == "tei": 200 if backend_name == "tei":
201 - from embeddings.tei_model import TEITextModel 201 + from embeddings.text_embedding_tei import TEITextModel
202 202
203 base_url = ( 203 base_url = (
204 os.getenv("TEI_BASE_URL") 204 os.getenv("TEI_BASE_URL")
@@ -216,7 +216,7 @@ def load_models():
216 timeout_sec=timeout_sec, 216 timeout_sec=timeout_sec,
217 ) 217 )
218 elif backend_name == "local_st": 218 elif backend_name == "local_st":
219 - from embeddings.qwen3_model import Qwen3TextModel 219 + from embeddings.text_embedding_sentence_transformers import Qwen3TextModel
220 220
221 model_id = ( 221 model_id = (
222 os.getenv("TEXT_MODEL_ID") 222 os.getenv("TEXT_MODEL_ID")
@@ -342,7 +342,7 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio
342 return out 342 return out
343 embs = _encode_local_st(normalized, normalize_embeddings=False) 343 embs = _encode_local_st(normalized, normalize_embeddings=False)
344 else: 344 else:
345 - embs = _text_model.encode_batch( 345 + embs = _text_model.encode(
346 normalized, 346 normalized,
347 batch_size=int(CONFIG.TEXT_BATCH_SIZE), 347 batch_size=int(CONFIG.TEXT_BATCH_SIZE),
348 device=CONFIG.TEXT_DEVICE, 348 device=CONFIG.TEXT_DEVICE,
embeddings/qwen3_model.py renamed to embeddings/text_embedding_sentence_transformers.py
@@ -47,6 +47,7 @@ class Qwen3TextModel(object):
47 device: str = "cuda", 47 device: str = "cuda",
48 batch_size: int = 32, 48 batch_size: int = 32,
49 ) -> np.ndarray: 49 ) -> np.ndarray:
  50 +
50 # SentenceTransformer + CUDA inference is not thread-safe in our usage; 51 # SentenceTransformer + CUDA inference is not thread-safe in our usage;
51 # keep one in-flight encode call while avoiding repeated .to(device) hops. 52 # keep one in-flight encode call while avoiding repeated .to(device) hops.
52 with self._encode_lock: 53 with self._encode_lock:
@@ -60,16 +61,3 @@ class Qwen3TextModel(object):
60 ) 61 )
61 return embeddings 62 return embeddings
62 63
63 - def encode_batch(  
64 - self,  
65 - texts: List[str],  
66 - batch_size: int = 32,  
67 - device: str = "cuda",  
68 - normalize_embeddings: bool = True,  
69 - ) -> np.ndarray:  
70 - return self.encode(  
71 - texts,  
72 - batch_size=batch_size,  
73 - device=device,  
74 - normalize_embeddings=normalize_embeddings,  
75 - )  
embeddings/tei_model.py renamed to embeddings/text_embedding_tei.py
@@ -54,24 +54,17 @@ class TEITextModel:
54 device: str = "cuda", 54 device: str = "cuda",
55 batch_size: int = 32, 55 batch_size: int = 32,
56 ) -> np.ndarray: 56 ) -> np.ndarray:
57 - if isinstance(sentences, str):  
58 - sentences = [sentences]  
59 - return self.encode_batch(  
60 - texts=sentences,  
61 - batch_size=batch_size,  
62 - device=device,  
63 - normalize_embeddings=normalize_embeddings,  
64 - ) 57 + """
  58 + Encode a single sentence or a list of sentences.
65 59
66 - def encode_batch(  
67 - self,  
68 - texts: List[str],  
69 - batch_size: int = 32,  
70 - device: str = "cuda",  
71 - normalize_embeddings: bool = True,  
72 - ) -> np.ndarray:  
73 - del batch_size # TEI performs its own batching.  
74 - del device # Not used by HTTP backend. 60 + TEI HTTP 后端天然是批量接口,这里统一通过 encode 处理单条和批量输入,
  61 + 不再额外暴露 encode_batch。
  62 + """
  63 +
  64 + if isinstance(sentences, str):
  65 + texts: List[str] = [sentences]
  66 + else:
  67 + texts = sentences
75 68
76 if texts is None or len(texts) == 0: 69 if texts is None or len(texts) == 0:
77 return np.array([], dtype=object) 70 return np.array([], dtype=object)
embeddings/text_encoder.py
@@ -135,33 +135,8 @@ class TextEmbeddingEncoder:
135 else: 135 else:
136 raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...") 136 raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...")
137 137
138 - # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None 138 + # 返回 numpy 数组(dtype=object),元素均为有效 np.ndarray 向量
139 return np.array(embeddings, dtype=object) 139 return np.array(embeddings, dtype=object)
140 -  
141 - def encode_batch(  
142 - self,  
143 - texts: List[str],  
144 - batch_size: int = 32,  
145 - device: str = 'cpu',  
146 - normalize_embeddings: bool = True,  
147 - ) -> np.ndarray:  
148 - """  
149 - Encode a batch of texts efficiently via network service.  
150 -  
151 - Args:  
152 - texts: List of texts to encode  
153 - batch_size: Batch size for processing  
154 - device: Device parameter ignored for service compatibility  
155 -  
156 - Returns:  
157 - numpy array of embeddings  
158 - """  
159 - return self.encode(  
160 - texts,  
161 - batch_size=batch_size,  
162 - device=device,  
163 - normalize_embeddings=normalize_embeddings,  
164 - )  
165 140
166 def _is_valid_embedding(self, embedding: np.ndarray) -> bool: 141 def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
167 """ 142 """
indexer/incremental_service.py
@@ -641,7 +641,7 @@ class IncrementalIndexerService:
641 title_doc_indices.append(i) 641 title_doc_indices.append(i)
642 642
643 if title_texts: 643 if title_texts:
644 - embeddings = encoder.encode_batch(title_texts, batch_size=32) 644 + embeddings = encoder.encode(title_texts, batch_size=32)
645 if embeddings is None or len(embeddings) != len(title_texts): 645 if embeddings is None or len(embeddings) != len(title_texts):
646 raise RuntimeError( 646 raise RuntimeError(
647 f"[IncrementalIndexing] Batch embedding length mismatch for tenant_id={tenant_id}: " 647 f"[IncrementalIndexing] Batch embedding length mismatch for tenant_id={tenant_id}: "
indexer/product_enrich.py
@@ -96,6 +96,9 @@ except Exception as e:
96 logger.warning(f"Failed to initialize Redis for anchors cache: {e}") 96 logger.warning(f"Failed to initialize Redis for anchors cache: {e}")
97 _anchor_redis = None 97 _anchor_redis = None
98 98
  99 +# 中文版本提示词(请勿删除):
  100 +# "你是一名电商平台的商品标注员,你的工作是对输入的每个商品进行理解、分析和标注,"
  101 +# "并按要求格式返回 Markdown 表格。所有输出内容必须为中文。"
99 102
100 SYSTEM_MESSAGES = ( 103 SYSTEM_MESSAGES = (
101 "You are a product annotator for an e-commerce platform. " 104 "You are a product annotator for an e-commerce platform. "
@@ -163,6 +166,31 @@ def create_prompt(products: List[Dict[str, str]], target_lang: str = "zh") -> st
163 """ 166 """
164 lang_name = SOURCE_LANG_CODE_MAP.get(target_lang, target_lang) 167 lang_name = SOURCE_LANG_CODE_MAP.get(target_lang, target_lang)
165 168
  169 +# 中文版本提示词(请勿删除)
  170 +# prompt = """请对输入的每条商品标题,分析并提取以下信息:
  171 +
  172 +# 1. 商品标题:将输入商品名称翻译为自然、完整的中文商品标题
  173 +# 2. 品类路径:从大类到细分品类,用">"分隔(例如:服装>女装>裤子>工装裤)
  174 +# 3. 细分标签:商品的风格、特点、功能等(例如:碎花,收腰,法式)
  175 +# 4. 适用人群:性别/年龄段等(例如:年轻女性)
  176 +# 5. 使用场景
  177 +# 6. 适用季节
  178 +# 7. 关键属性
  179 +# 8. 材质说明
  180 +# 9. 功能特点
  181 +# 10. 商品卖点:分析和提取一句话核心卖点,用于推荐理由
  182 +# 11. 锚文本:生成一组能够代表该商品、并可能被用户用于搜索的词语或短语。这些词语应覆盖用户需求的各个维度,如品类、细分标签、功能特性、需求场景等等。
  183 +
  184 +# 输入商品列表:
  185 +
  186 +# """
  187 +# prompt_tail = """
  188 +# 请严格按照以下markdown表格格式返回,每列内部的多值内容都用逗号分隔,不要添加任何其他说明:
  189 +
  190 +# | 序号 | 商品标题 | 品类路径 | 细分标签 | 适用人群 | 使用场景 | 适用季节 | 关键属性 | 材质说明 | 功能特点 | 商品卖点 | 锚文本 |
  191 +# |----|----|----|----|----|----|----|----|----|----|----|----|
  192 +# """
  193 +
166 prompt = """Please analyze each input product title and extract the following information: 194 prompt = """Please analyze each input product title and extract the following information:
167 195
168 1. Product title: a natural English product name derived from the input title 196 1. Product title: a natural English product name derived from the input title