Skip to content

feat: add input filters #523

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ func appendInputAsEnv(env []string, input string) []string {
dec := json.NewDecoder(bytes.NewReader([]byte(input)))
dec.UseNumber()

env = appendEnv(env, "GPTSCRIPT_INPUT", input)

if err := json.Unmarshal([]byte(input), &data); err != nil {
// ignore invalid JSON
return env
Expand All @@ -206,7 +208,6 @@ func appendInputAsEnv(env []string, input string) []string {
}
}

env = appendEnv(env, "GPTSCRIPT_INPUT", input)
return env
}

Expand Down
3 changes: 2 additions & 1 deletion pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ const (
ProviderToolCategory ToolCategory = "provider"
CredentialToolCategory ToolCategory = "credential"
ContextToolCategory ToolCategory = "context"
InputToolCategory ToolCategory = "input"
NoCategory ToolCategory = ""
)

Expand Down Expand Up @@ -180,7 +181,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) Context {
return callCtx
}

func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
tool, ok := c.Program.ToolSet[toolID]
if !ok {
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
Expand Down
9 changes: 7 additions & 2 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package openai

import (
"context"
"fmt"
"io"
"log/slog"
"os"
Expand All @@ -16,6 +15,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/mvl"
"github.com/gptscript-ai/gptscript/pkg/prompt"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
Expand All @@ -29,6 +29,7 @@ const (
var (
key = os.Getenv("OPENAI_API_KEY")
url = os.Getenv("OPENAI_BASE_URL")
log = mvl.Package()
)

type InvalidAuthError struct{}
Expand Down Expand Up @@ -305,7 +306,11 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}

if len(msgs) == 0 {
return nil, fmt.Errorf("invalid request, no messages to send to LLM")
log.Errorf("invalid request, no messages to send to LLM")
return &types.CompletionMessage{
Role: types.CompletionMessageRoleTypeAssistant,
Content: types.Text(""),
}, nil
}

request := openai.ChatCompletionRequest{
Expand Down
11 changes: 9 additions & 2 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...)
case "tool", "tools":
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...)
case "inputfilter", "inputfilters":
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...)
case "shareinputfilter", "shareinputfilters":
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...)
case "agent", "agents":
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...)
case "globaltool", "globaltools":
Expand Down Expand Up @@ -183,10 +187,13 @@ type context struct {

func (c *context) finish(tools *[]Node) {
c.tool.Instructions = strings.TrimSpace(strings.Join(c.instructions, ""))
if c.tool.Instructions != "" || c.tool.Parameters.Name != "" ||
len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 ||
if c.tool.Instructions != "" ||
c.tool.Parameters.Name != "" ||
len(c.tool.Export) > 0 ||
len(c.tool.Tools) > 0 ||
c.tool.GlobalModelName != "" ||
len(c.tool.GlobalTools) > 0 ||
len(c.tool.ExportInputFilters) > 0 ||
c.tool.Chat {
*tools = append(*tools, Node{
ToolNode: &ToolNode{
Expand Down
24 changes: 24 additions & 0 deletions pkg/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,27 @@ name: bad
},
}}).Equal(t, out)
}

func TestParseInput(t *testing.T) {
input := `
input filters: input
share input filters: shared
`
out, err := Parse(strings.NewReader(input))
require.NoError(t, err)
autogold.Expect(Document{Nodes: []Node{
{ToolNode: &ToolNode{
Tool: types.Tool{
ToolDef: types.ToolDef{
Parameters: types.Parameters{
InputFilters: []string{
"input",
},
ExportInputFilters: []string{"shared"},
},
},
Source: types.ToolSource{LineNo: 1},
},
}},
}}).Equal(t, out)
}
27 changes: 27 additions & 0 deletions pkg/runner/input.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package runner

import (
"fmt"

"github.com/gptscript-ai/gptscript/pkg/engine"
)

func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program)
if err != nil {
return "", err
}

