Skip to content

Commit b3d662a

Browse files
chore: add --default-model-provider
This change allows you to set a default model provider, whereas before the default could only be a URL/API-key pair for OpenAI. Setting this also implicitly disables the OpenAI provider configured with the --openai* flags
1 parent 2584308 commit b3d662a

File tree

5 files changed

+90
-70
lines changed

5 files changed

+90
-70
lines changed

pkg/cli/gptscript.go

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,24 +54,25 @@ type GPTScript struct {
5454
Output string `usage:"Save output to a file, or - for stdout" short:"o"`
5555
EventsStreamTo string `usage:"Stream events to this location, could be a file descriptor/handle (e.g. fd://2), filename, or named pipe (e.g. \\\\.\\pipe\\my-pipe)" name:"events-stream-to"`
5656
// Input should not be using GPTSCRIPT_INPUT env var because that is the same value that is set in tool executions
57-
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
58-
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
59-
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
60-
ListModels bool `usage:"List the models available and exit" local:"true"`
61-
ListTools bool `usage:"List built-in tools and exit" local:"true"`
62-
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
63-
Chdir string `usage:"Change current working directory" short:"C"`
64-
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
65-
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
66-
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
67-
CredentialOverride []string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
68-
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
69-
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
70-
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
71-
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
72-
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
73-
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
74-
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
57+
Input string `usage:"Read input from a file (\"-\" for stdin)" short:"f" env:"GPTSCRIPT_INPUT_FILE"`
58+
SubTool string `usage:"Use tool of this name, not the first tool in file" local:"true"`
59+
Assemble bool `usage:"Assemble tool to a single artifact, saved to --output" hidden:"true" local:"true"`
60+
ListModels bool `usage:"List the models available and exit" local:"true"`
61+
ListTools bool `usage:"List built-in tools and exit" local:"true"`
62+
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:0" hidden:"true"`
63+
Chdir string `usage:"Change current working directory" short:"C"`
64+
Daemon bool `usage:"Run tool as a daemon" local:"true" hidden:"true"`
65+
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
66+
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
67+
CredentialOverride []string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
68+
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state" local:"true"`
69+
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool" local:"true"`
70+
ForceSequential bool `usage:"Force parallel calls to run sequentially" local:"true"`
71+
Workspace string `usage:"Directory to use for the workspace, if specified it will not be deleted on exit"`
72+
UI bool `usage:"Launch the UI" local:"true" name:"ui"`
73+
DisableTUI bool `usage:"Don't use chat TUI but instead verbose output" local:"true" name:"disable-tui"`
74+
SaveChatStateFile string `usage:"A file to save the chat state to so that a conversation can be resumed with --chat-state" local:"true"`
75+
DefaultModelProvider string `usage:"Default LLM model provider to use, this will override OpenAI settings"`
7576

7677
readData []byte
7778
}
@@ -136,11 +137,12 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
136137
CredentialOverrides: r.CredentialOverride,
137138
Sequential: r.ForceSequential,
138139
},
139-
Quiet: r.Quiet,
140-
Env: os.Environ(),
141-
CredentialContext: r.CredentialContext,
142-
Workspace: r.Workspace,
143-
DisablePromptServer: r.UI,
140+
Quiet: r.Quiet,
141+
Env: os.Environ(),
142+
CredentialContext: r.CredentialContext,
143+
Workspace: r.Workspace,
144+
DisablePromptServer: r.UI,
145+
DefaultModelProvider: r.DefaultModelProvider,
144146
}
145147

