From b2e50710eba04858a7f45112120b57fc77574cca Mon Sep 17 00:00:00 2001 From: tangwang Date: Fri, 19 Dec 2025 18:05:59 +0800 Subject: [PATCH] BgeEncoder.encode(...) 返回:np.ndarray(dtype=object),每个元素 要么是 np.ndarray,要么是 None。 cache/service 任一环节返回坏 embedding(含 NaN/Inf/空/非 ndarray)都会 视为 None,并且坏 cache 会被自动删除。 --- docs/索引字段说明v2.md | 2 +- embeddings/text_encoder.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------- indexer/document_transformer.py | 4 ++++ indexer/incremental_service.py | 8 ++++++-- query/query_parser.py | 17 +++++++++++++---- 5 files changed, 99 insertions(+), 39 deletions(-) diff --git a/docs/索引字段说明v2.md b/docs/索引字段说明v2.md index 5705e45..40900ec 100644 --- a/docs/索引字段说明v2.md +++ b/docs/索引字段说明v2.md @@ -433,7 +433,7 @@ filters AND (text_recall OR embedding_recall) ### 文本召回字段 -默认同时搜索以下字段(中英文都包含): +根据查询词的语言选择对应的索引字段: - `title_zh^3.0`, `title_en^3.0` - `brief_zh^1.5`, `brief_en^1.5` - `description_zh^1.0`, `description_en^1.0` diff --git a/embeddings/text_encoder.py b/embeddings/text_encoder.py index a27c287..8d6d3ac 100644 --- a/embeddings/text_encoder.py +++ b/embeddings/text_encoder.py @@ -101,24 +101,17 @@ class BgeEncoder: batch_size: Batch size for processing (used for service requests) Returns: - numpy array of shape (n, 1024) containing embeddings + numpy array of dtype=object, where each element is either: + - np.ndarray (valid embedding vector) or + - None (no embedding available for that text) """ # Convert single string to list if isinstance(sentences, str): sentences = [sentences] # Check cache first - cached_embeddings = [] - uncached_indices = [] - uncached_texts = [] - - for i, text in enumerate(sentences): - cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding - if cached is not None: - cached_embeddings.append((i, cached)) - else: - uncached_indices.append(i) - uncached_texts.append(text) + uncached_indices: List[int] = [] + uncached_texts: List[str] = [] # Prepare request data for uncached texts request_data = [] @@ -136,11 +129,16 @@ class BgeEncoder: request_data.append(request_item) # Process response - embeddings = [None] * len(sentences) - - # Fill in cached embeddings - for idx, cached_emb in cached_embeddings: - embeddings[idx] = cached_emb + # Each element can be np.ndarray or None (表示该文本没有可用的向量) + embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) + + for i, text in enumerate(sentences): + cached = self._get_cached_embedding(text, 'en') # Use 'en' as default language for title embedding + if cached is not None: + embeddings[i] = cached + else: + uncached_indices.append(i) + uncached_texts.append(text) # If there are uncached texts, call service if uncached_texts: @@ -168,25 +166,35 @@ class BgeEncoder: if embedding is not None: embedding_array = np.array(embedding, dtype=np.float32) - embeddings[original_idx] = embedding_array - # Cache the embedding - self._set_cached_embedding(text, 'en', embedding_array) + # Validate embedding from service - if invalid, treat as no result + if self._is_valid_embedding(embedding_array): + embeddings[original_idx] = embedding_array + # Cache the embedding + self._set_cached_embedding(text, 'en', embedding_array) + else: + logger.warning( + f"Invalid embedding returned from service for text {original_idx} " + f"(contains NaN/Inf or invalid shape), treating as no result. " + f"Text preview: {text[:50]}..." + ) + # 不生成兜底向量,保持为 None + embeddings[original_idx] = None else: logger.warning(f"No embedding found for text {original_idx}: {text[:50]}...") - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) + # 不生成兜底向量,保持为 None + embeddings[original_idx] = None else: logger.warning(f"No response found for text {original_idx}") - embeddings[original_idx] = np.zeros(1024, dtype=np.float32) + # 不生成兜底向量,保持为 None + embeddings[original_idx] = None except Exception as e: logger.error(f"Failed to encode texts: {e}", exc_info=True) - # Fill missing embeddings with zeros - for idx in uncached_indices: - if embeddings[idx] is None: - embeddings[idx] = np.zeros(1024, dtype=np.float32) + # 出错时不要生成兜底全零向量,保持为 None + pass - # Convert to numpy array - return np.array(embeddings, dtype=np.float32) + # 返回 numpy 数组(dtype=object),元素为 np.ndarray 或 None + return np.array(embeddings, dtype=object) def encode_batch( self, @@ -211,6 +219,27 @@ class BgeEncoder: """Generate a cache key for the query""" return f"embedding:{language}:{query}" + def _is_valid_embedding(self, embedding: np.ndarray) -> bool: + """ + Check if embedding is valid (not None, correct shape, no NaN/Inf). + + Args: + embedding: Embedding array to validate + + Returns: + True if valid, False otherwise + """ + if embedding is None: + return False + if not isinstance(embedding, np.ndarray): + return False + if embedding.size == 0: + return False + # Check for NaN or Inf values + if not np.isfinite(embedding).all(): + return False + return True + def _get_cached_embedding(self, query: str, language: str) -> Optional[np.ndarray]: """Get embedding from cache if exists (with sliding expiration)""" if not self.redis_client: @@ -220,10 +249,24 @@ class BgeEncoder: cache_key = self._get_cache_key(query, language) cached_data = self.redis_client.get(cache_key) if cached_data: - logger.debug(f"Cache hit for embedding: {query}") - # Update expiration time on access (sliding expiration) - self.redis_client.expire(cache_key, self.expire_time) - return pickle.loads(cached_data) + embedding = pickle.loads(cached_data) + # Validate cached embedding - if invalid, ignore cache and return None + if self._is_valid_embedding(embedding): + logger.debug(f"Cache hit for embedding: {query}") + # Update expiration time on access (sliding expiration) + self.redis_client.expire(cache_key, self.expire_time) + return embedding + else: + logger.warning( + f"Invalid embedding found in cache (contains NaN/Inf or invalid shape), " + f"ignoring cache for query: {query[:50]}..." + ) + # Delete invalid cache entry + try: + self.redis_client.delete(cache_key) + except Exception as e: + logger.debug(f"Failed to delete invalid cache entry: {e}") + return None return None except Exception as e: logger.error(f"Error retrieving embedding from cache: {e}") diff --git a/indexer/document_transformer.py b/indexer/document_transformer.py index 1d5a34f..1c9b6c5 100644 --- a/indexer/document_transformer.py +++ b/indexer/document_transformer.py @@ -5,6 +5,7 @@ SPU文档转换器 - 公共转换逻辑。 """ import pandas as pd +import numpy as np import logging from typing import Dict, Any, Optional, List from config import ConfigLoader @@ -594,6 +595,9 @@ class SPUDocumentTransformer: if embeddings is not None and len(embeddings) > 0: # 取第一个embedding(因为只传了一个文本) embedding = embeddings[0] + if not isinstance(embedding, np.ndarray): + logger.warning(f"Embedding is None/invalid for title: {title_text[:50]}...") + return # 转换为列表格式(ES需要) doc['title_embedding'] = embedding.tolist() logger.debug(f"Generated title_embedding for SPU: {doc.get('spu_id')}, title: {title_text[:50]}...") diff --git a/indexer/incremental_service.py b/indexer/incremental_service.py index e2a0989..1246f75 100644 --- a/indexer/incremental_service.py +++ b/indexer/incremental_service.py @@ -5,6 +5,7 @@ import logging import time import threading from typing import Dict, Any, Optional, List, Tuple +import numpy as np from sqlalchemy import text, bindparam from indexer.indexing_utils import load_category_mapping, create_document_transformer from indexer.bulk_indexer import BulkIndexer @@ -134,7 +135,9 @@ class IncrementalIndexerService: try: embeddings = encoder.encode(title_text) if embeddings is not None and len(embeddings) > 0: - doc["title_embedding"] = embeddings[0].tolist() + emb0 = embeddings[0] + if isinstance(emb0, np.ndarray): + doc["title_embedding"] = emb0.tolist() except Exception as e: logger.warning(f"Failed to generate embedding for spu_id={spu_id}: {e}") @@ -564,7 +567,8 @@ class IncrementalIndexerService: embeddings = encoder.encode_batch(title_texts, batch_size=32) for j, emb in enumerate(embeddings): doc_idx = title_doc_indices[j] - documents[doc_idx][1]["title_embedding"] = emb.tolist() + if isinstance(emb, np.ndarray): + documents[doc_idx][1]["title_embedding"] = emb.tolist() except Exception as e: logger.warning(f"[IncrementalIndexing] Batch embedding failed for tenant_id={tenant_id}: {e}", exc_info=True) diff --git a/query/query_parser.py b/query/query_parser.py index d0c6112..c71c4db 100644 --- a/query/query_parser.py +++ b/query/query_parser.py @@ -331,8 +331,14 @@ class QueryParser: log_debug("开始生成查询向量(异步)") # Submit encoding task to thread pool for async execution encoding_executor = ThreadPoolExecutor(max_workers=1) + def _encode_query_vector() -> Optional[np.ndarray]: + arr = self.text_encoder.encode([query_text]) + if arr is None or len(arr) == 0: + return None + vec = arr[0] + return vec if isinstance(vec, np.ndarray) else None embedding_future = encoding_executor.submit( - lambda: self.text_encoder.encode([query_text])[0] + _encode_query_vector ) except Exception as e: error_msg = f"查询向量生成任务提交失败 | 错误: {str(e)}" @@ -370,9 +376,12 @@ class QueryParser: context.store_intermediate_result(f'translation_{lang}', result) elif task_type == 'embedding': query_vector = result - log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") - if context: - context.store_intermediate_result('query_vector_shape', query_vector.shape) + if query_vector is not None: + log_debug(f"查询向量生成完成 | 形状: {query_vector.shape}") + if context: + context.store_intermediate_result('query_vector_shape', query_vector.shape) + else: + log_info("查询向量生成完成但结果为空(None),将按无向量处理") except Exception as e: if task_type == 'translation': error_msg = f"翻译失败 | 语言: {lang} | 错误: {str(e)}" -- libgit2 0.21.2