Blame view

reranker/backends/batching_utils.py 1.31 KB
9f5994b4   tangwang   reranker
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
  """Utilities for reranker batching and deduplication."""
  
  from __future__ import annotations
  
  from typing import Iterable, List, Sequence, Tuple
  
  
  def deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
      """
      Collapse repeated texts, keeping the order in which each first appears.

      Returns:
          unique_texts: each distinct text once, ordered by first occurrence
          position_to_unique: for every input position, the index of its text
              within unique_texts
      """
      # Dicts keep insertion order, so the keys double as the unique list and
      # the values record each text's slot in that list.
      slot_of: dict[str, int] = {}
      mapping: List[int] = []

      for item in texts:
          if item not in slot_of:
              slot_of[item] = len(slot_of)
          mapping.append(slot_of[item])

      return list(slot_of), mapping
  
  
  def sort_indices_by_length(lengths: Sequence[int]) -> List[int]:
      """Return the indices of lengths ordered by ascending value; ties keep index order."""
      order = list(range(len(lengths)))
      # list.sort is stable, so equal lengths stay in their original relative order.
      order.sort(key=lengths.__getitem__)
      return order
  
  
  def iter_batches(indices: Sequence[int], batch_size: int) -> Iterable[List[int]]:
      """
      Yield successive slices of indices as lists of at most batch_size items.

      Raises:
          ValueError: if batch_size is not positive (raised when iteration starts,
              since this is a generator).
      """
      if batch_size <= 0:
          raise ValueError(f"batch_size must be > 0, got {batch_size}")

      total = len(indices)
      start = 0
      while start < total:
          stop = start + batch_size
          # Slicing clamps at the end, so the final batch may be shorter.
          yield list(indices[start:stop])
          start = stop