Skip to content

Commit 8bbb029

Browse files
authored
Merge pull request #802 from thedadams/sdk-list-model-other-providers
feat: add ability to list models from other providers
2 parents 20c983e + 1124ed1 commit 8bbb029

File tree

2 files changed

+15
-30
lines changed

2 files changed

+15
-30
lines changed

pkg/sdkserver/routes.go

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -73,39 +73,13 @@ func (s *server) version(w http.ResponseWriter, r *http.Request) {
7373
// listTools will return the output of `gptscript --list-tools`
7474
func (s *server) listTools(w http.ResponseWriter, r *http.Request) {
7575
logger := gcontext.GetLogger(r.Context())
76-
var prg types.Program
77-
if r.ContentLength != 0 {
78-
reqObject := new(toolOrFileRequest)
79-
err := json.NewDecoder(r.Body).Decode(reqObject)
80-
if err != nil {
81-
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
82-
return
83-
}
84-
85-
if reqObject.Content != "" {
86-
prg, err = loader.ProgramFromSource(r.Context(), reqObject.Content, reqObject.SubTool, loader.Options{Cache: s.client.Cache})
87-
} else if reqObject.File != "" {
88-
prg, err = loader.Program(r.Context(), reqObject.File, reqObject.SubTool, loader.Options{Cache: s.client.Cache})
89-
} else {
90-
prg, err = loader.ProgramFromSource(r.Context(), reqObject.ToolDefs.String(), reqObject.SubTool, loader.Options{Cache: s.client.Cache})
91-
}
92-
if err != nil {
93-
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to load program: %w", err))
94-
return
95-
}
96-
}
97-
98-
tools := s.client.ListTools(r.Context(), prg)
76+
tools := s.client.ListTools(r.Context(), types.Program{})
9977
sort.Slice(tools, func(i, j int) bool {
10078
return tools[i].Name < tools[j].Name
10179
})
10280

10381
lines := make([]string, 0, len(tools))
10482
for _, tool := range tools {
105-
if tool.Name == "" {
106-
tool.Name = prg.Name
107-
}
108-
10983
// Don't print instructions
11084
tool.Instructions = ""
11185

@@ -118,22 +92,31 @@ func (s *server) listTools(w http.ResponseWriter, r *http.Request) {
11892
// listModels will return the output of `gptscript --list-models`
11993
func (s *server) listModels(w http.ResponseWriter, r *http.Request) {
12094
logger := gcontext.GetLogger(r.Context())
95+
client := s.client
96+
12197
var providers []string
12298
if r.ContentLength != 0 {
12399
reqObject := new(modelsRequest)
124-
if err := json.NewDecoder(r.Body).Decode(reqObject); err != nil {
100+
err := json.NewDecoder(r.Body).Decode(reqObject)
101+
if err != nil {
125102
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
126103
return
127104
}
128105

129106
providers = reqObject.Providers
107+
108+
client, err = gptscript.New(r.Context(), s.gptscriptOpts, gptscript.Options{Env: reqObject.Env, Runner: runner.Options{CredentialOverrides: reqObject.CredentialOverrides}})
109+
if err != nil {
110+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to create client: %w", err))
111+
return
112+
}
130113
}
131114

132115
if s.gptscriptOpts.DefaultModelProvider != "" {
133116
providers = append(providers, s.gptscriptOpts.DefaultModelProvider)
134117
}
135118

136-
out, err := s.client.ListModels(r.Context(), providers...)
119+
out, err := client.ListModels(r.Context(), providers...)
137120
if err != nil {
138121
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list models: %w", err))
139122
return

pkg/sdkserver/types.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ type parseRequest struct {
100100
}
101101

102102
type modelsRequest struct {
103-
Providers []string `json:"providers"`
103+
Providers []string `json:"providers"`
104+
Env []string `json:"env"`
105+
CredentialOverrides []string `json:"credentialOverrides"`
104106
}
105107

106108
type runInfo struct {

0 commit comments

Comments
 (0)