Skip to content

Commit deb1c61

Browse files
chore: add type field to tools
This allows the tools field to reference tools, context tools, or agent tools without using the context:, agent: fields.
1 parent a7509b0 commit deb1c61

File tree

7 files changed

+161
-41
lines changed

7 files changed

+161
-41
lines changed

pkg/parser/parser.go

+2
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
150150
tool.Parameters.Credentials = append(tool.Parameters.Credentials, value)
151151
case "sharecredentials", "sharecreds", "sharecredential", "sharecred", "sharedcredentials", "sharedcreds", "sharedcredential", "sharedcred":
152152
tool.Parameters.ExportCredentials = append(tool.Parameters.ExportCredentials, value)
153+
case "type":
154+
tool.Type = types.ToolType(strings.ToLower(value))
153155
default:
154156
return false, nil
155157
}

pkg/runner/runner.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
332332
}
333333

334334
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) {
335-
toolRefs, err := callCtx.Program.GetContextToolRefs(callCtx.Tool.ID)
335+
toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program)
336336
if err != nil {
337337
return nil, nil, err
338338
}

pkg/tests/runner_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -995,3 +995,8 @@ func TestMissingTool(t *testing.T) {
995995
r.AssertResponded(t)
996996
autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp)
997997
}
998+
999+
func TestToolRefAll(t *testing.T) {
1000+
r := tester.NewRunner(t)
1001+
r.RunDefault()
1002+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
`{
2+
"role": "assistant",
3+
"content": [
4+
{
5+
"text": "TEST RESULT CALL: 1"
6+
}
7+
],
8+
"usage": {}
9+
}`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
`{
2+
"model": "gpt-4o",
3+
"tools": [
4+
{
5+
"function": {
6+
"toolID": "testdata/TestToolRefAll/test.gpt:tool",
7+
"name": "tool",
8+
"parameters": {
9+
"properties": {
10+
"toolArg": {
11+
"description": "stuff",
12+
"type": "string"
13+
}
14+
},
15+
"type": "object"
16+
}
17+
}
18+
},
19+
{
20+
"function": {
21+
"toolID": "testdata/TestToolRefAll/test.gpt:none",
22+
"name": "none",
23+
"parameters": {
24+
"properties": {
25+
"noneArg": {
26+
"description": "stuff",
27+
"type": "string"
28+
}
29+
},
30+
"type": "object"
31+
}
32+
}
33+
},
34+
{
35+
"function": {
36+
"toolID": "testdata/TestToolRefAll/test.gpt:agentAssistant",
37+
"name": "agent",
38+
"parameters": {
39+
"properties": {
40+
"defaultPromptParameter": {
41+
"description": "Prompt to send to the tool. This may be an instruction or question.",
42+
"type": "string"
43+
}
44+
},
45+
"type": "object"
46+
}
47+
}
48+
}
49+
],
50+
"messages": [
51+
{
52+
"role": "system",
53+
"content": [
54+
{
55+
"text": "\nContext Body\nMain tool"
56+
}
57+
],
58+
"usage": {}
59+
}
60+
]
61+
}`

pkg/types/tool.go

+78-40
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ var (
2626
DefaultFiles = []string{"agent.gpt", "tool.gpt"}
2727
)
2828

29+
type ToolType string
30+
31+
const (
32+
ToolTypeContext = ToolType("context")
33+
ToolTypeAgent = ToolType("agent")
34+
ToolTypeOutput = ToolType("output")
35+
ToolTypeInput = ToolType("input")
36+
ToolTypeAssistant = ToolType("assistant")
37+
ToolTypeTool = ToolType("tool")
38+
ToolTypeCredential = ToolType("credential")
39+
ToolTypeProvider = ToolType("provider")
40+
ToolTypeDefault = ToolType("")
41+
)
42+
2943
type ErrToolNotFound struct {
3044
ToolName string
3145
}
@@ -77,28 +91,6 @@ type ToolReference struct {
7791
ToolID string `json:"toolID,omitempty"`
7892
}
7993

80-
func (p Program) GetContextToolRefs(toolID string) ([]ToolReference, error) {
81-
return p.ToolSet[toolID].GetContextTools(p)
82-
}
83-
84-
func (p Program) GetCompletionTools() (result []CompletionTool, err error) {
85-
return Tool{
86-
ToolDef: ToolDef{
87-
Parameters: Parameters{
88-
Tools: []string{"main"},
89-
},
90-
},
91-
ToolMapping: map[string][]ToolReference{
92-
"main": {
93-
{
94-
Reference: "main",
95-
ToolID: p.EntryToolID,
96-
},
97-
},
98-
},
99-
}.GetCompletionTools(p)
100-
}
101-
10294
func (p Program) TopLevelTools() (result []Tool) {
10395
for _, tool := range p.ToolSet[p.EntryToolID].LocalTools {
10496
if target, ok := p.ToolSet[tool]; ok {
@@ -145,6 +137,7 @@ type Parameters struct {
145137
OutputFilters []string `json:"outputFilters,omitempty"`
146138
ExportOutputFilters []string `json:"exportOutputFilters,omitempty"`
147139
Blocking bool `json:"-"`
140+
Type ToolType `json:"type,omitempty"`
148141
}
149142

150143
func (p Parameters) ToolRefNames() []string {
@@ -347,6 +340,13 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
347340
return nil, err
348341
}
349342

343+
genericToolRefs, err := t.getCompletionToolRefs(prg, nil, ToolTypeAgent)
344+
if err != nil {
345+
return nil, err
346+
}
347+
348+
toolRefs = append(toolRefs, genericToolRefs...)
349+
350350
// Agent Tool refs must be named
351351
for i, toolRef := range toolRefs {
352352
if toolRef.Named != "" {
@@ -358,7 +358,9 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
358358
name = toolRef.Reference
359359
}
360360
normed := ToolNormalizer(name)
361-
normed = strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant")
361+
if trimmed := strings.TrimSuffix(strings.TrimSuffix(normed, "Agent"), "Assistant"); trimmed != "" {
362+
normed = trimmed
363+
}
362364
toolRefs[i].Named = normed
363365
}
364366

@@ -404,6 +406,9 @@ func (t ToolDef) String() string {
404406
if t.Parameters.Description != "" {
405407
_, _ = fmt.Fprintf(buf, "Description: %s\n", t.Parameters.Description)
406408
}
409+
if t.Parameters.Type != ToolTypeDefault {
410+
_, _ = fmt.Fprintf(buf, "Type: %s\n", strings.ToUpper(string(t.Type[0]))+string(t.Type[1:]))
411+
}
407412
if len(t.Parameters.Agents) != 0 {
408413
_, _ = fmt.Fprintf(buf, "Agents: %s\n", strings.Join(t.Parameters.Agents, ", "))
409414
}
@@ -486,7 +491,7 @@ func (t ToolDef) String() string {
486491
return buf.String()
487492
}
488493

489-
func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) {
494+
func (t Tool) getExportedContext(prg Program) ([]ToolReference, error) {
490495
result := &toolRefSet{}
491496

492497
exportRefs, err := t.GetToolRefsFromNames(t.ExportContext)
@@ -498,13 +503,13 @@ func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) {
498503
result.Add(exportRef)
499504

500505
tool := prg.ToolSet[exportRef.ToolID]
501-
result.AddAll(tool.GetExportedContext(prg))
506+
result.AddAll(tool.getExportedContext(prg))
502507
}
503508

504509
return result.List()
505510
}
506511

507-
func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
512+
func (t Tool) getExportedTools(prg Program) ([]ToolReference, error) {
508513
result := &toolRefSet{}
509514

510515
exportRefs, err := t.GetToolRefsFromNames(t.Export)
@@ -514,7 +519,7 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
514519

515520
for _, exportRef := range exportRefs {
516521
result.Add(exportRef)
517-
result.AddAll(prg.ToolSet[exportRef.ToolID].GetExportedTools(prg))
522+
result.AddAll(prg.ToolSet[exportRef.ToolID].getExportedTools(prg))
518523
}
519524

520525
return result.List()
@@ -524,14 +529,23 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
524529
// contexts that are exported by the context tools. This will recurse all exports.
525530
func (t Tool) GetContextTools(prg Program) ([]ToolReference, error) {
526531
result := &toolRefSet{}
532+
result.AddAll(t.getDirectContextToolRefs(prg))
533+
result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeContext))
534+
return result.List()
535+
}
536+
537+
// GetContextTools returns all tools that are in the context of the tool including all the
538+
// contexts that are exported by the context tools. This will recurse all exports.
539+
func (t Tool) getDirectContextToolRefs(prg Program) ([]ToolReference, error) {
540+
result := &toolRefSet{}
527541

528542
contextRefs, err := t.GetToolRefsFromNames(t.Context)
529543
if err != nil {
530544
return nil, err
531545
}
532546

533547
for _, contextRef := range contextRefs {
534-
result.AddAll(prg.ToolSet[contextRef.ToolID].GetExportedContext(prg))
548+
result.AddAll(prg.ToolSet[contextRef.ToolID].getExportedContext(prg))
535549
result.Add(contextRef)
536550
}
537551

@@ -550,7 +564,9 @@ func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) {
550564
result.Add(outputFilterRef)
551565
}
552566

553-
contextRefs, err := t.GetContextTools(program)
567+
result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeOutput))
568+
569+
contextRefs, err := t.getDirectContextToolRefs(program)
554570
if err != nil {
555571
return nil, err
556572
}
@@ -575,7 +591,9 @@ func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) {
575591
result.Add(inputFilterRef)
576592
}
577593

