Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zygote AD failure workarounds & test cleanup #414

Merged
merged 10 commits into from
Dec 18, 2021
100 changes: 56 additions & 44 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
@inferred f(args...)
@inferred Zygote._pullback(ctx, f, args...)
out, pb = Zygote._pullback(ctx, f, args...)
@test_throws ErrorException @inferred pb(out)
@inferred pb(out)
end

function test_ADs(
Expand Down Expand Up @@ -224,65 +224,77 @@ end

function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3])
@testset "$(AD)" begin
# Test kappa function
k = if args === nothing
kernelfunction()
else
kernelfunction(args)
end
rng = MersenneTwister(42)

if k isa SimpleKernel
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
@testset "kappa function" begin
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
end
end
end
end
# Testing kernel evaluations
x = rand(rng, dims[1])
y = rand(rng, dims[1])
compare_gradient(AD, x) do x
k(x, y)
end
compare_gradient(AD, y) do y
k(x, y)
end
if !(args === nothing)
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
# Testing kernel matrices
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
compare_gradient(AD, A) do a
testfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testfunction(k, a, B, dim)

@testset "kernel evaluations" begin
x = rand(rng, dims[1])
y = rand(rng, dims[1])
@testset "first argument" begin
compare_gradient(AD, x) do x
k(x, y)
end
end
compare_gradient(AD, B) do b
testfunction(k, A, b, dim)
@testset "second argument" begin
compare_gradient(AD, y) do y
k(x, y)
end
end
if !(args === nothing)
compare_gradient(AD, args) do p
testfunction(kernelfunction(p), A, dim)
@testset "hyperparameters" begin
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
end
end

compare_gradient(AD, A) do a
testdiagfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testdiagfunction(k, a, B, dim)
end
compare_gradient(AD, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
compare_gradient(AD, args) do p
testdiagfunction(kernelfunction(p), A, dim)
Comment on lines -274 to -285
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was doing exactly the same for testdiagfunction as the code above for testfunction, so I've unified it with a for loop over the two functions.

@testset "kernel matrices" begin
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
compare_gradient(AD, A) do a
testfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testfunction(k, a, B, dim)
end
compare_gradient(AD, B) do b
testfunction(k, A, b, dim)
end
if !(args === nothing)
compare_gradient(AD, args) do p
testfunction(kernelfunction(p), A, dim)
end
end

compare_gradient(AD, A) do a
testdiagfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testdiagfunction(k, a, B, dim)
end
compare_gradient(AD, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
compare_gradient(AD, args) do p
testdiagfunction(kernelfunction(p), A, dim)
end
end
end
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rossviljoen doing the changes above I spotted that you overloaded test_AD() for MOKernel - this worries me: it seems like all the "fallback" AD tests above are not run. It also seems to suggest that we can't substitute an MOKernel where we need an arbitrary Kernel (which would I believe violate Liskov substitution principle). Could you comment pls? thanks:)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that was me?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, apologies! Now I'm not even sure anymore why I assumed it was you 😔 @willtebbutt / @thomasgudjonwright ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm don't think it was me. A lot of this code was @theogf I think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did the original compare gradient but someone else extended it to MOKernel

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wasn't me! Could always look back at the commit history :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should've checked git blame in the first place... it was #263 (@david-vicente).

This to me seems like a case where overloading the same method isn't really the right thing to do (because it suggests they do the same thing - it's a kernel AD check in either case - but they don't actually seem to check the same things)... now that i've got y'all in a thread, what is your opinion on these ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Separated out into #416

Expand Down
4 changes: 3 additions & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
@test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
randn(rng, 4);
ADs=[:ForwardDiff, :ReverseDiff]
st-- marked this conversation as resolved.
Show resolved Hide resolved
)
@test_broken "test_AD of chain transform is currently broken in Zygote"
end