Commit babfe9e (1 parent: c2101a2)

perplexity : support using multiple sequences to allow larger batch sizes

ggml-ci
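The change in a nutshell: when the batch size is a multiple of the context size, the perplexity example packs n_batch/n_ctx context-sized chunks into a single llama_decode call, one KV-cache sequence per chunk, instead of evaluating one chunk at a time. A standalone sketch of the arithmetic, using hypothetical values of 512 for the context and 2048 for the batch (not taken from the commit):

// Illustration only: how many chunks fit into one decode call, and how large
// the KV cache has to be, for a hypothetical 512-context / 2048-batch setup.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512;   // logical context window per chunk (hypothetical)
    const int n_batch = 2048;  // physical batch size (hypothetical)

    // same derivation as in the patch below: one sequence per context-sized chunk
    const int n_seq = std::max(1, n_batch / n_ctx);         // -> 4 chunks per decode
    const int n_kv  = std::max(1, n_batch / n_ctx) * n_ctx; // -> KV cache sized for 4*512 tokens

    printf("n_seq = %d, n_kv = %d\n", n_seq, n_kv);
    return 0;
}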

2 files changed (+102, -52 lines)

examples/perplexity/perplexity.cpp (87 additions, 48 deletions)

@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -500,12 +499,16 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
 
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    const int n_seq = std::max(1, n_batch / n_ctx);
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
+
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
@@ -518,10 +521,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens. Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
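The comment block moved above the chunk loop spells out the scoring scheme: only the second half of each window is scored, and the reported value is exp of the running mean negative log-likelihood. A minimal standalone sketch of that aggregation, with made-up per-token probabilities standing in for real model output (not part of the commit):

// Illustration only: perplexity = e^(average negative log-likelihood) over the
// tokens that were actually scored (the second half of each window).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical probabilities assigned to the scored tokens of one window
    const std::vector<double> p = {0.25, 0.10, 0.50, 0.05};

    double nll   = 0.0;
    int    count = 0;
    for (const double pi : p) {
        nll   += -std::log(pi); // accumulate negative log-likelihood
        count += 1;             // number of scored tokens so far
    }
    printf("ppl = %.4f\n", std::exp(nll / count));
    return 0;
}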
@@ -531,22 +550,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
 
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token[idx]     = tokens[seq_start + k];
+                    batch.pos[idx]       = j*n_batch + k;
+                    batch.n_seq_id[idx]  = 1;
+                    batch.seq_id[idx][0] = seq;
+                    batch.logits[idx]    = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
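For intuition, the batch built by the loop above can be pictured with a simplified stand-in structure: tokens from several chunks share one flat buffer, each chunk keeps its own position numbering and sequence id, and logits are requested only for positions in the second half of its window. This is an illustration with a toy struct and made-up token ids, not the real llama_batch:

// Illustration only: layout of a packed multi-sequence batch (toy types).
#include <cstdio>
#include <vector>

struct toy_batch_entry {
    int  token;   // token id (made up below)
    int  pos;     // position within its own sequence
    int  seq_id;  // which chunk/sequence the token belongs to
    bool logits;  // request logits for this position?
};

int main() {
    const int n_ctx = 8;         // hypothetical tiny context window
    const int n_seq = 2;         // two chunks packed into one batch
    const int first = n_ctx / 2; // score only the last half of each chunk

    std::vector<toy_batch_entry> batch;
    for (int seq = 0; seq < n_seq; seq++) {
        for (int k = 0; k < n_ctx; k++) {
            // mirrors idx = seq*n_ctx + k in the patch; 1000*seq + k is a fake token id
            batch.push_back({1000*seq + k, k, seq, k >= first});
        }
    }

    for (const auto & e : batch) {
        printf("tok=%4d pos=%d seq=%d logits=%d\n", e.token, e.pos, e.seq_id, e.logits ? 1 : 0);
    }
    return 0;
}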
@@ -558,45 +592,39 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
             }
             fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
         }
 
-        // We get the logits for all the tokens in the context window (params.n_ctx)
-        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half of the window (so the model always has
-        // some context to predict the token).
-        //
-        // We rely on the fact that attention in the forward pass only looks at previous
-        // tokens here, so the logits returned for each token are an accurate representation
-        // of what the model would have predicted at that point.
-        //
-        // Example, we have a context window of 512, we will compute perplexity for each of the
-        // last 256 tokens. Then, we split the input up into context window size chunks to
-        // process the entire prompt.
-        const int first = n_ctx/2;
-        const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-        if (!params.logits_file.empty()) {
-            process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, log_probs, nll, nll2);
-        } else {
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-        }
-        count += n_ctx - first - 1;
-
-        // perplexity is e^(average negative log-likelihood)
-        if (params.ppl_output_type == 0) {
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
-        } else {
-            double av = nll/count;
-            double av2 = nll2/count - av*av;
-            if (av2 > 0) av2 = sqrt(av2/(count-1));
-            printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+        for (int seq = 0; seq < n_seq_batch; seq++) {
+            const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
+            llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+            if (!params.logits_file.empty()) {
+                process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, log_probs, nll, nll2);
+            } else {
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data() + start + seq*n_ctx + first);
+            }
+            count += n_ctx - first - 1;
+
+            // perplexity is e^(average negative log-likelihood)
+            if (params.ppl_output_type == 0) {
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            } else {
+                double av = nll/count;
+                double av2 = nll2/count - av*av;
+                if (av2 > 0) av2 = sqrt(av2/(count-1));
+                printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
+            }
         }
         fflush(stdout);
 
@@ -615,6 +643,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1812,22 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int32_t n_kv = std::max(1, params.n_batch / n_ctx) * n_ctx;
+        params.n_ctx   = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@@ -1847,7 +1886,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
    }
 
     llama_print_timings(ctx);

llama.cpp (15 additions, 4 deletions)

@@ -8925,17 +8925,28 @@ static int llama_decode_internal(
 
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
+            int32_t i_first = -1;
             for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
+                if (batch.logits[i] == 0 || i == n_tokens - 1) {
+                    if (i_first != -1) {
+                        int i_last = batch.logits[i] == 0 ? i : i + 1;
+                        // extract logits for the range [i_first, i_last)
+                        // group the requests to minimize the number of calls to the backend
+                        ggml_backend_tensor_get_async(backend_res, res,
+                                logits_out.data() + (n_vocab*i_first),
+                                (n_vocab*i_first)*sizeof(float),
+                                (i_last - i_first)*n_vocab*sizeof(float));
+                        i_first = -1;
+                    }
+                } else if (i_first == -1) {
+                    i_first = (int32_t) i;
                 }
-                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
 #ifndef NDEBUG
                 logits_valid[i] = true;
 #endif
             }
         } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab * n_tokens);
+            logits_out.resize(n_vocab*n_tokens);
             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);
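The reworked loop coalesces consecutive rows whose logits were requested into [i_first, i_last) ranges, so one ggml_backend_tensor_get_async call covers each run instead of one call per row. The same grouping idea, written as a standalone sketch with hypothetical flags and a printf standing in for the backend copy (not the commit's exact code):

// Illustration only: group contiguous runs of requested rows into single copies.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // hypothetical per-token "logits requested" flags, as in batch.logits
    const std::vector<int8_t> flags = {0, 0, 1, 1, 1, 0, 1, 1};
    const uint32_t n_tokens = (uint32_t) flags.size();

    int32_t i_first = -1;
    for (uint32_t i = 0; i <= n_tokens; i++) {
        const bool wanted = i < n_tokens && flags[i] != 0;
        if (wanted && i_first == -1) {
            i_first = (int32_t) i;  // start of a new run of requested rows
        }
        if (!wanted && i_first != -1) {
            // one backend copy would be issued here for rows [i_first, i)
            printf("copy rows [%d, %u)\n", i_first, (unsigned) i);
            i_first = -1;
        }
    }
    return 0;
}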
