Skip to content

Commit 6db2e59

Browse files
committed
enhance: share credential
Signed-off-by: Grant Linville <[email protected]>
1 parent 9ca6e93 commit 6db2e59

File tree

6 files changed

+228
-19
lines changed

6 files changed

+228
-19
lines changed

integration/cred_test.go

+16
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,19 @@ func TestGPTScriptCredential(t *testing.T) {
1111
require.NoError(t, err)
1212
require.Contains(t, out, "CREDENTIAL")
1313
}
14+
15+
// TestCredentialScopes makes sure that environment variables set by credential tools and shared credential tools
16+
// are only available to the correct tools. See scripts/credscopes.gpt for more details.
17+
func TestCredentialScopes(t *testing.T) {
18+
out, err := RunScript("scripts/credscopes.gpt", "--sub-tool", "oneOne")
19+
require.NoError(t, err)
20+
require.Contains(t, out, "good")
21+
22+
out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoOne")
23+
require.NoError(t, err)
24+
require.Contains(t, out, "good")
25+
26+
out, err = RunScript("scripts/credscopes.gpt", "--sub-tool", "twoTwo")
27+
require.NoError(t, err)
28+
require.Contains(t, out, "good")
29+
}

integration/helpers.go

+4
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,7 @@ func GPTScriptExec(args ...string) (string, error) {
1414
out, err := cmd.CombinedOutput()
1515
return string(out), err
1616
}
17+
18+
func RunScript(script string, options ...string) (string, error) {
19+
return GPTScriptExec(append(options, "--quiet", script)...)
20+
}

integration/scripts/credscopes.gpt

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# This script sets up a chain of tools in a tree structure.
2+
# The root is oneOne, with children twoOne and twoTwo, with children threeOne, threeTwo, and threeThree, with only
3+
# threeTwo shared between them.
4+
# Each tool should only have access to any credentials it defines and any credentials exported/shared by its
5+
# immediate children (but not grandchildren).
6+
# This script checks to make sure that this is working properly.
7+
name: oneOne
8+
tools: twoOne, twoTwo
9+
cred: getcred with oneOne as var and 11 as val
10+
11+
#!python3
12+
13+
import os
14+
15+
oneOne = os.getenv('oneOne')
16+
twoOne = os.getenv('twoOne')
17+
twoTwo = os.getenv('twoTwo')
18+
threeOne = os.getenv('threeOne')
19+
threeTwo = os.getenv('threeTwo')
20+
threeThree = os.getenv('threeThree')
21+
22+
if oneOne != '11':
23+
print('error: oneOne is not 11')
24+
exit(1)
25+
26+
if twoOne != '21':
27+
print('error: twoOne is not 21')
28+
exit(1)
29+
30+
if twoTwo != '22':
31+
print('error: twoTwo is not 22')
32+
exit(1)
33+
34+
if threeOne is not None:
35+
print('error: threeOne is not None')
36+
exit(1)
37+
38+
if threeTwo is not None:
39+
print('error: threeTwo is not None')
40+
exit(1)
41+
42+
if threeThree is not None:
43+
print('error: threeThree is not None')
44+
exit(1)
45+
46+
print('good')
47+
48+
---
49+
name: twoOne
50+
tools: threeOne, threeTwo
51+
exportcred: getcred with twoOne as var and 21 as val
52+
53+
#!python3
54+
55+
import os
56+
57+
oneOne = os.getenv('oneOne')
58+
twoOne = os.getenv('twoOne')
59+
twoTwo = os.getenv('twoTwo')
60+
threeOne = os.getenv('threeOne')
61+
threeTwo = os.getenv('threeTwo')
62+
threeThree = os.getenv('threeThree')
63+
64+
if oneOne is not None:
65+
print('error: oneOne is not None')
66+
exit(1)
67+
68+
if twoOne is not None:
69+
print('error: twoOne is not None')
70+
exit(1)
71+
72+
if twoTwo is not None:
73+
print('error: twoTwo is not None')
74+
exit(1)
75+
76+
if threeOne != '31':
77+
print('error: threeOne is not 31')
78+
exit(1)
79+
80+
if threeTwo != '32':
81+
print('error: threeTwo is not 32')
82+
exit(1)
83+
84+
if threeThree is not None:
85+
print('error: threeThree is not None')
86+
exit(1)
87+
88+
print('good')
89+
90+
---
91+
name: twoTwo
92+
tools: threeTwo, threeThree
93+
exportcred: getcred with twoTwo as var and 22 as val
94+
95+
#!python3
96+
97+
import os
98+
99+
oneOne = os.getenv('oneOne')
100+
twoOne = os.getenv('twoOne')
101+
twoTwo = os.getenv('twoTwo')
102+
threeOne = os.getenv('threeOne')
103+
threeTwo = os.getenv('threeTwo')
104+
threeThree = os.getenv('threeThree')
105+
106+
if oneOne is not None:
107+
print('error: oneOne is not None')
108+
exit(1)
109+
110+
if twoOne is not None:
111+
print('error: twoOne is not None')
112+
exit(1)
113+
114+
if twoTwo is not None:
115+
print('error: twoTwo is not None')
116+
exit(1)
117+
118+
if threeOne is not None:
119+
print('error: threeOne is not None')
120+
exit(1)
121+
122+
if threeTwo != '32':
123+
print('error: threeTwo is not 32')
124+
exit(1)
125+
126+
if threeThree != '33':
127+
print('error: threeThree is not 33')
128+
exit(1)
129+
130+
print('good')
131+
132+
---
133+
name: threeOne
134+
exportcred: getcred with threeOne as var and 31 as val
135+
136+
---
137+
name: threeTwo
138+
exportcred: getcred with threeTwo as var and 32 as val
139+
140+
---
141+
name: threeThree
142+
exportcred: getcred with threeThree as var and 33 as val
143+
144+
---
145+
name: getcred
146+
147+
#!python3
148+
149+
import os
150+
import json
151+
152+
var = os.getenv('var')
153+
val = os.getenv('val')
154+
155+
output = {
156+
"env": {
157+
var: val
158+
}
159+
}
160+
print(json.dumps(output))

