From b8567cef25616b855583ac94e3d6ed489a543cc6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 2 May 2025 16:47:24 -0600 Subject: [PATCH 01/15] feat: First pass at llama_kv_cache_hybrid This implementation covers both `llama_memory_i` and `llama_kv_cache` interfaces, but they could very well not be correct. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 225 +++++++++++++++++++++++++++++++++++++++++ src/llama-kv-cache.h | 74 ++++++++++++++ 2 files changed, 299 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 265db2527c7ca..662e17d4e14dc 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2392,6 +2392,231 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return true; } +// +// llama_kv_cache_hybrid +// +llama_kv_cache_hybrid::llama_kv_cache_hybrid( + const llama_hparams & hparams, + const std::vector & children) : + m_hparams(hparams), + m_layer_cache_map( + [](const std::vector& caches) -> std::unordered_map { + std::unordered_map map; + for (const auto & cache : caches) { + for (size_t layer_id : cache.layer_ids) { + map[layer_id] = cache.child; + } + } + + return map; + }(children) + ), + m_children( + [](std::vector caches) -> std::set { + // Sort the caches by the lowest layer ID so the order is repeatable + for (auto & cache : caches) { + GGML_ASSERT(cache.layer_ids.size() > 0); + std::sort(cache.layer_ids.begin(), cache.layer_ids.end()); + } + std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) { + return a.layer_ids[0] < b.layer_ids[0]; + }); + std::set unique_caches; + for (const auto & cache : caches) { + unique_caches.insert(cache.child); + } + return unique_caches; + }(children) + ), + m_has_recurrent( + [](const std::vector& caches) -> bool { + for (const auto & cache : caches) { + if (dynamic_cast(cache.child)) { + return true; + } + } + return false; + }(children) + ) +{ + // Ensure at least one child + GGML_ASSERT(m_children.size() > 0); + + // Ensure layers are not overlapping and are concurrent + std::set seen_layers; + size_t max_layer = 0; + for (const auto & cache : children) { + for (const auto & layer_id : cache.layer_ids) { + GGML_ASSERT(seen_layers.find(layer_id) == seen_layers.end()); + seen_layers.insert(layer_id); + if (layer_id > max_layer) { + max_layer = layer_id; + } + } + } + GGML_ASSERT(max_layer == seen_layers.size()); +} + +void llama_kv_cache_hybrid::clear() { + for (const auto & cache : m_children) { + cache->clear(); + } +} + +bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: Will it cause problems if some caches are able to remove the seq + // but others aren't? + bool removed = true; + for (const auto & cache : m_children) { + removed = cache->seq_rm(seq_id, p0, p1) && removed; + } + return removed; +} + +void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + for (const auto & cache : m_children) { + cache->seq_cp(seq_id_src, seq_id_dst, p0, p1); + } +} + +void llama_kv_cache_hybrid::seq_keep(llama_seq_id seq_id) { + for (const auto & cache : m_children) { + cache->seq_keep(seq_id); + } +} + +void llama_kv_cache_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + for (const auto & cache : m_children) { + cache->seq_add(seq_id, p0, p1, delta); + } +} + +void llama_kv_cache_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + for (const auto & cache : m_children) { + cache->seq_div(seq_id, p0, p1, d); + } +} + +llama_pos llama_kv_cache_hybrid::seq_pos_max(llama_seq_id seq_id) const { + llama_pos max_pos = 0; + for (const auto & cache : m_children) { + max_pos = std::max(max_pos, cache->seq_pos_max(seq_id)); + } + return max_pos; +} + +void llama_kv_cache_hybrid::restore() { + for (const auto & cache : m_children) { + cache->restore(); + } +} + +void llama_kv_cache_hybrid::commit() { + for (const auto & cache : m_children) { + cache->commit(); + } +} + +bool llama_kv_cache_hybrid::update(llama_context & ctx) { + bool updated = false; + for (const auto & cache : m_children) { + updated = cache->update(ctx) || updated; + } + return updated; +} + +void llama_kv_cache_hybrid::defrag_sched(float thold) { + for (const auto & cache : m_children) { + cache->defrag_sched(thold); + } +} + +void llama_kv_cache_hybrid::set_full() { + for (const auto & cache : m_children) { + cache->set_full(); + } +} + +llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) { + // If any of the caches are recurrent, require simple split + return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all); +} + +llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (m_has_recurrent) { + return sbatch.split_simple(n_ubatch); + } + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); + } + return sbatch.split_equal(n_ubatch); +} + +bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) { + bool found = true; + for (const auto & cache : m_children) { + found = cache->find_slot(batch) && found; + } + return found; +} + +int32_t llama_kv_cache_hybrid::get_n_tokens() const { + // The number of tokens should be the same across all child caches + int32_t n_tokens = -1; + for (const auto & cache : m_children) { + const auto cache_n_tokens = cache->get_n_tokens(); + GGML_ASSERT(n_tokens == -1 || cache_n_tokens == n_tokens); + n_tokens = cache_n_tokens; + } + return n_tokens; +} + +int32_t llama_kv_cache_hybrid::get_used_cells() const { + // TODO: Is this correct? + // Return the largetst number of used cells + int32_t used_cells = -1; + for (const auto & cache : m_children) { + used_cells = std::max(used_cells, cache->get_used_cells()); + } + return used_cells; +} + +llama_pos llama_kv_cache_hybrid::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cache : m_children) { + pos_max = std::max(pos_max, cache->get_pos_max()); + } + return pos_max; +} + +bool llama_kv_cache_hybrid::get_can_shift() const { + // TODO: Is this correct? + // If any children can shift, return true + for (const auto & cache : m_children) { + if (cache->get_can_shift()) { + return true; + } + } + return false; +} + +void llama_kv_cache_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + // Write each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : m_children) { + cache->state_write(io, seq_id); + } +} + +void llama_kv_cache_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + // Read each cache state in order. Note that order is guaranteed at + // initialization by using an ordered set sorted by lowest layer ID + for (const auto & cache : m_children) { + cache->state_read(io, seq_id); + } +} + // // kv cache view // diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index e83e12c09f2b1..82f696dd22be2 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -9,6 +9,7 @@ #include #include +#include struct llama_cparams; struct llama_hparams; @@ -389,6 +390,79 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; +// +// llama_kv_cache_hybrid +// + +class llama_kv_cache_hybrid : public llama_kv_cache { +public: + + struct child_cache { + llama_kv_cache * child; + std::vector layer_ids; + }; + + llama_kv_cache_hybrid( + const llama_hparams & hparams, + const std::vector & children); + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & ctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + + // updates the cache head + // Note: On success, it's important that cache.head points + // to the first cell of the slot. + bool find_slot(const llama_ubatch & batch) override; + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + +private: + + const llama_hparams & m_hparams; + const std::unordered_map m_layer_cache_map; + const std::set m_children; // Ordered for state IO + const bool m_has_recurrent; +}; + // // kv cache view From 9eca84e1011ac8a060a9464c401ad6a4c80eb977 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:33:22 -0600 Subject: [PATCH 02/15] fix: Fix confusion on simple vs equal splitting Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 662e17d4e14dc..12ba32201f898 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2538,19 +2538,19 @@ void llama_kv_cache_hybrid::set_full() { } llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) { - // If any of the caches are recurrent, require simple split - return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all); + // If any of the caches are recurrent, require equal split + return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all); } llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (m_has_recurrent) { - return sbatch.split_simple(n_ubatch); - } if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) return sbatch.split_seq(n_ubatch); } - return sbatch.split_equal(n_ubatch); + if (m_has_recurrent) { + return sbatch.split_equal(n_ubatch); + } + return sbatch.split_simple(n_ubatch); } bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) { From f1ceed6ad18695f01d54f773d83ca77fc99f861b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 16:30:13 -0600 Subject: [PATCH 03/15] fix: Split up seq_rm interface into immutable can_seq_rm and mutating seq_rm This allows the hybrid cache to check first before mutating any of the children. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 84 +++++++++++++++++++++++++++++++----------- src/llama-kv-cache.h | 11 ++++++ 2 files changed, 74 insertions(+), 21 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 12ba32201f898..03db981f3adbf 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -450,6 +450,11 @@ void llama_kv_cache_unified::set_full() { head = 0; } +bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + // Unified attention cache can always do a sequence removal + return true; +} + llama_sbatch llama_kv_cache_unified::sbatch_init( const llama_batch & batch, bool logits_all) { @@ -1488,39 +1493,33 @@ void llama_kv_cache_recurrent::clear() { } bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - uint32_t new_head = size; + if (!can_seq_rm(seq_id, p0, p1)) { + // could be fatal + return false; + } + uint32_t new_head = size; if (p0 < 0) { p0 = 0; } - if (p1 < 0) { p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { const kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos))); // invalidate tails which will be cleared if (p0 <= cell.pos && cell.pos < p1) { tail_id = -1; } } } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } + // already validated in can_seq_rm + GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max()))); } for (uint32_t i = 0; i < size; ++i) { @@ -1722,6 +1721,35 @@ void llama_kv_cache_recurrent::set_full() { head = 0; } +bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + const int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + } + // seq_id is negative, then the range should include everything or nothing + } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + return true; +} + llama_sbatch llama_kv_cache_recurrent::sbatch_init( const llama_batch & batch, bool logits_all) { @@ -2464,13 +2492,18 @@ void llama_kv_cache_hybrid::clear() { } bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - // TODO: Will it cause problems if some caches are able to remove the seq - // but others aren't? - bool removed = true; + // First check if we can do this removal. This checks all children so that + // no mutation happens before we know if it's possible + if (!can_seq_rm(seq_id, p0, p1)) { + return false; + } + + // Do the removal from each child which should never fail for (const auto & cache : m_children) { - removed = cache->seq_rm(seq_id, p0, p1) && removed; + const bool failed = cache->seq_rm(seq_id, p0, p1); + GGML_ASSERT(!failed); } - return removed; + return true; } void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { @@ -2537,6 +2570,15 @@ void llama_kv_cache_hybrid::set_full() { } } +bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + for (const auto & cache : m_children) { + if (!cache->can_seq_rm(seq_id, p0, p1)) { + return false; + } + } + return true; +} + llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) { // If any of the caches are recurrent, require equal split return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all); @@ -2574,7 +2616,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const { int32_t llama_kv_cache_hybrid::get_used_cells() const { // TODO: Is this correct? - // Return the largetst number of used cells + // Return the largest number of used cells int32_t used_cells = -1; for (const auto & cache : m_children) { used_cells = std::max(used_cells, cache->get_used_cells()); diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 82f696dd22be2..ba8a55d5b9a4f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -37,6 +37,11 @@ struct llama_kv_cache : public llama_memory_i { // simulate full cache, used for allocating worst-case compute buffers virtual void set_full() = 0; + // sometimes it is useful to check whether a cache can remove a sequence + // before attempting to mutate the cache (eg a hybrid cache with multiple + // children to keep in sync) + virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0; + // // batch processing // @@ -150,6 +155,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; @@ -318,6 +325,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; @@ -433,6 +442,8 @@ class llama_kv_cache_hybrid : public llama_kv_cache { void set_full() override; + bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override; + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; From a99cbd361dad4682cdfc9a5968ddb3ac4391bed4 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 10:28:52 -0600 Subject: [PATCH 04/15] fix: Mark unused params correctly Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 03db981f3adbf..44d3e23ef74c2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -451,6 +451,9 @@ void llama_kv_cache_unified::set_full() { } bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const { + GGML_UNUSED(seq_id); + GGML_UNUSED(p0); + GGML_UNUSED(p1); // Unified attention cache can always do a sequence removal return true; } From c678901d2f682a7d4d16d87bad4bd7550dce7df7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 10:42:23 -0600 Subject: [PATCH 05/15] fix: Give ownership of child caches to the hybrid cache The parent should fully own the lifecycle of the children which is managed by the m_children member holding unique_ptrs. These need to be initialized correctly, so the constructor now takes the input vector of child_cache by value instead of reference so that the child pointers can be transferred to the parent cache. The expectation is that the vector of child_cache instances will be instantiated in-place with move semantics. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 20 ++++++++++---------- src/llama-kv-cache.h | 11 +++++++---- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 44d3e23ef74c2..819f9aa60374b 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2427,15 +2427,15 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // llama_kv_cache_hybrid // llama_kv_cache_hybrid::llama_kv_cache_hybrid( - const llama_hparams & hparams, - const std::vector & children) : + const llama_hparams & hparams, + std::vector children) : m_hparams(hparams), m_layer_cache_map( [](const std::vector& caches) -> std::unordered_map { std::unordered_map map; for (const auto & cache : caches) { for (size_t layer_id : cache.layer_ids) { - map[layer_id] = cache.child; + map[layer_id] = cache.child.get(); } } @@ -2443,7 +2443,7 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( }(children) ), m_children( - [](std::vector caches) -> std::set { + [](std::vector& caches) -> std::set> { // Sort the caches by the lowest layer ID so the order is repeatable for (auto & cache : caches) { GGML_ASSERT(cache.layer_ids.size() > 0); @@ -2452,22 +2452,22 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( std::sort(caches.begin(), caches.end(), [](const child_cache & a, const child_cache & b) { return a.layer_ids[0] < b.layer_ids[0]; }); - std::set unique_caches; - for (const auto & cache : caches) { - unique_caches.insert(cache.child); + std::set> unique_caches; + for (auto & cache : caches) { + unique_caches.emplace(cache.child.release()); } return unique_caches; }(children) ), m_has_recurrent( - [](const std::vector& caches) -> bool { + [](const std::set> & caches) -> bool { for (const auto & cache : caches) { - if (dynamic_cast(cache.child)) { + if (dynamic_cast(cache.get())) { return true; } } return false; - }(children) + }(m_children) ) { // Ensure at least one child diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index ba8a55d5b9a4f..d9caaf3f7ed8c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -407,13 +407,16 @@ class llama_kv_cache_hybrid : public llama_kv_cache { public: struct child_cache { - llama_kv_cache * child; - std::vector layer_ids; + std::unique_ptr child; + std::vector layer_ids; + + child_cache(std::unique_ptr child_, std::vector layer_ids_) + : child(std::move(child_)), layer_ids(std::move(layer_ids_)) {} }; llama_kv_cache_hybrid( const llama_hparams & hparams, - const std::vector & children); + std::vector children); // // llama_memory_i @@ -470,7 +473,7 @@ class llama_kv_cache_hybrid : public llama_kv_cache { const llama_hparams & m_hparams; const std::unordered_map m_layer_cache_map; - const std::set m_children; // Ordered for state IO + const std::set> m_children; // Ordered for state IO const bool m_has_recurrent; }; From cd0dc98f0d864efb1b1de5e77558a76d839f174a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:08:33 -0600 Subject: [PATCH 06/15] feat: Add c++ side constants for attention layer indices hparam Branch: GraniteFour --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index abf436adac416..690b15d88d126 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -144,6 +144,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..adba61c44e35d 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -148,6 +148,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, From a00f3f68303113315f1ded063fb7d66042c61355 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:21:29 -0600 Subject: [PATCH 07/15] feat: Add llama_model_is_hybrid API call Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart --- include/llama.h | 3 +++ src/llama-arch.cpp | 22 ++++++++++++++++++++++ src/llama-arch.h | 3 +++ src/llama-model.cpp | 13 +++++-------- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/include/llama.h b/include/llama.h index 99e5fba244fcc..2670df03b8146 100644 --- a/include/llama.h +++ b/include/llama.h @@ -552,6 +552,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 690b15d88d126..291e5da230ad2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1745,3 +1745,25 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index adba61c44e35d..1bfd9780ac962 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -436,3 +436,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7fd094b63f269..15a3b23a03c08 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13616,14 +13616,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); +} + +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 8a13b03d6a9d0c4aad2a023786d60eff8a4124d1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:22:18 -0600 Subject: [PATCH 08/15] feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 15a3b23a03c08..a03d51db0ccf5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -467,6 +467,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + llm_arch_is_recurrent(ml.get_arch())); ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); From 5a60db59aec34c3c969a9013b40b58c62a205569 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 11:20:31 -0600 Subject: [PATCH 09/15] feat: Zero-out recurrent / non-recurrent layers in the single-type caches This is a bit of an inversion of concerns, so we could conceivably make the interface to this more opaque to the other cache types by providing something like a layer mask, but since these cache implementations already have access to the hparams, it seems minimally invasive to just check the new recurrent_layer function. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 819f9aa60374b..c1b25c55cf114 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -100,8 +100,11 @@ llama_kv_cache_unified::llama_kv_cache_unified( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + // any recurrent layers in the model will not use this cache + const uint32_t tensor_dim = hparams.recurrent_layer(i) ? 0 : kv_size; + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*tensor_dim); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*tensor_dim); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l.push_back(k); @@ -1447,8 +1450,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + // any non-recurrent layers in the model will not use this cache + const uint32_t tensor_dim = hparams.recurrent_layer(i) ? kv_size : 0; + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*tensor_dim); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*tensor_dim); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l.push_back(k); From bb7d4bd85b539a3f5179812f2d30e244f9a344df Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:04:36 -0600 Subject: [PATCH 10/15] feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 14 ++++++++++++-- src/llama-hparams.h | 10 ++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 90dfe7a7fcc00..d57c5defe7157 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -49,7 +49,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s() const { +uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -60,7 +63,10 @@ uint32_t llama_hparams::n_embd_k_s() const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s() const { +uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -70,6 +76,10 @@ uint32_t llama_hparams::n_embd_v_s() const { return ssm_d_state * ssm_d_inner; } +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; +} + bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 7ee6a5b75ad1e..8c35407480a79 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -102,6 +102,9 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + // for hybrid state space models + std::array recurrent_layer_arr; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -149,10 +152,13 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s() const; + uint32_t n_embd_k_s(uint32_t il = 0) const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s() const; + uint32_t n_embd_v_s(uint32_t il = 0) const; + + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; bool is_swa(uint32_t il) const; }; From cbf6b102c7d66e3d9699ed2c4b006b026a952454 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 11:22:30 -0600 Subject: [PATCH 11/15] feat!: Instantiate hybrid cache for hybrid models There is a small breaking change here that extends the create_memory method signature to include the hparams. Currently, this member is only used inside llama_context and is not part of an interface that's expected to be extended by classes derived from llama_model, so I don't think this should actually break any downstream use cases. Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-context.cpp | 2 +- src/llama-model.cpp | 98 ++++++++++++++++++++++++++++++++----------- src/llama-model.h | 5 ++- 3 files changed, 78 insertions(+), 27 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a3b84a6a82e74..73f21bda6d49c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -180,7 +180,7 @@ llama_context::llama_context( /*.type_v =*/ params.type_v, }; - memory.reset(model.create_memory(params_mem, cparams)); + memory.reset(model.create_memory(params_mem, cparams, hparams)); } // init backends diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a03d51db0ccf5..dd3713a844edd 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13040,10 +13040,15 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { +llama_memory_i * llama_model::create_memory( + const llama_memory_params & params, + llama_cparams & cparams, + const llama_hparams & hparams) const { llama_memory_i * res; switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: @@ -13051,35 +13056,78 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - res = new llama_kv_cache_recurrent( - *this, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max)); - } break; + // Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); + if (llm_arch_is_hybrid(arch)) { + // make vectors of recurrent and non-recurrent layer indices + std::vector recurrent_layers; + std::vector unified_layers; + for (auto il = 0u; il < hparams.n_layer; ++il) { + if (hparams.recurrent_layer(il)) { + recurrent_layers.push_back(il); + } else { + unified_layers.push_back(il); + } + } + + const auto padding = llama_kv_cache_unified::get_padding(cparams); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + // initialize the children + std::vector children; + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max)) + ), + std::move(recurrent_layers) + ); + children.emplace_back( + std::unique_ptr( + new llama_kv_cache_unified( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + padding) + ), + std::move(unified_layers) + ); + + // initialize the hybrid cache with both children + res = new llama_kv_cache_hybrid(hparams, std::move(children)); + } else if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max)); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - res = new llama_kv_cache_unified( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - cparams.n_ctx, - padding); + res = new llama_kv_cache_unified( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + padding); + } } } diff --git a/src/llama-model.h b/src/llama-model.h index 6bdec263b709b..ed79b249cc6f9 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -402,7 +402,10 @@ struct llama_model { // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; + llama_memory_i * create_memory( + const llama_memory_params & params, + llama_cparams & cparams, + const llama_hparams & hparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( From 220456d3691901a24ff6ab1dd1bb47e23c089561 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 16 May 2025 15:25:15 -0600 Subject: [PATCH 12/15] fix: Remove unnecessary hparams argument to create_memory It was already available as a member! :facepalm: Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-context.cpp | 2 +- src/llama-model.cpp | 3 +-- src/llama-model.h | 5 +---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 73f21bda6d49c..a3b84a6a82e74 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -180,7 +180,7 @@ llama_context::llama_context( /*.type_v =*/ params.type_v, }; - memory.reset(model.create_memory(params_mem, cparams, hparams)); + memory.reset(model.create_memory(params_mem, cparams)); } // init backends diff --git a/src/llama-model.cpp b/src/llama-model.cpp index dd3713a844edd..07378d764f15b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13042,8 +13042,7 @@ struct llm_build_bailingmoe : public llm_graph_context { llama_memory_i * llama_model::create_memory( const llama_memory_params & params, - llama_cparams & cparams, - const llama_hparams & hparams) const { + llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { diff --git a/src/llama-model.h b/src/llama-model.h index ed79b249cc6f9..6bdec263b709b 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -402,10 +402,7 @@ struct llama_model { // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory( - const llama_memory_params & params, - llama_cparams & cparams, - const llama_hparams & hparams) const; + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( From f857dc04d6092bd0b0e297d4550ca9960a7440dc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 16:51:39 -0600 Subject: [PATCH 13/15] feat: Add a templated helper to the hybrid cache to retrieve a child This will be the public interface used by functions that need to access one specific type of child. It's a bit brittle since the rest of the hybrid class intentionally avoids expecting there to be exactly one unified child and one recurrent child, but the idea is that this should only be used from a context where that's known to be true. Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.h | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d9caaf3f7ed8c..9abeb0e1702b0 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -418,6 +418,21 @@ class llama_kv_cache_hybrid : public llama_kv_cache { const llama_hparams & hparams, std::vector children); + // getters for specific child cache type + // NOTE: This will fail if there are multiple of the given type + template + const child_t * get_child_cache() const { + const child_t * child = nullptr; + for (const auto & child_cache : m_children) { + const child_t * child_cast = dynamic_cast(child_cache.get()); + if (child_cast) { + GGML_ASSERT(!child); + child = child_cast; + } + } + return child; + } + // // llama_memory_i // From 0cc96c93141e4821c33abd9162a3ad5b2be6197e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 13 May 2025 16:49:58 -0600 Subject: [PATCH 14/15] fix: Fix off-by-one error for assertion check on layers in hybrid cache constructor Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c1b25c55cf114..832ede29383d5 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -2491,7 +2491,8 @@ llama_kv_cache_hybrid::llama_kv_cache_hybrid( } } } - GGML_ASSERT(max_layer == seen_layers.size()); + LLAMA_LOG_DEBUG("max_layer=%zu, seen_layers.size()=%zu\n", max_layer, seen_layers.size()); + GGML_ASSERT(max_layer + 1 == seen_layers.size()); } void llama_kv_cache_hybrid::clear() { From f46a72758df4c923b7463fe718c46a749c63a4eb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 16 May 2025 15:30:22 -0600 Subject: [PATCH 15/15] tests: Add initial unit tests for kv caches So far this only tests constructor logic (and barely that) Branch: HybridCache Signed-off-by: Gabe Goodhart --- tests/CMakeLists.txt | 1 + tests/test-memory.cpp | 133 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 tests/test-memory.cpp diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 083347d188880..9e4ac8342b24e 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -145,6 +145,7 @@ endif() llama_build_and_test(test-log.cpp) llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-regex-partial.cpp) +llama_build_and_test(test-memory.cpp) # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135) if (NOT WIN32) diff --git a/tests/test-memory.cpp b/tests/test-memory.cpp new file mode 100644 index 0000000000000..0c08fb8e06a07 --- /dev/null +++ b/tests/test-memory.cpp @@ -0,0 +1,133 @@ +#include "../src/llama-arch.h" +#include "../src/llama-hparams.h" +#include "../src/llama-impl.h" +#include "../src/llama-kv-cache.h" +#include "../src/llama-model.h" + +#include "llama.h" + +#include +#include +#include + +/*- Helpers ------------------------------------------------------------------*/ + +static std::shared_ptr _make_model() { + llama_model_params params; + params.tensor_buft_overrides = nullptr; + std::shared_ptr model(new llama_model(params)); + model->hparams = llama_hparams(); + model->arch = LLM_ARCH_LLAMA; + return model; +} + +struct log_scope { + const char * name; + explicit log_scope(const char * name) : name(name) { + LLAMA_LOG_INFO("--------\n"); + LLAMA_LOG_INFO("START: %s\n", name); + } + ~log_scope() { + LLAMA_LOG_INFO("END: %s\n", name); + LLAMA_LOG_INFO("--------\n"); + } +}; + +#define LOG_SCOPE() log_scope __log_scope(__func__) + +/*- Unified Cache ------------------------------------------------------------*/ + +/* Test that the unified cache can be constructed and destructed safely */ +static void test_llama_kv_cache_unified_constructor() { + LOG_SCOPE(); + auto model = _make_model(); + llama_kv_cache_unified cache( + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 10, + /* padding */ 10 + ); +} + +/*- Recurrent Cache ----------------------------------------------------------*/ + +/* Test that the recurrent cache can be constructed and destructed safely */ +static void test_llama_kv_cache_recurrent_constructor() { + LOG_SCOPE(); + auto model = _make_model(); + llama_kv_cache_recurrent cache( + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10 + ); +} + +/*- Hybrid Cache -------------------------------------------------------------*/ + +/* Test that the hybrid cache can be constructed and destructed safely */ +static void test_llama_kv_cache_hybrid_constructor() { + LOG_SCOPE(); + auto model = _make_model(); + model->hparams.n_layer = 4; + model->hparams.n_embd_head_k = 4; + model->hparams.n_embd_head_v = 4; + auto& recurrent_layer_arr = model->hparams.recurrent_layer_arr; + recurrent_layer_arr[0] = 1; + recurrent_layer_arr[1] = 0; + recurrent_layer_arr[2] = 1; + recurrent_layer_arr[3] = 0; + auto& n_head_kv_arr = model->hparams.n_head_kv_arr; + n_head_kv_arr[0] = 16; + n_head_kv_arr[1] = 8; + n_head_kv_arr[2] = 16; + n_head_kv_arr[3] = 8; + + std::unique_ptr u_cache( + new llama_kv_cache_unified( + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* v_trans */ false, + /* offload */ false, + /* kv_size */ 20, + /* padding */ 2 + ) + ); + auto * u_cache_ptr = u_cache.get(); + std::unique_ptr r_cache ( + new llama_kv_cache_recurrent( + /* model */ *model, + /* type_k */ GGML_TYPE_F32, + /* type_v */ GGML_TYPE_F16, + /* offload */ false, + /* kv_size */ 10 + ) + ); + auto * r_cache_ptr = r_cache.get(); + + std::vector children; + children.emplace_back(std::move(u_cache), std::vector{1, 3}); + children.emplace_back(std::move(r_cache), std::vector{0, 2}); + + llama_kv_cache_hybrid cache(model->hparams, std::move(children)); + + GGML_ASSERT(cache.get_child_cache() == u_cache_ptr); + GGML_ASSERT(cache.get_child_cache() == r_cache_ptr); +} + +/*- Main ---------------------------------------------------------------------*/ + +int main() { + // Unified Cache Tests + test_llama_kv_cache_unified_constructor(); + // Recurrent Cache Tests + test_llama_kv_cache_recurrent_constructor(); + // Hybrid Cache Tests + test_llama_kv_cache_hybrid_constructor(); + return 0; +}