Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into jit-interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
pacheco committed Jan 8, 2025
2 parents 0a4475f + 067b633 commit d3554f1
Show file tree
Hide file tree
Showing 39 changed files with 932 additions and 577 deletions.
49 changes: 30 additions & 19 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::parsed::{
TraitDeclaration, TraitImplementation, TypeDeclaration,
};

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
pub enum StatementIdentifier {
/// Either an intermediate column or a definition.
Definition(String),
Expand Down Expand Up @@ -685,7 +685,7 @@ impl DegreeRange {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct Symbol {
pub id: u64,
pub source: SourceRef,
Expand Down Expand Up @@ -745,7 +745,7 @@ impl Symbol {
/// The "kind" of a symbol. In the future, this will be mostly
/// replaced by its type.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum SymbolKind {
/// Fixed, witness or intermediate polynomial
Expand Down Expand Up @@ -815,7 +815,7 @@ impl Children<Expression> for NamedType {
}
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PublicDeclaration {
pub id: u64,
pub source: SourceRef,
Expand All @@ -835,7 +835,9 @@ impl PublicDeclaration {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct SelectedExpressions<T> {
pub selector: AlgebraicExpression<T>,
pub expressions: Vec<AlgebraicExpression<T>>,
Expand All @@ -861,7 +863,7 @@ impl<T> Children<AlgebraicExpression<T>> for SelectedExpressions<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PolynomialIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -878,7 +880,7 @@ impl<T> Children<AlgebraicExpression<T>> for PolynomialIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct LookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -900,7 +902,7 @@ impl<T> Children<AlgebraicExpression<T>> for LookupIdentity<T> {
///
/// This identity is used as a replacement for a lookup identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomLookupIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -929,7 +931,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomLookupIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -951,7 +953,7 @@ impl<T> Children<AlgebraicExpression<T>> for PermutationIdentity<T> {
///
/// This identity is used as a replacement for a permutation identity which has been turned into challenge-based polynomial identities.
/// This is ignored by the backend.
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomPermutationIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -969,7 +971,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomPermutationIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct ConnectIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand All @@ -987,7 +989,9 @@ impl<T> Children<AlgebraicExpression<T>> for ConnectIdentity<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord)]
#[derive(
Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, PartialOrd, Ord, Hash,
)]
pub struct ExpressionList<T>(pub Vec<AlgebraicExpression<T>>);

impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
Expand All @@ -999,7 +1003,7 @@ impl<T> Children<AlgebraicExpression<T>> for ExpressionList<T> {
}
}

#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize, JsonSchema, Hash)]
pub struct PhantomBusInteractionIdentity<T> {
// The ID is globally unique among identities.
pub id: u64,
Expand Down Expand Up @@ -1034,6 +1038,7 @@ impl<T> Children<AlgebraicExpression<T>> for PhantomBusInteractionIdentity<T> {
Serialize,
Deserialize,
JsonSchema,
Hash,
derive_more::Display,
derive_more::From,
derive_more::TryInto,
Expand Down Expand Up @@ -1235,7 +1240,9 @@ impl Hash for AlgebraicReference {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicExpression<T> {
Reference(AlgebraicReference),
PublicReference(String),
Expand All @@ -1245,7 +1252,9 @@ pub enum AlgebraicExpression<T> {
UnaryOperation(AlgebraicUnaryOperation<T>),
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicBinaryOperation<T> {
pub left: Box<AlgebraicExpression<T>>,
pub op: AlgebraicBinaryOperator,
Expand All @@ -1271,7 +1280,9 @@ impl<T> From<AlgebraicBinaryOperation<T>> for AlgebraicExpression<T> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct AlgebraicUnaryOperation<T> {
pub op: AlgebraicUnaryOperator,
pub expr: Box<AlgebraicExpression<T>>,
Expand Down Expand Up @@ -1468,7 +1479,7 @@ impl<T> AlgebraicExpression<T> {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct Challenge {
/// Challenge ID
Expand All @@ -1477,7 +1488,7 @@ pub struct Challenge {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicBinaryOperator {
Add,
Expand Down Expand Up @@ -1515,7 +1526,7 @@ impl TryFrom<BinaryOperator> for AlgebraicBinaryOperator {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum AlgebraicUnaryOperator {
Minus,
Expand Down
4 changes: 3 additions & 1 deletion ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,9 @@ impl<R> Children<Expression<R>> for EnumVariant<Expression<R>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct TraitImplementation<Expr> {
pub name: SymbolPath,
pub source_ref: SourceRef,
Expand Down
10 changes: 4 additions & 6 deletions backend/src/halo2/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use powdr_ast::analyzed::{
AlgebraicExpression, AlgebraicReferenceThin, Identity, PolynomialIdentity, PolynomialType,
SelectedExpressions,
};
use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, GlobalValues, TraceValues};
use powdr_executor_utils::expression_evaluator::{ExpressionEvaluator, TerminalAccess};
use powdr_number::FieldElement;

const FIRST_STEP_NAME: &str = "__first_step";
Expand Down Expand Up @@ -553,8 +553,8 @@ impl<'a, F: PrimeField<Repr = [u8; 32]>> Data<'a, '_, '_, F> {
fn evaluator<T: FieldElement>(
&self,
intermediate_definitions: &'a BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<T>>,
) -> ExpressionEvaluator<'a, T, Expression<F>, &Self, &Self> {
ExpressionEvaluator::new_with_custom_expr(self, self, intermediate_definitions, |n| {
) -> ExpressionEvaluator<'a, T, Expression<F>, &Self> {
ExpressionEvaluator::new_with_custom_expr(self, intermediate_definitions, |n| {
Expression::Constant(convert_field(*n))
})
}
Expand All @@ -564,7 +564,7 @@ impl<'a, F: PrimeField<Repr = [u8; 32]>> Data<'a, '_, '_, F> {
}
}

impl<F: Field> TraceValues<Expression<F>> for &Data<'_, '_, '_, F> {
impl<F: Field> TerminalAccess<Expression<F>> for &Data<'_, '_, '_, F> {
fn get(&self, poly_ref: &powdr_ast::analyzed::AlgebraicReference) -> Expression<F> {
let rotation = match poly_ref.next {
false => Rotation::cur(),
Expand All @@ -578,9 +578,7 @@ impl<F: Field> TraceValues<Expression<F>> for &Data<'_, '_, '_, F> {
panic!("Unknown reference: {}", poly_ref.name)
}
}
}

impl<F: Field> GlobalValues<Expression<F>> for &Data<'_, '_, '_, F> {
fn get_public(&self, _public: &str) -> Expression<F> {
unimplemented!()
}
Expand Down
2 changes: 1 addition & 1 deletion backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub const DEFAULT_ESTARK_OPTIONS: &str = "stark_gl";
impl BackendType {
pub fn factory<T: FieldElement>(&self) -> Box<dyn BackendFactory<T>> {
match self {
BackendType::Mock => Box::new(mock::MockBackendFactory::new()),
BackendType::Mock => Box::new(mock::MockBackendFactory),
#[cfg(feature = "halo2")]
BackendType::Halo2 => Box::new(halo2::Halo2ProverFactory),
#[cfg(feature = "halo2")]
Expand Down
28 changes: 4 additions & 24 deletions backend/src/mock/connection_constraint_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use std::ops::ControlFlow;

use itertools::Itertools;
use powdr_ast::analyzed::AlgebraicExpression;
use powdr_ast::analyzed::AlgebraicReference;
use powdr_ast::analyzed::Analyzed;
use powdr_ast::analyzed::{
Identity, LookupIdentity, PermutationIdentity, PhantomLookupIdentity,
Expand All @@ -15,8 +14,7 @@ use powdr_ast::parsed::visitor::ExpressionVisitable;
use powdr_ast::parsed::visitor::VisitOrder;
use powdr_backend_utils::referenced_namespaces_algebraic_expression;
use powdr_executor_utils::expression_evaluator::ExpressionEvaluator;
use powdr_executor_utils::expression_evaluator::OwnedGlobalValues;
use powdr_executor_utils::expression_evaluator::TraceValues;
use powdr_executor_utils::expression_evaluator::TerminalAccess;
use powdr_number::FieldElement;
use rayon::iter::IntoParallelIterator;
use rayon::iter::ParallelIterator;
Expand Down Expand Up @@ -152,25 +150,17 @@ impl<F: FieldElement> Connection<F> {
pub struct ConnectionConstraintChecker<'a, F: FieldElement> {
connections: &'a [Connection<F>],
machines: BTreeMap<String, Machine<'a, F>>,
global_values: OwnedGlobalValues<F>,
}

impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> {
/// Creates a new connection constraint checker.
pub fn new(
connections: &'a [Connection<F>],
machines: BTreeMap<String, Machine<'a, F>>,
challenges: &'a BTreeMap<u64, F>,
) -> Self {
let global_values = OwnedGlobalValues {
// TODO: Support publics.
public_values: BTreeMap::new(),
challenge_values: challenges.clone(),
};
Self {
connections,
machines,
global_values,
}
}
}
Expand Down Expand Up @@ -276,8 +266,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> {
.into_par_iter()
.filter_map(|row| {
let mut evaluator = ExpressionEvaluator::new(
machine.trace_values.row(row),
&self.global_values,
machine.values.row(row),
&machine.intermediate_definitions,
);
let result = evaluator.evaluate(&selected_expressions.selector);
Expand All @@ -300,9 +289,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> {
None => {
let empty_variables = EmptyVariables {};
let empty_definitions = BTreeMap::new();
let empty_globals = OwnedGlobalValues::default();
let mut evaluator =
ExpressionEvaluator::new(empty_variables, &empty_globals, &empty_definitions);
let mut evaluator = ExpressionEvaluator::new(empty_variables, &empty_definitions);
let selector_value: F = evaluator.evaluate(&selected_expressions.selector);

match selector_value.to_degree() {
Expand Down Expand Up @@ -339,14 +326,7 @@ impl<'a, F: FieldElement> ConnectionConstraintChecker<'a, F> {

struct EmptyVariables;

impl<T> TraceValues<T> for EmptyVariables
where
T: FieldElement,
{
fn get(&self, _reference: &AlgebraicReference) -> T {
panic!()
}
}
impl<T: FieldElement> TerminalAccess<T> for EmptyVariables {}

/// Converts a slice to a multi-set, represented as a map from elements to their count.
fn to_multi_set<T: Ord>(a: &[T]) -> BTreeMap<&T, usize> {
Expand Down
10 changes: 6 additions & 4 deletions backend/src/mock/machine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use itertools::Itertools;
use powdr_ast::analyzed::{AlgebraicExpression, AlgebraicReferenceThin, Analyzed};
use powdr_backend_utils::{machine_fixed_columns, machine_witness_columns};
use powdr_executor::constant_evaluator::VariablySizedColumn;
use powdr_executor_utils::{expression_evaluator::OwnedTraceValues, WitgenCallback};
use powdr_executor_utils::{expression_evaluator::OwnedTerminalValues, WitgenCallback};
use powdr_number::{DegreeType, FieldElement};

/// A collection of columns with self-contained constraints.
pub struct Machine<'a, F> {
pub machine_name: String,
pub size: usize,
pub trace_values: OwnedTraceValues<F>,
pub values: OwnedTerminalValues<F>,
pub pil: &'a Analyzed<F>,
pub intermediate_definitions: BTreeMap<AlgebraicReferenceThin, AlgebraicExpression<F>>,
}
Expand Down Expand Up @@ -55,12 +55,14 @@ impl<'a, F: FieldElement> Machine<'a, F> {

let intermediate_definitions = pil.intermediate_definitions();

let trace_values = OwnedTraceValues::new(pil, witness, fixed);
// TODO: Supports publics.
let values =
OwnedTerminalValues::new(pil, witness, fixed).with_challenges(challenges.clone());

Some(Self {
machine_name,
size,
trace_values,
values,
pil,
intermediate_definitions,
})
Expand Down
Loading

0 comments on commit d3554f1

Please sign in to comment.