diff --git a/jwt.go b/jwt.go index ec7d664..2c43ec8 100644 --- a/jwt.go +++ b/jwt.go @@ -118,6 +118,16 @@ var ErrJWTMissing = echo.NewHTTPError(http.StatusUnauthorized, "missing or malfo // ErrJWTInvalid denotes an error raised when JWT token value is invalid or expired var ErrJWTInvalid = echo.NewHTTPError(http.StatusUnauthorized, "invalid or expired jwt") +// TokenError is used to return error with error occurred JWT token when processing JWT token +type TokenError struct { + Token *jwt.Token + Err error +} + +func (e *TokenError) Error() string { return e.Err.Error() } + +func (e *TokenError) Unwrap() error { return e.Err } + // JWT returns a JSON Web Token (JWT) auth middleware. // // For valid token, it sets the user in context and calls next handler. @@ -233,9 +243,12 @@ func (config Config) ToMiddleware() (echo.MiddlewareFunc, error) { }, nil } +// defaultKeyFunc creates JWTGo implementation for KeyFunc. +// +// error returns TokenError. func (config Config) defaultKeyFunc(token *jwt.Token) (interface{}, error) { if token.Method.Alg() != config.SigningMethod { - return nil, fmt.Errorf("unexpected jwt signing method=%v", token.Header["alg"]) + return nil, &TokenError{Token: token, Err: fmt.Errorf("unexpected jwt signing method=%v", token.Header["alg"])} } if len(config.SigningKeys) == 0 { return config.SigningKey, nil @@ -246,17 +259,19 @@ func (config Config) defaultKeyFunc(token *jwt.Token) (interface{}, error) { return key, nil } } - return nil, fmt.Errorf("unexpected jwt key id=%v", token.Header["kid"]) + return nil, &TokenError{Token: token, Err: fmt.Errorf("unexpected jwt key id=%v", token.Header["kid"])} } -// defaultParseTokenFunc creates JWTGo implementation for ParseTokenFunc +// defaultParseTokenFunc creates JWTGo implementation for ParseTokenFunc. +// +// error returns TokenError. func (config Config) defaultParseTokenFunc(c echo.Context, auth string) (interface{}, error) { token, err := jwt.ParseWithClaims(auth, config.NewClaimsFunc(c), config.KeyFunc) if err != nil { - return nil, err + return nil, &TokenError{Token: token, Err: err} } if !token.Valid { - return nil, errors.New("invalid token") + return nil, &TokenError{Token: token, Err: errors.New("invalid token")} } return token, nil }