mul/ewise rules for basic arithmetic semiring #26
I apologize for the messy PR; the only important parts are in the tests and chainrules folders. I'm primarily interested in your thoughts about the Everything works fine for dense inputs. For sparse inputs, though, there are two problems:
Didn't manage to finish it all today, but here's a few comments
test/runtests.jl
@@ -14,4 +14,5 @@ println("Testing SuiteSparseGraphBLAS.jl")
@testset "SuiteSparseGraphBLAS" begin
include_test("gbarray.jl")
include_test("operations.jl")
include_test("testrules.jl")
Usually the structure of the `test` folder mirrors the `src/` folder, which makes it easier to find things when the package grows.
I think the underlying issue is the same (and also the same one as in the elementwise rules). What it comes down to is an instance of the "array dilemma", discussed in great detail over many issues and PRs. See JuliaDiff/ChainRulesCore.jl#347 (and related issues) for a discussion, but I warn you, it is a rabbit hole ;)
Essentially what it comes down to is whether you think of the input, say Primal computation will be fast because if you interpret it as an array, the
If you interpret it as a struct, meaning that the zeros are structural, it doesn't make sense to compute the tangents to all the zeros, and you can compute the backward pass efficiently. Since
Long story short, we are treating them as structs now in order to not completely kill efficiency. We should probably treat them as structs here as well.
Aside: projection, merged recently, was a way to make sure rules with abstractly typed arguments still return the correct tangent type. The classic example is
In this case, as you point out, masking is all we need to do, since we are writing dedicated rules for
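To make the "array" versus "struct" distinction concrete, here is a minimal sketch using only the `SparseArrays` standard library. It uses `SparseMatrixCSC` as a stand-in for the package's sparse type, and `mask_to_pattern` is a hypothetical helper written for this illustration, not part of any library. For `C = A * B` the cotangent with respect to `A` is `ΔC * B'`; the struct interpretation keeps only the entries that lie on the stored pattern of `A`.

```julia
using SparseArrays, LinearAlgebra

# Hypothetical helper (an assumption for this sketch): keep only the entries
# of `dense` that sit on the stored pattern of `pattern`.
function mask_to_pattern(dense::AbstractMatrix, pattern::SparseMatrixCSC)
    rows, cols, _ = findnz(pattern)
    vals = [dense[i, j] for (i, j) in zip(rows, cols)]
    return sparse(rows, cols, vals, size(pattern)...)
end

# Sparse input with three stored entries on the diagonal.
A = sparse([1, 2, 3], [1, 2, 3], [2.0, 3.0, 4.0], 3, 3)
B = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]

# Cotangent of C = A * B, here corresponding to the loss sum(A * B).
ΔC = ones(3, 3)

# "Array" interpretation: every position of A gets a tangent, so the result is dense.
dense_tangent = ΔC * B'

# "Struct" interpretation: zeros are structural, so only stored positions get a tangent.
struct_tangent = mask_to_pattern(dense_tangent, A)
```

In this sketch `dense_tangent` has a nonzero in every position, while `struct_tangent` stores values only where `A` does. A real rule would avoid forming the dense product in the first place; the dense intermediate is kept here only to show that the two interpretations agree on the stored pattern.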
…teSparseGraphBLAS.jl into arithmeticchains
I removed some I'm still testing to get feedback on these and avoid a monster PR.
Notes:
- `test_*rule` with `check_inferred=false`. Issue "Output eltype inference is not type stable" #25 will fix.
- `mul`: there's a deeper issue. I'm 85-90% sure the rules are correct, but the patterns are not the same as for FiniteDifferences, and occasionally there are different values.