Commit 775168412f4c1d83386da3d642ff77bd74ca645a

Authored by tangwang
1 parent 3d588bef

tidy embeddings

docs/DEVELOPER_GUIDE.md
... ... @@ -315,7 +315,7 @@ services:
315 315  
316 316 **重排后端协议(服务内)**:所有在 reranker 服务内加载的后端须实现 `score_with_meta(query, docs, normalize=True) -> (scores: List[float], meta: dict)`。返回的 `scores[i]` 与 `docs[i]` 一一对应;meta 至少含 `input_docs`、`usable_docs`、`elapsed_ms` 等。对外 HTTP 契约固定:`POST /rerank` 请求体 `{ "query": str, "docs": [str] }`,响应体 `{ "scores": [float], "meta": object }`;`GET /health` 返回 `status`、`model`、`backend` 等。
317 317  
318   -**向量化后端协议(服务内)**:文本后端需支持 `encode_batch(texts, batch_size, device) -> List[ndarray]`,与 texts 一一对应;图片后端实现 `embeddings/protocols.ImageEncoderProtocol`:`encode_image_urls(urls, batch_size) -> List[Optional[ndarray]]`,与 urls 等长。
  318 +**向量化后端协议(服务内)**:文本后端需支持 `encode(sentences: Union[str, List[str]], batch_size, device) -> Union[ndarray, List[ndarray]]`,单条与批量输入统一通过一个接口处理;图片后端实现 `embeddings/protocols.ImageEncoderProtocol`:`encode_image_urls(urls, batch_size) -> List[Optional[ndarray]]`,与 urls 等长。
319 319  
320 320 **配置速查**:
321 321  
... ...
docs/MySQL到ES文档映射说明.md
... ... @@ -679,7 +679,7 @@ if enable_embedding and encoder and documents:
679 679 title_doc_indices.append(i)
680 680  
681 681 if title_texts:
682   - embeddings = encoder.encode_batch(title_texts, batch_size=32)
  682 + embeddings = encoder.encode(title_texts, batch_size=32)
683 683 for j, emb in enumerate(embeddings):
684 684 doc_idx = title_doc_indices[j]
685 685 if isinstance(emb, np.ndarray):
... ... @@ -731,7 +731,7 @@ if enable_embedding and encoder and documents:
731 731  
732 732 7. **批量生成 Embedding**(如果启用)
733 733 - 收集所有文档的标题文本
734   - - 批量调用 `encoder.encode_batch()` 生成 embedding
  734 + - 批量调用 `encoder.encode()`(传入 list[str])生成 embedding
735 735 - 填充到对应文档
736 736  
737 737 8. **批量写入 ES**
... ...
embeddings/README.md
... ... @@ -10,7 +10,7 @@
10 10 这个目录是一个完整的“向量化模块”,包含:
11 11  
12 12 - **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用)
13   -- **本地模型实现**:`qwen3_model.py` / `clip_model.py`
  13 +- **本地模型实现**:`text_embedding_sentence_transformers.py` / `clip_model.py`
