From 9a0709ff93f764b1c02f869453cf58d3c5ba8d93 Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Thu, 29 Aug 2024 19:31:50 -0400
Subject: [PATCH 1/4] enhance: avoid context limit

Signed-off-by: Grant Linville
---
 pkg/openai/client.go | 47 ++++++++++++++++++++++++++++++++++++++++++++
 pkg/openai/count.go  | 38 ++++++++++++++++++++++++++---------
 2 files changed, 76 insertions(+), 9 deletions(-)

diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index 42a1a39e..b69950a1 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -2,8 +2,10 @@ package openai
 
 import (
 	"context"
+	"errors"
 	"io"
 	"log/slog"
+	"math"
 	"os"
 	"slices"
 	"sort"
@@ -24,6 +26,7 @@ import (
 const (
 	DefaultModel    = openai.GPT4o
 	BuiltinCredName = "sys.openai"
+	TooLongMessage  = "Error: tool call output is too long"
 )
 
 var (
@@ -317,6 +320,14 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	}
 
 	if messageRequest.Chat {
+		// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
+		lastMessage := msgs[len(msgs)-1]
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(math.Round(float64(getBudget(messageRequest.MaxTokens))*0.8)) {
+			// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
+			msgs[len(msgs)-1].Content = TooLongMessage
+			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
+		}
+
 		msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
 	}
 
@@ -383,6 +394,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 		return nil, err
 	} else if !ok {
 		response, err = c.call(ctx, request, id, status)
+
+		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
+		var apiError *openai.APIError
+		if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+			// Decrease maxTokens by 10% to make garbage collection more aggressive.
+			// The retry loop will further decrease maxTokens if needed.
+			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
+			response, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status)
+		}
+
 		if err != nil {
 			return nil, err
 		}
@@ -421,6 +442,32 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	return &result, nil
 }
 
+func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) ([]openai.ChatCompletionStreamResponse, error) {
+	var (
+		response []openai.ChatCompletionStreamResponse
+		err      error
+	)
+
+	for range 10 { // maximum 10 tries
+		// Try to drop older messages again, with a decreased max tokens.
+		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
+		response, err = c.call(ctx, request, id, status)
+		if err == nil {
+			break
+		}
+
+		var apiError *openai.APIError
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
+			// Decrease maxTokens and try again
+			maxTokens = decreaseTenPercent(maxTokens)
+			continue
+		}
+		return nil, err
+	}
+
+	return response, nil
+}
+
 func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionStreamResponse) types.CompletionMessage {
 	msg.Usage.CompletionTokens = types.FirstSet(msg.Usage.CompletionTokens, response.Usage.CompletionTokens)
 	msg.Usage.PromptTokens = types.FirstSet(msg.Usage.PromptTokens, response.Usage.PromptTokens)
diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index 47c5c9bd..f46790b6 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -1,20 +1,32 @@
 package openai
 
-import openai "github.com/gptscript-ai/chat-completion-client"
+import (
+	"math"
+
+	openai "github.com/gptscript-ai/chat-completion-client"
+)
+
+const DefaultMaxTokens = 128_000
+
+func decreaseTenPercent(maxTokens int) int {
+	maxTokens = getBudget(maxTokens)
+	return int(math.Round(float64(maxTokens) * 0.9))
+}
+
+func getBudget(maxTokens int) int {
+	if maxTokens == 0 {
+		return DefaultMaxTokens
+	}
+	return maxTokens
+}
 
 func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
 	var (
 		lastSystem   int
 		withinBudget int
-		budget       = maxTokens
+		budget       = getBudget(maxTokens)
 	)
 
-	if maxTokens == 0 {
-		budget = 300_000
-	} else {
-		budget *= 3
-	}
-
 	for i, msg := range msgs {
 		if msg.Role == openai.ChatMessageRoleSystem {
 			budget -= countMessage(msg)
@@ -33,7 +45,15 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 		}
 	}
 
-	if withinBudget == len(msgs)-1 {
+	// OpenAI gets upset if there is a tool message without a tool call preceding it.
+	// Check the oldest message within budget, and if it is a tool message, just drop it.
+	// We do this in a loop because it is possible for multiple tool messages to be in a row,
+	// due to parallel tool calls.
+	for withinBudget < len(msgs) && msgs[withinBudget].Role == openai.ChatMessageRoleTool {
+		withinBudget++
+	}
+
+	if withinBudget >= len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
 		return msgs
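For a feel for the numbers in this patch, here is a minimal, illustrative Go sketch (not part of the patch; getBudget and decreaseTenPercent are copied from the new pkg/openai/count.go, and main is a hypothetical driver): with no explicit limit the budget defaults to 128,000 tokens, a lone tool output is rejected once it exceeds 80% of that (102,400 tokens), and each retry shrinks the budget by another 10%.

package main

import (
	"fmt"
	"math"
)

// Copied from the patch's pkg/openai/count.go.
const DefaultMaxTokens = 128_000

func getBudget(maxTokens int) int {
	if maxTokens == 0 {
		return DefaultMaxTokens
	}
	return maxTokens
}

func decreaseTenPercent(maxTokens int) int {
	maxTokens = getBudget(maxTokens)
	return int(math.Round(float64(maxTokens) * 0.9))
}

func main() {
	// A lone tool message is rejected if it alone exceeds 80% of the budget.
	fmt.Println(int(math.Round(float64(getBudget(0)) * 0.8))) // 102400

	// Each context_length_exceeded retry shrinks the budget by 10%.
	budget := getBudget(0)
	for range 3 {
		budget = decreaseTenPercent(budget)
		fmt.Println(budget) // 115200, then 103680, then 93312
	}
}
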
From 92876e8c062e5edd606f63bed27a16ac6d34de74 Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Thu, 29 Aug 2024 19:37:07 -0400
Subject: [PATCH 2/4] fix retry loop

Signed-off-by: Grant Linville
---
 pkg/openai/client.go | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index b69950a1..a5f77001 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -453,7 +453,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC
 		request.Messages = dropMessagesOverCount(maxTokens, request.Messages)
 		response, err = c.call(ctx, request, id, status)
 		if err == nil {
-			break
+			return response, nil
 		}
 
 		var apiError *openai.APIError
 		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" {
 			// Decrease maxTokens and try again
 			maxTokens = decreaseTenPercent(maxTokens)
 			continue
 		}
 		return nil, err
 	}
 
-	return response, nil
+	return nil, err
 }
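The intent of this fix, distilled into a standalone, hypothetical sketch (invented names, not code from the repo): return as soon as a call succeeds, keep shrinking only on the one retryable error, and surface the last error once the tries are exhausted, rather than falling out of the loop with a stale response.

package main

import (
	"errors"
	"fmt"
)

var errContextLength = errors.New("context_length_exceeded")

// retryShrinking mirrors the corrected loop's shape: attempt once per
// iteration, return on success, shrink and retry on the one retryable
// error, and report the last error if every try fails.
func retryShrinking(budget int, attempt func(int) (string, error)) (string, error) {
	var err error
	for range 10 { // maximum 10 tries, as in the patch
		var resp string
		resp, err = attempt(budget)
		if err == nil {
			return resp, nil // patch 2: return directly instead of break
		}
		if !errors.Is(err, errContextLength) {
			return "", err // not retryable
		}
		budget = budget * 9 / 10 // stand-in for decreaseTenPercent
	}
	return "", err // patch 2: the last error, not a stale response
}

func main() {
	resp, err := retryShrinking(128, func(budget int) (string, error) {
		if budget > 100 {
			return "", errContextLength
		}
		return "ok", nil
	})
	fmt.Println(resp, err) // ok <nil>
}
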
From 30f24a43b072641621625c79c222e1559126cc82 Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Thu, 29 Aug 2024 19:39:01 -0400
Subject: [PATCH 3/4] fix

Signed-off-by: Grant Linville
---
 pkg/openai/count.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index f46790b6..1bae22dc 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -53,7 +53,7 @@ func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (
 		withinBudget++
 	}
 
-	if withinBudget >= len(msgs)-1 {
+	if withinBudget == len(msgs)-1 {
 		// We are going to drop all non system messages, which seems useless, so just return them
 		// all and let it fail
 		return msgs
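For intuition about the loop this one-line change guards, here is a toy, self-contained example with invented data (not code from the repo), showing how skipping leading tool messages avoids sending a tool result whose originating tool call was dropped:

package main

import "fmt"

func main() {
	// Invented chat history; index 0 is oldest.
	roles := []string{"system", "user", "assistant", "tool", "tool", "user"}

	// Suppose the budget pass decided everything before index 3 must go.
	withinBudget := 3

	// The loop from patch 1: skip leading tool messages so the first kept
	// message is never a tool result whose tool call was dropped.
	for withinBudget < len(roles) && roles[withinBudget] == "tool" {
		withinBudget++
	}

	fmt.Println(roles[withinBudget:]) // [user]
}
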
From 53b96712a418e7637b1a14acafbf93d1cf204ad8 Mon Sep 17 00:00:00 2001
From: Grant Linville
Date: Fri, 30 Aug 2024 09:27:20 -0400
Subject: [PATCH 4/4] PR feedback

Signed-off-by: Grant Linville
---
 pkg/openai/client.go | 5 ++---
 pkg/openai/count.go  | 4 +---
 2 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/pkg/openai/client.go b/pkg/openai/client.go
index a5f77001..61a7ec77 100644
--- a/pkg/openai/client.go
+++ b/pkg/openai/client.go
@@ -5,7 +5,6 @@ import (
 	"errors"
 	"io"
 	"log/slog"
-	"math"
 	"os"
 	"slices"
 	"sort"
@@ -322,7 +321,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 	if messageRequest.Chat {
 		// Check the last message. If it is from a tool call, and if it takes up more than 80% of the budget on its own, reject it.
 		lastMessage := msgs[len(msgs)-1]
-		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(math.Round(float64(getBudget(messageRequest.MaxTokens))*0.8)) {
+		if lastMessage.Role == string(types.CompletionMessageRoleTypeTool) && countMessage(lastMessage) > int(float64(getBudget(messageRequest.MaxTokens))*0.8) {
 			// We need to update it in the msgs slice for right now and in the messageRequest for future calls.
 			msgs[len(msgs)-1].Content = TooLongMessage
 			messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage)
@@ -397,7 +396,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
 
 		// If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass.
 		var apiError *openai.APIError
-		if err != nil && errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
+		if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat {
 			// Decrease maxTokens by 10% to make garbage collection more aggressive.
 			// The retry loop will further decrease maxTokens if needed.
 			maxTokens := decreaseTenPercent(messageRequest.MaxTokens)
diff --git a/pkg/openai/count.go b/pkg/openai/count.go
index 1bae22dc..ffd902e5 100644
--- a/pkg/openai/count.go
+++ b/pkg/openai/count.go
@@ -1,8 +1,6 @@
 package openai
 
 import (
-	"math"
-
 	openai "github.com/gptscript-ai/chat-completion-client"
 )
 
 const DefaultMaxTokens = 128_000
 
 func decreaseTenPercent(maxTokens int) int {
 	maxTokens = getBudget(maxTokens)
-	return int(math.Round(float64(maxTokens) * 0.9))
+	return int(float64(maxTokens) * 0.9)
 }
 
 func getBudget(maxTokens int) int {
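Dropping math.Round here swaps rounding for truncation, which can differ by at most one token; a quick illustrative check (hypothetical input value, not code from the repo):

package main

import (
	"fmt"
	"math"
)

func main() {
	maxTokens := 1001 // hypothetical budget where rounding and truncation differ
	fmt.Println(int(math.Round(float64(maxTokens) * 0.9))) // 901 (old behavior)
	fmt.Println(int(float64(maxTokens) * 0.9))             // 900 (new behavior)
}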