Commit 678d7b1 — "no more scan loop in n_kv_tokens()"
1 parent: 536bea4

File tree: 2 files changed, +16 / −29 lines

tools/server/server.cpp — 6 additions, 4 deletions
@@ -2099,10 +2099,11 @@ struct server_context {
             }

             // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-            int cur_lcs_len = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
+            auto common_pos = slot.cache_tokens.get_common_prefix(task.prompt_tokens);
+            int cur_lcs_len = common_pos.first; // position, not tokens

             // fraction of the common subsequence length compared to the current slot's prompt length
-            float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.n_kv_tokens());
+            float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.n_pos());

             // select the current slot if the criteria match
             if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
@@ -3094,8 +3095,9 @@ struct server_context {

                 if (slot.params.cache_prompt) {
                     // reuse any previously computed tokens that are common with the new prompt
-                    slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
-                    slot.n_kv_tokens = slot.cache_tokens.n_kv_tokens(slot.n_past);
+                    auto common_pos = slot.cache_tokens.get_common_prefix(prompt_tokens);
+                    slot.n_past      = common_pos.first;
+                    slot.n_kv_tokens = common_pos.second;

                     // reuse chunks from the cached prompt by shifting their KV cache in the new position
                     if (params_base.n_cache_reuse > 0) {

tools/server/utils.hpp — 10 additions, 25 deletions
@@ -1162,26 +1162,8 @@ struct server_tokens {
         tokens[pos] = id;
     }

-    // if end_pos == -1, we count all positions
-    size_t n_kv_tokens(llama_pos end_pos = -1) const {
-        if (end_pos == -1) {
-            return n_kv;
-        } else {
-            size_t res = 0;
-            for (llama_pos i = 0; i < end_pos;) {
-                auto & t = tokens[i];
-                if (t == LLAMA_TOKEN_NULL) {
-                    auto & chunk = find_chunk(i);
-                    auto img_tokens = mtmd_input_chunk_get_tokens_image(chunk.get());
-                    res += mtmd_image_tokens_get_n_tokens(img_tokens);
-                    i += mtmd_image_tokens_get_n_pos(img_tokens);
-                } else {
-                    res++;
-                    i++;
-                }
-            }
-            return res;
-        }
+    size_t n_kv_tokens() const {
+        return n_kv;
     }

     llama_pos n_pos() const {
@@ -1239,9 +1221,10 @@ struct server_tokens {
         return common_detokenize(ctx, text_tokens, special);
     }

-    // returns the position of the first token that is different
-    size_t get_common_prefix(const server_tokens & b) const {
+    // returns pair of <position, n_kv_tokens>
+    std::pair<llama_pos, size_t> get_common_prefix(const server_tokens & b) const {
         size_t max_idx = std::min(tokens.size(), b.tokens.size());
+        size_t n_tok = 0;
         for (size_t i = 0; i < max_idx; ++i) {
             auto & ai = tokens[i];
             auto & bi = b.tokens[i];
@@ -1260,17 +1243,19 @@ struct server_tokens {
                 if (ai_id == bi_id && a_pos == b_pos) {
                     GGML_ASSERT(a_pos > 0 && "Invalid image token"); // should never happen
                     i += a_pos - 1; // will be +1 by the for loop
+                    n_tok += mtmd_image_tokens_get_n_tokens(a_img);
                     continue;
                 } else {
-                    return i;
+                    return {i, n_tok};
                 }
             } else if (ai == bi) {
+                n_tok++;
                 continue;
             } else {
-                return i;
+                return {i, n_tok};
             }
         }
-        return max_idx; // all tokens are equal
+        return {max_idx, n_tok}; // all tokens are equal
     }

     // make sure all text tokens are within the vocab range

Comments (0)