From 35b432c47b454d6af925497337c32bc0f3958df0 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 9 Jan 2025 17:54:55 +0000 Subject: [PATCH] Performance Robustness in Reverse Pass (#442) * Fix up zero_rdata_from_type * Stop generating hundreds of methods * Call getfield directly instead of getindex * Manually inline ad stmts for rvs-pass for call * Refactor PhiNode * Remove increment_if_ref * Remove commented-out code * Improve docstring * Improve special_functions testset display * Extend increment * Remove increment_ref usage * PiNode * Remove increment_ref and __pi_rvs * Add regression test * Remove __deref_and_zero * Docstring * Reformat * Bump patch version --- Project.toml | 2 +- src/fwds_rvs_data.jl | 11 +- src/interpreter/s2s_reverse_mode_ad.jl | 167 ++++++++++++------ src/interpreter/zero_like_rdata.jl | 2 + src/utils.jl | 38 ++-- .../special_functions/special_functions.jl | 4 +- test/interpreter/s2s_reverse_mode_ad.jl | 9 +- 7 files changed, 137 insertions(+), 96 deletions(-) diff --git a/Project.toml b/Project.toml index 333477155..08f312226 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.75" +version = "0.4.76" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/fwds_rvs_data.jl b/src/fwds_rvs_data.jl index 584227fa0..bd088a004 100644 --- a/src/fwds_rvs_data.jl +++ b/src/fwds_rvs_data.jl @@ -662,14 +662,11 @@ with. if P isa DataType names = fieldnames(P) types = fieldtypes(P) - wrapped_field_zeros = map(enumerate(tangent_field_types(P))) do (n, tt) + wrapped_field_zeros = map(enumerate(always_initialised(P))) do (n, init) fzero = :(zero_rdata_from_type($(types[n]))) - if tt <: PossiblyUninitTangent - Q = :(rdata_type(tangent_type($(fieldtype(P, n))))) - return :(PossiblyUninitTangent{$Q}($fzero)) - else - return fzero - end + init && return fzero + Q = :(rdata_type(tangent_type($(fieldtype(P, n))))) + return :(PossiblyUninitTangent{$Q}($fzero)) end wrapped_field_zeros_tuple = Expr(:call, :tuple, wrapped_field_zeros...) wrapped_expr = :(R(NamedTuple{$names}($wrapped_field_zeros_tuple))) diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index d8a55b615..0b604798e 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -405,7 +405,7 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) end if is_active(stmt.val) rdata_id = get_rev_data_id(info, stmt.val) - rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2))) + rvs = increment_ref_stmts(rdata_id, Argument(2)) assert_id = ID() val = __inc(stmt.val) fwds = [ @@ -479,7 +479,13 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) val_rdata_ref_id = get_rev_data_id(info, stmt.val) output_rdata_ref_id = get_rev_data_id(info, line) fwds = PiNode(__inc(stmt.val), fcodual_type(CC.widenconst(stmt.typ))) - rvs = Expr(:call, __pi_rvs!, P, val_rdata_ref_id, output_rdata_ref_id) + + # Get the rdata from the output_rdata_ref, and set its new value to zero, and + # increment the output ref. + output_rdata_id = ID() + deref_stmts = deref_and_zero_stmts(P, output_rdata_ref_id, output_rdata_id) + inc_exprs = increment_ref_stmts(val_rdata_ref_id, output_rdata_id) + rvs = vcat(deref_stmts, inc_exprs) else # If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to # do on the reverse-pass. @@ -494,11 +500,6 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) return ad_stmt_info(line, nothing, fwds, rvs) end -@inline function __pi_rvs!(::Type{P}, val_rdata_ref::Ref, output_rdata_ref::Ref) where {P} - increment_ref!(val_rdata_ref, __deref_and_zero(P, output_rdata_ref)) - return nothing -end - # Constant GlobalRefs are handled. See const_codual. Non-constant GlobalRefs are handled by # assuming that they are constant, and creating a CoDual with the value. We then check at # run-time that the value has not changed. @@ -723,17 +724,53 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) rvs_pass = if T_pb!! <: NoPullback nothing else - Expr( - :call, - __run_rvs_pass!, - get_primal_type(info, line), - sig, - pb, - get_rev_data_id(info, line), - map(Base.Fix1(get_rev_data_id, info), args)..., + # Get the rdata which we pass into the pullback from its rdata ref. + rdata_ref_id = get_rev_data_id(info, line) + rdata_output_id = ID() + rdata_output_expr = Expr(:call, getfield, rdata_ref_id, QuoteNode(:x)) + rdata_output = (rdata_output_id, new_inst(rdata_output_expr)) + + # Zero out the value stored in this rdata ref now that we have its current + # value. The new value is rdata, so must be an instance of a bits type, so is + # safe to interpolate straight into instruction. + zero_val = zero_like_rdata_from_type(get_primal_type(info, line)) + zero_rdata_expr = Expr(:call, setfield!, rdata_ref_id, QuoteNode(:x), zero_val) + zero_rdata_ref = (ID(), new_inst(zero_rdata_expr)) + + # Run the pullback. The result is a tuple comprising `length(args)` elements. + call_pullback_id = ID() + call_pullback = (call_pullback_id, new_inst(Expr(:call, pb, rdata_output_id))) + + # For each element of the tuple returned by call_pullback, if the corresponding + # value in the primal IR is an Argument / SSA (if `get_rev_data_id` does not + # return nothing), increment the value in its rdata ref. This is equivalent to + # rdata_ref[] = increment!!(rdata_ref[], rdata_inc_resulting_from_pullback), + # but written out manually to ensure nothing fails to inline. + # If the corresponding value in the primal IR is not an Argument / SSA (e.g. it + # is a literal, a `QuoteNode`, or a `GlobalRef`), do nothing as we do not track + # gradients w.r.t. it. + tmp = map(enumerate(args)) do (n, arg) + rev_data_id = get_rev_data_id(info, arg) + + # If arg is not an SSA / Argument, then no rdata ref to inc. + rev_data_id === nothing && return nothing + + # Extract rdata from result of calling pullback. + rdata_inc_id = ID() + rdata_inc_expr = Expr(:call, getfield, call_pullback_id, n) + rdata_inc = (rdata_inc_id, new_inst(rdata_inc_expr)) + + # Construct statments to increment ref. + return vcat(rdata_inc, increment_ref_stmts(rev_data_id, rdata_inc_id)) + end + + # Concatenate all statements, and return them. + vcat( + IDInstPair[rdata_output, zero_rdata_ref, call_pullback], + reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]), ) end - return ad_stmt_info(line, comms_id, fwds, new_inst(rvs_pass)) + return ad_stmt_info(line, comms_id, fwds, rvs_pass) elseif Meta.isexpr(stmt, :boundscheck) # For some reason the compiler cannot handle boundscheck statements when we run it @@ -782,6 +819,29 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) end end +""" + increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair} + +Equivalent to `ref[] = increment!!(ref[], inc_data)`, where `ref` and `inc_data` are the +values associated to `ref_id` and `inc_data` respectively. +""" +function increment_ref_stmts(ref_id::ID, inc_data)::Vector{IDInstPair} + + # Get the value stored in the `Base.RefValue`. + ref_val_id = ID() + ref_val = (ref_val_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x)))) + + # Increment the value by inc_data. + new_val_id = ID() + new_val = (new_val_id, new_inst(Expr(:call, increment!!, ref_val_id, inc_data))) + + # Update the value stored in the rdata reference. + set_ref_expr = Expr(:call, setfield!, ref_id, QuoteNode(:x), new_val_id) + set_ref = (ID(), new_inst(set_ref_expr)) + + return IDInstPair[ref_val, new_val, set_ref] +end + is_active(::Union{Argument,ID}) = true is_active(::Any) = false @@ -807,33 +867,6 @@ end __get_primal(x::CoDual) = primal(x) __get_primal(x) = x -""" - __run_rvs_pass!( - P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs... - ) where {sig} - -Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. -""" -@inline function __run_rvs_pass!( - P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs... -) where {sig} - tuple_map(increment_if_ref!, arg_rev_data_refs, pb!!(ret_rev_data_ref[])) - set_ret_ref_to_zero!!(P, ret_rev_data_ref) - return nothing -end - -@inline increment_if_ref!(ref::Ref, rvs_data) = increment_ref!(ref, rvs_data) -@inline increment_if_ref!(::Ref, ::ZeroRData) = nothing -@inline increment_if_ref!(::Nothing, ::Any) = nothing - -@inline increment_ref!(x::Ref, t) = setindex!(x, increment!!(x[], t)) -@inline increment_ref!(::Base.RefValue{NoRData}, t) = nothing - -@inline function set_ret_ref_to_zero!!(::Type{P}, r::Ref{R}) where {P,R} - return r[] = zero_like_rdata_from_type(P) -end -@inline set_ret_ref_to_zero!!(::Type{P}, r::Base.RefValue{NoRData}) where {P} = nothing - const RuleMC{A,R} = MistyClosure{OpaqueClosure{A,R}} # @@ -1437,7 +1470,7 @@ function pullback_ir( # De-reference the nth rdata. rdata_id = ID() - rdata = new_inst(Expr(:call, getindex, arg_rdata_ref_ids[n])) + rdata = new_inst(Expr(:call, getfield, arg_rdata_ref_ids[n], QuoteNode(:x))) # Get the nth lazy zero rdata. lazy_zero_rdata_id = ID() @@ -1511,11 +1544,12 @@ function conclude_rvs_block( # Create statements which extract + zero the rdata refs associated to them. rdata_ids = map(_ -> ID(), phi_ids) - deref_stmts = map(phi_ids, rdata_ids) do phi_id, deref_id + tmp = map(phi_ids, rdata_ids) do phi_id, deref_id P = get_primal_type(info, phi_id) r = get_rev_data_id(info, phi_id) - return (deref_id, new_inst(Expr(:call, __deref_and_zero, P, r))) + return deref_and_zero_stmts(P, r, deref_id) end + deref_stmts = reduce(vcat, tmp; init=IDInstPair[]) # For each predecessor, create a `BBlock` which processes its corresponding edge in # each of the `PhiNode`s. @@ -1540,14 +1574,19 @@ function __get_value(edge::ID, x::IDPhiNode) end """ - __deref_and_zero(::Type{P}, x::Ref) where {P} + deref_and_zero_stmts(P, ref_id, val_id) -Helper, used in conclude_rvs_block. +Equivalent to something like +```julia +val = ref[] +ref[] = zero_rdata_from_type(P) +``` """ -@inline function __deref_and_zero(::Type{P}, x::Ref) where {P} - t = x[] - x[] = Mooncake.zero_like_rdata_from_type(P) - return t +function deref_and_zero_stmts(P, ref_id, val_id) + val = (val_id, new_inst(Expr(:call, getfield, ref_id, QuoteNode(:x)))) + r = Mooncake.zero_like_rdata_from_type(P) + set_ref = (ID(), new_inst(Expr(:call, setfield!, ref_id, QuoteNode(:x), r))) + return IDInstPair[val, set_ref] end """ @@ -1562,10 +1601,14 @@ of some block: %6 = φ (#2 => _1, #3 => %5) %7 = φ (#2 => 5., #3 => _2) ``` -Let the tangent refs associated to `%6`, `%7`, and `_1`` be denoted `t%6`, `t%7`, and `t_1` -resp., and let `pred_id` be `#2`, then this function will produce a basic block of the form +Let the rdata refs associated to `%6`, `%7`, and `_1`` be denoted `r%6`, `r%7`, and `r_1` +resp., and let `pred_id` be `#2`, and `increment_ref!` be the following function, ```julia -increment_ref!(t_1, t%6) +increment_ref!(ref, x) = ref[] = increment!!(ref[], x) +``` +then this `rvs_phi_block` will produce a basic block of the form +```julia +increment_ref!(r_1, r%6) nothing goto #2 ``` @@ -1577,15 +1620,23 @@ on. The same ideas apply if `pred_id` were `#3`. The block would end with `#3`, and there would be two `increment_ref!` calls because both `%5` and `_2` are not constants. + +In practice, code which is equivalent to `increment_ref!` is created directly, rather than +inserting a call to a generic Julia function. This is because we need to be certain that +the getfield and setfield! calls applied to any references are visible to the SROA +optimisation pass. If we insert a call to a function like `increment_ref!`, it might not be +inlined away, making such references opaque. """ function rvs_phi_block( pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo ) @assert length(rdata_ids) == length(values) - inc_stmts = map(rdata_ids, values) do id, val - stmt = Expr(:call, increment_if_ref!, get_rev_data_id(info, val), id) - return (ID(), new_inst(stmt)) + tmp = map(rdata_ids, values) do id, val + rev_data_id = get_rev_data_id(info, val) + rev_data_id === nothing && return nothing + return increment_ref_stmts(rev_data_id, id) end + inc_stmts = reduce(vcat, filter(x -> !(x === nothing), tmp); init=IDInstPair[]) goto_stmt = (ID(), new_inst(IDGotoNode(pred_id))) return BBlock(ID(), vcat(inc_stmts, goto_stmt)) end diff --git a/src/interpreter/zero_like_rdata.jl b/src/interpreter/zero_like_rdata.jl index 882f5d9f1..6cb9e4078 100644 --- a/src/interpreter/zero_like_rdata.jl +++ b/src/interpreter/zero_like_rdata.jl @@ -11,6 +11,8 @@ error -- please open an issue in such a situation. struct ZeroRData end @inline increment!!(::ZeroRData, r::R) where {R} = r +@inline increment!!(r::R, ::ZeroRData) where {R} = r +@inline increment!!(::ZeroRData, ::ZeroRData) = ZeroRData() """ zero_like_rdata_type(::Type{P}) where {P} diff --git a/src/utils.jl b/src/utils.jl index 6914ab2ad..263e8173e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -24,43 +24,29 @@ the same length, while `map` will just produce a new tuple whose length is equal shorter of `x` and `y`. """ @inline @generated function tuple_map(f::F, x::Tuple) where {F} - return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), eachindex(x.parameters))...) + return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:fieldcount(x))...) end @inline @generated function tuple_map(f::F, x::Tuple, y::Tuple) where {F} if length(x.parameters) != length(y.parameters) return :(throw(ArgumentError("length(x) != length(y)"))) else - stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters)) + stmts = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x)) return Expr(:call, :tuple, stmts...) end end -for N in 1:128 - @eval @inline function tuple_map(f::F, x::Tuple{Vararg{Any,$N}}) where {F} - return $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...)) - end - @eval @inline function tuple_map( - f::F, x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}} - ) where {F,names} - return NamedTuple{names}( - $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), 1:N)...)) - ) - end - @eval @inline function tuple_map(f, x::Tuple{Vararg{Any,$N}}, y::Tuple{Vararg{Any,$N}}) - return $(Expr( - :call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)... - )) - end - @eval @inline function tuple_map( - f::F, - x::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}, - y::NamedTuple{names,<:Tuple{Vararg{Any,$N}}}, - ) where {F,names} - return NamedTuple{names}( - $(Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:N)...)) - ) +@generated function tuple_map(f, x::NamedTuple{names}) where {names} + getfield_exprs = map(n -> :(f(getfield(x, $n))), 1:fieldcount(x)) + return :(NamedTuple{names}($(Expr(:call, :tuple, getfield_exprs...)))) +end + +@generated function tuple_map(f, x::NamedTuple{names}, y::NamedTuple{names}) where {names} + if fieldcount(x) != fieldcount(y) + return :(throw(ArgumentError("length(x) != length(y)"))) end + getfield_exprs = map(n -> :(f(getfield(x, $n), getfield(y, $n))), 1:fieldcount(x)) + return :(NamedTuple{names}($(Expr(:call, :tuple, getfield_exprs...)))) end for N in 1:256 diff --git a/test/ext/special_functions/special_functions.jl b/test/ext/special_functions/special_functions.jl index b5a44b43e..613ceba10 100644 --- a/test/ext/special_functions/special_functions.jl +++ b/test/ext/special_functions/special_functions.jl @@ -7,7 +7,7 @@ using Mooncake.TestUtils: test_rule # Rules in this file are only lightly tester, because they are all just @from_rrule rules. @testset "special_functions" begin - @testset for (perf_flag, f, x...) in vcat( + @testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat( map([Float64, Float32]) do P return Any[ (:stability, airyai, P(0.1)), @@ -51,7 +51,7 @@ using Mooncake.TestUtils: test_rule ) test_rule(StableRNG(123456), f, x...; perf_flag) end - @testset for (perf_flag, f, x...) in vcat( + @testset "$perf_flag, $(typeof((f, x...)))" for (perf_flag, f, x...) in vcat( map([Float64, Float32]) do P return Any[ (:none, logerf, P(0.3), P(0.5)), # first branch diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index 344f8f76a..d680bdbdf 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -12,6 +12,8 @@ struct A end f(a, x) = dot(a.data, x) +unstable_tester(x::Ref{Any}) = sin(x[]) + end @testset "s2s_reverse_mode_ad" begin @@ -106,8 +108,6 @@ end @test length(stmts.fwds) == 2 @test stmts.fwds[1][2].stmt isa Expr @test stmts.fwds[2][2].stmt isa ReturnNode - @test Meta.isexpr(only(stmts.rvs)[2].stmt, :call) - @test only(stmts.rvs)[2].stmt.args[1] == Mooncake.increment_ref! end @testset "literal" begin stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info) @@ -344,4 +344,9 @@ end f() = Float64 @test length(build_rrule(Tuple{typeof(f)}).fwds_oc.oc.captures) == 2 end + @testset "all `Ref`s for rdata are eliminated in type unstable code" begin + ir = Mooncake.rvs_ir(Tuple{typeof(S2SGlobals.unstable_tester),Ref{Any}}) + stmts = Mooncake.stmt(ir.stmts) + @test !any(x -> Meta.isexpr(x, :new) && x.args[1] <: Base.RefValue, stmts) + end end