From 1538324c43add56def0caa78cd965e6e021f3230 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 16 Sep 2024 11:34:11 -0400 Subject: [PATCH] fix: enzyme reverse bias needs a check on Const --- Project.toml | 2 +- src/impl/activation.jl | 10 +++++----- src/impl/batched_mul.jl | 4 ++-- src/impl/matmul.jl | 6 +++--- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index ff5f055c..390cec9d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.2.2" +version = "1.2.3" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/impl/activation.jl b/src/impl/activation.jl index de2cfc7e..604b0614 100644 --- a/src/impl/activation.jl +++ b/src/impl/activation.jl @@ -196,17 +196,17 @@ for (f, dfdx) in [ (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))) #! format: on ] - @eval CRC.@scalar_rule($f(x), $dfdx) + @eval CRC.@scalar_rule($f(x), $(dfdx)) ∇f = Symbol(:∇broadcasted_, f) @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f), x::Union{Numeric, Broadcast.Broadcasted}) - Ω = $f.(x) - function $∇f(dΩ) - ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $dfdx), CRC.@thunk @.(dΩ*$dfdx)) + Ω = $(f).(x) + function $(∇f)(dΩ) + ∂x = CRC.InplaceableThunk(dx -> @.(dx+=dΩ * $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx))) return CRC.NoTangent(), CRC.NoTangent(), ∂x end - return Ω, $∇f + return Ω, $(∇f) end end diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index de760581..c5e3fdf3 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -137,8 +137,8 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!) end dCs = C.dval - dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval - dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval + dAs = A isa EnzymeCore.Const ? dCs : A.dval + dBs = B isa EnzymeCore.Const ? dCs : B.dval if EnzymeRules.width(cfg) == 1 dCs = (dCs,) diff --git a/src/impl/matmul.jl b/src/impl/matmul.jl index b7eaf7bd..59767c58 100644 --- a/src/impl/matmul.jl +++ b/src/impl/matmul.jl @@ -270,9 +270,9 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)}, end ∂Cs = C.dval - ∂As = (typeof(A) <: EnzymeCore.Const) ? ∂Cs : A.dval - ∂Bs = (typeof(B) <: EnzymeCore.Const) ? ∂Cs : B.dval - ∂bs = bias.dval + ∂As = A isa EnzymeCore.Const ? ∂Cs : A.dval + ∂Bs = B isa EnzymeCore.Const ? ∂Cs : B.dval + ∂bs = bias isa EnzymeCore.Const ? ∂Cs : bias.dval if EnzymeRules.width(cfg) == 1 ∂Cs = (∂Cs,)