Commit 3d588beff5d016b419549d68fa0d61eef8b5e67e
1 parent
8b74784e
embeddings
Showing 9 changed files with 26 additions and 63 deletions
Show diff stats
config/env_config.py
| ... | ... | @@ -44,6 +44,8 @@ REDIS_CONFIG = { |
| 44 | 44 | 'cache_expire_days': int(os.getenv('REDIS_CACHE_EXPIRE_DAYS', 360*2)), # 720 days (~2 years) |
| 45 | 45 | 'translation_cache_expire_days': int(os.getenv('REDIS_TRANSLATION_CACHE_EXPIRE_DAYS', 360*2)), |
| 46 | 46 | 'translation_cache_prefix': os.getenv('REDIS_TRANSLATION_CACHE_PREFIX', 'trans'), |
| 47 | + # Embedding 缓存 key 前缀,例如 "embedding" | |
| 48 | + 'embedding_cache_prefix': os.getenv('REDIS_EMBEDDING_CACHE_PREFIX', 'embedding'), | |
| 47 | 49 | } |
| 48 | 50 | |
| 49 | 51 | # DeepL API Key | ... | ... |
docs/缓存与Redis使用说明.md
| ... | ... | @@ -20,7 +20,7 @@ |
| 20 | 20 | |
| 21 | 21 | | 模块 / 场景 | Key 模板 | Value 内容示例 | 过期策略 | 备注 | |
| 22 | 22 | |------------|----------|----------------|----------|------| |
| 23 | -| 文本向量缓存(embedding) | `embedding:{language}:{norm_flag}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py` | | |
| 23 | +| 文本向量缓存(embedding) | `{EMBEDDING_CACHE_PREFIX}:{query}` | `pickle.dumps(np.ndarray)`,如 1024 维 BGE 向量 | TTL=`REDIS_CONFIG["cache_expire_days"]` 天;访问时滑动过期 | 见 `embeddings/text_encoder.py`,前缀由 `REDIS_CONFIG["embedding_cache_prefix"]` 控制 | | |
| 24 | 24 | | 翻译结果缓存(Qwen-MT 翻译) | `{cache_prefix}:{model}:{src}:{tgt}:{sha256(payload)}` | 机翻后的单条字符串 | TTL=`services.translation.cache.ttl_seconds` 秒;可配置滑动过期 | 见 `query/qwen_mt_translate.py` + `config/config.yaml` | |
| 25 | 25 | | 商品内容理解缓存(anchors / 语义属性 / tags) | `{ANCHOR_CACHE_PREFIX}:{tenant_or_global}:{target_lang}:{md5(title)}` | `json.dumps(dict)`,包含 id/title/category/tags/anchor_text 等 | TTL=`ANCHOR_CACHE_EXPIRE_DAYS` 天 | 见 `indexer/product_enrich.py` | |
| 26 | 26 | |
| ... | ... | @@ -35,16 +35,16 @@ |
| 35 | 35 | |
| 36 | 36 | ### 2.1 Key 设计 |
| 37 | 37 | |
| 38 | -- 函数:`_get_cache_key(query: str, language: str, normalize_embeddings: bool) -> str` | |
| 38 | +- 函数:`_get_cache_key(query: str, normalize_embeddings: bool) -> str` | |
| 39 | 39 | - 模板: |
| 40 | 40 | |
| 41 | 41 | ```text |
| 42 | -embedding:{language}:{norm_flag}:{query} | |
| 42 | +{EMBEDDING_CACHE_PREFIX}:{query} | |
| 43 | 43 | ``` |
| 44 | 44 | |
| 45 | 45 | - 字段说明: |
| 46 | - - `language`:当前实现中统一传入 `"generic"`; | |
| 47 | - - `norm_flag`:`"norm1"` 表示归一化向量,`"norm0"` 表示未归一化; | |
| 46 | + - `EMBEDDING_CACHE_PREFIX`:来自 `REDIS_CONFIG["embedding_cache_prefix"]`,默认值为 `"embedding"`,可通过环境变量 `REDIS_EMBEDDING_CACHE_PREFIX` 覆盖; | |
| 47 | + - 当前实现**不再区分 language 与 normalize flag**,即无论是否归一化,key 结构都相同; | |
| 48 | 48 | - `query`:原始文本(未做哈希),注意长度特别长的 query 会直接出现在 key 中。 |
| 49 | 49 | |
| 50 | 50 | ### 2.2 Value 与类型 | ... | ... |
embeddings/clip_as_service_encoder.py
| ... | ... | @@ -54,13 +54,8 @@ class ClipAsServiceImageEncoder: |
| 54 | 54 | show_progress: whether to show progress bar when encoding. |
| 55 | 55 | """ |
| 56 | 56 | _ensure_clip_client_path() |
| 57 | - try: | |
| 58 | - from clip_client import Client | |
| 59 | - except ImportError as e: | |
| 60 | - raise ImportError( | |
| 61 | - "clip_client not found. Add third-party/clip-as-service/client to PYTHONPATH " | |
| 62 | - "or run: pip install -e third-party/clip-as-service/client" | |
| 63 | - ) from e | |
| 57 | + | |
| 58 | + from clip_client import Client | |
| 64 | 59 | |
| 65 | 60 | self._server = server |
| 66 | 61 | self._batch_size = batch_size | ... | ... |
embeddings/text_encoder.py
| ... | ... | @@ -15,11 +15,7 @@ logger = logging.getLogger(__name__) |
| 15 | 15 | from config.services_config import get_embedding_base_url |
| 16 | 16 | |
| 17 | 17 | # Try to import REDIS_CONFIG, but allow import to fail |
| 18 | -try: | |
| 19 | - from config.env_config import REDIS_CONFIG | |
| 20 | -except ImportError: | |
| 21 | - REDIS_CONFIG = {} | |
| 22 | - | |
| 18 | +from config.env_config import REDIS_CONFIG | |
| 23 | 19 | |
| 24 | 20 | class TextEmbeddingEncoder: |
| 25 | 21 | """ |
| ... | ... | @@ -31,6 +27,7 @@ class TextEmbeddingEncoder: |
| 31 | 27 | self.service_url = str(resolved_url).rstrip("/") |
| 32 | 28 | self.endpoint = f"{self.service_url}/embed/text" |
| 33 | 29 | self.expire_time = timedelta(days=REDIS_CONFIG.get("cache_expire_days", 180)) |
| 30 | + self.cache_prefix = str(REDIS_CONFIG.get("embedding_cache_prefix", "embedding")).strip() or "embedding" | |
| 34 | 31 | logger.info("Creating TextEmbeddingEncoder instance with service URL: %s", self.service_url) |
| 35 | 32 | |
| 36 | 33 | try: |
| ... | ... | @@ -104,7 +101,7 @@ class TextEmbeddingEncoder: |
| 104 | 101 | embeddings: List[Optional[np.ndarray]] = [None] * len(sentences) |
| 105 | 102 | |
| 106 | 103 | for i, text in enumerate(sentences): |
| 107 | - cached = self._get_cached_embedding(text, "generic", normalize_embeddings) | |
| 104 | + cached = self._get_cached_embedding(text) | |
| 108 | 105 | if cached is not None: |
| 109 | 106 | embeddings[i] = cached |
| 110 | 107 | else: |
| ... | ... | @@ -130,7 +127,7 @@ class TextEmbeddingEncoder: |
| 130 | 127 | embedding_array = np.array(embedding, dtype=np.float32) |
| 131 | 128 | if self._is_valid_embedding(embedding_array): |
| 132 | 129 | embeddings[original_idx] = embedding_array |
| 133 | - self._set_cached_embedding(text, "generic", embedding_array, normalize_embeddings) | |
| 130 | + self._set_cached_embedding(text, embedding_array, normalize_embeddings) | |
| 134 | 131 | else: |
| 135 | 132 | raise ValueError( |
| 136 | 133 | f"Invalid embedding returned from service for text index {original_idx}" |
| ... | ... | @@ -165,12 +162,7 @@ class TextEmbeddingEncoder: |
| 165 | 162 | device=device, |
| 166 | 163 | normalize_embeddings=normalize_embeddings, |
| 167 | 164 | ) |
| 168 | - | |
| 169 | - def _get_cache_key(self, query: str, language: str, normalize_embeddings: bool = True) -> str: | |
| 170 | - """Generate a cache key for the query""" | |
| 171 | - norm_flag = "norm1" if normalize_embeddings else "norm0" | |
| 172 | - return f"embedding:{language}:{norm_flag}:{query}" | |
| 173 | - | |
| 165 | + | |
| 174 | 166 | def _is_valid_embedding(self, embedding: np.ndarray) -> bool: |
| 175 | 167 | """ |
| 176 | 168 | Check if embedding is valid (not None, correct shape, no NaN/Inf). |
| ... | ... | @@ -194,16 +186,14 @@ class TextEmbeddingEncoder: |
| 194 | 186 | |
| 195 | 187 | def _get_cached_embedding( |
| 196 | 188 | self, |
| 197 | - query: str, | |
| 198 | - language: str, | |
| 199 | - normalize_embeddings: bool = True, | |
| 189 | + query: str | |
| 200 | 190 | ) -> Optional[np.ndarray]: |
| 201 | 191 | """Get embedding from cache if exists (with sliding expiration)""" |
| 202 | 192 | if not self.redis_client: |
| 203 | 193 | return None |
| 204 | 194 | |
| 205 | 195 | try: |
| 206 | - cache_key = self._get_cache_key(query, language, normalize_embeddings) | |
| 196 | + cache_key = f"{self.cache_prefix}:{query}" | |
| 207 | 197 | cached_data = self.redis_client.get(cache_key) |
| 208 | 198 | if cached_data: |
| 209 | 199 | embedding = pickle.loads(cached_data) |
| ... | ... | @@ -232,7 +222,6 @@ class TextEmbeddingEncoder: |
| 232 | 222 | def _set_cached_embedding( |
| 233 | 223 | self, |
| 234 | 224 | query: str, |
| 235 | - language: str, | |
| 236 | 225 | embedding: np.ndarray, |
| 237 | 226 | normalize_embeddings: bool = True, |
| 238 | 227 | ) -> bool: |
| ... | ... | @@ -241,7 +230,7 @@ class TextEmbeddingEncoder: |
| 241 | 230 | return False |
| 242 | 231 | |
| 243 | 232 | try: |
| 244 | - cache_key = self._get_cache_key(query, language, normalize_embeddings) | |
| 233 | + cache_key = f"{self.cache_prefix}:{query}" | |
| 245 | 234 | serialized_data = pickle.dumps(embedding) |
| 246 | 235 | self.redis_client.setex( |
| 247 | 236 | cache_key, | ... | ... |
indexer/document_transformer.py
| ... | ... | @@ -18,13 +18,7 @@ from indexer.product_enrich import analyze_products |
| 18 | 18 | |
| 19 | 19 | logger = logging.getLogger(__name__) |
| 20 | 20 | |
| 21 | -# Try to import translator (optional dependency) | |
| 22 | -try: | |
| 23 | - from query.qwen_mt_translate import Translator | |
| 24 | - TRANSLATOR_AVAILABLE = True | |
| 25 | -except ImportError: | |
| 26 | - TRANSLATOR_AVAILABLE = False | |
| 27 | - Translator = None | |
| 21 | +from query.qwen_mt_translate import Translator | |
| 28 | 22 | |
| 29 | 23 | |
| 30 | 24 | class SPUDocumentTransformer: | ... | ... |
reranker/backends/qwen3_transformers.py
| ... | ... | @@ -13,15 +13,8 @@ from typing import Any, Dict, List, Optional, Tuple |
| 13 | 13 | |
| 14 | 14 | logger = logging.getLogger("reranker.backends.qwen3_transformers") |
| 15 | 15 | |
| 16 | -try: | |
| 17 | - import torch | |
| 18 | - from transformers import AutoModelForCausalLM, AutoTokenizer | |
| 19 | -except ImportError as e: | |
| 20 | - raise ImportError( | |
| 21 | - "Qwen3-Transformers reranker backend requires transformers>=4.51.0 and torch. " | |
| 22 | - "Install with: pip install transformers>=4.51.0 torch" | |
| 23 | - ) from e | |
| 24 | - | |
| 16 | +import torch | |
| 17 | +from transformers import AutoModelForCausalLM, AutoTokenizer | |
| 25 | 18 | |
| 26 | 19 | def _format_instruction(instruction: str, query: str, doc: str) -> str: |
| 27 | 20 | """Format (query, doc) pair per official Qwen3-Reranker spec.""" | ... | ... |
reranker/backends/qwen3_vllm.py
| ... | ... | @@ -16,16 +16,10 @@ from typing import Any, Dict, List, Tuple |
| 16 | 16 | |
| 17 | 17 | logger = logging.getLogger("reranker.backends.qwen3_vllm") |
| 18 | 18 | |
| 19 | -try: | |
| 20 | - import torch | |
| 21 | - from transformers import AutoTokenizer | |
| 22 | - from vllm import LLM, SamplingParams | |
| 23 | - from vllm.inputs.data import TokensPrompt | |
| 24 | -except ImportError as e: | |
| 25 | - raise ImportError( | |
| 26 | - "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. " | |
| 27 | - "Install with: pip install vllm transformers" | |
| 28 | - ) from e | |
| 19 | +import torch | |
| 20 | +from transformers import AutoTokenizer | |
| 21 | +from vllm import LLM, SamplingParams | |
| 22 | +from vllm.inputs.data import TokensPrompt | |
| 29 | 23 | |
| 30 | 24 | |
| 31 | 25 | def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: | ... | ... |
tests/test_embedding_pipeline.py
| ... | ... | @@ -128,7 +128,7 @@ def test_text_embedding_encoder_raises_on_missing_vector(monkeypatch): |
| 128 | 128 | def test_text_embedding_encoder_cache_hit(monkeypatch): |
| 129 | 129 | fake_redis = _FakeRedis() |
| 130 | 130 | cached = np.array([0.9, 0.8], dtype=np.float32) |
| 131 | - fake_redis.store["embedding:generic:cached-text"] = pickle.dumps(cached) | |
| 131 | + fake_redis.store["embedding:cached-text"] = pickle.dumps(cached) | |
| 132 | 132 | monkeypatch.setattr("embeddings.text_encoder.redis.Redis", lambda **kwargs: fake_redis) |
| 133 | 133 | |
| 134 | 134 | calls = {"count": 0} | ... | ... |
utils/es_client.py
| ... | ... | @@ -8,11 +8,7 @@ from typing import Dict, Any, List, Optional |
| 8 | 8 | import os |
| 9 | 9 | import logging |
| 10 | 10 | |
| 11 | -# Try to import ES_CONFIG, but allow import to fail | |
| 12 | -try: | |
| 13 | - from config.env_config import ES_CONFIG | |
| 14 | -except ImportError: | |
| 15 | - ES_CONFIG = None | |
| 11 | +from config.env_config import ES_CONFIG | |
| 16 | 12 | |
| 17 | 13 | logger = logging.getLogger(__name__) |
| 18 | 14 | ... | ... |