Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add a preference to disable loop vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 08f8448 commit 8d7c497
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 6 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -78,6 +79,7 @@ MLDataDevices = "1.2"
Markdown = "1.10"
NNlib = "0.9.24"
Octavian = "0.3.28"
Preferences = "1.4.3"
Polyester = "0.7.15"
Random = "1.10"
Reexport = "1"
Expand Down
3 changes: 3 additions & 0 deletions src/LuxLib.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module LuxLib

using Compat: @compat
using Preferences: @load_preference
using Reexport: @reexport
using Static: Static, known

Expand All @@ -15,6 +16,8 @@ const Numeric = Union{AbstractArray{<:T}, T} where {T <: Number}
const ∂∅ = NoTangent()
const CRC = ChainRulesCore

# Value of the "disable_loop_vectorization" preference, read with Preferences.jl's
# `@load_preference`; falls back to `false` when the preference is not set.
const DISABLE_LOOP_VECTORIZATION = @load_preference("disable_loop_vectorization", false)

include("utils.jl")
include("traits.jl")
include("impl/Impl.jl")
Expand Down
11 changes: 8 additions & 3 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ using ChainRulesCore: ChainRulesCore
using Hwloc: Hwloc
using Static: static, False, True

using ..LuxLib: DISABLE_LOOP_VECTORIZATION
using ..Utils: is_extension_loaded, safe_minimum

const CRC = ChainRulesCore
Expand Down Expand Up @@ -130,9 +131,13 @@ end

CRC.@non_differentiable explicit_blas_loaded()

function use_octavian()
return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
(INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
# `use_octavian()` is chosen once, when this file is loaded: with
# `DISABLE_LOOP_VECTORIZATION` set it always answers `False()`; otherwise the answer
# combines the Octavian extension check, the x86_64 check and the CPU vendor flags.
@static if DISABLE_LOOP_VECTORIZATION
    function use_octavian()
        return False()
    end
else
    function use_octavian()
        supported_cpu = INTEL_HARDWARE | AMD_RYZEN_HARDWARE
        return is_extension_loaded(Val(:Octavian)) & is_x86_64() & supported_cpu
    end
end

CRC.@non_differentiable use_octavian()
Expand Down
10 changes: 7 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using NNlib: NNlib
using Static: Static, StaticBool, False, True, static
using StaticArraysCore: SVector, SMatrix

using ..LuxLib: Optional, ∂∅
using ..LuxLib: Optional, ∂∅, DISABLE_LOOP_VECTORIZATION

const CRC = ChainRulesCore
const KA = KernelAbstractions
Expand Down Expand Up @@ -325,8 +325,12 @@ end

CRC.@non_differentiable static_training_mode_check(::Any...)

@inline function can_loopvec_args(args...)
return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
# `can_loopvec_args(args...)` is chosen once, when this file is loaded: with
# `DISABLE_LOOP_VECTORIZATION` set it always returns `false`; otherwise it forwards
# `args` to `can_loopvec_args_check` together with the result of the
# LoopVectorization extension check.
@static if DISABLE_LOOP_VECTORIZATION
    @inline function can_loopvec_args(args...)
        return false
    end
else
    @inline function can_loopvec_args(args...)
        loopvec_loaded = is_extension_loaded(Val(:LoopVectorization))
        return can_loopvec_args_check(loopvec_loaded, args...)
    end
end

# Method for a static `False` first argument: returns `false` without inspecting `args`.
@inline can_loopvec_args_check(::False, args...) = false
Expand Down

0 comments on commit 8d7c497

Please sign in to comment.