Skip to content

Commit 4eb4ba0

Browse files
committed
fix: openapi revamp: tell the LLM the names of the tools it should use
Signed-off-by: Grant Linville <[email protected]>
1 parent eaaf0cd commit 4eb4ba0

File tree

4 files changed

+28
-13
lines changed

4 files changed

+28
-13
lines changed

pkg/engine/openapi.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error
4343
return nil, fmt.Errorf("failed to load OpenAPI file %s: %w", source, err)
4444
}
4545

46-
opList, err := openapi.List(t, filter)
46+
opList, err := openapi.List(t, filter, "getSchema"+strings.TrimPrefix(tool.Name, "listOperations"))
4747
if err != nil {
4848
return nil, fmt.Errorf("failed to list operations: %w", err)
4949
}
@@ -66,7 +66,7 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error
6666
} else if !match {
6767
// Report to the LLM that the operation was not found
6868
return &Return{
69-
Result: ptr(fmt.Sprintf("operation %s not found", operation)),
69+
Result: ptr(fmt.Sprintf("ERROR: operation %s not found", operation)),
7070
}, nil
7171
}
7272
}
@@ -85,14 +85,14 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error
8585
defaultHost = u.Scheme + "://" + u.Hostname()
8686
}
8787

88-
schema, _, found, err := openapi.GetSchema(operation, defaultHost, t)
88+
schema, _, found, err := openapi.GetSchema(operation, defaultHost, "runOperation"+strings.TrimPrefix(tool.Name, "getSchema"), t)
8989
if err != nil {
9090
return nil, fmt.Errorf("failed to get schema: %w", err)
9191
}
9292
if !found {
9393
// Report to the LLM that the operation was not found
9494
return &Return{
95-
Result: ptr(fmt.Sprintf("operation %s not found", operation)),
95+
Result: ptr(fmt.Sprintf("ERROR: operation %s not found", operation)),
9696
}, nil
9797
}
9898

@@ -115,7 +115,7 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error
115115
} else if !match {
116116
// Report to the LLM that the operation was not found
117117
return &Return{
118-
Result: ptr(fmt.Sprintf("operation %s not found", operation)),
118+
Result: ptr(fmt.Sprintf("ERROR: operation %s not found", operation)),
119119
}, nil
120120
}
121121
}
@@ -134,13 +134,13 @@ func (e *Engine) runOpenAPIRevamp(tool types.Tool, input string) (*Return, error
134134
defaultHost = u.Scheme + "://" + u.Hostname()
135135
}
136136

137-
result, found, err := openapi.Run(operation, defaultHost, args, t, e.Env)
137+
result, found, err := openapi.Run(operation, defaultHost, args, tool.Name, t, e.Env)
138138
if err != nil {
139139
return nil, fmt.Errorf("failed to run operation %s: %w", operation, err)
140140
} else if !found {
141141
// Report to the LLM that the operation was not found
142142
return &Return{
143-
Result: ptr(fmt.Sprintf("operation %s not found", operation)),
143+
Result: ptr(fmt.Sprintf("ERROR: operation %s not found", operation)),
144144
}, nil
145145
}
146146

pkg/openapi/getschema.go

+12-2
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,14 @@ func GetSupportedSecurityTypes() []string {
4242
return supportedSecurityTypes
4343
}
4444

45+
type argsWithMessage struct {
46+
Arguments *openapi3.Schema
47+
Message string `json:"message"`
48+
}
49+
4550
// GetSchema returns the JSONSchema and OperationInfo for a particular OpenAPI operation.
4651
// Return values in order: JSONSchema (string), OperationInfo, found (bool), error.
47-
func GetSchema(operationID, defaultHost string, t *openapi3.T) (string, OperationInfo, bool, error) {
52+
func GetSchema(operationID, defaultHost, runToolName string, t *openapi3.T) (string, OperationInfo, bool, error) {
4853
arguments := &openapi3.Schema{
4954
Type: &openapi3.Types{"object"},
5055
Properties: openapi3.Schemas{},
@@ -227,7 +232,12 @@ func GetSchema(operationID, defaultHost string, t *openapi3.T) (string, Operatio
227232
}
228233
}
229234

230-
argumentsJSON, err := json.MarshalIndent(arguments, "", " ")
235+
withMessage := argsWithMessage{
236+
Arguments: arguments,
237+
Message: fmt.Sprintf("You can use the %s tool to run this operation.", runToolName),
238+
}
239+
240+
argumentsJSON, err := json.MarshalIndent(withMessage, "", " ")
231241
if err != nil {
232242
return "", OperationInfo{}, false, err
233243
}

pkg/openapi/list.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
package openapi
22

33
import (
4+
"fmt"
45
"path/filepath"
56
"strings"
67

78
"github.com/getkin/kin-openapi/openapi3"
89
)
910

1011
type OperationList struct {
12+
Message string `json:"message"`
1113
Operations map[string]Operation `json:"operations"`
1214
}
1315

@@ -21,7 +23,7 @@ const (
2123
NoFilter = "<none>"
2224
)
2325

24-
func List(t *openapi3.T, filter string) (OperationList, error) {
26+
func List(t *openapi3.T, filter, getSchemaToolName string) (OperationList, error) {
2527
operations := make(map[string]Operation)
2628
for _, pathItem := range t.Paths.Map() {
2729
for _, operation := range pathItem.Operations() {
@@ -51,7 +53,10 @@ func List(t *openapi3.T, filter string) (OperationList, error) {
5153
}
5254
}
5355

54-
return OperationList{Operations: operations}, nil
56+
return OperationList{
57+
Message: fmt.Sprintf("You can get the schema for these operations using the %s tool.", getSchemaToolName),
58+
Operations: operations,
59+
}, nil
5560
}
5661

5762
func MatchFilters(filters []string, operationID string) (bool, error) {

pkg/openapi/run.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import (
2020

2121
const RunTool = "run"
2222

23-
func Run(operationID, defaultHost, args string, t *openapi3.T, envs []string) (string, bool, error) {
23+
func Run(operationID, defaultHost, args, runToolName string, t *openapi3.T, envs []string) (string, bool, error) {
2424
envMap := make(map[string]string, len(envs))
2525
for _, e := range envs {
2626
k, v, _ := strings.Cut(e, "=")
@@ -30,7 +30,7 @@ func Run(operationID, defaultHost, args string, t *openapi3.T, envs []string) (s
3030
if args == "" {
3131
args = "{}"
3232
}
33-
schemaJSON, opInfo, found, err := GetSchema(operationID, defaultHost, t)
33+
schemaJSON, opInfo, found, err := GetSchema(operationID, defaultHost, runToolName, t)
3434
if err != nil || !found {
3535
return "", false, err
3636
}

0 commit comments

Comments
 (0)