146148
if r.Confirm {

pkg/gptscript/gptscript.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,16 @@ type GPTScript struct {
4040
}
4141

4242
type Options struct {
43-
Cache cache.Options
44-
OpenAI openai.Options
45-
Monitor monitor.Options
46-
Runner runner.Options
47-
CredentialContext string
48-
Quiet *bool
49-
Workspace string
50-
DisablePromptServer bool
51-
Env []string
43+
Cache cache.Options
44+
OpenAI openai.Options
45+
Monitor monitor.Options
46+
Runner runner.Options
47+
DefaultModelProvider string
48+
CredentialContext string
49+
Quiet *bool
50+
Workspace string
51+
DisablePromptServer bool
52+
Env []string
5253
}
5354

5455
func Complete(opts ...Options) Options {
@@ -64,6 +65,7 @@ func Complete(opts ...Options) Options {
6465
result.Workspace = types.FirstSet(opt.Workspace, result.Workspace)
6566
result.Env = append(result.Env, opt.Env...)
6667
result.DisablePromptServer = types.FirstSet(opt.DisablePromptServer, result.DisablePromptServer)
68+
result.DefaultModelProvider = types.FirstSet(opt.DefaultModelProvider, result.DefaultModelProvider)
6769
}
6870

6971
if result.Quiet == nil {
@@ -106,16 +108,18 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
106108
return nil, err
107109
}
108110

109-
oaiClient, err := openai.NewClient(ctx, credStore, opts.OpenAI, openai.Options{
110-
Cache: cacheClient,
111-
SetSeed: true,
112-
})
113-
if err != nil {
114-
return nil, err
115-
}
111+
if opts.DefaultModelProvider == "" {
112+
oaiClient, err := openai.NewClient(ctx, credStore, opts.OpenAI, openai.Options{
113+
Cache: cacheClient,
114+
SetSeed: true,
115+
})
116+
if err != nil {
117+
return nil, err
118+
}
116119

117-
if err := registry.AddClient(oaiClient); err != nil {
118-
return nil, err
120+
if err := registry.AddClient(oaiClient); err != nil {
121+
return nil, err
122+
}
119123
}
120124

121125
if opts.Runner.MonitorFactory == nil {
@@ -143,7 +147,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
143147

144148
fullEnv := append(opts.Env, extraEnv...)
145149

146-
remoteClient := remote.New(runner, fullEnv, cacheClient, credStore)
150+
remoteClient := remote.New(runner, fullEnv, cacheClient, credStore, opts.DefaultModelProvider)
147151
if err := registry.AddClient(remoteClient); err != nil {
148152
closeServer()
149153
return nil, err

pkg/remote/remote.go

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,23 @@ import (
2222
)
2323

2424
type Client struct {
25-
clientsLock sync.Mutex
26-
cache *cache.Client
27-
clients map[string]*openai.Client
28-
models map[string]*openai.Client
29-
runner *runner.Runner
30-
envs []string
31-
credStore credentials.CredentialStore
25+
clientsLock sync.Mutex
26+
cache *cache.Client
27+
clients map[string]*openai.Client
28+
models map[string]*openai.Client
29+
runner *runner.Runner
30+
envs []string
31+
credStore credentials.CredentialStore
32+
defaultProvider string
3233
}
3334

34-
func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credentials.CredentialStore) *Client {
35+
func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credentials.CredentialStore, defaultProvider string) *Client {
3536
return &Client{
36-
cache: cache,
37-
runner: r,
38-
envs: envs,
39-
credStore: credStore,
37+
cache: cache,
38+
runner: r,
39+
envs: envs,
40+
credStore: credStore,
41+
defaultProvider: defaultProvider,
4042
}
4143
}
4244

@@ -73,13 +75,23 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
7375
return
7476
}
7577

76-
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
77-
toolName, modelNameSuffix := types.SplitToolRef(modelName)
78-
if modelNameSuffix == "" {
78+
func (c *Client) parseModel(modelString string) (modelName, providerName string) {
79+
toolName, subTool := types.SplitToolRef(modelString)
80+
if subTool == "" {
81+
// This is just a plain model string "gpt4o"
82+
return toolName, c.defaultProvider
83+
}
84+
// This is a provider string "modelName from provider"
85+
return subTool, toolName
86+
}
87+
88+
func (c *Client) Supports(ctx context.Context, modelString string) (bool, error) {
89+
_, providerName := c.parseModel(modelString)
90+
if providerName == "" {
7991
return false, nil
8092
}
8193

82-
client, err := c.load(ctx, toolName)
94+
client, err := c.load(ctx, providerName)
8395
if err != nil {
8496
return false, err
8597
}
@@ -91,7 +103,7 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
91103
c.models = map[string]*openai.Client{}
92104
}
93105

94-
c.models[modelName] = client
106+
c.models[modelString] = client
95107
return true, nil
96108
}
97109

pkg/sdkserver/routes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
204204
CredentialOverrides: reqObject.CredentialOverrides,
205205
Sequential: reqObject.ForceSequential,
206206
},
207+
DefaultModelProvider: reqObject.DefaultModelProvider,
207208
}
208209

209210
if reqObject.Confirm {

pkg/sdkserver/types.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,18 @@ type toolOrFileRequest struct {
5252
cacheOptions `json:",inline"`
5353
openAIOptions `json:",inline"`
5454

55-
ToolDefs toolDefs `json:"toolDefs,inline"`
56-
SubTool string `json:"subTool"`
57-
Input string `json:"input"`
58-
ChatState string `json:"chatState"`
59-
Workspace string `json:"workspace"`
60-
Env []string `json:"env"`
61-
CredentialContext string `json:"credentialContext"`
62-
CredentialOverrides []string `json:"credentialOverrides"`
63-
Confirm bool `json:"confirm"`
64-
Location string `json:"location,omitempty"`
65-
ForceSequential bool `json:"forceSequential"`
55+
ToolDefs toolDefs `json:"toolDefs,inline"`
56+
SubTool string `json:"subTool"`
57+
Input string `json:"input"`
58+
ChatState string `json:"chatState"`
59+
Workspace string `json:"workspace"`
60+
Env []string `json:"env"`
61+
CredentialContext string `json:"credentialContext"`
62+
CredentialOverrides []string `json:"credentialOverrides"`
63+
Confirm bool `json:"confirm"`
64+
Location string `json:"location,omitempty"`
65+
ForceSequential bool `json:"forceSequential"`
66+
DefaultModelProvider string `json:"defaultModelProvider,omitempty"`
6667
}
6768

6869
type content struct {

0 commit comments

Comments
 (0)