@@ -1743,6 +1743,7 @@ struct llama_layer {
 struct llama_kv_cell {
     llama_pos pos = -1;
     llama_pos delta = 0;
+    int32_t src = 0; // used by recurrent state models to copy states

     std::set<llama_seq_id> seq_id;

@@ -1763,6 +1764,7 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
     bool do_defrag = false;
+    bool do_copy = false;
     // with Mamba, a cell can hold the state for more than one past token
     bool unlimited = false;

@@ -2001,7 +2003,8 @@ struct llama_context {
     struct ggml_tensor * inp_K_shift; // I32 [kv_size]
     struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls; // I32 [n_batch]
-    struct ggml_tensor * inp_s_mask; // F32 [kv_size] (only used by constant state models like Mamba)
+    struct ggml_tensor * inp_s_copy; // I32 [kv_size]
+    struct ggml_tensor * inp_s_mask; // F32 [kv_size]
     struct ggml_tensor * inp_s_seq; // I32 [kv_size, n_batch]

 #ifdef GGML_USE_MPI
@@ -2043,9 +2046,9 @@ static bool llama_kv_cache_init(

     if (cache.unlimited) {
         for (uint32_t i = 0; i < cache.size; ++i) {
-            cache.cells[i].delta = i;
+            cache.cells[i].src = i;
         }
-    } // else, delta is already initialized to zero
+    }

 #ifdef GGML_USE_CLBLAST
     offload = false;
@@ -2296,19 +2299,20 @@ static void llama_kv_cache_seq_cp(

     if (cache.unlimited) {
         if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) {
-            seq_id_src = cache.cells[seq_id_src].delta;
+            seq_id_src = cache.cells[seq_id_src].src;
             GGML_ASSERT((uint32_t) seq_id_src < cache.size);
             // intent to "copy from"
             // supports copy chains thanks to taking the source of the source
-            cache.cells[seq_id_dst].delta = seq_id_src;
+            cache.cells[seq_id_dst].src = seq_id_src;

-            // prevent the destination from getting cleared if the source is not empty
+            // preserve the "keep or clear" status of the copied sequence
             if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) {
                 cache.cells[seq_id_dst].seq_id.insert(seq_id_dst);
+            } else {
+                cache.cells[seq_id_dst].seq_id.erase(seq_id_dst);
             }
-            // repurposed as a "need copy" flag
-            // (shifting can't be done anyway for this kind of KV cache)
-            cache.has_shift = true;
+
+            cache.do_copy = true;

             cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos;
         }
@@ -5335,21 +5339,7 @@ struct llm_build_context {
     struct ggml_cgraph * build_k_shift() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

-        // TODO: do this in a another graph with a dedicated input tensor
-        if (kv_self.unlimited) {
-            for (int il = 0; il < n_layer; ++il) {
-                ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
-                ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
-
-                conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_K_shift);
-                ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_K_shift);
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
-            }
-
-            return gf;
-        }
+        GGML_ASSERT(kv_self.size == n_ctx);

         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * tmp =
@@ -5369,6 +5359,25 @@ struct llm_build_context {
         return gf;
     }

+    struct ggml_cgraph * build_s_copy() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, kv_self.size);
+            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], n_embd_v_gqa, kv_self.size);
+
+            conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
+            ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
+
+            // TODO: name the intermediate tensors with cb()
+
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il]));
+            ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il]));
+        }
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -7985,6 +7994,23 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
     return result;
 }

+static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
+    llama_batch dummy;
+    dummy.n_tokens = 0;
+
+    llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
+
+    struct llm_build_context llm(lctx, dummy, cb, false);
+
+    llm.init();
+
+    struct ggml_cgraph * result = llm.build_s_copy();
+
+    llm.free();
+
+    return result;
+}
+
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
@@ -8120,6 +8146,18 @@ static void llama_set_k_shift(llama_context & lctx) {
     }
 }

+static void llama_set_s_copy(llama_context & lctx) {
+    const int64_t kv_size = lctx.kv_self.size;
+
+    assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
+
+    int32_t * data = (int32_t *) lctx.inp_s_copy->data;
+
+    for (int i = 0; i < kv_size; ++i) {
+        data[i] = lctx.kv_self.cells[i].src;
+    }
+}
+
 static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     //
     // set input data
@@ -8234,17 +8272,17 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
     }

     if (kv_self.unlimited) {
-        const int64_t n_kv = kv_self.n;
+        const int64_t n_kv = kv_self.n;

         {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer));
             float * data = (float *) lctx.inp_s_mask->data;

             // states which are not affected by the current batch are left untouched
             for (int i = 0; i < n_kv; ++i) {
-                llama_seq_id seq_id = i + lctx.kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id];
-                bool has_self_seq = kv_cell.has_seq_id(seq_id);
+                llama_seq_id    seq_id       = i + lctx.kv_self.head;
+                llama_kv_cell & kv_cell      = lctx.kv_self.cells[seq_id];
+                bool            has_self_seq = kv_cell.has_seq_id(seq_id);

                 data[i] = (float) has_self_seq;

@@ -8731,7 +8769,7 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {

 static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     // apply K-shift if needed
-    if ((lctx.kv_self.unlimited || lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
         llama_set_k_shift(lctx);

         {
@@ -8746,7 +8784,27 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
             kv_self.has_shift = false;

             for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].delta = kv_self.unlimited ? i : 0;
+                kv_self.cells[i].delta = 0;
+            }
+        }
+    }
+
+    if (lctx.kv_self.unlimited && lctx.kv_self.do_copy) {
+        llama_set_s_copy(lctx);
+
+        {
+            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
+
+            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
+        }
+
+        {
+            auto & kv_self = lctx.kv_self;
+
+            kv_self.do_copy = false;
+
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                kv_self.cells[i].src = i;
             }
         }
     }
@@ -12458,7 +12516,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size */ ggml_tensor_overhead()*(8 + 2*(ctx->kv_self.unlimited)),
+                /* .mem_size */ ggml_tensor_overhead()*(8 + 3*(ctx->kv_self.unlimited)),
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc */ true,
             };
@@ -12473,6 +12531,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_mean = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             if (ctx->kv_self.unlimited) {
+                ctx->inp_s_copy = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, kv_size);
                 ctx->inp_s_mask = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, kv_size);
                 ctx->inp_s_seq = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_I32, kv_size, cparams.n_batch);
             }
@@ -12486,6 +12545,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_set_name(ctx->inp_mean, "inp_mean");
             ggml_set_name(ctx->inp_cls, "inp_cls");
             if (ctx->kv_self.unlimited) {
+                ggml_set_name(ctx->inp_s_copy, "inp_s_copy");
                 ggml_set_name(ctx->inp_s_mask, "inp_s_mask");
                 ggml_set_name(ctx->inp_s_seq, "inp_s_seq");
             }