diff --git a/docs/reference/商品数据源入ES配置规范.md b/docs/reference/商品数据源入ES配置规范.md deleted file mode 100644 index a503d64..0000000 --- a/docs/reference/商品数据源入ES配置规范.md +++ /dev/null @@ -1,221 +0,0 @@ -根据您提供的内容,我将其整理为规范的Markdown格式: - -# ES索引配置文档 - -## 1. 全局配置 - -### 1.1 文本字段相关性设定 -需要修改所有text字段相关性算法-BM25算法的默认参数: -```json -"similarity": { - "default": { - "type": "BM25", - "b": "0.0", - "k1": "0.0" - } -} -``` - -### 1.2 索引分片设定 -- `number_of_replicas`:0/1 -- `number_of_shards`:设置建议 分片数 <= ES集群的总CPU核心个数/ (副本数 + 1) - -### 1.3 索引刷新时间设定 -- `refresh_interval`:默认30S,根据客户需要进行调整 -```json -"refresh_interval": "30s" -``` - -## 2. 单个字段配置 - -| 分析方式 | 字段预处理和ES输入格式要求 | 对应ES mapping配置 | 备注 | -|---------|--------------------------|-------------------|------| -| 电商通用分析-中文 | - | ```json { "type": "text", "analyzer": "index_ansj", "search_analyzer": "query_ansj" } ``` | - | -| 文本-多语言向量化 | 调用"文本向量化"模块得到1024维向量 | ```json { "type": "dense_vector", "dims": 1024, "index": true, "similarity": "dot_product" } ``` | 1. 依赖"文本向量化"模块
2. 如果定期全量,需要对向量化结果做缓存 | -| 图片-向量化 | 调用"图片向量化"模块得到1024维向量 | ```json { "type": "nested", "properties": { "vector": { "type": "dense_vector", "dims": 1024, "similarity": "dot_product" }, "url": { "type": "text" } } } ``` | 1. 依赖"图片向量化"模块
2. 如果定期全量,需要对向量化结果做缓存 | -| 关键词 | ES输入格式:list或者单个值 | ```json {"type": "keyword"} ``` | - | -| 电商通用分析-英文 | - | ```json {"type": "text", "analyzer": "english"} ``` | - | -| 电商通用分析-阿拉伯文 | - | ```json {"type": "text", "analyzer": "arabic"} ``` | - | -| 电商通用分析-西班牙文 | - | ```json {"type": "text", "analyzer": "spanish"} ``` | - | -| 电商通用分析-俄文 | - | ```json {"type": "text", "analyzer": "russian"} ``` | - | -| 电商通用分析-日文 | - | ```json {"type": "text", "analyzer": "japanese"} ``` | - | -| 数值-整数 | - | ```json {"type": "long"} ``` | - | -| 数值-浮点型 | - | ```json {"type": "float"} ``` | - | -| 分值 | 输入是float,配置处理方式:log, pow, sigmoid等 | TODO:给代码, log | - | -| 子串 | - | 暂时不支持 | - | -| ngram匹配或前缀匹配或边缘前缀匹配 | - | 暂时不支持 | 以后根据需要再添加 | - -这样整理后,文档结构更加清晰,表格格式规范,便于阅读和理解。 - - -参考 opensearch: - -数据接口 -文本相关性字段 -向量相关性字段 -3. 模块提取 -文本向量化 -import sys -import torch -from sentence_transformers import SentenceTransformer -import time -import threading -from modelscope import snapshot_download -from transformers import AutoModel -import os -from openai import OpenAI -from config.logging_config import get_app_logger - -# Get logger for this module -logger = get_app_logger(__name__) - -class BgeEncoder: - _instance = None - _lock = threading.Lock() - - def __new__(cls, model_dir='Xorbits/bge-m3'): - with cls._lock: - if cls._instance is None: - cls._instance = super(BgeEncoder, cls).__new__(cls) - logger.info("[BgeEncoder] Creating a new instance with model directory: %s", model_dir) - cls._instance.model = SentenceTransformer(snapshot_download(model_dir)) - logger.info("[BgeEncoder] New instance has been created") - return cls._instance - - def encode(self, sentences, normalize_embeddings=True, device='cuda'): - # Move model to specified device - if device == 'gpu': - device = 'cuda' - self.model = self.model.to(device) - embeddings = self.model.encode(sentences, normalize_embeddings=normalize_embeddings, device=device, show_progress_bar=False) - return embeddings -图片向量化 -import sys -import os -import io -import requests -import torch -import numpy as np -from PIL import Image -import logging -import threading -from typing import List, Optional, Union -from config.logging_config import get_app_logger -import cn_clip.clip as clip -from cn_clip.clip import load_from_name - -# Get logger for this module -logger = get_app_logger(__name__) - -# DEFAULT_MODEL_NAME = "ViT-L-14-336" # ["ViT-B-16", "ViT-L-14", "ViT-L-14-336", "ViT-H-14", "RN50"] -DEFAULT_MODEL_NAME = "ViT-H-14" -MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" - -class CLIPImageEncoder: - """CLIP Image Encoder for generating image embeddings using cn_clip""" - - _instance = None - _lock = threading.Lock() - - def __new__(cls, model_name=DEFAULT_MODEL_NAME, device=None): - with cls._lock: - if cls._instance is None: - cls._instance = super(CLIPImageEncoder, cls).__new__(cls) - logger.info(f"[CLIPImageEncoder] Creating new instance with model: {model_name}") - cls._instance._initialize_model(model_name, device) - return cls._instance - - def _initialize_model(self, model_name, device): - """Initialize the CLIP model using cn_clip""" - try: - self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu") - self.model, self.preprocess = load_from_name(model_name, device=self.device, download_root=MODEL_DOWNLOAD_DIR) - self.model.eval() - self.model_name = model_name - logger.info(f"[CLIPImageEncoder] Model {model_name} initialized successfully on device {self.device}") - - except Exception as e: - logger.error(f"[CLIPImageEncoder] Failed to initialize model: {str(e)}") - raise - - def validate_image(self, image_data: bytes) -> Image.Image: - """Validate image data and return PIL Image if valid""" - try: - image_stream = io.BytesIO(image_data) - image = Image.open(image_stream) - image.verify() - image_stream.seek(0) - image = Image.open(image_stream) - if image.mode != 'RGB': - image = image.convert('RGB') - return image - except Exception as e: - raise ValueError(f"Invalid image data: {str(e)}") - - def download_image(self, url: str, timeout: int = 10) -> bytes: - """Download image from URL""" - try: - if url.startswith(('http://', 'https://')): - response = requests.get(url, timeout=timeout) - if response.status_code != 200: - raise ValueError(f"HTTP {response.status_code}") - return response.content - else: - # Local file path - with open(url, 'rb') as f: - return f.read() - except Exception as e: - raise ValueError(f"Failed to download image from {url}: {str(e)}") - - def preprocess_image(self, image: Image.Image, max_size: int = 1024) -> Image.Image: - """Preprocess image for CLIP model""" - # Resize if too large - if max(image.size) > max_size: - ratio = max_size / max(image.size) - new_size = tuple(int(dim * ratio) for dim in image.size) - image = image.resize(new_size, Image.Resampling.LANCZOS) - return image - - def encode_text(self, text): - """Encode text to embedding vector using cn_clip""" - text_data = clip.tokenize([text] if type(text) == str else text).to(self.device) - with torch.no_grad(): - text_features = self.model.encode_text(text_data) - text_features /= text_features.norm(dim=-1, keepdim=True) - return text_features - - def encode_image(self, image: Image.Image) -> Optional[np.ndarray]: - """Encode image to embedding vector using cn_clip""" - if not isinstance(image, Image.Image): - raise ValueError("CLIPImageEncoder.encode_image Input must be a PIL.Image") - - try: - infer_data = self.preprocess(image).unsqueeze(0).to(self.device) - with torch.no_grad(): - image_features = self.model.encode_image(infer_data) - image_features /= image_features.norm(dim=-1, keepdim=True) - return image_features.cpu().numpy().astype('float32')[0] - except Exception as e: - logger.error(f"Failed to process image. Reason: {str(e)}") - return None - - def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: - """Complete pipeline: download, validate, preprocess and encode image from URL""" - try: - # Download image - image_data = self.download_image(url) - - # Validate image - image = self.validate_image(image_data) - - # Preprocess image - image = self.preprocess_image(image) - - # Encode image - embedding = self.encode_image(image) - - return embedding - - except Exception as e: - logger.error(f"Error processing image from URL {url}: {str(e)}") - return None \ No newline at end of file diff --git a/docs/reference/阿里opensearch电商行业.md b/docs/reference/阿里opensearch电商行业.md deleted file mode 100644 index 2e54e03..0000000 --- a/docs/reference/阿里opensearch电商行业.md +++ /dev/null @@ -1,47 +0,0 @@ -https://help.aliyun.com/zh/open-search/industry-algorithm-edition/e-commerce?spm=a2c4g.11186623.help-menu-29102.d_3_2_1.5a903cfbxOsaHt&scm=20140722.H_99739._.OR_help-T_cn~zh-V_1 - - -## 定义应用结构 -示例如下: -| 字段名称 | 主键 | 字段标签 | 类型 | -|----------------|------|------------|--------------| -| title | | 商品标题 | TEXT | -| text_embedding | | 文本向量 | EMBEDDING | -| image_embedding | | 图片向量 | EMBEDDING | -| category_name | | 类目名称 | TEXT | -| image_url | | | LITERAL_ARRAY| -| description | | 商品描述 | TEXT | -| brand_name | | 品牌名称 | TEXT | -| thumbnail_url | | | LITERAL_ARRAY| -| is_onsale | | | INT | -| url | | | LITERAL | -| brand_id | | | LITERAL | -| series_id | | | LITERAL | -| sold_num | | 商品销量 | INT | -| category_id | | | INT | -| onsale_time | | 上架时间 | INT | -| price | | | DOUBLE | -| series_name | | | TEXT | -| discount_price | | DOUBLE | -| pid | ● | INT | -| sale_price | | DOUBLE | -| act_price | | DOUBLE | - - -## 定义索引结构 - -| 索引名称 | 索引标签 | 包含字段 | 分析方式 | 使用示例 | -| --- | --- | --- | --- | --- | -| default | 默认索引 | category_name, description, brand_name, title, create_by, update_by | 行业 - 电商通用分析 | query=default:“云搜索” | -| category_name | 类目名称索引 | category_name | 行业 - 电商通用分析 | query=category_name:“云搜索” | -| category_id | | category_id | 关键字 | query=category_id:“云搜索” | -| series_name | | series_name | 中文 - 通用分析 | query=series_name:“云搜索” | -| brand_name | | brand_name | 中文 - 通用分析 | query=brand_name:“云搜索” | -| id | | id | 关键字 | query=id:“云搜索” | -| title | 标题索引 | title | 行业 - 电商通用分析 | query=title:“云搜索” | -| seller_id | | seller_id | 关键字 | query=seller_id:“云搜索” | -| brand_id | | brand_id | 关键字 | query=brand_id:“云搜索” | -| series_id | | series_id | 关键字 | query=series_id:“云搜索” | - -上面的只是阿里云的opensearch的例子,我们也要有同样的一套配置,这里支持的“字分析方式” 为ES预先支持的 多种分析器,我们要支持的分析方式参考 @商品数据源入ES配置规范.md - diff --git a/docs/temporary/sku_image_src问题诊断报告.md b/docs/temporary/sku_image_src问题诊断报告.md deleted file mode 100644 index ddd3d31..0000000 --- a/docs/temporary/sku_image_src问题诊断报告.md +++ /dev/null @@ -1,117 +0,0 @@ -# SKU image_src 字段为空问题诊断报告 - -## 问题描述 - -返回结果的每条结果中,多款式字段 `skus` 下面每个 SKU 的 `image_src` 为空。 - -## 问题分析 - -### 1. ES 数据检查 - -通过查询 ES 数据,发现: -- ES 中确实有 `skus` 数据(不是空数组) -- 但是 `skus` 数组中的每个 SKU 对象**都没有 `image_src` 字段** - -示例 ES 文档: -```json -{ - "spu_id": "68238", - "skus": [ - { - "sku_id": "3568395", - "price": 329.61, - "compare_at_price": 485.65, - "sku_code": "3468269", - "stock": 57, - "weight": 0.26, - "weight_unit": "kg", - "option1_value": "", - "option2_value": "", - "option3_value": "" - // 注意:这里没有 image_src 字段 - } - ] -} -``` - -### 2. 代码逻辑检查 - -在 `indexer/document_transformer.py` 的 `_transform_sku_row` 方法中(第558-560行),原有逻辑为: - -```python -# Image src -if pd.notna(sku_row.get('image_src')): - sku_data['image_src'] = str(sku_row['image_src']) -``` - -**问题根源**: -- 只有当 MySQL 中的 `image_src` 字段**非空**时,才会将其添加到 `sku_data` 字典中 -- 如果 MySQL 中的 `image_src` 是 `NULL` 或空字符串,这个字段就**不会出现在返回的字典中** -- 导致 ES 文档中缺少 `image_src` 字段 -- API 返回时,`sku_entry.get('image_src')` 返回 `None`,前端看到的就是空值 - -### 3. MySQL 数据情况 - -根据代码逻辑推断: -- MySQL 的 `shoplazza_product_sku` 表中,`image_src` 字段可能为 `NULL` 或空字符串 -- 这导致索引时该字段没有被写入 ES - -## 解决方案 - -### 修复方案 - -修改 `indexer/document_transformer.py` 中的 `_transform_sku_row` 方法,**始终包含 `image_src` 字段**,即使值为空也设置为 `None`: - -```python -# Image src - always include this field, even if empty -# This ensures the field is present in ES documents and API responses -image_src = sku_row.get('image_src') -if pd.notna(image_src) and str(image_src).strip(): - sku_data['image_src'] = str(image_src).strip() -else: - # Set to None (will be serialized as null in JSON) instead of omitting the field - sku_data['image_src'] = None -``` - -### 修复效果 - -修复后: -1. **即使 MySQL 中 `image_src` 为 NULL 或空字符串**,ES 文档中也会包含该字段(值为 `null`) -2. API 返回时,前端可以明确知道该字段存在但值为空 -3. 符合 API 模型定义:`image_src: Optional[str] = Field(None, ...)` - -## 问题分类 - -**问题类型**:**本项目填充的问题** - -- ✅ **不是 MySQL 原始数据的问题**:MySQL 中 `image_src` 字段可能确实为 NULL,但这是正常的业务数据 -- ✅ **不是 ES 数据的问题**:ES mapping 中 `image_src` 字段定义正确 -- ❌ **是本项目填充的问题**:代码逻辑导致当 MySQL 中 `image_src` 为空时,该字段没有被写入 ES 文档 - -## 后续操作 - -1. **重新索引数据**:修复代码后,需要重新索引数据才能生效 - ```bash - # 重新索引指定租户的数据 - ./scripts/ingest.sh true - ``` - -2. **验证修复**:重新索引后,查询 ES 验证 `image_src` 字段是否已包含: - ```bash - curl -u 'saas:4hOaLaf41y2VuI8y' -X GET 'http://localhost:9200/search_products/_search?pretty' \ - -H 'Content-Type: application/json' \ - -d '{ - "size": 1, - "query": {"nested": {"path": "skus", "query": {"exists": {"field": "skus"}}}}, - "_source": ["spu_id", "skus"] - }' - ``` - -3. **可选优化**:如果业务需要,可以考虑当 SKU 的 `image_src` 为空时,使用 SPU 的主图(`image_url`)作为默认值 - -## 相关文件 - -- `indexer/document_transformer.py` - 已修复 -- `api/models.py` - `SkuResult.image_src: Optional[str]` - 模型定义正确 -- `api/result_formatter.py` - `image_src=sku_entry.get('image_src')` - 读取逻辑正确 -- `mappings/search_products.json` - `skus.image_src` mapping 定义正确 diff --git a/embeddings/README.md b/embeddings/README.md index 5b59f2a..e39b8ac 100644 --- a/embeddings/README.md +++ b/embeddings/README.md @@ -8,8 +8,10 @@ - **HTTP 客户端**:`text_encoder.py` / `image_encoder.py`(供搜索/索引模块调用) - **本地模型实现**:`bge_model.py` / `clip_model.py` +- **clip-as-service 客户端**:`clip_as_service_encoder.py`(图片向量,推荐) - **向量化服务(FastAPI)**:`server.py` - **统一配置**:`config.py` +- **接口契约**:`protocols.ImageEncoderProtocol`(图片编码统一为 `encode_image_urls(urls, batch_size)`,本地 CN-CLIP 与 clip-as-service 均实现该接口) ### 服务接口 @@ -21,6 +23,24 @@ - 入参:`["url或本地路径1", ...]` - 出参:`[[...], null, ...]`(与输入按 index 对齐,失败为 `null`) +### 图片向量:clip-as-service(推荐) + +默认使用 `third-party/clip-as-service` 的 Jina CLIP 服务生成图片向量。 + +1. **安装 clip-client**(首次使用): + ```bash + pip install -e third-party/clip-as-service/client + ``` + +2. **启动 CN-CLIP 服务**(独立 gRPC 服务,默认端口 51000,详见 `docs/CNCLIP_SERVICE说明文档.md`): + ```bash + ./scripts/start_cnclip_service.sh + ``` + +3. **配置**(`embeddings/config.py` 或环境变量): + - `USE_CLIP_AS_SERVICE=true`(默认) + - `CLIP_AS_SERVICE_SERVER=grpc://127.0.0.1:51000` + ### 启动服务 使用仓库脚本启动(默认端口 6005): @@ -35,5 +55,6 @@ - `PORT`: 服务端口(默认 6005) - `TEXT_MODEL_DIR`, `TEXT_DEVICE`, `TEXT_BATCH_SIZE` -- `IMAGE_MODEL_NAME`, `IMAGE_DEVICE` +- `USE_CLIP_AS_SERVICE`, `CLIP_AS_SERVICE_SERVER`:图片向量(clip-as-service) +- `IMAGE_MODEL_NAME`, `IMAGE_DEVICE`:本地 CN-CLIP(当 `USE_CLIP_AS_SERVICE=false` 时) diff --git a/embeddings/clip_as_service_encoder.py b/embeddings/clip_as_service_encoder.py new file mode 100644 index 0000000..b7048c5 --- /dev/null +++ b/embeddings/clip_as_service_encoder.py @@ -0,0 +1,122 @@ +""" +Image encoder using third-party clip-as-service (Jina CLIP server). + +Requires clip-as-service server to be running. The client is loaded from +third-party/clip-as-service/client so no separate pip install is needed +if that path is on sys.path or the package is installed in development mode. +""" + +import logging +import os +import sys +from typing import List, Optional + +import numpy as np + +logger = logging.getLogger(__name__) + +# Ensure third-party clip client is importable +def _ensure_clip_client_path(): + repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + client_path = os.path.join(repo_root, "third-party", "clip-as-service", "client") + if os.path.isdir(client_path) and client_path not in sys.path: + sys.path.insert(0, client_path) + + +def _normalize_image_url(url: str) -> str: + """Normalize image URL for clip-as-service (e.g. //host/path -> https://host/path).""" + if not url or not isinstance(url, str): + return "" + url = url.strip() + if url.startswith("//"): + return "https:" + url + return url + + +class ClipAsServiceImageEncoder: + """ + Image embedding encoder using clip-as-service Client. + Encodes image URLs in batch; returns 1024-dim vectors (server model must match). + """ + + def __init__( + self, + server: str = "grpc://127.0.0.1:51000", + batch_size: int = 8, + show_progress: bool = False, + ): + """ + Args: + server: clip-as-service server URI (e.g. grpc://127.0.0.1:51000 or http://127.0.0.1:51000). + batch_size: batch size for encode requests. + show_progress: whether to show progress bar when encoding. + """ + _ensure_clip_client_path() + try: + from clip_client import Client + except ImportError as e: + raise ImportError( + "clip_client not found. Add third-party/clip-as-service/client to PYTHONPATH " + "or run: pip install -e third-party/clip-as-service/client" + ) from e + + self._server = server + self._batch_size = batch_size + self._show_progress = show_progress + self._client = Client(server) + + def encode_image_urls( + self, + urls: List[str], + batch_size: Optional[int] = None, + ) -> List[Optional[np.ndarray]]: + """ + Encode a list of image URLs to vectors. + + Args: + urls: list of image URLs (http/https or //host/path). + batch_size: override instance batch_size for this call. + + Returns: + List of vectors (1024-dim float32) or None for failed items, same length as urls. + """ + if not urls: + return [] + + normalized = [_normalize_image_url(u) for u in urls] + valid_indices = [i for i, u in enumerate(normalized) if u] + if not valid_indices: + return [None] * len(urls) + + valid_urls = [normalized[i] for i in valid_indices] + bs = batch_size if batch_size is not None else self._batch_size + out: List[Optional[np.ndarray]] = [None] * len(urls) + + try: + # Client.encode(iterable of str) returns np.ndarray [N, D] for string input + arr = self._client.encode( + valid_urls, + batch_size=bs, + show_progress=self._show_progress, + ) + if arr is not None and hasattr(arr, "shape") and len(arr) == len(valid_indices): + for j, idx in enumerate(valid_indices): + row = arr[j] + if row is not None and hasattr(row, "tolist"): + out[idx] = np.asarray(row, dtype=np.float32) + else: + out[idx] = np.array(row, dtype=np.float32) + else: + logger.warning( + "clip-as-service encode returned unexpected shape/length, " + "expected %d vectors", len(valid_indices) + ) + except Exception as e: + logger.warning("clip-as-service encode failed: %s", e, exc_info=True) + + return out + + def encode_image_from_url(self, url: str) -> Optional[np.ndarray]: + """Encode a single image URL. Returns 1024-dim vector or None.""" + results = self.encode_image_urls([url], batch_size=1) + return results[0] if results else None diff --git a/embeddings/clip_model.py b/embeddings/clip_model.py index 9bb1cf8..9beb210 100644 --- a/embeddings/clip_model.py +++ b/embeddings/clip_model.py @@ -17,7 +17,7 @@ import cn_clip.clip as clip DEFAULT_MODEL_NAME = "ViT-H-14" -MODEL_DOWNLOAD_DIR = "/data/tw/uat/EsSearcher" +MODEL_DOWNLOAD_DIR = "/data/" class ClipImageModel(object): @@ -91,6 +91,23 @@ class ClipImageModel(object): image = self.preprocess_image(image) return self.encode_image(image) + def encode_image_urls( + self, + urls: List[str], + batch_size: Optional[int] = None, + ) -> List[Optional[np.ndarray]]: + """ + Encode a list of image URLs to vectors. Same interface as ClipAsServiceImageEncoder. + + Args: + urls: list of image URLs or local paths. + batch_size: batch size for internal batching (default 8). + + Returns: + List of vectors (or None for failed items), same length as urls. + """ + return self.encode_batch(urls, batch_size=batch_size or 8) + def encode_batch(self, images: List[Union[str, Image.Image]], batch_size: int = 8) -> List[Optional[np.ndarray]]: results: List[Optional[np.ndarray]] = [] for i in range(0, len(images), batch_size): diff --git a/embeddings/cloud_text_encoder.py b/embeddings/cloud_text_encoder.py index 288589d..9c8360d 100644 --- a/embeddings/cloud_text_encoder.py +++ b/embeddings/cloud_text_encoder.py @@ -35,7 +35,7 @@ class CloudTextEncoder: if not api_key: raise ValueError("DASHSCOPE_API_KEY must be set in environment or passed as parameter") - # Use Beijing region by default + # 以下是北京地域base-url,如果使用新加坡地域的模型,需要将base_url替换为:https://dashscope-intl.aliyuncs.com/compatible-mode/v1 base_url = base_url or "https://dashscope.aliyuncs.com/compatible-mode/v1" cls._instance.client = OpenAI( diff --git a/embeddings/config.py b/embeddings/config.py index 7e0bf3a..767a15a 100644 --- a/embeddings/config.py +++ b/embeddings/config.py @@ -21,7 +21,12 @@ class EmbeddingConfig(object): TEXT_DEVICE = "cuda" # "cuda" or "cpu" (model may fall back to CPU if needed) TEXT_BATCH_SIZE = 32 - # Image embeddings (CN-CLIP) + # Image embeddings + # Option A: clip-as-service (Jina CLIP server, recommended) + USE_CLIP_AS_SERVICE = os.getenv("USE_CLIP_AS_SERVICE", "true").lower() in ("1", "true", "yes") + CLIP_AS_SERVICE_SERVER = os.getenv("CLIP_AS_SERVICE_SERVER", "grpc://127.0.0.1:51000") + + # Option B: local CN-CLIP (when USE_CLIP_AS_SERVICE=false) IMAGE_MODEL_NAME = "ViT-H-14" IMAGE_DEVICE = None # type: Optional[str] # "cuda" / "cpu" / None(auto) diff --git a/embeddings/protocols.py b/embeddings/protocols.py new file mode 100644 index 0000000..c9071c6 --- /dev/null +++ b/embeddings/protocols.py @@ -0,0 +1,27 @@ +""" +Protocols for embedding backends (structural typing, no inheritance required). + +Used by the embedding service so that any backend (ClipAsServiceImageEncoder, +ClipImageModel, etc.) can be used as long as it implements the same interface. +""" + +from typing import List, Optional, Protocol + +import numpy as np + + +class ImageEncoderProtocol(Protocol): + """Contract for image encoders used by the embedding service /embed/image endpoint.""" + + def encode_image_urls( + self, + urls: List[str], + batch_size: Optional[int] = None, + ) -> List[Optional[np.ndarray]]: + """ + Encode a list of image URLs to vectors. + + Returns: + List of vectors (or None for failed items), same length as urls. + """ + ... diff --git a/embeddings/server.py b/embeddings/server.py index d7bf9b8..779a1b3 100644 --- a/embeddings/server.py +++ b/embeddings/server.py @@ -16,6 +16,8 @@ from fastapi import FastAPI from embeddings.config import CONFIG from embeddings.bge_model import BgeTextModel from embeddings.clip_model import ClipImageModel +from embeddings.clip_as_service_encoder import ClipAsServiceImageEncoder +from embeddings.protocols import ImageEncoderProtocol logger = logging.getLogger(__name__) @@ -23,9 +25,9 @@ app = FastAPI(title="saas-search Embedding Service", version="1.0.0") # Models are loaded at startup, not lazily _text_model: Optional[BgeTextModel] = None -_image_model: Optional[ClipImageModel] = None +_image_model: Optional[ImageEncoderProtocol] = None open_text_model = True -open_image_model = False +open_image_model = True # Enable image embedding when using clip-as-service _text_encode_lock = threading.Lock() _image_encode_lock = threading.Lock() @@ -49,15 +51,23 @@ def load_models(): raise - # Load image model + # Load image model: clip-as-service (recommended) or local CN-CLIP if open_image_model: try: - logger.info(f"Loading image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})") - _image_model = ClipImageModel( - model_name=CONFIG.IMAGE_MODEL_NAME, - device=CONFIG.IMAGE_DEVICE, - ) - logger.info("Image model loaded successfully") + if CONFIG.USE_CLIP_AS_SERVICE: + logger.info(f"Loading image encoder via clip-as-service: {CONFIG.CLIP_AS_SERVICE_SERVER}") + _image_model = ClipAsServiceImageEncoder( + server=CONFIG.CLIP_AS_SERVICE_SERVER, + batch_size=CONFIG.IMAGE_BATCH_SIZE, + ) + logger.info("Image model (clip-as-service) loaded successfully") + else: + logger.info(f"Loading local image model: {CONFIG.IMAGE_MODEL_NAME} (device: {CONFIG.IMAGE_DEVICE})") + _image_model = ClipImageModel( + model_name=CONFIG.IMAGE_MODEL_NAME, + device=CONFIG.IMAGE_DEVICE, + ) + logger.info("Image model (local CN-CLIP) loaded successfully") except Exception as e: logger.error(f"Failed to load image model: {e}", exc_info=True) raise @@ -125,20 +135,31 @@ def embed_image(images: List[str]) -> List[Optional[List[float]]]: raise RuntimeError("Image model not loaded") out: List[Optional[List[float]]] = [None] * len(images) + # Normalize inputs + urls = [] + indices = [] + for i, url_or_path in enumerate(images): + if url_or_path is None: + continue + if not isinstance(url_or_path, str): + url_or_path = str(url_or_path) + url_or_path = url_or_path.strip() + if url_or_path: + urls.append(url_or_path) + indices.append(i) + + if not urls: + return out + with _image_encode_lock: - for i, url_or_path in enumerate(images): - try: - if url_or_path is None: - continue - if not isinstance(url_or_path, str): - url_or_path = str(url_or_path) - url_or_path = url_or_path.strip() - if not url_or_path: - continue - emb = _image_model.encode_image_from_url(url_or_path) - out[i] = _as_list(emb) - except Exception: - out[i] = None + try: + # Both ClipAsServiceImageEncoder and ClipImageModel implement encode_image_urls(urls, batch_size) + vectors = _image_model.encode_image_urls(urls, batch_size=CONFIG.IMAGE_BATCH_SIZE) + for j, idx in enumerate(indices): + out[idx] = _as_list(vectors[j] if j < len(vectors) else None) + except Exception: + for idx in indices: + out[idx] = None return out diff --git a/requirements.txt b/requirements.txt index 19b98e3..8f69f39 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,3 +40,9 @@ click>=8.1.0 pytest>=7.4.0 pytest-asyncio>=0.21.0 httpx>=0.24.0 + +# clip-as-service client (for image embeddings via clip-as-service) +# Install with: pip install -e third-party/clip-as-service/client +# Or: pip install jina docarray +jina>=3.12.0 +docarray[common]>=0.19.0,<0.30.0 -- libgit2 0.21.2