578-
contextRefs, err := t.GetContextTools(program)
594+
result.AddAll(t.getCompletionToolRefs(program, nil, ToolTypeInput))
595+
596+
contextRefs, err := t.getDirectContextToolRefs(program)
579597
if err != nil {
580598
return nil, err
581599
}
@@ -602,11 +620,28 @@ func (t Tool) GetNextAgentGroup(prg Program, agentGroup []ToolReference, toolID
602620
return agentGroup, nil
603621
}
604622

623+
func filterRefs(prg Program, refs []ToolReference, types ...ToolType) (result []ToolReference) {
624+
for _, ref := range refs {
625+
if slices.Contains(types, prg.ToolSet[ref.ToolID].Type) {
626+
result = append(result, ref)
627+
}
628+
}
629+
return
630+
}
631+
605632
func (t Tool) GetCompletionTools(prg Program, agentGroup ...ToolReference) (result []CompletionTool, err error) {
606-
refs, err := t.getCompletionToolRefs(prg, agentGroup)
633+
toolSet := &toolRefSet{}
634+
toolSet.AddAll(t.getCompletionToolRefs(prg, agentGroup, ToolTypeDefault, ToolTypeTool))
635+
636+
if err := t.addAgents(prg, toolSet); err != nil {
637+
return nil, err
638+
}
639+
640+
refs, err := toolSet.List()
607641
if err != nil {
608642
return nil, err
609643
}
644+
610645
return toolRefsToCompletionTools(refs, prg), nil
611646
}
612647

