"""Utilities for reranker batching and deduplication.""" from __future__ import annotations from typing import Iterable, List, Sequence, Tuple def deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]: """ Deduplicate texts globally while preserving first-seen order. Returns: unique_texts: deduplicated texts in first-seen order position_to_unique: mapping from each original position to unique index """ unique_texts: List[str] = [] position_to_unique: List[int] = [] seen: dict[str, int] = {} for text in texts: idx = seen.get(text) if idx is None: idx = len(unique_texts) seen[text] = idx unique_texts.append(text) position_to_unique.append(idx) return unique_texts, position_to_unique def sort_indices_by_length(lengths: Sequence[int]) -> List[int]: """Return stable ascending indices by lengths.""" return sorted(range(len(lengths)), key=lambda i: lengths[i]) def iter_batches(indices: Sequence[int], batch_size: int) -> Iterable[List[int]]: """Yield consecutive batches from indices.""" if batch_size <= 0: raise ValueError(f"batch_size must be > 0, got {batch_size}") for i in range(0, len(indices), batch_size): yield list(indices[i : i + batch_size])