
Commit c82742a

llama : add llama_beam_search() (#2267)
* Add llama_beam_search().
* Add '// Beam search' heading to llama.{h,cpp} after llama_grammar_accept_token().
* Add space around * pointers and & references.
* Add spaces around comparison and assignment operators.
* Prefer west const.
* Use llama_ prefix for structs in global namespace.
* Delete obsolete comment from an earlier revision.
* Change eos to eob in llama_beam and llama_beam_view structs.
1 parent 28b2c99 commit c82742a
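The changes to llama.h and llama.cpp (two of the seven files) are not reproduced on this page. As rough orientation only, here is a minimal C++ sketch of the callback-driven API surface as it can be inferred from how beam_search.cpp and server.cpp below use it; field order, exact types, and the real declarations in llama.h may differ.

// Sketch only: reconstructed from usage in the diffs below, not copied from llama.h.
// Read/write view of a single beam handed to the callback.
struct llama_beam_view {
    llama_token * tokens;   // tokens accumulated by this beam so far
    size_t n_tokens;        // number of entries in tokens
    float p;                // cumulative probability of this beam
    bool eob;               // end-of-beam; the callback may set this to stop extending the beam
};

// Passed to the callback each time the beam lengths increase, and once more when the stop condition is met.
struct llama_beams_state {
    llama_beam_view * beam_views;  // one view per beam
    size_t n_beams;                // number of elements in beam_views
    size_t common_prefix_length;   // tokens currently shared by all beams; safe to collect now
    bool last_call;                // true on the final invocation
};

// Callback type (name taken from the comments in the examples below).
typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state beams_state);

// Entry point, matching the call sites below: beam search of width n_beams for up to
// n_predict new tokens, starting after n_past evaluated tokens.
void llama_beam_search(llama_context * ctx,
                       llama_beam_search_callback_fn_t callback, void * callback_data,
                       size_t n_beams, int n_past, int n_predict, int n_threads);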

7 files changed: 563 additions, 13 deletions

common/common.h
Lines changed: 1 addition & 0 deletions

@@ -28,6 +28,7 @@ struct gpt_params {
     int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
     float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
+    int32_t n_beams = 0; // if non-zero then use beam search of given width.
     float rope_freq_base = 10000.0f; // RoPE base frequency
     float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

examples/CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ else()
     add_subdirectory(simple)
     add_subdirectory(embd-input)
     add_subdirectory(llama-bench)
+    add_subdirectory(beam_search)
     if (LLAMA_METAL)
         add_subdirectory(metal)
     endif()

examples/beam_search/CMakeLists.txt
Lines changed: 8 additions & 0 deletions (new file)

@@ -0,0 +1,8 @@
set(TARGET beam_search)
add_executable(${TARGET} beam_search.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()

examples/beam_search/beam_search.cpp
Lines changed: 188 additions & 0 deletions (new file)

@@ -0,0 +1,188 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif

#include "common.h"
#include "llama.h"
#include "build-info.h"

#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif

// Used for debugging to print out beam tokens.
struct ostream_beam_view {
    llama_context * ctx;
    llama_beam_view beam_view;
};
std::ostream& operator<<(std::ostream& os, const ostream_beam_view & obv) {
    os << "p(" << obv.beam_view.p << ") eob(" << std::boolalpha << obv.beam_view.eob << ") tokens(";
    for (size_t i = 0 ; i < obv.beam_view.n_tokens ; ++i) {
        os << llama_token_to_str(obv.ctx, obv.beam_view.tokens[i]);
    }
    return os << ')';
}

// Put here anything you want back in beam_search_callback().
struct beam_search_callback_data {
    llama_context * ctx;
    std::vector<llama_token> response;
};

// In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
// For example, eob can be flagged due to maximum token length, stop words, etc.
bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, const size_t n_tokens) {
    return n_tokens && tokens[n_tokens-1] == llama_token_eos(callback_data.ctx);
}

// Function matching type llama_beam_search_callback_fn_t.
// Custom callback example is called each time the beam lengths increase:
// * Show progress by printing ',' followed by number of convergent beam tokens if any.
// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
// This is also called when the stop condition is met.
// Collect tokens into std::vector<llama_token> response which is pointed to by callback_data.
void beam_search_callback(void * callback_data_ptr, llama_beams_state beams_state) {
    auto& callback_data = *static_cast<beam_search_callback_data*>(callback_data_ptr);
    // Mark beams as EOS as needed.
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        llama_beam_view& beam_view = beams_state.beam_views[i];
        if (!beam_view.eob && is_at_eob(callback_data, beam_view.tokens, beam_view.n_tokens)) {
            beam_view.eob = true;
        }
    }
    printf(","); // Show progress
    if (const size_t n = beams_state.common_prefix_length) {
        callback_data.response.resize(callback_data.response.size() + n);
        assert(0u < beams_state.n_beams);
        const llama_token * tokens = beams_state.beam_views[0].tokens;
        std::copy(tokens, tokens + n, callback_data.response.end() - n);
        printf("%lu", n);
    }
    fflush(stdout);
#if 1 // DEBUG: print current beams for this iteration
    std::cout << "\n\nCurrent beams (last_call=" << beams_state.last_call << "):\n";
    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
        std::cout << "beams["<<i<<"]: " << ostream_beam_view{callback_data.ctx,beams_state.beam_views[i]} << std::endl;
    }
#endif
}

int main(int argc, char ** argv)
{
    gpt_params params;
    //params.n_gpu_layers = 200;

    //---------------------------------
    // Print help :
    //---------------------------------

    if ( argc < 2 || argv[1][0] == '-' )
    {
        printf( "Usage: %s MODEL_PATH [BEAM_WIDTH=2] [PROMPT]\n" , argv[0] );
        return 1 ;
    }

    //---------------------------------
    // Load parameters :
    //---------------------------------

    params.model = argv[1];

    params.n_beams = 2 < argc ? std::stoi(argv[2]) : 2;

    if ( argc > 3 )
    {
        params.prompt = argv[3];
    }

    if ( params.prompt.empty() )
    {
        params.prompt = "### Request:\nHow many countries are there?\n\n### Response:\n";
    }

    //---------------------------------
    // Init LLM :
    //---------------------------------

    llama_backend_init(params.numa);

    llama_model * model;
    llama_context * ctx;

    std::tie(model, ctx) = llama_init_from_gpt_params( params );

    if ( model == NULL )
    {
        fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
        return 1;
    }

    //---------------------------------
    // Tokenize the prompt :
    //---------------------------------

    std::vector<llama_token> tokens_list = llama_tokenize(ctx, params.prompt, true);

    const size_t max_context_size     = llama_n_ctx( ctx );
    const size_t max_tokens_list_size = max_context_size - 4 ;

    if (tokens_list.size() > max_tokens_list_size)
    {
        fprintf( stderr , "%s: error: prompt too long (%lu tokens, max %lu)\n" ,
                 __func__ , tokens_list.size() , max_tokens_list_size );
        return 1;
    }

    fprintf( stderr, "\n\n" );

    // Print the tokens from the prompt :

    for( auto id : tokens_list )
    {
        std::cout << llama_token_to_str(ctx, id);
    }
    std::cout << std::flush;

    int n_past = llama_get_kv_cache_token_count(ctx);
    if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
    {
        fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
        return 1;
    }
    n_past += tokens_list.size();

    beam_search_callback_data callback_data{ctx, {}};
    size_t const beam_width = static_cast<size_t>(params.n_beams);
    int const n_predict = 256;
    llama_beam_search(ctx, beam_search_callback, &callback_data, beam_width, n_past, n_predict, params.n_threads);

    std::cout << "\n\n";
    for (llama_token const token_id : callback_data.response) {
        std::cout << llama_token_to_str(ctx,token_id);
    }
    std::cout << std::endl;

    llama_free( ctx );
    llama_free_model( model );

    llama_backend_free();

    return 0;
}
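The comment above is_at_eob() notes that end-of-beam need not coincide with the EOS token; it can also be triggered by a maximum length, stop words, and so on. As an illustration only (none of this is part of the commit), a predicate along those lines could look like the sketch below. The max_tokens and stop_word fields are hypothetical additions to the example's callback data, and the includes from beam_search.cpp are assumed.

// Hypothetical extension of beam_search_callback_data with extra stopping criteria.
struct beam_search_callback_data_ex {
    llama_context * ctx;
    std::vector<llama_token> response;
    size_t max_tokens;      // illustrative cap on beam length
    std::string stop_word;  // illustrative stop string
};

// Flag end-of-beam on EOS, on reaching max_tokens, or when the detokenized tail ends with stop_word.
bool is_at_eob_ex(const beam_search_callback_data_ex & cd, const llama_token * tokens, const size_t n_tokens) {
    if (n_tokens == 0) {
        return false;
    }
    if (tokens[n_tokens - 1] == llama_token_eos(cd.ctx)) {
        return true;
    }
    if (cd.max_tokens && n_tokens >= cd.max_tokens) {
        return true;
    }
    if (cd.stop_word.empty()) {
        return false;
    }
    // Detokenize only the last few tokens and check whether the text ends with the stop word.
    std::string tail;
    for (size_t i = n_tokens > 8 ? n_tokens - 8 : 0; i < n_tokens; ++i) {
        tail += llama_token_to_str(cd.ctx, tokens[i]);
    }
    return tail.size() >= cd.stop_word.size() &&
           tail.compare(tail.size() - cd.stop_word.size(), cd.stop_word.size(), cd.stop_word) == 0;
}

A callback using this predicate would set beam_view.eob exactly as beam_search_callback() does above.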

examples/server/server.cpp
Lines changed: 77 additions & 13 deletions

@@ -1209,6 +1209,62 @@ static void log_server_request(const Request &req, const Response &res)
     });
 }
 
+bool is_at_eob(llama_server_context & server_context, const llama_token * tokens, const size_t n_tokens) {
+    return n_tokens && tokens[n_tokens-1] == llama_token_eos(server_context.ctx);
+}
+
+// Function matching type llama_beam_search_callback_fn_t.
+// Custom callback example is called each time the beam lengths increase:
+// * Show progress by printing ',' followed by number of convergent beam tokens if any.
+// * When all beams converge to a common prefix, they are made available in beams_state.beams[0].
+// This is also called when the stop condition is met.
+// Collect tokens into llama.generated_token_probs, owned by the llama_server_context passed via callback_data.
+void beam_search_callback(void * callback_data, llama_beams_state beams_state) {
+    auto & llama = *static_cast<llama_server_context*>(callback_data);
+    // Mark beams as EOS as needed.
+    for (size_t i = 0 ; i < beams_state.n_beams ; ++i) {
+        llama_beam_view& beam_view = beams_state.beam_views[i];
+        if (!beam_view.eob && is_at_eob(llama, beam_view.tokens, beam_view.n_tokens)) {
+            beam_view.eob = true;
+        }
+    }
+    printf(","); // Show progress
+    if (const size_t n = beams_state.common_prefix_length) {
+        llama.generated_token_probs.resize(llama.generated_token_probs.size() + n);
+        assert(0u < beams_state.n_beams);
+        const llama_token * tokens = beams_state.beam_views[0].tokens;
+        const auto map = [](llama_token tok) { return completion_token_output{{},tok}; };
+        std::transform(tokens, tokens + n, llama.generated_token_probs.end() - n, map);
+        printf("%lu", n);
+    }
+    fflush(stdout);
+#if 0 // DEBUG: print current beams for this iteration
+    std::cout << "\n\nCurrent beams:\n";
+    for (size_t i=0 ; i < beams_state.n_beams ; ++i) {
+        std::cout << "beams["<<i<<"]: " << ostream_beam_view{state.ctx,beams_state.beam_views[i]} << std::endl;
+    }
+#endif
+}
+
+struct token_translator {
+    llama_context * ctx;
+    std::string operator()(llama_token tok) const { return llama_token_to_str(ctx, tok); }
+    std::string operator()(completion_token_output cto) const { return (*this)(cto.tok); }
+};
+
+void append_to_generated_text_from_generated_token_probs(llama_server_context & llama) {
+    auto & gtps = llama.generated_token_probs;
+    auto translator = token_translator{llama.ctx};
+    auto add_strlen = [=](size_t sum, const completion_token_output & cto) { return sum + translator(cto).size(); };
+    const size_t len = std::accumulate(gtps.begin(), gtps.end(), size_t(0), add_strlen);
+    if (llama.generated_text.capacity() < llama.generated_text.size() + len) {
+        llama.generated_text.reserve(llama.generated_text.size() + len);
+    }
+    for (const completion_token_output & cto : gtps) {
+        llama.generated_text += translator(cto);
+    }
+}
+
 int main(int argc, char **argv)
 {
     // own arguments required by this example

@@ -1291,22 +1347,30 @@ int main(int argc, char **argv)
         llama.beginCompletion();
 
         if (!llama.stream) {
-            size_t stop_pos = std::string::npos;
+            if (llama.params.n_beams) {
+                // Fill llama.generated_token_probs vector with final beam.
+                llama_beam_search(llama.ctx, beam_search_callback, &llama, llama.params.n_beams,
+                                  llama.n_past, llama.n_remain, llama.params.n_threads);
+                // Translate llama.generated_token_probs to llama.generated_text.
+                append_to_generated_text_from_generated_token_probs(llama);
+            } else {
+                size_t stop_pos = std::string::npos;
 
-            while (llama.has_next_token) {
-                const completion_token_output token_with_probs = llama.doCompletion();
-                const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
+                while (llama.has_next_token) {
+                    const completion_token_output token_with_probs = llama.doCompletion();
+                    const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok);
 
-                stop_pos = llama.findStoppingStrings(llama.generated_text,
-                                                     token_text.size(), STOP_FULL);
-            }
+                    stop_pos = llama.findStoppingStrings(llama.generated_text,
+                                                         token_text.size(), STOP_FULL);
+                }
 
-            if (stop_pos == std::string::npos) {
-                stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
-            }
-            if (stop_pos != std::string::npos) {
-                llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
-                                           llama.generated_text.end());
+                if (stop_pos == std::string::npos) {
+                    stop_pos = llama.findStoppingStrings(llama.generated_text, 0, STOP_PARTIAL);
+                }
+                if (stop_pos != std::string::npos) {
+                    llama.generated_text.erase(llama.generated_text.begin() + stop_pos,
+                                               llama.generated_text.end());
+                }
             }
 
             const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
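token_translator is a small function object with overloads for both llama_token and completion_token_output, so the same translator can be used when sizing and when concatenating the generated text. A minimal usage sketch (not part of the commit, assuming server.cpp's types are in scope):

// Flatten beam-search output to plain text using the token_translator added above.
static std::string beam_tokens_to_text(llama_server_context & llama) {
    token_translator translator{llama.ctx};
    std::string text;
    for (const completion_token_output & cto : llama.generated_token_probs) {
        text += translator(cto); // the completion_token_output overload forwards to the llama_token overload
    }
    // Same content that append_to_generated_text_from_generated_token_probs() appends to
    // llama.generated_text, minus the up-front reserve().
    return text;
}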
