Commit b2e50710eba04858a7f45112120b57fc77574cca

Authored by tangwang
1 parent 5c2b70a2

BgeEncoder.encode(...) 返回:np.ndarray(dtype=object),每个元素要么是 np.ndarray,要么是 None。

cache/service 任一环节返回坏 embedding(含 NaN/Inf/空/非 ndarray)都会被视为 None,并且坏 cache 会被自动删除。
docs/索引字段说明v2.md
@@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall)
433 433
434 ### 文本召回字段 434 ### 文本召回字段
435 435
436 -默认同时搜索以下字段(中英文都包含) 436 +根据查询词的语言选择对应的索引字段
437 - `title_zh^3.0`, `title_en^3.0` 437 - `title_zh^3.0`, `title_en^3.0`
438 - `brief_zh^1.5`, `brief_en^1.5` 438 - `brief_zh^1.5`, `brief_en^1.5`
439 - `description_zh^1.0`, `description_en^1.0` 439 - `description_zh^1.0`, `description_en^1.0`
embeddings/text_encoder.py
@@ -101,24 +101,17 @@ class BgeEncoder:
101 batch_size: Batch size for processing (used for service requests) 101 batch_size: Batch size for processing (used for service requests)
102 102
103 Returns: 103 Returns:
104 - numpy array of shape (n, 1024) containing embeddings 104 + numpy array of dtype=object, where each element is either:
  105 + - np.ndarray (valid embedding vector) or
  106 + - None (no embedding available for that text)
105 """ 107 """
106 # Convert single string to list 108 # Convert single string to list
107 if isinstance(sentences, str): 109 if isinstance(sentences, str):
108 sentences = [sentences] 110 sentences = [sentences]
109 111
110 # Check cache first 112 # Check cache first
111 - cached_embeddings = []  
112 - uncached_indices = []  
113 - uncached_texts = []  
114 -  
115 - for i, text in enumerate(sentences):  
116 - cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding  
117 - if cached is not None:  
118 - cached_embeddings.append((i, cached))  
119 - else:  
120 - uncached_indices.append(i)  
121 - uncached_texts.append(text) 113 + uncached_indices: List[int] = []
  114 + uncached_texts: List[str] = []
122 115
123 # Prepare request data for uncached texts 116 # Prepare request data for uncached texts
124 request_data = [] 117 request_data = []
@@ -136,11 +129,16 @@ class BgeEncoder:
136 request_data.append(request_item) 129 request_data.append(request_item)
137 130
138 # Process response 131 # Process response
139 - embeddings = [None] * len(sentences)  
140 -  
141 - # Fill in cached embeddings  
142 - for idx, cached_emb in cached_embeddings:  
143 - embeddings[idx] = cached_emb 132 + # Each element can be np.ndarray or None (表示该文本没有可用的向量)
  133 + embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
  134 +
  135 + for i, text in enumerate(sentences):
  136 + cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding
  137 + if cached is not None:
  138 + embeddings[i] = cached
  139 + else:
  140 + uncached_indices.append(i)
  141 + uncached_texts.append(text)
144 142
145 # If there are uncached texts, call service 143 # If there are uncached texts, call service
146 if uncached_texts: 144 if uncached_texts:
@@ -168,25 +166,35 @@ class BgeEncoder:
168 166
169 if embedding is not None: 167 if embedding is not None:
170 embedding_array = np.array(embedding, dtype=np.float32) 168 embedding_array = np.array(embedding, dtype=np.float32)
171 - embeddings[original_idx] = embedding_array  
172 - # Cache the embedding  
173 - self._set_cached_embedding(text, 'en', embedding_array) 169 + # Validate embedding from service - if invalid, treat as no result
  170 + if self._is_valid_embedding(embedding_array):
  171 + embeddings[original_idx] = embedding_array
  172 + # Cache the embedding
  173 + self._set_cached_embedding(text, 'en', embedding_array)
  174 + else:
  175 + logger.warning(
  176 + f"Invalid embedding returned from service for text {original_idx} "
  177 + f"(contains NaN/Inf or invalid shape), treating as no result. "
  178 + f"Text preview: {text[:50]}..."
  179 + )
  180 + # 不生成兜底向量,保持为 None
  181 + embeddings[original_idx] = None
