Skip to content

Commit 567e729

Browse files
committed
Add NimbusJwtEncoder Builders
Closes gh-16267 Signed-off-by: Suraj Bhadrike <[email protected]>
1 parent 226e81d commit 567e729

File tree

2 files changed

+492
-3
lines changed

2 files changed

+492
-3
lines changed

oauth2/oauth2-jose/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwtEncoder.java

Lines changed: 300 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,39 @@
1818

1919
import java.net.URI;
2020
import java.net.URL;
21+
import java.security.KeyPair;
22+
import java.security.interfaces.ECPrivateKey;
23+
import java.security.interfaces.ECPublicKey;
24+
import java.security.interfaces.RSAPrivateKey;
2125
import java.time.Instant;
2226
import java.util.ArrayList;
2327
import java.util.Date;
2428
import java.util.HashMap;
2529
import java.util.List;
2630
import java.util.Map;
2731
import java.util.Set;
32+
import java.util.UUID;
2833
import java.util.concurrent.ConcurrentHashMap;
2934

35+
import javax.crypto.SecretKey;
36+
3037
import com.nimbusds.jose.JOSEException;
3138
import com.nimbusds.jose.JOSEObjectType;
3239
import com.nimbusds.jose.JWSAlgorithm;
3340
import com.nimbusds.jose.JWSHeader;
3441
import com.nimbusds.jose.JWSSigner;
3542
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
43+
import com.nimbusds.jose.jwk.Curve;
44+
import com.nimbusds.jose.jwk.ECKey;
3645
import com.nimbusds.jose.jwk.JWK;
3746
import com.nimbusds.jose.jwk.JWKMatcher;
3847
import com.nimbusds.jose.jwk.JWKSelector;
48+
import com.nimbusds.jose.jwk.JWKSet;
3949
import com.nimbusds.jose.jwk.KeyType;
4050
import com.nimbusds.jose.jwk.KeyUse;
51+
import com.nimbusds.jose.jwk.OctetSequenceKey;
52+
import com.nimbusds.jose.jwk.RSAKey;
53+
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
4154
import com.nimbusds.jose.jwk.source.JWKSource;
4255
import com.nimbusds.jose.proc.SecurityContext;
4356
import com.nimbusds.jose.produce.JWSSignerFactory;
@@ -47,6 +60,7 @@
4760
import com.nimbusds.jwt.SignedJWT;
4861

4962
import org.springframework.core.convert.converter.Converter;
63+
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
5064
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
5165
import org.springframework.util.Assert;
5266
import org.springframework.util.CollectionUtils;
@@ -83,6 +97,8 @@ public final class NimbusJwtEncoder implements JwtEncoder {
8397

8498
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
8599

100+
private JwsHeader jwsHeader;
101+
86102
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
87103

88104
private final JWKSource<SecurityContext> jwkSource;
@@ -119,14 +135,16 @@ public void setJwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
119135
this.jwkSelector = jwkSelector;
120136
}
121137

138+
public void setJwsHeader(JwsHeader jwsHeader) {
139+
this.jwsHeader = jwsHeader;
140+
}
141+
122142
@Override
123143
public Jwt encode(JwtEncoderParameters parameters) throws JwtEncodingException {
124144
Assert.notNull(parameters, "parameters cannot be null");
125145

126146
JwsHeader headers = parameters.getJwsHeader();
127-
if (headers == null) {
128-
headers = DEFAULT_JWS_HEADER;
129-
}
147+
headers = (headers != null) ? headers : (this.jwsHeader != null) ? this.jwsHeader : DEFAULT_JWS_HEADER;
130148
JwtClaimsSet claims = parameters.getClaims();
131149

132150
JWK jwk = selectJwk(headers);
@@ -369,4 +387,283 @@ private static URI convertAsURI(String header, URL url) {
369387
}
370388
}
371389