@@ -638,26 +673,30 @@ func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error {
638673
result.Add(subToolRef)
639674

640675
// Get all tools exports
641-
result.AddAll(prg.ToolSet[subToolRef.ToolID].GetExportedTools(prg))
676+
result.AddAll(prg.ToolSet[subToolRef.ToolID].getExportedTools(prg))
642677
}
643678

644679
return nil
645680
}
646681

647682
func (t Tool) addContextExportedTools(prg Program, result *toolRefSet) error {
648-
contextTools, err := t.GetContextTools(prg)
683+
contextTools, err := t.getDirectContextToolRefs(prg)
649684
if err != nil {
650685
return err
651686
}
652687

653688
for _, contextTool := range contextTools {
654-
result.AddAll(prg.ToolSet[contextTool.ToolID].GetExportedTools(prg))
689+
result.AddAll(prg.ToolSet[contextTool.ToolID].getExportedTools(prg))
655690
}
656691

657692
return nil
658693
}
659694

660-
func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
695+
func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference, types ...ToolType) ([]ToolReference, error) {
696+
if len(types) == 0 {
697+
types = []ToolType{ToolTypeDefault, ToolTypeTool}
698+
}
699+
661700
result := toolRefSet{}
662701

663702
if t.Chat {
@@ -677,18 +716,17 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
677716
return nil, err
678717
}
679718

680-
if err := t.addAgents(prg, &result); err != nil {
681-
return nil, err
682-
}
683-
684-
return result.List()
719+
refs, err := result.List()
720+
return filterRefs(prg, refs, types...), err
685721
}
686722

687723
func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]ToolReference, error) {
688724
result := toolRefSet{}
689725

690726
result.AddAll(t.GetToolRefsFromNames(t.Credentials))
691727

728+
result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential))
729+
692730
toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
693731
if err != nil {
694732
return nil, err

pkg/types/tool_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ func TestToolDef_String(t *testing.T) {
3333
ExportInputFilters: []string{"SharedFilter1", "SharedFilter2"},
3434
OutputFilters: []string{"Filter1", "Filter2"},
3535
ExportOutputFilters: []string{"SharedFilter1", "SharedFilter2"},
36+
ExportCredentials: []string{"ExportCredential1", "ExportCredential2"},
37+
Type: ToolTypeContext,
3638
},
3739
Instructions: "This is a sample instruction",
3840
}
@@ -41,6 +43,7 @@ func TestToolDef_String(t *testing.T) {
4143
Global Tools: GlobalTool1, GlobalTool2
4244
Name: Tool Sample
4345
Description: This is a sample tool
46+
Type: Context
4447
Agents: Agent1, Agent2
4548
Tools: Tool1, Tool2
4649
Share Tools: Export1, Export2
@@ -60,6 +63,8 @@ Parameter: arg2: desc2
6063
Internal Prompt: true
6164
Credential: Credential1
6265
Credential: Credential2
66+
Share Credential: ExportCredential1
67+
Share Credential: ExportCredential2
6368
Chat: true
6469
6570
This is a sample instruction

0 commit comments

Comments
 (0)