-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add DenseSparsityDetector * Doctest * Add local warning * More warnings * Tests with more shapes * Fix matrices * Doc * Coverage
- Loading branch information
Showing
6 changed files
with
257 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
""" | ||
    DenseSparsityDetector
Sparsity pattern detector satisfying the [detection API](https://sciml.github.io/ADTypes.jl/stable/#Sparse-AD) of [ADTypes.jl](https://github.com/SciML/ADTypes.jl). | ||
The nonzeros in a Jacobian or Hessian are detected by computing the relevant matrix with _dense_ AD, and thresholding the entries with a given tolerance (which can be numerically inaccurate). | ||
!!! warning
    This detector can be very slow, and should only be used if its output can be exploited multiple times to compute many sparse matrices.

!!! danger
    In general, the sparsity pattern you obtain can depend on the provided input `x`. If you want to reuse the pattern, make sure that it is input-agnostic.
# Fields | ||
- `backend::AbstractADType` is the dense AD backend used under the hood | ||
- `atol::Float64` is the absolute tolerance: a matrix entry is considered nonzero only if its magnitude is strictly greater than `atol`
# Constructor | ||
    DenseSparsityDetector(backend; atol, method=:iterative)
The keyword argument `method::Symbol` can be either: | ||
- `:iterative`: compute the matrix in a sequence of matrix-vector products (memory-efficient) | ||
- `:direct`: compute the matrix all at once (memory-hungry but sometimes faster). | ||
Note that the constructor is type-unstable because `method` ends up being a type parameter of the `DenseSparsityDetector` object (this is not part of the API and might change). | ||
# Examples | ||
```jldoctest detector | ||
using ADTypes, DifferentiationInterface, SparseArrays | ||
import ForwardDiff | ||
detector = DenseSparsityDetector(AutoForwardDiff(); atol=1e-5, method=:direct) | ||
ADTypes.jacobian_sparsity(diff, rand(5), detector) | ||
# output | ||
4×5 SparseMatrixCSC{Bool, Int64} with 8 stored entries: | ||
 1  1  ⋅  ⋅  ⋅
 ⋅  1  1  ⋅  ⋅
 ⋅  ⋅  1  1  ⋅
 ⋅  ⋅  ⋅  1  1
``` | ||
Sometimes the sparsity pattern is input-dependent: | ||
```jldoctest detector | ||
ADTypes.jacobian_sparsity(x -> [prod(x)], rand(2), detector) | ||
# output | ||
1×2 SparseMatrixCSC{Bool, Int64} with 2 stored entries: | ||
 1  1
``` | ||
```jldoctest detector | ||
ADTypes.jacobian_sparsity(x -> [prod(x)], [0, 1], detector) | ||
# output | ||
1×2 SparseMatrixCSC{Bool, Int64} with 1 stored entry: | ||
 1  ⋅
``` | ||
""" | ||
struct DenseSparsityDetector{method,Backend} <: ADTypes.AbstractSparsityDetector
    # Dense AD backend used to compute the matrix (or its products) before thresholding.
    backend::Backend
    # Entries whose magnitude is strictly greater than this tolerance count as nonzero.
    atol::Float64
end
|
||
# Print a one-line, constructor-like summary of the detector: its `method` type
# parameter, its backend and its tolerance.
function Base.show(io::IO, detector::DenseSparsityDetector{method}) where {method}
    backend = detector.backend
    atol = detector.atol
    summary_text = "DenseSparsityDetector{:$method}($backend; atol=$atol)"
    return print(io, summary_text)
end
|
||
# Outer constructor: validates `method` and lifts it into the first type parameter
# of the detector.
#
# - `backend`: the AD backend stored in the detector.
# - `atol` (required keyword): tolerance. Any `Real` is accepted and converted to
#   `Float64`, so `atol=0` or `atol=1f-6` no longer fail on the keyword annotation.
# - `method` (keyword, default `:iterative`): either `:iterative` or `:direct`.
#
# Throws an `ArgumentError` for any other value of `method`.
function DenseSparsityDetector(
    backend::AbstractADType; atol::Real, method::Symbol=:iterative
)
    if !(method in (:iterative, :direct))
        throw(
            ArgumentError("The keyword `method` must be either `:iterative` or `:direct`.")
        )
    end
    return DenseSparsityDetector{method,typeof(backend)}(backend, convert(Float64, atol))
