batching_utils.py 1.31 KB
"""Utilities for reranker batching and deduplication."""

from __future__ import annotations

from typing import Iterable, List, Sequence, Tuple


def deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
    """
    Collapse repeated texts, remembering where each original entry went.

    Equal strings share one slot; slots are numbered in the order their text
    first appears in ``texts``.

    Returns:
        unique_texts: each distinct text once, in first-appearance order
        position_to_unique: for every input position, the slot index of its text
    """
    # dicts keep insertion order, so the keys double as the ordered unique list.
    slot_of: dict[str, int] = {}
    slots: List[int] = []
    for entry in texts:
        # len(slot_of) is evaluated before any insertion, so a new text gets
        # the next free slot; a repeated text gets its existing slot back.
        slots.append(slot_of.setdefault(entry, len(slot_of)))
    return list(slot_of), slots


def sort_indices_by_length(lengths: Sequence[int]) -> List[int]:
    """Return positions of ``lengths`` ordered shortest-first; ties keep input order."""
    positions = range(len(lengths))
    # sorted() is guaranteed stable, so equal lengths stay in original order.
    return sorted(positions, key=lengths.__getitem__)


def iter_batches(indices: Sequence[int], batch_size: int) -> Iterable[List[int]]:
    """
    Yield consecutive batches from indices.

    Each batch is a new list of up to ``batch_size`` items taken in order;
    only the final batch may be shorter. An empty ``indices`` yields nothing.

    Raises:
        ValueError: if ``batch_size`` is not positive. This is raised as soon
            as ``iter_batches`` is called, not deferred to the first iteration.
    """
    # A function whose body contains ``yield`` executes nothing until the
    # first next(), which would postpone this check. Validate here, then hand
    # the actual iteration to an inner generator so bad arguments fail fast.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")

    def _generate() -> Iterable[List[int]]:
        for start in range(0, len(indices), batch_size):
            yield list(indices[start : start + batch_size])

    return _generate()