Skip to content

Commit

Permalink
implement fixes and suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 12, 2024
1 parent c159984 commit c1c174e
Show file tree
Hide file tree
Showing 6 changed files with 483 additions and 504 deletions.
4 changes: 2 additions & 2 deletions src/auxiliary/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ using ..TensorKit: OrthogonalFactorizationAlgorithm,

# TODO: define for CuMatrix if we support this
function one!(A::StridedMatrix)
A[:] .= 0
A[diagind(A)] .= 1
length(A) > 0 || return A
copyto!(A, LinearAlgebra.I)
return A
end

Expand Down
2 changes: 1 addition & 1 deletion src/tensors/abstracttensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ spacetype(::Type{<:AbstractTensorMap{<:Any,S}}) where {S} = S
Return the type of sector `I` of a tensor.
"""
sectortype(::Type{<:AbstractTensorMap{<:Any,S}}) where {S} = sectortype(S)
sectortype(::Type{TT}) where {TT<:AbstractTensorMap} = sectortype(spacetype(TT))

function InnerProductStyle(::Type{TT}) where {TT<:AbstractTensorMap}
return InnerProductStyle(spacetype(TT))
Expand Down
19 changes: 10 additions & 9 deletions src/tensors/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,31 @@
Specific subtype of [`AbstractTensorMap`](@ref) that is a lazy wrapper for representing the
adjoint of an instance of [`AbstractTensorMap`](@ref).
"""
struct AdjointTensorMap{T,S,N₁,N₂,TT<:AbstractTensorMap{T,S,N₁,N₂}} <:
struct AdjointTensorMap{T,S,N₁,N₂,TT<:AbstractTensorMap{T,S,N₂,N₁}} <:
AbstractTensorMap{T,S,N₁,N₂}
parent::TT
end
Base.parent(t::AdjointTensorMap) = t.parent

# Constructor: construct from taking adjoint of a tensor
Base.adjoint(t::AdjointTensorMap) = t.parent
Base.adjoint(t::AdjointTensorMap) = parent(t)
Base.adjoint(t::AbstractTensorMap) = AdjointTensorMap(t)

