Commit b2e50710eba04858a7f45112120b57fc77574cca

Authored by tangwang
1 parent 5c2b70a2

BgeEncoder.encode(...) 返回:np.ndarray(dtype=object),每个元素 要么是 np.ndarray,要么是 None。

cache/service 任一环节返回坏 embedding(含 NaN/Inf/空/非 ndarray)都会 视为 None,并且坏 cache 会被自动删除。
docs/索引字段说明v2.md
... ... @@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall)
433 433  
434 434 ### 文本召回字段
435 435  
436   -默认同时搜索以下字段(中英文都包含)
  436 +根据查询词的语言选择对应的索引字段
437 437 - `title_zh^3.0`, `title_en^3.0`
438 438 - `brief_zh^1.5`, `brief_en^1.5`
439 439 - `description_zh^1.0`, `description_en^1.0`
... ...
embeddings/text_encoder.py
... ... @@ -101,24 +101,17 @@ class BgeEncoder:
101 101 batch_size: Batch size for processing (used for service requests)
102 102  
103 103 Returns:
104   - numpy array of shape (n, 1024) containing embeddings
  104 + numpy array of dtype=object, where each element is either:
  105 + - np.ndarray (valid embedding vector) or
  106 + - None (no embedding available for that text)
105 107 """
106 108 # Convert single string to list
107 109 if isinstance(sentences, str):
108 110 sentences = [sentences]
109 111  
110 112 # Check cache first
111   - cached_embeddings = []
112   - uncached_indices = []
113   - uncached_texts = []
114   -
115   - for i, text in enumerate(sentences):
116   - cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding
117   - if cached is not None:
118   - cached_embeddings.append((i, cached))
119   - else:
120   - uncached_indices.append(i)
121   - uncached_texts.append(text)
  113 + uncached_indices: List[int] = []
  114 + uncached_texts: List[str] = []
122 115  
123 116 # Prepare request data for uncached texts
124 117 request_data = []
... ... @@ -136,11 +129,16 @@ class BgeEncoder:
136 129 request_data.append(request_item)
137 130  
138 131 # Process response
139   - embeddings = [None] * len(sentences)
140   -
141   - # Fill in cached embeddings
142   - for idx, cached_emb in cached_embeddings:
143   - embeddings[idx] = cached_emb
  132 + # Each element can be np.ndarray or None (表示该文本没有可用的向量)
  133 + embeddings: List[Optional[np.ndarray]] = [None] * len(sentences)
  134 +
  135 + for i, text in enumerate(sentences):
  136 + cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding
  137 + if cached is not None:
  138 + embeddings[i] = cached
  139 + else:
  140 + uncached_indices.append(i)
  141 + uncached_texts.append(text)
144 142  
145 143 # If there are uncached texts, call service
146 144 if uncached_texts:
... ... @@ -168,25 +166,35 @@ class BgeEncoder:
168 166  
169 167 if embedding is not None:
170 168 embedding_array = np.array(embedding, dtype=np.float32)
171   - embeddings[original_idx] = embedding_array
172   - # Cache the embedding
173   - self._set_cached_embedding(text, 'en', embedding_array)
  169 + # Validate embedding from service - if invalid, treat as no result
  170 + if self._is_valid_embedding(embedding_array):
  171 + embeddings[original_idx] = embedding_array
  172 + # Cache the embedding
  173 + self._set_cached_embedding(text, 'en', embedding_array)
  174 + else:
  175 + logger.warning(
  176 + f"Invalid embedding returned from service for text {original_idx} "
  177 + f"(contains NaN/Inf or invalid shape), treating as no result. "
  178 + f"Text preview: {text[:50]}..."
  179 + )
  180 + # 不生成兜底向量,保持为 None
  181 + embeddings[original_idx] = None
174 182 else:
175 183 logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...")
176   - embeddings[original_idx] = np.zeros(1024, dtype=np.float32)
  184 + # 不生成兜底向量,保持为 None
  185 + embeddings[original_idx] = None
177 186 else:
178 187 logger.warning(f"No response found for text {original_idx}")
179   - embeddings[original_idx] = np.zeros(1024, dtype=np.float32)
  188 + # 不生成兜底向量,保持为 None
  189 + embeddings[original_idx] = None
180 190  
181 191 except Exception as e:
182 192 logger.error(f"Failed to encode texts: {e}", exc_info=True)
183   - # Fill missing embeddings with zeros
184   - for idx in uncached_indices:
185   - if embeddings[idx] is None:
186   - embeddings[idx] = np.zeros(1024, dtype=np.float32)
  193 + # 出错时不要生成兜底全零向量,保持为 None
  194 + pass
187 195  
188   - # Convert to numpy array
189   - return np.array(embeddings, dtype=np.float32)
  196 + # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None
  197 + return np.array(embeddings, dtype=object)
190 198  
191 199 def encode_batch(
192 200 self,
... ... @@ -211,6 +219,27 @@ class BgeEncoder:
211 219 """Generate a cache key for the query"""
212 220 return f"embedding:{language}:{query}"
213 221  
  222 + def _is_valid_embedding(self, embedding: np.ndarray) -> bool:
  223 + """
  224 + Check if embedding is valid (not None, correct shape, no NaN/Inf).
  225 +
  226 + Args:
  227 + embedding: Embedding array to validate
  228 +
  229 + Returns:
  230 + True if valid, False otherwise
  231 + """
  232 + if embedding is None:
  233 + return False
  234 + if not isinstance(embedding, np.ndarray):
  235 + return False
  236 + if embedding.size == 0:
  237 + return False
  238 + # Check for NaN or Inf values
  239 + if not np.isfinite(embedding).all():
  240 + return False
  241 + return True
  242 +
214 243 def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]:
215 244 """Get embedding from cache if exists (with sliding expiration)"""
216 245 if not self.redis_client:
... ... @@ -220,10 +249,24 @@ class BgeEncoder:
220 249 cache_key = self._get_cache_key(query, language)
221 250 cached_data = self.redis_client.get(cache_key)
222 251 if cached_data:
223   - logger.debug(f"Cache hit for embedding: {query}")
224   - # Update expiration time on access (sliding expiration)
225   - self.redis_client.expire(cache_key, self.expire_time)
226   - return pickle.loads(cached_data)
  252 + embedding = pickle.loads(cached_data)
  253 + # Validate cached embedding - if invalid, ignore cache and return None
  254 + if self._is_valid_embedding(embedding):
  255 + logger.debug(f"Cache hit for embedding: {query}")
  256 + # Update expiration time on access (sliding expiration)
  257 + self.redis_client.expire(cache_key, self.expire_time)
  258 + return embedding
  259 + else:
  260 + logger.warning(
  261 + f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), "
  262 + f"ignoring cache for query: {query[:50]}..."
  263 + )
  264 + # Delete invalid cache entry
  265 + try:
  266 + self.redis_client.delete(cache_key)
  267 + except Exception as e:
  268 + logger.debug(f"Failed to delete invalid cache entry: {e}")
  269 + return None
227 270 return None
228 271 except Exception as e:
229 272 logger.error(f"Error retrieving embedding from cache: {e}")
... ...
indexer/document_transformer.py
... ... @@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。
5 5 """
6 6  
7 7 import pandas as pd
  8 +import numpy as np
