@@ -443,6 +443,11 @@ void llama_kv_cache_unified::set_full() {
443
443
n = size;
444
444
}
445
445
446
+ bool llama_kv_cache_unified::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
447
+ // Unified attention cache can always do a sequence removal
448
+ return true ;
449
+ }
450
+
446
451
llama_sbatch llama_kv_cache_unified::sbatch_init (
447
452
const llama_batch & batch,
448
453
bool logits_all) {
@@ -1481,39 +1486,33 @@ void llama_kv_cache_recurrent::clear() {
1481
1486
}
1482
1487
1483
1488
bool llama_kv_cache_recurrent::seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
1484
- uint32_t new_head = size;
1489
+ if (!can_seq_rm (seq_id, p0, p1)) {
1490
+ // could be fatal
1491
+ return false ;
1492
+ }
1485
1493
1494
+ uint32_t new_head = size;
1486
1495
if (p0 < 0 ) {
1487
1496
p0 = 0 ;
1488
1497
}
1489
-
1490
1498
if (p1 < 0 ) {
1491
1499
p1 = std::numeric_limits<llama_pos>::max ();
1492
1500
}
1493
1501
1494
- // models like Mamba or RWKV can't have a state partially erased
1495
- if (seq_id >= (int64_t ) size) {
1496
- // could be fatal
1497
- return false ;
1498
- }
1499
1502
if (0 <= seq_id) {
1500
1503
int32_t & tail_id = cells[seq_id].tail ;
1501
1504
if (tail_id >= 0 ) {
1502
1505
const kv_cell & cell = cells[tail_id];
1503
- // partial intersection is invalid
1504
- if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1505
- return false ;
1506
- }
1506
+ // already validated in can_seq_rm
1507
+ GGML_ASSERT (!((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )));
1507
1508
// invalidate tails which will be cleared
1508
1509
if (p0 <= cell.pos && cell.pos < p1) {
1509
1510
tail_id = -1 ;
1510
1511
}
1511
1512
}
1512
1513
} else {
1513
- // seq_id is negative, then the range should include everything or nothing
1514
- if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1515
- return false ;
1516
- }
1514
+ // already validated in can_seq_rm
1515
+ GGML_ASSERT (!(p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())));
1517
1516
}
1518
1517
1519
1518
for (uint32_t i = 0 ; i < size; ++i) {
@@ -1714,6 +1713,35 @@ void llama_kv_cache_recurrent::set_full() {
1714
1713
n = size;
1715
1714
}
1716
1715
1716
+ bool llama_kv_cache_recurrent::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
1717
+ if (p0 < 0 ) {
1718
+ p0 = 0 ;
1719
+ }
1720
+
1721
+ if (p1 < 0 ) {
1722
+ p1 = std::numeric_limits<llama_pos>::max ();
1723
+ }
1724
+ // models like Mamba or RWKV can't have a state partially erased
1725
+ if (seq_id >= (int64_t ) size) {
1726
+ // could be fatal
1727
+ return false ;
1728
+ }
1729
+ if (0 <= seq_id) {
1730
+ const int32_t & tail_id = cells[seq_id].tail ;
1731
+ if (tail_id >= 0 ) {
1732
+ const kv_cell & cell = cells[tail_id];
1733
+ // partial intersection is invalid
1734
+ if ((0 < p0 && p0 <= cell.pos ) || (0 < p1 && p1 <= cell.pos )) {
1735
+ return false ;
1736
+ }
1737
+ }
1738
+ // seq_id is negative, then the range should include everything or nothing
1739
+ } else if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max ())) {
1740
+ return false ;
1741
+ }
1742
+ return true ;
1743
+ }
1744
+
1717
1745
llama_sbatch llama_kv_cache_recurrent::sbatch_init (
1718
1746
const llama_batch & batch,
1719
1747
bool logits_all) {
@@ -2456,13 +2484,18 @@ void llama_kv_cache_hybrid::clear() {
2456
2484
}
2457
2485
2458
2486
// Removes positions [p0, p1) of seq_id from every child cache.
//
// The removal is first validated against all children via can_seq_rm() so that
// no child is mutated unless every child can perform the removal.
//
// Returns false (leaving all children untouched) when can_seq_rm() rejects the
// request, true once every child has performed the removal.
bool llama_kv_cache_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // First check if we can do this removal. This checks all children so that
    // no mutation happens before we know if it's possible
    if (!can_seq_rm(seq_id, p0, p1)) {
        return false;
    }

    // Do the removal from each child, which should never fail after the check above
    for (const auto & cache : m_children) {
        // seq_rm() returns true on success, so the assertion must hold on the
        // returned value itself (asserting its negation would abort on success)
        const bool removed = cache->seq_rm(seq_id, p0, p1);
        GGML_ASSERT(removed);
    }
    return true;
}
2467
2500
2468
2501
void llama_kv_cache_hybrid::seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
@@ -2529,6 +2562,15 @@ void llama_kv_cache_hybrid::set_full() {
2529
2562
}
2530
2563
}
2531
2564
2565
+ bool llama_kv_cache_hybrid::can_seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) const {
2566
+ for (const auto & cache : m_children) {
2567
+ if (!cache->can_seq_rm (seq_id, p0, p1)) {
2568
+ return false ;
2569
+ }
2570
+ }
2571
+ return true ;
2572
+ }
2573
+
2532
2574
llama_sbatch llama_kv_cache_hybrid::sbatch_init (const llama_batch & batch, bool logits_all) {
2533
2575
// If any of the caches are recurrent, require equal split
2534
2576
return llama_sbatch (batch, m_hparams.n_embd , !m_has_recurrent, logits_all);
@@ -2566,7 +2608,7 @@ int32_t llama_kv_cache_hybrid::get_n_tokens() const {
2566
2608
2567
2609
int32_t llama_kv_cache_hybrid::get_used_cells () const {
2568
2610
// TODO: Is this correct?
2569
- // Return the largetst number of used cells
2611
+ // Return the largest number of used cells
2570
2612
int32_t used_cells = -1 ;
2571
2613
for (const auto & cache : m_children) {
2572
2614
used_cells = std::max (used_cells, cache->get_used_cells ());
0 commit comments