26
26
DefaultFiles = []string {"agent.gpt" , "tool.gpt" }
27
27
)
28
28
29
+ type ToolType string
30
+
31
+ const (
32
+ ToolTypeContext = ToolType ("context" )
33
+ ToolTypeAgent = ToolType ("agent" )
34
+ ToolTypeOutput = ToolType ("output" )
35
+ ToolTypeInput = ToolType ("input" )
36
+ ToolTypeAssistant = ToolType ("assistant" )
37
+ ToolTypeTool = ToolType ("tool" )
38
+ ToolTypeCredential = ToolType ("credential" )
39
+ ToolTypeProvider = ToolType ("provider" )
40
+ ToolTypeDefault = ToolType ("" )
41
+ )
42
+
29
43
type ErrToolNotFound struct {
30
44
ToolName string
31
45
}
@@ -77,28 +91,6 @@ type ToolReference struct {
77
91
ToolID string `json:"toolID,omitempty"`
78
92
}
79
93
80
- func (p Program ) GetContextToolRefs (toolID string ) ([]ToolReference , error ) {
81
- return p .ToolSet [toolID ].GetContextTools (p )
82
- }
83
-
84
- func (p Program ) GetCompletionTools () (result []CompletionTool , err error ) {
85
- return Tool {
86
- ToolDef : ToolDef {
87
- Parameters : Parameters {
88
- Tools : []string {"main" },
89
- },
90
- },
91
- ToolMapping : map [string ][]ToolReference {
92
- "main" : {
93
- {
94
- Reference : "main" ,
95
- ToolID : p .EntryToolID ,
96
- },
97
- },
98
- },
99
- }.GetCompletionTools (p )
100
- }
101
-
102
94
func (p Program ) TopLevelTools () (result []Tool ) {
103
95
for _ , tool := range p .ToolSet [p .EntryToolID ].LocalTools {
104
96
if target , ok := p .ToolSet [tool ]; ok {
@@ -145,6 +137,7 @@ type Parameters struct {
145
137
OutputFilters []string `json:"outputFilters,omitempty"`
146
138
ExportOutputFilters []string `json:"exportOutputFilters,omitempty"`
147
139
Blocking bool `json:"-"`
140
+ Type ToolType `json:"type,omitempty"`
148
141
}
149
142
150
143
func (p Parameters ) ToolRefNames () []string {
@@ -347,6 +340,13 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
347
340
return nil , err
348
341
}
349
342
343
+ genericToolRefs , err := t .getCompletionToolRefs (prg , nil , ToolTypeAgent )
344
+ if err != nil {
345
+ return nil , err
346
+ }
347
+
348
+ toolRefs = append (toolRefs , genericToolRefs ... )
349
+
350
350
// Agent Tool refs must be named
351
351
for i , toolRef := range toolRefs {
352
352
if toolRef .Named != "" {
@@ -358,7 +358,9 @@ func (t Tool) GetAgents(prg Program) (result []ToolReference, _ error) {
358
358
name = toolRef .Reference
359
359
}
360
360
normed := ToolNormalizer (name )
361
- normed = strings .TrimSuffix (strings .TrimSuffix (normed , "Agent" ), "Assistant" )
361
+ if trimmed := strings .TrimSuffix (strings .TrimSuffix (normed , "Agent" ), "Assistant" ); trimmed != "" {
362
+ normed = trimmed
363
+ }
362
364
toolRefs [i ].Named = normed
363
365
}
364
366
@@ -404,6 +406,9 @@ func (t ToolDef) String() string {
404
406
if t .Parameters .Description != "" {
405
407
_ , _ = fmt .Fprintf (buf , "Description: %s\n " , t .Parameters .Description )
406
408
}
409
+ if t .Parameters .Type != ToolTypeDefault {
410
+ _ , _ = fmt .Fprintf (buf , "Type: %s\n " , strings .ToUpper (string (t .Type [0 ]))+ string (t .Type [1 :]))
411
+ }
407
412
if len (t .Parameters .Agents ) != 0 {
408
413
_ , _ = fmt .Fprintf (buf , "Agents: %s\n " , strings .Join (t .Parameters .Agents , ", " ))
409
414
}
@@ -486,7 +491,7 @@ func (t ToolDef) String() string {
486
491
return buf .String ()
487
492
}
488
493
489
- func (t Tool ) GetExportedContext (prg Program ) ([]ToolReference , error ) {
494
+ func (t Tool ) getExportedContext (prg Program ) ([]ToolReference , error ) {
490
495
result := & toolRefSet {}
491
496
492
497
exportRefs , err := t .GetToolRefsFromNames (t .ExportContext )
@@ -498,13 +503,13 @@ func (t Tool) GetExportedContext(prg Program) ([]ToolReference, error) {
498
503
result .Add (exportRef )
499
504
500
505
tool := prg .ToolSet [exportRef .ToolID ]
501
- result .AddAll (tool .GetExportedContext (prg ))
506
+ result .AddAll (tool .getExportedContext (prg ))
502
507
}
503
508
504
509
return result .List ()
505
510
}
506
511
507
- func (t Tool ) GetExportedTools (prg Program ) ([]ToolReference , error ) {
512
+ func (t Tool ) getExportedTools (prg Program ) ([]ToolReference , error ) {
508
513
result := & toolRefSet {}
509
514
510
515
exportRefs , err := t .GetToolRefsFromNames (t .Export )
@@ -514,7 +519,7 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
514
519
515
520
for _ , exportRef := range exportRefs {
516
521
result .Add (exportRef )
517
- result .AddAll (prg .ToolSet [exportRef .ToolID ].GetExportedTools (prg ))
522
+ result .AddAll (prg .ToolSet [exportRef .ToolID ].getExportedTools (prg ))
518
523
}
519
524
520
525
return result .List ()
@@ -524,14 +529,23 @@ func (t Tool) GetExportedTools(prg Program) ([]ToolReference, error) {
524
529
// contexts that are exported by the context tools. This will recurse all exports.
525
530
func (t Tool ) GetContextTools (prg Program ) ([]ToolReference , error ) {
526
531
result := & toolRefSet {}
532
+ result .AddAll (t .getDirectContextToolRefs (prg ))
533
+ result .AddAll (t .getCompletionToolRefs (prg , nil , ToolTypeContext ))
534
+ return result .List ()
535
+ }
536
+
537
+ // GetContextTools returns all tools that are in the context of the tool including all the
538
+ // contexts that are exported by the context tools. This will recurse all exports.
539
+ func (t Tool ) getDirectContextToolRefs (prg Program ) ([]ToolReference , error ) {
540
+ result := & toolRefSet {}
527
541
528
542
contextRefs , err := t .GetToolRefsFromNames (t .Context )
529
543
if err != nil {
530
544
return nil , err
531
545
}
532
546
533
547
for _ , contextRef := range contextRefs {
534
- result .AddAll (prg .ToolSet [contextRef .ToolID ].GetExportedContext (prg ))
548
+ result .AddAll (prg .ToolSet [contextRef .ToolID ].getExportedContext (prg ))
535
549
result .Add (contextRef )
536
550
}
537
551
@@ -550,7 +564,9 @@ func (t Tool) GetOutputFilterTools(program Program) ([]ToolReference, error) {
550
564
result .Add (outputFilterRef )
551
565
}
552
566
553
- contextRefs , err := t .GetContextTools (program )
567
+ result .AddAll (t .getCompletionToolRefs (program , nil , ToolTypeOutput ))
568
+
569
+ contextRefs , err := t .getDirectContextToolRefs (program )
554
570
if err != nil {
555
571
return nil , err
556
572
}
@@ -575,7 +591,9 @@ func (t Tool) GetInputFilterTools(program Program) ([]ToolReference, error) {
575
591
result .Add (inputFilterRef )
576
592
}
577
593
578
- contextRefs , err := t .GetContextTools (program )
594
+ result .AddAll (t .getCompletionToolRefs (program , nil , ToolTypeInput ))
595
+
596
+ contextRefs , err := t .getDirectContextToolRefs (program )
579
597
if err != nil {
580
598
return nil , err
581
599
}
@@ -602,11 +620,28 @@ func (t Tool) GetNextAgentGroup(prg Program, agentGroup []ToolReference, toolID
602
620
return agentGroup , nil
603
621
}
604
622
623
+ func filterRefs (prg Program , refs []ToolReference , types ... ToolType ) (result []ToolReference ) {
624
+ for _ , ref := range refs {
625
+ if slices .Contains (types , prg .ToolSet [ref .ToolID ].Type ) {
626
+ result = append (result , ref )
627
+ }
628
+ }
629
+ return
630
+ }
631
+
605
632
func (t Tool ) GetCompletionTools (prg Program , agentGroup ... ToolReference ) (result []CompletionTool , err error ) {
606
- refs , err := t .getCompletionToolRefs (prg , agentGroup )
633
+ toolSet := & toolRefSet {}
634
+ toolSet .AddAll (t .getCompletionToolRefs (prg , agentGroup , ToolTypeDefault , ToolTypeTool ))
635
+
636
+ if err := t .addAgents (prg , toolSet ); err != nil {
637
+ return nil , err
638
+ }
639
+
640
+ refs , err := toolSet .List ()
607
641
if err != nil {
608
642
return nil , err
609
643
}
644
+
610
645
return toolRefsToCompletionTools (refs , prg ), nil
611
646
}
612
647
@@ -638,26 +673,30 @@ func (t Tool) addReferencedTools(prg Program, result *toolRefSet) error {
638
673
result .Add (subToolRef )
639
674
640
675
// Get all tools exports
641
- result .AddAll (prg .ToolSet [subToolRef .ToolID ].GetExportedTools (prg ))
676
+ result .AddAll (prg .ToolSet [subToolRef .ToolID ].getExportedTools (prg ))
642
677
}
643
678
644
679
return nil
645
680
}
646
681
647
682
func (t Tool ) addContextExportedTools (prg Program , result * toolRefSet ) error {
648
- contextTools , err := t .GetContextTools (prg )
683
+ contextTools , err := t .getDirectContextToolRefs (prg )
649
684
if err != nil {
650
685
return err
651
686
}
652
687
653
688
for _ , contextTool := range contextTools {
654
- result .AddAll (prg .ToolSet [contextTool .ToolID ].GetExportedTools (prg ))
689
+ result .AddAll (prg .ToolSet [contextTool .ToolID ].getExportedTools (prg ))
655
690
}
656
691
657
692
return nil
658
693
}
659
694
660
- func (t Tool ) getCompletionToolRefs (prg Program , agentGroup []ToolReference ) ([]ToolReference , error ) {
695
+ func (t Tool ) getCompletionToolRefs (prg Program , agentGroup []ToolReference , types ... ToolType ) ([]ToolReference , error ) {
696
+ if len (types ) == 0 {
697
+ types = []ToolType {ToolTypeDefault , ToolTypeTool }
698
+ }
699
+
661
700
result := toolRefSet {}
662
701
663
702
if t .Chat {
@@ -677,18 +716,17 @@ func (t Tool) getCompletionToolRefs(prg Program, agentGroup []ToolReference) ([]
677
716
return nil , err
678
717
}
679
718
680
- if err := t .addAgents (prg , & result ); err != nil {
681
- return nil , err
682
- }
683
-
684
- return result .List ()
719
+ refs , err := result .List ()
720
+ return filterRefs (prg , refs , types ... ), err
685
721
}
686
722
687
723
func (t Tool ) GetCredentialTools (prg Program , agentGroup []ToolReference ) ([]ToolReference , error ) {
688
724
result := toolRefSet {}
689
725
690
726
result .AddAll (t .GetToolRefsFromNames (t .Credentials ))
691
727
728
+ result .AddAll (t .getCompletionToolRefs (prg , nil , ToolTypeCredential ))
729
+
692
730
toolRefs , err := t .getCompletionToolRefs (prg , agentGroup )
693
731
if err != nil {
694
732
return nil , err
0 commit comments