Skip to content

Commit

Permalink
Use a no-expiration-value cache for each cached JWTDecoder
Browse files Browse the repository at this point in the history
  • Loading branch information
enricovianello committed Mar 10, 2024
1 parent 702c695 commit e0d36df
Show file tree
Hide file tree
Showing 8 changed files with 137 additions and 11 deletions.
1 change: 0 additions & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
<owner-config.version>1.0.5.1</owner-config.version>

<spring-security-oauth2.version>2.3.3.RELEASE</spring-security-oauth2.version>
<nimbus-jose-jwt.version>6.0.2</nimbus-jose-jwt.version>
<mock-server.version>5.5.1</mock-server.version>
<bouncycastle.version>1.76</bouncycastle.version>
<voms-api-java.version>3.3.2</voms-api-java.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ public boolean isEnforceAudienceChecks() {
@Min(value = 1, message = "The refresh period must be a positive integer")
int refreshPeriodMinutes = 60;

@Min(value = 1, message = "The refresh timeout must be a positive integer")
int refreshTimeoutSeconds = 30;

public List<AuthorizationServer> getIssuers() {
return issuers;
}
Expand All @@ -112,6 +115,14 @@ public void setRefreshPeriodMinutes(int refreshPeriodMinutes) {
this.refreshPeriodMinutes = refreshPeriodMinutes;
}

public int getRefreshTimeoutSeconds() {
return refreshTimeoutSeconds;
}

public void setRefreshTimeoutSeconds(int refreshTimeoutSeconds) {
this.refreshTimeoutSeconds = refreshTimeoutSeconds;
}

public void setEnableOidc(boolean enableOidc) {
this.enableOidc = enableOidc;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,25 @@
import static java.lang.String.format;

import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;

import org.italiangrid.storm.webdav.config.OAuthProperties;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

import com.nimbusds.jose.RemoteKeySourceException;

@Service
public class DefaultOidcConfigurationFetcher implements OidcConfigurationFetcher {

Expand All @@ -39,13 +46,17 @@ public class DefaultOidcConfigurationFetcher implements OidcConfigurationFetcher
public static final String NO_JWKS_URI_ERROR_TEMPLATE =
"No jwks_uri found in metadata for issuer '%s'";

private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");

public static final Logger LOG = LoggerFactory.getLogger(DefaultOidcConfigurationFetcher.class);

final RestTemplateBuilder restBuilder;
final OAuthProperties oAuthProperties;

@Autowired
public DefaultOidcConfigurationFetcher(RestTemplateBuilder restBuilder) {
public DefaultOidcConfigurationFetcher(RestTemplateBuilder restBuilder,
OAuthProperties oAuthProperties) {
this.restBuilder = restBuilder;
this.oAuthProperties = oAuthProperties;
}

private void metadataChecks(String issuer, Map<String, Object> oidcConfiguration) {
Expand Down Expand Up @@ -95,4 +106,32 @@ public Map<String, Object> loadConfigurationForIssuer(String issuer) {
}
}

@Override
public String loadJWKSourceForURL(URI uri) throws RemoteKeySourceException {

LOG.debug("Fetching JWK from {}", uri);

final Duration TIMEOUT = Duration.ofSeconds(oAuthProperties.getRefreshTimeoutSeconds());
RestTemplate rest = restBuilder.setConnectTimeout(TIMEOUT).setReadTimeout(TIMEOUT).build();

HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
ResponseEntity<String> response = null;
try {
RequestEntity<Void> request = RequestEntity.get(uri).headers(headers).build();
response = rest.exchange(request, String.class);
if (response.getStatusCodeValue() != 200) {
throw new RuntimeException(format("Received status code: %s", response.getStatusCodeValue()));
}
} catch (RuntimeException e) {
final String errorMsg = format("Unable to get JWK from '%s'", uri);
if (LOG.isDebugEnabled()) {
LOG.error("{}: {}", errorMsg, e.getMessage());
}
throw new RemoteKeySourceException(errorMsg, e);
}

return response.getBody();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/**
* Copyright (c) Istituto Nazionale di Fisica Nucleare, 2014-2023.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.util.concurrent.Callable;

import org.springframework.cache.support.AbstractValueAdaptingCache;
import org.springframework.lang.Nullable;

public class NoExpirationStringCache extends AbstractValueAdaptingCache {

private static final String NAME = "NoExpirationCache";
private final String value;

public NoExpirationStringCache(String value) {
super(false);
this.value = value;
}

@Override
public String getName() {
return NAME;
}

@Override
public Object getNativeCache() {
return this;
}

@Override
@Nullable
protected Object lookup(Object key) {
return value;
}

@Override
public void put(Object key, Object value) {
return;
}

@Override
public void evict(Object key) {
return;
}

@Override
public void clear() {
return;
}

@SuppressWarnings("unchecked")
@Override
public <T> T get(Object key, Callable<T> valueLoader) {
return (T) fromStoreValue(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.net.URI;
import java.util.Map;

import com.nimbusds.jose.RemoteKeySourceException;

public interface OidcConfigurationFetcher {

Map<String, Object> loadConfigurationForIssuer(String issuer);

String loadJWKSourceForURL(URI uri) throws RemoteKeySourceException;

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.italiangrid.storm.webdav.oauth.utils;

import java.net.URI;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
Expand All @@ -28,8 +29,8 @@
import org.italiangrid.storm.webdav.oauth.validator.WlcgProfileValidator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.cache.Cache;
import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jwt.Jwt;
Expand All @@ -53,7 +54,6 @@ public class TrustedJwtDecoderCacheLoader extends CacheLoader<String, JwtDecoder
private final ExecutorService executor;
private final OAuthProperties oauthProperties;

@Autowired
public TrustedJwtDecoderCacheLoader(ServiceConfigurationProperties properties,
OAuthProperties oauthProperties, RestTemplateBuilder builder,
OidcConfigurationFetcher fetcher, ExecutorService executor) {
Expand All @@ -76,9 +76,14 @@ public JwtDecoder load(String issuer) throws Exception {
.orElseThrow(unknownTokenIssuer(issuer));

Map<String, Object> oidcConfiguration = fetcher.loadConfigurationForIssuer(issuer);
URI jwksUri = URI.create(oidcConfiguration.get("jwks_uri").toString());
Cache noExpirationCache =
new NoExpirationStringCache(fetcher.loadJWKSourceForURL(jwksUri).toString());

NimbusJwtDecoder decoder =
NimbusJwtDecoder.withJwkSetUri((oidcConfiguration.get("jwks_uri").toString())).build();
NimbusJwtDecoder.withJwkSetUri((oidcConfiguration.get("jwks_uri").toString()))
.cache(noExpirationCache)
.build();

OAuth2TokenValidator<Jwt> jwtValidator = JwtValidators.createDefaultWithIssuer(issuer);
OAuth2TokenValidator<Jwt> wlcgProfileValidator = new WlcgProfileValidator();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,9 @@ CloseableHttpClient transferClient(ThirdPartyCopyProperties props,
ClientRegistrationRepository clientRegistrationRepository(
OAuth2ClientProperties clientProperties, OAuthProperties props, ExecutorService executor) {


ClientRegistrationCacheLoader loader =
new ClientRegistrationCacheLoader(clientProperties, props, executor);


LoadingCache<String, ClientRegistration> clients = CacheBuilder.newBuilder()
.refreshAfterWrite(props.getRefreshPeriodMinutes(), TimeUnit.MINUTES)
.build(loader);
Expand Down Expand Up @@ -346,7 +344,6 @@ ClientRegistrationRepository clientRegistrationRepository(
JwtDecoder jwtDecoder(OAuthProperties props, ServiceConfigurationProperties sProps,
RestTemplateBuilder builder, OidcConfigurationFetcher fetcher, ExecutorService executor) {


TrustedJwtDecoderCacheLoader loader =
new TrustedJwtDecoderCacheLoader(sProps, props, builder, fetcher, executor);

Expand Down
3 changes: 2 additions & 1 deletion src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ tpc:
enable-expect-continue-threshold: ${STORM_WEBDAV_TPC_ENABLE_EXPECT_CONTINUE_THRESHOLD:1048576}

oauth:
refresh-period-minutes: 60
refresh-period-minutes: ${STORM_WEBDAV_OAUTH_REFRESH_PERIOD_MINUTES:60}
refresh-timeout-seconds: ${STORM_WEBDAV_OAUTH_REFRESH_TIMEOUT_SECONDS:30}
issuers:

storm:
Expand Down

0 comments on commit e0d36df

Please sign in to comment.