-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor ChainRulesCoreExt into separate files #133
Merged
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
fcad6a9
Refactor ChainRulesCoreExt into separate files
lkdvos 8a821e0
Add rrule test `dot`
lkdvos 1148fe9
Add tests `ProjectTo`
lkdvos c15b7ab
Add tests rrules converters
lkdvos 76078de
Restrict convert rrule to trivialtensormap
lkdvos 7ada03f
Add kwargs to rrule
lkdvos e0ce389
Add unthunk
lkdvos c98624f
Refactor array -> tensormap conversion
lkdvos d06cca1
Reenable constructor ad tests
lkdvos 1827d6a
Refactor _interleave
lkdvos 5003daa
add converter for fusiontreepair to array
lkdvos 655d9e0
Refactor `project_symmetric!` into separate function
lkdvos 99f71c0
Add rule for (not) generating tangents for vector spaces
lkdvos File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
28 changes: 28 additions & 0 deletions
28
ext/TensorKitChainRulesCoreExt/TensorKitChainRulesCoreExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
module TensorKitChainRulesCoreExt | ||
|
||
using TensorOperations | ||
using VectorInterface | ||
using TensorKit | ||
using ChainRulesCore | ||
using LinearAlgebra | ||
using TupleTools | ||
|
||
import TensorOperations as TO | ||
using TensorOperations: Backend, promote_contract | ||
using VectorInterface: promote_scale, promote_add | ||
|
||
ext = @static if isdefined(Base, :get_extension) | ||
Base.get_extension(TensorOperations, :TensorOperationsChainRulesCoreExt) | ||
else | ||
TensorOperations.TensorOperationsChainRulesCoreExt | ||
end | ||
const _conj = ext._conj | ||
const trivtuple = ext.trivtuple | ||
|
||
include("utility.jl") | ||
include("constructors.jl") | ||
include("linalg.jl") | ||
include("tensoroperations.jl") | ||
include("factorizations.jl") | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
@non_differentiable TensorKit.TensorMap(f::Function, storagetype, cod, dom) | ||
@non_differentiable TensorKit.id(args...) | ||
@non_differentiable TensorKit.isomorphism(args...) | ||
@non_differentiable TensorKit.isometry(args...) | ||
@non_differentiable TensorKit.unitary(args...) | ||
|
||
function ChainRulesCore.rrule(::Type{<:TensorMap}, d::DenseArray, args...; kwargs...) | ||
function TensorMap_pullback(Δt) | ||
∂d = convert(Array, unthunk(Δt)) | ||
return NoTangent(), ∂d, ntuple(_ -> NoTangent(), length(args))... | ||
end | ||
return TensorMap(d, args...; kwargs...), TensorMap_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Base.copy), t::AbstractTensorMap) | ||
copy_pullback(Δt) = NoTangent(), Δt | ||
return copy(t), copy_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Base.convert), T::Type{<:Array}, | ||
t::AbstractTensorMap) | ||
A = convert(T, t) | ||
function convert_pullback(ΔA) | ||
# use constructor to (unconditionally) project back onto symmetric subspace | ||
∂t = TensorMap(unthunk(ΔA), codomain(t), domain(t); tol=Inf) | ||
return NoTangent(), NoTangent(), ∂t | ||
end | ||
return A, convert_pullback | ||
end | ||
|
||
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{Dict}, t::AbstractTensorMap) | ||
out = convert(Dict, t) | ||
function convert_pullback(c′) | ||
c = unthunk(c′) | ||
if haskey(c, :data) # :data is the only thing for which this dual makes sense | ||
dual = copy(out) | ||
dual[:data] = c[:data] | ||
return (NoTangent(), NoTangent(), convert(TensorMap, dual)) | ||
else | ||
# instead of zero(t) you can also return ZeroTangent(), which is type unstable | ||
return (NoTangent(), NoTangent(), zero(t)) | ||
end | ||
end | ||
return out, convert_pullback | ||
end | ||
function ChainRulesCore.rrule(::typeof(Base.convert), ::Type{TensorMap}, | ||
t::Dict{Symbol,Any}) | ||
return convert(TensorMap, t), v -> (NoTangent(), NoTangent(), convert(Dict, v)) | ||
end |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about
dual = typeof(out)(:data => c[:data])
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that works, because then the spaces are missing in the dictionary-to-tensormap converter. Maybe the comment is a bit misleading -- all fields in the dictionary are required, but Zygote tends to drop fields that do not contribute (i.e. codomain and domain). Thus, this uses the dictionary from the forwards pass with the data from the backwards pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. I wanted to avoid copying the
data
fromout
, but probablycopy(out)
is a shallow copy that does not duplicateout[:data]
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, it should really only duplicate the pointer, which I think is acceptable. On top of that, I am not sure this rule is ever used/useful anyways. I think I copied this at some point from maarten, but presumably that was also only introduced for testing purposes, as it also suffers from the weird interplay of "inner product on the parameters" for non-abelian symmetries.