Skip to content

Commit e6dc215

Browse files
committed
fix: Split up seq_rm interface into immutable can_seq_rm and mutating seq_rm
This allows the hybrid cache to check all of its children first, before mutating any of them. Branch: HybridCache. Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 9c62d55 commit e6dc215

File tree

2 files changed

+74
-21
lines changed

2 files changed

+74
-21
lines changed

src/llama-kv-cache.cpp

Lines changed: 63 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,11 @@ void llama_kv_cache_unified::set_full() {
443443
n = size;
444444
}
445445

446+
bool llama_kv_cache_unified::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
447+
// Unified attention cache can always do a sequence removal
448+
return true;
449+
}
450+
446451
llama_sbatch llama_kv_cache_unified::sbatch_init(
447452
const llama_batch & batch,
448453
bool logits_all) {
@@ -1481,39 +1486,33 @@ void llama_kv_cache_recurrent::clear() {
14811486
}
14821487

14831488
bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1484-
uint32_t new_head = size;
1489+
if (!can_seq_rm(seq_id, p0, p1)) {
1490+
// could be fatal
1491+
return false;
1492+
}
14851493

1494+
uint32_t new_head = size;
14861495
if (p0 < 0) {
14871496
p0 = 0;
14881497
}
1489-
14901498
if (p1 < 0) {
14911499
p1 = std::numeric_limits<llama_pos>::max();
14921500
}
14931501

1494-
// models like Mamba or RWKV can't have a state partially erased
1495-
if (seq_id >= (int64_t) size) {
1496-
// could be fatal
1497-
return false;
1498-
}
14991502
if (0 <= seq_id) {
15001503
int32_t & tail_id = cells[seq_id].tail;
15011504
if (tail_id >= 0) {
15021505
const kv_cell & cell = cells[tail_id];
1503-
// partial intersection is invalid
1504-
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1505-
return false;
1506-
}
1506+
// already validated in can_seq_rm
1507+
GGML_ASSERT(!((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)));
15071508
// invalidate tails which will be cleared
15081509
if (p0 <= cell.pos && cell.pos < p1) {
15091510
tail_id = -1;
15101511
}
15111512
}
15121513
} else {
1513-
// seq_id is negative, then the range should include everything or nothing
1514-
if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1515-
return false;
1516-
}
1514+
// already validated in can_seq_rm
1515+
GGML_ASSERT(!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())));
15171516
}
15181517

15191518
for (uint32_t i = 0; i < size; ++i) {
@@ -1714,6 +1713,35 @@ void llama_kv_cache_recurrent::set_full() {
17141713
n = size;
17151714
}
17161715

1716+
bool llama_kv_cache_recurrent::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1717+
if (p0 < 0) {
1718+
p0 = 0;
1719+
}
1720+
1721+
if (p1 < 0) {
1722+
p1 = std::numeric_limits<llama_pos>::max();
1723+
}
1724+
// models like Mamba or RWKV can't have a state partially erased
1725+
if (seq_id >= (int64_t) size) {
1726+
// could be fatal
1727+
return false;
1728+
}
1729+
if (0 <= seq_id) {
1730+
const int32_t & tail_id = cells[seq_id].tail;
1731+
if (tail_id >= 0) {
1732+
const kv_cell & cell = cells[tail_id];
1733+
// partial intersection is invalid
1734+
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
1735+
return false;
1736+
}
1737+
}
1738+
// if seq_id is negative, the range should include everything or nothing
1739+
} else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
1740+
return false;
1741+
}
1742+
return true;
1743+
}
1744+
17171745
llama_sbatch llama_kv_cache_recurrent::sbatch_init(
17181746
const llama_batch & batch,
17191747
bool logits_all) {
@@ -2456,13 +2484,18 @@ void llama_kv_cache_hybrid::clear() {
24562484
}
24572485

24582486
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
2459-
// TODO: Will it cause problems if some caches are able to remove the seq
2460-
// but others aren't?
2461-
bool removed = true;
2487+
// First check if we can do this removal. This checks all children so that
2488+
// no mutation happens before we know if it's possible
2489+
if (!can_seq_rm(seq_id, p0, p1)) {
2490+
return false;
2491+
}
2492+
2493+
// Do the removal from each child which should never fail
24622494
for (const auto & cache : m_children) {
2463-
removed = cache->seq_rm(seq_id, p0, p1) && removed;
2495+
const bool ok = cache->seq_rm(seq_id, p0, p1);
2496+
GGML_ASSERT(ok && "child seq_rm failed after can_seq_rm succeeded");
24642497
}
2465-
return removed;
2498+
return true;
24662499
}
24672500

24682501
void llama_kv_cache_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -2529,6 +2562,15 @@ void llama_kv_cache_hybrid::set_full() {
25292562
}
25302563
}
25312564

2565+
bool llama_kv_cache_hybrid::can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2566+
for (const auto & cache : m_children) {
2567+
if (!cache->can_seq_rm(seq_id, p0, p1)) {
2568+
return false;
2569+
}
2570+
}
2571+
return true;
2572+
}
2573+
25322574
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
25332575
// If any of the caches are recurrent, require equal split
25342576
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
@@ -2566,7 +2608,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const {
25662608

25672609
int32_t llama_kv_cache_hybrid::get_used_cells() const {
25682610
// TODO: Is this correct?
2569-
// Return the largetst number of used cells
2611+
// Return the largest number of used cells
25702612
int32_t used_cells = -1;
25712613
for (const auto & cache : m_children) {
25722614
used_cells = std::max(used_cells, cache->get_used_cells());

src/llama-kv-cache.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ struct llama_kv_cache : public llama_memory_i {
3737
// simulate full cache, used for allocating worst-case compute buffers
3838
virtual void set_full() = 0;
3939

40+
// sometimes it is useful to check whether a cache can remove a sequence
41+
// before attempting to mutate the cache (eg a hybrid cache with multiple
42+
// children to keep in sync)
43+
virtual bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const = 0;
44+
4045
//
4146
// batch processing
4247
//
@@ -150,6 +155,8 @@ class llama_kv_cache_unified : public llama_kv_cache {
150155

151156
void set_full() override;
152157

158+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
159+
153160
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
154161

155162
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
@@ -321,6 +328,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache {
321328

322329
void set_full() override;
323330

331+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
332+
324333
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
325334

326335
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;
@@ -439,6 +448,8 @@ class llama_kv_cache_hybrid : public llama_kv_cache {
439448

440449
void set_full() override;
441450

451+
bool can_seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) const override;
452+
442453
llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override;
443454

444455
llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override;

0 commit comments

Comments
 (0)