Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: missing enzyme rules for matmuladd! (CUDA support) #159

Merged
merged 4 commits into from
Sep 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.2.1"
version = "1.2.2"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
72 changes: 72 additions & 0 deletions src/impl/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,78 @@ function CRC.rrule(
end

# EnzymeRules
# Enzyme reverse-mode forward pass ("augmented primal") for the in-place call
# `matmuladd!(C, opmode, A, B, bias)`.
#
# * Saves a copy of `A` (needed for the gradient w.r.t. `B`) and of `B` (needed for the
#   gradient w.r.t. `A`) on the tape, but only when Enzyme reports that the argument may be
#   overwritten before the reverse pass and the corresponding gradient is requested.
# * Runs the primal computation unless `C` is annotated `DuplicatedNoNeed` /
#   `BatchDuplicatedNoNeed`.
#
# Returns an `AugmentedReturn` with no primal and no shadow (the primal function returns
# `nothing`) and the tape `(A_cache, B_cache)`; each entry is `nothing` when no copy was
# required.
function EnzymeRules.augmented_primal(cfg, ::EnzymeCore.Const{typeof(matmuladd!)},
        ::Type{EnzymeCore.Const{Nothing}}, C::EnzymeCore.Annotation{<:AbstractMatrix},
        opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode},
        A::EnzymeCore.Annotation{<:AbstractMatrix},
        B::EnzymeCore.Annotation{<:AbstractMatrix},
        bias::EnzymeCore.Annotation{<:AbstractVector})
    # Index layout of `overwritten(cfg)`: 1 => function, 2 => C, 3 => opmode, 4 => A, 5 => B.
    overwritten = EnzymeRules.overwritten(cfg)
    C_is_active = !(C isa EnzymeCore.Const)

    A_cache = overwritten[4] && !(B isa EnzymeCore.Const) && C_is_active ?
              copy(A.val) : nothing
    B_cache = overwritten[5] && !(A isa EnzymeCore.Const) && C_is_active ?
              copy(B.val) : nothing

    if !(C isa EnzymeCore.DuplicatedNoNeed || C isa EnzymeCore.BatchDuplicatedNoNeed)
        # Replay exactly the call this rule intercepts, i.e. forward `opmode` as well, so
        # the same backend-specific method runs as in the undifferentiated program.
        matmuladd!(C.val, opmode.val, A.val, B.val, bias.val)
    end

    return EnzymeRules.AugmentedReturn(nothing, nothing, (A_cache, B_cache))
end

