Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: broken enzyme tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 22, 2024
1 parent 597e0a9 commit bfc2f83
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ LuxCore = "1"
MKL = "0.7"
MLDataDevices = "1.1.1"
Markdown = "1.10"
NNlib = "0.9.21"
NNlib = "0.9.24"
Octavian = "0.3.28"
Polyester = "0.7.15"
Random = "1.10"
Expand Down
2 changes: 1 addition & 1 deletion ext/LuxLibEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ using Static: True

Utils.is_extension_loaded(::Val{:Enzyme}) = True()

end
end
5 changes: 2 additions & 3 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,14 +169,13 @@ end
end

@testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
using LuxLib, Random, LuxTestUtils, Enzyme
using LuxLib, Random, ForwardDiff, Enzyme

x = rand(Float32, 2, 2)

f(x) = sum(abs2, LuxLib.Impl.matmul(x, x))

# Just test that we don't crash
@test length(Enzyme.gradient(Forward, f, x)) == 4
@test only(Enzyme.gradient(Forward, f, x)) ForwardDiff.gradient(f, x)
end

@testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin
Expand Down

0 comments on commit bfc2f83

Please sign in to comment.