View Javadoc
1   /*
2    * Licensed to The Apereo Foundation under one or more contributor license
3    * agreements. See the NOTICE file distributed with this work for additional
4    * information regarding copyright ownership.
5    *
6    *
7    * The Apereo Foundation licenses this file to you under the Educational
8    * Community License, Version 2.0 (the "License"); you may not use this file
9    * except in compliance with the License. You may obtain a copy of the License
10   * at:
11   *
12   *   http://opensource.org/licenses/ecl2.txt
13   *
14   * Unless required by applicable law or agreed to in writing, software
15   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
17   * License for the specific language governing permissions and limitations under
18   * the License.
19   *
20   */
21  
22  package org.opencastproject.security.jwt;
23  
24  import org.opencastproject.security.api.Organization;
25  import org.opencastproject.security.api.SecurityService;
26  import org.opencastproject.security.api.UserDirectoryService;
27  import org.opencastproject.security.impl.jpa.JpaOrganization;
28  import org.opencastproject.security.impl.jpa.JpaRole;
29  import org.opencastproject.security.impl.jpa.JpaUserReference;
30  import org.opencastproject.security.util.SecurityUtil;
31  import org.opencastproject.userdirectory.api.UserReferenceProvider;
32  
33  import com.google.common.cache.Cache;
34  import com.google.common.cache.CacheBuilder;
35  import com.nimbusds.jose.JOSEException;
36  import com.nimbusds.jwt.SignedJWT;
37  
38  import org.apache.commons.lang3.StringUtils;
39  import org.osgi.service.component.annotations.Reference;
40  import org.slf4j.Logger;
41  import org.slf4j.LoggerFactory;
42  import org.springframework.beans.factory.InitializingBean;
43  import org.springframework.context.expression.MapAccessor;
44  import org.springframework.expression.Expression;
45  import org.springframework.expression.ExpressionParser;
46  import org.springframework.expression.spel.standard.SpelExpressionParser;
47  import org.springframework.expression.spel.support.StandardEvaluationContext;
48  import org.springframework.security.core.userdetails.UserDetailsService;
49  import org.springframework.security.core.userdetails.UsernameNotFoundException;
50  import org.springframework.util.Assert;
51  
52  import java.text.ParseException;
53  import java.util.ArrayList;
54  import java.util.Date;
55  import java.util.HashSet;
56  import java.util.List;
57  import java.util.Set;
58  import java.util.concurrent.TimeUnit;
59  import java.util.function.Consumer;
60  
61  /**
62   * Dynamic login handler for JWTs.
63   */
64  public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
65  
66    /** Logging facility. */
67    private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
68  
69    /** Spring security's user details manager. */
70    private UserDetailsService userDetailsService = null;
71  
72    /** User directory service. */
73    private UserDirectoryService userDirectoryService = null;
74  
75    /** User reference provider. */
76    private UserReferenceProvider userReferenceProvider = null;
77  
78    /** Security service. */
79    private SecurityService securityService = null;
80  
81    /** JWKS URL to use for JWT validation (asymmetric algorithms). */
82    private String jwksUrl = null;
83  
84    /** The default time to live of cached JWK sets, in milliseconds. */
85    private int jwksTimeToLive = 1000 * 60 * 60;
86  
87    /** The default refresh timeout of cached JWK sets, in milliseconds. */
88    private int jwksRefreshTimeout = 1000 * 60;
89  
90    /** Secret to use for JWT validation (symmetric algorithms). */
91    private String secret = null;
92  
93    /** Allowed algorithms with which a valid JWT may be signed ('alg' claim). */
94    private List<String> expectedAlgorithms = null;
95  
96    /** Constraints that the claims of a valid JWT must fulfill. */
97    private List<String> claimConstraints = null;
98  
99    /** Mapping used to extract the username from the JWT. */
100   private String usernameMapping = null;
101 
102   /** Mapping used to extract the name from the JWT. */
103   private String nameMapping = null;
104 
105   /** Mapping used to extract the email from the JWT. */
106   private String emailMapping = null;
107 
108   /** Whether the built-in schema role mapping should be used. */
109   private boolean ocStandardRoleMappings = true;
110 
111   /** Mapping used to extract roles from the JWT. */
112   private List<String> roleMappings = null;
113 
114   /** Mapping used to extract roles from the JWT. */
115   private JWKSetProvider jwkProvider;
116 
117   /** Size of the JWT cache. */
118   private int jwtCacheSize = 500;
119 
120   /** Number of minutes validated JWTs will be cached before re-validating them. */
121   private int jwtCacheExpiresIn = 60;
122 
123   /** Cache for validated JWTs. */
124   private Cache<String, CachedJWT> cache;
125 
126   @Override
127   public void afterPropertiesSet() {
128     Assert.notNull(userDetailsService, "A UserDetailsService must be set");
129     Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
130     Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
131     Assert.notNull(securityService, "A SecurityService must be set");
132     Assert.isTrue(!(StringUtils.isNotBlank(jwksUrl) && StringUtils.isNotBlank(secret)),
133         "A JWKS URL and a secret cannot be set at the same time");
134     Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
135     Assert.notNull(claimConstraints, "Claim constraints must be set");
136     Assert.notNull(usernameMapping, "User name mapping must be set");
137     Assert.notNull(nameMapping, "Name mapping must be set");
138     Assert.notNull(emailMapping, "Email mapping must be set");
139     Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
140             "Role mappings must be set if ocStandardRoleMappings is false");
141 
142     if (StringUtils.isBlank(jwksUrl) && StringUtils.isBlank(secret)) {
143       logger.info("JWT login handler disabled as neither 'jwksUrl' nor 'secret' are set");
144     }
145 
146     if (jwksUrl != null) {
147       jwkProvider = new JWKSetProvider(jwksUrl, jwksTimeToLive, jwksRefreshTimeout);
148     }
149     userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
150     cache = CacheBuilder.newBuilder()
151         .maximumSize(jwtCacheSize)
152         .expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
153         .build();
154   }
155 
156   @Override
157   public String handleToken(String token) {
158     if (jwkProvider == null && secret == null) {
159       logger.debug("neither jwksURL nor secret set: ignoring JWT");
160       return null;
161     }
162 
163     try {
164       String signature = extractSignature(token);
165       CachedJWT cachedJwt = cache.getIfPresent(signature);
166 
167       if (cachedJwt == null) {
168         // JWT hasn't been cached before, so validate all claims
169         SignedJWT jwt = decodeAndValidate(token);
170         String username = extractUsername(jwt);
171 
172         try {
173           if (userDetailsService.loadUserByUsername(username) != null) {
174             existingUserLogin(username, jwt);
175           }
176         } catch (UsernameNotFoundException e) {
177           newUserLogin(username, jwt);
178         }
179 
180         userDirectoryService.invalidate(username);
181         cache.put(jwt.getSignature().toString(), new CachedJWT(jwt, username));
182         return username;
183       } else {
184         // JWT has been cached before, so only check if it has expired
185         if (cachedJwt.hasExpired()) {
186           cache.invalidate(signature);
187           throw new JOSEException("JWT token is not valid anymore");
188         }
189         logger.debug("Using decoded and validated JWT from cache");
190         return cachedJwt.getUsername();
191       }
192     } catch (ParseException | JOSEException exception) {
193       logger.debug(exception.getMessage());
194     }
195 
196     return null;
197   }
198 
199   /**
200    * Decodes and validates a JWT.
201    *
202    * @param token The JWT string.
203    * @return The decoded JWT.
204    * @throws JOSEException If the JWT fails to be validated.
205    */
206   private SignedJWT decodeAndValidate(String token) throws ParseException, JOSEException {
207     SignedJWT jwt;
208 
209     if (jwksUrl != null) {
210       jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
211     } else {
212       jwt = JWTVerifier.verify(token, secret, claimConstraints);
213     }
214 
215     if (!expectedAlgorithms.contains(jwt.getHeader().getAlgorithm().getName())) {
216       throw new JOSEException(
217           "JWT token was signed with an unexpected algorithm '" + jwt.getHeader().getAlgorithm() + "'"
218       );
219     }
220 
221     return jwt;
222   }
223 
224   /**
225    * Extracts the signature from a JWT.
226    *
227    * @param token The JWT string.
228    * @return The JWT's signature.
229    */
230   private String extractSignature(String token) throws JOSEException {
231     String[] parts = token.split("\\.");
232     if (parts.length != 3) {
233       throw new JOSEException("Given token is not in a valid JWT format");
234     }
235     return parts[2];
236   }
237 
238   /**
239    * Extracts the username from a decoded and validated JWT.
240    *
241    * @param jwt The decoded JWT.
242    * @return The username.
243    */
244   private String extractUsername(SignedJWT jwt) throws ParseException {
245     String username = evaluateMapping(jwt, usernameMapping);
246     Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
247     return username;
248   }
249 
250   /**
251    * Extracts the name from a decoded and validated JWT.
252    *
253    * @param jwt The decoded JWT.
254    * @return The name.
255    */
256   private String extractName(SignedJWT jwt) throws ParseException {
257     String name = evaluateMapping(jwt, nameMapping);
258     Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
259     return name;
260   }
261 
262   /**
263    * Extracts the email from a decoded and validated JWT.
264    *
265    * @param jwt The decoded JWT.
266    * @return The email.
267    */
268   private String extractEmail(SignedJWT jwt) throws ParseException {
269     String email = evaluateMapping(jwt, emailMapping);
270     Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
271     return email;
272   }
273 
274   /**
275    * Extracts the roles from a decoded and validated JWT.
276    *
277    * @param jwt The decoded JWT.
278    * @return The roles.
279    */
280   private Set<JpaRole> extractRoles(SignedJWT jwt) throws ParseException {
281     JpaOrganization organization = fromOrganization(securityService.getOrganization());
282     Set<JpaRole> roles = new HashSet<>();
283     Consumer<String> addRole = (String role) -> {
284       if (StringUtils.isNotBlank(role)) {
285         roles.add(new JpaRole(role, organization));
286       }
287     };
288 
289     // Evaluate the standard role mapping if specified
290     if (ocStandardRoleMappings) {
291       // Read `role` claim
292       try {
293         var rolesClaim = jwt.getJWTClaimsSet().getStringArrayClaim("roles");
294         if (rolesClaim != null) {
295           for (String r : rolesClaim) {
296             addRole.accept(r);
297           }
298         }
299       } catch (ParseException e) {
300         logger.debug("claim 'roles' is not an array of strings, ignoring");
301       }
302 
303       // Read `oc` claim
304       try {
305         var ocClaim = jwt.getJWTClaimsSet().getJSONObjectClaim("oc");
306         if (ocClaim != null) {
307           for (var entry : ocClaim.entrySet()) {
308             var key = entry.getKey();
309             var parts = key.split(":", 2);
310             if (parts.length != 2) {
311               logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
312               continue;
313             }
314             var type = parts[0];
315             var id = parts[1];
316 
317             try {
318               for (var actionObj : (List<?>) entry.getValue()) {
319                 var action = (String) actionObj;
320                 if (action.isBlank()) {
321                   continue;
322                 }
323 
324                 if (type.equals("e")) {
325                   addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
326                 } else {
327                   logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
328                 }
329               }
330             } catch (ClassCastException e) {
331               logger.debug("value in 'oc' claim is not a string array -> ignoring");
332               continue;
333             }
334           }
335         }
336       } catch (ParseException e) {
337         logger.debug("claim 'oc' is not an array of strings, ignoring");
338       }
339     }
340 
341     for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
342       ExpressionParser parser = new SpelExpressionParser();
343       Expression exp = parser.parseExpression(mapping);
344       StandardEvaluationContext ctx = new StandardEvaluationContext();
345       ctx.addPropertyAccessor(new MapAccessor());
346       Object value = exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims());
347       if (value != null) {
348         // We allow the expression to either return a string directly, or a list/array of strings.
349         if (value instanceof String) {
350           addRole.accept((String) value);
351         } else if (value.getClass().isArray()) {
352           for (var role : (Object[]) value) {
353             addRole.accept((String) role);
354           }
355         } else {
356           for (var role : (List<?>) value) {
357             addRole.accept((String) role);
358           }
359         }
360       }
361     }
362     Assert.notEmpty(roles, "No roles could be extracted");
363     return roles;
364   }
365 
366   /**
367    * Evaluates a mapping given in SpEL on a decoded JWT.
368    *
369    * @param jwt The decoded JWT.
370    * @param mapping The mapping.
371    *
372    * @return The string evaluated from the mapping.
373    */
374   private String evaluateMapping(SignedJWT jwt, String mapping) throws ParseException {
375     ExpressionParser parser = new SpelExpressionParser();
376     Expression exp = parser.parseExpression(mapping);
377     StandardEvaluationContext ctx = new StandardEvaluationContext();
378     ctx.addPropertyAccessor(new MapAccessor());
379     return exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims(), String.class);
380   }
381 
382   /**
383    * Handles a new user login.
384    *
385    * @param username The username.
386    * @param jwt The decoded JWT.
387    */
388   public void newUserLogin(String username, SignedJWT jwt) throws ParseException {
389     // Create a new user reference
390     JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
391         new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
392 
393     logger.debug("JWT user '{}' logged in for the first time", username);
394     userReferenceProvider.addUserReference(userReference, MECH_JWT);
395   }
396 
397   /**
398    * Handles an existing user login.
399    *
400    * @param username The username.
401    * @param jwt The decoded JWT.
402    */
403   public void existingUserLogin(String username, SignedJWT jwt) throws ParseException {
404     Organization organization = securityService.getOrganization();
405 
406     // Load the user reference
407     JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
408     if (userReference == null) {
409       throw new UsernameNotFoundException("User reference '" + username + "' was not found");
410     }
411 
412     // Update the reference
413     userReference.setName(extractName(jwt));
414     userReference.setEmail(extractEmail(jwt));
415     userReference.setLastLogin(new Date());
416     userReference.setRoles(extractRoles(jwt));
417 
418     logger.debug("JWT user '{}' logged in", username);
419     userReferenceProvider.updateUserReference(userReference);
420   }
421 
422   /**
423    * Converts a {@link Organization} object into a {@link JpaOrganization} object.
424    *
425    * @param org The {@link Organization} object.
426    * @return The corresponding {@link JpaOrganization} object.
427    */
428   private JpaOrganization fromOrganization(Organization org) {
429     if (org instanceof JpaOrganization) {
430       return (JpaOrganization) org;
431     }
432 
433     return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
434         org.getProperties());
435   }
436 
437   /**
438    * Setter for the user details service.
439    *
440    * @param userDetailsService The user details service.
441    */
442   @Reference
443   public void setUserDetailsService(UserDetailsService userDetailsService) {
444     this.userDetailsService = userDetailsService;
445   }
446 
447   /**
448    * Setter for the user directory service.
449    *
450    * @param userDirectoryService The user directory service.
451    */
452   @Reference
453   public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
454     this.userDirectoryService = userDirectoryService;
455   }
456 
457   /**
458    * Setter for the security service.
459    *
460    * @param securityService The security service.
461    */
462   @Reference
463   public void setSecurityService(SecurityService securityService) {
464     this.securityService = securityService;
465   }
466 
467   /**
468    * Setter for the user reference provider.
469    *
470    * @param userReferenceProvider The user reference provider.
471    */
472   @Reference
473   public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
474     this.userReferenceProvider = userReferenceProvider;
475   }
476 
477   /**
478    * Setter for the JWKS URL.
479    *
480    * @param jwksUrl The JWKS URL.
481    */
482   public void setJwksUrl(String jwksUrl) {
483     this.jwksUrl = jwksUrl;
484   }
485 
486   /**
487    * Setter for the JWKS ttl.
488    *
489    * @param jwksTimeToLive The default time to live of cached JWK sets, in milliseconds
490    */
491   public void setJwksTimeToLive(int jwksTimeToLive) {
492     this.jwksTimeToLive = jwksTimeToLive;
493   }
494 
495   /**
496    * Setter for the secret used for JWT validation.
497    *
498    * @param secret The secret.
499    */
500   public void setSecret(String secret) {
501     this.secret = secret;
502   }
503 
504   /**
505    * Setter for the expected algorithms.
506    *
507    * @param expectedAlgorithms The expected algorithms.
508    */
509   public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
510     this.expectedAlgorithms = expectedAlgorithms;
511   }
512 
513   /**
514    * Setter for the claim constraints.
515    *
516    * @param claimConstraints The claim constraints.
517    */
518   public void setClaimConstraints(List<String> claimConstraints) {
519     this.claimConstraints = claimConstraints;
520   }
521 
522   /**
523    * Setter for the username mapping.
524    * @param usernameMapping The username mapping.
525    */
526   public void setUsernameMapping(String usernameMapping) {
527     this.usernameMapping = usernameMapping;
528   }
529 
530   /**
531    * Setter for the name mapping.
532    *
533    * @param nameMapping The name mapping.
534    */
535   public void setNameMapping(String nameMapping) {
536     this.nameMapping = nameMapping;
537   }
538 
539   /**
540    * Setter for the email mapping.
541    * @param emailMapping The email mapping.
542    */
543   public void setEmailMapping(String emailMapping) {
544     this.emailMapping = emailMapping;
545   }
546 
547   public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
548     this.ocStandardRoleMappings = ocStandardRoleMappings;
549   }
550 
551   /**
552    * Setter for the role mappings.
553    *
554    * @param roleMappings The role mappings.
555    */
556   public void setRoleMappings(List<String> roleMappings) {
557     this.roleMappings = roleMappings;
558   }
559 
560   /**
561    * Setter for the JWT cache size.
562    *
563    * @param jwtCacheSize The JWT cache size.
564    */
565   public void setJwtCacheSize(int jwtCacheSize) {
566     this.jwtCacheSize = jwtCacheSize;
567   }
568 
569   /**
570    * Setter for the JWT cache expiration.
571    *
572    * @param jwtCacheExpiresIn The number of minutes after which a cached JWT expires.
573    */
574   public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
575     this.jwtCacheExpiresIn = jwtCacheExpiresIn;
576   }
577 
578 }