-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtokenization_utils.py
More file actions
444 lines (351 loc) · 16.1 KB
/
tokenization_utils.py
File metadata and controls
444 lines (351 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
import random
from transformers import AutoTokenizer
from typing import Union, Literal
from functools import lru_cache
import torch
from tqdm import tqdm
import sys
# set maximum recursion depth
# (get_all_tokenizations / F_memoized recurse once per consumed token, so
# long inputs can exceed the default limit of 1000)
sys.setrecursionlimit(10**8)
# doing things this way because subclassing AutoTokenizer is not
# a great way to spend one's evening
# so instead we load this tokenizer and change it in-place
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B")
# NOTE(review): 128005 is presumably a reserved Llama-3.2 special-token id
# used here as the pad token — confirm against the model's vocab.
tokenizer.pad_token_id = 128005
# just making sure...
tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
# Single-element id lists for the pieces of the Llama-3 chat template;
# kept as lists so they can be concatenated directly in format_message.
start_of_text = [tokenizer.convert_tokens_to_ids("<|begin_of_text|>")]
start_header = [tokenizer.convert_tokens_to_ids("<|start_header_id|>")]
# Role name -> [token id] for the three chat roles.
whois = {
    name: [tokenizer.convert_tokens_to_ids(name)]
    for name in ["system", "user", "assistant"]
}
# "ĊĊ" is the byte-level token for "\n\n" that follows the header marker.
end_header = [
    tokenizer.convert_tokens_to_ids("<|end_header_id|>"),
    tokenizer.convert_tokens_to_ids("ĊĊ"),
]
end_message = [tokenizer.convert_tokens_to_ids("<|eot_id|>")]
def format_message(
    message: list[int] | str,
    who: Union[str, Literal["system", "user", "assistant"]],
    ended: bool = True,
):
    """Wrap a message in Llama-3 chat-header framing for role `who`.

    `message` may be raw text (tokenized here) or pre-tokenized ids; the
    <|eot_id|> terminator is appended only when `ended` is True.
    """
    body = tokenize(message) if isinstance(message, str) else message
    if isinstance(body, tuple):
        body = list(body)
    tail = end_message if ended else []
    return start_header + whois[who] + end_header + body + tail
def tokenize(text: str):
    """Encode `text` to a list of token ids, without BOS/EOS special tokens."""
    encoding = tokenizer(text, add_special_tokens=False)
    return encoding["input_ids"]
def pad_sequences(sequences: list, pad_to_len: int | None = None, pad_with=tokenizer.pad_token_id):
    """We always pad on the left.

    Pads every sequence with `pad_with` up to `pad_to_len`
    (default: the longest sequence in the batch).
    """
    target = max(map(len, sequences)) if pad_to_len is None else pad_to_len
    padded = []
    for seq in sequences:
        padded.append([pad_with] * (target - len(seq)) + seq)
    return padded
original_tokenizer_dict = {v: k for (k, v) in tokenizer.get_vocab().items()}
def pretokenize(text: str, return_list=False):
    """Run only the byte-level pre-tokenizer over `text`.

    Returns the list of pre-token strings when `return_list` is True,
    otherwise their concatenation (text in byte-level surface form).
    """
    pieces = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    fragments = [fragment for fragment, _span in pieces]
    return fragments if return_list else "".join(fragments)
def byte_level_tokenize(text: str):
    """Tokenize `text` one pre-tokenized character at a time (one id per char)."""
    ids = []
    for ch in pretokenize(text):
        ids.append(tokenize(ch)[0])
    return ids
hashable_original_dict = frozenset(original_tokenizer_dict.items())
def min_len_tokenization(message):
    """Number of tokens in the tokenizer's standard tokenization of `message`."""
    return len(tokenize(message))
def max_len_tokenization(message):
    """Number of tokens when `message` is tokenized one character at a time."""
    return len(byte_level_tokenize(message))
@lru_cache(maxsize=None)
def get_all_tokenizations(
    text: str, token_set: frozenset | None = None, pretokenize_before=True
):
    """Enumerate every possible tokenization of `text` over `token_set`.

    `token_set` is a frozenset of (token_id, token_string) pairs (defaults to
    the full vocab). Returns a sequence of token-id tuples: a tuple of tuples
    for the base cases, a list of tuples otherwise.
    """
    # There's an optimization to be made here (faster result, but more memory):
    # At each step filter a copy of original_tokenizer_dict with only the
    # tokens that are present in the leftover text, and pass that down.
    # This makes the iteration over its items much faster.
    # I don't really need it, but just noting it here
    if token_set is None:
        token_set = hashable_original_dict
    if pretokenize_before:
        text = pretokenize(text)
    if len(text) == 0:
        return ((),)
    if len(text) == 1:
        # A single character always resolves to exactly one id.
        return ((tokenizer.convert_tokens_to_ids(text),),)
    tokenizations = []
    for tok_id, tok in token_set:
        if text.startswith(tok):
            # Prepend this token to every tokenization of the remaining text.
            tokenizations += (
                (tok_id,) + t
                for t in get_all_tokenizations(
                    text[len(tok):],
                    token_set=token_set,
                    pretokenize_before=False,
                )
            )
    return tokenizations
def get_prob_of_continuation(
    model,
    device,
    prefix_tokens: list[int],
    continuation_tokens: list[tuple[int]],
):
    """Compute P(continuation | prefix) under `model` for each continuation.

    Returns {continuation_tuple: joint probability}. Probabilities are
    computed in float64 log-space and exponentiated at the end.
    """
    # torch and tqdm are already imported at module level; the previous
    # in-function re-imports were redundant and have been removed.
    model.eval()
    probs = {}
    for continuation in tqdm(continuation_tokens):
        full_input = prefix_tokens + list(continuation)
        input_ids = torch.tensor(full_input).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits.to(torch.float64)  # shape (1, seq_len, vocab)
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
        # Logits at position i predict token i+1, hence the -1 shift below.
        start = len(prefix_tokens)
        end = start + len(continuation)
        target_tokens = torch.tensor(continuation).to(device)
        token_log_probs = log_probs[0, start - 1 : end - 1].gather(
            1, target_tokens.unsqueeze(1)
        ).squeeze(1)
        joint_log_prob = token_log_probs.sum()
        probs[continuation] = torch.exp(joint_log_prob).item()
    return probs
def most_likey_completion(
model,
device,
prefix_tokens: list[int],
top_k: int = 5,
tokens_id_to_include: list[int] | None = None,
):
# Convert prefix tokens into a tensor and move it to the proper device.
input_ids = torch.tensor(prefix_tokens).unsqueeze(0).to(device)
with torch.no_grad():
# Compute the logits for the last token position.
logits = model(input_ids).logits[:, -1, :].to(torch.float64)
# Sort logits in descending order to get a complete ranking.
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
# Remove the batch dimension.
sorted_indices = sorted_indices.squeeze(0)
sorted_probs = sorted_probs.squeeze(0)
# Build the list for the top_k tokens.
top_k_tokens = [
(rank, int(token_id), float(prob))
for rank, (token_id, prob) in enumerate(
zip(sorted_indices[:top_k], sorted_probs[:top_k])
)
]
extra_tokens = []
if tokens_id_to_include is not None:
# Create a set of token IDs already present in top_k.
top_k_set = {token_id for _, token_id, _ in top_k_tokens}
for token_id in tokens_id_to_include:
if token_id not in top_k_set:
# Locate the ranking of this token in the sorted_indices.
# (This returns a tensor of indices; we expect a single match.)
rank_tensor = (sorted_indices == token_id).nonzero(as_tuple=True)[0]
if rank_tensor.numel() > 0:
rank = rank_tensor.item()
prob = float(sorted_probs[rank])
extra_tokens.append((rank, token_id, prob))
# Return top_k tokens first, then the extra tokens from tokens_id_to_include.
return top_k_tokens + extra_tokens
def decode_list(tokens: list[int]):
    """Decode each token id individually, returning a tuple of per-token strings
    instead of the usual single concatenated string."""
    return tuple(map(tokenizer.decode, tokens))
# Every vocab entry rendered as its surface string. A set comprehension is
# used directly (the previous set([...]) built a throwaway list first), and
# iterating the vocab dict yields keys, so the unused values are dropped.
all_tokens = {
    tokenizer.convert_tokens_to_string([token])
    for token in tokenizer.get_vocab()
}
@lru_cache(maxsize=None)  # Using built-in LRU cache
def F_memoized(text: str, token_set: frozenset):
    """Count the tokenizations of `text` using the strings in `token_set`.

    The empty string and single characters always count as exactly one
    tokenization, mirroring get_all_tokenizations' base cases.
    """
    if len(text) in {0, 1}:
        return 1
    total = 0
    for candidate in token_set:
        if text.startswith(candidate):
            total += F_memoized(text[len(candidate):], token_set)
    return total
def count_tokenizations(text):
    """Count tokenizations of `text`, restricting the search to vocab strings
    that actually occur somewhere in it."""
    relevant = frozenset(tok for tok in all_tokens if tok in text)
    return F_memoized(text, relevant)
def evaluate_theta(message, thetas):
    """Sample one tokenization of `message` per theta in `thetas`.

    The candidate token set is narrowed to vocab entries whose string occurs
    in the pre-tokenized message. Returns (theta, length, tokenization) triples.
    """
    pretokenized = pretokenize(message)
    relevant = frozenset(
        (tok_id, tok)
        for tok_id, tok in original_tokenizer_dict.items()
        if tok in pretokenized
    )
    samples = sample_from_all_tokenizations(
        message, thetas, toks_to_consider=relevant
    )
    return [(theta, len(sample), sample) for theta, sample in zip(thetas, samples)]
def find_tokenization_for_each_length(message, lengths: list[int]):
    """Search over the sampling bias theta until a tokenization of every
    requested length has been found.

    Asserts that the requested lengths are bracketed by the tokenization
    lengths obtained at theta=-220 and theta=+220. Returns a list of
    (length, tokenization) pairs. The caller's `lengths` list is no longer
    mutated: the original implementation consumed it with .remove(), which
    silently emptied the argument — we now work on a private copy.
    """
    lengths = list(lengths)
    starting_points = [-220, 0, 220]
    theta_values = dict()
    found_toks = []
    initial_thetas = evaluate_theta(message, starting_points)
    for theta, length, _ in initial_thetas:
        theta_values[theta] = length
    assert min(lengths) >= theta_values[min(starting_points)], \
        f"min_len: {min(lengths)}, theta_values[min(starting_points)]: {theta_values[min(starting_points)]}"
    assert max(lengths) <= theta_values[max(starting_points)], \
        f"max_len: {max(lengths)}, theta_values[max(starting_points)]: {theta_values[max(starting_points)]}"
    while lengths:
        thetas_to_try = []
        batch_size = min(25, len(lengths))
        for length in random.sample(lengths, batch_size):
            # Bracket the target length between the closest thetas seen so far.
            below = {theta: val for theta, val in theta_values.items() if val < length}
            above = {theta: val for theta, val in theta_values.items() if val > length}
            below_max = max(below.items(), key=lambda x: x[1], default=None)
            above_min = min(above.items(), key=lambda x: x[1], default=None)
            lower_bound = -220 if below_max is None else below_max[0]
            upper_bound = 220 if above_min is None else above_min[0]
            # take 8 points between lower_bound and upper_bound, randomly sampled
            # (the original comment said 4, but the loop has always drawn 8)
            for _ in range(8):
                new_theta = random.random() * (upper_bound - lower_bound) + lower_bound
                thetas_to_try.append(new_theta)
        found_toks_theta = evaluate_theta(message, thetas_to_try)
        for _, length, tokenization in found_toks_theta:
            if length in lengths:
                found_toks.append((length, tokenization))
                lengths.remove(length)
        # Record every probe so later iterations bracket more tightly.
        for theta, length, _ in found_toks_theta:
            theta_values[theta] = length
    return found_toks
def find_all_integer_toks_between(message, min_len, max_len):
    """Find one tokenization of `message` for every integer length in
    [min_len, max_len].

    Assumes theta=-200 yields exactly min_len tokens and theta=+200 yields
    exactly max_len tokens (asserted). Returns {length: tokenization}.
    """
    integers_to_find = list(range(min_len, max_len + 1))
    starting_points = [-200, 0, 200]
    theta_values = dict()
    initial_thetas = evaluate_theta(message, starting_points)
    for theta, length, _ in initial_thetas:
        theta_values[theta] = length
    assert theta_values[min(starting_points)] == min_len, f"min_len: {min_len}, theta_values[min(starting_points)]: {theta_values[min(starting_points)]}"
    # Fixed: this message previously hard-coded "31,514" instead of max_len.
    assert theta_values[max(starting_points)] == max_len, f"max_len: {max_len}, theta_values[max(starting_points)]: {theta_values[max(starting_points)]}"
    found_toks = {
        min_len: initial_thetas[0][2],
        max_len: initial_thetas[2][2],
    }
    integers_to_find.remove(min_len)
    integers_to_find.remove(max_len)
    while len(integers_to_find) > 0:
        thetas_to_try = []
        for integer in integers_to_find:
            # Bracket the target length between the closest thetas seen so far.
            below = {theta: val for theta, val in theta_values.items() if val < integer}
            above = {theta: val for theta, val in theta_values.items() if val > integer}
            below_max = max(below.items(), key=lambda x: x[1], default=None)
            above_min = min(above.items(), key=lambda x: x[1], default=None)
            lower_bound = below_max[0]
            upper_bound = above_min[0]
            dist_below = integer - below_max[1]
            dist_above = above_min[1] - integer
            total_dist = dist_below + dist_above
            # because we want to weight inversely proportional:
            new_theta_to_try = (lower_bound * dist_above
                                + upper_bound * dist_below) / total_dist
            thetas_to_try.append(new_theta_to_try)
            # As the remaining set shrinks, probe more aggressively.
            if len(integers_to_find) < 50:
                thetas_to_try.append((lower_bound + upper_bound) / 2)
                more_points = torch.rand(2) * (upper_bound - lower_bound) + lower_bound
                thetas_to_try.extend(more_points.tolist())
            if len(integers_to_find) < 25:
                thetas_to_try.append((lower_bound + upper_bound) / 2)
                more_points = torch.rand(4) * (upper_bound - lower_bound) + lower_bound
                thetas_to_try.extend(more_points.tolist())
            if len(integers_to_find) < 10:
                more_points = torch.rand(10) * (upper_bound - lower_bound) + lower_bound
                thetas_to_try.extend(more_points.tolist())
            if len(integers_to_find) < 10:
                # NOTE(review): below_max/above_min are (theta, length) 2-tuples,
                # so min(2, len(...)-1) == 1 and this indexes the *length* value,
                # not a theta. That makes the bounds below lengths, which looks
                # like a bug — confirm the original intent before changing.
                lower_bound = below_max[min(2, len(below_max) - 1)]
                upper_bound = above_min[min(2, len(below_max) - 1)]
                more_points = torch.rand(10) * (upper_bound - lower_bound) + lower_bound
                thetas_to_try.extend(more_points.tolist())
        found_toks_theta = evaluate_theta(message, thetas_to_try)
        for _, length, tokenization in found_toks_theta:
            if length in integers_to_find:
                found_toks[length] = tokenization
                integers_to_find.remove(length)
        # also update theta_values
        for theta, length, _ in found_toks_theta:
            theta_values[theta] = length
    return found_toks
def sample_from_all_tokenizations(text, p_list, toks_to_consider=None):
    """For each p in `p_list`, sample one complete tokenization of `text`.

    Each pre-tokenized piece is sampled independently with weights
    exp(p * normalized_length); the piece's default tokenization is recorded
    with length 0. Returns one token-id tuple per p. Dead commented-out
    debugging code from the original has been removed; logic is unchanged.
    """
    def batch_distribution(lengths: list[int], p_list: list[float]) -> list[int]:
        """
        For each p in p_list, sample an index from the distribution defined by lengths and p.
        Args:
            lengths: 1D list or tensor of object lengths, shape (n,).
            p_list: 1D list or tensor of scalar values controlling bias toward long or short lengths,
                shape (k,). A positive p biases toward longer lengths, while a negative p biases
                toward shorter lengths.
        Returns:
            A list of k sampled indices, one for each p in p_list.
        """
        lengths_t = torch.tensor(lengths, dtype=torch.float64)
        # Normalize lengths if there is more than one length value
        if len(lengths_t) > 1:
            lengths_t = lengths_t / torch.max(lengths_t)
        p_t = torch.tensor(p_list, dtype=torch.float64)
        # Broadcasting yields a (k x n) matrix; row i is exp(p_t[i] * lengths_t).
        weights = torch.exp(p_t.unsqueeze(1) * lengths_t.unsqueeze(0))
        # Normalize each row into a probability distribution.
        row_sums = torch.sum(weights, dim=1, keepdim=True) + 1e-20  # Avoid division by zero
        probabilities = weights / row_sums
        # One sampled index per row (per p). Shape: (k, 1) -> flat list.
        sampled_indices = torch.multinomial(probabilities, num_samples=1)
        return sampled_indices.squeeze(1).tolist()

    pieces = tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)
    # The tokenizer's own tokenization of each piece (via the (i, j) char spans).
    default_tokenizations = [tuple(tokenize(text[i:j])) for _, (i, j) in pieces]
    candidates_per_piece = [
        get_all_tokenizations(piece[0], toks_to_consider, pretokenize_before=False)
        for piece in pieces
    ]
    # Length of each candidate; the default tokenization is scored as 0.
    lengths = [
        [
            0 if candidate == default_tokenizations[i] else len(candidate)
            for candidate in candidates
        ]
        for i, candidates in enumerate(candidates_per_piece)
    ]
    chosen_toks_s = [batch_distribution(length, p_list) for length in lengths]
    # chosen_toks_s[piece][p_index] is a candidate index; transpose to per-p rows.
    tokens_for_each = [
        [per_piece[i] for per_piece in chosen_toks_s]
        for i in range(len(p_list))
    ]
    # Concatenate the chosen candidate of every piece into one tuple per p.
    final_list = [
        sum([candidates_per_piece[i][choice[i]] for i in range(len(candidates_per_piece))], start=())
        for choice in tokens_for_each
    ]
    return final_list