end
|
||
## Direct | ||
|
||
# Direct Jacobian detection for an out-of-place function: compute the whole Jacobian
# with the detector's backend, then keep entries whose magnitude exceeds `atol`.
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    dense_jac = jacobian(f, detector.backend, x)
    return sparse(abs.(dense_jac) .> detector.atol)
end
|
||
# Direct Jacobian detection for an in-place function `f!(y, x)`: compute the whole
# Jacobian with the detector's backend, then keep entries whose magnitude exceeds `atol`.
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:direct})
    dense_jac = jacobian(f!, y, detector.backend, x)
    return sparse(abs.(dense_jac) .> detector.atol)
end
|
||
# Direct Hessian detection: compute the whole Hessian with the detector's backend,
# then keep entries whose magnitude exceeds `atol`.
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:direct})
    dense_hess = hessian(f, detector.backend, x)
    return sparse(abs.(dense_hess) .> detector.atol)
end
|
||
## Iterative | ||
|
||
# Iterative Jacobian detection for an out-of-place function `f` at `x`.
#
# When the backend reports `PushforwardFast`, one pushforward is run per basis
# direction of `x` and its result is recorded as one column. Otherwise one pullback is
# run per basis direction of `y = f(x)` and its result is recorded as one row.
# Only entries whose magnitude exceeds `atol` are kept. The result is a
# `length(y) × length(x)` sparse `Bool` matrix, with positions counted in linear order.
function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    y = f(x)
    nrows, ncols = length(y), length(x)
    rows, cols = Int[], Int[]
    if pushforward_performance(backend) isa PushforwardFast
        column = similar(y)
        # Preparation only needs a representative seed; the first basis element is used.
        first_dx = basis(backend, x, first(CartesianIndices(x)))
        extras = prepare_pushforward_same_point(f, backend, x, first_dx)
        for (col, xindex) in enumerate(CartesianIndices(x))
            pushforward!(f, column, backend, x, basis(backend, x, xindex), extras)
            for row in LinearIndices(column)
                abs(column[row]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    else
        gradient_row = similar(x)
        first_dy = basis(backend, y, first(CartesianIndices(y)))
        extras = prepare_pullback_same_point(f, backend, x, first_dy)
        for (row, yindex) in enumerate(CartesianIndices(y))
            pullback!(f, gradient_row, backend, x, basis(backend, y, yindex), extras)
            for col in LinearIndices(gradient_row)
                abs(gradient_row[col]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), nrows, ncols)
end
|
||
# Iterative Jacobian detection for an in-place function `f!(y, x)`.
#
# When the backend reports `PushforwardFast`, one pushforward is run per basis
# direction of `x` and its result is recorded as one column. Otherwise one pullback is
# run per basis direction of `y` and its result is recorded as one row.
# Only entries whose magnitude exceeds `atol` are kept. The result is a
# `length(y) × length(x)` sparse `Bool` matrix, with positions counted in linear order.
function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    nrows, ncols = length(y), length(x)
    rows, cols = Int[], Int[]
    if pushforward_performance(backend) isa PushforwardFast
        column = similar(y)
        # Preparation only needs a representative seed; the first basis element is used.
        first_dx = basis(backend, x, first(CartesianIndices(x)))
        extras = prepare_pushforward_same_point(f!, y, backend, x, first_dx)
        for (col, xindex) in enumerate(CartesianIndices(x))
            pushforward!(f!, y, column, backend, x, basis(backend, x, xindex), extras)
            for row in LinearIndices(column)
                abs(column[row]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    else
        gradient_row = similar(x)
        first_dy = basis(backend, y, first(CartesianIndices(y)))
        extras = prepare_pullback_same_point(f!, y, backend, x, first_dy)
        for (row, yindex) in enumerate(CartesianIndices(y))
            pullback!(f!, y, gradient_row, backend, x, basis(backend, y, yindex), extras)
            for col in LinearIndices(gradient_row)
                abs(gradient_row[col]) > atol || continue
                push!(rows, row)
                push!(cols, col)
            end
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), nrows, ncols)
end
|
||
# Iterative Hessian detection for `f` at `x`.
#
# One Hessian-vector product is run per basis direction of `x`, and its result is
# recorded as one column. Only entries whose magnitude exceeds `atol` are kept.
# The result is a `length(x) × length(x)` sparse `Bool` matrix.
function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterative})
    backend, atol = detector.backend, detector.atol
    dim = length(x)
    rows, cols = Int[], Int[]
    product = similar(x)
    # Preparation only needs a representative seed; the first basis element is used.
    first_dx = basis(backend, x, first(CartesianIndices(x)))
    extras = prepare_hvp_same_point(f, backend, x, first_dx)
    for (col, xindex) in enumerate(CartesianIndices(x))
        hvp!(f, product, backend, x, basis(backend, x, xindex), extras)
        for row in LinearIndices(product)
            abs(product[row]) > atol || continue
            push!(rows, row)
            push!(cols, col)
        end
    end
    return sparse(rows, cols, fill(true, length(rows)), dim, dim)
