@@ -1162,26 +1162,8 @@ struct server_tokens {
1162
1162
tokens[pos] = id;
1163
1163
}
1164
1164
1165
- // if end_pos == -1, we count all positions
1166
- size_t n_kv_tokens (llama_pos end_pos = -1 ) const {
1167
- if (end_pos == -1 ) {
1168
- return n_kv;
1169
- } else {
1170
- size_t res = 0 ;
1171
- for (llama_pos i = 0 ; i < end_pos;) {
1172
- auto & t = tokens[i];
1173
- if (t == LLAMA_TOKEN_NULL) {
1174
- auto & chunk = find_chunk (i);
1175
- auto img_tokens = mtmd_input_chunk_get_tokens_image (chunk.get ());
1176
- res += mtmd_image_tokens_get_n_tokens (img_tokens);
1177
- i += mtmd_image_tokens_get_n_pos (img_tokens);
1178
- } else {
1179
- res++;
1180
- i++;
1181
- }
1182
- }
1183
- return res;
1184
- }
1165
+ size_t n_kv_tokens () const {
1166
+ return n_kv;
1185
1167
}
1186
1168
1187
1169
llama_pos n_pos () const {
@@ -1239,9 +1221,10 @@ struct server_tokens {
1239
1221
return common_detokenize (ctx, text_tokens, special);
1240
1222
}
1241
1223
1242
- // returns the position of the first token that is different
1243
- size_t get_common_prefix (const server_tokens & b) const {
1224
+ // returns pair of <position of the first token that is different, number of tokens (n_kv_tokens) counted in the common prefix>
1225
+ std::pair<llama_pos, size_t > get_common_prefix (const server_tokens & b) const {
1244
1226
size_t max_idx = std::min (tokens.size (), b.tokens .size ());
1227
+ size_t n_tok = 0 ;
1245
1228
for (size_t i = 0 ; i < max_idx; ++i) {
1246
1229
auto & ai = tokens[i];
1247
1230
auto & bi = b.tokens [i];
@@ -1260,17 +1243,19 @@ struct server_tokens {
1260
1243
if (ai_id == bi_id && a_pos == b_pos) {
1261
1244
GGML_ASSERT (a_pos > 0 && " Invalid image token" ); // should never happen
1262
1245
i += a_pos - 1 ; // will be +1 by the for loop
1246
+ n_tok += mtmd_image_tokens_get_n_tokens (a_img);
1263
1247
continue ;
1264
1248
} else {
1265
- return i ;
1249
+ return {i, n_tok} ;
1266
1250
}
1267
1251
} else if (ai == bi) {
1252
+ n_tok++;
1268
1253
continue ;
1269
1254
} else {
1270
- return i ;
1255
+ return {i, n_tok} ;
1271
1256
}
1272
1257
}
1273
- return max_idx; // all tokens are equal
1258
+ return { max_idx, n_tok} ; // all tokens are equal
1274
1259
}
1275
1260
1276
1261
// make sure all text tokens are within the vocab range
0 commit comments