diff --git a/pkg/cli/credential.go b/pkg/cli/credential.go index cb000125..733590c4 100644 --- a/pkg/cli/credential.go +++ b/pkg/cli/credential.go @@ -45,7 +45,7 @@ func (c *Credential) Run(cmd *cobra.Command, _ []string) error { ctx := c.root.CredentialContext if c.AllContexts { - ctx = "*" + ctx = credentials.AllCredentialContexts } opts, err := c.root.NewGPTScriptOpts() diff --git a/pkg/credentials/store.go b/pkg/credentials/store.go index 3940184b..c8558f3a 100644 --- a/pkg/credentials/store.go +++ b/pkg/credentials/store.go @@ -11,6 +11,11 @@ import ( "github.com/gptscript-ai/gptscript/pkg/config" ) +const ( + DefaultCredentialContext = "default" + AllCredentialContexts = "*" +) + type CredentialBuilder interface { EnsureCredentialHelpers(ctx context.Context) error } @@ -105,7 +110,7 @@ func (s Store) List(ctx context.Context) ([]Credential, error) { if err != nil { return nil, err } - if s.credCtx == "*" || c.Context == s.credCtx { + if s.credCtx == AllCredentialContexts || c.Context == s.credCtx { creds = append(creds, c) } } @@ -139,7 +144,7 @@ func validateCredentialCtx(ctx string) error { return fmt.Errorf("credential context cannot be empty") } - if ctx == "*" { // this represents "all contexts" and is allowed + if ctx == AllCredentialContexts { return nil } diff --git a/pkg/gptscript/gptscript.go b/pkg/gptscript/gptscript.go index 755fe632..abae80ac 100644 --- a/pkg/gptscript/gptscript.go +++ b/pkg/gptscript/gptscript.go @@ -75,7 +75,7 @@ func Complete(opts ...Options) Options { result.Env = os.Environ() } if result.CredentialContext == "" { - result.CredentialContext = "default" + result.CredentialContext = credentials.DefaultCredentialContext } return result diff --git a/pkg/sdkserver/credentials.go b/pkg/sdkserver/credentials.go new file mode 100644 index 00000000..adbaacdc --- /dev/null +++ b/pkg/sdkserver/credentials.go @@ -0,0 +1,176 @@ +package sdkserver + +import ( + "encoding/json" + "fmt" + "net/http" + + "github.com/gptscript-ai/gptscript/pkg/config" + gcontext "github.com/gptscript-ai/gptscript/pkg/context" + "github.com/gptscript-ai/gptscript/pkg/credentials" + "github.com/gptscript-ai/gptscript/pkg/repos/runtimes" +) + +func (s *server) initializeCredentialStore(ctx string) (credentials.CredentialStore, error) { + cfg, err := config.ReadCLIConfig(s.gptscriptOpts.OpenAI.ConfigFile) + if err != nil { + return nil, fmt.Errorf("failed to read CLI config: %w", err) + } + + // TODO - are we sure we want to always use runtimes.Default here? + store, err := credentials.NewStore(cfg, runtimes.Default(s.gptscriptOpts.Cache.CacheDir), ctx, s.gptscriptOpts.Cache.CacheDir) + if err != nil { + return nil, fmt.Errorf("failed to initialize credential store: %w", err) + } + + return store, nil +} + +func (s *server) listCredentials(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + req := new(credentialsRequest) + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + if req.AllContexts { + req.Context = credentials.AllCredentialContexts + } else if req.Context == "" { + req.Context = credentials.DefaultCredentialContext + } + + store, err := s.initializeCredentialStore(req.Context) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, err) + return + } + + creds, err := store.List(r.Context()) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to list credentials: %w", err)) + return + } + + // Remove the environment variable values (which are secrets) and refresh tokens from the response. + for i := range creds { + for k := range creds[i].Env { + creds[i].Env[k] = "" + } + creds[i].RefreshToken = "" + } + + writeResponse(logger, w, map[string]any{"stdout": creds}) +} + +func (s *server) createCredential(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + req := new(credentialsRequest) + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + cred := new(credentials.Credential) + if err := json.Unmarshal([]byte(req.Content), cred); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid credential: %w", err)) + return + } + + if cred.Context == "" { + cred.Context = credentials.DefaultCredentialContext + } + + store, err := s.initializeCredentialStore(cred.Context) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, err) + return + } + + if err := store.Add(r.Context(), *cred); err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to create credential: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": "Credential created successfully"}) +} + +func (s *server) revealCredential(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + req := new(credentialsRequest) + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + return + } + + if req.Name == "" { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("missing credential name")) + return + } + + if req.AllContexts || req.Context == credentials.AllCredentialContexts { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("allContexts is not supported for credential retrieval; please specify the specific context that the credential is in")) + return + } else if req.Context == "" { + req.Context = credentials.DefaultCredentialContext + } + + store, err := s.initializeCredentialStore(req.Context) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, err) + return + } + + cred, ok, err := store.Get(r.Context(), req.Name) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to get credential: %w", err)) + return + } else if !ok { + writeError(logger, w, http.StatusNotFound, fmt.Errorf("credential not found")) + return + } + + writeResponse(logger, w, map[string]any{"stdout": cred}) +} + +func (s *server) deleteCredential(w http.ResponseWriter, r *http.Request) { + logger := gcontext.GetLogger(r.Context()) + req := new(credentialsRequest) + if err := json.NewDecoder(r.Body).Decode(req); err != nil { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("invalid request body: %w", err)) + } + + if req.Name == "" { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("missing credential name")) + return + } + + if req.AllContexts || req.Context == credentials.AllCredentialContexts { + writeError(logger, w, http.StatusBadRequest, fmt.Errorf("allContexts is not supported for credential deletion; please specify the specific context that the credential is in")) + return + } else if req.Context == "" { + req.Context = credentials.DefaultCredentialContext + } + + store, err := s.initializeCredentialStore(req.Context) + if err != nil { + writeError(logger, w, http.StatusInternalServerError, err) + return + } + + // Check to see if a cred exists so we can return a 404 if it doesn't. + if _, ok, err := store.Get(r.Context(), req.Name); err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to get credential: %w", err)) + return + } else if !ok { + writeError(logger, w, http.StatusNotFound, fmt.Errorf("credential not found")) + return + } + + if err := store.Remove(r.Context(), req.Name); err != nil { + writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("failed to delete credential: %w", err)) + return + } + + writeResponse(logger, w, map[string]any{"stdout": "Credential deleted successfully"}) +} diff --git a/pkg/sdkserver/routes.go b/pkg/sdkserver/routes.go index 6cb1e620..c180097e 100644 --- a/pkg/sdkserver/routes.go +++ b/pkg/sdkserver/routes.go @@ -58,6 +58,11 @@ func (s *server) addRoutes(mux *http.ServeMux) { mux.HandleFunc("POST /confirm/{id}", s.confirm) mux.HandleFunc("POST /prompt/{id}", s.prompt) mux.HandleFunc("POST /prompt-response/{id}", s.promptResponse) + + mux.HandleFunc("POST /credentials", s.listCredentials) + mux.HandleFunc("POST /credentials/create", s.createCredential) + mux.HandleFunc("POST /credentials/reveal", s.revealCredential) + mux.HandleFunc("POST /credentials/delete", s.deleteCredential) } // health just provides an endpoint for checking whether the server is running and accessible. diff --git a/pkg/sdkserver/types.go b/pkg/sdkserver/types.go index 2889626b..7ed7da78 100644 --- a/pkg/sdkserver/types.go +++ b/pkg/sdkserver/types.go @@ -252,3 +252,10 @@ type prompt struct { Type runner.EventType `json:"type,omitempty"` Time time.Time `json:"time,omitempty"` } + +type credentialsRequest struct { + content `json:",inline"` + AllContexts bool `json:"allContexts"` + Context string `json:"context"` + Name string `json:"name"` +}