Commit 809559351b2cb130a0d8c124f604040035c167ad

Authored by tangwang
1 parent bc089b43

Reranker: add qwen3_transformers backend

.env deleted
@@ -1,41 +0,0 @@
-# Elasticsearch Configuration
-ES_HOST=http://localhost:9200
-ES_USERNAME=saas
-ES_PASSWORD=4hOaLaf41y2VuI8y
-
-# Redis Configuration (Optional) - AI production 10.200.16.14:6479
-REDIS_HOST=10.200.16.14
-REDIS_PORT=6479
-REDIS_PASSWORD=dxEkegEZ@C5SXWKv
-
-# DeepL Translation API
-DEEPL_AUTH_KEY=c9293ab4-ad25-479b-919f-ab4e63b429ed
-
-# API Service Configuration
-API_HOST=0.0.0.0
-API_PORT=6002
-
-# MySQL Database Configuration (Shoplazza) - AI production 10.200.16.14:3316
-DB_HOST=10.200.16.14
-DB_PORT=3316
-DB_DATABASE=saas
-DB_USERNAME=root
-DB_PASSWORD=qY8tgodLoA&KT#yQ
-
-# Model Directories
-TEXT_MODEL_DIR=/data/tw/models/bge-m3 # now served via web requests; local model no longer used
-IMAGE_MODEL_DIR=/data/tw/models/cn-clip # now served via web requests; local model no longer used
-
-# Cache Directory
-CACHE_DIR=.cache
-
-# Frontend API Base URL
-API_BASE_URL=http://43.166.252.75:6002
-
-
-# China (domestic)
-DASHSCOPE_API_KEY=sk-c3b8d4db061840aa8effb748df2a997b
-# US
-DASHSCOPE_API_KEY=sk-482cc3ff37a8467dab134a7a46830556
-
-OPENAI_API_KEY=sk-HvmTMKtuznibZ75l7L2uF2jiaYocCthqd8Cbdkl09KTE7Ft0
reranker/README.md
@@ -4,10 +4,10 @@
 
 ---
 
-The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM). Callers access it over HTTP and do not need to care which backend is used.
+The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwen3-vLLM, Qwen3-Transformers). Callers access it over HTTP and do not need to care which backend is used.
 
 **Features**
-- Multiple backends: `qwen3_vllm` (default, Qwen3-Reranker-0.6B + vLLM), `bge` (kept for compatibility)
+- Multiple backends: `qwen3_vllm` (default, Qwen3-Reranker-0.6B + vLLM), `qwen3_transformers` (pure Transformers, no vLLM required), `bge` (kept for compatibility)
 - Unified configuration: `config/config.yaml` → `services.rerank.backend` / `services.rerank.backends.<name>`
 - Document deduplication, scores aligned with input order, FP16/GPU support (backend dependent)
 
@@ -17,18 +17,20 @@ The Reranker service exposes a unified `/rerank` API with pluggable backends (BGE, Qwe
 - `backends/__init__.py`: `get_rerank_backend(name, config)`
 - `backends/bge.py`: BGE backend
 - `backends/qwen3_vllm.py`: Qwen3-Reranker-0.6B + vLLM backend
+- `backends/qwen3_transformers.py`: Qwen3-Reranker-0.6B pure Transformers backend (official Usage recipe)
 - `reranker/bge_reranker.py`: BGE core inference (wrapped by the bge backend)
 - `reranker/config.py`: service port, MAX_DOCS, NORMALIZE, etc. (backend parameters live in config.yaml)
 
 ## Dependencies
 - Common: `torch`, `modelscope`, `fastapi`, `uvicorn` (see the project's `requirements.txt` / `requirements_ml.txt`)
-- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers` (optional; install only when using `backend: qwen3_vllm`)
+- **Qwen3-vLLM backend**: `vllm>=0.8.5`, `transformers>=4.51.0` (vLLM is only needed when using `backend: qwen3_vllm`)
+- **Qwen3-Transformers backend**: `transformers>=4.51.0`, `torch` (no vLLM; suited to CPU or small-VRAM GPUs)
 ```bash
 ./scripts/setup_reranker_venv.sh
 ```
 
 ## Configuration
-- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `bge`), or the `RERANK_BACKEND` environment variable.
+- **Backend selection**: `services.rerank.backend` in `config/config.yaml` (`qwen3_vllm` | `qwen3_transformers` | `bge`), or the `RERANK_BACKEND` environment variable.
 - **Backend parameters**: `services.rerank.backends.bge` / `services.rerank.backends.qwen3_vllm`, for example:
 
 ```yaml
@@ -47,6 +49,12 @@ services:
     qwen3_vllm:
       model_name: "Qwen/Qwen3-Reranker-0.6B"
       max_model_len: 8192
+    qwen3_transformers:
+      model_name: "Qwen/Qwen3-Reranker-0.6B"
+      instruction: "Given a web search query, retrieve relevant passages that answer the query"
+      max_length: 8192
+      batch_size: 64
+      use_fp16: true
       tensor_parallel_size: 1
       gpu_memory_utilization: 0.8
       enable_prefix_caching: true
@@ -111,3 +119,4 @@ uvicorn reranker.server:app --host 0.0.0.0 --port 6007 --log-level info
 - No request-level cache; inputs are deduplicated by string before inference, and scores are filled back in the original input order.
 - Empty or null docs are skipped and scored as 0.
 - **Qwen3-vLLM**: see [Qwen3-Reranker-0.6B](https://huggingface.co/Qwen/Qwen3-Reranker-0.6B); needs a GPU and a fair amount of VRAM; compared with BGE it suits long-text, high-throughput scenarios (vLLM prefix caching).
+- **Qwen3-Transformers**: the official Transformers Usage recipe, no vLLM required; suited to CPU or small-VRAM GPUs, with optional `attn_implementation: "flash_attention_2"` for speed.
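The README above covers backend selection and the unified `/rerank` endpoint; a minimal sketch of calling the running service follows. The payload field names (`query`, `docs`) are an assumption inferred from the `score_with_meta(query, docs)` signature added in this commit, not a documented request schema; check `reranker/server.py` for the actual one.

```python
# Hypothetical call to a service started e.g. with:
#   RERANK_BACKEND=qwen3_transformers uvicorn reranker.server:app --host 0.0.0.0 --port 6007
# The payload shape ("query"/"docs") is assumed from score_with_meta(query, docs).
import requests

resp = requests.post(
    "http://localhost:6007/rerank",
    json={
        "query": "wireless noise cancelling headphones",
        "docs": [
            "Bluetooth over-ear headphones with ANC",
            "Stainless steel water bottle",
        ],
    },
    timeout=30,
)
print(resp.json())  # expected: per-doc scores aligned with the input order
```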
reranker/backends/__init__.py
@@ -43,7 +43,12 @@ def get_rerank_backend(name: str, config: Dict[str, Any]) -> RerankBackendProtoc
     if name == "qwen3_vllm":
         from reranker.backends.qwen3_vllm import Qwen3VLLMRerankerBackend
         return Qwen3VLLMRerankerBackend(config)
-    raise ValueError(f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm")
+    if name == "qwen3_transformers":
+        from reranker.backends.qwen3_transformers import Qwen3TransformersRerankerBackend
+        return Qwen3TransformersRerankerBackend(config)
+    raise ValueError(
+        f"Unknown rerank backend: {name!r}. Supported: bge, qwen3_vllm, qwen3_transformers"
+    )
 
 
 __all__ = ["RerankBackendProtocol", "get_rerank_backend"]
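For in-process use (bypassing the `/rerank` HTTP service), the factory above can be called directly. A minimal sketch, assuming the `qwen3_transformers` config keys shown in the README example; the model is downloaded by the backend constructor on first use.

```python
# Minimal in-process sketch using the factory added in this commit.
# The config dict mirrors services.rerank.backends.qwen3_transformers from the README.
from reranker.backends import get_rerank_backend

backend = get_rerank_backend(
    "qwen3_transformers",
    {
        "model_name": "Qwen/Qwen3-Reranker-0.6B",
        "max_length": 8192,
        "batch_size": 64,
        "use_fp16": True,
    },
)

scores, meta = backend.score_with_meta(
    query="what is the capital of China",
    docs=["The capital of China is Beijing.", "Gravity makes apples fall."],
)
print(scores)               # one float per input doc, in the original order
print(meta["unique_docs"])  # deduplicated doc count reported by the backend
```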
reranker/backends/qwen3_transformers.py 0 → 100644
@@ -0,0 +1,204 @@
+"""
+Qwen3-Reranker-0.6B backend using Transformers (direct usage). No vLLM required.
+
+Reference: https://huggingface.co/Qwen/Qwen3-Reranker-0.6B
+Requires: transformers>=4.51.0, torch.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+from typing import Any, Dict, List, Optional, Tuple
+
+logger = logging.getLogger("reranker.backends.qwen3_transformers")
+
+try:
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+except ImportError as e:
+    raise ImportError(
+        "Qwen3-Transformers reranker backend requires transformers>=4.51.0 and torch. "
+        "Install with: pip install transformers>=4.51.0 torch"
+    ) from e
+
+
+def _format_instruction(instruction: str, query: str, doc: str) -> str:
+    """Format (query, doc) pair per official Qwen3-Reranker spec."""
+    return "<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}".format(
+        instruction=instruction, query=query, doc=doc
+    )
+
+
+class Qwen3TransformersRerankerBackend:
+    """
+    Qwen3-Reranker-0.6B with Transformers (AutoModelForCausalLM) inference.
+    Config from services.rerank.backends.qwen3_transformers.
+    No vLLM dependency; lighter than qwen3_vllm, suitable for CPU or small GPU.
+    """
+
+    def __init__(self, config: Dict[str, Any]) -> None:
+        self._config = config or {}
+        model_name = str(self._config.get("model_name") or "Qwen/Qwen3-Reranker-0.6B")
+        self._instruction = str(
+            self._config.get("instruction")
+            or "Given a web search query, retrieve relevant passages that answer the query"
+        )
+        max_length = int(self._config.get("max_length", 8192))
+        batch_size = int(self._config.get("batch_size", 64))
+        use_fp16 = bool(self._config.get("use_fp16", True))
+        device = self._config.get("device")
+        attn_impl = self._config.get("attn_implementation")  # e.g. "flash_attention_2"
+
+        self._model_name = model_name
+        self._batch_size = batch_size
+
+        logger.info(
+            "[Qwen3_Transformers] Loading model %s (max_length=%s, batch=%s, fp16=%s)",
+            model_name,
+            max_length,
+            batch_size,
+            use_fp16,
+        )
+
+        self._tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+        self._tokenizer.pad_token = self._tokenizer.eos_token
+
+        # Prefix/suffix from official reference
+        prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        self._prefix_tokens = self._tokenizer.encode(prefix, add_special_tokens=False)
+        self._suffix_tokens = self._tokenizer.encode(suffix, add_special_tokens=False)
+        self._max_length = max_length
+        self._effective_max_len = max_length - len(self._prefix_tokens) - len(self._suffix_tokens)
+
+        self._token_true_id = self._tokenizer.convert_tokens_to_ids("yes")
+        self._token_false_id = self._tokenizer.convert_tokens_to_ids("no")
+
+        kwargs = {}
+        if use_fp16 and torch.cuda.is_available():
+            kwargs["torch_dtype"] = torch.float16
+        if attn_impl:
+            kwargs["attn_implementation"] = attn_impl
+
+        self._model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs).eval()
+        if device is not None:
+            self._model = self._model.to(device)
+        elif torch.cuda.is_available():
+            self._model = self._model.cuda()
+
+        logger.info(
+            "[Qwen3_Transformers] Model ready | model=%s device=%s",
+            model_name,
+            next(self._model.parameters()).device,
+        )
+
+    def _process_inputs(self, pairs: List[str]) -> Dict[str, torch.Tensor]:
+        """Tokenize pairs and add prefix/suffix tokens. Returns batched tensors on model device."""
+        inputs = self._tokenizer(
+            pairs,
+            padding=False,
+            truncation="longest_first",
+            return_attention_mask=False,
+            max_length=self._effective_max_len,
+        )
+        for i, ele in enumerate(inputs["input_ids"]):
+            inputs["input_ids"][i] = self._prefix_tokens + ele + self._suffix_tokens
+        inputs = self._tokenizer.pad(
+            inputs,
+            padding=True,
+            return_tensors="pt",
+        )
+        for key in inputs:
+            inputs[key] = inputs[key].to(self._model.device)
+        return inputs
+
+    @torch.no_grad()
+    def _compute_scores(self, pairs: List[str]) -> List[float]:
+        """Run forward pass and compute yes/no probability per pair."""
+        if not pairs:
+            return []
+        inputs = self._process_inputs(pairs)
+        outputs = self._model(**inputs)
+        batch_scores = outputs.logits[:, -1, :]
+        true_vector = batch_scores[:, self._token_true_id]
+        false_vector = batch_scores[:, self._token_false_id]
+        batch_scores = torch.stack([false_vector, true_vector], dim=1)
+        batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+        scores = batch_scores[:, 1].exp().tolist()
+        return scores
+
+    def score_with_meta(
+        self,
+        query: str,
+        docs: List[str],
+        normalize: bool = True,
+    ) -> Tuple[List[float], Dict[str, Any]]:
+        start_ts = time.time()
+        total_docs = len(docs) if docs else 0
+        output_scores: List[float] = [0.0] * total_docs
+
+        query = "" if query is None else str(query).strip()
+        indexed: List[Tuple[int, str]] = []
+        for i, doc in enumerate(docs or []):
+            if doc is None:
+                continue
+            text = str(doc).strip()
+            if not text:
+                continue
+            indexed.append((i, text))
+
+        if not query or not indexed:
+            elapsed_ms = (time.time() - start_ts) * 1000.0
+            return output_scores, {
+                "input_docs": total_docs,
+                "usable_docs": len(indexed),
+                "unique_docs": 0,
+                "dedup_ratio": 0.0,
+                "elapsed_ms": round(elapsed_ms, 3),
+                "model": self._model_name,
+                "backend": "qwen3_transformers",
+                "normalize": normalize,
+            }
+
+        # Collapse consecutive duplicate texts; map each position to its unique index
+        unique_texts: List[str] = []
+        position_to_unique: List[int] = []
+        prev: Optional[str] = None
+        for _idx, text in indexed:
+            if text != prev:
+                unique_texts.append(text)
+                prev = text
+            position_to_unique.append(len(unique_texts) - 1)
+
+        pairs = [
+            _format_instruction(self._instruction, query, t)
+            for t in unique_texts
+        ]
+
+        # Batch inference
+        unique_scores: List[float] = []
+        for i in range(0, len(pairs), self._batch_size):
+            batch = pairs[i : i + self._batch_size]
+            batch_scores = self._compute_scores(batch)
+            unique_scores.extend(batch_scores)
+
+        for (orig_idx, _), unique_idx in zip(indexed, position_to_unique):
+            output_scores[orig_idx] = float(unique_scores[unique_idx])
+
+        elapsed_ms = (time.time() - start_ts) * 1000.0
+        dedup_ratio = 0.0
+        if indexed:
+            dedup_ratio = 1.0 - (len(unique_texts) / float(len(indexed)))
+
+        meta = {
+            "input_docs": total_docs,
+            "usable_docs": len(indexed),
+            "unique_docs": len(unique_texts),
+            "dedup_ratio": round(dedup_ratio, 4),
+            "elapsed_ms": round(elapsed_ms, 3),
+            "model": self._model_name,
+            "backend": "qwen3_transformers",
+            "normalize": normalize,
+        }
+        return output_scores, meta
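The `_compute_scores` step above converts the last-position logits for the "yes"/"no" tokens into a relevance probability. A tiny self-contained sketch of that arithmetic with fabricated logit values (illustrative only, independent of any model):

```python
# Illustrative only: reproduces the yes/no -> probability step from _compute_scores
# with made-up logit values.
import torch

# Pretend these are the last-token logits for "yes" and "no" across a 3-pair batch.
yes_logits = torch.tensor([2.0, -1.0, 0.5])
no_logits = torch.tensor([0.0, 1.0, 0.5])

stacked = torch.stack([no_logits, yes_logits], dim=1)        # shape (batch, 2)
log_probs = torch.nn.functional.log_softmax(stacked, dim=1)  # normalize over {no, yes}
scores = log_probs[:, 1].exp().tolist()                      # P("yes") per pair
print([round(s, 3) for s in scores])  # ~[0.881, 0.119, 0.5]
```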