diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index a76a3556..bea8439a 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -99,6 +99,7 @@ const ( CredentialToolCategory ToolCategory = "credential" ContextToolCategory ToolCategory = "context" InputToolCategory ToolCategory = "input" + OutputToolCategory ToolCategory = "output" NoCategory ToolCategory = "" ) diff --git a/pkg/parser/parser.go b/pkg/parser/parser.go index e72dc32f..b998320b 100644 --- a/pkg/parser/parser.go +++ b/pkg/parser/parser.go @@ -109,6 +109,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) { tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...) case "shareinputfilter", "shareinputfilters": tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...) + case "outputfilter", "outputfilters": + tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(value)...) + case "shareoutputfilter", "shareoutputfilters": + tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(value)...) case "agent", "agents": tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...) case "globaltool", "globaltools": @@ -194,6 +198,7 @@ func (c *context) finish(tools *[]Node) { c.tool.GlobalModelName != "" || len(c.tool.GlobalTools) > 0 || len(c.tool.ExportInputFilters) > 0 || + len(c.tool.ExportOutputFilters) > 0 || c.tool.Chat { *tools = append(*tools, Node{ ToolNode: &ToolNode{ diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index fb5a3ab7..9f682efa 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -215,3 +215,27 @@ share input filters: shared }}, }}).Equal(t, out) } + +func TestParseOutput(t *testing.T) { + output := ` +output filters: output +share output filters: shared +` + out, err := Parse(strings.NewReader(output)) + require.NoError(t, err) + autogold.Expect(Document{Nodes: []Node{ + {ToolNode: &ToolNode{ + Tool: types.Tool{ + ToolDef: types.ToolDef{ + Parameters: types.Parameters{ + OutputFilters: []string{ + "output", + }, + ExportOutputFilters: []string{"shared"}, + }, + }, + Source: types.ToolSource{LineNo: 1}, + }, + }}, + }}).Equal(t, out) +} diff --git a/pkg/runner/input.go b/pkg/runner/input.go index 0d8cb7f0..7d77330e 100644 --- a/pkg/runner/input.go +++ b/pkg/runner/input.go @@ -1,6 +1,7 @@ package runner import ( + "encoding/json" "fmt" "github.com/gptscript-ai/gptscript/pkg/engine" @@ -13,7 +14,13 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri } for _, inputToolRef := range inputToolRefs { - res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, input, "", engine.InputToolCategory) + inputData, err := json.Marshal(map[string]any{ + "input": input, + }) + if err != nil { + return "", fmt.Errorf("failed to marshal input: %w", err) + } + res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory) if err != nil { return "", err } diff --git a/pkg/runner/output.go b/pkg/runner/output.go new file mode 100644 index 00000000..858d106c --- /dev/null +++ b/pkg/runner/output.go @@ -0,0 +1,72 @@ +package runner + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/gptscript-ai/gptscript/pkg/engine" +) + +func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) { + outputToolRefs, err := callCtx.Tool.GetOutputFilterTools(*callCtx.Program) + if err != nil { + return nil, err + } + + if len(outputToolRefs) == 0 { + return state, retErr + } + + var ( + continuation bool + chatFinish bool + output string + ) + + if errMessage := (*engine.ErrChatFinish)(nil); errors.As(retErr, &errMessage) && callCtx.Tool.Chat { + chatFinish = true + output = errMessage.Message + } else if retErr != nil { + return state, retErr + } else if state.Continuation != nil && state.Continuation.Result != nil { + continuation = true + output = *state.Continuation.Result + } else if state.Result != nil { + output = *state.Result + } else { + return state, nil + } + + for _, outputToolRef := range outputToolRefs { + inputData, err := json.Marshal(map[string]any{ + "output": output, + "chatFinish": chatFinish, + "continuation": continuation, + "chat": callCtx.Tool.Chat, + }) + if err != nil { + return nil, fmt.Errorf("marshaling input for output filter: %w", err) + } + res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory) + if err != nil { + return nil, err + } + if res.Result == nil { + return nil, fmt.Errorf("invalid state: output tool [%s] can not result in a chat continuation", outputToolRef.Reference) + } + output = *res.Result + } + + if chatFinish { + return state, &engine.ErrChatFinish{ + Message: output, + } + } else if continuation { + state.Continuation.Result = &output + } else { + state.Result = &output + } + + return state, nil +} diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index 7df697ff..fb2cba0d 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -536,7 +536,11 @@ type Needed struct { Input string `json:"input,omitempty"` } -func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (*State, error) { +func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) { + defer func() { + retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr) + }() + if state.StartContinuation { return nil, fmt.Errorf("invalid state, resume should not have StartContinuation set to true") } diff --git a/pkg/tests/runner_test.go b/pkg/tests/runner_test.go index 1421e849..db185d75 100644 --- a/pkg/tests/runner_test.go +++ b/pkg/tests/runner_test.go @@ -849,6 +849,52 @@ func TestInput(t *testing.T) { autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2")) } +func TestOutput(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip() + } + + r := tester.NewRunner(t) + r.RespondWith(tester.Result{ + Text: "Response 1", + }) + + prg, err := r.Load("") + require.NoError(t, err) + + resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1") + require.NoError(t, err) + r.AssertResponded(t) + assert.False(t, resp.Done) + autogold.Expect(`CHAT: true CONTENT: Response 1 CONTINUATION: true FINISH: false suffix +`).Equal(t, resp.Content) + autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1")) + + r.RespondWith(tester.Result{ + Text: "Response 2", + }) + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 2") + require.NoError(t, err) + r.AssertResponded(t) + assert.False(t, resp.Done) + autogold.Expect(`CHAT: true CONTENT: Response 2 CONTINUATION: true FINISH: false suffix +`).Equal(t, resp.Content) + autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2")) + + r.RespondWith(tester.Result{ + Err: &engine.ErrChatFinish{ + Message: "Chat Done", + }, + }) + resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 3") + require.NoError(t, err) + r.AssertResponded(t) + assert.True(t, resp.Done) + autogold.Expect(`CHAT FINISH: CHAT: true CONTENT: Chat Done CONTINUATION: false FINISH: true suffix +`).Equal(t, resp.Content) + autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step3")) +} + func TestSysContext(t *testing.T) { if runtime.GOOS == "windows" { t.Skip() diff --git a/pkg/tests/testdata/TestInput/test.gpt b/pkg/tests/testdata/TestInput/test.gpt index bcb85a43..79522d90 100644 --- a/pkg/tests/testdata/TestInput/test.gpt +++ b/pkg/tests/testdata/TestInput/test.gpt @@ -7,9 +7,10 @@ Tool body --- name: taunt args: foo: this is useless +args: input: this is used #!/bin/bash -echo "No, ${GPTSCRIPT_INPUT}!" +echo "No, ${INPUT}!" --- name: exporter @@ -18,6 +19,7 @@ share input filters: taunt2 --- name: taunt2 args: foo: this is useless +args: input: this is used #!/bin/bash -echo "${GPTSCRIPT_INPUT} ha ha ha" \ No newline at end of file +echo "${INPUT} ha ha ha" \ No newline at end of file diff --git a/pkg/tests/testdata/TestOutput/call1-resp.golden b/pkg/tests/testdata/TestOutput/call1-resp.golden new file mode 100644 index 00000000..a6c5b94a --- /dev/null +++ b/pkg/tests/testdata/TestOutput/call1-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "Response 1" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestOutput/call1.golden b/pkg/tests/testdata/TestOutput/call1.golden new file mode 100644 index 00000000..9430afee --- /dev/null +++ b/pkg/tests/testdata/TestOutput/call1.golden @@ -0,0 +1,24 @@ +`{ + "model": "gpt-4o", + "internalSystemPrompt": false, + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nTool body" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestOutput/call2-resp.golden b/pkg/tests/testdata/TestOutput/call2-resp.golden new file mode 100644 index 00000000..e5170fb8 --- /dev/null +++ b/pkg/tests/testdata/TestOutput/call2-resp.golden @@ -0,0 +1,9 @@ +`{ + "role": "assistant", + "content": [ + { + "text": "Response 2" + } + ], + "usage": {} +}` diff --git a/pkg/tests/testdata/TestOutput/call2.golden b/pkg/tests/testdata/TestOutput/call2.golden new file mode 100644 index 00000000..32bb7039 --- /dev/null +++ b/pkg/tests/testdata/TestOutput/call2.golden @@ -0,0 +1,42 @@ +`{ + "model": "gpt-4o", + "internalSystemPrompt": false, + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nTool body" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 1" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 2" + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestOutput/call3.golden b/pkg/tests/testdata/TestOutput/call3.golden new file mode 100644 index 00000000..01aed5eb --- /dev/null +++ b/pkg/tests/testdata/TestOutput/call3.golden @@ -0,0 +1,60 @@ +`{ + "model": "gpt-4o", + "internalSystemPrompt": false, + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nTool body" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 1" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 2" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 2" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 3" + } + ], + "usage": {} + } + ] +}` diff --git a/pkg/tests/testdata/TestOutput/step1.golden b/pkg/tests/testdata/TestOutput/step1.golden new file mode 100644 index 00000000..46f1b8e8 --- /dev/null +++ b/pkg/tests/testdata/TestOutput/step1.golden @@ -0,0 +1,47 @@ +`{ + "done": false, + "content": "CHAT: true CONTENT: Response 1 CONTINUATION: true FINISH: false suffix\n", + "toolID": "testdata/TestOutput/test.gpt:", + "state": { + "continuation": { + "state": { + "input": "Input 1", + "completion": { + "model": "gpt-4o", + "internalSystemPrompt": false, + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nTool body" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 1" + } + ], + "usage": {} + } + ] + } + }, + "result": "CHAT: true CONTENT: Response 1 CONTINUATION: true FINISH: false suffix\n" + }, + "continuationToolID": "testdata/TestOutput/test.gpt:" + } +}` diff --git a/pkg/tests/testdata/TestOutput/step2.golden b/pkg/tests/testdata/TestOutput/step2.golden new file mode 100644 index 00000000..d5fd89a0 --- /dev/null +++ b/pkg/tests/testdata/TestOutput/step2.golden @@ -0,0 +1,65 @@ +`{ + "done": false, + "content": "CHAT: true CONTENT: Response 2 CONTINUATION: true FINISH: false suffix\n", + "toolID": "testdata/TestOutput/test.gpt:", + "state": { + "continuation": { + "state": { + "input": "Input 1", + "completion": { + "model": "gpt-4o", + "internalSystemPrompt": false, + "messages": [ + { + "role": "system", + "content": [ + { + "text": "\nTool body" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 1" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 1" + } + ], + "usage": {} + }, + { + "role": "user", + "content": [ + { + "text": "Input 2" + } + ], + "usage": {} + }, + { + "role": "assistant", + "content": [ + { + "text": "Response 2" + } + ], + "usage": {} + } + ] + } + }, + "result": "CHAT: true CONTENT: Response 2 CONTINUATION: true FINISH: false suffix\n" + }, + "continuationToolID": "testdata/TestOutput/test.gpt:" + } +}` diff --git a/pkg/tests/testdata/TestOutput/step3.golden b/pkg/tests/testdata/TestOutput/step3.golden new file mode 100644 index 00000000..c4e63adc --- /dev/null +++ b/pkg/tests/testdata/TestOutput/step3.golden @@ -0,0 +1,6 @@ +`{ + "done": true, + "content": "CHAT FINISH: CHAT: true CONTENT: Chat Done CONTINUATION: false FINISH: true suffix\n", + "toolID": "", + "state": null +}` diff --git a/pkg/tests/testdata/TestOutput/test.gpt b/pkg/tests/testdata/TestOutput/test.gpt new file mode 100644 index 00000000..cc35faa0 --- /dev/null +++ b/pkg/tests/testdata/TestOutput/test.gpt @@ -0,0 +1,31 @@ +output filter: prefix +context: context +chat: true + +Tool body + +--- +name: context +share output filters: suffix + +--- +name: prefix +args: chat: is it chat +args: output: the output content +args: continuation: if this is a non-terminating response +args: chatFinish: chat finish message + +#!/bin/bash + +echo CHAT: ${CHAT} +echo CONTENT: ${OUTPUT} +echo CONTINUATION: ${CONTINUATION} +echo FINISH: ${CHATFINISH} + +--- +name: suffix +args: output: the output content + +#!/bin/bash + +echo ${OUTPUT} suffix \ No newline at end of file diff --git a/pkg/types/tool.go b/pkg/types/tool.go index dd89c471..a5124796 100644 --- a/pkg/types/tool.go +++ b/pkg/types/tool.go @@ -120,28 +120,30 @@ func (p Program) SetBlocking() Program { type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error) type Parameters struct { - Name string `json:"name,omitempty"` - Description string `json:"description,omitempty"` - MaxTokens int `json:"maxTokens,omitempty"` - ModelName string `json:"modelName,omitempty"` - ModelProvider bool `json:"modelProvider,omitempty"` - JSONResponse bool `json:"jsonResponse,omitempty"` - Chat bool `json:"chat,omitempty"` - Temperature *float32 `json:"temperature,omitempty"` - Cache *bool `json:"cache,omitempty"` - InternalPrompt *bool `json:"internalPrompt"` - Arguments *openapi3.Schema `json:"arguments,omitempty"` - Tools []string `json:"tools,omitempty"` - GlobalTools []string `json:"globalTools,omitempty"` - GlobalModelName string `json:"globalModelName,omitempty"` - Context []string `json:"context,omitempty"` - ExportContext []string `json:"exportContext,omitempty"` - Export []string `json:"export,omitempty"` - Agents []string `json:"agents,omitempty"` - Credentials []string `json:"credentials,omitempty"` - InputFilters []string `json:"inputFilters,omitempty"` - ExportInputFilters []string `json:"exportInputFilters,omitempty"` - Blocking bool `json:"-"` + Name string `json:"name,omitempty"` + Description string `json:"description,omitempty"` + MaxTokens int `json:"maxTokens,omitempty"` + ModelName string `json:"modelName,omitempty"` + ModelProvider bool `json:"modelProvider,omitempty"` + JSONResponse bool `json:"jsonResponse,omitempty"` + Chat bool `json:"chat,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + Cache *bool `json:"cache,omitempty"` + InternalPrompt *bool `json:"internalPrompt"` + Arguments *openapi3.Schema `json:"arguments,omitempty"` + Tools []string `json:"tools,omitempty"` + GlobalTools []string `json:"globalTools,omitempty"` + GlobalModelName string `json:"globalModelName,omitempty"` + Context []string `json:"context,omitempty"` + ExportContext []string `json:"exportContext,omitempty"` + Export []string `json:"export,omitempty"` + Agents []string `json:"agents,omitempty"` + Credentials []string `json:"credentials,omitempty"` + InputFilters []string `json:"inputFilters,omitempty"` + ExportInputFilters []string `json:"exportInputFilters,omitempty"` + OutputFilters []string `json:"outputFilters,omitempty"` + ExportOutputFilters []string `json:"exportOutputFilters,omitempty"` + Blocking bool `json:"-"` } func (p Parameters) ToolRefNames() []string { @@ -153,7 +155,9 @@ func (p Parameters) ToolRefNames() []string { p.Context, p.Credentials, p.InputFilters, - p.ExportInputFilters) + p.ExportInputFilters, + p.OutputFilters, + p.ExportOutputFilters) } type ToolDef struct { @@ -419,6 +423,12 @@ func (t ToolDef) String() string { if len(t.Parameters.ExportInputFilters) != 0 { _, _ = fmt.Fprintf(buf, "Share Input Filters: %s\n", strings.Join(t.Parameters.ExportInputFilters, ", ")) } + if len(t.Parameters.OutputFilters) != 0 { + _, _ = fmt.Fprintf(buf, "Output Filters: %s\n", strings.Join(t.Parameters.OutputFilters, ", ")) + } + if len(t.Parameters.ExportOutputFilters) != 0 { + _, _ = fmt.Fprintf(buf, "Share Output Filters: %s\n", strings.Join(t.Parameters.ExportOutputFilters, ", ")) + } if t.Parameters.MaxTokens != 0 { _, _ = fmt.Fprintf(buf, "Max Tokens: %d\n", t.Parameters.MaxTokens) } @@ -521,6 +531,31 @@ func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) { return result.List() } +func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) { + result := &toolRefSet{} + + outputFilterRefs, err := t.GetToolRefsFromNames(t.OutputFilters) + if err != nil { + return nil, err + } + + for _, outputFilterRef := range outputFilterRefs { + result.Add(outputFilterRef) + } + + contextRefs, err := t.GetContextTools(program) + if err != nil { + return nil, err + } + + for _, contextRef := range contextRefs { + contextTool := program.ToolSet[contextRef.ToolID] + result.AddAll(contextTool.GetToolRefsFromNames(contextTool.ExportOutputFilters)) + } + + return result.List() +} + func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) { result := &toolRefSet{} diff --git a/pkg/types/tool_test.go b/pkg/types/tool_test.go index 6e3d98d3..43af6cee 100644 --- a/pkg/types/tool_test.go +++ b/pkg/types/tool_test.go @@ -9,28 +9,30 @@ import ( func TestToolDef_String(t *testing.T) { tool := ToolDef{ Parameters: Parameters{ - Name: "Tool Sample", - Description: "This is a sample tool", - MaxTokens: 1024, - ModelName: "ModelSample", - ModelProvider: true, - JSONResponse: true, - Chat: true, - Temperature: float32Ptr(0.8), - Cache: boolPtr(true), - InternalPrompt: boolPtr(true), - Arguments: ObjectSchema("arg1", "desc1", "arg2", "desc2"), - Tools: []string{"Tool1", "Tool2"}, - GlobalTools: []string{"GlobalTool1", "GlobalTool2"}, - GlobalModelName: "GlobalModelSample", - Context: []string{"Context1", "Context2"}, - ExportContext: []string{"ExportContext1", "ExportContext2"}, - Export: []string{"Export1", "Export2"}, - Agents: []string{"Agent1", "Agent2"}, - Credentials: []string{"Credential1", "Credential2"}, - Blocking: true, - InputFilters: []string{"Filter1", "Filter2"}, - ExportInputFilters: []string{"SharedFilter1", "SharedFilter2"}, + Name: "Tool Sample", + Description: "This is a sample tool", + MaxTokens: 1024, + ModelName: "ModelSample", + ModelProvider: true, + JSONResponse: true, + Chat: true, + Temperature: float32Ptr(0.8), + Cache: boolPtr(true), + InternalPrompt: boolPtr(true), + Arguments: ObjectSchema("arg1", "desc1", "arg2", "desc2"), + Tools: []string{"Tool1", "Tool2"}, + GlobalTools: []string{"GlobalTool1", "GlobalTool2"}, + GlobalModelName: "GlobalModelSample", + Context: []string{"Context1", "Context2"}, + ExportContext: []string{"ExportContext1", "ExportContext2"}, + Export: []string{"Export1", "Export2"}, + Agents: []string{"Agent1", "Agent2"}, + Credentials: []string{"Credential1", "Credential2"}, + Blocking: true, + InputFilters: []string{"Filter1", "Filter2"}, + ExportInputFilters: []string{"SharedFilter1", "SharedFilter2"}, + OutputFilters: []string{"Filter1", "Filter2"}, + ExportOutputFilters: []string{"SharedFilter1", "SharedFilter2"}, }, Instructions: "This is a sample instruction", } @@ -46,6 +48,8 @@ Context: Context1, Context2 Share Context: ExportContext1, ExportContext2 Input Filters: Filter1, Filter2 Share Input Filters: SharedFilter1, SharedFilter2 +Output Filters: Filter1, Filter2 +Share Output Filters: SharedFilter1, SharedFilter2 Max Tokens: 1024 Model: ModelSample Model Provider: true