@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll2 = 0.0;
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
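A quick illustration of the sizing arithmetic introduced here: `n_seq` is how many full `n_ctx`-sized windows fit into one decode batch, and the asserts require that a batch is either smaller than one window (so a chunk is split across several `llama_decode` calls) or an exact multiple of it. The sketch below is standalone C++ with made-up `-c`/`-b` values (512 and 2048 are assumptions, not values from the patch):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdio>

int main() {
    const int n_ctx   = 512;   // logical context per sequence (perplexity window) -- assumed value
    const int n_batch = 2048;  // tokens submitted per llama_decode call           -- assumed value

    // number of sequences evaluated in parallel: at least 1, more if the
    // batch can hold several full context windows
    const int n_seq = std::max(1, n_batch / n_ctx);

    // either the batch is smaller than one window, or it holds a whole number of windows
    assert(n_batch < n_ctx || n_batch % n_ctx == 0);

    // the KV cache must hold all parallel sequences at once
    const int n_kv = n_seq * n_ctx;

    printf("n_seq = %d, kv cache size = %d, tokens per decode = %d\n",
           n_seq, n_kv, std::min(n_batch, n_kv));
    return 0;
}
```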
@@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens. Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
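For reference, the quantity accumulated over those last `n_ctx - first - 1` positions is ordinary token-level perplexity: the exponential of the mean negative log-likelihood. A toy, self-contained illustration with made-up per-token probabilities (the numbers are not from any model):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical probabilities p(token | previous tokens) for the scored positions
    const std::vector<double> probs = {0.25, 0.10, 0.60, 0.05};

    double nll = 0.0;
    for (double p : probs) {
        nll += -std::log(p);   // accumulate negative log-likelihood
    }

    // perplexity is e^(average negative log-likelihood)
    const double ppl = std::exp(nll / probs.size());
    printf("perplexity = %.4f\n", ppl);
    return 0;
}
```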
@@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token   [idx] = tokens[seq_start + k];
+                    batch.pos     [idx] = j*n_batch + k;
+                    batch.n_seq_id[idx] = 1;
+                    batch.seq_id  [idx][0] = seq;
+                    batch.logits  [idx] = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
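The manual filling of `batch.token` / `batch.pos` / `batch.n_seq_id` / `batch.seq_id` / `batch.logits` above is the same pattern common.cpp's `llama_batch_add()` helper uses, except that each sequence is written into its own contiguous `n_ctx`-sized slice of the batch and logits are requested only for positions in the second half of the window. A trimmed-down, hedged sketch of that pattern as a standalone helper (the function name `batch_add_token` is hypothetical, not part of the patch; it assumes a batch obtained from `llama_batch_init` with enough capacity):

```cpp
#include "llama.h"

// Append one token at position `pos` of sequence `seq` to `batch`,
// requesting logits for it only when `want_logits` is true.
static void batch_add_token(llama_batch & batch, llama_token tok,
                            llama_pos pos, llama_seq_id seq, bool want_logits) {
    const int idx = batch.n_tokens;

    batch.token   [idx]    = tok;
    batch.pos     [idx]    = pos;
    batch.n_seq_id[idx]    = 1;      // the token belongs to a single sequence
    batch.seq_id  [idx][0] = seq;
    batch.logits  [idx]    = want_logits ? 1 : 0;

    batch.n_tokens++;
}
```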
@@ -558,45 +594,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+            if (!params.logits_file.empty()) {
+                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, log_probs, nll, nll2);
+            } else {
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data() + start + seq*n_ctx + first);
+            }
+            count += n_ctx - first - 1;
+
+            // perplexity is e^(average negative log-likelihood)
+            if (params.ppl_output_type == 0) {
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            } else {
+                double av = nll/count;
+                double av2 = nll2/count - av*av;
+                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+            }
         }
 
         fflush(stdout);
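The `nll`/`nll2` pair updated by `process_logits` is just the running sum and sum of squares of the per-token negative log-likelihoods; the `ppl_output_type != 0` branch prints their mean and an estimate of its standard error. A self-contained illustration of that bookkeeping, using dummy NLL values rather than anything produced by a model:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-token negative log-likelihoods
    const std::vector<double> nlls = {1.4, 0.7, 2.1, 1.0, 1.6};

    double nll = 0.0, nll2 = 0.0;
    int count = 0;
    for (double x : nlls) {
        nll  += x;      // sum of NLLs
        nll2 += x*x;    // sum of squared NLLs
        count++;
    }

    const double av = nll / count;                        // mean NLL
    double av2 = nll2 / count - av*av;                    // variance of the NLLs
    if (av2 > 0) av2 = std::sqrt(av2 / (count - 1));      // standard error of the mean
    printf("ppl = %.4f, mean nll = %.4f +/- %.4f\n", std::exp(av), av, av2);
    return 0;
}
```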
@@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int n_seq = std::max(1, params.n_batch / n_ctx);
+        int32_t n_kv = n_seq * n_ctx;
+        params.n_parallel = n_seq;
+        params.n_ctx = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
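To make the user-visible effect of this rewiring concrete, the small sketch below traces what a few hypothetical `-c`/`-b` combinations would be turned into before the context is created (the specific value pairs are assumptions for illustration, not taken from the patch):

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    // hypothetical command-line combinations: {n_ctx (-c), n_batch (-b)}
    const int cases[][2] = { {512, 2048}, {512, 512}, {4096, 512} };

    for (const auto & c : cases) {
        const int n_ctx   = c[0];
        const int n_batch = c[1];

        const int n_seq = std::max(1, n_batch / n_ctx);  // parallel sequences
        const int n_kv  = n_seq * n_ctx;                 // KV cache sized for all of them

        printf("-c %4d -b %4d  ->  n_parallel = %d, params.n_ctx = %4d, params.n_batch = %4d\n",
               n_ctx, n_batch, n_seq, n_kv, std::min(n_batch, n_kv));
    }
    return 0;
}
```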
@@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
     }
 
     llama_print_timings(ctx);