Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: add a preference to disable loop vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent 08f8448 commit 1d21e54
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -78,6 +79,7 @@ MLDataDevices = "1.2"
Markdown = "1.10"
NNlib = "0.9.24"
Octavian = "0.3.28"
Preferences = "1.4.3"
Polyester = "0.7.15"
Random = "1.10"
Reexport = "1"
Expand Down
11 changes: 8 additions & 3 deletions src/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using Static: True, False, static
using StaticArraysCore: StaticArray

using ..LuxLib: Numeric
using ..LuxLibPreferences: DISABLE_LOOP_VECTORIZATION
using ..Utils: NotaNumber, only_derivative, unrolled_any, unrolled_map

function fast_scalar_indexing(::T) where {T <: AbstractArray}
Expand Down Expand Up @@ -130,9 +131,13 @@ end

CRC.@non_differentiable explicit_blas_loaded()

function use_octavian()
return is_extension_loaded(Val(:Octavian)) & is_x86_64() &
(INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
# `use_octavian` is selected once, at definition time, from the
# `DISABLE_LOOP_VECTORIZATION` preference flag.
@static if !DISABLE_LOOP_VECTORIZATION
    # Octavian is used only when its extension is loaded, the CPU is x86_64,
    # and the hardware is detected as Intel or AMD Ryzen.
    function use_octavian()
        octavian_on_x86 = is_extension_loaded(Val(:Octavian)) & is_x86_64()
        return octavian_on_x86 & (INTEL_HARDWARE | AMD_RYZEN_HARDWARE)
    end
else
    # Loop vectorization is disabled: always report Octavian as unavailable.
    use_octavian() = False()
end

CRC.@non_differentiable use_octavian()
Expand Down
20 changes: 18 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
module LuxLibPreferences

using Preferences: load_preference

using ..LuxLib: LuxLib

"""
    DISABLE_LOOP_VECTORIZATION

Boolean flag read from LuxLib's `"disable_loop_vectorization"` preference when the
package is loaded. It is `false` when the preference has not been set.

Throws an `ArgumentError` at load time if the stored preference is not a `Bool`.
"""
const DISABLE_LOOP_VECTORIZATION = let
    pref = load_preference(LuxLib, "disable_loop_vectorization", false)
    # The flag is consumed as an `@static if` condition, which needs a real `Bool`.
    # Reject anything else here (e.g. the string "true") so the failure names the
    # offending preference instead of surfacing as an opaque error at the use site.
    pref isa Bool || throw(ArgumentError(
        "LuxLib preference `disable_loop_vectorization` must be `true` or `false`, " *
        "got $(repr(pref)) of type $(typeof(pref))"))
    pref
end

end

module Utils

using ChainRulesCore: ChainRulesCore
Expand All @@ -12,6 +23,7 @@ using Static: Static, StaticBool, False, True, static
using StaticArraysCore: SVector, SMatrix

using ..LuxLib: Optional, ∂∅
using ..LuxLibPreferences: DISABLE_LOOP_VECTORIZATION

const CRC = ChainRulesCore
const KA = KernelAbstractions
Expand Down Expand Up @@ -325,8 +337,12 @@ end

CRC.@non_differentiable static_training_mode_check(::Any...)

@inline function can_loopvec_args(args...)
return can_loopvec_args_check(is_extension_loaded(Val(:LoopVectorization)), args...)
# `can_loopvec_args` is selected once, at definition time, from the
# `DISABLE_LOOP_VECTORIZATION` preference flag.
@static if !DISABLE_LOOP_VECTORIZATION
    # Forward to `can_loopvec_args_check`, passing whether the LoopVectorization
    # extension is loaded as the first argument.
    @inline function can_loopvec_args(args...)
        loopvec_loaded = is_extension_loaded(Val(:LoopVectorization))
        return can_loopvec_args_check(loopvec_loaded, args...)
    end
else
    # Loop vectorization is disabled: no arguments ever qualify.
    @inline can_loopvec_args(args...) = false
end

# Method for a static `False` first argument: the answer is `false` regardless of
# the remaining arguments.
@inline function can_loopvec_args_check(::False, args...)
    return false
end
Expand Down

0 comments on commit 1d21e54

Please sign in to comment.