1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package org.opencastproject.security.jwt;
23
24 import static com.google.common.base.Preconditions.checkArgument;
25 import static com.google.common.base.Strings.isNullOrEmpty;
26
27 import com.auth0.jwk.Jwk;
28 import com.auth0.jwk.JwkException;
29 import com.auth0.jwk.SigningKeyNotFoundException;
30 import com.auth0.jwk.UrlJwkProvider;
31 import com.auth0.jwt.algorithms.Algorithm;
32 import com.auth0.jwt.interfaces.DecodedJWT;
33 import com.google.common.cache.Cache;
34 import com.google.common.cache.CacheBuilder;
35
36 import org.slf4j.Logger;
37 import org.slf4j.LoggerFactory;
38
39 import java.net.MalformedURLException;
40 import java.net.URI;
41 import java.net.URISyntaxException;
42 import java.net.URL;
43 import java.util.ArrayList;
44 import java.util.Collections;
45 import java.util.List;
46 import java.util.concurrent.ExecutionException;
47 import java.util.concurrent.TimeUnit;
48
49
50
51
52 public class GuavaCachedUrlJwkProvider extends UrlJwkProvider {
53
54
55 private static final Logger logger = LoggerFactory.getLogger(GuavaCachedUrlJwkProvider.class);
56
57
58 private final Cache<String, List<Jwk>> cache;
59
60
61 private static final String KEY = "GET_ALL";
62
63
64
65
66
67
68
69
70 public GuavaCachedUrlJwkProvider(String jwksUrl, long expiresIn, TimeUnit expiresUnit) {
71 super(urlFromString(jwksUrl));
72 this.cache = CacheBuilder.newBuilder().maximumSize(1).expireAfterWrite(expiresIn, expiresUnit).build();
73 }
74
75
76
77
78
79
80
81 private static URL urlFromString(String url) {
82 checkArgument(!isNullOrEmpty(url), "A URL is required");
83 try {
84 final URI uri = new URI(url).normalize();
85 return uri.toURL();
86 } catch (MalformedURLException | URISyntaxException e) {
87 throw new IllegalArgumentException("Invalid JWKS URI", e);
88 }
89 }
90
91 @Override
92 public List<Jwk> getAll() {
93 return getAll(false);
94 }
95
96
97
98
99
100
101
102 public List<Jwk> getAll(boolean forceFetch) {
103 try {
104 if (forceFetch) {
105 cache.invalidate(KEY);
106 }
107
108 List<Jwk> jwks = cache.getIfPresent(KEY);
109 if (jwks == null) {
110 logger.debug("JWKS cache miss");
111 jwks = cache.get(KEY, super::getAll);
112 } else {
113 logger.debug("JWKS cache hit");
114 }
115
116 return jwks;
117 } catch (ExecutionException e) {
118 logger.error("Error while loading from JWKS cache: " + e.getMessage());
119 return Collections.emptyList();
120 }
121 }
122
123
124
125
126
127
128
129
130
131 public List<Algorithm> getAlgorithms(DecodedJWT jwt, boolean forceFetch) throws JwkException {
132 List<Algorithm> algorithms = new ArrayList<>();
133
134 for (Jwk jwk : getAll(forceFetch)) {
135 if (jwt.getKeyId() == null && jwt.getAlgorithm().equals(jwk.getAlgorithm())
136 || jwt.getKeyId() != null && jwt.getKeyId().equals(jwk.getId())) {
137 algorithms.add(AlgorithmBuilder.buildAlgorithm(jwk));
138 }
139 }
140
141 if (algorithms.isEmpty()) {
142 throw new SigningKeyNotFoundException("No key found", null);
143 }
144
145 return algorithms;
146 }
147
148 }