Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ChainRulesCoreExt into separate files #133

Merged
merged 13 commits into from
Jul 2, 2024
28 changes: 28 additions & 0 deletions ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
module TensorKitChainRulesCoreExt

using TensorOperations
using VectorInterface
using TensorKit
using ChainRulesCore
using LinearAlgebra
using TupleTools

import TensorOperations as TO
using TensorOperations: Backend, promote_contract
using VectorInterface: promote_scale, promote_add

ext = @static if isdefined(Base, :get_extension)
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt)
else
TensorOperations.TensorOperationsChainRulesCoreExt
end
const _conj = ext._conj
const trivtuple = ext.trivtuple

include("utility.jl")
include("constructors.jl")
include("linalg.jl")
include("tensoroperations.jl")
include("factorizations.jl")

end
49 changes: 49 additions & 0 deletions ext/TensorKitChainRulesCoreExt/constructors.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom)
@non_differentiable TensorKit.id(args...)
@non_differentiable TensorKit.isomorphism(args...)
@non_differentiable TensorKit.isometry(args...)
@non_differentiable TensorKit.unitary(args...)

function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...)
function TensorMap_pullback(Δt)
∂d = convert(Array, unthunk(Δt))
return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))...
end
return TensorMap(d, args...), TensorMap_pullback
end

function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap)
copy_pullback(Δt) = NoTangent(), Δt
return copy(t), copy_pullback
end

# this rule does not work for generic symmetries, as we currently have no way to
# project back onto the symmetric subspace
lkdvos marked this conversation as resolved.
Show resolved Hide resolved
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array},
t::TrivialTensorMap)
A = convert(T, t)
function convert_pullback(ΔA)
∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t))
return NoTangent(), NoTangent(), ∂t
end
return A, convert_pullback
end

function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap)
out = convert(Dict, t)
function convert_pullback(c)
if haskey(c, :data) # :data is the only thing for which this dual makes sense
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about dual = typeof(out)(:data => c[:data]) ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that works, because then the spaces are missing in the dictionary-to-tensormap converter. Maybe the comment is a bit misleading -- all fields in the dictionary are required, but Zygote tends to drop fields that do not contribute (i.e. codomain and domain). Thus, this uses the dictionary from the forwards pass with the data from the backwards pass.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok. I wanted to avoid copying the data from out, but probably copy(out) is a shallow copy that does not duplicate out[:data]?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, it should really only duplicate the pointer, which I think is acceptable. On top of that, I am not sure this rule is ever used/useful anyways. I think I copied this at some point from maarten, but presumably that was also only introduced for testing purposes, as it also suffers from the weird interplay of "inner product on the parameters" for non-abelian symmetries.

dual = copy(out)
dual[:data] = c[:data]
return (NoTangent(), NoTangent(), convert(TensorMap, dual))

Check warning on line 38 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L32-L38

Added lines #L32 - L38 were not covered by tests
else
# instead of zero(t) you can also return ZeroTangent(), which is type unstable
return (NoTangent(), NoTangent(), zero(t))

Check warning on line 41 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L41

Added line #L41 was not covered by tests
end
end
return out, convert_pullback

Check warning on line 44 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L44

Added line #L44 was not covered by tests
end
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap},

Check warning on line 46 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L46

Added line #L46 was not covered by tests
t::Dict{Symbol,Any})
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v))

Check warning on line 48 in ext/TensorKitChainRulesCoreExt/constructors.jl

View check run for this annotation

Codecov / codecov/patch

ext/TensorKitChainRulesCoreExt/constructors.jl#L48

Added line #L48 was not covered by tests
end
Loading