Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: update minimum version of Enzyme to 0.13 #166

Merged
merged 7 commits into from
Sep 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Expand All @@ -44,6 +45,7 @@ LuxLibAppleAccelerateExt = "AppleAccelerate"
LuxLibBLISBLASExt = "BLISBLAS"
LuxLibCUDAExt = "CUDA"
LuxLibMKLExt = "MKL"
LuxLibEnzymeExt = "Enzyme"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
LuxLibTrackerExt = "Tracker"
Expand All @@ -59,7 +61,8 @@ ChainRulesCore = "1.24"
Compat = "4.15.0"
CpuId = "0.3"
DispatchDoctor = "0.4.12"
EnzymeCore = "0.7.7"
Enzyme = "0.13.1"
EnzymeCore = "0.8"
FastClosures = "0.3.2"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
Expand Down
8 changes: 8 additions & 0 deletions ext/LuxLibEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module LuxLibEnzymeExt

using LuxLib: Utils
using Static: True

Utils.is_extension_loaded(::Val{:Enzyme}) = True()

end
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 3 additions & 3 deletions ext/LuxLibReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ Utils.remove_tracking(x::TrackedArray) = ReverseDiff.value(x)
Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = ReverseDiff.value.(x)
Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T)

Utils.within_gradient(::TrackedReal) = True()
Utils.within_gradient(::TrackedArray) = True()
Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True()
Utils.within_autodiff(::TrackedReal) = True()
Utils.within_autodiff(::TrackedArray) = True()
Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True()

# Traits extensions
Traits.is_tracked(::Type{<:TrackedReal}) = True()
Expand Down
6 changes: 3 additions & 3 deletions ext/LuxLibTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ Utils.remove_tracking(x::TrackedArray) = Tracker.data(x)
Utils.remove_tracking(x::AbstractArray{<:TrackedReal}) = Tracker.data.(x)
Utils.remove_tracking(::Type{<:TrackedReal{T}}) where {T} = Utils.remove_tracking(T)

Utils.within_gradient(::TrackedReal) = True()
Utils.within_gradient(::TrackedArray) = True()
Utils.within_gradient(::AbstractArray{<:TrackedReal}) = True()
Utils.within_autodiff(::TrackedReal) = True()
Utils.within_autodiff(::TrackedArray) = True()
Utils.within_autodiff(::AbstractArray{<:TrackedReal}) = True()

# Traits extensions
Traits.is_tracked(::Type{<:TrackedReal}) = True()
Expand Down
5 changes: 3 additions & 2 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,18 +213,19 @@ end
# Enzyme works for all of these except `gelu`.
# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)},
cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)},
::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number})
primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(
::EnzymeRules.ConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)},
::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)},
dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number})
return (dret.val * ∇gelu(x.val),)
end

# FIXME: ForwardRules changed in EnzymeCore 0.8
function EnzymeRules.forward(
::EnzymeCore.Const{typeof(gelu)}, ::Type{<:EnzymeCore.Duplicated},
x::EnzymeCore.Duplicated{<:Number})
Expand Down
4 changes: 2 additions & 2 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
@eval begin
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))},
cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))},
::Type{RT}, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
Expand All @@ -155,7 +155,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
end

function EnzymeRules.reverse(
cfg::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(func))},
cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))},
::Type{RT}, cache, C::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
A::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}},
B::EnzymeCore.Annotation{<:AbstractArray{<:Any, 3}}) where {RT}
Expand Down
25 changes: 14 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ CRC.@non_differentiable safe_minimum(::Any...)
macro enzyme_alternative(f₁, f₂)
return esc(quote
function EnzymeRules.augmented_primal(
::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))},
::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))},
::Type{RT}, args...) where {RT}
fwd, rev = EnzymeCore.autodiff_thunk(
EnzymeCore.ReverseSplitWithPrimal, EnzymeCore.Const{typeof($(f₂))},
Expand All @@ -245,11 +245,12 @@ macro enzyme_alternative(f₁, f₂)
end

function EnzymeRules.reverse(
::EnzymeRules.ConfigWidth, ::EnzymeCore.Const{typeof($(f₁))},
::EnzymeRules.RevConfig, ::EnzymeCore.Const{typeof($(f₁))},
::Type{RT}, (tape, rev), args...) where {RT}
return only(rev(EnzymeCore.Const($(f₂)), args..., tape))
end

# FIXME: ForwardRules changed in EnzymeCore 0.8
function EnzymeRules.forward(
::EnzymeCore.Const{typeof($(f₁))}, ::Type{RT}, args...) where {RT}
EnzymeCore.autodiff(EnzymeCore.Forward, EnzymeCore.Const($(f₂)), RT, args...)
Expand All @@ -269,20 +270,23 @@ end
return
end

within_gradient_vararg(args...) = unrolled_any(within_gradient, args)
within_autodiff_vararg(args...) = unrolled_any(within_autodiff, args)

within_gradient(_) = False()
within_gradient(::ForwardDiff.Dual) = True()
within_gradient(::AbstractArray{<:ForwardDiff.Dual}) = True()
function within_autodiff(_)
is_extension_loaded(Val(:Enzyme)) && return static(EnzymeCore.within_autodiff())
return False()
end
within_autodiff(::ForwardDiff.Dual) = True()
within_autodiff(::AbstractArray{<:ForwardDiff.Dual}) = True()

CRC.rrule(::typeof(within_gradient), x) = True(), _ -> (∂∅, ∂∅)
CRC.rrule(::typeof(within_autodiff), x) = True(), _ -> (∂∅, ∂∅)

static_training_mode(::Nothing, args...) = within_gradient_vararg(args...)
static_training_mode(::Nothing, args...) = within_autodiff_vararg(args...)

function static_training_mode(
training::Union{Bool, Val{true}, Val{false}, StaticBool}, args...)
return static_training_mode_check(
training, static(training), within_gradient_vararg(args...))
training, static(training), within_autodiff_vararg(args...))
end

function CRC.rrule(::typeof(static_training_mode), ::Nothing, args...)
Expand All @@ -304,8 +308,7 @@ function static_training_mode_check(training, ::True, ::False)
`Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. \
Reliance on this behavior is discouraged, and is not guaranteed by Semantic \
Versioning, and might be removed without a deprecation cycle. It is recommended \
to fix this issue in your code. \n\n\
If you are using Enzyme.jl, then you can ignore this warning." maxlog=1
to fix this issue in your code." maxlog=1
return True()
end

Expand Down
6 changes: 3 additions & 3 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ BLISBLAS = "0.1"
BenchmarkTools = "1.5"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.16"
Enzyme = "0.12.26"
EnzymeCore = "0.7.7"
Enzyme = "0.13.1"
EnzymeCore = "0.8"
ExplicitImports = "1.9.0"
ForwardDiff = "0.10.36"
Hwloc = "3.2"
InteractiveUtils = "<0.0.1, 1"
JLArrays = "0.1.5"
LuxTestUtils = "1.2"
LuxTestUtils = "1.2.1"
MKL = "0.7"
MLDataDevices = "1.0.0"
NNlib = "0.9.21"
Expand Down
Loading