Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
refactor: move LV and octavian behind an extension
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 17, 2024
1 parent fa06a05 commit c75af09
Show file tree
Hide file tree
Showing 19 changed files with 199 additions and 133 deletions.
22 changes: 20 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ concurrency:

jobs:
ci:
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }}
name: Julia ${{ matrix.version }} - ${{ matrix.test_group }} - ${{ matrix.os }} - ${{ matrix.blas_backend }} - ${{ matrix.loopvec }}
if: ${{ !contains(github.event.head_commit.message, '[skip tests]') }}
runs-on: ${{ matrix.os }}
strategy:
Expand All @@ -43,27 +43,44 @@ jobs:
- "others"
blas_backend:
- "default"
loopvec:
- "true"
include:
- os: ubuntu-latest
test_group: "dense"
blas_backend: "blis"
version: "1.10"
loopvec: "true"
- os: ubuntu-latest
test_group: "dense"
blas_backend: "mkl"
version: "1.10"
loopvec: "true"
- os: ubuntu-latest
test_group: "dense"
blas_backend: "default"
version: "1.10"
loopvec: "false"
- os: ubuntu-latest
test_group: "batched_ops"
blas_backend: "default"
version: "1.10"
loopvec: "false"
- os: macos-latest
test_group: "dense"
blas_backend: "appleaccelerate"
version: "1.10"
loopvec: "true"
- os: macos-latest
test_group: "all"
blas_backend: "default"
version: "1.10"
loopvec: "true"
- os: windows-latest
test_group: "all"
blas_backend: "default"
version: "1.10"
loopvec: "true"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
Expand All @@ -84,6 +101,7 @@ jobs:
env:
LUXLIB_TEST_GROUP: ${{ matrix.test_group }}
LUXLIB_BLAS_BACKEND: ${{ matrix.blas_backend }}
LUXLIB_LOAD_LOOPVEC: ${{ matrix.loopvec }}
- uses: julia-actions/julia-processcoverage@v1
with:
directories: src,ext
Expand Down Expand Up @@ -129,7 +147,7 @@ jobs:
repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
path: downstream
- name: Load this and run the downstream tests
shell: julia --code-coverage=user --color=yes --project=downstream {0}
shell: julia --code-coverage=user --color=true --project=downstream {0}
run: |
using Pkg
try
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Hwloc = "0e44f5e4-bd66-52a0-8798-143a42290a1d"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -36,6 +34,8 @@ BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Expand All @@ -46,6 +46,8 @@ LuxLibBLISBLASExt = "BLISBLAS"
LuxLibCUDAExt = "CUDA"
LuxLibMKLExt = "MKL"
LuxLibEnzymeExt = "Enzyme"
LuxLibLoopVectorizationExt = "LoopVectorization"
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
LuxLibTrackerExt = "Tracker"
Expand Down
2 changes: 2 additions & 0 deletions benchmarks/Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand Down
1 change: 1 addition & 0 deletions benchmarks/runbenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Pkg
using BenchmarkTools
using InteractiveUtils
using LinearAlgebra
using Octavian, LoopVectorization

const SUITE = BenchmarkGroup()
BenchmarkTools.DEFAULT_PARAMETERS.seconds = 5
Expand Down
49 changes: 49 additions & 0 deletions ext/LuxLibLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module LuxLibLoopVectorizationExt

using LoopVectorization: LoopVectorization, @tturbo, @turbo, indices
using Static: True

using LuxLib: LuxLib, Utils

Utils.is_extension_loaded(::Val{:LoopVectorization}) = True()

Utils.can_loopvec_args_check(::True, args...) = LoopVectorization.check_args(args...)

# matmul
for serial in (true, false)
opname = serial ? :serial_matmul_loopvec! : :matmul_loopvec!
@eval @inline function LuxLib.Impl.$(opname)(
C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number)
if !iszero(β) # Secial case this because Base.FastMath.mul_fast(NaN, false) = NaN
@turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1)
Cⱼₖ = zero(eltype(C))
for I in indices((A, B), (2, 1))
Cⱼₖ += A[J, I] * B[I, K]
end
C[J, K] = α * Cⱼₖ + β * C[J, K]
end
else
@turbo thread=$(!serial) for K in indices((C, B), 2), J in indices((C, A), 1)
Cⱼₖ = zero(eltype(C))
for I in indices((A, B), (2, 1))
Cⱼₖ += A[J, I] * B[I, K]
end
C[J, K] = α * Cⱼₖ
end
end
end
end

@inline function LuxLib.Impl.matmuladd_loopvec!(
C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, bias::AbstractVector)
@tturbo for K in indices((C, B), 2), J in indices((C, A), 1)
Cⱼₖ = zero(eltype(C))
for I in indices((A, B), (2, 1))
Cⱼₖ += A[J, I] * B[I, K]
end
C[J, K] = bias[J] + Cⱼₖ
end
return
end

end
15 changes: 15 additions & 0 deletions ext/LuxLibOctavianExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module LuxLibOctavianExt

