9f5994b4
tangwang
reranker
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
import pytest
from reranker.backends.batching_utils import (
deduplicate_with_positions,
iter_batches,
sort_indices_by_length,
)
def test_deduplicate_with_positions_global_not_adjacent():
    """Check that repeated texts are merged even when they are not adjacent.

    The input repeats "a" and "b" at non-neighbouring positions; the test
    expects one entry per distinct text (in first-seen order) plus, for
    every input position, the index of its text in that distinct list.
    """
    inputs = ["a", "b", "a", "c", "b", "a"]
    expected_distinct = ["a", "b", "c"]
    expected_positions = [0, 1, 0, 2, 1, 0]

    distinct, positions = deduplicate_with_positions(inputs)

    assert distinct == expected_distinct
    assert positions == expected_positions
def test_sort_indices_by_length_stable():
    """Check that indices come back ordered by ascending length, ties stable."""
    item_lengths = [5, 2, 2, 9, 4]
    # Indices 1 and 2 share length 2; a stable ordering keeps 1 ahead of 2.
    expected_order = [1, 2, 4, 0, 3]

    assert sort_indices_by_length(item_lengths) == expected_order
def test_iter_batches():
    """Check that ten indices split into batches of four, with a short tail."""
    positions = list(range(10))

    # Materialize the iterator so the batches can be compared as a whole.
    chunks = [chunk for chunk in iter_batches(positions, 4)]

    # 10 items at batch size 4 -> two full batches and a final batch of 2.
    assert chunks == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
def test_iter_batches_invalid_batch_size():
    """Check that a batch size of zero is rejected with a ValueError."""
    expected_message = "batch_size must be > 0"

    with pytest.raises(ValueError, match=expected_message):
        # The result is consumed with list() so the error surfaces even if
        # iter_batches is lazy (presumably a generator; not visible here).
        list(iter_batches([0, 1], 0))
|