Skip to content

Commit

Permalink
Set default batch size to 1 (#340)
Browse files Browse the repository at this point in the history
* Default batch size to 1

* Add test

* Remove ambiguity

* Default Enzyme batch size to 8

* Actuallly 16

* Avoid ambiguity
  • Loading branch information
gdalle authored Jun 27, 2024
1 parent ff529cb commit 102fa86
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 35 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.5.6"
version = "0.5.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ reverse_mode(::AnyAutoEnzyme{Nothing}) = Reverse

DI.check_available(::AutoEnzyme) = true

# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
DI.pick_batchsize(::AnyAutoEnzyme, dimension::Integer) = min(dimension, 16)

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
b = zero(a)
Expand Down
4 changes: 4 additions & 0 deletions DifferentiationInterface/src/misc/from_primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ abstract type FromPrimitive <: AbstractADType end
check_available(fromprim::FromPrimitive) = check_available(fromprim.backend)
twoarg_support(fromprim::FromPrimitive) = twoarg_support(fromprim.backend)

function pick_batchsize(fromprim::FromPrimitive, dimension::Integer)
return pick_batchsize(fromprim.backend, dimension)
end

## Forward

struct AutoForwardFromPrimitive{B} <: FromPrimitive
Expand Down
6 changes: 3 additions & 3 deletions DifferentiationInterface/src/utils/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
pick_batchsize(backend::AbstractADType, dimension::Integer)
Pick a reasonable batch size for batched derivative evaluation with a given total `dimension`.
Returns `1` for backends which have not overloaded it.
"""
function pick_batchsize(::AbstractADType, dimension::Integer)
return min(dimension, 8)
end
pick_batchsize(::AbstractADType, dimension::Integer) = 1

"""
Batch{B,T}
Expand Down
22 changes: 0 additions & 22 deletions DifferentiationInterface/test/Internals/autosparse.jl

This file was deleted.

36 changes: 36 additions & 0 deletions DifferentiationInterface/test/Internals/backends.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
using ADTypes
using DifferentiationInterface
import DifferentiationInterface as DI
using Test

@testset "SecondOrder" begin
backend = SecondOrder(AutoForwardDiff(), AutoZygote())
@test ADTypes.mode(backend) isa ADTypes.ForwardMode
@test DifferentiationInterface.outer(backend) isa AutoForwardDiff
@test DifferentiationInterface.inner(backend) isa AutoZygote
end

@testset "Sparse" begin
for backend in [AutoForwardDiff(), AutoZygote()]
sparse_backend = AutoSparse(backend)
@test ADTypes.mode(sparse_backend) == ADTypes.mode(backend)
@test check_available(sparse_backend) == check_available(backend)
@test DI.twoarg_support(sparse_backend) == DI.twoarg_support(backend)
@test DI.pushforward_performance(sparse_backend) ==
DI.pushforward_performance(backend)
@test DI.pullback_performance(sparse_backend) == DI.pullback_performance(backend)
end

for backend in [
SecondOrder(AutoForwardDiff(), AutoZygote()),
SecondOrder(AutoZygote(), AutoForwardDiff()),
]
sparse_backend = AutoSparse(backend)
@test ADTypes.mode(sparse_backend) == ADTypes.mode(backend)
@test DI.hvp_mode(sparse_backend) == DI.hvp_mode(backend)
end
end

@testset "Batch size" begin
@test DI.pick_batchsize(AutoZygote(), 2) == 1
end
9 changes: 0 additions & 9 deletions DifferentiationInterface/test/Internals/second_order.jl

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ for backend in vcat(fromprimitive_backends)
@test check_available(backend)
@test check_twoarg(backend)
@test check_hessian(backend)
@test DifferentiationInterface.pick_batchsize(backend, 100) == 5
end

## Dense backends
Expand Down

0 comments on commit 102fa86

Please sign in to comment.