pkg/parser/parser.go

+2
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
148148
}
149149
case "credentials", "creds", "credential", "cred":
150150
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
151+
case "exportcredentials", "exportcreds", "exportcredential", "exportcred", "sharecredentials", "sharecreds", "sharecredential", "sharecred":
152+
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
151153
default:
152154
return false, nil
153155
}

pkg/runner/runner.go

+22-19
Original file line numberDiff line numberDiff line change
@@ -419,9 +419,13 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
419419
return nil, err
420420
}
421421

422-
if len(callCtx.Tool.Credentials) > 0 {
422+
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
423+
if err != nil {
424+
return nil, err
425+
}
426+
if len(credTools) > 0 {
423427
var err error
424-
env, err = r.handleCredentials(callCtx, monitor, env)
428+
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
425429
if err != nil {
426430
return nil, err
427431
}
@@ -552,9 +556,13 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
552556
progress, progressClose := streamProgress(&callCtx, monitor)
553557
defer progressClose()
554558

555-
if len(callCtx.Tool.Credentials) > 0 {
559+
credTools, err := callCtx.Tool.GetCredentialTools(*callCtx.Program, callCtx.AgentGroup)
560+
if err != nil {
561+
return nil, err
562+
}
563+
if len(credTools) > 0 {
556564
var err error
557-
env, err = r.handleCredentials(callCtx, monitor, env)
565+
env, err = r.handleCredentials(callCtx, monitor, env, credTools)
558566
if err != nil {
559567
return nil, err
560568
}
@@ -828,7 +836,7 @@ func getEventContent(content string, callCtx engine.Context) string {
828836
return content
829837
}
830838

831-
func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) {
839+
func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string, credToolRefs []types.ToolReference) ([]string, error) {
832840
// Since credential tools (usually) prompt the user, we want to only run one at a time.
833841
r.credMutex.Lock()
834842
defer r.credMutex.Unlock()
@@ -845,10 +853,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
845853
}
846854
}
847855

848-
for _, credToolName := range callCtx.Tool.Credentials {
849-
toolName, credentialAlias, args, err := types.ParseCredentialArgs(credToolName, callCtx.Input)
856+
for _, ref := range credToolRefs {
857+
toolName, credentialAlias, args, err := types.ParseCredentialArgs(ref.Reference, callCtx.Input)
850858
if err != nil {
851-
return nil, fmt.Errorf("failed to parse credential tool %q: %w", credToolName, err)
859+
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
852860
}
853861

854862
credName := toolName
@@ -895,11 +903,6 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
895903
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
896904
// and save it in the store.
897905
if !exists || c.IsExpired() {
898-
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
899-
if !ok || len(credToolRefs) != 1 {
900-
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
901-
}
902-
903906
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
904907
if exists && c.IsExpired() {
905908
credJSON, err := json.Marshal(c)
@@ -914,22 +917,22 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
914917
if args != nil {
915918
inputBytes, err := json.Marshal(args)
916919
if err != nil {
917-
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", credToolName, err)
920+
return nil, fmt.Errorf("failed to marshal args for tool %s: %w", ref.Reference, err)
918921
}
919922
input = string(inputBytes)
920923
}
921924

922-
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
925+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, ref.ToolID, input, "", engine.CredentialToolCategory)
923926
if err != nil {
924-
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
927+
return nil, fmt.Errorf("failed to run credential tool %s: %w", ref.Reference, err)
925928
}
926929

927930
if res.Result == nil {
928-
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
931+
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
929932
}
930933

931934
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
932-
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
935+
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
933936
}
934937
c.ToolName = credName
935938
c.Type = credentials.CredentialTypeTool
@@ -943,7 +946,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
943946
}
944947

945948
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
946-
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
949+
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
947950
if isEmpty {
948951
log.Warnf("Not saving empty credential for tool %s", toolName)
949952
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {

pkg/types/tool.go

+24
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ type Parameters struct {
139139
Export []string `json:"export,omitempty"`
140140
Agents []string `json:"agents,omitempty"`
141141
Credentials []string `json:"credentials,omitempty"`
142+
ExportCredentials []string `json:"exportCredentials,omitempty"`
142143
InputFilters []string `json:"inputFilters,omitempty"`
143144
ExportInputFilters []string `json:"exportInputFilters,omitempty"`
144145
OutputFilters []string `json:"outputFilters,omitempty"`
@@ -154,6 +155,7 @@ func (p Parameters) ToolRefNames() []string {
154155
p.ExportContext,
155156
p.Context,
156157
p.Credentials,
158+
p.ExportCredentials,
157159
p.InputFilters,
158160
p.ExportInputFilters,
159161
p.OutputFilters,
@@ -466,6 +468,11 @@ func (t ToolDef) String() string {
466468
_, _ = fmt.Fprintf(buf, "Credential: %s\n", cred)
467469
}
468470
}
471+
if len(t.Parameters.ExportCredentials) > 0 {
472+
for _, exportCred := range t.Parameters.ExportCredentials {
473+
_, _ = fmt.Fprintf(buf, "Share Credential: %s\n", exportCred)
474+
}
475+
}
469476
if t.Parameters.Chat {
470477
_, _ = fmt.Fprintf(buf, "Chat: true\n")
471478
}
@@ -675,6 +682,23 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
675682
return result.List()
676683
}
677684

685+
func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
686+
result := toolRefSet{}
687+
688+
result.AddAll(t.GetToolRefsFromNames(t.Credentials))
689+
690+
toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
691+
if err != nil {
692+
return nil, err
693+
}
694+
for _, toolRef := range toolRefs {
695+
referencedTool := prg.ToolSet[toolRef.ToolID]
696+
result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials))
697+
}
698+
699+
return result.List()
700+
}
701+
678702
func toolRefsToCompletionTools(completionTools []ToolReference, prg Program) (result []CompletionTool) {
679703
toolNames := map[string]struct{}{}
680704

0 commit comments

Comments
 (0)