1
0
mirror of synced 2026-05-22 21:33:16 +00:00

Enable null-safety in spring-security-oauth2-authorization-server

Closes gh-18937
This commit is contained in:
Joe Grandja
2026-03-19 13:56:17 -04:00
parent fe24bd3d0c
commit 1db0d4f83d
166 changed files with 1861 additions and 858 deletions
@@ -1,5 +1,6 @@
plugins { plugins {
id 'compile-warnings-error' id 'compile-warnings-error'
id 'security-nullability'
} }
apply plugin: 'io.spring.convention.spring-module' apply plugin: 'io.spring.convention.spring-module'
@@ -23,7 +23,8 @@ import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@@ -90,8 +91,7 @@ public final class InMemoryOAuth2AuthorizationConsentService implements OAuth2Au
} }
@Override @Override
@Nullable public @Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty"); Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty");
int id = getId(registeredClientId, principalName); int id = getId(registeredClientId, principalName);
@@ -23,7 +23,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode; import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -125,17 +126,15 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
} }
} }
@Nullable
@Override @Override
public OAuth2Authorization findById(String id) { public @Nullable OAuth2Authorization findById(String id) {
Assert.hasText(id, "id cannot be empty"); Assert.hasText(id, "id cannot be empty");
OAuth2Authorization authorization = this.authorizations.get(id); OAuth2Authorization authorization = this.authorizations.get(id);
return (authorization != null) ? authorization : this.initializedAuthorizations.get(id); return (authorization != null) ? authorization : this.initializedAuthorizations.get(id);
} }
@Nullable
@Override @Override
public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) { public @Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
Assert.hasText(token, "token cannot be empty"); Assert.hasText(token, "token cannot be empty");
for (OAuth2Authorization authorization : this.authorizations.values()) { for (OAuth2Authorization authorization : this.authorizations.values()) {
if (hasToken(authorization, token, tokenType)) { if (hasToken(authorization, token, tokenType)) {
@@ -25,6 +25,8 @@ import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.function.Function; import java.util.function.Function;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.RuntimeHintsRegistrar;
import org.springframework.context.annotation.ImportRuntimeHints; import org.springframework.context.annotation.ImportRuntimeHints;
@@ -35,7 +37,6 @@ import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter; import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.lang.Nullable;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -162,8 +163,7 @@ public class JdbcOAuth2AuthorizationConsentService implements OAuth2Authorizatio
} }
@Override @Override
@Nullable public @Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty"); Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty");
SqlParameterValue[] parameters = new SqlParameterValue[] { SqlParameterValue[] parameters = new SqlParameterValue[] {
@@ -281,7 +281,7 @@ public class JdbcOAuth2AuthorizationConsentService implements OAuth2Authorizatio
static class JdbcOAuth2AuthorizationConsentServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar { static class JdbcOAuth2AuthorizationConsentServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override @Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) { public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources() hints.resources()
.registerResource(new ClassPathResource( .registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql")); "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql"));
@@ -17,6 +17,7 @@
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.DatabaseMetaData; import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement; import java.sql.PreparedStatement;
import java.sql.ResultSet; import java.sql.ResultSet;
@@ -36,6 +37,7 @@ import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import tools.jackson.databind.JacksonModule; import tools.jackson.databind.JacksonModule;
import tools.jackson.databind.json.JsonMapper; import tools.jackson.databind.json.JsonMapper;
@@ -54,7 +56,6 @@ import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.jdbc.support.lob.DefaultLobHandler; import org.springframework.jdbc.support.lob.DefaultLobHandler;
import org.springframework.jdbc.support.lob.LobCreator; import org.springframework.jdbc.support.lob.LobCreator;
import org.springframework.jdbc.support.lob.LobHandler; import org.springframework.jdbc.support.lob.LobHandler;
import org.springframework.lang.Nullable;
import org.springframework.security.jackson.SecurityJacksonModules; import org.springframework.security.jackson.SecurityJacksonModules;
import org.springframework.security.jackson2.SecurityJackson2Modules; import org.springframework.security.jackson2.SecurityJackson2Modules;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -210,7 +211,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
private static Map<String, ColumnMetadata> columnMetadataMap; private static final Map<String, ColumnMetadata> columnMetadataMap = new HashMap<>();
private final JdbcOperations jdbcOperations; private final JdbcOperations jdbcOperations;
@@ -292,18 +293,16 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss); this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss);
} }
@Nullable
@Override @Override
public OAuth2Authorization findById(String id) { public @Nullable OAuth2Authorization findById(String id) {
Assert.hasText(id, "id cannot be empty"); Assert.hasText(id, "id cannot be empty");
List<SqlParameterValue> parameters = new ArrayList<>(); List<SqlParameterValue> parameters = new ArrayList<>();
parameters.add(new SqlParameterValue(Types.VARCHAR, id)); parameters.add(new SqlParameterValue(Types.VARCHAR, id));
return findBy(PK_FILTER, parameters); return findBy(PK_FILTER, parameters);
} }
@Nullable
@Override @Override
public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) { public @Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
Assert.hasText(token, "token cannot be empty"); Assert.hasText(token, "token cannot be empty");
List<SqlParameterValue> parameters = new ArrayList<>(); List<SqlParameterValue> parameters = new ArrayList<>();
if (tokenType == null) { if (tokenType == null) {
@@ -347,7 +346,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return null; return null;
} }
private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) { private @Nullable OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
try (LobCreator lobCreator = getLobHandler().getLobCreator()) { try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
parameters.toArray()); parameters.toArray());
@@ -399,7 +398,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
} }
private static void initColumnMetadata(JdbcOperations jdbcOperations) { private static void initColumnMetadata(JdbcOperations jdbcOperations) {
columnMetadataMap = new HashMap<>(); columnMetadataMap.clear();
ColumnMetadata columnMetadata; ColumnMetadata columnMetadata;
columnMetadata = getColumnMetadata(jdbcOperations, "attributes", Types.BLOB); columnMetadata = getColumnMetadata(jdbcOperations, "attributes", Types.BLOB);
@@ -432,32 +431,37 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName, private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName,
int defaultDataType) { int defaultDataType) {
Integer dataType = jdbcOperations.execute((ConnectionCallback<Integer>) (conn) -> { @Nullable Integer dataType = jdbcOperations.execute(new ConnectionCallback<@Nullable Integer>() {
DatabaseMetaData databaseMetaData = conn.getMetaData(); @Override
ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName); public @Nullable Integer doInConnection(Connection conn) throws SQLException {
if (rs.next()) { DatabaseMetaData databaseMetaData = conn.getMetaData();
return rs.getInt("DATA_TYPE"); ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
// NOTE: (Applies to HSQL)
// When a database object is created with one of the CREATE statements or
// renamed with the ALTER statement,
// if the name is enclosed in double quotes, the exact name is used as the
// case-normal form.
// But if it is not enclosed in double quotes,
// the name is converted to uppercase and this uppercase version is stored
// in
// the database as the case-normal form.
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(Locale.ENGLISH),
columnName.toUpperCase(Locale.ENGLISH));
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return null;
} }
// NOTE: (Applies to HSQL)
// When a database object is created with one of the CREATE statements or
// renamed with the ALTER statement,
// if the name is enclosed in double quotes, the exact name is used as the
// case-normal form.
// But if it is not enclosed in double quotes,
// the name is converted to uppercase and this uppercase version is stored in
// the database as the case-normal form.
rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(Locale.ENGLISH),
columnName.toUpperCase(Locale.ENGLISH));
if (rs.next()) {
return rs.getInt("DATA_TYPE");
}
return null;
}); });
return new ColumnMetadata(columnName, (dataType != null) ? dataType : defaultDataType); return new ColumnMetadata(columnName, (dataType != null) ? dataType : defaultDataType);
} }
private static SqlParameterValue mapToSqlParameter(String columnName, String value) { private static SqlParameterValue mapToSqlParameter(String columnName, @Nullable String value) {
ColumnMetadata columnMetadata = columnMetadataMap.get(columnName); ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
Assert.notNull(columnMetadata, "Column metadata not found for column '" + columnName + "'");
return (Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value)) return (Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value))
? new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8)) ? new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8))
: new SqlParameterValue(columnMetadata.getDataType(), value); : new SqlParameterValue(columnMetadata.getDataType(), value);
@@ -610,6 +614,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
.equalsIgnoreCase(rs.getString("access_token_type"))) { .equalsIgnoreCase(rs.getString("access_token_type"))) {
tokenType = OAuth2AccessToken.TokenType.DPOP; tokenType = OAuth2AccessToken.TokenType.DPOP;
} }
Assert.notNull(tokenType, "access_token_type must be BEARER or DPOP");
Set<String> scopes = Collections.emptySet(); Set<String> scopes = Collections.emptySet();
String accessTokenScopes = rs.getString("access_token_scopes"); String accessTokenScopes = rs.getString("access_token_scopes");
@@ -627,8 +632,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant(); tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, OIDC_ID_TOKEN_METADATA)); Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, OIDC_ID_TOKEN_METADATA));
OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, @SuppressWarnings("unchecked")
(Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME)); Map<String, Object> idTokenClaims = (Map<String, Object>) oidcTokenMetadata
.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME);
if (idTokenClaims == null) {
idTokenClaims = Collections.emptyMap();
}
OidcIdToken oidcToken = new OidcIdToken(oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, idTokenClaims);
builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata)); builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
} }
@@ -670,9 +680,10 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return builder.build(); return builder.build();
} }
private String getLobValue(ResultSet rs, String columnName) throws SQLException { private @Nullable String getLobValue(ResultSet rs, String columnName) throws SQLException {
String columnValue = null; String columnValue = null;
ColumnMetadata columnMetadata = columnMetadataMap.get(columnName); ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
Assert.notNull(columnMetadata, "Column metadata not found for column '" + columnName + "'");
if (Types.BLOB == columnMetadata.getDataType()) { if (Types.BLOB == columnMetadata.getDataType()) {
byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName); byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName);
if (columnValueBytes != null) { if (columnValueBytes != null) {
@@ -701,7 +712,10 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
return this.lobHandler; return this.lobHandler;
} }
private Map<String, Object> parseMap(String data) { private Map<String, Object> parseMap(@Nullable String data) {
if (!StringUtils.hasText(data)) {
return Collections.emptyMap();
}
try { try {
return readValue(data); return readValue(data);
} }
@@ -849,7 +863,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
} }
private <T extends OAuth2Token> List<SqlParameterValue> toSqlParameterList(String tokenColumnName, private <T extends OAuth2Token> List<SqlParameterValue> toSqlParameterList(String tokenColumnName,
String tokenMetadataColumnName, OAuth2Authorization.Token<T> token) { String tokenMetadataColumnName, OAuth2Authorization.@Nullable Token<T> token) {
List<SqlParameterValue> parameters = new ArrayList<>(); List<SqlParameterValue> parameters = new ArrayList<>();
String tokenValue = null; String tokenValue = null;
@@ -933,7 +947,8 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
} }
@Override @Override
protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException { protected void doSetValue(PreparedStatement ps, int parameterPosition, @Nullable Object argValue)
throws SQLException {
if (argValue instanceof SqlParameterValue paramValue) { if (argValue instanceof SqlParameterValue paramValue) {
if (paramValue.getSqlType() == Types.BLOB) { if (paramValue.getSqlType() == Types.BLOB) {
if (paramValue.getValue() != null) { if (paramValue.getValue() != null) {
@@ -983,7 +998,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
static class JdbcOAuth2AuthorizationServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar { static class JdbcOAuth2AuthorizationServiceRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override @Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) { public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources() hints.resources()
.registerResource(new ClassPathResource( .registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql")); "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql"));
@@ -28,7 +28,8 @@ import java.util.Set;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken; import org.springframework.security.oauth2.core.OAuth2RefreshToken;
@@ -58,19 +59,19 @@ public class OAuth2Authorization implements Serializable {
@Serial @Serial
private static final long serialVersionUID = 880363144799377926L; private static final long serialVersionUID = 880363144799377926L;
private String id; private @Nullable String id;
private String registeredClientId; private @Nullable String registeredClientId;
private String principalName; private @Nullable String principalName;
private AuthorizationGrantType authorizationGrantType; private @Nullable AuthorizationGrantType authorizationGrantType;
private Set<String> authorizedScopes; private @Nullable Set<String> authorizedScopes;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens; private @Nullable Map<Class<? extends OAuth2Token>, Token<?>> tokens;
private Map<String, Object> attributes; private @Nullable Map<String, Object> attributes;
protected OAuth2Authorization() { protected OAuth2Authorization() {
} }
@@ -80,6 +81,7 @@ public class OAuth2Authorization implements Serializable {
* @return the identifier for the authorization * @return the identifier for the authorization
*/ */
public String getId() { public String getId() {
Assert.notNull(this.id, "id cannot be null");
return this.id; return this.id;
} }
@@ -88,6 +90,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link RegisteredClient#getId()} * @return the {@link RegisteredClient#getId()}
*/ */
public String getRegisteredClientId() { public String getRegisteredClientId() {
Assert.notNull(this.registeredClientId, "registeredClientId cannot be null");
return this.registeredClientId; return this.registeredClientId;
} }
@@ -96,6 +99,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@code Principal} name of the resource owner (or client) * @return the {@code Principal} name of the resource owner (or client)
*/ */
public String getPrincipalName() { public String getPrincipalName() {
Assert.notNull(this.principalName, "principalName cannot be null");
return this.principalName; return this.principalName;
} }
@@ -105,6 +109,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link AuthorizationGrantType} used for the authorization * @return the {@link AuthorizationGrantType} used for the authorization
*/ */
public AuthorizationGrantType getAuthorizationGrantType() { public AuthorizationGrantType getAuthorizationGrantType() {
Assert.notNull(this.authorizationGrantType, "authorizationGrantType cannot be null");
return this.authorizationGrantType; return this.authorizationGrantType;
} }
@@ -113,14 +118,16 @@ public class OAuth2Authorization implements Serializable {
* @return the {@code Set} of authorized scope(s) * @return the {@code Set} of authorized scope(s)
*/ */
public Set<String> getAuthorizedScopes() { public Set<String> getAuthorizedScopes() {
Assert.notNull(this.authorizedScopes, "authorizedScopes cannot be null");
return this.authorizedScopes; return this.authorizedScopes;
} }
/** /**
* Returns the {@link Token} of type {@link OAuth2AccessToken}. * Returns the {@link Token} of type {@link OAuth2AccessToken}.
* @return the {@link Token} of type {@link OAuth2AccessToken} * @return the {@link Token} of type {@link OAuth2AccessToken}, or {@code null} if not
* available
*/ */
public Token<OAuth2AccessToken> getAccessToken() { public @Nullable Token<OAuth2AccessToken> getAccessToken() {
return getToken(OAuth2AccessToken.class); return getToken(OAuth2AccessToken.class);
} }
@@ -129,8 +136,7 @@ public class OAuth2Authorization implements Serializable {
* @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if * @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if
* not available * not available
*/ */
@Nullable public @Nullable Token<OAuth2RefreshToken> getRefreshToken() {
public Token<OAuth2RefreshToken> getRefreshToken() {
return getToken(OAuth2RefreshToken.class); return getToken(OAuth2RefreshToken.class);
} }
@@ -140,10 +146,10 @@ public class OAuth2Authorization implements Serializable {
* @param <T> the type of the token * @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available * @return the {@link Token}, or {@code null} if not available
*/ */
@Nullable
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(Class<T> tokenType) { public <T extends OAuth2Token> @Nullable Token<T> getToken(Class<T> tokenType) {
Assert.notNull(tokenType, "tokenType cannot be null"); Assert.notNull(tokenType, "tokenType cannot be null");
Assert.notNull(this.tokens, "tokens cannot be null");
Token<?> token = this.tokens.get(tokenType); Token<?> token = this.tokens.get(tokenType);
return (token != null) ? (Token<T>) token : null; return (token != null) ? (Token<T>) token : null;
} }
@@ -154,10 +160,10 @@ public class OAuth2Authorization implements Serializable {
* @param <T> the type of the token * @param <T> the type of the token
* @return the {@link Token}, or {@code null} if not available * @return the {@link Token}, or {@code null} if not available
*/ */
@Nullable
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T extends OAuth2Token> Token<T> getToken(String tokenValue) { public <T extends OAuth2Token> @Nullable Token<T> getToken(String tokenValue) {
Assert.hasText(tokenValue, "tokenValue cannot be empty"); Assert.hasText(tokenValue, "tokenValue cannot be empty");
Assert.notNull(this.tokens, "tokens cannot be null");
for (Token<?> token : this.tokens.values()) { for (Token<?> token : this.tokens.values()) {
if (token.getToken().getTokenValue().equals(tokenValue)) { if (token.getToken().getTokenValue().equals(tokenValue)) {
return (Token<T>) token; return (Token<T>) token;
@@ -171,6 +177,7 @@ public class OAuth2Authorization implements Serializable {
* @return a {@code Map} of the attribute(s) * @return a {@code Map} of the attribute(s)
*/ */
public Map<String, Object> getAttributes() { public Map<String, Object> getAttributes() {
Assert.notNull(this.attributes, "attributes cannot be null");
return this.attributes; return this.attributes;
} }
@@ -181,10 +188,10 @@ public class OAuth2Authorization implements Serializable {
* @return the value of an attribute associated to the authorization, or {@code null} * @return the value of an attribute associated to the authorization, or {@code null}
* if not available * if not available
*/ */
@Nullable
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> T getAttribute(String name) { public <T> @Nullable T getAttribute(String name) {
Assert.hasText(name, "name cannot be empty"); Assert.hasText(name, "name cannot be empty");
Assert.notNull(this.attributes, "attributes cannot be null");
return (T) this.attributes.get(name); return (T) this.attributes.get(name);
} }
@@ -230,6 +237,7 @@ public class OAuth2Authorization implements Serializable {
*/ */
public static Builder from(OAuth2Authorization authorization) { public static Builder from(OAuth2Authorization authorization) {
Assert.notNull(authorization, "authorization cannot be null"); Assert.notNull(authorization, "authorization cannot be null");
Assert.notNull(authorization.tokens, "tokens cannot be null");
return new Builder(authorization.getRegisteredClientId()).id(authorization.getId()) return new Builder(authorization.getRegisteredClientId()).id(authorization.getId())
.principalName(authorization.getPrincipalName()) .principalName(authorization.getPrincipalName())
.authorizationGrantType(authorization.getAuthorizationGrantType()) .authorizationGrantType(authorization.getAuthorizationGrantType())
@@ -324,8 +332,7 @@ public class OAuth2Authorization implements Serializable {
* Returns the claims associated to the token. * Returns the claims associated to the token.
* @return a {@code Map} of the claims, or {@code null} if not available * @return a {@code Map} of the claims, or {@code null} if not available
*/ */
@Nullable public @Nullable Map<String, Object> getClaims() {
public Map<String, Object> getClaims() {
return getMetadata(CLAIMS_METADATA_NAME); return getMetadata(CLAIMS_METADATA_NAME);
} }
@@ -335,9 +342,8 @@ public class OAuth2Authorization implements Serializable {
* @param <V> the value type of the metadata * @param <V> the value type of the metadata
* @return the value of the metadata, or {@code null} if not available * @return the value of the metadata, or {@code null} if not available
*/ */
@Nullable
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <V> V getMetadata(String name) { public <V> @Nullable V getMetadata(String name) {
Assert.hasText(name, "name cannot be empty"); Assert.hasText(name, "name cannot be empty");
return (V) this.metadata.get(name); return (V) this.metadata.get(name);
} }
@@ -380,15 +386,15 @@ public class OAuth2Authorization implements Serializable {
*/ */
public static class Builder { public static class Builder {
private String id; private @Nullable String id;
private final String registeredClientId; private final String registeredClientId;
private String principalName; private @Nullable String principalName;
private AuthorizationGrantType authorizationGrantType; private @Nullable AuthorizationGrantType authorizationGrantType;
private Set<String> authorizedScopes; private @Nullable Set<String> authorizedScopes;
private Map<Class<? extends OAuth2Token>, Token<?>> tokens = new HashMap<>(); private Map<Class<? extends OAuth2Token>, Token<?>> tokens = new HashMap<>();
@@ -503,8 +509,10 @@ public class OAuth2Authorization implements Serializable {
token(token, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); token(token, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) { if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
Token<?> accessToken = this.tokens.get(OAuth2AccessToken.class); Token<?> accessToken = this.tokens.get(OAuth2AccessToken.class);
token(accessToken.getToken(), if (accessToken != null) {
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true)); token(accessToken.getToken(),
(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
}
Token<?> authorizationCode = this.tokens.get(OAuth2AuthorizationCode.class); Token<?> authorizationCode = this.tokens.get(OAuth2AuthorizationCode.class);
if (authorizationCode != null && !authorizationCode.isInvalidated()) { if (authorizationCode != null && !authorizationCode.isInvalidated()) {
@@ -24,7 +24,6 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.NonNull;
import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority; import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -99,8 +98,9 @@ public final class OAuth2AuthorizationConsent implements Serializable {
public Set<String> getScopes() { public Set<String> getScopes() {
Set<String> authorities = new HashSet<>(); Set<String> authorities = new HashSet<>();
for (GrantedAuthority authority : getAuthorities()) { for (GrantedAuthority authority : getAuthorities()) {
if (authority.getAuthority().startsWith(AUTHORITIES_SCOPE_PREFIX)) { String authorityValue = authority.getAuthority();
authorities.add(authority.getAuthority().substring(AUTHORITIES_SCOPE_PREFIX.length())); if (authorityValue != null && authorityValue.startsWith(AUTHORITIES_SCOPE_PREFIX)) {
authorities.add(authorityValue.substring(AUTHORITIES_SCOPE_PREFIX.length()));
} }
} }
return authorities; return authorities;
@@ -146,7 +146,7 @@ public final class OAuth2AuthorizationConsent implements Serializable {
* @param principalName the {@code Principal} name * @param principalName the {@code Principal} name
* @return the {@link Builder} * @return the {@link Builder}
*/ */
public static Builder withId(@NonNull String registeredClientId, @NonNull String principalName) { public static Builder withId(String registeredClientId, String principalName) {
Assert.hasText(registeredClientId, "registeredClientId cannot be empty"); Assert.hasText(registeredClientId, "registeredClientId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty");
return new Builder(registeredClientId, principalName); return new Builder(registeredClientId, principalName);
@@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization;
import java.security.Principal; import java.security.Principal;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
/** /**
@@ -50,7 +51,6 @@ public interface OAuth2AuthorizationConsentService {
* @param principalName the name of the {@link Principal} * @param principalName the name of the {@link Principal}
* @return the {@link OAuth2AuthorizationConsent} if found, otherwise {@code null} * @return the {@link OAuth2AuthorizationConsent} if found, otherwise {@code null}
*/ */
@Nullable @Nullable OAuth2AuthorizationConsent findById(String registeredClientId, String principalName);
OAuth2AuthorizationConsent findById(String registeredClientId, String principalName);
} }
@@ -19,8 +19,11 @@ package org.springframework.security.oauth2.server.authorization;
import java.net.URL; import java.net.URL;
import java.util.List; import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms;
import org.springframework.util.Assert;
/** /**
* A {@link ClaimAccessor} for the "claims" an Authorization Server describes about its * A {@link ClaimAccessor} for the "claims" an Authorization Server describes about its
@@ -57,7 +60,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} the Authorization Server asserts as its Issuer Identifier * @return the {@code URL} the Authorization Server asserts as its Issuer Identifier
*/ */
default URL getIssuer() { default URL getIssuer() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.ISSUER); URL issuer = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.ISSUER);
Assert.notNull(issuer, "issuer cannot be null");
return issuer;
} }
/** /**
@@ -66,7 +71,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} of the OAuth 2.0 Authorization Endpoint * @return the {@code URL} of the OAuth 2.0 Authorization Endpoint
*/ */
default URL getAuthorizationEndpoint() { default URL getAuthorizationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.AUTHORIZATION_ENDPOINT); URL authorizationEndpoint = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.AUTHORIZATION_ENDPOINT);
Assert.notNull(authorizationEndpoint, "authorizationEndpoint cannot be null");
return authorizationEndpoint;
} }
/** /**
@@ -74,7 +81,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (pushed_authorization_request_endpoint)}. * {@code (pushed_authorization_request_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Pushed Authorization Request Endpoint * @return the {@code URL} of the OAuth 2.0 Pushed Authorization Request Endpoint
*/ */
default URL getPushedAuthorizationRequestEndpoint() { default @Nullable URL getPushedAuthorizationRequestEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.PUSHED_AUTHORIZATION_REQUEST_ENDPOINT); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.PUSHED_AUTHORIZATION_REQUEST_ENDPOINT);
} }
@@ -83,7 +90,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (device_authorization_endpoint)}. * {@code (device_authorization_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Device Authorization Endpoint * @return the {@code URL} of the OAuth 2.0 Device Authorization Endpoint
*/ */
default URL getDeviceAuthorizationEndpoint() { default @Nullable URL getDeviceAuthorizationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.DEVICE_AUTHORIZATION_ENDPOINT); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.DEVICE_AUTHORIZATION_ENDPOINT);
} }
@@ -92,7 +99,9 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@code URL} of the OAuth 2.0 Token Endpoint * @return the {@code URL} of the OAuth 2.0 Token Endpoint
*/ */
default URL getTokenEndpoint() { default URL getTokenEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT); URL tokenEndpoint = getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT);
Assert.notNull(tokenEndpoint, "tokenEndpoint cannot be null");
return tokenEndpoint;
} }
/** /**
@@ -100,7 +109,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (token_endpoint_auth_methods_supported)}. * {@code (token_endpoint_auth_methods_supported)}.
* @return the client authentication methods supported by the OAuth 2.0 Token Endpoint * @return the client authentication methods supported by the OAuth 2.0 Token Endpoint
*/ */
default List<String> getTokenEndpointAuthenticationMethods() { default @Nullable List<String> getTokenEndpointAuthenticationMethods() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED); return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED);
} }
@@ -108,7 +117,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* Returns the {@code URL} of the JSON Web Key Set {@code (jwks_uri)}. * Returns the {@code URL} of the JSON Web Key Set {@code (jwks_uri)}.
* @return the {@code URL} of the JSON Web Key Set * @return the {@code URL} of the JSON Web Key Set
*/ */
default URL getJwkSetUrl() { default @Nullable URL getJwkSetUrl() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.JWKS_URI); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.JWKS_URI);
} }
@@ -116,7 +125,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* Returns the OAuth 2.0 {@code scope} values supported {@code (scopes_supported)}. * Returns the OAuth 2.0 {@code scope} values supported {@code (scopes_supported)}.
* @return the OAuth 2.0 {@code scope} values supported * @return the OAuth 2.0 {@code scope} values supported
*/ */
default List<String> getScopes() { default @Nullable List<String> getScopes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.SCOPES_SUPPORTED); return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.SCOPES_SUPPORTED);
} }
@@ -126,7 +135,10 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the OAuth 2.0 {@code response_type} values supported * @return the OAuth 2.0 {@code response_type} values supported
*/ */
default List<String> getResponseTypes() { default List<String> getResponseTypes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.RESPONSE_TYPES_SUPPORTED); List<String> responseTypes = getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.RESPONSE_TYPES_SUPPORTED);
Assert.notNull(responseTypes, "responseTypes cannot be null");
return responseTypes;
} }
/** /**
@@ -134,7 +146,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (grant_types_supported)}. * {@code (grant_types_supported)}.
* @return the OAuth 2.0 {@code grant_type} values supported * @return the OAuth 2.0 {@code grant_type} values supported
*/ */
default List<String> getGrantTypes() { default @Nullable List<String> getGrantTypes() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.GRANT_TYPES_SUPPORTED); return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.GRANT_TYPES_SUPPORTED);
} }
@@ -143,7 +155,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (revocation_endpoint)}. * {@code (revocation_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Token Revocation Endpoint * @return the {@code URL} of the OAuth 2.0 Token Revocation Endpoint
*/ */
default URL getTokenRevocationEndpoint() { default @Nullable URL getTokenRevocationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT);
} }
@@ -153,7 +165,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the client authentication methods supported by the OAuth 2.0 Token * @return the client authentication methods supported by the OAuth 2.0 Token
* Revocation Endpoint * Revocation Endpoint
*/ */
default List<String> getTokenRevocationEndpointAuthenticationMethods() { default @Nullable List<String> getTokenRevocationEndpointAuthenticationMethods() {
return getClaimAsStringList( return getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT_AUTH_METHODS_SUPPORTED); OAuth2AuthorizationServerMetadataClaimNames.REVOCATION_ENDPOINT_AUTH_METHODS_SUPPORTED);
} }
@@ -163,7 +175,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (introspection_endpoint)}. * {@code (introspection_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Token Introspection Endpoint * @return the {@code URL} of the OAuth 2.0 Token Introspection Endpoint
*/ */
default URL getTokenIntrospectionEndpoint() { default @Nullable URL getTokenIntrospectionEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT);
} }
@@ -173,7 +185,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the client authentication methods supported by the OAuth 2.0 Token * @return the client authentication methods supported by the OAuth 2.0 Token
* Introspection Endpoint * Introspection Endpoint
*/ */
default List<String> getTokenIntrospectionEndpointAuthenticationMethods() { default @Nullable List<String> getTokenIntrospectionEndpointAuthenticationMethods() {
return getClaimAsStringList( return getClaimAsStringList(
OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT_AUTH_METHODS_SUPPORTED); OAuth2AuthorizationServerMetadataClaimNames.INTROSPECTION_ENDPOINT_AUTH_METHODS_SUPPORTED);
} }
@@ -183,7 +195,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* {@code (registration_endpoint)}. * {@code (registration_endpoint)}.
* @return the {@code URL} of the OAuth 2.0 Dynamic Client Registration Endpoint * @return the {@code URL} of the OAuth 2.0 Dynamic Client Registration Endpoint
*/ */
default URL getClientRegistrationEndpoint() { default @Nullable URL getClientRegistrationEndpoint() {
return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REGISTRATION_ENDPOINT); return getClaimAsURL(OAuth2AuthorizationServerMetadataClaimNames.REGISTRATION_ENDPOINT);
} }
@@ -192,7 +204,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* supported {@code (code_challenge_methods_supported)}. * supported {@code (code_challenge_methods_supported)}.
* @return the {@code code_challenge_method} values supported * @return the {@code code_challenge_method} values supported
*/ */
default List<String> getCodeChallengeMethods() { default @Nullable List<String> getCodeChallengeMethods() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.CODE_CHALLENGE_METHODS_SUPPORTED); return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.CODE_CHALLENGE_METHODS_SUPPORTED);
} }
@@ -213,7 +225,7 @@ public interface OAuth2AuthorizationServerMetadataClaimAccessor extends ClaimAcc
* @return the {@link JwsAlgorithms JSON Web Signature (JWS) algorithms} supported for * @return the {@link JwsAlgorithms JSON Web Signature (JWS) algorithms} supported for
* DPoP Proof JWTs * DPoP Proof JWTs
*/ */
default List<String> getDPoPSigningAlgorithms() { default @Nullable List<String> getDPoPSigningAlgorithms() {
return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.DPOP_SIGNING_ALG_VALUES_SUPPORTED); return getClaimAsStringList(OAuth2AuthorizationServerMetadataClaimNames.DPOP_SIGNING_ALG_VALUES_SUPPORTED);
} }
@@ -16,7 +16,7 @@
package org.springframework.security.oauth2.server.authorization; package org.springframework.security.oauth2.server.authorization;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
/** /**
* Implementations of this interface are responsible for the management of * Implementations of this interface are responsible for the management of
@@ -47,8 +47,7 @@ public interface OAuth2AuthorizationService {
* @param id the authorization identifier * @param id the authorization identifier
* @return the {@link OAuth2Authorization} if found, otherwise {@code null} * @return the {@link OAuth2Authorization} if found, otherwise {@code null}
*/ */
@Nullable @Nullable OAuth2Authorization findById(String id);
OAuth2Authorization findById(String id);
/** /**
* Returns the {@link OAuth2Authorization} containing the provided {@code token}, or * Returns the {@link OAuth2Authorization} containing the provided {@code token}, or
@@ -57,7 +56,6 @@ public interface OAuth2AuthorizationService {
* @param tokenType the {@link OAuth2TokenType token type} * @param tokenType the {@link OAuth2TokenType token type}
* @return the {@link OAuth2Authorization} if found, otherwise {@code null} * @return the {@link OAuth2Authorization} if found, otherwise {@code null}
*/ */
@Nullable @Nullable OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType);
OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType);
} }
@@ -20,7 +20,10 @@ import java.net.URL;
import java.time.Instant; import java.time.Instant;
import java.util.List; import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.util.Assert;
/** /**
* A {@link ClaimAccessor} for the claims that are contained in the OAuth 2.0 Client * A {@link ClaimAccessor} for the claims that are contained in the OAuth 2.0 Client
@@ -41,7 +44,9 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the Client Identifier * @return the Client Identifier
*/ */
default String getClientId() { default String getClientId() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_ID); String clientId = getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_ID);
Assert.notNull(clientId, "clientId cannot be null");
return clientId;
} }
/** /**
@@ -49,7 +54,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_id_issued_at)}. * {@code (client_id_issued_at)}.
* @return the time at which the Client Identifier was issued * @return the time at which the Client Identifier was issued
*/ */
default Instant getClientIdIssuedAt() { default @Nullable Instant getClientIdIssuedAt() {
return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT); return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_ID_ISSUED_AT);
} }
@@ -57,7 +62,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* Returns the Client Secret {@code (client_secret)}. * Returns the Client Secret {@code (client_secret)}.
* @return the Client Secret * @return the Client Secret
*/ */
default String getClientSecret() { default @Nullable String getClientSecret() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_SECRET); return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_SECRET);
} }
@@ -66,7 +71,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_secret_expires_at)}. * {@code (client_secret_expires_at)}.
* @return the time at which the {@code client_secret} will expire * @return the time at which the {@code client_secret} will expire
*/ */
default Instant getClientSecretExpiresAt() { default @Nullable Instant getClientSecretExpiresAt() {
return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT); return getClaimAsInstant(OAuth2ClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT);
} }
@@ -75,7 +80,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (client_name)}. * {@code (client_name)}.
* @return the name of the Client to be presented to the End-User * @return the name of the Client to be presented to the End-User
*/ */
default String getClientName() { default @Nullable String getClientName() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_NAME); return getClaimAsString(OAuth2ClientMetadataClaimNames.CLIENT_NAME);
} }
@@ -84,7 +89,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (redirect_uris)}. * {@code (redirect_uris)}.
* @return the redirection {@code URI} values used by the Client * @return the redirection {@code URI} values used by the Client
*/ */
default List<String> getRedirectUris() { default @Nullable List<String> getRedirectUris() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.REDIRECT_URIS); return getClaimAsStringList(OAuth2ClientMetadataClaimNames.REDIRECT_URIS);
} }
@@ -93,7 +98,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* {@code (token_endpoint_auth_method)}. * {@code (token_endpoint_auth_method)}.
* @return the authentication method used by the Client for the Token Endpoint * @return the authentication method used by the Client for the Token Endpoint
*/ */
default String getTokenEndpointAuthenticationMethod() { default @Nullable String getTokenEndpointAuthenticationMethod() {
return getClaimAsString(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); return getClaimAsString(OAuth2ClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
} }
@@ -103,7 +108,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code grant_type} values that the Client will restrict * @return the OAuth 2.0 {@code grant_type} values that the Client will restrict
* itself to using * itself to using
*/ */
default List<String> getGrantTypes() { default @Nullable List<String> getGrantTypes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.GRANT_TYPES); return getClaimAsStringList(OAuth2ClientMetadataClaimNames.GRANT_TYPES);
} }
@@ -113,7 +118,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code response_type} values that the Client will restrict * @return the OAuth 2.0 {@code response_type} values that the Client will restrict
* itself to using * itself to using
*/ */
default List<String> getResponseTypes() { default @Nullable List<String> getResponseTypes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES); return getClaimAsStringList(OAuth2ClientMetadataClaimNames.RESPONSE_TYPES);
} }
@@ -123,7 +128,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to * @return the OAuth 2.0 {@code scope} values that the Client will restrict itself to
* using * using
*/ */
default List<String> getScopes() { default @Nullable List<String> getScopes() {
return getClaimAsStringList(OAuth2ClientMetadataClaimNames.SCOPE); return getClaimAsStringList(OAuth2ClientMetadataClaimNames.SCOPE);
} }
@@ -131,7 +136,7 @@ public interface OAuth2ClientMetadataClaimAccessor extends ClaimAccessor {
* Returns the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}. * Returns the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}.
* @return the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)} * @return the {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
*/ */
default URL getJwkSetUrl() { default @Nullable URL getJwkSetUrl() {
return getClaimAsURL(OAuth2ClientMetadataClaimNames.JWKS_URI); return getClaimAsURL(OAuth2ClientMetadataClaimNames.JWKS_URI);
} }
@@ -20,6 +20,8 @@ import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.generate.GenerationContext; import org.springframework.aot.generate.GenerationContext;
import org.springframework.aot.hint.BindingReflectionHintsRegistrar; import org.springframework.aot.hint.BindingReflectionHintsRegistrar;
import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.MemberCategory;
@@ -84,7 +86,7 @@ class OAuth2AuthorizationServerBeanRegistrationAotProcessor implements BeanRegis
private boolean jacksonContributed; private boolean jacksonContributed;
@Override @Override
public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { public @Nullable BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) {
boolean isJdbcBasedOAuth2AuthorizationService = JdbcOAuth2AuthorizationService.class boolean isJdbcBasedOAuth2AuthorizationService = JdbcOAuth2AuthorizationService.class
.isAssignableFrom(registeredBean.getBeanClass()); .isAssignableFrom(registeredBean.getBeanClass());
@@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.aot.hint; package org.springframework.security.oauth2.server.authorization.aot.hint;
import org.jspecify.annotations.Nullable;
import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.RuntimeHintsRegistrar;
@@ -35,7 +37,7 @@ import org.springframework.security.oauth2.server.authorization.web.OAuth2Author
class OAuth2AuthorizationServerRuntimeHints implements RuntimeHintsRegistrar { class OAuth2AuthorizationServerRuntimeHints implements RuntimeHintsRegistrar {
@Override @Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) { public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.reflection() hints.reflection()
.registerType(OAuth2AuthorizationCodeRequestAuthenticationProvider.class, .registerType(OAuth2AuthorizationCodeRequestAuthenticationProvider.class,
MemberCategory.INVOKE_DECLARED_METHODS); MemberCategory.INVOKE_DECLARED_METHODS);
@@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Spring Framework AOT {@link org.springframework.aot.hint.RuntimeHints} for GraalVM
* native images for the authorization server module.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.aot.hint;
import org.jspecify.annotations.NullMarked;
@@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -48,9 +49,9 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
private final Authentication principal; private final Authentication principal;
private final String redirectUri; private final @Nullable String redirectUri;
private final String state; private final @Nullable String state;
private final Set<String> scopes; private final Set<String> scopes;
@@ -103,8 +104,7 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
* Returns the redirect uri. * Returns the redirect uri.
* @return the redirect uri * @return the redirect uri
*/ */
@Nullable public @Nullable String getRedirectUri() {
public String getRedirectUri() {
return this.redirectUri; return this.redirectUri;
} }
@@ -112,8 +112,7 @@ abstract class AbstractOAuth2AuthorizationCodeRequestAuthenticationToken extends
* Returns the state. * Returns the state.
* @return the state * @return the state
*/ */
@Nullable public @Nullable String getState() {
public String getState() {
return this.state; return this.state;
} }
@@ -20,6 +20,7 @@ import java.time.Instant;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -92,7 +93,7 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
// @formatter:off // @formatter:off
@@ -105,7 +106,7 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString(); String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); throw invalidClientException(OAuth2ParameterNames.CLIENT_ID);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -114,26 +115,27 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
if (!registeredClient.getClientAuthenticationMethods() if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) { .contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method"); throw invalidClientException("authentication_method");
} }
if (clientAuthentication.getCredentials() == null) { Object credentials = clientAuthentication.getCredentials();
throwInvalidClient("credentials"); if (credentials == null) {
throw invalidClientException("credentials");
} }
String clientSecret = clientAuthentication.getCredentials().toString(); String clientSecret = credentials.toString();
if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) { if (!this.passwordEncoder.matches(clientSecret, registeredClient.getClientSecret())) {
if (this.logger.isDebugEnabled()) { if (this.logger.isDebugEnabled()) {
this.logger.debug(LogMessage.format( this.logger.debug(LogMessage.format(
"Invalid request: client_secret does not match" + " for registered client '%s'", "Invalid request: client_secret does not match" + " for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwInvalidClient(OAuth2ParameterNames.CLIENT_SECRET); throw invalidClientException(OAuth2ParameterNames.CLIENT_SECRET);
} }
if (registeredClient.getClientSecretExpiresAt() != null if (registeredClient.getClientSecretExpiresAt() != null
&& Instant.now().isAfter(registeredClient.getClientSecretExpiresAt())) { && Instant.now().isAfter(registeredClient.getClientSecretExpiresAt())) {
throwInvalidClient("client_secret_expires_at"); throw invalidClientException("client_secret_expires_at");
} }
if (this.passwordEncoder.upgradeEncoding(registeredClient.getClientSecret())) { if (this.passwordEncoder.upgradeEncoding(registeredClient.getClientSecret())) {
@@ -164,10 +166,10 @@ public final class ClientSecretAuthenticationProvider implements AuthenticationP
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
} }
private static void throwInvalidClient(String parameterName) { private static OAuth2AuthenticationException invalidClientException(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI); "Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error); return new OAuth2AuthenticationException(error);
} }
} }
@@ -24,6 +24,7 @@ import java.util.Map;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -65,7 +66,7 @@ final class CodeVerifierAuthenticator {
void authenticateRequired(OAuth2ClientAuthenticationToken clientAuthentication, RegisteredClient registeredClient) { void authenticateRequired(OAuth2ClientAuthenticationToken clientAuthentication, RegisteredClient registeredClient) {
if (!authenticate(clientAuthentication, registeredClient)) { if (!authenticate(clientAuthentication, registeredClient)) {
throwInvalidGrant(PkceParameterNames.CODE_VERIFIER); throw invalidGrantException(PkceParameterNames.CODE_VERIFIER);
} }
} }
@@ -82,10 +83,11 @@ final class CodeVerifierAuthenticator {
return false; return false;
} }
OAuth2Authorization authorization = this.authorizationService String code = (String) parameters.get(OAuth2ParameterNames.CODE);
.findByToken((String) parameters.get(OAuth2ParameterNames.CODE), AUTHORIZATION_CODE_TOKEN_TYPE); Assert.hasText(code, "code cannot be empty");
OAuth2Authorization authorization = this.authorizationService.findByToken(code, AUTHORIZATION_CODE_TOKEN_TYPE);
if (authorization == null) { if (authorization == null) {
throwInvalidGrant(OAuth2ParameterNames.CODE); throw invalidGrantException(OAuth2ParameterNames.CODE);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -94,6 +96,7 @@ final class CodeVerifierAuthenticator {
OAuth2AuthorizationRequest authorizationRequest = authorization OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName()); .getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
String codeChallenge = (String) authorizationRequest.getAdditionalParameters() String codeChallenge = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE); .get(PkceParameterNames.CODE_CHALLENGE);
@@ -105,7 +108,7 @@ final class CodeVerifierAuthenticator {
"Invalid request: code_challenge is required" + " for registered client '%s'", "Invalid request: code_challenge is required" + " for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwInvalidGrant(PkceParameterNames.CODE_CHALLENGE); throw invalidGrantException(PkceParameterNames.CODE_CHALLENGE);
} }
else { else {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -119,6 +122,7 @@ final class CodeVerifierAuthenticator {
this.logger.trace("Validated code verifier parameters"); this.logger.trace("Validated code verifier parameters");
} }
Assert.hasText(codeChallenge, "codeChallenge cannot be empty");
String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters() String codeChallengeMethod = (String) authorizationRequest.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD); .get(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) { if (!codeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
@@ -127,7 +131,7 @@ final class CodeVerifierAuthenticator {
"Invalid request: code_verifier is missing or invalid" + " for registered client '%s'", "Invalid request: code_verifier is missing or invalid" + " for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwInvalidGrant(PkceParameterNames.CODE_VERIFIER); throw invalidGrantException(PkceParameterNames.CODE_VERIFIER);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -143,12 +147,13 @@ final class CodeVerifierAuthenticator {
return false; return false;
} }
if (!StringUtils.hasText((String) parameters.get(OAuth2ParameterNames.CODE))) { if (!StringUtils.hasText((String) parameters.get(OAuth2ParameterNames.CODE))) {
throwInvalidGrant(OAuth2ParameterNames.CODE); throw invalidGrantException(OAuth2ParameterNames.CODE);
} }
return true; return true;
} }
private boolean codeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) { private boolean codeVerifierValid(@Nullable String codeVerifier, String codeChallenge,
@Nullable String codeChallengeMethod) {
if (!StringUtils.hasText(codeVerifier)) { if (!StringUtils.hasText(codeVerifier)) {
return false; return false;
} }
@@ -169,10 +174,10 @@ final class CodeVerifierAuthenticator {
return false; return false;
} }
private static void throwInvalidGrant(String parameterName) { private static OAuth2AuthenticationException invalidGrantException(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT,
"Client authentication failed: " + parameterName, null); "Client authentication failed: " + parameterName, null);
throw new OAuth2AuthenticationException(error); return new OAuth2AuthenticationException(error);
} }
} }
@@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@@ -24,6 +26,7 @@ import org.springframework.security.oauth2.jwt.DPoPProofJwtDecoderFactory;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder; import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtDecoderFactory; import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
/** /**
@@ -42,14 +45,17 @@ final class DPoPProofVerifier {
private DPoPProofVerifier() { private DPoPProofVerifier() {
} }
static Jwt verifyIfAvailable(OAuth2AuthorizationGrantAuthenticationToken authorizationGrantAuthentication) { static @Nullable Jwt verifyIfAvailable(
OAuth2AuthorizationGrantAuthenticationToken authorizationGrantAuthentication) {
String dPoPProof = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_proof"); String dPoPProof = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_proof");
if (!StringUtils.hasText(dPoPProof)) { if (!StringUtils.hasText(dPoPProof)) {
return null; return null;
} }
String method = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_method"); String method = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_method");
Assert.hasText(method, "dpop_method cannot be empty");
String targetUri = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_target_uri"); String targetUri = (String) authorizationGrantAuthentication.getAdditionalParameters().get("dpop_target_uri");
Assert.hasText(targetUri, "dpop_target_uri cannot be empty");
Jwt dPoPProofJwt; Jwt dPoPProofJwt;
try { try {
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -82,7 +83,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) { if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod())) {
@@ -92,7 +93,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
String clientId = clientAuthentication.getPrincipal().toString(); String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); throw invalidClientException(OAuth2ParameterNames.CLIENT_ID);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -102,21 +103,22 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
// @formatter:off // @formatter:off
if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT) && if (!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT) &&
!registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) { !registeredClient.getClientAuthenticationMethods().contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT)) {
throwInvalidClient("authentication_method"); throw invalidClientException("authentication_method");
} }
// @formatter:on // @formatter:on
if (clientAuthentication.getCredentials() == null) { Object credentials = clientAuthentication.getCredentials();
throwInvalidClient("credentials"); if (credentials == null) {
throw invalidClientException("credentials");
} }
Jwt jwtAssertion = null; Jwt jwtAssertion = null;
JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient); JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient);
try { try {
jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString()); jwtAssertion = jwtDecoder.decode(credentials.toString());
} }
catch (JwtException ex) { catch (JwtException ex) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION, ex); throw invalidClientException(OAuth2ParameterNames.CLIENT_ASSERTION, ex);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -159,14 +161,15 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
this.jwtDecoderFactory = jwtDecoderFactory; this.jwtDecoderFactory = jwtDecoderFactory;
} }
private static void throwInvalidClient(String parameterName) { private static OAuth2AuthenticationException invalidClientException(String parameterName) {
throwInvalidClient(parameterName, null); return invalidClientException(parameterName, null);
} }
private static void throwInvalidClient(String parameterName, Throwable cause) { private static OAuth2AuthenticationException invalidClientException(String parameterName,
@Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI); "Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause); return new OAuth2AuthenticationException(error, error.toString(), cause);
} }
} }
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler; import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AccessTokenResponseAuthenticationSuccessHandler;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -47,9 +48,8 @@ public final class OAuth2AccessTokenAuthenticationContext implements OAuth2Authe
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -65,7 +65,9 @@ public final class OAuth2AccessTokenAuthenticationContext implements OAuth2Authe
* @return the {@link OAuth2AccessTokenResponse.Builder} * @return the {@link OAuth2AccessTokenResponse.Builder}
*/ */
public OAuth2AccessTokenResponse.Builder getAccessTokenResponse() { public OAuth2AccessTokenResponse.Builder getAccessTokenResponse() {
return get(OAuth2AccessTokenResponse.Builder.class); OAuth2AccessTokenResponse.Builder accessTokenResponse = get(OAuth2AccessTokenResponse.Builder.class);
Assert.notNull(accessTokenResponse, "accessTokenResponse cannot be null");
return accessTokenResponse;
} }
/** /**
@@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -52,7 +53,7 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
private final OAuth2AccessToken accessToken; private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken; private final @Nullable OAuth2RefreshToken refreshToken;
private final Map<String, Object> additionalParameters; private final Map<String, Object> additionalParameters;
@@ -135,8 +136,7 @@ public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthentication
* Returns the {@link OAuth2RefreshToken refresh token}. * Returns the {@link OAuth2RefreshToken refresh token}.
* @return the {@link OAuth2RefreshToken} or {@code null} if not available * @return the {@link OAuth2RefreshToken} or {@code null} if not available
*/ */
@Nullable public @Nullable OAuth2RefreshToken getRefreshToken() {
public OAuth2RefreshToken getRefreshToken() {
return this.refreshToken; return this.refreshToken;
} }
@@ -20,6 +20,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.context.Context; import org.springframework.security.oauth2.server.authorization.context.Context;
@@ -42,7 +44,9 @@ public interface OAuth2AuthenticationContext extends Context {
*/ */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
default <T extends Authentication> T getAuthentication() { default <T extends Authentication> T getAuthentication() {
return (T) get(Authentication.class); Authentication authentication = get(Authentication.class);
Assert.notNull(authentication, "authentication cannot be null");
return (T) authentication;
} }
/** /**
@@ -85,7 +89,7 @@ public interface OAuth2AuthenticationContext extends Context {
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected <V> V get(Object key) { protected <V> @Nullable V get(Object key) {
return (V) getContext().get(key); return (V) getContext().get(key);
} }
@@ -43,8 +43,9 @@ final class OAuth2AuthenticationProviderUtils {
static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) { static OAuth2ClientAuthenticationToken getAuthenticatedClientElseThrowInvalidClient(Authentication authentication) {
OAuth2ClientAuthenticationToken clientPrincipal = null; OAuth2ClientAuthenticationToken clientPrincipal = null;
if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication.getPrincipal().getClass())) { Object principal = authentication.getPrincipal();
clientPrincipal = (OAuth2ClientAuthenticationToken) authentication.getPrincipal(); if (principal != null && OAuth2ClientAuthenticationToken.class.isAssignableFrom(principal.getClass())) {
clientPrincipal = (OAuth2ClientAuthenticationToken) principal;
} }
if (clientPrincipal != null && clientPrincipal.isAuthenticated()) { if (clientPrincipal != null && clientPrincipal.isAuthenticated()) {
return clientPrincipal; return clientPrincipal;
@@ -30,6 +30,7 @@ import java.util.Map;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -96,7 +97,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator; private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
private SessionRegistry sessionRegistry; private @Nullable SessionRegistry sessionRegistry;
/** /**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the
@@ -119,6 +120,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(authorizationCodeAuthentication); .getAuthenticatedClientElseThrowInvalidClient(authorizationCodeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -136,9 +138,11 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization
.getToken(OAuth2AuthorizationCode.class); .getToken(OAuth2AuthorizationCode.class);
Assert.notNull(authorizationCode, "authorizationCode cannot be null");
OAuth2AuthorizationRequest authorizationRequest = authorization OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName()); .getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) { if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
if (!authorizationCode.isInvalidated()) { if (!authorizationCode.isInvalidated()) {
@@ -193,6 +197,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
} }
Authentication principal = authorization.getAttribute(Principal.class.getName()); Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off // @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
@@ -331,10 +336,14 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
this.sessionRegistry = sessionRegistry; this.sessionRegistry = sessionRegistry;
} }
private SessionInformation getSessionInformation(Authentication principal) { private @Nullable SessionInformation getSessionInformation(Authentication principal) {
SessionInformation sessionInformation = null; SessionInformation sessionInformation = null;
if (this.sessionRegistry != null) { if (this.sessionRegistry != null) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), false); Object sessionPrincipal = principal.getPrincipal();
if (sessionPrincipal == null) {
return null;
}
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(sessionPrincipal, false);
if (!CollectionUtils.isEmpty(sessions)) { if (!CollectionUtils.isEmpty(sessions)) {
sessionInformation = sessions.get(0); sessionInformation = sessions.get(0);
if (sessions.size() > 1) { if (sessions.size() > 1) {
@@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -38,7 +39,7 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends OAuth2Authorizat
private final String code; private final String code;
private final String redirectUri; private final @Nullable String redirectUri;
/** /**
* Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided
@@ -68,8 +69,7 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends OAuth2Authorizat
* Returns the redirect uri. * Returns the redirect uri.
* @return the redirect uri * @return the redirect uri
*/ */
@Nullable public @Nullable String getRedirectUri() {
public String getRedirectUri() {
return this.redirectUri; return this.redirectUri;
} }
@@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.time.Instant; import java.time.Instant;
import java.util.Base64; import java.util.Base64;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.crypto.keygen.Base64StringKeyGenerator; import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
import org.springframework.security.crypto.keygen.StringKeyGenerator; import org.springframework.security.crypto.keygen.StringKeyGenerator;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -42,10 +43,9 @@ final class OAuth2AuthorizationCodeGenerator implements OAuth2TokenGenerator<OAu
private final StringKeyGenerator authorizationCodeGenerator = new Base64StringKeyGenerator( private final StringKeyGenerator authorizationCodeGenerator = new Base64StringKeyGenerator(
Base64.getUrlEncoder().withoutPadding(), 96); Base64.getUrlEncoder().withoutPadding(), 96);
@Nullable
@Override @Override
public OAuth2AuthorizationCode generate(OAuth2TokenContext context) { public @Nullable OAuth2AuthorizationCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null || !OAuth2ParameterNames.CODE.equals(context.getTokenType().getValue())) { if (!OAuth2ParameterNames.CODE.equals(context.getTokenType().getValue())) {
return null; return null;
} }
Instant issuedAt = Instant.now(); Instant issuedAt = Instant.now();
@@ -22,7 +22,8 @@ import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -50,9 +51,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -67,15 +67,16 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
* Returns the {@link OAuth2AuthorizationRequest authorization request}. * Returns the {@link OAuth2AuthorizationRequest authorization request}.
* @return the {@link OAuth2AuthorizationRequest} * @return the {@link OAuth2AuthorizationRequest}
*/ */
@Nullable public @Nullable OAuth2AuthorizationRequest getAuthorizationRequest() {
public OAuth2AuthorizationRequest getAuthorizationRequest() {
return get(OAuth2AuthorizationRequest.class); return get(OAuth2AuthorizationRequest.class);
} }
@@ -83,8 +84,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationContext implement
* Returns the {@link OAuth2AuthorizationConsent authorization consent}. * Returns the {@link OAuth2AuthorizationConsent authorization consent}.
* @return the {@link OAuth2AuthorizationConsent} * @return the {@link OAuth2AuthorizationConsent}
*/ */
@Nullable public @Nullable OAuth2AuthorizationConsent getAuthorizationConsent() {
public OAuth2AuthorizationConsent getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class); return get(OAuth2AuthorizationConsent.class);
} }
@@ -16,7 +16,8 @@
package org.springframework.security.oauth2.server.authorization.authentication; package org.springframework.security.oauth2.server.authorization.authentication;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException; import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
@@ -33,7 +34,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
*/ */
public class OAuth2AuthorizationCodeRequestAuthenticationException extends OAuth2AuthenticationException { public class OAuth2AuthorizationCodeRequestAuthenticationException extends OAuth2AuthenticationException {
private final OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication; private final @Nullable OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication;
/** /**
* Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationException} using * Constructs an {@code OAuth2AuthorizationCodeRequestAuthenticationException} using
@@ -67,8 +68,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationException extends OAuth
* (or Consent), or {@code null} if not available. * (or Consent), or {@code null} if not available.
* @return the {@link OAuth2AuthorizationCodeRequestAuthenticationToken} * @return the {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
*/ */
@Nullable public @Nullable OAuth2AuthorizationCodeRequestAuthenticationToken getAuthorizationCodeRequestAuthentication() {
public OAuth2AuthorizationCodeRequestAuthenticationToken getAuthorizationCodeRequestAuthentication() {
return this.authorizationCodeRequestAuthentication; return this.authorizationCodeRequestAuthentication;
} }
@@ -30,6 +30,7 @@ import java.util.function.Predicate;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
@@ -129,19 +130,19 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters() String requestUri = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
.get(OAuth2ParameterNames.REQUEST_URI); .get(OAuth2ParameterNames.REQUEST_URI);
if (StringUtils.hasText(requestUri)) { if (StringUtils.hasText(requestUri)) {
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = null; OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri;
try { try {
pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri.parse(requestUri); pushedAuthorizationRequestUri = OAuth2PushedAuthorizationRequestUri.parse(requestUri);
} }
catch (Exception ex) { catch (Exception ex) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null); authorizationCodeRequestAuthentication, null);
} }
pushedAuthorization = this.authorizationService.findByToken(pushedAuthorizationRequestUri.getState(), pushedAuthorization = this.authorizationService.findByToken(pushedAuthorizationRequestUri.getState(),
STATE_TOKEN_TYPE); STATE_TOKEN_TYPE);
if (pushedAuthorization == null) { if (pushedAuthorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null); authorizationCodeRequestAuthentication, null);
} }
@@ -151,9 +152,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
OAuth2AuthorizationRequest authorizationRequest = pushedAuthorization OAuth2AuthorizationRequest authorizationRequest = pushedAuthorization
.getAttribute(OAuth2AuthorizationRequest.class.getName()); .getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
if (!authorizationCodeRequestAuthentication.getClientId().equals(authorizationRequest.getClientId())) { if (!authorizationCodeRequestAuthentication.getClientId().equals(authorizationRequest.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, null); authorizationCodeRequestAuthentication, null);
} }
@@ -165,7 +167,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
.warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'", .warn(LogMessage.format("Removed expired pushed authorization request for client id '%s'",
authorizationRequest.getClientId())); authorizationRequest.getClientId()));
} }
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REQUEST_URI,
authorizationCodeRequestAuthentication, null); authorizationCodeRequestAuthentication, null);
} }
@@ -179,7 +181,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
RegisteredClient registeredClient = this.registeredClientRepository RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(authorizationCodeRequestAuthentication.getClientId()); .findByClientId(authorizationCodeRequestAuthentication.getClientId());
if (registeredClient == null) { if (registeredClient == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, null); authorizationCodeRequestAuthentication, null);
} }
@@ -233,11 +235,12 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (!isPrincipalAuthenticated(principal)) { if (!isPrincipalAuthenticated(principal)) {
if (promptValues.contains(OidcPrompt.NONE)) { if (promptValues.contains(OidcPrompt.NONE)) {
throwError("login_required", "prompt", authorizationCodeRequestAuthentication, registeredClient); throw createException("login_required", "prompt", authorizationCodeRequestAuthentication,
registeredClient);
} }
else { else {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "principal", authorizationCodeRequestAuthentication, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, "principal",
registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -260,7 +263,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) { if (this.authorizationConsentRequired.test(authenticationContextBuilder.build())) {
if (promptValues.contains(OidcPrompt.NONE)) { if (promptValues.contains(OidcPrompt.NONE)) {
// Return an error instead of displaying the consent page // Return an error instead of displaying the consent page
throwError("consent_required", "prompt", authorizationCodeRequestAuthentication, registeredClient); throw createException("consent_required", "prompt", authorizationCodeRequestAuthentication,
registeredClient);
} }
String state = DEFAULT_STATE_GENERATOR.generateKey(); String state = DEFAULT_STATE_GENERATOR.generateKey();
@@ -416,15 +420,17 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) { if (!authenticationContext.getRegisteredClient().getClientSettings().isRequireAuthorizationConsent()) {
return false; return false;
} }
OAuth2AuthorizationRequest authorizationRequest = authenticationContext.getAuthorizationRequest();
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
// 'openid' scope does not require consent // 'openid' scope does not require consent
if (authenticationContext.getAuthorizationRequest().getScopes().contains(OidcScopes.OPENID) if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)
&& authenticationContext.getAuthorizationRequest().getScopes().size() == 1) { && authorizationRequest.getScopes().size() == 1) {
return false; return false;
} }
if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent() if (authenticationContext.getAuthorizationConsent() != null && authenticationContext.getAuthorizationConsent()
.getScopes() .getScopes()
.containsAll(authenticationContext.getAuthorizationRequest().getScopes())) { .containsAll(authorizationRequest.getScopes())) {
return false; return false;
} }
@@ -442,7 +448,8 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
private static OAuth2TokenContext createAuthorizationCodeTokenContext( private static OAuth2TokenContext createAuthorizationCodeTokenContext(
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2Authorization authorization, Set<String> authorizedScopes) { RegisteredClient registeredClient, @Nullable OAuth2Authorization authorization,
Set<String> authorizedScopes) {
// @formatter:off // @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
@@ -467,23 +474,27 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
&& principal.isAuthenticated(); && principal.isAuthenticated();
} }
private static void throwError(String errorCode, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) { @Nullable RegisteredClient registeredClient) {
throwError(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication, registeredClient, null); return createException(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication,
registeredClient, null);
} }
private static void throwError(String errorCode, String parameterName, String errorUri, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName, String errorUri,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) { @Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri); OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
throwError(error, parameterName, authorizationCodeRequestAuthentication, registeredClient, return createException(error, parameterName, authorizationCodeRequestAuthentication, registeredClient,
authorizationRequest); authorizationRequest);
} }
private static void throwError(OAuth2Error error, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) { @Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
String redirectUri = resolveRedirectUri(authorizationCodeRequestAuthentication, authorizationRequest, String redirectUri = resolveRedirectUri(authorizationCodeRequestAuthentication, authorizationRequest,
registeredClient); registeredClient);
@@ -500,13 +511,13 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
authorizationCodeRequestAuthentication.getState(), authorizationCodeRequestAuthentication.getScopes(), authorizationCodeRequestAuthentication.getState(), authorizationCodeRequestAuthentication.getScopes(),
authorizationCodeRequestAuthentication.getAdditionalParameters()); authorizationCodeRequestAuthentication.getAdditionalParameters());
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult); authorizationCodeRequestAuthenticationResult);
} }
private static String resolveRedirectUri( private static @Nullable String resolveRedirectUri(
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
OAuth2AuthorizationRequest authorizationRequest, RegisteredClient registeredClient) { @Nullable OAuth2AuthorizationRequest authorizationRequest, @Nullable RegisteredClient registeredClient) {
if (authorizationCodeRequestAuthentication != null if (authorizationCodeRequestAuthentication != null
&& StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) { && StringUtils.hasText(authorizationCodeRequestAuthentication.getRedirectUri())) {
@@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -40,7 +41,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
@Serial @Serial
private static final long serialVersionUID = -1946164725241393094L; private static final long serialVersionUID = -1946164725241393094L;
private final OAuth2AuthorizationCode authorizationCode; private final @Nullable OAuth2AuthorizationCode authorizationCode;
private boolean validated; private boolean validated;
@@ -86,8 +87,7 @@ public class OAuth2AuthorizationCodeRequestAuthenticationToken
* Returns the {@link OAuth2AuthorizationCode}. * Returns the {@link OAuth2AuthorizationCode}.
* @return the {@link OAuth2AuthorizationCode} * @return the {@link OAuth2AuthorizationCode}
*/ */
@Nullable public @Nullable OAuth2AuthorizationCode getAuthorizationCode() {
public OAuth2AuthorizationCode getAuthorizationCode() {
return this.authorizationCode; return this.authorizationCode;
} }
@@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -104,7 +105,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
"Invalid request: requested grant_type is not allowed for registered client '%s'", "Invalid request: requested grant_type is not allowed for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwError(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID, throw createException(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT, OAuth2ParameterNames.CLIENT_ID,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -130,7 +131,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
LOGGER.debug(LogMessage.format("Invalid request: redirect_uri is missing or contains a fragment" LOGGER.debug(LogMessage.format("Invalid request: redirect_uri is missing or contains a fragment"
+ " for registered client '%s'", registeredClient.getId())); + " for registered client '%s'", registeredClient.getId()));
} }
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
@@ -140,7 +141,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
// When comparing client redirect URIs against pre-registered URIs, // When comparing client redirect URIs against pre-registered URIs,
// authorization servers MUST utilize exact string matching. // authorization servers MUST utilize exact string matching.
if (!registeredClient.getRedirectUris().contains(requestedRedirectUri)) { if (!registeredClient.getRedirectUris().contains(requestedRedirectUri)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -166,7 +167,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
"Invalid request: redirect_uri does not match for registered client '%s'", "Invalid request: redirect_uri does not match for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -178,7 +179,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID) if (authorizationCodeRequestAuthentication.getScopes().contains(OidcScopes.OPENID)
|| registeredClient.getRedirectUris().size() != 1) { || registeredClient.getRedirectUris().size() != 1) {
// redirect_uri is REQUIRED for OpenID Connect // redirect_uri is REQUIRED for OpenID Connect
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -197,7 +198,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
LogMessage.format("Invalid request: requested scope is not allowed for registered client '%s'", LogMessage.format("Invalid request: requested scope is not allowed for registered client '%s'",
registeredClient.getId())); registeredClient.getId()));
} }
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE, throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -215,12 +216,12 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
String codeChallengeMethod = (String) authorizationCodeRequestAuthentication.getAdditionalParameters() String codeChallengeMethod = (String) authorizationCodeRequestAuthentication.getAdditionalParameters()
.get(PkceParameterNames.CODE_CHALLENGE_METHOD); .get(PkceParameterNames.CODE_CHALLENGE_METHOD);
if (!StringUtils.hasText(codeChallengeMethod) || !"S256".equals(codeChallengeMethod)) { if (!StringUtils.hasText(codeChallengeMethod) || !"S256".equals(codeChallengeMethod)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD,
authorizationCodeRequestAuthentication, registeredClient); PKCE_ERROR_URI, authorizationCodeRequestAuthentication, registeredClient);
} }
} }
else if (registeredClient.getClientSettings().isRequireProofKey()) { else if (registeredClient.getClientSettings().isRequireProofKey()) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI,
authorizationCodeRequestAuthentication, registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
@@ -239,15 +240,15 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
if (promptValues.contains(OidcPrompt.NONE)) { if (promptValues.contains(OidcPrompt.NONE)) {
if (promptValues.contains(OidcPrompt.LOGIN) || promptValues.contains(OidcPrompt.CONSENT) if (promptValues.contains(OidcPrompt.LOGIN) || promptValues.contains(OidcPrompt.CONSENT)
|| promptValues.contains(OidcPrompt.SELECT_ACCOUNT)) { || promptValues.contains(OidcPrompt.SELECT_ACCOUNT)) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, "prompt", authorizationCodeRequestAuthentication, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, "prompt",
registeredClient); authorizationCodeRequestAuthentication, registeredClient);
} }
} }
} }
} }
} }
private static boolean isLoopbackAddress(String host) { private static boolean isLoopbackAddress(@Nullable String host) {
if (!StringUtils.hasText(host)) { if (!StringUtils.hasText(host)) {
return false; return false;
} }
@@ -273,20 +274,24 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
} }
} }
private static void throwError(String errorCode, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) { RegisteredClient registeredClient) {
throwError(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication, registeredClient); return createException(errorCode, parameterName, ERROR_URI, authorizationCodeRequestAuthentication,
registeredClient);
} }
private static void throwError(String errorCode, String parameterName, String errorUri, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
String parameterName, String errorUri,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) { RegisteredClient registeredClient) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri); OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
throwError(error, parameterName, authorizationCodeRequestAuthentication, registeredClient); return createException(error, parameterName, authorizationCodeRequestAuthentication, registeredClient);
} }
private static void throwError(OAuth2Error error, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
String parameterName,
OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication, OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
RegisteredClient registeredClient) { RegisteredClient registeredClient) {
@@ -306,7 +311,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationValidator
authorizationCodeRequestAuthentication.getAdditionalParameters()); authorizationCodeRequestAuthentication.getAdditionalParameters());
authorizationCodeRequestAuthenticationResult.setAuthenticated(true); authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult); authorizationCodeRequestAuthenticationResult);
} }
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
@@ -50,9 +51,8 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -68,7 +68,9 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link OAuth2AuthorizationConsent.Builder} * @return the {@link OAuth2AuthorizationConsent.Builder}
*/ */
public OAuth2AuthorizationConsent.Builder getAuthorizationConsent() { public OAuth2AuthorizationConsent.Builder getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.Builder.class); OAuth2AuthorizationConsent.Builder authorizationConsentBuilder = get(OAuth2AuthorizationConsent.Builder.class);
Assert.notNull(authorizationConsentBuilder, "authorizationConsentBuilder cannot be null");
return authorizationConsentBuilder;
} }
/** /**
@@ -76,7 +78,9 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
@@ -84,14 +88,16 @@ public final class OAuth2AuthorizationConsentAuthenticationContext implements OA
* @return the {@link OAuth2Authorization} * @return the {@link OAuth2Authorization}
*/ */
public OAuth2Authorization getAuthorization() { public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class); OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
} }
/** /**
* Returns the {@link OAuth2AuthorizationRequest authorization request}. * Returns the {@link OAuth2AuthorizationRequest authorization request}.
* @return the {@link OAuth2AuthorizationRequest} * @return the {@link OAuth2AuthorizationRequest}
*/ */
public OAuth2AuthorizationRequest getAuthorizationRequest() { public @Nullable OAuth2AuthorizationRequest getAuthorizationRequest() {
return get(OAuth2AuthorizationRequest.class); return get(OAuth2AuthorizationRequest.class);
} }
@@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -79,7 +80,7 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator(); private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
private Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer; private @Nullable Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
/** /**
* Constructs an {@code OAuth2AuthorizationConsentAuthenticationProvider} using the * Constructs an {@code OAuth2AuthorizationConsentAuthenticationProvider} using the
@@ -100,7 +101,7 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
if (authentication instanceof OAuth2DeviceAuthorizationConsentAuthenticationToken) { if (authentication instanceof OAuth2DeviceAuthorizationConsentAuthenticationToken) {
// This is NOT an OAuth 2.0 Authorization Consent for the Authorization Code // This is NOT an OAuth 2.0 Authorization Consent for the Authorization Code
// Grant, // Grant,
@@ -114,8 +115,8 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
OAuth2Authorization authorization = this.authorizationService OAuth2Authorization authorization = this.authorizationService
.findByToken(authorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE); .findByToken(authorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE);
if (authorization == null) { if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE, authorizationConsentAuthentication, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE,
null, null); authorizationConsentAuthentication, null, null);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -125,14 +126,18 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
// The 'in-flight' authorization must be associated to the current principal // The 'in-flight' authorization must be associated to the current principal
Authentication principal = (Authentication) authorizationConsentAuthentication.getPrincipal(); Authentication principal = (Authentication) authorizationConsentAuthentication.getPrincipal();
if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) { if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE, authorizationConsentAuthentication, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE,
null, null); authorizationConsentAuthentication, null, null);
} }
RegisteredClient registeredClient = this.registeredClientRepository RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(authorizationConsentAuthentication.getClientId()); .findByClientId(authorizationConsentAuthentication.getClientId());
if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) { if (registeredClient == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID, throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, null, null);
}
if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, registeredClient, null); authorizationConsentAuthentication, registeredClient, null);
} }
@@ -142,11 +147,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
OAuth2AuthorizationRequest authorizationRequest = authorization OAuth2AuthorizationRequest authorizationRequest = authorization
.getAttribute(OAuth2AuthorizationRequest.class.getName()); .getAttribute(OAuth2AuthorizationRequest.class.getName());
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
Set<String> requestedScopes = authorizationRequest.getScopes(); Set<String> requestedScopes = authorizationRequest.getScopes();
Set<String> authorizedScopes = new HashSet<>(authorizationConsentAuthentication.getScopes()); Set<String> authorizedScopes = new HashSet<>(authorizationConsentAuthentication.getScopes());
if (!requestedScopes.containsAll(authorizedScopes)) { if (!requestedScopes.containsAll(authorizedScopes)) {
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE, authorizationConsentAuthentication, throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE,
registeredClient, authorizationRequest); authorizationConsentAuthentication, registeredClient, authorizationRequest);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -215,12 +221,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Removed authorization"); this.logger.trace("Removed authorization");
} }
throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID, throw createException(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID,
authorizationConsentAuthentication, registeredClient, authorizationRequest); authorizationConsentAuthentication, registeredClient, authorizationRequest);
} }
OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build(); OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
if (!authorizationConsent.equals(currentAuthorizationConsent)) { if (currentAuthorizationConsent == null || !authorizationConsent.equals(currentAuthorizationConsent)) {
this.authorizationConsentService.save(authorizationConsent); this.authorizationConsentService.save(authorizationConsent);
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Saved authorization consent"); this.logger.trace("Saved authorization consent");
@@ -334,16 +340,17 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
&& principal.isAuthenticated(); && principal.isAuthenticated();
} }
private static void throwError(String errorCode, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(String errorCode,
OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication, String parameterName, OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) { @Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI); OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI);
throwError(error, parameterName, authorizationConsentAuthentication, registeredClient, authorizationRequest); return createException(error, parameterName, authorizationConsentAuthentication, registeredClient,
authorizationRequest);
} }
private static void throwError(OAuth2Error error, String parameterName, private static OAuth2AuthorizationCodeRequestAuthenticationException createException(OAuth2Error error,
OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication, String parameterName, OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthentication,
RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) { @Nullable RegisteredClient registeredClient, @Nullable OAuth2AuthorizationRequest authorizationRequest) {
String redirectUri = resolveRedirectUri(authorizationRequest, registeredClient); String redirectUri = resolveRedirectUri(authorizationRequest, registeredClient);
if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST) if (error.getErrorCode().equals(OAuth2ErrorCodes.INVALID_REQUEST)
@@ -363,12 +370,12 @@ public final class OAuth2AuthorizationConsentAuthenticationProvider implements A
(Authentication) authorizationConsentAuthentication.getPrincipal(), redirectUri, state, requestedScopes, (Authentication) authorizationConsentAuthentication.getPrincipal(), redirectUri, state, requestedScopes,
null); null);
throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, return new OAuth2AuthorizationCodeRequestAuthenticationException(error,
authorizationCodeRequestAuthenticationResult); authorizationCodeRequestAuthenticationResult);
} }
private static String resolveRedirectUri(OAuth2AuthorizationRequest authorizationRequest, private static @Nullable String resolveRedirectUri(@Nullable OAuth2AuthorizationRequest authorizationRequest,
RegisteredClient registeredClient) { @Nullable RegisteredClient registeredClient) {
if (authorizationRequest != null && StringUtils.hasText(authorizationRequest.getRedirectUri())) { if (authorizationRequest != null && StringUtils.hasText(authorizationRequest.getRedirectUri())) {
return authorizationRequest.getRedirectUri(); return authorizationRequest.getRedirectUri();
} }
@@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -45,9 +46,8 @@ public final class OAuth2ClientAuthenticationContext implements OAuth2Authentica
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -62,7 +62,9 @@ public final class OAuth2ClientAuthenticationContext implements OAuth2Authentica
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
@@ -20,7 +20,8 @@ import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.Transient; import org.springframework.security.core.Transient;
@@ -49,11 +50,11 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
private final String clientId; private final String clientId;
private final RegisteredClient registeredClient; private final @Nullable RegisteredClient registeredClient;
private final ClientAuthenticationMethod clientAuthenticationMethod; private final ClientAuthenticationMethod clientAuthenticationMethod;
private final Object credentials; private final @Nullable Object credentials;
private final Map<String, Object> additionalParameters; private final Map<String, Object> additionalParameters;
@@ -103,9 +104,8 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
return this.clientId; return this.clientId;
} }
@Nullable
@Override @Override
public Object getCredentials() { public @Nullable Object getCredentials() {
return this.credentials; return this.credentials;
} }
@@ -115,8 +115,7 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
* @return the authenticated {@link RegisteredClient}, or {@code null} if not * @return the authenticated {@link RegisteredClient}, or {@code null} if not
* authenticated * authenticated
*/ */
@Nullable public @Nullable RegisteredClient getRegisteredClient() {
public RegisteredClient getRegisteredClient() {
return this.registeredClient; return this.registeredClient;
} }
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -45,9 +46,8 @@ public final class OAuth2ClientCredentialsAuthenticationContext implements OAuth
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -62,7 +62,9 @@ public final class OAuth2ClientCredentialsAuthenticationContext implements OAuth
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
@@ -95,6 +95,7 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(clientCredentialsAuthentication); .getAuthenticatedClientElseThrowInvalidClient(clientCredentialsAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -21,7 +21,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -21,10 +21,12 @@ import java.net.URISyntaxException;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -138,6 +140,7 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
} }
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "accessToken cannot be null");
if (!authorizedAccessToken.isActive()) { if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
} }
@@ -199,9 +202,10 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
private OAuth2ClientRegistrationAuthenticationToken registerClient( private OAuth2ClientRegistrationAuthenticationToken registerClient(
OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthentication, OAuth2ClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) { @Nullable OAuth2Authorization authorization) {
if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getRedirectUris())) { List<String> redirectUris = clientRegistrationAuthentication.getClientRegistration().getRedirectUris();
if (!isValidRedirectUris((redirectUris != null) ? redirectUris : Collections.emptyList())) {
throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI, throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI,
OAuth2ClientMetadataClaimNames.REDIRECT_URIS); OAuth2ClientMetadataClaimNames.REDIRECT_URIS);
} }
@@ -236,8 +240,10 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
if (authorization != null) { if (authorization != null) {
// Invalidate the "initial" access token as it can only be used once // Invalidate the "initial" access token as it can only be used once
OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getAccessToken();
Assert.notNull(accessToken, "accessToken cannot be null");
OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization) OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken()); .invalidate(accessToken.getToken());
if (authorization.getRefreshToken() != null) { if (authorization.getRefreshToken() != null) {
builder.invalidate(authorization.getRefreshToken().getToken()); builder.invalidate(authorization.getRefreshToken().getToken());
} }
@@ -265,8 +271,9 @@ public final class OAuth2ClientRegistrationAuthenticationProvider implements Aut
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken, private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) { Set<String> requiredScope) {
Collection<String> authorizedScope = Collections.emptySet(); Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) { Map<String, Object> claims = authorizedAccessToken.getClaims();
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE); if (claims != null && claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
} }
if (!authorizedScope.containsAll(requiredScope)) { if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
@@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.io.Serial; import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration; import org.springframework.security.oauth2.server.authorization.OAuth2ClientRegistration;
@@ -40,8 +41,7 @@ public class OAuth2ClientRegistrationAuthenticationToken extends AbstractAuthent
@Serial @Serial
private static final long serialVersionUID = 7135429161909989115L; private static final long serialVersionUID = 7135429161909989115L;
@Nullable private final @Nullable Authentication principal;
private final Authentication principal;
private final OAuth2ClientRegistration clientRegistration; private final OAuth2ClientRegistration clientRegistration;
@@ -62,9 +62,8 @@ public class OAuth2ClientRegistrationAuthenticationToken extends AbstractAuthent
} }
} }
@Nullable
@Override @Override
public Object getPrincipal() { public @Nullable Object getPrincipal() {
return this.principal; return this.principal;
} }
@@ -23,6 +23,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -72,7 +73,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
private final OAuth2AuthorizationConsentService authorizationConsentService; private final OAuth2AuthorizationConsentService authorizationConsentService;
private Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer; private @Nullable Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
/** /**
* Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationProvider} using * Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationProvider} using
@@ -99,7 +100,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
OAuth2Authorization authorization = this.authorizationService OAuth2Authorization authorization = this.authorizationService
.findByToken(deviceAuthorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE); .findByToken(deviceAuthorizationConsentAuthentication.getState(), STATE_TOKEN_TYPE);
if (authorization == null) { if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -109,13 +110,13 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
// The authorization must be associated to the current principal // The authorization must be associated to the current principal
Authentication principal = (Authentication) deviceAuthorizationConsentAuthentication.getPrincipal(); Authentication principal = (Authentication) deviceAuthorizationConsentAuthentication.getPrincipal();
if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) { if (!isPrincipalAuthenticated(principal) || !principal.getName().equals(authorization.getPrincipalName())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE); throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE);
} }
RegisteredClient registeredClient = this.registeredClientRepository RegisteredClient registeredClient = this.registeredClientRepository
.findByClientId(deviceAuthorizationConsentAuthentication.getClientId()); .findByClientId(deviceAuthorizationConsentAuthentication.getClientId());
if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) { if (registeredClient == null || !registeredClient.getId().equals(authorization.getRegisteredClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -123,9 +124,10 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
} }
Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE); Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
Assert.notNull(requestedScopes, "requestedScopes cannot be null");
Set<String> authorizedScopes = new HashSet<>(deviceAuthorizationConsentAuthentication.getScopes()); Set<String> authorizedScopes = new HashSet<>(deviceAuthorizationConsentAuthentication.getScopes());
if (!requestedScopes.containsAll(authorizedScopes)) { if (!requestedScopes.containsAll(authorizedScopes)) {
throwError(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE); throw createException(OAuth2ErrorCodes.INVALID_SCOPE, OAuth2ParameterNames.SCOPE);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -177,7 +179,9 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
authorizationConsentBuilder.authorities(authorities::addAll); authorizationConsentBuilder.authorities(authorities::addAll);
OAuth2Authorization.Token<OAuth2DeviceCode> deviceCodeToken = authorization.getToken(OAuth2DeviceCode.class); OAuth2Authorization.Token<OAuth2DeviceCode> deviceCodeToken = authorization.getToken(OAuth2DeviceCode.class);
Assert.notNull(deviceCodeToken, "deviceCode cannot be null");
OAuth2Authorization.Token<OAuth2UserCode> userCodeToken = authorization.getToken(OAuth2UserCode.class); OAuth2Authorization.Token<OAuth2UserCode> userCodeToken = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCodeToken, "userCode cannot be null");
if (authorities.isEmpty()) { if (authorities.isEmpty()) {
// Authorization consent denied (or revoked) // Authorization consent denied (or revoked)
@@ -196,11 +200,11 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Invalidated device code and user code because authorization consent was denied"); this.logger.trace("Invalidated device code and user code because authorization consent was denied");
} }
throwError(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID); throw createException(OAuth2ErrorCodes.ACCESS_DENIED, OAuth2ParameterNames.CLIENT_ID);
} }
OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build(); OAuth2AuthorizationConsent authorizationConsent = authorizationConsentBuilder.build();
if (!authorizationConsent.equals(currentAuthorizationConsent)) { if (currentAuthorizationConsent == null || !authorizationConsent.equals(currentAuthorizationConsent)) {
this.authorizationConsentService.save(authorizationConsent); this.authorizationConsentService.save(authorizationConsent);
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Saved authorization consent"); this.logger.trace("Saved authorization consent");
@@ -263,9 +267,9 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
&& principal.isAuthenticated(); && principal.isAuthenticated();
} }
private static void throwError(String errorCode, String parameterName) { private static OAuth2AuthenticationException createException(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI); OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error); return new OAuth2AuthenticationException(error);
} }
} }
@@ -22,7 +22,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -43,7 +44,7 @@ public class OAuth2DeviceAuthorizationConsentAuthenticationToken extends OAuth2A
private final String userCode; private final String userCode;
private final Set<String> requestedScopes; private final @Nullable Set<String> requestedScopes;
/** /**
* Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationToken} using the * Constructs an {@code OAuth2DeviceAuthorizationConsentAuthenticationToken} using the
@@ -98,9 +99,9 @@ public class OAuth2DeviceAuthorizationConsentAuthenticationToken extends OAuth2A
/** /**
* Returns the requested scopes. * Returns the requested scopes.
* @return the requested scopes * @return the requested scopes, or {@code null} if not available
*/ */
public Set<String> getRequestedScopes() { public @Nullable Set<String> getRequestedScopes() {
return this.requestedScopes; return this.requestedScopes;
} }
@@ -23,9 +23,9 @@ import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.log.LogMessage; import org.springframework.core.log.LogMessage;
import org.springframework.lang.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.AuthenticationException;
@@ -101,6 +101,7 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(deviceAuthorizationRequestAuthentication); .getAuthenticatedClientElseThrowInvalidClient(deviceAuthorizationRequestAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -224,9 +225,8 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
private final StringKeyGenerator deviceCodeGenerator = new Base64StringKeyGenerator( private final StringKeyGenerator deviceCodeGenerator = new Base64StringKeyGenerator(
Base64.getUrlEncoder().withoutPadding(), 96); Base64.getUrlEncoder().withoutPadding(), 96);
@Nullable
@Override @Override
public OAuth2DeviceCode generate(OAuth2TokenContext context) { public @Nullable OAuth2DeviceCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null if (context.getTokenType() == null
|| !OAuth2ParameterNames.DEVICE_CODE.equals(context.getTokenType().getValue())) { || !OAuth2ParameterNames.DEVICE_CODE.equals(context.getTokenType().getValue())) {
return null; return null;
@@ -268,9 +268,8 @@ public final class OAuth2DeviceAuthorizationRequestAuthenticationProvider implem
private final StringKeyGenerator userCodeGenerator = new UserCodeStringKeyGenerator(); private final StringKeyGenerator userCodeGenerator = new UserCodeStringKeyGenerator();
@Nullable
@Override @Override
public OAuth2UserCode generate(OAuth2TokenContext context) { public @Nullable OAuth2UserCode generate(OAuth2TokenContext context) {
if (context.getTokenType() == null if (context.getTokenType() == null
|| !OAuth2ParameterNames.USER_CODE.equals(context.getTokenType().getValue())) { || !OAuth2ParameterNames.USER_CODE.equals(context.getTokenType().getValue())) {
return null; return null;
@@ -23,7 +23,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2DeviceCode; import org.springframework.security.oauth2.core.OAuth2DeviceCode;
@@ -47,13 +48,13 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
private final Authentication clientPrincipal; private final Authentication clientPrincipal;
private final String authorizationUri; private final @Nullable String authorizationUri;
private final Set<String> scopes; private final Set<String> scopes;
private final OAuth2DeviceCode deviceCode; private final @Nullable OAuth2DeviceCode deviceCode;
private final OAuth2UserCode userCode; private final @Nullable OAuth2UserCode userCode;
private final Map<String, Object> additionalParameters; private final Map<String, Object> additionalParameters;
@@ -116,7 +117,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the authorization {@code URI}. * Returns the authorization {@code URI}.
* @return the authorization {@code URI} * @return the authorization {@code URI}
*/ */
public String getAuthorizationUri() { public @Nullable String getAuthorizationUri() {
return this.authorizationUri; return this.authorizationUri;
} }
@@ -132,7 +133,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the device code. * Returns the device code.
* @return the device code * @return the device code
*/ */
public OAuth2DeviceCode getDeviceCode() { public @Nullable OAuth2DeviceCode getDeviceCode() {
return this.deviceCode; return this.deviceCode;
} }
@@ -140,7 +141,7 @@ public class OAuth2DeviceAuthorizationRequestAuthenticationToken extends Abstrac
* Returns the user code. * Returns the user code.
* @return the user code * @return the user code
*/ */
public OAuth2UserCode getUserCode() { public @Nullable OAuth2UserCode getUserCode() {
return this.userCode; return this.userCode;
} }
@@ -104,6 +104,7 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(deviceCodeAuthentication); .getAuthenticatedClientElseThrowInvalidClient(deviceCodeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -119,8 +120,8 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
this.logger.trace("Retrieved authorization with device code"); this.logger.trace("Retrieved authorization with device code");
} }
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class); OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode = authorization.getToken(OAuth2DeviceCode.class);
Assert.notNull(deviceCode, "deviceCode cannot be null");
if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) { if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
if (!deviceCode.isInvalidated()) { if (!deviceCode.isInvalidated()) {
@@ -158,6 +159,9 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
throw new OAuth2AuthenticationException(error); throw new OAuth2AuthenticationException(error);
} }
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCode, "userCode cannot be null");
// authorization_pending // authorization_pending
// The authorization request is still pending as the end user hasn't // The authorization request is still pending as the end user hasn't
// yet completed the user-interaction steps (Section 3.3). The // yet completed the user-interaction steps (Section 3.3). The
@@ -193,10 +197,13 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
this.logger.trace("Validated device token request parameters"); this.logger.trace("Validated device token request parameters");
} }
Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off // @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient) .registeredClient(registeredClient)
.principal(authorization.getAttribute(Principal.class.getName())) .principal(principal)
.authorizationServerContext(AuthorizationServerContextHolder.getContext()) .authorizationServerContext(AuthorizationServerContextHolder.getContext())
.authorization(authorization) .authorization(authorization)
.authorizedScopes(authorization.getAuthorizedScopes()) .authorizedScopes(authorization.getAuthorizedScopes())
@@ -18,7 +18,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames; import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
@@ -48,9 +49,8 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -65,7 +65,9 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
@@ -73,15 +75,16 @@ public final class OAuth2DeviceVerificationAuthenticationContext implements OAut
* @return the {@link OAuth2Authorization} * @return the {@link OAuth2Authorization}
*/ */
public OAuth2Authorization getAuthorization() { public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class); OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
} }
/** /**
* Returns the {@link OAuth2AuthorizationConsent authorization consent}. * Returns the {@link OAuth2AuthorizationConsent authorization consent}.
* @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available * @return the {@link OAuth2AuthorizationConsent}, or {@code null} if not available
*/ */
@Nullable public @Nullable OAuth2AuthorizationConsent getAuthorizationConsent() {
public OAuth2AuthorizationConsent getAuthorizationConsent() {
return get(OAuth2AuthorizationConsent.class); return get(OAuth2AuthorizationConsent.class);
} }
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.security.Principal; import java.security.Principal;
import java.util.Base64; import java.util.Base64;
import java.util.Collections;
import java.util.Set; import java.util.Set;
import java.util.function.Predicate; import java.util.function.Predicate;
@@ -115,6 +116,7 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
} }
OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class); OAuth2Authorization.Token<OAuth2UserCode> userCode = authorization.getToken(OAuth2UserCode.class);
Assert.notNull(userCode, "userCode cannot be null");
if (!userCode.isActive()) { if (!userCode.isActive()) {
if (!userCode.isInvalidated()) { if (!userCode.isInvalidated()) {
authorization = OAuth2Authorization.from(authorization).invalidate(userCode.getToken()).build(); authorization = OAuth2Authorization.from(authorization).invalidate(userCode.getToken()).build();
@@ -137,12 +139,16 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
RegisteredClient registeredClient = this.registeredClientRepository RegisteredClient registeredClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId()); .findById(authorization.getRegisteredClientId());
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
} }
Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE); Set<String> requestedScopes = authorization.getAttribute(OAuth2ParameterNames.SCOPE);
if (requestedScopes == null) {
requestedScopes = Collections.emptySet();
}
OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext OAuth2DeviceVerificationAuthenticationContext.Builder authenticationContextBuilder = OAuth2DeviceVerificationAuthenticationContext
.with(deviceVerificationAuthentication) .with(deviceVerificationAuthentication)
@@ -174,7 +180,7 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
} }
Set<String> currentAuthorizedScopes = (currentAuthorizationConsent != null) Set<String> currentAuthorizedScopes = (currentAuthorizationConsent != null)
? currentAuthorizationConsent.getScopes() : null; ? currentAuthorizationConsent.getScopes() : Collections.emptySet();
AuthorizationServerSettings authorizationServerSettings = AuthorizationServerContextHolder.getContext() AuthorizationServerSettings authorizationServerSettings = AuthorizationServerContextHolder.getContext()
.getAuthorizationServerSettings(); .getAuthorizationServerSettings();
@@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -46,7 +47,7 @@ public class OAuth2DeviceVerificationAuthenticationToken extends AbstractAuthent
private final Map<String, Object> additionalParameters; private final Map<String, Object> additionalParameters;
private final String clientId; private final @Nullable String clientId;
/** /**
* Constructs an {@code OAuth2DeviceVerificationAuthenticationToken} using the * Constructs an {@code OAuth2DeviceVerificationAuthenticationToken} using the
@@ -114,9 +115,9 @@ public class OAuth2DeviceVerificationAuthenticationToken extends AbstractAuthent
/** /**
* Returns the client identifier. * Returns the client identifier.
* @return the client identifier * @return the client identifier, or {@code null} if not set
*/ */
public String getClientId() { public @Nullable String getClientId() {
return this.clientId; return this.clientId;
} }
@@ -74,6 +74,7 @@ public final class OAuth2PushedAuthorizationRequestAuthenticationProvider implem
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(pushedAuthorizationRequestAuthentication); .getAuthenticatedClientElseThrowInvalidClient(pushedAuthorizationRequestAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -21,7 +21,8 @@ import java.time.Instant;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -39,9 +40,9 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
@Serial @Serial
private static final long serialVersionUID = 7330534287786569644L; private static final long serialVersionUID = 7330534287786569644L;
private final String requestUri; private final @Nullable String requestUri;
private final Instant requestUriExpiresAt; private final @Nullable Instant requestUriExpiresAt;
/** /**
* Constructs an {@code OAuth2PushedAuthorizationRequestAuthenticationToken} using the * Constructs an {@code OAuth2PushedAuthorizationRequestAuthenticationToken} using the
@@ -91,8 +92,7 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
* Returns the {@code request_uri} corresponding to the authorization request posted. * Returns the {@code request_uri} corresponding to the authorization request posted.
* @return the {@code request_uri} corresponding to the authorization request posted * @return the {@code request_uri} corresponding to the authorization request posted
*/ */
@Nullable public @Nullable String getRequestUri() {
public String getRequestUri() {
return this.requestUri; return this.requestUri;
} }
@@ -102,8 +102,7 @@ public class OAuth2PushedAuthorizationRequestAuthenticationToken
* @return the expiration time on or after which the {@code request_uri} MUST NOT be * @return the expiration time on or after which the {@code request_uri} MUST NOT be
* accepted * accepted
*/ */
@Nullable public @Nullable Instant getRequestUriExpiresAt() {
public Instant getRequestUriExpiresAt() {
return this.requestUriExpiresAt; return this.requestUriExpiresAt;
} }
@@ -38,11 +38,11 @@ final class OAuth2PushedAuthorizationRequestUri {
private static final StringKeyGenerator DEFAULT_STATE_GENERATOR = new Base64StringKeyGenerator( private static final StringKeyGenerator DEFAULT_STATE_GENERATOR = new Base64StringKeyGenerator(
Base64.getUrlEncoder()); Base64.getUrlEncoder());
private String requestUri; private final String requestUri;
private String state; private final String state;
private Instant expiresAt; private final Instant expiresAt;
static OAuth2PushedAuthorizationRequestUri create() { static OAuth2PushedAuthorizationRequestUri create() {
return create(Instant.now().plusSeconds(300)); return create(Instant.now().plusSeconds(300));
@@ -50,23 +50,17 @@ final class OAuth2PushedAuthorizationRequestUri {
static OAuth2PushedAuthorizationRequestUri create(Instant expiresAt) { static OAuth2PushedAuthorizationRequestUri create(Instant expiresAt) {
String state = DEFAULT_STATE_GENERATOR.generateKey(); String state = DEFAULT_STATE_GENERATOR.generateKey();
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = new OAuth2PushedAuthorizationRequestUri(); String requestUri = REQUEST_URI_PREFIX + state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
pushedAuthorizationRequestUri.requestUri = REQUEST_URI_PREFIX + state + REQUEST_URI_DELIMITER state = state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
+ expiresAt.toEpochMilli(); return new OAuth2PushedAuthorizationRequestUri(requestUri, state, expiresAt);
pushedAuthorizationRequestUri.state = state + REQUEST_URI_DELIMITER + expiresAt.toEpochMilli();
pushedAuthorizationRequestUri.expiresAt = expiresAt;
return pushedAuthorizationRequestUri;
} }
static OAuth2PushedAuthorizationRequestUri parse(String requestUri) { static OAuth2PushedAuthorizationRequestUri parse(String requestUri) {
int stateStartIndex = REQUEST_URI_PREFIX.length(); int stateStartIndex = REQUEST_URI_PREFIX.length();
int expiresAtStartIndex = requestUri.indexOf(REQUEST_URI_DELIMITER) + REQUEST_URI_DELIMITER.length(); int expiresAtStartIndex = requestUri.indexOf(REQUEST_URI_DELIMITER) + REQUEST_URI_DELIMITER.length();
OAuth2PushedAuthorizationRequestUri pushedAuthorizationRequestUri = new OAuth2PushedAuthorizationRequestUri(); String state = requestUri.substring(stateStartIndex);
pushedAuthorizationRequestUri.requestUri = requestUri; Instant expiresAt = Instant.ofEpochMilli(Long.parseLong(requestUri.substring(expiresAtStartIndex)));
pushedAuthorizationRequestUri.state = requestUri.substring(stateStartIndex); return new OAuth2PushedAuthorizationRequestUri(requestUri, state, expiresAt);
pushedAuthorizationRequestUri.expiresAt = Instant
.ofEpochMilli(Long.parseLong(requestUri.substring(expiresAtStartIndex)));
return pushedAuthorizationRequestUri;
} }
String getRequestUri() { String getRequestUri() {
@@ -81,7 +75,10 @@ final class OAuth2PushedAuthorizationRequestUri {
return this.expiresAt; return this.expiresAt;
} }
private OAuth2PushedAuthorizationRequestUri() { private OAuth2PushedAuthorizationRequestUri(String requestUri, String state, Instant expiresAt) {
this.requestUri = requestUri;
this.state = state;
this.expiresAt = expiresAt;
} }
} }
@@ -105,6 +105,7 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(refreshTokenAuthentication); .getAuthenticatedClientElseThrowInvalidClient(refreshTokenAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -137,6 +138,7 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
} }
OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken(); OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
Assert.notNull(refreshToken, "refreshToken cannot be null");
if (!refreshToken.isActive()) { if (!refreshToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2 // As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code, // invalid_grant: The provided authorization grant (e.g., authorization code,
@@ -168,7 +170,10 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
&& clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) { && clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
// For public clients, verify the DPoP Proof public key is same as (current) // For public clients, verify the DPoP Proof public key is same as (current)
// access token public key binding // access token public key binding
Map<String, Object> accessTokenClaims = authorization.getAccessToken().getClaims(); OAuth2Authorization.Token<OAuth2AccessToken> accessToken = authorization.getAccessToken();
Assert.notNull(accessToken, "accessToken cannot be null");
Map<String, Object> accessTokenClaims = (accessToken.getClaims() != null) ? accessToken.getClaims()
: Collections.emptyMap();
verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims); verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims);
} }
@@ -180,10 +185,12 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
scopes = authorizedScopes; scopes = authorizedScopes;
} }
Authentication principal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(principal, "principal cannot be null");
// @formatter:off // @formatter:off
DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder() DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
.registeredClient(registeredClient) .registeredClient(registeredClient)
.principal(authorization.getAttribute(Principal.class.getName())) .principal(principal)
.authorizationServerContext(AuthorizationServerContextHolder.getContext()) .authorizationServerContext(AuthorizationServerContextHolder.getContext())
.authorization(authorization) .authorization(authorization)
.authorizedScopes(scopes) .authorizedScopes(scopes)
@@ -21,7 +21,8 @@ import java.util.HashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -48,11 +48,15 @@ public final class OAuth2TokenExchangeActor implements ClaimAccessor {
} }
public String getIssuer() { public String getIssuer() {
return getClaimAsString(OAuth2TokenClaimNames.ISS); String issuer = getClaimAsString(OAuth2TokenClaimNames.ISS);
Assert.notNull(issuer, "issuer cannot be null");
return issuer;
} }
public String getSubject() { public String getSubject() {
return getClaimAsString(OAuth2TokenClaimNames.SUB); String subject = getClaimAsString(OAuth2TokenClaimNames.SUB);
Assert.notNull(subject, "subject cannot be null");
return subject;
} }
@Override @Override
@@ -28,6 +28,7 @@ import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -106,6 +107,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(tokenExchangeAuthentication); .getAuthenticatedClientElseThrowInvalidClient(tokenExchangeAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -133,6 +135,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2Authorization.Token<OAuth2Token> subjectToken = subjectAuthorization OAuth2Authorization.Token<OAuth2Token> subjectToken = subjectAuthorization
.getToken(tokenExchangeAuthentication.getSubjectToken()); .getToken(tokenExchangeAuthentication.getSubjectToken());
Assert.notNull(subjectToken, "subjectToken cannot be null");
if (!subjectToken.isActive()) { if (!subjectToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2 // As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization code, // invalid_grant: The provided authorization grant (e.g., authorization code,
@@ -175,6 +178,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
OAuth2Authorization.Token<OAuth2Token> actorToken = actorAuthorization OAuth2Authorization.Token<OAuth2Token> actorToken = actorAuthorization
.getToken(tokenExchangeAuthentication.getActorToken()); .getToken(tokenExchangeAuthentication.getActorToken());
Assert.notNull(actorToken, "actorToken cannot be null");
if (!actorToken.isActive()) { if (!actorToken.isActive()) {
// As per https://tools.ietf.org/html/rfc6749#section-5.2 // As per https://tools.ietf.org/html/rfc6749#section-5.2
// invalid_grant: The provided authorization grant (e.g., authorization // invalid_grant: The provided authorization grant (e.g., authorization
@@ -184,8 +188,11 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
} }
if (!isValidTokenType(tokenExchangeAuthentication.getActorTokenType(), actorToken)) { String actorTokenType = tokenExchangeAuthentication.getActorTokenType();
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST); if (actorTokenType != null) {
if (!isValidTokenType(actorTokenType, actorToken)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
}
} }
if (authorizedActorClaims != null) { if (authorizedActorClaims != null) {
@@ -288,7 +295,7 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
return new LinkedHashSet<>(requestedScopes); return new LinkedHashSet<>(requestedScopes);
} }
private static void validateClaims(Map<String, Object> expectedClaims, Map<String, Object> actualClaims, private static void validateClaims(Map<String, Object> expectedClaims, @Nullable Map<String, Object> actualClaims,
String... claimNames) { String... claimNames) {
if (actualClaims == null) { if (actualClaims == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
@@ -302,8 +309,9 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
} }
private static Authentication getPrincipal(OAuth2Authorization subjectAuthorization, private static Authentication getPrincipal(OAuth2Authorization subjectAuthorization,
OAuth2Authorization actorAuthorization) { @Nullable OAuth2Authorization actorAuthorization) {
Authentication subjectPrincipal = subjectAuthorization.getAttribute(Principal.class.getName()); Authentication subjectPrincipal = subjectAuthorization.getAttribute(Principal.class.getName());
Assert.notNull(subjectPrincipal, "subject principal cannot be null");
if (actorAuthorization == null) { if (actorAuthorization == null) {
if (subjectPrincipal instanceof OAuth2TokenExchangeCompositeAuthenticationToken compositeAuthenticationToken) { if (subjectPrincipal instanceof OAuth2TokenExchangeCompositeAuthenticationToken compositeAuthenticationToken) {
return compositeAuthenticationToken.getSubject(); return compositeAuthenticationToken.getSubject();
@@ -312,8 +320,11 @@ public final class OAuth2TokenExchangeAuthenticationProvider implements Authenti
} }
// Capture claims for current actor's access token // Capture claims for current actor's access token
OAuth2TokenExchangeActor currentActor = new OAuth2TokenExchangeActor( OAuth2Authorization.Token<OAuth2AccessToken> actorAccessToken = actorAuthorization.getAccessToken();
actorAuthorization.getAccessToken().getClaims()); Assert.notNull(actorAccessToken, "actor access token cannot be null");
Map<String, Object> actorAccessTokenClaims = actorAccessToken.getClaims();
Assert.notNull(actorAccessTokenClaims, "actor access token claims cannot be null");
OAuth2TokenExchangeActor currentActor = new OAuth2TokenExchangeActor(actorAccessTokenClaims);
List<OAuth2TokenExchangeActor> actorPrincipals = new LinkedList<>(); List<OAuth2TokenExchangeActor> actorPrincipals = new LinkedList<>();
actorPrincipals.add(currentActor); actorPrincipals.add(currentActor);
@@ -22,7 +22,8 @@ import java.util.LinkedHashSet;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -43,9 +44,9 @@ public class OAuth2TokenExchangeAuthenticationToken extends OAuth2AuthorizationG
private final String subjectTokenType; private final String subjectTokenType;
private final String actorToken; private final @Nullable String actorToken;
private final String actorTokenType; private final @Nullable String actorTokenType;
private final Set<String> resources; private final Set<String> resources;
@@ -113,17 +114,17 @@ public class OAuth2TokenExchangeAuthenticationToken extends OAuth2AuthorizationG
/** /**
* Returns the actor token. * Returns the actor token.
* @return the actor token * @return the actor token, or {@code null} if not provided
*/ */
public String getActorToken() { public @Nullable String getActorToken() {
return this.actorToken; return this.actorToken;
} }
/** /**
* Returns the actor token type. * Returns the actor token type.
* @return the actor token type * @return the actor token type, or {@code null} if not provided
*/ */
public String getActorTokenType() { public @Nullable String getActorTokenType() {
return this.actorTokenType; return this.actorTokenType;
} }
@@ -21,6 +21,8 @@ import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Objects; import java.util.Objects;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -51,12 +53,12 @@ public class OAuth2TokenExchangeCompositeAuthenticationToken extends AbstractAut
} }
@Override @Override
public Object getPrincipal() { public @Nullable Object getPrincipal() {
return this.subject.getPrincipal(); return this.subject.getPrincipal();
} }
@Override @Override
public Object getCredentials() { public @Nullable Object getCredentials() {
return null; return null;
} }
@@ -102,6 +102,7 @@ public final class OAuth2TokenIntrospectionAuthenticationProvider implements Aut
OAuth2Authorization.Token<OAuth2Token> authorizedToken = authorization OAuth2Authorization.Token<OAuth2Token> authorizedToken = authorization
.getToken(tokenIntrospectionAuthentication.getToken()); .getToken(tokenIntrospectionAuthentication.getToken());
Assert.notNull(authorizedToken, "authorizedToken cannot be null");
if (!authorizedToken.isActive()) { if (!authorizedToken.isActive()) {
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Did not introspect token since not active"); this.logger.trace("Did not introspect token since not active");
@@ -112,6 +113,7 @@ public final class OAuth2TokenIntrospectionAuthenticationProvider implements Aut
RegisteredClient authorizedClient = this.registeredClientRepository RegisteredClient authorizedClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId()); .findById(authorization.getRegisteredClientId());
Assert.notNull(authorizedClient, "authorizedClient cannot be null");
OAuth2TokenIntrospection tokenClaims = withActiveTokenClaims(authorizedToken, authorizedClient); OAuth2TokenIntrospection tokenClaims = withActiveTokenClaims(authorizedToken, authorizedClient);
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -21,7 +21,8 @@ import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenIntrospection; import org.springframework.security.oauth2.server.authorization.OAuth2TokenIntrospection;
@@ -46,7 +47,7 @@ public class OAuth2TokenIntrospectionAuthenticationToken extends AbstractAuthent
private final Authentication clientPrincipal; private final Authentication clientPrincipal;
private final String tokenTypeHint; private final @Nullable String tokenTypeHint;
private final Map<String, Object> additionalParameters; private final Map<String, Object> additionalParameters;
@@ -118,8 +119,7 @@ public class OAuth2TokenIntrospectionAuthenticationToken extends AbstractAuthent
* Returns the token type hint. * Returns the token type hint.
* @return the token type hint * @return the token type hint
*/ */
@Nullable public @Nullable String getTokenTypeHint() {
public String getTokenTypeHint() {
return this.tokenTypeHint; return this.tokenTypeHint;
} }
@@ -64,6 +64,7 @@ public final class OAuth2TokenRevocationAuthenticationProvider implements Authen
OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils OAuth2ClientAuthenticationToken clientPrincipal = OAuth2AuthenticationProviderUtils
.getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication); .getAuthenticatedClientElseThrowInvalidClient(tokenRevocationAuthentication);
RegisteredClient registeredClient = clientPrincipal.getRegisteredClient(); RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
Assert.notNull(registeredClient, "registeredClient cannot be null");
OAuth2Authorization authorization = this.authorizationService OAuth2Authorization authorization = this.authorizationService
.findByToken(tokenRevocationAuthentication.getToken(), null); .findByToken(tokenRevocationAuthentication.getToken(), null);
@@ -80,6 +81,7 @@ public final class OAuth2TokenRevocationAuthenticationProvider implements Authen
} }
OAuth2Authorization.Token<OAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken()); OAuth2Authorization.Token<OAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
Assert.notNull(token, "token cannot be null");
authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build(); authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
this.authorizationService.save(authorization); this.authorizationService.save(authorization);
@@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import java.io.Serial; import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.OAuth2Token; import org.springframework.security.oauth2.core.OAuth2Token;
@@ -43,7 +44,7 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica
private final Authentication clientPrincipal; private final Authentication clientPrincipal;
private final String tokenTypeHint; private final @Nullable String tokenTypeHint;
/** /**
* Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided * Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided
@@ -100,8 +101,7 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica
* Returns the token type hint. * Returns the token type hint.
* @return the token type hint * @return the token type hint
*/ */
@Nullable public @Nullable String getTokenTypeHint() {
public String getTokenTypeHint() {
return this.tokenTypeHint; return this.tokenTypeHint;
} }
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -70,7 +71,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!ClientAuthenticationMethod.NONE.equals(clientAuthentication.getClientAuthenticationMethod())) { if (!ClientAuthenticationMethod.NONE.equals(clientAuthentication.getClientAuthenticationMethod())) {
@@ -80,7 +81,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
String clientId = clientAuthentication.getPrincipal().toString(); String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); throw invalidClient(OAuth2ParameterNames.CLIENT_ID);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -89,7 +90,7 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
if (!registeredClient.getClientAuthenticationMethods() if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) { .contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method"); throw invalidClient("authentication_method");
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -112,10 +113,10 @@ public final class PublicClientAuthenticationProvider implements AuthenticationP
return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication); return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
} }
private static void throwInvalidClient(String parameterName) { private static OAuth2AuthenticationException invalidClient(String parameterName) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI); "Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error); return new OAuth2AuthenticationException(error);
} }
} }
@@ -21,6 +21,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -79,7 +80,7 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication; OAuth2ClientAuthenticationToken clientAuthentication = (OAuth2ClientAuthenticationToken) authentication;
if (!ClientAuthenticationMethod.TLS_CLIENT_AUTH.equals(clientAuthentication.getClientAuthenticationMethod()) if (!ClientAuthenticationMethod.TLS_CLIENT_AUTH.equals(clientAuthentication.getClientAuthenticationMethod())
@@ -91,7 +92,7 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
String clientId = clientAuthentication.getPrincipal().toString(); String clientId = clientAuthentication.getPrincipal().toString();
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId); RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throwInvalidClient(OAuth2ParameterNames.CLIENT_ID); throw invalidClient(OAuth2ParameterNames.CLIENT_ID);
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -100,11 +101,11 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
if (!registeredClient.getClientAuthenticationMethods() if (!registeredClient.getClientAuthenticationMethods()
.contains(clientAuthentication.getClientAuthenticationMethod())) { .contains(clientAuthentication.getClientAuthenticationMethod())) {
throwInvalidClient("authentication_method"); throw invalidClient("authentication_method");
} }
if (!(clientAuthentication.getCredentials() instanceof X509Certificate[])) { if (!(clientAuthentication.getCredentials() instanceof X509Certificate[])) {
throwInvalidClient("credentials"); throw invalidClient("credentials");
} }
OAuth2ClientAuthenticationContext authenticationContext = OAuth2ClientAuthenticationContext OAuth2ClientAuthenticationContext authenticationContext = OAuth2ClientAuthenticationContext
@@ -170,22 +171,23 @@ public final class X509ClientCertificateAuthenticationProvider implements Authen
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication(); OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient(); RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials(); X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
Assert.notEmpty(clientCertificateChain, "clientCertificateChain cannot be empty");
X509Certificate clientCertificate = clientCertificateChain[0]; X509Certificate clientCertificate = clientCertificateChain[0];
String expectedSubjectDN = registeredClient.getClientSettings().getX509CertificateSubjectDN(); String expectedSubjectDN = registeredClient.getClientSettings().getX509CertificateSubjectDN();
if (!StringUtils.hasText(expectedSubjectDN) if (!StringUtils.hasText(expectedSubjectDN)
|| !clientCertificate.getSubjectX500Principal().getName().equals(expectedSubjectDN)) { || !clientCertificate.getSubjectX500Principal().getName().equals(expectedSubjectDN)) {
throwInvalidClient("x509_certificate_subject_dn"); throw invalidClient("x509_certificate_subject_dn");
} }
} }
private static void throwInvalidClient(String parameterName) { private static OAuth2AuthenticationException invalidClient(String parameterName) {
throwInvalidClient(parameterName, null); return invalidClient(parameterName, null);
} }
private static void throwInvalidClient(String parameterName, Throwable cause) { private static OAuth2AuthenticationException invalidClient(String parameterName, @Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI); "Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause); return new OAuth2AuthenticationException(error, error.toString(), cause);
} }
} }
@@ -37,6 +37,7 @@ import javax.security.auth.x500.X500Principal;
import com.nimbusds.jose.jwk.JWK; import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher; import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSet; import com.nimbusds.jose.jwk.JWKSet;
import org.jspecify.annotations.Nullable;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
@@ -48,6 +49,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes; import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations; import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
@@ -74,12 +76,13 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication(); OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient(); RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials(); X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
Assert.notEmpty(clientCertificateChain, "clientCertificateChain cannot be empty");
X509Certificate clientCertificate = clientCertificateChain[0]; X509Certificate clientCertificate = clientCertificateChain[0];
X500Principal issuer = clientCertificate.getIssuerX500Principal(); X500Principal issuer = clientCertificate.getIssuerX500Principal();
X500Principal subject = clientCertificate.getSubjectX500Principal(); X500Principal subject = clientCertificate.getSubjectX500Principal();
if (issuer == null || !issuer.equals(subject)) { if (issuer == null || !issuer.equals(subject)) {
throwInvalidClient("x509_certificate_issuer"); throw invalidClient("x509_certificate_issuer");
} }
JWKSet jwkSet = this.jwkSetSupplier.apply(registeredClient); JWKSet jwkSet = this.jwkSetSupplier.apply(registeredClient);
@@ -95,18 +98,18 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
} }
if (!publicKeyMatches) { if (!publicKeyMatches) {
throwInvalidClient("x509_certificate"); throw invalidClient("x509_certificate");
} }
} }
private static void throwInvalidClient(String parameterName) { private static OAuth2AuthenticationException invalidClient(String parameterName) {
throwInvalidClient(parameterName, null); return invalidClient(parameterName, null);
} }
private static void throwInvalidClient(String parameterName, Throwable cause) { private static OAuth2AuthenticationException invalidClient(String parameterName, @Nullable Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT, OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI); "Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause); return new OAuth2AuthenticationException(error, error.toString(), cause);
} }
private static final class JwkSetSupplier implements Function<RegisteredClient, JWKSet> { private static final class JwkSetSupplier implements Function<RegisteredClient, JWKSet> {
@@ -128,7 +131,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
public JWKSet apply(RegisteredClient registeredClient) { public JWKSet apply(RegisteredClient registeredClient) {
Supplier<JWKSet> jwkSetSupplier = this.jwkSets.computeIfAbsent(registeredClient.getId(), (key) -> { Supplier<JWKSet> jwkSetSupplier = this.jwkSets.computeIfAbsent(registeredClient.getId(), (key) -> {
if (!StringUtils.hasText(registeredClient.getClientSettings().getJwkSetUrl())) { if (!StringUtils.hasText(registeredClient.getClientSettings().getJwkSetUrl())) {
throwInvalidClient("client_jwk_set_url"); throw invalidClient("client_jwk_set_url");
} }
return new JwkSetHolder(registeredClient.getClientSettings().getJwkSetUrl()); return new JwkSetHolder(registeredClient.getClientSettings().getJwkSetUrl());
}); });
@@ -136,34 +139,36 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
} }
private JWKSet retrieve(String jwkSetUrl) { private JWKSet retrieve(String jwkSetUrl) {
URI jwkSetUri = null; final URI jwkSetUri;
try { try {
jwkSetUri = new URI(jwkSetUrl); jwkSetUri = new URI(jwkSetUrl);
} }
catch (URISyntaxException ex) { catch (URISyntaxException ex) {
throwInvalidClient("jwk_set_uri", ex); throw invalidClient("jwk_set_uri", ex);
} }
HttpHeaders headers = new HttpHeaders(); HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON)); headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, jwkSetUri); RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, jwkSetUri);
ResponseEntity<String> response = null; final ResponseEntity<String> response;
try { try {
response = this.restOperations.exchange(request, String.class); response = this.restOperations.exchange(request, String.class);
} }
catch (Exception ex) { catch (Exception ex) {
throwInvalidClient("jwk_set_response_error", ex); throw invalidClient("jwk_set_response_error", ex);
} }
if (response.getStatusCode().value() != 200) { if (response.getStatusCode().value() != 200) {
throwInvalidClient("jwk_set_response_status"); throw invalidClient("jwk_set_response_status");
} }
JWKSet jwkSet = null; final JWKSet jwkSet;
try { try {
jwkSet = JWKSet.parse(response.getBody()); String body = response.getBody();
Assert.notNull(body, "response body cannot be null");
jwkSet = JWKSet.parse(body);
} }
catch (ParseException ex) { catch (ParseException ex) {
throwInvalidClient("jwk_set_response_body", ex); throw invalidClient("jwk_set_response_body", ex);
} }
return jwkSet; return jwkSet;
@@ -177,9 +182,9 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
private final String jwkSetUrl; private final String jwkSetUrl;
private JWKSet jwkSet; private @Nullable JWKSet jwkSet;
private Instant lastUpdatedAt; private @Nullable Instant lastUpdatedAt;
private JwkSetHolder(String jwkSetUrl) { private JwkSetHolder(String jwkSetUrl) {
this.jwkSetUrl = jwkSetUrl; this.jwkSetUrl = jwkSetUrl;
@@ -204,6 +209,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
} }
try { try {
Assert.notNull(this.jwkSet, "jwkSet cannot be null");
return this.jwkSet; return this.jwkSet;
} }
finally { finally {
@@ -213,7 +219,7 @@ final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAu
private boolean shouldRefresh() { private boolean shouldRefresh() {
// Refresh every 5 minutes // Refresh every 5 minutes
return (this.jwkSet == null return (this.jwkSet == null || this.lastUpdatedAt == null
|| this.clock.instant().isAfter(this.lastUpdatedAt.plus(5, ChronoUnit.MINUTES))); || this.clock.instant().isAfter(this.lastUpdatedAt.plus(5, ChronoUnit.MINUTES)));
} }
@@ -0,0 +1,25 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* {@link org.springframework.security.authentication.AuthenticationProvider}
* implementations and related types for OAuth2 and OpenID Connect 1.0 flows handled by
* the authorization server.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.authentication;
import org.jspecify.annotations.NullMarked;
@@ -21,7 +21,8 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
@@ -83,16 +84,14 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien
this.clientIdRegistrationMap.put(registeredClient.getClientId(), registeredClient); this.clientIdRegistrationMap.put(registeredClient.getClientId(), registeredClient);
} }
@Nullable
@Override @Override
public RegisteredClient findById(String id) { public @Nullable RegisteredClient findById(String id) {
Assert.hasText(id, "id cannot be empty"); Assert.hasText(id, "id cannot be empty");
return this.idRegistrationMap.get(id); return this.idRegistrationMap.get(id);
} }
@Nullable
@Override @Override
public RegisteredClient findByClientId(String clientId) { public @Nullable RegisteredClient findByClientId(String clientId) {
Assert.hasText(clientId, "clientId cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty");
return this.clientIdRegistrationMap.get(clientId); return this.clientIdRegistrationMap.get(clientId);
} }
@@ -31,6 +31,7 @@ import java.util.function.Function;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.Module; import com.fasterxml.jackson.databind.Module;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import tools.jackson.databind.JacksonModule; import tools.jackson.databind.JacksonModule;
import tools.jackson.databind.json.JsonMapper; import tools.jackson.databind.json.JsonMapper;
@@ -190,18 +191,18 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
} }
@Override @Override
public RegisteredClient findById(String id) { public @Nullable RegisteredClient findById(String id) {
Assert.hasText(id, "id cannot be empty"); Assert.hasText(id, "id cannot be empty");
return findBy("id = ?", id); return findBy("id = ?", id);
} }
@Override @Override
public RegisteredClient findByClientId(String clientId) { public @Nullable RegisteredClient findByClientId(String clientId) {
Assert.hasText(clientId, "clientId cannot be empty"); Assert.hasText(clientId, "clientId cannot be empty");
return findBy("client_id = ?", clientId); return findBy("client_id = ?", clientId);
} }
private RegisteredClient findBy(String filter, Object... args) { private @Nullable RegisteredClient findBy(String filter, Object... args) {
List<RegisteredClient> result = this.jdbcOperations.query(LOAD_REGISTERED_CLIENT_SQL + filter, List<RegisteredClient> result = this.jdbcOperations.query(LOAD_REGISTERED_CLIENT_SQL + filter,
this.registeredClientRowMapper, args); this.registeredClientRowMapper, args);
return !result.isEmpty() ? result.get(0) : null; return !result.isEmpty() ? result.get(0) : null;
@@ -334,10 +335,15 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
// @formatter:off // @formatter:off
RegisteredClient.Builder builder = RegisteredClient.withId(rs.getString("id")) RegisteredClient.Builder builder = RegisteredClient.withId(rs.getString("id"))
.clientId(rs.getString("client_id")) .clientId(rs.getString("client_id"));
.clientIdIssuedAt((clientIdIssuedAt != null) ? clientIdIssuedAt.toInstant() : null) if (clientIdIssuedAt != null) {
.clientSecret(rs.getString("client_secret")) builder.clientIdIssuedAt(clientIdIssuedAt.toInstant());
.clientSecretExpiresAt((clientSecretExpiresAt != null) ? clientSecretExpiresAt.toInstant() : null) }
builder.clientSecret(rs.getString("client_secret"));
if (clientSecretExpiresAt != null) {
builder.clientSecretExpiresAt(clientSecretExpiresAt.toInstant());
}
builder
.clientName(rs.getString("client_name")) .clientName(rs.getString("client_name"))
.clientAuthenticationMethods((authenticationMethods) -> .clientAuthenticationMethods((authenticationMethods) ->
clientAuthenticationMethods.forEach((authenticationMethod) -> clientAuthenticationMethods.forEach((authenticationMethod) ->
@@ -558,7 +564,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
static class JdbcRegisteredClientRepositoryRuntimeHintsRegistrar implements RuntimeHintsRegistrar { static class JdbcRegisteredClientRepositoryRuntimeHintsRegistrar implements RuntimeHintsRegistrar {
@Override @Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) { public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) {
hints.resources() hints.resources()
.registerResource(new ClassPathResource( .registerResource(new ClassPathResource(
"org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql")); "org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql"));
@@ -27,7 +27,8 @@ import java.util.Objects;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.server.authorization.settings.ClientSettings; import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
@@ -50,31 +51,31 @@ public class RegisteredClient implements Serializable {
@Serial @Serial
private static final long serialVersionUID = -717282636175335081L; private static final long serialVersionUID = -717282636175335081L;
private String id; private @Nullable String id;
private String clientId; private @Nullable String clientId;
private Instant clientIdIssuedAt; private @Nullable Instant clientIdIssuedAt;
private String clientSecret; private @Nullable String clientSecret;
private Instant clientSecretExpiresAt; private @Nullable Instant clientSecretExpiresAt;
private String clientName; private @Nullable String clientName;
private Set<ClientAuthenticationMethod> clientAuthenticationMethods; private @Nullable Set<ClientAuthenticationMethod> clientAuthenticationMethods;
private Set<AuthorizationGrantType> authorizationGrantTypes; private @Nullable Set<AuthorizationGrantType> authorizationGrantTypes;
private Set<String> redirectUris; private @Nullable Set<String> redirectUris;
private Set<String> postLogoutRedirectUris; private @Nullable Set<String> postLogoutRedirectUris;
private Set<String> scopes; private @Nullable Set<String> scopes;
private ClientSettings clientSettings; private @Nullable ClientSettings clientSettings;
private TokenSettings tokenSettings; private @Nullable TokenSettings tokenSettings;
protected RegisteredClient() { protected RegisteredClient() {
} }
@@ -84,6 +85,7 @@ public class RegisteredClient implements Serializable {
* @return the identifier for the registration * @return the identifier for the registration
*/ */
public String getId() { public String getId() {
Assert.notNull(this.id, "id cannot be null");
return this.id; return this.id;
} }
@@ -92,6 +94,7 @@ public class RegisteredClient implements Serializable {
* @return the client identifier * @return the client identifier
*/ */
public String getClientId() { public String getClientId() {
Assert.notNull(this.clientId, "clientId cannot be null");
return this.clientId; return this.clientId;
} }
@@ -99,8 +102,7 @@ public class RegisteredClient implements Serializable {
* Returns the time at which the client identifier was issued. * Returns the time at which the client identifier was issued.
* @return the time at which the client identifier was issued * @return the time at which the client identifier was issued
*/ */
@Nullable public @Nullable Instant getClientIdIssuedAt() {
public Instant getClientIdIssuedAt() {
return this.clientIdIssuedAt; return this.clientIdIssuedAt;
} }
@@ -108,8 +110,7 @@ public class RegisteredClient implements Serializable {
* Returns the client secret or {@code null} if not available. * Returns the client secret or {@code null} if not available.
* @return the client secret or {@code null} if not available * @return the client secret or {@code null} if not available
*/ */
@Nullable public @Nullable String getClientSecret() {
public String getClientSecret() {
return this.clientSecret; return this.clientSecret;
} }
@@ -119,8 +120,7 @@ public class RegisteredClient implements Serializable {
* @return the time at which the client secret expires or {@code null} if it does not * @return the time at which the client secret expires or {@code null} if it does not
* expire * expire
*/ */
@Nullable public @Nullable Instant getClientSecretExpiresAt() {
public Instant getClientSecretExpiresAt() {
return this.clientSecretExpiresAt; return this.clientSecretExpiresAt;
} }
@@ -129,6 +129,7 @@ public class RegisteredClient implements Serializable {
* @return the client name * @return the client name
*/ */
public String getClientName() { public String getClientName() {
Assert.notNull(this.clientName, "clientName cannot be null");
return this.clientName; return this.clientName;
} }
@@ -139,6 +140,7 @@ public class RegisteredClient implements Serializable {
* method(s)} * method(s)}
*/ */
public Set<ClientAuthenticationMethod> getClientAuthenticationMethods() { public Set<ClientAuthenticationMethod> getClientAuthenticationMethods() {
Assert.notNull(this.clientAuthenticationMethods, "clientAuthenticationMethods cannot be null");
return this.clientAuthenticationMethods; return this.clientAuthenticationMethods;
} }
@@ -149,6 +151,7 @@ public class RegisteredClient implements Serializable {
* type(s)} * type(s)}
*/ */
public Set<AuthorizationGrantType> getAuthorizationGrantTypes() { public Set<AuthorizationGrantType> getAuthorizationGrantTypes() {
Assert.notNull(this.authorizationGrantTypes, "authorizationGrantTypes cannot be null");
return this.authorizationGrantTypes; return this.authorizationGrantTypes;
} }
@@ -157,6 +160,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of redirect URI(s) * @return the {@code Set} of redirect URI(s)
*/ */
public Set<String> getRedirectUris() { public Set<String> getRedirectUris() {
Assert.notNull(this.redirectUris, "redirectUris cannot be null");
return this.redirectUris; return this.redirectUris;
} }
@@ -167,6 +171,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of post logout redirect URI(s) * @return the {@code Set} of post logout redirect URI(s)
*/ */
public Set<String> getPostLogoutRedirectUris() { public Set<String> getPostLogoutRedirectUris() {
Assert.notNull(this.postLogoutRedirectUris, "postLogoutRedirectUris cannot be null");
return this.postLogoutRedirectUris; return this.postLogoutRedirectUris;
} }
@@ -175,6 +180,7 @@ public class RegisteredClient implements Serializable {
* @return the {@code Set} of scope(s) * @return the {@code Set} of scope(s)
*/ */
public Set<String> getScopes() { public Set<String> getScopes() {
Assert.notNull(this.scopes, "scopes cannot be null");
return this.scopes; return this.scopes;
} }
@@ -183,6 +189,7 @@ public class RegisteredClient implements Serializable {
* @return the {@link ClientSettings} * @return the {@link ClientSettings}
*/ */
public ClientSettings getClientSettings() { public ClientSettings getClientSettings() {
Assert.notNull(this.clientSettings, "clientSettings cannot be null");
return this.clientSettings; return this.clientSettings;
} }
@@ -191,6 +198,7 @@ public class RegisteredClient implements Serializable {
* @return the {@link TokenSettings} * @return the {@link TokenSettings}
*/ */
public TokenSettings getTokenSettings() { public TokenSettings getTokenSettings() {
Assert.notNull(this.tokenSettings, "tokenSettings cannot be null");
return this.tokenSettings; return this.tokenSettings;
} }
@@ -261,17 +269,17 @@ public class RegisteredClient implements Serializable {
*/ */
public static class Builder { public static class Builder {
private String id; private @Nullable String id;
private String clientId; private @Nullable String clientId;
private Instant clientIdIssuedAt; private @Nullable Instant clientIdIssuedAt;
private String clientSecret; private @Nullable String clientSecret;
private Instant clientSecretExpiresAt; private @Nullable Instant clientSecretExpiresAt;
private String clientName; private @Nullable String clientName;
private final Set<ClientAuthenticationMethod> clientAuthenticationMethods = new HashSet<>(); private final Set<ClientAuthenticationMethod> clientAuthenticationMethods = new HashSet<>();
@@ -283,9 +291,9 @@ public class RegisteredClient implements Serializable {
private final Set<String> scopes = new HashSet<>(); private final Set<String> scopes = new HashSet<>();
private ClientSettings clientSettings; private @Nullable ClientSettings clientSettings;
private TokenSettings tokenSettings; private @Nullable TokenSettings tokenSettings;
protected Builder(String id) { protected Builder(String id) {
this.id = id; this.id = id;
@@ -16,7 +16,7 @@
package org.springframework.security.oauth2.server.authorization.client; package org.springframework.security.oauth2.server.authorization.client;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
/** /**
* A repository for OAuth 2.0 {@link RegisteredClient}(s). * A repository for OAuth 2.0 {@link RegisteredClient}(s).
@@ -45,8 +45,7 @@ public interface RegisteredClientRepository {
* @param id the registration identifier * @param id the registration identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null} * @return the {@link RegisteredClient} if found, otherwise {@code null}
*/ */
@Nullable @Nullable RegisteredClient findById(String id);
RegisteredClient findById(String id);
/** /**
* Returns the registered client identified by the provided {@code clientId}, or * Returns the registered client identified by the provided {@code clientId}, or
@@ -54,7 +53,6 @@ public interface RegisteredClientRepository {
* @param clientId the client identifier * @param clientId the client identifier
* @return the {@link RegisteredClient} if found, otherwise {@code null} * @return the {@link RegisteredClient} if found, otherwise {@code null}
*/ */
@Nullable @Nullable RegisteredClient findByClientId(String clientId);
RegisteredClient findByClientId(String clientId);
} }
@@ -0,0 +1,25 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Client registration persistence for the authorization server, including
* {@link org.springframework.security.oauth2.server.authorization.client.RegisteredClient}
* and repository abstractions.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.client;
import org.jspecify.annotations.NullMarked;
@@ -16,7 +16,8 @@
package org.springframework.security.oauth2.server.authorization.context; package org.springframework.security.oauth2.server.authorization.context;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
@@ -34,8 +35,7 @@ public interface Context {
* @return the value of the attribute associated to the key, or {@code null} if not * @return the value of the attribute associated to the key, or {@code null} if not
* available * available
*/ */
@Nullable <V> @Nullable V get(Object key);
<V> V get(Object key);
/** /**
* Returns the value of the attribute associated to the key. * Returns the value of the attribute associated to the key.
@@ -44,8 +44,7 @@ public interface Context {
* @return the value of the attribute associated to the key, or {@code null} if not * @return the value of the attribute associated to the key, or {@code null} if not
* available or not of the specified type * available or not of the specified type
*/ */
@Nullable default <V> @Nullable V get(Class<V> key) {
default <V> V get(Class<V> key) {
Assert.notNull(key, "key cannot be null"); Assert.notNull(key, "key cannot be null");
V value = get((Object) key); V value = get((Object) key);
return key.isInstance(value) ? value : null; return key.isInstance(value) ? value : null;
@@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Context types that carry authorization server request state and attributes during
* protocol processing.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.context;
import org.jspecify.annotations.NullMarked;
@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.converter;
import java.time.Instant; import java.time.Instant;
import java.util.Base64; import java.util.Base64;
import java.util.List;
import java.util.UUID; import java.util.UUID;
import java.util.function.Consumer; import java.util.function.Consumer;
@@ -58,9 +59,11 @@ public final class OAuth2ClientRegistrationRegisteredClientConverter
// @formatter:off // @formatter:off
RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString()) RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString())
.clientId(CLIENT_ID_GENERATOR.generateKey()) .clientId(CLIENT_ID_GENERATOR.generateKey())
.clientIdIssuedAt(Instant.now()) .clientIdIssuedAt(Instant.now());
.clientName(clientRegistration.getClientName()); String clientName = clientRegistration.getClientName();
if (clientName != null) {
builder.clientName(clientName);
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) { if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
builder builder
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST) .clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST)
@@ -80,9 +83,10 @@ public final class OAuth2ClientRegistrationRegisteredClientConverter
redirectUris.addAll(clientRegistration.getRedirectUris())); redirectUris.addAll(clientRegistration.getRedirectUris()));
} }
if (!CollectionUtils.isEmpty(clientRegistration.getGrantTypes())) { List<String> grantTypes = clientRegistration.getGrantTypes();
if (!CollectionUtils.isEmpty(grantTypes)) {
builder.authorizationGrantTypes((authorizationGrantTypes) -> builder.authorizationGrantTypes((authorizationGrantTypes) ->
clientRegistration.getGrantTypes().forEach((grantType) -> grantTypes.forEach((grantType) ->
authorizationGrantTypes.add(new AuthorizationGrantType(grantType)))); authorizationGrantTypes.add(new AuthorizationGrantType(grantType))));
} }
else { else {
@@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.converter; package org.springframework.security.oauth2.server.authorization.converter;
import java.time.Instant;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
@@ -39,8 +41,11 @@ public final class RegisteredClientOAuth2ClientRegistrationConverter
// @formatter:off // @formatter:off
OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder() OAuth2ClientRegistration.Builder builder = OAuth2ClientRegistration.builder()
.clientId(registeredClient.getClientId()) .clientId(registeredClient.getClientId())
.clientIdIssuedAt(registeredClient.getClientIdIssuedAt())
.clientName(registeredClient.getClientName()); .clientName(registeredClient.getClientName());
Instant clientIdIssuedAt = registeredClient.getClientIdIssuedAt();
if (clientIdIssuedAt != null) {
builder.clientIdIssuedAt(clientIdIssuedAt);
}
builder builder
.tokenEndpointAuthenticationMethod(registeredClient.getClientAuthenticationMethods().iterator().next().getValue()); .tokenEndpointAuthenticationMethod(registeredClient.getClientAuthenticationMethods().iterator().next().getValue());
@@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* {@link org.springframework.core.convert.converter.Converter} implementations for
* authorization server domain types.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.converter;
import org.jspecify.annotations.NullMarked;
@@ -16,6 +16,8 @@
package org.springframework.security.oauth2.server.authorization.http.converter; package org.springframework.security.oauth2.server.authorization.http.converter;
import org.jspecify.annotations.Nullable;
import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.GenericHttpMessageConverter;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.json.GsonHttpMessageConverter; import org.springframework.http.converter.json.GsonHttpMessageConverter;
@@ -54,7 +56,7 @@ final class HttpMessageConverters {
} }
@SuppressWarnings("removal") @SuppressWarnings("removal")
static GenericHttpMessageConverter<Object> getJsonMessageConverter() { static @Nullable GenericHttpMessageConverter<Object> getJsonMessageConverter() {
if (jacksonPresent) { if (jacksonPresent) {
return new GenericHttpMessageConverterAdapter<>(new JacksonJsonHttpMessageConverter()); return new GenericHttpMessageConverterAdapter<>(new JacksonJsonHttpMessageConverter());
} }
@@ -53,8 +53,7 @@ public class OAuth2AuthorizationServerMetadataHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() { private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
}; };
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters private final GenericHttpMessageConverter<Object> jsonMessageConverter;
.getJsonMessageConverter();
private Converter<Map<String, Object>, OAuth2AuthorizationServerMetadata> authorizationServerMetadataConverter = new OAuth2AuthorizationServerMetadataConverter(); private Converter<Map<String, Object>, OAuth2AuthorizationServerMetadata> authorizationServerMetadataConverter = new OAuth2AuthorizationServerMetadataConverter();
@@ -62,6 +61,9 @@ public class OAuth2AuthorizationServerMetadataHttpMessageConverter
public OAuth2AuthorizationServerMetadataHttpMessageConverter() { public OAuth2AuthorizationServerMetadataHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
} }
@Override @Override
@@ -26,6 +26,8 @@ import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.jspecify.annotations.Nullable;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.convert.TypeDescriptor; import org.springframework.core.convert.TypeDescriptor;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
@@ -60,8 +62,7 @@ public class OAuth2ClientRegistrationHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() { private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
}; };
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters private final GenericHttpMessageConverter<Object> jsonMessageConverter;
.getJsonMessageConverter();
private Converter<Map<String, Object>, OAuth2ClientRegistration> clientRegistrationConverter = new MapOAuth2ClientRegistrationConverter(); private Converter<Map<String, Object>, OAuth2ClientRegistration> clientRegistrationConverter = new MapOAuth2ClientRegistrationConverter();
@@ -69,6 +70,9 @@ public class OAuth2ClientRegistrationHttpMessageConverter
public OAuth2ClientRegistrationHttpMessageConverter() { public OAuth2ClientRegistrationHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
} }
@Override @Override
@@ -187,7 +191,7 @@ public class OAuth2ClientRegistrationHttpMessageConverter
return (source) -> CLAIM_CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor); return (source) -> CLAIM_CONVERSION_SERVICE.convert(source, OBJECT_TYPE_DESCRIPTOR, targetDescriptor);
} }
private static Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) { private static @Nullable Instant convertClientSecretExpiresAt(Object clientSecretExpiresAt) {
if (clientSecretExpiresAt != null && String.valueOf(clientSecretExpiresAt).equals("0")) { if (clientSecretExpiresAt != null && String.valueOf(clientSecretExpiresAt).equals("0")) {
// 0 indicates that client_secret_expires_at does not expire // 0 indicates that client_secret_expires_at does not expire
return null; return null;
@@ -61,8 +61,7 @@ public class OAuth2TokenIntrospectionHttpMessageConverter
private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() { private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
}; };
private final GenericHttpMessageConverter<Object> jsonMessageConverter = HttpMessageConverters private final GenericHttpMessageConverter<Object> jsonMessageConverter;
.getJsonMessageConverter();
private Converter<Map<String, Object>, OAuth2TokenIntrospection> tokenIntrospectionConverter = new MapOAuth2TokenIntrospectionConverter(); private Converter<Map<String, Object>, OAuth2TokenIntrospection> tokenIntrospectionConverter = new MapOAuth2TokenIntrospectionConverter();
@@ -70,6 +69,9 @@ public class OAuth2TokenIntrospectionHttpMessageConverter
public OAuth2TokenIntrospectionHttpMessageConverter() { public OAuth2TokenIntrospectionHttpMessageConverter() {
super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json")); super(MediaType.APPLICATION_JSON, new MediaType("application", "*+json"));
GenericHttpMessageConverter<Object> converter = HttpMessageConverters.getJsonMessageConverter();
Assert.notNull(converter, "Unable to locate a supported JSON message converter");
this.jsonMessageConverter = converter;
} }
@Override @Override
@@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* HTTP message converters for OAuth2 Authorization Server protocol representations.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.http.converter;
import org.jspecify.annotations.NullMarked;
@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.server.authorization.jackson;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.jspecify.annotations.Nullable;
import tools.jackson.core.type.TypeReference; import tools.jackson.core.type.TypeReference;
import tools.jackson.databind.DeserializationContext; import tools.jackson.databind.DeserializationContext;
import tools.jackson.databind.JsonNode; import tools.jackson.databind.JsonNode;
@@ -37,7 +38,7 @@ abstract class JsonNodeUtils {
static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() { static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() {
}; };
static String findStringValue(JsonNode jsonNode, String fieldName) { static @Nullable String findStringValue(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
} }
@@ -45,7 +46,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isString()) ? value.stringValue() : null; return (value != null && value.isString()) ? value.stringValue() : null;
} }
static <T> T findValue(JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference, static <T> @Nullable T findValue(@Nullable JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
DeserializationContext context) { DeserializationContext context) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
@@ -55,7 +56,7 @@ abstract class JsonNodeUtils {
? context.readTreeAsValue(value, context.getTypeFactory().constructType(valueTypeReference)) : null; ? context.readTreeAsValue(value, context.getTypeFactory().constructType(valueTypeReference)) : null;
} }
static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) { static @Nullable JsonNode findObjectNode(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
} }
@@ -16,6 +16,10 @@
package org.springframework.security.oauth2.server.authorization.jackson; package org.springframework.security.oauth2.server.authorization.jackson;
import java.util.Collections;
import java.util.Map;
import org.jspecify.annotations.Nullable;
import tools.jackson.core.JsonParser; import tools.jackson.core.JsonParser;
import tools.jackson.databind.DeserializationContext; import tools.jackson.databind.DeserializationContext;
import tools.jackson.databind.JsonNode; import tools.jackson.databind.JsonNode;
@@ -25,6 +29,7 @@ import tools.jackson.databind.exc.InvalidFormatException;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder;
import org.springframework.util.Assert;
/** /**
* A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}. * A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}.
@@ -45,16 +50,27 @@ final class OAuth2AuthorizationRequestDeserializer extends ValueDeserializer<OAu
private OAuth2AuthorizationRequest deserialize(JsonParser parser, DeserializationContext context, JsonNode root) { private OAuth2AuthorizationRequest deserialize(JsonParser parser, DeserializationContext context, JsonNode root) {
AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType( AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType(
JsonNodeUtils.findObjectNode(root, "authorizationGrantType")); JsonNodeUtils.findObjectNode(root, "authorizationGrantType"));
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
Builder builder = getBuilder(parser, authorizationGrantType); Builder builder = getBuilder(parser, authorizationGrantType);
builder.authorizationUri(JsonNodeUtils.findStringValue(root, "authorizationUri")); String authorizationUri = JsonNodeUtils.findStringValue(root, "authorizationUri");
builder.clientId(JsonNodeUtils.findStringValue(root, "clientId")); Assert.notNull(authorizationUri, "authorizationUri cannot be null");
builder.authorizationUri(authorizationUri);
String clientId = JsonNodeUtils.findStringValue(root, "clientId");
Assert.notNull(clientId, "clientId cannot be null");
builder.clientId(clientId);
builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri")); builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri"));
builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, context)); builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, context));
builder.state(JsonNodeUtils.findStringValue(root, "state")); builder.state(JsonNodeUtils.findStringValue(root, "state"));
builder.additionalParameters( Map<String, Object> additionalParameters = JsonNodeUtils.findValue(root, "additionalParameters",
JsonNodeUtils.findValue(root, "additionalParameters", JsonNodeUtils.STRING_OBJECT_MAP, context)); JsonNodeUtils.STRING_OBJECT_MAP, context);
builder.authorizationRequestUri(JsonNodeUtils.findStringValue(root, "authorizationRequestUri")); builder.additionalParameters((additionalParameters != null) ? additionalParameters : Collections.emptyMap());
builder.attributes(JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP, context)); String authorizationRequestUri = JsonNodeUtils.findStringValue(root, "authorizationRequestUri");
if (authorizationRequestUri != null) {
builder.authorizationRequestUri(authorizationRequestUri);
}
Map<String, Object> attributes = JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP,
context);
builder.attributes((attributes != null) ? attributes : Collections.emptyMap());
return builder.build(); return builder.build();
} }
@@ -66,7 +82,10 @@ final class OAuth2AuthorizationRequestDeserializer extends ValueDeserializer<OAu
AuthorizationGrantType.class); AuthorizationGrantType.class);
} }
private static AuthorizationGrantType convertAuthorizationGrantType(JsonNode jsonNode) { private static @Nullable AuthorizationGrantType convertAuthorizationGrantType(@Nullable JsonNode jsonNode) {
if (jsonNode == null) {
return null;
}
String value = JsonNodeUtils.findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) { if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) {
return AuthorizationGrantType.AUTHORIZATION_CODE; return AuthorizationGrantType.AUTHORIZATION_CODE;
@@ -0,0 +1,23 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Jackson 3 ({@code tools.jackson}) serialization support for authorization server types.
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.jackson;
import org.jspecify.annotations.NullMarked;
@@ -22,6 +22,7 @@ import java.util.Set;
import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
/** /**
* Utility class for {@code JsonNode}. * Utility class for {@code JsonNode}.
@@ -41,7 +42,7 @@ abstract class JsonNodeUtils {
static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() { static final TypeReference<Map<String, Object>> STRING_OBJECT_MAP = new TypeReference<>() {
}; };
static String findStringValue(JsonNode jsonNode, String fieldName) { static @Nullable String findStringValue(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
} }
@@ -49,7 +50,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isTextual()) ? value.asText() : null; return (value != null && value.isTextual()) ? value.asText() : null;
} }
static <T> T findValue(JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference, static <T> @Nullable T findValue(@Nullable JsonNode jsonNode, String fieldName, TypeReference<T> valueTypeReference,
ObjectMapper mapper) { ObjectMapper mapper) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
@@ -58,7 +59,7 @@ abstract class JsonNodeUtils {
return (value != null && value.isContainerNode()) ? mapper.convertValue(value, valueTypeReference) : null; return (value != null && value.isContainerNode()) ? mapper.convertValue(value, valueTypeReference) : null;
} }
static JsonNode findObjectNode(JsonNode jsonNode, String fieldName) { static @Nullable JsonNode findObjectNode(@Nullable JsonNode jsonNode, String fieldName) {
if (jsonNode == null) { if (jsonNode == null) {
return null; return null;
} }
@@ -17,6 +17,7 @@
package org.springframework.security.oauth2.server.authorization.jackson2; package org.springframework.security.oauth2.server.authorization.jackson2;
import java.io.IOException; import java.io.IOException;
import java.util.Map;
import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.JsonParser;
@@ -24,10 +25,12 @@ import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder; import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest.Builder;
import org.springframework.util.Assert;
/** /**
* A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}. * A {@code JsonDeserializer} for {@link OAuth2AuthorizationRequest}.
@@ -57,27 +60,42 @@ final class OAuth2AuthorizationRequestDeserializer extends JsonDeserializer<OAut
AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType( AuthorizationGrantType authorizationGrantType = convertAuthorizationGrantType(
JsonNodeUtils.findObjectNode(root, "authorizationGrantType")); JsonNodeUtils.findObjectNode(root, "authorizationGrantType"));
Builder builder = getBuilder(parser, authorizationGrantType); Builder builder = getBuilder(parser, authorizationGrantType);
builder.authorizationUri(JsonNodeUtils.findStringValue(root, "authorizationUri")); String authorizationUri = JsonNodeUtils.findStringValue(root, "authorizationUri");
builder.clientId(JsonNodeUtils.findStringValue(root, "clientId")); Assert.notNull(authorizationUri, "authorizationUri cannot be null");
builder.authorizationUri(authorizationUri);
String clientId = JsonNodeUtils.findStringValue(root, "clientId");
Assert.notNull(clientId, "clientId cannot be null");
builder.clientId(clientId);
builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri")); builder.redirectUri(JsonNodeUtils.findStringValue(root, "redirectUri"));
builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, mapper)); builder.scopes(JsonNodeUtils.findValue(root, "scopes", JsonNodeUtils.STRING_SET, mapper));
builder.state(JsonNodeUtils.findStringValue(root, "state")); builder.state(JsonNodeUtils.findStringValue(root, "state"));
builder.additionalParameters( Map<String, Object> additionalParameters = JsonNodeUtils.findValue(root, "additionalParameters",
JsonNodeUtils.findValue(root, "additionalParameters", JsonNodeUtils.STRING_OBJECT_MAP, mapper)); JsonNodeUtils.STRING_OBJECT_MAP, mapper);
builder.authorizationRequestUri(JsonNodeUtils.findStringValue(root, "authorizationRequestUri")); if (additionalParameters != null) {
builder.attributes(JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP, mapper)); builder.additionalParameters(additionalParameters);
}
String authorizationRequestUri = JsonNodeUtils.findStringValue(root, "authorizationRequestUri");
if (authorizationRequestUri != null) {
builder.authorizationRequestUri(authorizationRequestUri);
}
Map<String, Object> attributes = JsonNodeUtils.findValue(root, "attributes", JsonNodeUtils.STRING_OBJECT_MAP,
mapper);
if (attributes != null) {
builder.attributes(attributes);
}
return builder.build(); return builder.build();
} }
private Builder getBuilder(JsonParser parser, AuthorizationGrantType authorizationGrantType) private Builder getBuilder(JsonParser parser, @Nullable AuthorizationGrantType authorizationGrantType)
throws JsonParseException { throws JsonParseException {
if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(authorizationGrantType)) { if (authorizationGrantType != null
&& authorizationGrantType.equals(AuthorizationGrantType.AUTHORIZATION_CODE)) {
return OAuth2AuthorizationRequest.authorizationCode(); return OAuth2AuthorizationRequest.authorizationCode();
} }
throw new JsonParseException(parser, "Invalid authorizationGrantType"); throw new JsonParseException(parser, "Invalid authorizationGrantType");
} }
private static AuthorizationGrantType convertAuthorizationGrantType(JsonNode jsonNode) { private static @Nullable AuthorizationGrantType convertAuthorizationGrantType(@Nullable JsonNode jsonNode) {
String value = JsonNodeUtils.findStringValue(jsonNode, "value"); String value = JsonNodeUtils.findStringValue(jsonNode, "value");
if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) { if (AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equalsIgnoreCase(value)) {
return AuthorizationGrantType.AUTHORIZATION_CODE; return AuthorizationGrantType.AUTHORIZATION_CODE;
@@ -0,0 +1,24 @@
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Jackson 2 ({@code com.fasterxml.jackson}) serialization support for authorization
* server types (deprecated in favor of {@code jackson}).
*/
@NullMarked
package org.springframework.security.oauth2.server.authorization.jackson2;
import org.jspecify.annotations.NullMarked;
@@ -19,6 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc;
import java.net.URL; import java.net.URL;
import java.util.List; import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod; import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
@@ -53,7 +55,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* to after a logout has been performed. * to after a logout has been performed.
* @return the post logout redirection {@code URI} values used by the Client * @return the post logout redirection {@code URI} values used by the Client
*/ */
default List<String> getPostLogoutRedirectUris() { default @Nullable List<String> getPostLogoutRedirectUris() {
return getClaimAsStringList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); return getClaimAsStringList(OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS);
} }
@@ -66,7 +68,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@link JwsAlgorithm JWS} algorithm that must be used for signing the * @return the {@link JwsAlgorithm JWS} algorithm that must be used for signing the
* {@link Jwt JWT} used to authenticate the Client at the Token Endpoint * {@link Jwt JWT} used to authenticate the Client at the Token Endpoint
*/ */
default String getTokenEndpointAuthenticationSigningAlgorithm() { default @Nullable String getTokenEndpointAuthenticationSigningAlgorithm() {
return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG); return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG);
} }
@@ -77,7 +79,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@link SignatureAlgorithm JWS} algorithm required for signing the * @return the {@link SignatureAlgorithm JWS} algorithm required for signing the
* {@link OidcIdToken ID Token} issued to the Client * {@link OidcIdToken ID Token} issued to the Client
*/ */
default String getIdTokenSignedResponseAlgorithm() { default @Nullable String getIdTokenSignedResponseAlgorithm() {
return getClaimAsString(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG); return getClaimAsString(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG);
} }
@@ -87,7 +89,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the Registration Access Token that can be used at the Client Configuration * @return the Registration Access Token that can be used at the Client Configuration
* Endpoint * Endpoint
*/ */
default String getRegistrationAccessToken() { default @Nullable String getRegistrationAccessToken() {
return getClaimAsString(OidcClientMetadataClaimNames.REGISTRATION_ACCESS_TOKEN); return getClaimAsString(OidcClientMetadataClaimNames.REGISTRATION_ACCESS_TOKEN);
} }
@@ -97,7 +99,7 @@ public interface OidcClientMetadataClaimAccessor extends OAuth2ClientMetadataCla
* @return the {@code URL} of the Client Configuration Endpoint where the Registration * @return the {@code URL} of the Client Configuration Endpoint where the Registration
* Access Token can be used * Access Token can be used
*/ */
default URL getRegistrationClientUrl() { default @Nullable URL getRegistrationClientUrl() {
return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI); return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI);
} }
@@ -19,11 +19,14 @@ package org.springframework.security.oauth2.server.authorization.oidc;
import java.net.URL; import java.net.URL;
import java.util.List; import java.util.List;
import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.ClaimAccessor; import org.springframework.security.oauth2.core.ClaimAccessor;
import org.springframework.security.oauth2.core.oidc.OidcIdToken; import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm; import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadataClaimAccessor; import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationServerMetadataClaimAccessor;
import org.springframework.util.Assert;
/** /**
* A {@link ClaimAccessor} for the "claims" that can be returned in the OpenID Provider * A {@link ClaimAccessor} for the "claims" that can be returned in the OpenID Provider
@@ -47,7 +50,9 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* @return the Subject Identifier types supported * @return the Subject Identifier types supported
*/ */
default List<String> getSubjectTypes() { default List<String> getSubjectTypes() {
return getClaimAsStringList(OidcProviderMetadataClaimNames.SUBJECT_TYPES_SUPPORTED); List<String> subjectTypes = getClaimAsStringList(OidcProviderMetadataClaimNames.SUBJECT_TYPES_SUPPORTED);
Assert.notNull(subjectTypes, "subjectTypes cannot be null");
return subjectTypes;
} }
/** /**
@@ -58,7 +63,10 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* {@link OidcIdToken ID Token} * {@link OidcIdToken ID Token}
*/ */
default List<String> getIdTokenSigningAlgorithms() { default List<String> getIdTokenSigningAlgorithms() {
return getClaimAsStringList(OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED); List<String> idTokenSigningAlgorithms = getClaimAsStringList(
OidcProviderMetadataClaimNames.ID_TOKEN_SIGNING_ALG_VALUES_SUPPORTED);
Assert.notNull(idTokenSigningAlgorithms, "idTokenSigningAlgorithms cannot be null");
return idTokenSigningAlgorithms;
} }
/** /**
@@ -66,7 +74,7 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* {@code (userinfo_endpoint)}. * {@code (userinfo_endpoint)}.
* @return the {@code URL} of the OpenID Connect 1.0 UserInfo Endpoint * @return the {@code URL} of the OpenID Connect 1.0 UserInfo Endpoint
*/ */
default URL getUserInfoEndpoint() { default @Nullable URL getUserInfoEndpoint() {
return getClaimAsURL(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT); return getClaimAsURL(OidcProviderMetadataClaimNames.USER_INFO_ENDPOINT);
} }
@@ -76,7 +84,9 @@ public interface OidcProviderMetadataClaimAccessor extends OAuth2AuthorizationSe
* @return the {@code URL} of the OpenID Connect 1.0 End Session Endpoint * @return the {@code URL} of the OpenID Connect 1.0 End Session Endpoint
*/ */
default URL getEndSessionEndpoint() { default URL getEndSessionEndpoint() {
return getClaimAsURL(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT); URL endSessionEndpoint = getClaimAsURL(OidcProviderMetadataClaimNames.END_SESSION_ENDPOINT);
Assert.notNull(endSessionEndpoint, "endSessionEndpoint cannot be null");
return endSessionEndpoint;
} }
} }
@@ -18,10 +18,12 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Map;
import java.util.Set; import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -99,7 +101,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication; OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication;
if (!StringUtils.hasText(clientRegistrationAuthentication.getClientId())) { if (!StringUtils.hasText(clientRegistrationAuthentication.getClientId())) {
@@ -132,6 +134,7 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
} }
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) { if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
} }
@@ -149,8 +152,9 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication, OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
RegisteredClient registeredClient = this.registeredClientRepository String clientId = clientRegistrationAuthentication.getClientId();
.findByClientId(clientRegistrationAuthentication.getClientId()); Assert.hasText(clientId, "clientId cannot be empty");
RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
if (registeredClient == null) { if (registeredClient == null) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
} }
@@ -176,9 +180,11 @@ public final class OidcClientConfigurationAuthenticationProvider implements Auth
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken, private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) { Set<String> requiredScope) {
Map<String, Object> claims = authorizedAccessToken.getClaims();
Assert.notNull(claims, "claims cannot be null");
Collection<String> authorizedScope = Collections.emptySet(); Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) { if (claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE); authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
} }
if (!authorizedScope.containsAll(requiredScope)) { if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
@@ -27,6 +27,7 @@ import java.util.Set;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.core.convert.converter.Converter; import org.springframework.core.convert.converter.Converter;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
@@ -124,7 +125,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
} }
@Override @Override
public Authentication authenticate(Authentication authentication) throws AuthenticationException { public @Nullable Authentication authenticate(Authentication authentication) throws AuthenticationException {
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication; OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication = (OidcClientRegistrationAuthenticationToken) authentication;
if (clientRegistrationAuthentication.getClientRegistration() == null) { if (clientRegistrationAuthentication.getClientRegistration() == null) {
@@ -157,6 +158,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
} }
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) { if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
} }
@@ -210,18 +212,24 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication, OidcClientRegistrationAuthenticationToken clientRegistrationAuthentication,
OAuth2Authorization authorization) { OAuth2Authorization authorization) {
if (!isValidRedirectUris(clientRegistrationAuthentication.getClientRegistration().getRedirectUris())) { OidcClientRegistration clientRegistrationRequest = clientRegistrationAuthentication.getClientRegistration();
Assert.notNull(clientRegistrationRequest, "clientRegistration cannot be null");
List<String> redirectUris = (clientRegistrationRequest.getRedirectUris() != null)
? clientRegistrationRequest.getRedirectUris() : Collections.emptyList();
if (!isValidRedirectUris(redirectUris)) {
throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI, throwInvalidClientRegistration(OAuth2ErrorCodes.INVALID_REDIRECT_URI,
OidcClientMetadataClaimNames.REDIRECT_URIS); OidcClientMetadataClaimNames.REDIRECT_URIS);
} }
if (!isValidRedirectUris( List<String> postLogoutRedirectUris = (clientRegistrationRequest.getPostLogoutRedirectUris() != null)
clientRegistrationAuthentication.getClientRegistration().getPostLogoutRedirectUris())) { ? clientRegistrationRequest.getPostLogoutRedirectUris() : Collections.emptyList();
if (!isValidRedirectUris(postLogoutRedirectUris)) {
throwInvalidClientRegistration("invalid_client_metadata", throwInvalidClientRegistration("invalid_client_metadata",
OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS); OidcClientMetadataClaimNames.POST_LOGOUT_REDIRECT_URIS);
} }
if (!isValidTokenEndpointAuthenticationMethod(clientRegistrationAuthentication.getClientRegistration())) { if (!isValidTokenEndpointAuthenticationMethod(clientRegistrationRequest)) {
throwInvalidClientRegistration("invalid_client_metadata", throwInvalidClientRegistration("invalid_client_metadata",
OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD); OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
} }
@@ -230,8 +238,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
this.logger.trace("Validated client registration request parameters"); this.logger.trace("Validated client registration request parameters");
} }
RegisteredClient registeredClient = this.registeredClientConverter RegisteredClient registeredClient = this.registeredClientConverter.convert(clientRegistrationRequest);
.convert(clientRegistrationAuthentication.getClientRegistration());
if (StringUtils.hasText(registeredClient.getClientSecret())) { if (StringUtils.hasText(registeredClient.getClientSecret())) {
// Encode the client secret // Encode the client secret
@@ -240,8 +247,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
.build(); .build();
this.registeredClientRepository.save(updatedRegisteredClient); this.registeredClientRepository.save(updatedRegisteredClient);
if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue() if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()
.equals(clientRegistrationAuthentication.getClientRegistration() .equals(clientRegistrationRequest.getTokenEndpointAuthenticationMethod())) {
.getTokenEndpointAuthenticationMethod())) {
// gh-1344 Return the hashed client_secret // gh-1344 Return the hashed client_secret
registeredClient = updatedRegisteredClient; registeredClient = updatedRegisteredClient;
} }
@@ -257,8 +263,10 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
OAuth2Authorization registeredClientAuthorization = registerAccessToken(registeredClient); OAuth2Authorization registeredClientAuthorization = registerAccessToken(registeredClient);
// Invalidate the "initial" access token as it can only be used once // Invalidate the "initial" access token as it can only be used once
OAuth2Authorization.Token<OAuth2AccessToken> initialAccessToken = authorization.getAccessToken();
Assert.notNull(initialAccessToken, "initialAccessToken cannot be null");
OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization) OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
.invalidate(authorization.getAccessToken().getToken()); .invalidate(initialAccessToken.getToken());
if (authorization.getRefreshToken() != null) { if (authorization.getRefreshToken() != null) {
builder.invalidate(authorization.getRefreshToken().getToken()); builder.invalidate(authorization.getRefreshToken().getToken());
} }
@@ -271,8 +279,11 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
Map<String, Object> clientRegistrationClaims = this.clientRegistrationConverter.convert(registeredClient) Map<String, Object> clientRegistrationClaims = this.clientRegistrationConverter.convert(registeredClient)
.getClaims(); .getClaims();
OAuth2Authorization.Token<OAuth2AccessToken> registrationAccessToken = registeredClientAuthorization
.getAccessToken();
Assert.notNull(registrationAccessToken, "registrationAccessToken cannot be null");
OidcClientRegistration clientRegistration = OidcClientRegistration.withClaims(clientRegistrationClaims) OidcClientRegistration clientRegistration = OidcClientRegistration.withClaims(clientRegistrationClaims)
.registrationAccessToken(registeredClientAuthorization.getAccessToken().getToken().getTokenValue()) .registrationAccessToken(registrationAccessToken.getToken().getTokenValue())
.build(); .build();
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -338,9 +349,11 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken, private static void checkScope(OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken,
Set<String> requiredScope) { Set<String> requiredScope) {
Map<String, Object> claims = authorizedAccessToken.getClaims();
Assert.notNull(claims, "claims cannot be null");
Collection<String> authorizedScope = Collections.emptySet(); Collection<String> authorizedScope = Collections.emptySet();
if (authorizedAccessToken.getClaims().containsKey(OAuth2ParameterNames.SCOPE)) { if (claims.containsKey(OAuth2ParameterNames.SCOPE)) {
authorizedScope = (Collection<String>) authorizedAccessToken.getClaims().get(OAuth2ParameterNames.SCOPE); authorizedScope = (Collection<String>) claims.get(OAuth2ParameterNames.SCOPE);
} }
if (!authorizedScope.containsAll(requiredScope)) { if (!authorizedScope.containsAll(requiredScope)) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INSUFFICIENT_SCOPE);
@@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial; import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration; import org.springframework.security.oauth2.server.authorization.oidc.OidcClientRegistration;
@@ -44,9 +45,9 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
private final Authentication principal; private final Authentication principal;
private final OidcClientRegistration clientRegistration; private final @Nullable OidcClientRegistration clientRegistration;
private final String clientId; private final @Nullable String clientId;
/** /**
* Constructs an {@code OidcClientRegistrationAuthenticationToken} using the provided * Constructs an {@code OidcClientRegistrationAuthenticationToken} using the provided
@@ -95,7 +96,7 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
* Returns the client registration. * Returns the client registration.
* @return the client registration * @return the client registration
*/ */
public OidcClientRegistration getClientRegistration() { public @Nullable OidcClientRegistration getClientRegistration() {
return this.clientRegistration; return this.clientRegistration;
} }
@@ -103,8 +104,7 @@ public class OidcClientRegistrationAuthenticationToken extends AbstractAuthentic
* Returns the client identifier. * Returns the client identifier.
* @return the client identifier * @return the client identifier
*/ */
@Nullable public @Nullable String getClientId() {
public String getClientId() {
return this.clientId; return this.clientId;
} }
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext; import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationContext;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.Assert; import org.springframework.util.Assert;
@@ -46,9 +47,8 @@ public final class OidcLogoutAuthenticationContext implements OAuth2Authenticati
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -63,7 +63,9 @@ public final class OidcLogoutAuthenticationContext implements OAuth2Authenticati
* @return the {@link RegisteredClient} * @return the {@link RegisteredClient}
*/ */
public RegisteredClient getRegisteredClient() { public RegisteredClient getRegisteredClient() {
return get(RegisteredClient.class); RegisteredClient registeredClient = get(RegisteredClient.class);
Assert.notNull(registeredClient, "registeredClient cannot be null");
return registeredClient;
} }
/** /**
@@ -26,6 +26,7 @@ import java.util.function.Consumer;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AuthenticationProvider; import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -99,7 +100,7 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
OAuth2Authorization authorization = this.authorizationService OAuth2Authorization authorization = this.authorizationService
.findByToken(oidcLogoutAuthentication.getIdTokenHint(), ID_TOKEN_TOKEN_TYPE); .findByToken(oidcLogoutAuthentication.getIdTokenHint(), ID_TOKEN_TOKEN_TYPE);
if (authorization == null) { if (authorization == null) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint"); throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
} }
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
@@ -107,13 +108,15 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
} }
OAuth2Authorization.Token<OidcIdToken> authorizedIdToken = authorization.getToken(OidcIdToken.class); OAuth2Authorization.Token<OidcIdToken> authorizedIdToken = authorization.getToken(OidcIdToken.class);
Assert.notNull(authorizedIdToken, "authorizedIdToken cannot be null");
if (authorizedIdToken.isInvalidated() || authorizedIdToken.isBeforeUse()) { if (authorizedIdToken.isInvalidated() || authorizedIdToken.isBeforeUse()) {
// Expired ID Token should be accepted // Expired ID Token should be accepted
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint"); throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "id_token_hint");
} }
RegisteredClient registeredClient = this.registeredClientRepository RegisteredClient registeredClient = this.registeredClientRepository
.findById(authorization.getRegisteredClientId()); .findById(authorization.getRegisteredClientId());
Assert.notNull(registeredClient, "registeredClient cannot be null");
if (this.logger.isTraceEnabled()) { if (this.logger.isTraceEnabled()) {
this.logger.trace("Retrieved registered client"); this.logger.trace("Retrieved registered client");
@@ -124,11 +127,11 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
// Validate client identity // Validate client identity
List<String> audClaim = idToken.getAudience(); List<String> audClaim = idToken.getAudience();
if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) { if (CollectionUtils.isEmpty(audClaim) || !audClaim.contains(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD); throw createException(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.AUD);
} }
if (StringUtils.hasText(oidcLogoutAuthentication.getClientId()) if (StringUtils.hasText(oidcLogoutAuthentication.getClientId())
&& !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) { && !oidcLogoutAuthentication.getClientId().equals(registeredClient.getClientId())) {
throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID); throw createException(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
} }
OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication) OidcLogoutAuthenticationContext context = OidcLogoutAuthenticationContext.with(oidcLogoutAuthentication)
@@ -144,9 +147,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
if (oidcLogoutAuthentication.isPrincipalAuthenticated()) { if (oidcLogoutAuthentication.isPrincipalAuthenticated()) {
Authentication currentUserPrincipal = (Authentication) oidcLogoutAuthentication.getPrincipal(); Authentication currentUserPrincipal = (Authentication) oidcLogoutAuthentication.getPrincipal();
Authentication authorizedUserPrincipal = authorization.getAttribute(Principal.class.getName()); Authentication authorizedUserPrincipal = authorization.getAttribute(Principal.class.getName());
Assert.notNull(authorizedUserPrincipal, "authorizedUserPrincipal cannot be null");
if (!StringUtils.hasText(idToken.getSubject()) if (!StringUtils.hasText(idToken.getSubject())
|| !currentUserPrincipal.getName().equals(authorizedUserPrincipal.getName())) { || !currentUserPrincipal.getName().equals(authorizedUserPrincipal.getName())) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.SUB); throw createException(OAuth2ErrorCodes.INVALID_TOKEN, IdTokenClaimNames.SUB);
} }
// Check for active session // Check for active session
@@ -166,7 +170,7 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
String sidClaim = idToken.getClaim("sid"); String sidClaim = idToken.getClaim("sid");
if (!StringUtils.hasText(sidClaim) || !sidClaim.equals(sessionIdHash)) { if (!StringUtils.hasText(sidClaim) || !sidClaim.equals(sessionIdHash)) {
throwError(OAuth2ErrorCodes.INVALID_TOKEN, "sid"); throw createException(OAuth2ErrorCodes.INVALID_TOKEN, "sid");
} }
} }
} }
@@ -205,8 +209,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
this.authenticationValidator = authenticationValidator; this.authenticationValidator = authenticationValidator;
} }
private SessionInformation findSessionInformation(Authentication principal, String sessionId) { private @Nullable SessionInformation findSessionInformation(Authentication principal, String sessionId) {
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(principal.getPrincipal(), true); Object sessionPrincipal = principal.getPrincipal();
Assert.notNull(sessionPrincipal, "sessionPrincipal cannot be null");
List<SessionInformation> sessions = this.sessionRegistry.getAllSessions(sessionPrincipal, true);
SessionInformation sessionInformation = null; SessionInformation sessionInformation = null;
if (!CollectionUtils.isEmpty(sessions)) { if (!CollectionUtils.isEmpty(sessions)) {
for (SessionInformation session : sessions) { for (SessionInformation session : sessions) {
@@ -219,10 +225,10 @@ public final class OidcLogoutAuthenticationProvider implements AuthenticationPro
return sessionInformation; return sessionInformation;
} }
private static void throwError(String errorCode, String parameterName) { private static OAuth2AuthenticationException createException(String errorCode, String parameterName) {
OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName, OAuth2Error error = new OAuth2Error(errorCode, "OpenID Connect 1.0 Logout Request Parameter: " + parameterName,
"https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling"); "https://openid.net/specs/openid-connect-rpinitiated-1_0.html#ValidationAndErrorHandling");
throw new OAuth2AuthenticationException(error); return new OAuth2AuthenticationException(error);
} }
private static String createHash(String value) throws NoSuchAlgorithmException { private static String createHash(String value) throws NoSuchAlgorithmException {
@@ -19,7 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial; import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
@@ -42,17 +43,17 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
private final String idTokenHint; private final String idTokenHint;
private final OidcIdToken idToken; private final @Nullable OidcIdToken idToken;
private final Authentication principal; private final Authentication principal;
private final String sessionId; private final @Nullable String sessionId;
private final String clientId; private final @Nullable String clientId;
private final String postLogoutRedirectUri; private final @Nullable String postLogoutRedirectUri;
private final String state; private final @Nullable String state;
/** /**
* Constructs an {@code OidcLogoutAuthenticationToken} using the provided parameters. * Constructs an {@code OidcLogoutAuthenticationToken} using the provided parameters.
@@ -147,8 +148,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the ID Token previously issued by the Provider to the Client. * Returns the ID Token previously issued by the Provider to the Client.
* @return the ID Token previously issued by the Provider to the Client * @return the ID Token previously issued by the Provider to the Client
*/ */
@Nullable public @Nullable OidcIdToken getIdToken() {
public OidcIdToken getIdToken() {
return this.idToken; return this.idToken;
} }
@@ -156,8 +156,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the End-User's current authenticated session identifier with the Provider. * Returns the End-User's current authenticated session identifier with the Provider.
* @return the End-User's current authenticated session identifier with the Provider * @return the End-User's current authenticated session identifier with the Provider
*/ */
@Nullable public @Nullable String getSessionId() {
public String getSessionId() {
return this.sessionId; return this.sessionId;
} }
@@ -165,8 +164,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* Returns the client identifier the ID Token was issued to. * Returns the client identifier the ID Token was issued to.
* @return the client identifier * @return the client identifier
*/ */
@Nullable public @Nullable String getClientId() {
public String getClientId() {
return this.clientId; return this.clientId;
} }
@@ -176,8 +174,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* @return the URI which the Client is requesting that the End-User's User Agent be * @return the URI which the Client is requesting that the End-User's User Agent be
* redirected to after a logout has been performed * redirected to after a logout has been performed
*/ */
@Nullable public @Nullable String getPostLogoutRedirectUri() {
public String getPostLogoutRedirectUri() {
return this.postLogoutRedirectUri; return this.postLogoutRedirectUri;
} }
@@ -187,8 +184,7 @@ public class OidcLogoutAuthenticationToken extends AbstractAuthenticationToken {
* @return the opaque value used by the Client to maintain state between the logout * @return the opaque value used by the Client to maintain state between the logout
* request and the callback to the {@link #getPostLogoutRedirectUri()} * request and the callback to the {@link #getPostLogoutRedirectUri()}
*/ */
@Nullable public @Nullable String getState() {
public String getState() {
return this.state; return this.state;
} }
@@ -21,7 +21,8 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import org.springframework.lang.Nullable; import org.jspecify.annotations.Nullable;
import org.springframework.security.oauth2.core.OAuth2AccessToken; import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization; import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
@@ -48,9 +49,8 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Nullable
@Override @Override
public <V> V get(Object key) { public <V> @Nullable V get(Object key) {
return hasKey(key) ? (V) this.context.get(key) : null; return hasKey(key) ? (V) this.context.get(key) : null;
} }
@@ -65,7 +65,9 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
* @return the {@link OAuth2AccessToken} * @return the {@link OAuth2AccessToken}
*/ */
public OAuth2AccessToken getAccessToken() { public OAuth2AccessToken getAccessToken() {
return get(OAuth2AccessToken.class); OAuth2AccessToken accessToken = get(OAuth2AccessToken.class);
Assert.notNull(accessToken, "accessToken cannot be null");
return accessToken;
} }
/** /**
@@ -73,7 +75,9 @@ public final class OidcUserInfoAuthenticationContext implements OAuth2Authentica
* @return the {@link OAuth2Authorization} * @return the {@link OAuth2Authorization}
*/ */
public OAuth2Authorization getAuthorization() { public OAuth2Authorization getAuthorization() {
return get(OAuth2Authorization.class); OAuth2Authorization authorization = get(OAuth2Authorization.class);
Assert.notNull(authorization, "authorization cannot be null");
return authorization;
} }
/** /**
@@ -98,6 +98,7 @@ public final class OidcUserInfoAuthenticationProvider implements AuthenticationP
} }
OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken(); OAuth2Authorization.Token<OAuth2AccessToken> authorizedAccessToken = authorization.getAccessToken();
Assert.notNull(authorizedAccessToken, "authorizedAccessToken cannot be null");
if (!authorizedAccessToken.isActive()) { if (!authorizedAccessToken.isActive()) {
throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN); throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_TOKEN);
} }
@@ -191,7 +192,9 @@ public final class OidcUserInfoAuthenticationProvider implements AuthenticationP
@Override @Override
public OidcUserInfo apply(OidcUserInfoAuthenticationContext authenticationContext) { public OidcUserInfo apply(OidcUserInfoAuthenticationContext authenticationContext) {
OAuth2Authorization authorization = authenticationContext.getAuthorization(); OAuth2Authorization authorization = authenticationContext.getAuthorization();
OidcIdToken idToken = authorization.getToken(OidcIdToken.class).getToken(); OAuth2Authorization.Token<OidcIdToken> authorizedIdToken = authorization.getToken(OidcIdToken.class);
Assert.notNull(authorizedIdToken, "authorizedIdToken cannot be null");
OidcIdToken idToken = authorizedIdToken.getToken();
OAuth2AccessToken accessToken = authenticationContext.getAccessToken(); OAuth2AccessToken accessToken = authenticationContext.getAccessToken();
Map<String, Object> scopeRequestedClaims = getClaimsRequestedByScope(idToken.getClaims(), Map<String, Object> scopeRequestedClaims = getClaimsRequestedByScope(idToken.getClaims(),
accessToken.getScopes()); accessToken.getScopes());
@@ -19,6 +19,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
import java.io.Serial; import java.io.Serial;
import java.util.Collections; import java.util.Collections;
import org.jspecify.annotations.Nullable;
import org.springframework.security.authentication.AbstractAuthenticationToken; import org.springframework.security.authentication.AbstractAuthenticationToken;
import org.springframework.security.core.Authentication; import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.core.oidc.OidcUserInfo; import org.springframework.security.oauth2.core.oidc.OidcUserInfo;
@@ -40,7 +42,7 @@ public class OidcUserInfoAuthenticationToken extends AbstractAuthenticationToken
private final Authentication principal; private final Authentication principal;
private final OidcUserInfo userInfo; private final @Nullable OidcUserInfo userInfo;
/** /**
* Constructs an {@code OidcUserInfoAuthenticationToken} using the provided * Constructs an {@code OidcUserInfoAuthenticationToken} using the provided
@@ -82,9 +84,9 @@ public class OidcUserInfoAuthenticationToken extends AbstractAuthenticationToken
/** /**
* Returns the UserInfo claims. * Returns the UserInfo claims.
* @return the UserInfo claims * @return the UserInfo claims, or {@code null} if not provided
*/ */
public OidcUserInfo getUserInfo() { public @Nullable OidcUserInfo getUserInfo() {
return this.userInfo; return this.userInfo;
} }

Some files were not shown because too many files have changed in this diff Show More