Skip to content

Commit 387e226

Browse files
committed
Add CustomTokenSource for custom token validation
This commit adds CustomTokenSource which allows the user to specify a function which accepts a token validation function. This also changes the internal implementation of ReuseTokenSource to use CustomTokenSource under the hood while keeping the same external API.
1 parent 0f29369 commit 387e226

File tree

2 files changed

+111
-45
lines changed

2 files changed

+111
-45
lines changed

oauth2.go

Lines changed: 55 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,7 @@ func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
244244
if t != nil {
245245
tkr.refreshToken = t.RefreshToken
246246
}
247-
return &reuseTokenSource{
248-
t: t,
249-
new: tkr,
250-
}
247+
return ReuseTokenSource(t, tkr)
251248
}
252249

253250
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
@@ -281,24 +278,64 @@ func (tf *tokenRefresher) Token() (*Token, error) {
281278
return tk, err
282279
}
283280

284-
// reuseTokenSource is a TokenSource that holds a single token in memory
285-
// and validates its expiry before each call to retrieve it with
286-
// Token. If it's expired, it will be auto-refreshed using the
287-
// new TokenSource.
288-
type reuseTokenSource struct {
289-
new TokenSource // called when t is expired.
281+
// StaticTokenSource returns a TokenSource that always returns the same token.
282+
// Because the provided token t is never refreshed, StaticTokenSource is only
283+
// useful for tokens that never expire.
284+
func StaticTokenSource(t *Token) TokenSource {
285+
return staticTokenSource{t}
286+
}
287+
288+
// staticTokenSource is a TokenSource that always returns the same Token.
289+
type staticTokenSource struct {
290+
t *Token
291+
}
292+
293+
func (s staticTokenSource) Token() (*Token, error) {
294+
return s.t, nil
295+
}
296+
297+
// ValidFunc should return false when the passed token is invalid and true when
298+
// the token is valid. ValidFunc should NOT modify the token passed to it.
299+
type ValidFunc func(t *Token) bool
300+
301+
// CustomTokenSource returns a TokenSource which repeatedly returns the
302+
// same token as long as ValidFunc returns true, starting with t.
303+
// When ValidFunc returns false (the cached token is invalid), a new token
304+
// is obtained from src.
305+
func CustomTokenSource(t *Token, src TokenSource, validFunc ValidFunc) TokenSource {
306+
// Don't wrap a customTokenSource in itself. That would work,
307+
// but cause an unnecessary number of mutex operations.
308+
// Just build the equivalent one.
309+
if rt, ok := src.(*customTokenSource); ok {
310+
if t == nil {
311+
// Just use it directly.
312+
return rt
313+
}
314+
src = rt.new
315+
}
316+
return &customTokenSource{
317+
t: t,
318+
new: src,
319+
validFunc: validFunc,
320+
}
321+
}
322+
323+
type customTokenSource struct {
324+
validFunc ValidFunc // used for determining whether the token should be refreshed
325+
new TokenSource // called when validFunc returns invalid
290326

291327
mu sync.Mutex // guards t
292328
t *Token
293329
}
294330

295-
// Token returns the current token if it's still valid, else will
296-
// refresh the current token (using r.Context for HTTP client
297-
// information) and return the new one.
298-
func (s *reuseTokenSource) Token() (*Token, error) {
331+
// Token returns a TokenSource that will return the current token so long as
332+
// ValidFunc returns that the token is valid, otherwise it will refresh the
333+
// current token (using r.Context for HTTP client information) and return the
334+
// new one.
335+
func (s *customTokenSource) Token() (*Token, error) {
299336
s.mu.Lock()
300337
defer s.mu.Unlock()
301-
if s.t.Valid() {
338+
if s.validFunc(s.t) {
302339
return s.t, nil
303340
}
304341
t, err := s.new.Token()
@@ -309,22 +346,6 @@ func (s *reuseTokenSource) Token() (*Token, error) {
309346
return t, nil
310347
}
311348

312-
// StaticTokenSource returns a TokenSource that always returns the same token.
313-
// Because the provided token t is never refreshed, StaticTokenSource is only
314-
// useful for tokens that never expire.
315-
func StaticTokenSource(t *Token) TokenSource {
316-
return staticTokenSource{t}
317-
}
318-
319-
// staticTokenSource is a TokenSource that always returns the same Token.
320-
type staticTokenSource struct {
321-
t *Token
322-
}
323-
324-
func (s staticTokenSource) Token() (*Token, error) {
325-
return s.t, nil
326-
}
327-
328349
// HTTPClient is the context key to use with golang.org/x/net/context's
329350
// WithValue function to associate an *http.Client value with a context.
330351
var HTTPClient internal.ContextKey
@@ -364,18 +385,7 @@ func NewClient(ctx context.Context, src TokenSource) *http.Client {
364385
// means it's always safe to wrap ReuseTokenSource around any other
365386
// TokenSource without adverse effects.
366387
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
367-
// Don't wrap a reuseTokenSource in itself. That would work,
368-
// but cause an unnecessary number of mutex operations.
369-
// Just build the equivalent one.
370-
if rt, ok := src.(*reuseTokenSource); ok {
371-
if t == nil {
372-
// Just use it directly.
373-
return rt
374-
}
375-
src = rt.new
376-
}
377-
return &reuseTokenSource{
378-
t: t,
379-
new: src,
380-
}
388+
return CustomTokenSource(t, src, func(t *Token) bool {
389+
return t.Valid()
390+
})
381391
}

oauth2_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,3 +565,59 @@ func TestConfigClientWithToken(t *testing.T) {
565565
t.Error(err)
566566
}
567567
}
568+
569+
type mockTokenSource struct {
570+
nextToken *Token
571+
}
572+
573+
func (s mockTokenSource) Token() (*Token, error) {
574+
return s.nextToken, nil
575+
}
576+
577+
func TestCustomTokenSource(t *testing.T) {
578+
foobarToken := &Token{AccessToken: "foobar"}
579+
barbazToken := &Token{AccessToken: "barbaz"}
580+
581+
testCases := []struct {
582+
name string
583+
t *Token
584+
src TokenSource
585+
validToken bool
586+
expectedToken *Token
587+
}{
588+
{
589+
name: "invalid token",
590+
t: foobarToken,
591+
src: mockTokenSource{nextToken: barbazToken},
592+
validToken: false,
593+
expectedToken: barbazToken,
594+
},
595+
{
596+
name: "valid token",
597+
t: foobarToken,
598+
src: mockTokenSource{nextToken: barbazToken},
599+
validToken: true,
600+
expectedToken: foobarToken,
601+
},
602+
}
603+
604+
for _, tt := range testCases {
605+
t.Run(tt.name, func(t *testing.T) {
606+
validFunc := func(t *Token) bool { return tt.validToken }
607+
ts := CustomTokenSource(tt.t, tt.src, validFunc)
608+
609+
// the same expected token should always be returned no matter how many iterations
610+
// we go through since the validfunc returns a constant value
611+
for i := 0; i < 3; i++ {
612+
newToken, err := ts.Token()
613+
if err != nil {
614+
t.Errorf("did not expect an error but got: %v", err)
615+
}
616+
617+
if tt.expectedToken != newToken {
618+
t.Errorf("expected token %v, but got %v", tt.expectedToken, newToken)
619+
}
620+
}
621+
})
622+
}
623+
}

0 commit comments

Comments
 (0)