View Javadoc
1   /*
2    * Licensed to The Apereo Foundation under one or more contributor license
3    * agreements. See the NOTICE file distributed with this work for additional
4    * information regarding copyright ownership.
5    *
6    *
7    * The Apereo Foundation licenses this file to you under the Educational
8    * Community License, Version 2.0 (the "License"); you may not use this file
9    * except in compliance with the License. You may obtain a copy of the License
10   * at:
11   *
12   *   http://opensource.org/licenses/ecl2.txt
13   *
14   * Unless required by applicable law or agreed to in writing, software
15   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
17   * License for the specific language governing permissions and limitations under
18   * the License.
19   *
20   */
21  
22  package org.opencastproject.security.jwt;
23  
24  import org.opencastproject.security.api.Organization;
25  import org.opencastproject.security.api.SecurityService;
26  import org.opencastproject.security.api.UserDirectoryService;
27  import org.opencastproject.security.impl.jpa.JpaOrganization;
28  import org.opencastproject.security.impl.jpa.JpaRole;
29  import org.opencastproject.security.impl.jpa.JpaUserReference;
30  import org.opencastproject.security.util.SecurityUtil;
31  import org.opencastproject.userdirectory.api.UserReferenceProvider;
32  
33  import com.google.common.cache.Cache;
34  import com.google.common.cache.CacheBuilder;
35  import com.nimbusds.jose.JOSEException;
36  import com.nimbusds.jwt.SignedJWT;
37  
38  import org.apache.commons.lang3.StringUtils;
39  import org.osgi.service.component.annotations.Reference;
40  import org.slf4j.Logger;
41  import org.slf4j.LoggerFactory;
42  import org.springframework.beans.factory.InitializingBean;
43  import org.springframework.context.expression.MapAccessor;
44  import org.springframework.expression.Expression;
45  import org.springframework.expression.ExpressionParser;
46  import org.springframework.expression.spel.standard.SpelExpressionParser;
47  import org.springframework.expression.spel.support.StandardEvaluationContext;
48  import org.springframework.security.core.userdetails.UserDetailsService;
49  import org.springframework.security.core.userdetails.UsernameNotFoundException;
50  import org.springframework.util.Assert;
51  
52  import java.text.ParseException;
53  import java.util.ArrayList;
54  import java.util.Date;
55  import java.util.HashSet;
56  import java.util.List;
57  import java.util.Set;
58  import java.util.concurrent.TimeUnit;
59  import java.util.function.Consumer;
60  
61  /**
62   * Dynamic login handler for JWTs.
63   */
64  public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
65  
66    /** Logging facility. */
67    private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
68  
69    /** Spring security's user details manager. */
70    private UserDetailsService userDetailsService = null;
71  
72    /** User directory service. */
73    private UserDirectoryService userDirectoryService = null;
74  
75    /** User reference provider. */
76    private UserReferenceProvider userReferenceProvider = null;
77  
78    /** Security service. */
79    private SecurityService securityService = null;
80  
81    /** JWKS URL to use for JWT validation (asymmetric algorithms). */
82    private String jwksUrl = null;
83  
84    /** The default time to live of cached JWK sets, in milliseconds. */
85    private int jwksTimeToLive = 1000 * 60 * 60;
86  
87    /** The default refresh timeout of cached JWK sets, in milliseconds. */
88    private int jwksRefreshTimeout = 1000 * 60;
89  
90    /** Secret to use for JWT validation (symmetric algorithms). */
91    private String secret = null;
92  
93    /** Allowed algorithms with which a valid JWT may be signed ('alg' claim). */
94    private List<String> expectedAlgorithms = null;
95  
96    /** Constraints that the claims of a valid JWT must fulfill. */
97    private List<String> claimConstraints = null;
98  
99    /** Mapping used to extract the username from the JWT. */
100   private String usernameMapping = null;
101 
102   /** Mapping used to extract the name from the JWT. */
103   private String nameMapping = null;
104 
105   /** Mapping used to extract the email from the JWT. */
106   private String emailMapping = null;
107 
108   /** Whether the built-in schema role mapping should be used. */
109   private boolean ocStandardRoleMappings = true;
110 
111   /** Mapping used to extract roles from the JWT. */
112   private List<String> roleMappings = null;
113 
114   /** Mapping used to extract roles from the JWT. */
115   private JWKSetProvider jwkProvider;
116 
117   /** Size of the JWT cache. */
118   private int jwtCacheSize = 500;
119 
120   /** Number of minutes validated JWTs will be cached before re-validating them. */
121   private int jwtCacheExpiresIn = 60;
122 
123   /** Cache for validated JWTs. */
124   private Cache<String, CachedJWT> cache;
125 
126   @Override
127   public void afterPropertiesSet() {
128     Assert.notNull(userDetailsService, "A UserDetailsService must be set");
129     Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
130     Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
131     Assert.notNull(securityService, "A SecurityService must be set");
132     Assert.isTrue(StringUtils.isNotBlank(jwksUrl) ^ StringUtils.isNotBlank(secret),
133         "Either a JWKS URL or a secret must be set");
134     Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
135     Assert.notEmpty(claimConstraints, "Claim constraints must be set");
136     Assert.notNull(usernameMapping, "User name mapping must be set");
137     Assert.notNull(nameMapping, "Name mapping must be set");
138     Assert.notNull(emailMapping, "Email mapping must be set");
139     Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
140             "Role mappings must be set if ocStandardRoleMappings is false");
141 
142     if (jwksUrl != null) {
143       jwkProvider = new JWKSetProvider(jwksUrl, jwksTimeToLive, jwksRefreshTimeout);
144     }
145     userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
146     cache = CacheBuilder.newBuilder()
147         .maximumSize(jwtCacheSize)
148         .expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
149         .build();
150   }
151 
152   @Override
153   public String handleToken(String token) {
154     try {
155       String signature = extractSignature(token);
156       CachedJWT cachedJwt = cache.getIfPresent(signature);
157 
158       if (cachedJwt == null) {
159         // JWT hasn't been cached before, so validate all claims
160         SignedJWT jwt = decodeAndValidate(token);
161         String username = extractUsername(jwt);
162 
163         try {
164           if (userDetailsService.loadUserByUsername(username) != null) {
165             existingUserLogin(username, jwt);
166           }
167         } catch (UsernameNotFoundException e) {
168           newUserLogin(username, jwt);
169         }
170 
171         userDirectoryService.invalidate(username);
172         cache.put(jwt.getSignature().toString(), new CachedJWT(jwt, username));
173         return username;
174       } else {
175         // JWT has been cached before, so only check if it has expired
176         if (cachedJwt.hasExpired()) {
177           cache.invalidate(signature);
178           throw new JOSEException("JWT token is not valid anymore");
179         }
180         logger.debug("Using decoded and validated JWT from cache");
181         return cachedJwt.getUsername();
182       }
183     } catch (ParseException | JOSEException exception) {
184       logger.debug(exception.getMessage());
185     }
186 
187     return null;
188   }
189 
190   /**
191    * Decodes and validates a JWT.
192    *
193    * @param token The JWT string.
194    * @return The decoded JWT.
195    * @throws JOSEException If the JWT fails to be validated.
196    */
197   private SignedJWT decodeAndValidate(String token) throws ParseException, JOSEException {
198     SignedJWT jwt;
199 
200     if (jwksUrl != null) {
201       jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
202     } else {
203       jwt = JWTVerifier.verify(token, secret, claimConstraints);
204     }
205 
206     if (!expectedAlgorithms.contains(jwt.getHeader().getAlgorithm().getName())) {
207       throw new JOSEException(
208           "JWT token was signed with an unexpected algorithm '" + jwt.getHeader().getAlgorithm() + "'"
209       );
210     }
211 
212     return jwt;
213   }
214 
215   /**
216    * Extracts the signature from a JWT.
217    *
218    * @param token The JWT string.
219    * @return The JWT's signature.
220    */
221   private String extractSignature(String token) throws JOSEException {
222     String[] parts = token.split("\\.");
223     if (parts.length != 3) {
224       throw new JOSEException("Given token is not in a valid JWT format");
225     }
226     return parts[2];
227   }
228 
229   /**
230    * Extracts the username from a decoded and validated JWT.
231    *
232    * @param jwt The decoded JWT.
233    * @return The username.
234    */
235   private String extractUsername(SignedJWT jwt) throws ParseException {
236     String username = evaluateMapping(jwt, usernameMapping);
237     Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
238     return username;
239   }
240 
241   /**
242    * Extracts the name from a decoded and validated JWT.
243    *
244    * @param jwt The decoded JWT.
245    * @return The name.
246    */
247   private String extractName(SignedJWT jwt) throws ParseException {
248     String name = evaluateMapping(jwt, nameMapping);
249     Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
250     return name;
251   }
252 
253   /**
254    * Extracts the email from a decoded and validated JWT.
255    *
256    * @param jwt The decoded JWT.
257    * @return The email.
258    */
259   private String extractEmail(SignedJWT jwt) throws ParseException {
260     String email = evaluateMapping(jwt, emailMapping);
261     Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
262     return email;
263   }
264 
265   /**
266    * Extracts the roles from a decoded and validated JWT.
267    *
268    * @param jwt The decoded JWT.
269    * @return The roles.
270    */
271   private Set<JpaRole> extractRoles(SignedJWT jwt) throws ParseException {
272     JpaOrganization organization = fromOrganization(securityService.getOrganization());
273     Set<JpaRole> roles = new HashSet<>();
274     Consumer<String> addRole = (String role) -> {
275       if (StringUtils.isNotBlank(role)) {
276         roles.add(new JpaRole(role, organization));
277       }
278     };
279 
280     // Evaluate the standard role mapping if specified
281     if (ocStandardRoleMappings) {
282       // Read `role` claim
283       try {
284         var rolesClaim = jwt.getJWTClaimsSet().getStringArrayClaim("roles");
285         if (rolesClaim != null) {
286           for (String r : rolesClaim) {
287             addRole.accept(r);
288           }
289         }
290       } catch (ParseException e) {
291         logger.debug("claim 'roles' is not an array of strings, ignoring");
292       }
293 
294       // Read `oc` claim
295       try {
296         var ocClaim = jwt.getJWTClaimsSet().getJSONObjectClaim("oc");
297         if (ocClaim != null) {
298           for (var entry : ocClaim.entrySet()) {
299             var key = entry.getKey();
300             var parts = key.split(":", 2);
301             if (parts.length != 2) {
302               logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
303               continue;
304             }
305             var type = parts[0];
306             var id = parts[1];
307 
308             try {
309               for (var actionObj : (List<?>) entry.getValue()) {
310                 var action = (String) actionObj;
311                 if (action.isBlank()) {
312                   continue;
313                 }
314 
315                 if (type.equals("e")) {
316                   addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
317                 } else {
318                   logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
319                 }
320               }
321             } catch (ClassCastException e) {
322               logger.debug("value in 'oc' claim is not a string array -> ignoring");
323               continue;
324             }
325           }
326         }
327       } catch (ParseException e) {
328         logger.debug("claim 'oc' is not an array of strings, ignoring");
329       }
330     }
331 
332     for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
333       ExpressionParser parser = new SpelExpressionParser();
334       Expression exp = parser.parseExpression(mapping);
335       StandardEvaluationContext ctx = new StandardEvaluationContext();
336       ctx.addPropertyAccessor(new MapAccessor());
337       Object value = exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims());
338       if (value != null) {
339         // We allow the expression to either return a string directly, or a list/array of strings.
340         if (value instanceof String) {
341           addRole.accept((String) value);
342         } else if (value.getClass().isArray()) {
343           for (var role : (Object[]) value) {
344             addRole.accept((String) role);
345           }
346         } else {
347           for (var role : (List<?>) value) {
348             addRole.accept((String) role);
349           }
350         }
351       }
352     }
353     Assert.notEmpty(roles, "No roles could be extracted");
354     return roles;
355   }
356 
357   /**
358    * Evaluates a mapping given in SpEL on a decoded JWT.
359    *
360    * @param jwt The decoded JWT.
361    * @param mapping The mapping.
362    *
363    * @return The string evaluated from the mapping.
364    */
365   private String evaluateMapping(SignedJWT jwt, String mapping) throws ParseException {
366     ExpressionParser parser = new SpelExpressionParser();
367     Expression exp = parser.parseExpression(mapping);
368     StandardEvaluationContext ctx = new StandardEvaluationContext();
369     ctx.addPropertyAccessor(new MapAccessor());
370     return exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims(), String.class);
371   }
372 
373   /**
374    * Handles a new user login.
375    *
376    * @param username The username.
377    * @param jwt The decoded JWT.
378    */
379   public void newUserLogin(String username, SignedJWT jwt) throws ParseException {
380     // Create a new user reference
381     JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
382         new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
383 
384     logger.debug("JWT user '{}' logged in for the first time", username);
385     userReferenceProvider.addUserReference(userReference, MECH_JWT);
386   }
387 
388   /**
389    * Handles an existing user login.
390    *
391    * @param username The username.
392    * @param jwt The decoded JWT.
393    */
394   public void existingUserLogin(String username, SignedJWT jwt) throws ParseException {
395     Organization organization = securityService.getOrganization();
396 
397     // Load the user reference
398     JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
399     if (userReference == null) {
400       throw new UsernameNotFoundException("User reference '" + username + "' was not found");
401     }
402 
403     // Update the reference
404     userReference.setName(extractName(jwt));
405     userReference.setEmail(extractEmail(jwt));
406     userReference.setLastLogin(new Date());
407     userReference.setRoles(extractRoles(jwt));
408 
409     logger.debug("JWT user '{}' logged in", username);
410     userReferenceProvider.updateUserReference(userReference);
411   }
412 
413   /**
414    * Converts a {@link Organization} object into a {@link JpaOrganization} object.
415    *
416    * @param org The {@link Organization} object.
417    * @return The corresponding {@link JpaOrganization} object.
418    */
419   private JpaOrganization fromOrganization(Organization org) {
420     if (org instanceof JpaOrganization) {
421       return (JpaOrganization) org;
422     }
423 
424     return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
425         org.getProperties());
426   }
427 
428   /**
429    * Setter for the user details service.
430    *
431    * @param userDetailsService The user details service.
432    */
433   @Reference
434   public void setUserDetailsService(UserDetailsService userDetailsService) {
435     this.userDetailsService = userDetailsService;
436   }
437 
438   /**
439    * Setter for the user directory service.
440    *
441    * @param userDirectoryService The user directory service.
442    */
443   @Reference
444   public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
445     this.userDirectoryService = userDirectoryService;
446   }
447 
448   /**
449    * Setter for the security service.
450    *
451    * @param securityService The security service.
452    */
453   @Reference
454   public void setSecurityService(SecurityService securityService) {
455     this.securityService = securityService;
456   }
457 
458   /**
459    * Setter for the user reference provider.
460    *
461    * @param userReferenceProvider The user reference provider.
462    */
463   @Reference
464   public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
465     this.userReferenceProvider = userReferenceProvider;
466   }
467 
468   /**
469    * Setter for the JWKS URL.
470    *
471    * @param jwksUrl The JWKS URL.
472    */
473   public void setJwksUrl(String jwksUrl) {
474     this.jwksUrl = jwksUrl;
475   }
476 
477   /**
478    * Setter for the JWKS ttl.
479    *
480    * @param jwksTimeToLive The default time to live of cached JWK sets, in milliseconds
481    */
482   public void setJwksTimeToLive(int jwksTimeToLive) {
483     this.jwksTimeToLive = jwksTimeToLive;
484   }
485 
486   /**
487    * Setter for the secret used for JWT validation.
488    *
489    * @param secret The secret.
490    */
491   public void setSecret(String secret) {
492     this.secret = secret;
493   }
494 
495   /**
496    * Setter for the expected algorithms.
497    *
498    * @param expectedAlgorithms The expected algorithms.
499    */
500   public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
501     this.expectedAlgorithms = expectedAlgorithms;
502   }
503 
504   /**
505    * Setter for the claim constraints.
506    *
507    * @param claimConstraints The claim constraints.
508    */
509   public void setClaimConstraints(List<String> claimConstraints) {
510     this.claimConstraints = claimConstraints;
511   }
512 
513   /**
514    * Setter for the username mapping.
515    * @param usernameMapping The username mapping.
516    */
517   public void setUsernameMapping(String usernameMapping) {
518     this.usernameMapping = usernameMapping;
519   }
520 
521   /**
522    * Setter for the name mapping.
523    *
524    * @param nameMapping The name mapping.
525    */
526   public void setNameMapping(String nameMapping) {
527     this.nameMapping = nameMapping;
528   }
529 
530   /**
531    * Setter for the email mapping.
532    * @param emailMapping The email mapping.
533    */
534   public void setEmailMapping(String emailMapping) {
535     this.emailMapping = emailMapping;
536   }
537 
538   public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
539     this.ocStandardRoleMappings = ocStandardRoleMappings;
540   }
541 
542   /**
543    * Setter for the role mappings.
544    *
545    * @param roleMappings The role mappings.
546    */
547   public void setRoleMappings(List<String> roleMappings) {
548     this.roleMappings = roleMappings;
549   }
550 
551   /**
552    * Setter for the JWT cache size.
553    *
554    * @param jwtCacheSize The JWT cache size.
555    */
556   public void setJwtCacheSize(int jwtCacheSize) {
557     this.jwtCacheSize = jwtCacheSize;
558   }
559 
560   /**
561    * Setter for the JWT cache expiration.
562    *
563    * @param jwtCacheExpiresIn The number of minutes after which a cached JWT expires.
564    */
565   public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
566     this.jwtCacheExpiresIn = jwtCacheExpiresIn;
567   }
568 
569 }