Skip to content

Commit

Permalink
use syntactic equality
Browse files Browse the repository at this point in the history
  • Loading branch information
Schaeff committed Nov 17, 2024
1 parent 58ddeaf commit d0a95ff
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 34 deletions.
10 changes: 7 additions & 3 deletions ast/src/analyzed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ pub enum SymbolKind {
Other(),
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
pub enum FunctionValueDefinition {
Array(ArrayExpression<Reference>),
Expression(TypedExpression),
Expand Down Expand Up @@ -1152,7 +1152,9 @@ impl<T> SelectedExpressions<T> {
pub type Expression = parsed::Expression<Reference>;
pub type TypedExpression = crate::parsed::TypedExpression<Reference, u64>;

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash, PartialOrd, Ord,
)]
pub enum Reference {
LocalVar(u64, String),
Poly(PolynomialReference),
Expand Down Expand Up @@ -1567,7 +1569,9 @@ impl<T> From<T> for AlgebraicExpression<T> {
/// Reference to a symbol with optional type arguments.
/// Named `PolynomialReference` for historical reasons, it can reference
/// any symbol.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, PartialOrd, Ord, Hash,
)]
pub struct PolynomialReference {
/// Absolute name of the symbol.
pub name: String,
Expand Down
102 changes: 76 additions & 26 deletions ast/src/parsed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,9 @@ impl Children<Expression> for PilStatement {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum TypeDeclaration<E = u64> {
Enum(EnumDeclaration<E>),
Struct(StructDeclaration<E>),
Expand Down Expand Up @@ -289,7 +291,9 @@ impl<R> Children<Expression<R>> for TypeDeclaration<u64> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct StructDeclaration<E = u64> {
pub name: String,
pub type_vars: TypeBounds,
Expand Down Expand Up @@ -326,7 +330,9 @@ impl<R> Children<Expression<R>> for StructDeclaration<u64> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct EnumDeclaration<E = u64> {
pub name: String,
pub type_vars: TypeBounds,
Expand All @@ -351,7 +357,9 @@ impl<R> Children<Expression<R>> for EnumDeclaration<Expression<R>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct EnumVariant<E = u64> {
pub name: String,
pub fields: Option<Vec<Type<E>>>,
Expand Down Expand Up @@ -452,7 +460,9 @@ impl<R> Children<Expression<R>> for TraitImplementation<Expression<R>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct NamedExpression<Expr> {
pub name: String,
pub body: Expr,
Expand All @@ -476,7 +486,9 @@ impl<R> Children<Expression<R>> for NamedExpression<Arc<Expression<R>>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct TraitDeclaration<E = u64> {
pub name: String,
pub type_vars: Vec<String>,
Expand All @@ -502,7 +514,9 @@ impl<R> Children<Expression<R>> for TraitDeclaration<Expression<R>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct NamedType<E = u64> {
pub name: String,
pub ty: Type<E>,
Expand Down Expand Up @@ -548,7 +562,9 @@ impl<T> Children<Expression<T>> for SelectedExpressions<Expression<T>> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum Expression<Ref = NamespacedPolynomialReference> {
Reference(SourceRef, Ref),
PublicReference(SourceRef, String),
Expand Down Expand Up @@ -641,7 +657,9 @@ impl_source_reference!(
Expression
);

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct UnaryOperation<E = Expression<NamespacedPolynomialReference>> {
pub op: UnaryOperator,
pub expr: Box<E>,
Expand All @@ -663,7 +681,9 @@ impl<E> Children<E> for UnaryOperation<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct BinaryOperation<E = Expression<NamespacedPolynomialReference>> {
pub left: Box<E>,
pub op: BinaryOperator,
Expand All @@ -686,7 +706,9 @@ impl<E> Children<E> for BinaryOperation<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct Number {
#[schemars(skip)]
pub value: BigUint,
Expand Down Expand Up @@ -739,7 +761,9 @@ impl<Ref> Expression<Ref> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct MatchExpression<E = Expression<NamespacedPolynomialReference>> {
pub scrutinee: Box<E>,
pub arms: Vec<MatchArm<E>>,
Expand All @@ -766,7 +790,9 @@ impl<E> Children<E> for MatchExpression<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct BlockExpression<E> {
pub statements: Vec<StatementInsideBlock<E>>,
pub expr: Option<Box<E>>,
Expand Down Expand Up @@ -959,7 +985,9 @@ impl NamespacedPolynomialReference {
}
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct LambdaExpression<E = Expression<NamespacedPolynomialReference>> {
pub kind: FunctionKind,
pub params: Vec<Pattern>,
Expand All @@ -985,15 +1013,17 @@ impl<E> Children<E> for LambdaExpression<E> {
}

#[derive(
Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum FunctionKind {
Pure,
Constr,
Query,
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct ArrayLiteral<E = Expression<NamespacedPolynomialReference>> {
pub items: Vec<E>,
}
Expand Down Expand Up @@ -1156,7 +1186,9 @@ impl BinaryOperator {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct IndexAccess<E = Expression<NamespacedPolynomialReference>> {
pub array: Box<E>,
pub index: Box<E>,
Expand All @@ -1178,7 +1210,9 @@ impl<E> Children<E> for IndexAccess<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct FunctionCall<E = Expression<NamespacedPolynomialReference>> {
pub function: Box<E>,
pub arguments: Vec<E>,
Expand All @@ -1200,7 +1234,9 @@ impl<E> Children<E> for FunctionCall<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct MatchArm<E = Expression<NamespacedPolynomialReference>> {
pub pattern: Pattern,
pub value: E,
Expand All @@ -1216,7 +1252,9 @@ impl<E> Children<E> for MatchArm<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct IfExpression<E = Expression<NamespacedPolynomialReference>> {
pub condition: Box<E>,
pub body: Box<E>,
Expand Down Expand Up @@ -1249,7 +1287,9 @@ impl<E> Children<E> for IfExpression<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct StructExpression<Ref = NamespacedPolynomialReference> {
pub name: Ref,
pub fields: Vec<NamedExpression<Box<Expression<Ref>>>>,
Expand All @@ -1271,7 +1311,9 @@ impl<R> Children<Expression<R>> for StructExpression<R> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum StatementInsideBlock<E = Expression<NamespacedPolynomialReference>> {
// TODO add a source ref here.
LetStatement(LetStatementInsideBlock<E>),
Expand All @@ -1294,7 +1336,9 @@ impl<E> Children<E> for StatementInsideBlock<E> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct LetStatementInsideBlock<E = Expression<NamespacedPolynomialReference>> {
pub pattern: Pattern,
pub ty: Option<Type<u64>>,
Expand Down Expand Up @@ -1352,7 +1396,9 @@ impl Children<Expression> for FunctionDefinition {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum ArrayExpression<Ref = NamespacedPolynomialReference> {
Value(Vec<Expression<Ref>>),
RepeatedValue(Vec<Expression<Ref>>),
Expand Down Expand Up @@ -1511,7 +1557,9 @@ impl<Ref> Children<Expression<Ref>> for ArrayExpression<Ref> {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub enum Pattern {
CatchAll(SourceRef), // "_", matches a single value
Ellipsis(SourceRef), // "..", matches a series of values, only valid inside array patterns
Expand Down Expand Up @@ -1604,7 +1652,9 @@ impl SourceReference for Pattern {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct TypedExpression<Ref = NamespacedPolynomialReference, E = Expression<Ref>> {
pub e: Expression<Ref>,
pub type_scheme: Option<TypeScheme<E>>,
Expand Down
6 changes: 4 additions & 2 deletions ast/src/parsed/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,9 @@ impl From<FunctionType> for Type {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema, Hash,
)]
pub struct TypeScheme<E = u64> {
/// Type variables and their trait bounds.
pub vars: TypeBounds,
Expand Down Expand Up @@ -481,7 +483,7 @@ impl From<Type> for TypeScheme {
}

#[derive(
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Default, Serialize, Deserialize, JsonSchema,
Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Default, Serialize, Deserialize, JsonSchema, Hash,
)]
// TODO bounds should be SymbolPaths in the future.
pub struct TypeBounds(Vec<(String, BTreeSet<String>)>);
Expand Down
5 changes: 5 additions & 0 deletions parser-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

use std::{
fmt::{self, Debug, Formatter},
hash::Hash,
sync::Arc,
};

Expand Down Expand Up @@ -38,6 +39,10 @@ impl PartialOrd for SourceRef {
}
}

impl Hash for SourceRef {
fn hash<H: std::hash::Hasher>(&self, _: &mut H) {}
}

impl SourceRef {
pub fn unknown() -> Self {
Default::default()
Expand Down
6 changes: 3 additions & 3 deletions pilopt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,18 +218,18 @@ fn constant_value(function: &FunctionValueDefinition) -> Option<BigUint> {
}

/// Deduplicate fixed columns of the same namespace which share the same value.
/// This uses the `Display` implementation of the function value, so `|i| i` is different from `|j| j`
/// This compares the function values, so `|i| i` is different from `|j| j`
/// This is enough for use cases where exactly the same function is inserted many times
/// This only replaces the references inside expressions and does not clean up the now unreachable fixed column definitions
fn deduplicate_fixed_columns<T: FieldElement>(pil_file: &mut Analyzed<T>) {
// build a map of `poly_id` to the `(name, poly_id)` they can be replaced by
let replacement_map: BTreeMap<PolyID, (String, PolyID)> = pil_file
.constant_polys_in_source_order()
// group symbols by common namespace and displayed value
// group symbols by common namespace and function value
.into_group_map_by(|(symbol, value)| {
(
symbol.absolute_name.split("::").next().unwrap(),
value.as_ref().unwrap().to_string(),
value.as_ref().unwrap(),
)
})
.values()
Expand Down

0 comments on commit d0a95ff

Please sign in to comment.