Commit 9de5ef4988cef11a85d55e9eed44075d74a71a49
1 parent 5c21a485
qwen3_vllm_score: task="score" using (original weights + hf_overrides) or a seq-cls model already converted on HuggingFace. generate()
Showing 7 changed files with 294 additions and 11 deletions
config/config.yaml
| ... | ... | @@ -231,7 +231,7 @@ rerank: |
| 231 | 231 | text_bias: 0.1 |
| 232 | 232 | text_exponent: 0.35 |
| 233 | 233 | knn_bias: 0.6 |
| 234 | - knn_exponent: 0.2 | |
| 234 | + knn_exponent: 0.0 | |
| 235 | 235 | |
| 236 | 236 | # Extensible service/provider registry (single source of configuration) |
| 237 | 237 | services: |
| ... | ... | @@ -381,7 +381,7 @@ services: |
| 381 | 381 | max_docs: 1000 |
| 382 | 382 | normalize: true |
| 383 | 383 | # In-service backend (read when the reranker process starts) |
| 384 | - backend: "qwen3_vllm" # bge | qwen3_vllm | qwen3_transformers | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank | |
| 384 | + backend: "qwen3_vllm_score" # bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank | |
| 385 | 385 | backends: |
| 386 | 386 | bge: |
| 387 | 387 | model_name: "BAAI/bge-reranker-v2-m3" |
| ... | ... | @@ -411,6 +411,26 @@ services: |
| 411 | 411 | # instruction: "Relevance ranking: category & style match first" |
| 412 | 412 | # instruction: "Score product relevance by query with category & style match prioritized" |
| 413 | 413 | instruction: "Rank products by query with category & style match prioritized" |
| 414 | + # vLLM LLM.score() (cross-encoder scoring); shares .venv-reranker and the same model weights as qwen3_vllm (vLLM 0.17+ uses runner/convert=auto; older versions used task=score) | |
| 415 | + qwen3_vllm_score: | |
| 416 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 417 | + # Original Hub weights require true; set to false when using pre-converted seq-cls weights (e.g. tomaarsen/...-seq-cls) | |
| 418 | + use_original_qwen3_hf_overrides: true | |
| 419 | + # Optional: aligns with vLLM defaults; usually keep auto | |
| 420 | + # vllm_runner: "auto" | |
| 421 | + # vllm_convert: "auto" | |
| 422 | + # Optional: merged with the built-in overrides when use_original_qwen3_hf_overrides is true | |
| 423 | + # hf_overrides: {} | |
| 424 | + engine: "vllm" | |
| 425 | + max_model_len: 160 | |
| 426 | + tensor_parallel_size: 1 | |
| 427 | + gpu_memory_utilization: 0.20 | |
| 428 | + dtype: "float16" | |
| 429 | + enable_prefix_caching: true | |
| 430 | + enforce_eager: false | |
| 431 | + infer_batch_size: 100 | |
| 432 | + sort_by_doc_length: true | |
| 433 | + instruction: "Rank products by query with category & style match prioritized" | |
| 414 | 434 | qwen3_transformers: |
| 415 | 435 | model_name: "Qwen/Qwen3-Reranker-0.6B" |
| 416 | 436 | instruction: "rank products by given query" | ... | ... |
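For reference, the new `qwen3_vllm_score` block above roughly translates into the following vLLM constructor call inside the backend. This is a minimal sketch using the values from this config; exact keyword names depend on the vLLM version, and `hf_overrides` is only attached for the original Hub weights.

```python
# Sketch only: mirrors services.rerank.backends.qwen3_vllm_score above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen3-Reranker-0.6B",
    runner="auto",                    # vllm_runner (vLLM 0.17+ path)
    convert="auto",                   # vllm_convert
    max_model_len=160,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.20,
    dtype="float16",
    enable_prefix_caching=True,
    enforce_eager=False,
    # Only when use_original_qwen3_hf_overrides is true (original Hub weights);
    # drop this block for an already-converted seq-cls checkpoint.
    hf_overrides={
        "architectures": ["Qwen3ForSequenceClassification"],
        "classifier_from_token": ["no", "yes"],
        "is_original_qwen3_reranker": True,
    },
)
```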
reranker/backends/__init__.py
| ... | ... | @@ -43,6 +43,9 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc |
| 43 | 43 | if name == "qwen3_vllm": |
| 44 | 44 | from reranker.backends.qwen3_vllm import Qwen3VLLMRerankerBackend |
| 45 | 45 | return Qwen3VLLMRerankerBackend(config) |
| 46 | + if name == "qwen3_vllm_score": | |
| 47 | + from reranker.backends.qwen3_vllm_score import Qwen3VLLMScoreRerankerBackend | |
| 48 | + return Qwen3VLLMScoreRerankerBackend(config) | |
| 46 | 49 | if name == "qwen3_transformers": |
| 47 | 50 | from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend |
| 48 | 51 | return Qwen3TransformersRerankerBackend(config) |
| ... | ... | @@ -60,7 +63,7 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc |
| 60 | 63 | from reranker.backends.dashscope_rerank import DashScopeRerankBackend |
| 61 | 64 | return DashScopeRerankBackend(config) |
| 62 | 65 | raise ValueError( |
| 63 | - f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers, qwen3_gguf, qwen3_gguf_06b, dashscope_rerank" | |
| 66 | + f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_vllm_score, qwen3_transformers, qwen3_gguf, qwen3_gguf_06b, dashscope_rerank" | |
| 64 | 67 | ) |
| 65 | 68 | |
| 66 | 69 | ... | ... |
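A hypothetical call site for the extended factory; the config slice passed in is an assumption, but the key point from the diff is that each backend module is imported lazily, so vLLM is only loaded when one of the vLLM backends is selected.

```python
from reranker.backends import get_rerank_backend

# Hypothetical: assumes the caller hands over the per-backend config block.
backend_cfg = app_config["services"]["rerank"]["backends"]["qwen3_vllm_score"]
backend = get_rerank_backend("qwen3_vllm_score", backend_cfg)
```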
reranker/backends/qwen3_vllm.py
| ... | ... | @@ -50,11 +50,11 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str |
| 50 | 50 | return [ |
| 51 | 51 | { |
| 52 | 52 | "role": "system", |
| 53 | - "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".", | |
| 53 | + "content": instruction, | |
| 54 | 54 | }, |
| 55 | 55 | { |
| 56 | 56 | "role": "user", |
| 57 | - "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", | |
| 57 | + "content": f"<Query>: {query}\n\n<Document>: {doc}", | |
| 58 | 58 | }, |
| 59 | 59 | ] |
| 60 | 60 | ... | ... |
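The change above moves the configured instruction into the system role and drops the fixed yes/no judging prompt plus the `<Instruct>` field from the user turn. With an illustrative query and document (both made up), the formatted messages now look like this:

```python
# Illustrative values only; shows the post-change output of _format_instruction.
_format_instruction(
    "Rank products by query with category & style match prioritized",
    "red running shoes",
    "Nike Pegasus 40, lightweight road running shoe",
)
# -> [
#   {"role": "system",
#    "content": "Rank products by query with category & style match prioritized"},
#   {"role": "user",
#    "content": "<Query>: red running shoes\n\n<Document>: Nike Pegasus 40, lightweight road running shoe"},
# ]
```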
reranker/backends/qwen3_vllm_score.py
| ... | ... | @@ -0,0 +1,260 @@ |
| 1 | +""" | |
| 2 | +Qwen3-Reranker via vLLM ``task="score"`` (official pooling/score API). | |
| 3 | + | |
| 4 | +Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired ``llm.score(query_texts, doc_texts)`` | |
| 5 | +with the recommended prefix/suffix templates. Same venv and default model as ``qwen3_vllm``. | |
| 6 | + | |
| 7 | +Reference: https://docs.vllm.ai/ (Qwen3 reranker example) | |
| 8 | +https://docs.vllm.com.cn/en/latest/examples/offline_inference/qwen3_reranker.html | |
| 9 | +""" | |
| 10 | + | |
| 11 | +from __future__ import annotations | |
| 12 | + | |
| 13 | +import logging | |
| 14 | +import os | |
| 15 | +import threading | |
| 16 | +import time | |
| 17 | +from typing import Any, Dict, List, Tuple | |
| 18 | + | |
| 19 | +logger = logging.getLogger("reranker.backends.qwen3_vllm_score") | |
| 20 | + | |
| 21 | +import torch | |
| 22 | +from vllm import LLM | |
| 23 | + | |
| 24 | +from reranker.backends.qwen3_vllm import deduplicate_with_positions | |
| 25 | + | |
| 26 | +# Official vLLM Qwen3 reranker prompt layout (im_start blocks + assistant suffix). | |
| 27 | +_DEFAULT_PREFIX = ( | |
| 28 | + "<|im_start|>system\n" | |
| 29 | + "Judge whether the Document meets the requirements based on the Query and the Instruct " | |
| 30 | + 'provided. Note that the answer can only be "yes" or "no".' | |
| 31 | + "<|im_end|>\n<|im_start|>user\n" | |
| 32 | +) | |
| 33 | +_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" | |
| 34 | +_DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" | |
| 35 | +_DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" | |
| 36 | + | |
| 37 | + | |
| 38 | +class Qwen3VLLMScoreRerankerBackend: | |
| 39 | + """ | |
| 40 | + Qwen3 reranker using vLLM ``LLM(..., task="score")`` and ``llm.score(queries, documents)``. | |
| 41 | + | |
| 42 | + Config from ``services.rerank.backends.qwen3_vllm_score``. | |
| 43 | + """ | |
| 44 | + | |
| 45 | + def __init__(self, config: Dict[str, Any]) -> None: | |
| 46 | + self._config = config or {} | |
| 47 | + model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") | |
| 48 | + max_model_len = int(self._config.get("max_model_len", 2048)) | |
| 49 | + tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1)) | |
| 50 | + gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4)) | |
| 51 | + enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False)) | |
| 52 | + enforce_eager = bool(self._config.get("enforce_eager", True)) | |
| 53 | + dtype = str(self._config.get("dtype", "float16")).strip().lower() | |
| 54 | + use_hf_overrides = self._config.get("use_original_qwen3_hf_overrides") | |
| 55 | + if use_hf_overrides is None: | |
| 56 | + use_hf_overrides = True | |
| 57 | + use_hf_overrides = bool(use_hf_overrides) | |
| 58 | + | |
| 59 | + self._instruction = str( | |
| 60 | + self._config.get("instruction") | |
| 61 | + or "Given a query, score the product for relevance" | |
| 62 | + ) | |
| 63 | + self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) | |
| 64 | + self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) | |
| 65 | + self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE) | |
| 66 | + self._document_template = str( | |
| 67 | + self._config.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE | |
| 68 | + ) | |
| 69 | + | |
| 70 | + infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get( | |
| 71 | + "infer_batch_size", 64 | |
| 72 | + ) | |
| 73 | + sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") | |
| 74 | + if sort_by_doc_length is None: | |
| 75 | + sort_by_doc_length = self._config.get("sort_by_doc_length", True) | |
| 76 | + | |
| 77 | + self._infer_batch_size = int(infer_batch_size) | |
| 78 | + self._sort_by_doc_length = str(sort_by_doc_length).strip().lower() in { | |
| 79 | + "1", | |
| 80 | + "true", | |
| 81 | + "yes", | |
| 82 | + "y", | |
| 83 | + "on", | |
| 84 | + } | |
| 85 | + | |
| 86 | + if not torch.cuda.is_available(): | |
| 87 | + raise RuntimeError( | |
| 88 | + "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False" | |
| 89 | + ) | |
| 90 | + if dtype not in {"float16", "half", "auto"}: | |
| 91 | + raise ValueError( | |
| 92 | + f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto." | |
| 93 | + ) | |
| 94 | + if self._infer_batch_size <= 0: | |
| 95 | + raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}") | |
| 96 | + | |
| 97 | + runner = str(self._config.get("vllm_runner") or "auto").strip().lower() | |
| 98 | + convert = str(self._config.get("vllm_convert") or "auto").strip().lower() | |
| 99 | + if runner not in {"auto", "generate", "pooling", "draft"}: | |
| 100 | + raise ValueError(f"Invalid vllm_runner: {runner!r}") | |
| 101 | + if convert not in {"auto", "none", "embed", "classify"}: | |
| 102 | + raise ValueError(f"Invalid vllm_convert: {convert!r}") | |
| 103 | + | |
| 104 | + logger.info( | |
| 105 | + "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, " | |
| 106 | + "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", | |
| 107 | + model_name, | |
| 108 | + runner, | |
| 109 | + convert, | |
| 110 | + use_hf_overrides, | |
| 111 | + max_model_len, | |
| 112 | + tensor_parallel_size, | |
| 113 | + gpu_memory_utilization, | |
| 114 | + dtype, | |
| 115 | + enable_prefix_caching, | |
| 116 | + ) | |
| 117 | + | |
| 118 | + # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official | |
| 119 | + # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify. | |
| 120 | + llm_kwargs: Dict[str, Any] = { | |
| 121 | + "model": model_name, | |
| 122 | + "runner": runner, | |
| 123 | + "convert": convert, | |
| 124 | + "tensor_parallel_size": tensor_parallel_size, | |
| 125 | + "max_model_len": max_model_len, | |
| 126 | + "gpu_memory_utilization": gpu_memory_utilization, | |
| 127 | + "enable_prefix_caching": enable_prefix_caching, | |
| 128 | + "enforce_eager": enforce_eager, | |
| 129 | + "dtype": dtype, | |
| 130 | + } | |
| 131 | + hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {}) | |
| 132 | + if use_hf_overrides: | |
| 133 | + hf_overrides = { | |
| 134 | + **hf_overrides, | |
| 135 | + "architectures": ["Qwen3ForSequenceClassification"], | |
| 136 | + "classifier_from_token": ["no", "yes"], | |
| 137 | + "is_original_qwen3_reranker": True, | |
| 138 | + } | |
| 139 | + if hf_overrides: | |
| 140 | + llm_kwargs["hf_overrides"] = hf_overrides | |
| 141 | + | |
| 142 | + self._llm = LLM(**llm_kwargs) | |
| 143 | + # vLLM score path: single-process safety (mirrors generate backend until verified). | |
| 144 | + self._infer_lock = threading.Lock() | |
| 145 | + | |
| 146 | + self._model_name = model_name | |
| 147 | + logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name) | |
| 148 | + | |
| 149 | + def _format_pair(self, query: str, doc: str) -> Tuple[str, str]: | |
| 150 | + q_text = self._query_template.format( | |
| 151 | + prefix=self._prefix, | |
| 152 | + instruction=self._instruction, | |
| 153 | + query=query, | |
| 154 | + ) | |
| 155 | + d_text = self._document_template.format(doc=doc, suffix=self._suffix) | |
| 156 | + return q_text, d_text | |
| 157 | + | |
| 158 | + def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]: | |
| 159 | + if not pairs: | |
| 160 | + return [] | |
| 161 | + queries: List[str] = [] | |
| 162 | + documents: List[str] = [] | |
| 163 | + for q, d in pairs: | |
| 164 | + qt, dt = self._format_pair(q, d) | |
| 165 | + queries.append(qt) | |
| 166 | + documents.append(dt) | |
| 167 | + with self._infer_lock: | |
| 168 | + outputs = self._llm.score(queries, documents, use_tqdm=False) | |
| 169 | + scores: List[float] = [] | |
| 170 | + for out in outputs: | |
| 171 | + so = out.outputs | |
| 172 | + scores.append(float(so.score)) | |
| 173 | + return scores | |
| 174 | + | |
| 175 | + @staticmethod | |
| 176 | + def _estimate_doc_lengths(docs: List[str]) -> List[int]: | |
| 177 | + if not docs: | |
| 178 | + return [] | |
| 179 | + return [len(text) for text in docs] | |
| 180 | + | |
| 181 | + def score_with_meta( | |
| 182 | + self, | |
| 183 | + query: str, | |
| 184 | + docs: List[str], | |
| 185 | + normalize: bool = True, | |
| 186 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 187 | + start_ts = time.time() | |
| 188 | + total_docs = len(docs) if docs else 0 | |
| 189 | + output_scores: List[float] = [0.0] * total_docs | |
| 190 | + | |
| 191 | + query = "" if query is None else str(query).strip() | |
| 192 | + indexed: List[Tuple[int, str]] = [] | |
| 193 | + for i, doc in enumerate(docs or []): | |
| 194 | + if doc is None: | |
| 195 | + continue | |
| 196 | + text = str(doc).strip() | |
| 197 | + if not text: | |
| 198 | + continue | |
| 199 | + indexed.append((i, text)) | |
| 200 | + | |
| 201 | + if not query or not indexed: | |
| 202 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 203 | + return output_scores, { | |
| 204 | + "input_docs": total_docs, | |
| 205 | + "usable_docs": len(indexed), | |
| 206 | + "unique_docs": 0, | |
| 207 | + "dedup_ratio": 0.0, | |
| 208 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 209 | + "model": self._model_name, | |
| 210 | + "backend": "qwen3_vllm_score", | |
| 211 | + "normalize": normalize, | |
| 212 | + "infer_batch_size": self._infer_batch_size, | |
| 213 | + "inference_batches": 0, | |
| 214 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 215 | + } | |
| 216 | + | |
| 217 | + indexed_texts = [text for _, text in indexed] | |
| 218 | + unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts) | |
| 219 | + | |
| 220 | + lengths = self._estimate_doc_lengths(unique_texts) | |
| 221 | + order = list(range(len(unique_texts))) | |
| 222 | + if self._sort_by_doc_length and len(unique_texts) > 1: | |
| 223 | + order = sorted(order, key=lambda i: lengths[i]) | |
| 224 | + | |
| 225 | + unique_scores: List[float] = [0.0] * len(unique_texts) | |
| 226 | + inference_batches = 0 | |
| 227 | + for start in range(0, len(order), self._infer_batch_size): | |
| 228 | + batch_indices = order[start : start + self._infer_batch_size] | |
| 229 | + inference_batches += 1 | |
| 230 | + pairs = [(query, unique_texts[i]) for i in batch_indices] | |
| 231 | + batch_scores = self._score_batch(pairs) | |
| 232 | + if len(batch_scores) != len(batch_indices): | |
| 233 | + raise RuntimeError( | |
| 234 | + f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}" | |
| 235 | + ) | |
| 236 | + for idx, score in zip(batch_indices, batch_scores): | |
| 237 | + unique_scores[idx] = float(score) | |
| 238 | + | |
| 239 | + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | |
| 240 | + output_scores[orig_idx] = float(unique_scores[unique_idx]) | |
| 241 | + | |
| 242 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 243 | + dedup_ratio = 0.0 | |
| 244 | + if indexed: | |
| 245 | + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) | |
| 246 | + | |
| 247 | + meta = { | |
| 248 | + "input_docs": total_docs, | |
| 249 | + "usable_docs": len(indexed), | |
| 250 | + "unique_docs": len(unique_texts), | |
| 251 | + "dedup_ratio": round(dedup_ratio, 4), | |
| 252 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 253 | + "model": self._model_name, | |
| 254 | + "backend": "qwen3_vllm_score", | |
| 255 | + "normalize": normalize, | |
| 256 | + "infer_batch_size": self._infer_batch_size, | |
| 257 | + "inference_batches": inference_batches, | |
| 258 | + "sort_by_doc_length": self._sort_by_doc_length, | |
| 259 | + } | |
| 260 | + return output_scores, meta | ... | ... |
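A minimal usage sketch of the new backend; the query and docs are illustrative, and the config values are copied from config.yaml above. `score_with_meta` keeps scores index-aligned with the input docs and reports batching/dedup details in `meta`.

```python
# Illustrative only: query/docs are made up, config values mirror config.yaml.
from reranker.backends.qwen3_vllm_score import Qwen3VLLMScoreRerankerBackend

backend = Qwen3VLLMScoreRerankerBackend({
    "model_name": "Qwen/Qwen3-Reranker-0.6B",
    "max_model_len": 160,
    "gpu_memory_utilization": 0.20,
    "infer_batch_size": 100,
    "instruction": "Rank products by query with category & style match prioritized",
})
scores, meta = backend.score_with_meta(
    "red running shoes",
    ["Nike Pegasus 40, lightweight road running shoe", "Leather office chair", ""],
)
# scores[2] stays 0.0 (the empty doc is skipped); meta includes unique_docs,
# dedup_ratio, inference_batches, elapsed_ms, backend="qwen3_vllm_score".
```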
reranker/server.py
| ... | ... | @@ -7,7 +7,7 @@ Request: { "query": "...", "docs": ["doc1", "doc2", ...], "normalize": optional |
| 7 | 7 | Response: { "scores": [float], "meta": {...} } |
| 8 | 8 | |
| 9 | 9 | Backend selected via config: services.rerank.backend |
| 10 | -(bge | qwen3_vllm | qwen3_transformers | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank), env RERANK_BACKEND. | |
| 10 | +(bge | qwen3_vllm | qwen3_vllm_score | qwen3_transformers | qwen3_gguf | qwen3_gguf_06b | dashscope_rerank), env RERANK_BACKEND. | |
| 11 | 11 | """ |
| 12 | 12 | |
| 13 | 13 | import logging | ... | ... |
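For illustration, a client call matching the documented request/response shape; the host, port, and path here are assumptions — only the JSON payload and response shape come from the docstring above.

```python
import requests

# Hypothetical endpoint; only the payload/response shape is from the docstring.
resp = requests.post(
    "http://127.0.0.1:8001/rerank",
    json={
        "query": "red running shoes",
        "docs": ["Nike Pegasus 40, lightweight road running shoe", "Leather office chair"],
        "normalize": True,
    },
    timeout=30,
)
data = resp.json()   # {"scores": [...], "meta": {...}}
```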
scripts/lib/reranker_backend_env.sh
| ... | ... | @@ -38,7 +38,7 @@ reranker_backend_venv_dir() { |
| 38 | 38 | local backend="$2" |
| 39 | 39 | |
| 40 | 40 | case "${backend}" in |
| 41 | - qwen3_vllm) printf '%s/.venv-reranker\n' "${project_root}" ;; | |
| 41 | + qwen3_vllm|qwen3_vllm_score) printf '%s/.venv-reranker\n' "${project_root}" ;; | |
| 42 | 42 | qwen3_gguf) printf '%s/.venv-reranker-gguf\n' "${project_root}" ;; |
| 43 | 43 | qwen3_gguf_06b) printf '%s/.venv-reranker-gguf-06b\n' "${project_root}" ;; |
| 44 | 44 | qwen3_transformers) printf '%s/.venv-reranker-transformers\n' "${project_root}" ;; |
| ... | ... | @@ -53,7 +53,7 @@ reranker_backend_requirements_file() { |
| 53 | 53 | local backend="$2" |
| 54 | 54 | |
| 55 | 55 | case "${backend}" in |
| 56 | - qwen3_vllm) printf '%s/requirements_reranker_qwen3_vllm.txt\n' "${project_root}" ;; | |
| 56 | + qwen3_vllm|qwen3_vllm_score) printf '%s/requirements_reranker_qwen3_vllm.txt\n' "${project_root}" ;; | |
| 57 | 57 | qwen3_gguf) printf '%s/requirements_reranker_qwen3_gguf.txt\n' "${project_root}" ;; |
| 58 | 58 | qwen3_gguf_06b) printf '%s/requirements_reranker_qwen3_gguf_06b.txt\n' "${project_root}" ;; |
| 59 | 59 | qwen3_transformers) printf '%s/requirements_reranker_qwen3_transformers.txt\n' "${project_root}" ;; | ... | ... |
scripts/start_reranker.sh
| ... | ... | @@ -47,9 +47,9 @@ if [[ "${RERANK_BACKEND}" == qwen3_gguf* ]]; then |
| 47 | 47 | export HF_HUB_DISABLE_XET="${HF_HUB_DISABLE_XET:-1}" |
| 48 | 48 | fi |
| 49 | 49 | |
| 50 | -if [[ "${RERANK_BACKEND}" == "qwen3_vllm" ]]; then | |
| 50 | +if [[ "${RERANK_BACKEND}" == "qwen3_vllm" || "${RERANK_BACKEND}" == "qwen3_vllm_score" ]]; then | |
| 51 | 51 | if ! command -v nvidia-smi >/dev/null 2>&1 || ! nvidia-smi >/dev/null 2>&1; then |
| 52 | - echo "ERROR: qwen3_vllm backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2 | |
| 52 | + echo "ERROR: ${RERANK_BACKEND} backend requires NVIDIA GPU, but nvidia-smi is unavailable." >&2 | |
| 53 | 53 | exit 1 |
| 54 | 54 | fi |
| 55 | 55 | if ! "${PYTHON_BIN}" - <<'PY' |
| ... | ... | @@ -62,7 +62,7 @@ except Exception: |
| 62 | 62 | raise SystemExit(1) |
| 63 | 63 | PY |
| 64 | 64 | then |
| 65 | - echo "ERROR: qwen3_vllm backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2 | |
| 65 | + echo "ERROR: ${RERANK_BACKEND} backend requires vllm + CUDA runtime in ${RERANKER_VENV}." >&2 | |
| 66 | 66 | echo "Please run: ./scripts/setup_reranker_venv.sh ${RERANK_BACKEND} and verify CUDA is available." >&2 |
| 67 | 67 | exit 1 |
| 68 | 68 | fi | ... | ... |
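The inline heredoc probe is only partially visible in this hunk; a hedged reconstruction of its intent (vllm importable and CUDA usable inside the reranker venv) would look roughly like the sketch below. The actual script may differ.

```python
# Hedged reconstruction of the inline probe; the real heredoc may differ.
try:
    import torch
    import vllm  # noqa: F401
except Exception:
    raise SystemExit(1)
if not torch.cuda.is_available():
    raise SystemExit(1)
```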