
Commit 9c67c27

Authored by ggerganov, JohannesGaessler, and phymbert
ggml : add Flash Attention (#5021)
* ggml : add ggml_flash_attn_ext API
* ggml : fix GQA support in ggml_flash_attn_ext
* ggml : online attention (CPU)
* metal : initial implementation
* metal : f16 precision
* metal : reduce branches
* metal : specialize for head size
* wip : 8 rows per simd group
* wip : 4 rows per simd group
* wip : template for rows per warp
* metal : parallelize across KV size
* metal : parallel reduce across heads
* metal : efficient flash_attn_f16 implementation
* metal : avoid redundant loads of the attention
* metal : scale and mask in matrix form
* metal : fix comment
* llama : avoid ggml_cast, use F32 query
* metal : add parallel reduce version (disabled)
* metal : move output into local memory + optimize
  - the result from each simdgroup now stays in the registers
  - significantly reduced SRAM usage
  - more efficient skipping of -INF blocks
  - avoid simdgroup barrier in hot loop
  - add comments
* metal : add tests, fix scaling, support C > 32
* metal : improve precision
* ggml : fix f16 mad
* metal : minor
* metal : support Q > 8
* tests : add ATTN tests
* metal : disable buffer allocation logs
* tests : more
* metal : faster inner loop for C == 32
* metal : fix array initialization
* tests : ifdef
* ggml : switch to padded F16 mask for ggml_soft_max, ggml_flash_attn_ext
* ggml : fix ggml_soft_max mask requirement
* cuda : fix soft_max to use correct mask size
* cuda : add flash_attn kernel (wip)
* metal : optimize softmax for C > 32
* metal : optimize softmax
* tests : minor fix
* cuda : avoid zeroing fragments
* tests : update dims
* cuda : fix __hisinf() result check
* cuda : avoid warp_reduce for smax
* cuda : use int instead of int64_t
  Noticeably improves performance (thanks to Johannes)
* cuda : make loops use the same loop values
  Thanks Johannes again for the tip
* cuda : unroll some of the loops
* cuda : avoid __hisinf branches
* cuda : use half2 in softmax
* cuda : switch to 1 warp for bs > 16
* cuda : speed-up reduce part of the kernel
* cuda : unroll Q*K^T loop
* cuda : fix -INF block check
* cuda : simplify softmax
* cuda : fix matrix names
* cuda : minor
* llama : adapt to F16 KQ_pos
* llama : adapt new models to F16 KQ_mask
* ggml : fix F16 store (ARM NEON)
* llama : fix type of KQ_mask and KQ_pos
* ggml : fix CPU soft_max
* tests : add hs=256
* cuda : fix build
* metal : improve perf via smaller int registers
* cuda : adapt soft_max to F16 mask and pos
* CUDA: faster FlashAttention, kernel for bs == 1
* 16 cols for Phi-2
* no vec for hs, no hs==256 ncols==32 for Volta
* adjust kernel selection logic
* 4 warps, 256 stride for all D
* no ncols == 64
* Multiple parallel blocks for batch size 1
* fix compile warnings
* fix excessive KQ_b loads
* fix cmake build
* fix KV cache padding, NaN from INFINITY (#6438)
* llama : flash_attn cparam + fix defrag
* server: support flash_attn param
* server: bench: enable flash_attn param
* CUDA: refactor host code, dyn. par. blocks
* fix flash_attn_vec_f16 race condition
* flush softmax exp below threshold to 0
* store temp KQ in registers
* Calculate KQ as FP32 if KQV has GGML_PREC_F32
* Add __hgt2_mask implementation for CUDA 11
* fix KQ FP32 precision for parallel_blocks > 1
* llama-bench : add -fa,--flash-attn arg
* metal : add BS=1 kernel for flash attention (#6508)
* metal : add BS=1 kernel for flash attention (wip)
* metal : support more than 1 warps
* metal : opts
* metal : opt
* metal : switch to parallel reduce
* metal : reduce registers
* metal : simplify
* metal : initial FA vec kernel
* metal : use F32 attention accumulators
* batched-bench : add fattn arg
* llama : simplify llama_build_kv_store ggml-ci
* llama : adapt build_olmo to changes
* ggml : fix arm fp16 store on windows
* metal : clean-up
* metal : clean-up kernel code
* metal : minor
* tests : remove benchmarks ggml-ci
* ggml : fix avx512 const correctness ggml-ci
* ggml : fix soft_max with bias on CPU ggml-ci
* common : print --flash-attn in help
* ggml : fix num dimensions in ggml_flash_attn_ext
* llama : force disable flash attention for incompatible models
* ggml : ggml_soft_max support F16/F32 mask/pos ggml-ci
* cuda : uint -> uint32_t
* cuda : "constexpr dim3" -> "const dim3" ggml-ci
* cuda : try to fix __hgt2_mask ggml-ci
* ggml : add TODO's for F16/F32 mask/pos support in other backends
* llama : replace bool need_kq_pos with use_alibi
* llama : prep ALiBi support for BERT models ggml-ci
* llama : fix n_batch requirements ggml-ci
* cont
* server : add help for --flash-attn arg
* llama : disable FA for AMD
* tests : remove TMP_ATTN_BENCH ggml-ci
* llama : support save/load state with FA enabled ggml-ci
* ci : add CUDA save-load-state tests ggml-ci
* llama : llama_kv_cache_clear zeroes data + fix save-load seq ggml-ci
* llama : fix copy-paste errors, add TODO
* llama : disallow incompatible states
* llama : update llama_state_get_size after v_trans field
* metal : remove tmp log
* llama : add static reminder for llama_state_get_size
* metal : fix max nsg ggml-ci
* ci : fix arg order ggml-ci

---------

Co-authored-by: Johannes Gäßler <[email protected]>
Co-authored-by: Pierrick HYMBERT <[email protected]>
1 parent 952d03d commit 9c67c27

22 files changed: 2917 additions, 453 deletions

ci/run.sh

Lines changed: 6 additions & 2 deletions
@@ -336,7 +336,8 @@ function gg_run_open_llama_3b_v2 {
 
     (time ./bin/imatrix --model ${model_f16} -f ${wiki_test_60} -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
 function check_ppl {
     qnt="$1"

@@ -517,7 +518,10 @@ function gg_run_open_llama_7b_v2 {
 
     (time ./bin/imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
 
-    (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa -ngl 10 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/save-load-state -fa -ngl 99 --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
 
 function check_ppl {
     qnt="$1"

common/common.cpp

Lines changed: 7 additions & 0 deletions
@@ -947,6 +947,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         params.cont_batching = true;
         return true;
     }
+    if (arg == "-fa" || arg == "--flash-attn") {
+        params.flash_attn = true;
+        return true;
+    }
     if (arg == "--color") {
         params.use_color = true;
         return true;

@@ -1513,6 +1517,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
     printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
     printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+    printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf(" --image IMAGE_FILE path to an image file. use with multimodal models. Specify multiple times for batching\n");
     if (llama_supports_mlock()) {

@@ -1885,6 +1890,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.cb_eval = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
     cparams.offload_kqv = !params.no_kv_offload;
+    cparams.flash_attn = params.flash_attn;
 
     cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

@@ -2707,6 +2713,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
     fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
     fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+    fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
     fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
 
     const std::vector<float> tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices());
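The -fa / --flash-attn flag added above is parsed into gpt_params and forwarded to llama_context_params.flash_attn, so any example built on common picks it up. A minimal usage sketch (the model path, prompt, and token count are placeholders, not taken from this commit):

# hypothetical invocation of an example that uses common's argument parser;
# adjust the model path to a local GGUF file
./main -m models/7B/ggml-model-q4_0.gguf -p "Hello" -n 64 -fa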

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -150,6 +150,7 @@ struct gpt_params {
     bool multiline_input = false; // reverse the usage of `\`
     bool simple_io = false; // improves compatibility with subprocesses and limited consoles
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
+    bool flash_attn = false; // flash attention
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool ignore_eos = false; // ignore generated EOS tokens

examples/batched-bench/batched-bench.cpp

Lines changed: 17 additions & 11 deletions
@@ -32,7 +32,7 @@ int main(int argc, char ** argv) {
     gpt_params params;
 
     if (argc == 1 || argv[1][0] == '-') {
-        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+        printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [FATTN] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
         printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
         printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
         return 1 ;

@@ -41,6 +41,7 @@ int main(int argc, char ** argv) {
     int n_kv_max = 2048;
     int n_batch = 2048;
     int n_ubatch = 512;
+    bool flash_attn = false;
     int is_pp_shared = 0;
     int n_gpu_layers = 0;
 

@@ -66,23 +67,27 @@ int main(int argc, char ** argv) {
     }
 
     if (argc >= 6) {
-        is_pp_shared = std::atoi(argv[5]);
+        flash_attn = std::atoi(argv[5]);
     }
 
     if (argc >= 7) {
-        n_gpu_layers = std::atoi(argv[6]);
+        is_pp_shared = std::atoi(argv[6]);
     }
 
     if (argc >= 8) {
-        n_pp = parse_list(argv[7]);
+        n_gpu_layers = std::atoi(argv[7]);
     }
 
     if (argc >= 9) {
-        n_tg = parse_list(argv[8]);
+        n_pp = parse_list(argv[8]);
     }
 
     if (argc >= 10) {
-        n_pl = parse_list(argv[9]);
+        n_tg = parse_list(argv[9]);
+    }
+
+    if (argc >= 11) {
+        n_pl = parse_list(argv[10]);
     }
 
     // init LLM

@@ -108,10 +113,11 @@ int main(int argc, char ** argv) {
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.seed = 1234;
-    ctx_params.n_ctx = n_kv_max;
-    ctx_params.n_batch = n_batch;
-    ctx_params.n_ubatch = n_ubatch;
+    ctx_params.seed = 1234;
+    ctx_params.n_ctx = n_kv_max;
+    ctx_params.n_batch = n_batch;
+    ctx_params.n_ubatch = n_ubatch;
+    ctx_params.flash_attn = flash_attn;
 
     ctx_params.n_threads = params.n_threads;
     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

@@ -169,7 +175,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+    LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, flash_attn = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, flash_attn, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
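Note that the new FATTN positional argument sits between N_UBATCH and IS_PP_SHARED, so existing batched-bench invocations need one extra value. A sketch following the updated usage string (model path and list values are placeholders):

# hypothetical run: FATTN=1 enables flash attention, IS_PP_SHARED=0, NGL=99
./batched-bench ggml-model-f16.gguf 2048 2048 512 1 0 99 128,256 128 1,2,4,8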

examples/llama-bench/llama-bench.cpp

Lines changed: 27 additions & 3 deletions
@@ -174,6 +174,7 @@ struct cmd_params {
     std::vector<llama_split_mode> split_mode;
     std::vector<int> main_gpu;
     std::vector<bool> no_kv_offload;
+    std::vector<bool> flash_attn;
     std::vector<std::vector<float>> tensor_split;
     std::vector<bool> use_mmap;
     std::vector<bool> embeddings;

@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
     /* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
     /* main_gpu */ {0},
     /* no_kv_offload */ {false},
+    /* flash_attn */ {false},
     /* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
     /* use_mmap */ {true},
     /* embeddings */ {false},

@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
     printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
     printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
+    printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
     printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
     printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
     printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");

@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
             }
             auto p = split<bool>(argv[i], split_delim);
             params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            auto p = split<bool>(argv[i], split_delim);
+            params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
         } else if (arg == "-mmp" || arg == "--mmap") {
             if (++i >= argc) {
                 invalid_param = true;

@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
     if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
     if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
     if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
+    if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
     if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
     if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
     if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }

@@ -498,6 +509,7 @@ struct cmd_params_instance {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;

@@ -532,6 +544,7 @@ struct cmd_params_instance {
         cparams.type_k = type_k;
         cparams.type_v = type_v;
         cparams.offload_kqv = !no_kv_offload;
+        cparams.flash_attn = flash_attn;
         cparams.embeddings = embeddings;
 
         return cparams;

@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
     for (const auto & tk : params.type_k)
     for (const auto & tv : params.type_v)
     for (const auto & nkvo : params.no_kv_offload)
+    for (const auto & fa : params.flash_attn)
     for (const auto & nt : params.n_threads) {
         for (const auto & n_prompt : params.n_prompt) {
             if (n_prompt == 0) {

@@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode = */ sm,
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,

@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .split_mode = */ sm,
                 /* .main_gpu = */ mg,
                 /* .no_kv_offload= */ nkvo,
+                /* .flash_attn = */ fa,
                 /* .tensor_split = */ ts,
                 /* .use_mmap = */ mmp,
                 /* .embeddings = */ embd,

@@ -633,6 +649,7 @@ struct test {
     llama_split_mode split_mode;
     int main_gpu;
     bool no_kv_offload;
+    bool flash_attn;
     std::vector<float> tensor_split;
     bool use_mmap;
     bool embeddings;

@@ -657,6 +674,7 @@ struct test {
         split_mode = inst.split_mode;
         main_gpu = inst.main_gpu;
         no_kv_offload = inst.no_kv_offload;
+        flash_attn = inst.flash_attn;
         tensor_split = inst.tensor_split;
         use_mmap = inst.use_mmap;
         embeddings = inst.embeddings;

@@ -731,7 +749,7 @@ struct test {
             "n_batch", "n_ubatch",
             "n_threads", "type_k", "type_v",
             "n_gpu_layers", "split_mode",
-            "main_gpu", "no_kv_offload",
+            "main_gpu", "no_kv_offload", "flash_attn",
             "tensor_split", "use_mmap", "embeddings",
             "n_prompt", "n_gen", "test_time",
             "avg_ns", "stddev_ns",

@@ -753,7 +771,7 @@ struct test {
         }
         if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
             field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
-            field == "use_mmap" || field == "embeddings") {
+            field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
             return BOOL;
         }
         if (field == "avg_ts" || field == "stddev_ts") {

@@ -787,7 +805,7 @@ struct test {
            std::to_string(n_batch), std::to_string(n_ubatch),
            std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
            std::to_string(n_gpu_layers), split_mode_str(split_mode),
-           std::to_string(main_gpu), std::to_string(no_kv_offload),
+           std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
            tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
            std::to_string(n_prompt), std::to_string(n_gen), test_time,
            std::to_string(avg_ns()), std::to_string(stdev_ns()),

@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
         if (field == "no_kv_offload") {
             return "nkvo";
         }
+        if (field == "flash_attn") {
+            return "fa";
+        }
         if (field == "use_mmap") {
             return "mmap";
         }

@@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
         if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
             fields.emplace_back("no_kv_offload");
         }
+        if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
+            fields.emplace_back("flash_attn");
+        }
         if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
             fields.emplace_back("tensor_split");
         }
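Like the other boolean parameters in llama-bench, -fa accepts a comma-separated list of 0/1 values, so both settings can be compared in a single run; the fa column is printed whenever the value differs from the default. A sketch (the model path is a placeholder):

# hypothetical comparison of flash attention off vs. on
./llama-bench -m models/7B/ggml-model-q4_0.gguf -fa 0,1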

examples/server/bench/bench.py

Lines changed: 1 addition & 0 deletions
@@ -268,6 +268,7 @@ def start_server_background(args):
         server_args.extend(['--defrag-thold', "0.1"])
     server_args.append('--cont-batching')
     server_args.append('--metrics')
+    server_args.append('--flash-attn')
     server_args.extend(['--log-format', "text"])
     args = [str(arg) for arg in [server_path, *server_args]]
     print(f"bench: starting server with: {' '.join(args)}")

examples/server/server.cpp

Lines changed: 3 additions & 0 deletions
@@ -2377,6 +2377,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params, co
     printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
     printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
     printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
+    printf(" -fa, --flash-attn enable Flash Attention (default: %s)\n", params.flash_attn ? "enabled" : "disabled");
     printf(" -spf FNAME, --system-prompt-file FNAME\n");
     printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf(" -ctk TYPE, --cache-type-k TYPE\n");

@@ -2742,6 +2743,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             params.embedding = true;
         } else if (arg == "-cb" || arg == "--cont-batching") {
             params.cont_batching = true;
+        } else if (arg == "-fa" || arg == "--flash-attn") {
+            params.flash_attn = true;
         } else if (arg == "-np" || arg == "--parallel") {
             if (++i >= argc) {
                 invalid_param = true;
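The server accepts the same -fa / --flash-attn switch. A launch sketch (model path, context size, and port are placeholders, not taken from this commit):

# hypothetical server launch with flash attention enabled
./server -m models/7B/ggml-model-q4_0.gguf -c 4096 --port 8080 --flash-attn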

ggml-cuda.cu

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,7 @@
 #include "ggml-cuda/cpy.cuh"
 #include "ggml-cuda/diagmask.cuh"
 #include "ggml-cuda/dmmv.cuh"
+#include "ggml-cuda/fattn.cuh"
 #include "ggml-cuda/getrows.cuh"
 #include "ggml-cuda/im2col.cuh"
 #include "ggml-cuda/mmq.cuh"

@@ -140,6 +141,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpb = prop.sharedMemPerBlock;
+        info.devices[id].nsm = prop.multiProcessorCount;
     }
 
     for (int id = 0; id < info.device_count; ++id) {

@@ -2290,6 +2292,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_cuda_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }

@@ -2564,6 +2569,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;