From a71f40a69becb71d077cf97a9f60c397d6beeeea Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 20 Nov 2024 14:33:56 -0500 Subject: [PATCH] docs: add 3rd order AD example using Reactant --- docs/make.jl | 2 + docs/src/.vitepress/config.mts | 4 ++ docs/src/manual/nested_autodiff.md | 14 ++--- docs/src/manual/nested_autodiff_reactant.md | 66 +++++++++++++++++++++ 4 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 docs/src/manual/nested_autodiff_reactant.md diff --git a/docs/make.jl b/docs/make.jl index 8d407f3d2..eaa360eaf 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -47,6 +47,8 @@ pages = [ "manual/distributed_utils.md", "manual/nested_autodiff.md", "manual/compiling_lux_models.md", + "manual/exporting_to_jax.md", + "manual/nested_autodiff_reactant.md" ], "API Reference" => [ "Lux" => [ diff --git a/docs/src/.vitepress/config.mts b/docs/src/.vitepress/config.mts index 3019b10a4..2b8abfe1e 100644 --- a/docs/src/.vitepress/config.mts +++ b/docs/src/.vitepress/config.mts @@ -314,6 +314,10 @@ export default defineConfig({ text: "Exporting Lux Models to Jax", link: "/manual/exporting_to_jax", }, + { + text: "Nested AutoDiff", + link: "/manual/nested_autodiff_reactant", + } ], }, { diff --git a/docs/src/manual/nested_autodiff.md b/docs/src/manual/nested_autodiff.md index 19270b5be..d92b9f852 100644 --- a/docs/src/manual/nested_autodiff.md +++ b/docs/src/manual/nested_autodiff.md @@ -1,16 +1,16 @@ # [Nested Automatic Differentiation](@id nested_autodiff) -!!! note - - This is a relatively new feature in Lux, so there might be some rough edges. If you - encounter any issues, please let us know by opening an issue on the - [GitHub repository](https://github.com/LuxDL/Lux.jl). - In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend with a faster one when needed. -!!! tip +!!! tip "Reactant Support" + + Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using + Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant) + manual. + +!!! tip "Disabling Nested AD Switching" Don't wan't Lux to do this switching for you? You can disable it by setting the `automatic_nested_ad_switching` Preference to `false`. diff --git a/docs/src/manual/nested_autodiff_reactant.md b/docs/src/manual/nested_autodiff_reactant.md new file mode 100644 index 000000000..297a5d5b6 --- /dev/null +++ b/docs/src/manual/nested_autodiff_reactant.md @@ -0,0 +1,66 @@ +# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant) + +We will be using the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614). + +```@example nested_ad_reactant +using Reactant, Enzyme, Lux, Random, LinearAlgebra + +const xdev = reactant_device() +const cdev = cpu_device() + +# XXX: We need to be able to compile this with a for-loop else tracing time will scale +# proportionally to the number of elements in the input. +function ∇potential(potential, x) + dxs = onehot(x) + ∇p = similar(x) + for i in eachindex(dxs) + dxᵢ = dxs[i] + res = only(Enzyme.autodiff( + Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ) + )) + @allowscalar ∇p[i] = res[i] + end + return ∇p +end + +function ∇²potential(potential, x) + dxs = onehot(x) + ∇²p = similar(x) + for i in eachindex(dxs) + dxᵢ = dxs[i] + res = only(Enzyme.autodiff( + Enzyme.set_abi(Forward, Reactant.ReactantABI), + ∇potential, Const(potential), Duplicated(x, dxᵢ) + )) + @allowscalar ∇²p[i] = res[i] + end + return ∇²p +end + +struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential} + potential::P +end + +function (potential::PotentialNet)(x, ps, st) + pnet = StatefulLuxLayer{true}(potential.potential, ps, st) + return ∇²potential(pnet, x), pnet.st +end + +model = PotentialNet(Dense(5 => 5, gelu)) +ps, st = Lux.setup(Random.default_rng(), model) |> xdev + +x_ra = randn(Float32, 5, 3) |> xdev + +model_compiled = @compile model(x_ra, ps, st) +model_compiled(x_ra, ps, st) + +sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st))) + +function enzyme_gradient(model, x, ps, st) + return Enzyme.gradient( + Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st) + ) +end + +@jit enzyme_gradient(model, x_ra, ps, st) +```