diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index cec85e55cc..f1d8386d0f 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -16,7 +16,8 @@ pub trait Children { } pub trait AllChildren { - /// Returns an iterator over all direct and indirect children of kind O in this object. + /// Returns an iterator over all direct and indirect children of kind `O` in this object. + /// If `O` and `Self` are the same type, also includes `self`. /// Pre-order visitor. fn all_children(&self) -> Box + '_>; } diff --git a/executor/src/witgen/jit/effect.rs b/executor/src/witgen/jit/effect.rs index b69225c8fd..f193e5aff1 100644 --- a/executor/src/witgen/jit/effect.rs +++ b/executor/src/witgen/jit/effect.rs @@ -28,25 +28,24 @@ pub enum Effect { } impl Effect { - pub fn referenced_variables(&self) -> Box + '_> { - match self { - Effect::Assignment(v, expr) => { - Box::new(iter::once(v).chain(expr.referenced_symbols()).unique()) - } + pub fn referenced_variables(&self) -> impl Iterator { + let iter: Box> = match self { + Effect::Assignment(v, expr) => Box::new(iter::once(v).chain(expr.referenced_symbols())), Effect::RangeConstraint(v, _) => Box::new(iter::once(v)), - Effect::Assertion(Assertion { lhs, rhs, .. }) => Box::new( - lhs.referenced_symbols() - .chain(rhs.referenced_symbols()) - .unique(), - ), - Effect::MachineCall(_, _, args) => Box::new(args.iter().unique()), - Effect::Branch(branch_condition, vec, vec1) => Box::new( - iter::once(&branch_condition.variable) - .chain(vec.iter().flat_map(|effect| effect.referenced_variables())) - .chain(vec1.iter().flat_map(|effect| effect.referenced_variables())) - .unique(), + Effect::Assertion(Assertion { lhs, rhs, .. }) => { + Box::new(lhs.referenced_symbols().chain(rhs.referenced_symbols())) + } + Effect::MachineCall(_, _, args) => Box::new(args.iter()), + Effect::Branch(branch_condition, first, second) => Box::new( + iter::once(&branch_condition.variable).chain( + [first, second] + .into_iter() + .flatten() + .flat_map(|effect| effect.referenced_variables()), + ), ), - } + }; + iter.unique() } } diff --git a/executor/src/witgen/jit/symbolic_expression.rs b/executor/src/witgen/jit/symbolic_expression.rs index 9a85e8cd34..1301d7ad85 100644 --- a/executor/src/witgen/jit/symbolic_expression.rs +++ b/executor/src/witgen/jit/symbolic_expression.rs @@ -1,6 +1,7 @@ +use auto_enums::auto_enum; use itertools::Itertools; use num_traits::Zero; -use powdr_ast::parsed::visitor::Children; +use powdr_ast::parsed::visitor::AllChildren; use powdr_number::FieldElement; use std::hash::Hash; use std::{ @@ -48,22 +49,26 @@ pub enum UnaryOperator { Neg, } -impl Children> for SymbolicExpression { - fn children(&self) -> Box> + '_> { +impl SymbolicExpression { + /// Returns all direct children of this expression. + /// Does specifically not implement the `Children` trait, because it does not go + /// well with recursive types. + #[auto_enum(Iterator)] + fn children(&self) -> impl Iterator> { match self { SymbolicExpression::BinaryOperation(lhs, _, rhs, _) => { - Box::new(iter::once(lhs.as_ref()).chain(iter::once(rhs.as_ref()))) - } - SymbolicExpression::UnaryOperation(_, expr, _) => Box::new(iter::once(expr.as_ref())), - SymbolicExpression::BitOperation(expr, _, _, _) => Box::new(iter::once(expr.as_ref())), - SymbolicExpression::Concrete(_) | SymbolicExpression::Symbol(..) => { - Box::new(iter::empty()) + [lhs.as_ref(), rhs.as_ref()].into_iter() } + SymbolicExpression::UnaryOperation(_, expr, _) + | SymbolicExpression::BitOperation(expr, _, _, _) => iter::once(expr.as_ref()), + SymbolicExpression::Concrete(_) | SymbolicExpression::Symbol(..) => iter::empty(), } } +} - fn children_mut(&mut self) -> Box> + '_> { - unimplemented!() +impl AllChildren> for SymbolicExpression { + fn all_children(&self) -> Box> + '_> { + Box::new(iter::once(self).chain(self.children().flat_map(|e| e.all_children()))) } } @@ -112,15 +117,13 @@ impl SymbolicExpression { } impl SymbolicExpression { - pub fn referenced_symbols(&self) -> Box + '_> { - match self { - SymbolicExpression::Symbol(s, _) => Box::new(iter::once(s)), - _ => Box::new( - self.children() - .flat_map(|c| c.referenced_symbols()) - .unique(), - ), - } + pub fn referenced_symbols(&self) -> impl Iterator { + self.all_children() + .flat_map(|e| match e { + SymbolicExpression::Symbol(s, _) => Some(s), + _ => None, + }) + .unique() } }