View Javadoc
1   /*
2    * Licensed to The Apereo Foundation under one or more contributor license
3    * agreements. See the NOTICE file distributed with this work for additional
4    * information regarding copyright ownership.
5    *
6    *
7    * The Apereo Foundation licenses this file to you under the Educational
8    * Community License, Version 2.0 (the "License"); you may not use this file
9    * except in compliance with the License. You may obtain a copy of the License
10   * at:
11   *
12   *   http://opensource.org/licenses/ecl2.txt
13   *
14   * Unless required by applicable law or agreed to in writing, software
15   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
17   * License for the specific language governing permissions and limitations under
18   * the License.
19   *
20   */
21  
22  package org.opencastproject.security.jwt;
23  
24  import org.opencastproject.security.api.Organization;
25  import org.opencastproject.security.api.SecurityService;
26  import org.opencastproject.security.api.UserDirectoryService;
27  import org.opencastproject.security.impl.jpa.JpaOrganization;
28  import org.opencastproject.security.impl.jpa.JpaRole;
29  import org.opencastproject.security.impl.jpa.JpaUserReference;
30  import org.opencastproject.security.util.SecurityUtil;
31  import org.opencastproject.userdirectory.api.UserReferenceProvider;
32  
33  import com.auth0.jwk.JwkException;
34  import com.auth0.jwt.exceptions.JWTDecodeException;
35  import com.auth0.jwt.exceptions.JWTVerificationException;
36  import com.auth0.jwt.interfaces.DecodedJWT;
37  import com.google.common.cache.Cache;
38  import com.google.common.cache.CacheBuilder;
39  
40  import org.apache.commons.lang3.StringUtils;
41  import org.osgi.service.component.annotations.Reference;
42  import org.slf4j.Logger;
43  import org.slf4j.LoggerFactory;
44  import org.springframework.beans.factory.InitializingBean;
45  import org.springframework.expression.Expression;
46  import org.springframework.expression.ExpressionParser;
47  import org.springframework.expression.spel.standard.SpelExpressionParser;
48  import org.springframework.security.core.userdetails.UserDetailsService;
49  import org.springframework.security.core.userdetails.UsernameNotFoundException;
50  import org.springframework.util.Assert;
51  
52  
53  import java.util.ArrayList;
54  import java.util.Date;
55  import java.util.HashSet;
56  import java.util.List;
57  import java.util.Set;
58  import java.util.concurrent.TimeUnit;
59  import java.util.function.Consumer;
60  
61  /**
62   * Dynamic login handler for JWTs.
63   */
64  public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
65  
66    /** Logging facility. */
67    private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
68  
69    /** Spring security's user details manager. */
70    private UserDetailsService userDetailsService = null;
71  
72    /** User directory service. */
73    private UserDirectoryService userDirectoryService = null;
74  
75    /** User reference provider. */
76    private UserReferenceProvider userReferenceProvider = null;
77  
78    /** Security service. */
79    private SecurityService securityService = null;
80  
81    /** JWKS URL to use for JWT validation (asymmetric algorithms). */
82    private String jwksUrl = null;
83  
84    /** Number of minutes fetched JWKs will be cached. */
85    private int jwksCacheExpiresIn = 60 * 24;
86  
87    /** Secret to use for JWT validation (symmetric algorithms). */
88    private String secret = null;
89  
90    /** Allowed algorithms with which a valid JWT may be signed ('alg' claim). */
91    private List<String> expectedAlgorithms = null;
92  
93    /** Constraints that the claims of a valid JWT must fulfill. */
94    private List<String> claimConstraints = null;
95  
96    /** Mapping used to extract the username from the JWT. */
97    private String usernameMapping = null;
98  
99    /** Mapping used to extract the name from the JWT. */
100   private String nameMapping = null;
101 
102   /** Mapping used to extract the email from the JWT. */
103   private String emailMapping = null;
104 
105   /** Whether the built-in schema role mapping should be used. */
106   private boolean ocStandardRoleMappings = true;
107 
108   /** Mapping used to extract roles from the JWT. */
109   private List<String> roleMappings = null;
110 
111   /** Mapping used to extract roles from the JWT. */
112   private GuavaCachedUrlJwkProvider jwkProvider;
113 
114   /** Size of the JWT cache. */
115   private int jwtCacheSize = 500;
116 
117   /** Number of minutes validated JWTs will be cached before re-validating them. */
118   private int jwtCacheExpiresIn = 60;
119 
120   /** Cache for validated JWTs. */
121   private Cache<String, CachedJWT> cache;
122 
123   @Override
124   public void afterPropertiesSet() {
125     Assert.notNull(userDetailsService, "A UserDetailsService must be set");
126     Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
127     Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
128     Assert.notNull(securityService, "A SecurityService must be set");
129     Assert.isTrue(StringUtils.isNotBlank(jwksUrl) ^ StringUtils.isNotBlank(secret),
130         "Either a JWKS URL or a secret must be set");
131     Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
132     Assert.notEmpty(claimConstraints, "Claim constraints must be set");
133     Assert.notNull(usernameMapping, "User name mapping must be set");
134     Assert.notNull(nameMapping, "Name mapping must be set");
135     Assert.notNull(emailMapping, "Email mapping must be set");
136     Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
137             "Role mappings must be set if ocStandardRoleMappings is false");
138 
139     if (jwksUrl != null) {
140       jwkProvider = new GuavaCachedUrlJwkProvider(jwksUrl, jwksCacheExpiresIn, TimeUnit.MINUTES);
141     }
142     userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
143     cache = CacheBuilder.newBuilder()
144         .maximumSize(jwtCacheSize)
145         .expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
146         .build();
147   }
148 
149   @Override
150   public String handleToken(String token) {
151     try {
152       String signature = extractSignature(token);
153       CachedJWT cachedJwt = cache.getIfPresent(signature);
154 
155       if (cachedJwt == null) {
156         // JWT hasn't been cached before, so validate all claims
157         DecodedJWT jwt = decodeAndValidate(token);
158         String username = extractUsername(jwt);
159 
160         try {
161           if (userDetailsService.loadUserByUsername(username) != null) {
162             existingUserLogin(username, jwt);
163           }
164         } catch (UsernameNotFoundException e) {
165           newUserLogin(username, jwt);
166         }
167 
168         userDirectoryService.invalidate(username);
169         cache.put(jwt.getSignature(), new CachedJWT(jwt, username));
170         return username;
171       } else {
172         // JWT has been cached before, so only check if it has expired
173         if (cachedJwt.hasExpired()) {
174           cache.invalidate(signature);
175           throw new JWTVerificationException("JWT token is not valid anymore");
176         }
177         logger.debug("Using decoded and validated JWT from cache");
178         return cachedJwt.getUsername();
179       }
180     } catch (JWTVerificationException | JwkException exception) {
181       logger.debug(exception.getMessage());
182     }
183 
184     return null;
185   }
186 
187   /**
188    * Decodes and validates a JWT.
189    *
190    * @param token The JWT string.
191    * @return The decoded JWT.
192    * @throws JwkException If the JWT fails to be validated.
193    */
194   private DecodedJWT decodeAndValidate(String token) throws JwkException {
195     DecodedJWT jwt;
196 
197     if (jwksUrl != null) {
198       jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
199     } else {
200       jwt = JWTVerifier.verify(token, secret, claimConstraints);
201     }
202 
203     if (!expectedAlgorithms.contains(jwt.getAlgorithm())) {
204       throw new JWTVerificationException(
205           "JWT token was signed with an unexpected algorithm '" + jwt.getAlgorithm() + "'"
206       );
207     }
208 
209     return jwt;
210   }
211 
212   /**
213    * Extracts the signature from a JWT.
214    *
215    * @param token The JWT string.
216    * @return The JWT's signature.
217    */
218   private String extractSignature(String token) {
219     String[] parts = token.split("\\.");
220     if (parts.length != 3) {
221       throw new JWTDecodeException("Given token is not in a valid JWT format");
222     }
223     return parts[2];
224   }
225 
226   /**
227    * Extracts the username from a decoded and validated JWT.
228    *
229    * @param jwt The decoded JWT.
230    * @return The username.
231    */
232   private String extractUsername(DecodedJWT jwt) {
233     String username = evaluateMapping(jwt, usernameMapping);
234     Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
235     return username;
236   }
237 
238   /**
239    * Extracts the name from a decoded and validated JWT.
240    *
241    * @param jwt The decoded JWT.
242    * @return The name.
243    */
244   private String extractName(DecodedJWT jwt) {
245     String name = evaluateMapping(jwt, nameMapping);
246     Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
247     return name;
248   }
249 
250   /**
251    * Extracts the email from a decoded and validated JWT.
252    *
253    * @param jwt The decoded JWT.
254    * @return The email.
255    */
256   private String extractEmail(DecodedJWT jwt) {
257     String email = evaluateMapping(jwt, emailMapping);
258     Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
259     return email;
260   }
261 
262   /**
263    * Extracts the roles from a decoded and validated JWT.
264    *
265    * @param jwt The decoded JWT.
266    * @return The roles.
267    */
268   private Set<JpaRole> extractRoles(DecodedJWT jwt) {
269     JpaOrganization organization = fromOrganization(securityService.getOrganization());
270     Set<JpaRole> roles = new HashSet<>();
271     Consumer<String> addRole = (String role) -> {
272       if (StringUtils.isNotBlank(role)) {
273         roles.add(new JpaRole(role, organization));
274       }
275     };
276 
277     // Evaluate the standard role mapping if specified
278     if (ocStandardRoleMappings) {
279       // Read `role` claim
280       try {
281         var rolesClaim = jwt.getClaim("roles");
282         if (rolesClaim != null && !rolesClaim.isNull()) {
283           for (String r : rolesClaim.asArray(String.class)) {
284             addRole.accept(r);
285           }
286         }
287       } catch (JWTDecodeException e) {
288         logger.debug("claim 'roles' is not an array of strings, ignoring");
289       }
290 
291       // Read `oc` claim
292       try {
293         var ocClaim = jwt.getClaim("oc");
294         if (ocClaim != null && !ocClaim.isNull()) {
295           for (var entry : ocClaim.asMap().entrySet()) {
296             var key = entry.getKey();
297             var parts = key.split(":", 2);
298             if (parts.length != 2) {
299               logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
300               continue;
301             }
302             var type = parts[0];
303             var id = parts[1];
304 
305             try {
306               for (var actionObj : (List<?>) entry.getValue()) {
307                 var action = (String) actionObj;
308                 if (action.isBlank()) {
309                   continue;
310                 }
311 
312                 if (type.equals("e")) {
313                   addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
314                 } else {
315                   logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
316                 }
317               }
318             } catch (ClassCastException e) {
319               logger.debug("value in 'oc' claim is not a string array -> ignoring");
320               continue;
321             }
322           }
323         }
324       } catch (JWTDecodeException e) {
325         logger.debug("claim 'oc' is not an array of strings, ignoring");
326       }
327     }
328 
329     for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
330       ExpressionParser parser = new SpelExpressionParser();
331       Expression exp = parser.parseExpression(mapping);
332       Object value = exp.getValue(jwt.getClaims());
333       if (value != null) {
334         // We allow the expression to either return a string directly, or a list/array of strings.
335         if (value instanceof String) {
336           addRole.accept((String) value);
337         } else if (value.getClass().isArray()) {
338           for (var role : (Object[]) value) {
339             addRole.accept((String) role);
340           }
341         } else {
342           for (var role : (List<?>) value) {
343             addRole.accept((String) role);
344           }
345         }
346       }
347     }
348     Assert.notEmpty(roles, "No roles could be extracted");
349     return roles;
350   }
351 
352   /**
353    * Evaluates a mapping given in SpEL on a decoded JWT.
354    *
355    * @param jwt The decoded JWT.
356    * @param mapping The mapping.
357    *
358    * @return The string evaluated from the mapping.
359    */
360   private String evaluateMapping(DecodedJWT jwt, String mapping) {
361     ExpressionParser parser = new SpelExpressionParser();
362     Expression exp = parser.parseExpression(mapping);
363     return exp.getValue(jwt.getClaims(), String.class);
364   }
365 
366   /**
367    * Handles a new user login.
368    *
369    * @param username The username.
370    * @param jwt The decoded JWT.
371    */
372   public void newUserLogin(String username, DecodedJWT jwt) {
373     // Create a new user reference
374     JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
375         new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
376 
377     logger.debug("JWT user '{}' logged in for the first time", username);
378     userReferenceProvider.addUserReference(userReference, MECH_JWT);
379   }
380 
381   /**
382    * Handles an existing user login.
383    *
384    * @param username The username.
385    * @param jwt The decoded JWT.
386    */
387   public void existingUserLogin(String username, DecodedJWT jwt) {
388     Organization organization = securityService.getOrganization();
389 
390     // Load the user reference
391     JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
392     if (userReference == null) {
393       throw new UsernameNotFoundException("User reference '" + username + "' was not found");
394     }
395 
396     // Update the reference
397     userReference.setName(extractName(jwt));
398     userReference.setEmail(extractEmail(jwt));
399     userReference.setLastLogin(new Date());
400     userReference.setRoles(extractRoles(jwt));
401 
402     logger.debug("JWT user '{}' logged in", username);
403     userReferenceProvider.updateUserReference(userReference);
404   }
405 
406   /**
407    * Converts a {@link Organization} object into a {@link JpaOrganization} object.
408    *
409    * @param org The {@link Organization} object.
410    * @return The corresponding {@link JpaOrganization} object.
411    */
412   private JpaOrganization fromOrganization(Organization org) {
413     if (org instanceof JpaOrganization) {
414       return (JpaOrganization) org;
415     }
416 
417     return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
418         org.getProperties());
419   }
420 
421   /**
422    * Setter for the user details service.
423    *
424    * @param userDetailsService The user details service.
425    */
426   @Reference
427   public void setUserDetailsService(UserDetailsService userDetailsService) {
428     this.userDetailsService = userDetailsService;
429   }
430 
431   /**
432    * Setter for the user directory service.
433    *
434    * @param userDirectoryService The user directory service.
435    */
436   @Reference
437   public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
438     this.userDirectoryService = userDirectoryService;
439   }
440 
441   /**
442    * Setter for the security service.
443    *
444    * @param securityService The security service.
445    */
446   @Reference
447   public void setSecurityService(SecurityService securityService) {
448     this.securityService = securityService;
449   }
450 
451   /**
452    * Setter for the user reference provider.
453    *
454    * @param userReferenceProvider The user reference provider.
455    */
456   @Reference
457   public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
458     this.userReferenceProvider = userReferenceProvider;
459   }
460 
461   /**
462    * Setter for the JWKS URL.
463    *
464    * @param jwksUrl The JWKS URL.
465    */
466   public void setJwksUrl(String jwksUrl) {
467     this.jwksUrl = jwksUrl;
468   }
469 
470   /**
471    * Setter for the JWKS cache expiration.
472    *
473    * @param jwksCacheExpiresIn The number of minutes after which a cached JWKS expires.
474    */
475   public void setJwksCacheExpiresIn(int jwksCacheExpiresIn) {
476     this.jwksCacheExpiresIn = jwksCacheExpiresIn;
477   }
478 
479   /**
480    * Setter for the secret used for JWT validation.
481    *
482    * @param secret The secret.
483    */
484   public void setSecret(String secret) {
485     this.secret = secret;
486   }
487 
488   /**
489    * Setter for the expected algorithms.
490    *
491    * @param expectedAlgorithms The expected algorithms.
492    */
493   public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
494     this.expectedAlgorithms = expectedAlgorithms;
495   }
496 
497   /**
498    * Setter for the claim constraints.
499    *
500    * @param claimConstraints The claim constraints.
501    */
502   public void setClaimConstraints(List<String> claimConstraints) {
503     this.claimConstraints = claimConstraints;
504   }
505 
506   /**
507    * Setter for the username mapping.
508    * @param usernameMapping The username mapping.
509    */
510   public void setUsernameMapping(String usernameMapping) {
511     this.usernameMapping = usernameMapping;
512   }
513 
514   /**
515    * Setter for the name mapping.
516    *
517    * @param nameMapping The name mapping.
518    */
519   public void setNameMapping(String nameMapping) {
520     this.nameMapping = nameMapping;
521   }
522 
523   /**
524    * Setter for the email mapping.
525    * @param emailMapping The email mapping.
526    */
527   public void setEmailMapping(String emailMapping) {
528     this.emailMapping = emailMapping;
529   }
530 
531   public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
532     this.ocStandardRoleMappings = ocStandardRoleMappings;
533   }
534 
535   /**
536    * Setter for the role mappings.
537    *
538    * @param roleMappings The role mappings.
539    */
540   public void setRoleMappings(List<String> roleMappings) {
541     this.roleMappings = roleMappings;
542   }
543 
544   /**
545    * Setter for the JWT cache size.
546    *
547    * @param jwtCacheSize The JWT cache size.
548    */
549   public void setJwtCacheSize(int jwtCacheSize) {
550     this.jwtCacheSize = jwtCacheSize;
551   }
552 
553   /**
554    * Setter for the JWT cache expiration.
555    *
556    * @param jwtCacheExpiresIn The number of minutes after which a cached JWT expires.
557    */
558   public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
559     this.jwtCacheExpiresIn = jwtCacheExpiresIn;
560   }
561 
562 }