Skip to content

chore: add sys.chat.current #542

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ require (
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
github.com/gptscript-ai/cmd v0.0.0-20240625175447-4250b42feb7d
github.com/gptscript-ai/tui v0.0.0-20240625175717-1e6eca7a66c1
github.com/gptscript-ai/tui v0.0.0-20240627001757-8b452fa47eb5
github.com/hexops/autogold/v2 v2.2.1
github.com/hexops/valast v1.4.4
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
Expand Down Expand Up @@ -94,6 +94,7 @@ require (
github.com/pterm/pterm v0.12.79 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e // indirect
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/therootcompany/xz v1.0.1 // indirect
github.com/tidwall/match v1.1.1 // indirect
Expand Down
6 changes: 4 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ github.com/gptscript-ai/cmd v0.0.0-20240625175447-4250b42feb7d h1:sKf7T7twhGXs6A
github.com/gptscript-ai/cmd v0.0.0-20240625175447-4250b42feb7d/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
github.com/gptscript-ai/go-gptscript v0.0.0-20240625134437-4b83849794cc h1:ABV7VAK65YBkqL7VlNp5ryVXnRqkKQ+U/NZfUO3ypqA=
github.com/gptscript-ai/go-gptscript v0.0.0-20240625134437-4b83849794cc/go.mod h1:Dh6vYRAiVcyC3ElZIGzTvNF1FxtYwA07BHfSiFKQY7s=
github.com/gptscript-ai/tui v0.0.0-20240625175717-1e6eca7a66c1 h1:sx/dJ0IRh3P9Ehr1g1TQ/jEw83KISmQyjrssVgPGUbE=
github.com/gptscript-ai/tui v0.0.0-20240625175717-1e6eca7a66c1/go.mod h1:R33cfOnNaqsEn9es5jLKR39wvDyHvsIVgeTMNqtzCb8=
github.com/gptscript-ai/tui v0.0.0-20240627001757-8b452fa47eb5 h1:knDhTTJNqaZB1XMudXJuVVnTqj9USrXzNfsl1nTqKXA=
github.com/gptscript-ai/tui v0.0.0-20240627001757-8b452fa47eb5/go.mod h1:NwFdBDmGQvjLFFDnSRBRakkhw0MIO1sSdRnWNk4cCQ0=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
Expand Down Expand Up @@ -308,6 +308,8 @@ github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ=
github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e h1:H+jDTUeF+SVd4ApwnSFoew8ZwGNRfgb9EsZc7LcocAg=
github.com/sourcegraph/go-diff-patch v0.0.0-20240223163233-798fd1e94a8e/go.mod h1:VsUklG6OQo7Ctunu0gS3AtEOCEc2kMB6r5rKzxAes58=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
Expand Down
32 changes: 32 additions & 0 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ var SafeTools = map[string]struct{}{
"sys.abort": {},
"sys.chat.finish": {},
"sys.chat.history": {},
"sys.chat.current": {},
"sys.echo": {},
"sys.prompt": {},
"sys.time.now": {},
Expand Down Expand Up @@ -229,6 +230,15 @@ var tools = map[string]types.Tool{
BuiltinFunc: SysChatHistory,
},
},
"sys.chat.current": {
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Description: "Retrieves the current chat dialog",
Arguments: types.ObjectSchema(),
},
BuiltinFunc: SysChatCurrent,
},
},
"sys.context": {
ToolDef: types.ToolDef{
Parameters: types.Parameters{
Expand Down Expand Up @@ -715,6 +725,28 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
return
}

// SysChatCurrent returns the in-progress chat call for the current engine
// context, serialized as JSON. When no engine context is available, or the
// context carries no pending completion state, it returns an empty JSON
// object ("{}") instead of an error.
func SysChatCurrent(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
	var payload any = map[string]any{}

	if engineContext, _ := engine.FromContext(ctx); engineContext != nil {
		if ret := engineContext.CurrentReturn; ret != nil && ret.State != nil {
			payload = engine.ChatHistoryCall{
				ID:         engineContext.ID,
				Tool:       engineContext.Tool,
				Completion: ret.State.Completion,
			}
		}
	}

	out, err := json.Marshal(payload)
	if err != nil {
		return invalidArgument("", err), nil
	}
	return string(out), nil
}

func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
var params struct {
Message string `json:"return,omitempty"`
Expand Down
2 changes: 1 addition & 1 deletion pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
DefaultModel: r.DefaultModel,
TrustedRepoPrefixes: []string{"github.com/gptscript-ai"},
DisableCache: r.DisableCache,
Input: strings.Join(args[1:], " "),
Input: toolInput,
CacheDir: r.CacheDir,
SubTool: r.SubTool,
Workspace: r.Workspace,
Expand Down
37 changes: 28 additions & 9 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type CallResult struct {
type commonContext struct {
ID string `json:"id"`
Tool types.Tool `json:"tool"`
CurrentAgent types.ToolReference `json:"currentAgent,omitempty"`
AgentGroup []types.ToolReference `json:"agentGroup,omitempty"`
InputContext []InputContext `json:"inputContext"`
ToolCategory ToolCategory `json:"toolCategory,omitempty"`
Expand All @@ -73,10 +74,11 @@ type CallContext struct {

type Context struct {
commonContext
Ctx context.Context
Parent *Context
LastReturn *Return
Program *types.Program
Ctx context.Context
Parent *Context
LastReturn *Return
CurrentReturn *Return
Program *types.Program
// Input is saved only so that we can render display text, don't use otherwise
Input string
}
Expand Down Expand Up @@ -129,6 +131,18 @@ func (c *Context) ParentID() string {
return c.Parent.ID
}

// CurrentAgent resolves the ToolReference of the agent currently in control:
// it scans each context's AgentGroup for a reference whose ToolID matches
// that context's own tool, walking up the parent chain until a match is
// found. The zero ToolReference is returned when no ancestor matches.
func (c *Context) CurrentAgent() types.ToolReference {
	for cur := c; cur != nil; cur = cur.Parent {
		for _, ref := range cur.AgentGroup {
			if ref.ToolID == cur.Tool.ID {
				return ref
			}
		}
	}
	return types.ToolReference{}
}

func (c *Context) GetCallContext() *CallContext {
var toolName string
if c.Parent != nil {
Expand All @@ -143,12 +157,15 @@ func (c *Context) GetCallContext() *CallContext {
}
}

return &CallContext{
result := &CallContext{
commonContext: c.commonContext,
ParentID: c.ParentID(),
ToolName: toolName,
DisplayText: types.ToDisplayText(c.Tool, c.Input),
}

result.CurrentAgent = c.CurrentAgent()
return result
}

func (c *Context) UnmarshalJSON([]byte) error {
Expand Down Expand Up @@ -215,10 +232,11 @@ func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID stri
AgentGroup: agentGroup,
ToolCategory: toolCategory,
},
Ctx: ctx,
Parent: c,
Program: c.Program,
Input: input,
Ctx: ctx,
Parent: c,
Program: c.Program,
CurrentReturn: c.CurrentReturn,
Input: input,
}, nil
}

Expand Down Expand Up @@ -270,6 +288,7 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
MaxTokens: tool.Parameters.MaxTokens,
JSONResponse: tool.Parameters.JSONResponse,
Cache: tool.Parameters.Cache,
Chat: tool.Parameters.Chat,
Temperature: tool.Parameters.Temperature,
InternalSystemPrompt: tool.Parameters.InternalPrompt,
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
if messageRequest.Model == "" {
messageRequest.Model = c.defaultModel
}

msgs, err := toMessages(messageRequest, !c.setSeed)
if err != nil {
return nil, err
}

if messageRequest.Chat {
msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs)
}

if len(msgs) == 0 {
log.Errorf("invalid request, no messages to send to LLM")
return &types.CompletionMessage{
Expand Down
57 changes: 57 additions & 0 deletions pkg/openai/count.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package openai

import openai "github.com/gptscript-ai/chat-completion-client"

// dropMessagesOverCount trims a conversation so it roughly fits within
// maxTokens: the leading run of system messages is always kept, and then as
// many of the most recent non-system messages as the budget allows.
// maxTokens == 0 means "no explicit limit" and falls back to a large
// default. Token usage is approximated by countMessage (bytes/3), so the
// budget is a heuristic, not an exact token count.
func dropMessagesOverCount(maxTokens int, msgs []openai.ChatCompletionMessage) (result []openai.ChatCompletionMessage) {
	budget := maxTokens
	if maxTokens == 0 {
		budget = 300_000
	} else {
		// countMessage approximates tokens as bytes/3, so scale the token
		// budget into the same byte-based units.
		budget *= 3
	}

	// Keep the leading system messages unconditionally; they still count
	// against the budget. lastSystem is the index of the last leading system
	// message, or -1 when there are none — starting at -1 (instead of 0)
	// lets the backward scan below reach msgs[0] when no system messages
	// exist, which the previous version silently dropped.
	lastSystem := -1
	for i, msg := range msgs {
		if msg.Role != openai.ChatMessageRoleSystem {
			break
		}
		budget -= countMessage(msg)
		lastSystem = i
		result = append(result, msg)
	}

	// Walk backward from the newest message, finding the oldest non-system
	// message that still fits. Starting withinBudget past the end makes an
	// empty scan (every message is a system message) append nothing extra,
	// instead of re-appending msgs[0:] and duplicating the whole slice.
	withinBudget := len(msgs)
	for i := len(msgs) - 1; i > lastSystem; i-- {
		withinBudget = i
		budget -= countMessage(msgs[i])
		if budget <= 0 {
			break
		}
	}

	if withinBudget == len(msgs)-1 {
		// We are going to drop all non system messages, which seems useless,
		// so just return them all and let it fail
		return msgs
	}

	return append(result, msgs[withinBudget:]...)
}

// countMessage returns a rough token estimate for msg, computed as the total
// byte length of its textual fields (role, content, multi-part text, tool
// call names/arguments, and tool call ID) divided by three.
func countMessage(msg openai.ChatCompletionMessage) int {
	total := len(msg.Role) + len(msg.Content) + len(msg.ToolCallID)
	for _, part := range msg.MultiContent {
		total += len(part.Text)
	}
	for _, call := range msg.ToolCalls {
		total += len(call.Function.Name) + len(call.Function.Arguments)
	}
	return total / 3
}
1 change: 0 additions & 1 deletion pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
for _, outputToolRef := range outputToolRefs {
inputData, err := json.Marshal(map[string]any{
"output": output,
"chatFinish": chatFinish,
"continuation": continuation,
"chat": callCtx.Tool.Chat,
})
Expand Down
2 changes: 2 additions & 0 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,8 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
}

for {
callCtx.CurrentReturn = state.Continuation

if state.Continuation.Result != nil && len(state.Continuation.Calls) == 0 && state.SubCallID == "" && state.ResumeInput == nil {
progressClose()
monitor.Event(Event{
Expand Down
18 changes: 11 additions & 7 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ func TestSubChat(t *testing.T) {
],
"usage": {}
}
]
],
"chat": true
}
},
"result": "Assistant 1"
Expand Down Expand Up @@ -555,7 +556,8 @@ func TestSubChat(t *testing.T) {
],
"usage": {}
}
]
],
"chat": true
}
},
"result": "Assistant 2"
Expand Down Expand Up @@ -622,7 +624,8 @@ func TestChat(t *testing.T) {
],
"usage": {}
}
]
],
"chat": true
}
},
"result": "Assistant 1"
Expand Down Expand Up @@ -691,7 +694,8 @@ func TestChat(t *testing.T) {
],
"usage": {}
}
]
],
"chat": true
}
},
"result": "Assistant 2"
Expand Down Expand Up @@ -866,7 +870,7 @@ func TestOutput(t *testing.T) {
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect(`CHAT: true CONTENT: Response 1 CONTINUATION: true FINISH: false suffix
autogold.Expect(`CHAT: true CONTENT: Response 1 CONTINUATION: true suffix
`).Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))

Expand All @@ -877,7 +881,7 @@ func TestOutput(t *testing.T) {
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect(`CHAT: true CONTENT: Response 2 CONTINUATION: true FINISH: false suffix
autogold.Expect(`CHAT: true CONTENT: Response 2 CONTINUATION: true suffix
`).Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))

Expand All @@ -890,7 +894,7 @@ func TestOutput(t *testing.T) {
require.NoError(t, err)
r.AssertResponded(t)
assert.True(t, resp.Done)
autogold.Expect(`CHAT FINISH: CHAT: true CONTENT: Chat Done CONTINUATION: false FINISH: true suffix
autogold.Expect(`CHAT FINISH: CHAT: true CONTENT: Chat Done CONTINUATION: false suffix
`).Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step3"))
}
Expand Down
3 changes: 2 additions & 1 deletion pkg/tests/testdata/TestAgents/call1.golden
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,6 @@
],
"usage": {}
}
]
],
"chat": true
}`
3 changes: 2 additions & 1 deletion pkg/tests/testdata/TestAgents/call2.golden
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@
],
"usage": {}
}
]
],
"chat": true
}`
3 changes: 2 additions & 1 deletion pkg/tests/testdata/TestAgents/call3.golden
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,6 @@
],
"usage": {}
}
]
],
"chat": true
}`
3 changes: 2 additions & 1 deletion pkg/tests/testdata/TestAgents/call4.golden
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
],
"usage": {}
}
]
],
"chat": true
}`
Loading