
Commit b44890d

model : disable SWA for Phi models (#13676)
* model : disable SWA for Phi models (ggml-ci)
* model : update warning message
* model : print warning only if n_swa > 0
* model : fix typo
1 parent 3398305 commit b44890d
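
For background: sliding-window attention (SWA) limits each query token to the most recent n_swa key positions, which is what the llama_kv_cache_unified_iswa variant serves; this commit stops Phi models from taking that path until the conversion scripts populate the SWA metadata correctly. A minimal sketch of the standard SWA masking rule, assuming the usual q - kv >= n_swa cutoff (illustrative helper, not llama.cpp code):

    #include <cstdint>

    // Standard sliding-window rule: a query at position q may attend to a
    // key at position kv only if the pair is causal and within the window.
    static bool swa_can_attend(int64_t q, int64_t kv, int64_t n_swa) {
        if (kv > q) {
            return false; // future token: masked by causality
        }
        return (q - kv) < n_swa; // inside the sliding window of size n_swa
    }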

File tree

2 files changed: +30 -42 lines


src/llama-graph.cpp

Lines changed: 3 additions & 4 deletions
@@ -1236,8 +1236,7 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);

     {
-        GGML_ASSERT(hparams.n_swa_pattern == 1 && "Use llama_kv_cache_unified_iswa for SWA");
-        GGML_ASSERT(hparams.n_swa == 0 && "Use llama_kv_cache_unified_iswa for SWA");
+        GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

         const auto n_kv = kv_self->get_n();

@@ -1312,8 +1311,8 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa()
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
     }

-    if (hparams.n_swa_pattern > 1) {
-        GGML_ASSERT(hparams.n_swa > 0 && "Use llama_kv_cache_unified for non-SWA");
+    {
+        GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");

         const auto n_kv = kv_self->get_kv_swa()->get_n();
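
Both guards now key off hparams.swa_type alone instead of inferring SWA-ness from the numeric n_swa/n_swa_pattern fields, and the iSWA-side check is no longer conditional on the pattern. A condensed sketch of the resulting invariants (enum values illustrative; the free functions are stand-ins for the two build_attn_inp_* entry points):

    #include <cassert>

    // Values illustrative; llama.cpp defines llama_swa_type in its hparams.
    enum llama_swa_type_sketch {
        SWA_NONE     = 0,
        SWA_STANDARD = 1,
    };

    // Non-SWA graph inputs are valid only when SWA is fully disabled ...
    void check_unified(llama_swa_type_sketch t) {
        assert(t == SWA_NONE && "Use llama_kv_cache_unified_iswa for SWA");
    }

    // ... and iSWA graph inputs only when some SWA variant is enabled.
    void check_unified_iswa(llama_swa_type_sketch t) {
        assert(t != SWA_NONE && "Use llama_kv_cache_unified for non-SWA");
    }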
13191318

src/llama-model.cpp

Lines changed: 27 additions & 38 deletions
@@ -853,43 +853,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }

-                // for backward compatibility ; see: https://github.com/ggerganov/llama.cpp/pull/8931
-                if ((hparams.n_layer == 32 || hparams.n_layer == 40) && hparams.n_ctx_train == 4096) {
-                    // default value for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct
-                    LLAMA_LOG_WARN("%s: assuming n_swa = 2047 for Phi-3-mini-4k-instruct and Phi-3-medium-4k-instruct\n", __func__);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);

-                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-
-                    hparams.n_swa = 2047;
-                } else if (hparams.n_layer == 32 && hparams.n_head_kv(0) == 32 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-mini-128k-instruct
-                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-mini-128k-instruct\n", __func__);
-
-                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
-
-                    hparams.n_swa = hparams.n_ctx_train;
-                    hparams.n_swa_pattern = 1;
-                } else if (hparams.n_layer == 40 && hparams.n_ctx_train == 131072) {
-                    // default value for Phi-3-medium-128k-instruct
-                    LLAMA_LOG_WARN("%s: assuming no SWA for Phi-3-medium-128k-instruct\n", __func__);
+                if (found_swa && hparams.n_swa > 0) {
+                    LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n",
+                            __func__, "https://github.com/ggml-org/llama.cpp/pull/13676");

+                    // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern`
                     hparams.swa_type = LLAMA_SWA_TYPE_NONE;

-                    hparams.n_swa = hparams.n_ctx_train;
-                    hparams.n_swa_pattern = 1;
-                }
-
-                bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
-                if (!found_swa && hparams.n_swa == 0) {
-                    throw std::runtime_error("invalid value for sliding_window");
-                }
-
-                if (hparams.n_swa > hparams.n_ctx_train) {
-                    LLAMA_LOG_WARN("%s: unexpected n_swa: %d >= %d, disabling SWA\n", __func__, hparams.n_swa, hparams.n_ctx_train);
-
-                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
-
-                    hparams.n_swa = hparams.n_ctx_train;
+                    hparams.n_swa = 0;
                     hparams.n_swa_pattern = 1;
                 }
             } break;
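
Net effect of the hunk above: the per-variant Phi-3 heuristics (guessing n_swa = 2047 for the 4k models, guessing "no SWA" for the 128k ones) are gone; if the GGUF metadata carries a non-zero sliding_window, the loader now warns and resets the model to the no-SWA configuration. A standalone sketch of that normalization, with a hypothetical struct standing in for the relevant llama_hparams fields:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for the relevant llama_hparams fields.
    struct phi_swa_params {
        bool     swa_enabled   = false; // models LLAMA_SWA_TYPE_NONE vs. others
        uint32_t n_swa         = 0;
        uint32_t n_swa_pattern = 1;
    };

    // Mirrors the new load-time rule: a non-zero sliding_window read from
    // the GGUF metadata triggers a warning and a reset to "no SWA".
    void phi_disable_swa(phi_swa_params & hp, bool found_swa) {
        if (found_swa && hp.n_swa > 0) {
            fprintf(stderr, "Phi SWA is currently disabled - results might be suboptimal for some models\n");
            hp.swa_enabled   = false;
            hp.n_swa         = 0;
            hp.n_swa_pattern = 1;
        }
    }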
@@ -7368,8 +7341,9 @@ struct llm_build_phi2 : public llm_graph_context {
     }
 };

-struct llm_build_phi3_iswa : public llm_graph_context {
-    llm_build_phi3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+template<bool iswa>
+struct llm_build_phi3 : public llm_graph_context {
+    llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
         const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();

@@ -7383,7 +7357,14 @@ struct llm_build_phi3_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();

-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }

         for (int il = 0; il < n_layer; ++il) {
             auto * residual = inpL;
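
In the new constructor body, std::conditional_t derives the static type of inp_attn from the template flag, and if constexpr discards the untaken branch at compile time, so each instantiation only ever references the matching build_attn_inp_* helper and its return type. A compilable toy example of the same pattern (all names here are illustrative, not llama.cpp's):

    #include <cstdio>
    #include <type_traits>

    struct input_swa     { const char * name = "iswa inputs";    };
    struct input_unified { const char * name = "unified inputs"; };

    input_swa     * make_swa_inputs()     { static input_swa     s; return &s; }
    input_unified * make_unified_inputs() { static input_unified u; return &u; }

    template <bool iswa>
    void build() {
        // The pointer's static type follows the template flag ...
        using inp_t = std::conditional_t<iswa, input_swa, input_unified>;
        inp_t * inp = nullptr;

        // ... and if constexpr drops the branch whose assignment would not
        // type-check, so both instantiations compile cleanly.
        if constexpr (iswa) {
            inp = make_swa_inputs();
        } else {
            inp = make_unified_inputs();
        }

        printf("%s\n", inp->name);
    }

    int main() {
        build<true>();  // prints "iswa inputs"
        build<false>(); // prints "unified inputs"
    }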
@@ -13232,7 +13213,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,

         LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);

-        if (hparams.n_swa > 0) {
+        if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+            GGML_ASSERT(hparams.n_swa_pattern != 1);
+
             res = new llama_kv_cache_unified_iswa(
                     *this,
                     params.type_k,
@@ -13245,6 +13228,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     cparams.n_batch,
                     padding);
         } else {
+            GGML_ASSERT(hparams.n_swa_pattern == 1);
+
             res = new llama_kv_cache_unified(
                     *this,
                     nullptr,
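
The paired assertions encode the invariant that n_swa_pattern describes interleaved SWA: in each group of n_swa_pattern consecutive layers some use the sliding window, so a pattern of 1 means no layer does and the plain unified cache applies. A hedged sketch of that layer rule, assuming the common convention that the last layer of each group keeps full attention (llama.cpp's exact convention may differ):

    #include <cstdint>

    // Pattern N: layers come in groups of N; all but one layer per group
    // use SWA. Pattern 1 therefore means "no SWA anywhere".
    // Assumption: the full-attention layer is the last of each group.
    bool layer_uses_swa(uint32_t il, uint32_t n_swa_pattern) {
        if (n_swa_pattern <= 1) {
            return false;
        }
        return il % n_swa_pattern < n_swa_pattern - 1;
    }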
@@ -13353,7 +13338,11 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_PHI3:
         case LLM_ARCH_PHIMOE:
             {
-                llm = std::make_unique<llm_build_phi3_iswa>(*this, params, gf);
+                if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
+                    llm = std::make_unique<llm_build_phi3<true>> (*this, params, gf);
+                } else {
+                    llm = std::make_unique<llm_build_phi3<false>>(*this, params, gf);
+                }
             } break;
         case LLM_ARCH_PLAMO:
             {
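
The dispatch above converts the runtime swa_type flag into a compile-time template argument, so each llm_build_phi3 instantiation carries exactly one KV-cache input path. A minimal standalone demo of this runtime-to-compile-time dispatch through a shared base class (names illustrative):

    #include <cstdio>
    #include <memory>

    struct graph_builder {
        virtual ~graph_builder() = default;
        virtual void build() const = 0;
    };

    // One class template; the bool parameter fixes the cache path per
    // instantiation, mirroring llm_build_phi3<iswa>.
    template <bool iswa>
    struct phi_builder : graph_builder {
        void build() const override {
            if constexpr (iswa) {
                printf("building with iSWA KV-cache inputs\n");
            } else {
                printf("building with unified KV-cache inputs\n");
            }
        }
    };

    // Runtime flag -> compile-time instantiation, as in build_graph().
    std::unique_ptr<graph_builder> make_phi_builder(bool use_swa) {
        if (use_swa) {
            return std::make_unique<phi_builder<true>>();
        }
        return std::make_unique<phi_builder<false>>();
    }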
