From 74dc729c0be0edfe629c5e4542d055e28a2b852d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 11 Dec 2024 14:38:57 +0100 Subject: [PATCH 01/19] server : fix logprobs, make it openai-compatible --- examples/server/server.cpp | 151 +++++++++++------- .../server/tests/unit/test_chat_completion.py | 62 ++++++- examples/server/tests/unit/test_completion.py | 43 ++++- examples/server/utils.hpp | 30 ++++ 4 files changed, 217 insertions(+), 69 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8cb992470a302..2c94318b4e27c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -342,6 +342,11 @@ struct server_task { } } + if (params.sampling.n_probs > 0 && params.cache_prompt) { + SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs); + params.cache_prompt = false; + } + std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; params.oaicompat_model = json_value(data, "model", model_name); @@ -416,6 +421,7 @@ inline std::string stop_type_to_str(stop_type type) { struct completion_token_output { llama_token tok; + float prob; std::string text_to_send; struct token_prob { llama_token tok; @@ -427,9 +433,13 @@ struct completion_token_output { json to_json() const { json probs_for_token = json::array(); for (const auto & p : probs) { + std::string tok_str(p.tok_str); + tok_str.resize(validate_utf8(tok_str)); probs_for_token.push_back(json { - {"tok_str", p.tok_str}, - {"prob", p.prob}, + {"id", p.tok}, + {"token", tok_str}, + {"bytes", str_to_bytes(p.tok_str)}, + {"logprob", p.prob}, }); } return probs_for_token; @@ -437,15 +447,27 @@ struct completion_token_output { static json probs_vector_to_json(const std::vector & probs) { json out = json::array(); - for (const auto & prob : probs) { - const std::string tok_str = prob.text_to_send; + for (const auto & it : probs) { + std::string tok_str(it.text_to_send); + tok_str.resize(validate_utf8(tok_str)); out.push_back(json { - {"content", tok_str}, - {"probs", prob.to_json()}, + {"id", it.tok}, + {"token", tok_str}, + {"logprob", it.prob}, + {"bytes", str_to_bytes(it.text_to_send)}, + {"top_logprobs", it.to_json()}, }); } return out; } + + static std::vector str_to_bytes(const std::string & str) { + std::vector bytes; + for (unsigned char c : str) { + bytes.push_back(c); + } + return bytes; + } }; struct server_task_result_cmpl_final : server_task_result { @@ -506,7 +528,7 @@ struct server_task_result_cmpl_final : server_task_result { {"tokens_cached", n_tokens_cached}, {"timings", timings.to_json()}, }; - if (!probs_output.empty()) { + if (!stream && !probs_output.empty()) { res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); } return res; @@ -518,19 +540,25 @@ struct server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - json choices = json::array({json{ + json choice = json{ {"finish_reason", finish_reason}, {"index", 0}, {"message", json{ {"content", content}, {"role", "assistant"} } - }}}); + }}; + + if (!stream && probs_output.size() > 0) { + choice["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json(probs_output)}, + }; + } std::time_t t = std::time(0); json res = json { - {"choices", choices}, + {"choices", json::array({choice})}, {"created", t}, {"model", oaicompat_model}, {"object", "chat.completion"}, @@ -560,12 +588,14 @@ struct 
server_task_result_cmpl_final : server_task_result { finish_reason = "stop"; } - json choices = json::array({json{{"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}}}); + json choice = json{ + {"finish_reason", finish_reason}, + {"index", 0}, + {"delta", json::object()} + }; json ret = json { - {"choices", choices}, + {"choices", json::array({choice})}, {"created", t}, {"id", oaicompat_cmpl_id}, {"model", oaicompat_model}, @@ -592,7 +622,7 @@ struct server_task_result_cmpl_partial : server_task_result { int32_t n_decoded; int32_t n_prompt_tokens; - std::vector probs_output; + completion_token_output prob_output; result_timings timings; // OAI-compat fields @@ -628,8 +658,8 @@ struct server_task_result_cmpl_partial : server_task_result { if (timings.prompt_n > 0) { res.push_back({"timings", timings.to_json()}); } - if (!probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); + if (!prob_output.probs.empty()) { + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}); } return res; } @@ -681,6 +711,14 @@ struct server_task_result_cmpl_partial : server_task_result { }}); } + GGML_ASSERT(choices.size() >= 1); + + if (prob_output.probs.size() > 0) { + choices[0]["logprobs"] = json{ + {"content", completion_token_output::probs_vector_to_json({prob_output})}, + }; + } + json ret = json { {"choices", choices}, {"created", t}, @@ -951,7 +989,6 @@ struct server_slot { // stats size_t n_sent_text = 0; // number of sent text character - size_t n_sent_token_probs = 0; int64_t t_start_process_prompt; int64_t t_start_generation; @@ -973,7 +1010,6 @@ struct server_slot { stopping_word = ""; n_past = 0; n_sent_text = 0; - n_sent_token_probs = 0; task_type = SERVER_TASK_TYPE_COMPLETION; generated_token_probs.clear(); @@ -1713,7 +1749,7 @@ struct server_context { bool process_token(completion_token_output & result, server_slot & slot) { // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special); + const std::string token_str = result.text_to_send; slot.sampled = result.tok; // search stop word and delete it @@ -1721,26 +1757,7 @@ struct server_context { slot.has_next_token = true; // check if there is incomplete UTF-8 character at the end - bool incomplete = false; - for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { - unsigned char c = slot.generated_text[slot.generated_text.size() - i]; - if ((c & 0xC0) == 0x80) { - // continuation byte: 10xxxxxx - continue; - } - if ((c & 0xE0) == 0xC0) { - // 2-byte character: 110xxxxx ... - incomplete = i < 2; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character: 1110xxxx ... - incomplete = i < 3; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character: 11110xxx ... 
- incomplete = i < 4; - } - // else 1-byte character or invalid byte - break; - } + bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); if (!incomplete) { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); @@ -1869,6 +1886,29 @@ struct server_context { return slot.has_next_token; // continue } + void populate_token_probs(const server_slot & slot, completion_token_output & result) { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + // set prob for the sampled token + for (size_t i = 0; i < max_probs; ++i) { + if (result.tok == cur_p->data[i].id) { + result.prob = cur_p->data[i].p; + break; + } + } + + // set probs for the top n tokens + for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) { + auto tok_id = cur_p->data[i].id; + result.probs.push_back({ + tok_id, + tokens_to_output_formatted_string(ctx, tok_id), + cur_p->data[i].p, + }); + } + } + void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { send_error(task.id, error, type); } @@ -1906,17 +1946,7 @@ struct server_context { // populate res.probs_output if (slot.params.sampling.n_probs > 0) { - const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false); - - const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size()); - const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size()); - - std::vector probs_output; - if (probs_pos < probs_stop_pos) { - res->probs_output = std::vector( - slot.generated_token_probs.begin() + probs_pos, - slot.generated_token_probs.begin() + probs_stop_pos); - } + res->prob_output = tkn; // copy the token probs } // populate timings if this is final response or timings_per_token is enabled @@ -2747,17 +2777,12 @@ struct server_context { slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3; completion_token_output result; - result.tok = id; + result.tok = id; + result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.prob = 1.0f; // set later - const auto * cur_p = common_sampler_get_candidates(slot.smpl); - - for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) { - auto tok_id = cur_p->data[i].id; - result.probs.push_back({ - tok_id, - tokens_to_output_formatted_string(ctx, tok_id), - i >= cur_p->size ? 
0.0f : cur_p->data[i].p, - }); + if (slot.params.sampling.n_probs > 0) { + populate_token_probs(slot, result); } if (!process_token(result, slot)) { @@ -2841,7 +2866,9 @@ struct server_context { for (size_t i = 0; i < ids.size(); ++i) { completion_token_output result; - result.tok = ids[i]; + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); + result.prob = 1.0f; // set later if (!process_token(result, slot)) { // release slot because of stop condition diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 6573cc17f7b87..299472fa46162 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -92,7 +92,6 @@ def test_chat_completion_with_openai_library(): seed=42, temperature=0.8, ) - print(res) assert res.choices[0].finish_reason == "length" assert res.choices[0].message.content is not None assert match_regex("(Suddenly)+", res.choices[0].message.content) @@ -163,3 +162,64 @@ def test_chat_completion_with_timings_per_token(): assert "predicted_per_second" in data["timings"] assert "predicted_n" in data["timings"] assert data["timings"]["predicted_n"] <= 10 + + +def test_logprobs(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + ) + output_text = res.choices[0].message.content + aggregated_text = '' + assert res.choices[0].logprobs is not None + assert res.choices[0].logprobs.content is not None + for token in res.choices[0].logprobs.content: + aggregated_text += token.token + assert 0.0 <= token.logprob <= 1.0 + assert token.bytes is not None and len(token.bytes) > 0 + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text + + +def test_logprobs_stream(): + global server + server.start() + client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") + res = client.chat.completions.create( + model="gpt-3.5-turbo-instruct", + temperature=0.0, + messages=[ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + max_tokens=5, + logprobs=True, + top_logprobs=10, + stream=True, + ) + output_text = '' + aggregated_text = '' + for data in res: + choice = data.choices[0] + if choice.finish_reason is None: + if choice.delta.content: + output_text += choice.delta.content + assert choice.logprobs is not None + assert choice.logprobs.content is not None + for token in choice.logprobs.content: + aggregated_text += token.token + assert 0.0 <= token.logprob <= 1.0 + assert token.bytes is not None and len(token.bytes) > 0 + assert token.top_logprobs is not None + assert len(token.top_logprobs) > 0 + assert aggregated_text == output_text diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 7f4f9cd038be4..4c89ee3ee0c0c 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -260,9 +260,40 @@ def test_n_probs(): assert "completion_probabilities" in res.body assert len(res.body["completion_probabilities"]) == 5 for tok in res.body["completion_probabilities"]: - assert 
"probs" in tok - assert len(tok["probs"]) == 10 - for prob in tok["probs"]: - assert "prob" in prob - assert "tok_str" in prob - assert 0.0 <= prob["prob"] <= 1.0 + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "bytes" in tok and len(tok["bytes"]) > 0 + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "bytes" in prob and len(prob["bytes"]) > 0 + + +def test_n_probs_stream(): + global server + server.start() + res = server.make_stream_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "stream": True, + }) + for data in res: + if data["stop"] == False: + assert "completion_probabilities" in data + assert len(data["completion_probabilities"]) == 1 + for tok in data["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "bytes" in tok and len(tok["bytes"]) > 0 + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "bytes" in prob and len(prob["bytes"]) > 0 diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 8f545aea52dc4..3750cf758c866 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -170,6 +170,36 @@ static std::vector tokenize_input_prompts(llama_context * ctx, con return result; } +// return the last index of character that can form a valid string +// if the last character is potentially cut in half, return the index before the cut +// if validate_utf8(text) == text.size(), then the whole text is valid utf8 +static size_t validate_utf8(const std::string& text) { + size_t len = text.size(); + if (len == 0) return 0; + + // Check the last few bytes to see if a multi-byte character is cut off + for (size_t i = 1; i <= 4 && i <= len; ++i) { + unsigned char c = text[len - i]; + // Check for start of a multi-byte sequence from the end + if ((c & 0xE0) == 0xC0) { + // 2-byte character start: 110xxxxx + // Needs at least 2 bytes + if (i < 2) return len - i; + } else if ((c & 0xF0) == 0xE0) { + // 3-byte character start: 1110xxxx + // Needs at least 3 bytes + if (i < 3) return len - i; + } else if ((c & 0xF8) == 0xF0) { + // 4-byte character start: 11110xxx + // Needs at least 4 bytes + if (i < 4) return len - i; + } + } + + // If no cut-off multi-byte character is found, return full length + return len; +} + // // template utils // From 7828013689dfaf8e80d6dfe36b080e5ca80206c9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 11 Dec 2024 14:47:49 +0100 Subject: [PATCH 02/19] update docs --- examples/server/README.md | 39 ++++++++++++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 7 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 6294f541fc7d7..4636a5f42d3c1 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -343,6 +343,10 @@ node index.js ### POST `/completion`: Given a `prompt`, it returns the predicted completion. 
+> [!IMPORTANT] +> +> This endpoint is **not** OAI-compatible + *Options:* `prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, if `cache_prompt` is `true`, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. A `BOS` token is inserted at the start, if all of the following conditions are true: @@ -448,27 +452,48 @@ These words will not be included in the completion, so make sure to add them to - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. -- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has the following structure: +- `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements: ```json { - "content": "", - "probs": [ + "content": "", + ... + "completion_probabilities": [ { + "id": , "prob": float, - "tok_str": "" + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + { + "id": , + "prob": float, + "token": "", + "bytes": [int, int, ...], + }, + { + "id": , + "prob": float, + "token": "", + "bytes": [int, int, ...], + }, + ... + ] }, { + "id": , "prob": float, - "tok_str": "" + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + ... + ] }, ... ] }, ``` -Notice that each `probs` is an array of length `n_probs`. - - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). From 01afafef93ad32b8be48987a0b86649bf176e39f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 11:16:12 +0100 Subject: [PATCH 03/19] add std::log --- examples/server/server.cpp | 14 +++++++------- examples/server/tests/unit/test_chat_completion.py | 4 ++-- examples/server/tests/unit/test_completion.py | 8 ++++---- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2c94318b4e27c..5a3f5d889ccd5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -342,11 +342,6 @@ struct server_task { } } - if (params.sampling.n_probs > 0 && params.cache_prompt) { - SRV_WRN("cache_prompt is not compatible with n_probs > 0 (current value = %d), disabling cache_prompt.\n", params.sampling.n_probs); - params.cache_prompt = false; - } - std::string model_name = params_base.model_alias.empty() ? 
DEFAULT_OAICOMPAT_MODEL : params_base.model_alias; params.oaicompat_model = json_value(data, "model", model_name); @@ -439,7 +434,7 @@ struct completion_token_output { {"id", p.tok}, {"token", tok_str}, {"bytes", str_to_bytes(p.tok_str)}, - {"logprob", p.prob}, + {"logprob", logarithm(p.prob)}, }); } return probs_for_token; @@ -453,7 +448,7 @@ struct completion_token_output { out.push_back(json { {"id", it.tok}, {"token", tok_str}, - {"logprob", it.prob}, + {"logprob", logarithm(it.prob)}, {"bytes", str_to_bytes(it.text_to_send)}, {"top_logprobs", it.to_json()}, }); @@ -461,6 +456,11 @@ struct completion_token_output { return out; } + static float logarithm(float x) { + // nlohmann::json converts -inf to null, so we need to prevent that + return x == 0.0f ? std::numeric_limits::lowest() : std::log(x); + } + static std::vector str_to_bytes(const std::string & str) { std::vector bytes; for (unsigned char c : str) { diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 299472fa46162..ce94398d69a0f 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -185,7 +185,7 @@ def test_logprobs(): assert res.choices[0].logprobs.content is not None for token in res.choices[0].logprobs.content: aggregated_text += token.token - assert 0.0 <= token.logprob <= 1.0 + assert token.logprob <= 0.0 assert token.bytes is not None and len(token.bytes) > 0 assert len(token.top_logprobs) > 0 assert aggregated_text == output_text @@ -218,7 +218,7 @@ def test_logprobs_stream(): assert choice.logprobs.content is not None for token in choice.logprobs.content: aggregated_text += token.token - assert 0.0 <= token.logprob <= 1.0 + assert token.logprob <= 0.0 assert token.bytes is not None and len(token.bytes) > 0 assert token.top_logprobs is not None assert len(token.top_logprobs) > 0 diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 4c89ee3ee0c0c..9e91c5da24c95 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -262,13 +262,13 @@ def test_n_probs(): for tok in res.body["completion_probabilities"]: assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str - assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "logprob" in tok and tok["logprob"] <= 0.0 assert "bytes" in tok and len(tok["bytes"]) > 0 assert len(tok["top_logprobs"]) == 10 for prob in tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str - assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "logprob" in prob and prob["logprob"] <= 0.0 assert "bytes" in prob and len(prob["bytes"]) > 0 @@ -289,11 +289,11 @@ def test_n_probs_stream(): for tok in data["completion_probabilities"]: assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str - assert "logprob" in tok and 0.0 <= tok["logprob"] <= 1.0 + assert "logprob" in tok and tok["logprob"] <= 0.0 assert "bytes" in tok and len(tok["bytes"]) > 0 assert len(tok["top_logprobs"]) == 10 for prob in tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str - assert "logprob" in prob and 0.0 <= prob["logprob"] <= 1.0 + assert "logprob" in prob and prob["logprob"] <= 0.0 assert "bytes" in prob and len(prob["bytes"]) > 0 From 
cc90cdbc33fe769b6e99a09597947befd43e2903 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 13:44:30 +0100 Subject: [PATCH 04/19] return pre-sampling p --- examples/server/server.cpp | 32 +++++++++++++------------------- examples/server/utils.hpp | 30 ++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5a3f5d889ccd5..d50ae2dc7ab75 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1886,25 +1886,17 @@ struct server_context { return slot.has_next_token; // continue } - void populate_token_probs(const server_slot & slot, completion_token_output & result) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl); - const size_t max_probs = cur_p->size; - - // set prob for the sampled token - for (size_t i = 0; i < max_probs; ++i) { - if (result.tok == cur_p->data[i].id) { - result.prob = cur_p->data[i].p; - break; - } - } + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) { + std::vector cur = get_token_probabilities(ctx, idx); + int n_vocab = llama_n_vocab(llama_get_model(ctx)); - // set probs for the top n tokens - for (size_t i = 0; i < std::min(max_probs, (size_t) slot.params.sampling.n_probs); ++i) { - auto tok_id = cur_p->data[i].id; + // only take at most n_probs tokens + const int n_probs = slot.params.sampling.n_probs; + for (int i = 0; i < std::min(n_probs, n_vocab); i++) { result.probs.push_back({ - tok_id, - tokens_to_output_formatted_string(ctx, tok_id), - cur_p->data[i].p, + cur[i].id, + common_detokenize(ctx, {cur[i].id}, special), + cur[i].p }); } } @@ -2758,7 +2750,9 @@ struct server_context { continue; // continue loop of slots } - llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i); + const int tok_idx = slot.i_batch - i; + + llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); slot.i_batch = -1; @@ -2782,7 +2776,7 @@ struct server_context { result.prob = 1.0f; // set later if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result); + populate_token_probs(slot, result, params_base.special, tok_idx); } if (!process_token(result, slot)) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 3750cf758c866..60c7656d32d2a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -694,3 +694,33 @@ static json format_logit_bias(const std::vector & logit_bias) static std::string safe_json_to_str(json data) { return data.dump(-1, ' ', false, json::error_handler_t::replace); } + +static std::vector get_token_probabilities(llama_context * ctx, int idx) { + std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + // apply softmax + float max_l = cur[0].logit; + float cum_sum = 0.0f; + for (size_t i = 0; i < cur.size(); ++i) { + float p = expf(cur[i].logit - max_l); + cur[i].p = p; + cum_sum += p; + } + for (size_t i = 0; i < cur.size(); ++i) { + cur[i].p /= cum_sum; + } + + // sort tokens by probability + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.p > b.p; + }); + + return cur; +} From 29c1495afa84012132531218efe225d2c78b2803 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 
Dec 2024 13:47:43 +0100 Subject: [PATCH 05/19] sort before apply softmax --- examples/server/utils.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 60c7656d32d2a..38c00b7db90b2 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -705,6 +705,11 @@ static std::vector get_token_probabilities(llama_context * ctx cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } + // sort tokens by logits + std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + // apply softmax float max_l = cur[0].logit; float cum_sum = 0.0f; @@ -717,10 +722,5 @@ static std::vector get_token_probabilities(llama_context * ctx cur[i].p /= cum_sum; } - // sort tokens by probability - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.p > b.p; - }); - return cur; } From 396ade0b02342509c50930f0eda70c0808e10683 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 13:50:42 +0100 Subject: [PATCH 06/19] add comment --- examples/server/server.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d50ae2dc7ab75..1a296aa00dac3 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2864,6 +2864,8 @@ struct server_context { result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); result.prob = 1.0f; // set later + // TODO: set result.probs + if (!process_token(result, slot)) { // release slot because of stop condition slot.release(); From 22b72c8574b5ea709fd559411a447c8b3033bd02 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 14:05:29 +0100 Subject: [PATCH 07/19] fix test --- examples/server/tests/unit/test_chat_completion.py | 4 ++-- examples/server/tests/unit/test_completion.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index ce94398d69a0f..0fa1a17c1f50a 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -186,7 +186,7 @@ def test_logprobs(): for token in res.choices[0].logprobs.content: aggregated_text += token.token assert token.logprob <= 0.0 - assert token.bytes is not None and len(token.bytes) > 0 + assert token.bytes is not None assert len(token.top_logprobs) > 0 assert aggregated_text == output_text @@ -219,7 +219,7 @@ def test_logprobs_stream(): for token in choice.logprobs.content: aggregated_text += token.token assert token.logprob <= 0.0 - assert token.bytes is not None and len(token.bytes) > 0 + assert token.bytes is not None assert token.top_logprobs is not None assert len(token.top_logprobs) > 0 assert aggregated_text == output_text diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 9e91c5da24c95..7b33ec531d38f 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -263,13 +263,13 @@ def test_n_probs(): assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str assert "logprob" in tok and tok["logprob"] <= 0.0 - assert "bytes" in tok and len(tok["bytes"]) > 0 + assert "bytes" in tok and type(tok["bytes"]) == list assert len(tok["top_logprobs"]) == 10 for prob in 
tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str assert "logprob" in prob and prob["logprob"] <= 0.0 - assert "bytes" in prob and len(prob["bytes"]) > 0 + assert "bytes" in prob and type(prob["bytes"]) == list def test_n_probs_stream(): @@ -290,10 +290,10 @@ def test_n_probs_stream(): assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str assert "logprob" in tok and tok["logprob"] <= 0.0 - assert "bytes" in tok and len(tok["bytes"]) > 0 + assert "bytes" in tok and type(tok["bytes"]) == list assert len(tok["top_logprobs"]) == 10 for prob in tok["top_logprobs"]: assert "id" in prob and prob["id"] > 0 assert "token" in prob and type(prob["token"]) == str assert "logprob" in prob and prob["logprob"] <= 0.0 - assert "bytes" in prob and len(prob["bytes"]) > 0 + assert "bytes" in prob and type(prob["bytes"]) == list From ed7f2d5756933edff163691b2a455e8847ac7651 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 12 Dec 2024 14:05:38 +0100 Subject: [PATCH 08/19] set p for sampled token --- examples/server/server.cpp | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1a296aa00dac3..95bd531b3878b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1889,15 +1889,26 @@ struct server_context { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) { std::vector cur = get_token_probabilities(ctx, idx); int n_vocab = llama_n_vocab(llama_get_model(ctx)); + size_t n_probs = slot.params.sampling.n_probs; - // only take at most n_probs tokens - const int n_probs = slot.params.sampling.n_probs; - for (int i = 0; i < std::min(n_probs, n_vocab); i++) { + bool found_sampled_tok = false; + result.probs.reserve(n_probs); + for (int i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + found_sampled_tok = true; + result.prob = cur[i].p; + } + // set probability for top n_probs tokens result.probs.push_back({ cur[i].id, common_detokenize(ctx, {cur[i].id}, special), cur[i].p }); + // break if we have all the necessary data + if (result.probs.size() == n_probs && found_sampled_tok) { + break; + } } } From 06bb38e75dbd3a706e99edcccb14d04318b4a91b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 13 Dec 2024 08:50:47 +0100 Subject: [PATCH 09/19] update docs --- examples/server/README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 4636a5f42d3c1..e04530713dc8d 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -461,19 +461,19 @@ These words will not be included in the completion, so make sure to add them to "completion_probabilities": [ { "id": , - "prob": float, + "logprob": float, "token": "", "bytes": [int, int, ...], "top_logprobs": [ { "id": , - "prob": float, + "logprob": float, "token": "", "bytes": [int, int, ...], }, { "id": , - "prob": float, + "logprob": float, "token": "", "bytes": [int, int, ...], }, @@ -482,7 +482,7 @@ These words will not be included in the completion, so make sure to add them to }, { "id": , - "prob": float, + "logprob": float, "token": "", "bytes": [int, int, ...], "top_logprobs": [ From 196e237e097c244e4ec14c7ff4f82a7dd7f0475b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 13 Dec 2024 14:33:44 +0100 Subject: [PATCH 10/19] add 
--multi-token-probs --- common/arg.cpp | 10 ++++++++++ common/common.h | 1 + examples/server/server.cpp | 4 ++++ examples/server/tests/unit/test_chat_completion.py | 2 ++ examples/server/tests/unit/test_completion.py | 2 ++ examples/server/tests/utils.py | 3 +++ 6 files changed, 22 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 49af31682510d..622f24fb47c69 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1057,6 +1057,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(value)); } ).set_sparam()); + add_opt(common_arg( + {"-mtp", "--multi-token-probs"}, + string_format( + "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)", + params.sampling.multi_token_probs ? "enabled" : "disabled" + ), + [](common_params & params) { + params.sampling.multi_token_probs = true; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS")); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", diff --git a/common/common.h b/common/common.h index 95d20401d2a9a..5fcb8e506fe2d 100644 --- a/common/common.h +++ b/common/common.h @@ -134,6 +134,7 @@ struct common_params_sampling { bool ignore_eos = false; bool no_perf = false; // disable performance metrics bool timing_per_token = false; + bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0) std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 95bd531b3878b..8f5778052211c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -239,6 +239,10 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + if (!params_base.sampling.multi_token_probs && params.n_predict > 1 && params.sampling.n_probs > 0) { + throw std::runtime_error("For performance reason, n_probs with n_predict > 1 is not allowed. 
To enable this, start the server with --multi-token-probs"); + } + if (params.sampling.dry_base < 1.0f) { params.sampling.dry_base = defaults.sampling.dry_base; } diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py index 0fa1a17c1f50a..37ac11006e2bc 100644 --- a/examples/server/tests/unit/test_chat_completion.py +++ b/examples/server/tests/unit/test_chat_completion.py @@ -166,6 +166,7 @@ def test_chat_completion_with_timings_per_token(): def test_logprobs(): global server + server.multi_token_probs = True server.start() client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") res = client.chat.completions.create( @@ -193,6 +194,7 @@ def test_logprobs(): def test_logprobs_stream(): global server + server.multi_token_probs = True server.start() client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}") res = client.chat.completions.create( diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 7b33ec531d38f..ee9b9f46663d3 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -249,6 +249,7 @@ def check_slots_status(): def test_n_probs(): global server + server.multi_token_probs = True server.start() res = server.make_request("POST", "/completion", data={ "prompt": "I believe the meaning of life is", @@ -274,6 +275,7 @@ def test_n_probs(): def test_n_probs_stream(): global server + server.multi_token_probs = True server.start() res = server.make_stream_request("POST", "/completion", data={ "prompt": "I believe the meaning of life is", diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index d988ccf5e3061..5221e0829a819 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -73,6 +73,7 @@ class ServerProcess: draft_min: int | None = None draft_max: int | None = None no_webui: bool | None = None + multi_token_probs: bool | None = None # session variables process: subprocess.Popen | None = None @@ -161,6 +162,8 @@ def start(self, timeout_seconds: int = 10) -> None: server_args.extend(["--draft-min", self.draft_min]) if self.no_webui: server_args.append("--no-webui") + if self.multi_token_probs: + server_args.append("--multi-token-probs") args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") From 630ddcc570a4cbffce08828b04807ca1ad3e2f96 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 13 Dec 2024 14:35:51 +0100 Subject: [PATCH 11/19] update docs --- examples/server/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/README.md b/examples/server/README.md index 0803c01ba01f2..b3ffb6c89a32e 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -140,6 +140,7 @@ The project is under active development, and we are [looking for feedback and co | `-sp, --special` | special tokens output enabled (default: false) | | `--no-warmup` | skip warming up the model with an empty run | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | +| `-mtp, --multi-token-probs` | allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: disabled)
<br/>(env: LLAMA_ARG_MULTI_TOKEN_PROBS) |
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
| `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | From ecadd37c63381766875adacb687b3856f27fa913 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 18 Dec 2024 14:11:04 +0100 Subject: [PATCH 12/19] add `post_sampling_probs` option --- examples/server/README.md | 84 +++++++------ examples/server/server.cpp | 116 ++++++++++++------ examples/server/tests/unit/test_completion.py | 27 ++++ examples/server/tests/unit/test_embedding.py | 3 + 4 files changed, 151 insertions(+), 79 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index fa6df1ce444de..e4384513557d7 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -449,52 +449,56 @@ These words will not be included in the completion, so make sure to add them to `timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false` +`post_sampling_probs`: Returns the probabilities of top `n_probs` tokens after applying sampling chain. + **Response format** - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. - `completion_probabilities`: An array of token probabilities for each completion. The array's length is `n_predict`. Each item in the array has a nested array `top_logprobs`. It contains at **maximum** `n_probs` elements: - -```json -{ - "content": "", - "tokens": [ generated token ids if requested ], - ... - "probs": [ - { - "id": , - "logprob": float, - "token": "", - "bytes": [int, int, ...], - "top_logprobs": [ - { - "id": , - "logprob": float, - "token": "", - "bytes": [int, int, ...], - }, - { - "id": , - "logprob": float, - "token": "", - "bytes": [int, int, ...], - }, - ... - ] - }, - { - "id": , - "logprob": float, - "token": "", - "bytes": [int, int, ...], - "top_logprobs": [ - ... - ] - }, + ```json + { + "content": "", + "tokens": [ generated token ids if requested ], ... - ] -}, -``` + "probs": [ + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + }, + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + }, + ... + ] + }, + { + "id": , + "logprob": float, + "token": "", + "bytes": [int, int, ...], + "top_logprobs": [ + ... + ] + }, + ... + ] + }, + ``` + Please note that if `post_sampling_probs` is set to `true`: + - `logprob` will be replace with `prob`, with the value between 0.0 and 1.0 + - Returned number of probabilities may be less than `n_probs` - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request. 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 854dbda1cfd6d..93196adcdc5b9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -93,6 +93,7 @@ struct slot_params { std::vector antiprompt; bool timings_per_token = false; + bool post_sampling_probs = false; bool ignore_eos = false; struct common_params_sampling sampling; @@ -151,6 +152,7 @@ struct slot_params { {"speculative.n_min", speculative.n_min}, {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, + {"post_sampling_probs", post_sampling_probs}, }; } }; @@ -231,6 +233,7 @@ struct server_task { params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); @@ -449,7 +452,7 @@ struct completion_token_output { }; std::vector probs; - json to_json() const { + json to_json(bool post_sampling_probs) const { json probs_for_token = json::array(); for (const auto & p : probs) { std::string tok_str(p.tok_str); @@ -458,13 +461,16 @@ struct completion_token_output { {"id", p.tok}, {"token", tok_str}, {"bytes", str_to_bytes(p.tok_str)}, - {"logprob", logarithm(p.prob)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? p.prob : logarithm(p.prob) + }, }); } return probs_for_token; } - static json probs_vector_to_json(const std::vector & probs) { + static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { json out = json::array(); for (const auto & it : probs) { std::string tok_str(it.text_to_send); @@ -472,9 +478,12 @@ struct completion_token_output { out.push_back(json { {"id", it.tok}, {"token", tok_str}, - {"logprob", logarithm(it.prob)}, {"bytes", str_to_bytes(it.text_to_send)}, - {"top_logprobs", it.to_json()}, + {"top_logprobs", it.to_json(post_sampling_probs)}, + { + post_sampling_probs ? "prob" : "logprob", + post_sampling_probs ? 
it.prob : logarithm(it.prob) + }, }); } return out; @@ -512,6 +521,7 @@ struct server_task_result_cmpl_final : server_task_result { std::string stopping_word; stop_type stop = STOP_TYPE_NONE; + bool post_sampling_probs; std::vector probs_output; slot_params generation_params; @@ -557,7 +567,7 @@ struct server_task_result_cmpl_final : server_task_result { {"timings", timings.to_json()}, }; if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output); + res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); } return res; } @@ -579,7 +589,7 @@ struct server_task_result_cmpl_final : server_task_result { if (!stream && probs_output.size() > 0) { choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output)}, + {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, }; } @@ -652,6 +662,7 @@ struct server_task_result_cmpl_partial : server_task_result { int32_t n_decoded; int32_t n_prompt_tokens; + bool post_sampling_probs; completion_token_output prob_output; result_timings timings; @@ -690,7 +701,7 @@ struct server_task_result_cmpl_partial : server_task_result { res.push_back({"timings", timings.to_json()}); } if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}); + res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); } return res; } @@ -746,7 +757,7 @@ struct server_task_result_cmpl_partial : server_task_result { if (prob_output.probs.size() > 0) { choices[0]["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output})}, + {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, }; } @@ -1944,28 +1955,53 @@ struct server_context { return slot.has_next_token; // continue } - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool special, int idx) { - std::vector cur = get_token_probabilities(ctx, idx); - int n_vocab = llama_n_vocab(llama_get_model(ctx)); + void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { size_t n_probs = slot.params.sampling.n_probs; - - bool found_sampled_tok = false; - result.probs.reserve(n_probs); - for (int i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - found_sampled_tok = true; - result.prob = cur[i].p; + int n_vocab = llama_n_vocab(llama_get_model(ctx)); + if (post_sampling) { + std::vector cur = get_token_probabilities(ctx, idx); + + bool found_sampled_tok = false; + result.probs.reserve(n_probs); + for (int i = 0; i < n_vocab; i++) { + // set probability for sampled token + if (cur[i].id == result.tok) { + found_sampled_tok = true; + result.prob = cur[i].p; + } + // set probability for top n_probs tokens + result.probs.push_back({ + cur[i].id, + common_detokenize(ctx, {cur[i].id}, special), + cur[i].p + }); + // break if we have all the necessary data + if (result.probs.size() == n_probs && found_sampled_tok) { + break; + } } - // set probability for top n_probs tokens - result.probs.push_back({ - cur[i].id, - common_detokenize(ctx, {cur[i].id}, special), - cur[i].p - }); - // break if we have all the necessary data - if (result.probs.size() == n_probs && found_sampled_tok) { - 
break; + } else { + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; + + bool found_sampled_tok = false; + result.probs.reserve(max_probs); + for (size_t i = 0; i < max_probs; i++) { + // set probability for sampled token + if (cur_p->data[i].id == result.tok) { + found_sampled_tok = true; + result.prob = cur_p->data[i].p; + } + // set probability for top n_probs tokens + result.probs.push_back({ + cur_p->data[i].id, + common_detokenize(ctx, {cur_p->data[i].id}, special), + cur_p->data[i].p + }); + // break if we have all the necessary data + if (result.probs.size() == n_probs && found_sampled_tok) { + break; + } } } } @@ -1997,8 +2033,9 @@ struct server_context { res->content = tkn.text_to_send; res->tokens = { tkn.tok }; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->post_sampling_probs = slot.params.post_sampling_probs; res->verbose = slot.params.verbose; res->oaicompat = slot.params.oaicompat; @@ -2030,13 +2067,14 @@ struct server_context { res->timings = slot.get_timings(); res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.n_prompt_tokens; - res->n_tokens_cached = slot.n_past; - res->has_new_line = slot.has_new_line; - res->stopping_word = slot.stopping_word; - res->stop = slot.stop; + res->truncated = slot.truncated; + res->n_decoded = slot.n_decoded; + res->n_prompt_tokens = slot.n_prompt_tokens; + res->n_tokens_cached = slot.n_past; + res->has_new_line = slot.has_new_line; + res->stopping_word = slot.stopping_word; + res->stop = slot.stop; + res->post_sampling_probs = slot.params.post_sampling_probs; res->verbose = slot.params.verbose; res->stream = slot.params.stream; @@ -2859,7 +2897,7 @@ struct server_context { result.prob = 1.0f; // set later if (slot.params.sampling.n_probs > 0) { - populate_token_probs(slot, result, params_base.special, tok_idx); + populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx); } if (!process_token(result, slot)) { diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index c26f982d80c6d..78aaed0522d57 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -309,3 +309,30 @@ def test_n_probs_stream(): assert "token" in prob and type(prob["token"]) == str assert "logprob" in prob and prob["logprob"] <= 0.0 assert "bytes" in prob and type(prob["bytes"]) == list + + +def test_n_probs_post_sampling(): + global server + server.multi_token_probs = True + server.start() + res = server.make_request("POST", "/completion", data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "post_sampling_probs": True, + }) + assert res.status_code == 200 + assert "completion_probabilities" in res.body + assert len(res.body["completion_probabilities"]) == 5 + for tok in res.body["completion_probabilities"]: + assert "id" in tok and tok["id"] > 0 + assert "token" in tok and type(tok["token"]) == str + assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0 + assert "bytes" in tok and type(tok["bytes"]) == list + assert len(tok["top_logprobs"]) == 10 + for prob in tok["top_logprobs"]: + assert "id" in prob and prob["id"] > 0 + assert "token" in prob and type(prob["token"]) == str + 
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0 + assert "bytes" in prob and type(prob["bytes"]) == list diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index e32d745829605..43e372fc70d71 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -50,6 +50,8 @@ def test_embedding_multiple(): @pytest.mark.parametrize( "input,is_multi_prompt", [ + # do not crash on empty input + ("", False), # single prompt ("string", False), ([12, 34, 56], False), @@ -103,6 +105,7 @@ def test_embedding_pooling_none_oai(): # /v1/embeddings does not support pooling type 'none' assert res.status_code == 400 + assert "error" in res.body def test_embedding_openai_library_single(): From 75fe7751e50440f6a59226db987724281766fad9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 18 Dec 2024 14:32:32 +0100 Subject: [PATCH 13/19] update docs [no ci] --- examples/server/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index e4384513557d7..480e40d30428b 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -504,8 +504,8 @@ These words will not be included in the completion, so make sure to add them to - `tokens`: Same as `content` but represented as raw token ids. Only populated if `"return_tokens": true` or `"stream": true` in the request. - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model`. These options may differ from the original ones in some way (e.g. bad values filtered out, strings converted to tokens, etc.). -- `model`: The path to the model loaded with `-m` -- `prompt`: The provided `prompt` +- `model`: The model alias (for model path, please use `/props` endpoint) +- `prompt`: The processed `prompt` (special tokens may be added) - `stop_type`: Indicating whether the completion has stopped. Possible values are: - `none`: Generating (not stopped) - `eos`: Stopped because it encountered the EOS token From 8734df73d9a470181ba82b5932b2980e35972fb9 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 18 Dec 2024 17:15:15 +0100 Subject: [PATCH 14/19] remove --multi-token-probs --- common/arg.cpp | 10 ---------- common/common.h | 1 - examples/server/README.md | 1 - examples/server/server.cpp | 4 ---- examples/server/tests/unit/test_chat_completion.py | 2 -- examples/server/tests/unit/test_completion.py | 3 --- examples/server/tests/utils.py | 3 --- 7 files changed, 24 deletions(-) diff --git a/common/arg.cpp b/common/arg.cpp index e3f546b76668d..3d55289c33192 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1085,16 +1085,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(value)); } ).set_sparam()); - add_opt(common_arg( - {"-mtp", "--multi-token-probs"}, - string_format( - "allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: %s)", - params.sampling.multi_token_probs ? 
"enabled" : "disabled" - ), - [](common_params & params) { - params.sampling.multi_token_probs = true; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MULTI_TOKEN_PROBS")); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", diff --git a/common/common.h b/common/common.h index 9ec4c9f4bc10a..ec0e49f6f1806 100644 --- a/common/common.h +++ b/common/common.h @@ -134,7 +134,6 @@ struct common_params_sampling { bool ignore_eos = false; bool no_perf = false; // disable performance metrics bool timing_per_token = false; - bool multi_token_probs = false; // output probabilities for multiple tokens (when n_probs > 0) std::vector dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY diff --git a/examples/server/README.md b/examples/server/README.md index 480e40d30428b..73e394cfb9b1a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -139,7 +139,6 @@ The project is under active development, and we are [looking for feedback and co | `-sp, --special` | special tokens output enabled (default: false) | | `--no-warmup` | skip warming up the model with an empty run | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | -| `-mtp, --multi-token-probs` | allow getting probabilities for multiple tokens. note: this will slow down the generation speed (default: disabled)
<br/>(env: LLAMA_ARG_MULTI_TOKEN_PROBS) |
| `--pooling {none,mean,cls,last,rank}` | pooling type for embeddings, use model default if unspecified<br/>(env: LLAMA_ARG_POOLING) |
| `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)<br/>(env: LLAMA_ARG_CONT_BATCHING) |
| `-nocb, --no-cont-batching` | disable continuous batching<br/>(env: LLAMA_ARG_NO_CONT_BATCHING) |

From fd4cf34b004fa630e2c5186ee62e51c56c208cfa Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 18 Dec 2024 17:27:29 +0100
Subject: [PATCH 15/19] "top_probs" with "post_sampling_probs"

---
 examples/server/README.md | 7 ++++++-
Each element contains: - `id`: token ID - `token`: token in string - `bytes`: token in bytes diff --git a/examples/server/server.cpp b/examples/server/server.cpp index a5ac8db76a7c3..dfa38db063e24 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -443,7 +443,7 @@ struct completion_token_output { std::string text_to_send; struct token_prob { llama_token tok; - std::string tok_str; + std::string txt; float prob; }; std::vector probs; @@ -451,12 +451,12 @@ struct completion_token_output { json to_json(bool post_sampling_probs) const { json probs_for_token = json::array(); for (const auto & p : probs) { - std::string tok_str(p.tok_str); - tok_str.resize(validate_utf8(tok_str)); + std::string txt(p.txt); + txt.resize(validate_utf8(txt)); probs_for_token.push_back(json { {"id", p.tok}, - {"token", tok_str}, - {"bytes", str_to_bytes(p.tok_str)}, + {"token", txt}, + {"bytes", str_to_bytes(p.txt)}, { post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob) @@ -468,20 +468,20 @@ struct completion_token_output { static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { json out = json::array(); - for (const auto & it : probs) { - std::string tok_str(it.text_to_send); - tok_str.resize(validate_utf8(tok_str)); + for (const auto & p : probs) { + std::string txt(p.text_to_send); + txt.resize(validate_utf8(txt)); out.push_back(json { - {"id", it.tok}, - {"token", tok_str}, - {"bytes", str_to_bytes(it.text_to_send)}, + {"id", p.tok}, + {"token", txt}, + {"bytes", str_to_bytes(p.text_to_send)}, { post_sampling_probs ? "top_probs" : "top_logprobs", - it.to_json(post_sampling_probs) + p.to_json(post_sampling_probs) }, { post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? it.prob : logarithm(it.prob) + post_sampling_probs ? 
p.prob : logarithm(p.prob) }, }); } @@ -1958,21 +1958,23 @@ struct server_context { size_t n_probs = slot.params.sampling.n_probs; int n_vocab = llama_n_vocab(llama_get_model(ctx)); if (post_sampling) { - std::vector cur = get_token_probabilities(ctx, idx); + // TODO: optimize this with min-p optimization + const auto * cur_p = common_sampler_get_candidates(slot.smpl); + const size_t max_probs = cur_p->size; bool found_sampled_tok = false; - result.probs.reserve(n_probs); - for (int i = 0; i < n_vocab; i++) { + result.probs.reserve(max_probs); + for (size_t i = 0; i < max_probs; i++) { // set probability for sampled token - if (cur[i].id == result.tok) { + if (cur_p->data[i].id == result.tok) { found_sampled_tok = true; - result.prob = cur[i].p; + result.prob = cur_p->data[i].p; } // set probability for top n_probs tokens result.probs.push_back({ - cur[i].id, - common_detokenize(ctx, {cur[i].id}, special), - cur[i].p + cur_p->data[i].id, + common_detokenize(ctx, {cur_p->data[i].id}, special), + cur_p->data[i].p }); // break if we have all the necessary data if (result.probs.size() == n_probs && found_sampled_tok) { @@ -1980,22 +1982,21 @@ struct server_context { } } } else { - const auto * cur_p = common_sampler_get_candidates(slot.smpl); - const size_t max_probs = cur_p->size; + std::vector cur = get_token_probabilities(ctx, idx); bool found_sampled_tok = false; - result.probs.reserve(max_probs); - for (size_t i = 0; i < max_probs; i++) { + result.probs.reserve(n_probs); + for (int i = 0; i < n_vocab; i++) { // set probability for sampled token - if (cur_p->data[i].id == result.tok) { + if (cur[i].id == result.tok) { found_sampled_tok = true; - result.prob = cur_p->data[i].p; + result.prob = cur[i].p; } // set probability for top n_probs tokens result.probs.push_back({ - cur_p->data[i].id, - common_detokenize(ctx, {cur_p->data[i].id}, special), - cur_p->data[i].p + cur[i].id, + common_detokenize(ctx, {cur[i].id}, special), + cur[i].p }); // break if we have all the necessary data if (result.probs.size() == n_probs && found_sampled_tok) { diff --git a/examples/server/tests/unit/test_completion.py b/examples/server/tests/unit/test_completion.py index 24342b3bbde3a..b88d45f18547f 100644 --- a/examples/server/tests/unit/test_completion.py +++ b/examples/server/tests/unit/test_completion.py @@ -325,7 +325,7 @@ def test_n_probs_post_sampling(): for tok in res.body["completion_probabilities"]: assert "id" in tok and tok["id"] > 0 assert "token" in tok and type(tok["token"]) == str - assert "prob" in tok and 0.0 <= tok["prob"] <= 1.0 + assert "prob" in tok and 0.0 < tok["prob"] <= 1.0 assert "bytes" in tok and type(tok["bytes"]) == list assert len(tok["top_probs"]) == 10 for prob in tok["top_probs"]: @@ -333,3 +333,5 @@ def test_n_probs_post_sampling(): assert "token" in prob and type(prob["token"]) == str assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0 assert "bytes" in prob and type(prob["bytes"]) == list + # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs + assert any(prob["prob"] == 1.0 for prob in tok["top_probs"]) From 65ef1c8dc9538957c1963e5629869ee189dd00cb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 19 Dec 2024 11:51:50 +0100 Subject: [PATCH 17/19] rename struct token_prob to prob_info --- examples/server/server.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index dfa38db063e24..aed1be62e85ce 100644 --- 
a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -441,12 +441,12 @@ struct completion_token_output { llama_token tok; float prob; std::string text_to_send; - struct token_prob { + struct prob_info { llama_token tok; std::string txt; float prob; }; - std::vector probs; + std::vector probs; json to_json(bool post_sampling_probs) const { json probs_for_token = json::array(); From a217382b25713b9299ba7a9a11cbd77aad859100 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 19 Dec 2024 12:02:11 +0100 Subject: [PATCH 18/19] correct comment placement --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index aed1be62e85ce..9b338e1d9af66 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1958,7 +1958,6 @@ struct server_context { size_t n_probs = slot.params.sampling.n_probs; int n_vocab = llama_n_vocab(llama_get_model(ctx)); if (post_sampling) { - // TODO: optimize this with min-p optimization const auto * cur_p = common_sampler_get_candidates(slot.smpl); const size_t max_probs = cur_p->size; @@ -1982,6 +1981,7 @@ struct server_context { } } } else { + // TODO: optimize this with min-p optimization std::vector cur = get_token_probabilities(ctx, idx); bool found_sampled_tok = false; From 5b966df17736037890cb481175d853a5418c2d2b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 19 Dec 2024 14:39:36 +0100 Subject: [PATCH 19/19] fix setting prob for sampled token --- examples/server/server.cpp | 45 ++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 9b338e1d9af66..fa3682a920649 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -475,14 +475,14 @@ struct completion_token_output { {"id", p.tok}, {"token", txt}, {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, { post_sampling_probs ? "prob" : "logprob", post_sampling_probs ? p.prob : logarithm(p.prob) }, + { + post_sampling_probs ? 
"top_probs" : "top_logprobs", + p.to_json(post_sampling_probs) + }, }); } return out; @@ -1956,52 +1956,49 @@ struct server_context { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { size_t n_probs = slot.params.sampling.n_probs; - int n_vocab = llama_n_vocab(llama_get_model(ctx)); + size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); if (post_sampling) { const auto * cur_p = common_sampler_get_candidates(slot.smpl); const size_t max_probs = cur_p->size; - bool found_sampled_tok = false; - result.probs.reserve(max_probs); + // set probability for sampled token for (size_t i = 0; i < max_probs; i++) { - // set probability for sampled token if (cur_p->data[i].id == result.tok) { - found_sampled_tok = true; result.prob = cur_p->data[i].p; + break; } - // set probability for top n_probs tokens + } + + // set probability for top n_probs tokens + result.probs.reserve(max_probs); + for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { result.probs.push_back({ cur_p->data[i].id, common_detokenize(ctx, {cur_p->data[i].id}, special), cur_p->data[i].p }); - // break if we have all the necessary data - if (result.probs.size() == n_probs && found_sampled_tok) { - break; - } } } else { // TODO: optimize this with min-p optimization std::vector cur = get_token_probabilities(ctx, idx); - bool found_sampled_tok = false; - result.probs.reserve(n_probs); - for (int i = 0; i < n_vocab; i++) { + // set probability for sampled token + for (size_t i = 0; i < n_vocab; i++) { // set probability for sampled token if (cur[i].id == result.tok) { - found_sampled_tok = true; result.prob = cur[i].p; + break; } - // set probability for top n_probs tokens + } + + // set probability for top n_probs tokens + result.probs.reserve(n_probs); + for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { result.probs.push_back({ cur[i].id, common_detokenize(ctx, {cur[i].id}, special), cur[i].p }); - // break if we have all the necessary data - if (result.probs.size() == n_probs && found_sampled_tok) { - break; - } } } } @@ -2894,7 +2891,7 @@ struct server_context { completion_token_output result; result.tok = id; result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special); - result.prob = 1.0f; // set later + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs if (slot.params.sampling.n_probs > 0) { populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);