9f5994b4
tangwang
reranker
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
"""Utilities for reranker batching and deduplication."""
from __future__ import annotations
from typing import Iterable, List, Sequence, Tuple
def deduplicate_with_positions(texts: Sequence[str]) -> Tuple[List[str], List[int]]:
    """
    Collapse duplicate texts, keeping each distinct text at its first occurrence.

    The input is walked exactly once. Each distinct text receives a slot number
    equal to the order in which it was first encountered.

    Returns:
        unique_texts: the distinct texts, ordered by first appearance
        position_to_unique: for every input position, the slot of its text in
            ``unique_texts`` (so ``unique_texts[position_to_unique[i]] == texts[i]``)
    """
    first_slot: dict[str, int] = {}
    distinct: List[str] = []
    slots: List[int] = []

    for item in texts:
        # A brand-new text is assigned the next free slot, which equals the
        # number of texts registered so far; a repeat keeps its earlier slot.
        slot = first_slot.setdefault(item, len(first_slot))
        if slot == len(distinct):
            distinct.append(item)
        slots.append(slot)

    return distinct, slots
def sort_indices_by_length(lengths: Sequence[int]) -> List[int]:
    """Return the positions of ``lengths`` ordered from shortest to longest.

    ``sorted`` is stable, so positions with equal lengths keep their original
    relative order.
    """
    positions = range(len(lengths))
    return sorted(positions, key=lengths.__getitem__)
def iter_batches(indices: Sequence[int], batch_size: int) -> Iterable[List[int]]:
    """Split ``indices`` into consecutive batches of at most ``batch_size`` items.

    Args:
        indices: the sequence to split; it must support ``len`` and slicing.
        batch_size: the maximum number of items per batch. It must be > 0.

    Returns:
        A lazy iterator of lists. Every batch except possibly the last has
        exactly ``batch_size`` items. An empty ``indices`` yields no batches.

    Raises:
        ValueError: if ``batch_size`` is not positive. The check runs at call
            time, not on the first iteration.
    """
    # Validate before building the generator. If this function were itself a
    # generator, the check would be deferred until the first next() call.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be > 0, got {batch_size}")

    def _generate() -> Iterable[List[int]]:
        # Slicing then list() copies each batch, so callers receive plain lists
        # even when ``indices`` is a tuple, range, or other Sequence.
        for start in range(0, len(indices), batch_size):
            yield list(indices[start : start + batch_size])

    return _generate()
|