Skip to content

Commit 20f384d

Browse files
authored
enhance: allocate a random token for each daemon tool (#972)
Signed-off-by: Grant Linville <[email protected]>
1 parent 9b2832d commit 20f384d

File tree

2 files changed

+50
-11
lines changed

2 files changed

+50
-11
lines changed

pkg/engine/daemon.go

+41-9
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"sync"
1212
"time"
1313

14+
cryptorand "crypto/rand"
15+
1416
"github.com/gptscript-ai/gptscript/pkg/system"
1517
"github.com/gptscript-ai/gptscript/pkg/types"
1618
)
@@ -19,6 +21,7 @@ var ports Ports
1921

2022
type Ports struct {
2123
daemonPorts map[string]int64
24+
daemonTokens map[string]string
2225
daemonsRunning map[string]func()
2326
daemonLock sync.Mutex
2427

@@ -119,18 +122,46 @@ func getPath(instructions string) (string, string) {
119122
return strings.TrimSpace(rest), strings.TrimSpace(value)
120123
}
121124

122-
func (e *Engine) startDaemon(tool types.Tool) (string, error) {
125+
func getDaemonToken(toolID string) (string, error) {
126+
token, ok := ports.daemonTokens[toolID]
127+
if !ok {
128+
// Generate a new token.
129+
tokenBytes := make([]byte, 50)
130+
count, err := cryptorand.Read(tokenBytes)
131+
if err != nil {
132+
return "", fmt.Errorf("failed to generate daemon token: %w", err)
133+
} else if count != len(tokenBytes) {
134+
return "", fmt.Errorf("failed to generate daemon token")
135+
}
136+
137+
token = fmt.Sprintf("%x", tokenBytes)
138+
139+
if ports.daemonTokens == nil {
140+
ports.daemonTokens = map[string]string{}
141+
}
142+
ports.daemonTokens[toolID] = token
143+
}
144+
145+
return token, nil
146+
}
147+
148+
func (e *Engine) startDaemon(tool types.Tool) (string, string, error) {
123149
ports.daemonLock.Lock()
124150
defer ports.daemonLock.Unlock()
125151

126152
instructions := strings.TrimPrefix(tool.Instructions, types.DaemonPrefix)
127153
instructions, path := getPath(instructions)
128154
tool.Instructions = types.CommandPrefix + instructions
129155

156+
token, err := getDaemonToken(tool.ID)
157+
if err != nil {
158+
return "", "", err
159+
}
160+
130161
port, ok := ports.daemonPorts[tool.ID]
131162
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
132163
if ok && ports.daemonsRunning[url] != nil {
133-
return url, nil
164+
return url, token, nil
134165
}
135166

136167
if ports.daemonCtx == nil {
@@ -149,18 +180,19 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
149180
cmd, stop, err := e.newCommand(ctx, []string{
150181
fmt.Sprintf("PORT=%d", port),
151182
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
183+
fmt.Sprintf("GPTSCRIPT_DAEMON_TOKEN=%s", token),
152184
},
153185
tool,
154186
"{}",
155187
false,
156188
)
157189
if err != nil {
158-
return url, err
190+
return url, "", err
159191
}
160192

161193
r, w, err := os.Pipe()
162194
if err != nil {
163-
return "", err
195+
return "", "", err
164196
}
165197

166198
// Loop back to gptscript to help with process supervision
@@ -178,7 +210,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
178210
log.Infof("launched [%s][%s] port [%d] %v", tool.Name, tool.ID, port, cmd.Args)
179211
if err := cmd.Start(); err != nil {
180212
stop()
181-
return url, err
213+
return url, "", err
182214
}
183215

184216
if ports.daemonPorts == nil {
@@ -217,20 +249,20 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
217249
_, _ = io.ReadAll(resp.Body)
218250
_ = resp.Body.Close()
219251
}()
220-
return url, nil
252+
return url, token, nil
221253
}
222254
select {
223255
case <-killedCtx.Done():
224-
return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx))
256+
return url, "", fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx))
225257
case <-time.After(time.Second):
226258
}
227259
}
228260

229-
return url, fmt.Errorf("timeout waiting for 200 response from GET %s", url)
261+
return url, "", fmt.Errorf("timeout waiting for 200 response from GET %s", url)
230262
}
231263

232264
func (e *Engine) runDaemon(ctx Context, tool types.Tool, input string) (cmdRet *Return, cmdErr error) {
233-
url, err := e.startDaemon(tool)
265+
url, _, err := e.startDaemon(tool)
234266
if err != nil {
235267
return nil, err
236268
}

pkg/engine/http.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
3939
return nil, err
4040
}
4141

42-
var requestedEnvVars map[string]struct{}
42+
var (
43+
requestedEnvVars map[string]struct{}
44+
daemonToken string
45+
)
4346
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
4447
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
4548
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
@@ -50,7 +53,7 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
5053
if !ok {
5154
return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname())
5255
}
53-
toolURL, err = e.startDaemon(referencedTool)
56+
toolURL, daemonToken, err = e.startDaemon(referencedTool)
5457
if err != nil {
5558
return nil, err
5659
}
@@ -85,6 +88,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
8588
return nil, err
8689
}
8790

91+
if daemonToken != "" {
92+
req.Header.Add("X-GPTScript-Daemon-Token", daemonToken)
93+
}
94+
8895
for _, k := range slices.Sorted(maps.Keys(envMap)) {
8996
if _, ok := requestedEnvVars[k]; ok || strings.HasPrefix(k, "GPTSCRIPT_WORKSPACE_") {
9097
req.Header.Add("X-GPTScript-Env", k+"="+envMap[k])

0 commit comments

Comments
 (0)