reranker/backends/qwen3_transformers.py
80955935   tangwang   Reranker: add qwen3...
  """
  Qwen3-Reranker-0.6B backend using Transformers directly; no vLLM required.
  
  Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
  Requires: transformers>=4.51.0, torch.
  """
  
  from __future__ import annotations
  
  import logging
  import time
  from typing import Any, Dict, List, Tuple
  
  logger = logging.getLogger("reranker.backends.qwen3_transformers")
  
  try:
      import torch
      from transformers import AutoModelForCausalLM, AutoTokenizer
  except ImportError as e:
      raise ImportError(
          "Qwen3-Transformers reranker backend requires transformers>=4.51.0 and torch. "
          "Install with: pip install transformers>=4.51.0 torch"
      ) from e
  
  
  def _format_instruction(instruction: str, query: str, doc: str) -> str:
      """Format (query, doc) pair per official Qwen3-Reranker spec."""
      return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
          instruction=instruction, query=query, doc=doc
      )
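  # Example (hypothetical strings): _format_instruction("Rate relevance",
  # "what is a panda?", "The giant panda is a bear.") returns:
  #   <Instruct>: Rate relevance
  #   <Query>: what is a panda?
  #   <Document>: The giant panda is a bear.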
  
  
  class Qwen3TransformersRerankerBackend:
      """
      Qwen3-Reranker-0.6B with Transformers (AutoModelForCausalLM) inference.
      Reads config from services.rerank.backends.qwen3_transformers.
      No vLLM dependency: lighter than the qwen3_vllm backend, suitable for
      CPU or a small GPU.
      """
  
      def __init__(self, config: Dict[str, Any]) -> None:
          self._config = config or {}
          model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
          self._instruction = str(
              self._config.get("instruction")
              or "Given a web search query, retrieve relevant passages that answer the query"
          )
          max_length = int(self._config.get("max_length", 8192))
          batch_size = int(self._config.get("batch_size", 64))
          use_fp16 = bool(self._config.get("use_fp16", True))
          device = self._config.get("device")
          attn_impl = self._config.get("attn_implementation")  # e.g. "flash_attention_2"
  
          self._model_name = model_name
          self._batch_size = batch_size
  
          logger.info(
              "[Qwen3_Transformers] Loading model %s (max_length=%s, batch=%s, fp16=%s)",
              model_name,
              max_length,
              batch_size,
              use_fp16,
          )
  
          self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
          self._tokenizer.pad_token = self._tokenizer.eos_token
  
          # Prefix/suffix from official reference
          prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
          suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
          self._prefix_tokens = self._tokenizer.encode(prefix, add_special_tokens=False)
          self._suffix_tokens = self._tokenizer.encode(suffix, add_special_tokens=False)
          self._max_length = max_length
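          # Reserve room for the fixed prefix/suffix tokens so the final
          # sequence stays within max_length.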
          self._effective_max_len = max_length - len(self._prefix_tokens) - len(self._suffix_tokens)
  
          self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
          self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
  
          kwargs = {}
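          # Load in fp16 only when CUDA is available; on CPU keep default fp32.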
          if use_fp16 and torch.cuda.is_available():
              kwargs["torch_dtype"] = torch.float16
          if attn_impl:
              kwargs["attn_implementation"] = attn_impl
  
          self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
          if device is not None:
              self._model = self._model.to(device)
          elif torch.cuda.is_available():
              self._model = self._model.cuda()
  
          logger.info(
              "[Qwen3_Transformers] Model ready | model=%s device=%s",
              model_name,
              next(self._model.parameters()).device,
          )
  
      def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
          """Tokenize pairs and add prefix/suffix tokens. Returns batched tensors on model device."""
          inputs = self._tokenizer(
              pairs,
              padding=False,
              truncation="longest_first",
              return_attention_mask=False,
              max_length=self._effective_max_len,
          )
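          # Wrap each tokenized pair with the chat-template prefix/suffix so
          # the model sees the official system/user/assistant framing.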
          for i, ele in enumerate(inputs["input_ids"]):
              inputs["input_ids"][i] = self._prefix_tokens + ele + self._suffix_tokens
          inputs = self._tokenizer.pad(
              inputs,
              padding=True,
              return_tensors="pt",
          )
          for key in inputs:
              inputs[key] = inputs[key].to(self._model.device)
          return inputs
  
      @torch.no_grad()
      def _compute_scores(self, pairs: List[str]) -> List[float]:
          """Run forward pass and compute yes/no probability per pair."""
          if not pairs:
              return []
          inputs = self._process_inputs(pairs)
          outputs = self._model(**inputs)
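          # Logits at the final position carry the yes/no judgment;
          # relevance = P("yes") via softmax over the {no, yes} pair.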
          batch_scores = outputs.logits[:, -1, :]
          true_vector = batch_scores[:, self._token_true_id]
          false_vector = batch_scores[:, self._token_false_id]
          batch_scores = torch.stack([false_vector, true_vector], dim=1)
          batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
          scores = batch_scores[:, 1].exp().tolist()
          return scores
  
      def score_with_meta(
          self,
          query: str,
          docs: List[str],
          normalize: bool = True,
      ) -> Tuple[List[float], Dict[str, Any]]:
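          """
          Score each doc against the query; returns (scores, meta).

          Scores are yes-probabilities in [0, 1] and thus already normalized;
          the normalize flag is only echoed back in the meta dict.
          """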
          start_ts = time.time()
          total_docs = len(docs) if docs else 0
          output_scores: List[float] = [0.0] * total_docs
  
          query = "" if query is None else str(query).strip()
          indexed: List[Tuple[int, str]] = []
          for i, doc in enumerate(docs or []):
              if doc is None:
                  continue
              text = str(doc).strip()
              if not text:
                  continue
              indexed.append((i, text))
  
          if not query or not indexed:
              elapsed_ms = (time.time() - start_ts) * 1000.0
              return output_scores, {
                  "input_docs": total_docs,
                  "usable_docs": len(indexed),
                  "unique_docs": 0,
                  "dedup_ratio": 0.0,
                  "elapsed_ms": round(elapsed_ms, 3),
                  "model": self._model_name,
                  "backend": "qwen3_transformers",
                  "normalize": normalize,
              }
  
          # Deduplicate by exact text, keeping a map from each original
          # position to its index in the unique-texts list.
          unique_texts: List[str] = []
          text_to_unique: Dict[str, int] = {}
          position_to_unique: List[int] = []
          for _idx, text in indexed:
              if text not in text_to_unique:
                  text_to_unique[text] = len(unique_texts)
                  unique_texts.append(text)
              position_to_unique.append(text_to_unique[text])
  
          pairs = [
              _format_instruction(self._instruction, query, t)
              for t in unique_texts
          ]
  
          # Batch inference
          unique_scores: List[float] = []
          for i in range(0, len(pairs), self._batch_size):
              batch = pairs[i : i + self._batch_size]
              batch_scores = self._compute_scores(batch)
              unique_scores.extend(batch_scores)
  
          for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
              output_scores[orig_idx] = float(unique_scores[unique_idx])
  
          elapsed_ms = (time.time() - start_ts) * 1000.0
          dedup_ratio = 0.0
          if indexed:
              dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
  
          meta = {
              "input_docs": total_docs,
              "usable_docs": len(indexed),
              "unique_docs": len(unique_texts),
              "dedup_ratio": round(dedup_ratio, 4),
              "elapsed_ms": round(elapsed_ms, 3),
              "model": self._model_name,
              "backend": "qwen3_transformers",
              "normalize": normalize,
          }
          return output_scores, meta
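

  # Minimal usage sketch: a hedged demo, not part of the service wiring.
  # Assumes Qwen/Qwen3-Reranker-0.6B is reachable (local cache or HF hub);
  # the config keys mirror the ones read in __init__ above.
  if __name__ == "__main__":
      backend = Qwen3TransformersRerankerBackend(
          {
              "model_name": "Qwen/Qwen3-Reranker-0.6B",
              "max_length": 2048,
              "batch_size": 8,
              "use_fp16": True,
          }
      )
      demo_scores, demo_meta = backend.score_with_meta(
          query="what is the capital of China?",
          docs=[
              "The capital of China is Beijing.",
              "Gravity makes apples fall toward the ground.",
          ],
      )
      print(demo_scores)
      print(demo_meta)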