
Commit 79bcf52

fix padding, GQA
1 parent 8d56dad commit 79bcf52

File tree (2 files changed, +6 -5 lines):
  ggml-cuda.cu
  llama.cpp

ggml-cuda.cu

Lines changed: 4 additions & 3 deletions

@@ -7552,9 +7552,10 @@ static __global__ void flash_attn_ext_f16(
     __builtin_assume(tid < nthreads);
     constexpr int D_padded = D + 8; // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts.

-    const float * Q_f = (const float *) (Q + nb02*blockIdx.y + ncols*nb01*blockIdx.x);
-    const half * K_h = (const half *) (K + nb12*blockIdx.y);
-    const half * V_h = (const half *) (V + nb12*blockIdx.y); // K and V have same shape
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f = (const float *) (Q + nb02*blockIdx.y + ncols*nb01*blockIdx.x);
+    const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
+    const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
     const half2 * mask2 = (half2 *) mask + ncols*ne11*blockIdx.x/2;

     const int stride_Q = nb01 / sizeof(float);
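For context on the GQA change above: the new gqa_ratio = ne02 / ne12 line encodes that there are ne02 Q matrices for every ne12 K/V matrices, and the kernel's blockIdx.y (which indexes Q matrices) is divided by gqa_ratio to pick the shared K/V matrix. Below is a minimal host-side sketch of that index mapping only; the head counts (32 Q heads, 8 K/V heads) are made-up example values, not taken from the commit.

    #include <cstdio>

    int main() {
        const int ne02 = 32;               // example: number of Q heads (Q matrices)
        const int ne12 = 8;                // example: number of K/V heads (K/V matrices)
        const int gqa_ratio = ne02 / ne12; // here 4 Q heads share each K/V head

        // blockIdx.y in flash_attn_ext_f16 runs over Q matrices; dividing it by
        // gqa_ratio yields the K/V matrix that group of Q heads should read.
        for (int iq = 0; iq < ne02; ++iq) {
            printf("Q head %2d -> K/V head %d\n", iq, iq / gqa_ratio);
        }
        return 0;
    }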

llama.cpp

Lines changed: 2 additions & 2 deletions
@@ -9166,7 +9166,7 @@ static int llama_decode_internal(
            // a heuristic, to avoid attending the full cache if it is not yet utilized
            // after enough generations, the benefit from this heuristic disappears
            // if we start defragmenting the cache, the benefit from this will be more important
-           kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
+           kv_self.n = std::min(kv_self.size, std::max(256u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 256)));
            //kv_self.n = llama_kv_cache_cell_max(kv_self);
        }
    }
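The only change in this hunk is the padding granularity of kv_self.n, from 128 to 256. As a rough illustration of the arithmetic (not code from the commit), assuming GGML_PAD(x, n) rounds x up to the next multiple of n; kv_size, cell_max and the pad_up helper below are hypothetical stand-ins:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for GGML_PAD: round x up to the next multiple of n.
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return (x + n - 1) / n * n;
    }

    int main() {
        const uint32_t kv_size  = 4096; // example kv_self.size
        const uint32_t cell_max = 300;  // example llama_kv_cache_cell_max() result

        const uint32_t n_old = std::min(kv_size, std::max(128u, pad_up(cell_max, 128)));
        const uint32_t n_new = std::min(kv_size, std::max(256u, pad_up(cell_max, 256)));
        printf("kv_self.n: old padding -> %u, new padding -> %u\n", n_old, n_new); // 384 vs 512
        return 0;
    }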
@@ -13083,7 +13083,7 @@ struct llama_context * llama_new_context_with_model(
    cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;

    // this is necessary due to kv_self.n being padded later during inference
-   cparams.n_ctx = GGML_PAD(cparams.n_ctx, 32);
+   cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);

    // with causal attention, the batch size is limited by the context size
    cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
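The second hunk raises the n_ctx padding from 32 to 256 to match. Presumably, since kv_self.n is clamped to kv_self.size (sized from cparams.n_ctx), the context size itself must be 256-aligned for the clamped value to stay 256-aligned. A hedged sketch of that interaction, reusing the hypothetical pad_up helper and made-up sizes:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for GGML_PAD: round x up to the next multiple of n.
    static uint32_t pad_up(uint32_t x, uint32_t n) {
        return (x + n - 1) / n * n;
    }

    int main() {
        const uint32_t n_ctx_requested = 2000; // example user-requested context size
        const uint32_t n_ctx_old = pad_up(n_ctx_requested, 32);  // 2016, not a multiple of 256
        const uint32_t n_ctx_new = pad_up(n_ctx_requested, 256); // 2048, 256-aligned

        const uint32_t cell_max = 1900; // example: cache nearly full
        const uint32_t n_old = std::min(n_ctx_old, std::max(256u, pad_up(cell_max, 256))); // 2016
        const uint32_t n_new = std::min(n_ctx_new, std::max(256u, pad_up(cell_max, 256))); // 2048
        printf("kv_self.n with old n_ctx padding: %u, with new: %u\n", n_old, n_new);
        return 0;
    }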
