Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move SLEEFPirates to an ext
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2024
1 parent c8e5006 commit 68fc1b3
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 66 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ jobs:
blas_backend: "default"
version: "1.10"
loopvec: "false"
- os: ubuntu-latest
test_group: "other_ops"
blas_backend: "default"
version: "1.10"
loopvec: "false"
- os: macos-latest
test_group: "dense"
blas_backend: "appleaccelerate"
Expand Down
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand All @@ -37,6 +36,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

Expand All @@ -49,6 +49,7 @@ LuxLibEnzymeExt = "Enzyme"
LuxLibLoopVectorizationExt = "LoopVectorization"
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibSLEEFPiratesExt = "SLEEFPirates"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
LuxLibTrackerExt = "Tracker"
LuxLibcuDNNExt = ["CUDA", "cuDNN"]
Expand Down
1 change: 1 addition & 0 deletions ext/LuxLibOctavianExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LuxLibOctavianExt

using Octavian: Octavian
using Static: True

using LuxLib: LuxLib, Utils

Expand Down
58 changes: 58 additions & 0 deletions ext/LuxLibSLEEFPiratesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
module LuxLibSLEEFPiratesExt

# Package extension holding the SLEEFPirates-backed activation functions. It defines
# scalar activations, their ChainRules derivative rules (scalar and broadcasted), and
# registers each one as the replacement returned by `Impl.sleefpirates_fast_act`.

using ChainRulesCore: ChainRulesCore
using NNlib: NNlib
using SLEEFPirates: SLEEFPirates

using LuxLib: Numeric, Impl

const CRC = ChainRulesCore

# Scalar activations. These are functions local to this module (they are not methods of
# the NNlib / Base functions of the same name); the mapping at the bottom of the file
# connects each NNlib / Base activation to its counterpart here.
sigmoid_fast(x::Number) = SLEEFPirates.sigmoid_fast(x)
softplus(x::Number) = SLEEFPirates.softplus(x)
logsigmoid(x::Number) = -softplus(-x)
swish(x::Number) = Base.FastMath.mul_fast(x, sigmoid_fast(x))
lisht(x::Number) = Base.FastMath.mul_fast(x, tanh_fast(x))
tanh(x::Number) = SLEEFPirates.tanh(x)
tanh_fast(x::Number) = SLEEFPirates.tanh_fast(x)

# Each entry pairs an activation name with an expression for its derivative, written in
# terms of the input `x` and the primal output `Ω`. The expression is spliced into both
# the scalar rule and the broadcasted rrule generated below.
for (f, dfdx) in [
    #! format: off
    (:sigmoid_fast, :(conj(Base.FastMath.mul_fast(Ω, Base.FastMath.sub_fast(1, Ω))))),
    (:softplus, :(sigmoid_fast(x))),
    (:logsigmoid, :(sigmoid_fast(-x))),
    (:swish, :(Base.FastMath.add_fast(Ω, Base.FastMath.mul_fast(sigmoid_fast(x), Base.FastMath.sub_fast(1, Ω))))),
    (:lisht, :(Base.FastMath.add_fast(x, Base.FastMath.mul_fast(tanh_fast(x), Base.FastMath.sub_fast(1, Ω))))),
    (:tanh, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω))))),
    (:tanh_fast, :(conj(Base.FastMath.sub_fast(1, Base.FastMath.mul_fast(Ω, Ω)))))
    #! format: on
]
    # Scalar derivative rule for `f`.
    @eval CRC.@scalar_rule($f(x), $(dfdx))

    # Broadcasted rule: computes `f.(x)` once and returns a pullback named
    # `∇broadcasted_<f>`.
    ∇f = Symbol(:∇broadcasted_, f)
    @eval function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof($f),
            x::Union{Numeric, Broadcast.Broadcasted})
        Ω = $(f).(x)
        function $(∇f)(dΩ)
            # The in-place branch accumulates `dΩ * df/dx` into an existing buffer `dx`;
            # the thunk branch materialises the same product lazily. Both must compute
            # the same quantity.
            ∂x = CRC.InplaceableThunk(
                dx -> @.(dx += dΩ * $(dfdx)), CRC.@thunk @.(dΩ * $(dfdx)))
            return CRC.NoTangent(), CRC.NoTangent(), ∂x
        end
        return Ω, $(∇f)
    end
end

# Register the replacements: `Impl.sleefpirates_fast_act(fbase)` returns the
# SLEEFPirates-backed function defined above for each listed activation.
for (fbase, ffast) in [
    #! format: off
    (NNlib.sigmoid_fast, sigmoid_fast),
    (NNlib.softplus, softplus),
    (NNlib.logsigmoid, logsigmoid),
    (NNlib.swish, swish),
    (NNlib.lisht, lisht),
    (Base.tanh, tanh),
    (NNlib.tanh_fast, tanh_fast)
    #! format: on
]
    @eval Impl.sleefpirates_fast_act(::typeof($fbase)) = $ffast
end

end
70 changes: 5 additions & 65 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,73 +132,13 @@ end
# Generic fallback: for any `AbstractInternalArrayOpMode` the activation is returned
# unchanged, regardless of the element type `T`.
function select_fastest_activation(
        f::F, ::AbstractInternalArrayOpMode, ::Type{T}) where {F, T}
    return f
end

# For `LoopedArrayOp`, defer the choice to `sleefpirates_fast_act`, which may swap `f`
# for a faster implementation depending on the element type `T`.
function select_fastest_activation(f::F, ::LoopedArrayOp, ::Type{T}) where {F, T}
    return sleefpirates_fast_act(f, T)
end

CRC.@non_differentiable select_fastest_activation(::Any...)

# Fast activations via SLEEFPirates.jl
#
# The methods below are identity fallbacks, so they work without SLEEFPirates being
# loaded. Only `Float32` is routed to the one-argument form; every other element type
# keeps `f` as is. The SLEEFPirates extension (ext/LuxLibSLEEFPiratesExt.jl) adds
# one-argument methods for specific activations, which take precedence over the generic
# `f::F` fallback.
sleefpirates_fast_act(f::F, ::Type{T}) where {F, T} = f
sleefpirates_fast_act(f::F, ::Type{Float32}) where {F} = sleefpirates_fast_act(f)
sleefpirates_fast_act(f::F) where {F} = f

CRC.@non_differentiable sleefpirates_fast_act(::Any...)

0 comments on commit 68fc1b3

Please sign in to comment.