Skip to content

Commit

Permalink
Fix caching of remote JWK sets
Browse files Browse the repository at this point in the history
  • Loading branch information
enricovianello committed Mar 19, 2024
1 parent e8dac03 commit bb5aa0a
Show file tree
Hide file tree
Showing 10 changed files with 582 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ public boolean isEnforceAudienceChecks() {
@Min(value = 1, message = "The refresh period must be a positive integer")
int refreshPeriodMinutes = 60;

@Min(value = 1, message = "The refresh timeout must be a positive integer")
int refreshTimeoutSeconds = 30;

public List<AuthorizationServer> getIssuers() {
return issuers;
}
Expand All @@ -112,6 +115,14 @@ public void setRefreshPeriodMinutes(int refreshPeriodMinutes) {
this.refreshPeriodMinutes = refreshPeriodMinutes;
}

public int getRefreshTimeoutSeconds() {
return refreshTimeoutSeconds;
}

public void setRefreshTimeoutSeconds(int refreshTimeoutSeconds) {
this.refreshTimeoutSeconds = refreshTimeoutSeconds;
}

public void setEnableOidc(boolean enableOidc) {
this.enableOidc = enableOidc;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,45 @@
import static java.lang.String.format;

import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;

import org.italiangrid.storm.webdav.config.OAuthProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

import com.nimbusds.jose.RemoteKeySourceException;

@Service
public class DefaultOidcConfigurationFetcher implements OidcConfigurationFetcher {

public static final String WELL_KNOWN_FRAGMENT = "/.well-known/openid-configuration";
public static final String ISSUER_MISMATCH_ERROR_TEMPLATE =
"Issuer in metadata '%s' does not match with requested issuer '%s'";
public static final String NO_JWKS_URI_ERROR_TEMPLATE =
public static final String NO_JWKS_URI_ERROR_TEMPLATE =
"No jwks_uri found in metadata for issuer '%s'";

private static final MediaType APPLICATION_JWK_SET_JSON =
new MediaType("application", "jwk-set+json");

public static final Logger LOG = LoggerFactory.getLogger(DefaultOidcConfigurationFetcher.class);

final RestTemplateBuilder restBuilder;
final RestTemplate restTemplate;

@Autowired
public DefaultOidcConfigurationFetcher(RestTemplateBuilder restBuilder) {
this.restBuilder = restBuilder;
public DefaultOidcConfigurationFetcher(RestTemplateBuilder restBuilder,
OAuthProperties oAuthProperties) {
final Duration TIMEOUT = Duration.ofSeconds(oAuthProperties.getRefreshTimeoutSeconds());
this.restTemplate = restBuilder.setConnectTimeout(TIMEOUT).setReadTimeout(TIMEOUT).build();
}

private void metadataChecks(String issuer, Map<String, Object> oidcConfiguration) {
Expand All @@ -59,40 +70,63 @@ private void metadataChecks(String issuer, Map<String, Object> oidcConfiguration
throw new OidcConfigurationResolutionError(
format(ISSUER_MISMATCH_ERROR_TEMPLATE, metadataIssuer, issuer));
}

if (!oidcConfiguration.containsKey("jwks_uri")) {
throw new OidcConfigurationResolutionError(format(NO_JWKS_URI_ERROR_TEMPLATE,issuer));
throw new OidcConfigurationResolutionError(format(NO_JWKS_URI_ERROR_TEMPLATE, issuer));
}
}

@Override
public Map<String, Object> loadConfigurationForIssuer(String issuer) {
LOG.debug("Fetching OpenID configuration for {}", issuer);

ParameterizedTypeReference<Map<String, Object>> typeReference =
new ParameterizedTypeReference<Map<String, Object>>() {};

RestTemplate rest = restBuilder.build();

URI uri = UriComponentsBuilder.fromUriString(issuer + WELL_KNOWN_FRAGMENT).build().toUri();

ResponseEntity<Map<String, Object>> response = null;
try {

RequestEntity<Void> request = RequestEntity.get(uri).build();
Map<String, Object> conf = rest.exchange(request, typeReference).getBody();
metadataChecks(issuer, conf);
return conf;
response = restTemplate.exchange(request, typeReference);
if (response.getStatusCodeValue() != 200) {
throw new RuntimeException(
format("Received status code: %s", response.getStatusCodeValue()));
}
metadataChecks(issuer, response.getBody());
return response.getBody();
} catch (RuntimeException e) {
final String errorMsg =
format("Unable to resolve OpenID configuration for issuer '%s' from '%s': %s", issuer,
uri, e.getMessage());

final String errorMsg = format("Unable to resolve OpenID configuration from '%s'", uri);
if (LOG.isDebugEnabled()) {
LOG.error(errorMsg, e);
LOG.error("{}: {}", errorMsg, e.getMessage());
}

throw new OidcConfigurationResolutionError(errorMsg, e);
}
}

@Override
public String loadJWKSourceForURL(URI uri) throws RemoteKeySourceException {

LOG.debug("Fetching JWK from {}", uri);

HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = null;
try {
RequestEntity<Void> request = RequestEntity.get(uri).headers(headers).build();
response = restTemplate.exchange(request, String.class);
if (response.getStatusCodeValue() != 200) {
throw new RuntimeException(
format("Received status code: %s", response.getStatusCodeValue()));
}
} catch (RuntimeException e) {
final String errorMsg = format("Unable to get JWK from '%s'", uri);
if (LOG.isDebugEnabled()) {
LOG.error("{}: {}", errorMsg, e.getMessage());
}
throw new RemoteKeySourceException(errorMsg, e);
}

return response.getBody();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright (c) Istituto Nazionale di Fisica Nucleare, 2014-2023.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.util.concurrent.Callable;

import org.springframework.cache.support.AbstractValueAdaptingCache;
import org.springframework.lang.Nullable;

public class NoExpirationStringCache extends AbstractValueAdaptingCache {

private static final String NAME = "NoExpirationCache";
private final String value;

public NoExpirationStringCache(String value) {
super(false);
this.value = value;
}

@Override
public String getName() {
return NAME;
}

@Override
public Object getNativeCache() {
return this;
}

@Override
@Nullable
protected Object lookup(Object key) {
return value;
}

@Override
public void put(Object key, Object value) {
return;
}

@Override
public void evict(Object key) {
return;
}

@Override
public void clear() {
return;
}

@SuppressWarnings("unchecked")
@Override
public <T> T get(Object key, Callable<T> valueLoader) {
return (T) fromStoreValue(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.net.URI;
import java.util.Map;

import com.nimbusds.jose.RemoteKeySourceException;

public interface OidcConfigurationFetcher {

Map<String, Object> loadConfigurationForIssuer(String issuer);

String loadJWKSourceForURL(URI uri) throws RemoteKeySourceException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
Expand All @@ -28,8 +29,8 @@
import org.italiangrid.storm.webdav.oauth.validator.WlcgProfileValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.cache.Cache;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jwt.Jwt;
Expand All @@ -53,7 +54,6 @@ public class TrustedJwtDecoderCacheLoader extends CacheLoader<String, JwtDecoder
private final ExecutorService executor;
private final OAuthProperties oauthProperties;

@Autowired
public TrustedJwtDecoderCacheLoader(ServiceConfigurationProperties properties,
OAuthProperties oauthProperties, RestTemplateBuilder builder,
OidcConfigurationFetcher fetcher, ExecutorService executor) {
Expand All @@ -76,9 +76,14 @@ public JwtDecoder load(String issuer) throws Exception {
.orElseThrow(unknownTokenIssuer(issuer));

Map<String, Object> oidcConfiguration = fetcher.loadConfigurationForIssuer(issuer);
URI jwksUri = URI.create(oidcConfiguration.get("jwks_uri").toString());
Cache noExpirationCache =
new NoExpirationStringCache(fetcher.loadJWKSourceForURL(jwksUri).toString());

NimbusJwtDecoder decoder =
NimbusJwtDecoder.withJwkSetUri((oidcConfiguration.get("jwks_uri").toString())).build();
NimbusJwtDecoder.withJwkSetUri((oidcConfiguration.get("jwks_uri").toString()))
.cache(noExpirationCache)
.build();

OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
OAuth2TokenValidator<Jwt> wlcgProfileValidator = new WlcgProfileValidator();
Expand Down
3 changes: 2 additions & 1 deletion src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ tpc:
enable-expect-continue-threshold: ${STORM_WEBDAV_TPC_ENABLE_EXPECT_CONTINUE_THRESHOLD:1048576}

oauth:
refresh-period-minutes: 60
refresh-period-minutes: ${STORM_WEBDAV_OAUTH_REFRESH_PERIOD_MINUTES:60}
refresh-timeout-seconds: ${STORM_WEBDAV_OAUTH_REFRESH_TIMEOUT_SECONDS:30}
issuers:

storm:
Expand Down
Loading

0 comments on commit bb5aa0a

Please sign in to comment.