reranker/backends/qwen3_vllm.py

  """
  Qwen3-Reranker-0.6B backend using vLLM.
  
  Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
  Requires: vllm>=0.8.5, transformers; a CUDA GPU is required.
  """
  
  from __future__ import annotations
  
  import logging
  import math
  import threading
  import time
  from typing import Any, Dict, List, Optional, Tuple
  
  logger = logging.getLogger("reranker.backends.qwen3_vllm")
  
  try:
      import torch
      from transformers import AutoTokenizer
      from vllm import LLM, SamplingParams
      from vllm.inputs.data import TokensPrompt
  except ImportError as e:
      raise ImportError(
          "Qwen3-vLLM reranker backend requires vllm>=0.8.5 and transformers. "
          "Install with: pip install vllm transformers"
      ) from e
  
  
  def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]:
      """Build chat messages for one (query, doc) pair."""
      return [
          {
              "role": "system",
              "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".",
          },
          {
              "role": "user",
              "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}",
          },
      ]
  
  
  class Qwen3VLLMRerankerBackend:
      """
      Qwen3-Reranker-0.6B with vLLM inference.
      Configuration is read from services.rerank.backends.qwen3_vllm.
      """
  
      def __init__(self, config: Dict[str, Any]) -> None:
          self._config = config or {}
          model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
          max_model_len = int(self._config.get("max_model_len", 2048))
          tensor_parallel_size = int(self._config.get("tensor_parallel_size", 1))
          gpu_memory_utilization = float(self._config.get("gpu_memory_utilization", 0.4))
          enable_prefix_caching = bool(self._config.get("enable_prefix_caching", False))
          enforce_eager = bool(self._config.get("enforce_eager", True))
          dtype = str(self._config.get("dtype", "float16")).strip().lower()
          self._instruction = str(
              self._config.get("instruction")
              or "Given a web search query, retrieve relevant passages that answer the query"
          )
          if not torch.cuda.is_available():
              raise RuntimeError("qwen3_vllm backend requires CUDA GPU, but torch.cuda.is_available() is False")
          if dtype not in {"float16", "half", "auto"}:
              raise ValueError(f"Unsupported dtype for qwen3_vllm: {dtype!r}. Use float16/half/auto.")
  
          logger.info(
              "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)",
              model_name,
              max_model_len,
              tensor_parallel_size,
              gpu_memory_utilization,
              dtype,
              enable_prefix_caching,
          )
  
          self._llm = LLM(
              model=model_name,
              tensor_parallel_size=tensor_parallel_size,
              max_model_len=max_model_len,
              gpu_memory_utilization=gpu_memory_utilization,
              enable_prefix_caching=enable_prefix_caching,
              enforce_eager=enforce_eager,
              dtype=dtype,
          )
          self._tokenizer = AutoTokenizer.from_pretrained(model_name)
          self._tokenizer.padding_side = "left"
          self._tokenizer.pad_token = self._tokenizer.eos_token
  
          # Assistant-turn suffix: close the user turn and open the assistant turn with
          # an empty <think> block, so the single generated token is the yes/no answer.
          self._suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
          self._suffix_tokens = self._tokenizer.encode(
              self._suffix, add_special_tokens=False
          )
          self._max_prompt_len = max_model_len - len(self._suffix_tokens)
  
          self._true_token = self._tokenizer("yes", add_special_tokens=False).input_ids[0]
          self._false_token = self._tokenizer("no", add_special_tokens=False).input_ids[0]
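          # Decode exactly one token, constrained to the "yes"/"no" token ids; logprobs
          # are requested so _compute_scores can read back both tokens' log-probabilities.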
          self._sampling_params = SamplingParams(
              temperature=0,
              max_tokens=1,
              logprobs=20,
              allowed_token_ids=[self._true_token, self._false_token],
          )
          # vLLM generate path is unstable under concurrent calls in this process model.
          # Serialize infer calls to avoid engine-core protocol corruption.
          self._infer_lock = threading.Lock()
  
          self._model_name = model_name
          logger.info("[Qwen3_VLLM] Model ready | model=%s", model_name)
  
      def _process_inputs(
          self,
          pairs: List[Tuple[str, str]],
      ) -> List[TokensPrompt]:
          """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template."""
          messages_batch = [
              _format_instruction(self._instruction, q, d) for q, d in pairs
          ]
          tokenized = self._tokenizer.apply_chat_template(
              messages_batch,
              tokenize=True,
              add_generation_prompt=False,
              enable_thinking=False,
          )
          # A single conversation yields a flat token list; a batch yields a list of lists.
          if tokenized and not isinstance(tokenized[0], list):
              tokenized = [tokenized]
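          # Reserve room for the fixed assistant suffix: each prompt is truncated to
          # max_model_len - len(suffix_tokens) before the suffix token ids are appended.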
          prompts = [
              TokensPrompt(
                  prompt_token_ids=ids[: self._max_prompt_len] + self._suffix_tokens
              )
              for ids in tokenized
          ]
          return prompts
  
      def _compute_scores(
          self,
          prompts: List[TokensPrompt],
      ) -> List[float]:
          """Run vLLM generate and compute yes/no probability per prompt."""
          if not prompts:
              return []
          outputs = self._llm.generate(prompts, self._sampling_params, use_tqdm=False)
          scores: List[float] = []
          for out in outputs:
              if not out.outputs:
                  scores.append(0.0)
                  continue
              final_logprobs = out.outputs[0].logprobs
              if not final_logprobs:
                  scores.append(0.0)
                  continue
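              # Convert the two token logprobs into P(yes) via a two-way softmax:
              #   score = exp(logprob_yes) / (exp(logprob_yes) + exp(logprob_no))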
              last = final_logprobs[-1]
              # Follow the official Qwen3-Reranker recipe: a token missing from the returned logprobs gets logprob -10.
              if self._true_token not in last:
                  true_logit = -10
              else:
                  true_logit = last[self._true_token].logprob
              if self._false_token not in last:
                  false_logit = -10
              else:
                  false_logit = last[self._false_token].logprob
              true_score = math.exp(true_logit)
              false_score = math.exp(false_logit)
              score = true_score / (true_score + false_score)
              scores.append(float(score))
          return scores
  
      def score_with_meta(
          self,
          query: str,
          docs: List[str],
          normalize: bool = True,
      ) -> Tuple[List[float], Dict[str, Any]]:
          start_ts = time.time()
          total_docs = len(docs) if docs else 0
          output_scores: List[float] = [0.0] * total_docs
  
          query = "" if query is None else str(query).strip()
          indexed: List[Tuple[int, str]] = []
          for i, doc in enumerate(docs or []):
              if doc is None:
                  continue
              text = str(doc).strip()
              if not text:
                  continue
              indexed.append((i, text))
  
          if not query or not indexed:
              elapsed_ms = (time.time() - start_ts) * 1000.0
              return output_scores, {
                  "input_docs": total_docs,
                  "usable_docs": len(indexed),
                  "unique_docs": 0,
                  "dedup_ratio": 0.0,
                  "elapsed_ms": round(elapsed_ms, 3),
                  "model": self._model_name,
                  "backend": "qwen3_vllm",
                  "normalize": normalize,
              }
  
          # Deduplicate consecutive identical texts, keeping a mapping back to the original indices.
          unique_texts: List[str] = []
          position_to_unique: List[int] = []
          prev: Optional[str] = None
          for _idx, text in indexed:
              if text != prev:
                  unique_texts.append(text)
                  prev = text
              position_to_unique.append(len(unique_texts) - 1)
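          # position_to_unique[k] is the index into unique_texts for the k-th usable doc,
          # so the per-unique scores can be mapped back to original positions below.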
  
          pairs = [(query, t) for t in unique_texts]
          with self._infer_lock:
              prompts = self._process_inputs(pairs)
              unique_scores = self._compute_scores(prompts)
  
          for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
              # Fan each unique text's score back out to every original doc position.
              # The score is already P(yes) in [0, 1] from the yes/(yes+no) softmax.
              output_scores[orig_idx] = float(unique_scores[unique_idx])
  
          elapsed_ms = (time.time() - start_ts) * 1000.0
          dedup_ratio = 0.0
          if indexed:
              dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
  
          meta = {
              "input_docs": total_docs,
              "usable_docs": len(indexed),
              "unique_docs": len(unique_texts),
              "dedup_ratio": round(dedup_ratio, 4),
              "elapsed_ms": round(elapsed_ms, 3),
              "model": self._model_name,
              "backend": "qwen3_vllm",
              "normalize": normalize,
          }
          return output_scores, meta
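

  # Minimal usage sketch (assumes a CUDA-capable GPU and access to the model weights):
  #
  #   backend = Qwen3VLLMRerankerBackend({"model_name": "Qwen/Qwen3-Reranker-0.6B"})
  #   scores, meta = backend.score_with_meta(
  #       query="what is vllm",
  #       docs=["vLLM is a high-throughput LLM serving engine.", "Bananas are yellow."],
  #   )
  #   # scores aligns with docs and holds P(yes) values in [0, 1]; meta reports timing,
  #   # dedup statistics, and the model/backend names.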