Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: enzyme reverse bias needs a check on Const
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 16, 2024
1 parent 7ba127a commit ac37989
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.2"
version = "1.2.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
10 changes: 5 additions & 5 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,17 @@ for (f, dfdx) in [
(:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω)))))
#! format: on
]
@eval CRC.@scalar_rule($f(x), $dfdx)
@eval CRC.@scalar_rule($f(x), $(dfdx))

∇f = Symbol(:∇broadcasted_, f)
@eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f),
x::Union{Numeric, Broadcast.Broadcasted})
Ω = $f.(x)
function $∇f(dΩ)
∂x = CRC.InplaceableThunk(dx -> @.(dx+=* $dfdx), CRC.@thunk @.(dΩ*$dfdx))
Ω = $(f).(x)
function $(∇f)(dΩ)
∂x = CRC.InplaceableThunk(dx -> @.(dx+=* $(dfdx)), CRC.@thunk @.(dΩ*$(dfdx)))
return CRC.NoTangent(), CRC.NoTangent(), ∂x
end
return Ω, $∇f
return Ω, $(∇f)
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
end

dCs = C.dval
dAs = (typeof(A) <: EnzymeCore.Const) ? dCs : A.dval
dBs = (typeof(B) <: EnzymeCore.Const) ? dCs : B.dval
dAs = A isa EnzymeCore.Const ? dCs : A.dval
dBs = B isa EnzymeCore.Const ? dCs : B.dval

if EnzymeRules.width(cfg) == 1
dCs = (dCs,)
Expand Down
6 changes: 3 additions & 3 deletions src/impl/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)},
end

∂Cs = C.dval
∂As = (typeof(A) <: EnzymeCore.Const) ? ∂Cs : A.dval
∂Bs = (typeof(B) <: EnzymeCore.Const) ? ∂Cs : B.dval
∂bs = bias.dval
∂As = A isa EnzymeCore.Const ? ∂Cs : A.dval
∂Bs = A isa EnzymeCore.Const ? ∂Cs : B.dval
∂bs = bias isa EnzymeCore.Const ? ∂Cs : bias.dval

if EnzymeRules.width(cfg) == 1
∂Cs = (∂Cs,)
Expand Down

0 comments on commit ac37989

Please sign in to comment.