# Properties
space(t::AdjointTensorMap) = adjoint(space(t'))
dim(t::AdjointTensorMap) = dim(t')
space(t::AdjointTensorMap) = adjoint(space(parent(t)))
dim(t::AdjointTensorMap) = dim(parent(t))
storagetype(::Type{AdjointTensorMap{T,S,N₁,N₂,TT}}) where {T,S,N₁,N₂,TT} = storagetype(TT)

# Blocks and subblocks
#----------------------
blocksectors(t::AdjointTensorMap) = blocksectors(t')
block(t::AdjointTensorMap, s::Sector) = block(t', s)'
blocksectors(t::AdjointTensorMap) = blocksectors(parent(t))
block(t::AdjointTensorMap, s::Sector) = block(parent(t), s)'

function Base.getindex(t::AdjointTensorMap{T,S,N₁,N₂},

Check warning on line 29 in src/tensors/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/adjoint.jl#L29

Added line #L29 was not covered by tests
f₁::FusionTree{I,N₁}, f₂::FusionTree{I,N₂}) where {T,S,N₁,N₂,I}
parent = t'
subblock = getindex(parent, f₂, f₁)
return permutedims(conj(subblock), (domainind(parent)..., codomainind(parent)...))
tp = parent(t)
subblock = getindex(tp, f₂, f₁)
return permutedims(conj(subblock), (domainind(tp)..., codomainind(tp)...))

Check warning on line 33 in src/tensors/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/adjoint.jl#L31-L33

Added lines #L31 - L33 were not covered by tests
end
function Base.setindex!(t::AdjointTensorMap{T,S,N₁,N₂}, v,

Check warning on line 35 in src/tensors/adjoint.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/adjoint.jl#L35

Added line #L35 was not covered by tests
f₁::FusionTree{I,N₁},
Expand Down
46 changes: 12 additions & 34 deletions src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,10 @@ Construct the identity endomorphism on space `V`, i.e. return a `t::TensorMap` w
or `storagetype(t) = T` if `T` is a `DenseVector` type.
"""
id(V::TensorSpace) = id(Float64, V)
function id(::Type{A}, V::TensorSpace{S}) where {A,S}
function id(A::Type, V::TensorSpace{S}) where {S}
W = V V
if A <: Number
t = TensorMap{A}(undef, W)
elseif A <: DenseVector
T = scalartype(A)
N = length(codomain(W))
t = TensorMap{T,S,N,N,A}(undef, W)
else
throw(ArgumentError("`id` only supports Number or DenseVector subtypes as first argument"))
end
return one!(t)
return one!(tensormaptype(S, N, N, A)(undef, W))
end

"""
Expand All @@ -90,17 +82,10 @@ error will be thrown.
See also [`unitary`](@ref) when `InnerProductStyle(cod) === EuclideanProduct()`.
"""
function isomorphism(::Type{A}, V::TensorMapSpace{S,N₁,N₂}) where {A<:VecOrNumber,S,N₁,N₂}
function isomorphism(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
codomain(V) domain(V) ||
throw(SpaceMismatch("codomain and domain are not isomorphic: $V"))
if A <: Number
t = TensorMap{A}(undef, V)
elseif A <: DenseVector
T = scalartype(A)
t = TensorMap{T,S,N₁,N₂,A}(undef, V)
else
throw(ArgumentError("`isomorphism` only supports Number or DenseVector subtypes as first argument"))
end
t = tensormaptype(S, N₁, N₂, A)(undef, V)
for (_, b) in blocks(t)
MatrixAlgebra.one!(b)
end
Expand Down Expand Up @@ -141,18 +126,11 @@ isometric inclusion, an error will be thrown.
See also [`isomorphism`](@ref) and [`unitary`](@ref).
"""
function isometry(::Type{A}, V::TensorMapSpace{S,N₁,N₂}) where {A<:VecOrNumber,S,N₁,N₂}
function isometry(A::Type, V::TensorMapSpace{S,N₁,N₂}) where {S,N₁,N₂}
InnerProductStyle(S) === EuclideanProduct() || throw_invalid_innerproduct(:isometry)
domain(V) codomain(V) ||
throw(SpaceMismatch("$V does not allow for an isometric inclusion"))
if A <: Number
t = TensorMap{A}(undef, V)
elseif A <: DenseVector
T = scalartype(A)
t = TensorMap{T,S,N₁,N₂,A}(undef, V)
else
throw(ArgumentError("`isometry` only supports Number or DenseVector subtypes as first argument"))
end
t = tensormaptype(S, N₁, N₂, A)(undef, V)
for (_, b) in blocks(t)
MatrixAlgebra.one!(b)
end
Expand Down Expand Up @@ -485,9 +463,9 @@ function ⊗(t1::AbstractTensorMap, t2::AbstractTensorMap)
d2 = dim(cod2)
d3 = dim(dom1)
d4 = dim(dom2)
m1 = reshape(t1[], (d1, 1, d3, 1))
m2 = reshape(t2[], (1, d2, 1, d4))
m = reshape(t[], (d1, d2, d3, d4))
m1 = sreshape(t1[trivial_fusiontree(t1)...], (d1, 1, d3, 1))
m2 = sreshape(t2[trivial_fusiontree(t2)...], (1, d2, 1, d4))
m = sreshape(t[trivial_fusiontree(t)...], (d1, d2, d3, d4))

Check warning on line 468 in src/tensors/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/linalg.jl#L466-L468

Added lines #L466 - L468 were not covered by tests
m .= m1 .* m2
else
for (f1l, f1r) in fusiontrees(t1)
Expand All @@ -504,9 +482,9 @@ function ⊗(t1::AbstractTensorMap, t2::AbstractTensorMap)
d2 = dim(cod2, f2l.uncoupled)
d3 = dim(dom1, f1r.uncoupled)
d4 = dim(dom2, f2r.uncoupled)
m1 = reshape(t1[f1l, f1r], (d1, 1, d3, 1))
m2 = reshape(t2[f2l, f2r], (1, d2, 1, d4))
m = reshape(t[fl, fr], (d1, d2, d3, d4))
m1 = sreshape(t1[f1l, f1r], (d1, 1, d3, 1))
m2 = sreshape(t2[f2l, f2r], (1, d2, 1, d4))
m = sreshape(t[fl, fr], (d1, d2, d3, d4))

Check warning on line 487 in src/tensors/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/linalg.jl#L485-L487

Added lines #L485 - L487 were not covered by tests
m .+= coeff1 .* conj(coeff2) .* m1 .* m2
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ function tensormaptype(S::Type{<:IndexSpace}, N₁, N₂, TorA::Type)
elseif TorA <: DenseVector
return TensorMap{scalartype(TorA),S,N₁,N₂,TorA}
else
throw(ArgumentError("invalid type for TensorMap data: $TorA"))
throw(ArgumentError("argument $TorA should specify a scalar type (`<:Number`) or a storage type `<:DenseVector{<:Number}`"))

Check warning on line 50 in src/tensors/tensor.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/tensor.jl#L50

Added line #L50 was not covered by tests
end
end

Expand Down
Loading

0 comments on commit c1c174e

Please sign in to comment.