feat: emit batch_norm ops from stablehlo (#1142)
* feat: emit batch_norm ops from stablehlo

* refactor: only implement inference path for now

* test: batchnorm layers
avik-pal authored Jan 17, 2025
1 parent c3f8b02 commit 30e7b01
Showing 5 changed files with 187 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -99,7 +99,7 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1, 1"
LuxCore = "1.2"
LuxLib = "1.3.7"
LuxLib = "1.5.0"
MLDataDevices = "1.6.6"
MLUtils = "0.4.4"
MPI = "0.20.19"
1 change: 1 addition & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -17,6 +17,7 @@ Utils.contiguous(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_a

Utils.eltype(::Type{<:TracedRArray{T, N}}) where {T, N} = T
Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T
Utils.eltype(x::Reactant.AnyTracedRArray) = Reactant.unwrapped_eltype(x)

function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
x isa Reactant.TracedType && return x
21 changes: 12 additions & 9 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.1"
version = "1.5.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -19,8 +19,8 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -32,23 +32,29 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[sources]
LuxCore = {path = "../LuxCore"}
MLDataDevices = {path = "../MLDataDevices"}

[extensions]
LuxLibAppleAccelerateExt = "AppleAccelerate"
LuxLibBLISBLASExt = "BLISBLAS"
LuxLibCUDAExt = "CUDA"
LuxLibMKLExt = "MKL"
LuxLibEnzymeExt = "Enzyme"
LuxLibLoopVectorizationExt = "LoopVectorization"
LuxLibMKLExt = "MKL"
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
LuxLibReactantExt = "Reactant"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibSLEEFPiratesExt = "SLEEFPirates"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -79,9 +85,10 @@ MLDataDevices = "1.6"
Markdown = "1.10"
NNlib = "0.9.26"
Octavian = "0.3.28"
Preferences = "1.4.3"
Polyester = "0.7.15"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.13"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
@@ -91,7 +98,3 @@ Statistics = "1.10"
Tracker = "0.2.36"
cuDNN = "1.3"
julia = "1.10"

[sources]
LuxCore = { path = "../LuxCore" }
MLDataDevices = { path = "../MLDataDevices" }
106 changes: 106 additions & 0 deletions lib/LuxLib/ext/LuxLibReactantExt.jl
@@ -0,0 +1,106 @@
module LuxLibReactantExt

using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
AnyTracedRVector, TracedRNumber
using Static: False

using LuxLib: LuxLib, Impl, Optional, Utils

# Most of the NN code gen happens in Reactant.jl via an extension on NNlib. However,
# NNlib doesn't have certain ops implemented; in those cases we can emit more optimized
# StableHLO directly.
function Impl.batchnorm(
x::AnyTracedRArray{T},
γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
::False, act::F, momentum, ϵ
) where {T, F}
x = TracedUtils.materialize_traced_array(x)

γ = if γ === nothing
Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(γ)
end
β = if β === nothing
Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(β)
end

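# No running statistics tracked: fall back to batch statistics computed on the fly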
if rμ === nothing && rσ² === nothing
μ, σ² = Impl.mean_var(
x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
)
μ = TracedUtils.materialize_traced_array(vec(μ))
σ² = TracedUtils.materialize_traced_array(vec(σ²))
else
@assert rμ !== nothing && rσ² !== nothing
μ = TracedUtils.materialize_traced_array(rμ)
σ² = TracedUtils.materialize_traced_array(rσ²)
end

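# Lower directly to stablehlo.batch_norm_inference; result 1 is the normalized tensor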
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.batch_norm_inference(
TracedUtils.get_mlir_data(x),
TracedUtils.get_mlir_data(γ),
TracedUtils.get_mlir_data(β),
TracedUtils.get_mlir_data(μ),
TracedUtils.get_mlir_data(σ²);
epsilon=Float32(ϵ),
feature_index=Int64(ndims(x) - 2)
),
1
)

return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
end

# The following code is commented out since we don't have Batchnorm Op Adjoint registered
# for EnzymeJAX yet
#=
function Impl.batchnorm(
x::AnyTracedRArray{T},
γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
training::StaticBool, act::F, momentum, ϵ
) where {T, F}
x = TracedUtils.materialize_traced_array(x)
γ = if γ === nothing
Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(γ)
end
β = if β === nothing
Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(β)
end
op = MLIR.Dialects.stablehlo.batch_norm_training(
TracedUtils.get_mlir_data(x),
TracedUtils.get_mlir_data(γ),
TracedUtils.get_mlir_data(β);
epsilon=Float32(ϵ),
feature_index=Int64(ndims(x) - 2)
)
res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))
if rμ === nothing && rσ² === nothing
return res, nothing, nothing
else
@assert rμ !== nothing && rσ² !== nothing
m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
rμ, rσ² = Impl.update_running_statistics(
rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
)
return res, rμ, rσ²
end
end
=#

end
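
For intuition: a minimal plain-Julia sketch (not part of the commit; the helper name and broadcast reshapes are illustrative) of the normalization that stablehlo.batch_norm_inference performs. The feature axis is ndims(x) - 1 in Julia's 1-based indexing, matching the 0-based feature_index = ndims(x) - 2 passed above:

function batchnorm_inference_reference(x::AbstractArray{T}, γ, β, μ, σ², ϵ) where {T}
    # Reshape the per-feature vectors so they broadcast along the feature axis
    shape = ntuple(i -> i == ndims(x) - 1 ? length(μ) : 1, ndims(x))
    γr, βr = reshape(γ, shape), reshape(β, shape)
    μr, σ²r = reshape(μ, shape), reshape(σ², shape)
    return @. γr * (x - μr) / sqrt(σ²r + T(ϵ)) + βr
end

The traced implementation returns the same values (before act is applied), but as a single fused StableHLO op instead of a chain of broadcast kernels.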
67 changes: 67 additions & 0 deletions test/reactant/layer_tests.jl
@@ -99,3 +99,70 @@ end
end
end
end

@testitem "BatchNorm Layer" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux, Random

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

dev = reactant_device(; force=true)

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for track_stats in (true, false), affine in (true, false),
act in (identity, tanh)

model = Chain(
Dense(2 => 3, tanh),
BatchNorm(3, act; track_stats, affine, init_bias=rand32, init_scale=rand32),
Dense(3 => 2)
)

x = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), model)

x_ra = x |> dev
ps_ra = ps |> dev
st_ra = st |> dev

y, st2 = model(x, ps, st)
y_ra, st2_ra = @jit model(x_ra, ps_ra, st_ra)

@test y ≈ y_ra rtol=1e-3 atol=1e-3
if track_stats
@test st2.layer_2.running_mean ≈ st2_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st2.layer_2.running_var ≈ st2_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

# TODO: Check for stablehlo.batch_norm_training once we emit it in LuxLib

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra ≈ ∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
end

y2, st3 = model(x, ps, Lux.testmode(st2))
y2_ra, st3_ra = @jit model(x_ra, ps_ra, Lux.testmode(st2_ra))

@test y2 ≈ y2_ra rtol=1e-3 atol=1e-3
if track_stats
@test st3.layer_2.running_mean ≈ st3_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st3.layer_2.running_var ≈ st3_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra))
@test contains(repr(hlo), "stablehlo.batch_norm_inference")
end
end
end
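
The gradient helpers ∇sumabs2_zygote and ∇sumabs2_enzyme are defined in SharedReactantLayersTestSetup, which is outside this diff. A plausible minimal sketch of what they compute, with all names and signatures assumed rather than taken from the commit:

using Enzyme, Zygote

sumabs2(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

# Zygote reverse-mode gradients w.r.t. the input and the parameters
∇sumabs2_zygote(model, x, ps, st) =
    Zygote.gradient((x, ps) -> sumabs2(model, x, ps, st), x, ps)

# Enzyme reverse-mode gradients; under @jit the call is traced and compiled by Reactant
function ∇sumabs2_enzyme(model, x, ps, st)
    dx = Enzyme.make_zero(x)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(
        Enzyme.Reverse, sumabs2, Enzyme.Active, Enzyme.Const(model),
        Enzyme.Duplicated(x, dx), Enzyme.Duplicated(ps, dps), Enzyme.Const(st)
    )
    return dx, dps
end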

2 comments on commit 30e7b01

@avik-pal
Member Author

@JuliaRegistrator register subdir=lib/LuxLib

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/123232

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a LuxLib-v1.5.0 -m "<description of version>" 30e7b01aa583b0a6e5862abb783e721e0da370ba
git push origin LuxLib-v1.5.0
