docs: add 3rd order AD example using Reactant #1097

Draft · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions docs/make.jl
@@ -47,6 +47,8 @@ pages = [
"manual/distributed_utils.md",
"manual/nested_autodiff.md",
"manual/compiling_lux_models.md",
"manual/exporting_to_jax.md",
"manual/nested_autodiff_reactant.md"
],
"API Reference" => [
"Lux" => [
4 changes: 4 additions & 0 deletions docs/src/.vitepress/config.mts
@@ -314,6 +314,10 @@ export default defineConfig({
text: "Exporting Lux Models to Jax",
link: "/manual/exporting_to_jax",
},
{
text: "Nested AutoDiff",
link: "/manual/nested_autodiff_reactant",
}
],
},
{
14 changes: 7 additions & 7 deletions docs/src/manual/nested_autodiff.md
@@ -1,16 +1,16 @@
# [Nested Automatic Differentiation](@id nested_autodiff)

!!! note

    This is a relatively new feature in Lux, so there might be some rough edges. If you
    encounter any issues, please let us know by opening an issue on the
    [GitHub repository](https://github.com/LuxDL/Lux.jl).

In this manual, we will explore how to use automatic differentiation (AD) inside your layers
or loss functions and have Lux automatically switch the AD backend with a faster one when
needed.

!!! tip
!!! tip "Reactant Support"

    Reactant + Lux natively supports Nested AD (even higher dimensions). If you are using
    Reactant, please see the [Nested AD with Reactant](@ref nested_autodiff_reactant)
    manual.

!!! tip "Disabling Nested AD Switching"

    Don't want Lux to do this switching for you? You can disable it by setting the
    `automatic_nested_ad_switching` Preference to `false`.
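
As a quick orientation before the worked examples, here is a minimal sketch of the pattern the switching targets (the loss name and the `2 => 2` toy model below are ours, not from this manual): an inner AD call over a `StatefulLuxLayer` inside a loss that is itself differentiated.

```julia
using Lux, Random, ForwardDiff, Zygote

model = Dense(2 => 2)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 4)

function loss_with_jacobian_penalty(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # Inner AD call: Lux can swap this nested call for a faster backend.
    J = ForwardDiff.jacobian(smodel, x)
    return sum(abs2, smodel(x)) + sum(abs2, J)
end

# Outer reverse-mode pass over the full loss.
Zygote.gradient(loss_with_jacobian_penalty, model, x, ps, st)
```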
66 changes: 66 additions & 0 deletions docs/src/manual/nested_autodiff_reactant.md
@@ -0,0 +1,66 @@
# [Nested AutoDiff with Reactant](@id nested_autodiff_reactant)

We will use the example from [issue 614](https://github.com/LuxDL/Lux.jl/issues/614): a model whose forward pass computes elementwise second derivatives of a wrapped network via two forward-mode sweeps, and whose loss gradient adds a reverse-mode pass on top, i.e. third-order automatic differentiation end to end.
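
Before the full example, a small standalone sketch (plain Enzyme without Reactant; the toy function `f` is ours) of the mechanism it relies on: seeding forward mode with a one-hot tangent extracts a single column of the Jacobian, and keeping entry `i` of the column seeded at `i` gives the elementwise derivative.

```julia
using Enzyme

f(x) = abs2.(x)  # toy elementwise function; its Jacobian is diagonal with entries 2 .* x

x  = Float32[1, 2, 3]
dx = Float32[0, 1, 0]  # one-hot tangent selecting ∂/∂x[2]

# Forward mode propagates the tangent, returning J(x) * dx,
# i.e. the second column of the Jacobian of `f` at `x`.
dy = only(Enzyme.autodiff(Forward, f, Duplicated(x, dx)))
dy[2]  # == 2f0 * x[2], the (2, 2) entry of the Jacobian
```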

```@example nested_ad_reactant
using Reactant, Enzyme, Lux, Random, LinearAlgebra

const xdev = reactant_device()
const cdev = cpu_device()

# XXX: We need to be able to compile this with a for-loop else tracing time will scale
# proportionally to the number of elements in the input.
# `∇potential` computes the elementwise derivative ∂potential(x)[i] / ∂x[i]:
# each one-hot tangent `dxᵢ` selects one column of the Jacobian via forward mode,
# and we keep only its `i`-th entry.
function ∇potential(potential, x)
    dxs = onehot(x)
    ∇p = similar(x)
    for i in eachindex(dxs)
        dxᵢ = dxs[i]
        res = only(Enzyme.autodiff(
            Enzyme.set_abi(Forward, Reactant.ReactantABI), potential, Duplicated(x, dxᵢ)
        ))
        @allowscalar ∇p[i] = res[i]
    end
    return ∇p
end

# `∇²potential` applies the same one-hot trick to `∇potential` itself
# (forward-over-forward), yielding the elementwise second derivatives
# ∂²potential(x)[i] / ∂x[i]².
function ∇²potential(potential, x)
    dxs = onehot(x)
    ∇²p = similar(x)
    for i in eachindex(dxs)
        dxᵢ = dxs[i]
        res = only(Enzyme.autodiff(
            Enzyme.set_abi(Forward, Reactant.ReactantABI),
            ∇potential, Const(potential), Duplicated(x, dxᵢ)
        ))
        @allowscalar ∇²p[i] = res[i]
    end
    return ∇²p
end

# A wrapper layer whose forward pass returns the elementwise second derivative
# of the wrapped network with respect to its input.
struct PotentialNet{P} <: Lux.AbstractLuxWrapperLayer{:potential}
    potential::P
end

# The layer's forward pass: wrap `potential` as a stateful layer so it can be
# called as `x -> y`, then return its elementwise second derivative.
function (potential::PotentialNet)(x, ps, st)
    pnet = StatefulLuxLayer{true}(potential.potential, ps, st)
    return ∇²potential(pnet, x), pnet.st
end

model = PotentialNet(Dense(5 => 5, gelu))
ps, st = Lux.setup(Random.default_rng(), model) |> xdev

x_ra = randn(Float32, 5, 3) |> xdev

# Compile the full forward pass, which already contains the two nested forward-mode sweeps.
model_compiled = @compile model(x_ra, ps, st)
model_compiled(x_ra, ps, st)

# Scalar loss: sum of squared second derivatives returned by the model.
sumabs2first(model, x, ps, st) = sum(abs2, first(model(x, ps, st)))

# Reverse-mode gradient of the loss above. Since the model's output already
# involves two forward-mode passes, this is third-order differentiation overall.
function enzyme_gradient(model, x, ps, st)
    return Enzyme.gradient(
        Enzyme.Reverse, Const(sumabs2first), Const(model), Const(x), ps, Const(st)
    )
end

@jit enzyme_gradient(model, x_ra, ps, st)
```
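
The example defines `cdev` but does not use it above; as a usage note (a sketch reusing the names from the example, not part of the original), the results can be moved back to the host for inspection:

```julia
∇²p, _ = model_compiled(x_ra, ps, st)
∇²p |> cdev   # 5×3 Array of elementwise second derivatives, back on the CPU

grads = @jit enzyme_gradient(model, x_ra, ps, st)
grads |> cdev # the gradient with respect to `ps` sits in the returned tuple
```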