View Javadoc
1   /*
2    * Licensed to The Apereo Foundation under one or more contributor license
3    * agreements. See the NOTICE file distributed with this work for additional
4    * information regarding copyright ownership.
5    *
6    *
7    * The Apereo Foundation licenses this file to you under the Educational
8    * Community License, Version 2.0 (the "License"); you may not use this file
9    * except in compliance with the License. You may obtain a copy of the License
10   * at:
11   *
12   *   http://opensource.org/licenses/ecl2.txt
13   *
14   * Unless required by applicable law or agreed to in writing, software
15   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
17   * License for the specific language governing permissions and limitations under
18   * the License.
19   *
20   */
21  
22  package org.opencastproject.security.jwt;
23  
24  import org.opencastproject.security.api.Organization;
25  import org.opencastproject.security.api.SecurityService;
26  import org.opencastproject.security.api.UserDirectoryService;
27  import org.opencastproject.security.impl.jpa.JpaOrganization;
28  import org.opencastproject.security.impl.jpa.JpaRole;
29  import org.opencastproject.security.impl.jpa.JpaUserReference;
30  import org.opencastproject.security.util.SecurityUtil;
31  import org.opencastproject.userdirectory.api.UserReferenceProvider;
32  
33  import com.google.common.cache.Cache;
34  import com.google.common.cache.CacheBuilder;
35  import com.nimbusds.jose.JOSEException;
36  import com.nimbusds.jwt.SignedJWT;
37  
38  import org.apache.commons.lang3.StringUtils;
39  import org.osgi.service.component.annotations.Reference;
40  import org.slf4j.Logger;
41  import org.slf4j.LoggerFactory;
42  import org.springframework.beans.factory.InitializingBean;
43  import org.springframework.context.expression.MapAccessor;
44  import org.springframework.expression.Expression;
45  import org.springframework.expression.ExpressionParser;
46  import org.springframework.expression.spel.standard.SpelExpressionParser;
47  import org.springframework.expression.spel.support.StandardEvaluationContext;
48  import org.springframework.security.core.userdetails.UserDetailsService;
49  import org.springframework.security.core.userdetails.UsernameNotFoundException;
50  import org.springframework.util.Assert;
51  
52  import java.text.ParseException;
53  import java.util.ArrayList;
54  import java.util.Date;
55  import java.util.HashSet;
56  import java.util.List;
57  import java.util.Objects;
58  import java.util.Set;
59  import java.util.concurrent.TimeUnit;
60  import java.util.function.Consumer;
61  
62  /**
63   * Dynamic login handler for JWTs.
64   */
65  public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
66  
67    /** Logging facility. */
68    private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
69  
70    /** Spring security's user details manager. */
71    private UserDetailsService userDetailsService = null;
72  
73    /** User directory service. */
74    private UserDirectoryService userDirectoryService = null;
75  
76    /** User reference provider. */
77    private UserReferenceProvider userReferenceProvider = null;
78  
79    /** Security service. */
80    private SecurityService securityService = null;
81  
82    /** JWKS URL to use for JWT validation (asymmetric algorithms). */
83    private String jwksUrl = null;
84  
85    /** The default time to live of cached JWK sets, in milliseconds. */
86    private int jwksTimeToLive = 1000 * 60 * 60;
87  
88    /** The default refresh timeout of cached JWK sets, in milliseconds. */
89    private int jwksRefreshTimeout = 1000 * 60;
90  
91    /** Secret to use for JWT validation (symmetric algorithms). */
92    private String secret = null;
93  
94    /** Allowed algorithms with which a valid JWT may be signed ('alg' claim). */
95    private List<String> expectedAlgorithms = null;
96  
97    /** Constraints that the claims of a valid JWT must fulfill. */
98    private List<String> claimConstraints = null;
99  
100   /** Mapping used to extract the username from the JWT. */
101   private String usernameMapping = null;
102 
103   /** Mapping used to extract the name from the JWT. */
104   private String nameMapping = null;
105 
106   /** Mapping used to extract the email from the JWT. */
107   private String emailMapping = null;
108 
109   /** Whether the built-in schema role mapping should be used. */
110   private boolean ocStandardRoleMappings = true;
111 
112   /** Mapping used to extract roles from the JWT. */
113   private List<String> roleMappings = null;
114 
115   /** Mapping used to extract roles from the JWT. */
116   private JWKSetProvider jwkProvider;
117 
118   /** Size of the JWT cache. */
119   private int jwtCacheSize = 500;
120 
121   /** Number of minutes validated JWTs will be cached before re-validating them. */
122   private int jwtCacheExpiresIn = 60;
123 
124   /** Cache for validated JWTs. */
125   private Cache<String, CachedJWT> cache;
126 
127   @Override
128   public void afterPropertiesSet() {
129     Assert.notNull(userDetailsService, "A UserDetailsService must be set");
130     Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
131     Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
132     Assert.notNull(securityService, "A SecurityService must be set");
133     Assert.isTrue(!(StringUtils.isNotBlank(jwksUrl) && StringUtils.isNotBlank(secret)),
134         "A JWKS URL and a secret cannot be set at the same time");
135     Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
136     Assert.notNull(claimConstraints, "Claim constraints must be set");
137     Assert.notNull(usernameMapping, "User name mapping must be set");
138     Assert.notNull(nameMapping, "Name mapping must be set");
139     Assert.notNull(emailMapping, "Email mapping must be set");
140     Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
141             "Role mappings must be set if ocStandardRoleMappings is false");
142 
143     if (StringUtils.isBlank(jwksUrl) && StringUtils.isBlank(secret)) {
144       logger.info("JWT login handler disabled as neither 'jwksUrl' nor 'secret' are set");
145     }
146 
147     if (jwksUrl != null) {
148       jwkProvider = new JWKSetProvider(jwksUrl, jwksTimeToLive, jwksRefreshTimeout);
149     }
150     userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
151     cache = CacheBuilder.newBuilder()
152         .maximumSize(jwtCacheSize)
153         .expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
154         .build();
155   }
156 
157   @Override
158   public String handleToken(String token) {
159     if (jwkProvider == null && secret == null) {
160       logger.debug("neither jwksURL nor secret set: ignoring JWT");
161       return null;
162     }
163 
164     try {
165       String signature = extractSignature(token);
166       CachedJWT cachedJwt = cache.getIfPresent(signature);
167 
168       if (cachedJwt == null) {
169         logger.debug("JWT cache miss for signature {}, validating token", abbreviateSignature(signature));
170 
171         // JWT hasn't been cached before, so validate all claims
172         SignedJWT jwt = decodeAndValidate(token);
173         String username = extractUsername(jwt);
174         boolean invalidateUserDirectory = true;
175 
176         try {
177           if (userDetailsService.loadUserByUsername(username) != null) {
178             invalidateUserDirectory = existingUserLogin(username, jwt);
179           }
180         } catch (UsernameNotFoundException e) {
181           newUserLogin(username, jwt);
182         }
183 
184         if (invalidateUserDirectory) {
185           logger.debug("Invalidating user directory caches for JWT user '{}'", username);
186           userDirectoryService.invalidate(username);
187         } else {
188           logger.debug("Skipping user directory invalidation for JWT user '{}' since effective user data is unchanged",
189               username);
190         }
191         cache.put(jwt.getSignature().toString(), new CachedJWT(jwt, username));
192         logger.debug("Cached validated JWT signature {} for user '{}'", abbreviateSignature(signature), username);
193         return username;
194       } else {
195         // JWT has been cached before, so only check if it has expired
196         if (cachedJwt.hasExpired()) {
197           cache.invalidate(signature);
198           throw new JOSEException("JWT token is not valid anymore");
199         }
200         logger.debug("Using validated JWT from cache for user '{}' and signature {}", cachedJwt.getUsername(),
201             abbreviateSignature(signature));
202         return cachedJwt.getUsername();
203       }
204     } catch (ParseException | JOSEException exception) {
205       logger.debug(exception.getMessage());
206     }
207 
208     return null;
209   }
210 
211   /**
212    * Decodes and validates a JWT.
213    *
214    * @param token The JWT string.
215    * @return The decoded JWT.
216    * @throws JOSEException If the JWT fails to be validated.
217    */
218   private SignedJWT decodeAndValidate(String token) throws ParseException, JOSEException {
219     SignedJWT jwt;
220 
221     if (jwksUrl != null) {
222       jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
223     } else {
224       jwt = JWTVerifier.verify(token, secret, claimConstraints);
225     }
226 
227     if (!expectedAlgorithms.contains(jwt.getHeader().getAlgorithm().getName())) {
228       throw new JOSEException(
229           "JWT token was signed with an unexpected algorithm '" + jwt.getHeader().getAlgorithm() + "'"
230       );
231     }
232 
233     return jwt;
234   }
235 
236   /**
237    * Extracts the signature from a JWT.
238    *
239    * @param token The JWT string.
240    * @return The JWT's signature.
241    */
242   private String extractSignature(String token) throws JOSEException {
243     String[] parts = token.split("\\.");
244     if (parts.length != 3) {
245       throw new JOSEException("Given token is not in a valid JWT format");
246     }
247     return parts[2];
248   }
249 
250   /**
251    * Extracts the username from a decoded and validated JWT.
252    *
253    * @param jwt The decoded JWT.
254    * @return The username.
255    */
256   private String extractUsername(SignedJWT jwt) throws ParseException {
257     String username = evaluateMapping(jwt, usernameMapping);
258     Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
259     return username;
260   }
261 
262   /**
263    * Extracts the name from a decoded and validated JWT.
264    *
265    * @param jwt The decoded JWT.
266    * @return The name.
267    */
268   private String extractName(SignedJWT jwt) throws ParseException {
269     String name = evaluateMapping(jwt, nameMapping);
270     Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
271     return name;
272   }
273 
274   /**
275    * Extracts the email from a decoded and validated JWT.
276    *
277    * @param jwt The decoded JWT.
278    * @return The email.
279    */
280   private String extractEmail(SignedJWT jwt) throws ParseException {
281     String email = evaluateMapping(jwt, emailMapping);
282     Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
283     return email;
284   }
285 
286   /**
287    * Extracts the roles from a decoded and validated JWT.
288    *
289    * @param jwt The decoded JWT.
290    * @return The roles.
291    */
292   private Set<JpaRole> extractRoles(SignedJWT jwt) throws ParseException {
293     JpaOrganization organization = fromOrganization(securityService.getOrganization());
294     Set<JpaRole> roles = new HashSet<>();
295     Consumer<String> addRole = (String role) -> {
296       if (StringUtils.isNotBlank(role)) {
297         roles.add(new JpaRole(role, organization));
298       }
299     };
300 
301     // Evaluate the standard role mapping if specified
302     if (ocStandardRoleMappings) {
303       // Read `role` claim
304       try {
305         var rolesClaim = jwt.getJWTClaimsSet().getStringArrayClaim("roles");
306         if (rolesClaim != null) {
307           for (String r : rolesClaim) {
308             addRole.accept(r);
309           }
310         }
311       } catch (ParseException e) {
312         logger.debug("claim 'roles' is not an array of strings, ignoring");
313       }
314 
315       // Read `oc` claim
316       try {
317         var ocClaim = jwt.getJWTClaimsSet().getJSONObjectClaim("oc");
318         if (ocClaim != null) {
319           for (var entry : ocClaim.entrySet()) {
320             var key = entry.getKey();
321             var parts = key.split(":", 2);
322             if (parts.length != 2) {
323               logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
324               continue;
325             }
326             var type = parts[0];
327             var id = parts[1];
328 
329             try {
330               for (var actionObj : (List<?>) entry.getValue()) {
331                 var action = (String) actionObj;
332                 if (action.isBlank()) {
333                   continue;
334                 }
335 
336                 if (type.equals("e")) {
337                   addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
338                 } else {
339                   logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
340                 }
341               }
342             } catch (ClassCastException e) {
343               logger.debug("value in 'oc' claim is not a string array -> ignoring");
344               continue;
345             }
346           }
347         }
348       } catch (ParseException e) {
349         logger.debug("claim 'oc' is not an array of strings, ignoring");
350       }
351     }
352 
353     for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
354       ExpressionParser parser = new SpelExpressionParser();
355       Expression exp = parser.parseExpression(mapping);
356       StandardEvaluationContext ctx = new StandardEvaluationContext();
357       ctx.addPropertyAccessor(new MapAccessor());
358       Object value = exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims());
359       if (value != null) {
360         // We allow the expression to either return a string directly, or a list/array of strings.
361         if (value instanceof String) {
362           addRole.accept((String) value);
363         } else if (value.getClass().isArray()) {
364           for (var role : (Object[]) value) {
365             addRole.accept((String) role);
366           }
367         } else {
368           for (var role : (List<?>) value) {
369             addRole.accept((String) role);
370           }
371         }
372       }
373     }
374     Assert.notEmpty(roles, "No roles could be extracted");
375     return roles;
376   }
377 
378   /**
379    * Evaluates a mapping given in SpEL on a decoded JWT.
380    *
381    * @param jwt The decoded JWT.
382    * @param mapping The mapping.
383    *
384    * @return The string evaluated from the mapping.
385    */
386   private String evaluateMapping(SignedJWT jwt, String mapping) throws ParseException {
387     ExpressionParser parser = new SpelExpressionParser();
388     Expression exp = parser.parseExpression(mapping);
389     StandardEvaluationContext ctx = new StandardEvaluationContext();
390     ctx.addPropertyAccessor(new MapAccessor());
391     return exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims(), String.class);
392   }
393 
394   /**
395    * Handles a new user login.
396    *
397    * @param username The username.
398    * @param jwt The decoded JWT.
399    */
400   public void newUserLogin(String username, SignedJWT jwt) throws ParseException {
401     // Create a new user reference
402     JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
403         new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
404 
405     logger.debug("JWT user '{}' logged in for the first time, creating a new user reference", username);
406     userReferenceProvider.addUserReference(userReference, MECH_JWT);
407   }
408 
409   /**
410    * Handles an existing user login.
411    *
412    * @param username The username.
413    * @param jwt The decoded JWT.
414    */
415   public boolean existingUserLogin(String username, SignedJWT jwt) throws ParseException {
416     Organization organization = securityService.getOrganization();
417     String name = extractName(jwt);
418     String email = extractEmail(jwt);
419     Set<JpaRole> roles = extractRoles(jwt);
420 
421     // Load the user reference
422     JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
423     if (userReference == null) {
424       throw new UsernameNotFoundException("User reference '" + username + "' was not found");
425     }
426 
427     boolean changed = userReferenceChanged(userReference, name, email, roles);
428     String changedFields = changed ? describeUserReferenceChanges(userReference, name, email, roles) : "";
429     // Update the reference
430     userReference.setName(name);
431     userReference.setEmail(email);
432     userReference.setLastLogin(new Date());
433     userReference.setRoles(roles);
434 
435     if (changed) {
436       logger.debug("JWT user '{}' logged in with updated user data ({}), refreshing merged user caches", username,
437           changedFields);
438     } else {
439       logger.debug("JWT user '{}' logged in with unchanged name, email, and roles", username);
440     }
441     userReferenceProvider.updateUserReference(userReference);
442     return changed;
443   }
444 
445   private boolean userReferenceChanged(JpaUserReference userReference, String name, String email, Set<JpaRole> roles) {
446     return !Objects.equals(userReference.getName(), name)
447         || !Objects.equals(userReference.getEmail(), email)
448         || !Objects.equals(userReference.getRoles(), roles);
449   }
450 
451   private String describeUserReferenceChanges(JpaUserReference userReference, String name, String email,
452       Set<JpaRole> roles) {
453     List<String> changedFields = new ArrayList<>();
454     if (!Objects.equals(userReference.getName(), name)) {
455       changedFields.add("name");
456     }
457     if (!Objects.equals(userReference.getEmail(), email)) {
458       changedFields.add("email");
459     }
460     if (!Objects.equals(userReference.getRoles(), roles)) {
461       changedFields.add("roles");
462     }
463     return String.join(", ", changedFields);
464   }
465 
466   private String abbreviateSignature(String signature) {
467     if (signature == null || signature.length() <= 12) {
468       return signature;
469     }
470     return signature.substring(0, 12) + "...";
471   }
472 
473   /**
474    * Converts a {@link Organization} object into a {@link JpaOrganization} object.
475    *
476    * @param org The {@link Organization} object.
477    * @return The corresponding {@link JpaOrganization} object.
478    */
479   private JpaOrganization fromOrganization(Organization org) {
480     if (org instanceof JpaOrganization) {
481       return (JpaOrganization) org;
482     }
483 
484     return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
485         org.getProperties());
486   }
487 
488   /**
489    * Setter for the user details service.
490    *
491    * @param userDetailsService The user details service.
492    */
493   @Reference
494   public void setUserDetailsService(UserDetailsService userDetailsService) {
495     this.userDetailsService = userDetailsService;
496   }
497 
498   /**
499    * Setter for the user directory service.
500    *
501    * @param userDirectoryService The user directory service.
502    */
503   @Reference
504   public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
505     this.userDirectoryService = userDirectoryService;
506   }
507 
508   /**
509    * Setter for the security service.
510    *
511    * @param securityService The security service.
512    */
513   @Reference
514   public void setSecurityService(SecurityService securityService) {
515     this.securityService = securityService;
516   }
517 
518   /**
519    * Setter for the user reference provider.
520    *
521    * @param userReferenceProvider The user reference provider.
522    */
523   @Reference
524   public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
525     this.userReferenceProvider = userReferenceProvider;
526   }
527 
528   /**
529    * Setter for the JWKS URL.
530    *
531    * @param jwksUrl The JWKS URL.
532    */
533   public void setJwksUrl(String jwksUrl) {
534     this.jwksUrl = jwksUrl;
535   }
536 
537   /**
538    * Setter for the JWKS ttl.
539    *
540    * @param jwksTimeToLive The default time to live of cached JWK sets, in milliseconds
541    */
542   public void setJwksTimeToLive(int jwksTimeToLive) {
543     this.jwksTimeToLive = jwksTimeToLive;
544   }
545 
546   /**
547    * Setter for the secret used for JWT validation.
548    *
549    * @param secret The secret.
550    */
551   public void setSecret(String secret) {
552     this.secret = secret;
553   }
554 
555   /**
556    * Setter for the expected algorithms.
557    *
558    * @param expectedAlgorithms The expected algorithms.
559    */
560   public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
561     this.expectedAlgorithms = expectedAlgorithms;
562   }
563 
564   /**
565    * Setter for the claim constraints.
566    *
567    * @param claimConstraints The claim constraints.
568    */
569   public void setClaimConstraints(List<String> claimConstraints) {
570     this.claimConstraints = claimConstraints;
571   }
572 
573   /**
574    * Setter for the username mapping.
575    * @param usernameMapping The username mapping.
576    */
577   public void setUsernameMapping(String usernameMapping) {
578     this.usernameMapping = usernameMapping;
579   }
580 
581   /**
582    * Setter for the name mapping.
583    *
584    * @param nameMapping The name mapping.
585    */
586   public void setNameMapping(String nameMapping) {
587     this.nameMapping = nameMapping;
588   }
589 
590   /**
591    * Setter for the email mapping.
592    * @param emailMapping The email mapping.
593    */
594   public void setEmailMapping(String emailMapping) {
595     this.emailMapping = emailMapping;
596   }
597 
598   public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
599     this.ocStandardRoleMappings = ocStandardRoleMappings;
600   }
601 
602   /**
603    * Setter for the role mappings.
604    *
605    * @param roleMappings The role mappings.
606    */
607   public void setRoleMappings(List<String> roleMappings) {
608     this.roleMappings = roleMappings;
609   }
610 
611   /**
612    * Setter for the JWT cache size.
613    *
614    * @param jwtCacheSize The JWT cache size.
615    */
616   public void setJwtCacheSize(int jwtCacheSize) {
617     this.jwtCacheSize = jwtCacheSize;
618   }
619 
620   /**
621    * Setter for the JWT cache expiration.
622    *
623    * @param jwtCacheExpiresIn The number of minutes after which a cached JWT expires.
624    */
625   public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
626     this.jwtCacheExpiresIn = jwtCacheExpiresIn;
627   }
628 
629 }