Skip to content

Commit

Permalink
Tests & publish ReauthEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
ggivo committed Dec 3, 2024
1 parent 91871b6 commit 16b2f1c
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 116 deletions.
34 changes: 30 additions & 4 deletions src/main/java/io/lettuce/core/BaseRedisAuthenticationHandler.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
package io.lettuce.core;

import io.lettuce.core.codec.StringCodec;
import io.lettuce.core.event.EventBus;
import io.lettuce.core.event.connection.ReauthEvent;
import io.lettuce.core.event.connection.ReauthFailedEvent;
import io.lettuce.core.protocol.AsyncCommand;
import io.lettuce.core.protocol.Endpoint;
import io.lettuce.core.protocol.RedisCommand;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
Expand All @@ -21,8 +25,11 @@ public abstract class BaseRedisAuthenticationHandler<T extends RedisChannelHandl

private final AtomicReference<Disposable> credentialsSubscription = new AtomicReference<>();

public BaseRedisAuthenticationHandler(T connection) {
protected final EventBus eventBus;

public BaseRedisAuthenticationHandler(T connection, EventBus eventBus) {
this.connection = connection;
this.eventBus = eventBus;
}

/**
Expand Down Expand Up @@ -94,11 +101,23 @@ private void reauthenticate(RedisCredentials credentials) {
authCmd = new AsyncCommand<>(commandBuilder.auth(password));
}

dispatchAuth(authCmd).exceptionally(throwable -> {
log.error("Re-authentication {} failed.", credentials.hasUsername() ? "with username" : "without username",
throwable);
dispatchAuth(authCmd).thenRun(() -> {
publishReauthEvent();
log.info("Re-authentication succeeded for endpoint {}.", getEpid());
}).exceptionally(throwable -> {
publishReauthFailedEvent(throwable);
log.error("Re-authentication failed for endpoint {}.", getEpid(), throwable);
return null;
});
;
}

private void publishReauthEvent() {
eventBus.publish(new ReauthEvent(getEpid()));
}

private void publishReauthFailedEvent(Throwable throwable) {
eventBus.publish(new ReauthFailedEvent(getEpid(), throwable));
}

protected boolean isSupportedConnection() {
Expand All @@ -114,4 +133,11 @@ private AsyncCommand<String, String, String> dispatchAuth(RedisCommand<String, S
return asyncCommand;
}

private String getEpid() {
if (connection.getChannelWriter() instanceof Endpoint) {
return ((Endpoint) connection.getChannelWriter()).getId();
}
return "unknown";
}

}
6 changes: 4 additions & 2 deletions src/main/java/io/lettuce/core/RedisAuthenticationHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
*/
package io.lettuce.core;

import io.lettuce.core.event.EventBus;
import io.lettuce.core.protocol.Endpoint;
import io.lettuce.core.protocol.ProtocolVersion;
import io.lettuce.core.pubsub.StatefulRedisPubSubConnection;
import io.netty.util.internal.logging.InternalLogger;
Expand All @@ -28,8 +30,8 @@ class RedisAuthenticationHandler extends BaseRedisAuthenticationHandler<Stateful

private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class);

public RedisAuthenticationHandler(StatefulRedisConnectionImpl<?, ?> connection) {
super(connection);
public RedisAuthenticationHandler(StatefulRedisConnectionImpl<?, ?> connection, EventBus eventBus) {
super(connection, eventBus);
}

protected boolean isSupportedConnection() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ public StatefulRedisConnectionImpl(RedisChannelWriter writer, PushHandler pushHa
this.sync = newRedisSyncCommandsImpl();
this.reactive = newRedisReactiveCommandsImpl();

this.authHandler = new RedisAuthenticationHandler(this);
this.authHandler = new RedisAuthenticationHandler(this, getResources().eventBus());
}

public RedisCodec<K, V> getCodec() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.lettuce.core.BaseRedisAuthenticationHandler;
import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection;
import io.lettuce.core.event.EventBus;
import io.lettuce.core.protocol.ProtocolVersion;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
Expand All @@ -29,8 +30,8 @@ class RedisClusterAuthenticationHandler extends BaseRedisAuthenticationHandler<S

private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisClusterAuthenticationHandler.class);

