diff --git a/Project.toml b/Project.toml index 76b516fd..59f30dfe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.11.1" +version = "0.11.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,8 +11,8 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d" [compat] ChainRulesCore = "0.9" -julia = "1" Richardson = "1.2" +julia = "1" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/to_vec.jl b/src/to_vec.jl index 4f74822f..ae311186 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -139,3 +139,26 @@ function to_vec(d::Dict) end return d_vec, Dict_from_vec end + + +# ChainRulesCore Differentials +function FiniteDifferences.to_vec(x::Composite{P}) where{P} + x_canon = canonicalize(x) # to be safe, fill in every field and put in primal order. + x_inner = ChainRulesCore.backing(x_canon) + x_vec, back_inner = FiniteDifferences.to_vec(x_inner) + function Composite_from_vec(y_vec) + y_back = back_inner(y_vec) + return Composite{P, typeof(y_back)}(y_back) + end + return x_vec, Composite_from_vec +end + + +function FiniteDifferences.to_vec(x::AbstractZero) + function AbstractZero_from_vec(z) + length(z) == 1 || throw(DimensionMismatch("tried to go back to $x from $z")) + iszero(first(z)) || throw(DomainError(first(z))) + return x + end + return [false], AbstractZero_from_vec +end diff --git a/test/to_vec.jl b/test/to_vec.jl index b105881f..5fbf2541 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -18,6 +18,14 @@ struct FillVector <: AbstractVector{Float64} len::Int end +# For testing Composite{ThreeFields} +struct ThreeFields + a + b + c +end + + Base.size(x::FillVector) = (x.len,) Base.getindex(x::FillVector, n::Int) = x.x @@ -105,6 +113,43 @@ end end end + @testset "ChainRulesCore Differentials" begin + @testset "Composite{Tuple}" begin + @testset "basic" begin + x_tup = (1.0, 2.0, 3.0) + x_comp = Composite{typeof(x_tup)}(x_tup...) + test_to_vec(x_comp) + end + + @testset "nested" begin + x_inner = (2, 3) + x_outer = (1, x_inner) + x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3)) + test_to_vec(x_comp) + end + end + + @testset "Composite Struct" begin + @testset "NamedTuple basic" begin + nt = (; a=1.0, b=20.0) + comp = Composite{typeof(nt)}(; nt...) + test_to_vec(comp) + end + + @testset "Struct" begin + test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0, c=30.0)) + test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0,)) + test_to_vec(Composite{ThreeFields}(; a=10.0, c=30.0)) + test_to_vec(Composite{ThreeFields}(; c=30.0, a=10.0, b=20.0)) + end + end + + @testset "AbstractZero" begin + test_to_vec(Zero()) + test_to_vec(DoesNotExist()) + end + end + @testset "FillVector" begin x = FillVector(5.0, 10) x_vec, from_vec = to_vec(x)