[Feature] Batch proposal spend limits #3471

Draft · wants to merge 6 commits into base: staging
25 changes: 25 additions & 0 deletions node/bft/ledger-service/src/ledger.rs
@@ -404,4 +404,29 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
tracing::info!("\n\nAdvanced to block {} at round {} - {}\n", block.height(), block.round(), block.hash());
Ok(())
}

fn compute_cost(&self, _transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64> {
// TODO: move to VM or ledger?
let process = self.ledger.vm().process();

// Deserialize the transaction. If the transaction exceeds the maximum size, then return an error.
let transaction = match transaction {
Collaborator:
Deserialization is extremely expensive. I would consider moving this calculation into ledger.rs:check_transaction_basic, where we already deserialize. Perhaps that function can return the compute cost? Or some other design?

Collaborator (author):
I did consider it and should revisit the idea, but intuitively we might want to avoid coupling the cost calculation with check_transaction_basic, as it's only called in propose_batch and not in process_batch_propose_from_peer, at least not directly.

Collaborator:
Ah, good point... That is unfortunate, because in the case of process_batch_propose_from_peer we might be retrieving the transmission from disk, in which case we'll most certainly have to incur the deserialization cost.

Maybe if we let fn compute_cost take in a Transaction instead of Data<Transaction>, it can at least be made explicit. For our own proposal we call it from within check_transaction_basic; for incoming proposals we'll need to deserialize before calling it.

And if it turns out to be a bottleneck, we can always refactor the locations where we deserialize more comprehensively, and potentially create a cache for the compute cost if needed.

Data::Object(transaction) => transaction,
Data::Buffer(bytes) => Transaction::<N>::read_le(&mut bytes.take(N::MAX_TRANSACTION_SIZE as u64))?,
};

// Collect the optional `Stack` corresponding to the transaction if it's an `Execution`.
let stack = if let Transaction::Execute(_, ref execution, _) = transaction {
// Get the root transition from the execution.
let root_transition = execution.peek()?;
// Get the stack from the process.
Some(process.read().get_stack(root_transition.program_id())?.clone())
} else {
None
};

use snarkvm::prelude::compute_cost;

compute_cost(&transaction, stack)
}
}
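
A minimal sketch of the alternative floated in the review thread above: compute_cost taking an already-deserialized &Transaction<N> instead of Data<Transaction<N>>, so each caller decides where the deserialization cost is paid. The helpers mirror this diff; the signature itself is hypothetical and not part of this PR.

// Hypothetical variant of `compute_cost` (not in this PR): the caller
// deserializes first, making the expensive step explicit at each call site.
fn compute_cost(&self, transaction: &Transaction<N>) -> Result<u64> {
    let process = self.ledger.vm().process();
    // Resolve the stack only when the transaction is an execution.
    let stack = if let Transaction::Execute(_, execution, _) = transaction {
        let root_transition = execution.peek()?;
        Some(process.read().get_stack(root_transition.program_id())?.clone())
    } else {
        None
    };
    snarkvm::prelude::compute_cost(transaction, stack)
}

Under this shape, our own proposals could call it from within check_transaction_basic, while incoming proposals deserialize before calling it, as suggested above.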
5 changes: 5 additions & 0 deletions node/bft/ledger-service/src/mock.rs
@@ -235,4 +235,9 @@ impl<N: Network> LedgerService<N> for MockLedgerService<N> {
self.height_to_round_and_hash.lock().insert(block.height(), (block.round(), block.hash()));
Ok(())
}

/// TODO: is this reasonable?
Collaborator (on "/// TODO: is this reasonable?"):
I guess if our mock ledger doesn't have the compute_cost function, yes, it's reasonable; no need to extend every test abstraction that exists.

fn compute_cost(&self, _transaction_id: N::TransactionID, _transaction: Data<Transaction<N>>) -> Result<u64> {
Ok(0)
}
}
4 changes: 4 additions & 0 deletions node/bft/ledger-service/src/prover.rs
@@ -186,4 +186,8 @@ impl<N: Network> LedgerService<N> for ProverLedgerService<N> {
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()> {
bail!("Cannot advance to next block in prover - {block}")
}

fn compute_cost(&self, transaction_id: N::TransactionID, _transaction: Data<Transaction<N>>) -> Result<u64> {
bail!("Transaction '{transaction_id}' doesn't exist in prover")
}
}
2 changes: 2 additions & 0 deletions node/bft/ledger-service/src/traits.rs
@@ -120,4 +120,6 @@ pub trait LedgerService<N: Network>: Debug + Send + Sync {
/// Adds the given block as the next block in the ledger.
#[cfg(feature = "ledger-write")]
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;

/// Returns the execution cost of the given transaction in microcredits.
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64>;
}
4 changes: 4 additions & 0 deletions node/bft/ledger-service/src/translucent.rs
@@ -195,4 +195,8 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for TranslucentLedgerS
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()> {
self.inner.advance_to_next_block(block)
}

fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64> {
self.inner.compute_cost(transaction_id, transaction)
}
}
179 changes: 110 additions & 69 deletions node/bft/src/primary.rs
@@ -490,88 +490,100 @@ impl<N: Network> Primary<N> {

// Determine the required number of transmissions per worker.
let num_transmissions_per_worker = BatchHeader::<N>::MAX_TRANSMISSIONS_PER_BATCH / self.num_workers() as usize;

// Initialize the map of transmissions.
let mut transmissions: IndexMap<_, _> = Default::default();
// Keeps track of the number of transmissions included thus far. The
// transmissions index is only updated in batches; this counter is more granular.
let mut num_transmissions_included = 0usize;
// Track the total execution costs of the batch proposal as it is being constructed.
let mut proposal_cost = 0u64;
// Take the transmissions from the workers.
for worker in self.workers.iter() {
'outer: for worker in self.workers.iter() {
// Initialize a tracker for included transmissions for the current worker.
let mut num_transmissions_included_for_worker = 0;
// Keep draining the worker until the desired number of transmissions is reached or the worker is empty.
'outer: while num_transmissions_included_for_worker < num_transmissions_per_worker {
// Determine the number of remaining transmissions for the worker.
let num_remaining_transmissions =
num_transmissions_per_worker.saturating_sub(num_transmissions_included_for_worker);
// Drain the worker.
let mut worker_transmissions = worker.drain(num_remaining_transmissions).peekable();
// If the worker is empty, break early.
if worker_transmissions.peek().is_none() {
break 'outer;
let mut worker_transmissions = worker.transmissions().into_iter();
Collaborator:
So correct me if I'm wrong, but the previous behaviour was to selectively drain transmissions, and if they were duplicates or incorrect, we discarded them.

Your current implementation completely drains transmissions, so besides discarding duplicates or invalid ones, once we hit the size or compute limit the remaining transmissions are discarded.

Some options:

  1. Preprocess the compute spend of a transmission and cache it in the worker's Ready queue. But this would unfortunately require a lot of code changes and wouldn't make sense everywhere.
  2. Drain transmissions one by one. If we notice post-deserialization that we're above the compute limit, we reinsert that one transmission back into the queue using worker.reinsert. This has strictly less overhead than draining all transmissions.

Collaborator (author, @niklaslong, Jan 23, 2025):
Previously the primary would preemptively drain the workers in batches, sized based on the remaining proposal space. This was fine as transactions were either included or dropped altogether, but we need more nuance now: we might choose to save a transaction for the following batch if it pushes the batch cost over the spend limit.

For the PoC, I opted to clone the transmissions (this happens in worker.transmissions()) to avoid reinsertion logic initially, leaving the underlying collection untouched until the number of valid transactions is determined and draining it only at the end of the worker loop. Ideally, I'd want to peek the next item in the collection, and that might still be the most efficient option. The next best thing would be to reinsert, but that's still O(n) and requires more logic.

Collaborator:
Peeking one by one would work!

Or draining one by one. If we optionally have to reinsert the very last transmission we drained when it crosses the compute threshold, I'm not sure what would be O(n).

Collaborator (author, @niklaslong, Jan 23, 2025):
The O(n) comes from shifting all the transmissions by one in the collection when reinserting at index 0 (to preserve fairness in the ordering of transmissions). Peeking directly likely won't work, because it requires either holding the lock (not a good idea) or cloning. But perhaps doing the lookup one by one, followed by a removal, could work. Though any operation that isn't inserting or removing values at the tail of the map will incur an O(n) shift cost.

Collaborator:
I see, thank you for explaining. I take everything back; your approach seems close to optimal (assuming we don't change or introduce additional data structures).

Indeed, you can clone and look up one by one (to avoid cloning the entire ready queue), and then batch-drain at the end like you're already doing.
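
A rough sketch of that agreed idea, under stated assumptions: worker.get_transmission and worker.drain appear elsewhere in this PR, while worker.transmission_ids() is a hypothetical cheap accessor for the queued IDs.

// Hypothetical sketch: look transmissions up one at a time (cloning only the
// inspected entry, not the whole ready queue), stop before the spend limit is
// crossed, then batch-drain exactly the selected prefix.
let mut num_selected = 0usize;
for id in worker.transmission_ids() {
    let Some(transmission) = worker.get_transmission(id) else { break };
    if let (TransmissionID::Transaction(tx_id, _), Transmission::Transaction(tx)) =
        (id, transmission)
    {
        let cost = self.ledger.compute_cost(tx_id, tx)?;
        // Leave the transaction queued for a later batch instead of dropping it.
        if proposal_cost.saturating_add(cost) > N::BATCH_SPEND_LIMIT {
            break;
        }
        proposal_cost += cost;
    }
    num_selected += 1;
}
// The rest of the queue is untouched, preserving ordering fairness.
for (id, transmission) in worker.drain(num_selected) {
    transmissions.insert(id, transmission);
}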


// Check the transactions for inclusion in the batch proposal.
while num_transmissions_included_for_worker < num_transmissions_per_worker {
Collaborator:
Wondering whether it is possible and desirable to put this check in the same place as the compute cost check, because they're both indicating "the batch proposal is too full now".

The loop could then be something like while let Some(transmission) = worker.drain(1).

Up to you, just a suggestion.
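
A rough sketch of that suggested shape, assuming worker.drain(1) yields at most one entry (its existing use returns an iterator) and that a worker.reinsert exists as mentioned above; both signatures are hypothetical here.

// Hypothetical loop shape: drain one transmission per iteration, so the size
// check and the spend-limit check share a single exit point.
while let Some((id, transmission)) = worker.drain(1).next() {
    if num_transmissions_included_for_worker >= num_transmissions_per_worker
        || proposal_cost >= N::BATCH_SPEND_LIMIT
    {
        // The proposal is full: put the drained transmission back and stop.
        worker.reinsert(id, transmission);
        break;
    }
    // ... checksum, validity, and cost checks as in the loop below ...
}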

let Some((id, transmission)) = worker_transmissions.next() else { break };

// Check if the ledger already contains the transmission.
if self.ledger.contains_transmission(&id).unwrap_or(true) {
trace!("Proposing - Skipping transmission '{}' - Already in ledger", fmt_id(id));
continue;
}
// Iterate through the worker transmissions.
'inner: for (id, transmission) in worker_transmissions {
// Check if the ledger already contains the transmission.
if self.ledger.contains_transmission(&id).unwrap_or(true) {
trace!("Proposing - Skipping transmission '{}' - Already in ledger", fmt_id(id));
continue 'inner;
}
// Check if the storage already contains the transmission.
// Note: We do not skip if this is the first transmission in the proposal, to ensure that
// the primary does not propose a batch with no transmissions.
if !transmissions.is_empty() && self.storage.contains_transmission(id) {
trace!("Proposing - Skipping transmission '{}' - Already in storage", fmt_id(id));
continue 'inner;
}
// Check the transmission is still valid.
match (id, transmission.clone()) {
(TransmissionID::Solution(solution_id, checksum), Transmission::Solution(solution)) => {
// Ensure the checksum matches.
match solution.to_checksum::<N>() {
Ok(solution_checksum) if solution_checksum == checksum => (),
_ => {
trace!(
"Proposing - Skipping solution '{}' - Checksum mismatch",
fmt_id(solution_id)
);
continue 'inner;
}
}
// Check if the solution is still valid.
if let Err(e) = self.ledger.check_solution_basic(solution_id, solution).await {
trace!("Proposing - Skipping solution '{}' - {e}", fmt_id(solution_id));
continue 'inner;

// Check if the storage already contains the transmission.
// Note: We do not skip if this is the first transmission in the proposal, to ensure that
// the primary does not propose a batch with no transmissions.
if num_transmissions_included != 0 && self.storage.contains_transmission(id) {
trace!("Proposing - Skipping transmission '{}' - Already in storage", fmt_id(id));
continue;
}

// Check the transmission is still valid.
match (id, transmission.clone()) {
(TransmissionID::Solution(solution_id, checksum), Transmission::Solution(solution)) => {
// Ensure the checksum matches.
match solution.to_checksum::<N>() {
Ok(solution_checksum) if solution_checksum == checksum => (),
_ => {
trace!("Proposing - Skipping solution '{}' - Checksum mismatch", fmt_id(solution_id));
continue;
}
}
(
TransmissionID::Transaction(transaction_id, checksum),
Transmission::Transaction(transaction),
) => {
// Ensure the checksum matches.
match transaction.to_checksum::<N>() {
Ok(transaction_checksum) if transaction_checksum == checksum => (),
_ => {
trace!(
"Proposing - Skipping transaction '{}' - Checksum mismatch",
fmt_id(transaction_id)
);
continue 'inner;
}
// Check if the solution is still valid.
if let Err(e) = self.ledger.check_solution_basic(solution_id, solution).await {
trace!("Proposing - Skipping solution '{}' - {e}", fmt_id(solution_id));
continue;
}
}
(TransmissionID::Transaction(transaction_id, checksum), Transmission::Transaction(transaction)) => {
// Ensure the checksum matches.
match transaction.to_checksum::<N>() {
Ok(transaction_checksum) if transaction_checksum == checksum => (),
_ => {
trace!(
"Proposing - Skipping transaction '{}' - Checksum mismatch",
fmt_id(transaction_id)
);
continue;
}
// Check if the transaction is still valid.
if let Err(e) = self.ledger.check_transaction_basic(transaction_id, transaction).await {
trace!("Proposing - Skipping transaction '{}' - {e}", fmt_id(transaction_id));
continue 'inner;
}
// Check if the transaction is still valid.
// TODO: check if clone is cheap, otherwise fix.
if let Err(e) = self.ledger.check_transaction_basic(transaction_id, transaction.clone()).await {
trace!("Proposing - Skipping transaction '{}' - {e}", fmt_id(transaction_id));
continue;
}

// Ensure the transaction doesn't bring the proposal above the spend limit.
match self.ledger.compute_cost(transaction_id, transaction) {
Ok(cost) if proposal_cost + cost <= N::BATCH_SPEND_LIMIT => proposal_cost += cost,
_ => {
trace!(
"Proposing - Skipping transaction '{}' - Batch spend limit surpassed",
fmt_id(transaction_id)
);
break 'outer;
}
}
// Note: We explicitly forbid including ratifications,
// as the protocol currently does not support ratifications.
(TransmissionID::Ratification, Transmission::Ratification) => continue,
// All other combinations are clearly invalid.
_ => continue 'inner,
}
// Insert the transmission into the map.
transmissions.insert(id, transmission);
num_transmissions_included_for_worker += 1;
// Note: We explicitly forbid including ratifications,
// as the protocol currently does not support ratifications.
(TransmissionID::Ratification, Transmission::Ratification) => continue,
// All other combinations are clearly invalid.
_ => continue,
}

num_transmissions_included += 1;
num_transmissions_included_for_worker += 1;
}

// Drain the selected transactions from the worker and insert them into the batch proposal.
for (id, transmission) in worker.drain(num_transmissions_included_for_worker) {
transmissions.insert(id, transmission);
}
}

@@ -755,6 +767,35 @@ impl<N: Network> Primary<N> {
// Inserts the missing transmissions into the workers.
self.insert_missing_transmissions_into_workers(peer_ip, missing_transmissions.into_iter())?;

// Ensure the transactions don't bring the proposal above the spend limit.
let mut proposal_cost = 0u64;
for transmission_id in batch_header.transmission_ids() {
let worker_id = assign_to_worker(*transmission_id, self.num_workers())?;
let Some(worker) = self.workers.get(worker_id as usize) else {
debug!("Unable to find worker {worker_id}");
return Ok(());
};

let Some(transmission) = worker.get_transmission(*transmission_id) else {
debug!("Unable to find transmission '{}' in worker '{worker_id}'", fmt_id(transmission_id));
return Ok(());
};

// If the transmission is a transaction, compute its execution cost.
if let (TransmissionID::Transaction(transaction_id, _), Transmission::Transaction(transaction)) =
(transmission_id, transmission)
{
proposal_cost += self.ledger.compute_cost(*transaction_id, transaction)?
}
}

if proposal_cost > N::BATCH_SPEND_LIMIT {
debug!(
"Batch propose from peer '{peer_ip}' exceeds the batch spend limit — cost in microcredits: '{proposal_cost}'"
);
return Ok(());
}

/* Proceeding to sign the batch. */

// Retrieve the batch ID.
1 change: 1 addition & 0 deletions node/bft/src/worker.rs
@@ -602,6 +602,7 @@ mod tests {
transmissions: IndexMap<TransmissionID<N>, Transmission<N>>,
) -> Result<Block<N>>;
fn advance_to_next_block(&self, block: &Block<N>) -> Result<()>;
fn compute_cost(&self, transaction_id: N::TransactionID, transaction: Data<Transaction<N>>) -> Result<u64>;
}
}
