Commit b2dff38f61d84472225b8c3796192fc794ac4c09

Authored by tangwang
1 parent dfb45131

embedding-image接口(POST /embed/image)支持SVG格式:先转 PNG 再走 CN-CLIP

Showing 2 changed files with 86 additions and 13 deletions   Show diff stats
embeddings/server.py
... ... @@ -15,9 +15,11 @@ import time
15 15 import uuid
16 16 from collections import deque
17 17 from dataclasses import dataclass
  18 +import tempfile
18 19 from typing import Any, Dict, List, Optional
19 20  
20 21 import numpy as np
  22 +import requests
21 23 from fastapi import FastAPI, HTTPException, Request, Response
22 24 from fastapi.concurrency import run_in_threadpool
23 25  
... ... @@ -847,19 +849,36 @@ def _embed_image_lane_impl(
847 849 )
848 850  
849 851 backend_t0 = time.perf_counter()
850   - with _image_encode_lock:
851   - if lane == "image":
852   - vectors = _image_model.encode_image_urls(
853   - missing_items,
854   - batch_size=CONFIG.IMAGE_BATCH_SIZE,
855   - normalize_embeddings=effective_normalize,
856   - )
857   - else:
858   - vectors = _image_model.encode_clip_texts(
859   - missing_items,
860   - batch_size=CONFIG.IMAGE_BATCH_SIZE,
861   - normalize_embeddings=effective_normalize,
862   - )
  852 + tmp_png_paths: List[str] = []
  853 + encode_inputs: List[str] = list(missing_items)
  854 + if lane == "image":
  855 + # Best-effort: rasterize SVGs into temporary PNGs so CN-CLIP can encode them.
  856 + for i, item in enumerate(missing_items):
  857 + if _looks_like_svg_image_ref(item):
  858 + png_path = _rasterize_svg_to_temp_png(item)
  859 + tmp_png_paths.append(png_path)
  860 + encode_inputs[i] = png_path
  861 +
  862 + try:
  863 + with _image_encode_lock:
  864 + if lane == "image":
  865 + vectors = _image_model.encode_image_urls(
  866 + encode_inputs,
  867 + batch_size=CONFIG.IMAGE_BATCH_SIZE,
  868 + normalize_embeddings=effective_normalize,
  869 + )
  870 + else:
  871 + vectors = _image_model.encode_clip_texts(
  872 + missing_items,
  873 + batch_size=CONFIG.IMAGE_BATCH_SIZE,
  874 + normalize_embeddings=effective_normalize,
  875 + )
  876 + finally:
  877 + for p in tmp_png_paths:
  878 + try:
  879 + os.remove(p)
  880 + except Exception:
  881 + pass
863 882 if vectors is None or len(vectors) != len(missing_items):
864 883 raise RuntimeError(
865 884 f"{lane} lane length mismatch: expected {len(missing_items)}, "
... ... @@ -1284,6 +1303,57 @@ def _parse_string_inputs(raw: List[Any], *, kind: str, empty_detail: str) -> Lis
1284 1303 return out
1285 1304  
1286 1305  
  1306 +def _looks_like_svg_image_ref(value: str) -> bool:
  1307 + """
  1308 + CN-CLIP image embedding path expects raster images (jpg/png/webp/...) that PIL can decode.
  1309 + SVG is a vector format and currently not supported by the image embedding backend.
  1310 + """
  1311 + v = (value or "").strip().lower()
  1312 + if not v:
  1313 + return False
  1314 + # Drop query/fragment for URL suffix check.
  1315 + for sep in ("?", "#"):
  1316 + if sep in v:
  1317 + v = v.split(sep, 1)[0]
  1318 + return v.endswith(".svg") or v.startswith("data:image/svg+xml")
  1319 +
  1320 +
  1321 +def _rasterize_svg_to_temp_png(svg_ref: str, *, timeout_sec: int = 10) -> str:
  1322 + """
  1323 + Download/resolve an SVG ref (URL or local path) and rasterize it into a temporary PNG file.
  1324 +
  1325 + Returns the PNG file path (caller is responsible for cleanup).
  1326 + """
  1327 + if svg_ref.startswith(("http://", "https://")):
  1328 + resp = requests.get(svg_ref, timeout=timeout_sec)
  1329 + if resp.status_code != 200:
  1330 + raise ValueError(f"HTTP {resp.status_code} when downloading SVG")
  1331 + svg_bytes = resp.content
  1332 + else:
  1333 + with open(svg_ref, "rb") as f:
  1334 + svg_bytes = f.read()
  1335 +
  1336 + try:
  1337 + import cairosvg # type: ignore
  1338 + except Exception as exc: # pragma: no cover
  1339 + raise RuntimeError(
  1340 + "SVG rasterization requires optional dependency 'cairosvg'. "
  1341 + "Install it in the embedding-image service environment."
  1342 + ) from exc
  1343 +
  1344 + fd, out_path = tempfile.mkstemp(prefix="embed_svg_", suffix=".png")
  1345 + os.close(fd)
  1346 + try:
  1347 + cairosvg.svg2png(bytestring=svg_bytes, write_to=out_path)
  1348 + except Exception:
  1349 + try:
  1350 + os.remove(out_path)
  1351 + except Exception:
  1352 + pass
  1353 + raise
  1354 + return out_path
  1355 +
  1356 +
1287 1357 async def _run_image_lane_embed(
1288 1358 *,
1289 1359 route: str,
... ...
requirements_embedding_service.txt
... ... @@ -12,6 +12,9 @@ numpy>=1.24.0
12 12 pyyaml>=6.0
13 13 redis>=5.0.0
14 14  
  15 +# Optional: rasterize SVG to PNG for /embed/image
  16 +cairosvg>=2.7.0
  17 +
15 18 # Image backend via clip-as-service client
16 19 setuptools<82
17 20 jina>=3.12.0
... ...