View Javadoc
1   /*
2    * Licensed to The Apereo Foundation under one or more contributor license
3    * agreements. See the NOTICE file distributed with this work for additional
4    * information regarding copyright ownership.
5    *
6    *
7    * The Apereo Foundation licenses this file to you under the Educational
8    * Community License, Version 2.0 (the "License"); you may not use this file
9    * except in compliance with the License. You may obtain a copy of the License
10   * at:
11   *
12   *   http://opensource.org/licenses/ecl2.txt
13   *
14   * Unless required by applicable law or agreed to in writing, software
15   * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
16   * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  See the
17   * License for the specific language governing permissions and limitations under
18   * the License.
19   *
20   */
21  
22  package org.opencastproject.security.jwt;
23  
24  import static com.google.common.base.Preconditions.checkArgument;
25  import static com.google.common.base.Strings.isNullOrEmpty;
26  
27  import com.auth0.jwk.Jwk;
28  import com.auth0.jwk.JwkException;
29  import com.auth0.jwk.SigningKeyNotFoundException;
30  import com.auth0.jwk.UrlJwkProvider;
31  import com.auth0.jwt.algorithms.Algorithm;
32  import com.auth0.jwt.interfaces.DecodedJWT;
33  import com.google.common.cache.Cache;
34  import com.google.common.cache.CacheBuilder;
35  
36  import org.slf4j.Logger;
37  import org.slf4j.LoggerFactory;
38  
39  import java.net.MalformedURLException;
40  import java.net.URI;
41  import java.net.URISyntaxException;
42  import java.net.URL;
43  import java.util.ArrayList;
44  import java.util.Collections;
45  import java.util.List;
46  import java.util.concurrent.ExecutionException;
47  import java.util.concurrent.TimeUnit;
48  
49  /**
50   * JWK provider that caches previously fetched JWKs in memory using a Google Guava cache.
51   */
52  public class GuavaCachedUrlJwkProvider extends UrlJwkProvider {
53  
54    /** Logging facility. */
55    private static final Logger logger = LoggerFactory.getLogger(GuavaCachedUrlJwkProvider.class);
56  
57    /** JWK cache. */
58    private final Cache<String, List<Jwk>> cache;
59  
60    /** Key for the JWK in the cache. */
61    private static final String KEY = "GET_ALL";
62  
63    /**
64     * Cons a new cached provider from a JWKs URL and a TTL.
65     *
66     * @param jwksUrl The URL where JWKs are published.
67     * @param expiresIn The amount of time the fetched JWKs will live in the cache.
68     * @param expiresUnit The unit of the expiresIn parameter.
69     */
70    public GuavaCachedUrlJwkProvider(String jwksUrl, long expiresIn, TimeUnit expiresUnit) {
71      super(urlFromString(jwksUrl));
72      this.cache = CacheBuilder.newBuilder().maximumSize(1).expireAfterWrite(expiresIn, expiresUnit).build();
73    }
74  
75    /**
76     * Converts a URL string into a {@link URL}.
77     *
78     * @param url The URL string.
79     * @return The {@link URL}.
80     */
81    private static URL urlFromString(String url) {
82      checkArgument(!isNullOrEmpty(url), "A URL is required");
83      try {
84        final URI uri = new URI(url).normalize();
85        return uri.toURL();
86      } catch (MalformedURLException | URISyntaxException e) {
87        throw new IllegalArgumentException("Invalid JWKS URI", e);
88      }
89    }
90  
91    @Override
92    public List<Jwk> getAll() {
93      return getAll(false);
94    }
95  
96    /**
97     * Getter for all JWKs.
98     *
99     * @param forceFetch Whether to force a re-fetch.
100    * @return The JWKs.
101    */
102   public List<Jwk> getAll(boolean forceFetch) {
103     try {
104       if (forceFetch) {
105         cache.invalidate(KEY);
106       }
107 
108       List<Jwk> jwks = cache.getIfPresent(KEY);
109       if (jwks == null) {
110         logger.debug("JWKS cache miss");
111         jwks = cache.get(KEY, super::getAll);
112       } else {
113         logger.debug("JWKS cache hit");
114       }
115 
116       return jwks;
117     } catch (ExecutionException e) {
118       logger.error("Error while loading from JWKS cache: " + e.getMessage());
119       return Collections.emptyList();
120     }
121   }
122 
123   /**
124    * Getter for all algorithms corresponding to the fetched JWKs.
125    *
126    * @param jwt The decoded JWT.
127    * @param forceFetch Whether to force a re-fetch.
128    * @return The algorithms.
129    * @throws JwkException If the algorithms cannot be constructed from the JWKs.
130    */
131   public List<Algorithm> getAlgorithms(DecodedJWT jwt, boolean forceFetch) throws JwkException {
132     List<Algorithm> algorithms = new ArrayList<>();
133 
134     for (Jwk jwk : getAll(forceFetch)) {
135       if (jwt.getKeyId() == null && jwt.getAlgorithm().equals(jwk.getAlgorithm())
136           || jwt.getKeyId() != null && jwt.getKeyId().equals(jwk.getId())) {
137         algorithms.add(AlgorithmBuilder.buildAlgorithm(jwk));
138       }
139     }
140 
141     if (algorithms.isEmpty()) {
142       throw new SigningKeyNotFoundException("No key found", null);
143     }
144 
145     return algorithms;
146   }
147 
148 }