Blame view

reranker/backends/qwen3_transformers_packed.py 15.5 KB
4823f463   tangwang   qwen3_vllm_score ...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
  """
  Qwen3-Reranker backend using packed inference with Transformers.
  
  This backend implements the sequence stitching optimization described in
  Qwen3-Reranker packed inference examples:
  1. Share the query/instruction prefix across many documents.
  2. Reset document ``position_ids`` relative to the shared prefix.
  3. Use a custom causal attention mask so each document can attend to the
     prefix and itself, but never to other documents.
  
  Compared with the standard per-pair batching path, this reduces repeated
  prefix computation and removes inter-sample padding waste. For online search
  requests like ``1 query + 400 docs``, the backend further packs documents into
  multiple chunks under a configurable total token budget.
  """
  
  from __future__ import annotations
  
  import logging
  import threading
  import time
  from typing import Any, Dict, List, Sequence, Tuple
  
  import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  
  # Module-level logger for this backend.
  logger = logging.getLogger("reranker.backends.qwen3_transformers_packed")
  
  # Default system/user chat-template opening for the Qwen3-Reranker yes/no
  # judgement prompt. Overridable through the ``prompt_prefix`` config key.
  _DEFAULT_PREFIX = (
      "<|im_start|>system\n"
      "Judge whether the Document meets the requirements based on the Query and the Instruct "
      'provided. Note that the answer can only be "yes" or "no".'
      "<|im_end|>\n<|im_start|>user\n"
  )
  # Text appended after every document (overridable via ``prompt_suffix``). The
  # backend reads the next-token logits at the last token of doc+suffix and
  # compares "yes" vs "no", so this must end right where the answer would begin.
  _DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
  # Shared per-query prefix, filled via str.format with ``prefix``, ``instruction``
  # and ``query``; document token ids are appended after it (tokenized separately).
  # Overridable via ``pair_prefix_template``.
  _DEFAULT_PAIR_PREFIX_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n<Document>: "
  
  
  def _deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
      unique_texts: List[str] = []
      position_to_unique: List[int] = []
      seen: Dict[str, int] = {}
  
      for text in texts:
          idx = seen.get(text)
          if idx is None:
              idx = len(unique_texts)
              seen[text] = idx
              unique_texts.append(text)
          position_to_unique.append(idx)
  
      return unique_texts, position_to_unique
  
  
  class Qwen3TransformersPackedRerankerBackend:
      """
      Qwen3-Reranker packed inference backend using Transformers.
  
      Config from ``services.rerank.backends.qwen3_transformers_packed``.
  
      Keys read from ``config`` (defaults in parentheses): ``model_name``
      ("Qwen/Qwen3-Reranker-0.6B"), ``instruction``, ``prompt_prefix``,
      ``prompt_suffix``, ``pair_prefix_template``, ``max_model_len`` (4096),
      ``max_doc_len`` (160), ``max_docs_per_pack`` (0 = no per-pack cap),
      ``use_fp16`` (True), ``device`` ("cuda"), ``attn_implementation``
      ("eager") and ``sort_by_doc_length`` (True).
  
      GPU inference is serialized through an internal lock; packs of one request
      are evaluated sequentially.
      """
  
      # String spellings accepted as "enabled" for boolean config flags.
      _TRUE_FLAG_VALUES = frozenset({"1", "true", "yes", "y", "on"})
  
      @classmethod
      def _parse_flag(cls, value: Any) -> bool:
          """Return True if ``value`` (bool, int or string) spells an enabled flag."""
          return str(value).strip().lower() in cls._TRUE_FLAG_VALUES
  
      def __init__(self, config: Dict[str, Any]) -> None:
          """Validate config, load tokenizer/model on CUDA and precompute scoring state.
  
          Raises:
              ValueError: ``max_doc_len`` <= 0, ``max_docs_per_pack`` < 0, a
                  non-CUDA ``device``, or ``max_model_len`` too small for the prompt.
              RuntimeError: CUDA unavailable, model not placed on CUDA, or the
                  "yes"/"no" token ids cannot be resolved.
          """
          self._config = config or {}
          model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
          self._instruction = str(
              self._config.get("instruction")
              or "Rank products by query with category & style match prioritized"
          )
          self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX)
          self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX)
          self._pair_prefix_template = str(
              self._config.get("pair_prefix_template") or _DEFAULT_PAIR_PREFIX_TEMPLATE
          )
  
          max_model_len = int(self._config.get("max_model_len", 4096))
          max_doc_len = int(self._config.get("max_doc_len", 160))
          max_docs_per_pack = int(self._config.get("max_docs_per_pack", 0))
          # Parse every boolean flag the same way so string overrides such as
          # "false" are not silently treated as True by bool().
          use_fp16 = self._parse_flag(self._config.get("use_fp16", True))
          device = self._config.get("device")
          attn_impl = str(self._config.get("attn_implementation") or "eager").strip()
  
          self._model_name = model_name
          self._max_model_len = max_model_len
          self._max_doc_len = max_doc_len
          self._max_docs_per_pack = max_docs_per_pack
          self._sort_by_doc_length = self._parse_flag(
              self._config.get("sort_by_doc_length", True)
          )
          self._attn_impl = attn_impl
  
          # Fail fast on config/hardware problems before the slow model load.
          if self._max_doc_len <= 0:
              raise ValueError(f"max_doc_len must be > 0, got {self._max_doc_len}")
          if self._max_docs_per_pack < 0:
              raise ValueError(
                  f"max_docs_per_pack must be >= 0, got {self._max_docs_per_pack}"
              )
          target_device = str(device).strip() if device is not None else "cuda"
          if not target_device.startswith("cuda"):
              raise ValueError(
                  "qwen3_transformers_packed backend is GPU-only. "
                  f"Unsupported device setting: {target_device!r}"
              )
          if not torch.cuda.is_available():
              raise RuntimeError(
                  "qwen3_transformers_packed backend requires CUDA GPU, "
                  "but torch.cuda.is_available() is False"
              )
  
          logger.info(
              "[Qwen3_Transformers_Packed] Loading model %s (max_model_len=%s, max_doc_len=%s, "
              "max_docs_per_pack=%s, fp16=%s, attn_impl=%s)",
              model_name,
              max_model_len,
              max_doc_len,
              max_docs_per_pack,
              use_fp16,
              attn_impl,
          )
  
          self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
          self._tokenizer.pad_token = self._tokenizer.eos_token
  
          self._prefix_tokens = self._tokenizer.encode(self._prefix, add_special_tokens=False)
          self._suffix_tokens = self._tokenizer.encode(self._suffix, add_special_tokens=False)
          self._suffix_len = len(self._suffix_tokens)
  
          # The prompt skeleton plus at least one doc token must fit in the window.
          prefix_budget = len(self._prefix_tokens) + self._suffix_len + 1
          if self._max_model_len <= prefix_budget:
              raise ValueError(
                  "max_model_len is too small for packed reranking. "
                  f"Need > {prefix_budget}, got {self._max_model_len}."
              )
  
          self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
          self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
          if self._token_true_id is None or self._token_false_id is None:
              raise RuntimeError("Failed to resolve Qwen3 reranker classifier token ids for yes/no")
  
          kwargs: Dict[str, Any] = {}
          if use_fp16:
              kwargs["torch_dtype"] = torch.float16
          if attn_impl:
              kwargs["attn_implementation"] = attn_impl
  
          self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
          self._model = self._model.to(target_device)
          first_param = next(self._model.parameters())
          self._device = first_param.device
          # Compute dtype of the model; the additive attention mask must match it
          # (SDPA rejects float masks whose dtype differs from the query dtype).
          self._dtype = first_param.dtype
          if self._device.type != "cuda":
              raise RuntimeError(
                  "qwen3_transformers_packed backend failed to place model on CUDA. "
                  f"Current device: {self._device}"
              )
  
          # Only the "no"/"yes" rows of the LM head are used for scoring; caching
          # them lets _score_pack skip the full-vocabulary projection.
          head = self._model.get_output_embeddings()
          cls_rows = torch.tensor(
              [self._token_false_id, self._token_true_id],
              dtype=torch.long,
              device=head.weight.device,
          )
          self._cls_weight = head.weight.detach().index_select(0, cls_rows)
          head_bias = getattr(head, "bias", None)
          self._cls_bias = (
              head_bias.detach().index_select(0, cls_rows) if head_bias is not None else None
          )
  
          self._infer_lock = threading.Lock()
  
          logger.info(
              "[Qwen3_Transformers_Packed] Model ready | model=%s device=%s",
              model_name,
              self._device,
          )
  
      def _build_pair_prefix_tokens(self, query: str) -> List[int]:
          """Tokenize the shared prompt prefix (system text + instruction + query)."""
          pair_prefix = self._pair_prefix_template.format(
              prefix=self._prefix,
              instruction=self._instruction,
              query=query,
          )
          return self._tokenizer.encode(pair_prefix, add_special_tokens=False)
  
      def _tokenize_documents(self, docs: Sequence[str], query_prefix_len: int) -> List[List[int]]:
          """Tokenize documents without special tokens, truncating each one.
  
          The per-doc cap is ``max_doc_len`` and also whatever the model window
          leaves after the query prefix and suffix (floored at 1 token).
          """
          max_doc_tokens = min(
              self._max_doc_len,
              max(1, self._max_model_len - query_prefix_len - self._suffix_len),
          )
          tokenized = self._tokenizer(
              list(docs),
              padding=False,
              truncation=True,
              max_length=max_doc_tokens,
              add_special_tokens=False,
              return_attention_mask=False,
          )
          return [list(ids) for ids in tokenized["input_ids"]]
  
      def _build_pack_plan(
          self,
          query_prefix_len: int,
          doc_tokens: Sequence[Sequence[int]],
      ) -> List[List[int]]:
          """Greedily group document indices into packs.
  
          Each pack's total length (query prefix + every doc and its suffix) stays
          within ``max_model_len``. When ``max_docs_per_pack`` > 0, a pack also
          holds at most that many docs. Docs are visited shortest-first when
          ``sort_by_doc_length`` is enabled.
  
          Raises:
              ValueError: a single doc cannot fit beside the query prefix, for
                  example when the query itself nearly fills the window.
          """
          order = list(range(len(doc_tokens)))
          if self._sort_by_doc_length and len(order) > 1:
              order.sort(key=lambda idx: len(doc_tokens[idx]))
  
          packs: List[List[int]] = []
          current_pack: List[int] = []
          current_len = query_prefix_len
          for idx in order:
              packed_doc_len = len(doc_tokens[idx]) + self._suffix_len
              if packed_doc_len <= 0:
                  continue
  
              over_docs_cap = (
                  self._max_docs_per_pack > 0 and len(current_pack) >= self._max_docs_per_pack
              )
              over_token_cap = bool(current_pack) and (
                  current_len + packed_doc_len > self._max_model_len
              )
              if over_docs_cap or over_token_cap:
                  packs.append(current_pack)
                  current_pack = []
                  current_len = query_prefix_len
  
              if query_prefix_len + packed_doc_len > self._max_model_len:
                  raise ValueError(
                      "Packed doc still exceeds max_model_len after truncation. "
                      f"query_prefix_len={query_prefix_len}, doc_len={packed_doc_len}, "
                      f"max_model_len={self._max_model_len}"
                  )
  
              current_pack.append(idx)
              current_len += packed_doc_len
  
          if current_pack:
              packs.append(current_pack)
          return packs
  
      def _build_pack_inputs(
          self,
          query_prefix_tokens: Sequence[int],
          doc_tokens: Sequence[Sequence[int]],
          doc_indices: Sequence[int],
      ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
          """Stitch prefix + docs into one sequence with a block-causal mask.
  
          Returns ``(inputs, logits_ids)``:
          * ``inputs``: ``input_ids`` and ``position_ids`` of shape
            ``(1, total_len)``, plus an additive ``attention_mask`` of shape
            ``(1, 1, total_len, total_len)`` in the model dtype. Allowed entries
            are 0; masked entries are ``finfo(dtype).min``.
          * ``logits_ids``: the index of each doc's last token (end of its suffix),
            in ``doc_indices`` order.
  
          Each doc's positions restart right after the prefix, and each doc
          attends causally to the whole prefix and to itself, never to other docs.
          """
          prefix_len = len(query_prefix_tokens)
          input_ids_list: List[int] = list(query_prefix_tokens)
          position_ids_list: List[int] = list(range(prefix_len))
          # Segment 0 is the shared prefix; the k-th stitched document is segment k.
          segment_list: List[int] = [0] * prefix_len
          last_positions: List[int] = []
  
          for segment, idx in enumerate(doc_indices, start=1):
              doc_with_suffix = list(doc_tokens[idx]) + self._suffix_tokens
              doc_len = len(doc_with_suffix)
              input_ids_list.extend(doc_with_suffix)
              position_ids_list.extend(range(prefix_len, prefix_len + doc_len))
              segment_list.extend([segment] * doc_len)
              last_positions.append(len(input_ids_list) - 1)
  
          total_len = len(input_ids_list)
          device = self._device
  
          # Vectorized mask: causal AND (key in prefix OR key in same document).
          # Every prefix key precedes every document query, so prefix rows stay
          # plain causal and document rows see the full prefix plus their own past.
          segment_ids = torch.tensor(segment_list, dtype=torch.long, device=device)
          causal = torch.ones((total_len, total_len), dtype=torch.bool, device=device).tril_()
          same_segment = segment_ids.unsqueeze(1) == segment_ids.unsqueeze(0)
          key_in_prefix = (segment_ids == 0).unsqueeze(0)
          allowed = causal & (same_segment | key_in_prefix)
  
          mask_dtype = self._dtype
          attention_mask = torch.zeros((total_len, total_len), dtype=mask_dtype, device=device)
          attention_mask.masked_fill_(~allowed, torch.finfo(mask_dtype).min)
  
          inputs = {
              "input_ids": torch.tensor([input_ids_list], dtype=torch.long, device=device),
              "position_ids": torch.tensor([position_ids_list], dtype=torch.long, device=device),
              "attention_mask": attention_mask.view(1, 1, total_len, total_len),
          }
          logits_ids = torch.tensor(last_positions, dtype=torch.long, device=device)
          return inputs, logits_ids
  
      @torch.no_grad()
      def _score_pack(
          self,
          query_prefix_tokens: Sequence[int],
          doc_tokens: Sequence[Sequence[int]],
          doc_indices: Sequence[int],
      ) -> Tuple[List[float], int]:
          """Score one pack; returns ``(P(yes) per doc in doc_indices order, seq_len)``."""
          inputs, logits_ids = self._build_pack_inputs(
              query_prefix_tokens=query_prefix_tokens,
              doc_tokens=doc_tokens,
              doc_indices=doc_indices,
          )
          # Run only the decoder stack. Producing full-vocabulary logits for every
          # packed position would allocate seq_len x vocab values just to read
          # two columns at a handful of positions.
          hidden = self._model.get_decoder()(**inputs, use_cache=False).last_hidden_state
          doc_hidden = hidden[0].index_select(0, logits_ids)
          # Column 0 = "no" logit, column 1 = "yes" logit (row order of _cls_weight).
          pair_logits = doc_hidden @ self._cls_weight.t()
          if self._cls_bias is not None:
              pair_logits = pair_logits + self._cls_bias
          yes_prob = torch.softmax(pair_logits.float(), dim=1)[:, 1]
          return yes_prob.tolist(), int(inputs["input_ids"].shape[1])
  
      def _build_meta(
          self,
          *,
          total_docs: int,
          usable_docs: int,
          unique_docs: int,
          dedup_ratio: float,
          elapsed_ms: float,
          normalize: bool,
          pack_lengths: Sequence[int],
      ) -> Dict[str, Any]:
          """Assemble the response metadata so every code path returns the same keys."""
          return {
              "input_docs": total_docs,
              "usable_docs": usable_docs,
              "unique_docs": unique_docs,
              "dedup_ratio": round(dedup_ratio, 4),
              "elapsed_ms": round(elapsed_ms, 3),
              "model": self._model_name,
              "backend": "qwen3_transformers_packed",
              "normalize": normalize,
              "packed_batches": len(pack_lengths),
              "packed_max_seq_len": max(pack_lengths) if pack_lengths else 0,
              "packed_avg_seq_len": round(sum(pack_lengths) / len(pack_lengths), 3)
              if pack_lengths
              else 0.0,
              "max_model_len": self._max_model_len,
              "max_doc_len": self._max_doc_len,
              "max_docs_per_pack": self._max_docs_per_pack,
              "sort_by_doc_length": self._sort_by_doc_length,
              "attn_implementation": self._attn_impl,
          }
  
      def score_with_meta(
          self,
          query: str,
          docs: List[str],
          normalize: bool = True,
      ) -> Tuple[List[float], Dict[str, Any]]:
          """Score ``docs`` against ``query``.
  
          Returns ``(scores, meta)``. ``scores`` is aligned with ``docs`` and holds
          P("yes") in [0, 1]. ``None`` or blank docs get 0.0, and every doc gets
          0.0 when the query is blank. Identical doc texts are scored once.
          ``normalize`` is only echoed in ``meta``; scores are always probabilities.
          """
          start_ts = time.perf_counter()
          total_docs = len(docs) if docs else 0
          output_scores: List[float] = [0.0] * total_docs
  
          query = "" if query is None else str(query).strip()
          indexed: List[Tuple[int, str]] = []
          for i, doc in enumerate(docs or []):
              if doc is None:
                  continue
              text = str(doc).strip()
              if text:
                  indexed.append((i, text))
  
          if not query or not indexed:
              return output_scores, self._build_meta(
                  total_docs=total_docs,
                  usable_docs=len(indexed),
                  unique_docs=0,
                  dedup_ratio=0.0,
                  elapsed_ms=(time.perf_counter() - start_ts) * 1000.0,
                  normalize=normalize,
                  pack_lengths=[],
              )
  
          indexed_texts = [text for _, text in indexed]
          unique_texts, position_to_unique = _deduplicate_with_positions(indexed_texts)
  
          query_prefix_tokens = self._build_pair_prefix_tokens(query)
          doc_tokens = self._tokenize_documents(unique_texts, query_prefix_len=len(query_prefix_tokens))
          pack_plan = self._build_pack_plan(
              query_prefix_len=len(query_prefix_tokens),
              doc_tokens=doc_tokens,
          )
  
          unique_scores: List[float] = [0.0] * len(unique_texts)
          pack_lengths: List[int] = []
          with self._infer_lock:
              for pack_doc_indices in pack_plan:
                  batch_scores, pack_seq_len = self._score_pack(
                      query_prefix_tokens=query_prefix_tokens,
                      doc_tokens=doc_tokens,
                      doc_indices=pack_doc_indices,
                  )
                  if len(batch_scores) != len(pack_doc_indices):
                      raise RuntimeError(
                          "Packed reranker score size mismatch: "
                          f"expected {len(pack_doc_indices)}, got {len(batch_scores)}"
                      )
                  for idx, score in zip(pack_doc_indices, batch_scores):
                      unique_scores[idx] = float(score)
                  pack_lengths.append(pack_seq_len)
  
          # Fan the per-unique-text scores back out to the caller's positions.
          for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
              output_scores[orig_idx] = float(unique_scores[unique_idx])
  
          # indexed is non-empty here (early return above), so no zero division.
          dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
          return output_scores, self._build_meta(
              total_docs=total_docs,
              usable_docs=len(indexed),
              unique_docs=len(unique_texts),
              dedup_ratio=dedup_ratio,
              elapsed_ms=(time.perf_counter() - start_ts) * 1000.0,
              normalize=normalize,
              pack_lengths=pack_lengths,
          )