diff --git a/config/config.yaml b/config/config.yaml index 776e7b5..7499494 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -402,6 +402,8 @@ services: enforce_eager: false infer_batch_size: 100 sort_by_doc_length: true + # 与 reranker/backends/qwen3_vllm.py 一致:standard=_format_instruction__standard(固定 yes/no system);compact=_format_instruction(instruction 作 system 且 user 内重复 Instruct) + instruction_format: compact # instruction: "Given a query, score the product for relevance" # "rank products by given query" 比 “Given a query, score the product for relevance” 更好点 # instruction: "rank products by given query, category match first" @@ -433,6 +435,9 @@ services: enforce_eager: false infer_batch_size: 100 sort_by_doc_length: true + # 与 qwen3_vllm 同名项语义一致;默认 standard 与 vLLM 官方 Qwen3 reranker 前缀一致 + # instruction_format: standard + instruction_format: compact instruction: "Rank products by query with category & style match prioritized" qwen3_transformers: model_name: "Qwen/Qwen3-Reranker-0.6B" diff --git a/requirements_reranker_qwen3_transformers_packed.txt b/requirements_reranker_qwen3_transformers_packed.txt index 025981f..a4878ae 100644 --- a/requirements_reranker_qwen3_transformers_packed.txt +++ b/requirements_reranker_qwen3_transformers_packed.txt @@ -1,3 +1,9 @@ # Isolated dependencies for qwen3_transformers_packed reranker backend. +# +# Keep this stack aligned with the validated CUDA runtime on our hosts. +# On this machine, torch 2.11.0 + cu130 fails CUDA init, while torch 2.10.0 + cu128 works. +# We also cap transformers <5 to stay on the same family as the working vLLM score env. 
-r requirements_reranker_qwen3_transformers.txt +torch==2.10.0 +transformers>=4.51.0,<5 diff --git a/reranker/backends/qwen3_vllm.py b/reranker/backends/qwen3_vllm.py index 4903700..4462666 100644 --- a/reranker/backends/qwen3_vllm.py +++ b/reranker/backends/qwen3_vllm.py @@ -45,6 +45,19 @@ def deduplicate_with_positions(texts: List[str]) -> Tuple[List[str], List[int]]: return unique_texts, position_to_unique +def _format_instruction__standard(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: + """Build chat messages for one (query, doc) pair.""" + return [ + { + "role": "system", + "content": "Judge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".", + }, + { + "role": "user", + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", + }, + ] + def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str, str]]: """Build chat messages for one (query, doc) pair.""" return [ @@ -54,11 +67,10 @@ def _format_instruction(instruction: str, query: str, doc: str) -> List[Dict[str }, { "role": "user", - "content": f"<Query>: {query}\n\n<Document>: {doc}", + "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}", }, ] - class Qwen3VLLMRerankerBackend: """ Qwen3-Reranker-0.6B with vLLM inference. 
@@ -78,6 +90,17 @@ class Qwen3VLLMRerankerBackend: self._config.get("instruction") or "Given a query, score the product for relevance" ) + _fmt = str(self._config.get("instruction_format") or "compact").strip().lower() + if _fmt not in {"standard", "compact"}: + raise ValueError( + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" + ) + self._instruction_format = _fmt + self._format_messages = ( + _format_instruction__standard + if self._instruction_format == "standard" + else _format_instruction + ) infer_batch_size = os.getenv("RERANK_VLLM_INFER_BATCH_SIZE") or self._config.get("infer_batch_size", 64) sort_by_doc_length = os.getenv("RERANK_VLLM_SORT_BY_DOC_LENGTH") if sort_by_doc_length is None: @@ -95,13 +118,15 @@ class Qwen3VLLMRerankerBackend: ) logger.info( - "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", + "[Qwen3_VLLM] Loading model %s (max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " + "instruction_format=%s)", model_name, max_model_len, tensor_parallel_size, gpu_memory_utilization, dtype, enable_prefix_caching, + self._instruction_format, ) self._llm = LLM( @@ -145,7 +170,7 @@ class Qwen3VLLMRerankerBackend: ) -> List[TokensPrompt]: """Build tokenized prompts for vLLM from (query, doc) pairs. Batch apply_chat_template.""" messages_batch = [ - _format_instruction(self._instruction, q, d) for q, d in pairs + self._format_messages(self._instruction, q, d) for q, d in pairs ] tokenized = self._tokenizer.apply_chat_template( messages_batch, @@ -242,6 +267,7 @@ class Qwen3VLLMRerankerBackend: "infer_batch_size": self._infer_batch_size, "inference_batches": 0, "sort_by_doc_length": self._sort_by_doc_length, + "instruction_format": self._instruction_format, } # Deduplicate globally by text, keep mapping to original indices. 
@@ -289,6 +315,7 @@ class Qwen3VLLMRerankerBackend: "normalize": normalize, "infer_batch_size": self._infer_batch_size, "inference_batches": inference_batches, - "sort_by_doc_length": self._sort_by_doc_length + "sort_by_doc_length": self._sort_by_doc_length, + "instruction_format": self._instruction_format, } return output_scores, meta diff --git a/reranker/backends/qwen3_vllm_score.py b/reranker/backends/qwen3_vllm_score.py index aab8e40..49d378b 100644 --- a/reranker/backends/qwen3_vllm_score.py +++ b/reranker/backends/qwen3_vllm_score.py @@ -37,6 +37,8 @@ _DEFAULT_PREFIX = ( _DEFAULT_SUFFIX = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n" _DEFAULT_QUERY_TEMPLATE = "{prefix}<Instruct>: {instruction}\n<Query>: {query}\n" _DEFAULT_DOCUMENT_TEMPLATE = "<Document>: {doc}{suffix}" +# compact:与 qwen3_vllm._format_instruction 一致(instruction 作 system,user 内重复 Instruct) +_IM_USER_START = "<|im_end|>\n<|im_start|>user\n" def _resolve_vllm_attention_config(config: Dict[str, Any]) -> Dict[str, Any] | None: @@ -99,6 +101,12 @@ class Qwen3VLLMScoreRerankerBackend: self._config.get("instruction") or "Given a query, score the product for relevance" ) + _fmt = str(self._config.get("instruction_format") or "standard").strip().lower() + if _fmt not in {"standard", "compact"}: + raise ValueError( + f"instruction_format must be 'standard' or 'compact', got {_fmt!r}" + ) + self._instruction_format = _fmt self._prefix = str(self._config.get("prompt_prefix") or _DEFAULT_PREFIX) self._suffix = str(self._config.get("prompt_suffix") or _DEFAULT_SUFFIX) self._query_template = str(self._config.get("query_template") or _DEFAULT_QUERY_TEMPLATE) @@ -142,7 +150,8 @@ class Qwen3VLLMScoreRerankerBackend: logger.info( "[Qwen3_VLLM_SCORE] Loading model %s (LLM.score API, runner=%s, convert=%s, " - "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s)", + "hf_overrides=%s, max_model_len=%s, tp=%s, gpu_mem=%.2f, dtype=%s, prefix_caching=%s, " + "instruction_format=%s)", model_name, runner, 
convert, @@ -152,6 +161,7 @@ class Qwen3VLLMScoreRerankerBackend: gpu_memory_utilization, dtype, enable_prefix_caching, + self._instruction_format, ) # vLLM 0.17+ uses runner/convert instead of LLM(..., task="score"). With the official @@ -190,6 +200,14 @@ class Qwen3VLLMScoreRerankerBackend: logger.info("[Qwen3_VLLM_SCORE] Model ready | model=%s", model_name) def _format_pair(self, query: str, doc: str) -> Tuple[str, str]: + if self._instruction_format == "compact": + # Align with reranker.backends.qwen3_vllm._format_instruction query/doc split for LLM.score(). + compact_prefix = f"<|im_start|>system\n{self._instruction}{_IM_USER_START}" + q_text = ( + f"{compact_prefix}<Instruct>: {self._instruction}\n\n<Query>: {query}\n" + ) + d_text = f"\n<Document>: {doc}{self._suffix}" + return q_text, d_text q_text = self._query_template.format( prefix=self._prefix, instruction=self._instruction, @@ -255,6 +273,7 @@ class Qwen3VLLMScoreRerankerBackend: "infer_batch_size": self._infer_batch_size, "inference_batches": 0, "sort_by_doc_length": self._sort_by_doc_length, + "instruction_format": self._instruction_format, } indexed_texts = [text for _, text in indexed] @@ -299,5 +318,6 @@ class Qwen3VLLMScoreRerankerBackend: "infer_batch_size": self._infer_batch_size, "inference_batches": inference_batches, "sort_by_doc_length": self._sort_by_doc_length, + "instruction_format": self._instruction_format, } return output_scores, meta -- libgit2 0.21.2