qwen3_vllm_score.py
13.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
"""
Qwen3-Reranker via vLLM ``LLM.score()`` (pooling / cross-encoder score API).
Matches vLLM ``examples/offline_inference/qwen3_reranker.py``: paired
``llm.score(query_texts, doc_texts)`` with the recommended prefix/suffix templates.
Requires vLLM >= 0.17 (uses ``runner``/``convert`` auto, not legacy ``task="score"``).
Dedicated venv: ``.venv-reranker-score`` + ``requirements_reranker_qwen3_vllm_score.txt``
(see ``./scripts/setup_reranker_venv.sh qwen3_vllm_score``). Default ``model_name`` can match
``qwen3_vllm``; only the Python env differs for pinned high-performance vLLM.
Reference: https://docs.vllm.ai/ — Qwen3 reranker example
"""
from __future__ import annotations
import logging
import os
import threading
import time
from typing import Any, Dict, List, Tuple
logger = logging.getLogger("reranker.backends.qwen3_vllm_score")
import torch
from vllm import LLM
from reranker.backends.qwen3_vllm import deduplicate_with_positions
# Official vLLM Qwen3 reranker prompt layout (im_start blocks + assistant suffix).
# These strings are fed verbatim to the model, so every character (including newlines)
# is significant.

# System turn with the yes/no judging instruction, followed by the opening of the user turn.
_DEFAULT_PREFIX = (
    "<|im_start|>system\n"
    "Judge whether the Document meets the requirements based on the Query and the Instruct "
    'provided. Note that the answer can only be "yes" or "no".'
    "<|im_end|>\n<|im_start|>user\n"
)
# Closes the user turn and opens the assistant turn with an empty <think> block.
_DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
# Query side of a score pair; filled via str.format(prefix=..., instruction=..., query=...).
_DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"
# Document side of a score pair; filled via str.format(doc=..., suffix=...).
_DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}"
# compact: matches qwen3_vllm._format_instruction (the instruction is used as the system
# message, and the <Instruct> line is repeated inside the user turn).
# This fragment closes the system turn and opens the user turn.
_IM_USER_START = "<|im_end|>\n<|im_start|>user\n"
def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None:
    """
    Choose the vLLM ``attention_config`` for this process, or ``None`` for vLLM's default.

    vLLM 0.18 defaults to Flash-Attention paths that require compute capability >= 8
    (Ampere+). Turing / Volta (e.g. T4 sm_75) must use a non-FA backend such as TRITON_ATTN.

    Precedence:
      1. the ``RERANK_VLLM_ATTENTION_BACKEND`` environment variable (if non-empty);
      2. ``config["vllm_attention_backend"]`` (if non-empty and not ``auto``);
      3. otherwise, capability detection on the current CUDA device: ``TRITON_ATTN`` below
         sm_80, ``None`` (vLLM default) at sm_80 and above.
    An explicit value of ``auto`` (any case) also falls through to step 3.

    Args:
        config: backend config mapping; only ``vllm_attention_backend`` is read.

    Returns:
        ``{"backend": NAME}`` with an upper-cased backend name, or ``None``.
    """
    from_env = (os.getenv("RERANK_VLLM_ATTENTION_BACKEND") or "").strip()
    from_cfg = config.get("vllm_attention_backend")

    requested = from_env
    if not requested and from_cfg is not None:
        cfg_text = str(from_cfg).strip()
        if cfg_text.lower() != "auto":
            requested = cfg_text

    backend = requested.strip().upper()
    if backend and backend != "AUTO":
        logger.info("[Qwen3_VLLM_SCORE] attention_config.backend=%s (from config/env)", backend)
        return {"backend": backend}

    # No explicit choice: inspect the GPU generation of the current CUDA device.
    major, minor = torch.cuda.get_device_capability()
    if major >= 8:
        return None
    logger.info(
        "[Qwen3_VLLM_SCORE] GPU compute capability %d.%d < 8.0; using attention backend "
        "TRITON_ATTN (Flash-Attention 2 requires sm >= 80). "
        "Override with services.rerank.backends.qwen3_vllm_score.vllm_attention_backend "
        "or RERANK_VLLM_ATTENTION_BACKEND.",
        major,
        minor,
    )
    return {"backend": "TRITON_ATTN"}
