Blame view

tests/test_reranker_batching_utils.py 923 Bytes
9f5994b4   tangwang   reranker
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
  import pytest
  
  from reranker.backends.batching_utils import (
      deduplicate_with_positions,
      iter_batches,
      sort_indices_by_length,
  )
  
  
  def test_deduplicate_with_positions_global_not_adjacent():
      """Duplicates are merged across the whole input, not just adjacent runs.

      The unique list keeps first-seen order, and the mapping gives, for each
      input position, the index of its text within the unique list.
      """
      unique_texts, positions = deduplicate_with_positions(
          ["a", "b", "a", "c", "b", "a"]
      )
      assert unique_texts == ["a", "b", "c"]
      assert positions == [0, 1, 0, 2, 1, 0]
  
  
  def test_sort_indices_by_length_stable():
      """Indices come back in ascending-length order, ties kept in input order."""
      # Indices 1 and 2 both have length 2; a stable sort must keep 1 ahead of 2.
      expected_order = [1, 2, 4, 0, 3]
      assert sort_indices_by_length([5, 2, 2, 9, 4]) == expected_order
  
  
  def test_iter_batches():
      """Ten items in batches of four: two full batches plus a short remainder."""
      expected = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
      produced = list(iter_batches(list(range(10)), 4))
      assert produced == expected
  
  
  def test_iter_batches_invalid_batch_size():
      """A batch size of zero is rejected with a descriptive ValueError."""
      # Call and materialize inside the context so the error is caught whether
      # iter_batches validates eagerly or only when iterated.
      with pytest.raises(ValueError) as excinfo:
          list(iter_batches([0, 1], 0))
      assert excinfo.match("batch_size must be > 0")