Skip to content

Commit

Permalink
Merge origin/main
Browse files Browse the repository at this point in the history
  • Loading branch information
georgwiese committed Jan 11, 2025
2 parents 15e3de5 + 65a2fdd commit cb11fb6
Show file tree
Hide file tree
Showing 20 changed files with 715 additions and 564 deletions.
1 change: 0 additions & 1 deletion .github/workflows/pr-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ concurrency:
env:
CARGO_TERM_COLOR: always
POWDR_GENERATE_PROOFS: "true"
POWDR_JIT_OPT_LEVEL: "0"
MAX_DEGREE_LOG: "20"

jobs:
Expand Down
3 changes: 2 additions & 1 deletion ast/src/parsed/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ pub trait Children<O> {
}

pub trait AllChildren<O> {
/// Returns an iterator over all direct and indirect children of kind O in this object.
/// Returns an iterator over all direct and indirect children of kind `O` in this object.
/// If `O` and `Self` are the same type, also includes `self`.
/// Pre-order visitor.
fn all_children(&self) -> Box<dyn Iterator<Item = &O> + '_>;
}
Expand Down
260 changes: 53 additions & 207 deletions executor/src/witgen/jit/block_machine_processor.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
use std::collections::{BTreeSet, HashSet};
use std::collections::HashSet;

use bit_vec::BitVec;
use itertools::Itertools;
use powdr_ast::analyzed::{
AlgebraicReference, Identity, PolyID, PolynomialType, SelectedExpressions,
};
use powdr_ast::analyzed::AlgebraicReference;
use powdr_number::FieldElement;

use crate::witgen::{jit::effect::format_code, machines::MachineParts, FixedData};
use crate::witgen::{jit::processor::Processor, machines::MachineParts, FixedData};

use super::{
effect::Effect,
variable::{Cell, Variable},
witgen_inference::{CanProcessCall, FixedEvaluator, Value, WitgenInference},
variable::Variable,
witgen_inference::{CanProcessCall, FixedEvaluator, WitgenInference},
};

/// This is a tuning value. It is the maximum nesting depth of branches in the JIT code.
const BLOCK_MACHINE_MAX_BRANCH_DEPTH: usize = 6;

/// A processor for generating JIT code for a block machine.
pub struct BlockMachineProcessor<'a, T: FieldElement> {
fixed_data: &'a FixedData<'a, T>,
Expand Down Expand Up @@ -74,209 +75,51 @@ impl<'a, T: FieldElement> BlockMachineProcessor<'a, T> {
witgen.assign_variable(expr, self.latch_row as i32, Variable::Param(index));
}

// Solve for the block witness.
// Fails if any machine call cannot be completed.
match self.solve_block(can_process, &mut witgen, connection.right) {
Ok(()) => Ok(witgen.finish()),
Err(e) => {
log::debug!("\nCode generation failed for connection:\n {connection}");
let known_args_str = known_args
.iter()
.enumerate()
.filter_map(|(i, b)| b.then_some(connection.right.expressions[i].to_string()))
.join("\n ");
log::debug!("Known arguments:\n {known_args_str}");
log::debug!("Error:\n {e}");
log::debug!(
"The following code was generated so far:\n{}",
format_code(witgen.code())
);
Err(format!("Code generation failed: {e}\nRun with RUST_LOG=debug to see the code generated so far."))
}
}
let identities = self.row_range().flat_map(move |row| {
self.machine_parts
.identities
.iter()
.map(move |&id| (id, row))
});
let requested_known = known_args
.iter()
.enumerate()
.filter_map(|(i, is_input)| (!is_input).then_some(Variable::Param(i)));
Processor::new(
self.fixed_data,
self,
identities,
self.block_size,
true,
requested_known,
BLOCK_MACHINE_MAX_BRANCH_DEPTH,
)
.generate_code(can_process, witgen)
.map_err(|e| {
let err_str = e.to_string_with_variable_formatter(|var| match var {
Variable::Param(i) => format!("{}", &connection.right.expressions[*i]),
_ => var.to_string(),
});
log::trace!("\nCode generation failed for connection:\n {connection}");
let known_args_str = known_args
.iter()
.enumerate()
.filter_map(|(i, b)| b.then_some(connection.right.expressions[i].to_string()))
.join("\n ");
log::trace!("Known arguments:\n {known_args_str}");
log::trace!("Error:\n {err_str}");
let shortened_error = err_str
.lines()
.take(10)
.format("\n ");
format!("Code generation failed: {shortened_error}\nRun with RUST_LOG=trace to see the code generated so far.")
})
}

