1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.opencastproject.security.jwt;
23
24 import org.opencastproject.security.api.Organization;
25 import org.opencastproject.security.api.SecurityService;
26 import org.opencastproject.security.api.UserDirectoryService;
27 import org.opencastproject.security.impl.jpa.JpaOrganization;
28 import org.opencastproject.security.impl.jpa.JpaRole;
29 import org.opencastproject.security.impl.jpa.JpaUserReference;
30 import org.opencastproject.security.util.SecurityUtil;
31 import org.opencastproject.userdirectory.api.UserReferenceProvider;
32
33 import com.google.common.cache.Cache;
34 import com.google.common.cache.CacheBuilder;
35 import com.nimbusds.jose.JOSEException;
36 import com.nimbusds.jwt.SignedJWT;
37
38 import org.apache.commons.lang3.StringUtils;
39 import org.osgi.service.component.annotations.Reference;
40 import org.slf4j.Logger;
41 import org.slf4j.LoggerFactory;
42 import org.springframework.beans.factory.InitializingBean;
43 import org.springframework.context.expression.MapAccessor;
44 import org.springframework.expression.Expression;
45 import org.springframework.expression.ExpressionParser;
46 import org.springframework.expression.spel.standard.SpelExpressionParser;
47 import org.springframework.expression.spel.support.StandardEvaluationContext;
48 import org.springframework.security.core.userdetails.UserDetailsService;
49 import org.springframework.security.core.userdetails.UsernameNotFoundException;
50 import org.springframework.util.Assert;
51
52 import java.text.ParseException;
53 import java.util.ArrayList;
54 import java.util.Date;
55 import java.util.HashSet;
56 import java.util.List;
57 import java.util.Objects;
58 import java.util.Set;
59 import java.util.concurrent.TimeUnit;
60 import java.util.function.Consumer;
61
62
63
64
65 public class DynamicLoginHandler implements InitializingBean, JWTLoginHandler {
66
67
68 private static final Logger logger = LoggerFactory.getLogger(DynamicLoginHandler.class);
69
70
71 private UserDetailsService userDetailsService = null;
72
73
74 private UserDirectoryService userDirectoryService = null;
75
76
77 private UserReferenceProvider userReferenceProvider = null;
78
79
80 private SecurityService securityService = null;
81
82
83 private String jwksUrl = null;
84
85
86 private int jwksTimeToLive = 1000 * 60 * 60;
87
88
89 private int jwksRefreshTimeout = 1000 * 60;
90
91
92 private String secret = null;
93
94
95 private List<String> expectedAlgorithms = null;
96
97
98 private List<String> claimConstraints = null;
99
100
101 private String usernameMapping = null;
102
103
104 private String nameMapping = null;
105
106
107 private String emailMapping = null;
108
109
110 private boolean ocStandardRoleMappings = true;
111
112
113 private List<String> roleMappings = null;
114
115
116 private JWKSetProvider jwkProvider;
117
118
119 private int jwtCacheSize = 500;
120
121
122 private int jwtCacheExpiresIn = 60;
123
124
125 private Cache<String, CachedJWT> cache;
126
127 @Override
128 public void afterPropertiesSet() {
129 Assert.notNull(userDetailsService, "A UserDetailsService must be set");
130 Assert.notNull(userDirectoryService, "A UserDirectoryService must be set");
131 Assert.notNull(userReferenceProvider, "A UserReferenceProvider must be set");
132 Assert.notNull(securityService, "A SecurityService must be set");
133 Assert.isTrue(!(StringUtils.isNotBlank(jwksUrl) && StringUtils.isNotBlank(secret)),
134 "A JWKS URL and a secret cannot be set at the same time");
135 Assert.notEmpty(expectedAlgorithms, "Expected algorithms must be set");
136 Assert.notNull(claimConstraints, "Claim constraints must be set");
137 Assert.notNull(usernameMapping, "User name mapping must be set");
138 Assert.notNull(nameMapping, "Name mapping must be set");
139 Assert.notNull(emailMapping, "Email mapping must be set");
140 Assert.isTrue(roleMappings != null || ocStandardRoleMappings,
141 "Role mappings must be set if ocStandardRoleMappings is false");
142
143 if (StringUtils.isBlank(jwksUrl) && StringUtils.isBlank(secret)) {
144 logger.info("JWT login handler disabled as neither 'jwksUrl' nor 'secret' are set");
145 }
146
147 if (jwksUrl != null) {
148 jwkProvider = new JWKSetProvider(jwksUrl, jwksTimeToLive, jwksRefreshTimeout);
149 }
150 userReferenceProvider.setRoleProvider(new JWTRoleProvider(securityService, userReferenceProvider));
151 cache = CacheBuilder.newBuilder()
152 .maximumSize(jwtCacheSize)
153 .expireAfterWrite(jwtCacheExpiresIn, TimeUnit.MINUTES)
154 .build();
155 }
156
157 @Override
158 public String handleToken(String token) {
159 if (jwkProvider == null && secret == null) {
160 logger.debug("neither jwksURL nor secret set: ignoring JWT");
161 return null;
162 }
163
164 try {
165 String signature = extractSignature(token);
166 CachedJWT cachedJwt = cache.getIfPresent(signature);
167
168 if (cachedJwt == null) {
169 logger.debug("JWT cache miss for signature {}, validating token", abbreviateSignature(signature));
170
171
172 SignedJWT jwt = decodeAndValidate(token);
173 String username = extractUsername(jwt);
174 boolean invalidateUserDirectory = true;
175
176 try {
177 if (userDetailsService.loadUserByUsername(username) != null) {
178 invalidateUserDirectory = existingUserLogin(username, jwt);
179 }
180 } catch (UsernameNotFoundException e) {
181 newUserLogin(username, jwt);
182 }
183
184 if (invalidateUserDirectory) {
185 logger.debug("Invalidating user directory caches for JWT user '{}'", username);
186 userDirectoryService.invalidate(username);
187 } else {
188 logger.debug("Skipping user directory invalidation for JWT user '{}' since effective user data is unchanged",
189 username);
190 }
191 cache.put(jwt.getSignature().toString(), new CachedJWT(jwt, username));
192 logger.debug("Cached validated JWT signature {} for user '{}'", abbreviateSignature(signature), username);
193 return username;
194 } else {
195
196 if (cachedJwt.hasExpired()) {
197 cache.invalidate(signature);
198 throw new JOSEException("JWT token is not valid anymore");
199 }
200 logger.debug("Using validated JWT from cache for user '{}' and signature {}", cachedJwt.getUsername(),
201 abbreviateSignature(signature));
202 return cachedJwt.getUsername();
203 }
204 } catch (ParseException | JOSEException exception) {
205 logger.debug(exception.getMessage());
206 }
207
208 return null;
209 }
210
211
212
213
214
215
216
217
218 private SignedJWT decodeAndValidate(String token) throws ParseException, JOSEException {
219 SignedJWT jwt;
220
221 if (jwksUrl != null) {
222 jwt = JWTVerifier.verify(token, jwkProvider, claimConstraints);
223 } else {
224 jwt = JWTVerifier.verify(token, secret, claimConstraints);
225 }
226
227 if (!expectedAlgorithms.contains(jwt.getHeader().getAlgorithm().getName())) {
228 throw new JOSEException(
229 "JWT token was signed with an unexpected algorithm '" + jwt.getHeader().getAlgorithm() + "'"
230 );
231 }
232
233 return jwt;
234 }
235
236
237
238
239
240
241
242 private String extractSignature(String token) throws JOSEException {
243 String[] parts = token.split("\\.");
244 if (parts.length != 3) {
245 throw new JOSEException("Given token is not in a valid JWT format");
246 }
247 return parts[2];
248 }
249
250
251
252
253
254
255
256 private String extractUsername(SignedJWT jwt) throws ParseException {
257 String username = evaluateMapping(jwt, usernameMapping);
258 Assert.isTrue(StringUtils.isNotBlank(username), "Extracted username is blank");
259 return username;
260 }
261
262
263
264
265
266
267
268 private String extractName(SignedJWT jwt) throws ParseException {
269 String name = evaluateMapping(jwt, nameMapping);
270 Assert.isTrue(StringUtils.isNotBlank(name), "Extracted name is blank");
271 return name;
272 }
273
274
275
276
277
278
279
280 private String extractEmail(SignedJWT jwt) throws ParseException {
281 String email = evaluateMapping(jwt, emailMapping);
282 Assert.isTrue(StringUtils.isNotBlank(email), "Extracted email is blank");
283 return email;
284 }
285
286
287
288
289
290
291
292 private Set<JpaRole> extractRoles(SignedJWT jwt) throws ParseException {
293 JpaOrganization organization = fromOrganization(securityService.getOrganization());
294 Set<JpaRole> roles = new HashSet<>();
295 Consumer<String> addRole = (String role) -> {
296 if (StringUtils.isNotBlank(role)) {
297 roles.add(new JpaRole(role, organization));
298 }
299 };
300
301
302 if (ocStandardRoleMappings) {
303
304 try {
305 var rolesClaim = jwt.getJWTClaimsSet().getStringArrayClaim("roles");
306 if (rolesClaim != null) {
307 for (String r : rolesClaim) {
308 addRole.accept(r);
309 }
310 }
311 } catch (ParseException e) {
312 logger.debug("claim 'roles' is not an array of strings, ignoring");
313 }
314
315
316 try {
317 var ocClaim = jwt.getJWTClaimsSet().getJSONObjectClaim("oc");
318 if (ocClaim != null) {
319 for (var entry : ocClaim.entrySet()) {
320 var key = entry.getKey();
321 var parts = key.split(":", 2);
322 if (parts.length != 2) {
323 logger.debug("key in 'oc' claim does not start with 'x:' -> ignoring");
324 continue;
325 }
326 var type = parts[0];
327 var id = parts[1];
328
329 try {
330 for (var actionObj : (List<?>) entry.getValue()) {
331 var action = (String) actionObj;
332 if (action.isBlank()) {
333 continue;
334 }
335
336 if (type.equals("e")) {
337 addRole.accept(SecurityUtil.getEpisodeRoleId(id, action));
338 } else {
339 logger.debug("in 'oc' claim: granting access to item type '{}' is not yet supported", type);
340 }
341 }
342 } catch (ClassCastException e) {
343 logger.debug("value in 'oc' claim is not a string array -> ignoring");
344 continue;
345 }
346 }
347 }
348 } catch (ParseException e) {
349 logger.debug("claim 'oc' is not an array of strings, ignoring");
350 }
351 }
352
353 for (String mapping : (roleMappings == null ? new ArrayList<String>() : roleMappings)) {
354 ExpressionParser parser = new SpelExpressionParser();
355 Expression exp = parser.parseExpression(mapping);
356 StandardEvaluationContext ctx = new StandardEvaluationContext();
357 ctx.addPropertyAccessor(new MapAccessor());
358 Object value = exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims());
359 if (value != null) {
360
361 if (value instanceof String) {
362 addRole.accept((String) value);
363 } else if (value.getClass().isArray()) {
364 for (var role : (Object[]) value) {
365 addRole.accept((String) role);
366 }
367 } else {
368 for (var role : (List<?>) value) {
369 addRole.accept((String) role);
370 }
371 }
372 }
373 }
374 Assert.notEmpty(roles, "No roles could be extracted");
375 return roles;
376 }
377
378
379
380
381
382
383
384
385
386 private String evaluateMapping(SignedJWT jwt, String mapping) throws ParseException {
387 ExpressionParser parser = new SpelExpressionParser();
388 Expression exp = parser.parseExpression(mapping);
389 StandardEvaluationContext ctx = new StandardEvaluationContext();
390 ctx.addPropertyAccessor(new MapAccessor());
391 return exp.getValue(ctx, jwt.getJWTClaimsSet().getClaims(), String.class);
392 }
393
394
395
396
397
398
399
400 public void newUserLogin(String username, SignedJWT jwt) throws ParseException {
401
402 JpaUserReference userReference = new JpaUserReference(username, extractName(jwt), extractEmail(jwt), MECH_JWT,
403 new Date(), fromOrganization(securityService.getOrganization()), extractRoles(jwt));
404
405 logger.debug("JWT user '{}' logged in for the first time, creating a new user reference", username);
406 userReferenceProvider.addUserReference(userReference, MECH_JWT);
407 }
408
409
410
411
412
413
414
415 public boolean existingUserLogin(String username, SignedJWT jwt) throws ParseException {
416 Organization organization = securityService.getOrganization();
417 String name = extractName(jwt);
418 String email = extractEmail(jwt);
419 Set<JpaRole> roles = extractRoles(jwt);
420
421
422 JpaUserReference userReference = userReferenceProvider.findUserReference(username, organization.getId());
423 if (userReference == null) {
424 throw new UsernameNotFoundException("User reference '" + username + "' was not found");
425 }
426
427 boolean changed = userReferenceChanged(userReference, name, email, roles);
428 String changedFields = changed ? describeUserReferenceChanges(userReference, name, email, roles) : "";
429
430 userReference.setName(name);
431 userReference.setEmail(email);
432 userReference.setLastLogin(new Date());
433 userReference.setRoles(roles);
434
435 if (changed) {
436 logger.debug("JWT user '{}' logged in with updated user data ({}), refreshing merged user caches", username,
437 changedFields);
438 } else {
439 logger.debug("JWT user '{}' logged in with unchanged name, email, and roles", username);
440 }
441 userReferenceProvider.updateUserReference(userReference);
442 return changed;
443 }
444
445 private boolean userReferenceChanged(JpaUserReference userReference, String name, String email, Set<JpaRole> roles) {
446 return !Objects.equals(userReference.getName(), name)
447 || !Objects.equals(userReference.getEmail(), email)
448 || !Objects.equals(userReference.getRoles(), roles);
449 }
450
451 private String describeUserReferenceChanges(JpaUserReference userReference, String name, String email,
452 Set<JpaRole> roles) {
453 List<String> changedFields = new ArrayList<>();
454 if (!Objects.equals(userReference.getName(), name)) {
455 changedFields.add("name");
456 }
457 if (!Objects.equals(userReference.getEmail(), email)) {
458 changedFields.add("email");
459 }
460 if (!Objects.equals(userReference.getRoles(), roles)) {
461 changedFields.add("roles");
462 }
463 return String.join(", ", changedFields);
464 }
465
466 private String abbreviateSignature(String signature) {
467 if (signature == null || signature.length() <= 12) {
468 return signature;
469 }
470 return signature.substring(0, 12) + "...";
471 }
472
473
474
475
476
477
478
479 private JpaOrganization fromOrganization(Organization org) {
480 if (org instanceof JpaOrganization) {
481 return (JpaOrganization) org;
482 }
483
484 return new JpaOrganization(org.getId(), org.getName(), org.getServers(), org.getAdminRole(), org.getAnonymousRole(),
485 org.getProperties());
486 }
487
488
489
490
491
492
493 @Reference
494 public void setUserDetailsService(UserDetailsService userDetailsService) {
495 this.userDetailsService = userDetailsService;
496 }
497
498
499
500
501
502
503 @Reference
504 public void setUserDirectoryService(UserDirectoryService userDirectoryService) {
505 this.userDirectoryService = userDirectoryService;
506 }
507
508
509
510
511
512
513 @Reference
514 public void setSecurityService(SecurityService securityService) {
515 this.securityService = securityService;
516 }
517
518
519
520
521
522
523 @Reference
524 public void setUserReferenceProvider(UserReferenceProvider userReferenceProvider) {
525 this.userReferenceProvider = userReferenceProvider;
526 }
527
528
529
530
531
532
533 public void setJwksUrl(String jwksUrl) {
534 this.jwksUrl = jwksUrl;
535 }
536
537
538
539
540
541
542 public void setJwksTimeToLive(int jwksTimeToLive) {
543 this.jwksTimeToLive = jwksTimeToLive;
544 }
545
546
547
548
549
550
551 public void setSecret(String secret) {
552 this.secret = secret;
553 }
554
555
556
557
558
559
560 public void setExpectedAlgorithms(List<String> expectedAlgorithms) {
561 this.expectedAlgorithms = expectedAlgorithms;
562 }
563
564
565
566
567
568
569 public void setClaimConstraints(List<String> claimConstraints) {
570 this.claimConstraints = claimConstraints;
571 }
572
573
574
575
576
577 public void setUsernameMapping(String usernameMapping) {
578 this.usernameMapping = usernameMapping;
579 }
580
581
582
583
584
585
586 public void setNameMapping(String nameMapping) {
587 this.nameMapping = nameMapping;
588 }
589
590
591
592
593
594 public void setEmailMapping(String emailMapping) {
595 this.emailMapping = emailMapping;
596 }
597
598 public void setOcStandardRoleMappings(boolean ocStandardRoleMappings) {
599 this.ocStandardRoleMappings = ocStandardRoleMappings;
600 }
601
602
603
604
605
606
607 public void setRoleMappings(List<String> roleMappings) {
608 this.roleMappings = roleMappings;
609 }
610
611
612
613
614
615
616 public void setJwtCacheSize(int jwtCacheSize) {
617 this.jwtCacheSize = jwtCacheSize;
618 }
619
620
621
622
623
624
625 public void setJwtCacheExpiresIn(int jwtCacheExpiresIn) {
626 this.jwtCacheExpiresIn = jwtCacheExpiresIn;
627 }
628
629 }