Skip to content

Commit 9c62d55

Browse files
committed
fix: Fix confusion on simple vs equal splitting
Branch: GraniteFour
Signed-off-by: Gabe Goodhart <[email protected]>
1 parent 7778b5b commit 9c62d55

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/llama-kv-cache.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2530,19 +2530,19 @@ void llama_kv_cache_hybrid::set_full() {
25302530
}
25312531

25322532
llama_sbatch llama_kv_cache_hybrid::sbatch_init(const llama_batch & batch, bool logits_all) {
2533-
// If any of the caches are recurrent, require simple split
2534-
return llama_sbatch(batch, m_hparams.n_embd, m_has_recurrent, logits_all);
2533+
// If any of the caches are recurrent, require equal split
2534+
return llama_sbatch(batch, m_hparams.n_embd, !m_has_recurrent, logits_all);
25352535
}
25362536

25372537
llama_ubatch llama_kv_cache_hybrid::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const {
2538-
if (m_has_recurrent) {
2539-
return sbatch.split_simple(n_ubatch);
2540-
}
25412538
if (embd_pooled) {
25422539
// Pooled embeddings cannot be split across ubatches (yet)
25432540
return sbatch.split_seq(n_ubatch);
25442541
}
2545-
return sbatch.split_equal(n_ubatch);
2542+
if (m_has_recurrent) {
2543+
return sbatch.split_equal(n_ubatch);
2544+
}
2545+
return sbatch.split_simple(n_ubatch);
25462546
}
25472547

25482548
bool llama_kv_cache_hybrid::find_slot(const llama_ubatch & batch) {

0 commit comments

Comments
 (0)