class Qwen3VLLMScoreRerankerBackend:
    """
    Qwen3 reranker using vLLM ``LLM.score()`` (pooling runner) for cross-encoder scores.

    Config comes from ``services.rerank.backends.qwen3_vllm_score``. Keys read here:
    ``model_name``, ``max_model_len``, ``tensor_parallel_size``, ``gpu_memory_utilization``,
    ``enable_prefix_caching``, ``enforce_eager``, ``dtype``, ``use_original_qwen3_hf_overrides``,
    ``hf_overrides``, ``instruction``, ``instruction_format``, ``prompt_prefix``,
    ``prompt_suffix``, ``query_template``, ``document_template``, ``infer_batch_size``,
    ``sort_by_doc_length``, ``vllm_runner``, ``vllm_convert`` (plus ``vllm_attention_backend``
    via ``_resolve_vllm_attention_config``).

    The env vars ``RERANK_VLLM_INFER_BATCH_SIZE`` and ``RERANK_VLLM_SORT_BY_DOC_LENGTH``
    override the matching config keys when set to a non-empty value.
    """

    # Accepted spellings of "true" for flags given as strings (env vars / string configs).
    _TRUTHY_STRINGS = frozenset({"1", "true", "yes", "y", "on"})

    @staticmethod
    def _parse_bool(value: Any, default: bool) -> bool:
        """
        Interpret a config/env flag as a boolean.

        ``None`` means "not set" and yields *default*. Strings are matched case-insensitively
        against ``_TRUTHY_STRINGS``, so ``"false"``/``"0"``/``"off"`` are False (plain
        ``bool("false")`` would be True). Any other value goes through ``bool()``.
        """
        if value is None:
            return default
        if isinstance(value, str):
            return value.strip().lower() in Qwen3VLLMScoreRerankerBackend._TRUTHY_STRINGS
        return bool(value)

    def __init__(self, config: Dict[str, Any]) -> None:
        """
        Parse and validate *config*, then load the model into a vLLM ``LLM`` instance.

        Raises:
            ValueError: unsupported ``instruction_format``, ``dtype``, ``infer_batch_size``,
                ``vllm_runner`` or ``vllm_convert``.
            RuntimeError: CUDA is not available.
        """
        self._config = config or {}
        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
        max_model_len = int(self._config.get("max_model_len", 2048))
        tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
        gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
        # Flags go through _parse_bool so string values like "false" are honoured and an
        # explicit null falls back to the default.
        enable_prefix_caching = self._parse_bool(
            self._config.get("enable_prefix_caching"), False
        )
        enforce_eager = self._parse_bool(self._config.get("enforce_eager"), True)
        dtype = str(self._config.get("dtype") or "float16").strip().lower()
        use_hf_overrides = self._parse_bool(
            self._config.get("use_original_qwen3_hf_overrides"), True
        )

        self._instruction = str(
            self._config.get("instruction")
            or "Given a query, score the product for relevance"
        )
        _fmt = str(self._config.get("instruction_format") or "standard").strip().lower()
        if _fmt not in {"standard", "compact"}:
            raise ValueError(
                f"instruction_format must be 'standard' or 'compact', got {_fmt!r}"
            )
        self._instruction_format = _fmt
        self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
        self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
        self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE)
        self._document_template = str(
            self._config.get("document_template") or _DEFAULT_DOCUMENT_TEMPLATE
        )

        # Empty env vars are treated as unset, so the config value applies.
        infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get(
            "infer_batch_size", 64
        )
        env_sort = (os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") or "").strip()
        raw_sort = env_sort if env_sort else self._config.get("sort_by_doc_length")
        self._infer_batch_size = int(infer_batch_size)
        self._sort_by_doc_length = self._parse_bool(raw_sort, True)

        if not torch.cuda.is_available():
            raise RuntimeError(
                "qwen3_vllm_score backend requires CUDA GPU, but torch.cuda.is_available() is False"
            )
        if dtype not in {"float16", "half", "auto"}:
            raise ValueError(
                f"Unsupported dtype for qwen3_vllm_score: {dtype!r}. Use float16/half/auto."
            )
        if self._infer_batch_size <= 0:
            raise ValueError(f"infer_batch_size must be > 0, got {self._infer_batch_size}")
        runner = str(self._config.get("vllm_runner") or "auto").strip().lower()
        convert = str(self._config.get("vllm_convert") or "auto").strip().lower()
        if runner not in {"auto", "generate", "pooling", "draft"}:
            raise ValueError(f"Invalid vllm_runner: {runner!r}")
        if convert not in {"auto", "none", "embed", "classify"}:
            raise ValueError(f"Invalid vllm_convert: {convert!r}")

        logger.info(
            "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, "
            "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, "
            "instruction_format=%s)",
            model_name,
            runner,
            convert,
            use_hf_overrides,
            max_model_len,
            tensor_parallel_size,
            gpu_memory_utilization,
            dtype,
            enable_prefix_caching,
            self._instruction_format,
        )
        # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official
        # Qwen3 reranker hf_overrides, architecture becomes *ForSequenceClassification -> pooling+classify.
        llm_kwargs: Dict[str, Any] = {
            "model": model_name,
            "runner": runner,
            "convert": convert,
            "tensor_parallel_size": tensor_parallel_size,
            "max_model_len": max_model_len,
            "gpu_memory_utilization": gpu_memory_utilization,
            "enable_prefix_caching": enable_prefix_caching,
            "enforce_eager": enforce_eager,
            "dtype": dtype,
        }
        hf_overrides: Dict[str, Any] = dict(self._config.get("hf_overrides") or {})
        if use_hf_overrides:
            # The reranker-specific keys win over any user-supplied hf_overrides.
            hf_overrides = {
                **hf_overrides,
                "architectures": ["Qwen3ForSequenceClassification"],
                "classifier_from_token": ["no", "yes"],
                "is_original_qwen3_reranker": True,
            }
        if hf_overrides:
            llm_kwargs["hf_overrides"] = hf_overrides
        attn_cfg = _resolve_vllm_attention_config(self._config)
        if attn_cfg is not None:
            llm_kwargs["attention_config"] = attn_cfg
        self._llm = LLM(**llm_kwargs)
        # Serialize LLM.score calls across threads (mirrors the generate backend until
        # concurrent use of the score path is verified).
        self._infer_lock = threading.Lock()
        self._model_name = model_name
        logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name)

    def _format_pair(self, query: str, doc: str) -> Tuple[str, str]:
        """
        Build the ``(query_text, document_text)`` pair passed to ``LLM.score``.

        ``compact`` puts the instruction in the system turn and repeats it in the user turn.
        ``standard`` fills the configurable query/document templates.
        """
        if self._instruction_format == "compact":
            # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score().
            compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}"
            q_text = (
                f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n"
            )
            d_text = f"\n<Document>: {doc}{self._suffix}"
            return q_text, d_text
        q_text = self._query_template.format(
            prefix=self._prefix,
            instruction=self._instruction,
            query=query,
        )
        d_text = self._document_template.format(doc=doc, suffix=self._suffix)
        return q_text, d_text

    def _score_batch(self, pairs: List[Tuple[str, str]]) -> List[float]:
        """
        Score one batch of raw ``(query, doc)`` pairs; returns one float per pair.

        Formatted queries and documents are passed as two equal-length lists, so
        ``LLM.score`` pairs them element-wise.
        """
        if not pairs:
            return []
        formatted = [self._format_pair(q, d) for q, d in pairs]
        queries = [q_text for q_text, _ in formatted]
        documents = [d_text for _, d_text in formatted]
        with self._infer_lock:
            outputs = self._llm.score(queries, documents, use_tqdm=False)
        return [float(out.outputs.score) for out in outputs]

    @staticmethod
    def _estimate_doc_lengths(docs: List[str]) -> List[int]:
        """Cheap per-document length proxy (character count) used for length sorting."""
        return [len(text) for text in docs or []]

    def _build_meta(
        self,
        *,
        total_docs: int,
        usable_docs: int,
        unique_docs: int,
        dedup_ratio: float,
        elapsed_ms: float,
        normalize: bool,
        inference_batches: int,
    ) -> Dict[str, Any]:
        """Assemble the diagnostics dict returned alongside the scores."""
        return {
            "input_docs": total_docs,
            "usable_docs": usable_docs,
            "unique_docs": unique_docs,
            "dedup_ratio": round(dedup_ratio, 4),
            "elapsed_ms": round(elapsed_ms, 3),
            "model": self._model_name,
            "backend": "qwen3_vllm_score",
            "normalize": normalize,
            "infer_batch_size": self._infer_batch_size,
            "inference_batches": inference_batches,
            "sort_by_doc_length": self._sort_by_doc_length,
            "instruction_format": self._instruction_format,
        }

    def score_with_meta(
        self,
        query: str,
        docs: List[str],
        normalize: bool = True,
    ) -> Tuple[List[float], Dict[str, Any]]:
        """
        Score every document in *docs* against *query*.

        ``None``/blank documents, or a blank/``None`` query, get a score of 0.0. Identical
        documents (after stripping) are scored once and the score is copied back to every
        position. The returned list is aligned with *docs*.

        ``normalize`` is only echoed in the meta dict. Scores are returned exactly as
        ``LLM.score`` produced them (NOTE(review): presumably already in [0, 1] for the
        classify conversion; confirm).

        Returns:
            ``(scores, meta)``, where ``meta`` holds counts, dedup ratio, timing and config.
        """
        start_ts = time.perf_counter()  # monotonic clock for durations
        total_docs = len(docs) if docs else 0
        output_scores: List[float] = [0.0] * total_docs
        query = "" if query is None else str(query).strip()

        indexed: List[Tuple[int, str]] = []
        for i, doc in enumerate(docs or []):
            if doc is None:
                continue
            text = str(doc).strip()
            if text:
                indexed.append((i, text))

        if not query or not indexed:
            return output_scores, self._build_meta(
                total_docs=total_docs,
                usable_docs=len(indexed),
                unique_docs=0,
                dedup_ratio=0.0,
                elapsed_ms=(time.perf_counter() - start_ts) * 1000.0,
                normalize=normalize,
                inference_batches=0,
            )

        indexed_texts = [text for _, text in indexed]
        unique_texts, position_to_unique = deduplicate_with_positions(indexed_texts)
        lengths = self._estimate_doc_lengths(unique_texts)
        order = list(range(len(unique_texts)))
        if self._sort_by_doc_length and len(order) > 1:
            # Stable sort by length so each batch holds similarly sized docs.
            order.sort(key=lengths.__getitem__)

        unique_scores: List[float] = [0.0] * len(unique_texts)
        inference_batches = 0
        for start in range(0, len(order), self._infer_batch_size):
            batch_indices = order[start : start + self._infer_batch_size]
            inference_batches += 1
            pairs = [(query, unique_texts[i]) for i in batch_indices]
            batch_scores = self._score_batch(pairs)
            if len(batch_scores) != len(batch_indices):
                raise RuntimeError(
                    f"Reranker score size mismatch: expected {len(batch_indices)}, got {len(batch_scores)}"
                )
            for idx, score in zip(batch_indices, batch_scores):
                unique_scores[idx] = float(score)

        # Fan the unique scores back out to every original (non-blank) position.
        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
            output_scores[orig_idx] = float(unique_scores[unique_idx])

        # indexed is non-empty here (guarded above), so the division is safe.
        dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
        meta = self._build_meta(
            total_docs=total_docs,
            usable_docs=len(indexed),
            unique_docs=len(unique_texts),
            dedup_ratio=dedup_ratio,
            elapsed_ms=(time.perf_counter() - start_ts) * 1000.0,
            normalize=normalize,
            inference_batches=inference_batches,
        )
        return output_scores, meta