using Octavian: Octavian

using LuxLib: LuxLib, Utils

Utils.is_extension_loaded(::Val{:Octavian}) = True()

@inline function LuxLib.Impl.matmul_octavian!(
C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix, α::Number, β::Number)
Octavian.matmul!(C, A, B, α, β)
return
end

end
5 changes: 1 addition & 4 deletions src/impl/Impl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ using ForwardDiff: ForwardDiff

using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index

using LoopVectorization: LoopVectorization, @turbo, @tturbo, indices
using Octavian: Octavian
using Polyester: @batch

using LinearAlgebra: LinearAlgebra, mul!
Expand All @@ -31,15 +29,14 @@ using ..Utils: Utils, NotaNumber, batchview, concrete_bias_act_output_eltype, co
copy_drop_gradients, eltype_mismatch, expand_batchdim,
maybe_reduce_BLAS_threads, ofeltype_array, only_derivative, remove_tracking,
reset_BLAS_threads, run_ka_kernel, safe_eltype, safe_vec, safe_warning,
unsafe_known, unrolled_mapreduce, @enzyme_alternative
unsafe_known, unrolled_mapreduce, can_loopvec_args, @enzyme_alternative
using ..Traits: activation_intermediate_not_needed, activation_has_rrule, is_mutable_array,
fuse_cpu_activation
using ..System: explicit_blas_loaded, use_octavian, fits_in_l1cache, fits_in_l2cache,
fits_in_l3cache

const CRC = ChainRulesCore
const KA = KernelAbstractions
const LV = LoopVectorization

include("activation.jl")
include("batched_mul.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/impl/activation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,11 @@ end
@inbounds function ∇activation(::LoopedArrayOp, Δ, out, act::F, x) where {F}
y = similar(out)
if x isa NotaNumber
@simd ivdep for i in indices((Δ, out))
@simd ivdep for i in eachindex(Δ, out)
@inbounds y[i] = only_derivative(out[i], act, x) * Δ[i]
end
else
@simd ivdep for i in indices((Δ, out, x))
@simd ivdep for i in eachindex(Δ, out, x)
@inbounds y[i] = only_derivative(out[i], act, x[i]) * Δ[i]
end
end
Expand Down
24 changes: 12 additions & 12 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,28 @@ end

function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp,
x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT}
if !LV.check_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) ||
unsafe_known(explicit_blas_loaded())
NNlib.batched_mul!(z, x, y)
if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) &&
!unsafe_known(explicit_blas_loaded())
batched_matmul_loopvec_impl!(z, x, y)
return
end
batched_matmul_loopvec_impl!(z, x, y)
NNlib.batched_mul!(z, x, y)
return
end

function batched_matmul_loopvec_impl!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3},
y::AbstractArray{yT, 3}, α::Number=true, β::Number=false) where {zT, xT, yT}
if size(x, 3) == size(y, 3)
@batch for L in indices((z, x, y), 3)
@batch for L in axes(z, 3)
serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, L), α, β)
end
elseif size(x, 3) == 1
@batch for L in indices((z, y), 3)
@batch for L in axes(z, 3)
serial_matmul_loopvec!(batchview(z, L), batchview(x, 1), batchview(y, L), α, β)
end
else # has to be size(y, 3) == 1
@batch for L in indices((z, x), 3)
@batch for L in axes(z, 3)
serial_matmul_loopvec!(batchview(z, L), batchview(x, L), batchview(y, 1), α, β)
end
end
Expand All @@ -96,15 +96,15 @@ function fallback_batched_matmul!(
throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul."))
end
if size(x, 3) == size(y, 3)
Threads.@threads for L in indices((x, y), 3)
Threads.@threads for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, L), batchview(y, L))
end
elseif size(x, 3) == 1
Threads.@threads for L in indices((x, y), 3)
Threads.@threads for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, 1), batchview(y, L))
end
else # has to be size(y, 3) == 1
Threads.@threads for L in indices((x, y), 3)
Threads.@threads for L in axes(z, 3)
mul!(batchview(z, L), batchview(x, L), batchview(y, 1))
end
end
Expand Down Expand Up @@ -192,7 +192,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
if size(dA, 3) == 1 && size(B.val, 3) != 1
B′ = NNlib.batched_adjoint(B.val)
dA′ = batchview(dA, 1)
for L in indices(B′, 3)
for L in axes(B′, 3)
mul!(dA′, batchview(dC, L),
batchview(B′, L), true, true)
end
Expand All @@ -205,7 +205,7 @@ for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
if size(dB, 3) == 1 && size(A.val, 3) != 1
A′ = NNlib.batched_adjoint(A.val)
dB′ = batchview(dB, 1)
for L in indices(A′, 3)
for L in axes(A′, 3)
mul!(dB′, batchview(A′, L),
batchview(dC, L), true, true)
end
Expand Down
Loading

0 comments on commit c75af09

Please sign in to comment.