# Enzyme reverse pass for the in-place call `matmuladd!(C, opmode, A, B, bias)`.
#
# The adjoints accumulated below correspond to `C = A * B .+ bias`:
#   ∂A += ∂C * B',    ∂B += A' * ∂C,    ∂b += sum of ∂C over its columns
# after which the shadow of `C` is zeroed, since `C` is fully overwritten by the primal.
#
# `(A_cache, B_cache)` is the tape produced by the augmented primal; an entry is `nothing`
# when the argument was not overwritten, in which case the live value is used instead.
# Every shadow is accumulated into (never overwritten), and arguments annotated `Const`
# are skipped. Returns one `nothing` per argument (C, opmode, A, B, bias) because no
# argument is `Active`.
function EnzymeRules.reverse(cfg, ::EnzymeCore.Const{typeof(matmuladd!)},
        ::Type{EnzymeCore.Const{Nothing}}, (A_cache, B_cache),
        C::EnzymeCore.Annotation{<:AbstractMatrix},
        opmode::EnzymeCore.Const{<:AbstractInternalArrayOpMode},
        A::EnzymeCore.Annotation{<:AbstractMatrix},
        B::EnzymeCore.Annotation{<:AbstractMatrix},
        bias::EnzymeCore.Annotation{<:AbstractVector})
    # A constant `C` has no shadow (`C.dval` does not exist), so there is nothing to
    # propagate back.
    C isa EnzymeCore.Const && return ntuple(Returns(nothing), 5)

    overwritten = EnzymeRules.overwritten(cfg)

    # If an input was not overwritten, nothing was cached; read the live value instead.
    if !(B isa EnzymeCore.Const) && !overwritten[4]
        A_cache = A.val
    end
    if !(A isa EnzymeCore.Const) && !overwritten[5]
        B_cache = B.val
    end

    # `Const` arguments carry no shadow. `∂Cs` is used as a same-shaped placeholder so the
    # `zip` below lines up; the placeholder is never touched because every use is guarded
    # by an `isa Const` check.
    ∂Cs = C.dval
    ∂As = A isa EnzymeCore.Const ? ∂Cs : A.dval
    ∂Bs = B isa EnzymeCore.Const ? ∂Cs : B.dval
    ∂bs = bias isa EnzymeCore.Const ? ∂Cs : bias.dval

    # Normalize the non-batched case to 1-tuples so both cases share one loop.
    if EnzymeRules.width(cfg) == 1
        ∂Cs = (∂Cs,)
        ∂As = (∂As,)
        ∂Bs = (∂Bs,)
        ∂bs = (∂bs,)
    end

    for (∂C, ∂A, ∂B, ∂b) in zip(∂Cs, ∂As, ∂Bs, ∂bs)
        # A shadow aliasing the primal marks an argument that is not differentiated.
        ∂C === C.val && continue

        if !(bias isa EnzymeCore.Const) && ∂b !== bias.val
            # Accumulate (do not overwrite): the shadow may already hold contributions
            # from other uses of `bias`.
            ∂b .+= vec(sum(∂C; dims=2))
        end

        if !(A isa EnzymeCore.Const) && ∂A !== A.val
            # TODO: we don't use our faster matmul here since we lack the 5 arg version
            mul!(∂A, ∂C, B_cache', true, true)
        end

        if !(B isa EnzymeCore.Const) && ∂B !== B.val
            # TODO: we don't use our faster matmul here since we lack the 5 arg version
            mul!(∂B, A_cache', ∂C, true, true)
        end

        ∂C .= 0
    end

    return ntuple(Returns(nothing), 5)
end

# NOTE(review): `@enzyme_alternative` is defined elsewhere. Presumably each line tells
# Enzyme to differentiate `matmul_linalg_default!` in place of the kernel named first.
# Confirm against the macro definition.
@enzyme_alternative matmul_octavian! matmul_linalg_default!
@enzyme_alternative serial_matmul_loopvec! matmul_linalg_default!
@enzyme_alternative matmul_loopvec! matmul_linalg_default!
Expand Down
18 changes: 18 additions & 0 deletions test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,12 @@ end
return
end

# Test-only convenience wrapper: picks the internal operation mode from all four arrays
# and forwards to the 5-argument `LuxLib.Impl.matmuladd!`, writing the result into `C`.
# Always returns `nothing`.
function matmuladd!(C, A, B, bias)
    LuxLib.Impl.matmuladd!(
        C, LuxLib.internal_operation_mode((C, A, B, bias)), A, B, bias)
    return nothing
end

rng = StableRNG(1234)

ALL_ACTS = [identity, tanh, tanh_fast, sigmoid, sigmoid_fast,
Expand Down Expand Up @@ -218,6 +224,18 @@ end
if hasbias
@test db≈db_zyg atol=1e-3 rtol=1e-3
end

act === identity || !hasbias || continue
avik-pal marked this conversation as resolved.
Show resolved Hide resolved

Enzyme.autodiff(Reverse, matmuladd!, Duplicated(y, copy(dy)),
Duplicated(weight, dweight), Duplicated(x, dx), b_enz)

_, pb_f = Zygote.pullback(matmuladd, weight, x, b)
avik-pal marked this conversation as resolved.
Show resolved Hide resolved
dweight_zyg, dx_zyg, db_zyg = pb_f(dy)

@test dweight≈dweight_zyg atol=1e-3 rtol=1e-3
@test dx≈dx_zyg atol=1e-3 rtol=1e-3
@test db≈db_zyg atol=1e-3 rtol=1e-3
end
end
end
Loading