diff --git a/pkg/credentials/noop.go b/pkg/credentials/noop.go index 5f3cc5ad..3a13b907 100644 --- a/pkg/credentials/noop.go +++ b/pkg/credentials/noop.go @@ -12,6 +12,10 @@ func (s NoopStore) Add(context.Context, Credential) error { return nil } +func (s NoopStore) Refresh(context.Context, Credential) error { + return nil +} + func (s NoopStore) Remove(context.Context, string) error { return nil } diff --git a/pkg/credentials/store.go b/pkg/credentials/store.go index 749aba3a..2414e1e8 100644 --- a/pkg/credentials/store.go +++ b/pkg/credentials/store.go @@ -5,6 +5,7 @@ import ( "fmt" "path/filepath" "regexp" + "slices" "strings" "github.com/docker/cli/cli/config/credentials" @@ -26,6 +27,7 @@ type CredentialBuilder interface { type CredentialStore interface { Get(ctx context.Context, toolName string) (*Credential, bool, error) Add(ctx context.Context, cred Credential) error + Refresh(ctx context.Context, cred Credential) error Remove(ctx context.Context, toolName string) error List(ctx context.Context) ([]Credential, error) } @@ -95,6 +97,8 @@ func (s Store) Get(ctx context.Context, toolName string) (*Credential, bool, err return &cred, true, nil } +// Add adds a new credential to the credential store. +// Any context set on the credential object will be overwritten with the first context of the credential store. func (s Store) Add(ctx context.Context, cred Credential) error { first := first(s.credCtxs) if first == AllCredentialContexts { @@ -113,6 +117,23 @@ func (s Store) Add(ctx context.Context, cred Credential) error { return store.Store(auth) } +// Refresh updates an existing credential in the credential store. +func (s Store) Refresh(ctx context.Context, cred Credential) error { + if !slices.Contains(s.credCtxs, cred.Context) { + return fmt.Errorf("context %q not in list of valid contexts for this credential store", cred.Context) + } + + store, err := s.getStore(ctx) + if err != nil { + return err + } + auth, err := cred.toDockerAuthConfig() + if err != nil { + return err + } + return store.Store(auth) +} + func (s Store) Remove(ctx context.Context, toolName string) error { first := first(s.credCtxs) if len(s.credCtxs) > 1 || first == AllCredentialContexts { diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index e6318c15..7ac9fae0 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -854,8 +854,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env } var ( - c *credentials.Credential - exists bool + c *credentials.Credential + resultCredential credentials.Credential + exists bool + refresh bool ) rm := runtimeWithLogger(callCtx, monitor, r.runtimeManager) @@ -886,6 +888,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env if !exists || c.IsExpired() { // If the existing credential is expired, we need to provide it to the cred tool through the environment. if exists && c.IsExpired() { + refresh = true credJSON, err := json.Marshal(c) if err != nil { return nil, fmt.Errorf("failed to marshal credential: %w", err) @@ -916,39 +919,56 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env continue } - if err := json.Unmarshal([]byte(*res.Result), &c); err != nil { + if err := json.Unmarshal([]byte(*res.Result), &resultCredential); err != nil { return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err) } - c.ToolName = credName - c.Type = credentials.CredentialTypeTool + resultCredential.ToolName = credName + resultCredential.Type = credentials.CredentialTypeTool + + if refresh { + // If this is a credential refresh, we need to make sure we use the same context. + resultCredential.Context = c.Context + } else { + // If it is a new credential, let the credential store determine the context. + resultCredential.Context = "" + } isEmpty := true - for _, v := range c.Env { + for _, v := range resultCredential.Env { if v != "" { isEmpty = false break } } - if !c.Ephemeral { + if !resultCredential.Ephemeral { // Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty. if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" { if isEmpty { log.Warnf("Not saving empty credential for tool %s", toolName) - } else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil { - return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err) + } else { + if refresh { + err = r.credStore.Refresh(callCtx.Ctx, resultCredential) + } else { + err = r.credStore.Add(callCtx.Ctx, resultCredential) + } + if err != nil { + return nil, fmt.Errorf("failed to save credential for tool %s: %w", toolName, err) + } } } else { log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName) } } + } else { + resultCredential = *c } - if c.ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration.After(*c.ExpiresAt)) { - nearestExpiration = c.ExpiresAt + if resultCredential.ExpiresAt != nil && (nearestExpiration == nil || nearestExpiration.After(*resultCredential.ExpiresAt)) { + nearestExpiration = resultCredential.ExpiresAt } - for k, v := range c.Env { + for k, v := range resultCredential.Env { env = append(env, fmt.Sprintf("%s=%s", k, v)) } }