Commit b2dff38f61d84472225b8c3796192fc794ac4c09
1 parent
dfb45131
embedding-image接口(POST /embed/image)支持SVG格式:先转 PNG 再走 CN-CLIP
Showing 2 changed files with 86 additions and 13 deletions
Show diff stats
embeddings/server.py
| ... | ... | @@ -15,9 +15,11 @@ import time |
| 15 | 15 | import uuid |
| 16 | 16 | from collections import deque |
| 17 | 17 | from dataclasses import dataclass |
| 18 | +import tempfile | |
| 18 | 19 | from typing import Any, Dict, List, Optional |
| 19 | 20 | |
| 20 | 21 | import numpy as np |
| 22 | +import requests | |
| 21 | 23 | from fastapi import FastAPI, HTTPException, Request, Response |
| 22 | 24 | from fastapi.concurrency import run_in_threadpool |
| 23 | 25 | |
| ... | ... | @@ -847,19 +849,36 @@ def _embed_image_lane_impl( |
| 847 | 849 | ) |
| 848 | 850 | |
| 849 | 851 | backend_t0 = time.perf_counter() |
| 850 | - with _image_encode_lock: | |
| 851 | - if lane == "image": | |
| 852 | - vectors = _image_model.encode_image_urls( | |
| 853 | - missing_items, | |
| 854 | - batch_size=CONFIG.IMAGE_BATCH_SIZE, | |
| 855 | - normalize_embeddings=effective_normalize, | |
| 856 | - ) | |
| 857 | - else: | |
| 858 | - vectors = _image_model.encode_clip_texts( | |
| 859 | - missing_items, | |
| 860 | - batch_size=CONFIG.IMAGE_BATCH_SIZE, | |
| 861 | - normalize_embeddings=effective_normalize, | |
| 862 | - ) | |
| 852 | + tmp_png_paths: List[str] = [] | |
| 853 | + encode_inputs: List[str] = list(missing_items) | |
| 854 | + if lane == "image": | |
| 855 | + # Best-effort: rasterize SVGs into temporary PNGs so CN-CLIP can encode them. | |
| 856 | + for i, item in enumerate(missing_items): | |
| 857 | + if _looks_like_svg_image_ref(item): | |
| 858 | + png_path = _rasterize_svg_to_temp_png(item) | |
| 859 | + tmp_png_paths.append(png_path) | |
| 860 | + encode_inputs[i] = png_path | |
| 861 | + | |
| 862 | + try: | |
| 863 | + with _image_encode_lock: | |
| 864 | + if lane == "image": | |
| 865 | + vectors = _image_model.encode_image_urls( | |
| 866 | + encode_inputs, | |
| 867 | + batch_size=CONFIG.IMAGE_BATCH_SIZE, | |
| 868 | + normalize_embeddings=effective_normalize, | |
| 869 | + ) | |
| 870 | + else: | |
| 871 | + vectors = _image_model.encode_clip_texts( | |
| 872 | + missing_items, | |
| 873 | + batch_size=CONFIG.IMAGE_BATCH_SIZE, | |
| 874 | + normalize_embeddings=effective_normalize, | |
| 875 | + ) | |
| 876 | + finally: | |
| 877 | + for p in tmp_png_paths: | |
| 878 | + try: | |
| 879 | + os.remove(p) | |
| 880 | + except Exception: | |
| 881 | + pass | |
| 863 | 882 | if vectors is None or len(vectors) != len(missing_items): |
| 864 | 883 | raise RuntimeError( |
| 865 | 884 | f"{lane} lane length mismatch: expected {len(missing_items)}, " |
| ... | ... | @@ -1284,6 +1303,57 @@ def _parse_string_inputs(raw: List[Any], *, kind: str, empty_detail: str) -> Lis |
| 1284 | 1303 | return out |
| 1285 | 1304 | |
| 1286 | 1305 | |
| 1306 | +def _looks_like_svg_image_ref(value: str) -> bool: | |
| 1307 | + """ | |
| 1308 | + CN-CLIP image embedding path expects raster images (jpg/png/webp/...) that PIL can decode. | |
| 1309 | + SVG is a vector format and currently not supported by the image embedding backend. | |
| 1310 | + """ | |
| 1311 | + v = (value or "").strip().lower() | |
| 1312 | + if not v: | |
| 1313 | + return False | |
| 1314 | + # Drop query/fragment for URL suffix check. | |
| 1315 | + for sep in ("?", "#"): | |
| 1316 | + if sep in v: | |
| 1317 | + v = v.split(sep, 1)[0] | |
| 1318 | + return v.endswith(".svg") or v.startswith("data:image/svg+xml") | |
| 1319 | + | |
| 1320 | + | |
| 1321 | +def _rasterize_svg_to_temp_png(svg_ref: str, *, timeout_sec: int = 10) -> str: | |
| 1322 | + """ | |
| 1323 | + Download/resolve an SVG ref (URL or local path) and rasterize it into a temporary PNG file. | |
| 1324 | + | |
| 1325 | + Returns the PNG file path (caller is responsible for cleanup). | |
| 1326 | + """ | |
| 1327 | + if svg_ref.startswith(("http://", "https://")): | |
| 1328 | + resp = requests.get(svg_ref, timeout=timeout_sec) | |
| 1329 | + if resp.status_code != 200: | |
| 1330 | + raise ValueError(f"HTTP {resp.status_code} when downloading SVG") | |
| 1331 | + svg_bytes = resp.content | |
| 1332 | + else: | |
| 1333 | + with open(svg_ref, "rb") as f: | |
| 1334 | + svg_bytes = f.read() | |
| 1335 | + | |
| 1336 | + try: | |
| 1337 | + import cairosvg # type: ignore | |
| 1338 | + except Exception as exc: # pragma: no cover | |
| 1339 | + raise RuntimeError( | |
| 1340 | + "SVG rasterization requires optional dependency 'cairosvg'. " | |
| 1341 | + "Install it in the embedding-image service environment." | |
| 1342 | + ) from exc | |
| 1343 | + | |
| 1344 | + fd, out_path = tempfile.mkstemp(prefix="embed_svg_", suffix=".png") | |
| 1345 | + os.close(fd) | |
| 1346 | + try: | |
| 1347 | + cairosvg.svg2png(bytestring=svg_bytes, write_to=out_path) | |
| 1348 | + except Exception: | |
| 1349 | + try: | |
| 1350 | + os.remove(out_path) | |
| 1351 | + except Exception: | |
| 1352 | + pass | |
| 1353 | + raise | |
| 1354 | + return out_path | |
| 1355 | + | |
| 1356 | + | |
| 1287 | 1357 | async def _run_image_lane_embed( |
| 1288 | 1358 | *, |
| 1289 | 1359 | route: str, | ... | ... |
requirements_embedding_service.txt