390+
/**
391+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
392+
* {@link SecretKey}.
393+
* @param secretKey the {@link SecretKey} to use for signing JWTs
394+
* @return a {@link SecretKeyJwtEncoderBuilder} for further configuration
395+
* @since // TODO: Update version
396+
*/
397+
public static SecretKeyJwtEncoderBuilder withSecretKey(SecretKey secretKey) {
398+
Assert.notNull(secretKey, "secretKey cannot be null");
399+
return new SecretKeyJwtEncoderBuilder(secretKey);
400+
}
401+
402+
/**
403+
* Creates a builder for constructing a {@link NimbusJwtEncoder} using the provided
404+
* {@link KeyPair}. The key pair must contain either an {@link RSAKey} or an
405+
* {@link ECKey}.
406+
* @param keyPair the {@link KeyPair} to use for signing JWTs
407+
* @return a {@link KeyPairJwtEncoderBuilder} for further configuration
408+
* @since // TODO: Update version
409+
*/
410+
public static KeyPairJwtEncoderBuilder withKeyPair(KeyPair keyPair) {
411+
Assert.notNull(keyPair, "keyPair cannot be null");
412+
Assert.notNull(keyPair.getPrivate(), "keyPair must contain a private key");
413+
Assert.notNull(keyPair.getPublic(), "keyPair must contain a public key");
414+
Assert.isTrue(
415+
keyPair.getPrivate() instanceof java.security.interfaces.RSAKey
416+
|| keyPair.getPrivate() instanceof java.security.interfaces.ECKey,
417+
"keyPair must be an RSAKey or an ECKey");
418+
return new KeyPairJwtEncoderBuilder(keyPair);
419+
}
420+
421+
/**
422+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
423+
* {@link SecretKey}.
424+
*
425+
* @since // TODO: Update version
426+
*/
427+
public static final class SecretKeyJwtEncoderBuilder {
428+
429+
private final SecretKey secretKey;
430+
431+
private String keyId;
432+
433+
private JWSAlgorithm jwsAlgorithm = JWSAlgorithm.HS256;
434+
435+
private Converter<List<JWK>, JWK> jwkSelector;
436+
437+
private SecretKeyJwtEncoderBuilder(SecretKey secretKey) {
438+
this.secretKey = secretKey;
439+
}
440+
441+
/**
442+
* Sets the JWS algorithm to use for signing. Defaults to
443+
* {@link JWSAlgorithm#HS256}. Must be an HMAC-based algorithm (HS256, HS384, or
444+
* HS512).
445+
* @param macAlgorithm the {@link MacAlgorithm} to use
446+
* @return this builder instance for method chaining
447+
*/
448+
public SecretKeyJwtEncoderBuilder macAlgorithm(MacAlgorithm macAlgorithm) {
449+
Assert.notNull(macAlgorithm, "macAlgorithm cannot be null");
450+
this.jwsAlgorithm = JWSAlgorithm.parse(macAlgorithm.getName());
451+
return this;
452+
}
453+
454+
/**
455+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
456+
* header.
457+
* @param keyId the key identifier
458+
* @return this builder instance for method chaining
459+
*/
460+
public SecretKeyJwtEncoderBuilder keyId(String keyId) {
461+
this.keyId = keyId;
462+
return this;
463+
}
464+
465+
/**
466+
* Configures the {@link Converter} used to select the JWK when multiple keys
467+
* match the header criteria. This is generally not needed for single-key setups
468+
* but is provided for consistency.
469+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
470+
* @return this builder instance for method chaining
471+
* @since // TODO: Update version
472+
*/
473+
public SecretKeyJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
474+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
475+
this.jwkSelector = jwkSelector;
476+
return this;
477+
}
478+
479+
/**
480+
* Builds the {@link NimbusJwtEncoder} instance.
481+
* @return the configured {@link NimbusJwtEncoder}
482+
* @throws IllegalStateException if the configured JWS algorithm is not compatible
483+
* with a {@link SecretKey}.
484+
*/
485+
public NimbusJwtEncoder build() {
486+
Assert.state(JWSAlgorithm.Family.HMAC_SHA.contains(this.jwsAlgorithm),
487+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with a SecretKey. "
488+
+ "Please use one of the HS256, HS384, or HS512 algorithms.");
489+
this.jwsAlgorithm = (this.jwsAlgorithm != null) ? this.jwsAlgorithm : JWSAlgorithm.HS256;
490+
491+
OctetSequenceKey.Builder builder = new OctetSequenceKey.Builder(this.secretKey).keyUse(KeyUse.SIGNATURE)
492+
.algorithm(this.jwsAlgorithm);
493+
494+
if (StringUtils.hasText(this.keyId)) {
495+
builder.keyID(this.keyId);
496+
}
497+
498+
OctetSequenceKey jwk = builder.build();
499+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
500+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
501+
encoder.setJwsHeader(JwsHeader.with(MacAlgorithm.from(this.jwsAlgorithm.getName())).build());
502+
if (this.jwkSelector != null) {
503+
encoder.setJwkSelector(this.jwkSelector);
504+
}
505+
else {
506+
encoder.setJwkSelector((jwkSet) -> jwkSet.stream().findFirst().orElse(null));
507+
}
508+
return encoder;
509+
}
510+
511+
}
512+
513+
/**
514+
* A builder for creating {@link NimbusJwtEncoder} instances configured with a
515+
* {@link KeyPair}.
516+
*
517+
* @since // TODO: Update version
518+
*/
519+
public static final class KeyPairJwtEncoderBuilder {
520+
521+
private final KeyPair keyPair;
522+
523+
private String keyId;
524+
525+
private JWSAlgorithm jwsAlgorithm;
526+
527+
private Converter<List<JWK>, JWK> jwkSelector;
528+
529+
private JwsHeader jwsHeader;
530+
531+
private KeyPairJwtEncoderBuilder(KeyPair keyPair) {
532+
this.keyPair = keyPair;
533+
}
534+
535+
/**
536+
* Sets the JWS algorithm to use for signing. Must be compatible with the key type
537+
* (RSA or EC). If not set, a default algorithm will be chosen based on the key
538+
* type (e.g., RS256 for RSA, ES256 for EC).
539+
* @param signatureAlgorithm the {@link SignatureAlgorithm} to use
540+
* @return this builder instance for method chaining
541+
*/
542+
public KeyPairJwtEncoderBuilder signatureAlgorithm(SignatureAlgorithm signatureAlgorithm) {
543+
Assert.notNull(signatureAlgorithm, "signatureAlgorithm cannot be null");
544+
this.jwsAlgorithm = JWSAlgorithm.parse(signatureAlgorithm.getName());
545+
return this;
546+
}
547+
548+
/**
549+
* Sets the key ID ({@code kid}) to be included in the JWK and potentially the JWS
550+
* header.
551+
* @param keyId the key identifier
552+
* @return this builder instance for method chaining
553+
*/
554+
public KeyPairJwtEncoderBuilder keyId(String keyId) {
555+
this.keyId = keyId;
556+
return this;
557+
}
558+
559+
/**
560+
* Configures the {@link Converter} used to select the JWK when multiple keys
561+
* match the header criteria. This is generally not needed for single-key setups
562+
* but is provided for consistency.
563+
* @param jwkSelector the {@link Converter} to select a {@link JWK}
564+
* @return this builder instance for method chaining
565+
* @since // TODO: Update version
566+
*/
567+
public KeyPairJwtEncoderBuilder jwkSelector(Converter<List<JWK>, JWK> jwkSelector) {
568+
Assert.notNull(jwkSelector, "jwkSelector cannot be null");
569+
this.jwkSelector = jwkSelector;
570+
return this;
571+
}
572+
573+
/**
574+
* Builds the {@link NimbusJwtEncoder} instance.
575+
* @return the configured {@link NimbusJwtEncoder}
576+
* @throws IllegalStateException if the key type is unsupported or the configured
577+
* JWS algorithm is not compatible with the key type.
578+
* @throws JwtEncodingException if the key is invalid (e.g., EC key with unknown
579+
* curve)
580+
*/
581+
public NimbusJwtEncoder build() {
582+
this.keyId = (this.keyId != null) ? this.keyId : UUID.randomUUID().toString();
583+
JWK jwk = buildJwk();
584+
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
585+
NimbusJwtEncoder encoder = new NimbusJwtEncoder(jwkSource);
586+
if (this.jwsHeader != null) {
587+
encoder.setJwsHeader(this.jwsHeader);
588+
}
589+
if (this.jwkSelector != null) {
590+
encoder.setJwkSelector(this.jwkSelector);
591+
}
592+
else {
593+
encoder.setJwkSelector((jwkSet) -> jwkSet.stream().findFirst().orElse(null));
594+
}
595+
return encoder;
596+
}
597+
598+
private JWK buildJwk() {
599+
if (this.keyPair.getPrivate() instanceof RSAPrivateKey) {
600+
RSAKey rsaKey = buildRsaJwk();
601+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
602+
.keyId(rsaKey.getKeyID())
603+
.build();
604+
return rsaKey;
605+
}
606+
else if (this.keyPair.getPrivate() instanceof ECPrivateKey) {
607+
ECKey ecKey = buildEcJwk();
608+
this.jwsHeader = JwsHeader.with(SignatureAlgorithm.from(this.jwsAlgorithm.getName()))
609+
.keyId(ecKey.getKeyID())
610+
.build();
611+
return ecKey;
612+
}
613+
else {
614+
throw new IllegalStateException(
615+
"Unsupported key pair type: " + this.keyPair.getPrivate().getClass().getName());
616+
}
617+
}
618+
619+
private RSAKey buildRsaJwk() {
620+
if (this.jwsAlgorithm == null) {
621+
this.jwsAlgorithm = JWSAlgorithm.RS256;
622+
}
623+
Assert.state(JWSAlgorithm.Family.RSA.contains(this.jwsAlgorithm),
624+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an RSAKey. "
625+
+ "Please use one of the RS256, RS384, RS512, PS256, PS384, or PS512 algorithms.");
626+
627+
RSAKey.Builder builder = new RSAKey.Builder(
628+
(java.security.interfaces.RSAPublicKey) this.keyPair.getPublic())
629+
.privateKey(this.keyPair.getPrivate())
630+
.keyUse(KeyUse.SIGNATURE)
631+
.algorithm(this.jwsAlgorithm);
632+
633+
if (StringUtils.hasText(this.keyId)) {
634+
builder.keyID(this.keyId);
635+
}
636+
return builder.build();
637+
}
638+
639+
private com.nimbusds.jose.jwk.ECKey buildEcJwk() {
640+
if (this.jwsAlgorithm == null) {
641+
this.jwsAlgorithm = JWSAlgorithm.ES256;
642+
}
643+
Assert.state(JWSAlgorithm.Family.EC.contains(this.jwsAlgorithm),
644+
() -> "The algorithm '" + this.jwsAlgorithm + "' is not compatible with an ECKey. "
645+
+ "Please use one of the ES256, ES384, or ES512 algorithms.");
646+
647+
ECPublicKey publicKey = (ECPublicKey) this.keyPair.getPublic();
648+
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
649+
if (curve == null) {
650+
throw new JwtEncodingException("Unable to determine Curve for EC public key.");
651+
}
652+
653+
com.nimbusds.jose.jwk.ECKey.Builder builder = new com.nimbusds.jose.jwk.ECKey.Builder(curve, publicKey)
654+
.privateKey(this.keyPair.getPrivate())
655+
.keyUse(KeyUse.SIGNATURE)
656+
.keyID(this.keyId)
657+
.algorithm(this.jwsAlgorithm);
658+
659+
try {
660+
return builder.build();
661+
}
662+
catch (IllegalStateException ex) {
663+
throw new IllegalArgumentException("Failed to build ECKey: " + ex.getMessage(), ex);
664+
}
665+
}
666+
667+
}
668+
372669
}

0 commit comments

Comments
 (0)