Blame view

translation/text_splitter.py 7.6 KB
294c3d0a   tangwang   实现第一版“按模型预算智能分句”的...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
  """Utilities for token-budget-aware translation text splitting."""
  
  from __future__ import annotations
  
  from typing import Callable, List, Optional
  
  # Callable that reports how many tokens a piece of text occupies.
  TokenLengthFn = Callable[[str], int]

  # Target languages whose translated segments are re-joined without a separator.
  _CJK_LANGS = {"zh", "ja", "ko"}

  # NOTE: fullwidth punctuation is written with \u escapes so the ASCII and
  # fullwidth forms stay distinct even if this file passes through a tool that
  # width-normalizes characters (which would silently turn them into duplicates).

  # Level-0 (sentence) terminators. An ASCII "." is not listed here because it
  # needs context; _match_boundary checks it through _is_sentence_period.
  _STRONG_BOUNDARIES = {
      "\n", "。", "…",
      "!", "?", ";",
      "\uff01", "\uff1f", "\uff1b",  # fullwidth ! ? ;
  }

  # Level-1 (clause) separators.
  _WEAK_BOUNDARIES = {
      ",", "、", ":", "(", ")", "[", "]", "【", "】", "/", "|",
      "\uff0c", "\uff1a", "\uff08", "\uff09",  # fullwidth , : ( )
  }

  # Closing quotes/brackets that _consume_boundary_tail keeps attached to the
  # boundary character that precedes them.
  _CLOSING_CHARS = {
      '"', "'", "”", "’", ")", "]", "}", "】", "》", "」", "』",
      "\uff09",  # fullwidth )
  }

  # join_translated_segments omits the separator when the next segment starts
  # with one of these characters ...
  _NO_SPACE_BEFORE = tuple('.,!?;:)]}%>"\'')
  # ... or when the text joined so far ends with one of these.
  _NO_SPACE_AFTER = tuple("([{$#@/<")
  
  
  def is_cjk_language(lang: Optional[str]) -> bool:
      """Return True when *lang* is one of the codes in ``_CJK_LANGS``.

      The value is stringified, stripped and lower-cased before the lookup;
      ``None`` and other falsy values are treated as an empty code.
      """
      normalized = str(lang or "").strip().lower()
      return normalized in _CJK_LANGS
  
  
  def compute_safe_input_token_limit(
      *,
      max_input_length: int,
      max_new_tokens: int,
      decoding_length_mode: str = "fixed",
      decoding_length_extra: int = 0,
      reserve_input_tokens: int = 8,
      reserve_output_tokens: int = 8,
  ) -> int:
      """Return a conservative source-token budget used when splitting text.

      The input-side budget is ``max_input_length`` minus ``reserve_input_tokens``
      (negative reserves count as zero). Depending on the decode settings it may
      be capped further:

      * ``max_new_tokens <= 0``: only the input-side budget applies.
      * mode ``"source"``: capped by ``max_new_tokens - decoding_length_extra``.
      * any other mode: capped by ``max_new_tokens - reserve_output_tokens``,
        unless ``max_new_tokens`` is at least ``max_input_length``, in which
        case the decode side is not the tighter one and is ignored.

      The mode is compared case-insensitively after stripping; a falsy mode means
      ``"fixed"``. The result is never smaller than 8.
      """

      floor = 8
      input_cap = int(max_input_length)
      input_limit = max(floor, input_cap - max(0, int(reserve_input_tokens)))
      decode_mode = str(decoding_length_mode or "fixed").strip().lower()
      new_tokens = int(max_new_tokens)

      if new_tokens <= 0:
          return input_limit

      # Pick which allowance is subtracted from the decode budget.
      if decode_mode == "source":
          allowance = int(decoding_length_extra)
      elif new_tokens >= input_cap:
          return input_limit
      else:
          allowance = int(reserve_output_tokens)

      output_limit = max(floor, new_tokens - max(0, allowance))
      return max(floor, min(input_limit, output_limit))
  
  
  def split_text_for_translation(
      text: str,
      *,
      max_tokens: int,
      token_length_fn: TokenLengthFn,
  ) -> List[str]:
      """Break *text* into segments that each fit within ``max_tokens``.

      Falsy text, or text that ``token_length_fn`` already measures within the
      budget, is returned unchanged as a single-element list. Otherwise the text
      is handed to ``_split_recursive`` starting at level 0 (sentence boundaries,
      then clause boundaries, then whitespace, then a character-level fallback),
      and empty segments are dropped from the result.
      """

      # Short-circuit keeps token_length_fn from being called on falsy text.
      already_fits = not text or token_length_fn(text) <= max_tokens
      if already_fits:
          return [text]
      pieces = _split_recursive(
          text,
          max_tokens=max_tokens,
          token_length_fn=token_length_fn,
          level=0,
      )
      return list(filter(None, pieces))
  
  
  def join_translated_segments(
      segments: List[Optional[str]],
      *,
      target_lang: Optional[str],
      original_text: str,
  ) -> Optional[str]:
      """Reassemble translated segments into one string.

      Segments that are falsy or blank are ignored and the rest are stripped.
      Returns ``None`` when nothing is left. For a CJK target language the
      segments are concatenated directly; otherwise they are joined with a
      newline when *original_text* contains one, or with a single space. The
      separator is also omitted between two segments when the earlier one ends
      with a character from ``_NO_SPACE_AFTER`` or the later one starts with a
      character from ``_NO_SPACE_BEFORE``.
      """
      cleaned = [segment.strip() for segment in segments if segment and segment.strip()]
      if not cleaned:
          return None

      cjk_target = is_cjk_language(target_lang)
      multiline_source = "\n" in original_text
      if cjk_target:
          glue = ""
      elif multiline_source:
          glue = "\n"
      else:
          glue = " "

      chunks = [cleaned[0]]
      for previous, current in zip(cleaned, cleaned[1:]):
          # Every cleaned segment is non-empty, so the text joined so far always
          # ends with the tail of the previous segment.
          tight = (
              not glue
              or previous.endswith(_NO_SPACE_AFTER)
              or current.startswith(_NO_SPACE_BEFORE)
          )
          if not tight:
              chunks.append(glue)
          chunks.append(current)
      return "".join(chunks).strip() or None
  
  
  def _split_recursive(
      text: str,
      *,
      max_tokens: int,
      token_length_fn: TokenLengthFn,
      level: int,
  ) -> List[str]:
      """Split *text* at boundary *level*, descending to finer levels as needed.

      Text that already fits is returned as-is. At level 3 and beyond the work
      is delegated to ``_hard_split``. Otherwise the text is cut with
      ``_split_by_level`` and adjacent pieces are greedily packed together for
      as long as the combined string stays within ``max_tokens``; whatever is
      flushed is passed through the next level so oversized pieces get split
      further.
      """

      if token_length_fn(text) <= max_tokens:
          return [text]
      if level >= 3:
          return _hard_split(text, max_tokens=max_tokens, token_length_fn=token_length_fn)

      def descend(fragment: str) -> List[str]:
          # Re-run the splitter on a fragment using the next, finer level.
          return _split_recursive(
              fragment,
              max_tokens=max_tokens,
              token_length_fn=token_length_fn,
              level=level + 1,
          )

      pieces = _split_by_level(text, level)
      if len(pieces) <= 1:
          # This level found no usable boundary; try the next one on the whole text.
          return descend(text)

      result: List[str] = []
      pending = ""
      for piece in pieces:
          combined = pending + piece if pending else piece
          if token_length_fn(combined) <= max_tokens:
              pending = combined
          elif pending:
              # Flush what has been packed so far and restart from this piece.
              result.extend(descend(pending))
              pending = piece
          else:
              # A lone piece that is already too long.
              result.extend(descend(piece))
      if pending:
          result.extend(descend(pending))
      return result
  
  
  def _split_by_level(text: str, level: int) -> List[str]:
      """Cut *text* after every boundary that ``_match_boundary`` reports.

      Each piece keeps its trailing boundary characters, so concatenating the
      returned pieces gives back the original text. Empty pieces are dropped.
      """
      total = len(text)
      pieces: List[str] = []
      piece_start = 0
      cursor = 0
      while cursor < total:
          cut = _match_boundary(text, cursor, level)
          if cut is not None:
              if cut > piece_start:
                  pieces.append(text[piece_start:cut])
                  piece_start = cut
              # Resume scanning right after the consumed boundary.
              cursor = cut
          else:
              cursor += 1
      if piece_start < total:
          # Whatever follows the last boundary forms the final piece.
          pieces.append(text[piece_start:])
      return list(filter(None, pieces))
  
  
  def _match_boundary(text: str, index: int, level: int) -> Optional[int]:
      """Return the index just past a boundary starting at *index*, else None.

      Level 0 accepts characters in ``_STRONG_BOUNDARIES`` plus a "." that
      ``_is_sentence_period`` approves; level 1 accepts ``_WEAK_BOUNDARIES``.
      For both, the end is extended by ``_consume_boundary_tail``. Level 2
      treats a run of whitespace as one boundary. Any other level never
      matches.
      """
      char = text[index]
      if level == 0:
          is_strong = char in _STRONG_BOUNDARIES or (
              char == "." and _is_sentence_period(text, index)
          )
          return _consume_boundary_tail(text, index + 1) if is_strong else None
      if level == 1:
          return _consume_boundary_tail(text, index + 1) if char in _WEAK_BOUNDARIES else None
      if level != 2 or not char.isspace():
          return None
      # Swallow the whole whitespace run so it stays with the preceding piece.
      end = index + 1
      total = len(text)
      while end < total and text[end].isspace():
          end += 1
      return end
  
  
  def _consume_boundary_tail(text: str, index: int) -> int:
      """Advance from *index* past closing characters, then past whitespace.

      Returns the index of the first character that belongs to neither run (or
      the position where scanning stopped at the end of *text*).
      """
      total = len(text)
      cursor = index
      # Two consecutive passes: characters in _CLOSING_CHARS first, whitespace second.
      for should_skip in (lambda ch: ch in _CLOSING_CHARS, str.isspace):
          while cursor < total and should_skip(text[cursor]):
              cursor += 1
      return cursor
  
  
  def _is_sentence_period(text: str, index: int) -> bool:
      """Decide whether the "." at *index* should count as a sentence end.

      A period at the very end of *text* counts. A period between two digits
      (as in "3.14") does not. Otherwise it counts only when the next character
      is whitespace or one of ``_CLOSING_CHARS``.
      """
      before = text[index - 1] if index > 0 else ""
      after = text[index + 1] if index + 1 < len(text) else ""
      if not after:
          # Nothing follows the period.
          return True
      if before.isdigit() and after.isdigit():
          return False
      return after.isspace() or after in _CLOSING_CHARS
  
  
  def _hard_split(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> List[str]:
      """Chop *text* into consecutive chunks by character position.

      While the remaining text is over budget, the longest fitting prefix is
      located with ``_largest_prefix_within_limit``, adjusted by
      ``_refine_cut``, and sliced off. At least one character is consumed per
      round. The leftover that fits becomes the last chunk; empty text yields
      an empty list.
      """
      chunks: List[str] = []
      rest = text
      while rest and token_length_fn(rest) > max_tokens:
          cut = _largest_prefix_within_limit(rest, max_tokens=max_tokens, token_length_fn=token_length_fn)
          cut_at = _refine_cut(rest, cut, max_tokens=max_tokens, token_length_fn=token_length_fn)
          if cut_at <= 0:
              # Guard against a non-advancing cut so the loop always makes progress.
              cut_at = max(1, cut)
          chunks.append(rest[:cut_at])
          rest = rest[cut_at:]
      if rest:
          chunks.append(rest)
      return chunks
  
  
  def _largest_prefix_within_limit(text: str, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
      """Binary-search the longest prefix length whose token count fits.

      The search covers lengths 1..len(text). The result defaults to 1, so 1 is
      returned even when no prefix fits (or when *text* is empty).
      NOTE: a binary search presumes the token count does not shrink as the
      prefix grows; that property depends on the supplied ``token_length_fn``.
      """
      lo, hi = 1, len(text)
      answer = 1
      while lo <= hi:
          probe = (lo + hi) // 2
          if token_length_fn(text[:probe]) > max_tokens:
              hi = probe - 1
          else:
              answer = probe
              lo = probe + 1
      return answer
  
  
  def _refine_cut(text: str, cut: int, *, max_tokens: int, token_length_fn: TokenLengthFn) -> int:
      """Move a hard cut back to a nearby natural break, when one exists.

      Scans backwards from *cut* over at most 32 positions, never going below
      half of *cut* (or below 1). The first position whose preceding character
      is whitespace or a member of ``_STRONG_BOUNDARIES`` / ``_WEAK_BOUNDARIES``,
      and whose prefix still fits within ``max_tokens``, is returned. If no such
      position exists, *cut* is returned unchanged.
      """
      # Candidates below half of the cut are never accepted, so the scan stops
      # there instead of walking the full 32-character window.
      lowest = max(1, cut - 32, cut // 2)
      for candidate in range(cut, lowest - 1, -1):
          char = text[candidate - 1]
          at_break = char.isspace() or char in _STRONG_BOUNDARIES or char in _WEAK_BOUNDARIES
          if at_break and token_length_fn(text[:candidate]) <= max_tokens:
              return candidate
      return cut