@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -500,12 +499,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    const int n_seq = std::max(1, n_batch / n_ctx);
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
+
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
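The assertion and the two new locals above are the heart of this change: when n_batch is a multiple of n_ctx, several n_ctx-sized chunks can share a single decode. Below is a minimal standalone sketch (not part of the patch; the n_ctx and n_batch values are made up) of how num_batches, n_seq, and the capacity passed to llama_batch_init relate.

#include <algorithm>
#include <cstdio>
#include <initializer_list>

int main() {
    const int n_ctx = 512;  // per-chunk context size (illustrative value)
    for (int n_batch : {256, 512, 2048}) {
        // mirrors the patch: batches per chunk, chunks per decode, and the batch capacity
        const int num_batches = (n_ctx + n_batch - 1) / n_batch;
        const int n_seq       = std::max(1, n_batch / n_ctx);
        std::printf("n_batch=%4d -> num_batches=%d, n_seq=%d, batch capacity=%d tokens\n",
                    n_batch, num_batches, n_seq, std::min(n_batch, n_ctx*n_seq));
    }
    return 0;
}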
@@ -518,10 +521,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens. Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
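Since the relocated comment block and the new loop header are easy to misread in diff form, here is a small illustrative sketch (values are assumptions, not from the patch) of how the rewritten loop steps over chunks n_seq at a time while still scoring only the last half of each n_ctx window.

#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512;      // per-chunk window (illustrative)
    const int n_chunk = 10;       // number of chunks (illustrative)
    const int n_seq   = 4;        // chunks evaluated per decode (illustrative)
    const int first   = n_ctx/2;  // only the last half of each window is scored

    for (int i = 0; i < n_chunk; i += n_seq) {
        const int n_seq_batch = std::min(n_seq, n_chunk - i);
        std::printf("decode covers chunks [%d, %d), scoring %d tokens per chunk\n",
                    i, i + n_seq_batch, n_ctx - 1 - first);
    }
    return 0;
}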
@@ -531,22 +550,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token[idx]    = tokens[seq_start + k];
+                    batch.pos[idx]      = j*n_batch + k;
+                    batch.n_seq_id[idx] = 1;
+                    batch.seq_id[idx][0] = seq;
+                    batch.logits[idx]   = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
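The batch-filling loop above packs each chunk into its own sequence slot of the shared llama_batch. A toy sketch (tiny assumed values, not from the patch) of the resulting layout: slot index seq*n_ctx + k, position j*n_batch + k, sequence id seq, and logits requested only from position n_ctx/2 onward.

#include <cstdio>

int main() {
    const int n_ctx       = 8;  // tiny window so the table stays short (illustrative)
    const int n_batch     = 8;  // whole chunk fits in one batch, so j == 0
    const int n_seq_batch = 2;  // two chunks packed into the same decode
    const int first       = n_ctx/2;
    const int j           = 0;
    const int batch_size  = n_batch;

    for (int seq = 0; seq < n_seq_batch; seq++) {
        for (int k = 0; k < batch_size; ++k) {
            const int idx = seq*n_ctx + k;  // slot in the batch
            const int pos = j*n_batch + k;  // position within the sequence
            std::printf("idx=%2d  pos=%d  seq_id=%d  logits=%d\n",
                        idx, pos, seq, pos >= first ? 1 : 0);
        }
    }
    return 0;
}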
@@ -558,45 +592,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+            if (!params.logits_file.empty()) {
+                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, log_probs, nll, nll2);
+            } else {
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data() + start + seq*n_ctx + first);
+            }
+            count += n_ctx - first - 1;
+
+            // perplexity is e^(average negative log-likelihood)
+            if (params.ppl_output_type == 0) {
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            } else {
+                double av = nll/count;
+                double av2 = nll2/count - av*av;
+                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+            }
         }
         fflush(stdout);
 
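To make the per-sequence indexing above easier to follow, here is a short arithmetic sketch (all values are assumptions, not from the patch): each sequence's scored region begins first tokens into its own n_ctx slice of the batch, and its slice of logit_history/prob_history starts at start + seq*n_ctx + first.

#include <cstdio>

int main() {
    const int n_ctx       = 512;  // per-chunk window (illustrative)
    const int n_seq_batch = 4;    // sequences in this decode (illustrative)
    const int first       = n_ctx/2;
    const int i           = 8;    // first chunk index covered by this decode (illustrative)
    const int start       = i * n_ctx;

    for (int seq = 0; seq < n_seq_batch; seq++) {
        std::printf("seq %d: first scored batch row %d, history offset %d, tokens scored %d\n",
                    seq, seq*n_ctx + first, start + seq*n_ctx + first, n_ctx - 1 - first);
    }
    return 0;
}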
@@ -615,6 +643,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1812,22 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int32_t n_kv = std::max(1, params.n_batch / n_ctx) * n_ctx;
+        params.n_ctx = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@@ -1847,7 +1886,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
     }
 
     llama_print_timings(ctx);