diff --git a/Project.toml b/Project.toml index 36d6396f..fa4c34df 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ ParallelStencil_MetalExt = "Metal" AMDGPU = "0.6, 0.7, 0.8, 0.9, 1" CUDA = "3.12, 4, 5" CellArrays = "0.3" -Enzyme = "0.11, 0.12, 0.13" +Enzyme = "0.12, 0.13" MacroTools = "0.5" Metal = "1.2" Polyester = "0.7" diff --git a/src/AD.jl b/src/AD.jl index 90640f05..98b1421f 100644 --- a/src/AD.jl +++ b/src/AD.jl @@ -7,8 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the import ParallelStencil.AD # Functions -- `autodiff_deferred!`: wraps function `autodiff_deferred`. -- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`. +- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const. +- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const. # Examples const USE_GPU = true @@ -43,9 +43,6 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the main() -!!! note "Enzyme runtime activity default" - If ParallelStencil is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil. - To see a description of a function type `?`. """ module AD diff --git a/src/ParallelKernel/EnzymeExt/AD.jl b/src/ParallelKernel/EnzymeExt/AD.jl index 7c1a9664..d3ef4a86 100644 --- a/src/ParallelKernel/EnzymeExt/AD.jl +++ b/src/ParallelKernel/EnzymeExt/AD.jl @@ -7,11 +7,8 @@ Provides GPU-compatible wrappers for automatic differentiation functions of the import ParallelKernel.AD # Functions -- `autodiff_deferred!`: wraps function `autodiff_deferred`. -- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`. - -!!! note "Enzyme runtime activity default" - If ParallelKernel is initialized with Threads, then `Enzyme.API.runtimeActivity!(true)` is called to ensure correct behavior of Enzyme. If you want to disable this behavior, then call `Enzyme.API.runtimeActivity!(false)` after loading ParallelStencil. +- `autodiff_deferred!`: wraps function `autodiff_deferred`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const. +- `autodiff_deferred_thunk!`: wraps function `autodiff_deferred_thunk`, promoting all arguments that are not Enzyme.Annotations to Enzyme.Const. To see a description of a function type `?`. """ diff --git a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl index b500f01e..f086bfe8 100644 --- a/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl +++ b/src/ParallelKernel/EnzymeExt/autodiff_gpu.jl @@ -2,16 +2,17 @@ import ParallelStencil import ParallelStencil: PKG_THREADS, PKG_POLYESTER import Enzyme +# NOTE: package specific initialization of Enzyme could be done as follows (not needed in the currently supported versions of Enzyme) # function ParallelStencil.ParallelKernel.AD.init_AD(package::Symbol) # if iscpu(package) # Enzyme.API.runtimeActivity!(true) # NOTE: this is currently required for Enzyme to work correctly with threads # end # end -# ParallelStencil injects a configuration parameter at the end, for Enzyme we need to wrap that parameter as a Annotation -# for all purposes this ought to be Const. This is not ideal since we might accidentially wrap other parameters the user -# provided as well. This is needed to support @parallel autodiff_deferred(...) - function promote_to_const(args::Vararg{Any,N}) where N +# NOTE: @parallel injects four parameters at the end, which need to be wrapped as Annotations. The current solution is to wrap all +# arguments which are not already Annotations (all the other arguments must be Annotations). Should this change, then one could +# explicitly wrap just the injected parameters. +function promote_to_const(args::Vararg{Any,N}) where N ntuple(Val(N)) do i @inline if !(args[i] isa Enzyme.Annotation || diff --git a/test/ParallelKernel/test_parallel.jl b/test/ParallelKernel/test_parallel.jl index 878cbf8b..f2f6e6da 100644 --- a/test/ParallelKernel/test_parallel.jl +++ b/test/ParallelKernel/test_parallel.jl @@ -133,8 +133,8 @@ eval(:( end return end - @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, Const(f!), Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) - Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!),Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a)) + @parallel configcall=f!(A, B, a) AD.autodiff_deferred!(Enzyme.Reverse, f!, Const, DuplicatedNoNeed(A, Ā), DuplicatedNoNeed(B, B̄), Const(a)) # NOTE: f! is automatically promoted to Const. + Enzyme.autodiff_deferred(Enzyme.Reverse, Const(g!), Const, DuplicatedNoNeed(A_ref, Ā_ref), DuplicatedNoNeed(B_ref, B̄_ref), Const(a)) @test Array(Ā) ≈ Ā_ref @test Array(B̄) ≈ B̄_ref end