Skip to content

Commit

Permalink
Fix slow path reconnect process to call update on the connection map
Browse files Browse the repository at this point in the history
Signed-off-by: GilboaAWS <[email protected]>
  • Loading branch information
GilboaAWS committed Jan 16, 2025
1 parent 029aafc commit 66a3e39
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,12 @@ pub(crate) struct RefreshConnectionStates<Connection> {
}

impl<Connection> RefreshConnectionStates<Connection> {
// Clears all ongoing refresh connection tasks and resets associated state tracking.
//
// - This method removes all entries in the `refresh_address_in_progress` map.
// - The `Drop` trait is responsible for notifying the associated notifiers and aborting any unfinished refresh tasks.
// - Additionally, this method clears `refresh_addresses_started` and `refresh_addresses_done`
// to ensure no stale data remains in the refresh state tracking.
pub(crate) fn clear_refresh_state(&mut self) {
debug!(
"clear_refresh_state: removing all in-progress refresh connection tasks for addresses: {:?}",
Expand All @@ -294,6 +300,35 @@ impl<Connection> RefreshConnectionStates<Connection> {
self.refresh_addresses_started.clear();
self.refresh_addresses_done.clear();
}

// Collects the notifiers for the given addresses and returns them as a vector.
//
// This function retrieves the notifiers for the provided addresses from the `refresh_address_in_progress`
// map and returns them, so they can be awaited outside of the lock.
//
// # Arguments
// * `addresses` - A list of addresses for which notifiers are required.
//
// # Returns
// A vector of `futures::future::Notified` that can be awaited.
pub(crate) fn collect_refresh_notifiers(
&self,
addresses: &HashSet<String>,
) -> Vec<Arc<Notify>> {
addresses
.iter()
.filter_map(|address| {
self.refresh_address_in_progress
.get(address)
.and_then(|refresh_state| match &refresh_state.status {
RefreshTaskStatus::Reconnecting(notifier) => {
Some(notifier.get_notifier().clone())
}
_ => None,
})
})
.collect()
}
}

impl<Connection> Default for RefreshConnectionStates<Connection> {
Expand Down
125 changes: 77 additions & 48 deletions glide-core/redis-rs/redis/src/cluster_async/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1364,23 +1364,59 @@ where

if !addrs_to_refresh.is_empty() {
// don't try existing nodes since we know a. it does not exist. b. exist but its connection is closed
Self::refresh_connections(
Self::refresh_and_update_connections(
inner.clone(),
addrs_to_refresh,
RefreshConnectionType::OnlyUserConnection,
true,
RefreshConnectionType::AllConnections,
false,
)
.await;
}
}

async fn refresh_connections(
// Creates refresh tasks, await on the tasks' notifier and the update the connection_container.
// Awaiting on the notifier guaranties at least one reconnect attempt on each address.
async fn refresh_and_update_connections(
inner: Arc<InnerCore<C>>,
addresses: HashSet<String>,
conn_type: RefreshConnectionType,
check_existing_conn: bool,
) {
info!("Started refreshing connections to {:?}", addresses);
trace!("refresh_and_update_connections: calling trigger_refresh_connection_tasks");
Self::trigger_refresh_connection_tasks(
inner.clone(),
addresses.clone(),
conn_type,
check_existing_conn,
)
.await;

trace!("refresh_and_update_connections: Await on all tasks' refresh notifier");
// Await on all tasks' refresh notifier if exists
let refresh_task_notifiers = inner
.clone()
.conn_lock
.read()
.expect(MUTEX_READ_ERR)
.refresh_conn_state
.collect_refresh_notifiers(&addresses);
let futures: Vec<_> = refresh_task_notifiers
.iter()
.map(|notify| notify.notified())
.collect();
futures::future::join_all(futures).await;

// Update the connections in the connection_container
Self::update_refreshed_connection(inner);
}

async fn trigger_refresh_connection_tasks(
inner: Arc<InnerCore<C>>,
addresses: HashSet<String>,
conn_type: RefreshConnectionType,
check_existing_conn: bool,
) {
debug!("Triggering refresh connections tasks to {:?} ", addresses);

for address in addresses {
if inner
Expand Down Expand Up @@ -1854,7 +1890,7 @@ where

if !addrs_to_refresh.is_empty() {
// immediately trigger connection reestablishment
Self::refresh_connections(
Self::refresh_and_update_connections(
inner.clone(),
addrs_to_refresh,
RefreshConnectionType::AllConnections,
Expand Down Expand Up @@ -1890,7 +1926,8 @@ where
}

if !failed_connections.is_empty() {
Self::refresh_connections(
trace!("check_for_topology_diff: calling trigger_refresh_connection_tasks");
Self::trigger_refresh_connection_tasks(
inner,
failed_connections.into_iter().collect::<HashSet<String>>(),
RefreshConnectionType::OnlyManagementConnection,
Expand Down Expand Up @@ -2416,7 +2453,7 @@ where
ConnectionCheck::Found((address, connection)) => (address, connection.await),
ConnectionCheck::OnlyAddress(address) => {
// No connection for this address in the conn_map
Self::refresh_connections(
Self::trigger_refresh_connection_tasks(
core.clone(),
HashSet::from_iter(once(address.clone())),
RefreshConnectionType::AllConnections,
Expand Down Expand Up @@ -2518,6 +2555,7 @@ where
}

fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), RedisError>> {
trace!("entered poll_recovere");
let recover_future = match &mut self.state {
ConnectionState::PollComplete => return Poll::Ready(Ok(())),
ConnectionState::Recover(future) => future,
Expand Down Expand Up @@ -2585,10 +2623,10 @@ where
Self::try_request(info, core).await
}

fn update_refreshed_connection(&mut self) {
fn update_refreshed_connection(inner: Arc<InnerCore<C>>) {
trace!("update_refreshed_connection started");
loop {
let connections_container = self.inner.conn_lock.read().expect(MUTEX_READ_ERR);
let connections_container = inner.conn_lock.read().expect(MUTEX_READ_ERR);

// Check if both sets are empty
if connections_container
Expand Down Expand Up @@ -2617,21 +2655,18 @@ where
.cloned()
.collect();

let current_existing_addresses_in_slot_map =
connections_container.slot_map.all_node_addresses();

drop(connections_container);

// Process refresh_addresses_started
for address in addresses_to_remove {
self.inner
inner
.conn_lock
.write()
.expect(MUTEX_READ_ERR)
.refresh_conn_state
.refresh_addresses_started
.remove(&address);
self.inner
inner
.conn_lock
.write()
.expect(MUTEX_READ_ERR)
Expand All @@ -2640,40 +2675,35 @@ where

// Process refresh_addresses_done
for address in addresses_done {
// Check if this address appears in the current topology
if current_existing_addresses_in_slot_map.contains(&address) {
// Check if the address exists in refresh_addresses_done
let mut conn_lock_write = self.inner.conn_lock.write().expect(MUTEX_READ_ERR);
if let Some(conn_option) = conn_lock_write
.refresh_conn_state
.refresh_addresses_done
.get_mut(&address)
{
// Match the content of the Option
match conn_option.take() {
Some(conn) => {
debug!(
"update_refreshed_connection: found refreshed connection for address {}",
address
);
// Move the node_conn to the function
conn_lock_write
.replace_or_add_connection_for_address(address.clone(), conn);
}
None => {
debug!(
"update_refreshed_connection: task completed, but no connection for address {}",
address
);
}
// Check if the address exists in refresh_addresses_done
let mut conn_lock_write = inner.conn_lock.write().expect(MUTEX_READ_ERR);
if let Some(conn_option) = conn_lock_write
.refresh_conn_state
.refresh_addresses_done
.get_mut(&address)
{
// Match the content of the Option
match conn_option.take() {
Some(conn) => {
debug!(
"update_refreshed_connection: found refreshed connection for address {}",
address
);
// Move the node_conn to the function
conn_lock_write
.replace_or_add_connection_for_address(address.clone(), conn);
}
None => {
debug!(
"update_refreshed_connection: task completed, but no connection for address {}",
address
);
}
}
} else {
debug!("update_refreshed_connection: address {:?} doesn't appear in addresses in slot_map: {:?}", address, current_existing_addresses_in_slot_map);
}

// Remove this address from refresh_addresses_done
self.inner
inner
.conn_lock
.write()
.expect(MUTEX_READ_ERR)
Expand All @@ -2682,8 +2712,7 @@ where
.remove(&address);

// Remove this entry from refresh_address_in_progress
if let Some(_) = self
.inner
if let Some(_) = inner
.conn_lock
.write()
.expect(MUTEX_READ_ERR)
Expand Down Expand Up @@ -2933,7 +2962,7 @@ where
// In case of active poll_recovery, the <RecoverSlots / Reconnect(reconnect_to_initial_nodes)> work should
// take care of the refreshed_connection, add them if still relevant, and kill the refresh_tasks of
// non-relevant addresses.
self.update_refreshed_connection();
ClusterConnInner::update_refreshed_connection(self.inner.clone());

match ready!(self.poll_complete(cx)) {
PollFlushAction::None => return Poll::Ready(Ok(())),
Expand All @@ -2947,7 +2976,7 @@ where
}
PollFlushAction::Reconnect(addresses) => {
self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
ClusterConnInner::refresh_connections(
ClusterConnInner::trigger_refresh_connection_tasks(
self.inner.clone(),
addresses,
RefreshConnectionType::OnlyUserConnection,
Expand Down

0 comments on commit 66a3e39

Please sign in to comment.