Commit b2dff38f61d84472225b8c3796192fc794ac4c09

Authored by tangwang
1 parent dfb45131

embedding-image接口(POST /embed/image)支持SVG格式:先转 PNG 再走 CN-CLIP

Showing 2 changed files with 86 additions and 13 deletions   Show diff stats
embeddings/server.py
@@ -15,9 +15,11 @@ import time @@ -15,9 +15,11 @@ import time
15 import uuid 15 import uuid
16 from collections import deque 16 from collections import deque
17 from dataclasses import dataclass 17 from dataclasses import dataclass
  18 +import tempfile
18 from typing import Any, Dict, List, Optional 19 from typing import Any, Dict, List, Optional
19 20
20 import numpy as np 21 import numpy as np
  22 +import requests
21 from fastapi import FastAPI, HTTPException, Request, Response 23 from fastapi import FastAPI, HTTPException, Request, Response
22 from fastapi.concurrency import run_in_threadpool 24 from fastapi.concurrency import run_in_threadpool
23 25
@@ -847,19 +849,36 @@ def _embed_image_lane_impl( @@ -847,19 +849,36 @@ def _embed_image_lane_impl(
847 ) 849 )
848 850
849 backend_t0 = time.perf_counter() 851 backend_t0 = time.perf_counter()
850 - with _image_encode_lock:  
851 - if lane == "image":  
852 - vectors = _image_model.encode_image_urls(  
853 - missing_items,  
854 - batch_size=CONFIG.IMAGE_BATCH_SIZE,  
855 - normalize_embeddings=effective_normalize,  
856 - )  
857 - else:  
858 - vectors = _image_model.encode_clip_texts(  
859 - missing_items,  
860 - batch_size=CONFIG.IMAGE_BATCH_SIZE,  
861 - normalize_embeddings=effective_normalize,  
862 - ) 852 + tmp_png_paths: List[str] = []
  853 + encode_inputs: List[str] = list(missing_items)
  854 + if lane == "image":
  855 + # Best-effort: rasterize SVGs into temporary PNGs so CN-CLIP can encode them.
  856 + for i, item in enumerate(missing_items):
  857 + if _looks_like_svg_image_ref(item):
  858 + png_path = _rasterize_svg_to_temp_png(item)
  859 + tmp_png_paths.append(png_path)
  860 + encode_inputs[i] = png_path
  861 +
  862 + try:
  863 + with _image_encode_lock:
  864 + if lane == "image":
  865 + vectors = _image_model.encode_image_urls(
  866 + encode_inputs,
  867 + batch_size=CONFIG.IMAGE_BATCH_SIZE,
  868 + normalize_embeddings=effective_normalize,
  869 + )
  870 + else:
  871 + vectors = _image_model.encode_clip_texts(
  872 + missing_items,
  873 + batch_size=CONFIG.IMAGE_BATCH_SIZE,
  874 + normalize_embeddings=effective_normalize,
  875 + )
  876 + finally:
  877 + for p in tmp_png_paths:
  878 + try:
  879 + os.remove(p)
  880 + except Exception:
  881 + pass
863 if vectors is None or len(vectors) != len(missing_items): 882 if vectors is None or len(vectors) != len(missing_items):
864 raise RuntimeError( 883 raise RuntimeError(
865 f"{lane} lane length mismatch: expected {len(missing_items)}, " 884 f"{lane} lane length mismatch: expected {len(missing_items)}, "
@@ -1284,6 +1303,57 @@ def _parse_string_inputs(raw: List[Any], *, kind: str, empty_detail: str) -> Lis @@ -1284,6 +1303,57 @@ def _parse_string_inputs(raw: List[Any], *, kind: str, empty_detail: str) -> Lis
1284 return out 1303 return out
1285 1304
1286 1305
  1306 +def _looks_like_svg_image_ref(value: str) -> bool:
  1307 + """
  1308 + CN-CLIP image embedding path expects raster images (jpg/png/webp/...) that PIL can decode.
  1309 + SVG is a vector format and currently not supported by the image embedding backend.
  1310 + """
  1311 + v = (value or "").strip().lower()
  1312 + if not v:
  1313 + return False
  1314 + # Drop query/fragment for URL suffix check.
  1315 + for sep in ("?", "#"):
  1316 + if sep in v:
  1317 + v = v.split(sep, 1)[0]
  1318 + return v.endswith(".svg") or v.startswith("data:image/svg+xml")
  1319 +
  1320 +
  1321 +def _rasterize_svg_to_temp_png(svg_ref: str, *, timeout_sec: int = 10) -> str:
  1322 + """
  1323 + Download/resolve an SVG ref (URL or local path) and rasterize it into a temporary PNG file.
  1324 +
  1325 + Returns the PNG file path (caller is responsible for cleanup).
  1326 + """
  1327 + if svg_ref.startswith(("http://", "https://")):
  1328 + resp = requests.get(svg_ref, timeout=timeout_sec)
  1329 + if resp.status_code != 200:
  1330 + raise ValueError(f"HTTP {resp.status_code} when downloading SVG")
  1331 + svg_bytes = resp.content
  1332 + else:
  1333 + with open(svg_ref, "rb") as f:
  1334 + svg_bytes = f.read()
  1335 +
  1336 + try:
  1337 + import cairosvg # type: ignore
  1338 + except Exception as exc: # pragma: no cover
  1339 + raise RuntimeError(
  1340 + "SVG rasterization requires optional dependency 'cairosvg'. "
  1341 + "Install it in the embedding-image service environment."
  1342 + ) from exc
  1343 +
  1344 + fd, out_path = tempfile.mkstemp(prefix="embed_svg_", suffix=".png")
  1345 + os.close(fd)
  1346 + try:
  1347 + cairosvg.svg2png(bytestring=svg_bytes, write_to=out_path)
  1348 + except Exception:
  1349 + try:
  1350 + os.remove(out_path)
  1351 + except Exception:
  1352 + pass
  1353 + raise
  1354 + return out_path
  1355 +
  1356 +
1287 async def _run_image_lane_embed( 1357 async def _run_image_lane_embed(
1288 *, 1358 *,
1289 route: str, 1359 route: str,
requirements_embedding_service.txt
@@ -12,6 +12,9 @@ numpy>=1.24.0 @@ -12,6 +12,9 @@ numpy>=1.24.0
12 pyyaml>=6.0 12 pyyaml>=6.0
13 redis>=5.0.0 13 redis>=5.0.0
14 14
  15 +# Optional: rasterize SVG to PNG for /embed/image
  16 +cairosvg>=2.7.0
  17 +
15 # Image backend via clip-as-service client 18 # Image backend via clip-as-service client
16 setuptools<82 19 setuptools<82
17 jina>=3.12.0 20 jina>=3.12.0