174 else: 182 else:
175 logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") 183 logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...")
176 - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) 184 + # 不生成兜底向量,保持为 None
  185 + embeddings[original_idx] = None
177 else: 186 else:
178 logger.warning(f"No response found for text {original_idx}") 187 logger.warning(f"No response found for text {original_idx}")
179 - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) 188 + # 不生成兜底向量,保持为 None
  189 + embeddings[original_idx] = None
180 190
181 except Exception as e: 191 except Exception as e:
182 logger.error(f"Failed to encode texts: {e}", exc_info=True) 192 logger.error(f"Failed to encode texts: {e}", exc_info=True)
183 - # Fill missing embeddings with zeros  
184 - for idx in uncached_indices:  
185 - if embeddings[idx] is None:  
186 - embeddings[idx] = np.zeros(1024, dtype=np.float32) 193 + # 出错时不要生成兜底全零向量,保持为 None
  194 + pass
187 195
188 - # Convert to numpy array  
189 - return np.array(embeddings, dtype=np.float32) 196 + # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None
  197 + return np.array(embeddings, dtype=object)
190 198
191 def encode_batch( 199 def encode_batch(
192 self, 200 self,
@@ -211,6 +219,27 @@ class BgeEncoder:
211 """Generate a cache key for the query""" 219 """Generate a cache key for the query"""
212 return f"embedding:{language}:{query}" 220 return f"embedding:{language}:{query}"
213 221
  222 + def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
  223 + """
  224 + Check if embedding is valid (not None, correct shape, no NaN/Inf).
  225 +
  226 + Args:
  227 + embedding: Embedding array to validate
  228 +
  229 + Returns:
  230 + True if valid, False otherwise
  231 + """
  232 + if embedding is None:
  233 + return False
  234 + if not isinstance(embedding, np.ndarray):
  235 + return False
  236 + if embedding.size == 0:
  237 + return False
  238 + # Check for NaN or Inf values
  239 + if not np.isfinite(embedding).all():
  240 + return False
  241 + return True
  242 +
214 def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: 243 def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]:
215 """Get embedding from cache if exists (with sliding expiration)""" 244 """Get embedding from cache if exists (with sliding expiration)"""
216 if not self.redis_client: 245 if not self.redis_client:
@@ -220,10 +249,24 @@ class BgeEncoder:
220 cache_key = self._get_cache_key(query, language) 249 cache_key = self._get_cache_key(query, language)
221 cached_data = self.redis_client.get(cache_key) 250 cached_data = self.redis_client.get(cache_key)
222 if cached_data: 251 if cached_data:
223 - logger.debug(f"Cache hit for embedding: {query}")  
224 - # Update expiration time on access (sliding expiration)  
225 - self.redis_client.expire(cache_key, self.expire_time)  
226 - return pickle.loads(cached_data) 252 + embedding = pickle.loads(cached_data)
  253 + # Validate cached embedding - if invalid, ignore cache and return None
  254 + if self._is_valid_embedding(embedding):
  255 + logger.debug(f"Cache hit for embedding: {query}")
  256 + # Update expiration time on access (sliding expiration)
  257 + self.redis_client.expire(cache_key, self.expire_time)
  258 + return embedding
  259 + else:
  260 + logger.warning(
  261 + f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "
  262 + f"ignoring cache for query: {query[:50]}..."
  263 + )
  264 + # Delete invalid cache entry
  265 + try:
  266 + self.redis_client.delete(cache_key)
  267 + except Exception as e:
  268 + logger.debug(f"Failed to delete invalid cache entry: {e}")
  269 + return None
227 return None 270 return None
228 except Exception as e: 271 except Exception as e:
229 logger.error(f"Error retrieving embedding from cache: {e}") 272 logger.error(f"Error retrieving embedding from cache: {e}")
indexer/document_transformer.py
@@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。
5 """ 5 """
6 6
7 import pandas as pd 7 import pandas as pd
  8 +import numpy as np
8 import logging 9 import logging
9 from typing import Dict, Any, Optional, List 10 from typing import Dict, Any, Optional, List
10 from config import ConfigLoader 11 from config import ConfigLoader
@@ -594,6 +595,9 @@ class SPUDocumentTransformer:
594 if embeddings is not None and len(embeddings) > 0: 595 if embeddings is not None and len(embeddings) > 0:
595 # 取第一个embedding(因为只传了一个文本) 596 # 取第一个embedding(因为只传了一个文本)
596 embedding = embeddings[0] 597 embedding = embeddings[0]
  598 + if not isinstance(embedding, np.ndarray):
  599 + logger.warning(f"Embedding is None/invalid for title: {title_text[:50]}...")
  600 + return
597 # 转换为列表格式(ES需要) 601 # 转换为列表格式(ES需要)
598 doc['title_embedding'] = embedding.tolist() 602 doc['title_embedding'] = embedding.tolist()
599 logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...") 603 logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...")
indexer/incremental_service.py
@@ -5,6 +5,7 @@ import logging
5 import time 5 import time
6 import threading 6 import threading
7 from typing import Dict, Any, Optional, List, Tuple 7 from typing import Dict, Any, Optional, List, Tuple
  8 +import numpy as np
