
Stack overflow from recursive function in DynamicExpressions.jl #428

Open
MilesCranmer opened this issue Dec 19, 2024 · 6 comments
Labels: enhancement (error messages), high priority

Comments

@MilesCranmer

Hey all,

Thanks for working on this! I'm trying out Mooncake 0.4.65 on DynamicExpressions.jl 1.8.0 (via DifferentiationInterface 0.6.27) and ran into a stack overflow from this example:

import Mooncake
using DynamicExpressions
using DifferentiationInterface

# Build up expression:
operators = OperatorEnum(; binary_operators=(+, -, *, /), unary_operators=(cos, sin))
variable_names = ["x1", "x2", "x3"]
x1, x2, x3 = map(i -> Expression(Node{Float64}(; feature=i); operators, variable_names), 1:3)
f = x1 + cos(x2 - 0.2)

eval_sum = let f = f
    X -> sum(f(X)[1])
end
backend = AutoMooncake(; config=nothing)

# Example data
X = randn(3, 100)
dX = gradient(f, backend, X)  # the trace below is from passing `f` directly; `eval_sum` is the scalar-valued wrapper

This hits the following error:

ERROR: LoadError: MooncakeRuleCompilationError: an error occured while Mooncake was compiling a rule to differentiate something. If the `caused by` error message below does not make it clear to you how the problem can be fixed, please open an issue at github.com/compintell/Mooncake.jl describing your problem.
To replicate this error run the following:

Mooncake.build_rrule(Mooncake.MooncakeInterpreter(), Tuple{Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, Matrix{Float64}}; debug_mode=false)

Note that you may need to `using` some additional packages if not all of the names printed in the above signature are available currently in your environment.

Stacktrace:
 [1] build_rrule(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
   @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1074
 [2] build_rrule
   @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1017 [inlined]
 [3] prepare_pullback_cache(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
   @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:191
 [4] prepare_pullback_cache
   @ ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:185 [inlined]
 [5] prepare_pullback(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64}, ::Tuple{Bool})
   @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/gjT8p/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:10
 [6] prepare_gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/gradient.jl:70
 [7] gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
   @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/fallbacks/no_prep.jl:48
 [8] top-level scope
   @ ~/PermaDocuments/SymbolicRegressionMonorepo/DynamicExpressions.jl/test/test_mooncake.jl:29
in expression starting at /Users/mcranmer/PermaDocuments/SymbolicRegressionMonorepo/DynamicExpressions.jl/test/test_mooncake.jl:29

caused by: StackOverflowError:
Stacktrace:
     [1] _stable_typeof
       @ ./operators.jl:929 [inlined]
     [2] Base.Fix1(f::typeof(Mooncake.tangent_field_type), x::Type)
       @ Base ./operators.jl:1123
     [3] tangent_field_types
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:445 [inlined]
     [4] #s11#44
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:437 [inlined]
     [5] var"#s11#44"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
     [6] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
     [7] tangent_field_type(::Type{Node{Float64}}, n::Int64)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:461
     [8] Fix1
       @ ./operators.jl:1127 [inlined]
     [9] tuple_map(f::Base.Fix1{typeof(Mooncake.tangent_field_type), Type{Node{Float64}}}, x::NTuple{7, Int64})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/utils.jl:41
--- the above 7 lines are repeated 5435 more times ---
 [38055] tangent_field_types
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:445 [inlined]
 [38056] #s11#44
       @ ~/.julia/packages/Mooncake/N9iX9/src/tangents.jl:437 [inlined]
 [38057] var"#s11#44"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
 [38058] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
 [38059] #s11#110
       @ ~/.julia/packages/Mooncake/N9iX9/src/fwds_rvs_data.jl:560 [inlined]
 [38060] var"#s11#110"(P::Any, ::Any, ::Any)
       @ Mooncake ./none:0
 [38061] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
       @ Core ./boot.jl:707
 [38062] lazy_zero_rdata_type(::Type{Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/fwds_rvs_data.jl:751
 [38063] call_composed
       @ ./operators.jl:1053 [inlined]
 [38064] (::ComposedFunction{typeof(Mooncake.lazy_zero_rdata_type), typeof(Mooncake._type)})(x::Type; kw::@Kwargs{})
       @ Base ./operators.jl:1050
 [38065] iterate
       @ ./generator.jl:48 [inlined]
 [38066] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, ComposedFunction{typeof(Mooncake.lazy_zero_rdata_type), typeof(Mooncake._type)}}, ::Base.EltypeUnknown, isz::Base.HasShape{1})
       @ Base ./array.jl:811
 [38067] collect_similar
       @ ./array.jl:720 [inlined]
 [38068] map
       @ ./abstractarray.jl:3371 [inlined]
 [38069] Mooncake.ADInfo(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, ir::Mooncake.BBCode, debug_mode::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:175
 [38070] generate_ir(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, do_inline::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1104
 [38071] generate_ir
       @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1087 [inlined]
 [38072] build_rrule(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1050
 [38073] build_rrule
       @ ~/.julia/packages/Mooncake/N9iX9/src/interpreter/s2s_reverse_mode_ad.jl:1017 [inlined]
 [38074] prepare_pullback_cache(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::Vararg{Any}; kwargs::@Kwargs{debug_mode::Bool, silence_debug_messages::Bool})
       @ Mooncake ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:191
 [38075] prepare_pullback_cache
       @ ~/.julia/packages/Mooncake/N9iX9/src/interface.jl:185 [inlined]
 [38076] prepare_pullback(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64}, ::Tuple{Bool})
       @ DifferentiationInterfaceMooncakeExt ~/.julia/packages/DifferentiationInterface/gjT8p/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:10
 [38077] prepare_gradient(::Expression{Float64, Node{Float64}, @NamedTuple{operators::OperatorEnum{Tuple{typeof(+), typeof(-), typeof(*), typeof(/)}, Tuple{typeof(cos), typeof(sin)}}, variable_names::Vector{String}}}, ::AutoMooncake{Nothing}, ::Matrix{Float64})
       @ DifferentiationInterface ~/.julia/packages/DifferentiationInterface/gjT8p/src/first_order/gradient.jl:70

I'm assuming this comes from the recursive evaluation in DynamicExpressions.jl, which branches from here: https://github.com/SymbolicML/DynamicExpressions.jl/blob/dde92915df3ed275989e53d3691fd7f9280d9b14/src/Evaluate.jl#L242-L269. I think it should be possible to make this work, since Enzyme.jl can now differentiate it. Zygote.jl can't, however, because the code mutates arrays.

@willtebbutt
Member

willtebbutt commented Dec 19, 2024

Hi Miles. Thanks for trying out Mooncake!

The only potential source of stack overflows in Mooncake that I'm currently aware of is asking for the tangent_type of a type whose name appears in its own definition. For example, something like:

julia> using Mooncake

julia> struct Foo
           x::Union{Foo, Nothing}
       end

julia> tangent_type(Foo)
ERROR: StackOverflowError:
Stacktrace:
 [1] tangent_type(::Type{Foo})
   @ Mooncake ./none:0
 [2] macro expansion
   @ ./none:0 [inlined]
 [3] tangent_type(::Type{Union{Nothing, Foo}})
   @ Mooncake ./none:0

It looks to me like your stack overflow is happening during a tangent_type call, so before I dig into your issue further, could you confirm whether, e.g., your Expression or Node types have this property?
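A rough way to check for this property without Mooncake is to walk a type's field types (splitting Unions into their members) and see whether the queried type shows up again. This is a plain-Julia sketch: `is_self_referential`, `_members` and `_reaches` are hypothetical helpers, not part of Mooncake's API, and they only report cycles that pass back through the queried type.

```julia
# Hypothetical helper, not part of Mooncake: does `T` reappear among the field
# types reachable from `T`? If so, deriving a tangent type field by field never ends.
is_self_referential(T::Type) = _reaches(T, T, Set{Any}())

# Split Union{A, B} into its members; any other type is returned as a 1-tuple.
_members(U::Union) = Base.uniontypes(U)
_members(T) = (T,)

function _reaches(target, T, seen)
    for F in _members(T)
        # Only concrete types have a definite list of fields; skip ones already visited.
        (isconcretetype(F) && !(F in seen)) || continue
        push!(seen, F)
        for FT in fieldtypes(F), M in _members(FT)
            M === target && return true
            _reaches(target, M, seen) && return true
        end
    end
    return false
end

# The example from above, plus a non-recursive struct for comparison.
struct Foo
    x::Union{Foo, Nothing}
end

struct Bar
    a::Float64
    b::Vector{Float64}
end
```

With these definitions, `is_self_referential(Foo)` returns `true` and `is_self_referential(Bar)` returns `false`.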

@MilesCranmer
Author

Thanks! Yes, Node is recursive: https://ai.damtp.cam.ac.uk/dynamicexpressions/dev/api/#Nodes. It’s a binary tree structure. Are there any workarounds?
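A heavily simplified stand-in for this kind of node shows both halves of the situation: the children `l` and `r` have the node's own type, and evaluation walks the tree recursively. `ToyNode`, `leaf`, `const_leaf`, `unary`, `binary` and `evaluate` are hypothetical names, only loosely modelled on DynamicExpressions' `Node`; the real fields and evaluator differ.

```julia
# Hypothetical, simplified stand-in loosely modelled on DynamicExpressions' Node{T}.
mutable struct ToyNode{T}
    degree::Int        # 0 = leaf, 1 = unary operator, 2 = binary operator
    val::T             # constant value (leaf with feature == 0)
    feature::Int       # input row to read (leaf with feature > 0)
    op::Function       # operator applied by degree-1 and degree-2 nodes
    l::ToyNode{T}      # children share the node's own type: the type is recursive
    r::ToyNode{T}
    ToyNode{T}() where {T} = new{T}()   # all fields start uninitialized
end

function const_leaf(val::T) where {T}
    n = ToyNode{T}(); n.degree = 0; n.feature = 0; n.val = val
    return n
end
function leaf(::Type{T}, feature::Int) where {T}
    n = ToyNode{T}(); n.degree = 0; n.feature = feature; n.val = zero(T)
    return n
end
function unary(op, c::ToyNode{T}) where {T}
    n = ToyNode{T}(); n.degree = 1; n.op = op; n.l = c
    return n
end
function binary(op, a::ToyNode{T}, b::ToyNode{T}) where {T}
    n = ToyNode{T}(); n.degree = 2; n.op = op; n.l = a; n.r = b
    return n
end

# Recursive evaluation over all columns of X (rows are features).
function evaluate(n::ToyNode{T}, X::AbstractMatrix{T}) where {T}
    if n.degree == 0
        return n.feature > 0 ? X[n.feature, :] : fill(n.val, size(X, 2))
    elseif n.degree == 1
        return n.op.(evaluate(n.l, X))
    else
        return n.op.(evaluate(n.l, X), evaluate(n.r, X))
    end
end

# x1 + cos(x2 - 0.2), as in the original example, on a 3×2 input
x1, x2 = leaf(Float64, 1), leaf(Float64, 2)
f = binary(+, x1, unary(cos, binary(-, x2, const_leaf(0.2))))
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
y = evaluate(f, X)
```

`ToyNode{Float64}` appears among its own `fieldtypes`, which is exactly the property described above.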

@willtebbutt
Member

Ah, damn. Sadly, there isn't a workaround at the minute.

Short explanation: Mooncake derives a "tangent type" for each Julia type it encounters, and it does this recursively. For a struct's "primal" type, it produces something of the form Tangent{NamedTuple{fieldnames, tangent_types_of_fields}}, where tangent_types_of_fields is a Tuple containing the result of tangent_type for each of the fields of the original struct. This is where the problem arises: the recursion never bottoms out if the type's name appears in its own field types.
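A toy model of that derivation makes the failure mode concrete. `ToyTangent`, `ToyNoTangent` and `toy_tangent_type` are hypothetical stand-ins, not Mooncake's internals. The one addition is a stack of the types currently being derived, used here to throw a readable error at the point where the unguarded recursion would overflow.

```julia
# Toy model of a structural tangent-type derivation; names are hypothetical.
struct ToyTangent{Tfields<:NamedTuple}
    fields::Tfields
end
struct ToyNoTangent end   # tangent of non-differentiable data such as Int

# Entry point: start with an empty stack of types currently being derived.
toy_tangent_type(P::Type) = _toy_tangent_type(P, ())

_toy_tangent_type(::Type{Float64}, stack) = Float64
_toy_tangent_type(::Type{Int}, stack) = ToyNoTangent

function _toy_tangent_type(P::Type, stack)
    # Without this check, a type whose own name appears in its field types would
    # send the recursion below round in circles until the stack overflows.
    P in stack && throw(ArgumentError(
        "$P appears in its own field types, so no structural tangent type can be derived"))
    (isconcretetype(P) && Base.isstructtype(P)) || throw(ArgumentError("no toy rule for $P"))
    inner = (stack..., P)
    field_tangents = map(F -> _toy_tangent_type(F, inner), fieldtypes(P))
    return ToyTangent{NamedTuple{fieldnames(P), Tuple{field_tangents...}}}
end

struct Point              # ordinary struct: the derivation terminates
    x::Float64
    n::Int
end

mutable struct SelfLoop   # mentions itself: the derivation hits the guard
    next::SelfLoop
    SelfLoop() = new()
end
```

Here `toy_tangent_type(Point)` is `ToyTangent{NamedTuple{(:x, :n), Tuple{Float64, ToyNoTangent}}}`, while `toy_tangent_type(SelfLoop)` throws an `ArgumentError` instead of overflowing the stack.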

Enzyme circumvents this problem entirely by using the primal type as its own tangent type.
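To illustrate that idea (this is not Enzyme's actual machinery), here is a toy in which the "shadow" that accumulates derivatives is simply another value of the primal type: a structurally identical tree with zeroed numbers. Because no separate tangent type is computed from the field types, the self-reference never causes type-level recursion. `ToyTree`, `zero_shadow`, `total` and `accumulate_total_grad!` are hypothetical names.

```julia
# Hypothetical toy tree; the type refers to itself through `children`.
mutable struct ToyTree
    val::Float64
    children::Vector{ToyTree}
end
ToyTree(val::Float64) = ToyTree(val, ToyTree[])

# Shadow with the same type as the primal: an identically shaped tree of zeros.
zero_shadow(t::ToyTree) = ToyTree(0.0, map(zero_shadow, t.children))

# Forward pass: sum of every `val` in the tree.
total(t::ToyTree) = t.val + sum(total, t.children; init=0.0)

# Reverse pass: d(total)/d(val) = 1 for every node, so add dy into each matching shadow node.
function accumulate_total_grad!(shadow::ToyTree, t::ToyTree, dy::Float64=1.0)
    shadow.val += dy
    for (s, c) in zip(shadow.children, t.children)
        accumulate_total_grad!(s, c, dy)
    end
    return shadow
end

t = ToyTree(1.0, [ToyTree(2.0), ToyTree(3.0)])
dt = accumulate_total_grad!(zero_shadow(t), t)
```

`dt` has type `ToyTree`, the same as `t`, and every node of it holds the derivative `1.0`.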

In the short term I can probably improve the error message to make it so that future users do not have to open an issue about this. In the medium term we should be able to add a macro that makes it a one-line fix to make this work correctly.

I'm going to label this issue as a "should have given a better error" issue for now.

@willtebbutt added the enhancement (error messages) label Dec 19, 2024
@yebai
Contributor

yebai commented Dec 19, 2024

> In the medium term we should be able to add a macro that makes it a one-line fix to make this work correctly.

@willtebbutt let's help implement this (or a simpler version if it involves lots of work), so @MilesCranmer can run Mooncake with DynamicExpressions.jl

@MilesCranmer
Author

MilesCranmer commented Jan 13, 2025

> Short explanation: Mooncake derives a "tangent type" for each Julia type it encounters, and it does this recursively. For a struct's "primal" type, it produces something of the form Tangent{NamedTuple{fieldnames, tangent_types_of_fields}}, where tangent_types_of_fields is a Tuple containing the result of tangent_type for each of the fields of the original struct. This is where the problem arises: this winds up being recursive if the name of the type appears in the type.

P.S. I wonder whether another, more structural solution could make sense here. Perhaps you could have a special, lazily expanded type for this scenario:

# (Existing)
struct Tangent{Tfields<:NamedTuple}
    fields::Tfields
end

# (Add)
struct SelfTangent{T} end

where T is the type itself. Then you could keep a stack of parent types while generating the Tangent types and, if any of them matches, return a SelfTangent instead. For example, in

@generated function build_tangent(::Type{P}, fields::Vararg{Any,N}) where {P,N}
    tangent_values_exprs = map(enumerate(tangent_field_types(P))) do (n, tt)
        tt <: PossiblyUninitTangent && return n <= N ? :($tt(fields[$n])) : :($tt())
        return :(fields[$n])
    end
    tuple_expr = Expr(:tuple, tangent_values_exprs...)
    return Expr(:call, tangent_type(P), Expr(:call, NamedTuple{fieldnames(P)}, tuple_expr))
end

you could add an extra argument, parents, that records the stack of parent types, and then trigger on any self-reference. Since a Vararg argument has to come last, parents would need to go before fields; passing it as a tuple type such as Tuple{A, B} keeps the membership check simple:

- @generated function build_tangent(::Type{P}, fields::Vararg{Any,N}) where {P,N}
+ @generated function build_tangent(::Type{P}, ::Type{PARENTS}, fields::Vararg{Any,N}) where {P,PARENTS<:Tuple,N}
+     P in fieldtypes(PARENTS) && return :($(SelfTangent){$P})

Whenever you descend into a SelfTangent, you would implicitly already know the types of its fields. Does this kinda make sense? I have no idea how hard it would be to get this working, though.
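A plain-Julia sketch of this proposal, assuming the parent stack is threaded through a tangent-type derivation rather than through build_tangent itself. `ToyTangent`, `ToyNoTangent`, `lazy_tangent_type` and `expand` are hypothetical, `SelfTangent` follows the name suggested above, and none of this is Mooncake code.

```julia
# Hypothetical toy types mirroring the proposal above; not Mooncake's implementation.
struct ToyTangent{Tfields<:NamedTuple}
    fields::Tfields
end
struct ToyNoTangent end
struct SelfTangent{T} end   # stands for "the tangent type of T", expanded lazily

lazy_tangent_type(P::Type) = _lazy(P, ())

_lazy(::Type{Float64}, parents) = Float64
_lazy(::Type{Int}, parents) = ToyNoTangent
_lazy(::Type{Nothing}, parents) = ToyNoTangent
# Optional fields such as Union{Nothing, T} are handled member by member.
_lazy(U::Union, parents) = Union{map(T -> _lazy(T, parents), Base.uniontypes(U))...}

function _lazy(P::Type, parents)
    # Self-reference: return a placeholder instead of recursing forever.
    P in parents && return SelfTangent{P}
    inner = (parents..., P)
    field_tangents = map(F -> _lazy(F, inner), fieldtypes(P))
    return ToyTangent{NamedTuple{fieldnames(P), Tuple{field_tangents...}}}
end

# Descending into a SelfTangent{T} simply re-derives T's tangent type on demand.
expand(::Type{SelfTangent{T}}) where {T} = lazy_tangent_type(T)

# A linked list whose optional tail has its own type.
mutable struct Link
    val::Float64
    next::Union{Link, Nothing}
end
```

`lazy_tangent_type(Link)` now terminates, recording the tangent of the `next` field as `Union{ToyNoTangent, SelfTangent{Link}}`; `expand(SelfTangent{Link})` returns the same result as `lazy_tangent_type(Link)`.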


Perhaps yet another option could be to take a trait-based approach to the dispatch of the tangent_type function.
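One way such a trait could look, as a hypothetical sketch rather than anything Mooncake provides: a `tangent_style` trait defaults to the structural derivation, and a package with recursive types opts in, with a single method definition, to reusing the primal type as its own tangent (the Enzyme-style approach mentioned earlier in the thread). `TangentStyle`, `StructuralStyle`, `SelfShadowStyle`, `tangent_style`, `trait_tangent_type` and `ToyTangent` are all assumed names.

```julia
# Hypothetical trait-based dispatch for tangent types; not Mooncake's API.
abstract type TangentStyle end
struct StructuralStyle <: TangentStyle end  # derive ToyTangent{NamedTuple{...}} from the fields
struct SelfShadowStyle <: TangentStyle end  # reuse the primal type as its own tangent type

# Default trait: every type gets the structural derivation.
tangent_style(::Type) = StructuralStyle()

struct ToyTangent{Tfields<:NamedTuple}
    fields::Tfields
end

# Dispatch first on the trait, then on the type.
trait_tangent_type(P::Type) = trait_tangent_type(tangent_style(P), P)
trait_tangent_type(::SelfShadowStyle, P::Type) = P
trait_tangent_type(::StructuralStyle, ::Type{Float64}) = Float64
function trait_tangent_type(::StructuralStyle, P::Type)
    field_tangents = map(trait_tangent_type, fieldtypes(P))
    return ToyTangent{NamedTuple{fieldnames(P), Tuple{field_tangents...}}}
end

# A recursive binary-tree node; children are left undefined on leaves.
mutable struct ToyTreeNode
    val::Float64
    l::ToyTreeNode
    r::ToyTreeNode
    ToyTreeNode(val::Float64) = new(val)
end

# The one-line opt-in. Without it, trait_tangent_type(ToyTreeNode) would
# recurse until the stack overflows, just like the error in this issue.
tangent_style(::Type{ToyTreeNode}) = SelfShadowStyle()

# A non-recursive struct containing such a node is still handled structurally.
struct Wrapper
    x::Float64
    tree::ToyTreeNode
end
```

The appeal of the trait is that the choice is made per type by whoever owns it, while every other type keeps the default structural behaviour.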

@yebai
Contributor

yebai commented Jan 14, 2025

Thanks, @MilesCranmer, for the suggestion. @willtebbutt is currently away, so it might take a while to get back to you.

As a manual solution, does #434 offer enough information to implement a customised tangent type?

3 participants