Explain the Abstract Primals Problem #343
Comments
I wrote some tidier comments here: JuliaDiff/ChainRules.jl#337 (comment).

The "big argument" about FillArray is here: FluxML/Zygote.jl#863, and it's actually quite civil! Tl;dr: I don't find time-complexity arguments all that compelling, but I do think preserving (some) structural constraints is a good idea.
This is also discussed in some detail here.
Plan

Having read the discussions in JuliaDiff/ChainRules.jl#337, #347, and JuliaDiff/ChainRules.jl#232, it seems that:
Is that a fair conclusion? It appears that while the automatic decisions on opt-in/out via

Aside

One idea that has been mentioned but not discussed extensively is the second point in JuliaDiff/ChainRules.jl#232 (comment), i.e. how to prevent abstractly typed rules from overriding what would have been an efficient AD transformation of a specialised forward pass. Could we require the signature of the

Doing this would draw a parallel to normal dispatch. I suppose we win in the cases where AD is more efficient at transforming the specialised forward pass than the fallback

Overall I still think the plan above is better. But I thought I'd bring it up in case people have opinions.
I'm still very uneasy about the idea of having to opt out of rules; my experience has generally been that, if I've written a specialised method for some type, I want AD to have a crack at it. We are all in agreement that if you define a type and don't implement a specialised method of a particular function for it, then you want to hit the generic projected fallback. That feels like a good tradeoff. I'm finding it hard to know how to make progress on the in-between ground, though. Maybe we need sketches of code in both cases or something?
If I understand correctly, in the case where we define a specialised method for some type, we want to do one of the two:

1) use the generic fallback rule (prefer rules over AD), or
2) let AD transform the specialised method (prefer AD over rules).
I don't have enough experience to guess which option is better in general. On the other hand, opt-out seems to require less effort than opt-in, so I have a slight preference for that. What are some examples in which 1) is better? I thought it would be better for the Diagonal case, but:

```julia
julia> n = 5;

julia> d = Diagonal(rand(n));

julia> m = rand(n,n);

julia> gradient(d, m -> sum(d*m), d, m)
ERROR: MethodError: objects of type Diagonal{Float64, Vector{Float64}} are not callable
Use square brackets [] for indexing an Array.
Stacktrace:
[1] macro expansion
  @ ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::Diagonal{Float64, Vector{Float64}}, ::var"#5#6", ::Diagonal{Float64, Vector{Float64}}, ::Matrix{Float64})
  @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface2.jl:9
[3] _pullback(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Diagonal{Float64, Vector{Float64}}, ::Matrix{Float64})
  @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:34
[4] pullback(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Diagonal{Float64, Vector{Float64}}, ::Vararg{Any, N} where N)
  @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:40
[5] gradient(::Diagonal{Float64, Vector{Float64}}, ::Function, ::Vararg{Any, N} where N)
  @ Zygote ~/JuliaEnvs/Zygote.jl/src/compiler/interface.jl:58
[6] top-level scope
  @ REPL[8]:1
```

EDIT:
Oh, that's embarrassing... Anyway, it seems that while the rrule scales terribly, the AD pullback is only faster above n ~ 1000. Using the fallback rule:

```julia
julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=5; d=Diagonal(rand(n)); m=rand(n,n))
  69.041 ns (1 allocation: 288 bytes)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=100; d=Diagonal(rand(n)); m=rand(n,n))
  10.750 μs (2 allocations: 78.20 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=1000; d=Diagonal(rand(n)); m=rand(n,n))
  1.895 ms (2 allocations: 7.63 MiB)
```

while AD (commenting out the rule) gets:

```julia
julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=5; d=Diagonal(rand(n)); m=rand(n,n))
  52.701 μs (487 allocations: 23.45 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=100; d=Diagonal(rand(n)); m=rand(n,n))
  69.860 μs (489 allocations: 179.30 KiB)

julia> @btime Zygote.pullback((d,m) -> d*m, d, m) setup=(n=1000; d=Diagonal(rand(n)); m=rand(n,n))
  2.914 ms (489 allocations: 15.28 MiB)
```
I think you're missing some brackets:
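Presumably the intended call parenthesises the anonymous function's arguments, so that d and m are the closure's inputs rather than d being passed as the function:

```julia
gradient((d, m) -> sum(d*m), d, m)
```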
Arguably this is mathematically wrong (as it is nonzero off-diagonal), not just computationally inefficient (O(n^2), not O(n)). But this is an easy rule to add, partly because ChainRules depends on LinearAlgebra. The specialised method for the forward pass is this, and it is not AD friendly:
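As a hedged stand-in (not the actual LinearAlgebra source), here is a sketch of the kind of specialised forward pass at issue; the in-place writes are what make it hostile to Zygote-style reverse-mode AD:

```julia
using LinearAlgebra

# Sketch only: scale each row of B by the matching diagonal entry of D.
# Fast and allocation-light, but the in-place writes into `out` are
# exactly what reverse-mode AD (e.g. Zygote) cannot differentiate through.
function diag_mul(D::Diagonal, B::AbstractMatrix)
    out = similar(B, promote_type(eltype(D), eltype(B)))
    for j in axes(B, 2), i in axes(B, 1)
        out[i, j] = D.diag[i] * B[i, j]
    end
    return out
end
```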
I think the crucial asymmetry is that, if you are defining a specialised method for

Whereas if I want the generic
I think we all agree that in this case (where no

As far as I understand, the question is whether in the case where the primal
Ok. So I agree you can imagine some automated rule that says "skip the abstract rule if there is a more specific primal". But I think there are two problems with that idea. One is that there are a great many more specific primals,
These are good points. I can imagine that for some
Preferring fallback rules over AD

I suppose an argument for this option is that we want to prefer things working out of the box (even if less efficiently) over AD running into problems and throwing up a stacktrace. Making sure that things are efficient (or even tractable) should be a secondary concern that would require some action: opting out of rules, or writing a custom rule. In this case, opt-outs are one-liners (see the sketch below).
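For reference, a hedged sketch of such a one-liner using ChainRulesCore's @opt_out macro (assuming a ChainRulesCore version that provides it):

```julia
using ChainRulesCore, LinearAlgebra

# Decline the abstractly-typed rule for this primal signature, so the AD
# system differentiates the specialised Diagonal method itself instead.
@opt_out rrule(::typeof(*), ::Diagonal, ::AbstractMatrix)
```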
Preferring AD over fallback rules

If we decide to prefer AD over fallback rules, then:
What would this opt-in mechanism look like? One option is to define a more specific rule. This could be quite repetitive and might result in a large number of rules. Some refactoring would be needed for each function, to define some

This sounds like it would be more code than opting out. @willtebbutt, did I get this right?

The information that we are missing is what fraction of methods are of each of the 1-4 types, and how important they are. Getting accurate numbers on this is hard; guesstimating them could be possible (but not by me). Did I miss any arguments in the above?
I don't know that there's a useful summary count, but I do think discussing particular examples is good, else it's easy to talk past each other. (Although I'm not sure I want to inflict the
Going complex -> real, or full -> diagonal, is a projection, not a change of basis. I believe this is correct for many examples, such as the Diagonal * Matrix above. I hope it will be true in general, i.e. that an abstract rule for
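As a concrete illustration of the projection idea, a minimal sketch using ChainRulesCore's ProjectTo (assuming a ChainRulesCore version that provides it):

```julia
using ChainRulesCore, LinearAlgebra

# ProjectTo captures the primal's tangent subspace from the primal value,
# then maps an arbitrary cotangent back onto that subspace.
project_d = ProjectTo(Diagonal([1.0, 2.0]))
project_d(ones(2, 2))  # drops the off-diagonal part, returning a Diagonal
```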
I think that's a fair summary. I'm becoming more convinced that making the opt-out mechanism work, provided that we properly handle the projections / representational changes for the sake of correctness, is the way forward. I don't particularly like it, but I agree it's the lesser of two evils.
@wesselb and I have been discussing this a bit, and I suspect that you can side-step this issue by being careful about how you define the tangent space of any given variable. Taking the
There is a problem that comes up with abstract primals, most commonly in the case of AbstractArrays.
We don't have a good explanation of what the problem is anywhere; it is scattered across various issues and PRs.
@willtebbutt has spent a bunch of time thinking about it.
I propose that we open a docs PR that clearly explains it, with examples etc., as part of the design docs section.
Once we have that PR open, we can talk more about solving it.
@mcabbott and I were discussing this on Slack. (They might post their notes on this later.)
The below is roughly extracted from that. It discusses a lot of the problem, though it isn't super clear; after all, the whole discussion exists because we don't have a clear explanation of what the problem is.
Rough ugly notes:
The problem is that we want to define rules not just for Arrays but also for StaticArrays, whose nearest common supertype is AbstractArray. But if you do that, then people say this is bad, because it will mean Diagonal will take O(N^2) rather than O(N). One could say that the user opted into this, since they used an AbstractMatrix, and so it is on the method author (in this case the rule author) to provide an optimized method if appropriate.
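To make the tension concrete, here is a minimal sketch of an abstractly-typed rule (illustrative, not the actual ChainRules definition). It is correct for any AbstractMatrix, but a Diagonal argument is captured too, and the pullback materialises dense cotangents for it:

```julia
using ChainRulesCore, LinearAlgebra

# Illustrative abstract rule for matrix multiplication.
function ChainRulesCore.rrule(::typeof(*), A::AbstractMatrix, B::AbstractMatrix)
    function times_pullback(ȳ)
        Ȳ = unthunk(ȳ)
        # Ȳ * B' and A' * Ȳ are dense: O(N^2) work and storage even when
        # A is a Diagonal whose true tangent has only N degrees of freedom.
        return (NoTangent(), @thunk(Ȳ * B'), @thunk(A' * Ȳ))
    end
    return A * B, times_pullback
end
```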
But there is a greater problem: IIRC some operations on FillArray give the wrong answer, not just the wrong time complexity, if you treat them as a generic Array.
Mike and Will T had a big argument about it.
However:
If you only define rules on fundamental array types like Array, and maybe StaticArray and GPUArrays, you get the correct time complexity, and you avoid defining things so generally that they break weirder array types. And the AD will decompose the wrapper arrays correctly and get at what is inside. FillArrays work correctly if you treat them like structs (as do all wrapper arrays).
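A sketch of what "treating a Fill like a struct" means (field names as in FillArrays.jl; the tangent construction is illustrative):

```julia
using ChainRulesCore, FillArrays

x = Fill(2.0, 3)  # lazy length-3 vector of 2.0; stores one Float64 plus axes

# A structural tangent mirrors Fill's fields: one scalar for `value` and
# nothing for the axes. O(1) storage, and it cannot represent the
# non-constant perturbations that a dense Array cotangent wrongly would.
dx = Tangent{typeof(x)}(value = 3.0)
```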
@willtebbutt has spent ages thinking about this. We do understand the problem, though I think we still don't have a great solution.
I am hoping Will will write something that we can put in the docs. (We have a few nice writeups like that in there.)
I think we might actually need to formally introduce the idea of a fundamental array type, maybe as a trait in ChainRules.
Or maybe just add a StaticArrays dep (it's basically a stdlib at this point) and then we can have a union for them.
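A rough sketch of what such a trait could look like (all names here are hypothetical, not an existing ChainRules API):

```julia
# Hypothetical trait distinguishing "fundamental" array types, for which
# rules are written directly, from everything else, which the AD should
# decompose structurally.
abstract type ArrayKind end
struct FundamentalArray <: ArrayKind end
struct WrapperArray <: ArrayKind end

arraykind(::Type{<:Array}) = FundamentalArray()
arraykind(::Type{<:AbstractArray}) = WrapperArray()  # default: let AD decompose
```

A rule could then consult `arraykind` of its arguments to decline dispatch for wrapper types.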
Now, defining things only on concrete types seems unidiomatic. Julia code is normally defined for the general case, and then multiple dispatch is used to provide optimizations for specific cases. However, AD has two ways to achieve functionality, rules and generation, and we want to ensure that overall the most specific one gets hit. Julia is supposed to work such that the most specific method takes precedence over less specific ones, right? The most specific method is the one most customised for that type. With AD, the most specific functionality is always available: it is to let the AD run, which generates code for the exact input. So when a rule is applied, it is actually getting in the way of the most specific code for that type.
E.g. a rule defined for AbstractMatrix prevents the AD from generating the more specific functionality for Symmetric{Diagonal{...}}.
When I say wrapper arrays, I mean anything that has a parent array (according to the parent function). Though really it is more general: it is anything which has the method under consideration defined without resorting to ccall. As long as it doesn't resort to ccall, the AD will be able to generate a pullback for the method. That generated pullback will call something that we will have an optimised rule for; it might not even be the same function, but it will be something, and so we are solid.
Remember, AD systems generally generate code that is optimal if they don't error, except where there is specific domain expertise the rule author is applying. So we only need the rule to catch the erroring case.
Julia's specialisation rules do apply to AD rules and to code generated by AD. But the AD doesn't get to generate its code if it hits a rule. And the generated code will be more specialised than something hitting an AbstractMatrix rule. This is the case where the AD on a Diagonal would get to break things down according to the primal method definition, and would end up hitting a rule for Vector rather than for a matrix. Whereas the compiler's specialisation of an rrule for AbstractMatrix is not so specialised, and ends up looking at a bunch of zero elements. The AD would have done the better thing if the rule hadn't been defined.
Because AD systems are good at finding derivatives. Optimal as in identical to the code someone would write for this concrete input by hand. AD doesn't always achieve that, because sometimes there is domain knowledge to apply. But in simple cases, like decomposing functions on wrapper arrays, it does.
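For instance, a hedged sketch of that decomposition at work (assuming Zygote, and that no AbstractMatrix rule intercepts the call):

```julia
using LinearAlgebra, Zygote

# With no abstract rule in the way, AD differentiates through Diagonal's
# structure: the computation reduces to operations on the `diag` field,
# a Vector, so the cotangent stays O(n).
D = Diagonal([1.0, 2.0, 3.0])
g = Zygote.gradient(D -> sum(abs2, D.diag), D)[1]
# g == (diag = [2.0, 4.0, 6.0],), a structural tangent mirroring Diagonal
```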
As I said: if you have domain knowledge, then you can do better. But generally that domain knowledge can still be applied to a "fundamental" array type (like Array, and maybe StaticArray, GPUArray), and it will still end up benefiting the wrapper array types. And the generated code that comes from the pullback between the wrapper type and the type that has the domain-knowledge rule will be optimal (in the sense of being basically identical to what a human would write to do this).