Commit 749d78c8351a81bdf33f7b2a1644448fbb0958d1
1 parent 4823f463
Support a compact reranker instruction format
Showing 4 changed files with 64 additions and 6 deletions
config/config.yaml
| @@ -402,6 +402,8 @@ services: | @@ -402,6 +402,8 @@ services: | ||
| 402 | enforce_eager: false | 402 | enforce_eager: false |
| 403 | infer_batch_size: 100 | 403 | infer_batch_size: 100 |
| 404 | sort_by_doc_length: true | 404 | sort_by_doc_length: true |
| 405 | + # Matches reranker/backends/qwen3_vllm.py: standard=_format_instruction__standard (fixed yes/no system prompt); compact=_format_instruction (instruction as system prompt, Instruct repeated in the user turn) | ||
| 406 | + instruction_format: compact | ||
| 405 | # instruction: "Given a query, score the product for relevance" | 407 | # instruction: "Given a query, score the product for relevance" |
| 406 | # "rank products by given query" 比 “Given a query, score the product for relevance” 更好点 | 408 | # "rank products by given query" 比 “Given a query, score the product for relevance” 更好点 |
| 407 | # instruction: "rank products by given query, category match first" | 409 | # instruction: "rank products by given query, category match first" |
| @@ -433,6 +435,9 @@ services: | @@ -433,6 +435,9 @@ services: | ||
| 433 | enforce_eager: false | 435 | enforce_eager: false |
| 434 | infer_batch_size: 100 | 436 | infer_batch_size: 100 |
| 435 | sort_by_doc_length: true | 437 | sort_by_doc_length: true |
| 438 | + # Same semantics as the identically named option in qwen3_vllm; the default standard matches the official vLLM Qwen3 reranker prefix | ||
| 439 | + # instruction_format: standard | ||
| 440 | + instruction_format: compact | ||
| 436 | instruction: "Rank products by query with category & style match prioritized" | 441 | instruction: "Rank products by query with category & style match prioritized" |
| 437 | qwen3_transformers: | 442 | qwen3_transformers: |
| 438 | model_name: "Qwen/Qwen3-Reranker-0.6B" | 443 | model_name: "Qwen/Qwen3-Reranker-0.6B" |
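For reference, a minimal sketch of how a backend resolves and validates the new key from this file. It assumes PyYAML, and the `services.reranker.qwen3_vllm` nesting is hypothetical; the validation itself mirrors what this commit adds in reranker/backends/qwen3_vllm.py below.

```python
# Minimal sketch (PyYAML assumed; the services.reranker.qwen3_vllm path is
# hypothetical) mirroring the instruction_format validation added below.
import yaml

with open("config/config.yaml") as f:
    cfg = yaml.safe_load(f)

backend_cfg = cfg["services"]["reranker"]["qwen3_vllm"]  # hypothetical nesting
fmt = str(backend_cfg.get("instruction_format") or "compact").strip().lower()
if fmt not in {"standard", "compact"}:
    raise ValueError(f"instruction_format must be 'standard' or 'compact', got {fmt!r}")
print("instruction_format =", fmt)
```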
requirements_reranker_qwen3_transformers_packed.txt
| 1 | # Isolated dependencies for qwen3_transformers_packed reranker backend. | 1 | # Isolated dependencies for qwen3_transformers_packed reranker backend. |
| 2 | +# | ||
| 3 | +# Keep this stack aligned with the validated CUDA runtime on our hosts. | ||
| 4 | +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works. | ||
| 5 | +# We also cap transformers <5 to stay on the same family as the working vLLM score env. | ||
| 2 | 6 | ||
| 3 | -r requirements_reranker_qwen3_transformers.txt | 7 | -r requirements_reranker_qwen3_transformers.txt |
| 8 | +torch==2.10.0 | ||
| 9 | +transformers>=4.51.0,<5 |
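Since the pins above encode a runtime constraint rather than a feature requirement, a quick sanity check (a sketch, not part of this commit) can confirm the environment before serving:

```python
# Sanity-check sketch: verify the pinned stack imports and CUDA initializes,
# per the note above (torch 2.10.0 + cu128 works; 2.11.0 + cu130 fails here).
import torch
import transformers

print("torch", torch.__version__, "| cuda build", torch.version.cuda)
print("transformers", transformers.__version__)
assert torch.cuda.is_available(), "CUDA failed to initialize; check the cu130 note above"
print("device:", torch.cuda.get_device_name(0))
```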
reranker/backends/qwen3_vllm.py
| @@ -45,6 +45,19 @@ def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: | @@ -45,6 +45,19 @@ def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: | ||
| 45 | return unique_texts, position_to_unique | 45 | return unique_texts, position_to_unique |
| 46 | 46 | ||
| 47 | 47 | ||
| 48 | +def _format_instruction__standard(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: | ||
| 49 | + """Build chat messages for one (query, doc) pair.""" | ||
| 50 | + return [ | ||
| 51 | + { | ||
| 52 | + "role": "system", | ||
| 53 | + "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".", | ||
| 54 | + }, | ||
| 55 | + { | ||
| 56 | + "role": "user", | ||
| 57 | + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", | ||
| 58 | + }, | ||
| 59 | + ] | ||
| 60 | + | ||
| 48 | def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: | 61 | def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: |
| 49 | """Build chat messages for one (query, doc) pair.""" | 62 | """Build chat messages for one (query, doc) pair.""" |
| 50 | return [ | 63 | return [ |
| @@ -54,11 +67,10 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str | @@ -54,11 +67,10 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str | ||
| 54 | }, | 67 | }, |
| 55 | { | 68 | { |
| 56 | "role": "user", | 69 | "role": "user", |
| 57 | - "content": f"<Query>: {query}\n\n<Document>: {doc}", | 70 | + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", |
| 58 | }, | 71 | }, |
| 59 | ] | 72 | ] |
| 60 | 73 | ||
| 61 | - | ||
| 62 | class Qwen3VLLMRerankerBackend: | 74 | class Qwen3VLLMRerankerBackend: |
| 63 | """ | 75 | """ |
| 64 | Qwen3-Reranker-0.6B with vLLM inference. | 76 | Qwen3-Reranker-0.6B with vLLM inference. |
| @@ -78,6 +90,17 @@ class Qwen3VLLMRerankerBackend: | @@ -78,6 +90,17 @@ class Qwen3VLLMRerankerBackend: | ||
| 78 | self._config.get("instruction") | 90 | self._config.get("instruction") |
| 79 | or "Given a query, score the product for relevance" | 91 | or "Given a query, score the product for relevance" |
| 80 | ) | 92 | ) |
| 93 | + _fmt = str(self._config.get("instruction_format") or "compact").strip().lower() | ||
| 94 | + if _fmt not in {"standard", "compact"}: | ||
| 95 | + raise ValueError( | ||
| 96 | + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" | ||
| 97 | + ) | ||
| 98 | + self._instruction_format = _fmt | ||
| 99 | + self._format_messages = ( | ||
| 100 | + _format_instruction__standard | ||
| 101 | + if self._instruction_format == "standard" | ||
| 102 | + else _format_instruction | ||
| 103 | + ) | ||
| 81 | infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64) | 104 | infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64) |
| 82 | sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") | 105 | sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") |
| 83 | if sort_by_doc_length is None: | 106 | if sort_by_doc_length is None: |
| @@ -95,13 +118,15 @@ class Qwen3VLLMRerankerBackend: | @@ -95,13 +118,15 @@ class Qwen3VLLMRerankerBackend: | ||
| 95 | ) | 118 | ) |
| 96 | 119 | ||
| 97 | logger.info( | 120 | logger.info( |
| 98 | - "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", | 121 | + "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " |
| 122 | + "instruction_format=%s)", | ||
| 99 | model_name, | 123 | model_name, |
| 100 | max_model_len, | 124 | max_model_len, |
| 101 | tensor_parallel_size, | 125 | tensor_parallel_size, |
| 102 | gpu_memory_utilization, | 126 | gpu_memory_utilization, |
| 103 | dtype, | 127 | dtype, |
| 104 | enable_prefix_caching, | 128 | enable_prefix_caching, |
| 129 | + self._instruction_format, | ||
| 105 | ) | 130 | ) |
| 106 | 131 | ||
| 107 | self._llm = LLM( | 132 | self._llm = LLM( |
| @@ -145,7 +170,7 @@ class Qwen3VLLMRerankerBackend: | @@ -145,7 +170,7 @@ class Qwen3VLLMRerankerBackend: | ||
| 145 | ) -> List[TokensPrompt]: | 170 | ) -> List[TokensPrompt]: |
| 146 | """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template.""" | 171 | """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template.""" |
| 147 | messages_batch = [ | 172 | messages_batch = [ |
| 148 | - _format_instruction(self._instruction, q, d) for q, d in pairs | 173 | + self._format_messages(self._instruction, q, d) for q, d in pairs |
| 149 | ] | 174 | ] |
| 150 | tokenized = self._tokenizer.apply_chat_template( | 175 | tokenized = self._tokenizer.apply_chat_template( |
| 151 | messages_batch, | 176 | messages_batch, |
| @@ -242,6 +267,7 @@ class Qwen3VLLMRerankerBackend: | @@ -242,6 +267,7 @@ class Qwen3VLLMRerankerBackend: | ||
| 242 | "infer_batch_size": self._infer_batch_size, | 267 | "infer_batch_size": self._infer_batch_size, |
| 243 | "inference_batches": 0, | 268 | "inference_batches": 0, |
| 244 | "sort_by_doc_length": self._sort_by_doc_length, | 269 | "sort_by_doc_length": self._sort_by_doc_length, |
| 270 | + "instruction_format": self._instruction_format, | ||
| 245 | } | 271 | } |
| 246 | 272 | ||
| 247 | # Deduplicate globally by text, keep mapping to original indices. | 273 | # Deduplicate globally by text, keep mapping to original indices. |
| @@ -289,6 +315,7 @@ class Qwen3VLLMRerankerBackend: | @@ -289,6 +315,7 @@ class Qwen3VLLMRerankerBackend: | ||
| 289 | "normalize": normalize, | 315 | "normalize": normalize, |
| 290 | "infer_batch_size": self._infer_batch_size, | 316 | "infer_batch_size": self._infer_batch_size, |
| 291 | "inference_batches": inference_batches, | 317 | "inference_batches": inference_batches, |
| 292 | - "sort_by_doc_length": self._sort_by_doc_length | 318 | + "sort_by_doc_length": self._sort_by_doc_length, |
| 319 | + "instruction_format": self._instruction_format, | ||
| 293 | } | 320 | } |
| 294 | return output_scores, meta | 321 | return output_scores, meta |
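To make the two formats concrete, the sketch below reproduces what the two builders above emit for one pair (the query/doc values are illustrative only):

```python
# Sketch of the messages each format yields; mirrors
# _format_instruction__standard (standard) and _format_instruction (compact).
instruction = "rank products by given query"          # from config.yaml above
query, doc = "red running shoes", "Lightweight red trail runner"  # illustrative

standard = [
    {"role": "system",
     "content": "Judge whether the Document meets the requirements based on the "
                'Query and the Instruct provided. Note that the answer can only be "yes" or "no".'},
    {"role": "user",
     "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"},
]
compact = [
    # In compact, the instruction itself serves as the system prompt...
    {"role": "system", "content": instruction},
    # ...and is repeated as <Instruct> in the user turn, per this commit.
    {"role": "user",
     "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"},
]
```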
reranker/backends/qwen3_vllm_score.py
| @@ -37,6 +37,8 @@ _DEFAULT_PREFIX = ( | @@ -37,6 +37,8 @@ _DEFAULT_PREFIX = ( | ||
| 37 | _DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" | 37 | _DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" |
| 38 | _DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" | 38 | _DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" |
| 39 | _DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" | 39 | _DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" |
| 40 | +# compact: matches qwen3_vllm._format_instruction (instruction as system prompt, Instruct repeated in the user turn) | ||
| 41 | +_IM_USER_START = "<|im_end|>\n<|im_start|>user\n" | ||
| 40 | 42 | ||
| 41 | 43 | ||
| 42 | def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None: | 44 | def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None: |
| @@ -99,6 +101,12 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -99,6 +101,12 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 99 | self._config.get("instruction") | 101 | self._config.get("instruction") |
| 100 | or "Given a query, score the product for relevance" | 102 | or "Given a query, score the product for relevance" |
| 101 | ) | 103 | ) |
| 104 | + _fmt = str(self._config.get("instruction_format") or "standard").strip().lower() | ||
| 105 | + if _fmt not in {"standard", "compact"}: | ||
| 106 | + raise ValueError( | ||
| 107 | + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" | ||
| 108 | + ) | ||
| 109 | + self._instruction_format = _fmt | ||
| 102 | self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) | 110 | self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) |
| 103 | self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) | 111 | self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) |
| 104 | self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE) | 112 | self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE) |
| @@ -142,7 +150,8 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -142,7 +150,8 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 142 | 150 | ||
| 143 | logger.info( | 151 | logger.info( |
| 144 | "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, " | 152 | "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, " |
| 145 | - "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", | 153 | + "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " |
| 154 | + "instruction_format=%s)", | ||
| 146 | model_name, | 155 | model_name, |
| 147 | runner, | 156 | runner, |
| 148 | convert, | 157 | convert, |
| @@ -152,6 +161,7 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -152,6 +161,7 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 152 | gpu_memory_utilization, | 161 | gpu_memory_utilization, |
| 153 | dtype, | 162 | dtype, |
| 154 | enable_prefix_caching, | 163 | enable_prefix_caching, |
| 164 | + self._instruction_format, | ||
| 155 | ) | 165 | ) |
| 156 | 166 | ||
| 157 | # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official | 167 | # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official |
| @@ -190,6 +200,14 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -190,6 +200,14 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 190 | logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name) | 200 | logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name) |
| 191 | 201 | ||
| 192 | def _format_pair(self, query: str, doc: str) -> Tuple[str, str]: | 202 | def _format_pair(self, query: str, doc: str) -> Tuple[str, str]: |
| 203 | + if self._instruction_format == "compact": | ||
| 204 | + # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score(). | ||
| 205 | + compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}" | ||
| 206 | + q_text = ( | ||
| 207 | + f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n" | ||
| 208 | + ) | ||
| 209 | + d_text = f"\n<Document>: {doc}{self._suffix}" | ||
| 210 | + return q_text, d_text | ||
| 193 | q_text = self._query_template.format( | 211 | q_text = self._query_template.format( |
| 194 | prefix=self._prefix, | 212 | prefix=self._prefix, |
| 195 | instruction=self._instruction, | 213 | instruction=self._instruction, |
| @@ -255,6 +273,7 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -255,6 +273,7 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 255 | "infer_batch_size": self._infer_batch_size, | 273 | "infer_batch_size": self._infer_batch_size, |
| 256 | "inference_batches": 0, | 274 | "inference_batches": 0, |
| 257 | "sort_by_doc_length": self._sort_by_doc_length, | 275 | "sort_by_doc_length": self._sort_by_doc_length, |
| 276 | + "instruction_format": self._instruction_format, | ||
| 258 | } | 277 | } |
| 259 | 278 | ||
| 260 | indexed_texts = [text for _, text in indexed] | 279 | indexed_texts = [text for _, text in indexed] |
| @@ -299,5 +318,6 @@ class Qwen3VLLMScoreRerankerBackend: | @@ -299,5 +318,6 @@ class Qwen3VLLMScoreRerankerBackend: | ||
| 299 | "infer_batch_size": self._infer_batch_size, | 318 | "infer_batch_size": self._infer_batch_size, |
| 300 | "inference_batches": inference_batches, | 319 | "inference_batches": inference_batches, |
| 301 | "sort_by_doc_length": self._sort_by_doc_length, | 320 | "sort_by_doc_length": self._sort_by_doc_length, |
| 321 | + "instruction_format": self._instruction_format, | ||
| 302 | } | 322 | } |
| 303 | return output_scores, meta | 323 | return output_scores, meta |
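Concatenated, the compact branch of _format_pair above produces a raw prompt along these lines (same illustrative pair; the suffix is the `_DEFAULT_SUFFIX` from the top of this file):

```python
# Sketch: the q_text + d_text string handed to LLM.score() in compact mode,
# matching the compact branch of _format_pair above (illustrative query/doc).
instruction = "rank products by given query"
query, doc = "red running shoes", "Lightweight red trail runner"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"  # _DEFAULT_SUFFIX

q_text = (
    f"<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n"
    f"<Instruct>: {instruction}\n\n<Query>: {query}\n"
)
d_text = f"\n<Document>: {doc}{suffix}"
print(q_text + d_text)
```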