Skip to content

Commit

Permalink
get both current and cleared version in same query
Browse files Browse the repository at this point in the history
  • Loading branch information
somtochiama committed Dec 20, 2024
1 parent 65d39cd commit 8adfca6
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 103 deletions.
52 changes: 50 additions & 2 deletions crates/corro-agent/src/agent/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use crate::{
transport::Transport,
};
use corro_tests::*;
use corro_types::broadcast::Timestamp;
use corro_types::{broadcast::Timestamp, config::FollowFrom};
use corro_types::change::Change;
use corro_types::{
actor::ActorId,
Expand All @@ -51,7 +51,7 @@ use corro_types::{
};

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn insert_rows_and_gossip() -> eyre::Result<()> {
pub async fn insert_rows_and_gossip() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let ta1 = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
Expand Down Expand Up @@ -1023,6 +1023,54 @@ async fn process_failed_changes() -> eyre::Result<()> {
Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn follow_basic() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();

let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
let main = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;

// setup the schema, for both nodes
let (status_code, _body) = api_v1_db_schema(
Extension(main.agent.clone()),
axum::Json(vec![corro_tests::TEST_SCHEMA.into()]),
)
.await;

assert_eq!(status_code, StatusCode::OK);

// make about 50 transactions to ta1
insert_rows(main.agent.clone(), 1, 20).await;
// clear some rows
insert_rows(main.agent.clone(), 10, 30).await;

let follower = launch_test_agent(|conf| conf.follow(main.agent.gossip_addr(), FollowFrom::Latest, None).build(), tripwire.clone()).await?;
let (status_code, _body) = api_v1_db_schema(
Extension(follower.agent.clone()),
axum::Json(vec![corro_tests::TEST_SCHEMA.into()]),
)
.await;
assert_eq!(status_code, StatusCode::OK);

sleep(Duration::from_secs(3)).await;
check_bookie_versions(
follower.clone(),
main.agent.actor_id(),
vec![Version(1)..=Version(9)],
vec![],
vec![],
vec![Version(10)..=Version(20)],
)
.await?;


tripwire_tx.send(()).await.ok();
tripwire_worker.await;
wait_for_all_pending_handles().await;

Ok(())
}

#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn test_process_multiple_changes() -> eyre::Result<()> {
_ = tracing_subscriber::fmt::try_init();
Expand Down
134 changes: 33 additions & 101 deletions crates/corro-agent/src/api/peer/follow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use rusqlite::{params_from_iter, OptionalExtension, Row, ToSql};
use speedy::{Readable, Writable};
use tokio::{sync::mpsc, task::block_in_place};
use tokio_util::codec::{Encoder, FramedRead, LengthDelimitedCodec};
use tracing::{debug, error, trace};
use tracing::{debug, error, info, trace};

use super::{encode_write_bipayload_msg, BiPayloadSendError};

Expand Down Expand Up @@ -148,7 +148,7 @@ pub async fn serve_follow(
});

let actor_id = agent.actor_id();
let from_ts: Timestamp = {
let mut from_ts: Timestamp = {
let conn = agent.pool().read().await?;
conn.query_row(
"SELECT MIN(ts) FROM __corro_bookkeeping WHERE db_version >= ? and (? or actor_id = ?)",
Expand All @@ -166,12 +166,16 @@ pub async fn serve_follow(

block_in_place(|| {
let (extra_where_clause, query_params): (_, Vec<&dyn ToSql>) = if local_only {
("AND actor_id = ?", vec![&last_db_version, &actor_id])
("AND actor_id = ?", vec![&last_db_version, &from_ts, &actor_id])
} else {
("", vec![&last_db_version])
("", vec![&last_db_version, &from_ts])
};

let mut bk_prepped = conn.prepare_cached(&format!("SELECT actor_id, start_version, db_version, last_seq, ts FROM __corro_bookkeeping WHERE db_version IS NOT NULL AND db_version > ? {extra_where_clause} ORDER BY db_version ASC"))?;

let mut bk_prepped = conn.prepare_cached(&format!("SELECT actor_id, start_version, end_version, db_version, last_seq, ts
FROM __corro_bookkeeping WHERE (db_version IS NOT NULL AND db_version > ?)
OR (db_version IS NULL and ts > ?) {extra_where_clause}
ORDER BY db_version ASC"))?;

let map = |row: &Row| {
Ok((
Expand All @@ -180,25 +184,43 @@ pub async fn serve_follow(
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
))
};

// implicit read transaction
let bk_rows = bk_prepped.query_map(params_from_iter(query_params), map)?;

for bk_res in bk_rows {
let (actor_id, version, db_version, last_seq, ts): (
let (actor_id, start_version, end_version, db_version, last_seq, ts): (
ActorId,
Version,
CrsqlDbVersion,
CrsqlSeq,
Option<Version>,
Option<CrsqlDbVersion>,
Option<CrsqlSeq>,
Timestamp,
) = bk_res?;

debug!("sending changes for: {actor_id} v{version} (db_version: {db_version})");
debug!("sending changes for: {actor_id} v{start_version} (db_version: {db_version:?})");

if let Some(end) = end_version {
tx.blocking_send(FollowMessage::V1(FollowMessageV1::Change(ChangeV1 {
actor_id,
changeset: Changeset::Empty {
versions: start_version..=end,
ts: Some(ts),
},
})))
.map_err(|_| FollowError::ChannelClosed)?;

from_ts = ts;
continue;
}

let last_seq = last_seq.unwrap();
let db_version:CrsqlDbVersion = db_version.unwrap();
let mut prepped = conn.prepare_cached(
"SELECT \"table\", pk, cid, val, col_version, db_version, seq, site_id, cl FROM crsql_changes WHERE db_version = ? ORDER BY db_version ASC, seq ASC",
"SELECT \"table\", pk, cid, val, col_version, db_version, seq, site_id, cl FROM crsql_changes WHERE db_version = ? ORDER BY db_version ASC, seq ASC, ts ASC",
)?;
// implicit read transaction
let rows = prepped.query_map([db_version], row_to_change)?;
Expand All @@ -210,7 +232,7 @@ pub async fn serve_follow(
tx.blocking_send(FollowMessage::V1(FollowMessageV1::Change(ChangeV1 {
actor_id,
changeset: Changeset::Full {
version,
version: start_version,
changes,
seqs,
last_seq,
Expand All @@ -223,52 +245,6 @@ pub async fn serve_follow(
last_db_version = db_version; // record last db version processed for next go around
}

// we do this everytime so we can pick up new actor_ids
let actor_ids = {
if local_only {
vec![actor_id]
} else {
conn.prepare_cached("SELECT DISTINCT actor_id FROM __corro_bookkeeping")?
.query_map([], |row| Ok(row.get(0)?))
.and_then(|rows| rows.collect::<rusqlite::Result<Vec<_>>>())?
}
};

info!("sending cleared version since from - {from_ts} for {} actors", last_empty_ts.len());
for id in actor_ids {
if !last_empty_ts.contains_key(&id) {
last_empty_ts.insert(actor_id, from_ts);
}
}

for (actor_id, empty_ts) in last_empty_ts.clone() {
let mut empty_prepped = conn.prepare_cached(
"SELECT start_version, end_version, ts FROM __corro_bookkeeping WHERE db_version IS NULL AND ts > ? AND actor_id = ? ORDER BY ts ASC",
)?;

let empty_rows = empty_prepped.query_map((empty_ts, actor_id), |row| {
Ok(Changeset::Empty {
versions: row.get(0)?..=row.get(1)?,
ts: row.get(2)?,
})
})?;

let mut last_ts: Option<Timestamp> = None;
for row in empty_rows {
let changeset = row?;
last_ts = changeset.ts();
debug!("sending cleared versions for {actor_id}, versions - {:?}", changeset.versions());
tx.blocking_send(FollowMessage::V1(FollowMessageV1::Change(ChangeV1 {
actor_id,
changeset,
})))
.map_err(|_| FollowError::ChannelClosed)?;
}

if let Some(ts) = last_ts {
last_empty_ts.insert(actor_id, ts);
}
}
Ok::<_, FollowError>(())
})?;

Expand Down Expand Up @@ -386,49 +362,5 @@ pub async fn follow(

#[cfg(test)]
mod tests {
// use corro_types::config::FollowFrom;
// use corro_tests::launch_test_agent;
// use axum::Extension;
// use hyper::StatusCode;
// use tripwire::Tripwire;

// use crate::{
// api::{
// public::{api_v1_db_schema, api_v1_transactions},
// },
// };
// use corro_tests::*;

use super::*;

// #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
// async fn test_follow() -> eyre::Result<()> {

// _ = tracing_subscriber::fmt::try_init();

// let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
// let main = launch_test_agent(|conf| conf.build(), tripwire.clone()).await?;
// let follower = launch_test_agent(|conf| conf.follow(main.agent.gossip_addr(), FollowFrom::Latest, None).build(), tripwire.clone()).await?;

// // setup the schema, for both nodes
// let (status_code, _body) = api_v1_db_schema(
// Extension(main.agent.clone()),
// axum::Json(vec![corro_tests::TEST_SCHEMA.into()]),
// )
// .await;

// assert_eq!(status_code, StatusCode::OK);

// let (status_code, _body) = api_v1_db_schema(
// Extension(follower.agent.clone()),
// axum::Json(vec![corro_tests::TEST_SCHEMA.into()]),
// )
// .await;
// assert_eq!(status_code, StatusCode::OK);

// // make about 50 transactions to ta1
// insert_rows(follower.agent.clone(), 1, 50).await;

// Ok(())
// }
}

0 comments on commit 8adfca6

Please sign in to comment.