8 from sqlalchemy import text, bindparam 9 from sqlalchemy import text, bindparam
9 from indexer.indexing_utils import load_category_mapping, create_document_transformer 10 from indexer.indexing_utils import load_category_mapping, create_document_transformer
10 from indexer.bulk_indexer import BulkIndexer 11 from indexer.bulk_indexer import BulkIndexer
@@ -134,7 +135,9 @@ class IncrementalIndexerService:
134 try: 135 try:
135 embeddings = encoder.encode(title_text) 136 embeddings = encoder.encode(title_text)
136 if embeddings is not None and len(embeddings) > 0: 137 if embeddings is not None and len(embeddings) > 0:
137 - doc["title_embedding"] = embeddings[0].tolist() 138 + emb0 = embeddings[0]
  139 + if isinstance(emb0, np.ndarray):
  140 + doc["title_embedding"] = emb0.tolist()
138 except Exception as e: 141 except Exception as e:
139 logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}") 142 logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}")
140 143
@@ -564,7 +567,8 @@ class IncrementalIndexerService:
564 embeddings = encoder.encode_batch(title_texts, batch_size=32) 567 embeddings = encoder.encode_batch(title_texts, batch_size=32)
565 for j, emb in enumerate(embeddings): 568 for j, emb in enumerate(embeddings):
566 doc_idx = title_doc_indices[j] 569 doc_idx = title_doc_indices[j]
567 - documents[doc_idx][1]["title_embedding"] = emb.tolist() 570 + if isinstance(emb, np.ndarray):
  571 + documents[doc_idx][1]["title_embedding"] = emb.tolist()
568 except Exception as e: 572 except Exception as e:
569 logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True) 573 logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True)
570 574
query/query_parser.py
@@ -331,8 +331,14 @@ class QueryParser:
331 log_debug("开始生成查询向量(异步)") 331 log_debug("开始生成查询向量(异步)")
332 # Submit encoding task to thread pool for async execution 332 # Submit encoding task to thread pool for async execution
333 encoding_executor = ThreadPoolExecutor(max_workers=1) 333 encoding_executor = ThreadPoolExecutor(max_workers=1)
  334 + def _encode_query_vector() -> Optional[np.ndarray]:
  335 + arr = self.text_encoder.encode([query_text])
  336 + if arr is None or len(arr) == 0:
  337 + return None
  338 + vec = arr[0]
  339 + return vec if isinstance(vec, np.ndarray) else None
334 embedding_future = encoding_executor.submit( 340 embedding_future = encoding_executor.submit(
335 - lambda: self.text_encoder.encode([query_text])[0] 341 + _encode_query_vector
336 ) 342 )
337 except Exception as e: 343 except Exception as e:
338 error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}" 344 error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}"
@@ -370,9 +376,12 @@ class QueryParser:
370 context.store_intermediate_result(f'translation_{lang}', result) 376 context.store_intermediate_result(f'translation_{lang}', result)
371 elif task_type == 'embedding': 377 elif task_type == 'embedding':
372 query_vector = result 378 query_vector = result
373 - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}")  
374 - if context:  
375 - context.store_intermediate_result('query_vector_shape', query_vector.shape) 379 + if query_vector is not None:
  380 + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}")
  381 + if context:
  382 + context.store_intermediate_result('query_vector_shape', query_vector.shape)
  383 + else:
  384 + log_info("查询向量生成完成但结果为空(None),将按无向量处理")
376 except Exception as e: 385 except Exception as e:
377 if task_type == 'translation': 386 if task_type == 'translation':
378 error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}" 387 error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}"