Commit b2e50710eba04858a7f45112120b57fc77574cca
1 parent
5c2b70a2
BgeEncoder.encode(...) 返回:np.ndarray(dtype=object),每个元素 要么是 np.ndarray,要么是 None。
cache/service 任一环节返回坏 embedding(含 NaN/Inf/空/非 ndarray)都会 视为 None,并且坏 cache 会被自动删除。
Showing
5 changed files
with
99 additions
and
39 deletions
Show diff stats
docs/索引字段说明v2.md
| @@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall) | @@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall) | ||
| 433 | 433 | ||
| 434 | ### 文本召回字段 | 434 | ### 文本召回字段 |
| 435 | 435 | ||
| 436 | -默认同时搜索以下字段(中英文都包含): | 436 | +根据查询词的语言选择对应的索引字段: |
| 437 | - `title_zh^3.0`, `title_en^3.0` | 437 | - `title_zh^3.0`, `title_en^3.0` |
| 438 | - `brief_zh^1.5`, `brief_en^1.5` | 438 | - `brief_zh^1.5`, `brief_en^1.5` |
| 439 | - `description_zh^1.0`, `description_en^1.0` | 439 | - `description_zh^1.0`, `description_en^1.0` |
embeddings/text_encoder.py
| @@ -101,24 +101,17 @@ class BgeEncoder: | @@ -101,24 +101,17 @@ class BgeEncoder: | ||
| 101 | batch_size: Batch size for processing (used for service requests) | 101 | batch_size: Batch size for processing (used for service requests) |
| 102 | 102 | ||
| 103 | Returns: | 103 | Returns: |
| 104 | - numpy array of shape (n, 1024) containing embeddings | 104 | + numpy array of dtype=object, where each element is either: |
| 105 | + - np.ndarray (valid embedding vector) or | ||
| 106 | + - None (no embedding available for that text) | ||
| 105 | """ | 107 | """ |
| 106 | # Convert single string to list | 108 | # Convert single string to list |
| 107 | if isinstance(sentences, str): | 109 | if isinstance(sentences, str): |
| 108 | sentences = [sentences] | 110 | sentences = [sentences] |
| 109 | 111 | ||
| 110 | # Check cache first | 112 | # Check cache first |
| 111 | - cached_embeddings = [] | ||
| 112 | - uncached_indices = [] | ||
| 113 | - uncached_texts = [] | ||
| 114 | - | ||
| 115 | - for i, text in enumerate(sentences): | ||
| 116 | - cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding | ||
| 117 | - if cached is not None: | ||
| 118 | - cached_embeddings.append((i, cached)) | ||
| 119 | - else: | ||
| 120 | - uncached_indices.append(i) | ||
| 121 | - uncached_texts.append(text) | 113 | + uncached_indices: List[int] = [] |
| 114 | + uncached_texts: List[str] = [] | ||
| 122 | 115 | ||
| 123 | # Prepare request data for uncached texts | 116 | # Prepare request data for uncached texts |
| 124 | request_data = [] | 117 | request_data = [] |
| @@ -136,11 +129,16 @@ class BgeEncoder: | @@ -136,11 +129,16 @@ class BgeEncoder: | ||
| 136 | request_data.append(request_item) | 129 | request_data.append(request_item) |
| 137 | 130 | ||
| 138 | # Process response | 131 | # Process response |
| 139 | - embeddings = [None] * len(sentences) | ||
| 140 | - | ||
| 141 | - # Fill in cached embeddings | ||
| 142 | - for idx, cached_emb in cached_embeddings: | ||
| 143 | - embeddings[idx] = cached_emb | 132 | + # Each element can be np.ndarray or None (表示该文本没有可用的向量) |
| 133 | + embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) | ||
| 134 | + | ||
| 135 | + for i, text in enumerate(sentences): | ||
| 136 | + cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding | ||
| 137 | + if cached is not None: | ||
| 138 | + embeddings[i] = cached | ||
| 139 | + else: | ||
| 140 | + uncached_indices.append(i) | ||
| 141 | + uncached_texts.append(text) | ||
| 144 | 142 | ||
| 145 | # If there are uncached texts, call service | 143 | # If there are uncached texts, call service |
| 146 | if uncached_texts: | 144 | if uncached_texts: |
| @@ -168,25 +166,35 @@ class BgeEncoder: | @@ -168,25 +166,35 @@ class BgeEncoder: | ||
| 168 | 166 | ||
| 169 | if embedding is not None: | 167 | if embedding is not None: |
| 170 | embedding_array = np.array(embedding, dtype=np.float32) | 168 | embedding_array = np.array(embedding, dtype=np.float32) |
| 171 | - embeddings[original_idx] = embedding_array | ||
| 172 | - # Cache the embedding | ||
| 173 | - self._set_cached_embedding(text, 'en', embedding_array) | 169 | + # Validate embedding from service - if invalid, treat as no result |
| 170 | + if self._is_valid_embedding(embedding_array): | ||
| 171 | + embeddings[original_idx] = embedding_array | ||
| 172 | + # Cache the embedding | ||
| 173 | + self._set_cached_embedding(text, 'en', embedding_array) | ||
| 174 | + else: | ||
| 175 | + logger.warning( | ||
| 176 | + f"Invalid embedding returned from service for text {original_idx} " | ||
| 177 | + f"(contains NaN/Inf or invalid shape), treating as no result. " | ||
| 178 | + f"Text preview: {text[:50]}..." | ||
| 179 | + ) | ||
| 180 | + # 不生成兜底向量,保持为 None | ||
| 181 | + embeddings[original_idx] = None | ||
| 174 | else: | 182 | else: |
| 175 | logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") | 183 | logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") |
| 176 | - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) | 184 | + # 不生成兜底向量,保持为 None |
| 185 | + embeddings[original_idx] = None | ||
| 177 | else: | 186 | else: |
| 178 | logger.warning(f"No response found for text {original_idx}") | 187 | logger.warning(f"No response found for text {original_idx}") |
| 179 | - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) | 188 | + # 不生成兜底向量,保持为 None |
| 189 | + embeddings[original_idx] = None | ||
| 180 | 190 | ||
| 181 | except Exception as e: | 191 | except Exception as e: |
| 182 | logger.error(f"Failed to encode texts: {e}", exc_info=True) | 192 | logger.error(f"Failed to encode texts: {e}", exc_info=True) |
| 183 | - # Fill missing embeddings with zeros | ||
| 184 | - for idx in uncached_indices: | ||
| 185 | - if embeddings[idx] is None: | ||
| 186 | - embeddings[idx] = np.zeros(1024, dtype=np.float32) | 193 | + # 出错时不要生成兜底全零向量,保持为 None |
| 194 | + pass | ||
| 187 | 195 | ||
| 188 | - # Convert to numpy array | ||
| 189 | - return np.array(embeddings, dtype=np.float32) | 196 | + # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None |
| 197 | + return np.array(embeddings, dtype=object) | ||
| 190 | 198 | ||
| 191 | def encode_batch( | 199 | def encode_batch( |
| 192 | self, | 200 | self, |
| @@ -211,6 +219,27 @@ class BgeEncoder: | @@ -211,6 +219,27 @@ class BgeEncoder: | ||
| 211 | """Generate a cache key for the query""" | 219 | """Generate a cache key for the query""" |
| 212 | return f"embedding:{language}:{query}" | 220 | return f"embedding:{language}:{query}" |
| 213 | 221 | ||
| 222 | + def _is_valid_embedding(self, embedding: np.ndarray) -> bool: | ||
| 223 | + """ | ||
| 224 | + Check if embedding is valid (not None, correct shape, no NaN/Inf). | ||
| 225 | + | ||
| 226 | + Args: | ||
| 227 | + embedding: Embedding array to validate | ||
| 228 | + | ||
| 229 | + Returns: | ||
| 230 | + True if valid, False otherwise | ||
| 231 | + """ | ||
| 232 | + if embedding is None: | ||
| 233 | + return False | ||
| 234 | + if not isinstance(embedding, np.ndarray): | ||
| 235 | + return False | ||
| 236 | + if embedding.size == 0: | ||
| 237 | + return False | ||
| 238 | + # Check for NaN or Inf values | ||
| 239 | + if not np.isfinite(embedding).all(): | ||
| 240 | + return False | ||
| 241 | + return True | ||
| 242 | + | ||
| 214 | def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: | 243 | def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: |
| 215 | """Get embedding from cache if exists (with sliding expiration)""" | 244 | """Get embedding from cache if exists (with sliding expiration)""" |
| 216 | if not self.redis_client: | 245 | if not self.redis_client: |
| @@ -220,10 +249,24 @@ class BgeEncoder: | @@ -220,10 +249,24 @@ class BgeEncoder: | ||
| 220 | cache_key = self._get_cache_key(query, language) | 249 | cache_key = self._get_cache_key(query, language) |
| 221 | cached_data = self.redis_client.get(cache_key) | 250 | cached_data = self.redis_client.get(cache_key) |
| 222 | if cached_data: | 251 | if cached_data: |
| 223 | - logger.debug(f"Cache hit for embedding: {query}") | ||
| 224 | - # Update expiration time on access (sliding expiration) | ||
| 225 | - self.redis_client.expire(cache_key, self.expire_time) | ||
| 226 | - return pickle.loads(cached_data) | 252 | + embedding = pickle.loads(cached_data) |
| 253 | + # Validate cached embedding - if invalid, ignore cache and return None | ||
| 254 | + if self._is_valid_embedding(embedding): | ||
| 255 | + logger.debug(f"Cache hit for embedding: {query}") | ||
| 256 | + # Update expiration time on access (sliding expiration) | ||
| 257 | + self.redis_client.expire(cache_key, self.expire_time) | ||
| 258 | + return embedding | ||
| 259 | + else: | ||
| 260 | + logger.warning( | ||
| 261 | + f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " | ||
| 262 | + f"ignoring cache for query: {query[:50]}..." | ||
| 263 | + ) | ||
| 264 | + # Delete invalid cache entry | ||
| 265 | + try: | ||
| 266 | + self.redis_client.delete(cache_key) | ||
| 267 | + except Exception as e: | ||
| 268 | + logger.debug(f"Failed to delete invalid cache entry: {e}") | ||
| 269 | + return None | ||
| 227 | return None | 270 | return None |
| 228 | except Exception as e: | 271 | except Exception as e: |
| 229 | logger.error(f"Error retrieving embedding from cache: {e}") | 272 | logger.error(f"Error retrieving embedding from cache: {e}") |
indexer/document_transformer.py
| @@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。 | @@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。 | ||
| 5 | """ | 5 | """ |
| 6 | 6 | ||
| 7 | import pandas as pd | 7 | import pandas as pd |
| 8 | +import numpy as np | ||
| 8 | import logging | 9 | import logging |
| 9 | from typing import Dict, Any, Optional, List | 10 | from typing import Dict, Any, Optional, List |
| 10 | from config import ConfigLoader | 11 | from config import ConfigLoader |
| @@ -594,6 +595,9 @@ class SPUDocumentTransformer: | @@ -594,6 +595,9 @@ class SPUDocumentTransformer: | ||
| 594 | if embeddings is not None and len(embeddings) > 0: | 595 | if embeddings is not None and len(embeddings) > 0: |
| 595 | # 取第一个embedding(因为只传了一个文本) | 596 | # 取第一个embedding(因为只传了一个文本) |
| 596 | embedding = embeddings[0] | 597 | embedding = embeddings[0] |
| 598 | + if not isinstance(embedding, np.ndarray): | ||
| 599 | + logger.warning(f"Embedding is None/invalid for title: {title_text[:50]}...") | ||
| 600 | + return | ||
| 597 | # 转换为列表格式(ES需要) | 601 | # 转换为列表格式(ES需要) |
| 598 | doc['title_embedding'] = embedding.tolist() | 602 | doc['title_embedding'] = embedding.tolist() |
| 599 | logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...") | 603 | logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...") |
indexer/incremental_service.py
| @@ -5,6 +5,7 @@ import logging | @@ -5,6 +5,7 @@ import logging | ||
| 5 | import time | 5 | import time |
| 6 | import threading | 6 | import threading |
| 7 | from typing import Dict, Any, Optional, List, Tuple | 7 | from typing import Dict, Any, Optional, List, Tuple |
| 8 | +import numpy as np | ||
| 8 | from sqlalchemy import text, bindparam | 9 | from sqlalchemy import text, bindparam |
| 9 | from indexer.indexing_utils import load_category_mapping, create_document_transformer | 10 | from indexer.indexing_utils import load_category_mapping, create_document_transformer |
| 10 | from indexer.bulk_indexer import BulkIndexer | 11 | from indexer.bulk_indexer import BulkIndexer |
| @@ -134,7 +135,9 @@ class IncrementalIndexerService: | @@ -134,7 +135,9 @@ class IncrementalIndexerService: | ||
| 134 | try: | 135 | try: |
| 135 | embeddings = encoder.encode(title_text) | 136 | embeddings = encoder.encode(title_text) |
| 136 | if embeddings is not None and len(embeddings) > 0: | 137 | if embeddings is not None and len(embeddings) > 0: |
| 137 | - doc["title_embedding"] = embeddings[0].tolist() | 138 | + emb0 = embeddings[0] |
| 139 | + if isinstance(emb0, np.ndarray): | ||
| 140 | + doc["title_embedding"] = emb0.tolist() | ||
| 138 | except Exception as e: | 141 | except Exception as e: |
| 139 | logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}") | 142 | logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}") |
| 140 | 143 | ||
| @@ -564,7 +567,8 @@ class IncrementalIndexerService: | @@ -564,7 +567,8 @@ class IncrementalIndexerService: | ||
| 564 | embeddings = encoder.encode_batch(title_texts, batch_size=32) | 567 | embeddings = encoder.encode_batch(title_texts, batch_size=32) |
| 565 | for j, emb in enumerate(embeddings): | 568 | for j, emb in enumerate(embeddings): |
| 566 | doc_idx = title_doc_indices[j] | 569 | doc_idx = title_doc_indices[j] |
| 567 | - documents[doc_idx][1]["title_embedding"] = emb.tolist() | 570 | + if isinstance(emb, np.ndarray): |
| 571 | + documents[doc_idx][1]["title_embedding"] = emb.tolist() | ||
| 568 | except Exception as e: | 572 | except Exception as e: |
| 569 | logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True) | 573 | logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True) |
| 570 | 574 |
query/query_parser.py
| @@ -331,8 +331,14 @@ class QueryParser: | @@ -331,8 +331,14 @@ class QueryParser: | ||
| 331 | log_debug("开始生成查询向量(异步)") | 331 | log_debug("开始生成查询向量(异步)") |
| 332 | # Submit encoding task to thread pool for async execution | 332 | # Submit encoding task to thread pool for async execution |
| 333 | encoding_executor = ThreadPoolExecutor(max_workers=1) | 333 | encoding_executor = ThreadPoolExecutor(max_workers=1) |
| 334 | + def _encode_query_vector() -> Optional[np.ndarray]: | ||
| 335 | + arr = self.text_encoder.encode([query_text]) | ||
| 336 | + if arr is None or len(arr) == 0: | ||
| 337 | + return None | ||
| 338 | + vec = arr[0] | ||
| 339 | + return vec if isinstance(vec, np.ndarray) else None | ||
| 334 | embedding_future = encoding_executor.submit( | 340 | embedding_future = encoding_executor.submit( |
| 335 | - lambda: self.text_encoder.encode([query_text])[0] | 341 | + _encode_query_vector |
| 336 | ) | 342 | ) |
| 337 | except Exception as e: | 343 | except Exception as e: |
| 338 | error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}" | 344 | error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}" |
| @@ -370,9 +376,12 @@ class QueryParser: | @@ -370,9 +376,12 @@ class QueryParser: | ||
| 370 | context.store_intermediate_result(f'translation_{lang}', result) | 376 | context.store_intermediate_result(f'translation_{lang}', result) |
| 371 | elif task_type == 'embedding': | 377 | elif task_type == 'embedding': |
| 372 | query_vector = result | 378 | query_vector = result |
| 373 | - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") | ||
| 374 | - if context: | ||
| 375 | - context.store_intermediate_result('query_vector_shape', query_vector.shape) | 379 | + if query_vector is not None: |
| 380 | + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") | ||
| 381 | + if context: | ||
| 382 | + context.store_intermediate_result('query_vector_shape', query_vector.shape) | ||
| 383 | + else: | ||
| 384 | + log_info("查询向量生成完成但结果为空(None),将按无向量处理") | ||
| 376 | except Exception as e: | 385 | except Exception as e: |
| 377 | if task_type == 'translation': | 386 | if task_type == 'translation': |
| 378 | error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}" | 387 | error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}" |