Skip to content

Commit 6c6a45b

Browse files
committed
feat(jwt): make KeyFunc public in JWT middleware
It allows a user-defined function to supply the key for a token verification.
1 parent d9e2354 commit 6c6a45b

File tree

2 files changed

+64
-23
lines changed

2 files changed

+64
-23
lines changed

middleware/jwt.go

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,12 @@ type (
2929
// ErrorHandlerWithContext is almost identical to ErrorHandler, but it's passed the current context.
3030
ErrorHandlerWithContext JWTErrorHandlerWithContext
3131

32-
// Signing key to validate token. Used as fallback if SigningKeys has length 0.
33-
// Required. This or SigningKeys.
32+
// Signing key to validate token. Used as fallback if KeyFunc is nil or SigningKeys has length 0.
33+
// Required. This or SigningKeys or KeyFunc.
3434
SigningKey interface{}
3535

36-
// Map of signing keys to validate token with kid field usage.
37-
// Required. This or SigningKey.
36+
// Map of signing keys to validate token with kid field usage. Used as fallback if KeyFunc is nil.
37+
// Required. This or SigningKey or KeyFunc.
3838
SigningKeys map[string]interface{}
3939

4040
// Signing method, used to check token signing method.
@@ -64,7 +64,12 @@ type (
6464
// Optional. Default value "Bearer".
6565
AuthScheme string
6666

67-
keyFunc jwt.Keyfunc
67+
// KeyFunc defines a function to supply the key for a token verification.
68+
// When a user-defined KeyFunc is provided, SigningKey, SigningKeys, and SigningMethod are ignored.
69+
// Required. This or SigningKey or SigningKeys.
70+
// Default to an internal implementation verifying the signing algorithm and selecting the proper key.
71+
// See: `jwt.Keyfunc`
72+
KeyFunc jwt.Keyfunc
6873
}
6974

7075
// JWTSuccessHandler defines a function which is executed for a valid token.
@@ -99,6 +104,7 @@ var (
99104
TokenLookup: "header:" + echo.HeaderAuthorization,
100105
AuthScheme: "Bearer",
101106
Claims: jwt.MapClaims{},
107+
KeyFunc: nil,
102108
}
103109
)
104110

@@ -123,7 +129,7 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
123129
if config.Skipper == nil {
124130
config.Skipper = DefaultJWTConfig.Skipper
125131
}
126-
if config.SigningKey == nil && len(config.SigningKeys) == 0 {
132+
if config.SigningKey == nil && len(config.SigningKeys) == 0 && config.KeyFunc == nil {
127133
panic("echo: jwt middleware requires signing key")
128134
}
129135
if config.SigningMethod == "" {
@@ -141,21 +147,8 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
141147
if config.AuthScheme == "" {
142148
config.AuthScheme = DefaultJWTConfig.AuthScheme
143149
}
144-
config.keyFunc = func(t *jwt.Token) (interface{}, error) {
145-
// Check the signing method
146-
if t.Method.Alg() != config.SigningMethod {
147-
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
148-
}
149-
if len(config.SigningKeys) > 0 {
150-
if kid, ok := t.Header["kid"].(string); ok {
151-
if key, ok := config.SigningKeys[kid]; ok {
152-
return key, nil
153-
}
154-
}
155-
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
156-
}
157-
158-
return config.SigningKey, nil
150+
if config.KeyFunc == nil {
151+
config.KeyFunc = config.defaultKeyFunc
159152
}
160153

161154
// Initialize
@@ -196,11 +189,11 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
196189
token := new(jwt.Token)
197190
// Issue #647, #656
198191
if _, ok := config.Claims.(jwt.MapClaims); ok {
199-
token, err = jwt.Parse(auth, config.keyFunc)
192+
token, err = jwt.Parse(auth, config.KeyFunc)
200193
} else {
201194
t := reflect.ValueOf(config.Claims).Type().Elem()
202195
claims := reflect.New(t).Interface().(jwt.Claims)
203-
token, err = jwt.ParseWithClaims(auth, claims, config.keyFunc)
196+
token, err = jwt.ParseWithClaims(auth, claims, config.KeyFunc)
204197
}
205198
if err == nil && token.Valid {
206199
// Store user information from token into context.
@@ -225,6 +218,24 @@ func JWTWithConfig(config JWTConfig) echo.MiddlewareFunc {
225218
}
226219
}
227220

221+
// defaultKeyFunc returns a signing key of the given token.
222+
func (config *JWTConfig) defaultKeyFunc(t *jwt.Token) (interface{}, error) {
223+
// Check the signing method
224+
if t.Method.Alg() != config.SigningMethod {
225+
return nil, fmt.Errorf("unexpected jwt signing method=%v", t.Header["alg"])
226+
}
227+
if len(config.SigningKeys) > 0 {
228+
if kid, ok := t.Header["kid"].(string); ok {
229+
if key, ok := config.SigningKeys[kid]; ok {
230+
return key, nil
231+
}
232+
}
233+
return nil, fmt.Errorf("unexpected jwt key id=%v", t.Header["kid"])
234+
}
235+
236+
return config.SigningKey, nil
237+
}
238+
228239
// jwtFromHeader returns a `jwtExtractor` that extracts token from the request header.
229240
func jwtFromHeader(header string, authScheme string) jwtExtractor {
230241
return func(c echo.Context) (string, error) {

middleware/jwt_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package middleware
22

33
import (
4+
"errors"
45
"net/http"
56
"net/http/httptest"
67
"net/url"
@@ -220,6 +221,35 @@ func TestJWT(t *testing.T) {
220221
expErrCode: http.StatusBadRequest,
221222
info: "Empty form field",
222223
},
224+
{
225+
hdrAuth: validAuth,
226+
config: JWTConfig{
227+
KeyFunc: func(*jwt.Token) (interface{}, error) {
228+
return validKey, nil
229+
},
230+
},
231+
info: "Valid JWT with a valid key using a user-defined KeyFunc",
232+
},
233+
{
234+
hdrAuth: validAuth,
235+
config: JWTConfig{
236+
KeyFunc: func(*jwt.Token) (interface{}, error) {
237+
return invalidKey, nil
238+
},
239+
},
240+
expErrCode: http.StatusUnauthorized,
241+
info: "Valid JWT with an invalid key using a user-defined KeyFunc",
242+
},
243+
{
244+
hdrAuth: validAuth,
245+
config: JWTConfig{
246+
KeyFunc: func(*jwt.Token) (interface{}, error) {
247+
return nil, errors.New("faulty KeyFunc")
248+
},
249+
},
250+
expErrCode: http.StatusUnauthorized,
251+
info: "Token verification does not pass using a user-defined KeyFunc",
252+
},
223253
} {
224254
if tc.reqURL == "" {
225255
tc.reqURL = "/"

0 commit comments

Comments
 (0)