DynamicLoginHandler.java
/*
* Licensed to The Apereo Foundation under one or more contributor license
* agreements. See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
*
*
* The Apereo Foundation licenses this file to you under the Educational
* Community License, Version 2.0 (the "License"); you may not use this file
* except in compliance with the License. You may obtain a copy of the License
* at:
*
* http://opensource.org/licenses/ecl2.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
*/
package org.opencastproject.security.jwt;
import org.opencastproject.security.api.Organization;
import org.opencastproject.security.api.SecurityService;
import org.opencastproject.security.api.UserDirectoryService;
import org.opencastproject.security.impl.jpa.JpaOrganization;
import org.opencastproject.security.impl.jpa.JpaRole;
import org.opencastproject.security.impl.jpa.JpaUserReference;
import org.opencastproject.security.util.SecurityUtil;
import org.opencastproject.userdirectory.api.UserReferenceProvider;
import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jwt.SignedJWT;
import org.apache.commons.lang3.StringUtils;
import org.osgi.service.component.annotations.Reference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.expression.MapAccessor;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.util.Assert;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
/**
* Dynamic login handler for JWTs.
*/
public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
/** Logging facility. */
private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
/** Spring security's user details manager. */
private UserDetailsService userDetailsService = null;
/** User directory service. */
private UserDirectoryService userDirectoryService = null;
/** User reference provider. */
private UserReferenceProvider userReferenceProvider = null;
/** Security service. */
private SecurityService securityService = null;
/** JWKS URL to use for JWT validation (asymmetric algorithms). */
private String jwksUrl = null;
/** The default time to live of cached JWK sets, in milliseconds. */
private int jwksTimeToLive = 1000 * 60 * 60;
/** The default refresh timeout of cached JWK sets, in milliseconds. */
private int jwksRefreshTimeout = 1000 * 60;
/** Secret to use for JWT validation (symmetric algorithms). */
private String secret = null;
/** Allowed algorithms with which a valid JWT may be signed ('alg' claim). */
private List<String> expectedAlgorithms = null;
/** Constraints that the claims of a valid JWT must fulfill. */
private List<String> claimConstraints = null;
/** Mapping used to extract the username from the JWT. */
private String usernameMapping = null;
/** Mapping used to extract the name from the JWT. */
private String nameMapping = null;
/** Mapping used to extract the email from the JWT. */
private String emailMapping = null;
/** Whether the built-in schema role mapping should be used. */
private boolean ocStandardRoleMappings = true;
/** Mapping used to extract roles from the JWT. */
private List<String> roleMappings = null;
/** Mapping used to extract roles from the JWT. */
private JWKSetProvider jwkProvider;
/** Size of the JWT cache. */
private int jwtCacheSize = 500;
/** Number of minutes validated JWTs will be cached before re-validating them. */
private int jwtCacheExpiresIn = 60;
/** Cache for validated JWTs. */
private Cache<String, CachedJWT> cache;
@Override
public void afterPropertiesSet() {
Assert.notNull(userDetailsService, "A UserDetailsService must be set");
Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
Assert.notNull(securityService, "A SecurityService must be set");
Assert.isTrue(!(StringUtils.isNotBlank(jwksUrl) && StringUtils.isNotBlank(secret)),
"A JWKS URL and a secret cannot be set at the same time");
Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
Assert.notNull(claimConstraints, "Claim constraints must be set");
Assert.notNull(usernameMapping, "User name mapping must be set");
Assert.notNull(nameMapping, "Name mapping must be set");
Assert.notNull(emailMapping, "Email mapping must be set");
Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
"Role mappings must be set if ocStandardRoleMappings is false");
if (StringUtils.isBlank(jwksUrl) && StringUtils.isBlank(secret)) {
logger.info("JWT login handler disabled as neither 'jwksUrl' nor 'secret' are set");
}
if (jwksUrl != null) {
jwkProvider = new JWKSetProvider(jwksUrl, jwksTimeToLive, jwksRefreshTimeout);
}
userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
cache = CacheBuilder.newBuilder()
.maximumSize(jwtCacheSize)
.expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
.build();
}
@Override
public String handleToken(String token) {
if (jwkProvider == null && secret == null) {
logger.debug("neither jwksURL nor secret set: ignoring JWT");
return null;
}
try {
String signature = extractSignature(token);
CachedJWT cachedJwt = cache.getIfPresent(signature);
if (cachedJwt == null) {
// JWT hasn't been cached before, so validate all claims
SignedJWT jwt = decodeAndValidate(token);
String username = extractUsername(jwt);
try {
if (userDetailsService.loadUserByUsername(username) != null) {
existingUserLogin(username, jwt);
}
} catch (UsernameNotFoundException e) {
newUserLogin(username, jwt);
}
userDirectoryService.invalidate(username);
cache.put(jwt.getSignature().toString(), new CachedJWT(jwt, username));
return username;
} else {
// JWT has been cached before, so only check if it has expired
if (cachedJwt.hasExpired()) {
cache.invalidate(signature);
throw new JOSEException("JWT token is not valid anymore");
}
logger.debug("Using decoded and validated JWT from cache");
return cachedJwt.getUsername();
}
} catch (ParseException | JOSEException exception) {
logger.debug(exception.getMessage());
}
return null;
}
/**
* Decodes and validates a JWT.
*
* @param token The JWT string.
* @return The decoded JWT.
* @throws JOSEException If the JWT fails to be validated.
*/
private SignedJWT decodeAndValidate(String token) throws ParseException, JOSEException {
SignedJWT jwt;
if (jwksUrl != null) {
jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
} else {
jwt = JWTVerifier.verify(token, secret, claimConstraints);
}
if (!expectedAlgorithms.contains(jwt.getHeader().getAlgorithm().getName())) {
throw new JOSEException(
"JWT token was signed with an unexpected algorithm '" + jwt.getHeader().getAlgorithm() + "'"
);
}
return jwt;
}
/**
* Extracts the signature from a JWT.
*
* @param token The JWT string.
* @return The JWT's signature.
*/
private String extractSignature(String token) throws JOSEException {
String[] parts = token.split("\\.");
if (parts.length != 3) {
throw new JOSEException("Given token is not in a valid JWT format");
}
return parts[2];
}
/**
* Extracts the username from a decoded and validated JWT.
*
* @param jwt The decoded JWT.
* @return The username.
*/
private String extractUsername(SignedJWT jwt) throws ParseException {
String username = evaluateMapping(jwt, usernameMapping);
Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
return username;
}
/**
* Extracts the name from a decoded and validated JWT.
*
* @param jwt The decoded JWT.
* @return The name.
*/
private String extractName(SignedJWT jwt) throws ParseException {
String name = evaluateMapping(jwt, nameMapping);
Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
return name;
}
/**
* Extracts the email from a decoded and validated JWT.
*
* @param jwt The decoded JWT.
* @return The email.
*/
private String extractEmail(SignedJWT jwt) throws ParseException {
String email = evaluateMapping(jwt, emailMapping);
Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
return email;
}
/**
* Extracts the roles from a decoded and validated JWT.
*
* @param jwt The decoded JWT.
* @return The roles.
*/
private Set<JpaRole> extractRoles(SignedJWT jwt) throws ParseException {
JpaOrganization organization = fromOrganization(securityService.getOrganization());
Set<JpaRole> roles = new HashSet<>();
Consumer<String> addRole = (String role) -> {
if (StringUtils.isNotBlank(role)) {
roles.add(new JpaRole(role, organization));
}
};
// Evaluate the standard role mapping if specified
if (ocStandardRoleMappings) {
// Read `role` claim
try {
var rolesClaim = jwt.getJWTClaimsSet().getStringArrayClaim("roles");
if (rolesClaim != null) {
for (String r : rolesClaim) {
addRole.accept(r);
}
}
} catch (ParseException e) {
logger.debug("claim 'roles' is not an array of strings, ignoring");
}
// Read `oc` claim
try {
var ocClaim = jwt.getJWTClaimsSet().getJSONObjectClaim("oc");
if (ocClaim != null) {
for (var entry : ocClaim.entrySet()) {
var key = entry.getKey();
var parts = key.split(":", 2);
if (parts.length != 2) {
logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
continue;
}
var type = parts[0];
var id = parts[1];
try {
for (var actionObj : (List<?>) entry.getValue()) {
var action = (String) actionObj;
if (action.isBlank()) {
continue;
}
if (type.equals("e")) {
addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
} else {
logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
}
}
} catch (ClassCastException e) {
logger.debug("value in 'oc' claim is not a string array -> ignoring");
continue;
}
}
}
} catch (ParseException e) {
logger.debug("claim 'oc' is not an array of strings, ignoring");
}
}
for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
ExpressionParser parser = new SpelExpressionParser();
Expression exp = parser.parseExpression(mapping);
StandardEvaluationContext ctx = new StandardEvaluationContext();
ctx.addPropertyAccessor(new MapAccessor());
Object value = exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims());
if (value != null) {
// We allow the expression to either return a string directly, or a list/array of strings.
if (value instanceof String) {
addRole.accept((String) value);
} else if (value.getClass().isArray()) {
for (var role : (Object[]) value) {
addRole.accept((String) role);
}
} else {
for (var role : (List<?>) value) {
addRole.accept((String) role);
}
}
}
}
Assert.notEmpty(roles, "No roles could be extracted");
return roles;
}
/**
* Evaluates a mapping given in SpEL on a decoded JWT.
*
* @param jwt The decoded JWT.
* @param mapping The mapping.
*
* @return The string evaluated from the mapping.
*/
private String evaluateMapping(SignedJWT jwt, String mapping) throws ParseException {
ExpressionParser parser = new SpelExpressionParser();
Expression exp = parser.parseExpression(mapping);
StandardEvaluationContext ctx = new StandardEvaluationContext();
ctx.addPropertyAccessor(new MapAccessor());
return exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims(), String.class);
}
/**
* Handles a new user login.
*
* @param username The username.
* @param jwt The decoded JWT.
*/
public void newUserLogin(String username, SignedJWT jwt) throws ParseException {
// Create a new user reference
JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
logger.debug("JWT user '{}' logged in for the first time", username);
userReferenceProvider.addUserReference(userReference, MECH_JWT);
}
/**
* Handles an existing user login.
*
* @param username The username.
* @param jwt The decoded JWT.
*/
public void existingUserLogin(String username, SignedJWT jwt) throws ParseException {
Organization organization = securityService.getOrganization();
// Load the user reference
JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
if (userReference == null) {
throw new UsernameNotFoundException("User reference '" + username + "' was not found");
}
// Update the reference
userReference.setName(extractName(jwt));
userReference.setEmail(extractEmail(jwt));
userReference.setLastLogin(new Date());
userReference.setRoles(extractRoles(jwt));
logger.debug("JWT user '{}' logged in", username);
userReferenceProvider.updateUserReference(userReference);
}
/**
* Converts a {@link Organization} object into a {@link JpaOrganization} object.
*
* @param org The {@link Organization} object.
* @return The corresponding {@link JpaOrganization} object.
*/
private JpaOrganization fromOrganization(Organization org) {
if (org instanceof JpaOrganization) {
return (JpaOrganization) org;
}
return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
org.getProperties());
}
/**
* Setter for the user details service.
*
* @param userDetailsService The user details service.
*/
@Reference
public void setUserDetailsService(UserDetailsService userDetailsService) {
this.userDetailsService = userDetailsService;
}
/**
* Setter for the user directory service.
*
* @param userDirectoryService The user directory service.
*/
@Reference
public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
this.userDirectoryService = userDirectoryService;
}
/**
* Setter for the security service.
*
* @param securityService The security service.
*/
@Reference
public void setSecurityService(SecurityService securityService) {
this.securityService = securityService;
}
/**
* Setter for the user reference provider.
*
* @param userReferenceProvider The user reference provider.
*/
@Reference
public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
this.userReferenceProvider = userReferenceProvider;
}
/**
* Setter for the JWKS URL.
*
* @param jwksUrl The JWKS URL.
*/
public void setJwksUrl(String jwksUrl) {
this.jwksUrl = jwksUrl;
}
/**
* Setter for the JWKS ttl.
*
* @param jwksTimeToLive The default time to live of cached JWK sets, in milliseconds
*/
public void setJwksTimeToLive(int jwksTimeToLive) {
this.jwksTimeToLive = jwksTimeToLive;
}
/**
* Setter for the secret used for JWT validation.
*
* @param secret The secret.
*/
public void setSecret(String secret) {
this.secret = secret;
}
/**
* Setter for the expected algorithms.
*
* @param expectedAlgorithms The expected algorithms.
*/
public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
this.expectedAlgorithms = expectedAlgorithms;
}
/**
* Setter for the claim constraints.
*
* @param claimConstraints The claim constraints.
*/
public void setClaimConstraints(List<String> claimConstraints) {
this.claimConstraints = claimConstraints;
}
/**
* Setter for the username mapping.
* @param usernameMapping The username mapping.
*/
public void setUsernameMapping(String usernameMapping) {
this.usernameMapping = usernameMapping;
}
/**
* Setter for the name mapping.
*
* @param nameMapping The name mapping.
*/
public void setNameMapping(String nameMapping) {
this.nameMapping = nameMapping;
}
/**
* Setter for the email mapping.
* @param emailMapping The email mapping.
*/
public void setEmailMapping(String emailMapping) {
this.emailMapping = emailMapping;
}
public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
this.ocStandardRoleMappings = ocStandardRoleMappings;
}
/**
* Setter for the role mappings.
*
* @param roleMappings The role mappings.
*/
public void setRoleMappings(List<String> roleMappings) {
this.roleMappings = roleMappings;
}
/**
* Setter for the JWT cache size.
*
* @param jwtCacheSize The JWT cache size.
*/
public void setJwtCacheSize(int jwtCacheSize) {
this.jwtCacheSize = jwtCacheSize;
}
/**
* Setter for the JWT cache expiration.
*
* @param jwtCacheExpiresIn The number of minutes after which a cached JWT expires.
*/
public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
this.jwtCacheExpiresIn = jwtCacheExpiresIn;
}
}