Add recursive map generalizing the make_zero mechanism #1852
Alright, I could take some feedback/discussion on this now.
TLDR: Should I rewrite […]
@gdalle Promised to tag you when this was ready for review, but note that this PR only deals with the low-level, non-public guts of the implementation. I'll do the vector space wrapper in a separate PR as soon as this is merged (hopefully that won't be long, I really need that QuadGK rule for my research 📐)
src/make_zero.jl
Outdated
return seen[prev]
xs::NTuple{N,T},
::Val{copy_if_inactive}=Val(false),
isleaftype::L=Returns(false),
Wondering whether this is necessary or if the leaf types could just be hardcoded to `Union{_RealOrComplexFloat,Array{<:_RealOrComplexFloat}}`. I'll make a prototype of the vector space wrapper and the updated QuadGK rules to see if customizable leaf types come in handy.
Update for anyone who's following: I've implemented the VectorSpace wrapper, which prompted me to adjust the recursive_map implementation a bit, all for the better. It's looking good and will make writing custom higher-order rules as well as the DI wrappers a lot nicer for arbitrary types. However, it dawned on me that you probably want […]
Awesome, sorry I haven't had a chance to review yet [just a bunch of shenanigans atm]. I'll try to take a closer look next week; ping me if not.
No worries! I restored the draft label when I realized there was a bit more to do and will remove it again once I think this is ready for review. No need to look at it until then; the current state here on GitHub doesn't reflect what I'm working with locally anyway.
src/make_zero.jl
Outdated
isleaftype::L=Returns(false),
) where {T,F,N,L,copy_if_inactive}
x1 = first(xs)
if guaranteed_const_nongen(T, nothing)
Just to confirm, this is only for make_zero, and not for add/etc?
Because this case here already feels specific to the context
It's going to look a bit different once I push the next update (hopefully tomorrow), but no, after some experimenting it seemed best to me to always skip guaranteed inactive subtrees and restrict `recursive_map` to applying `f` to the differentiable values only. I tried doing the opposite initially, leaving it as part of the `isleaftype` filter and handling the possible `deepcopy` within the mapped function `f`, but it made things a lot more complicated. I think the main issue was that the whole mechanism with `seen` and keeping track of object identity then becomes the purview of the mapped function `f` instead of `recursive_map` itself, increasing boilerplate and complicating the contract between `recursive_map` and its callers. I couldn't think of a use case within Enzyme where you're interested in mapping over the guaranteed inactive parts anyway, and not recursing through inactive subtrees saves you from having to deal with deconstruction/reconstruction of a few specialized types (`deepcopy` has a lot more methods than `recursive_map`). So I went with this solution instead.
Of course, adding a `skip_guaranteed_const` flag would be straightforward (or combining it with `copy_if_inactive` into a single `inactive_mode` parameter). Do you think this is warranted?
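To make the design choice above concrete, here is a minimal toy sketch of a recursive map that owns the `seen` bookkeeping itself and never hands inactive subtrees to `f`. It is not the PR's implementation: `toy_map` and `toy_isinactive` are hypothetical names, the list of inactive types is an assumption, and only floats, tuples, and fully initialized arrays are handled.

```julia
# Toy sketch only; every name below is hypothetical and none of this is Enzyme's code.

# Assumption: these types can never hold differentiable values.
toy_isinactive(::Type{T}) where {T} = T <: Union{Integer,AbstractString,Symbol,Nothing}

# Differentiable scalar: the only place where `f` is called.
toy_map(seen::IdDict, f, x::AbstractFloat) = f(x)

# Tuples are immutable, so rebuild them element by element.
toy_map(seen::IdDict, f, x::Tuple) = map(xi -> toy_map(seen, f, xi), x)

# Arrays: `seen` is consulted and updated here, never inside `f`.
function toy_map(seen::IdDict, f, x::Array{T}) where {T}
    haskey(seen, x) && return seen[x]::Array{T}
    if toy_isinactive(T)
        seen[x] = x          # skip the inactive subtree and share it as-is
        return x
    end
    y = similar(x)
    seen[x] = y              # register before recursing so aliases map to one output
    for i in eachindex(x)
        y[i] = toy_map(seen, f, x[i])
    end
    return y
end

# Anything else must be inactive in this toy; it is returned untouched.
function toy_map(seen::IdDict, f, x)
    toy_isinactive(typeof(x)) || error("unsupported type $(typeof(x))")
    return x
end

# Usage: a make_zero-like call on a structure with aliasing and inactive parts.
a = [1.0, 2.0]
x = (a, a, [1, 2], "label")
y = toy_map(IdDict(), zero, x)
```

Because the mapped function only ever sees scalars, it needs no knowledge of `seen` or object identity, which is the simplification of the caller contract described in the comment.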
At long last, I think this one's ready for you to take a look. Hit me with any questions and concerns, from major design issues to bikeshedding over names. I put both the implementation and tests in their own modules because they define a lot of helpers and I didn't want to pollute other modules' namespaces.
Codecov Report
Attention: Patch coverage is […]
Additional details and impacted files
@@ Coverage Diff @@
## main #1852 +/- ##
==========================================
+ Coverage 67.50% 75.21% +7.70%
==========================================
Files 31 56 +25
Lines 12668 16618 +3950
==========================================
+ Hits 8552 12499 +3947
- Misses 4116 4119 +3
I'm still knee deep in 1.11 land and don't have cycles to review this immediately. @vchuravy can you take a look?
1.11 efforts deeply appreciated! Don't rush this. I'll keep using 1.10 and a local fork for my own needs, and occasionally push small changes here as my tinkering surfaces new concerns/opportunities.
@danielwe is this the one now to review or is there a different related PR I should review first? (and also would you mind rebasing)
This is the one wrt. recursive maps and all that, but I've been continually refining stuff locally, so I need to both rebase and push the latest changes, hang on! Will remove draft status when ready for review.
Finally wrapped up and rebased this! I'll come back later and write a little blurb, but the code should be ready for review as-is.
...as well as recursive_add, recursive_accumulate!, and accumulate_into!
Nice!
"""
function make_zero! end

"""
make_zero(prev::T)
isvectortype(::Type{T})::Bool
This is a new API? It will need a version bump for EnzymeCore.
We should discuss. The reason I put these helpers in EnzymeCore rather than keeping them internal was that the StaticArrays extension needed to add a method, so I figured there's a chance others might have to do the same for their custom types. However, subtyping `DenseArray` (and `AbstractFloat` if that ever becomes relevant) should almost always be sufficient. Either way, the point is only to make these extensible in package extensions. I don't think anyone should ever have to call them.
However, the vector space wrapper functionality I've built on top of this (which will be a separate PR) will probably involve a new type in EnzymeCore, so if that gets accepted there will have to be a new EnzymeCore release anyway
Yeah, absolutely. I think this is the right place to add them, and EnzymeCore is basically meant for people to be able to extend things without having to bite the load time bullet that is Enzyme.
Btw. names are infinitely bikesheddable, both in this case and elsewhere in the PR. My mindset working on this PR is to enable consistent treatment of arbitrary objects as vectors in a space spanned by the scalar (float) values reachable from the object, hence all the vector/scalar terminology, but I don't know if this works well or if it's confusing, especially as part of the public API.
Blurb time! Let's start with a quick overview of the API. I also wrote exhaustive docstrings in the code (just my way of clarifying my thinking, and I hope it's useful for code review too), but this should be a more conversational introduction highlighting the main points.
Out-of-place usage
(out1::T, out2::T, ...) = recursive_map([seen::IdDict,] f, Val(Nout), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])
This generalizes […]
Note how this supports mapped functions […]
Partially-in-place usage
(new_out1::T, new_out2::T, ...) = recursive_map([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])
This form has bangbang-like semantics where mutable storage within the […]
To use this form, the mapped function […]
The design of the whole API clicked for me the moment I realized that if you get this version right, both the in-place and out-of-place versions are just special cases, so you only need a single core of recursive functions with a few different entry points for the various use cases.
In-place usage
If you pass types where all differentiable values live in mutable storage, partially-in-place already implements in-place behavior for you. The in-place function […]
recursive_map!([seen::IdDict,] f!!, (out1::T, out2::T, ...), (in1::T, in2::T, ...), [Val(copy_if_inactive), [isinactivetype]])::Nothing
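As a rough illustration of the bangbang-like semantics mentioned for the partially-in-place form, the toy below updates mutable storage in place and returns fresh values for immutable parts. It does not use `recursive_map`; `toy_add!!` is a hypothetical name, and the exact contract of the real `f!!` argument is not reproduced here.

```julia
# Toy sketch of bangbang-style semantics; `toy_add!!` is a hypothetical helper.

# Immutable scalar: cannot be updated in place, so return a new value.
toy_add!!(out::AbstractFloat, x::AbstractFloat) = out + x

# Mutable array: update the existing storage and return that same array.
function toy_add!!(out::Array{<:AbstractFloat}, x::Array{<:AbstractFloat})
    out .+= x
    return out
end

# Immutable container: rebuild it, reusing mutable storage found inside.
toy_add!!(out::Tuple, x::Tuple) = map(toy_add!!, out, x)

out = (1.0, [1.0, 2.0])
inp = (10.0, [10.0, 20.0])
new_out = toy_add!!(out, inp)
```

The caller has to use the returned `new_out`, because the scalar slot of `out` is unchanged, while the array inside it has been updated in place and is shared between `out` and `new_out`.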
Optional arguments
There were three concerns:
Rather than a proliferation of optional arguments and internal voodoo, this called for a dedicated abstraction. Hence the
This functionality is the reason behind the method overwriting warning pointed out by @vchuravy above:
This warning comes from the test suite, not the package code. The test suite overwrites
The optional argument […]
A word on generated functions
Unrolling loops over tuples and struct fields is crucial for type stability and performance. I tried to only use […]
Generality/GPU compatibility
When the recursion hits a […]
In the common case where the GPU array eltype is scalar (i.e., float), the array is considered a leaf and dispatched directly to the mapped function […]
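For readers unfamiliar with the unrolling technique referred to under generated functions, here is a small sketch of a generated function that maps over struct fields without a runtime loop. `toy_mapfields` and `Point` are hypothetical; the sketch assumes the struct has a default positional constructor and says nothing about how the PR structures its own generated code.

```julia
# Hypothetical example type with a default positional constructor.
struct Point
    x::Float64
    y::Float64
end

# Hypothetical helper: apply `f` to every field of `x` and rebuild the struct.
# The generator runs once per concrete type `T` and emits one `getfield` call
# per field, so the compiled method body contains no loop.
@generated function toy_mapfields(f, x::T) where {T}
    exprs = [:(f(getfield(x, $i))) for i in 1:fieldcount(T)]
    return :(T($(exprs...)))
end

p = toy_mapfields(v -> 2v, Point(1.0, 2.0))
```

Since each field access uses a literal index, the compiler can infer every field type individually, which is what keeps this kind of traversal type stable for structs with heterogeneous fields.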
...and a nonsense default argument
I did some more profiling, and this is not true anymore, possibly due to the […]
src/analyses/activity.jl
Outdated
@@ -427,6 +427,11 @@ Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_n
return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState
end

Base.@assume_effects :removable :foldable :nothrow @inline function guaranteed_nonactive_nongen(::Type{T}, world)::Bool where {T}
Would a default arg value `world=nothing` be acceptable here and in `guaranteed_const_nongen`?
end
recursive_map!(accumulate_into!!, (into, from), (into, from))
@wsmoses @vchuravy This line in `accumulate_into!` is the only reason multiple outputs are supported in `recursive_map(!)`. The reverse rule for `deepcopy` is the only place this is used.
We could change this to
recursive_map!(accumulate_into_alt!!, into, (into, from))
make_zero!(from)
to completely remove the need for multiple outputs. That would simplify the implementation of `recursive_map!` and make its signature/usage more intuitive.
However, this implementation of `accumulate_into!` recurses twice through `from` instead of once, allocating two `IdDict`s instead of one, so the `deepcopy` rule will perform somewhat worse.
Other uses of `recursive_map` may see slightly improved performance because the indirection of the ubiquitous tuples in the current implementation seems to add some extra runtime dispatch when objects have abstractly typed fields/elements. In type-stable cases, I don't think there will be any difference.
What's your opinion on this tradeoff?
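The tradeoff can be pictured on flat arrays. The sketch below assumes, based on the two snippets in this comment, that `accumulate_into!` adds `from` into `into` and then zeros `from`; both function names are hypothetical stand-ins, and the real functions recurse through arbitrary structures rather than plain vectors.

```julia
# Hypothetical flat-array analogues of the two variants discussed above.

# One traversal with two outputs: `into` and `from` are both written per element.
function toy_accumulate_into!(into::Vector{Float64}, from::Vector{Float64})
    for i in eachindex(into, from)
        into[i] += from[i]
        from[i] = 0.0
    end
    return nothing
end

# Two traversals with one output each: accumulate first, then zero `from`.
function toy_accumulate_into_alt!(into::Vector{Float64}, from::Vector{Float64})
    into .+= from        # first pass writes only `into`
    fill!(from, 0.0)     # second pass writes only `from`
    return nothing
end

into1, from1 = [1.0, 2.0], [10.0, 20.0]
into2, from2 = [1.0, 2.0], [10.0, 20.0]
toy_accumulate_into!(into1, from1)
toy_accumulate_into_alt!(into2, from2)
```

Both variants leave the same values behind; the difference raised in the comment is only the cost of the second traversal and its extra `IdDict` in the recursive setting.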
This is to explore functionality for realizing JuliaMath/QuadGK.jl#120. The current draft cuts time and allocations in half for the MWE in that PR compared to the `make_zero` hack from the comments. Not sure if modifying the existing `recursive_*` functions like this is appropriate or whether it would be better to implement a separate `deep_recursive_accumulate`.
This probably breaks some existing uses of `recursive_accumulate`, like the Holomorphic derivative code, because `recursive_accumulate` now traverses most/all of the structure on its own and will double-accumulate when combined with the iteration over the `seen` IdDicts. Curious to see the total impact on the test suite.
This doesn't yet have any concept of `seen` and will thus double-accumulate if the structure has internal aliasing. That obviously needs to be fixed. Perhaps we can factor out and share the recursion code from `make_zero`.
A bit of a tangent, but perhaps a final version of this PR should include migrating `ClosureVector` to Enzyme from the QuadGK ext as suggested in JuliaMath/QuadGK.jl#110 (comment). Looks like that's the most relevant application of fully recursive accumulation at the moment.
Let me also throw out another suggestion: what if we implement a recursive generalization of broadcasting with an arbitrary number of arguments, i.e., `recursive_broadcast!(f, a, b, c, ...)` as a recursive generalization of `a .= f.(b, c, ...)`, free of intermediate allocations whenever possible (and similarly an out-of-place `recursive_broadcast(f, a, b, c...)` generalizing `f.(a, b, c...)` that only materializes/allocates once if possible). That would enable more optimized custom rules with Duplicated args, such as having the QuadGK rule call the in-place version `quadgk!(f!, result, segs...)`. Not sure if it would be hard to correctly handle aliasing without being overly defensive, or if that could mostly be taken care of by proper reuse of the existing broadcasting functionality.
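The double-accumulation problem with internal aliasing, mentioned in the description above, can be reproduced with a few lines. This is a toy sketch with hypothetical names (`naive_acc!`, `seen_acc!`) restricted to tuples of `Vector{Float64}`; it only illustrates why identity tracking with an `IdDict` is needed and is not how the PR implements it.

```julia
# Naive recursion: every reachable array is visited, aliased or not.
naive_acc!(into::Vector{Float64}, from::Vector{Float64}) = (into .+= from; nothing)
naive_acc!(into::Tuple, from::Tuple) = (foreach(naive_acc!, into, from); nothing)

# Identity-tracking recursion: each distinct `into` array is updated once.
function seen_acc!(seen::IdDict, into::Vector{Float64}, from::Vector{Float64})
    haskey(seen, into) && return nothing   # already accumulated via an alias
    seen[into] = from
    into .+= from
    return nothing
end
function seen_acc!(seen::IdDict, into::Tuple, from::Tuple)
    foreach((i, f) -> seen_acc!(seen, i, f), into, from)
    return nothing
end

# Both tuples refer to the same array twice.
b = [10.0]
a1 = [1.0]
naive_acc!((a1, a1), (b, b))          # adds `b` twice

a2 = [1.0]
seen_acc!(IdDict(), (a2, a2), (b, b))  # adds `b` once
```

Here `a1` ends up with `b` added twice while `a2` receives it once, which is the behavior a `seen`-aware traversal is meant to guarantee.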