Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

fix: rollback custom gelu implementation #168

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.0"
version = "1.3.1"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
38 changes: 0 additions & 38 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ CRC.@non_differentiable select_fastest_activation(::Any...)
module SLEEFActivations

using ChainRulesCore: ChainRulesCore
using EnzymeCore: EnzymeCore, EnzymeRules
using NNlib: NNlib
using SLEEFPirates: SLEEFPirates

Expand All @@ -164,32 +163,16 @@ const CRC = ChainRulesCore
sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x)
softplus(x::Number) = SLEEFPirates.softplus(x)
logsigmoid(x::Number) = -softplus(-x)
gelu(x::Number) = SLEEFPirates.gelu(x)
swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x))
lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x))
tanh(x::Number) = SLEEFPirates.tanh(x)
tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x)

const gelu_λ = √(2 / π)
const gelu_2λ = √(8 / π)

function ∇gelu(x::Number)
α = oftype(x, 0.044715)
α2 = oftype(x, 0.08943)
λλ = oftype(x, gelu_2λ)
x2 = Base.FastMath.mul_fast(x, x)
t = muladd(x2, α, one(x))
Ω = sigmoid_fast(λλ * x * t)
dσ = conj(Ω * (1 - Ω))
return muladd(dσ * λλ * muladd(x2, α2, t), x, Ω)
end

for (f, dfdx) in [
#! format: off
(:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))),
(:softplus, :(sigmoid_fast(x))),
(:logsigmoid, :(sigmoid_fast(-x))),
(:gelu, :(∇gelu(x))),
(:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))),
(:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))),
(:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))),
Expand All @@ -210,26 +193,6 @@ for (f, dfdx) in [
end
end

# Enzyme works for all of these except `gelu`.
# See https://github.com/EnzymeAD/Enzyme.jl/issues/1671
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.RevConfigWidth{1}, func::EnzymeCore.Const{typeof(gelu)},
::Type{<:EnzymeCore.Active}, x::EnzymeCore.Active{<:Number})
primal = EnzymeRules.needs_primal(cfg) ? func.val(x.val) : nothing
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(
::EnzymeRules.RevConfigWidth{1}, ::EnzymeCore.Const{typeof(gelu)},
dret::EnzymeCore.Active, ::Nothing, x::EnzymeCore.Active{<:Number})
return (dret.val * ∇gelu(x.val),)
end

function EnzymeRules.forward(::EnzymeRules.FwdConfig, ::EnzymeCore.Const{typeof(gelu)},
::Type{<:EnzymeCore.Duplicated}, x::EnzymeCore.Duplicated{<:Number})
return EnzymeCore.Duplicated(gelu(x.val), x.dval * ∇gelu(x.val))
end

fast_act(f::F, ::Type{T}) where {F, T} = f
fast_act(f::F, ::Type{Float32}) where {F} = fast_act(f)

Expand All @@ -238,7 +201,6 @@ for (fbase, ffast) in [
(NNlib.sigmoid_fast, sigmoid_fast),
(NNlib.softplus, softplus),
(NNlib.logsigmoid, logsigmoid),
(NNlib.gelu, gelu),
(NNlib.swish, swish),
(NNlib.lisht, lisht),
(Base.tanh, tanh),
Expand Down
Loading