diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index 58de592b..3a0ecba6 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -11,6 +11,8 @@ import ( "sync" "time" + cryptorand "crypto/rand" + "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -19,6 +21,7 @@ var ports Ports type Ports struct { daemonPorts map[string]int64 + daemonTokens map[string]string daemonsRunning map[string]func() daemonLock sync.Mutex @@ -119,7 +122,30 @@ func getPath(instructions string) (string, string) { return strings.TrimSpace(rest), strings.TrimSpace(value) } -func (e *Engine) startDaemon(tool types.Tool) (string, error) { +func getDaemonToken(toolID string) (string, error) { + token, ok := ports.daemonTokens[toolID] + if !ok { + // Generate a new token. + tokenBytes := make([]byte, 50) + count, err := cryptorand.Read(tokenBytes) + if err != nil { + return "", fmt.Errorf("failed to generate daemon token: %w", err) + } else if count != len(tokenBytes) { + return "", fmt.Errorf("failed to generate daemon token") + } + + token = fmt.Sprintf("%x", tokenBytes) + + if ports.daemonTokens == nil { + ports.daemonTokens = map[string]string{} + } + ports.daemonTokens[toolID] = token + } + + return token, nil +} + +func (e *Engine) startDaemon(tool types.Tool) (string, string, error) { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() @@ -127,10 +153,15 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { instructions, path := getPath(instructions) tool.Instructions = types.CommandPrefix + instructions + token, err := getDaemonToken(tool.ID) + if err != nil { + return "", "", err + } + port, ok := ports.daemonPorts[tool.ID] url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path) if ok && ports.daemonsRunning[url] != nil { - return url, nil + return url, token, nil } if ports.daemonCtx == nil { @@ -149,18 +180,19 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), fmt.Sprintf("GPTSCRIPT_PORT=%d", port), + fmt.Sprintf("GPTSCRIPT_DAEMON_TOKEN=%s", token), }, tool, "{}", false, ) if err != nil { - return url, err + return url, "", err } r, w, err := os.Pipe() if err != nil { - return "", err + return "", "", err } // Loop back to gptscript to help with process supervision @@ -178,7 +210,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { log.Infof("launched [%s][%s] port [%d] %v", tool.Name, tool.ID, port, cmd.Args) if err := cmd.Start(); err != nil { stop() - return url, err + return url, "", err } if ports.daemonPorts == nil { @@ -217,20 +249,20 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { _, _ = io.ReadAll(resp.Body) _ = resp.Body.Close() }() - return url, nil + return url, token, nil } select { case <-killedCtx.Done(): - return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx)) + return url, "", fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx)) case <-time.After(time.Second): } } - return url, fmt.Errorf("timeout waiting for 200 response from GET %s", url) + return url, "", fmt.Errorf("timeout waiting for 200 response from GET %s", url) } func (e *Engine) runDaemon(ctx Context, tool types.Tool, input string) (cmdRet *Return, cmdErr error) { - url, err := e.startDaemon(tool) + url, _, err := e.startDaemon(tool) if err != nil { return nil, err } diff --git a/pkg/engine/http.go b/pkg/engine/http.go index 49738b1a..a9a635e8 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -39,7 +39,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re return nil, err } - var requestedEnvVars map[string]struct{} + var ( + requestedEnvVars map[string]struct{} + daemonToken string + ) if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) { referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix) referencedToolRefs, ok := tool.ToolMapping[referencedToolName] @@ -50,7 +53,7 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re if !ok { return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname()) } - toolURL, err = e.startDaemon(referencedTool) + toolURL, daemonToken, err = e.startDaemon(referencedTool) if err != nil { return nil, err } @@ -85,6 +88,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re return nil, err } + if daemonToken != "" { + req.Header.Add("X-GPTScript-Daemon-Token", daemonToken) + } + for _, k := range slices.Sorted(maps.Keys(envMap)) { if _, ok := requestedEnvVars[k]; ok || strings.HasPrefix(k, "GPTSCRIPT_WORKSPACE_") { req.Header.Add("X-GPTScript-Env", k+"="+envMap[k])