end
43 changes: 43 additions & 0 deletions
43
DifferentiationInterface/test/Double/Enzyme-ForwardDiff/detector.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
using ADTypes: jacobian_sparsity, hessian_sparsity | ||
using DifferentiationInterface | ||
using ForwardDiff: ForwardDiff | ||
using Enzyme: Enzyme | ||
using LinearAlgebra | ||
using SparseArrays | ||
using StableRNGs | ||
using Test | ||
|
||
# Generator with a fixed seed, used to draw the reference patterns below.
rng = StableRNG(63)

# Reference patterns the detectors are compared against: a 10×20 Jacobian pattern and
# a symmetrized 20×20 Hessian pattern.
const Jc = sprand(rng, Bool, 10, 20, 0.3)
const Hc = sprand(rng, Bool, 20, 20, 0.3) |> Symmetric |> sparse
|
||
# Linear test function `x -> Jc * x`; its Jacobian pattern is compared against `Jc`.
function f(x::AbstractVector)
    return Jc * x
end

# Matrix inputs are flattened, mapped through the vector method, and reshaped to 5×2.
function f(x::AbstractMatrix)
    flat_output = f(vec(x))
    return reshape(flat_output, (5, 2))
end
|
||
# In-place counterpart of `f`: writes `f(x)` into `y` elementwise and returns `nothing`.
function f!(y, x)
    result = f(x)
    y .= result
    return nothing
end
|
||
# Quadratic test function `x -> dot(x, Hc, x)`; its Hessian pattern is compared
# against `Hc`.
function g(x::AbstractVector)
    return dot(x, Hc, x)
end

# Matrix inputs are flattened before evaluation.
function g(x::AbstractMatrix)
    return g(vec(x))
end
|
||
# For each backend and each detection method, check that the detected Jacobian
# pattern equals `Jc` and the detected Hessian pattern equals `Hc`, for both vector
# and matrix inputs. Hessian detection is only exercised with ForwardDiff.
@testset verbose = true "$(typeof(backend))" for backend in [
    AutoEnzyme(; mode=Enzyme.Reverse), AutoForwardDiff()
]
    # Unknown `method` symbols must be rejected by the constructor.
    @test_throws ArgumentError DenseSparsityDetector(backend; atol=1e-5, method=:random)
    @testset "$method" for method in (:iterative, :direct)
        detector = DenseSparsityDetector(backend; atol=1e-5, method)
        # Assert on the printed form instead of only checking that `show` does not throw.
        @test startswith(string(detector), "DenseSparsityDetector{:$method}(")
        # Inputs are drawn from the seeded `rng` so that any failure is reproducible.
        for (x, y) in ((rand(rng, 20), zeros(10)), (rand(rng, 2, 10), zeros(5, 2)))
            @test Jc == jacobian_sparsity(f, x, detector)
            @test Jc == jacobian_sparsity(f!, copy(y), x, detector)
        end
        if backend isa AutoForwardDiff
            for x in (rand(rng, 20), rand(rng, 2, 10))
                @test Hc == hessian_sparsity(g, x, detector)
            end
        end
    end
end