Commit 809559351b2cb130a0d8c124f604040035c167ad
1 parent bc089b43
Reranker: add qwen3_transformers backend
Showing 4 changed files with 223 additions and 46 deletions
.env deleted
| ... | ... | @@ -1,41 +0,0 @@ |
| 1 | -# Elasticsearch Configuration | |
| 2 | -ES_HOST=http://localhost:9200 | |
| 3 | -ES_USERNAME=saas | |
| 4 | -ES_PASSWORD=4hOaLaf41y2VuI8y | |
| 5 | - | |
| 6 | -# Redis Configuration (Optional) - AI production 10.200.16.14:6479 | |
| 7 | -REDIS_HOST=10.200.16.14 | |
| 8 | -REDIS_PORT=6479 | |
| 9 | -REDIS_PASSWORD=dxEkegEZ@C5SXWKv | |
| 10 | - | |
| 11 | -# DeepL Translation API | |
| 12 | -DEEPL_AUTH_KEY=c9293ab4-ad25-479b-919f-ab4e63b429ed | |
| 13 | - | |
| 14 | -# API Service Configuration | |
| 15 | -API_HOST=0.0.0.0 | |
| 16 | -API_PORT=6002 | |
| 17 | - | |
| 18 | -# MySQL Database Configuration (Shoplazza) - AI production 10.200.16.14:3316 | |
| 19 | -DB_HOST=10.200.16.14 | |
| 20 | -DB_PORT=3316 | |
| 21 | -DB_DATABASE=saas | |
| 22 | -DB_USERNAME=root | |
| 23 | -DB_PASSWORD=qY8tgodLoA&KT#yQ | |
| 24 | - | |
| 25 | -# Model Directories | |
| 26 | -TEXT_MODEL_DIR=/data/tw/models/bge-m3 # switched to web requests; the local model is no longer used | |
| 27 | -IMAGE_MODEL_DIR=/data/tw/models/cn-clip # switched to web requests; the local model is no longer used | |
| 28 | - | |
| 29 | -# Cache Directory | |
| 30 | -CACHE_DIR=.cache | |
| 31 | - | |
| 32 | -# Frontend API Base URL | |
| 33 | -API_BASE_URL=http://43.166.252.75:6002 | |
| 34 | - | |
| 35 | - | |
| 36 | -# China (domestic) | |
| 37 | -DASHSCOPE_API_KEY=sk-c3b8d4db061840aa8effb748df2a997b | |
| 38 | -# US | |
| 39 | -DASHSCOPE_API_KEY=sk-482cc3ff37a8467dab134a7a46830556 | |
| 40 | - | |
| 41 | -OPENAI_API_KEY=sk-HvmTMKtuznibZ75l7L2uF2jiaYocCthqd8Cbdkl09KTE7Ft0 |
reranker/README.md
| ... | ... | @@ -4,10 +4,10 @@ |
| 4 | 4 | |
| 5 | 5 | --- |
| 6 | 6 | |
| 7 | -The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM). Callers access it over HTTP and do not need to know which backend is in use. | |
| 7 | +The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM, Qwen3-Transformers). Callers access it over HTTP and do not need to know which backend is in use. | |
| 8 | 8 | |
| 9 | 9 | **Features** |
| 10 | -- Multiple backends: `qwen3_vllm` (default, Qwen3-Reranker-0.6B + vLLM), `bge` (kept for compatibility) | |
| 10 | +- Multiple backends: `qwen3_vllm` (default, Qwen3-Reranker-0.6B + vLLM), `qwen3_transformers` (pure Transformers, no vLLM required), `bge` (kept for compatibility) | |
| 11 | 11 | - Unified configuration: `config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.<name>` |
| 12 | 12 | - Document deduplication, scores aligned with input order, FP16/GPU support (backend dependent) |
| 13 | 13 | |
| ... | ... | @@ -17,18 +17,20 @@ The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwe |
| 17 | 17 | - `backends/__init__.py`: `get_rerank_backend(name, config)` |
| 18 | 18 | - `backends/bge.py`: BGE backend |
| 19 | 19 | - `backends/qwen3_vllm.py`: Qwen3-Reranker-0.6B + vLLM backend |
| 20 | + - `backends/qwen3_transformers.py`: Qwen3-Reranker-0.6B pure Transformers backend (official Usage approach) | |
| 20 | 21 | - `reranker/bge_reranker.py`: core BGE inference (wrapped by the bge backend) |
| 21 | 22 | - `reranker/config.py`: service port, MAX_DOCS, NORMALIZE, etc. (backend parameters live in config.yaml) |
| 22 | 23 | |
| 23 | 24 | ## Dependencies |
| 24 | 25 | - Common: `torch`, `modelscope`, `fastapi`, `uvicorn` (see the project's `requirements.txt` / `requirements_ml.txt`) |
| 25 | -- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers` (optional; install only when using `backend: qwen3_vllm`) | |
| 26 | +- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers>=4.51.0` (vLLM is required only when using `backend: qwen3_vllm`) | |
| 27 | +- **Qwen3-Transformers backend**: `transformers>=4.51.0`, `torch` (no vLLM; suitable for CPU or limited GPU memory) | |
| 26 | 28 | ```bash |
| 27 | 29 | ./scripts/setup_reranker_venv.sh |
| 28 | 30 | ``` |
| 29 | 31 | |
| 30 | 32 | ## Configuration |
| 31 | -- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `bge`), or the `RERANK_BACKEND` environment variable. | |
| 33 | +- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `qwen3_transformers` | `bge`), or the `RERANK_BACKEND` environment variable. | |
| 32 | 34 | - **Backend parameters**: `services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`, for example: |
| 33 | 35 | |
| 34 | 36 | ```yaml |
| ... | ... | @@ -47,6 +49,12 @@ services: |
| 47 | 49 | qwen3_vllm: |
| 48 | 50 | model_name: "Qwen/Qwen3-Reranker-0.6B" |
| 49 | 51 | max_model_len: 8192 |
| 50 | 52 | tensor_parallel_size: 1 |
| 51 | 53 | gpu_memory_utilization: 0.8 |
| 52 | 54 | enable_prefix_caching: true |
| 55 | + qwen3_transformers: | |
| 56 | + model_name: "Qwen/Qwen3-Reranker-0.6B" | |
| 57 | + instruction: "Given a web search query, retrieve relevant passages that answer the query" | |
| 58 | + max_length: 8192 | |
| 59 | + batch_size: 64 | |
| 60 | + use_fp16: true | |
| ... | ... | @@ -111,3 +119,4 @@ uvicorn reranker.server:app --host 0.0.0.0 --port 6007 --log-level info |
| 111 | 119 | - No request-level cache; inputs are deduplicated by string before inference, then scores are mapped back in the original order. |
| 112 | 120 | - Empty or null docs are skipped and scored as 0. |
| 113 | 121 | - **Qwen3-vLLM**: see [Qwen3-Reranker-0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B); requires a GPU and a fair amount of VRAM; compared with BGE, better suited to long documents and high-throughput scenarios (vLLM prefix caching). |
| 122 | +- **Qwen3-Transformers**: the official Transformers Usage approach, no vLLM required; suitable for CPU or limited GPU memory; optionally set `attn_implementation: "flash_attention_2"` for faster inference. | |
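The README above documents backend selection and the unified `/rerank` HTTP API. As a rough sketch only: the request/response schema lives in `reranker/server.py`, which is not part of this diff, so the `query` / `docs` field names below are assumptions; the port matches the uvicorn command shown in the README.

```python
# Hypothetical /rerank client sketch. "query", "docs" and the response shape are
# assumptions; check reranker/server.py for the real schema.
import json
import urllib.request

payload = json.dumps({
    "query": "what is a reranker",
    "docs": ["A reranker scores query-document pairs.", "Bananas are yellow."],
}).encode("utf-8")

request = urllib.request.Request(
    "http://localhost:6007/rerank",
    data=payload,
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(request) as response:
    print(json.load(response))  # expected: per-doc scores aligned with the input order
```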
reranker/backends/__init__.py
| ... | ... | @@ -43,7 +43,12 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc |
| 43 | 43 | if name == "qwen3_vllm": |
| 44 | 44 | from reranker.backends.qwen3_vllm import Qwen3VLLMRerankerBackend |
| 45 | 45 | return Qwen3VLLMRerankerBackend(config) |
| 46 | - raise ValueError(f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm") | |
| 46 | + if name == "qwen3_transformers": | |
| 47 | + from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend | |
| 48 | + return Qwen3TransformersRerankerBackend(config) | |
| 49 | + raise ValueError( | |
| 50 | + f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers" | |
| 51 | + ) | |
| 47 | 52 | |
| 48 | 53 | |
| 49 | 54 | __all__ = ["RerankBackendProtocol", "get_rerank_backend"] |
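A minimal sketch of the extended factory in use, assuming the backend is configured with the `services.rerank.backends.qwen3_transformers` keys shown in the README (constructing it loads the Qwen/Qwen3-Reranker-0.6B weights):

```python
# Minimal sketch of the factory dispatch added in this commit. The config keys
# mirror services.rerank.backends.qwen3_transformers from the README.
from reranker.backends import get_rerank_backend

backend = get_rerank_backend(
    "qwen3_transformers",
    {"model_name": "Qwen/Qwen3-Reranker-0.6B", "batch_size": 64, "use_fp16": True},
)

# Unknown names now list all three supported backends in the raised ValueError.
try:
    get_rerank_backend("does_not_exist", {})
except ValueError as err:
    print(err)
```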
reranker/backends/qwen3_transformers.py (new file)
| ... | ... | @@ -0,0 +1,204 @@ |
| 1 | +""" | |
| 2 | +Qwen3-Reranker-0.6B backend using Transformers (direct usage). No vLLM required. | |
| 3 | + | |
| 4 | +Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B | |
| 5 | +Requires: transformers>=4.51.0, torch. | |
| 6 | +""" | |
| 7 | + | |
| 8 | +from __future__ import annotations | |
| 9 | + | |
| 10 | +import logging | |
| 11 | +import time | |
| 12 | +from typing import Any, Dict, List, Optional, Tuple | |
| 13 | + | |
| 14 | +logger = logging.getLogger("reranker.backends.qwen3_transformers") | |
| 15 | + | |
| 16 | +try: | |
| 17 | + import torch | |
| 18 | + from transformers import AutoModelForCausalLM, AutoTokenizer | |
| 19 | +except ImportError as e: | |
| 20 | + raise ImportError( | |
| 21 | + "Qwen3-Transformers reranker backend requires transformers>=4.51.0 and torch. " | |
| 22 | + "Install with: pip install transformers>=4.51.0 torch" | |
| 23 | + ) from e | |
| 24 | + | |
| 25 | + | |
| 26 | +def _format_instruction(instruction: str, query: str, doc: str) -> str: | |
| 27 | + """Format (query, doc) pair per official Qwen3-Reranker spec.""" | |
| 28 | + return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format( | |
| 29 | + instruction=instruction, query=query, doc=doc | |
| 30 | + ) | |
| 31 | + | |
| 32 | + | |
| 33 | +class Qwen3TransformersRerankerBackend: | |
| 34 | + """ | |
| 35 | + Qwen3-Reranker-0.6B with Transformers (AutoModelForCausalLM) inference. | |
| 36 | + Config from services.rerank.backends.qwen3_transformers. | |
| 37 | + No vLLM dependency; lighter than qwen3_vllm, suitable for CPU or small GPU. | |
| 38 | + """ | |
| 39 | + | |
| 40 | + def __init__(self, config: Dict[str, Any]) -> None: | |
| 41 | + self._config = config or {} | |
| 42 | + model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B") | |
| 43 | + self._instruction = str( | |
| 44 | + self._config.get("instruction") | |
| 45 | + or "Given a web search query, retrieve relevant passages that answer the query" | |
| 46 | + ) | |
| 47 | + max_length = int(self._config.get("max_length", 8192)) | |
| 48 | + batch_size = int(self._config.get("batch_size", 64)) | |
| 49 | + use_fp16 = bool(self._config.get("use_fp16", True)) | |
| 50 | + device = self._config.get("device") | |
| 51 | + attn_impl = self._config.get("attn_implementation") # e.g. "flash_attention_2" | |
| 52 | + | |
| 53 | + self._model_name = model_name | |
| 54 | + self._batch_size = batch_size | |
| 55 | + | |
| 56 | + logger.info( | |
| 57 | + "[Qwen3_Transformers] Loading model %s (max_length=%s, batch=%s, fp16=%s)", | |
| 58 | + model_name, | |
| 59 | + max_length, | |
| 60 | + batch_size, | |
| 61 | + use_fp16, | |
| 62 | + ) | |
| 63 | + | |
| 64 | + self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left") | |
| 65 | + self._tokenizer.pad_token = self._tokenizer.eos_token | |
| 66 | + | |
| 67 | + # Prefix/suffix from official reference | |
| 68 | + prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n" | |
| 69 | + suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" | |
| 70 | + self._prefix_tokens = self._tokenizer.encode(prefix, add_special_tokens=False) | |
| 71 | + self._suffix_tokens = self._tokenizer.encode(suffix, add_special_tokens=False) | |
| 72 | + self._max_length = max_length | |
| 73 | + self._effective_max_len = max_length - len(self._prefix_tokens) - len(self._suffix_tokens) | |
| 74 | + | |
| 75 | + self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes") | |
| 76 | + self._token_false_id = self._tokenizer.convert_tokens_to_ids("no") | |
| 77 | + | |
| 78 | + kwargs = {} | |
| 79 | + if use_fp16 and torch.cuda.is_available(): | |
| 80 | + kwargs["torch_dtype"] = torch.float16 | |
| 81 | + if attn_impl: | |
| 82 | + kwargs["attn_implementation"] = attn_impl | |
| 83 | + | |
| 84 | + self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval() | |
| 85 | + if device is not None: | |
| 86 | + self._model = self._model.to(device) | |
| 87 | + elif torch.cuda.is_available(): | |
| 88 | + self._model = self._model.cuda() | |
| 89 | + | |
| 90 | + logger.info( | |
| 91 | + "[Qwen3_Transformers] Model ready | model=%s device=%s", | |
| 92 | + model_name, | |
| 93 | + next(self._model.parameters()).device, | |
| 94 | + ) | |
| 95 | + | |
| 96 | + def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]: | |
| 97 | + """Tokenize pairs and add prefix/suffix tokens. Returns batched tensors on model device.""" | |
| 98 | + inputs = self._tokenizer( | |
| 99 | + pairs, | |
| 100 | + padding=False, | |
| 101 | + truncation="longest_first", | |
| 102 | + return_attention_mask=False, | |
| 103 | + max_length=self._effective_max_len, | |
| 104 | + ) | |
| 105 | + for i, ele in enumerate(inputs["input_ids"]): | |
| 106 | + inputs["input_ids"][i] = self._prefix_tokens + ele + self._suffix_tokens | |
| 107 | + inputs = self._tokenizer.pad( | |
| 108 | + inputs, | |
| 109 | + padding=True, | |
| 110 | + return_tensors="pt", | |
| 111 | + ) | |
| 112 | + for key in inputs: | |
| 113 | + inputs[key] = inputs[key].to(self._model.device) | |
| 114 | + return inputs | |
| 115 | + | |
| 116 | + @torch.no_grad() | |
| 117 | + def _compute_scores(self, pairs: List[str]) -> List[float]: | |
| 118 | + """Run forward pass and compute yes/no probability per pair.""" | |
| 119 | + if not pairs: | |
| 120 | + return [] | |
| 121 | + inputs = self._process_inputs(pairs) | |
| 122 | + outputs = self._model(**inputs) | |
| 123 | + batch_scores = outputs.logits[:, -1, :] | |
| 124 | + true_vector = batch_scores[:, self._token_true_id] | |
| 125 | + false_vector = batch_scores[:, self._token_false_id] | |
| 126 | + batch_scores = torch.stack([false_vector, true_vector], dim=1) | |
| 127 | + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) | |
| 128 | + scores = batch_scores[:, 1].exp().tolist() | |
| 129 | + return scores | |
| 130 | + | |
| 131 | + def score_with_meta( | |
| 132 | + self, | |
| 133 | + query: str, | |
| 134 | + docs: List[str], | |
| 135 | + normalize: bool = True, | |
| 136 | + ) -> Tuple[List[float], Dict[str, Any]]: | |
| 137 | + start_ts = time.time() | |
| 138 | + total_docs = len(docs) if docs else 0 | |
| 139 | + output_scores: List[float] = [0.0] * total_docs | |
| 140 | + | |
| 141 | + query = "" if query is None else str(query).strip() | |
| 142 | + indexed: List[Tuple[int, str]] = [] | |
| 143 | + for i, doc in enumerate(docs or []): | |
| 144 | + if doc is None: | |
| 145 | + continue | |
| 146 | + text = str(doc).strip() | |
| 147 | + if not text: | |
| 148 | + continue | |
| 149 | + indexed.append((i, text)) | |
| 150 | + | |
| 151 | + if not query or not indexed: | |
| 152 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 153 | + return output_scores, { | |
| 154 | + "input_docs": total_docs, | |
| 155 | + "usable_docs": len(indexed), | |
| 156 | + "unique_docs": 0, | |
| 157 | + "dedup_ratio": 0.0, | |
| 158 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 159 | + "model": self._model_name, | |
| 160 | + "backend": "qwen3_transformers", | |
| 161 | + "normalize": normalize, | |
| 162 | + } | |
| 163 | + | |
| 164 | + # Collapse consecutive duplicate texts, keeping a mapping back to original indices | |
| 165 | + unique_texts: List[str] = [] | |
| 166 | + position_to_unique: List[int] = [] | |
| 167 | + prev: Optional[str] = None | |
| 168 | + for _idx, text in indexed: | |
| 169 | + if text != prev: | |
| 170 | + unique_texts.append(text) | |
| 171 | + prev = text | |
| 172 | + position_to_unique.append(len(unique_texts) - 1) | |
| 173 | + | |
| 174 | + pairs = [ | |
| 175 | + _format_instruction(self._instruction, query, t) | |
| 176 | + for t in unique_texts | |
| 177 | + ] | |
| 178 | + | |
| 179 | + # Batch inference | |
| 180 | + unique_scores: List[float] = [] | |
| 181 | + for i in range(0, len(pairs), self._batch_size): | |
| 182 | + batch = pairs[i : i + self._batch_size] | |
| 183 | + batch_scores = self._compute_scores(batch) | |
| 184 | + unique_scores.extend(batch_scores) | |
| 185 | + | |
| 186 | + for (orig_idx, _), unique_idx in zip(indexed, position_to_unique): | |
| 187 | + output_scores[orig_idx] = float(unique_scores[unique_idx]) | |
| 188 | + | |
| 189 | + elapsed_ms = (time.time() - start_ts) * 1000.0 | |
| 190 | + dedup_ratio = 0.0 | |
| 191 | + if indexed: | |
| 192 | + dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed))) | |
| 193 | + | |
| 194 | + meta = { | |
| 195 | + "input_docs": total_docs, | |
| 196 | + "usable_docs": len(indexed), | |
| 197 | + "unique_docs": len(unique_texts), | |
| 198 | + "dedup_ratio": round(dedup_ratio, 4), | |
| 199 | + "elapsed_ms": round(elapsed_ms, 3), | |
| 200 | + "model": self._model_name, | |
| 201 | + "backend": "qwen3_transformers", | |
| 202 | + "normalize": normalize, | |
| 203 | + } | |
| 204 | + return output_scores, meta | |
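For reference, a short usage sketch of the new backend class on its own; the model weights are pulled from Hugging Face on first load, and the printed numbers are illustrative, not real outputs:

```python
# Usage sketch for Qwen3TransformersRerankerBackend.score_with_meta.
from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend

backend = Qwen3TransformersRerankerBackend({"batch_size": 8, "use_fp16": False})
scores, meta = backend.score_with_meta(
    "what is a reranker",
    ["A reranker scores query-document pairs.", "", "Bananas are yellow."],
)
# Scores stay aligned with the input order; the empty doc keeps a score of 0.0.
print(scores)               # e.g. [0.93, 0.0, 0.02]
print(meta["usable_docs"])  # 2
print(meta["backend"])      # "qwen3_transformers"
```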