8 9 import logging
9 10 from typing import Dict, Any, Optional, List
10 11 from config import ConfigLoader
... ... @@ -594,6 +595,9 @@ class SPUDocumentTransformer:
594 595 if embeddings is not None and len(embeddings) > 0:
595 596 # 取第一个embedding(因为只传了一个文本)
596 597 embedding = embeddings[0]
  598 + if not isinstance(embedding, np.ndarray):
  599 + logger.warning(f"Embedding is None/invalid for title: {title_text[:50]}...")
  600 + return
597 601 # 转换为列表格式(ES需要)
598 602 doc['title_embedding'] = embedding.tolist()
599 603 logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...")
... ...
indexer/incremental_service.py
... ... @@ -5,6 +5,7 @@ import logging
5 5 import time
6 6 import threading
7 7 from typing import Dict, Any, Optional, List, Tuple
  8 +import numpy as np
8 9 from sqlalchemy import text, bindparam
9 10 from indexer.indexing_utils import load_category_mapping, create_document_transformer
10 11 from indexer.bulk_indexer import BulkIndexer
... ... @@ -134,7 +135,9 @@ class IncrementalIndexerService:
134 135 try:
135 136 embeddings = encoder.encode(title_text)
136 137 if embeddings is not None and len(embeddings) > 0:
137   - doc["title_embedding"] = embeddings[0].tolist()
  138 + emb0 = embeddings[0]
  139 + if isinstance(emb0, np.ndarray):
  140 + doc["title_embedding"] = emb0.tolist()
138 141 except Exception as e:
139 142 logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}")
140 143  
... ... @@ -564,7 +567,8 @@ class IncrementalIndexerService:
564 567 embeddings = encoder.encode_batch(title_texts, batch_size=32)
565 568 for j, emb in enumerate(embeddings):
566 569 doc_idx = title_doc_indices[j]
567   - documents[doc_idx][1]["title_embedding"] = emb.tolist()
  570 + if isinstance(emb, np.ndarray):
  571 + documents[doc_idx][1]["title_embedding"] = emb.tolist()
568 572 except Exception as e:
569 573 logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True)
570 574  
... ...
query/query_parser.py
... ... @@ -331,8 +331,14 @@ class QueryParser:
331 331 log_debug("开始生成查询向量(异步)")
332 332 # Submit encoding task to thread pool for async execution
333 333 encoding_executor = ThreadPoolExecutor(max_workers=1)
  334 + def _encode_query_vector() -> Optional[np.ndarray]:
  335 + arr = self.text_encoder.encode([query_text])
  336 + if arr is None or len(arr) == 0:
  337 + return None
  338 + vec = arr[0]
  339 + return vec if isinstance(vec, np.ndarray) else None
334 340 embedding_future = encoding_executor.submit(
335   - lambda: self.text_encoder.encode([query_text])[0]
  341 + _encode_query_vector
336 342 )
337 343 except Exception as e:
338 344 error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}"
... ... @@ -370,9 +376,12 @@ class QueryParser:
370 376 context.store_intermediate_result(f'translation_{lang}', result)
371 377 elif task_type == 'embedding':
372 378 query_vector = result
373   - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}")
374   - if context:
375   - context.store_intermediate_result('query_vector_shape', query_vector.shape)
  379 + if query_vector is not None:
  380 + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}")
  381 + if context:
  382 + context.store_intermediate_result('query_vector_shape', query_vector.shape)
  383 + else:
  384 + log_info("查询向量生成完成但结果为空(None),将按无向量处理")
376 385 except Exception as e:
377 386 if task_type == 'translation':
378 387 error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}"
... ...