-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Zygote AD failure workarounds & test cleanup #414
Changes from 3 commits
7ccfbef
313535e
11b2bcd
feb5b5c
919e445
7d38977
a4ad008
51c9250
821503c
8b1cddc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -100,7 +100,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) | |
@inferred f(args...) | ||
@inferred Zygote._pullback(ctx, f, args...) | ||
out, pb = Zygote._pullback(ctx, f, args...) | ||
@test_throws ErrorException @inferred pb(out) | ||
@inferred pb(out) | ||
end | ||
|
||
function test_ADs( | ||
|
@@ -224,65 +224,77 @@ end | |
|
||
function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3]) | ||
@testset "$(AD)" begin | ||
# Test kappa function | ||
k = if args === nothing | ||
kernelfunction() | ||
else | ||
kernelfunction(args) | ||
end | ||
rng = MersenneTwister(42) | ||
|
||
if k isa SimpleKernel | ||
for d in log.([eps(), rand(rng)]) | ||
compare_gradient(AD, [d]) do x | ||
kappa(k, exp(x[1])) | ||
@testset "kappa function" begin | ||
for d in log.([eps(), rand(rng)]) | ||
compare_gradient(AD, [d]) do x | ||
kappa(k, exp(x[1])) | ||
end | ||
end | ||
end | ||
end | ||
# Testing kernel evaluations | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
compare_gradient(AD, x) do x | ||
k(x, y) | ||
end | ||
compare_gradient(AD, y) do y | ||
k(x, y) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
# Testing kernel matrices | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, B, dim) | ||
|
||
@testset "kernel evaluations" begin | ||
x = rand(rng, dims[1]) | ||
y = rand(rng, dims[1]) | ||
@testset "first argument" begin | ||
compare_gradient(AD, x) do x | ||
k(x, y) | ||
end | ||
end | ||
compare_gradient(AD, B) do b | ||
testfunction(k, A, b, dim) | ||
@testset "second argument" begin | ||
compare_gradient(AD, y) do y | ||
k(x, y) | ||
end | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
testfunction(kernelfunction(p), A, dim) | ||
@testset "hyperparameters" begin | ||
compare_gradient(AD, args) do p | ||
kernelfunction(p)(x, y) | ||
end | ||
end | ||
end | ||
end | ||
|
||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if args !== nothing | ||
compare_gradient(AD, args) do p | ||
testdiagfunction(kernelfunction(p), A, dim) | ||
@testset "kernel matrices" begin | ||
A = rand(rng, dims...) | ||
B = rand(rng, dims...) | ||
for dim in 1:2 | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testfunction(k, A, b, dim) | ||
end | ||
if !(args === nothing) | ||
compare_gradient(AD, args) do p | ||
testfunction(kernelfunction(p), A, dim) | ||
end | ||
end | ||
|
||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, dim) | ||
end | ||
compare_gradient(AD, A) do a | ||
testdiagfunction(k, a, B, dim) | ||
end | ||
compare_gradient(AD, B) do b | ||
testdiagfunction(k, A, b, dim) | ||
end | ||
if args !== nothing | ||
compare_gradient(AD, args) do p | ||
testdiagfunction(kernelfunction(p), A, dim) | ||
end | ||
end | ||
end | ||
end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @rossviljoen doing the changes above I spotted that you overloaded test_AD() for MOKernel - this worries me: it seems like all the "fallback" AD tests above are not run. It also seems to suggest that we can't substitute an MOKernel where we need an arbitrary Kernel (which would I believe violate Liskov substitution principle). Could you comment pls? thanks:) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think that was me? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, apologies! Now I'm not even sure anymore why I assumed it was you 😔 @willtebbutt / @thomasgudjonwright ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm don't think it was me. A lot of this code was @theogf I think? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did the original compare gradient but someone else extended it to MOKernel There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wasn't me! Could always look back at the commit history :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should've checked git blame in the first place... it was #263 (@david-vicente). This to me seems like a case where overloading the same method isn't really the right thing to do (because it suggests they do the same thing - it's a kernel AD check in either case - but they don't actually seem to check the same things)... now that i've got y'all in a thread, what is your opinion on these ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Separated out into #416 |
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this was doing exactly the same for
testdiagfunction
as the code above fortestfunction
, so I've unified it with a for loop over the two functions.