14 14 - **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐)
15 15 - **向量化服务(FastAPI)**:`server.py`
16 16 - **统一配置**:`config.py`
... ...
embeddings/server.py
... ... @@ -83,7 +83,7 @@ def _preview_inputs(items: List[str], max_items: int, max_chars: int) -> List[Di
83 83  
84 84 def _encode_local_st(texts: List[str], normalize_embeddings: bool) -> Any:
85 85 with _text_encode_lock:
86   - return _text_model.encode_batch(
  86 + return _text_model.encode(
87 87 texts,
88 88 batch_size=int(CONFIG.TEXT_BATCH_SIZE),
89 89 device=CONFIG.TEXT_DEVICE,
... ... @@ -198,7 +198,7 @@ def load_models():
198 198 backend_name, backend_cfg = get_embedding_backend_config()
199 199 _text_backend_name = backend_name
200 200 if backend_name == "tei":
201   - from embeddings.tei_model import TEITextModel
  201 + from embeddings.text_embedding_tei import TEITextModel
202 202  
203 203 base_url = (
204 204 os.getenv("TEI_BASE_URL")
... ... @@ -216,7 +216,7 @@ def load_models():
216 216 timeout_sec=timeout_sec,
217 217 )
218 218 elif backend_name == "local_st":
219   - from embeddings.qwen3_model import Qwen3TextModel
  219 + from embeddings.text_embedding_sentence_transformers import Qwen3TextModel
220 220  
221 221 model_id = (
222 222 os.getenv("TEXT_MODEL_ID")
... ... @@ -342,7 +342,7 @@ def embed_text(texts: List[str], normalize: Optional[bool] = None) -> List[Optio
342 342 return out
343 343 embs = _encode_local_st(normalized, normalize_embeddings=False)
344 344 else:
345   - embs = _text_model.encode_batch(
  345 + embs = _text_model.encode(
346 346 normalized,
347 347 batch_size=int(CONFIG.TEXT_BATCH_SIZE),
348 348 device=CONFIG.TEXT_DEVICE,
... ...
embeddings/qwen3_model.py renamed to embeddings/text_embedding_sentence_transformers.py
... ... @@ -47,6 +47,7 @@ class Qwen3TextModel(object):
47 47 device: str = "cuda",
48 48 batch_size: int = 32,
49 49 ) -> np.ndarray:
  50 +
50 51 # SentenceTransformer + CUDA inference is not thread-safe in our usage;
51 52 # keep one in-flight encode call while avoiding repeated .to(device) hops.
52 53 with self._encode_lock:
... ... @@ -60,16 +61,3 @@ class Qwen3TextModel(object):
60 61 )
61 62 return embeddings
62 63  
63   - def encode_batch(
64   - self,
65   - texts: List[str],
66   - batch_size: int = 32,
67   - device: str = "cuda",
68   - normalize_embeddings: bool = True,
69   - ) -> np.ndarray:
70   - return self.encode(
71   - texts,
72   - batch_size=batch_size,
73   - device=device,
74   - normalize_embeddings=normalize_embeddings,
75   - )
... ...
embeddings/tei_model.py renamed to embeddings/text_embedding_tei.py
... ... @@ -54,24 +54,17 @@ class TEITextModel:
54 54 device: str = "cuda",
55 55 batch_size: int = 32,
56 56 ) -> np.ndarray:
57   - if isinstance(sentences, str):
58   - sentences = [sentences]
59   - return self.encode_batch(
60   - texts=sentences,
61   - batch_size=batch_size,
62   - device=device,
63   - normalize_embeddings=normalize_embeddings,
64   - )
  57 + """
  58 + Encode a single sentence or a list of sentences.
65 59  
66   - def encode_batch(
67   - self,
68   - texts: List[str],
69   - batch_size: int = 32,
70   - device: str = "cuda",
71   - normalize_embeddings: bool = True,
72   - ) -> np.ndarray:
73   - del batch_size # TEI performs its own batching.
74   - del device # Not used by HTTP backend.
  60 + TEI HTTP 后端天然是批量接口,这里统一通过 encode 处理单条和批量输入,
  61 + 不再额外暴露 encode_batch。
  62 + """
  63 +
  64 + if isinstance(sentences, str):
  65 + texts: List[str] = [sentences]
  66 + else:
  67 + texts = sentences
75 68  
76 69 if texts is None or len(texts) == 0:
77 70 return np.array([], dtype=object)
... ...
embeddings/text_encoder.py
... ... @@ -135,33 +135,8 @@ class TextEmbeddingEncoder:
135 135 else:
136 136 raise ValueError(f"No embedding found for text index {original_idx}: {text[:50]}...")
137 137  
138   - # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None
  138 + # 返回 numpy 数组(dtype=object),元素均为有效 np.ndarray 向量
139 139 return np.array(embeddings, dtype=object)
140   -
141   - def encode_batch(
142   - self,
143   - texts: List[str],
144   - batch_size: int = 32,
145   - device: str = 'cpu',
146   - normalize_embeddings: bool = True,
147   - ) -> np.ndarray:
148   - """
149   - Encode a batch of texts efficiently via network service.
150   -
151   - Args:
152   - texts: List of texts to encode
153   - batch_size: Batch size for processing
154   - device: Device parameter ignored for service compatibility
155   -
156   - Returns:
157   - numpy array of embeddings
158   - """
159   - return self.encode(
160   - texts,
161   - batch_size=batch_size,
162   - device=device,
163   - normalize_embeddings=normalize_embeddings,
164   - )
165 140  
166 141 def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
167 142 """
... ...
indexer/incremental_service.py
... ... @@ -641,7 +641,7 @@ class IncrementalIndexerService:
641 641 title_doc_indices.append(i)
642 642  
643 643 if title_texts:
644   - embeddings = encoder.encode_batch(title_texts, batch_size=32)
  644 + embeddings = encoder.encode(title_texts, batch_size=32)
645 645 if embeddings is None or len(embeddings) != len(title_texts):
646 646 raise RuntimeError(
647 647 f"[IncrementalIndexing] Batch embedding length mismatch for tenant_id={tenant_id}: "
... ...
indexer/product_enrich.py
... ... @@ -96,6 +96,9 @@ except Exception as e:
96 96 logger.warning(f"Failed to initialize Redis for anchors cache: {e}")
97 97 _anchor_redis = None
98 98  
  99 +# 中文版本提示词(请勿删除):
  100 +# "你是一名电商平台的商品标注员,你的工作是对输入的每个商品进行理解、分析和标注,"
  101 +# "并按要求格式返回 Markdown 表格。所有输出内容必须为中文。"
99 102  
100 103 SYSTEM_MESSAGES = (
101 104 "You are a product annotator for an e-commerce platform. "
... ... @@ -163,6 +166,31 @@ def create_prompt(products: List[Dict[str, str]], target_lang: str = "zh") -> st
163 166 """
164 167 lang_name = SOURCE_LANG_CODE_MAP.get(target_lang, target_lang)
165 168  
  169 +# 中文版本提示词(请勿删除)
  170 +# prompt = """请对输入的每条商品标题,分析并提取以下信息:
  171 +
  172 +# 1. 商品标题:将输入商品名称翻译为自然、完整的中文商品标题
  173 +# 2. 品类路径:从大类到细分品类,用">"分隔(例如:服装>女装>裤子>工装裤)
  174 +# 3. 细分标签:商品的风格、特点、功能等(例如:碎花,收腰,法式)
  175 +# 4. 适用人群:性别/年龄段等(例如:年轻女性)
  176 +# 5. 使用场景
  177 +# 6. 适用季节
  178 +# 7. 关键属性
  179 +# 8. 材质说明
  180 +# 9. 功能特点
  181 +# 10. 商品卖点:分析和提取一句话核心卖点,用于推荐理由
  182 +# 11. 锚文本:生成一组能够代表该商品、并可能被用户用于搜索的词语或短语。这些词语应覆盖用户需求的各个维度,如品类、细分标签、功能特性、需求场景等等。
  183 +
  184 +# 输入商品列表:
  185 +
  186 +# """
  187 +# prompt_tail = """
  188 +# 请严格按照以下markdown表格格式返回,每列内部的多值内容都用逗号分隔,不要添加任何其他说明:
  189 +
  190 +# | 序号 | 商品标题 | 品类路径 | 细分标签 | 适用人群 | 使用场景 | 适用季节 | 关键属性 | 材质说明 | 功能特点 | 商品卖点 | 锚文本 |
  191 +# |----|----|----|----|----|----|----|----|----|----|----|----|
  192 +# """
  193 +
166 194 prompt = """Please analyze each input product title and extract the following information:
167 195  
168 196 1. Product title: a natural English product name derived from the input title
... ...