From 1107cbf6a0a55c9b9b1fb63407326da58003cb7a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 12 Oct 2020 14:16:44 +0100 Subject: [PATCH 1/2] Add to_vec Composite{Tuple} and AbstractZero --- src/to_vec.jl | 22 ++++++++++++++++++++++ test/to_vec.jl | 23 +++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/src/to_vec.jl b/src/to_vec.jl index 4f74822f..b9fe81d7 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -139,3 +139,25 @@ function to_vec(d::Dict) end return d_vec, Dict_from_vec end + + +# ChainRulesCore Differentials +function FiniteDifferences.to_vec(x::Composite{P, T}) where{P, T<:Tuple} + x_tuple = convert(Tuple, x) + x_vec, back_tuple = FiniteDifferences.to_vec(x_tuple) + function CompositeTuple_from_vec(y_vec) + y_tuple = back_tuple(y_vec) + return Composite{P, typeof(y_tuple)}(y_tuple) + end + return x_vec, CompositeTuple_from_vec +end + + +function FiniteDifferences.to_vec(x::AbstractZero) + function from_vec_AbstractZero(z) + length(z) == 1 || throw(DimensionMismatch("tried to go back to $x from $z")) + iszero(first(z)) || throw(DomainError(first(z))) + return x + end + return [false], from_vec_AbstractZero +end diff --git a/test/to_vec.jl b/test/to_vec.jl index b105881f..9b0cc2f4 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -105,6 +105,29 @@ end end end + @testset "ChainRulesCore Differentials" begin + @testset "Composite{Tuple}" begin + @testset "basic" begin + x_tup = (1.0, 2.0, 3.0) + x_comp = Composite{typeof(x_tup)}(x_tup...) + test_to_vec(x_comp) + end + + @testset "nested" begin + x_inner = (2, 3) + x_outer = (1, x_inner) + x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3)) + + test_to_vec(x_comp) + end + end + + @testset "AbstractZero" begin + test_to_vec(Zero()) + test_to_vec(DoesNotExist()) + end + end + @testset "FillVector" begin x = FillVector(5.0, 10) x_vec, from_vec = to_vec(x) From c8c2158833dd0058560669f13bb21460b1b4b1ae Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 12 Oct 2020 15:20:07 +0100 Subject: [PATCH 2/2] Add to_vec on struct namestuples Update src/to_vec.jl Co-authored-by: Nick Robinson Update src/to_vec.jl --- Project.toml | 4 ++-- src/to_vec.jl | 19 ++++++++++--------- test/to_vec.jl | 24 +++++++++++++++++++++++- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 76b516fd..59f30dfe 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "FiniteDifferences" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.11.1" +version = "0.11.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -11,8 +11,8 @@ Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d" [compat] ChainRulesCore = "0.9" -julia = "1" Richardson = "1.2" +julia = "1" [extras] Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" diff --git a/src/to_vec.jl b/src/to_vec.jl index b9fe81d7..ae311186 100644 --- a/src/to_vec.jl +++ b/src/to_vec.jl @@ -142,22 +142,23 @@ end # ChainRulesCore Differentials -function FiniteDifferences.to_vec(x::Composite{P, T}) where{P, T<:Tuple} - x_tuple = convert(Tuple, x) - x_vec, back_tuple = FiniteDifferences.to_vec(x_tuple) - function CompositeTuple_from_vec(y_vec) - y_tuple = back_tuple(y_vec) - return Composite{P, typeof(y_tuple)}(y_tuple) +function FiniteDifferences.to_vec(x::Composite{P}) where{P} + x_canon = canonicalize(x) # to be safe, fill in every field and put in primal order. + x_inner = ChainRulesCore.backing(x_canon) + x_vec, back_inner = FiniteDifferences.to_vec(x_inner) + function Composite_from_vec(y_vec) + y_back = back_inner(y_vec) + return Composite{P, typeof(y_back)}(y_back) end - return x_vec, CompositeTuple_from_vec + return x_vec, Composite_from_vec end function FiniteDifferences.to_vec(x::AbstractZero) - function from_vec_AbstractZero(z) + function AbstractZero_from_vec(z) length(z) == 1 || throw(DimensionMismatch("tried to go back to $x from $z")) iszero(first(z)) || throw(DomainError(first(z))) return x end - return [false], from_vec_AbstractZero + return [false], AbstractZero_from_vec end diff --git a/test/to_vec.jl b/test/to_vec.jl index 9b0cc2f4..5fbf2541 100644 --- a/test/to_vec.jl +++ b/test/to_vec.jl @@ -18,6 +18,14 @@ struct FillVector <: AbstractVector{Float64} len::Int end +# For testing Composite{ThreeFields} +struct ThreeFields + a + b + c +end + + Base.size(x::FillVector) = (x.len,) Base.getindex(x::FillVector, n::Int) = x.x @@ -117,11 +125,25 @@ end x_inner = (2, 3) x_outer = (1, x_inner) x_comp = Composite{typeof(x_outer)}(1, Composite{typeof(x_inner)}(2, 3)) - test_to_vec(x_comp) end end + @testset "Composite Struct" begin + @testset "NamedTuple basic" begin + nt = (; a=1.0, b=20.0) + comp = Composite{typeof(nt)}(; nt...) + test_to_vec(comp) + end + + @testset "Struct" begin + test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0, c=30.0)) + test_to_vec(Composite{ThreeFields}(; a=10.0, b=20.0,)) + test_to_vec(Composite{ThreeFields}(; a=10.0, c=30.0)) + test_to_vec(Composite{ThreeFields}(; c=30.0, a=10.0, b=20.0)) + end + end + @testset "AbstractZero" begin test_to_vec(Zero()) test_to_vec(DoesNotExist())