
Commit 06ead3d

mamba : multiple sequences, but one at a time
This is a step towards making this Mamba implementation usable with the server example (the way the system prompt is kept when clearing the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number of sequences kept at any single time. For now, this number is obtained from n_parallel (plus one, to have an extra sequence to dedicate to the system prompt), but there might be a better way to do this which won't also make the main example use 2 cells even if only 1 is really used (for this specific case, --parallel 0 helps).

Simultaneous sequence processing will probably require changes to ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

  This (mis)uses the logic around K shifts, because tokens in a state can't be shifted anyway, and because inp_K_shift has the right shape and type. Using ggml_get_rows is a nice way to do copies, but copy chains can't work. Fortunately, copy chains don't really seem to be used in the examples. Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

  It's cleaner than the previous heuristic of checking for the pos of the first token in the batch. inp_KQ_mask could not be re-used for this, because it has the wrong shape and because it seems more suited to the next step of simultaneous sequence processing (helping with the problem of remembering which token belongs to which sequence(s)/state(s)).

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok
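For context, a minimal sketch of how the two mechanisms described above can be expressed with existing ggml ops: a per-cell state mask that zeroes the states of cells whose sequence starts fresh, and ggml_get_rows used to realize llama_kv_cache_seq_cp as a row copy. The tensor names and shapes here (ssm_states, state_mask, state_copy_rows) are illustrative assumptions, not the exact code added by this commit.

// Hedged sketch (assumed names): mask then copy the recurrent states, one row per KV cell.
// ssm_states:      F32 {d_state*d_inner, n_kv} -- one recurrent state per cell
// state_mask:      F32 {1, n_kv}               -- 0.0f clears a state, 1.0f keeps it
// state_copy_rows: I32 {n_kv}                  -- source cell index for each destination cell
struct ggml_tensor * states = ggml_mul(ctx, ssm_states, state_mask);   // clear states that start fresh
states = ggml_get_rows(ctx, states, state_copy_rows);                  // copy whole states between cells

Because every destination row reads from the pre-copy source tensor, chained copies (A -> B -> C in one pass) cannot work this way, which matches the note above that copy chains are not supported.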
1 parent 5d8d127 commit 06ead3d

File tree

4 files changed, +255 -89 lines changed

common/common.cpp

Lines changed: 1 addition & 0 deletions

@@ -1243,6 +1243,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.n_ctx = params.n_ctx;
     cparams.n_batch = params.n_batch;
+    cparams.n_parallel = params.n_parallel;
     cparams.n_threads = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
     cparams.mul_mat_q = params.mul_mat_q;
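
Passing n_parallel through to the context parameters is what lets the cache be sized per sequence. A rough sketch of the sizing rule described in the commit message (variable names illustrative, not the literal llama.cpp code):

// Hedged sketch: for a recurrent (Mamba-style) model, one state cell per client
// slot plus one extra cell reserved for the system prompt.
uint32_t kv_size = cparams.n_parallel + 1;
// With the default --parallel 1 this gives 2 cells even if only 1 is used;
// --parallel 0 avoids the extra cell, as noted in the commit message.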

ggml.c

Lines changed: 13 additions & 13 deletions

@@ -5948,15 +5948,15 @@ struct ggml_tensor * ggml_ssm_scan(
     GGML_ASSERT(ggml_is_matrix(s)); // the ssm_state should be 2D

     {
-        const int64_t d_state = s->ne[0];
-        const int64_t d_inner = s->ne[1];
-        const int64_t n_tok = x->ne[1];
+        const int64_t d_state = s->ne[0];
+        const int64_t d_inner = s->ne[1];
+        const int64_t n_tokens = x->ne[1];

         GGML_ASSERT(x->ne[0] == d_inner);
         GGML_ASSERT(A->ne[0] == d_state);
         GGML_ASSERT(A->ne[1] == d_inner);
         GGML_ASSERT(B->ne[0] == d_state);
-        GGML_ASSERT(B->ne[1] == n_tok);
+        GGML_ASSERT(B->ne[1] == n_tokens);
     }

     bool is_node = false;

@@ -14237,12 +14237,12 @@ static void ggml_compute_forward_ssm_scan_f32(

     // first batch
     {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1])); // {d_state, d_inner, n_tokens}
         float * s = (float *) ((char *) src0->data + ir0*(src0->nb[1])); // {d_state, d_inner}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tok}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));

@@ -14258,12 +14258,12 @@ static void ggml_compute_forward_ssm_scan_f32(

     // compute state for rest of tokens, previous state comes from dest
     for (int i2 = 1; i2 < n_t; ++i2) {
-        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tok}
-        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tok}
-        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tok}
+        float * pdst = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + i2 *( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * s = (float *) ((char *) dst->data + ir0*( dst->nb[1]) + (i2-1)*( dst->nb[2])); // {d_state, d_inner, n_tokens}
+        float * x = (float *) ((char *) src1->data + ir0*(src1->nb[0]) + i2 *(src1->nb[1])); // {d_inner, n_tokens}
+        float * dt = (float *) ((char *) src2->data + ir0*(src2->nb[0]) + i2 *(src2->nb[1])); // {d_inner, n_tokens}
         float * A = (float *) ((char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
-        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tok}
+        float * B = (float *) ((char *) src4->data + i2*(src4->nb[1])); // {d_state, n_tokens}
         // d_inner
         for (int i1 = 0; i1 < ir; ++i1) {
             float dt_soft_plus = log1pf(expf(dt[i1]));
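
To make the role of the renamed n_tokens dimension concrete, here is a hedged, self-contained reference of the per-token recurrence that ggml_compute_forward_ssm_scan_f32 evaluates. It is deliberately simplified (single thread, no C/D projection, plain row-major arrays instead of ggml's strided layout), so names and layout are assumptions rather than the exact implementation.

#include <math.h>

// Simplified reference of the selective-scan recurrence: the state written for
// token t is the input state of token t+1, which is why the scan is sequential
// over n_tokens and why simultaneous sequence processing would need changes to
// ggml_ssm_scan, as noted in the commit message.
static void ssm_scan_ref(int n_tokens, int d_inner, int d_state,
                         float       * s,    // [d_inner][d_state] recurrent state, updated in place
                         const float * x,    // [n_tokens][d_inner]
                         const float * dt,   // [n_tokens][d_inner]
                         const float * A,    // [d_inner][d_state]
                         const float * B) {  // [n_tokens][d_state]
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < d_inner; ++i) {
            // softplus of the timestep, as in the diff above
            const float dt_soft_plus = log1pf(expf(dt[t*d_inner + i]));
            for (int j = 0; j < d_state; ++j) {
                float * sij = &s[i*d_state + j];
                // discretized state update: s = exp(dt*A) * s + dt * B * x
                *sij = *sij * expf(dt_soft_plus * A[i*d_state + j])
                     + dt_soft_plus * B[t*d_state + j] * x[t*d_inner + i];
            }
        }
    }
}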
