# test_reranker_batching_utils.py 923 Bytes
import pytest

from reranker.backends.batching_utils import (
    deduplicate_with_positions,
    iter_batches,
    sort_indices_by_length,
)


def test_deduplicate_with_positions_global_not_adjacent():
    """Duplicates are merged across the whole input, not only adjacent runs.

    The input repeats "a" and "b" at non-neighbouring positions; the test
    expects one entry per distinct value and, for every input position, the
    index of its value within that deduplicated list.
    """
    inputs = ["a", "b", "a", "c", "b", "a"]
    expected_unique = ["a", "b", "c"]
    expected_mapping = [0, 1, 0, 2, 1, 0]

    distinct, positions = deduplicate_with_positions(inputs)

    assert distinct == expected_unique
    assert positions == expected_mapping


def test_sort_indices_by_length_stable():
    """Indices come back ordered by ascending length, with ties kept in input order."""
    sizes = [5, 2, 2, 9, 4]
    # Indices 1 and 2 share length 2; a stable sort keeps 1 ahead of 2.
    expected_order = [1, 2, 4, 0, 3]

    assert sort_indices_by_length(sizes) == expected_order


def test_iter_batches():
    """Ten indices with batch size 4 yield two full batches and a short tail."""
    items = list(range(10))
    expected = [
        [0, 1, 2, 3],
        [4, 5, 6, 7],
        [8, 9],
    ]

    produced = list(iter_batches(items, 4))

    assert produced == expected


def test_iter_batches_invalid_batch_size():
    """A batch size of zero is rejected with a ValueError."""
    expect_error = pytest.raises(ValueError, match="batch_size must be > 0")
    with expect_error:
        # list() consumes the result so the error surfaces even if
        # iter_batches only validates once iteration starts.
        list(iter_batches([0, 1], 0))