Commit 260cdb2

llama-bench : add -fa,--flash-attn arg

Parent commit: 87968de

File tree

1 file changed: +27 additions, −3 deletions

examples/llama-bench/llama-bench.cpp

Lines changed: 27 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -174,6 +174,7 @@ struct cmd_params {
174174
std::vector<llama_split_mode> split_mode;
175175
std::vector<int> main_gpu;
176176
std::vector<bool> no_kv_offload;
177+
std::vector<bool> flash_attn;
177178
std::vector<std::vector<float>> tensor_split;
178179
std::vector<bool> use_mmap;
179180
std::vector<bool> embeddings;
@@ -195,6 +196,7 @@ static const cmd_params cmd_params_defaults = {
195196
/* split_mode */ {LLAMA_SPLIT_MODE_LAYER},
196197
/* main_gpu */ {0},
197198
/* no_kv_offload */ {false},
199+
/* flash_attn */ {false},
198200
/* tensor_split */ {std::vector<float>(llama_max_devices(), 0.0f)},
199201
/* use_mmap */ {true},
200202
/* embeddings */ {false},
@@ -220,6 +222,7 @@ static void print_usage(int /* argc */, char ** argv) {
220222
printf(" -sm, --split-mode <none|layer|row> (default: %s)\n", join(transform_to_str(cmd_params_defaults.split_mode, split_mode_str), ",").c_str());
221223
printf(" -mg, --main-gpu <i> (default: %s)\n", join(cmd_params_defaults.main_gpu, ",").c_str());
222224
printf(" -nkvo, --no-kv-offload <0|1> (default: %s)\n", join(cmd_params_defaults.no_kv_offload, ",").c_str());
225+
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
223226
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
224227
printf(" -embd, --embeddings <0|1> (default: %s)\n", join(cmd_params_defaults.embeddings, ",").c_str());
225228
printf(" -ts, --tensor-split <ts0/ts1/..> (default: 0)\n");
@@ -393,6 +396,13 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
393396
}
394397
auto p = split<bool>(argv[i], split_delim);
395398
params.no_kv_offload.insert(params.no_kv_offload.end(), p.begin(), p.end());
399+
} else if (arg == "-fa" || arg == "--flash-attn") {
400+
if (++i >= argc) {
401+
invalid_param = true;
402+
break;
403+
}
404+
auto p = split<bool>(argv[i], split_delim);
405+
params.flash_attn.insert(params.flash_attn.end(), p.begin(), p.end());
396406
} else if (arg == "-mmp" || arg == "--mmap") {
397407
if (++i >= argc) {
398408
invalid_param = true;
@@ -477,6 +487,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
477487
if (params.split_mode.empty()) { params.split_mode = cmd_params_defaults.split_mode; }
478488
if (params.main_gpu.empty()) { params.main_gpu = cmd_params_defaults.main_gpu; }
479489
if (params.no_kv_offload.empty()){ params.no_kv_offload = cmd_params_defaults.no_kv_offload; }
490+
if (params.flash_attn.empty()) { params.flash_attn = cmd_params_defaults.flash_attn; }
480491
if (params.tensor_split.empty()) { params.tensor_split = cmd_params_defaults.tensor_split; }
481492
if (params.use_mmap.empty()) { params.use_mmap = cmd_params_defaults.use_mmap; }
482493
if (params.embeddings.empty()) { params.embeddings = cmd_params_defaults.embeddings; }
@@ -498,6 +509,7 @@ struct cmd_params_instance {
498509
llama_split_mode split_mode;
499510
int main_gpu;
500511
bool no_kv_offload;
512+
bool flash_attn;
501513
std::vector<float> tensor_split;
502514
bool use_mmap;
503515
bool embeddings;
@@ -532,6 +544,7 @@ struct cmd_params_instance {
532544
cparams.type_k = type_k;
533545
cparams.type_v = type_v;
534546
cparams.offload_kqv = !no_kv_offload;
547+
cparams.flash_attn = flash_attn;
535548
cparams.embeddings = embeddings;
536549

537550
return cparams;
@@ -554,6 +567,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
554567
for (const auto & tk : params.type_k)
555568
for (const auto & tv : params.type_v)
556569
for (const auto & nkvo : params.no_kv_offload)
570+
for (const auto & fa : params.flash_attn)
557571
for (const auto & nt : params.n_threads) {
558572
for (const auto & n_prompt : params.n_prompt) {
559573
if (n_prompt == 0) {
@@ -572,6 +586,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
572586
/* .split_mode = */ sm,
573587
/* .main_gpu = */ mg,
574588
/* .no_kv_offload= */ nkvo,
589+
/* .flash_attn = */ fa,
575590
/* .tensor_split = */ ts,
576591
/* .use_mmap = */ mmp,
577592
/* .embeddings = */ embd,
@@ -596,6 +611,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
596611
/* .split_mode = */ sm,
597612
/* .main_gpu = */ mg,
598613
/* .no_kv_offload= */ nkvo,
614+
/* .flash_attn = */ fa,
599615
/* .tensor_split = */ ts,
600616
/* .use_mmap = */ mmp,
601617
/* .embeddings = */ embd,
@@ -633,6 +649,7 @@ struct test {
633649
llama_split_mode split_mode;
634650
int main_gpu;
635651
bool no_kv_offload;
652+
bool flash_attn;
636653
std::vector<float> tensor_split;
637654
bool use_mmap;
638655
bool embeddings;
@@ -657,6 +674,7 @@ struct test {
657674
split_mode = inst.split_mode;
658675
main_gpu = inst.main_gpu;
659676
no_kv_offload = inst.no_kv_offload;
677+
flash_attn = inst.flash_attn;
660678
tensor_split = inst.tensor_split;
661679
use_mmap = inst.use_mmap;
662680
embeddings = inst.embeddings;
@@ -731,7 +749,7 @@ struct test {
731749
"n_batch", "n_ubatch",
732750
"n_threads", "type_k", "type_v",
733751
"n_gpu_layers", "split_mode",
734-
"main_gpu", "no_kv_offload",
752+
"main_gpu", "no_kv_offload", "flash_attn",
735753
"tensor_split", "use_mmap", "embeddings",
736754
"n_prompt", "n_gen", "test_time",
737755
"avg_ns", "stddev_ns",
@@ -753,7 +771,7 @@ struct test {
753771
}
754772
if (field == "cuda" || field == "opencl" || field == "vulkan" || field == "kompute" || field == "metal" ||
755773
field == "gpu_blas" || field == "blas" || field == "sycl" ||field == "f16_kv" || field == "no_kv_offload" ||
756-
field == "use_mmap" || field == "embeddings") {
774+
field == "flash_attn" || field == "use_mmap" || field == "embeddings") {
757775
return BOOL;
758776
}
759777
if (field == "avg_ts" || field == "stddev_ts") {
@@ -787,7 +805,7 @@ struct test {
787805
std::to_string(n_batch), std::to_string(n_ubatch),
788806
std::to_string(n_threads), ggml_type_name(type_k), ggml_type_name(type_v),
789807
std::to_string(n_gpu_layers), split_mode_str(split_mode),
790-
std::to_string(main_gpu), std::to_string(no_kv_offload),
808+
std::to_string(main_gpu), std::to_string(no_kv_offload), std::to_string(flash_attn),
791809
tensor_split_str, std::to_string(use_mmap), std::to_string(embeddings),
792810
std::to_string(n_prompt), std::to_string(n_gen), test_time,
793811
std::to_string(avg_ns()), std::to_string(stdev_ns()),
@@ -955,6 +973,9 @@ struct markdown_printer : public printer {
955973
if (field == "no_kv_offload") {
956974
return "nkvo";
957975
}
976+
if (field == "flash_attn") {
977+
return "fa";
978+
}
958979
if (field == "use_mmap") {
959980
return "mmap";
960981
}
@@ -1001,6 +1022,9 @@ struct markdown_printer : public printer {
10011022
if (params.no_kv_offload.size() > 1 || params.no_kv_offload != cmd_params_defaults.no_kv_offload) {
10021023
fields.emplace_back("no_kv_offload");
10031024
}
1025+
if (params.flash_attn.size() > 1 || params.flash_attn != cmd_params_defaults.flash_attn) {
1026+
fields.emplace_back("flash_attn");
1027+
}
10041028
if (params.tensor_split.size() > 1 || params.tensor_split != cmd_params_defaults.tensor_split) {
10051029
fields.emplace_back("tensor_split");
10061030
}

0 commit comments

Comments
 (0)