fn row_range(&self) -> std::ops::Range<i32> {
// We iterate over all rows of the block +/- one row, so that we can also solve for non-rectangular blocks.
-1..self.block_size as i32
}

/// Repeatedly processes all identities on all rows, until no progress is made.
/// Fails iff there are incomplete machine calls in the latch row.
fn solve_block<CanProcess: CanProcessCall<T> + Clone>(
&self,
can_process: CanProcess,
witgen: &mut WitgenInference<'a, T, &Self>,
connection_rhs: &SelectedExpressions<T>,
) -> Result<(), String> {
let mut complete = HashSet::new();
for iteration in 0.. {
let mut progress = false;

for row in self.row_range() {
for id in &self.machine_parts.identities {
if !complete.contains(&(id.id(), row)) {
let result = witgen.process_identity(can_process.clone(), id, row);
if result.complete {
complete.insert((id.id(), row));
}
progress |= result.progress;
}
}
}
if !progress {
log::trace!(
"Finishing block machine witgen code generation after {iteration} iterations"
);
break;
}
}

for (index, expr) in connection_rhs.expressions.iter().enumerate() {
if !witgen.is_known(&Variable::Param(index)) {
return Err(format!(
"Unable to derive algorithm to compute output value \"{expr}\""
));
}
}

if let Err(e) = self.check_block_shape(witgen) {
// Fail hard, as this should never happen for a correctly detected block machine.
log::debug!(
"The following code was generated so far:\n{}",
format_code(witgen.code())
);
panic!("{e}");
}
self.check_incomplete_machine_calls(&complete)?;

Ok(())
}

/// After solving, the known values should be such that we can stack different blocks.
fn check_block_shape(&self, witgen: &mut WitgenInference<'a, T, &Self>) -> Result<(), String> {
let known_columns = witgen
.known_variables()
.iter()
.filter_map(|var| match var {
Variable::Cell(cell) => Some(cell.id),
_ => None,
})
.collect::<BTreeSet<_>>();

let can_stack = known_columns.iter().all(|column_id| {
// Increase the range by 1, because in row <block_size>,
// we might have processed an identity with next references.
let row_range = self.row_range();
let values = (row_range.start..(row_range.end + 1))
.map(|row| {
witgen.value(&Variable::Cell(Cell {
id: *column_id,
row_offset: row,
// Dummy value, the column name is ignored in the implementation
// of Cell::eq, etc.
column_name: "".to_string(),
}))
})
.collect::<Vec<_>>();

// Two values that refer to the same row (modulo block size) are compatible if:
// - One of them is unknown, or
// - Both are concrete and equal
let is_compatible = |v1: Value<T>, v2: Value<T>| match (v1, v2) {
(Value::Unknown, _) | (_, Value::Unknown) => true,
(Value::Concrete(a), Value::Concrete(b)) => a == b,
_ => false,
};
// A column is stackable if all rows equal to each other modulo
// the block size are compatible.
let stackable = (0..(values.len() - self.block_size))
.all(|i| is_compatible(values[i], values[i + self.block_size]));

if !stackable {
let column_name = self.fixed_data.column_name(&PolyID {
id: *column_id,
ptype: PolynomialType::Committed,
});
let block_list = values.iter().skip(1).take(self.block_size).join(", ");
let column_str = format!(
"... {} | {} | {} ...",
values[0],
block_list,
values[self.block_size + 1]
);
log::error!("Column {column_name} is not stackable:\n{column_str}");
}

stackable
});

match can_stack {
true => Ok(()),
false => Err("Block machine shape does not allow stacking".to_string()),
}
}