for _, inputToolRef := range inputToolRefs {
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, input, "", engine.InputToolCategory)
if err != nil {
return "", err
}
if res.Result == nil {
return "", fmt.Errorf("invalid state: input tool [%s] can not result in a chat continuation", inputToolRef.Reference)
}
input = *res.Result
}

return input, nil
}
34 changes: 23 additions & 11 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
Content: input,
})

input, err := r.handleInput(callCtx, monitor, env, input)
if err != nil {
return nil, err
}

if len(callCtx.Tool.Credentials) > 0 {
var err error
env, err = r.handleCredentials(callCtx, monitor, env)
Expand All @@ -417,7 +422,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

var (
err error
newState *State
)
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
Expand Down Expand Up @@ -446,7 +450,10 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

if !authResp.Accept {
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
msg := authResp.Message
if msg == "" {
msg = "Tool call request has been denied"
}
return &State{
Continuation: &engine.Return{
Result: &msg,
Expand Down Expand Up @@ -631,8 +638,12 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
}

if state.ResumeInput != nil {
input, err := r.handleInput(callCtx, monitor, env, *state.ResumeInput)
if err != nil {
return state, err
}
engineResults = append(engineResults, engine.CallResult{
User: *state.ResumeInput,
User: input,
})
}

Expand Down Expand Up @@ -689,16 +700,22 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
}

func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory)
callCtx, err := parentContext.SubCallContext(ctx, input, toolID, callID, toolCategory)
if err != nil {
return nil, err
}

if toolCategory == engine.ContextToolCategory && callCtx.Tool.IsNoop() {
return &State{
Result: new(string),
}, nil
}

return r.call(callCtx, monitor, env, input)
}

func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory)
callCtx, err := parentContext.SubCallContext(ctx, "", toolID, callID, toolCategory)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -882,12 +899,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
input = string(inputBytes)
}

subCtx, err := callCtx.SubCall(callCtx.Ctx, input, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
if err != nil {
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
}

res, err := r.call(subCtx, monitor, env, input)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
if err != nil {
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
}
Expand Down
25 changes: 25 additions & 0 deletions pkg/tests/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,28 @@ func TestAgents(t *testing.T) {
autogold.Expect("TEST RESULT CALL: 4").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
}

func TestInput(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip()
}

r := tester.NewRunner(t)

prg, err := r.Load("")
require.NoError(t, err)

resp, err := r.Chat(context.Background(), nil, prg, nil, "You're stupid")
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect("TEST RESULT CALL: 1").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))

resp, err = r.Chat(context.Background(), resp.State, prg, nil, "You're ugly")
require.NoError(t, err)
r.AssertResponded(t)
assert.False(t, resp.Done)
autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp.Content)
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
}
9 changes: 9 additions & 0 deletions pkg/tests/testdata/TestInput/call1-resp.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 1"
}
],
"usage": {}
}`
24 changes: 24 additions & 0 deletions pkg/tests/testdata/TestInput/call1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
`{
"model": "gpt-4o",
"internalSystemPrompt": false,
"messages": [
{
"role": "system",
"content": [
{
"text": "\nTool body"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're stupid!\n ha ha ha\n"
}
],
"usage": {}
}
]
}`
9 changes: 9 additions & 0 deletions pkg/tests/testdata/TestInput/call2-resp.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
`{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 2"
}
],
"usage": {}
}`
42 changes: 42 additions & 0 deletions pkg/tests/testdata/TestInput/call2.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
`{
"model": "gpt-4o",
"internalSystemPrompt": false,
"messages": [
{
"role": "system",
"content": [
{
"text": "\nTool body"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're stupid!\n ha ha ha\n"
}
],
"usage": {}
},
{
"role": "assistant",
"content": [
{
"text": "TEST RESULT CALL: 1"
}
],
"usage": {}
},
{
"role": "user",
"content": [
{
"text": "No, You're ugly!\n ha ha ha\n"
}
],
"usage": {}
}
]
}`
Loading
Loading