Commit 7b1ff55

mamba : dedicate an input tensor for state copy indices
This is cleaner and makes it easier to adapt when/if token positions (and by extension, inp_K_shift) are no longer integers.
1 parent 2e1ddf4 commit 7b1ff55
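
The gist of the change, condensed from the hunks below (not standalone code; every identifier here is taken from the diff): recurrent-state copy indices get their own I32 input tensor, inp_s_copy, instead of piggybacking on inp_K_shift, so state copying keeps working even if token positions (and with them inp_K_shift) stop being plain integer tensors.

    // allocated per context, only when the KV cache is recurrent ("unlimited"):
    ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
    ggml_set_name(ctx->inp_s_copy, "inp_s_copy");

    // filled right before the copy graph is computed: one source cell index per KV cell
    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
    for (int i = 0; i < kv_size; ++i) {
        data[i] = lctx.kv_self.cells[i].src;   // src == i means "keep this cell's own state"
    }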

File tree

1 file changed: +91 −31 lines changed


llama.cpp

Lines changed: 91 additions & 31 deletions
@@ -1743,6 +1743,7 @@ struct llama_layer {
struct llama_kv_cell {
    llama_pos pos   = -1;
    llama_pos delta = 0;
+    int32_t   src   = 0; // used by recurrent state models to copy states

    std::set<llama_seq_id> seq_id;

@@ -1763,6 +1764,7 @@ struct llama_kv_cell {
struct llama_kv_cache {
    bool has_shift = false;
    bool do_defrag = false;
+    bool do_copy   = false;
    // with Mamba, a cell can hold the state for more than one past token
    bool unlimited = false;

@@ -2001,7 +2003,8 @@ struct llama_context {
    struct ggml_tensor * inp_K_shift; // I32 [kv_size]
    struct ggml_tensor * inp_mean;    // F32 [n_batch, n_batch]
    struct ggml_tensor * inp_cls;     // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask;  // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy;  // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask;  // F32 [kv_size]
    struct ggml_tensor * inp_s_seq;   // I32 [kv_size, n_batch]

#ifdef GGML_USE_MPI
@@ -2043,9 +2046,9 @@ static bool llama_kv_cache_init(

    if (cache.unlimited) {
        for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
        }
-    } // else, delta is already initialized to zero
+    }

#ifdef GGML_USE_CLBLAST
    offload = false;
@@ -2296,19 +2299,20 @@ static void llama_kv_cache_seq_cp(

    if (cache.unlimited) {
        if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
            GGML_ASSERT((uint32_t) seq_id_src < cache.size);
            // intent to "copy from"
            // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;

-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
            if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
            }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;

            cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
        }
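
A note on the "copy chains" comment in the hunk above: because llama_kv_cache_seq_cp() records the source of the source, chained copies all end up pointing at the original cell, so one gather over the state rows (what build_s_copy() below does with ggml_get_rows + ggml_cpy) applies every pending copy in a single pass. A minimal self-contained sketch of that behaviour, with illustrative names (toy_cell, states) that are not part of llama.cpp:

    #include <cstdio>
    #include <vector>

    struct toy_cell { int src; }; // stand-in for llama_kv_cell's new src field

    int main() {
        const int kv_size = 4;
        std::vector<toy_cell> cells(kv_size);
        std::vector<float>    states = {10.f, 11.f, 12.f, 13.f}; // one "state row" per cell

        for (int i = 0; i < kv_size; ++i) cells[i].src = i; // default: copy from self (a no-op)

        // seq_cp(0 -> 1), then seq_cp(1 -> 2): taking the source of the source
        // keeps the chain flat, so cell 2 points at cell 0 rather than at cell 1
        cells[1].src = cells[0].src; // == 0
        cells[2].src = cells[1].src; // == 0, not 1

        // the copy graph, in scalar form: gather rows by src, then write them back
        std::vector<float> gathered(kv_size);
        for (int i = 0; i < kv_size; ++i) gathered[i] = states[cells[i].src];
        states = gathered;

        for (int i = 0; i < kv_size; ++i) {
            std::printf("cell %d: state %.0f (src %d)\n", i, states[i], cells[i].src);
        }
        // prints 10 for cells 0, 1 and 2, and 13 for cell 3
    }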
@@ -5335,21 +5339,7 @@ struct llm_build_context {
    struct ggml_cgraph * build_k_shift() {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        // TODO: do this in a another graph with a dedicated input tensor
-        if (kv_self.unlimited) {
-            for (int il = 0; il < n_layer; ++il) {
-                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-                ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-            }
-
-            return gf;
-        }
+        GGML_ASSERT(kv_self.size == n_ctx);

        for (int il = 0; il < n_layer; ++il) {
            struct ggml_tensor * tmp =
@@ -5369,6 +5359,25 @@ struct llm_build_context {
        return gf;
    }

+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
    struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -7985,6 +7994,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
    return result;
}

+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
static struct ggml_cgraph * llama_build_graph(
         llama_context & lctx,
     const llama_batch & batch,
@@ -8120,6 +8146,18 @@ static void llama_set_k_shift(llama_context & lctx) {
    }
}

+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
    //
    // set input data
@@ -8234,17 +8272,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
    }

    if (kv_self.unlimited) {
-        const int64_t n_kv = kv_self.n;
+        const int64_t n_kv = kv_self.n;

        {
            GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
            float * data = (float *) lctx.inp_s_mask->data;

            // states which are not affected by the current batch are left untouched
            for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id seq_id = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
-                bool has_self_seq = kv_cell.has_seq_id(seq_id);
+                llama_seq_id seq_id = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
+                bool has_self_seq = kv_cell.has_seq_id(seq_id);

                data[i] = (float) has_self_seq;

@@ -8731,7 +8769,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

static void llama_kv_cache_update_internal(struct llama_context & lctx) {
    // apply K-shift if needed
-    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
        llama_set_k_shift(lctx);

        {
@@ -8746,7 +8784,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
            kv_self.has_shift = false;

            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
            }
        }
    }
@@ -12458,7 +12516,7 @@ struct llama_context * llama_new_context_with_model(
        // graph inputs
        {
            ggml_init_params init_params = {
-                /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                /* .mem_buffer */ nullptr,
                /* .no_alloc */ true,
            };
@@ -12473,6 +12531,7 @@ struct llama_context * llama_new_context_with_model(
            ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
            ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
            if (ctx->kv_self.unlimited) {
+                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
            }
@@ -12486,6 +12545,7 @@ struct llama_context * llama_new_context_with_model(
            ggml_set_name(ctx->inp_mean, "inp_mean");
            ggml_set_name(ctx->inp_cls, "inp_cls");
            if (ctx->kv_self.unlimited) {
+                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
            }
