Commit b2e50710eba04858a7f45112120b57fc77574cca
1 parent
5c2b70a2
BgeEncoder.encode(...) 返回:np.ndarray(dtype=object),每个元素要么是 np.ndarray,要么是 None。
cache/service 任一环节返回坏 embedding(含 NaN/Inf/空/非 ndarray)都会被视为 None,并且坏 cache 会被自动删除。
Showing 5 changed files with 99 additions and 39 deletions
Show diff stats
docs/索引字段说明v2.md
| ... | ... | @@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall) |
| 433 | 433 | |
| 434 | 434 | ### 文本召回字段 |
| 435 | 435 | |
| 436 | -默认同时搜索以下字段(中英文都包含): | |
| 436 | +根据查询词的语言选择对应的索引字段: | |
| 437 | 437 | - `title_zh^3.0`, `title_en^3.0` |
| 438 | 438 | - `brief_zh^1.5`, `brief_en^1.5` |
| 439 | 439 | - `description_zh^1.0`, `description_en^1.0` | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -101,24 +101,17 @@ class BgeEncoder: |
| 101 | 101 | batch_size: Batch size for processing (used for service requests) |
| 102 | 102 | |
| 103 | 103 | Returns: |
| 104 | - numpy array of shape (n, 1024) containing embeddings | |
| 104 | + numpy array of dtype=object, where each element is either: | |
| 105 | + - np.ndarray (valid embedding vector) or | |
| 106 | + - None (no embedding available for that text) | |
| 105 | 107 | """ |
| 106 | 108 | # Convert single string to list |
| 107 | 109 | if isinstance(sentences, str): |
| 108 | 110 | sentences = [sentences] |
| 109 | 111 | |
| 110 | 112 | # Check cache first |
| 111 | - cached_embeddings = [] | |
| 112 | - uncached_indices = [] | |
| 113 | - uncached_texts = [] | |
| 114 | - | |
| 115 | - for i, text in enumerate(sentences): | |
| 116 | - cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding | |
| 117 | - if cached is not None: | |
| 118 | - cached_embeddings.append((i, cached)) | |
| 119 | - else: | |
| 120 | - uncached_indices.append(i) | |
| 121 | - uncached_texts.append(text) | |
| 113 | + uncached_indices: List[int] = [] | |
| 114 | + uncached_texts: List[str] = [] | |
| 122 | 115 | |
| 123 | 116 | # Prepare request data for uncached texts |
| 124 | 117 | request_data = [] |
| ... | ... | @@ -136,11 +129,16 @@ class BgeEncoder: |
| 136 | 129 | request_data.append(request_item) |
| 137 | 130 | |
| 138 | 131 | # Process response |
| 139 | - embeddings = [None] * len(sentences) | |
| 140 | - | |
| 141 | - # Fill in cached embeddings | |
| 142 | - for idx, cached_emb in cached_embeddings: | |
| 143 | - embeddings[idx] = cached_emb | |
| 132 | + # Each element can be np.ndarray or None (表示该文本没有可用的向量) | |
| 133 | + embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) | |
| 134 | + | |
| 135 | + for i, text in enumerate(sentences): | |
| 136 | + cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding | |
| 137 | + if cached is not None: | |
| 138 | + embeddings[i] = cached | |
| 139 | + else: | |
| 140 | + uncached_indices.append(i) | |
| 141 | + uncached_texts.append(text) | |
| 144 | 142 | |
| 145 | 143 | # If there are uncached texts, call service |
| 146 | 144 | if uncached_texts: |
| ... | ... | @@ -168,25 +166,35 @@ class BgeEncoder: |
| 168 | 166 | |
| 169 | 167 | if embedding is not None: |
| 170 | 168 | embedding_array = np.array(embedding, dtype=np.float32) |
| 171 | - embeddings[original_idx] = embedding_array | |
| 172 | - # Cache the embedding | |
| 173 | - self._set_cached_embedding(text, 'en', embedding_array) | |
| 169 | + # Validate embedding from service - if invalid, treat as no result | |
| 170 | + if self._is_valid_embedding(embedding_array): | |
| 171 | + embeddings[original_idx] = embedding_array | |
| 172 | + # Cache the embedding | |
| 173 | + self._set_cached_embedding(text, 'en', embedding_array) | |
| 174 | + else: | |
| 175 | + logger.warning( | |
| 176 | + f"Invalid embedding returned from service for text {original_idx} " | |
| 177 | + f"(contains NaN/Inf or invalid shape), treating as no result. " | |
| 178 | + f"Text preview: {text[:50]}..." | |
| 179 | + ) | |
| 180 | + # 不生成兜底向量,保持为 None | |
| 181 | + embeddings[original_idx] = None | |
| 174 | 182 | else: |
| 175 | 183 | logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") |
| 176 | - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) | |
| 184 | + # 不生成兜底向量,保持为 None | |
| 185 | + embeddings[original_idx] = None | |
| 177 | 186 | else: |
| 178 | 187 | logger.warning(f"No response found for text {original_idx}") |
| 179 | - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) | |
| 188 | + # 不生成兜底向量,保持为 None | |
| 189 | + embeddings[original_idx] = None | |
| 180 | 190 | |
| 181 | 191 | except Exception as e: |
| 182 | 192 | logger.error(f"Failed to encode texts: {e}", exc_info=True) |
| 183 | - # Fill missing embeddings with zeros | |
| 184 | - for idx in uncached_indices: | |
| 185 | - if embeddings[idx] is None: | |
| 186 | - embeddings[idx] = np.zeros(1024, dtype=np.float32) | |
| 193 | + # 出错时不要生成兜底全零向量,保持为 None | |
| 194 | + pass | |
| 187 | 195 | |
| 188 | - # Convert to numpy array | |
| 189 | - return np.array(embeddings, dtype=np.float32) | |
| 196 | + # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None | |
| 197 | + return np.array(embeddings, dtype=object) | |
| 190 | 198 | |
| 191 | 199 | def encode_batch( |
| 192 | 200 | self, |
| ... | ... | @@ -211,6 +219,27 @@ class BgeEncoder: |
| 211 | 219 | """Generate a cache key for the query""" |
| 212 | 220 | return f"embedding:{language}:{query}" |
| 213 | 221 | |
| 222 | + def _is_valid_embedding(self, embedding: np.ndarray) -> bool: | |
| 223 | + """ | |
| 224 | + Check if embedding is valid (not None, correct shape, no NaN/Inf). | |
| 225 | + | |
| 226 | + Args: | |
| 227 | + embedding: Embedding array to validate | |
| 228 | + | |
| 229 | + Returns: | |
| 230 | + True if valid, False otherwise | |
| 231 | + """ | |
| 232 | + if embedding is None: | |
| 233 | + return False | |
| 234 | + if not isinstance(embedding, np.ndarray): | |
| 235 | + return False | |
| 236 | + if embedding.size == 0: | |
| 237 | + return False | |
| 238 | + # Check for NaN or Inf values | |
| 239 | + if not np.isfinite(embedding).all(): | |
| 240 | + return False | |
| 241 | + return True | |
| 242 | + | |
| 214 | 243 | def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: |
| 215 | 244 | """Get embedding from cache if exists (with sliding expiration)""" |
| 216 | 245 | if not self.redis_client: |
| ... | ... | @@ -220,10 +249,24 @@ class BgeEncoder: |
| 220 | 249 | cache_key = self._get_cache_key(query, language) |
| 221 | 250 | cached_data = self.redis_client.get(cache_key) |
| 222 | 251 | if cached_data: |
| 223 | - logger.debug(f"Cache hit for embedding: {query}") | |
| 224 | - # Update expiration time on access (sliding expiration) | |
| 225 | - self.redis_client.expire(cache_key, self.expire_time) | |
| 226 | - return pickle.loads(cached_data) | |
| 252 | + embedding = pickle.loads(cached_data) | |
| 253 | + # Validate cached embedding - if invalid, ignore cache and return None | |
| 254 | + if self._is_valid_embedding(embedding): | |
| 255 | + logger.debug(f"Cache hit for embedding: {query}") | |
| 256 | + # Update expiration time on access (sliding expiration) | |
| 257 | + self.redis_client.expire(cache_key, self.expire_time) | |
| 258 | + return embedding | |
| 259 | + else: | |
| 260 | + logger.warning( | |
| 261 | + f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " | |
| 262 | + f"ignoring cache for query: {query[:50]}..." | |
| 263 | + ) | |
| 264 | + # Delete invalid cache entry | |
| 265 | + try: | |
| 266 | + self.redis_client.delete(cache_key) | |
| 267 | + except Exception as e: | |
| 268 | + logger.debug(f"Failed to delete invalid cache entry: {e}") | |
| 269 | + return None | |
| 227 | 270 | return None |
| 228 | 271 | except Exception as e: |
| 229 | 272 | logger.error(f"Error retrieving embedding from cache: {e}") | ... | ... |
indexer/document_transformer.py
| ... | ... | @@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。 |
| 5 | 5 | """ |
| 6 | 6 | |
| 7 | 7 | import pandas as pd |
| 8 | +import numpy as np | |
| 8 | 9 | import logging |
| 9 | 10 | from typing import Dict, Any, Optional, List |
| 10 | 11 | from config import ConfigLoader |
| ... | ... | @@ -594,6 +595,9 @@ class SPUDocumentTransformer: |
| 594 | 595 | if embeddings is not None and len(embeddings) > 0: |
| 595 | 596 | # 取第一个embedding(因为只传了一个文本) |
| 596 | 597 | embedding = embeddings[0] |
| 598 | + if not isinstance(embedding, np.ndarray): | |
| 599 | + logger.warning(f"Embedding is None/invalid for title: {title_text[:50]}...") | |
| 600 | + return | |
| 597 | 601 | # 转换为列表格式(ES需要) |
| 598 | 602 | doc['title_embedding'] = embedding.tolist() |
| 599 | 603 | logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...") | ... | ... |
indexer/incremental_service.py
| ... | ... | @@ -5,6 +5,7 @@ import logging |
| 5 | 5 | import time |
| 6 | 6 | import threading |
| 7 | 7 | from typing import Dict, Any, Optional, List, Tuple |
| 8 | +import numpy as np | |
| 8 | 9 | from sqlalchemy import text, bindparam |
| 9 | 10 | from indexer.indexing_utils import load_category_mapping, create_document_transformer |
| 10 | 11 | from indexer.bulk_indexer import BulkIndexer |
| ... | ... | @@ -134,7 +135,9 @@ class IncrementalIndexerService: |
| 134 | 135 | try: |
| 135 | 136 | embeddings = encoder.encode(title_text) |
| 136 | 137 | if embeddings is not None and len(embeddings) > 0: |
| 137 | - doc["title_embedding"] = embeddings[0].tolist() | |
| 138 | + emb0 = embeddings[0] | |
| 139 | + if isinstance(emb0, np.ndarray): | |
| 140 | + doc["title_embedding"] = emb0.tolist() | |
| 138 | 141 | except Exception as e: |
| 139 | 142 | logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}") |
| 140 | 143 | |
| ... | ... | @@ -564,7 +567,8 @@ class IncrementalIndexerService: |
| 564 | 567 | embeddings = encoder.encode_batch(title_texts, batch_size=32) |
| 565 | 568 | for j, emb in enumerate(embeddings): |
| 566 | 569 | doc_idx = title_doc_indices[j] |
| 567 | - documents[doc_idx][1]["title_embedding"] = emb.tolist() | |
| 570 | + if isinstance(emb, np.ndarray): | |
| 571 | + documents[doc_idx][1]["title_embedding"] = emb.tolist() | |
| 568 | 572 | except Exception as e: |
| 569 | 573 | logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True) |
| 570 | 574 | ... | ... |
query/query_parser.py
| ... | ... | @@ -331,8 +331,14 @@ class QueryParser: |
| 331 | 331 | log_debug("开始生成查询向量(异步)") |
| 332 | 332 | # Submit encoding task to thread pool for async execution |
| 333 | 333 | encoding_executor = ThreadPoolExecutor(max_workers=1) |
| 334 | + def _encode_query_vector() -> Optional[np.ndarray]: | |
| 335 | + arr = self.text_encoder.encode([query_text]) | |
| 336 | + if arr is None or len(arr) == 0: | |
| 337 | + return None | |
| 338 | + vec = arr[0] | |
| 339 | + return vec if isinstance(vec, np.ndarray) else None | |
| 334 | 340 | embedding_future = encoding_executor.submit( |
| 335 | - lambda: self.text_encoder.encode([query_text])[0] | |
| 341 | + _encode_query_vector | |
| 336 | 342 | ) |
| 337 | 343 | except Exception as e: |
| 338 | 344 | error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}" |
| ... | ... | @@ -370,9 +376,12 @@ class QueryParser: |
| 370 | 376 | context.store_intermediate_result(f'translation_{lang}', result) |
| 371 | 377 | elif task_type == 'embedding': |
| 372 | 378 | query_vector = result |
| 373 | - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") | |
| 374 | - if context: | |
| 375 | - context.store_intermediate_result('query_vector_shape', query_vector.shape) | |
| 379 | + if query_vector is not None: | |
| 380 | + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") | |
| 381 | + if context: | |
| 382 | + context.store_intermediate_result('query_vector_shape', query_vector.shape) | |
| 383 | + else: | |
| 384 | + log_info("查询向量生成完成但结果为空(None),将按无向量处理") | |
| 376 | 385 | except Exception as e: |
| 377 | 386 | if task_type == 'translation': |
| 378 | 387 | error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}" | ... | ... |