/// If any machine call could not be completed, that's bad because machine calls typically have side effects.
/// So, the underlying lookup / permutation / bus argument likely does not hold.
/// This function checks that all machine calls are complete, at least for a window of <block_size> rows.
fn check_incomplete_machine_calls(&self, complete: &HashSet<(u64, i32)>) -> Result<(), String> {
let machine_calls = self
.machine_parts
.identities
.iter()
.filter(|id| is_machine_call(id));

let incomplete_machine_calls = machine_calls
.flat_map(|call| {
let complete_rows = self
.row_range()
.filter(|row| complete.contains(&(call.id(), *row)))
.collect::<Vec<_>>();
// Because we process rows -1..block_size+1, it is fine to have two incomplete machine calls,
// as long as <block_size> consecutive rows are complete.
if complete_rows.len() >= self.block_size {
let (min, max) = complete_rows.iter().minmax().into_option().unwrap();
let is_consecutive = max - min == complete_rows.len() as i32 - 1;
if is_consecutive {
return vec![];
}
}
self.row_range()
.filter(|row| !complete.contains(&(call.id(), *row)))
.map(|row| (call, row))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();

if !incomplete_machine_calls.is_empty() {
Err(format!(
"Incomplete machine calls:\n {}",
incomplete_machine_calls
.iter()
.map(|(identity, row)| format!("{identity} (row {row})"))
.join("\n ")
))
} else {
Ok(())
}
}
}

fn is_machine_call<T>(identity: &Identity<T>) -> bool {
match identity {
Identity::Lookup(_)
| Identity::Permutation(_)
| Identity::PhantomLookup(_)
| Identity::PhantomPermutation(_) => true,
// TODO(bus_interaction): Bus interactions are currently ignored,
// so processing them does not succeed. We currently assume that for
// every bus interaction, there is an equivalent (phantom) lookup or
// permutation constraint.
// Returning false here to give JITing a chance to succeed.
Identity::PhantomBusInteraction(_) => false,
Identity::Polynomial(_) | Identity::Connect(_) => false,
}
}

impl<T: FieldElement> FixedEvaluator<T> for &BlockMachineProcessor<'_, T> {
Expand Down Expand Up @@ -315,7 +158,10 @@ mod test {
use crate::witgen::{
data_structures::mutable_state::MutableState,
global_constraints,
jit::{effect::Effect, test_util::read_pil},
jit::{
effect::{format_code, Effect},
test_util::read_pil,
},
machines::{machine_extractor::MachineExtractor, KnownMachine, Machine},
FixedData,
};
Expand Down Expand Up @@ -400,11 +246,11 @@ params[2] = Add::c[0];"
.err()
.unwrap();
assert!(err_str
.contains("Unable to derive algorithm to compute output value \"Unconstrained::c\""));
.contains("The following variables or values are still missing: Unconstrained::c"));
}

#[test]
#[should_panic = "Block machine shape does not allow stacking"]
#[should_panic = "Column NotStackable::a is not stackable in a 1-row block"]
fn not_stackable() {
let input = "
namespace Main(256);
Expand Down
5 changes: 3 additions & 2 deletions executor/src/witgen/jit/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ pub fn compile_effects<T: FieldElement>(

record_start("JIT-compilation");
let start = std::time::Instant::now();
log::trace!("Calling cargo...");
let r = powdr_jit_compiler::call_cargo(&code);
let opt_level = 0;
log::trace!("Compiling the following code using optimization level {opt_level}:\n{code}");
let r = powdr_jit_compiler::call_cargo(&code, Some(opt_level));
log::trace!("Done compiling, took {:.2}s", start.elapsed().as_secs_f32());
record_end("JIT-compilation");
let lib_path = r.map_err(|e| format!("Failed to compile generated code: {e}"))?;
Expand Down
33 changes: 16 additions & 17 deletions executor/src/witgen/jit/effect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,24 @@ pub enum Effect<T: FieldElement, V> {
}

impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
pub fn referenced_variables(&self) -> Box<dyn Iterator<Item = &V> + '_> {
match self {
Effect::Assignment(v, expr) => {
Box::new(iter::once(v).chain(expr.referenced_symbols()).unique())
}
pub fn referenced_variables(&self) -> impl Iterator<Item = &V> {
let iter: Box<dyn Iterator<Item = &V>> = match self {
Effect::Assignment(v, expr) => Box::new(iter::once(v).chain(expr.referenced_symbols())),
Effect::RangeConstraint(v, _) => Box::new(iter::once(v)),
Effect::Assertion(Assertion { lhs, rhs, .. }) => Box::new(
lhs.referenced_symbols()
.chain(rhs.referenced_symbols())
.unique(),
),
Effect::MachineCall(_, _, args) => Box::new(args.iter().unique()),
Effect::Branch(branch_condition, vec, vec1) => Box::new(
iter::once(&branch_condition.variable)
.chain(vec.iter().flat_map(|effect| effect.referenced_variables()))
.chain(vec1.iter().flat_map(|effect| effect.referenced_variables()))
.unique(),
Effect::Assertion(Assertion { lhs, rhs, .. }) => {
Box::new(lhs.referenced_symbols().chain(rhs.referenced_symbols()))
}
Effect::MachineCall(_, _, args) => Box::new(args.iter()),
Effect::Branch(branch_condition, first, second) => Box::new(
iter::once(&branch_condition.variable).chain(
[first, second]
.into_iter()
.flatten()
.flat_map(|effect| effect.referenced_variables()),
),
),
}
};
iter.unique()
}
}

Expand Down
Loading

0 comments on commit cb11fb6

Please sign in to comment.