From bfc2f834d9c03d81efce0bad64c3ce144a480537 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 21:05:21 -0400 Subject: [PATCH] fix: broken enzyme tests --- Project.toml | 2 +- ext/LuxLibEnzymeExt.jl | 2 +- test/common_ops/dense_tests.jl | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 27b771f8..d1a88266 100644 --- a/Project.toml +++ b/Project.toml @@ -73,7 +73,7 @@ LuxCore = "1" MKL = "0.7" MLDataDevices = "1.1.1" Markdown = "1.10" -NNlib = "0.9.21" +NNlib = "0.9.24" Octavian = "0.3.28" Polyester = "0.7.15" Random = "1.10" diff --git a/ext/LuxLibEnzymeExt.jl b/ext/LuxLibEnzymeExt.jl index 14855718..958075c4 100644 --- a/ext/LuxLibEnzymeExt.jl +++ b/ext/LuxLibEnzymeExt.jl @@ -5,4 +5,4 @@ using Static: True Utils.is_extension_loaded(::Val{:Enzyme}) = True() -end \ No newline at end of file +end diff --git a/test/common_ops/dense_tests.jl b/test/common_ops/dense_tests.jl index e438647c..99d1810c 100644 --- a/test/common_ops/dense_tests.jl +++ b/test/common_ops/dense_tests.jl @@ -169,14 +169,13 @@ end end @testitem "Enzyme.Forward patch: dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin - using LuxLib, Random, LuxTestUtils, Enzyme + using LuxLib, Random, ForwardDiff, Enzyme x = rand(Float32, 2, 2) f(x) = sum(abs2, LuxLib.Impl.matmul(x, x)) - # Just test that we don't crash - @test length(Enzyme.gradient(Forward, f, x)) == 4 + @test only(Enzyme.gradient(Forward, f, x)) ≈ ForwardDiff.gradient(f, x) end @testitem "Enzyme rules for fused dense" tags=[:dense] setup=[SharedTestSetup] skip=:(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin