From 4ed4fe75ed85b54b8c44a6e216d98eb20da35f3b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Feb 2025 00:48:12 +0100 Subject: [PATCH 01/52] first proposal for private llama_batch --- include/llama.h | 78 +++++++++++++++++++---------- src/llama-batch.cpp | 118 +++++++++++++++++++++++++++++++++++--------- src/llama-batch.h | 24 +++++++++ 3 files changed, 171 insertions(+), 49 deletions(-) diff --git a/include/llama.h b/include/llama.h index 1f5f3a09b311e..2c7569f8ab7a1 100644 --- a/include/llama.h +++ b/include/llama.h @@ -231,29 +231,7 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); - // Input data for llama_decode - // A llama_batch object can contain input about one or many sequences - // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens - // - // - token : the token ids of the input (used when embd is NULL) - // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) - // - pos : the positions of the respective token in the sequence - // (if set to NULL, the token position will be tracked automatically by llama_decode) - // - seq_id : the sequence to which the respective token belongs - // (if set to NULL, the sequence ID will be assumed to be 0) - // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output - // (if set to NULL, only the logits for last token will be returned) - // - typedef struct llama_batch { - int32_t n_tokens; - - llama_token * token; - float * embd; - llama_pos * pos; - int32_t * n_seq_id; - llama_seq_id ** seq_id; - int8_t * logits; // TODO: rename this to "output" - } llama_batch; + struct llama_batch; enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, @@ -829,7 +807,7 @@ extern "C" { // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // - LLAMA_API struct llama_batch llama_batch_get_one( + LLAMA_API struct llama_batch * llama_batch_get_one( llama_token * tokens, int32_t n_tokens); @@ -840,13 +818,59 @@ extern "C" { // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - LLAMA_API struct llama_batch llama_batch_init( + // LLAMA_API struct llama_batch llama_batch_init( + // int32_t n_tokens, + // int32_t embd, + // int32_t n_seq_max); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_free() + LLAMA_API struct llama_batch * llama_batch_init( int32_t n_tokens, - int32_t embd, int32_t n_seq_max); + // Same with llama_batch_init, but initializes the batch with the provided raw embeddings + LLAMA_API struct llama_batch * llama_batch_init_from_embd( + float * embd, + size_t n_embd, + int32_t pos0, + int32_t seq_id); + + // Add text tokens to the batch + // First token in the list starts at position pos0 + // Return values: + // 0 : success + // -1 : not enough space in the batch + // -2 : embd is already set, cannot add text tokens + LLAMA_API int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t seq_id); + + // Same as llama_batch_add_text, but accepts multiple sequences + LLAMA_API int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + 
int32_t * seq_ids, + size_t n_seq_ids); + + // Set logits for the token in the ith sequence + // If pos == -1, logits will be set for the all tokens + LLAMA_API int32_t llama_batch_set_logits( + struct llama_batch * batch, + int32_t pos, + int32_t seq_id); + + // Remove everything from the batch + LLAMA_API void llama_batch_clear(struct llama_batch * batch); + // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch batch); + LLAMA_API void llama_batch_free(struct llama_batch * batch); // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..027ac24139f5e 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -309,10 +309,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0 // interface implementation // -struct llama_batch llama_batch_get_one( +struct llama_batch * llama_batch_get_one( llama_token * tokens, int32_t n_tokens) { - return { + return new llama_batch{ /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, /*embd =*/ nullptr, @@ -323,8 +323,8 @@ struct llama_batch llama_batch_get_one( }; } -struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { +static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch * batch = new llama_batch{ /*n_tokens =*/ 0, /*tokens =*/ nullptr, /*embd =*/ nullptr, @@ -335,34 +335,108 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ }; if (embd) { - batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); } else { - batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } - batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); - batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); - batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); for (int i = 0; i < n_tokens_alloc; ++i) { - batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); } - batch.seq_id[n_tokens_alloc] = nullptr; + batch->seq_id[n_tokens_alloc] = nullptr; - batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); return batch; } -void llama_batch_free(struct llama_batch batch) { - if (batch.token) free(batch.token); - if (batch.embd) free(batch.embd); - if (batch.pos) free(batch.pos); - if (batch.n_seq_id) free(batch.n_seq_id); - if (batch.seq_id) { - for (int i = 0; batch.seq_id[i] != nullptr; ++i) { - free(batch.seq_id[i]); +struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { + return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); +} + +struct llama_batch * llama_batch_init_from_embd( + float * embd, + size_t n_embd, + int32_t pos0, + int32_t seq_id) { + 
struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); + memcpy(batch->embd, embd, n_embd * sizeof(float)); + for (int32_t i = 0; i < n_embd; i++) { + batch->pos [i] = pos0 + i; + batch->n_seq_id[i] = 1; + batch->seq_id [i][0] = seq_id; + } +} + +int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t * seq_ids, + size_t n_seq_ids) { + if (batch->n_tokens + n_tokens > batch->n_tokens) { + return -1; + } + if (batch->embd) { + return -2; + } + for (int32_t i = 0; i < n_tokens; i++) { + batch->token [batch->n_tokens + i] = tokens[i]; + batch->pos [batch->n_tokens + i] = pos0 + i; + batch->n_seq_id[batch->n_tokens + i] = n_seq_ids; + for (int32_t j = 0; j < n_seq_ids; j++) { + batch->seq_id[batch->n_tokens + i][j] = seq_ids[j]; + } + } +} + +int32_t llama_batch_add_text( + struct llama_batch * batch, + llama_token * tokens, + size_t n_tokens, + int32_t pos0, + int32_t seq_id) { + std::array seq_ids = { seq_id }; + return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size()); +} + +int32_t llama_batch_set_logits( + struct llama_batch * batch, + int32_t pos, + int32_t seq_id) { + for (int32_t i = 0; i < batch->n_tokens; i++) { + // find the token having seq_id + for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { + if (batch->seq_id[i][j] == seq_id) { + // found the sequence + if (pos == -1 || pos == batch->pos[i]) { + batch->logits[i] = true; + break; + } + } + } + } +} + +void llama_batch_clear(struct llama_batch * batch) { + batch->n_tokens = 0; +} + +void llama_batch_free(struct llama_batch * batch) { + if (batch->token) free(batch->token); + if (batch->embd) free(batch->embd); + if (batch->pos) free(batch->pos); + if (batch->n_seq_id) free(batch->n_seq_id); + if (batch->seq_id) { + for (int i = 0; batch->seq_id[i] != nullptr; ++i) { + free(batch->seq_id[i]); } - free(batch.seq_id); + free(batch->seq_id); } - if (batch.logits) free(batch.logits); + if (batch->logits) free(batch->logits); + delete batch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 773c3808b770f..de702da76d5b3 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -5,6 +5,30 @@ #include #include +// Input data for llama_decode +// A llama_batch object can contain input about one or many sequences +// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +// +// - token : the token ids of the input (used when embd is NULL) +// - embd : token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) +// - pos : the positions of the respective token in the sequence +// (if set to NULL, the token position will be tracked automatically by llama_decode) +// - seq_id : the sequence to which the respective token belongs +// (if set to NULL, the sequence ID will be assumed to be 0) +// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +// (if set to NULL, only the logits for last token will be returned) +// +struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" +}; + // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { From f2e59a8eb91bcfc0e89fe599e0ddf9dbdb3fb1b2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Feb 2025 18:16:49 +0100 Subject: [PATCH 02/52] rework, targeting llama-server --- .gitignore | 1 + common/common.cpp | 21 +++---- common/common.h | 4 +- common/speculative.cpp | 6 +- examples/server/server.cpp | 98 +++++++++++++---------------- include/llama-cpp.h | 5 ++ include/llama.h | 56 +++++++++++------ src/llama-batch.cpp | 124 +++++++++++++++++++++++++------------ src/llama-batch.h | 2 + src/llama.cpp | 8 +-- 10 files changed, 190 insertions(+), 135 deletions(-) diff --git a/.gitignore b/.gitignore index 694f36e042fb5..56b5ac2c18cfe 100644 --- a/.gitignore +++ b/.gitignore @@ -98,6 +98,7 @@ examples/server/*.css.hpp examples/server/*.html.hpp examples/server/*.js.hpp examples/server/*.mjs.hpp +examples/server/*.gz.hpp !build_64.sh !examples/*.bat !examples/*/*.kts diff --git a/common/common.cpp b/common/common.cpp index 8661e164ada6b..c79f1e73688d8 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -580,6 +580,7 @@ std::string string_from(const struct llama_context * ctx, const std::vector common_get_hf_file(const std::string &, cons // Batch utils // -void common_batch_clear(struct llama_batch & batch) { - batch.n_tokens = 0; +void common_batch_clear(struct llama_batch * batch) { + llama_batch_clear(batch); } void common_batch_add( - struct llama_batch & batch, + struct llama_batch * batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { - GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); - - batch.token [batch.n_tokens] = id; - batch.pos [batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = seq_ids.size(); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; + int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits); + if (res == -1) { + LOG_ERR("%s: llama_batch size exceeded\n", __func__); } - batch.logits [batch.n_tokens] = logits; - - batch.n_tokens++; } // diff --git a/common/common.h b/common/common.h index 98b9a4464787a..8ce6f2f127a6b 100644 --- a/common/common.h +++ b/common/common.h @@ -554,10 +554,10 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector & seq_ids, diff --git a/common/speculative.cpp b/common/speculative.cpp index 318e96ea35468..0836845ecc2a7 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,7 +13,7 @@ struct common_speculative { struct llama_context * ctx; struct common_sampler * smpl; - llama_batch batch; + llama_batch * batch; llama_tokens prompt; }; @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* 
.ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1), /* .prompt = */ {}, }; @@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft( } // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { + if (llama_batch_get_n_tokens(batch) > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 71151183b81da..41f8dc505795d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1215,7 +1215,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - llama_batch batch_spec = {}; + llama_batch_ptr batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1787,7 +1787,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch_ptr batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1820,11 +1820,7 @@ struct server_context { common_speculative_free(slot.spec); slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); } - - llama_batch_free(batch); } bool load_model(const common_params & params) { @@ -1944,7 +1940,7 @@ struct server_context { slot.n_predict = params_base.n_predict; if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1)); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1969,7 +1965,7 @@ struct server_context { slot.reset(); - slots.push_back(slot); + slots.push_back(std::move(slot)); } default_generation_settings_for_props = slots[0].to_json(); @@ -1980,7 +1976,7 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1)); } metrics.init(); @@ -2098,9 +2094,7 @@ struct server_context { } if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1)); } slot.state = SLOT_STATE_STARTED; @@ -2408,7 +2402,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, const llama_batch & batch) { + void send_embedding(const server_slot & slot, llama_batch_ptr & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2419,18 +2413,19 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { + llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); + if (!tok.logits || tok.seq_id[0] != slot.id) { continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); 
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; @@ -2451,24 +2446,25 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, const llama_batch & batch) { + void send_rerank(const server_slot & slot, llama_batch_ptr & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; res->n_tokens = slot.n_prompt_tokens; - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { + for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { + llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); + if (!tok.logits || tok.seq_id[0] != slot.id) { continue; } - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); + const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); res->score = -1e6; continue; @@ -2859,7 +2855,7 @@ struct server_context { } // start populating the batch for this iteration - common_batch_clear(batch); + common_batch_clear(batch.get()); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2881,9 +2877,9 @@ struct server_context { continue; } - slot.i_batch = batch.n_tokens; + slot.i_batch = llama_batch_get_n_tokens(batch.get()); - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true); slot.n_past += 1; @@ -2900,7 +2896,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { + if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3066,7 +3062,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3086,11 +3082,11 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); + common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3100,13 +3096,13 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "prompt processing progress, 
n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0); common_sampler_reset(slot.smpl); @@ -3116,27 +3112,27 @@ struct server_context { } // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + llama_batch_set_logits_last(batch.get()); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get())); } } - if (batch.n_tokens >= n_batch) { + if (llama_batch_get_n_tokens(batch.get()) >= n_batch) { break; } } } - if (batch.n_tokens == 0) { + if (llama_batch_get_n_tokens(batch.get()) == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get())); if (slot_batched) { // make sure we're in the right embedding mode @@ -3146,20 +3142,12 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i); + + llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens)); - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode(ctx, batch_view.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3294,16 +3282,16 @@ struct server_context { } // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + common_batch_clear(slot.batch_spec.get()); + common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true); for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get())); - llama_decode(ctx, slot.batch_spec); + llama_decode(ctx, slot.batch_spec.get()); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 8f6368177de09..80c726e301009 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -24,7 +24,12 @@ struct llama_adapter_lora_deleter { void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; 
+struct llama_batch_deleter { + void operator()(llama_batch * batch) { llama_batch_free(batch); } +}; + typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; +typedef std::unique_ptr llama_batch_ptr; diff --git a/include/llama.h b/include/llama.h index 2c7569f8ab7a1..79dc8604db1be 100644 --- a/include/llama.h +++ b/include/llama.h @@ -233,6 +233,14 @@ extern "C" { struct llama_batch; + struct llama_batch_token_info { + llama_token token; + llama_pos pos; + int32_t n_seq_id; + llama_seq_id * seq_id; + int8_t logits; + }; + enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, @@ -837,34 +845,44 @@ extern "C" { int32_t pos0, int32_t seq_id); + // Get the number of tokens in the batch + LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch); + + LLAMA_API struct llama_batch_token_info llama_batch_get_token_info( + struct llama_batch * batch, + int32_t i); + // Add text tokens to the batch - // First token in the list starts at position pos0 // Return values: // 0 : success // -1 : not enough space in the batch // -2 : embd is already set, cannot add text tokens - LLAMA_API int32_t llama_batch_add_text( + LLAMA_API int32_t llama_batch_add_text_token( struct llama_batch * batch, - llama_token * tokens, - size_t n_tokens, - int32_t pos0, - int32_t seq_id); - - // Same as llama_batch_add_text, but accepts multiple sequences - LLAMA_API int32_t llama_batch_add_text( - struct llama_batch * batch, - llama_token * tokens, - size_t n_tokens, - int32_t pos0, - int32_t * seq_ids, - size_t n_seq_ids); + llama_token token, + llama_pos pos, + const llama_seq_id * seq_ids, + size_t n_seq_ids, + float logits); // Set logits for the token in the ith sequence // If pos == -1, logits will be set for the all tokens + // Returns -1 if the token is not in the batch LLAMA_API int32_t llama_batch_set_logits( struct llama_batch * batch, - int32_t pos, - int32_t seq_id); + llama_pos pos, + llama_seq_id seq_id); + + // Set logits for the last added token + // Returns -1 if there is no tokens in the batch + LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch); + + // Get a "view" from a number of tokens offset + // Return returned batch must be freed with llama_batch_free() + LLAMA_API struct llama_batch * llama_batch_get_view( + struct llama_batch * batch, + int32_t offset, + int32_t n_tokens); // Remove everything from the batch LLAMA_API void llama_batch_clear(struct llama_batch * batch); @@ -878,7 +896,7 @@ extern "C" { // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, - struct llama_batch batch); + struct llama_batch * batch); // Positive return values does not mean a fatal error, but rather a warning. // 0 - success @@ -886,7 +904,7 @@ extern "C" { // < 0 - error. 
the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, - struct llama_batch batch); + struct llama_batch * batch); // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 027ac24139f5e..c9b6a97f73f50 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -314,6 +314,8 @@ struct llama_batch * llama_batch_get_one( int32_t n_tokens) { return new llama_batch{ /*n_tokens =*/ n_tokens, + /*max_tokens =*/ n_tokens, + /*is_view =*/ false, /*tokens =*/ tokens, /*embd =*/ nullptr, /*pos =*/ nullptr, @@ -326,6 +328,8 @@ struct llama_batch * llama_batch_get_one( static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch * batch = new llama_batch{ /*n_tokens =*/ 0, + /*max_tokens =*/ n_tokens_alloc, + /*is_view =*/ false, /*tokens =*/ nullptr, /*embd =*/ nullptr, /*pos =*/ nullptr, @@ -364,50 +368,46 @@ struct llama_batch * llama_batch_init_from_embd( int32_t seq_id) { struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); memcpy(batch->embd, embd, n_embd * sizeof(float)); - for (int32_t i = 0; i < n_embd; i++) { + for (size_t i = 0; i < n_embd; i++) { batch->pos [i] = pos0 + i; batch->n_seq_id[i] = 1; batch->seq_id [i][0] = seq_id; } + return batch; +} + +int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) { + return batch->n_tokens; } -int32_t llama_batch_add_text( +int32_t llama_batch_add_text_token( struct llama_batch * batch, - llama_token * tokens, - size_t n_tokens, - int32_t pos0, - int32_t * seq_ids, - size_t n_seq_ids) { - if (batch->n_tokens + n_tokens > batch->n_tokens) { - return -1; + llama_token token, + llama_pos pos, + const llama_seq_id * seq_ids, + size_t n_seq_ids, + float logits) { + if (batch->n_tokens + 1 > batch->max_tokens) { + return -1; // llama_batch size exceeded } if (batch->embd) { - return -2; + return -2; // embd is already set, cannot add text tokens } - for (int32_t i = 0; i < n_tokens; i++) { - batch->token [batch->n_tokens + i] = tokens[i]; - batch->pos [batch->n_tokens + i] = pos0 + i; - batch->n_seq_id[batch->n_tokens + i] = n_seq_ids; - for (int32_t j = 0; j < n_seq_ids; j++) { - batch->seq_id[batch->n_tokens + i][j] = seq_ids[j]; - } + batch->token [batch->n_tokens] = token; + batch->pos [batch->n_tokens] = pos; + batch->n_seq_id[batch->n_tokens] = n_seq_ids; + for (size_t j = 0; j < n_seq_ids; j++) { + batch->seq_id[batch->n_tokens][j] = seq_ids[j]; } -} - -int32_t llama_batch_add_text( - struct llama_batch * batch, - llama_token * tokens, - size_t n_tokens, - int32_t pos0, - int32_t seq_id) { - std::array seq_ids = { seq_id }; - return llama_batch_add_text(batch, tokens, n_tokens, pos0, seq_ids.data(), seq_ids.size()); + batch->logits [batch->n_tokens] = logits; + batch->n_tokens++; + return 0; } int32_t llama_batch_set_logits( struct llama_batch * batch, - int32_t pos, - int32_t seq_id) { + llama_pos pos, + llama_seq_id seq_id) { for (int32_t i = 0; i < batch->n_tokens; i++) { // find the token having seq_id for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { @@ -415,28 +415,74 @@ int32_t llama_batch_set_logits( // found the sequence if (pos == -1 || pos == batch->pos[i]) { batch->logits[i] = true; - break; + return 0; } } } } + return -1; // not found +} + +int32_t llama_batch_set_logits_last(struct llama_batch * batch) { + if (batch->n_tokens == 0) { + return -1; + } + 
batch->logits[batch->n_tokens - 1] = true; + return 0; } void llama_batch_clear(struct llama_batch * batch) { batch->n_tokens = 0; } +struct llama_batch * llama_batch_get_view( + struct llama_batch * batch, + int32_t offset, + int32_t n_tokens) { + if (batch->embd) { + return nullptr; // not yet supported + } + llama_batch * batch_view = new llama_batch{ + /*n_tokens =*/ n_tokens, + /*max_tokens =*/ n_tokens, + /*is_view =*/ true, + /*tokens =*/ batch->token + offset, + /*embd =*/ nullptr, + /*pos =*/ batch->pos + offset, + /*n_seq_id =*/ batch->n_seq_id + offset, + /*seq_id =*/ batch->seq_id + offset, + /*logits =*/ batch->logits + offset, + }; + return batch_view; +} + +struct llama_batch_token_info llama_batch_get_token_info( + struct llama_batch * batch, + int32_t i) { + GGML_ASSERT(i >= 0 && i < batch->n_tokens); + return llama_batch_token_info{ + /*token =*/ batch->token [i], + /*pos =*/ batch->pos [i], + /*n_seq_id =*/ batch->n_seq_id[i], + /*seq_id =*/ batch->seq_id [i], + /*logits =*/ batch->logits [i], + }; +} + void llama_batch_free(struct llama_batch * batch) { - if (batch->token) free(batch->token); - if (batch->embd) free(batch->embd); - if (batch->pos) free(batch->pos); - if (batch->n_seq_id) free(batch->n_seq_id); - if (batch->seq_id) { - for (int i = 0; batch->seq_id[i] != nullptr; ++i) { - free(batch->seq_id[i]); + // do not free the members if it's a view + if (!batch->is_view) { + if (batch->token) free(batch->token); + if (batch->embd) free(batch->embd); + if (batch->pos) free(batch->pos); + if (batch->n_seq_id) free(batch->n_seq_id); + if (batch->seq_id) { + for (int i = 0; batch->seq_id[i] != nullptr; ++i) { + free(batch->seq_id[i]); + } + free(batch->seq_id); } - free(batch->seq_id); + if (batch->logits) free(batch->logits); } - if (batch->logits) free(batch->logits); delete batch; } diff --git a/src/llama-batch.h b/src/llama-batch.h index de702da76d5b3..70bc6d4052c75 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -20,6 +20,8 @@ // struct llama_batch { int32_t n_tokens; + int32_t max_tokens; + bool is_view; llama_token * token; float * embd; diff --git a/src/llama.cpp b/src/llama.cpp index 607f278615969..978ce0dd76c09 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9978,8 +9978,8 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { int32_t llama_encode( struct llama_context * ctx, - struct llama_batch batch) { - const int ret = llama_encode_impl(*ctx, batch); + struct llama_batch * batch) { + const int ret = llama_encode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -9989,8 +9989,8 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, - struct llama_batch batch) { - const int ret = llama_decode_impl(*ctx, batch); + struct llama_batch * batch) { + const int ret = llama_decode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } From 17d3658b5f0a595cb6e0c56fa04dc00f8a6ab58d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 16 Feb 2025 00:02:53 +0100 Subject: [PATCH 03/52] move to llama_batch_ext --- common/common.cpp | 21 +++++--- common/common.h | 6 ++- common/speculative.cpp | 6 +-- include/llama-cpp.h | 6 +-- include/llama.h | 111 ++++++++++++++++++++++++++++----------- src/llama-batch.cpp | 116 +++++++++++++++++++++++++---------------- src/llama-batch.h | 16 +++--- src/llama.cpp | 59 +++++++++++++-------- 8 files changed, 223 insertions(+), 118 deletions(-) diff --git a/common/common.cpp 
b/common/common.cpp index c79f1e73688d8..b54e546f96f37 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1610,20 +1610,29 @@ std::pair common_get_hf_file(const std::string &, cons // Batch utils // -void common_batch_clear(struct llama_batch * batch) { - llama_batch_clear(batch); +// DEPRECATED +void common_batch_clear(struct llama_batch & batch) { + batch.n_tokens = 0; } +// DEPRECATED void common_batch_add( - struct llama_batch * batch, + struct llama_batch & batch, llama_token id, llama_pos pos, const std::vector & seq_ids, bool logits) { - int32_t res = llama_batch_add_text_token(batch, id, pos, seq_ids.data(), seq_ids.size(), logits); - if (res == -1) { - LOG_ERR("%s: llama_batch size exceeded\n", __func__); + GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); + + batch.token [batch.n_tokens] = id; + batch.pos [batch.n_tokens] = pos; + batch.n_seq_id[batch.n_tokens] = seq_ids.size(); + for (size_t i = 0; i < seq_ids.size(); ++i) { + batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } + batch.logits [batch.n_tokens] = logits; + + batch.n_tokens++; } // diff --git a/common/common.h b/common/common.h index 8ce6f2f127a6b..524559de42f0e 100644 --- a/common/common.h +++ b/common/common.h @@ -554,10 +554,12 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector & seq_ids, diff --git a/common/speculative.cpp b/common/speculative.cpp index 0836845ecc2a7..318e96ea35468 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,7 +13,7 @@ struct common_speculative { struct llama_context * ctx; struct common_sampler * smpl; - llama_batch * batch; + llama_batch batch; llama_tokens prompt; }; @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 1), + /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), /* .prompt = */ {}, }; @@ -215,7 +215,7 @@ llama_tokens common_speculative_gen_draft( } // we should rarely end-up here during normal decoding - if (llama_batch_get_n_tokens(batch) > 0) { + if (batch.n_tokens > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode(ctx, batch); diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 80c726e301009..880a6a5fae8f5 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -24,12 +24,12 @@ struct llama_adapter_lora_deleter { void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; -struct llama_batch_deleter { - void operator()(llama_batch * batch) { llama_batch_free(batch); } +struct llama_batch_ext_deleter { + void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); } }; typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; -typedef std::unique_ptr llama_batch_ptr; +typedef std::unique_ptr llama_batch_ext_ptr; diff --git a/include/llama.h b/include/llama.h index 79dc8604db1be..32b4cdbe1dd05 100644 --- a/include/llama.h +++ b/include/llama.h @@ -231,9 +231,38 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); - struct llama_batch; - - struct llama_batch_token_info { + // Input data for llama_decode + // + // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead + // + // A llama_batch object can contain input about one or many 
sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // (if set to NULL, the token position will be tracked automatically by llama_decode) + // - seq_id : the sequence to which the respective token belongs + // (if set to NULL, the sequence ID will be assumed to be 0) + // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output + // (if set to NULL, only the logits for last token will be returned) + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" + } llama_batch; + + // Input data for llama_decode / llama_encode + // It can contain text tokens and embeddings for one or many sequences + struct llama_batch_ext; + + struct llama_batch_ext_token_info { llama_token token; llama_pos pos; int32_t n_seq_id; @@ -815,9 +844,9 @@ extern "C" { // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // - LLAMA_API struct llama_batch * llama_batch_get_one( + DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens); + int32_t n_tokens), "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -826,30 +855,47 @@ extern "C" { // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - // LLAMA_API struct llama_batch llama_batch_init( - // int32_t n_tokens, - // int32_t embd, - // int32_t n_seq_max); + DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max), "use llama_batch_ext API instead"); + + // Frees a batch of tokens allocated with llama_batch_init() + DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), + "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids - // The batch has to be freed with llama_batch_free() - LLAMA_API struct llama_batch * llama_batch_init( + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init( int32_t n_tokens, int32_t n_seq_max); + // Same with llama_batch_init, but initializes the batch with the provided text tokens + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text( + llama_token * tokens, + int32_t n_tokens, + int32_t pos0, + int32_t seq_id); + // Same with llama_batch_init, but initializes the batch with the provided raw embeddings - LLAMA_API struct llama_batch * llama_batch_init_from_embd( + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, size_t n_embd, int32_t pos0, 
int32_t seq_id); // Get the number of tokens in the batch - LLAMA_API int32_t llama_batch_get_n_tokens(const struct llama_batch * batch); + LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); - LLAMA_API struct llama_batch_token_info llama_batch_get_token_info( - struct llama_batch * batch, + LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info( + struct llama_batch_ext * batch, int32_t i); // Add text tokens to the batch @@ -857,8 +903,8 @@ extern "C" { // 0 : success // -1 : not enough space in the batch // -2 : embd is already set, cannot add text tokens - LLAMA_API int32_t llama_batch_add_text_token( - struct llama_batch * batch, + LLAMA_API int32_t llama_batch_ext_add_text_token( + struct llama_batch_ext * batch, llama_token token, llama_pos pos, const llama_seq_id * seq_ids, @@ -868,43 +914,50 @@ extern "C" { // Set logits for the token in the ith sequence // If pos == -1, logits will be set for the all tokens // Returns -1 if the token is not in the batch - LLAMA_API int32_t llama_batch_set_logits( - struct llama_batch * batch, + LLAMA_API int32_t llama_batch_ext_set_logits( + struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id); // Set logits for the last added token // Returns -1 if there is no tokens in the batch - LLAMA_API int32_t llama_batch_set_logits_last(struct llama_batch * batch); + LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); // Get a "view" from a number of tokens offset // Return returned batch must be freed with llama_batch_free() - LLAMA_API struct llama_batch * llama_batch_get_view( - struct llama_batch * batch, + LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, int32_t offset, int32_t n_tokens); // Remove everything from the batch - LLAMA_API void llama_batch_clear(struct llama_batch * batch); + LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch); - // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch * batch); + // Frees a batch of tokens allocated with llama_batch_ext_init() + // If this is a view, the original batch is not freed + LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch); // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call - LLAMA_API int32_t llama_encode( + DEPRECATED(LLAMA_API int32_t llama_encode( + struct llama_context * ctx, + struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_text_encode( struct llama_context * ctx, - struct llama_batch * batch); + struct llama_batch_ext * batch); // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error. 
the KV cache state is restored to the state before this call - LLAMA_API int32_t llama_decode( + DEPRECATED(LLAMA_API int32_t llama_decode( + struct llama_context * ctx, + struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_text_decode( struct llama_context * ctx, - struct llama_batch * batch); + struct llama_batch_ext * batch); // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index c9b6a97f73f50..36a3d00be1412 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -273,49 +273,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ); } -llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { - batch = in_batch; - GGML_ASSERT(batch.n_tokens > 0); - if (!batch.pos) { - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { +llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) { + batch = new llama_batch_ext{ + /*n_tokens =*/ in_batch.n_tokens, + /*max_tokens =*/ in_batch.n_tokens, + /*is_view =*/ false, + /*tokens =*/ in_batch.token, + /*embd =*/ in_batch.embd, + /*pos =*/ in_batch.pos, + /*n_seq_id =*/ in_batch.n_seq_id, + /*seq_id =*/ in_batch.seq_id, + /*logits =*/ in_batch.logits, + }; + GGML_ASSERT(batch->n_tokens > 0); + if (!in_batch.pos) { + pos.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { pos[i] = i + p0; } - batch.pos = pos.data(); + batch->pos = pos.data(); } - if (!batch.n_seq_id) { - n_seq_id.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->n_seq_id) { + n_seq_id.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { n_seq_id[i] = seq_id_0.size(); } - batch.n_seq_id = n_seq_id.data(); + batch->n_seq_id = n_seq_id.data(); } - if (!batch.seq_id) { - seq_id.resize(batch.n_tokens + 1); - seq_id[batch.n_tokens] = NULL; - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->seq_id) { + seq_id.resize(batch->n_tokens + 1); + seq_id[batch->n_tokens] = NULL; + for (int32_t i = 0; i < batch->n_tokens; i++) { seq_id[i] = seq_id_0.data(); } - batch.seq_id = seq_id.data(); + batch->seq_id = seq_id.data(); } - if (!batch.logits) { - logits.resize(batch.n_tokens); + if (!batch->logits) { + logits.resize(batch->n_tokens); logits[logits.size() - 1] = true; - batch.logits = logits.data(); + batch->logits = logits.data(); } } +llama_batch_allocr::~llama_batch_allocr() { + delete batch; +} + // // interface implementation // -struct llama_batch * llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens) { - return new llama_batch{ +struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens) { + return llama_batch{ /*n_tokens =*/ n_tokens, - /*max_tokens =*/ n_tokens, - /*is_view =*/ false, /*tokens =*/ tokens, /*embd =*/ nullptr, /*pos =*/ nullptr, @@ -325,8 +337,20 @@ struct llama_batch * llama_batch_get_one( }; } -static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_t embd, 
int32_t n_seq_max) { - llama_batch * batch = new llama_batch{ +struct llama_batch_ext * llama_batch_ext_init_from_text( + llama_token * tokens, + int32_t n_tokens, + int32_t pos0, + int32_t seq_id) { + llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); + for (int32_t i = 0; i < n_tokens; i++) { + llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false); + } + return batch; +} + +static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch_ext * batch = new llama_batch_ext{ /*n_tokens =*/ 0, /*max_tokens =*/ n_tokens_alloc, /*is_view =*/ false, @@ -357,16 +381,16 @@ static struct llama_batch * llama_batch_init_impl(int32_t n_tokens_alloc, int32_ return batch; } -struct llama_batch * llama_batch_init(int32_t n_tokens_alloc, int32_t n_seq_max) { - return llama_batch_init_impl(n_tokens_alloc, 0, n_seq_max); +struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) { + return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max); } -struct llama_batch * llama_batch_init_from_embd( +struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, size_t n_embd, int32_t pos0, int32_t seq_id) { - struct llama_batch * batch = llama_batch_init_impl(0, n_embd, 1); + struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1); memcpy(batch->embd, embd, n_embd * sizeof(float)); for (size_t i = 0; i < n_embd; i++) { batch->pos [i] = pos0 + i; @@ -376,12 +400,12 @@ struct llama_batch * llama_batch_init_from_embd( return batch; } -int32_t llama_batch_get_n_tokens(const struct llama_batch * batch) { +int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { return batch->n_tokens; } -int32_t llama_batch_add_text_token( - struct llama_batch * batch, +int32_t llama_batch_ext_add_text_token( + struct llama_batch_ext * batch, llama_token token, llama_pos pos, const llama_seq_id * seq_ids, @@ -404,8 +428,8 @@ int32_t llama_batch_add_text_token( return 0; } -int32_t llama_batch_set_logits( - struct llama_batch * batch, +int32_t llama_batch_ext_set_logits( + struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id) { for (int32_t i = 0; i < batch->n_tokens; i++) { @@ -423,7 +447,7 @@ int32_t llama_batch_set_logits( return -1; // not found } -int32_t llama_batch_set_logits_last(struct llama_batch * batch) { +int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { if (batch->n_tokens == 0) { return -1; } @@ -431,18 +455,18 @@ int32_t llama_batch_set_logits_last(struct llama_batch * batch) { return 0; } -void llama_batch_clear(struct llama_batch * batch) { +void llama_batch_ext_clear(struct llama_batch_ext * batch) { batch->n_tokens = 0; } -struct llama_batch * llama_batch_get_view( - struct llama_batch * batch, +struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, int32_t offset, int32_t n_tokens) { if (batch->embd) { return nullptr; // not yet supported } - llama_batch * batch_view = new llama_batch{ + llama_batch_ext * batch_view = new llama_batch_ext{ /*n_tokens =*/ n_tokens, /*max_tokens =*/ n_tokens, /*is_view =*/ true, @@ -456,11 +480,11 @@ struct llama_batch * llama_batch_get_view( return batch_view; } -struct llama_batch_token_info llama_batch_get_token_info( - struct llama_batch * batch, +struct llama_batch_ext_token_info llama_batch_ext_get_token_info( + struct llama_batch_ext * batch, int32_t i) { GGML_ASSERT(i >= 0 && i < batch->n_tokens); - return llama_batch_token_info{ + 
return llama_batch_ext_token_info{ /*token =*/ batch->token [i], /*pos =*/ batch->pos [i], /*n_seq_id =*/ batch->n_seq_id[i], @@ -469,7 +493,7 @@ struct llama_batch_token_info llama_batch_get_token_info( }; } -void llama_batch_free(struct llama_batch * batch) { +void llama_batch_ext_free(struct llama_batch_ext * batch) { // do not free the members if it's a view if (!batch->is_view) { if (batch->token) free(batch->token); diff --git a/src/llama-batch.h b/src/llama-batch.h index 70bc6d4052c75..bbd2205b3d3d7 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -5,8 +5,8 @@ #include #include -// Input data for llama_decode -// A llama_batch object can contain input about one or many sequences +// Input data for llama_decode / llama_encode +// A llama_batch_ext object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens // // - token : the token ids of the input (used when embd is NULL) @@ -18,7 +18,7 @@ // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output // (if set to NULL, only the logits for last token will be returned) // -struct llama_batch { +struct llama_batch_ext { int32_t n_tokens; int32_t max_tokens; bool is_view; @@ -73,7 +73,7 @@ struct llama_sbatch { std::vector out_ids; std::vector seq; - const llama_batch * batch = nullptr; + const llama_batch_ext * batch = nullptr; // buffers for the ubatch std::vector ubatch_token; @@ -96,12 +96,12 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed struct llama_batch_allocr { - struct llama_batch batch; + struct llama_batch_ext * batch; std::array seq_id_0 = { 0 }; // default sequence id std::vector pos; @@ -110,5 +110,7 @@ struct llama_batch_allocr { std::vector logits; // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); + llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0); + + ~llama_batch_allocr(); }; diff --git a/src/llama.cpp b/src/llama.cpp index 978ce0dd76c09..a3dc7824aebcb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -8445,7 +8445,7 @@ static enum ggml_status llama_graph_compute( static int llama_prepare_sbatch( llama_context & lctx, - const llama_batch & batch, + const llama_batch_ext & batch, uint32_t & n_outputs) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -8585,7 +8585,7 @@ static int llama_prepare_ubatch( // static int llama_decode_impl( llama_context & lctx, - llama_batch inp_batch) { + llama_batch_ext & inp_batch) { lctx.is_encoding = false; @@ -8594,10 +8594,6 @@ static int llama_decode_impl( return -1; } - // temporarily allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : lctx.kv_self.max_pos() + 1); - const llama_batch & batch = batch_allocr.batch; - const auto & model = lctx.model; const auto & vocab = model.vocab; const auto & hparams = model.hparams; @@ -8616,7 +8612,7 @@ static int llama_decode_impl( uint32_t n_outputs_prev = 0; { - const int ret = llama_prepare_sbatch(lctx, batch, n_outputs); + const int ret = llama_prepare_sbatch(lctx, inp_batch, n_outputs); if (ret != 0) { return ret; } @@ -8625,7 +8621,7 @@ static int llama_decode_impl( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; { - const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens); + const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, inp_batch.n_tokens); if (ret != 0) { return ret; } @@ -8832,7 +8828,7 @@ static int llama_decode_impl( // static int llama_encode_impl( llama_context & lctx, - llama_batch inp_batch) { + llama_batch_ext & inp_batch) { lctx.is_encoding = true; @@ -8841,22 +8837,18 @@ static int llama_encode_impl( return -1; } - // temporary allocate memory for the input batch if needed - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1); - - const llama_batch & batch = batch_allocr.batch; - const uint32_t n_tokens = batch.n_tokens; + const uint32_t n_tokens = inp_batch.n_tokens; const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + GGML_ASSERT((!inp_batch.token && inp_batch.embd) || (inp_batch.token && !inp_batch.embd)); // NOLINT - if (batch.token) { + if (inp_batch.token) { for (uint32_t i = 0; i < n_tokens; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + if (inp_batch.token[i] < 0 || (uint32_t) inp_batch.token[i] >= model.vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, inp_batch.token[i]); return -1; } } @@ -8873,7 +8865,7 @@ static int llama_encode_impl( const int64_t n_embd = hparams.n_embd; - lctx.sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + lctx.sbatch.from_batch(inp_batch, n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = lctx.sbatch.split_simple(n_tokens); @@ -9976,9 +9968,32 @@ bool llama_kv_cache_can_shift(struct llama_context * ctx) { /// + +// DEPRECATED int32_t llama_encode( struct llama_context * ctx, - struct llama_batch * batch) { + struct llama_batch batch) { + // temporarily allocate memory for the input batch if needed + // also convert llama_batch to llama_batch_ext + llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); + llama_batch_ext * batch_ext = batch_allocr.batch; + return llama_text_encode(ctx, batch_ext); +} + +// DEPRECATED +int32_t llama_decode( + struct llama_context * ctx, + struct llama_batch batch) { + // temporarily allocate memory for the input batch if needed + // also convert llama_batch to llama_batch_ext + llama_batch_allocr batch_allocr(batch, batch.pos ? 
-1 : ctx->kv_self.max_pos() + 1); + llama_batch_ext * batch_ext = batch_allocr.batch; + return llama_text_decode(ctx, batch_ext); +} + +int32_t llama_text_encode( + struct llama_context * ctx, + struct llama_batch_ext * batch) { const int ret = llama_encode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); @@ -9987,9 +10002,9 @@ int32_t llama_encode( return ret; } -int32_t llama_decode( +int32_t llama_text_decode( struct llama_context * ctx, - struct llama_batch * batch) { + struct llama_batch_ext * batch) { const int ret = llama_decode_impl(*ctx, *batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); From 85ef80cbe95a45ad0d4c01a7ba83d58513a9e47b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 16 Feb 2025 00:06:48 +0100 Subject: [PATCH 04/52] server : use llama_batch_ext --- examples/server/server.cpp | 73 ++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 41f8dc505795d..caf412341d0ae 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1215,7 +1215,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - llama_batch_ptr batch_spec; + llama_batch_ext_ptr batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1787,7 +1787,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch_ptr batch; + llama_batch_ext_ptr batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1940,7 +1940,7 @@ struct server_context { slot.n_predict = params_base.n_predict; if (model_dft) { - slot.batch_spec.reset(llama_batch_init(params_base.speculative.n_max + 1, 1)); + slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1)); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1976,7 +1976,7 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch.reset(llama_batch_init(std::max(n_batch, params_base.n_parallel), 1)); + batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1)); } metrics.init(); @@ -2094,7 +2094,7 @@ struct server_context { } if (slot.ctx_dft) { - slot.batch_spec.reset(llama_batch_init(slot.params.speculative.n_max + 1, 1)); + slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1)); } slot.state = SLOT_STATE_STARTED; @@ -2402,7 +2402,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, llama_batch_ptr & batch) { + void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2413,8 +2413,8 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); - for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { - llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); + for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { + llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i); if (!tok.logits || tok.seq_id[0] != slot.id) { continue; } @@ -2446,14 +2446,14 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, llama_batch_ptr & batch) { + void send_rerank(const server_slot 
& slot, llama_batch_ext_ptr & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; res->n_tokens = slot.n_prompt_tokens; - for (int i = 0; i < llama_batch_get_n_tokens(batch.get()); ++i) { - llama_batch_token_info tok = llama_batch_get_token_info(batch.get(), i); + for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { + llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i); if (!tok.logits || tok.seq_id[0] != slot.id) { continue; } @@ -2855,7 +2855,7 @@ struct server_context { } // start populating the batch for this iteration - common_batch_clear(batch.get()); + llama_batch_ext_clear(batch.get()); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2877,9 +2877,10 @@ struct server_context { continue; } - slot.i_batch = llama_batch_get_n_tokens(batch.get()); + slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()); - common_batch_add(batch.get(), slot.sampled, slot.n_past, { slot.id }, true); + std::array seq_id = { slot.id }; + llama_batch_ext_add_text_token(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true); slot.n_past += 1; @@ -2896,7 +2897,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || llama_batch_get_n_tokens(batch.get()) == 0) { + if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3062,7 +3063,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (llama_batch_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3082,11 +3083,12 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && llama_batch_get_n_tokens(batch.get()) < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch.get(), prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); + std::array seq_id = { slot.id }; + llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), true); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3096,13 +3098,13 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - 
GGML_ASSERT(llama_batch_get_n_tokens(batch.get()) > 0); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0); common_sampler_reset(slot.smpl); @@ -3112,27 +3114,27 @@ struct server_context { } // extract the logits only for the last token - llama_batch_set_logits_last(batch.get()); + llama_batch_ext_set_logits_last(batch.get()); slot.n_decoded = 0; - slot.i_batch = llama_batch_get_n_tokens(batch.get()) - 1; + slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_get_n_tokens(batch.get())); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get())); } } - if (llama_batch_get_n_tokens(batch.get()) >= n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) { break; } } } - if (llama_batch_get_n_tokens(batch.get()) == 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_get_n_tokens(batch.get())); + SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get())); if (slot_batched) { // make sure we're in the right embedding mode @@ -3142,12 +3144,12 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < llama_batch_get_n_tokens(batch.get()); i += n_batch) { - const int32_t n_tokens = std::min(n_batch, llama_batch_get_n_tokens(batch.get()) - i); + for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i); - llama_batch_ptr batch_view(llama_batch_get_view(batch.get(), i, n_tokens)); + llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); - const int ret = llama_decode(ctx, batch_view.get()); + const int ret = llama_text_decode(ctx, batch_view.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3282,16 +3284,17 @@ struct server_context { } // construct the speculation batch - common_batch_clear(slot.batch_spec.get()); - common_batch_add (slot.batch_spec.get(), id, slot.n_past, { slot.id }, true); + llama_batch_ext_clear(slot.batch_spec.get()); + std::array seq_id = { slot.id }; + llama_batch_ext_add_text_token(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true); for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, { slot.id }, true); + llama_batch_ext_add_text_token(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_get_n_tokens(slot.batch_spec.get())); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); - llama_decode(ctx, slot.batch_spec.get()); + llama_text_decode(ctx, slot.batch_spec.get()); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); From aed4a8e980d76f246aaa83ba4f79e4f000c47f53 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 16 Feb 2025 11:36:50 +0100 Subject: [PATCH 05/52] fix server --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index caf412341d0ae..029bd97778beb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3088,7 +3088,7 @@ 
struct server_context { const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; std::array seq_id = { slot.id }; - llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), true); + llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); From 4bf7ca3943dfa6f34f3ab63deb58cfdec59d2fa6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 24 Feb 2025 17:01:20 +0100 Subject: [PATCH 06/52] llama_decode_ext --- examples/server/server.cpp | 4 ++-- include/llama.h | 4 ++-- src/llama.cpp | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 029bd97778beb..89d79f73e8461 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3149,7 +3149,7 @@ struct server_context { llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); - const int ret = llama_text_decode(ctx, batch_view.get()); + const int ret = llama_decode_ext(ctx, batch_view.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3294,7 +3294,7 @@ struct server_context { SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); - llama_text_decode(ctx, slot.batch_spec.get()); + llama_decode_ext(ctx, slot.batch_spec.get()); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); diff --git a/include/llama.h b/include/llama.h index 32b4cdbe1dd05..c0a3533de904d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -944,7 +944,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); - LLAMA_API int32_t llama_text_encode( + LLAMA_API int32_t llama_encode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); @@ -955,7 +955,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); - LLAMA_API int32_t llama_text_decode( + LLAMA_API int32_t llama_decode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); diff --git a/src/llama.cpp b/src/llama.cpp index a3dc7824aebcb..fb0e88c5b63fe 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9977,7 +9977,7 @@ int32_t llama_encode( // also convert llama_batch to llama_batch_ext llama_batch_allocr batch_allocr(batch, batch.pos ? -1 : ctx->kv_self.max_pos() + 1); llama_batch_ext * batch_ext = batch_allocr.batch; - return llama_text_encode(ctx, batch_ext); + return llama_encode_ext(ctx, batch_ext); } // DEPRECATED @@ -9988,10 +9988,10 @@ int32_t llama_decode( // also convert llama_batch to llama_batch_ext llama_batch_allocr batch_allocr(batch, batch.pos ? 
-1 : ctx->kv_self.max_pos() + 1); llama_batch_ext * batch_ext = batch_allocr.batch; - return llama_text_decode(ctx, batch_ext); + return llama_decode_ext(ctx, batch_ext); } -int32_t llama_text_encode( +int32_t llama_encode_ext( struct llama_context * ctx, struct llama_batch_ext * batch) { const int ret = llama_encode_impl(*ctx, *batch); @@ -10002,7 +10002,7 @@ int32_t llama_text_encode( return ret; } -int32_t llama_text_decode( +int32_t llama_decode_ext( struct llama_context * ctx, struct llama_batch_ext * batch) { const int ret = llama_decode_impl(*ctx, *batch); From f0ffd811305436c3eeb219f79c3fe7c7ea531c25 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 1 Mar 2025 12:12:52 +0100 Subject: [PATCH 07/52] adapt common --- common/common.cpp | 6 ++++-- common/speculative.cpp | 28 ++++++++++++++-------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index d45061f59ab2b..84d23163e50b2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1047,7 +1047,8 @@ struct common_init_result common_init_from_params(common_params & params) { } if (llama_model_has_encoder(model)) { - llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size())); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0)); + llama_encode_ext(lctx, batch.get()); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { decoder_start_token_id = bos; @@ -1056,7 +1057,8 @@ struct common_init_result common_init_from_params(common_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch))); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_encode_ext(lctx, batch.get()); } llama_kv_cache_clear(lctx); llama_synchronize(lctx); diff --git a/common/speculative.cpp b/common/speculative.cpp index b1fff27a55f91..585850aaea3a4 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -13,7 +13,7 @@ struct common_speculative { struct llama_context * ctx; struct common_sampler * smpl; - llama_batch batch; + llama_batch_ext_ptr batch; llama_tokens prompt; }; @@ -22,7 +22,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)), /* .prompt = */ {}, }; @@ -68,8 +68,6 @@ void common_speculative_free(struct common_speculative * spec) { common_sampler_free(spec->smpl); - llama_batch_free(spec->batch); - delete spec; } @@ -150,6 +148,8 @@ llama_tokens common_speculative_gen_draft( const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + const llama_seq_id seq_id = 0; + // reuse as much as possible from the old draft context // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt for (int i = 0; i < (int) prompt.size(); ++i) { @@ -205,40 +205,40 @@ llama_tokens common_speculative_gen_draft( } // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = 
%6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false); prompt.push_back(prompt_tgt[i]); } // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); } const llama_pos n_past = prompt.size(); LOG_DBG("%s: n_past = %d\n", __func__, n_past); - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); + llama_batch_ext_clear(batch.get()); + llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true); prompt.push_back(id_last); //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); common_sampler_reset(smpl); // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); common_sampler_sample(smpl, ctx, 0, true); @@ -265,10 +265,10 @@ llama_tokens common_speculative_gen_draft( break; } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); prompt.push_back(id); } From 40989f4116a4e47f3af384902daac01537b62293 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 1 Mar 2025 14:00:05 +0100 Subject: [PATCH 08/52] correct llama_decode_ext --- common/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 84d23163e50b2..072a133b1ec42 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1058,7 +1058,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (llama_model_has_decoder(model)) { llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_encode_ext(lctx, batch.get()); + llama_decode_ext(lctx, batch.get()); } llama_kv_cache_clear(lctx); llama_synchronize(lctx); From 1170135dfbec0c6cfdfd7374b335f90480dea597 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 1 Mar 2025 14:00:14 +0100 Subject: [PATCH 09/52] llama_batch_ext_add_text --- common/speculative.cpp | 6 +++--- examples/server/server.cpp | 8 ++++---- include/llama.h | 2 +- src/llama-batch.cpp | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 585850aaea3a4..62ec5bfd87d65 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -209,7 +209,7 @@ llama_tokens common_speculative_gen_draft( for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - llama_batch_ext_add_text_token(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false); + llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false); prompt.push_back(prompt_tgt[i]); } @@ -226,7 +226,7 @@ llama_tokens common_speculative_gen_draft( LOG_DBG("%s: n_past = %d\n", __func__, n_past); llama_batch_ext_clear(batch.get()); - 
llama_batch_ext_add_text_token(batch.get(), id_last, n_past, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true); prompt.push_back(id_last); @@ -265,7 +265,7 @@ llama_tokens common_speculative_gen_draft( break; } - llama_batch_ext_add_text_token(batch.get(), id, n_past + i + 1, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true); // evaluate the drafted tokens on the draft model llama_decode_ext(ctx, batch.get()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2b06914ea336e..b745dd044db11 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2849,7 +2849,7 @@ struct server_context { slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()); std::array seq_id = { slot.id }; - llama_batch_ext_add_text_token(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true); + llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true); slot.n_past += 1; @@ -3057,7 +3057,7 @@ struct server_context { const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; std::array seq_id = { slot.id }; - llama_batch_ext_add_text_token(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd); + llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3255,10 +3255,10 @@ struct server_context { // construct the speculation batch llama_batch_ext_clear(slot.batch_spec.get()); std::array seq_id = { slot.id }; - llama_batch_ext_add_text_token(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true); + llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true); for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_ext_add_text_token(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true); + llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true); } SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); diff --git a/include/llama.h b/include/llama.h index 86aa40d8cb027..dab1aea2b9b3b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -905,7 +905,7 @@ extern "C" { // 0 : success // -1 : not enough space in the batch // -2 : embd is already set, cannot add text tokens - LLAMA_API int32_t llama_batch_ext_add_text_token( + LLAMA_API int32_t llama_batch_ext_add_text( struct llama_batch_ext * batch, llama_token token, llama_pos pos, diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 36a3d00be1412..b63d4ec7ffc09 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -344,7 +344,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text( int32_t seq_id) { llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); for (int32_t i = 0; i < n_tokens; i++) { - llama_batch_ext_add_text_token(batch, tokens[i], pos0 + i, &seq_id, 1, false); + llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false); } return batch; } @@ -404,7 +404,7 @@ int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { return batch->n_tokens; } -int32_t llama_batch_ext_add_text_token( +int32_t llama_batch_ext_add_text( struct llama_batch_ext * batch, 
llama_token token, llama_pos pos, From 1d6ba97789ed71a16b2e0b1d6f89dfda2499ac44 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 1 Mar 2025 16:21:16 +0100 Subject: [PATCH 10/52] remove token_info API --- examples/server/server.cpp | 126 ++++++++++++++++++++++++------------- include/llama.h | 12 ---- src/llama-batch.cpp | 13 ---- 3 files changed, 82 insertions(+), 69 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b745dd044db11..057184764104d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1205,6 +1205,47 @@ struct server_task_result_apply_lora : server_task_result { } }; +struct server_batch { + llama_batch_ext_ptr batch; + struct batch_token { + llama_token token; + llama_seq_id seq_id; + bool logits; + }; + std::vector tokens; + server_batch() = default; + server_batch(int32_t n_tokens, int32_t n_seq_max) { + batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); + tokens.reserve(n_tokens); + } + void clear() { + llama_batch_ext_clear(batch.get()); + tokens.clear(); + } + void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { + llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); + tokens.push_back({token, seq_id, logits}); + } + void set_logits_last() { + if (!tokens.empty()) { + llama_batch_ext_set_logits_last(batch.get()); + tokens.back().logits = true; + } + } + int32_t get_n_tokens() const { + return (int32_t)tokens.size(); + } + server_batch get_view(int32_t offset, int32_t n_tokens) { + server_batch view; + view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); + view.tokens.reserve(n_tokens); + for (int32_t i = 0; i < n_tokens; i++) { + view.tokens.push_back(tokens[offset + i]); + } + return view; + } +}; + struct server_slot { int id; int id_task = -1; @@ -1212,7 +1253,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - llama_batch_ext_ptr batch_spec; + server_batch batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1784,7 +1825,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch_ext_ptr batch; + server_batch batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1909,7 +1950,7 @@ struct server_context { slot.n_predict = params_base.n_predict; if (model_dft) { - slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1)); + slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1945,7 +1986,7 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1)); + batch = server_batch(std::max(n_batch, params_base.n_parallel), 1); } metrics.init(); @@ -2063,7 +2104,7 @@ struct server_context { } if (slot.ctx_dft) { - slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1)); + slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1); } slot.state = SLOT_STATE_STARTED; @@ -2371,7 +2412,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) { + void send_embedding(const server_slot & slot, server_batch & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; 
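
For reference, the batch lifecycle that the server_batch wrapper above (and the rest of this series) builds on is the raw llama_batch_ext API. A minimal sketch, assuming a loaded llama_context * ctx and an already tokenized prompt; decode_prompt and prompt_tokens are placeholder names and error handling is trimmed:

    #include "llama.h"
    #include <vector>

    // build a single-sequence batch, request logits only for the last token,
    // decode it, then release the heap-allocated batch
    static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt_tokens) {
        const llama_seq_id seq_id = 0;

        llama_batch_ext * batch = llama_batch_ext_init((int32_t) prompt_tokens.size(), /*n_seq_max =*/ 1);

        for (size_t i = 0; i < prompt_tokens.size(); ++i) {
            llama_batch_ext_add_text(batch, prompt_tokens[i], (llama_pos) i, &seq_id, 1, false);
        }
        llama_batch_ext_set_logits_last(batch);

        const int ret = llama_decode_ext(ctx, batch);

        llama_batch_ext_free(batch);
        return ret;
    }
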
@@ -2382,19 +2423,19 @@ struct server_context { std::vector embd_res(n_embd, 0.0f); - for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { - llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i); - if (!tok.logits || tok.seq_id[0] != slot.id) { + for (int i = 0; i < batch.get_n_tokens(); ++i) { + auto tok = batch.tokens[i]; + if (!tok.logits || tok.seq_id != slot.id) { continue; } - const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); + const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id); if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); res->embedding.push_back(std::vector(n_embd, 0.0f)); continue; @@ -2415,25 +2456,25 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) { + void send_rerank(const server_slot & slot, server_batch & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; res->n_tokens = slot.n_prompt_tokens; - for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { - llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i); - if (!tok.logits || tok.seq_id[0] != slot.id) { + for (int i = 0; i < batch.get_n_tokens(); ++i) { + auto tok = batch.tokens[i]; + if (!tok.logits || tok.seq_id != slot.id) { continue; } - const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]); + const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id); if (embd == NULL) { embd = llama_get_embeddings_ith(ctx, i); } if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]); + SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); res->score = -1e6; continue; @@ -2824,7 +2865,7 @@ struct server_context { } // start populating the batch for this iteration - llama_batch_ext_clear(batch.get()); + batch.clear(); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2846,10 +2887,9 @@ struct server_context { continue; } - slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()); + slot.i_batch = batch.get_n_tokens(); - std::array seq_id = { slot.id }; - llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true); + batch.add_text(slot.sampled, slot.n_past, slot.id, true); slot.n_past += 1; @@ -2866,7 +2906,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (params_base.cont_batching || batch.get_n_tokens() == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3032,7 +3072,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { + if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3052,12 +3092,11 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt 
tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - std::array seq_id = { slot.id }; - llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd); + batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3067,13 +3106,13 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0); + GGML_ASSERT(batch.get_n_tokens() > 0); common_sampler_reset(slot.smpl); @@ -3083,27 +3122,27 @@ struct server_context { } // extract the logits only for the last token - llama_batch_ext_set_logits_last(batch.get()); + batch.set_logits_last(); slot.n_decoded = 0; - slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; + slot.i_batch = batch.get_n_tokens() - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get())); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens()); } } - if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) { + if (batch.get_n_tokens() >= n_batch) { break; } } } - if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (batch.get_n_tokens() == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get())); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens()); if (slot_batched) { // make sure we're in the right embedding mode @@ -3113,12 +3152,12 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) { - const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i); + for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); - llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); + server_batch batch_view = batch.get_view(i, n_tokens); - const int ret = llama_decode_ext(ctx, batch_view.get()); + const int ret = llama_decode_ext(ctx, batch_view.batch.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3253,17 +3292,16 @@ struct server_context { } // construct the speculation batch - llama_batch_ext_clear(slot.batch_spec.get()); - std::array seq_id = { slot.id }; - llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true); + slot.batch_spec.clear(); + slot.batch_spec.add_text(id, slot.n_past, slot.id, true); 
for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true); + slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); - llama_decode_ext(ctx, slot.batch_spec.get()); + llama_decode_ext(ctx, slot.batch_spec.batch.get()); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); diff --git a/include/llama.h b/include/llama.h index dab1aea2b9b3b..d370e05234886 100644 --- a/include/llama.h +++ b/include/llama.h @@ -263,14 +263,6 @@ extern "C" { // It can contain text tokens and embeddings for one or many sequences struct llama_batch_ext; - struct llama_batch_ext_token_info { - llama_token token; - llama_pos pos; - int32_t n_seq_id; - llama_seq_id * seq_id; - int8_t logits; - }; - enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, @@ -896,10 +888,6 @@ extern "C" { // Get the number of tokens in the batch LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); - LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info( - struct llama_batch_ext * batch, - int32_t i); - // Add text tokens to the batch // Return values: // 0 : success diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index b63d4ec7ffc09..d8117c3f08bdf 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -480,19 +480,6 @@ struct llama_batch_ext * llama_batch_ext_get_view( return batch_view; } -struct llama_batch_ext_token_info llama_batch_ext_get_token_info( - struct llama_batch_ext * batch, - int32_t i) { - GGML_ASSERT(i >= 0 && i < batch->n_tokens); - return llama_batch_ext_token_info{ - /*token =*/ batch->token [i], - /*pos =*/ batch->pos [i], - /*n_seq_id =*/ batch->n_seq_id[i], - /*seq_id =*/ batch->seq_id [i], - /*logits =*/ batch->logits [i], - }; -} - void llama_batch_ext_free(struct llama_batch_ext * batch) { // do not free the members if it's a view if (!batch->is_view) { From 46596caf6df14ce1aa7a6d06fb7f5acbc3938014 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 1 Mar 2025 20:42:18 +0100 Subject: [PATCH 11/52] apply various in places --- common/common.h | 46 ++++++++++++++ examples/batched-bench/batched-bench.cpp | 40 ++++++------ examples/batched/batched.cpp | 32 +++++----- .../cvector-generator/cvector-generator.cpp | 3 +- examples/embedding/embedding.cpp | 29 +++++---- examples/eval-callback/eval-callback.cpp | 3 +- examples/gritlm/gritlm.cpp | 24 ++++---- examples/imatrix/imatrix.cpp | 13 ++-- examples/infill/infill.cpp | 3 +- examples/llama-bench/llama-bench.cpp | 6 +- examples/lookup/lookup.cpp | 19 +++--- examples/server/server.cpp | 61 +++---------------- 12 files changed, 144 insertions(+), 135 deletions(-) diff --git a/common/common.h b/common/common.h index c7e71bb290b8a..86cad86559bc6 100644 --- a/common/common.h +++ b/common/common.h @@ -565,6 +565,52 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions +// this is meant to be temporary +struct common_batch { + llama_batch_ext_ptr batch; + struct batch_token { + llama_token token; + llama_seq_id seq_id; + bool logits; + }; + std::vector tokens; + common_batch() = 
default; + common_batch(int32_t n_tokens, int32_t n_seq_max) { + batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); + tokens.reserve(n_tokens); + } + void clear() { + llama_batch_ext_clear(batch.get()); + tokens.clear(); + } + void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { + llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); + tokens.push_back({token, seq_id, logits}); + } + void set_logits_last() { + if (!tokens.empty()) { + llama_batch_ext_set_logits_last(batch.get()); + tokens.back().logits = true; + } + } + int32_t get_n_tokens() const { + return (int32_t)tokens.size(); + } + llama_batch_ext * get() { + return batch.get(); + } + common_batch get_view(int32_t offset, int32_t n_tokens) { + common_batch view; + view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); + view.tokens.reserve(n_tokens); + for (int32_t i = 0; i < n_tokens; i++) { + view.tokens.push_back(tokens[offset + i]); + } + return view; + } +}; + // // Token utils // diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 0659ab6f119a7..829bf7f949ba2 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -59,24 +59,17 @@ int main(int argc, char ** argv) { const int32_t n_kv_max = llama_n_ctx(ctx); - llama_batch batch = llama_batch_init(n_kv_max, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1); // decode in batches of ctx_params.n_batch tokens - auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); + auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) { + const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch); + for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i)); + + llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens)); + + const int ret = llama_decode_ext(ctx, batch_view.get()); if (ret != 0) { LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; @@ -91,7 +84,8 @@ int main(int argc, char ** argv) { // warm up { for (int i = 0; i < 16; ++i) { - common_batch_add(batch, 0, i, { 0 }, false); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -121,14 +115,14 @@ int main(int argc, char ** argv) { continue; } - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int i = 0; i < pp; ++i) { for (int j = 0; j < (is_pp_shared ? 
1 : pl); ++j) { - common_batch_add(batch, 0, i, { j }, false); + llama_batch_ext_add_text(batch, 0, i, &j, 1, false); } } - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_logits_last(batch); const auto t_pp_start = ggml_time_us(); @@ -150,10 +144,10 @@ int main(int argc, char ** argv) { const auto t_tg_start = ggml_time_us(); for (int i = 0; i < tg; ++i) { - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int j = 0; j < pl; ++j) { - common_batch_add(batch, 0, pp + i, { j }, true); + llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -191,7 +185,7 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_free(ctx); llama_model_free(model); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e4e83..858053a889e3a 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -102,7 +102,7 @@ int main(int argc, char ** argv) { // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); + llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -111,12 +111,12 @@ int main(int argc, char ** argv) { // evaluate the initial prompt for (size_t i = 0; i < tokens_list.size(); ++i) { - common_batch_add(batch, tokens_list[i], i, seq_ids, false); + llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false); } - GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size()); if (llama_model_has_encoder(model)) { - if (llama_encode(ctx, batch)) { + if (llama_encode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -126,14 +126,14 @@ int main(int argc, char ** argv) { decoder_start_token_id = llama_vocab_bos(vocab); } - common_batch_clear(batch); - common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + llama_batch_ext_clear(batch); + llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false); } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_logits_last(batch); - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -155,16 +155,16 @@ int main(int argc, char ** argv) { // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, batch.n_tokens - 1); + std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); - int n_cur = batch.n_tokens; + int n_cur = llama_batch_ext_get_n_tokens(batch); int n_decode = 0; const auto t_main_start = ggml_time_us(); while (n_cur <= n_predict) { // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -193,23 +193,23 @@ int main(int argc, char ** argv) { streams[i] += common_token_to_piece(ctx, new_token_id); - i_batch[i] = batch.n_tokens; + i_batch[i] = 
llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_cur, { i }, true); + llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false); n_decode += 1; } // all streams are finished - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -234,7 +234,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_sampler_free(smpl); llama_free(ctx); diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 413b71d34c52b..689e3e53900f2 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_cache_clear(ctx); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 38d22c90f82bb..c71200958d4ca 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -25,14 +25,14 @@ static std::vector split_lines(const std::string & s, const std::st return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { +static void batch_add_seq(common_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); + batch.add_text(tokens[i], i, seq_id, true); } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { +static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const struct llama_model * model = llama_get_model(ctx); @@ -40,21 +40,21 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu llama_kv_cache_clear(ctx); // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { // encoder-only model - if (llama_encode(ctx, batch) < 0) { + if (llama_encode_ext(ctx, batch.get()) < 0) { LOG_ERR("%s : failed to encode\n", __func__); } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model - if (llama_decode(ctx, batch) < 0) { + if (llama_decode_ext(ctx, batch.get()) < 0) { LOG_ERR("%s : failed to decode\n", __func__); } } - for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { + if (!batch.tokens[i].logits) { continue; } @@ -68,8 +68,8 @@ static 
void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu GGML_ASSERT(embd != NULL && "failed to get token embeddings"); } else { // try to get sequence embeddings - supported only when pooling_type is not NONE - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - embd_pos = batch.seq_id[i][0]; + embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); + embd_pos = batch.tokens[i].seq_id; GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); } @@ -170,7 +170,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + struct common_batch batch = common_batch(n_batch, 1); // count number of embeddings int n_embd_count = 0; @@ -197,12 +197,12 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { + if (batch.get_n_tokens() + n_toks > n_batch) { float * out = emb + e * n_embd; batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s; s = 0; - common_batch_clear(batch); + batch.clear(); } // add to batch @@ -318,7 +318,6 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); // clean up - llama_batch_free(batch); llama_backend_free(); return 0; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index fb188f5a9e132..7e600440d839d 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 72eb46257429e..aa87c3a27855c 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -13,10 +13,10 @@ static std::vector> encode(llama_context * ctx, const std::ve const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); - llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1); for (uint64_t i = 0; i < sentences.size(); i++) { - common_batch_clear(batch); + llama_batch_ext_clear(batch); const std::string input_string = instruction + sentences[i]; @@ -41,7 +41,8 @@ static std::vector> encode(llama_context * ctx, const std::ve // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { - common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst); } // clear previous kv_cache values (irrelevant for embeddings) @@ -50,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_set_causal_attn(ctx, false); // run model - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch); // get embedding dimensions uint64_t n_embd = llama_model_n_embd(model); @@ -89,7 +90,7 @@ static std::vector> encode(llama_context * ctx, const std::ve #endif } - 
llama_batch_free(batch); + llama_batch_ext_free(batch); return result; } @@ -106,25 +107,26 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); - llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); + llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1); std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; while (true) { - common_batch_clear(bat); + llama_batch_ext_clear(bat); { const int32_t n_inputs = inputs.size(); for (int32_t i = 0; i < n_inputs; i++) { - common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1); } } inputs.clear(); - llama_decode(ctx, bat); + llama_decode_ext(ctx, bat); - llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); + llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1); if (token == eos_token) { break; @@ -145,7 +147,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::printf("\n"); } - llama_batch_free(bat); + llama_batch_ext_free(bat); return result; } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 4edc0bfacf125..86f7ccbc3bbb0 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -500,7 +500,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_cache_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -514,14 +514,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return false; } @@ -534,7 +535,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { } } - llama_batch_free(batch); + llama_batch_ext_free(batch); const auto t_end = std::chrono::high_resolution_clock::now(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 489a208b66b34..738fd6e11ce9f 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,8 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index f518d02d38689..f270cce69c0c5 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int 
n_batch, int n_th for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0)); + llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0)); + llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index dbd0444ec8742..fee09adcd9103 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -91,8 +91,10 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); + llama_decode_ext(ctx, batch0.get()); + llama_decode_ext(ctx, batch1.get()); const auto t_enc_end = ggml_time_us(); @@ -108,7 +110,7 @@ int main(int argc, char ** argv){ std::vector draft; - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); @@ -194,8 +196,9 @@ int main(int argc, char ** argv){ // clean the cache of draft tokens that weren't accepted llama_kv_cache_seq_rm(ctx, 0, n_past, -1); - common_batch_clear(batch_tgt); - common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); + const llama_seq_id seq_id = 0; + llama_batch_ext_clear(batch_tgt); + llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true); // Draft already contains a single token sampled from the model: GGML_ASSERT(draft.size() == 1); @@ -205,13 +208,13 @@ int main(int argc, char ** argv){ common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); for (size_t i = 1; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); } t_draft_us += ggml_time_us() - t_start_draft_us; n_drafted += draft.size() - 1; - llama_decode(ctx, batch_tgt); + llama_decode_ext(ctx, batch_tgt); ++n_past; draft.erase(draft.begin()); @@ -243,7 +246,7 @@ int main(int argc, char ** argv){ common_sampler_free(smpl); - llama_batch_free(batch_tgt); + llama_batch_ext_free(batch_tgt); llama_backend_free(); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 057184764104d..22d2c6e92b0f0 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1205,47 +1205,6 @@ struct server_task_result_apply_lora : server_task_result { } }; -struct server_batch { - llama_batch_ext_ptr batch; - struct batch_token { - llama_token token; - llama_seq_id seq_id; - bool logits; - }; - std::vector tokens; - server_batch() = default; - server_batch(int32_t n_tokens, int32_t n_seq_max) { - batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); - 
tokens.reserve(n_tokens); - } - void clear() { - llama_batch_ext_clear(batch.get()); - tokens.clear(); - } - void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { - llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, seq_id, logits}); - } - void set_logits_last() { - if (!tokens.empty()) { - llama_batch_ext_set_logits_last(batch.get()); - tokens.back().logits = true; - } - } - int32_t get_n_tokens() const { - return (int32_t)tokens.size(); - } - server_batch get_view(int32_t offset, int32_t n_tokens) { - server_batch view; - view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); - view.tokens.reserve(n_tokens); - for (int32_t i = 0; i < n_tokens; i++) { - view.tokens.push_back(tokens[offset + i]); - } - return view; - } -}; - struct server_slot { int id; int id_task = -1; @@ -1253,7 +1212,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - server_batch batch_spec; + common_batch batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1825,7 +1784,7 @@ struct server_context { llama_context_params cparams_dft; - server_batch batch; + common_batch batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1950,7 +1909,7 @@ struct server_context { slot.n_predict = params_base.n_predict; if (model_dft) { - slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1); + slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1986,7 +1945,7 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = server_batch(std::max(n_batch, params_base.n_parallel), 1); + batch = common_batch(std::max(n_batch, params_base.n_parallel), 1); } metrics.init(); @@ -2104,7 +2063,7 @@ struct server_context { } if (slot.ctx_dft) { - slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1); + slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1); } slot.state = SLOT_STATE_STARTED; @@ -2412,7 +2371,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, server_batch & batch) { + void send_embedding(const server_slot & slot, common_batch & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2456,7 +2415,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, server_batch & batch) { + void send_rerank(const server_slot & slot, common_batch & batch) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -3155,9 +3114,9 @@ struct server_context { for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); - server_batch batch_view = batch.get_view(i, n_tokens); + common_batch batch_view = batch.get_view(i, n_tokens); - const int ret = llama_decode_ext(ctx, batch_view.batch.get()); + const int ret = llama_decode_ext(ctx, batch_view.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3301,7 +3260,7 @@ struct server_context { SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); - llama_decode_ext(ctx, slot.batch_spec.batch.get()); + llama_decode_ext(ctx, slot.batch_spec.get()); // the accepted 
tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); From 86973cb14a51bbd5268871c23fe7ab1ddfa75830 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 17:32:36 +0100 Subject: [PATCH 12/52] fix merge errors --- include/llama.h | 2 ++ src/llama-context.cpp | 63 +++++++++++++++++++++++++------------------ src/llama-context.h | 4 +++ 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/include/llama.h b/include/llama.h index ac04813938fcd..564ffe1aa961c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -994,6 +994,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_encode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); @@ -1005,6 +1006,7 @@ extern "C" { DEPRECATED(LLAMA_API int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch), "use llama_batch_ext API instead"); + LLAMA_API int32_t llama_decode_ext( struct llama_context * ctx, struct llama_batch_ext * batch); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 0a43a3af8e003..d89e1ac2cc265 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -4,6 +4,7 @@ #include "llama-io.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-batch.h" #include "llama-kv-cache.h" #include @@ -980,16 +981,26 @@ bool llama_context::apply_adapter_cvec( } int llama_context::encode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return encode(*batch_allocr.batch); +} + +int llama_context::decode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return decode(*batch_allocr.batch); +} + +int llama_context::encode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const int32_t n_tokens = batch.n_tokens; const auto & hparams = model.hparams; @@ -1132,17 +1143,13 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -int llama_context::decode(llama_batch & inp_batch) { +int llama_context::decode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const auto & vocab = model.vocab; const auto & hparams = model.hparams; @@ -2714,26 +2721,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla /// +// deprecated int32_t llama_encode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->encode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); - } - - return ret; + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->encode(inp_batch); } +// deprecated int32_t llama_decode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->decode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->decode(inp_batch); +} + +int32_t llama_encode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->encode(*inp_batch); +} - return ret; +int32_t llama_decode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->decode(*inp_batch); } // diff --git a/src/llama-context.h b/src/llama-context.h index 71d702e8baeeb..29bb230f1060b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -81,9 +81,13 @@ struct llama_context { int32_t il_start, int32_t il_end); + // deprecated int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); + int encode(llama_batch_ext & inp_batch); + int decode(llama_batch_ext & inp_batch); + // // state save/load // From 4aabf4e8f4b88e96c6c98a504b2c8cbe0d815e46 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 17:47:07 +0100 Subject: [PATCH 13/52] return output ID from llama_batch_ext_add/set --- common/common.h | 2 +- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched/batched.cpp | 2 +- include/llama.h | 26 ++++++++++++++---------- src/llama-batch.cpp | 24 ++++++++++++---------- 5 files changed, 31 insertions(+), 25 deletions(-) diff --git a/common/common.h b/common/common.h index c7dbcc202325d..afede57bbe24d 100644 --- a/common/common.h +++ b/common/common.h @@ -606,7 +606,7 @@ struct common_batch { } void set_logits_last() { if (!tokens.empty()) { - llama_batch_ext_set_logits_last(batch.get()); + llama_batch_ext_set_output_last(batch.get()); tokens.back().logits = true; } } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 1eb0ede77fcb7..8f7c2c94b8964 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -122,7 +122,7 @@ int main(int argc, char ** argv) { llama_batch_ext_add_text(batch, 0, i, &j, 1, false); } } - llama_batch_ext_set_logits_last(batch); + llama_batch_ext_set_output_last(batch); const auto t_pp_start = ggml_time_us(); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 858053a889e3a..1ed189859d4d0 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { } // llama_decode will output logits only for the last token of the prompt - llama_batch_ext_set_logits_last(batch); + llama_batch_ext_set_output_last(batch); if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/include/llama.h b/include/llama.h index 564ffe1aa961c..ee74d9a8c16c2 100644 --- a/include/llama.h +++ 
b/include/llama.h @@ -900,7 +900,7 @@ extern "C" { // DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens), "use llama_batch_ext API instead"); + int32_t n_tokens), "use llama_batch_ext_init_from_text instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -912,7 +912,7 @@ extern "C" { DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( int32_t n_tokens, int32_t embd, - int32_t n_seq_max), "use llama_batch_ext API instead"); + int32_t n_seq_max), "use llama_batch_ext_init instead"); // Frees a batch of tokens allocated with llama_batch_init() DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), @@ -950,28 +950,32 @@ extern "C" { // Add text tokens to the batch // Return values: - // 0 : success // -1 : not enough space in the batch // -2 : embd is already set, cannot add text tokens + // otherwise, returns the output ID LLAMA_API int32_t llama_batch_ext_add_text( struct llama_batch_ext * batch, llama_token token, llama_pos pos, const llama_seq_id * seq_ids, size_t n_seq_ids, - float logits); + bool output); - // Set logits for the token in the ith sequence - // If pos == -1, logits will be set for the all tokens - // Returns -1 if the token is not in the batch - LLAMA_API int32_t llama_batch_ext_set_logits( + // Set output (logits/embeddings) for the token in the ith sequence + // If pos == -1, output will be set for the all tokens + // Return values: + // -1 : the token is not in the batch + // otherwise, returns the output ID + LLAMA_API int32_t llama_batch_ext_set_output( struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id); - // Set logits for the last added token - // Returns -1 if there is no tokens in the batch - LLAMA_API int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch); + // Set output (logits/embeddings) for the last added token + // Return values: + // -1 : the batch is empty + // otherwise, returns the output ID + LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch); // Get a "view" from a number of tokens offset // Return returned batch must be freed with llama_batch_free() diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index d8117c3f08bdf..bae8b37b3fc1f 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -410,25 +410,26 @@ int32_t llama_batch_ext_add_text( llama_pos pos, const llama_seq_id * seq_ids, size_t n_seq_ids, - float logits) { + bool output) { if (batch->n_tokens + 1 > batch->max_tokens) { return -1; // llama_batch size exceeded } if (batch->embd) { return -2; // embd is already set, cannot add text tokens } - batch->token [batch->n_tokens] = token; - batch->pos [batch->n_tokens] = pos; - batch->n_seq_id[batch->n_tokens] = n_seq_ids; + const int32_t output_id = batch->n_tokens; + batch->token [output_id] = token; + batch->pos [output_id] = pos; + batch->n_seq_id[output_id] = n_seq_ids; for (size_t j = 0; j < n_seq_ids; j++) { batch->seq_id[batch->n_tokens][j] = seq_ids[j]; } - batch->logits [batch->n_tokens] = logits; + batch->logits [output_id] = output; batch->n_tokens++; - return 0; + return output_id; } -int32_t llama_batch_ext_set_logits( +int32_t llama_batch_ext_set_output( struct llama_batch_ext * batch, llama_pos pos, llama_seq_id seq_id) { @@ -439,7 +440,7 @@ int32_t llama_batch_ext_set_logits( // found the sequence if (pos == -1 || pos == batch->pos[i]) { batch->logits[i] = true; - return 0; + return i; } } 
} @@ -447,12 +448,13 @@ int32_t llama_batch_ext_set_logits( return -1; // not found } -int32_t llama_batch_ext_set_logits_last(struct llama_batch_ext * batch) { +int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) { if (batch->n_tokens == 0) { return -1; } - batch->logits[batch->n_tokens - 1] = true; - return 0; + const int32_t output_id = batch->n_tokens - 1; + batch->logits[output_id] = true; + return output_id; } void llama_batch_ext_clear(struct llama_batch_ext * batch) { From 47086fa82d24b8d39ba4e4ecdc09927c721055ad Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 22:36:27 +0100 Subject: [PATCH 14/52] apply to the rest --- common/common.cpp | 37 ------- common/common.h | 18 +++- examples/llava/gemma3-cli.cpp | 58 +++------- examples/llava/llava.cpp | 38 +------ examples/llava/qwen2vl-cli.cpp | 1 + examples/lookahead/lookahead.cpp | 21 ++-- examples/parallel/parallel.cpp | 50 ++++----- examples/passkey/passkey.cpp | 32 +++--- examples/perplexity/perplexity.cpp | 100 +++++++----------- examples/retrieval/retrieval.cpp | 61 +++++++---- examples/run/run.cpp | 10 +- examples/save-load-state/save-load-state.cpp | 39 ++++--- examples/simple-chat/simple-chat.cpp | 13 ++- examples/simple/simple.cpp | 14 ++- .../speculative-simple/speculative-simple.cpp | 12 ++- examples/speculative/speculative.cpp | 19 ++-- examples/tts/tts.cpp | 36 ++++--- include/llama.h | 8 +- 18 files changed, 243 insertions(+), 324 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8eb65053c1ea2..ec4bf699ab808 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -582,43 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector & values); std::string string_from(const struct llama_context * ctx, const std::vector & tokens); -std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); // // Filesystem utils @@ -587,10 +586,10 @@ struct common_batch { llama_batch_ext_ptr batch; struct batch_token { llama_token token; - llama_seq_id seq_id; bool logits; }; std::vector tokens; + int n_outputs = 0; common_batch() = default; common_batch(int32_t n_tokens, int32_t n_seq_max) { batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); @@ -602,7 +601,17 @@ struct common_batch { } void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, seq_id, logits}); + tokens.push_back({token, logits}); + if (logits) { + n_outputs++; + } + } + void add_text(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { + llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); + tokens.push_back({token, logits}); + if (logits) { + n_outputs++; + } } void set_logits_last() { if (!tokens.empty()) { @@ -622,6 +631,9 @@ struct common_batch { view.tokens.reserve(n_tokens); for (int32_t i = 0; i < n_tokens; i++) { view.tokens.push_back(tokens[offset + i]); + if (tokens[offset + i].logits) { + view.n_outputs++; + } } return view; } diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index c36bb2eda0c70..9aa71065249e1 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -5,6 +5,7 @@ #include "clip.h" #include "stb_image.h" #include "llama.h" +#include "llama-cpp.h" #include "ggml.h" #include "console.h" @@ -63,7 +64,7 @@ struct gemma3_context { llama_model * model; llama_context * lctx; const llama_vocab * vocab; - 
llama_batch batch; + llama_batch_ext_ptr batch; int n_threads = 1; llama_pos n_past = 0; @@ -73,7 +74,7 @@ struct gemma3_context { lctx = llama_init.context.get(); vocab = llama_model_get_vocab(model); n_threads = params.cpuparams.n_threads; - batch = llama_batch_init(params.n_batch, 0, 1); + batch.reset(llama_batch_ext_init(params.n_batch, 1)); init_clip_model(params); } @@ -87,50 +88,18 @@ struct gemma3_context { } }; -struct decode_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); - common_batch_clear(ctx.batch); + llama_batch_ext_clear(ctx.batch.get()); for (llama_token & t : tokens) { - common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false); } if (logits_last) { - ctx.batch.logits[ctx.batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(ctx.batch.get()); } // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); - if (llama_decode(ctx.lctx, ctx.batch)) { + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("Failed to decode text\n"); return 1; } @@ -179,8 +148,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { int64_t t1 = ggml_time_ms(); eval_text(ctx, ""); llama_set_causal_attn(ctx.lctx, false); - decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); - if (llama_decode(ctx.lctx, batch_img.batch)) { + llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0)); + if (llama_decode_ext(ctx.lctx, batch_img.get())) { LOG_ERR("failed to decode image\n"); return 1; } @@ -210,9 +179,10 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ fflush(stdout); // eval the token - common_batch_clear(ctx.batch); - common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); - if (llama_decode(ctx.lctx, ctx.batch)) { + llama_batch_ext_clear(ctx.batch.get()); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, true); + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("failed to decode token\n"); return 1; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 518aad3f1f70b..53ce30215508b 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -2,6 +2,7 @@ #include "llava.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co return true; } -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - 
std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); @@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0)); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 132a7da543c2a..d65e88f9d12d5 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -66,6 +66,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); + // TODO: move this to llama_batch_ext API llama_batch batch = { int32_t(n_eval), // n_tokens nullptr, // token diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 7df20aee17046..1c2c3ec46c903 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -115,7 +115,7 @@ int main(int argc, char ** argv) { // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams - llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); + llama_batch_ext * batch = llama_batch_ext_init(params.n_ctx, W + G + 1); // target model sampling context struct common_sampler * smpl = common_sampler_init(model, params.sampling); @@ -204,10 +204,10 @@ int main(int argc, char ** argv) { // V V V V V V // id { - common_batch_clear(batch); + llama_batch_ext_clear(batch); // current token - first token of the first level - common_batch_add(batch, id, n_past, seq_id_all, true); + llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true); // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { @@ -230,9 +230,10 @@ int main(int argc, char ** argv) { const llama_token t = ngrams_observed.tokens[idx + j]; ngrams_cur[g].tokens [j + 1] = t; - ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; + ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch); - common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + llama_seq_id seq_id = W + 1 + g; + llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, 
true); } } } @@ -244,18 +245,20 @@ int main(int argc, char ** argv) { seq_id_look[j] = i + j + 1; } - common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i, + seq_id_look.data(), seq_id_look.size(), false); } // fill the rest of the levels for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { - common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + llama_seq_id seq_id = i + 1; + llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2); } } } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); return 1; } @@ -475,7 +478,7 @@ int main(int argc, char ** argv) { llama_kv_cache_view_free(&kvc_view); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_backend_free(); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 588632f0432b2..1d5f59f7d2124 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -174,7 +174,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time - llama_batch batch = llama_batch_init(n_ctx, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -192,10 +192,11 @@ int main(int argc, char ** argv) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { - common_batch_add(batch, tokens_system[i], i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -216,7 +217,7 @@ int main(int argc, char ** argv) { common_kv_cache_dump_view_seqs(kvc_view, 40); } - common_batch_clear(batch); + llama_batch_ext_clear(batch); // decode any currently ongoing sequences for (auto & client : clients) { @@ -224,14 +225,15 @@ int main(int argc, char ** argv) { continue; } - client.i_batch = batch.n_tokens; + client.i_batch = llama_batch_ext_get_n_tokens(batch); - common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + llama_seq_id seq_id = client.id + 1; + llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true); client.n_decoded += 1; } - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { llama_kv_self_seq_rm(ctx, i, -1, -1); @@ -243,7 +245,7 @@ int main(int argc, char ** argv) { } // insert new sequences for decoding - if (cont_batching || batch.n_tokens == 0) { + if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -262,17 +264,18 @@ int main(int argc, char ** argv) { tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + 
llama_seq_id seq_id = client.id + 1; + llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false); } // extract the logits only for the last token - if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; + if (llama_batch_ext_get_n_tokens(batch) > 0) { + llama_batch_ext_set_output_last(batch); } client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = batch.n_tokens - 1; + client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1; LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); @@ -286,14 +289,15 @@ int main(int argc, char ** argv) { } } - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch); + for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -301,19 +305,11 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); + const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i)); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); + llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens); + const int ret = llama_decode_ext(ctx, batch_view); + llama_batch_ext_free(batch_view); if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size @@ -417,7 +413,7 @@ int main(int argc, char ** argv) { // TODO: print sampling/grammar timings for all clients llama_perf_context_print(ctx); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_backend_free(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index ea3a6c1fca3ee..88e6ccdde6424 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -122,7 +123,7 @@ int main(int argc, char ** argv) { LOG_INF("prompt tokens: %d\n", n_tokens_all); //LOG_INF("prompt: %s\n", params.prompt.c_str()); - llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1)); int n_past = 0; @@ -140,17 +141,18 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_INF("%s: llama_decode() failed\n", __func__); return 1; } @@ -174,17 +176,18 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; - common_batch_clear(batch); + 
llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -223,7 +226,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); // is it an end of generation? if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { @@ -237,16 +240,17 @@ int main(int argc, char ** argv) { n_decode += 1; // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch.get()); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past++, { 0 }, true); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true); } n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -266,8 +270,6 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); - llama_batch_free(batch); - llama_free(ctx); llama_model_free(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8c413f7d66e6d..d24fddbf450e1 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + common_batch batch(n_batch, 1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { //LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return {tokens, -1, logit_history, prob_history}; } @@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { @@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(params.n_ctx == n_seq * n_ctx); - llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); + common_batch batch(std::min(n_batch, n_ctx*n_seq), 1); std::vector logits; if (num_batches > 1) { @@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & int 
n_outputs = 0; - batch.n_tokens = 0; + batch.clear(); for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -569,21 +566,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & for (int k = 0; k < batch_size; ++k) { const int idx = seq*n_ctx + k; - batch.token [idx] = tokens[seq_start + k]; - batch.pos [idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; + const llama_pos pos = j*n_batch + k; + bool output = pos >= first; + batch.add_text(tokens[seq_start + k], pos, seq, output); - n_outputs += batch.logits[idx] != 0; + n_outputs += output ? 1 : 0; } - batch.n_tokens += batch_size; // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_INF("%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -653,36 +647,23 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); } - llama_batch_free(batch); - return {tokens, ppl, logit_history, prob_history}; } -static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { +static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { int prev_outputs = 0; - for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { - const int n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) { + const int n_tokens = std::min(n_batch, batch.get_n_tokens() - i); + + common_batch batch_view = batch.get_view(i, n_tokens); - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode_ext(ctx, batch_view.get()); if (ret != 0) { LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; } - int n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; - } + int n_outputs = batch_view.n_outputs; memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); @@ -863,7 +844,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 4); + common_batch batch(n_ctx, 4); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -879,7 +860,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -895,9 +876,9 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + 
batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -905,7 +886,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -992,8 +973,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { i0 = i1 - 1; } - llama_batch_free(batch); - LOG("\n"); } @@ -1147,7 +1126,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 2); + common_batch batch(n_ctx, 2); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -1166,7 +1145,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) size_t i1 = i0; size_t i_logits = 0; - common_batch_clear(batch); + batch.clear(); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1176,15 +1155,15 @@ static void winogrande_score(llama_context * ctx, const common_params & params) } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); n_logits += 1; } } @@ -1501,7 +1480,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + common_batch batch(n_ctx, max_seq); std::vector tok_logits(n_vocab); std::vector batch_logits(size_t(n_ctx)*n_vocab); @@ -1521,7 +1500,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1544,9 +1523,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par for (size_t i = 0; i < cur_task.common_prefix; ++i) { //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); + batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); } - 
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { @@ -1554,7 +1533,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -1653,8 +1632,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par i0 = i1 - 1; } - llama_batch_free(batch); - if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; float p = 1.f*n_correct/n_done; @@ -1767,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + common_batch batch(n_batch, 1); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -1781,14 +1758,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return; } @@ -1801,8 +1777,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 0efe20d4b3f5d..d43270e856554 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -74,40 +74,56 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { +static void batch_add_seq(common_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); + batch.add_text(tokens[i], i, seq_id, true); } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + const struct llama_model * model = llama_get_model(ctx); + // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); + 
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { + // encoder-only model + if (llama_encode_ext(ctx, batch.get()) < 0) { + LOG_ERR("%s : failed to encode\n", __func__); + } + } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + // decoder-only model + if (llama_decode_ext(ctx, batch.get()) < 0) { + LOG_ERR("%s : failed to decode\n", __func__); + } } - for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { + for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { + if (!batch.tokens[i].logits) { continue; } - // try to get sequence embeddings - supported only when pooling_type is not NONE - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { + const float * embd = nullptr; + int embd_pos = 0; + + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + // try to get token embeddings embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } + embd_pos = i; + GGML_ASSERT(embd != NULL && "failed to get token embeddings"); + } else { + // try to get sequence embeddings - supported only when pooling_type is not NONE + embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); + embd_pos = batch.tokens[i].seq_id; + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); } - float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd, 2); + float * out = output + embd_pos * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -214,7 +230,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_chunks = chunks.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + struct common_batch batch = common_batch(n_batch, 1); // allocate output const int n_embd = llama_model_n_embd(model); @@ -231,10 +247,10 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - common_batch_clear(batch); + batch.clear(); p += s; s = 0; } @@ -255,7 +271,7 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } - struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + struct common_batch query_batch = common_batch(n_batch, 1); // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; @@ -269,7 +285,7 @@ int main(int argc, char ** argv) { std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); - common_batch_clear(query_batch); + query_batch.clear(); // compute cosine similarities { @@ -299,6 +315,5 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); // clean up - llama_batch_free(query_batch); llama_backend_free(); } diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 437f2533e5777..02cafa9da0a92 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -905,10 +905,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt } // Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { +static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { const int n_ctx = llama_n_ctx(ctx.get()); const 
int n_ctx_used = llama_kv_self_used_cells(ctx.get()); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); return 1; @@ -946,11 +946,11 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { + if (llama_decode_ext(llama_data.context.get(), batch.get())) { printe("failed to decode\n"); return 1; } @@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0)); } printf(LOG_COL_DEFAULT); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 760ebbbf08788..d1cf599b1665b 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,15 +48,11 @@ int main(int argc, char ** argv) { auto tokens = common_tokenize(ctx, params.prompt, true); // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); - } - batch.logits[batch.n_tokens - 1] = true; // generate next token + llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0); // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; + llama_decode_ext(ctx, batch); + n_past += llama_batch_ext_get_n_tokens(batch); // save state (rng, logits, embedding and kv_cache) to file { @@ -83,12 +79,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -135,12 +132,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 1; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx2, batch)) { + if (llama_decode_ext(ctx2, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -216,12 +214,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {1}, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 1; + llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); - if (llama_decode(ctx3, batch)) { 
+ if (llama_decode_ext(ctx3, batch)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return 1; } n_past += 1; @@ -233,7 +232,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); - llama_batch_free(batch); + llama_batch_ext_free(batch); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 84f4159737260..cee00ea82421b 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -108,19 +108,20 @@ int main(int argc, char ** argv) { } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch_ext_set_output_last(batch); llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { GGML_ABORT("failed to decode\n"); } @@ -144,9 +145,13 @@ int main(int argc, char ** argv) { response += piece; // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); } + llama_batch_ext_free(batch); + return response; }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 10e79a0a69eeb..7b3ba8d815e45 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -143,7 +143,8 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch_ext_set_output_last(batch); // main loop @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { int n_decode = 0; llama_token new_token_id; - for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { + for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } - n_pos += batch.n_tokens; + n_pos += llama_batch_ext_get_n_tokens(batch); // sample the next token { @@ -180,7 +181,9 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); n_decode += 1; } @@ -198,6 +201,7 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); fprintf(stderr, "\n"); + llama_batch_ext_free(batch); llama_sampler_free(smpl); llama_free(ctx); llama_model_free(model); diff --git a/examples/speculative-simple/speculative-simple.cpp 
b/examples/speculative-simple/speculative-simple.cpp index a5d2bc9d09de7..e61e863ce02eb 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -132,7 +132,7 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(ctx_dft); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1); const auto t_enc_end = ggml_time_us(); @@ -151,8 +151,9 @@ int main(int argc, char ** argv) { //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); // always have a token to evaluate from before - id_last - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + llama_batch_ext_clear(batch_tgt); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { @@ -162,12 +163,12 @@ int main(int argc, char ** argv) { } for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt); } // sample from the full target batch and return the accepted tokens based on the target sampler @@ -253,6 +254,7 @@ int main(int argc, char ** argv) { common_sampler_free(smpl); common_speculative_free(spec); + llama_batch_ext_free(batch_tgt); llama_backend_free(); LOG("\n\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index bfddc67e034fb..1f55db7b65f53 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { } common_init(); - +#ifdef 0 if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -199,8 +199,8 @@ int main(int argc, char ** argv) { drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } - llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); + llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1); + llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft); const auto t_dec_start = ggml_time_us(); @@ -441,12 +441,13 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - common_batch_clear(batch_dft); - common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); + llama_batch_ext_clear(batch_dft); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft); ++n_past_dft; } @@ -471,8 +472,9 @@ int main(int argc, char ** argv) { drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + llama_batch_ext_clear(batch_tgt); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, 
true); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { @@ -640,5 +642,6 @@ int main(int argc, char ** argv) { LOG("\n\n"); +#endif return 0; } diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index c658f3182f4c2..32f8c43a8d314 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -817,7 +817,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); + llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -826,14 +826,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // evaluate the initial prompt for (size_t i = 0; i < prompt_inp.size(); ++i) { - common_batch_add(batch, prompt_inp[i], i, seq_ids, false); + llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false); } - GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); - if (llama_decode(ctx_ttc, batch) != 0) { + if (llama_decode_ext(ctx_ttc, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -852,16 +852,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, batch.n_tokens - 1); + std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); - int n_past = batch.n_tokens; + int n_past = llama_batch_ext_get_n_tokens(batch); int n_decode = 0; bool next_token_uses_guide_token = true; while (n_decode <= n_predict) { // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -917,14 +917,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 //LOG_CNT("%d", i); } - i_batch[i] = batch.n_tokens; + i_batch[i] = llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past, { i }, true); + llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, false); } // all streams are finished - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } @@ -932,13 +932,13 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 n_past += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx_ttc, batch)) { + if (llama_decode_ext(ctx_ttc, batch)) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - llama_batch_free(batch); + llama_batch_ext_free(batch); LOG("\n"); LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); @@ -1007,14 +1007,15 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const int n_codes = codes.size(); - llama_batch batch = llama_batch_init(n_codes, 0, 1); + 
llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1); for (size_t i = 0; i < codes.size(); ++i) { - common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits? } - GGML_ASSERT(batch.n_tokens == n_codes); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes); - if (llama_decode(ctx_cts, batch) != 0) { + if (llama_decode_ext(ctx_cts, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -1076,6 +1077,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 LOG_INF("%s: audio written to file '%s'\n", __func__, fname.c_str()); + llama_batch_ext_free(batch); llama_backend_free(); return 0; diff --git a/include/llama.h b/include/llama.h index ee74d9a8c16c2..4521b3a415638 100644 --- a/include/llama.h +++ b/include/llama.h @@ -995,9 +995,9 @@ extern "C" { // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call - DEPRECATED(LLAMA_API int32_t llama_encode( + LLAMA_API int32_t llama_encode( struct llama_context * ctx, - struct llama_batch batch), "use llama_batch_ext API instead"); + struct llama_batch batch); LLAMA_API int32_t llama_encode_ext( struct llama_context * ctx, @@ -1007,9 +1007,9 @@ extern "C" { // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error. the KV cache state is restored to the state before this call - DEPRECATED(LLAMA_API int32_t llama_decode( + LLAMA_API int32_t llama_decode( struct llama_context * ctx, - struct llama_batch batch), "use llama_batch_ext API instead"); + struct llama_batch batch); LLAMA_API int32_t llama_decode_ext( struct llama_context * ctx, From 9fb2d81eab08b8242be9be2c1db8f43d6e6e5b2a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 22:38:04 +0100 Subject: [PATCH 15/52] fix common_batch missing seq_id --- common/common.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/common.h b/common/common.h index 94f3352063604..4bcc4a2c44cf6 100644 --- a/common/common.h +++ b/common/common.h @@ -586,6 +586,7 @@ struct common_batch { llama_batch_ext_ptr batch; struct batch_token { llama_token token; + llama_seq_id seq_id; // only support single seq for now bool logits; }; std::vector tokens; @@ -601,14 +602,14 @@ struct common_batch { } void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, logits}); + tokens.push_back({token, seq_id, logits}); if (logits) { n_outputs++; } } void add_text(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); - tokens.push_back({token, logits}); + tokens.push_back({token, seq_ids[0], logits}); if (logits) { n_outputs++; } From 65f0184517aa886443d3cf648031db4df3507946 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 22:56:35 +0100 Subject: [PATCH 16/52] compile ok --- common/common.h | 2 +- examples/llava/llava-cli.cpp | 3 ++- examples/llava/minicpmv-cli.cpp | 3 ++- examples/llava/qwen2vl-cli.cpp | 20 +++++++++++++------ examples/lookahead/lookahead.cpp | 6 ++++-- examples/main/main.cpp | 6 ++++-- examples/perplexity/perplexity.cpp | 15 
+++++++------- .../speculative-simple/speculative-simple.cpp | 3 ++- examples/speculative/speculative.cpp | 11 ++++++---- 9 files changed, 43 insertions(+), 26 deletions(-) diff --git a/common/common.h b/common/common.h index 4bcc4a2c44cf6..c223685f21ba8 100644 --- a/common/common.h +++ b/common/common.h @@ -607,7 +607,7 @@ struct common_batch { n_outputs++; } } - void add_text(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { + void add_text_multi_seq(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); tokens.push_back({token, seq_ids[0], logits}); if (logits) { diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 40aa0876f24a7..b4b6e63c7552a 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 12f536cf5cfff..adc3a615fe868 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index d65e88f9d12d5..1f4242580469f 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -96,16 +96,24 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_get_one(&tokens[i], n_eval); + // TODO: add mrope pos ids somewhere else - pos.resize(batch.n_tokens * 4); + int n_tokens = n_eval; + pos.resize(n_tokens * 4); std::fill(pos.begin(), pos.end(), 0); - for (int j = 0; j < batch.n_tokens * 3; j ++) { - pos[j] = *st_pos_id + (j % batch.n_tokens); + for (int j = 0; j < n_tokens * 3; j ++) { + pos[j] = *st_pos_id + (j % n_tokens); } - batch.pos = pos.data(); - if (llama_decode(ctx_llama, batch)) { + llama_batch_ext_ptr batch(llama_batch_ext_init(n_eval, 1)); + for (int j = 0; j < n_eval; j++) { + llama_token token = tokens[i + j]; + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), token, pos[j], &seq_id, 1, false); + } + llama_batch_ext_set_output_last(batch.get()); + + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. 
token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 1c2c3ec46c903..1e8de9673edc5 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -92,8 +92,10 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); + llama_decode_ext(ctx, batch0.get()); + llama_decode_ext(ctx, batch1.get()); for (int s = 1; s < W + G + 1; ++s) { llama_kv_self_seq_cp(ctx, 0, s, -1, -1); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fd7410a646c69..2f735c4203e7b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,8 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0)); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -668,7 +669,8 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index d24fddbf450e1..956c115d40fe3 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -565,7 +565,6 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & } for (int k = 0; k < batch_size; ++k) { - const int idx = seq*n_ctx + k; const llama_pos pos = j*n_batch + k; bool output = pos >= first; batch.add_text(tokens[seq_start + k], pos, seq, output); @@ -876,7 +875,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - batch.add_text(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); } llama_batch_ext_set_output_last(batch.get()); n_logits += 1; @@ -886,7 +885,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - batch.add_text(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -1155,7 +1154,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - batch.add_text(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); } llama_batch_ext_set_output_last(batch.get()); n_logits 
+= 1; @@ -1163,7 +1162,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - batch.add_text(data[i1].seq_tokens[s][i], i, { s0 + s }, true); + batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true); n_logits += 1; } } @@ -1523,7 +1522,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par for (size_t i = 0; i < cur_task.common_prefix; ++i) { //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); + batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false); } llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; @@ -1533,7 +1532,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - batch.add_text(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); n_logits += needs_logits; } } @@ -1760,7 +1759,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { batch.clear(); for (int i = 0; i < batch_size; i++) { - batch.add_text(tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true); } if (llama_decode_ext(ctx, batch.get())) { diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index e61e863ce02eb..2f4a85abdb90a 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -113,7 +113,8 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0)); + llama_decode_ext(ctx_tgt, batch.get()); // note: keep the last token separate! 
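For reference, a minimal sketch of the token-by-token batch construction that the hunks above migrate the examples to; it is illustrative only, not part of the patch. It assumes a llama_context * ctx, a std::vector<llama_token> tokens, a llama_pos n_past and an int n_batch from the surrounding caller, and uses only llama_batch_ext_* calls that appear in this series (declared in include/llama.h):

    // allocate a batch that can hold up to n_batch tokens, one sequence id per token
    llama_batch_ext * batch = llama_batch_ext_init(/*n_tokens_alloc=*/n_batch, /*n_seq_max=*/1);
    llama_seq_id seq_id = 0;
    for (size_t i = 0; i < tokens.size(); ++i) {
        // positions continue from n_past; do not request logits per token here
        llama_batch_ext_add_text(batch, tokens[i], n_past + (llama_pos) i, &seq_id, 1, false);
    }
    // request logits only for the last token of the batch
    llama_batch_ext_set_output_last(batch);
    if (llama_decode_ext(ctx, batch) != 0) {
        // handle the error, as the converted examples above do
    }
    n_past += llama_batch_ext_get_n_tokens(batch);
    // a reused batch is reset with llama_batch_ext_clear(); a one-off batch is freed
    llama_batch_ext_free(batch);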
llama_token id_last = inp.back(); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 1f55db7b65f53..2d44dc82c2b30 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { } common_init(); -#ifdef 0 +#if 0 if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -166,9 +166,12 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); + llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0)); + llama_decode_ext(ctx_tgt, batch0); + llama_decode_ext(ctx_tgt, batch1); + llama_decode_ext(ctx_dft, batch2); const auto t_enc_end = ggml_time_us(); From c3dd79007bdec0ed9c3097db8bab2aad5b992d73 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 23:09:27 +0100 Subject: [PATCH 17/52] fix llama_batch_ext_init_from_text --- common/common.cpp | 4 ++-- examples/cvector-generator/cvector-generator.cpp | 3 ++- examples/eval-callback/eval-callback.cpp | 2 +- examples/infill/infill.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 5 +++-- examples/llava/llava-cli.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 4 ++-- examples/lookup/lookup.cpp | 4 ++-- examples/main/main.cpp | 5 +++-- examples/run/run.cpp | 4 ++-- examples/save-load-state/save-load-state.cpp | 2 +- examples/simple-chat/simple-chat.cpp | 8 ++++++-- examples/simple/simple.cpp | 2 +- examples/speculative-simple/speculative-simple.cpp | 2 +- examples/speculative/speculative.cpp | 6 +++--- include/llama.h | 4 +++- src/llama-batch.cpp | 6 +++++- 18 files changed, 40 insertions(+), 27 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index ec4bf699ab808..c7cf665458dc5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1014,7 +1014,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (llama_model_has_encoder(model)) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true)); llama_encode_ext(lctx, batch.get()); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { @@ -1024,7 +1024,7 @@ struct common_init_result common_init_from_params(common_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true)); llama_decode_ext(lctx, batch.get()); } llama_kv_self_clear(lctx); diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index f5ca61c317722..13fa2c44230ea 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ 
b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_self_clear(ctx); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); + llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 7e600440d839d..47dfd94d21833 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 631d0b07d83c5..2c84ab8e75ef0 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true)); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 730e994b2ff07..6a6ab4ab28a95 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1444,7 +1444,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true)); + llama_batch_ext_set_output_last(batch.get()); llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1462,7 +1463,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_token token = llama_vocab_get_add_bos(vocab) ? 
llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true)); llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index b4b6e63c7552a..233480354f9f2 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true)); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index adc3a615fe868..0740b4b4f962e 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true)); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 1e8de9673edc5..88d0b1606b528 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -92,8 +92,8 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index a6bf80fdf77f8..0e885fa4174a2 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -91,8 +91,8 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2f735c4203e7b..8caf1ae3b6ce3 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,7 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0)); + 
llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true)); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; @@ -669,7 +669,8 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true)); + llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 02cafa9da0a92..d7faa1472e022 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -946,7 +946,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); @@ -969,7 +969,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0)); + batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true)); } printf(LOG_COL_DEFAULT); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index d1cf599b1665b..6ab35133bb427 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,7 +48,7 @@ int main(int argc, char ** argv) { auto tokens = common_tokenize(ctx, params.prompt, true); // prepare the batch - llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0); + llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true); // evaluate prompt llama_decode_ext(ctx, batch); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index cee00ea82421b..0c2d34d563d4b 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -108,8 +108,11 @@ int main(int argc, char ** argv) { } // prepare a batch for the prompt - llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_pos n_past = 0; + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); llama_batch_ext_set_output_last(batch); + n_past += llama_batch_ext_get_n_tokens(batch); + llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch @@ -147,7 +150,8 @@ int main(int argc, char ** argv) { // prepare the next batch with the sampled token llama_batch_ext_clear(batch); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); + llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true); + n_past++; } llama_batch_ext_free(batch); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 7b3ba8d815e45..9101cc6bbb4d0 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -143,7 +143,7 @@ int main(int argc, char 
** argv) { // prepare a batch for the prompt - llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0); + llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true); llama_batch_ext_set_output_last(batch); // main loop diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 2f4a85abdb90a..61b9af2f0f7db 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true)); llama_decode_ext(ctx_tgt, batch.get()); // note: keep the last token separate! diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 2d44dc82c2b30..2812846d1b9f5 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -166,9 +166,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0)); - llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); + llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); llama_decode_ext(ctx_tgt, batch0); llama_decode_ext(ctx_tgt, batch1); llama_decode_ext(ctx_dft, batch2); diff --git a/include/llama.h b/include/llama.h index 4521b3a415638..5864519fd0fcc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -928,12 +928,14 @@ extern "C" { // Same with llama_batch_init, but initializes the batch with the provided text tokens // First token will be at position pos0 // The sequence ID will be fixed to seq_id + // If output_last is true, the last token will have output set // The batch has to be freed with llama_batch_ext_free() LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text( llama_token * tokens, int32_t n_tokens, int32_t pos0, - int32_t seq_id); + int32_t seq_id, + bool output_last); // Same with llama_batch_init, but initializes the batch with the provided raw embeddings // First token will be at position pos0 diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index bae8b37b3fc1f..80f1592e9d0db 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -341,11 +341,15 @@ struct llama_batch_ext * llama_batch_ext_init_from_text( llama_token * tokens, int32_t n_tokens, int32_t pos0, - int32_t seq_id) { + int32_t seq_id, + bool output_last) { llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); for (int32_t i = 0; i < n_tokens; i++) { llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false); } + if (output_last) { + llama_batch_ext_set_output_last(batch); + } return batch; } From 04f8641815bee928180e300e15502581d8a1d553 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 13 Mar 2025 23:14:16 +0100 Subject: [PATCH 
18/52] rm redundant llama_batch_ext_set_output_last --- examples/cvector-generator/cvector-generator.cpp | 1 - examples/llama-bench/llama-bench.cpp | 1 - examples/simple-chat/simple-chat.cpp | 1 - examples/simple/simple.cpp | 1 - 4 files changed, 4 deletions(-) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 13fa2c44230ea..b3236ea854f93 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -344,7 +344,6 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_self_clear(ctx); llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); - llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 6a6ab4ab28a95..bf39134d015d7 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1445,7 +1445,6 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th tokens[i] = std::rand() % n_vocab; } llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true)); - llama_batch_ext_set_output_last(batch.get()); llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 0c2d34d563d4b..dbde1ee9e88d6 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -110,7 +110,6 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt llama_pos n_past = 0; llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); - llama_batch_ext_set_output_last(batch); n_past += llama_batch_ext_get_n_tokens(batch); llama_token new_token_id; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 9101cc6bbb4d0..4aea9dbdc531f 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -144,7 +144,6 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true); - llama_batch_ext_set_output_last(batch); // main loop From 54566ad95db209d990b600b9597a2618f104dd77 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 00:21:06 +0100 Subject: [PATCH 19/52] correct comment --- include/llama.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/llama.h b/include/llama.h index 5864519fd0fcc..fb6edda8eba55 100644 --- a/include/llama.h +++ b/include/llama.h @@ -980,7 +980,7 @@ extern "C" { LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch); // Get a "view" from a number of tokens offset - // Return returned batch must be freed with llama_batch_free() + // Return returned batch must be freed with llama_batch_ext_free() LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view( struct llama_batch_ext * batch, int32_t offset, From bfdddbc150734ac69689edb5287197a6a0db3e89 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 00:22:28 +0100 Subject: [PATCH 20/52] bring back mistakenly deleted llama_batch_init/free --- src/llama-batch.cpp | 
46 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 80f1592e9d0db..d55625da1c335 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -503,3 +503,49 @@ void llama_batch_ext_free(struct llama_batch_ext * batch) { } delete batch; } + +// deprecated +struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { + llama_batch batch = { + /*n_tokens =*/ 0, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (embd) { + batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + } else { + batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch.seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens_alloc] = nullptr; + + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +// deprecated +void llama_batch_free(struct llama_batch batch) { + if (batch.token) free(batch.token); + if (batch.embd) free(batch.embd); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i] != nullptr; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} From 5e6a6d4e1c24f28e284ebbd6dd1618dd048ebbd4 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 10:32:43 +0100 Subject: [PATCH 21/52] fix llama-run n_past --- examples/run/run.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/run/run.cpp b/examples/run/run.cpp index d7faa1472e022..39026813bf3b6 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -595,6 +595,7 @@ class LlamaData { std::vector messages; // TODO: switch to common_chat_msg std::list msg_strs; std::vector fmtted; + llama_pos n_past = 0; int init(Opt & opt) { model = initialize_model(opt); @@ -946,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true)); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); @@ -955,6 +956,8 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str return 1; } + llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get()); + // sample the next token, check is it an end of generation? 
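For reference, a minimal sketch of the n_past bookkeeping that this commit establishes for llama-run; it is illustrative only, not part of the patch. The helper name eval_and_advance is hypothetical, ctx and tokens are assumed from the caller, and the llama_* calls and llama_batch_ext_ptr (from include/llama-cpp.h) are the ones used in the hunks:

    static bool eval_and_advance(llama_context * ctx, std::vector<llama_token> & tokens, llama_pos & n_past) {
        // positions start at the current n_past; request logits for the last token
        llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(
            tokens.data(), tokens.size(), n_past, /*seq_id=*/0, /*output_last=*/true));
        if (llama_decode_ext(ctx, batch.get())) {
            return false;
        }
        // advance n_past by the number of tokens that were just decoded
        n_past += llama_batch_ext_get_n_tokens(batch.get());
        return true;
    }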
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); if (llama_vocab_is_eog(vocab, new_token_id)) { @@ -969,7 +972,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, 0, 0, true)); + batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true)); } printf(LOG_COL_DEFAULT); From 32940369d390db4c052b2c4d5f9d91354a86a08a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 10:33:28 +0100 Subject: [PATCH 22/52] fix gemma3-cli --- examples/llava/gemma3-cli.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 9aa71065249e1..e2fdfcc288668 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -93,7 +93,7 @@ static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = llama_batch_ext_clear(ctx.batch.get()); for (llama_token & t : tokens) { llama_seq_id seq_id = 0; - llama_batch_ext_add_text(ctx.batch.get(), t, 0, &seq_id, 1, false); + llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false); } if (logits_last) { llama_batch_ext_set_output_last(ctx.batch.get()); From 07d84fa3c2b9386a9878ce22ef363ef7bb0c12c9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 10:47:08 +0100 Subject: [PATCH 23/52] fix missing n_past in various places this is actually a revert of https://github.com/ggml-org/llama.cpp/commit/cda0e4b648dde8fac162b3430b14a99597d3d74f --- examples/llama-bench/llama-bench.cpp | 16 ++++++++-------- examples/lookahead/lookahead.cpp | 4 ++-- examples/lookup/lookup.cpp | 4 ++-- examples/save-load-state/save-load-state.cpp | 4 ++-- examples/simple/simple.cpp | 2 +- examples/speculative/speculative.cpp | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index bf39134d015d7..992df2b516f16 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1427,7 +1427,7 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { +static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, 0, 0, true)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true)); llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1452,7 +1452,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th llama_synchronize(ctx); } -static void test_gen(llama_context * ctx, int n_gen, int n_threads) { +static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { 
llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, 0, 0, true)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true)); llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; @@ -1610,13 +1610,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); } //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count); } - test_gen(ctx, 1, t.n_threads); + test_gen(ctx, 1, 0, t.n_threads); } for (int i = 0; i < params.reps; i++) { @@ -1629,14 +1629,14 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_gen(ctx, t.n_gen, t.n_threads); + test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); } uint64_t t_ns = get_time_ns() - t_start; diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 88d0b1606b528..8277559689074 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -92,8 +92,8 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 0e885fa4174a2..07e57afcbab9f 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -91,8 +91,8 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 6ab35133bb427..2ff4e24c19c1e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -133,7 +133,7 @@ int main(int argc, char ** argv) { result1 += next_token_str; llama_batch_ext_clear(batch); - llama_seq_id seq_id = 1; + llama_seq_id seq_id = 0; 
llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); if (llama_decode_ext(ctx2, batch)) { @@ -215,7 +215,7 @@ int main(int argc, char ** argv) { result2 += next_token_str; llama_batch_ext_clear(batch); - llama_seq_id seq_id = 1; + llama_seq_id seq_id = 1; // seq 1 instead of 0 llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); if (llama_decode_ext(ctx3, batch)) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 4aea9dbdc531f..26009a5aec398 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -182,7 +182,7 @@ int main(int argc, char ** argv) { // prepare the next batch with the sampled token llama_batch_ext_clear(batch); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, new_token_id, 0, &seq_id, 1, true); + llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true); n_decode += 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 2812846d1b9f5..4d987332a5422 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -166,9 +166,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, 0, 0, true)); - llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); + llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); + llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); + llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); llama_decode_ext(ctx_tgt, batch0); llama_decode_ext(ctx_tgt, batch1); llama_decode_ext(ctx_dft, batch2); From ba79369615f84dd07573e9e86c3d228f0aec7c63 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 11:17:22 +0100 Subject: [PATCH 24/52] fix llama_batch_ext_init_from_embd --- examples/llava/gemma3-cli.cpp | 2 +- examples/llava/llava.cpp | 2 +- include/llama.h | 3 +++ src/llama-batch.cpp | 17 +++++++++-------- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index e2fdfcc288668..3efa604b935b6 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -148,7 +148,7 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { int64_t t1 = ggml_time_ms(); eval_text(ctx, ""); llama_set_causal_attn(ctx.lctx, false); - llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, ctx.n_past, 0)); + llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0)); if (llama_decode_ext(ctx.lctx, batch_img.get())) { LOG_ERR("failed to decode image\n"); return 1; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 53ce30215508b..de967e0699680 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, 0, 0)); + llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0)); if (llama_decode_ext(ctx_llama, batch.get())) { 
LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/include/llama.h b/include/llama.h index fb6edda8eba55..2f58085fc8fef 100644 --- a/include/llama.h +++ b/include/llama.h @@ -938,11 +938,14 @@ extern "C" { bool output_last); // Same with llama_batch_init, but initializes the batch with the provided raw embeddings + // Size of embd should be n_tokens * n_embd + // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd() // First token will be at position pos0 // The sequence ID will be fixed to seq_id // The batch has to be freed with llama_batch_ext_free() LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, + size_t n_tokens, size_t n_embd, int32_t pos0, int32_t seq_id); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index d55625da1c335..a7f2717f1fd34 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -353,7 +353,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_text( return batch; } -static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { +static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) { llama_batch_ext * batch = new llama_batch_ext{ /*n_tokens =*/ 0, /*max_tokens =*/ n_tokens_alloc, @@ -366,8 +366,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc /*logits =*/ nullptr, }; - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd); + if (n_embd) { + batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd); } else { batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } @@ -391,14 +391,15 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_ struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, + size_t n_tokens, size_t n_embd, int32_t pos0, int32_t seq_id) { - struct llama_batch_ext * batch = llama_batch_ext_init_impl(0, n_embd, 1); - memcpy(batch->embd, embd, n_embd * sizeof(float)); - for (size_t i = 0; i < n_embd; i++) { - batch->pos [i] = pos0 + i; - batch->n_seq_id[i] = 1; + struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1); + memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float)); + for (size_t i = 0; i < n_tokens; i++) { + batch->pos [i] = pos0 + i; + batch->n_seq_id[i] = 1; batch->seq_id [i][0] = seq_id; } return batch; From a363251fac38be94d27f4c63c1765efa65d9d0d2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 11:25:36 +0100 Subject: [PATCH 25/52] qwen2vl: use llama_batch_ext_set_pos --- examples/llava/qwen2vl-cli.cpp | 17 +++++------------ include/llama.h | 6 ++++++ src/llama-batch.cpp | 8 ++++++++ 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 1f4242580469f..a702ab46adb86 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -66,18 +66,11 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); - // TODO: move this to llama_batch_ext API - llama_batch batch = { - int32_t(n_eval), // n_tokens - nullptr, // token - (image_embed->embed+i*n_embd), // embed - batch_mrope_pos.data(), // pos - 
nullptr, // n_seq_id - nullptr, // seq_id - nullptr, // logits - }; - - if (llama_decode(ctx_llama, batch)) { + float * batch_embd = image_embed->embed+i*n_embd; + llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0)); + llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval); + + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/include/llama.h b/include/llama.h index 2f58085fc8fef..28fb82606446c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -950,6 +950,12 @@ extern "C" { int32_t pos0, int32_t seq_id); + // Set arbitrary token to the embeddings batch + // Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd() + // n_pos must match the n_tokens of the batch + // Returns -1 if n_pos does not match the n_tokens of the batch + LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos); + // Get the number of tokens in the batch LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index a7f2717f1fd34..f56b3b03b80d2 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -405,6 +405,14 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd( return batch; } +int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) { + if (batch->n_tokens != n_pos) { + return -1; + } + memcpy(batch->pos, pos, n_pos * sizeof(llama_pos)); + return 0; +} + int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { return batch->n_tokens; } From 8e7714fa777bf4bc54f63825827e402abdb1b642 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 11:28:15 +0100 Subject: [PATCH 26/52] fix compile --- src/llama-batch.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index f56b3b03b80d2..0455db9d0617d 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -406,7 +406,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd( } int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) { - if (batch->n_tokens != n_pos) { + if ((size_t) batch->n_tokens != n_pos) { return -1; } memcpy(batch->pos, pos, n_pos * sizeof(llama_pos)); From eaffba0f2ed6f18402f95e29de36c48018967f6f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 17:12:03 +0100 Subject: [PATCH 27/52] llama_batch_ext_ptr::from_text/embd --- .../cvector-generator/cvector-generator.cpp | 2 +- examples/eval-callback/eval-callback.cpp | 2 +- examples/infill/infill.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 ++-- examples/llava/llava-cli.cpp | 2 +- examples/llava/llava.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/llava/qwen2vl-cli.cpp | 2 +- examples/main/main.cpp | 4 ++-- examples/run/run.cpp | 2 +- .../speculative-simple/speculative-simple.cpp | 2 +- include/llama-cpp.h | 22 ++++++++++++++++++- 12 files changed, 34 insertions(+), 14 deletions(-) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index b3236ea854f93..e0b647632445e 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & 
tokens) { llama_kv_self_clear(ctx); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 47dfd94d21833..86a36223dea9d 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 2c84ab8e75ef0..574ef644f28fe 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 992df2b516f16..0ed841f09af33 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true); llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&token, 1, n_past + i, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(&token, 1, n_past + i, 0, true); llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 233480354f9f2..ed4326f87610e 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. 
token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index de967e0699680..901061ca3e2e7 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(embd, n_eval, n_embd, 0, 0)); + auto batch = llama_batch_ext_ptr::from_embd(embd, n_eval, n_embd, 0, 0); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 0740b4b4f962e..2a725d384d9ff 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&tokens[i], n_eval, *n_past, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index a702ab46adb86..c655fd7a28021 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -67,7 +67,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); float * batch_embd = image_embed->embed+i*n_embd; - llama_batch_ext_ptr batch(llama_batch_ext_init_from_embd(batch_embd, n_eval, n_embd, 0, 0)); + auto batch = llama_batch_ext_ptr::from_embd(batch_embd, n_eval, n_embd, 0, 0); llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval); if (llama_decode_ext(ctx_llama, batch.get())) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8caf1ae3b6ce3..1ec5a51aa205f 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,7 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(enc_input_buf, enc_input_size, 0, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(enc_input_buf, enc_input_size, 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; @@ -669,7 +669,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(&embd[i], n_eval, n_past, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true); llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 39026813bf3b6..aac2f39009352 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -947,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true)); + auto batch = 
llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 61b9af2f0f7db..b15593b150cba 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(inp.data(), inp.size() - 1, 0, 0, true)); + auto batch = llama_batch_ext_ptr::from_text(inp.data(), inp.size() - 1, 0, 0, true); llama_decode_ext(ctx_tgt, batch.get()); // note: keep the last token separate! diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 880a6a5fae8f5..dfced7ef9bbc9 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -32,4 +32,24 @@ typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; -typedef std::unique_ptr llama_batch_ext_ptr; + +struct llama_batch_ext_ptr : std::unique_ptr { + llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr(batch) {} + + // convience function to create a batch from text tokens, without worrying about manually freeing it + static llama_batch_ext_ptr from_text(llama_token * tokens, + int32_t n_tokens, + int32_t pos0, + int32_t seq_id, + bool output_last) { + return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last)); + } + + static llama_batch_ext_ptr from_embd(float * embd, + size_t n_tokens, + size_t n_embd, + int32_t pos0, + int32_t seq_id) { + return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id)); + } +}; From 116b9a1662281c09c7e4b144e3654dbe3ce2227f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 14 Mar 2025 22:17:07 +0100 Subject: [PATCH 28/52] rename to init_from_text --- examples/cvector-generator/cvector-generator.cpp | 2 +- examples/eval-callback/eval-callback.cpp | 2 +- examples/infill/infill.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 ++-- examples/llava/llava-cli.cpp | 2 +- examples/llava/llava.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/llava/qwen2vl-cli.cpp | 2 +- examples/main/main.cpp | 4 ++-- examples/run/run.cpp | 2 +- examples/speculative-simple/speculative-simple.cpp | 2 +- include/llama-cpp.h | 7 ++++--- 12 files changed, 17 insertions(+), 16 deletions(-) diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index e0b647632445e..6b25dc1db6efe 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_self_clear(ctx); - auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 
86a36223dea9d..21ca9b4ceec61 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 574ef644f28fe..29cba998968e7 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 0ed841f09af33..c671194c77864 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - auto batch = llama_batch_ext_ptr::from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true); llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - auto batch = llama_batch_ext_ptr::from_text(&token, 1, n_past + i, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true); llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index ed4326f87610e..1fa72a24d8a63 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. 
token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 901061ca3e2e7..eda96e19f1b20 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - auto batch = llama_batch_ext_ptr::from_embd(embd, n_eval, n_embd, 0, 0); + auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 2a725d384d9ff..81fbc247af292 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_ext_ptr::from_text(&tokens[i], n_eval, *n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index c655fd7a28021..d4fcabb1081e9 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -67,7 +67,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); float * batch_embd = image_embed->embed+i*n_embd; - auto batch = llama_batch_ext_ptr::from_embd(batch_embd, n_eval, n_embd, 0, 0); + auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0); llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval); if (llama_decode_ext(ctx_llama, batch.get())) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1ec5a51aa205f..0d264f6534d5a 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,7 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - auto batch = llama_batch_ext_ptr::from_text(enc_input_buf, enc_input_size, 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; @@ -669,7 +669,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - auto batch = llama_batch_ext_ptr::from_text(&embd[i], n_eval, n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true); llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index aac2f39009352..91edd741ce0de 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -947,7 +947,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - auto batch = llama_batch_ext_ptr::from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 
llama_data.n_past, 0, true); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index b15593b150cba..74abd98d75e68 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - auto batch = llama_batch_ext_ptr::from_text(inp.data(), inp.size() - 1, 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(inp.data(), inp.size() - 1, 0, 0, true); llama_decode_ext(ctx_tgt, batch.get()); // note: keep the last token separate! diff --git a/include/llama-cpp.h b/include/llama-cpp.h index dfced7ef9bbc9..efc5074df93e5 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -36,8 +36,8 @@ typedef std::unique_ptr llama_ad struct llama_batch_ext_ptr : std::unique_ptr { llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr(batch) {} - // convience function to create a batch from text tokens, without worrying about manually freeing it - static llama_batch_ext_ptr from_text(llama_token * tokens, + // convenience function to create a batch from text tokens, without worrying about manually freeing it + static llama_batch_ext_ptr init_from_text(llama_token * tokens, int32_t n_tokens, int32_t pos0, int32_t seq_id, @@ -45,7 +45,8 @@ struct llama_batch_ext_ptr : std::unique_ptr Date: Fri, 14 Mar 2025 22:30:29 +0100 Subject: [PATCH 29/52] fix compile --- include/llama-cpp.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/llama-cpp.h b/include/llama-cpp.h index efc5074df93e5..fee15ef9c2bae 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -34,6 +34,7 @@ typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; struct llama_batch_ext_ptr : std::unique_ptr { + llama_batch_ext_ptr() : std::unique_ptr() {} llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr(batch) {} // convenience function to create a batch from text tokens, without worrying about manually freeing it From de788e071bc397607413a3dc3f3018049c272ac5 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 17 Mar 2025 12:05:23 +0100 Subject: [PATCH 30/52] Update examples/tts/tts.cpp Co-authored-by: Georgi Gerganov --- examples/tts/tts.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 32f8c43a8d314..9f510d004c3da 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -920,7 +920,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 i_batch[i] = llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, false); + llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true); } // all streams are finished From eab5606d7b7e7d66d70ab7bc3f8cc0d4a2633711 Mon Sep 17 00:00:00 2001 From: Xuan-Son Nguyen Date: Mon, 17 Mar 2025 12:17:14 +0100 Subject: [PATCH 31/52] Apply suggestions from code review --- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched/batched.cpp | 2 +- examples/main/main.cpp | 1 - 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 8f7c2c94b8964..063b5ca8bc84a 100644 --- 
a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -147,7 +147,7 @@ int main(int argc, char ** argv) { llama_batch_ext_clear(batch); for (int j = 0; j < pl; ++j) { - llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, false); + llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 1ed189859d4d0..9f169b41b505a 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -196,7 +196,7 @@ int main(int argc, char ** argv) { i_batch[i] = llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, false); + llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true); n_decode += 1; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 0d264f6534d5a..4a779e3601bd0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -670,7 +670,6 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true); - llama_batch_ext_set_output_last(batch.get()); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; From 7a3c178d788e98e59f4d2fe66a23ac7f9b39ded1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Mar 2025 16:10:26 +0200 Subject: [PATCH 32/52] speculative : adapt to new llama API ggml-ci --- examples/speculative/speculative.cpp | 63 ++++++++++++++++------------ 1 file changed, 36 insertions(+), 27 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 4d987332a5422..ff5eceb643208 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -45,7 +45,6 @@ int main(int argc, char ** argv) { } common_init(); -#if 0 if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -169,9 +168,9 @@ int main(int argc, char ** argv) { llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); - llama_decode_ext(ctx_tgt, batch0); - llama_decode_ext(ctx_tgt, batch1); - llama_decode_ext(ctx_dft, batch2); + llama_decode_ext(ctx_tgt, batch0.get()); + llama_decode_ext(ctx_tgt, batch1.get()); + llama_decode_ext(ctx_dft, batch2.get()); const auto t_enc_end = ggml_time_us(); @@ -338,7 +337,7 @@ int main(int argc, char ** argv) { if (i == s) { continue; } - if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { + if (drafts[i].active && drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) { // synchronize active status for sequences with the same drafted token drafts[i].active = drafts[i].active && accept; if (!drafts[i].active) { @@ -446,7 +445,7 @@ int main(int argc, char ** argv) { llama_batch_ext_clear(batch_dft); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_tgt, token_id, n_past_tgt, &seq_id, 1, true); + llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); @@ -475,13 +474,19 @@ int main(int argc, char ** argv) { 
drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - llama_batch_ext_clear(batch_tgt); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_tgt, drafts[0].tokens[0], n_past_tgt, &seq_id, 1, true); + struct batch_info { + llama_token id; + llama_pos pos; + std::vector seq_id; + }; + + std::vector batch_tgt_data; + + batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} }); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - batch_dft.n_tokens = 0; + llama_batch_ext_clear(batch_dft); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].skip = false; @@ -512,11 +517,10 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch - for (int t = 0; t < batch_tgt.n_tokens; ++t) { - for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { - if (batch_tgt.seq_id[t][p] == s) { - batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; - batch_tgt.n_seq_id[t]++; + for (int t = 0; t < (int) batch_tgt_data.size(); ++t) { + for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) { + if (batch_tgt_data[t].seq_id[p] == s) { + batch_tgt_data[t].seq_id.push_back(n_seq_cur); break; } } @@ -558,32 +562,30 @@ int main(int argc, char ** argv) { drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); // add unique drafted tokens to the target batch - drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + drafts[s].i_batch_tgt.push_back(batch_tgt_data.size()); - common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }}); // add the token to the batch for batched decoding with the draft model - drafts[s].i_batch_dft = batch_dft.n_tokens; - - common_batch_add(batch_dft, id, n_past_cur, { s }, true); + drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true); - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { drafts[s].drafting = false; } } } // no sequence is drafting anymore - if (batch_dft.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch_dft) == 0) { break; } // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft); ++n_past_cur; ++n_drafted; - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { break; } } @@ -595,8 +597,15 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); } + llama_batch_ext_clear(batch_tgt); + for (int i = 0; i < (int) batch_tgt_data.size(); ++i) { + const auto & data = batch_tgt_data[i]; + + llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true); + } + // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt); ++n_past_tgt; } @@ -639,12 +648,12 @@ int main(int argc, char ** argv) { common_sampler_free(drafts[s].smpl); } - llama_batch_free(batch_dft); + llama_batch_ext_free(batch_dft); + llama_batch_ext_free(batch_tgt); llama_backend_free(); LOG("\n\n"); -#endif return 0; } From b0db7fc2c60158d54a9b8986a334b00b659c9e56 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 19 Mar 2025 10:16:55 +0200 Subject: [PATCH 33/52] android : adapt to new API --- common/common.cpp | 29 ------- common/common.h | 11 --- .../llama/src/main/cpp/llama-android.cpp | 85 +++++++------------ 3 files 
changed, 31 insertions(+), 94 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 92f2c57cc9d19..f8498f01d6f71 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1576,35 +1576,6 @@ std::pair common_get_hf_file(const std::string &, cons #endif // LLAMA_USE_CURL -// -// Batch utils -// - -// DEPRECATED -void common_batch_clear(struct llama_batch & batch) { - batch.n_tokens = 0; -} - -// DEPRECATED -void common_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits) { - GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); - - batch.token [batch.n_tokens] = id; - batch.pos [batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = seq_ids.size(); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; - } - batch.logits [batch.n_tokens] = logits; - - batch.n_tokens++; -} - // // Token utils // diff --git a/common/common.h b/common/common.h index c223685f21ba8..5fe149ff8c991 100644 --- a/common/common.h +++ b/common/common.h @@ -569,17 +569,6 @@ void common_set_adapter_lora(struct llama_context * ctx, std::vector & seq_ids, - bool logits); - // convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions // this is meant to be temporary struct common_batch { diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 9654cd53cf8d5..9bf7db399b408 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -125,7 +125,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; - llama_context * context = llama_new_context_with_model(model, ctx_params); + llama_context * context = llama_init_from_model(model, ctx_params); if (!context) { LOGe("llama_new_context_with_model() returned null)"); @@ -175,7 +175,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto context = reinterpret_cast(context_pointer); const auto model = reinterpret_cast(model_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); const int n_ctx = llama_n_ctx(context); @@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( for (nri = 0; nri < nr; nri++) { LOGi("Benchmark prompt processing (pp)"); - common_batch_clear(*batch); + llama_batch_ext_clear(batch); const int n_tokens = pp; for (i = 0; i < n_tokens; i++) { - common_batch_add(*batch, 0, i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); } - batch->logits[batch->n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); llama_kv_self_clear(context); const auto t_pp_start = ggml_time_us(); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during prompt processing"); + if (llama_decode_ext(context, batch) != 0) { + LOGi("llama_decode_ext() failed during prompt processing"); } const auto t_pp_end = ggml_time_us(); @@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { - common_batch_clear(*batch); + llama_batch_ext_clear(batch); for (j = 0; j < pl; j++) { - common_batch_add(*batch, 0, i, { j }, true); + llama_seq_id seq_id = j; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true); } 
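                // (illustrative note, not part of the original patch) with the
                // llama_batch_ext API each token is appended together with its
                // seq_id array and a per-token output flag, instead of filling
                // the public llama_batch arrays and setting batch->logits directly.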
- LOGi("llama_decode() text generation: %d", i); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during text generation"); + LOGi("llama_decode_ext() text generation: %d", i); + if (llama_decode_ext(context, batch) != 0) { + LOGi("llama_decode_ext() failed during text generation"); } } @@ -272,32 +274,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( extern "C" JNIEXPORT jlong JNICALL Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { - - // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. - - llama_batch *batch = new llama_batch { - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - }; - - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - - batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); - batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); - for (int i = 0; i < n_tokens; ++i) { - batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); - } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max); return reinterpret_cast(batch); } @@ -305,9 +282,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - //llama_batch_free(*reinterpret_cast(batch_pointer)); - const auto batch = reinterpret_cast(batch_pointer); - delete batch; + llama_batch_ext_free(reinterpret_cast(batch_pointer)); } extern "C" @@ -355,7 +330,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( const auto text = env->GetStringUTFChars(jtext, 0); const auto context = reinterpret_cast(context_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); bool parse_special = (format_chat == JNI_TRUE); const auto tokens_list = common_tokenize(context, text, true, parse_special); @@ -363,7 +338,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( auto n_ctx = llama_n_ctx(context); auto n_kv_req = tokens_list.size() + n_len; - LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); + LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req); if (n_kv_req > n_ctx) { LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); @@ -373,23 +348,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); } - common_batch_clear(*batch); + llama_batch_ext_clear(batch); // evaluate the initial prompt for (auto i = 0; i < tokens_list.size(); i++) { - common_batch_add(*batch, tokens_list[i], i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false); } // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() failed"); + if (llama_decode_ext(context, batch) != 0) { + LOGe("llama_decode_ext() failed"); } env->ReleaseStringUTFChars(jtext, text); - return batch->n_tokens; + return llama_batch_ext_get_n_tokens(batch); } 
extern "C" @@ -404,7 +380,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); const auto vocab = llama_model_get_vocab(model); @@ -433,13 +409,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( new_token = env->NewStringUTF(""); } - common_batch_clear(*batch); - common_batch_add(*batch, new_token_id, n_cur, { 0 }, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true); env->CallVoidMethod(intvar_ncur, la_int_var_inc); - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() returned null"); + if (llama_decode_ext(context, batch) != 0) { + LOGe("llama_decode_ext() returned null"); } return new_token; From 96ca6e8d23c046aae5df8fc2a534e15ff18398cd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 19 Mar 2025 10:48:42 +0200 Subject: [PATCH 34/52] swift : adapt to new API --- .../llama.cpp.swift/LibLlama.swift | 64 +++++++------------ 1 file changed, 24 insertions(+), 40 deletions(-) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index f6e31abc93c09..d04c6353eec1d 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -5,35 +5,19 @@ enum LlamaError: Error { case couldNotInitializeContext } -func llama_batch_clear(_ batch: inout llama_batch) { - batch.n_tokens = 0 -} - -func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) { - batch.token [Int(batch.n_tokens)] = id - batch.pos [Int(batch.n_tokens)] = pos - batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count) - for i in 0.. - private var batch: llama_batch + private var batch: OpaquePointer private var tokens_list: [llama_token] var is_done: Bool = false /// This variable is used to store temporarily invalid cchars private var temporary_invalid_cchars: [CChar] - var n_len: Int32 = 1024 + var n_len: Int32 = 128 var n_cur: Int32 = 0 var n_decode: Int32 = 0 @@ -42,7 +26,7 @@ actor LlamaContext { self.model = model self.context = context self.tokens_list = [] - self.batch = llama_batch_init(512, 0, 1) + self.batch = llama_batch_ext_init(512, 1) self.temporary_invalid_cchars = [] let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) @@ -53,7 +37,7 @@ actor LlamaContext { deinit { llama_sampler_free(sampling) - llama_batch_free(batch) + llama_batch_ext_free(batch) llama_model_free(model) llama_free(context) llama_backend_free() @@ -111,7 +95,7 @@ actor LlamaContext { } func get_n_tokens() -> Int32 { - return batch.n_tokens; + return llama_batch_ext_get_n_tokens(batch) } func completion_init(text: String) { @@ -133,25 +117,25 @@ actor LlamaContext { print(String(cString: token_to_piece(token: id) + [0])) } - llama_batch_clear(&batch) + llama_batch_ext_clear(batch) for i1 in 0.. 
String { var new_token_id: llama_token = 0 - new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) + new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1) if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { print("\n") @@ -178,13 +162,13 @@ actor LlamaContext { print(new_token_str) // tokens_list.append(new_token_id) - llama_batch_clear(&batch) - llama_batch_add(&batch, new_token_id, n_cur, [0], true) + llama_batch_ext_clear(batch) + llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true) n_decode += 1 n_cur += 1 - if llama_decode(context, batch) != 0 { + if llama_decode_ext(context, batch) != 0 { print("failed to evaluate llama!") } @@ -201,21 +185,21 @@ actor LlamaContext { for _ in 0.. Date: Wed, 19 Mar 2025 10:49:30 +0100 Subject: [PATCH 35/52] android : fix permission --- .../app/src/main/AndroidManifest.xml | 2 ++ .../java/com/example/llama/MainActivity.kt | 21 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml index 41a358a299154..de94ee3cd1faf 100644 --- a/examples/llama.android/app/src/main/AndroidManifest.xml +++ b/examples/llama.android/app/src/main/AndroidManifest.xml @@ -3,6 +3,8 @@ xmlns:tools="http://schemas.android.com/tools"> + + Date: Wed, 19 Mar 2025 13:50:15 +0200 Subject: [PATCH 36/52] retrieval : avoid common_batch ggml-ci --- examples/retrieval/retrieval.cpp | 59 ++++++++++++-------------------- include/llama.h | 4 +-- 2 files changed, 24 insertions(+), 39 deletions(-) diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index d43270e856554..9fe6f8b643728 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -74,55 +74,38 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(common_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { - size_t n_tokens = tokens.size(); +static void batch_add_seq(llama_batch_ext * batch, const std::vector & tokens, llama_seq_id seq_id) { + const size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - batch.add_text(tokens[i], i, seq_id, true); + llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true); } } -static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { - const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); +static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq); if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { // encoder-only model - if (llama_encode_ext(ctx, batch.get()) < 0) { + if (llama_encode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to encode\n", __func__); } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model - if (llama_decode_ext(ctx, batch.get()) < 0) { + if (llama_decode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to decode\n", 
__func__); } } - for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { - if (!batch.tokens[i].logits) { - continue; - } - - const float * embd = nullptr; - int embd_pos = 0; - - if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - // try to get token embeddings - embd = llama_get_embeddings_ith(ctx, i); - embd_pos = i; - GGML_ASSERT(embd != NULL && "failed to get token embeddings"); - } else { - // try to get sequence embeddings - supported only when pooling_type is not NONE - embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); - embd_pos = batch.tokens[i].seq_id; - GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); - } + for (int s = 0; s < n_seq; s++) { + const float * embd = llama_get_embeddings_seq(ctx, s); + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); - float * out = output + embd_pos * n_embd; + float * out = output + s * n_embd; common_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -230,7 +213,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_chunks = chunks.size(); - struct common_batch batch = common_batch(n_batch, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); // allocate output const int n_embd = llama_model_n_embd(model); @@ -247,10 +230,10 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (llama_batch_ext_get_n_tokens(batch.get()) + n_toks > n_batch) { - float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); - batch.clear(); + if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) { + batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); + llama_batch_ext_clear(batch); + p += s; s = 0; } @@ -261,8 +244,7 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_decode(ctx, batch, emb + p * n_embd, s, n_embd); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { @@ -271,7 +253,7 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } - struct common_batch query_batch = common_batch(n_batch, 1); + llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1); // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; @@ -285,7 +267,7 @@ int main(int argc, char ** argv) { std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); - query_batch.clear(); + llama_batch_ext_clear(query_batch); // compute cosine similarities { @@ -314,6 +296,9 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); + llama_batch_ext_free(batch); + llama_batch_ext_free(query_batch); + // clean up llama_backend_free(); } diff --git a/include/llama.h b/include/llama.h index 73fecf029e4be..d6aeb510011ee 100644 --- a/include/llama.h +++ b/include/llama.h @@ -945,8 +945,8 @@ extern "C" { // The batch has to be freed with llama_batch_ext_free() LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( float * embd, - size_t n_tokens, - size_t n_embd, + size_t n_tokens, + size_t n_embd, int32_t pos0, int32_t seq_id); From 8b80d68338a904bbc680d148c84c028fabf6fc76 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 19 Mar 2025 14:29:04 +0200 Subject: [PATCH 37/52] embedding : avoid common_batch ggml-ci --- examples/embedding/embedding.cpp | 65 +++++++++++++++----------------- examples/retrieval/retrieval.cpp | 2 +- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git 
a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 5e72f0e1a160d..947bbc1741021 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -26,56 +26,52 @@ static std::vector split_lines(const std::string & s, const std::st return lines; } -static void batch_add_seq(common_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { +static void batch_add_seq(llama_batch_ext * batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); for (size_t i = 0; i < n_tokens; i++) { - batch.add_text(tokens[i], i, seq_id, true); + llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true); } } -static void batch_decode(llama_context * ctx, common_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { +static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - const struct llama_model * model = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); + const int n_tokens = llama_batch_ext_get_n_tokens(batch); + // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch.get()), n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, n_tokens, n_seq); if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { // encoder-only model - if (llama_encode_ext(ctx, batch.get()) < 0) { + if (llama_encode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to encode\n", __func__); } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model - if (llama_decode_ext(ctx, batch.get()) < 0) { + if (llama_decode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to decode\n", __func__); } } - for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i++) { - if (!batch.tokens[i].logits) { - continue; - } - - const float * embd = nullptr; - int embd_pos = 0; - - if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - // try to get token embeddings - embd = llama_get_embeddings_ith(ctx, i); - embd_pos = i; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int i = 0; i < n_tokens; i++) { + const float * embd = llama_get_embeddings_ith(ctx, i); GGML_ASSERT(embd != NULL && "failed to get token embeddings"); - } else { - // try to get sequence embeddings - supported only when pooling_type is not NONE - embd = llama_get_embeddings_seq(ctx, batch.tokens[i].seq_id); - embd_pos = batch.tokens[i].seq_id; - GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + + float * out = output + i * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); } + } else { + for (int s = 0; s < n_seq; s++) { + const float * embd = llama_get_embeddings_seq(ctx, s); + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); - float * out = output + embd_pos * n_embd; - common_embd_normalize(embd, out, n_embd, embd_norm); + float * out = output + s * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); + } } } @@ -171,7 +167,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - struct common_batch batch = common_batch(n_batch, 1); + llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); // count number of embeddings int n_embd_count = 0; @@ -198,12 +194,12 @@ int main(int argc, char ** argv) { const uint64_t 
n_toks = inp.size(); // encode if at capacity - if (batch.get_n_tokens() + n_toks > n_batch) { - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.get_n_tokens() : s; + if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) { + batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize); + llama_batch_ext_clear(batch); + + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s; s = 0; - batch.clear(); } // add to batch @@ -212,8 +208,7 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize); if (params.embd_out.empty()) { LOG("\n"); @@ -318,6 +313,8 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); + llama_batch_ext_free(batch); + // clean up llama_backend_free(); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 9fe6f8b643728..6086665494654 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch_ext * batch, const std::vector & } static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { - const struct llama_model * model = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); From 76fd7d6f5b0c74429817c2b54c1f6be47b6d88a3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 20 Mar 2025 12:21:40 +0200 Subject: [PATCH 38/52] perplexity : avoid common_batch ggml-ci --- examples/perplexity/perplexity.cpp | 91 ++++++++++++++---------------- 1 file changed, 41 insertions(+), 50 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 956c115d40fe3..15265aa9e86ae 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -363,15 +363,16 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // clear the KV cache llama_kv_self_clear(ctx); - common_batch batch(n_batch, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1)); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - batch.clear(); + llama_batch_ext_clear(batch.get()); for (int i = 0; i < batch_size; i++) { - batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); @@ -501,7 +502,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(params.n_ctx == n_seq * n_ctx); - common_batch batch(std::min(n_batch, n_ctx*n_seq), 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1)); std::vector logits; if (num_batches > 1) { @@ -552,7 +553,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & int n_outputs = 0; - batch.clear(); + llama_batch_ext_clear(batch.get()); for (int seq = 0; 
seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -567,7 +568,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & for (int k = 0; k < batch_size; ++k) { const llama_pos pos = j*n_batch + k; bool output = pos >= first; - batch.add_text(tokens[seq_start + k], pos, seq, output); + llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output); n_outputs += output ? 1 : 0; } @@ -649,26 +650,15 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & return {tokens, ppl, logit_history, prob_history}; } -static bool decode_helper(llama_context * ctx, common_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { - int prev_outputs = 0; - for (int i = 0; i < (int) batch.get_n_tokens(); i += n_batch) { - const int n_tokens = std::min(n_batch, batch.get_n_tokens() - i); - - common_batch batch_view = batch.get_view(i, n_tokens); - - const int ret = llama_decode_ext(ctx, batch_view.get()); - if (ret != 0) { - LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); - return false; - } - - int n_outputs = batch_view.n_outputs; - - memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); - - prev_outputs += n_outputs; +static bool decode_helper(llama_context * ctx, llama_batch_ext_ptr & batch, std::vector & batch_logits, size_t n_outputs, int n_vocab) { + const int ret = llama_decode_ext(ctx, batch.get()); + if (ret != 0) { + LOG_ERR("failed to decode the batch, ret = %d\n", ret); + return false; } + memcpy(batch_logits.data(), llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float)); + return true; } @@ -836,14 +826,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { double acc = 0.0f; const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - common_batch batch(n_ctx, 4); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4)); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -859,7 +847,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - batch.clear(); + llama_batch_ext_clear(batch.get()); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -875,7 +863,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - batch.add_text_multi_seq(hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + std::vector seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }; + llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false); } llama_batch_ext_set_output_last(batch.get()); n_logits += 1; @@ -885,7 +874,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - batch.add_text_multi_seq(hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + llama_seq_id seq_id = s0 + s; + 
llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits); n_logits += needs_logits; } } @@ -907,7 +897,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1118,14 +1108,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params) LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__); const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - common_batch batch(n_ctx, 2); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2)); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -1144,7 +1132,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) size_t i1 = i0; size_t i_logits = 0; - batch.clear(); + llama_batch_ext_clear(batch.get()); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1154,7 +1142,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params) } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - batch.add_text_multi_seq(data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + std::vector seq_ids{ s0 + 0, s0 + 1 }; + llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false); } llama_batch_ext_set_output_last(batch.get()); n_logits += 1; @@ -1162,7 +1151,8 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - batch.add_text_multi_seq(data[i1].seq_tokens[s][i], i, { s0 + s }, true); + llama_seq_id seq_id = s0 + s; + llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true); n_logits += 1; } } @@ -1184,7 +1174,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1472,14 +1462,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par LOG("\ntask\tacc_norm\n"); const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - common_batch batch(n_ctx, max_seq); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq)); std::vector tok_logits(n_vocab); std::vector batch_logits(size_t(n_ctx)*n_vocab); @@ -1499,7 +1487,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - batch.clear(); + llama_batch_ext_clear(batch.get()); // batch as much tasks as 
possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1518,11 +1506,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par if (int(batch_indeces.size()) != num_answers) { batch_indeces.resize(num_answers); } - for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s; + for (int s = 0; s < num_answers; ++s) { + batch_indeces[s] = s0 + s; + } for (size_t i = 0; i < cur_task.common_prefix; ++i) { - //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - batch.add_text_multi_seq(cur_task.seq_tokens[0][i], i, batch_indeces, false); + llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false); } llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; @@ -1532,7 +1521,8 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - batch.add_text_multi_seq(cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + llama_seq_id seq_id = { s0 + s }; + llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits); n_logits += needs_logits; } } @@ -1556,7 +1546,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1743,7 +1733,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - common_batch batch(n_batch, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1)); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -1757,9 +1747,10 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - batch.clear(); + llama_batch_ext_clear(batch.get()); for (int i = 0; i < batch_size; i++) { - batch.add_text_multi_seq(tokens[batch_start + i], j*n_batch + i, {0}, true); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); } if (llama_decode_ext(ctx, batch.get())) { From 8a23b4a54a76f4cc179990e29e61753c00da5a5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 20 Mar 2025 16:52:24 +0200 Subject: [PATCH 39/52] server : avoid common_batch ggml-ci --- common/common.h | 64 -------- examples/server/server.cpp | 146 ++++++++++--------- examples/server/tests/unit/test_embedding.py | 18 ++- 3 files changed, 91 insertions(+), 137 deletions(-) diff --git a/common/common.h b/common/common.h index 5fe149ff8c991..197108be0ebba 100644 --- a/common/common.h +++ b/common/common.h @@ -565,70 +565,6 @@ std::pair common_get_hf_file( // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); -// -// Batch utils -// - -// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions -// this is meant to be temporary -struct common_batch { - llama_batch_ext_ptr batch; - struct batch_token { - 
llama_token token; - llama_seq_id seq_id; // only support single seq for now - bool logits; - }; - std::vector tokens; - int n_outputs = 0; - common_batch() = default; - common_batch(int32_t n_tokens, int32_t n_seq_max) { - batch.reset(llama_batch_ext_init(n_tokens, n_seq_max)); - tokens.reserve(n_tokens); - } - void clear() { - llama_batch_ext_clear(batch.get()); - tokens.clear(); - } - void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) { - llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits); - tokens.push_back({token, seq_id, logits}); - if (logits) { - n_outputs++; - } - } - void add_text_multi_seq(llama_token token, llama_pos pos, std::vector seq_ids, bool logits) { - llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits); - tokens.push_back({token, seq_ids[0], logits}); - if (logits) { - n_outputs++; - } - } - void set_logits_last() { - if (!tokens.empty()) { - llama_batch_ext_set_output_last(batch.get()); - tokens.back().logits = true; - } - } - int32_t get_n_tokens() const { - return (int32_t)tokens.size(); - } - llama_batch_ext * get() { - return batch.get(); - } - common_batch get_view(int32_t offset, int32_t n_tokens) { - common_batch view; - view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens)); - view.tokens.reserve(n_tokens); - for (int32_t i = 0; i < n_tokens; i++) { - view.tokens.push_back(tokens[offset + i]); - if (tokens[offset + i].logits) { - view.n_outputs++; - } - } - return view; - } -}; - // // Token utils // diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 80daec9792e79..bcbaa070f905d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1224,7 +1224,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - common_batch batch_spec; + llama_batch_ext_ptr batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1248,7 +1248,7 @@ struct server_slot { int32_t n_past = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; - int32_t i_batch = -1; + int32_t i_batch = -1; // TODO: remove and use only sequence-based sampling int32_t n_predict = -1; // TODO: disambiguate from params.n_predict // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated @@ -1796,7 +1796,7 @@ struct server_context { llama_context_params cparams_dft; - common_batch batch; + llama_batch_ext_ptr batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1922,7 +1922,7 @@ struct server_context { slot.n_predict = params_base.n_predict; if (model_dft) { - slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1); + slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1)); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1958,7 +1958,7 @@ struct server_context { const int32_t n_batch = llama_n_batch(ctx); // only a single seq_id per token is needed - batch = common_batch(std::max(n_batch, params_base.n_parallel), 1); + batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1)); } metrics.init(); @@ -2093,7 +2093,7 @@ struct server_context { } if (slot.ctx_dft) { - slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1); + slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1)); } slot.state = SLOT_STATE_STARTED; @@ -2401,7 +2401,7 @@ struct server_context 
{ queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, common_batch & batch) { + void send_embedding(const server_slot & slot) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2410,34 +2410,40 @@ struct server_context { const int n_embd = llama_model_n_embd(model); - std::vector embd_res(n_embd, 0.0f); + const llama_seq_id seq_id = slot.id; - for (int i = 0; i < batch.get_n_tokens(); ++i) { - auto tok = batch.tokens[i]; - if (!tok.logits || tok.seq_id != slot.id) { - continue; - } + std::vector embd_res(n_embd, 0.0f); - const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + const float * embd = llama_get_embeddings_seq(ctx, seq_id); if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); + SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id); res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; } - // normalize only when there is pooling // TODO: configurable - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); - res->embedding.push_back(embd_res); - } else { - res->embedding.push_back({ embd, embd + n_embd }); - } + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + GGML_ABORT("embeddings without pooling is not supported yet"); + //for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { + // auto tok = batch.tokens[i]; + // if (!tok.logits || tok.seq_id != slot.id) { + // continue; + // } + + // const float * embd = llama_get_embeddings_ith(ctx, tok.seq_id); + // if (embd == NULL) { + // SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); + + // res->embedding.push_back(std::vector(n_embd, 0.0f)); + // continue; + // } + + // res->embedding.push_back({ embd, embd + n_embd }); + //} } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2445,30 +2451,20 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, common_batch & batch) { + void send_rerank(const server_slot & slot) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; res->n_tokens = slot.n_prompt_tokens; - for (int i = 0; i < batch.get_n_tokens(); ++i) { - auto tok = batch.tokens[i]; - if (!tok.logits || tok.seq_id != slot.id) { - continue; - } + const llama_seq_id seq_id = slot.id; - const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); - - res->score = -1e6; - continue; - } + const float * embd = llama_get_embeddings_seq(ctx, seq_id); + if (embd == NULL) { + SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id); + res->score = -1e6; + } else { res->score = embd[0]; } @@ -2854,7 +2850,7 @@ struct server_context { } // start populating the batch for this iteration - batch.clear(); + llama_batch_ext_clear(batch.get()); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2876,9 +2872,9 @@ struct server_context { continue; } - slot.i_batch = batch.get_n_tokens(); + slot.i_batch = 
llama_batch_ext_get_n_tokens(batch.get()); - batch.add_text(slot.sampled, slot.n_past, slot.id, true); + llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, &slot.id, 1, true); slot.n_past += 1; @@ -2895,7 +2891,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.get_n_tokens() == 0) { + if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3061,7 +3057,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3081,11 +3077,12 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); + //batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); + llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, &slot.id, 1, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3095,13 +3092,14 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", + slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.get_n_tokens() > 0); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0); common_sampler_reset(slot.smpl); @@ -3111,27 +3109,28 @@ struct server_context { } // extract the logits only for the last token - batch.set_logits_last(); + //batch.set_logits_last(); + llama_batch_ext_set_output_last(batch.get()); slot.n_decoded = 0; - slot.i_batch = batch.get_n_tokens() - 1; + slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens()); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get())); } } - if (batch.get_n_tokens() >= n_batch) { + if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) { break; } } } - if (batch.get_n_tokens() == 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens()); + SRV_DBG("decoding batch, n_tokens = %d\n", 
llama_batch_ext_get_n_tokens(batch.get())); if (slot_batched) { // make sure we're in the right embedding mode @@ -3141,10 +3140,10 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i); + for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i); - common_batch batch_view = batch.get_view(i, n_tokens); + llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); const int ret = llama_decode_ext(ctx, batch_view.get()); metrics.on_decoded(slots); @@ -3177,14 +3176,14 @@ struct server_context { if (slot.state == SLOT_STATE_DONE_PROMPT) { if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding - send_embedding(slot, batch_view); + send_embedding(slot); slot.release(); slot.i_batch = -1; continue; // continue loop of slots } if (slot.task_type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); + send_rerank(slot); slot.release(); slot.i_batch = -1; continue; // continue loop of slots @@ -3281,14 +3280,17 @@ struct server_context { } // construct the speculation batch - slot.batch_spec.clear(); - slot.batch_spec.add_text(id, slot.n_past, slot.id, true); + //slot.batch_spec.clear(); + //slot.batch_spec.add_text(id, slot.n_past, slot.id, true); + llama_batch_ext_clear(slot.batch_spec.get()); + llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true); for (size_t i = 0; i < draft.size(); ++i) { - slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); + //slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); + llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens()); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); llama_decode_ext(ctx, slot.batch_spec.get()); @@ -4147,6 +4149,11 @@ int main(int argc, char ** argv) { return; } + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' is not yet supported. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // for the shape of input/content, see tokenize_input_prompts() json prompt; if (body.count("input") != 0) { @@ -4241,6 +4248,11 @@ int main(int argc, char ** argv) { return; } + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' cannot be used with reranking. 
Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + const json body = json::parse(req.body); // TODO: implement diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index 8b0eb42b0926f..889a759aea934 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -88,13 +88,19 @@ def test_embedding_pooling_none(): res = server.make_request("POST", "/embeddings", data={ "input": "hello hello hello", }) - assert res.status_code == 200 - assert 'embedding' in res.body[0] - assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special - # make sure embedding vector is not normalized - for x in res.body[0]['embedding']: - assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + # /embeddings does not support pooling type 'none' + assert res.status_code == 400 + assert "error" in res.body + + # TODO: re-enable when we figure out how to support pooling type 'none' + #assert res.status_code == 200 + #assert 'embedding' in res.body[0] + #assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + + ## make sure embedding vector is not normalized + #for x in res.body[0]['embedding']: + # assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON def test_embedding_pooling_none_oai(): From b8b173274d50c9fe12d94c3308fa47dccbf96580 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 20 Mar 2025 18:19:55 +0200 Subject: [PATCH 40/52] server : remove old commented code [no ci] --- examples/server/server.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index bcbaa070f905d..b99059511e7e7 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3081,7 +3081,6 @@ struct server_context { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - //batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, &slot.id, 1, need_embd); if (slot.params.cache_prompt) { @@ -3109,7 +3108,6 @@ struct server_context { } // extract the logits only for the last token - //batch.set_logits_last(); llama_batch_ext_set_output_last(batch.get()); slot.n_decoded = 0; @@ -3280,13 +3278,10 @@ struct server_context { } // construct the speculation batch - //slot.batch_spec.clear(); - //slot.batch_spec.add_text(id, slot.n_past, slot.id, true); llama_batch_ext_clear(slot.batch_spec.get()); llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true); for (size_t i = 0; i < draft.size(); ++i) { - //slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true); } From 30f1db9936112bc04917df1a9520421ae5f3ffc1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 20 Mar 2025 22:27:36 +0100 Subject: [PATCH 41/52] remove C API llama_batch_ext_init_from_text --- common/common.cpp | 4 +-- examples/lookahead/lookahead.cpp | 4 +-- examples/lookup/lookup.cpp | 4 +-- examples/run/run.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 32 ++++++++++---------- examples/simple-chat/simple-chat.cpp | 15 ++++----- examples/simple/simple.cpp | 14 ++++----- examples/speculative/speculative.cpp | 6 ++-- include/llama-cpp.h | 32 
+++++++++++++------- include/llama.h | 24 ++++----------- src/llama-batch.cpp | 26 +++------------- 11 files changed, 73 insertions(+), 90 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f8498f01d6f71..e3b9261fc0f8c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1016,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) { } if (llama_model_has_encoder(model)) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), tmp.size(), 0, 0, true)); + auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true); llama_encode_ext(lctx, batch.get()); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { @@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - llama_batch_ext_ptr batch(llama_batch_ext_init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true)); + auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true); llama_decode_ext(lctx, batch.get()); } llama_kv_self_clear(lctx); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 8277559689074..985b89e448f95 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -92,8 +92,8 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); + auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 07e57afcbab9f..232f7816ee217 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -91,8 +91,8 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); + auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true); llama_decode_ext(ctx, batch0.get()); llama_decode_ext(ctx, batch1.get()); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index a61982dcf64e7..9396687d6e1ea 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -1017,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch.reset(llama_batch_ext_init_from_text(&new_token_id, 1, llama_data.n_past, 0, true)); + batch = llama_batch_ext_ptr::init_from_text(&new_token_id, 1, llama_data.n_past, 0, true); } printf(LOG_COL_DEFAULT); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 2ff4e24c19c1e..ad818917e1b8e 100644 --- a/examples/save-load-state/save-load-state.cpp +++ 
b/examples/save-load-state/save-load-state.cpp @@ -48,11 +48,11 @@ int main(int argc, char ** argv) { auto tokens = common_tokenize(ctx, params.prompt, true); // prepare the batch - llama_batch_ext * batch = llama_batch_ext_init_from_text(tokens.data(), tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true); // evaluate prompt - llama_decode_ext(ctx, batch); - n_past += llama_batch_ext_get_n_tokens(batch); + llama_decode_ext(ctx, batch.get()); + n_past += llama_batch_ext_get_n_tokens(batch.get()); // save state (rng, logits, embedding and kv_cache) to file { @@ -79,13 +79,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); - if (llama_decode_ext(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_ext_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -132,13 +132,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); - if (llama_decode_ext(ctx2, batch)) { + if (llama_decode_ext(ctx2, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_ext_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -214,13 +214,13 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); llama_seq_id seq_id = 1; // seq 1 instead of 0 - llama_batch_ext_add_text(batch, next_token, 0, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); - if (llama_decode_ext(ctx3, batch)) { + if (llama_decode_ext(ctx3, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_ext_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -232,7 +232,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); - llama_batch_ext_free(batch); + llama_batch_ext_free(batch.get()); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index dbde1ee9e88d6..4ad8a4ecd5f97 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include "llama-cpp.h" #include #include #include @@ -109,21 +110,21 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt llama_pos n_past = 0; - llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); - n_past += llama_batch_ext_get_n_tokens(batch); + auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); + n_past += llama_batch_ext_get_n_tokens(batch.get()); llama_token new_token_id; while (true) { // check if we have enough space in 
the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx); - if (n_ctx_used + llama_batch_ext_get_n_tokens(batch) > n_ctx) { + if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } - if (llama_decode_ext(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { GGML_ABORT("failed to decode\n"); } @@ -147,13 +148,13 @@ int main(int argc, char ** argv) { response += piece; // prepare the next batch with the sampled token - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, new_token_id, n_past, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), new_token_id, n_past, &seq_id, 1, true); n_past++; } - llama_batch_ext_free(batch); + llama_batch_ext_free(batch.get()); return response; }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 26009a5aec398..63d7703e04f6c 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include "llama-cpp.h" #include #include #include @@ -143,7 +144,7 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt - llama_batch_ext * batch = llama_batch_ext_init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true); // main loop @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { int n_decode = 0; llama_token new_token_id; - for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch) < n_prompt + n_predict; ) { + for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch.get()) < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model - if (llama_decode_ext(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } - n_pos += llama_batch_ext_get_n_tokens(batch); + n_pos += llama_batch_ext_get_n_tokens(batch.get()); // sample the next token { @@ -180,9 +181,9 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, new_token_id, n_pos, &seq_id, 1, true); + llama_batch_ext_add_text(batch.get(), new_token_id, n_pos, &seq_id, 1, true); n_decode += 1; } @@ -200,7 +201,6 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); fprintf(stderr, "\n"); - llama_batch_ext_free(batch); llama_sampler_free(smpl); llama_free(ctx); llama_model_free(model); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index ff5eceb643208..561956bf669cd 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -165,9 +165,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_batch_ext_ptr batch0(llama_batch_ext_init_from_text( inp.data(), n_input - 1, 0, 0, true)); - llama_batch_ext_ptr batch1(llama_batch_ext_init_from_text(&inp.back(), 1, n_input - 1, 0, true)); - llama_batch_ext_ptr batch2(llama_batch_ext_init_from_text( inp.data(), n_input , 0, 0, true)); + auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, 
n_input - 1, 0, true); + auto batch2 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input , 0, 0, true); llama_decode_ext(ctx_tgt, batch0.get()); llama_decode_ext(ctx_tgt, batch1.get()); llama_decode_ext(ctx_dft, batch2.get()); diff --git a/include/llama-cpp.h b/include/llama-cpp.h index fee15ef9c2bae..a415fc0310a02 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -37,21 +37,31 @@ struct llama_batch_ext_ptr : std::unique_ptr() {} llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr(batch) {} - // convenience function to create a batch from text tokens, without worrying about manually freeing it + // Convenience C++ wrapper to create a batch from text tokens, without worrying about manually freeing it + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // If output_last is true, the last token will have output set static llama_batch_ext_ptr init_from_text(llama_token * tokens, - int32_t n_tokens, - int32_t pos0, - int32_t seq_id, - bool output_last) { - return llama_batch_ext_ptr(llama_batch_ext_init_from_text(tokens, n_tokens, pos0, seq_id, output_last)); + int32_t n_tokens, + llama_pos pos0, + llama_seq_id seq_id, + bool output_last) { + llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); + for (int32_t i = 0; i < n_tokens; i++) { + llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false); + } + if (output_last) { + llama_batch_ext_set_output_last(batch); + } + return llama_batch_ext_ptr(batch); } - // convenience function to create a batch from text embeddings, without worrying about manually freeing it + // Convenience C++ wrapper to create a batch from text embeddings, without worrying about manually freeing it static llama_batch_ext_ptr init_from_embd(float * embd, - size_t n_tokens, - size_t n_embd, - int32_t pos0, - int32_t seq_id) { + size_t n_tokens, + size_t n_embd, + llama_pos pos0, + llama_seq_id seq_id) { return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(embd, n_tokens, n_embd, pos0, seq_id)); } }; diff --git a/include/llama.h b/include/llama.h index d6aeb510011ee..a63a397193016 100644 --- a/include/llama.h +++ b/include/llama.h @@ -900,7 +900,7 @@ extern "C" { // DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens), "use llama_batch_ext_init_from_text instead"); + int32_t n_tokens), "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -925,18 +925,6 @@ extern "C" { int32_t n_tokens, int32_t n_seq_max); - // Same with llama_batch_init, but initializes the batch with the provided text tokens - // First token will be at position pos0 - // The sequence ID will be fixed to seq_id - // If output_last is true, the last token will have output set - // The batch has to be freed with llama_batch_ext_free() - LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_text( - llama_token * tokens, - int32_t n_tokens, - int32_t pos0, - int32_t seq_id, - bool output_last); - // Same with llama_batch_init, but initializes the batch with the provided raw embeddings // Size of embd should be n_tokens * n_embd // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd() @@ -944,11 +932,11 @@ extern "C" { // The sequence ID will be fixed to seq_id // The batch has to be freed with llama_batch_ext_free() LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( - float * embd, - size_t 
n_tokens, - size_t n_embd, - int32_t pos0, - int32_t seq_id); + const float * embd, + size_t n_tokens, + size_t n_embd, + llama_pos pos0, + llama_seq_id seq_id); // Set arbitrary token to the embeddings batch // Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd() diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 0455db9d0617d..375724a4784fc 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -337,22 +337,6 @@ struct llama_batch llama_batch_get_one( }; } -struct llama_batch_ext * llama_batch_ext_init_from_text( - llama_token * tokens, - int32_t n_tokens, - int32_t pos0, - int32_t seq_id, - bool output_last) { - llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1); - for (int32_t i = 0; i < n_tokens; i++) { - llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false); - } - if (output_last) { - llama_batch_ext_set_output_last(batch); - } - return batch; -} - static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) { llama_batch_ext * batch = new llama_batch_ext{ /*n_tokens =*/ 0, @@ -390,11 +374,11 @@ struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_ } struct llama_batch_ext * llama_batch_ext_init_from_embd( - float * embd, - size_t n_tokens, - size_t n_embd, - int32_t pos0, - int32_t seq_id) { + const float * embd, + size_t n_tokens, + size_t n_embd, + llama_pos pos0, + llama_seq_id seq_id) { struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1); memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float)); for (size_t i = 0; i < n_tokens; i++) { From 2134cabf50ab327cbb9c18150cc4fbe5d2bc8d36 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 21 Mar 2025 11:48:29 +0100 Subject: [PATCH 42/52] add cpp batch.add_text wrapper --- examples/perplexity/perplexity.cpp | 22 +++++++++------------- include/llama-cpp.h | 27 +++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 15265aa9e86ae..3c8bb0cc9d59b 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -371,8 +371,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params llama_batch_ext_clear(batch.get()); for (int i = 0; i < batch_size; i++) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); @@ -568,7 +567,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & for (int k = 0; k < batch_size; ++k) { const llama_pos pos = j*n_batch + k; bool output = pos >= first; - llama_batch_ext_add_text(batch.get(), tokens[seq_start + k], pos, &seq, 1, output); + batch.add_text(tokens[seq_start + k], pos, seq, output); n_outputs += output ? 
1 : 0; } @@ -864,7 +863,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) { std::vector seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }; - llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false); + batch.add_text(hs_cur.seq_tokens[0][i], i, seq_ids, false); } llama_batch_ext_set_output_last(batch.get()); n_logits += 1; @@ -875,7 +874,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; llama_seq_id seq_id = s0 + s; - llama_batch_ext_add_text(batch.get(), hs_cur.seq_tokens[s][i], i, &seq_id, 1, needs_logits); + batch.add_text(hs_cur.seq_tokens[s][i], i, seq_id, needs_logits); n_logits += needs_logits; } } @@ -1143,7 +1142,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (size_t i = 0; i < data[i1].common_prefix; ++i) { std::vector seq_ids{ s0 + 0, s0 + 1 }; - llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[0][i], i, seq_ids.data(), seq_ids.size(), false); + batch.add_text(data[i1].seq_tokens[0][i], i, seq_ids, false); } llama_batch_ext_set_output_last(batch.get()); n_logits += 1; @@ -1151,8 +1150,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - llama_seq_id seq_id = s0 + s; - llama_batch_ext_add_text(batch.get(), data[i1].seq_tokens[s][i], i, &seq_id, 1, true); + batch.add_text(data[i1].seq_tokens[s][i], i, s0 + s, true); n_logits += 1; } } @@ -1511,7 +1509,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par } for (size_t i = 0; i < cur_task.common_prefix; ++i) { - llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[0][i], i, batch_indeces.data(), batch_indeces.size(), false); + batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); } llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; @@ -1521,8 +1519,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - llama_seq_id seq_id = { s0 + s }; - llama_batch_ext_add_text(batch.get(), cur_task.seq_tokens[s][i], i, &seq_id, 1, needs_logits); + batch.add_text(cur_task.seq_tokens[s][i], i, s0 + s, needs_logits); n_logits += needs_logits; } } @@ -1749,8 +1746,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { llama_batch_ext_clear(batch.get()); for (int i = 0; i < batch_size; i++) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } if (llama_decode_ext(ctx, batch.get())) { diff --git a/include/llama-cpp.h b/include/llama-cpp.h index a415fc0310a02..940b6127b8666 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -64,4 +64,31 @@ struct llama_batch_ext_ptr : std::unique_ptr & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) { + size_t n_tokens = tokens.size(); + for (size_t i 
= 0; i < n_tokens; i++) { + llama_batch_ext_add_text(this->get(), tokens[i], i + pos0, &seq_id, 1, false); + } + if (output_last) { + llama_batch_ext_set_output_last(this->get()); + } + } + + // Wrapper to add a single token to the batch, support multiple sequence IDs + void add_text(llama_token token, llama_pos pos, std::vector & seq_id, bool output_last) { + llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false); + if (output_last) { + llama_batch_ext_set_output_last(this->get()); + } + } + + // Wrapper to add a single token to the batch (single sequence ID) + void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) { + llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false); + if (output_last) { + llama_batch_ext_set_output_last(this->get()); + } + } }; From 2cec1cff7412999c1002949fc53edc2204361575 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 21 Mar 2025 17:23:16 +0100 Subject: [PATCH 43/52] move various places to batch.add_text --- common/speculative.cpp | 8 ++--- examples/gritlm/gritlm.cpp | 25 ++++++------- examples/llava/gemma3-cli.cpp | 6 ++-- examples/llava/qwen2vl-cli.cpp | 3 +- examples/lookup/lookup.cpp | 14 ++++---- examples/parallel/parallel.cpp | 34 +++++++++--------- examples/passkey/passkey.cpp | 6 ++-- examples/save-load-state/save-load-state.cpp | 9 ++--- examples/server/server.cpp | 8 ++--- examples/simple-chat/simple-chat.cpp | 3 +- examples/simple/simple.cpp | 3 +- .../speculative-simple/speculative-simple.cpp | 13 ++++--- examples/speculative/speculative.cpp | 29 +++++++-------- examples/tts/tts.cpp | 36 +++++++++---------- include/llama-cpp.h | 22 ++++-------- 15 files changed, 91 insertions(+), 128 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 464d84ad7aa9c..2b98a978db7c8 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -149,8 +149,6 @@ llama_tokens common_speculative_gen_draft( const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); - const llama_seq_id seq_id = 0; - // reuse as much as possible from the old draft context // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt for (int i = 0; i < (int) prompt.size(); ++i) { @@ -210,7 +208,7 @@ llama_tokens common_speculative_gen_draft( for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - llama_batch_ext_add_text(batch.get(), prompt_tgt[i], i - i_start, &seq_id, 1, false); + batch.add_text(prompt_tgt[i], i - i_start, 0, false); prompt.push_back(prompt_tgt[i]); } @@ -227,7 +225,7 @@ llama_tokens common_speculative_gen_draft( LOG_DBG("%s: n_past = %d\n", __func__, n_past); llama_batch_ext_clear(batch.get()); - llama_batch_ext_add_text(batch.get(), id_last, n_past, &seq_id, 1, true); + batch.add_text(id_last, n_past, 0, true); prompt.push_back(id_last); @@ -266,7 +264,7 @@ llama_tokens common_speculative_gen_draft( break; } - llama_batch_ext_add_text(batch.get(), id, n_past + i + 1, &seq_id, 1, true); + batch.add_text( id, n_past + i + 1, 0, true); // evaluate the drafted tokens on the draft model llama_decode_ext(ctx, batch.get()); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index a28effa44fd89..9a3ea912a1122 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -1,6 +1,7 @@ #include "arg.h" #include "common.h" #include 
"llama.h" +#include "llama-cpp.h" #include #include @@ -13,10 +14,10 @@ static std::vector> encode(llama_context * ctx, const std::ve const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); - llama_batch_ext * batch = llama_batch_ext_init(llama_n_batch(ctx), 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1)); for (uint64_t i = 0; i < sentences.size(); i++) { - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); const std::string input_string = instruction + sentences[i]; @@ -41,8 +42,7 @@ static std::vector> encode(llama_context * ctx, const std::ve // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { - const llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, inputs[j], j, &seq_id, 1 , j >= n_inst); + batch.add_text(inputs[j], j, 0, j >= n_inst); } // clear previous kv_cache values (irrelevant for embeddings) @@ -51,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_set_causal_attn(ctx, false); // run model - llama_decode_ext(ctx, batch); + llama_decode_ext(ctx, batch.get()); // get embedding dimensions uint64_t n_embd = llama_model_n_embd(model); @@ -90,8 +90,6 @@ static std::vector> encode(llama_context * ctx, const std::ve #endif } - llama_batch_ext_free(batch); - return result; } @@ -107,26 +105,25 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); - llama_batch_ext * bat = llama_batch_ext_init(llama_n_batch(ctx), 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1)); std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; while (true) { - llama_batch_ext_clear(bat); + llama_batch_ext_clear(batch.get()); { const int32_t n_inputs = inputs.size(); for (int32_t i = 0; i < n_inputs; i++) { - const llama_seq_id seq_id = 0; - llama_batch_ext_add_text(bat, inputs[i], i_current_token++, &seq_id, 1, i == n_inputs - 1); + batch.add_text(inputs[i], i_current_token++, 0, i == n_inputs - 1); } } inputs.clear(); - llama_decode_ext(ctx, bat); + llama_decode_ext(ctx, batch.get()); - llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(bat) - 1); + llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); if (token == eos_token) { break; @@ -147,8 +144,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::printf("\n"); } - llama_batch_ext_free(bat); - return result; } diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 3efa604b935b6..28d5a13a8abc1 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -92,8 +92,7 @@ static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); llama_batch_ext_clear(ctx.batch.get()); for (llama_token & t : tokens) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(ctx.batch.get(), t, ctx.n_past++, &seq_id, 1, false); + ctx.batch.add_text(t, ctx.n_past++, 0, false); } if (logits_last) { llama_batch_ext_set_output_last(ctx.batch.get()); @@ -180,8 +179,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ // eval the token llama_batch_ext_clear(ctx.batch.get()); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(ctx.batch.get(), token_id, ctx.n_past++, &seq_id, 1, 
true); + ctx.batch.add_text(token_id, ctx.n_past++, 0, true); if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("failed to decode token\n"); return 1; diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index d4fcabb1081e9..490ee7fdc6ee1 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -101,8 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector #include @@ -110,7 +111,7 @@ int main(int argc, char ** argv){ std::vector draft; - llama_batch_ext * batch_tgt = llama_batch_ext_init(params.n_ctx, 1); + llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(params.n_ctx, 1)); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); @@ -196,9 +197,8 @@ int main(int argc, char ** argv){ // clean the cache of draft tokens that weren't accepted llama_kv_self_seq_rm(ctx, 0, n_past, -1); - const llama_seq_id seq_id = 0; - llama_batch_ext_clear(batch_tgt); - llama_batch_ext_add_text(batch_tgt, draft[0], n_past, &seq_id, 1, true); + llama_batch_ext_clear(batch_tgt.get()); + batch_tgt.add_text(draft[0], n_past, 0, true); // Draft already contains a single token sampled from the model: GGML_ASSERT(draft.size() == 1); @@ -208,13 +208,13 @@ int main(int argc, char ** argv){ common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); for (size_t i = 1; i < draft.size(); ++i) { - llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); + batch_tgt.add_text(draft[i], n_past + i, 0, true); } t_draft_us += ggml_time_us() - t_start_draft_us; n_drafted += draft.size() - 1; - llama_decode_ext(ctx, batch_tgt); + llama_decode_ext(ctx, batch_tgt.get()); ++n_past; draft.erase(draft.begin()); @@ -246,8 +246,6 @@ int main(int argc, char ** argv){ common_sampler_free(smpl); - llama_batch_ext_free(batch_tgt); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 1d5f59f7d2124..c9e80daea68f1 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -6,6 +6,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -174,7 +175,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. 
regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time - llama_batch_ext * batch = llama_batch_ext_init(n_ctx, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1)); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -192,11 +193,10 @@ int main(int argc, char ** argv) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, tokens_system[i], i, &seq_id, 1, false); + batch.add_text(tokens_system[i], i, 0, false); } - if (llama_decode_ext(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -217,7 +217,7 @@ int main(int argc, char ** argv) { common_kv_cache_dump_view_seqs(kvc_view, 40); } - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); // decode any currently ongoing sequences for (auto & client : clients) { @@ -225,15 +225,15 @@ int main(int argc, char ** argv) { continue; } - client.i_batch = llama_batch_ext_get_n_tokens(batch); + client.i_batch = llama_batch_ext_get_n_tokens(batch.get()); llama_seq_id seq_id = client.id + 1; - llama_batch_ext_add_text(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, &seq_id, 1, true); + batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true); client.n_decoded += 1; } - if (llama_batch_ext_get_n_tokens(batch) == 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { llama_kv_self_seq_rm(ctx, i, -1, -1); @@ -245,7 +245,7 @@ int main(int argc, char ** argv) { } // insert new sequences for decoding - if (cont_batching || llama_batch_ext_get_n_tokens(batch) == 0) { + if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -265,17 +265,17 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < tokens_prompt.size(); ++i) { llama_seq_id seq_id = client.id + 1; - llama_batch_ext_add_text(batch, tokens_prompt[i], i + n_tokens_system, &seq_id, 1, false); + batch.add_text(tokens_prompt[i], i + n_tokens_system, seq_id, false); } // extract the logits only for the last token - if (llama_batch_ext_get_n_tokens(batch) > 0) { - llama_batch_ext_set_output_last(batch); + if (llama_batch_ext_get_n_tokens(batch.get()) > 0) { + llama_batch_ext_set_output_last(batch.get()); } client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = llama_batch_ext_get_n_tokens(batch) - 1; + client.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); @@ -289,14 +289,14 @@ int main(int argc, char ** argv) { } } - if (llama_batch_ext_get_n_tokens(batch) == 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch); + int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get()); for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { @@ -307,7 +307,7 @@ int main(int argc, char ** argv) { const int32_t n_tokens = std::min(n_batch, 
(int32_t) (n_tokens_in_batch - i)); - llama_batch_ext * batch_view = llama_batch_ext_get_view(batch, i, n_tokens); + llama_batch_ext * batch_view = llama_batch_ext_get_view(batch.get(), i, n_tokens); const int ret = llama_decode_ext(ctx, batch_view); llama_batch_ext_free(batch_view); if (ret != 0) { @@ -413,8 +413,6 @@ int main(int argc, char ** argv) { // TODO: print sampling/grammar timings for all clients llama_perf_context_print(ctx); - llama_batch_ext_free(batch); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 88e6ccdde6424..42674525064fd 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -144,8 +144,7 @@ int main(int argc, char ** argv) { llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); + batch.add_text(tokens_list[i + j], n_past++, 0, false); } if (i + n_batch >= n_tokens_all) { @@ -179,8 +178,7 @@ int main(int argc, char ** argv) { llama_batch_ext_clear(batch.get()); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), tokens_list[i + j], n_past++, &seq_id, 1, false); + batch.add_text(tokens_list[i + j], n_past++, 0, false); } if (i + n_batch >= n_tokens_all) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ad818917e1b8e..2f7804ade596c 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -80,8 +80,7 @@ int main(int argc, char ** argv) { result0 += next_token_str; llama_batch_ext_clear(batch.get()); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); + batch.add_text(next_token, 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); @@ -133,8 +132,7 @@ int main(int argc, char ** argv) { result1 += next_token_str; llama_batch_ext_clear(batch.get()); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); + batch.add_text(next_token, 0, 0, true); if (llama_decode_ext(ctx2, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); @@ -215,8 +213,7 @@ int main(int argc, char ** argv) { result2 += next_token_str; llama_batch_ext_clear(batch.get()); - llama_seq_id seq_id = 1; // seq 1 instead of 0 - llama_batch_ext_add_text(batch.get(), next_token, 0, &seq_id, 1, true); + batch.add_text(next_token, 0, 1, true); if (llama_decode_ext(ctx3, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b99059511e7e7..2dbdb2079b147 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2874,7 +2874,7 @@ struct server_context { slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()); - llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, &slot.id, 1, true); + batch.add_text(slot.sampled, slot.n_past, slot.id, true); slot.n_past += 1; @@ -3081,7 +3081,7 @@ struct server_context { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], 
slot.n_past, &slot.id, 1, need_embd); + batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3279,10 +3279,10 @@ struct server_context { // construct the speculation batch llama_batch_ext_clear(slot.batch_spec.get()); - llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true); + slot.batch_spec.add_text(id, slot.n_past, slot.id, true); for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true); + slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); } SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 4ad8a4ecd5f97..51c5c837dc624 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -149,8 +149,7 @@ int main(int argc, char ** argv) { // prepare the next batch with the sampled token llama_batch_ext_clear(batch.get()); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), new_token_id, n_past, &seq_id, 1, true); + batch.add_text(new_token_id, n_past, 0, true); n_past++; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 63d7703e04f6c..bcd119ad00247 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -182,8 +182,7 @@ int main(int argc, char ** argv) { // prepare the next batch with the sampled token llama_batch_ext_clear(batch.get()); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch.get(), new_token_id, n_pos, &seq_id, 1, true); + batch.add_text(new_token_id, n_pos, 0, true); n_decode += 1; } diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 74abd98d75e68..d1c4e4a5d7b45 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -4,6 +4,7 @@ #include "speculative.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -133,7 +134,7 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(ctx_dft); - llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), 1); + llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(llama_n_batch(ctx_tgt), 1)); const auto t_enc_end = ggml_time_us(); @@ -152,9 +153,8 @@ int main(int argc, char ** argv) { //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); // always have a token to evaluate from before - id_last - llama_batch_ext_clear(batch_tgt); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_tgt, id_last, n_past++, &seq_id, 1, true); + llama_batch_ext_clear(batch_tgt.get()); + batch_tgt.add_text(id_last, n_past++, 0, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { @@ -164,12 +164,12 @@ int main(int argc, char ** argv) { } for (size_t i = 0; i < draft.size(); ++i) { - llama_batch_ext_add_text(batch_tgt, draft[i], n_past + i, &seq_id, 1, true); + batch_tgt.add_text(draft[i], n_past + i, 0, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); - llama_decode_ext(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt.get()); } // sample from the full target batch and return the accepted tokens based on the target sampler @@ -255,7 +255,6 @@ int 
main(int argc, char ** argv) { common_sampler_free(smpl); common_speculative_free(spec); - llama_batch_ext_free(batch_tgt); llama_backend_free(); LOG("\n\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 9915045564e3f..8f4500e8ca68d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -3,6 +3,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -201,8 +202,8 @@ int main(int argc, char ** argv) { drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } - llama_batch_ext * batch_dft = llama_batch_ext_init(llama_n_batch(ctx_dft), 1); - llama_batch_ext * batch_tgt = llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft); + llama_batch_ext_ptr batch_dft(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)); + llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft)); const auto t_dec_start = ggml_time_us(); @@ -443,13 +444,12 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - llama_batch_ext_clear(batch_dft); - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch_dft, token_id, n_past_dft, &seq_id, 1, true); + llama_batch_ext_clear(batch_dft.get()); + batch_dft.add_text(token_id, n_past_dft, 0, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode_ext(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft.get()); ++n_past_dft; } @@ -486,7 +486,7 @@ int main(int argc, char ** argv) { // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - llama_batch_ext_clear(batch_dft); + llama_batch_ext_clear(batch_dft.get()); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].skip = false; @@ -567,7 +567,7 @@ int main(int argc, char ** argv) { batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }}); // add the token to the batch for batched decoding with the draft model - drafts[s].i_batch_dft = llama_batch_ext_add_text(batch_dft, id, n_past_cur, &s, 1, true); + drafts[s].i_batch_dft = batch_dft.add_text(id, n_past_cur, s, true); if (batch_tgt_data.size() > (size_t) n_draft) { drafts[s].drafting = false; @@ -576,12 +576,12 @@ int main(int argc, char ** argv) { } // no sequence is drafting anymore - if (llama_batch_ext_get_n_tokens(batch_dft) == 0) { + if (llama_batch_ext_get_n_tokens(batch_dft.get()) == 0) { break; } // evaluate the drafted tokens on the draft model - llama_decode_ext(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft.get()); ++n_past_cur; ++n_drafted; @@ -597,15 +597,15 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); } - llama_batch_ext_clear(batch_tgt); + llama_batch_ext_clear(batch_tgt.get()); for (int i = 0; i < (int) batch_tgt_data.size(); ++i) { const auto & data = batch_tgt_data[i]; - llama_batch_ext_add_text(batch_tgt, data.id, data.pos, data.seq_id.data(), data.seq_id.size(), true); + batch_tgt.add_text(data.id, data.pos, data.seq_id, true); } // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); - llama_decode_ext(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt.get()); ++n_past_tgt; } @@ -648,9 +648,6 @@ int main(int argc, char ** argv) { common_sampler_free(drafts[s].smpl); } - llama_batch_ext_free(batch_dft); - llama_batch_ext_free(batch_tgt); - llama_backend_free(); LOG("\n\n"); diff 
--git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index bfa4b6d574945..45107e2d82cfe 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -826,7 +826,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // create a llama_batch // we use this object to submit token data for decoding - llama_batch_ext * batch = llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel); + llama_batch_ext_ptr batch(llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel)); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -835,14 +835,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // evaluate the initial prompt for (size_t i = 0; i < prompt_inp.size(); ++i) { - llama_batch_ext_add_text(batch, prompt_inp[i], i, seq_ids.data(), seq_ids.size(), false); + batch.add_text(prompt_inp[i], i, seq_ids, false); } - GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) prompt_inp.size()); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - llama_batch_ext_set_output_last(batch); + llama_batch_ext_set_output_last(batch.get()); - if (llama_decode_ext(ctx_ttc, batch) != 0) { + if (llama_decode_ext(ctx_ttc, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -861,16 +861,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); + std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch.get()) - 1); - int n_past = llama_batch_ext_get_n_tokens(batch); + int n_past = llama_batch_ext_get_n_tokens(batch.get()); int n_decode = 0; bool next_token_uses_guide_token = true; while (n_decode <= n_predict) { // prepare the next batch - llama_batch_ext_clear(batch); + llama_batch_ext_clear(batch.get()); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -926,14 +926,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 //LOG_CNT("%d", i); } - i_batch[i] = llama_batch_ext_get_n_tokens(batch); + i_batch[i] = llama_batch_ext_get_n_tokens(batch.get()); // push this new token for next evaluation - llama_batch_ext_add_text(batch, new_token_id, n_past, &i, 1, true); + batch.add_text(new_token_id, n_past, i, true); } // all streams are finished - if (llama_batch_ext_get_n_tokens(batch) == 0) { + if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { break; } @@ -941,14 +941,12 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 n_past += 1; // evaluate the current batch with the transformer model - if (llama_decode_ext(ctx_ttc, batch)) { + if (llama_decode_ext(ctx_ttc, batch.get())) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - llama_batch_ext_free(batch); - LOG("\n"); LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); } @@ -1016,15 +1014,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const int n_codes = codes.size(); - llama_batch_ext * batch = llama_batch_ext_init(n_codes, 1); + llama_batch_ext_ptr batch(llama_batch_ext_init(n_codes, 1)); for (size_t i = 0; i < 
codes.size(); ++i) { - llama_seq_id seq_id = 0; - llama_batch_ext_add_text(batch, codes[i], i, &seq_id, 1, true); // TODO: all logits? + batch.add_text(codes[i], i, 0, true); // TODO: all logits? } - GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == n_codes); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) == n_codes); - if (llama_decode_ext(ctx_cts, batch) != 0) { + if (llama_decode_ext(ctx_cts, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -1088,7 +1085,6 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 retval = ENOENT; } - llama_batch_ext_free(batch); llama_backend_free(); return retval; diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 940b6127b8666..1c4f4e859b77c 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -5,6 +5,7 @@ #endif #include <memory> +#include <vector> #include "llama.h" @@ -65,30 +66,21 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> - void add_text(std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) { - size_t n_tokens = tokens.size(); - for (size_t i = 0; i < n_tokens; i++) { - llama_batch_ext_add_text(this->get(), tokens[i], i + pos0, &seq_id, 1, false); - } - if (output_last) { - llama_batch_ext_set_output_last(this->get()); - } - } - // Wrapper to add a single token to the batch, support multiple sequence IDs - void add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> & seq_id, bool output_last) { - llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false); + int32_t add_text(llama_token token, llama_pos pos, const std::vector<llama_seq_id> & seq_id, bool output_last) { + int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false); if (output_last) { llama_batch_ext_set_output_last(this->get()); } + return output_id; } // Wrapper to add a single token to the batch (single sequence ID) - void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) { - llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false); + int32_t add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) { + int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false); if (output_last) { llama_batch_ext_set_output_last(this->get()); } + return output_id; } }; From 3802ff2a9b8df77d28797d53eb3d9dca24f960d1 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 21 Mar 2025 17:33:48 +0100 Subject: [PATCH 44/52] add batch.clear() and batch.n_tokens() --- common/speculative.cpp | 8 ++--- examples/gritlm/gritlm.cpp | 6 ++-- examples/llava/gemma3-cli.cpp | 4 +-- examples/lookup/lookup.cpp | 2 +- examples/parallel/parallel.cpp | 16 ++++----- examples/passkey/passkey.cpp | 8 ++--- examples/perplexity/perplexity.cpp | 12 +++---- examples/run/run.cpp | 4 +-- examples/save-load-state/save-load-state.cpp | 8 ++--- examples/server/server.cpp | 34 +++++++++---------- examples/simple-chat/simple-chat.cpp | 6 ++-- examples/simple/simple.cpp | 6 ++-- .../speculative-simple/speculative-simple.cpp | 2 +- examples/speculative/speculative.cpp | 8 ++--- examples/tts/tts.cpp | 14 ++++---- include/llama-cpp.h | 8 +++++ 16 files changed, 77 insertions(+), 69 deletions(-) diff --git a/common/speculative.cpp b/common/speculative.cpp index 2b98a978db7c8..a798fcb67f2a7 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -204,7 +204,7 @@ llama_tokens common_speculative_gen_draft( } // prepare a batch to evaluate any new tokens in the prompt -
llama_batch_ext_clear(batch.get()); + batch.clear(); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); @@ -214,7 +214,7 @@ llama_tokens common_speculative_gen_draft( } // we should rarely end-up here during normal decoding - if (llama_batch_ext_get_n_tokens(batch.get()) > 0) { + if (batch.n_tokens() > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); llama_decode_ext(ctx, batch.get()); @@ -224,7 +224,7 @@ llama_tokens common_speculative_gen_draft( LOG_DBG("%s: n_past = %d\n", __func__, n_past); - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(id_last, n_past, 0, true); prompt.push_back(id_last); @@ -237,7 +237,7 @@ llama_tokens common_speculative_gen_draft( // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { - llama_batch_ext_clear(batch.get()); + batch.clear(); common_sampler_sample(smpl, ctx, 0, true); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 9a3ea912a1122..deda96099b212 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -17,7 +17,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1)); for (uint64_t i = 0; i < sentences.size(); i++) { - llama_batch_ext_clear(batch.get()); + batch.clear(); const std::string input_string = instruction + sentences[i]; @@ -111,7 +111,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std int32_t i_current_token = 0; while (true) { - llama_batch_ext_clear(batch.get()); + batch.clear(); { const int32_t n_inputs = inputs.size(); @@ -123,7 +123,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_decode_ext(ctx, batch.get()); - llama_token token = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); + llama_token token = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1); if (token == eos_token) { break; diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 28d5a13a8abc1..2ae5e665e942b 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -90,7 +90,7 @@ struct gemma3_context { static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); - llama_batch_ext_clear(ctx.batch.get()); + ctx.batch.clear(); for (llama_token & t : tokens) { ctx.batch.add_text(t, ctx.n_past++, 0, false); } @@ -178,7 +178,7 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ fflush(stdout); // eval the token - llama_batch_ext_clear(ctx.batch.get()); + ctx.batch.clear(); ctx.batch.add_text(token_id, ctx.n_past++, 0, true); if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("failed to decode token\n"); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 5fd43f8dd4bea..b812e6151d793 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -197,7 +197,7 @@ int main(int argc, char ** argv){ // clean the cache of draft tokens that weren't accepted llama_kv_self_seq_rm(ctx, 0, n_past, -1); - llama_batch_ext_clear(batch_tgt.get()); + batch_tgt.clear(); batch_tgt.add_text(draft[0], n_past, 0, true); // Draft already contains a single token sampled from the model: diff --git 
a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index c9e80daea68f1..8f54875700ee4 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -217,7 +217,7 @@ int main(int argc, char ** argv) { common_kv_cache_dump_view_seqs(kvc_view, 40); } - llama_batch_ext_clear(batch.get()); + batch.clear(); // decode any currently ongoing sequences for (auto & client : clients) { @@ -225,7 +225,7 @@ int main(int argc, char ** argv) { continue; } - client.i_batch = llama_batch_ext_get_n_tokens(batch.get()); + client.i_batch = batch.n_tokens(); llama_seq_id seq_id = client.id + 1; batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true); @@ -233,7 +233,7 @@ int main(int argc, char ** argv) { client.n_decoded += 1; } - if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (batch.n_tokens() == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { llama_kv_self_seq_rm(ctx, i, -1, -1); @@ -245,7 +245,7 @@ int main(int argc, char ** argv) { } // insert new sequences for decoding - if (cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (cont_batching || batch.n_tokens() == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -269,13 +269,13 @@ int main(int argc, char ** argv) { } // extract the logits only for the last token - if (llama_batch_ext_get_n_tokens(batch.get()) > 0) { + if (batch.n_tokens() > 0) { llama_batch_ext_set_output_last(batch.get()); } client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; + client.i_batch = batch.n_tokens() - 1; LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); @@ -289,14 +289,14 @@ int main(int argc, char ** argv) { } } - if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (batch.n_tokens() == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - int32_t n_tokens_in_batch = llama_batch_ext_get_n_tokens(batch.get()); + int32_t n_tokens_in_batch = batch.n_tokens(); for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 42674525064fd..de093e5e0d35e 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -141,7 +141,7 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } - llama_batch_ext_clear(batch.get()); + batch.clear(); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { batch.add_text(tokens_list[i + j], n_past++, 0, false); @@ -175,7 +175,7 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; - llama_batch_ext_clear(batch.get()); + batch.clear(); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { batch.add_text(tokens_list[i + j], n_past++, 0, false); @@ -224,7 +224,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, llama_batch_ext_get_n_tokens(batch.get()) - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1); // is it an end of generation? 
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { @@ -238,7 +238,7 @@ int main(int argc, char ** argv) { n_decode += 1; // prepare the next batch - llama_batch_ext_clear(batch.get()); + batch.clear(); // push this new token for next evaluation llama_seq_id seq_id = 0; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 3c8bb0cc9d59b..ac04ba355e6d9 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -369,7 +369,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - llama_batch_ext_clear(batch.get()); + batch.clear(); for (int i = 0; i < batch_size; i++) { batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } @@ -552,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & int n_outputs = 0; - llama_batch_ext_clear(batch.get()); + batch.clear(); for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -846,7 +846,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - llama_batch_ext_clear(batch.get()); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1131,7 +1131,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) size_t i1 = i0; size_t i_logits = 0; - llama_batch_ext_clear(batch.get()); + batch.clear(); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1485,7 +1485,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - llama_batch_ext_clear(batch.get()); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1744,7 +1744,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - llama_batch_ext_clear(batch.get()); + batch.clear(); for (int i = 0; i < batch_size; i++) { batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 9396687d6e1ea..68526519baacb 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -954,7 +954,7 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); - if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { + if (n_ctx_used + batch.n_tokens() > n_ctx) { printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); return 1; @@ -1001,7 +1001,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str return 1; } - llama_data.n_past += llama_batch_ext_get_n_tokens(batch.get()); + llama_data.n_past += batch.n_tokens(); // sample the next token, check is it an end of generation? 
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 2f7804ade596c..9ffe6780c5503 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { // evaluate prompt llama_decode_ext(ctx, batch.get()); - n_past += llama_batch_ext_get_n_tokens(batch.get()); + n_past += batch.n_tokens(); // save state (rng, logits, embedding and kv_cache) to file { @@ -79,7 +79,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(next_token, 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { @@ -131,7 +131,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(next_token, 0, 0, true); if (llama_decode_ext(ctx2, batch.get())) { @@ -212,7 +212,7 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(next_token, 0, 1, true); if (llama_decode_ext(ctx3, batch.get())) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2dbdb2079b147..a3097d87eff36 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2428,7 +2428,7 @@ struct server_context { res->embedding.push_back(embd_res); } else { GGML_ABORT("embeddings without pooling is not supported yet"); - //for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) { + //for (int i = 0; i < batch.n_tokens(); ++i) { // auto tok = batch.tokens[i]; // if (!tok.logits || tok.seq_id != slot.id) { // continue; @@ -2850,7 +2850,7 @@ struct server_context { } // start populating the batch for this iteration - llama_batch_ext_clear(batch.get()); + batch.clear(); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2872,7 +2872,7 @@ struct server_context { continue; } - slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()); + slot.i_batch = batch.n_tokens(); batch.add_text(slot.sampled, slot.n_past, slot.id, true); @@ -2891,7 +2891,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (params_base.cont_batching || batch.n_tokens() == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3057,7 +3057,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3077,7 +3077,7 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { // without pooling, we want to output the embeddings for all the tokens 
in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; @@ -3092,13 +3092,13 @@ struct server_context { } SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", - slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0); + GGML_ASSERT(batch.n_tokens() > 0); common_sampler_reset(slot.smpl); @@ -3111,24 +3111,24 @@ struct server_context { llama_batch_ext_set_output_last(batch.get()); slot.n_decoded = 0; - slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1; + slot.i_batch = batch.n_tokens() - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get())); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens()); } } - if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) { + if (batch.n_tokens() >= n_batch) { break; } } } - if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (batch.n_tokens() == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get())); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens()); if (slot_batched) { // make sure we're in the right embedding mode @@ -3138,8 +3138,8 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) { - const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i); + for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i); llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); @@ -3278,14 +3278,14 @@ struct server_context { } // construct the speculation batch - llama_batch_ext_clear(slot.batch_spec.get()); + slot.batch_spec.clear(); slot.batch_spec.add_text(id, slot.n_past, slot.id, true); for (size_t i = 0; i < draft.size(); ++i) { slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get())); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens()); llama_decode_ext(ctx, slot.batch_spec.get()); diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 51c5c837dc624..44824199c4fb9 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -111,14 +111,14 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt llama_pos n_past = 0; auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); - n_past += llama_batch_ext_get_n_tokens(batch.get()); + n_past += batch.n_tokens(); llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx); - if (n_ctx_used + llama_batch_ext_get_n_tokens(batch.get()) > n_ctx) { + if (n_ctx_used + batch.n_tokens() > n_ctx) { 
printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); @@ -148,7 +148,7 @@ int main(int argc, char ** argv) { response += piece; // prepare the next batch with the sampled token - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(new_token_id, n_past, 0, true); n_past++; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index bcd119ad00247..d9a6a63396817 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -152,14 +152,14 @@ int main(int argc, char ** argv) { int n_decode = 0; llama_token new_token_id; - for (int n_pos = 0; n_pos + llama_batch_ext_get_n_tokens(batch.get()) < n_prompt + n_predict; ) { + for (int n_pos = 0; n_pos + batch.n_tokens() < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } - n_pos += llama_batch_ext_get_n_tokens(batch.get()); + n_pos += batch.n_tokens(); // sample the next token { @@ -181,7 +181,7 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - llama_batch_ext_clear(batch.get()); + batch.clear(); batch.add_text(new_token_id, n_pos, 0, true); n_decode += 1; diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d1c4e4a5d7b45..1d112c7fa0463 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); // always have a token to evaluate from before - id_last - llama_batch_ext_clear(batch_tgt.get()); + batch_tgt.clear(); batch_tgt.add_text(id_last, n_past++, 0, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 8f4500e8ca68d..c7baac692a3b7 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -444,7 +444,7 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - llama_batch_ext_clear(batch_dft.get()); + batch_dft.clear(); batch_dft.add_text(token_id, n_past_dft, 0, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); @@ -486,7 +486,7 @@ int main(int argc, char ** argv) { // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - llama_batch_ext_clear(batch_dft.get()); + batch_dft.clear(); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].skip = false; @@ -576,7 +576,7 @@ int main(int argc, char ** argv) { } // no sequence is drafting anymore - if (llama_batch_ext_get_n_tokens(batch_dft.get()) == 0) { + if (batch_dft.n_tokens() == 0) { break; } @@ -597,7 +597,7 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); } - llama_batch_ext_clear(batch_tgt.get()); + batch_tgt.clear(); for (int i = 0; i < (int) batch_tgt_data.size(); ++i) { const auto & data = batch_tgt_data[i]; diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 45107e2d82cfe..ca220d36ca35c 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -837,7 +837,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 for (size_t i = 0; i < prompt_inp.size(); ++i) { batch.add_text(prompt_inp[i], i, seq_ids, false); } - 
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) == (int) prompt_inp.size()); + GGML_ASSERT(batch.n_tokens() == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt llama_batch_ext_set_output_last(batch.get()); @@ -861,16 +861,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch.get()) - 1); + std::vector<int32_t> i_batch(n_parallel, batch.n_tokens() - 1); - int n_past = llama_batch_ext_get_n_tokens(batch.get()); + int n_past = batch.n_tokens(); int n_decode = 0; bool next_token_uses_guide_token = true; while (n_decode <= n_predict) { // prepare the next batch - llama_batch_ext_clear(batch.get()); + batch.clear(); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -926,14 +926,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 //LOG_CNT("%d", i); } - i_batch[i] = llama_batch_ext_get_n_tokens(batch.get()); + i_batch[i] = batch.n_tokens(); // push this new token for next evaluation batch.add_text(new_token_id, n_past, i, true); } // all streams are finished - if (llama_batch_ext_get_n_tokens(batch.get()) == 0) { + if (batch.n_tokens() == 0) { break; } @@ -1019,7 +1019,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 for (size_t i = 0; i < codes.size(); ++i) { batch.add_text(codes[i], i, 0, true); // TODO: all logits? } - GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) == n_codes); + GGML_ASSERT(batch.n_tokens() == n_codes); if (llama_decode_ext(ctx_cts, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 1c4f4e859b77c..676493ad27fb6 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -83,4 +83,12 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> + + void clear() { + llama_batch_ext_clear(this->get()); + } + + int32_t n_tokens() const { + return llama_batch_ext_get_n_tokens(this->get()); + } }; From a9efdbbce56e95579b255d0052b289a76c1e7d9f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 23 Mar 2025 18:22:58 +0100 Subject: [PATCH 45/52] qwen2vl: fix mrope position --- examples/llava/qwen2vl-cli.cpp | 12 ++++++------ src/llama-batch.cpp | 5 +++-- src/llama-graph.cpp | 4 +++- src/llama-graph.h | 2 ++ 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 490ee7fdc6ee1..17c673024dca6 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -68,7 +68,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla float * batch_embd = image_embed->embed+i*n_embd; auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0); - llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval); + llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval * 4); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); @@ -91,18 +91,18 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector #include @@ -356,7 +357,7 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); } - batch->pos = (llama_pos *) malloc(sizeof(llama_pos) *
n_tokens_alloc); + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc * MAX_POS_PER_TOKEN); batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); for (int i = 0; i < n_tokens_alloc; ++i) { @@ -390,7 +391,7 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd( } int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) { - if ((size_t) batch->n_tokens != n_pos) { + if ((size_t) batch->n_tokens * MAX_POS_PER_TOKEN < n_pos) { return -1; } memcpy(batch->pos, pos, n_pos * sizeof(llama_pos)); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..23f86a5e9dea5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -603,7 +603,9 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : } int64_t llm_graph_context::n_pos_per_token() const { - return arch == LLM_ARCH_QWEN2VL ? 4 : 1; + constexpr int64_t n_pos_per_token_qwen2vl = 4; + static_assert(n_pos_per_token_qwen2vl <= MAX_POS_PER_TOKEN); + return arch == LLM_ARCH_QWEN2VL ? n_pos_per_token_qwen2vl : 1; } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..0f80f567c312f 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -10,6 +10,8 @@ #include #include +#define MAX_POS_PER_TOKEN 4 + struct ggml_cgraph; struct ggml_context; struct ggml_tensor; From d18a79ed07c61e21474db1097ccf929b98f044e7 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Mar 2025 16:35:31 +0100 Subject: [PATCH 46/52] llama_batch_ext_init with ctx --- common/common.cpp | 4 +- common/speculative.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched/batched.cpp | 2 +- .../cvector-generator/cvector-generator.cpp | 2 +- examples/embedding/embedding.cpp | 2 +- examples/eval-callback/eval-callback.cpp | 2 +- examples/gritlm/gritlm.cpp | 4 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 +- .../llama/src/main/cpp/llama-android.cpp | 5 ++- .../java/android/llama/cpp/LLamaAndroid.kt | 4 +- .../llama.cpp.swift/LibLlama.swift | 2 +- examples/llava/gemma3-cli.cpp | 5 ++- examples/llava/llava-cli.cpp | 2 +- examples/llava/llava.cpp | 2 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/llava/qwen2vl-cli.cpp | 9 ++-- examples/lookahead/lookahead.cpp | 6 +-- examples/lookup/lookup.cpp | 6 +-- examples/main/main.cpp | 4 +- examples/parallel/parallel.cpp | 2 +- examples/passkey/passkey.cpp | 2 +- examples/perplexity/perplexity.cpp | 12 +++--- examples/retrieval/retrieval.cpp | 4 +- examples/run/run.cpp | 4 +- examples/save-load-state/save-load-state.cpp | 2 +- examples/server/server.cpp | 6 +-- examples/simple-chat/simple-chat.cpp | 2 +- examples/simple/simple.cpp | 2 +- .../speculative-simple/speculative-simple.cpp | 4 +- examples/speculative/speculative.cpp | 10 ++--- examples/tts/tts.cpp | 4 +- include/llama-cpp.h | 41 +++++++++++++------ include/llama.h | 16 +++----- src/llama-batch.cpp | 26 +++++------- src/llama-graph.cpp | 12 ++---- src/llama-graph.h | 3 +- src/llama-model.cpp | 8 ++++ src/llama-model.h | 2 + 41 files changed, 124 insertions(+), 113 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index e3b9261fc0f8c..736d8899a5eee 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1016,7 +1016,7 @@ struct common_init_result 
common_init_from_params(common_params & params) { } if (llama_model_has_encoder(model)) { - auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true); llama_encode_ext(lctx, batch.get()); llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { @@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) { tmp.push_back(decoder_start_token_id); } if (llama_model_has_decoder(model)) { - auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true); llama_decode_ext(lctx, batch.get()); } llama_kv_self_clear(lctx); diff --git a/common/speculative.cpp b/common/speculative.cpp index a798fcb67f2a7..f16b2c6e36d6c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)), + /* .batch = */ llama_batch_ext_ptr(ctx_dft), /* .prompt = */ {}, }; diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 063b5ca8bc84a..331ec88852733 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -59,7 +59,7 @@ int main(int argc, char ** argv) { const int32_t n_kv_max = llama_n_ctx(ctx); - llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); // decode in batches of ctx_params.n_batch tokens auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 9f169b41b505a..204544f6f38cd 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -102,7 +102,7 @@ int main(int argc, char ** argv) { // create a llama_batch // we use this object to submit token data for decoding - llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel); + llama_batch_ext * batch = llama_batch_ext_init(ctx); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 6b25dc1db6efe..5b7a42025d8b4 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_self_clear(ctx); - auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 947bbc1741021..49c4700238522 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -167,7 +167,7 @@ int main(int 
argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); // count number of embeddings int n_embd_count = 0; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 21ca9b4ceec61..4253b5ca4193e 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index deda96099b212..e8753f2163e01 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -14,7 +14,7 @@ static std::vector> encode(llama_context * ctx, const std::ve const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); - llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1)); + llama_batch_ext_ptr batch(ctx); for (uint64_t i = 0; i < sentences.size(); i++) { batch.clear(); @@ -105,7 +105,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); - llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1)); + llama_batch_ext_ptr batch(ctx); std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 2461afcdce565..55cecad84d97a 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 29cba998968e7..9d5c129581c6f 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index c671194c77864..e956baf15d263 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1444,7 +1444,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), n_tokens, n_past + n_processed, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), n_tokens, n_past + n_processed, 0, true); 
llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } @@ -1462,7 +1462,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - auto batch = llama_batch_ext_ptr::init_from_text(&token, 1, n_past + i, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &token, 1, n_past + i, 0, true); llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 9bf7db399b408..c3c94b11bd3fc 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -273,8 +273,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( extern "C" JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { - llama_batch_ext * batch = llama_batch_ext_init(n_tokens, n_seq_max); +Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jlong context_pointer) { + const auto context = reinterpret_cast(context_pointer); + llama_batch_ext * batch = llama_batch_ext_init(context); return reinterpret_cast(batch); } diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index b964d93e37819..f58f7431a3ca6 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -45,7 +45,7 @@ class LLamaAndroid { private external fun free_context(context: Long) private external fun backend_init(numa: Boolean) private external fun backend_free() - private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long + private external fun new_batch(context: Long): Long private external fun free_batch(batch: Long) private external fun new_sampler(): Long private external fun free_sampler(sampler: Long) @@ -102,7 +102,7 @@ class LLamaAndroid { val context = new_context(model) if (context == 0L) throw IllegalStateException("new_context() failed") - val batch = new_batch(512, 0, 1) + val batch = new_batch(context) if (batch == 0L) throw IllegalStateException("new_batch() failed") val sampler = new_sampler() diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index d04c6353eec1d..c4ddaf9bc51c3 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -26,7 +26,7 @@ actor LlamaContext { self.model = model self.context = context self.tokens_list = [] - self.batch = llama_batch_ext_init(512, 1) + self.batch = llama_batch_ext_init(context) self.temporary_invalid_cchars = [] let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) diff --git a/examples/llava/gemma3-cli.cpp b/examples/llava/gemma3-cli.cpp index 2ae5e665e942b..56c08b27e6994 100644 --- a/examples/llava/gemma3-cli.cpp +++ b/examples/llava/gemma3-cli.cpp @@ -74,7 +74,7 @@ struct gemma3_context { lctx = llama_init.context.get(); vocab = llama_model_get_vocab(model); n_threads = params.cpuparams.n_threads; - batch.reset(llama_batch_ext_init(params.n_batch, 1)); + 
batch.reset(llama_batch_ext_init(lctx)); init_clip_model(params); } @@ -147,7 +147,8 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { int64_t t1 = ggml_time_ms(); eval_text(ctx, ""); llama_set_causal_attn(ctx.lctx, false); - llama_batch_ext_ptr batch_img(llama_batch_ext_init_from_embd(image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0)); + llama_batch_ext_ptr batch_img = llama_batch_ext_ptr::init_from_embd( + ctx.lctx, image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0); if (llama_decode_ext(ctx.lctx, batch_img.get())) { LOG_ERR("failed to decode image\n"); return 1; diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 1fa72a24d8a63..884547fcb831a 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index eda96e19f1b20..f88e4e7a800b9 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -448,7 +448,7 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - auto batch = llama_batch_ext_ptr::init_from_embd(embd, n_eval, n_embd, 0, 0); + auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, embd, n_eval, n_embd, 0, 0); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 81fbc247af292..7aadca9489ab5 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_ext_ptr::init_from_text(&tokens[i], n_eval, *n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. 
token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 17c673024dca6..9ac4bd086f40c 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -67,8 +67,8 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); float * batch_embd = image_embed->embed+i*n_embd; - auto batch = llama_batch_ext_ptr::init_from_embd(batch_embd, n_eval, n_embd, 0, 0); - llama_batch_ext_set_pos(batch.get(), batch_mrope_pos.data(), n_eval * 4); + const llama_pos * pos = batch_mrope_pos.data(); + auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0); if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); @@ -97,12 +97,11 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector draft; - llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(params.n_ctx, 1)); + llama_batch_ext_ptr batch_tgt(ctx); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4a779e3601bd0..84bfefba2a13c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,7 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - auto batch = llama_batch_ext_ptr::init_from_text(enc_input_buf, enc_input_size, 0, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, enc_input_buf, enc_input_size, 0, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; @@ -669,7 +669,7 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true); if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 8f54875700ee4..74d823e9a2eee 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -175,7 +175,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. 
diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 8f54875700ee4..74d823e9a2eee 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -175,7 +175,7 @@ int main(int argc, char ** argv) {
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
    // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 1));
+    llama_batch_ext_ptr batch(ctx);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index de093e5e0d35e..94ede1c5a91c2 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -123,7 +123,7 @@ int main(int argc, char ** argv) {
     LOG_INF("prompt tokens: %d\n", n_tokens_all);
     //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(params.n_batch, 1));
+    llama_batch_ext_ptr batch(ctx);
 
     int n_past = 0;
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index ac04ba355e6d9..d0fbc3f571734 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -363,7 +363,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+        llama_batch_ext_ptr batch(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
@@ -501,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
     GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(std::min(n_batch, n_ctx*n_seq), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> logits;
     if (num_batches > 1) {
@@ -830,7 +830,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 4));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1112,7 +1112,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
     const int max_tasks_per_batch = 128;
     const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, 2));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     // TODO: this could be made smaller; it's currently the worst-case size
@@ -1465,7 +1465,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
     const int max_tasks_per_batch = 32;
     const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_ctx, max_seq));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<float> tok_logits(n_vocab);
     std::vector<float> batch_logits(size_t(n_ctx)*n_vocab);
@@ -1730,7 +1730,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch, 1));
+        llama_batch_ext_ptr batch(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
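The repeated replacement with llama_batch_ext_ptr batch(ctx) above relies on the new context-aware constructor: the capacity arguments are gone and the batch is sized from the context instead. A hedged sketch of what the two spellings amount to, assuming the init behaviour introduced later in this series (capacity taken from llama_n_batch() and llama_n_seq_max()):

    #include "llama-cpp.h"

    // hedged sketch: both forms allocate a batch whose capacity comes from the context
    static void make_batches(llama_context * ctx) {
        llama_batch_ext_ptr batch(ctx);                      // wraps llama_batch_ext_init(ctx)

        llama_batch_ext * raw = llama_batch_ext_init(ctx);   // same allocation, manual ownership
        llama_batch_ext_ptr batch2(raw);                     // adopt it; freed by the smart pointer's deleter
    }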
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index 6086665494654..aa242e96f94b7 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -213,7 +213,7 @@ int main(int argc, char ** argv) {
     // initialize batch
     const int n_chunks = chunks.size();
-    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
@@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    llama_batch_ext * query_batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * query_batch = llama_batch_ext_init(ctx);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
diff --git a/examples/run/run.cpp b/examples/run/run.cpp
index 68526519baacb..876a5a4c0d254 100644
--- a/examples/run/run.cpp
+++ b/examples/run/run.cpp
@@ -992,7 +992,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
     }
 
     // prepare a batch for the prompt
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), llama_data.n_past, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(llama_data.context.get(), tokens.data(), tokens.size(), llama_data.n_past, 0, true);
     llama_token new_token_id;
     while (true) {
         check_context_size(llama_data.context, batch);
@@ -1017,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
         print_word_and_concatenate_to_response(piece, response);
 
         // prepare the next batch with the sampled token
-        batch = llama_batch_ext_ptr::init_from_text(&new_token_id, 1, llama_data.n_past, 0, true);
+        batch.clear();
     }
 
     printf(LOG_COL_DEFAULT);
diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp
index 9ffe6780c5503..03b1e7ccfabe4 100644
--- a/examples/save-load-state/save-load-state.cpp
+++ b/examples/save-load-state/save-load-state.cpp
@@ -48,7 +48,7 @@ int main(int argc, char ** argv) {
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
     // prepare the batch
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
 
     // evaluate prompt
     llama_decode_ext(ctx, batch.get());
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index adeeca479563b..daf11ce2f5e05 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1927,7 +1927,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec.clear();
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1963,7 +1963,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch.clear();
         }
 
         metrics.init();
@@ -2098,7 +2098,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec.clear();
         }
 
         slot.state = SLOT_STATE_STARTED;
diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp
index 44824199c4fb9..1425d2b114438 100644
--- a/examples/simple-chat/simple-chat.cpp
+++ b/examples/simple-chat/simple-chat.cpp
@@ -110,7 +110,7 @@ int main(int argc, char ** argv) {
         // prepare a batch for the prompt
         llama_pos n_past = 0;
-        auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
+        auto batch = llama_batch_ext_ptr::init_from_text(ctx, prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true);
         n_past += batch.n_tokens();
 
         llama_token new_token_id;
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index d9a6a63396817..90bae0aa8614b 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -144,7 +144,7 @@ int main(int argc, char ** argv) {
 
     // prepare a batch for the prompt
 
-    auto batch = llama_batch_ext_ptr::init_from_text(prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, prompt_tokens.data(), prompt_tokens.size(), 0, 0, true);
 
     // main loop
 
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 1d112c7fa0463..5981d85304f33 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -114,7 +114,7 @@ int main(int argc, char ** argv) {
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // eval the prompt
-    auto batch = llama_batch_ext_ptr::init_from_text(inp.data(), inp.size() - 1, 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx_tgt, inp.data(), inp.size() - 1, 0, 0, true);
     llama_decode_ext(ctx_tgt, batch.get());
 
     // note: keep the last token separate!
@@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
 
     struct common_speculative * spec = common_speculative_init(ctx_dft);
 
-    llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(llama_n_batch(ctx_tgt), 1));
+    llama_batch_ext_ptr batch_tgt(ctx_tgt);
 
     const auto t_enc_end = ggml_time_us();
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index c7baac692a3b7..d61a173408a15 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -166,9 +166,9 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    auto batch0 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input - 1, 0, 0, true);
-    auto batch1 = llama_batch_ext_ptr::init_from_text(&inp.back(), 1, n_input - 1, 0, true);
-    auto batch2 = llama_batch_ext_ptr::init_from_text( inp.data(), n_input    , 0, 0, true);
+    auto batch0 = llama_batch_ext_ptr::init_from_text(ctx_tgt,  inp.data(), n_input - 1, 0, 0, true);
+    auto batch1 = llama_batch_ext_ptr::init_from_text(ctx_tgt, &inp.back(), 1, n_input - 1, 0, true);
+    auto batch2 = llama_batch_ext_ptr::init_from_text(ctx_dft,  inp.data(), n_input    , 0, 0, true);
     llama_decode_ext(ctx_tgt, batch0.get());
     llama_decode_ext(ctx_tgt, batch1.get());
     llama_decode_ext(ctx_dft, batch2.get());
@@ -202,8 +202,8 @@ int main(int argc, char ** argv) {
         drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
     }
 
-    llama_batch_ext_ptr batch_dft(llama_batch_ext_init(llama_n_batch(ctx_dft), 1));
-    llama_batch_ext_ptr batch_tgt(llama_batch_ext_init(llama_n_batch(ctx_tgt), n_seq_dft));
+    llama_batch_ext_ptr batch_dft(ctx_dft);
+    llama_batch_ext_ptr batch_tgt(ctx_tgt);
 
     const auto t_dec_start = ggml_time_us();
diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp
index ca220d36ca35c..dac097859fd41 100644
--- a/examples/tts/tts.cpp
+++ b/examples/tts/tts.cpp
@@ -826,7 +826,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch_ext_ptr batch(llama_batch_ext_init(std::max(prompt_inp.size(), (size_t) n_parallel), n_parallel));
+    llama_batch_ext_ptr batch(ctx_ttc);
 
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {
@@ -1014,7 +1014,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
 
     const int n_codes = codes.size();
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(n_codes, 1));
+    llama_batch_ext_ptr batch(ctx_cts);
 
     for (size_t i = 0; i < codes.size(); ++i) {
         batch.add_text(codes[i], i, 0, true); // TODO: all logits?
diff --git a/include/llama-cpp.h b/include/llama-cpp.h
index 676493ad27fb6..cb1a6e2fedeac 100644
--- a/include/llama-cpp.h
+++ b/include/llama-cpp.h
@@ -36,18 +36,20 @@ typedef std::unique_ptr llama_ad
 struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
     llama_batch_ext_ptr() : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>() {}
+    llama_batch_ext_ptr(struct llama_context * ctx) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(llama_batch_ext_init(ctx)) {}
     llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(batch) {}
 
     // Convenience C++ wrapper to create a batch from text tokens, without worrying about manually freeing it
     // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
     // If output_last is true, the last token will have output set
-    static llama_batch_ext_ptr init_from_text(llama_token * tokens,
-                                              int32_t n_tokens,
-                                              llama_pos pos0,
-                                              llama_seq_id seq_id,
-                                              bool output_last) {
-        llama_batch_ext * batch = llama_batch_ext_init(n_tokens, 1);
+    static llama_batch_ext_ptr init_from_text(struct llama_context * ctx,
+                                              llama_token * tokens,
+                                              int32_t n_tokens,
+                                              llama_pos pos0,
+                                              llama_seq_id seq_id,
+                                              bool output_last) {
+        llama_batch_ext * batch = llama_batch_ext_init(ctx);
         for (int32_t i = 0; i < n_tokens; i++) {
             llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
         }
@@ -58,12 +60,27 @@ struct llama_batch_ext_ptr : std::unique_ptr
+        std::vector<llama_pos> pos(n_tokens);
+        for (size_t i = 0; i < n_tokens; i++) {
+            pos[i] = pos0 + i;
+        }
+        return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(ctx, embd, n_tokens, n_embd, pos.data(), seq_id));
     }
 
     // Wrapper to add a single token to the batch, support multiple sequence IDs
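For reference, a caller-side sketch of the embedding wrapper changed above, mirroring the gemma3/llava call sites earlier in this patch: the wrapper generates consecutive positions starting at pos0 and forwards them to the C API. The function name and arguments here are illustrative, not part of the patch:

    #include "llama-cpp.h"

    // hedged sketch: decode a block of precomputed embeddings on sequence 0,
    // with positions pos0 .. pos0 + n_tokens - 1 generated inside the wrapper
    static bool eval_embd_block(llama_context * ctx, float * embd, int32_t n_tokens, int32_t n_embd, llama_pos pos0) {
        auto batch = llama_batch_ext_ptr::init_from_embd(ctx, embd, n_tokens, n_embd, pos0, /*seq_id=*/0);
        return llama_decode_ext(ctx, batch.get()) == 0;
    }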
diff --git a/include/llama.h b/include/llama.h
index 214345efee7c4..dfc5968e4dc29 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -483,6 +483,7 @@ extern "C" {
     LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
 
     LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
+    LLAMA_API uint32_t llama_n_pos_per_token(const struct llama_model * model);
 
     LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int32_t llama_model_n_embd     (const struct llama_model * model);
@@ -922,29 +923,22 @@ extern "C" {
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
     // The batch has to be freed with llama_batch_ext_free()
-    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(
-            int32_t n_tokens,
-            int32_t n_seq_max);
+    LLAMA_API struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx);
 
     // Same with llama_batch_init, but initializes the batch with the provided raw embeddings
     // Size of embd should be n_tokens * n_embd
+    // Size of pos should be n_tokens * n_pos_per_token
     // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd()
-    // First token will be at position pos0
     // The sequence ID will be fixed to seq_id
     // The batch has to be freed with llama_batch_ext_free()
     LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd(
+            struct llama_context * ctx,
             const float * embd,
             size_t n_tokens,
             size_t n_embd,
-            llama_pos pos0,
+            const llama_pos * pos,
             llama_seq_id seq_id);
 
-    // Set arbitrary token to the embeddings batch
-    // Note: this is only to be used in conjunction with llama_batch_ext_init_from_embd()
-    // n_pos must match the n_tokens of the batch
-    // Returns -1 if n_pos does not match the n_tokens of the batch
-    LLAMA_API int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos);
-
     // Get the number of tokens in the batch
     LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp
index c6ea43ae0e1af..67bc0ef2161ed 100644
--- a/src/llama-batch.cpp
+++ b/src/llama-batch.cpp
@@ -370,34 +370,28 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
     return batch;
 }
 
-struct llama_batch_ext * llama_batch_ext_init(int32_t n_tokens_alloc, int32_t n_seq_max) {
-    return llama_batch_ext_init_impl(n_tokens_alloc, 0, n_seq_max);
+struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
+    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
 }
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
-        const float * embd,
-        size_t n_tokens,
-        size_t n_embd,
-        llama_pos pos0,
-        llama_seq_id seq_id) {
+        struct llama_context * ctx,
+        const float * embd,
+        size_t n_tokens,
+        size_t n_embd,
+        const llama_pos * pos,
+        llama_seq_id seq_id) {
+    auto model = llama_get_model(ctx);
     struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
+    memcpy(batch->pos, pos, n_tokens * llama_n_pos_per_token(model) * sizeof(llama_pos));
     for (size_t i = 0; i < n_tokens; i++) {
-        batch->pos     [i] = pos0 + i;
         batch->n_seq_id[i] = 1;
         batch->seq_id  [i][0] = seq_id;
     }
     return batch;
 }
 
-int32_t llama_batch_ext_set_pos(struct llama_batch_ext * batch, llama_pos * pos, size_t n_pos) {
-    if ((size_t) batch->n_tokens * MAX_POS_PER_TOKEN < n_pos) {
-        return -1;
-    }
-    memcpy(batch->pos, pos, n_pos * sizeof(llama_pos));
-    return 0;
-}
-
 int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) {
     return batch->n_tokens;
 }
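With this change the C-level embedding initializer takes an explicit position buffer whose expected length is n_tokens * n_pos_per_token. A hedged sketch of filling it for a model with one position per token; mrope models like qwen2vl report 4 via llama_n_pos_per_token() and need per-dimension values as discussed in the next commit. The helper name is illustrative:

    #include <vector>
    #include "llama.h"

    // hedged sketch: build the pos array expected by llama_batch_ext_init_from_embd()
    static llama_batch_ext * embd_batch_with_pos(llama_context * ctx, const float * embd,
                                                 size_t n_tokens, size_t n_embd, llama_pos pos0) {
        const uint32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
        std::vector<llama_pos> pos(n_tokens * n_pos_per_token);
        for (size_t i = 0; i < n_tokens; ++i) {
            for (uint32_t d = 0; d < n_pos_per_token; ++d) {
                // same value in every position dimension; mrope models may need different values per dimension
                pos[i * n_pos_per_token + d] = pos0 + (llama_pos) i;
            }
        }
        return llama_batch_ext_init_from_embd(ctx, embd, n_tokens, n_embd, pos.data(), /*seq_id=*/0);
    }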
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 23f86a5e9dea5..477120ce8ac5c 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2,6 +2,7 @@
 
 #include "llama-impl.h"
 #include "llama-batch.h"
+#include "llama-model.h"
 #include "llama-cparams.h"
 #include "llama-kv-cache.h"
 
@@ -565,6 +566,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     hparams         (params.hparams),
     cparams         (params.cparams),
     ubatch          (params.ubatch),
+    n_pos_per_token (llama_n_pos_per_token(params.arch)),
     n_embd          (hparams.n_embd),
     n_layer         (hparams.n_layer),
     n_rot           (hparams.n_rot),
@@ -602,12 +604,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res             (std::make_unique<llm_graph_result>()) {
     }
 
-int64_t llm_graph_context::n_pos_per_token() const {
-    constexpr int64_t n_pos_per_token_qwen2vl = 4;
-    static_assert(n_pos_per_token_qwen2vl <= MAX_POS_PER_TOKEN);
-    return arch == LLM_ARCH_QWEN2VL ? n_pos_per_token_qwen2vl : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
@@ -1005,11 +1001,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token());
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_token);
 
     auto & cur = inp->pos;
 
-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token());
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token);
     ggml_set_input(cur);
 
     res->add_input(std::move(inp));
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 0f80f567c312f..adcca07fe81c7 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -357,6 +357,7 @@ struct llm_graph_context {
     const llama_cparams & cparams;
     const llama_ubatch  & ubatch;
 
+    const int64_t n_pos_per_token;
     const int64_t n_embd;
     const int64_t n_layer;
     const int64_t n_rot;
@@ -404,8 +405,6 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_token() const;
-
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
     //
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 0ae754154b069..4992d869612b6 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -12057,6 +12057,14 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
     return LLAMA_ROPE_TYPE_NONE;
 }
 
+uint32_t llama_n_pos_per_token(llm_arch arch) {
+    return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
+}
+
+uint32_t llama_n_pos_per_token(const struct llama_model * model) {
+    return llama_n_pos_per_token(model->arch);
+}
+
 float llama_model_rope_freq_scale_train(const llama_model * model) {
     return model->hparams.rope_freq_scale_train;
 }
diff --git a/src/llama-model.h b/src/llama-model.h
index a9da1215abbfd..7320201de7a11 100644
--- a/src/llama-model.h
+++ b/src/llama-model.h
@@ -401,3 +401,5 @@ const char * llm_type_name(llm_type type);
 // For internal test use
 // TODO: remove
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model);
+
+uint32_t llama_n_pos_per_token(llm_arch arch);
From c4fea7fe65927227ec2601ee3de2e8a7f0a8b7f8 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 25 Mar 2025 18:35:47 +0100
Subject: [PATCH 47/52] fix qwen2vl mrope position input

---
 examples/llava/qwen2vl-cli.cpp | 18 ++++----
 include/llama.h                |  1 +
 src/llama-batch.cpp            | 76 +++++++++++++++++++---------
 src/llama-batch.h              |  3 +-
 src/llama-model.cpp            |  5 +++
 5 files changed, 60 insertions(+), 43 deletions(-)

diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp
index 9ac4bd086f40c..327831c2434f7 100644
--- a/examples/llava/qwen2vl-cli.cpp
+++ b/examples/llava/qwen2vl-cli.cpp
@@ -66,8 +66,17 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla
         memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos));
         memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos));
 
+        // transpose from layout 0123012301230123 to 0000111122223333
+        // TODO @ngxson : this is a low-effort solution, generated with the help of LLM; we should improve this in the future
+        std::vector<llama_pos> batch_mrope_pos_T(n_eval * 4);
+        for (int r = 0; r < 4; r++) {
+            for (int c = 0; c < n_eval; c++) {
+                batch_mrope_pos_T[c*4 + r] = batch_mrope_pos[r*n_eval + c];
+            }
+        }
+
         float * batch_embd = image_embed->embed+i*n_embd;
-        const llama_pos * pos = batch_mrope_pos.data();
+        const llama_pos * pos = batch_mrope_pos_T.data();
         auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0);
 
         if (llama_decode_ext(ctx_llama, batch.get())) {
@@ -90,13 +99,6 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector
 n_tokens > 0);
     if (!in_batch.pos) {
@@ -338,17 +339,18 @@ struct llama_batch llama_batch_get_one(
     };
 }
 
-static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max) {
+static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max, int32_t n_pos_per_token) {
     llama_batch_ext * batch = new llama_batch_ext{
-        /*n_tokens        =*/ 0,
-        /*max_tokens      =*/ n_tokens_alloc,
-        /*is_view         =*/ false,
-        /*tokens          =*/ nullptr,
-        /*embd            =*/ nullptr,
-        /*pos             =*/ nullptr,
-        /*n_seq_id        =*/ nullptr,
-        /*seq_id          =*/ nullptr,
-        /*logits          =*/ nullptr,
+        /*n_tokens        =*/ 0,
+        /*max_tokens      =*/ n_tokens_alloc,
+        /*n_pos_per_token =*/ n_pos_per_token,
+        /*is_view         =*/ false,
+        /*tokens          =*/ nullptr,
+        /*embd            =*/ nullptr,
+        /*pos             =*/ nullptr,
+        /*n_seq_id        =*/ nullptr,
+        /*seq_id          =*/ nullptr,
+        /*logits          =*/ nullptr,
     };
 
     if (n_embd) {
@@ -371,7 +373,8 @@ static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc
 }
 
 struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) {
-    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx));
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx), n_pos_per_token);
 }
 
 struct llama_batch_ext * llama_batch_ext_init_from_embd(
@@ -381,10 +384,10 @@ struct llama_batch_ext * llama_batch_ext_init_from_embd(
         size_t n_embd,
         const llama_pos * pos,
         llama_seq_id seq_id) {
-    auto model = llama_get_model(ctx);
-    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1);
+    int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx));
+    struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1, n_pos_per_token);
     memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float));
-    memcpy(batch->pos, pos, n_tokens * llama_n_pos_per_token(model) * sizeof(llama_pos));
+    memcpy(batch->pos, pos, n_tokens * n_pos_per_token * sizeof(llama_pos));
     for (size_t i = 0; i < n_tokens; i++) {
         batch->n_seq_id[i] = 1;
         batch->seq_id  [i][0] = seq_id;
pos : 0; + } batch->n_seq_id[output_id] = n_seq_ids; for (size_t j = 0; j < n_seq_ids; j++) { batch->seq_id[batch->n_tokens][j] = seq_ids[j]; } - batch->logits [output_id] = output; batch->n_tokens++; return output_id; } @@ -461,15 +468,16 @@ struct llama_batch_ext * llama_batch_ext_get_view( return nullptr; // not yet supported } llama_batch_ext * batch_view = new llama_batch_ext{ - /*n_tokens =*/ n_tokens, - /*max_tokens =*/ n_tokens, - /*is_view =*/ true, - /*tokens =*/ batch->token + offset, - /*embd =*/ nullptr, - /*pos =*/ batch->pos + offset, - /*n_seq_id =*/ batch->n_seq_id + offset, - /*seq_id =*/ batch->seq_id + offset, - /*logits =*/ batch->logits + offset, + /*n_tokens =*/ n_tokens, + /*max_tokens =*/ n_tokens, + /*n_pos_per_token =*/ batch->n_pos_per_token, + /*is_view =*/ true, + /*tokens =*/ batch->token + offset, + /*embd =*/ nullptr, + /*pos =*/ batch->pos + offset * batch->n_pos_per_token, + /*n_seq_id =*/ batch->n_seq_id + offset, + /*seq_id =*/ batch->seq_id + offset, + /*logits =*/ batch->logits + offset, }; return batch_view; } diff --git a/src/llama-batch.h b/src/llama-batch.h index 1b3413ac24428..6671cdcd76df1 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -21,11 +21,12 @@ struct llama_batch_ext { int32_t n_tokens; int32_t max_tokens; + int32_t n_pos_per_token = 1; bool is_view; llama_token * token; float * embd; - llama_pos * pos; + llama_pos * pos; // if multi pos per token: 000011112222... int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 4992d869612b6..54697c3bb908c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6075,6 +6075,11 @@ struct llm_build_qwen2vl : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); + // TODO @ngxson : transpose layout 0000111122223333 to 0123012301230123, we should improve this in the future + inp_pos = ggml_reshape_2d(ctx0, inp_pos, n_tokens, n_pos_per_token); + inp_pos = ggml_cont(ctx0, ggml_transpose(ctx0, inp_pos)); + inp_pos = ggml_reshape_1d(ctx0, inp_pos, n_pos_per_token * n_tokens); + auto * inp_attn = build_attn_inp_kv_unified(); int sections[4]; From 42062cc2c7106681aa7fe95e6fbda9e1369b2124 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 25 Mar 2025 18:39:19 +0100 Subject: [PATCH 48/52] fix build --- examples/parallel/parallel.cpp | 2 -- examples/server/server.cpp | 8 +------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 74d823e9a2eee..ee713ab416645 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -158,8 +158,6 @@ int main(int argc, char ** argv) { LOG_INF("\n\n"); - const int n_ctx = llama_n_ctx(ctx); - std::vector clients(n_clients); for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index daf11ce2f5e05..1def160356c08 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1959,13 +1959,7 @@ struct server_context { // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. 
From 42062cc2c7106681aa7fe95e6fbda9e1369b2124 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 25 Mar 2025 18:39:19 +0100
Subject: [PATCH 48/52] fix build

---
 examples/parallel/parallel.cpp | 2 --
 examples/server/server.cpp     | 8 +-------
 2 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp
index 74d823e9a2eee..ee713ab416645 100644
--- a/examples/parallel/parallel.cpp
+++ b/examples/parallel/parallel.cpp
@@ -158,8 +158,6 @@ int main(int argc, char ** argv) {
 
     LOG_INF("\n\n");
 
-    const int n_ctx = llama_n_ctx(ctx);
-
     std::vector<client> clients(n_clients);
     for (size_t i = 0; i < clients.size(); ++i) {
         auto & client = clients[i];
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index daf11ce2f5e05..1def160356c08 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1959,13 +1959,7 @@ struct server_context {
         // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
-        {
-            const int32_t n_batch = llama_n_batch(ctx);
-
-            // only a single seq_id per token is needed
-            batch.clear();
-        }
-
+        batch.clear();
         metrics.init();
     }

From 56e82d024435e8672432ee72eecfe632c15a5d65 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 25 Mar 2025 19:16:19 +0100
Subject: [PATCH 49/52] fix server

---
 examples/server/server.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1def160356c08..876367c6d97ec 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1959,7 +1959,7 @@ struct server_context {
         // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
-        batch.clear();
+        batch = llama_batch_ext_ptr(ctx);
         metrics.init();
     }

From 50fb3963cfdf7734baad18502028a6893fc6c0eb Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Tue, 25 Mar 2025 19:25:30 +0100
Subject: [PATCH 50/52] server: fix batch_spec

---
 examples/server/server.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 876367c6d97ec..36ff66c2c3f0f 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1925,6 +1925,7 @@ struct server_context {
             slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params_base.n_predict;
+            slot.batch_spec = llama_batch_ext_ptr(ctx);
 
             if (model_dft) {
                 slot.batch_spec.clear();

From 8ec0ff9b7f8ab2935e64862e94507233a5614f7f Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 27 Mar 2025 11:26:16 +0100
Subject: [PATCH 51/52] fix embeddings and retrieval

---
 examples/embedding/embedding.cpp | 23 +++++++----------------
 examples/retrieval/retrieval.cpp | 30 ++++++++++--------------------
 include/llama-cpp.h              | 12 ++++++++++++
 3 files changed, 29 insertions(+), 36 deletions(-)
diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp
index 49c4700238522..91b4579f2c747 100644
--- a/examples/embedding/embedding.cpp
+++ b/examples/embedding/embedding.cpp
@@ -26,13 +26,6 @@ static std::vector split_lines(const std::string & s, const std::st
     return lines;
 }
 
-static void batch_add_seq(llama_batch_ext * batch, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
-    size_t n_tokens = tokens.size();
-    for (size_t i = 0; i < n_tokens; i++) {
-        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) {
     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
     const llama_model * model = llama_get_model(ctx);
@@ -167,7 +160,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    llama_batch_ext * batch = llama_batch_ext_init(ctx);
+    llama_batch_ext_ptr batch(ctx);
 
     // count number of embeddings
     int n_embd_count = 0;
@@ -194,21 +187,21 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
-            batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
-            llama_batch_ext_clear(batch);
+        if (batch.n_tokens() + n_toks > n_batch) {
+            batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize);
+            batch.clear();
 
-            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? llama_batch_ext_get_n_tokens(batch) : s;
+            e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens() : s;
             s = 0;
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch.add_seq(inp, 0, s, true);
         s += 1;
     }
 
     // final batch
-    batch_decode(ctx, batch, emb + e * n_embd, s, n_embd, params.embd_normalize);
+    batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize);
 
     if (params.embd_out.empty()) {
         LOG("\n");
@@ -313,8 +306,6 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
-    llama_batch_ext_free(batch);
-
     // clean up
     llama_backend_free();
diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp
index aa242e96f94b7..00617e059f412 100644
--- a/examples/retrieval/retrieval.cpp
+++ b/examples/retrieval/retrieval.cpp
@@ -74,13 +74,6 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz
     return chunks;
 }
 
-static void batch_add_seq(llama_batch_ext * batch, const std::vector<llama_token> & tokens, llama_seq_id seq_id) {
-    const size_t n_tokens = tokens.size();
-    for (size_t i = 0; i < n_tokens; i++) {
-        llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, true);
-    }
-}
-
 static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) {
     const llama_model * model = llama_get_model(ctx);
 
@@ -213,7 +206,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_chunks = chunks.size();
-    llama_batch_ext * batch = llama_batch_ext_init(ctx);
+    llama_batch_ext_ptr batch(ctx);
 
     // allocate output
     const int n_embd = llama_model_n_embd(model);
@@ -230,21 +223,21 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (llama_batch_ext_get_n_tokens(batch) + n_toks > n_batch) {
-            batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
-            llama_batch_ext_clear(batch);
+        if (batch.n_tokens() + n_toks > n_batch) {
+            batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd);
+            batch.clear();
             p += s;
             s = 0;
         }
 
         // add to batch
-        batch_add_seq(batch, inp, s);
+        batch.add_seq(inp, 0, s, true);
         s += 1;
     }
 
     // final batch
-    batch_decode(ctx, batch, emb + p * n_embd, s, n_embd);
+    batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd);
 
     // save embeddings to chunks
     for (int i = 0; i < n_chunks; i++) {
@@ -253,7 +246,7 @@ int main(int argc, char ** argv) {
         chunks[i].tokens.clear();
     }
 
-    llama_batch_ext * query_batch = llama_batch_ext_init(ctx);
+    llama_batch_ext_ptr query_batch(ctx);
 
     // start loop, receive query and return top k similar chunks based on cosine similarity
     std::string query;
@@ -262,12 +255,12 @@ int main(int argc, char ** argv) {
         std::getline(std::cin, query);
         std::vector<llama_token> query_tokens = common_tokenize(ctx, query, true);
 
-        batch_add_seq(query_batch, query_tokens, 0);
+        batch.add_seq(query_tokens, 0, 0, true);
 
        std::vector<float> query_emb(n_embd, 0);
-        batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
+        batch_decode(ctx, query_batch.get(), query_emb.data(), 1, n_embd);
 
-        llama_batch_ext_clear(query_batch);
+        query_batch.clear();
 
         // compute cosine similarities
         {
@@ -296,9 +289,6 @@ int main(int argc, char ** argv) {
     LOG("\n");
     llama_perf_context_print(ctx);
 
-    llama_batch_ext_free(batch);
-    llama_batch_ext_free(query_batch);
-
     // clean up
     llama_backend_free();
 }
diff --git a/include/llama-cpp.h b/include/llama-cpp.h
index cb1a6e2fedeac..e752ae6385c72 100644
--- a/include/llama-cpp.h
+++ b/include/llama-cpp.h
@@ -101,6 +101,18 @@ struct llama_batch_ext_ptr : std::unique_ptr
+    int32_t add_seq(std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
+        int32_t output_id = -1;
+        for (size_t i = 0; i < tokens.size(); i++) {
+            output_id = llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
+        }
+        if (output_last) {
+            llama_batch_ext_set_output_last(this->get());
+        }
+        return output_id;
+    }
+
     void clear() {
         llama_batch_ext_clear(this->get());
     }

From c1f4a78f0e0ba72672f07a1cb3354de8a27b18f0 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Thu, 27 Mar 2025 11:32:03 +0100
Subject: [PATCH 52/52] correct output_id for llama-cpp header

---
 include/llama-cpp.h | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/include/llama-cpp.h b/include/llama-cpp.h
index e752ae6385c72..5a4d316cbda4b 100644
--- a/include/llama-cpp.h
+++ b/include/llama-cpp.h
@@ -54,6 +54,7 @@ struct llama_batch_ext_ptr : std::unique_ptr
     int32_t add_text(llama_token token, llama_pos pos, std::vector<llama_seq_id> & seq_id, bool output_last) {
-        int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false);
+        int32_t output_id = -1;
+        llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false);
         if (output_last) {
-            llama_batch_ext_set_output_last(this->get());
+            output_id = llama_batch_ext_set_output_last(this->get());
         }
         return output_id;
     }
 
     // Wrapper to add a single token to the batch (single sequence ID)
     int32_t add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) {
-        int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false);
+        int32_t output_id = -1;
+        llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false);
         if (output_last) {
-            llama_batch_ext_set_output_last(this->get());
+            output_id = llama_batch_ext_set_output_last(this->get());
         }
         return output_id;
     }
@@ -105,10 +108,10 @@ struct llama_batch_ext_ptr : std::unique_ptr
     int32_t add_seq(std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
         int32_t output_id = -1;
         for (size_t i = 0; i < tokens.size(); i++) {
-            output_id = llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
+            llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
         }
         if (output_last) {
-            llama_batch_ext_set_output_last(this->get());
+            output_id = llama_batch_ext_set_output_last(this->get());
         }
         return output_id;
     }
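After this last commit the add_text/add_seq helpers only report an output id when output_last is set, taking it from llama_batch_ext_set_output_last(). A hedged sketch of how a caller might use the returned index; llama_get_logits_ith() is the long-standing getter from llama.h, and the exact pairing with this index is an assumption here:

    #include "llama-cpp.h"

    // hedged sketch: request output for the last token and read it back by index
    static bool decode_one(llama_context * ctx, llama_token token, llama_pos pos) {
        llama_batch_ext_ptr batch(ctx);
        const int32_t output_id = batch.add_text(token, pos, /*seq_id=*/0, /*output_last=*/true);
        if (output_id < 0 || llama_decode_ext(ctx, batch.get())) {
            return false;
        }
        const float * logits = llama_get_logits_ith(ctx, output_id); // assumption: index maps to the logits row
        return logits != nullptr;
    }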