diff --git a/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl new file mode 100644 index 00000000..da202178 --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl @@ -0,0 +1,28 @@ +module TensorKitChainRulesCoreExt + +using TensorOperations +using VectorInterface +using TensorKit +using ChainRulesCore +using LinearAlgebra +using TupleTools + +import TensorOperations as TO +using TensorOperations: Backend, promote_contract +using VectorInterface: promote_scale, promote_add + +ext = @static if isdefined(Base, :get_extension) + Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) +else + TensorOperations.TensorOperationsChainRulesCoreExt +end +const _conj = ext._conj +const trivtuple = ext.trivtuple + +include("utility.jl") +include("constructors.jl") +include("linalg.jl") +include("tensoroperations.jl") +include("factorizations.jl") + +end diff --git a/ext/TensorKitChainRulesCoreExt/constructors.jl b/ext/TensorKitChainRulesCoreExt/constructors.jl new file mode 100644 index 00000000..38ed3a6b --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/constructors.jl @@ -0,0 +1,49 @@ +@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) +@non_differentiable TensorKit.id(args...) +@non_differentiable TensorKit.isomorphism(args...) +@non_differentiable TensorKit.isometry(args...) +@non_differentiable TensorKit.unitary(args...) + +function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...) + function TensorMap_pullback(Δt) + ∂d = convert(Array, unthunk(Δt)) + return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... + end + return TensorMap(d, args...; kwargs...), TensorMap_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) + copy_pullback(Δt) = NoTangent(), Δt + return copy(t), copy_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array}, + t::AbstractTensorMap) + A = convert(T, t) + function convert_pullback(ΔA) + # use constructor to (unconditionally) project back onto symmetric subspace + ∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t); tol=Inf) + return NoTangent(), NoTangent(), ∂t + end + return A, convert_pullback +end + +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) + out = convert(Dict, t) + function convert_pullback(c′) + c = unthunk(c′) + if haskey(c, :data) # :data is the only thing for which this dual makes sense + dual = copy(out) + dual[:data] = c[:data] + return (NoTangent(), NoTangent(), convert(TensorMap, dual)) + else + # instead of zero(t) you can also return ZeroTangent(), which is type unstable + return (NoTangent(), NoTangent(), zero(t)) + end + end + return out, convert_pullback +end +function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, + t::Dict{Symbol,Any}) + return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) +end diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt/factorizations.jl similarity index 54% rename from ext/TensorKitChainRulesCoreExt.jl rename to ext/TensorKitChainRulesCoreExt/factorizations.jl index 16b51df6..b254b9a8 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt/factorizations.jl @@ -1,185 +1,5 @@ -module TensorKitChainRulesCoreExt - -using TensorOperations -using VectorInterface -using TensorKit -using ChainRulesCore -using LinearAlgebra -using TupleTools - -import TensorOperations as TO -using TensorOperations: Backend, promote_contract -using VectorInterface: promote_scale, promote_add - -ext = @static if isdefined(Base, :get_extension) - Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) -else - TensorOperations.TensorOperationsChainRulesCoreExt -end -const _conj = ext._conj -const trivtuple = ext.trivtuple - -# Utility -# ------- - -function _repartition(p::IndexTuple, N₁::Int) - length(p) >= N₁ || - throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) - return p[1:N₁], p[(N₁ + 1):end] -end -_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) -function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} - return _repartition(p, N₁) -end -function _repartition(p::Union{IndexTuple,Index2Tuple}, - ::AbstractTensorMap{<:Any,N₁}) where {N₁} - return _repartition(p, N₁) -end - -TensorKit.block(t::ZeroTangent, c::Sector) = t - -# Constructors -# ------------ - -@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) -@non_differentiable TensorKit.id(args...) -@non_differentiable TensorKit.isomorphism(args...) -@non_differentiable TensorKit.isometry(args...) -@non_differentiable TensorKit.unitary(args...) - -function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...) - function TensorMap_pullback(Δt) - ∂d = convert(Array, Δt) - return NoTangent(), ∂d, fill(NoTangent(), length(args))... - end - return TensorMap(d, args...), TensorMap_pullback -end - -function ChainRulesCore.rrule(::typeof(convert), T::Type{<:Array}, t::AbstractTensorMap) - A = convert(T, t) - function convert_pullback(ΔA) - ∂t = TensorMap(ΔA, codomain(t), domain(t)) - return NoTangent(), NoTangent(), ∂t - end - return A, convert_pullback -end - -function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) - copy_pullback(Δt) = NoTangent(), Δt - return copy(t), copy_pullback -end - -ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() -function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{S,N1,N2}, - T2<:AbstractTensorMap{S,N1,N2}} - T1 === T2 && return x - y = similar(x, scalartype(T1)) - for (c, b) in blocks(y) - p = ProjectTo(b) - b .= p(block(x, c)) - end - return y -end - -# Base Linear Algebra -# ------------------- - -function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) - plus_pullback(Δc) = NoTangent(), Δc, Δc - return a + b, plus_pullback -end - -ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap) = -a, Δc -> (NoTangent(), -Δc) -function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) - minus_pullback(Δc) = NoTangent(), Δc, -Δc - return a - b, minus_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) - times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) - times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) - times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) - return a * b, times_pullback -end - -function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap) - C = A ⊗ B - projectA = ProjectTo(A) - projectB = ProjectTo(B) - function otimes_pullback(ΔC_) - # TODO: this rule is probably better written in terms of inner products, - # using planarcontract and adjoint tensormaps would remove the twists. - ΔC = unthunk(ΔC_) - pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), - ((codomainind(B) .+ numout(A))..., - (domainind(B) .+ (numin(A) + numout(A)))...)) - dA_ = @thunk begin - ipA = (codomainind(A), domainind(A)) - pB = (allind(B), ()) - dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) - tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) - dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C) - return projectA(dA) - end - dB_ = @thunk begin - ipB = (codomainind(B), domainind(B)) - pA = ((), allind(A)) - dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) - tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) - dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N) - return projectB(dB) - end - return NoTangent(), dA_, dB_ - end - return C, otimes_pullback -end - -function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; - copy::Bool=false) - function permute_pullback(Δtdst) - invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) - return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() - end - return permute(tsrc, p; copy=true), permute_pullback -end - -# LinearAlgebra -# ------------- - -function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) - tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) - return tr(A), tr_pullback -end - -function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) - adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) - return adjoint(A), adjoint_pullback -end - -function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) - dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd) - return dot(a, b), dot_pullback -end - -function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) - p == 2 || error("currently only implemented for p = 2") - n = norm(a, p) - function norm_pullback(Δn) - return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() - end - return n, norm_pullback -end - -# Factorizations -# -------------- +# Factorizations rules +# -------------------- function ChainRulesCore.rrule(::typeof(TensorKit.tsvd!), t::AbstractTensorMap; trunc::TensorKit.TruncationScheme=TensorKit.NoTruncation(), p::Real=2, @@ -669,172 +489,3 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ldiv!(LowerTriangular(L11)', ΔA1) return ΔA end - -# Convert rrules -#---------------- -function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) - out = convert(Dict, t) - function convert_pullback(c) - if haskey(c, :data) # :data is the only thing for which this dual makes sense - dual = copy(out) - dual[:data] = c[:data] - return (NoTangent(), NoTangent(), convert(TensorMap, dual)) - else - # instead of zero(t) you can also return ZeroTangent(), which is type unstable - return (NoTangent(), NoTangent(), zero(t)) - end - end - return out, convert_pullback -end -function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, - t::Dict{Symbol,Any}) - return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) -end - -function ChainRulesCore.rrule(::typeof(TO.tensorcontract!), - C::AbstractTensorMap{S}, pC::Index2Tuple, - A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, - B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol, - α::Number, β::Number, - backend::Backend...) where {S} - C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...) - - projectA = ProjectTo(A) - projectB = ProjectTo(B) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - ipC = invperm(linearize(pC)) - pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))), - TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) - - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipA = (invperm(linearize(pA)), ()) - conjΔC = conjA == :C ? :C : :N - conjB′ = conjA == :C ? conjB : _conj(conjB) - _dA = zerovector(A, - promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) - tB = twist(B, - TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), - filter(x -> isdual(space(B, x)), pB[2]))) - _dA = tensorcontract!(_dA, ipA, - ΔC, pΔC, conjΔC, - tB, reverse(pB), conjB′, - conjA == :C ? α : conj(α), Zero(), backend...) - return projectA(_dA) - end - dB = @thunk begin - ipB = (invperm(linearize(pB)), ()) - conjΔC = conjB == :C ? :C : :N - conjA′ = conjB == :C ? conjA : _conj(conjA) - _dB = zerovector(B, - promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) - tA = twist(A, - TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), - filter(x -> !isdual(space(A, x)), pA[2]))) - _dB = tensorcontract!(_dB, ipB, - tA, reverse(pA), conjA′, - ΔC, pΔC, conjΔC, - conjB == :C ? α : conj(α), Zero(), backend...) - return projectB(_dB) - end - dα = @thunk begin - # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB - AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB) - return projectα(inner(AB, ΔC)) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), - dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ, - dbackend... - end - return C′, pullback -end - -function ChainRulesCore.rrule(::typeof(TO.tensoradd!), - C::AbstractTensorMap{S}, pC::Index2Tuple, - A::AbstractTensorMap{S}, conjA::Symbol, - α::Number, β::Number, backend::Backend...) where {S} - C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...) - - projectA = ProjectTo(A) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipC = invperm(linearize(pC)) - _dA = zerovector(A, promote_add(ΔC, α)) - _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(), - backend...) - return projectA(_dA) - end - dα = @thunk begin - # TODO: this is an inner product implemented as a contraction - # for non-symmetric tensors this might be more efficient like this, - # but for symmetric tensors an intermediate object will anyways be created - # and then it might be more efficient to use an addition and inner product - tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) - _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)), - _conj(conjA), tΔC, - (trivtuple(TO.numind(pC)), - ()), :N, One(), backend...)) - return projectα(_dα) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend... - end - - return C′, pullback -end - -function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, - pC::Index2Tuple, A::AbstractTensorMap{S}, - pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, - backend::Backend...) where {S} - C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...) - - projectA = ProjectTo(A) - projectC = ProjectTo(C) - projectα = ProjectTo(α) - projectβ = ProjectTo(β) - - function pullback(ΔC′) - ΔC = unthunk(ΔC′) - dC = @thunk projectC(scale(ΔC, conj(β))) - dA = @thunk begin - ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...)) - E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA)) - twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) - _dA = zerovector(A, promote_scale(ΔC, α)) - _dA = tensorproduct!(_dA, (ipC, ()), ΔC, - (trivtuple(TO.numind(pC)), ()), conjA, E, - ((), trivtuple(TO.numind(pA))), conjA, - conjA == :N ? conj(α) : α, Zero(), backend...) - return projectA(_dA) - end - dα = @thunk begin - # TODO: this result might be easier to compute as: - # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α - At = tensortrace(pC, A, pA, conjA) - return projectα(inner(At, ΔC)) - end - dβ = @thunk projectβ(inner(C, ΔC)) - dbackend = map(x -> NoTangent(), backend) - return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ, - dbackend... - end - - return C′, pullback -end - -end diff --git a/ext/TensorKitChainRulesCoreExt/linalg.jl b/ext/TensorKitChainRulesCoreExt/linalg.jl new file mode 100644 index 00000000..a7bd64fa --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/linalg.jl @@ -0,0 +1,92 @@ +# Linear Algebra chainrules +# ------------------------- +function ChainRulesCore.rrule(::typeof(+), a::AbstractTensorMap, b::AbstractTensorMap) + plus_pullback(Δc) = NoTangent(), Δc, Δc + return a + b, plus_pullback +end + +ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap) = -a, Δc -> (NoTangent(), -Δc) +function ChainRulesCore.rrule(::typeof(-), a::AbstractTensorMap, b::AbstractTensorMap) + minus_pullback(Δc) = NoTangent(), Δc, -Δc + return a - b, minus_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::AbstractTensorMap, b::Number) + times_pullback(Δc) = NoTangent(), @thunk(Δc * b'), @thunk(dot(a, Δc)) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(*), a::Number, b::AbstractTensorMap) + times_pullback(Δc) = NoTangent(), @thunk(dot(b, Δc)), @thunk(a' * Δc) + return a * b, times_pullback +end + +function ChainRulesCore.rrule(::typeof(⊗), A::AbstractTensorMap, B::AbstractTensorMap) + C = A ⊗ B + projectA = ProjectTo(A) + projectB = ProjectTo(B) + function otimes_pullback(ΔC_) + # TODO: this rule is probably better written in terms of inner products, + # using planarcontract and adjoint tensormaps would remove the twists. + ΔC = unthunk(ΔC_) + pΔC = ((codomainind(A)..., (domainind(A) .+ numout(B))...), + ((codomainind(B) .+ numout(A))..., + (domainind(B) .+ (numin(A) + numout(A)))...)) + dA_ = @thunk begin + ipA = (codomainind(A), domainind(A)) + pB = (allind(B), ()) + dA = zerovector(A, promote_contract(scalartype(ΔC), scalartype(B))) + tB = twist(B, filter(x -> isdual(space(B, x)), allind(B))) + dA = tensorcontract!(dA, ipA, ΔC, pΔC, :N, tB, pB, :C) + return projectA(dA) + end + dB_ = @thunk begin + ipB = (codomainind(B), domainind(B)) + pA = ((), allind(A)) + dB = zerovector(B, promote_contract(scalartype(ΔC), scalartype(A))) + tA = twist(A, filter(x -> isdual(space(A, x)), allind(A))) + dB = tensorcontract!(dB, ipB, tA, pA, :C, ΔC, pΔC, :N) + return projectB(dB) + end + return NoTangent(), dA_, dB_ + end + return C, otimes_pullback +end + +function ChainRulesCore.rrule(::typeof(permute), tsrc::AbstractTensorMap, p::Index2Tuple; + copy::Bool=false) + function permute_pullback(Δtdst) + invp = TensorKit._canonicalize(TupleTools.invperm(linearize(p)), tsrc) + return NoTangent(), permute(unthunk(Δtdst), invp; copy=true), NoTangent() + end + return permute(tsrc, p; copy=true), permute_pullback +end + +function ChainRulesCore.rrule(::typeof(tr), A::AbstractTensorMap) + tr_pullback(Δtr) = NoTangent(), Δtr * id(domain(A)) + return tr(A), tr_pullback +end + +function ChainRulesCore.rrule(::typeof(adjoint), A::AbstractTensorMap) + adjoint_pullback(Δadjoint) = NoTangent(), adjoint(unthunk(Δadjoint)) + return adjoint(A), adjoint_pullback +end + +function ChainRulesCore.rrule(::typeof(dot), a::AbstractTensorMap, b::AbstractTensorMap) + dot_pullback(Δd) = NoTangent(), @thunk(b * Δd'), @thunk(a * Δd) + return dot(a, b), dot_pullback +end + +function ChainRulesCore.rrule(::typeof(norm), a::AbstractTensorMap, p::Real=2) + p == 2 || error("currently only implemented for p = 2") + n = norm(a, p) + function norm_pullback(Δn) + return NoTangent(), a * (Δn' + Δn) / 2 / hypot(n, eps(one(n))), NoTangent() + end + return n, norm_pullback +end diff --git a/ext/TensorKitChainRulesCoreExt/tensoroperations.jl b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl new file mode 100644 index 00000000..72c25a5a --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/tensoroperations.jl @@ -0,0 +1,145 @@ +function ChainRulesCore.rrule(::typeof(TO.tensorcontract!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, pA::Index2Tuple, conjA::Symbol, + B::AbstractTensorMap{S}, pB::Index2Tuple, conjB::Symbol, + α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensorcontract!(copy(C), pC, A, pA, conjA, B, pB, conjB, α, β, backend...) + + projectA = ProjectTo(A) + projectB = ProjectTo(B) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + ipC = invperm(linearize(pC)) + pΔC = (TupleTools.getindices(ipC, trivtuple(TO.numout(pA))), + TupleTools.getindices(ipC, TO.numout(pA) .+ trivtuple(TO.numin(pB)))) + + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipA = (invperm(linearize(pA)), ()) + conjΔC = conjA == :C ? :C : :N + conjB′ = conjA == :C ? conjB : _conj(conjB) + _dA = zerovector(A, + promote_contract(scalartype(ΔC), scalartype(B), scalartype(α))) + tB = twist(B, + TupleTools.vcat(filter(x -> !isdual(space(B, x)), pB[1]), + filter(x -> isdual(space(B, x)), pB[2]))) + _dA = tensorcontract!(_dA, ipA, + ΔC, pΔC, conjΔC, + tB, reverse(pB), conjB′, + conjA == :C ? α : conj(α), Zero(), backend...) + return projectA(_dA) + end + dB = @thunk begin + ipB = (invperm(linearize(pB)), ()) + conjΔC = conjB == :C ? :C : :N + conjA′ = conjB == :C ? conjA : _conj(conjA) + _dB = zerovector(B, + promote_contract(scalartype(ΔC), scalartype(A), scalartype(α))) + tA = twist(A, + TupleTools.vcat(filter(x -> isdual(space(A, x)), pA[1]), + filter(x -> !isdual(space(A, x)), pA[2]))) + _dB = tensorcontract!(_dB, ipB, + tA, reverse(pA), conjA′, + ΔC, pΔC, conjΔC, + conjB == :C ? α : conj(α), Zero(), backend...) + return projectB(_dB) + end + dα = @thunk begin + # TODO: this result should be AB = (C′ - βC) / α as C′ = βC + αAB + AB = tensorcontract(pC, A, pA, conjA, B, pB, conjB) + return projectα(inner(AB, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), + dA, NoTangent(), NoTangent(), dB, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(TO.tensoradd!), + C::AbstractTensorMap{S}, pC::Index2Tuple, + A::AbstractTensorMap{S}, conjA::Symbol, + α::Number, β::Number, backend::Backend...) where {S} + C′ = tensoradd!(copy(C), pC, A, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm(linearize(pC)) + _dA = zerovector(A, promote_add(ΔC, α)) + _dA = tensoradd!(_dA, (ipC, ()), ΔC, conjA, conjA == :N ? conj(α) : α, Zero(), + backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this is an inner product implemented as a contraction + # for non-symmetric tensors this might be more efficient like this, + # but for symmetric tensors an intermediate object will anyways be created + # and then it might be more efficient to use an addition and inner product + tΔC = twist(ΔC, filter(x -> isdual(space(ΔC, x)), allind(ΔC))) + _dα = tensorscalar(tensorcontract(((), ()), A, ((), linearize(pC)), + _conj(conjA), tΔC, + (trivtuple(TO.numind(pC)), + ()), :N, One(), backend...)) + return projectα(_dα) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), dα, dβ, dbackend... + end + + return C′, pullback +end + +function ChainRulesCore.rrule(::typeof(tensortrace!), C::AbstractTensorMap{S}, + pC::Index2Tuple, A::AbstractTensorMap{S}, + pA::Index2Tuple, conjA::Symbol, α::Number, β::Number, + backend::Backend...) where {S} + C′ = tensortrace!(copy(C), pC, A, pA, conjA, α, β, backend...) + + projectA = ProjectTo(A) + projectC = ProjectTo(C) + projectα = ProjectTo(α) + projectβ = ProjectTo(β) + + function pullback(ΔC′) + ΔC = unthunk(ΔC′) + dC = @thunk projectC(scale(ΔC, conj(β))) + dA = @thunk begin + ipC = invperm((linearize(pC)..., pA[1]..., pA[2]...)) + E = one!(TO.tensoralloc_add(scalartype(A), pA, A, conjA)) + twist!(E, filter(x -> !isdual(space(E, x)), codomainind(E))) + _dA = zerovector(A, promote_scale(ΔC, α)) + _dA = tensorproduct!(_dA, (ipC, ()), ΔC, + (trivtuple(TO.numind(pC)), ()), conjA, E, + ((), trivtuple(TO.numind(pA))), conjA, + conjA == :N ? conj(α) : α, Zero(), backend...) + return projectA(_dA) + end + dα = @thunk begin + # TODO: this result might be easier to compute as: + # C′ = βC + α * trace(A) ⟹ At = (C′ - βC) / α + At = tensortrace(pC, A, pA, conjA) + return projectα(inner(At, ΔC)) + end + dβ = @thunk projectβ(inner(C, ΔC)) + dbackend = map(x -> NoTangent(), backend) + return NoTangent(), dC, NoTangent(), dA, NoTangent(), NoTangent(), dα, dβ, + dbackend... + end + + return C′, pullback +end diff --git a/ext/TensorKitChainRulesCoreExt/utility.jl b/ext/TensorKitChainRulesCoreExt/utility.jl new file mode 100644 index 00000000..170ed09f --- /dev/null +++ b/ext/TensorKitChainRulesCoreExt/utility.jl @@ -0,0 +1,29 @@ +# Utility +# ------- +function _repartition(p::IndexTuple, N₁::Int) + length(p) >= N₁ || + throw(ArgumentError("cannot repartition $(typeof(p)) to $N₁, $(length(p) - N₁)")) + return p[1:N₁], p[(N₁ + 1):end] +end +_repartition(p::Index2Tuple, N₁::Int) = _repartition(linearize(p), N₁) +function _repartition(p::Union{IndexTuple,Index2Tuple}, ::Index2Tuple{N₁}) where {N₁} + return _repartition(p, N₁) +end +function _repartition(p::Union{IndexTuple,Index2Tuple}, + ::AbstractTensorMap{<:Any,N₁}) where {N₁} + return _repartition(p, N₁) +end + +TensorKit.block(t::ZeroTangent, c::Sector) = t + +ChainRulesCore.ProjectTo(::T) where {T<:AbstractTensorMap} = ProjectTo{T}() +function (::ProjectTo{T1})(x::T2) where {S,N1,N2,T1<:AbstractTensorMap{S,N1,N2}, + T2<:AbstractTensorMap{S,N1,N2}} + T1 === T2 && return x + y = similar(x, scalartype(T1)) + for (c, b) in blocks(y) + p = ProjectTo(b) + b .= p(block(x, c)) + end + return y +end diff --git a/src/auxiliary/auxiliary.jl b/src/auxiliary/auxiliary.jl index e8c9445b..3cf0da46 100644 --- a/src/auxiliary/auxiliary.jl +++ b/src/auxiliary/auxiliary.jl @@ -49,3 +49,13 @@ end else using Base: @constprop end + +""" + _interleave(a::NTuple{N}, b::NTuple{N}) -> NTuple{2N} + +Interleave two tuples of the same length. +""" +_interleave(::Tuple{}, ::Tuple{}) = () +function _interleave(a::NTuple{N}, b::NTuple{N}) where {N} + return (a[1], b[1], _interleave(tail(a), tail(b))...) +end diff --git a/src/fusiontrees/fusiontrees.jl b/src/fusiontrees/fusiontrees.jl index 4e168b86..f6ebe84f 100644 --- a/src/fusiontrees/fusiontrees.jl +++ b/src/fusiontrees/fusiontrees.jl @@ -207,6 +207,20 @@ function Base.convert(A::Type{<:AbstractArray}, f::FusionTree{I,N}) where {I,N} Ctail, ((1,), Base.tail(trivialtuple)), :N, true, false) end +# TODO: is this piracy? +function Base.convert(A::Type{<:AbstractArray}, + (f₁, f₂)::Tuple{FusionTree{I},FusionTree{I}}) where {I} + F₁ = convert(A, f₁) + F₂ = convert(A, f₂) + sz1 = size(F₁) + sz2 = size(F₂) + d1 = TupleTools.front(sz1) + d2 = TupleTools.front(sz2) + + return reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) * + reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)) +end + # Show methods function Base.show(io::IO, t::FusionTree{I,N,M,K,Nothing}) where {I<:Sector,N,M,K} return print(IOContext(io, :typeinfo => I), "FusionTree{", type_repr(I), "}(", diff --git a/src/tensors/abstracttensor.jl b/src/tensors/abstracttensor.jl index 55195e4f..38997c94 100644 --- a/src/tensors/abstracttensor.jl +++ b/src/tensors/abstracttensor.jl @@ -267,14 +267,7 @@ function Base.convert(::Type{Array}, t::AbstractTensorMap{S,N₁,N₂}) where {S dom = domain(t) local A for (f₁, f₂) in fusiontrees(t) - F₁ = convert(Array, f₁) - F₂ = convert(Array, f₂) - sz1 = size(F₁) - sz2 = size(F₂) - d1 = TupleTools.front(sz1) - d2 = TupleTools.front(sz2) - F = reshape(reshape(F₁, TupleTools.prod(d1), sz1[end]) * - reshape(F₂, TupleTools.prod(d2), sz2[end])', (d1..., d2...)) + F = convert(Array, (f₁, f₂)) if !(@isdefined A) if eltype(F) <: Complex T = complex(float(scalartype(t))) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 233b82dd..85e46858 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -350,50 +350,53 @@ function TensorMap(data::DenseArray, codom::ProductSpace{S,N₁}, dom::ProductSp size(data) == (dims(codom)..., dims(dom)...)) throw(DimensionMismatch()) end + if sectortype(S) === Trivial data2 = reshape(data, (d1, d2)) A = typeof(data2) return TensorMap{S,N₁,N₂,Trivial,A,Nothing,Nothing}(data2, codom, dom) - else - t = TensorMap(zeros, eltype(data), codom, dom) - ta = convert(Array, t) - l = length(ta) - dimt = dim(t) - basis = zeros(eltype(ta), (l, dimt)) - qdims = zeros(real(eltype(ta)), (dimt,)) - i = 1 - for (c, b) in blocks(t) - for k in 1:length(b) - b[k] = 1 - copy!(view(basis, :, i), reshape(convert(Array, t), (l,))) - qdims[i] = dim(c) - b[k] = 0 - i += 1 - end - end - rhs = reshape(data, (l,)) - if FusionStyle(sectortype(t)) isa UniqueFusion - lhs = basis' * rhs - else - lhs = Diagonal(qdims) \ (basis' * rhs) - end - if norm(basis * lhs - rhs) > tol - throw(ArgumentError("Data has non-zero elements at incompatible positions")) - end - if eltype(lhs) != scalartype(t) - t2 = TensorMap(zeros, promote_type(eltype(lhs), scalartype(t)), codom, dom) - else - t2 = t - end - i = 1 - for (c, b) in blocks(t2) - for k in 1:length(b) - b[k] = lhs[i] - i += 1 - end + end + + t = TensorMap(undef, eltype(data), codom, dom) + project_symmetric!(t, data) + + if !isapprox(data, convert(Array, t); atol=tol) + throw(ArgumentError("Data has non-zero elements at incompatible positions")) + end + + return t +end + +""" + project_symmetric!(t::TensorMap, data::DenseArray) -> TensorMap + +Project the data from a dense array `data` into the tensor map `t`. This function discards +any data that does not fit the symmetry structure of `t`. +""" +function project_symmetric!(t::TensorMap, data::DenseArray) + if sectortype(t) === Trivial + copy!(t.data, reshape(data, size(t.data))) + return t + end + + for (f₁, f₂) in fusiontrees(t) + F = convert(Array, (f₁, f₂)) + b = zeros(eltype(data), dims(codomain(t), f₁.uncoupled)..., + dims(domain(t), f₂.uncoupled)...) + szbF = _interleave(size(b), size(F)) + dataslice = sreshape(StridedView(data)[axes(codomain(t), f₁.uncoupled)..., + axes(domain(t), f₂.uncoupled)...], szbF) + # project (can this be done in one go?) + d = inv(dim(f₁.coupled)) + for k in eachindex(b) + b[k] = 1 + projector = _kron(b, F) # probably possible to re-use memory + t[f₁, f₂][k] = dot(projector, dataslice) * d + b[k] = 0 end - return t2 end + + return t end # Efficient copy constructors diff --git a/test/ad.jl b/test/ad.jl index 01b1eb34..020a49e1 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -15,6 +15,7 @@ end function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::AbstractTensorMap) return TensorMap(randn, scalartype(x), space(x)) end +ChainRulesTestUtils.rand_tangent(::AbstractRNG, ::VectorSpace) = NoTangent() function ChainRulesTestUtils.test_approx(actual::AbstractTensorMap, expected::AbstractTensorMap, msg=""; kwargs...) for (c, b) in blocks(actual) @@ -132,6 +133,22 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), @timedtestset "Automatic Differentiation with spacetype $(TensorKit.type_repr(eltype(V)))" verbose = true for V in Vlist + @timedtestset "Basic utility" begin + T1 = TensorMap(randn, Float64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) + T2 = TensorMap(randn, ComplexF64, V[1] ⊗ V[2] ← V[3] ⊗ V[4]) + + P1 = ProjectTo(T1) + @test P1(T1) == T1 + @test P1(T2) == real(T2) + + test_rrule(copy, T1) + test_rrule(copy, T2) + + test_rrule(convert, Array, T1) + test_rrule(TensorMap, convert(Array, T1), codomain(T1), domain(T1); + fkwargs=(; tol=Inf)) + end + @timedtestset "Basic Linear Algebra with scalartype $T" for T in (Float64, ComplexF64) A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) B = TensorMap(randn, T, space(A)) @@ -163,6 +180,9 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'), A = TensorMap(randn, T, V[1] ⊗ V[2] ← V[3] ⊗ V[4] ⊗ V[5]) test_rrule(LinearAlgebra.adjoint, A) test_rrule(LinearAlgebra.norm, A, 2) + + B = TensorMap(randn, T, space(A)) + test_rrule(LinearAlgebra.dot, A, B) end @timedtestset "TensorOperations with scalartype $T" for T in (Float64, ComplexF64)