@@ -315,13 +315,15 @@ def sample(
         )

         vocab_size = tokenizer.vocab_size
+        num_special_tokens = tokenizer.num_special_tokens_to_add()
+        real_input_len = input_len - num_special_tokens

         prefix_token_ids = (np.random.randint(
             0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else [])

         # New sampling logic: [X * (1 - b), X * (1 + b)]
-        input_low = int(input_len * (1 - range_ratio))
-        input_high = int(input_len * (1 + range_ratio))
+        input_low = int(real_input_len * (1 - range_ratio))
+        input_high = int(real_input_len * (1 + range_ratio))
         output_low = int(output_len * (1 - range_ratio))
         output_high = int(output_len * (1 + range_ratio))

@@ -344,6 +346,17 @@ def sample(
                          vocab_size).tolist()
             token_sequence = prefix_token_ids + inner_seq
             prompt = tokenizer.decode(token_sequence)
+            # After decoding the prompt we have to encode and decode it again.
+            # This is done because in some cases N consecutive tokens
+            # give a string tokenized into != N number of tokens.
+            # For example for GPT2Tokenizer:
+            # [6880, 6881] -> ['Ġcalls', 'here'] ->
+            # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
+            # To avoid an uncontrolled change of the prompt length,
+            # the encoded sequence is truncated before being decoded again.
+            re_encoded_sequence = tokenizer.encode(
+                prompt, add_special_tokens=False)[:input_lens[i]]
+            prompt = tokenizer.decode(re_encoded_sequence)
             total_input_len = prefix_len + int(input_lens[i])
             requests.append(
                 SampleRequest(
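The re-encode-and-truncate step added in the second hunk can be reproduced in isolation. Below is a minimal sketch (not part of the patch), assuming the Hugging Face `transformers` GPT-2 tokenizer is available, using the token ids from the comment above to show why the extra round trip and the `[:input_lens[i]]` truncation are needed:

```python
# Minimal sketch (assumption: `transformers` is installed); illustrates why
# the patch re-encodes and truncates the decoded prompt.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Two consecutive token ids taken from the comment in the diff above.
token_ids = [6880, 6881]                  # ['Ġcalls', 'here']
prompt = tokenizer.decode(token_ids)

# Re-encoding the decoded string does not round-trip: GPT-2 splits the text
# differently and produces three tokens instead of two.
re_encoded = tokenizer.encode(prompt, add_special_tokens=False)
print(len(token_ids), len(re_encoded))    # 2 3  -> [1650, 939, 486]

# Truncating before decoding again bounds the prompt length, mirroring what
# the patched sample() does with `[:input_lens[i]]`.
prompt = tokenizer.decode(re_encoded[:len(token_ids)])
```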