public RedisClusterAuthenticationHandler(StatefulRedisClusterConnectionImpl<?, ?> connection) {
super(connection);
public RedisClusterAuthenticationHandler(StatefulRedisClusterConnectionImpl<?, ?> connection, EventBus eventBus) {
super(connection, eventBus);
}

protected boolean isSupportedConnection() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public StatefulRedisClusterConnectionImpl(RedisChannelWriter writer, ClusterPush
this.sync = newRedisAdvancedClusterCommandsImpl();
this.reactive = newRedisAdvancedClusterReactiveCommandsImpl();

this.authHandler = new RedisClusterAuthenticationHandler(this);
this.authHandler = new RedisClusterAuthenticationHandler(this, getResources().eventBus());
}

protected RedisAdvancedClusterReactiveCommandsImpl<K, V> newRedisAdvancedClusterReactiveCommandsImpl() {
Expand Down
22 changes: 22 additions & 0 deletions src/main/java/io/lettuce/core/event/connection/ReauthEvent.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.lettuce.core.event.connection;

import io.lettuce.core.event.Event;

/**
* Event fired on failed authentication caused either by I/O issues or during connection reauthentication.
*
* @author Ivo Gaydajiev
*/
public class ReauthEvent implements Event {

private final String epId;

public ReauthEvent(String epId) {
this.epId = epId;
}

public String getEpId() {
return epId;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package io.lettuce.core.event.connection;

import io.lettuce.core.event.Event;

import java.net.SocketAddress;

/**
* Event fired on failed authentication caused either by I/O issues or during connection reauthentication.
*
* @author Ivo Gaydajiev
*/
public class ReauthFailedEvent implements Event {

private final String epId;

private final Throwable cause;

public ReauthFailedEvent(String epId, Throwable cause) {
this.epId = epId;
this.cause = cause;
}

public String getEpId() {
return epId;
}

/**
* Returns the {@link Throwable} that describes the reauth failure cause.
*
* @return the {@link Throwable} that describes the reauth failure cause.
*/
public Throwable getCause() {
return cause;
}

}
90 changes: 14 additions & 76 deletions src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
import io.lettuce.test.WithPassword;
import io.lettuce.test.condition.EnabledOnCommand;
import io.lettuce.test.settings.TestSettings;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Sinks;

import java.time.Duration;
import java.util.ArrayList;
Expand Down Expand Up @@ -86,12 +84,10 @@ void ownCredentialProvider(RedisClient client) {
// Simulate test user credential rotation, and verify that re-authentication is successful
@Test
@Inject
void renewableCredentialProvider(RedisClient client) {
void streamingCredentialProvider(RedisClient client) {

// Thread-safe list to capture intercepted commands
List<RedisCommand<?, ?, ?>> interceptedCommands = Collections.synchronizedList(new ArrayList<>());

// CommandListener to track successful commands
CommandListener commandListener = new CommandListener() {

@Override
Expand All @@ -100,66 +96,33 @@ public void commandSucceeded(CommandSucceededEvent event) {
}

};

// Add CommandListener to the client
client.addListener(commandListener);

// Configure client options
client.setOptions(
ClientOptions.builder().disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS).build());

// Connection for managing test user credential rotation
StatefulRedisConnection<String, String> adminConnection = client.connect();

String testUser = "streaming_cred_test_user";
char[] initialPassword = "token_1".toCharArray();
char[] updatedPassword = "token_2".toCharArray();

// Streaming credentials provider to simulate token emission
RenewableRedisCredentialsProvider credentialsProvider = new RenewableRedisCredentialsProvider();

// Build RedisURI with streaming credentials provider
MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider();
RedisURI uri = RedisURI.builder().withHost(TestSettings.host()).withPort(TestSettings.port())
.withClientName("streaming_cred_test").withAuthentication(credentialsProvider)
.withTimeout(Duration.ofSeconds(1)).build();

// Create test user and set initial credentials
createTestUser(adminConnection, testUser, initialPassword);
credentialsProvider.emitToken(new StaticRedisCredentials(testUser, initialPassword));

// Establish connection using the streaming credentials provider
StatefulRedisConnection<String, String> userConnection = client.connect(StringCodec.UTF8, uri);
WithPassword.run(client, () -> {

// Verify initial authentication
assertThat(userConnection.sync().aclWhoami()).isEqualTo(testUser);
credentialsProvider.emitCredentials(TestSettings.username(), TestSettings.password().toString().toCharArray());

// Update test user credentials and emit updated credentials
updateTestUser(adminConnection, testUser, updatedPassword);
credentialsProvider.emitToken(new StaticRedisCredentials(testUser, updatedPassword));
StatefulRedisConnection<String, String> adminConnection = client.connect(uri);
assertThat(adminConnection.sync().aclWhoami()).isEqualTo(TestSettings.username());

// Wait for the `AUTH` command with updated credentials
Awaitility.await().atMost(Duration.ofSeconds(1)).until(() -> interceptedCommands.stream()
.anyMatch(command -> isAuthCommandWithCredentials(command, testUser, updatedPassword)));
credentialsProvider.emitCredentials(TestSettings.aclUsername(),
TestSettings.aclPassword().toString().toCharArray());

// Verify re-authentication and connection functionality
assertThat(userConnection.sync().ping()).isEqualTo("PONG");
assertThat(userConnection.sync().aclWhoami()).isEqualTo(testUser);
Awaitility.await().atMost(Duration.ofSeconds(1))
.until(() -> interceptedCommands.stream().anyMatch(command -> isAuthCommandWithCredentials(command,
TestSettings.aclUsername(), TestSettings.aclPassword().toString().toCharArray())));

// Clean up
adminConnection.close();
userConnection.close();
}
assertThat(adminConnection.sync().aclWhoami()).isEqualTo(TestSettings.aclUsername());

private void createTestUser(StatefulRedisConnection<String, String> connection, String username, char[] password) {
AclSetuserArgs args = AclSetuserArgs.Builder.on().allCommands().allChannels().allKeys().nopass()
.addPassword(String.valueOf(password));
connection.sync().aclSetuser(username, args);
}
adminConnection.close();
});

private void updateTestUser(StatefulRedisConnection<String, String> connection, String username, char[] newPassword) {
AclSetuserArgs args = AclSetuserArgs.Builder.on().allCommands().allChannels().allKeys().nopass()
.addPassword(String.valueOf(newPassword));
connection.sync().aclSetuser(username, args);
}

private boolean isAuthCommandWithCredentials(RedisCommand<?, ?, ?> command, String username, char[] password) {
Expand All @@ -170,29 +133,4 @@ private boolean isAuthCommandWithCredentials(RedisCommand<?, ?, ?> command, Stri
return false;
}

static class RenewableRedisCredentialsProvider implements StreamingCredentialsProvider {

private final Sinks.Many<RedisCredentials> credentialsSink = Sinks.many().replay().latest();

@Override
public Mono<RedisCredentials> resolveCredentials() {

return credentialsSink.asFlux().next();
}

public Flux<RedisCredentials> credentials() {

return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials
}

public void shutdown() {
credentialsSink.tryEmitComplete();
}

public void emitToken(RedisCredentials credentials) {
credentialsSink.tryEmitNext(credentials);
}

}

}
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package io.lettuce.core;

import io.lettuce.core.event.DefaultEventBus;
import io.lettuce.core.event.EventBus;
import io.lettuce.core.protocol.CommandType;
import io.lettuce.core.protocol.RedisCommand;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;
import reactor.core.scheduler.Schedulers;

import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
Expand All @@ -23,43 +26,25 @@ class BaseRedisAuthenticationHandlerTest {

private RedisChannelWriter channelWriter;

private StreamingCredentialsProvider streamingCredentialsProvider;

private Sinks.Many<RedisCredentials> sink;

@BeforeEach
void setUp() {

EventBus eventBus = new DefaultEventBus(Schedulers.immediate());
connection = mock(RedisChannelHandler.class);
channelWriter = mock(RedisChannelWriter.class);
when(connection.getChannelWriter()).thenReturn(channelWriter);
streamingCredentialsProvider = mock(StreamingCredentialsProvider.class);
sink = Sinks.many().replay().latest();
Flux<RedisCredentials> credentialsFlux = sink.asFlux();
when(streamingCredentialsProvider.credentials()).thenReturn(credentialsFlux);
handler = new BaseRedisAuthenticationHandler<RedisChannelHandler<?, ?>>(connection) {

@Override
protected boolean isSupportedConnection() {
return true;
}
handler = new BaseRedisAuthenticationHandler<RedisChannelHandler<?, ?>>(connection, eventBus) {

};
}

@SuppressWarnings("unchecked")
@Test
void subscribeWithStreamingCredentialsProviderInvokesReauth() {
MyStreamingRedisCredentialsProvider credentialsProvider = new MyStreamingRedisCredentialsProvider();

// Subscribe to the provider
handler.subscribe(streamingCredentialsProvider);
sink.tryEmitNext(RedisCredentials.just("newuser", "newpassword"));

// Ensure credentials() method was invoked
verify(streamingCredentialsProvider).credentials();

// Verify that write() is invoked once
verify(channelWriter, times(1)).write(any(RedisCommand.class));
handler.subscribe(credentialsProvider);
credentialsProvider.emitCredentials("newuser", "newpassword".toCharArray());

ArgumentCaptor<RedisCommand<Object, Object, Object>> captor = ArgumentCaptor.forClass(RedisCommand.class);
verify(channelWriter).write(captor.capture());
Expand All @@ -86,12 +71,11 @@ void shouldHandleErrorInCredentialsStream() {

@Test
void shouldNotSubscribeIfConnectionIsNotSupported() {
Sinks.Many<RedisCredentials> sink = Sinks.many().replay().latest();
Flux<RedisCredentials> credentialsFlux = sink.asFlux();
EventBus eventBus = new DefaultEventBus(Schedulers.immediate());
StreamingCredentialsProvider credentialsProvider = mock(StreamingCredentialsProvider.class);
when(credentialsProvider.credentials()).thenReturn(credentialsFlux);

BaseRedisAuthenticationHandler<?> handler = new BaseRedisAuthenticationHandler<RedisChannelHandler<?, ?>>(connection) {
BaseRedisAuthenticationHandler<?> handler = new BaseRedisAuthenticationHandler<RedisChannelHandler<?, ?>>(connection,
eventBus) {

@Override
protected boolean isSupportedConnection() {
Expand Down
Loading

0 comments on commit 16b2f1c

Please sign in to comment.