Skip to content

Commit

Permalink
only depend on data for TensorMap
Browse files Browse the repository at this point in the history
  • Loading branch information
Jutho committed Sep 12, 2024
1 parent 0cd3fd1 commit c159984
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
15 changes: 14 additions & 1 deletion src/tensors/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,26 @@ LinearAlgebra.isdiag(t::AbstractTensorMap) = all(LinearAlgebra.isdiag, values(bl
# Wrapping the blocks in a StridedView enables multithreading if JULIA_NUM_THREADS > 1
# TODO: reconsider this strategy, consider spawning different threads for different blocks

# Copy, adjoint! and fill:
# Copy, adjoint and fill:
function Base.copy!(tdst::AbstractTensorMap, tsrc::AbstractTensorMap)
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
for c in blocksectors(tdst)
copy!(StridedView(block(tdst, c)), StridedView(block(tsrc, c)))
end
return tdst
end
function Base.copy!(tdst::TensorMap, tsrc::TensorMap)
space(tdst) == space(tsrc) || throw(SpaceMismatch("$(space(tdst))$(space(tsrc))"))
copy!(tdst.data, tsrc.data)
return tdst
end
function Base.fill!(t::AbstractTensorMap, value::Number)
for (c, b) in blocks(t)
fill!(b, value)
end
return t
end
function Base.fill!(t::TensorMap, value::Number)
fill!(t.data, value)
return t
end
Expand Down
4 changes: 2 additions & 2 deletions src/tensors/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,8 @@ function Base.convert(TT::Type{TensorMap{T,S,N₁,N₂,A}},
if typeof(t) === TT
return t
else
data = convert(A, t.data)
return TensorMap(data, space(t))
tnew = TT(undef, space(t))
return copy!(tnew, t)
end
end

Expand Down
27 changes: 27 additions & 0 deletions src/tensors/vectorinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ function VectorInterface.zerovector(t::AbstractTensorMap, ::Type{S}) where {S<:N
return zerovector!(similar(t, S))
end
function VectorInterface.zerovector!(t::AbstractTensorMap)
for (c, b) in blocks(t)
zerovector!(b)
end
return t
end
function VectorInterface.zerovector!(t::TensorMap)
zerovector!(t.data)
return t
end
Expand All @@ -20,6 +26,12 @@ function VectorInterface.scale(t::AbstractTensorMap, α::Number)
return scale!(similar(t, T), t, α)
end
function VectorInterface.scale!(t::AbstractTensorMap, α::Number)
for (c, b) in blocks(t)
scale!(b, α)
end
return t
end
function VectorInterface.scale!(t::TensorMap, α::Number)
scale!(t.data, α)
return t
end
Expand All @@ -29,6 +41,13 @@ function VectorInterface.scale!!(t::AbstractTensorMap, α::Number)
return T <: scalartype(t) ? scale!(t, α) : scale(t, α)
end
function VectorInterface.scale!(ty::AbstractTensorMap, tx::AbstractTensorMap, α::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for ((cy, by), (cx, bx)) in zip(blocks(ty), blocks(tx))
scale!(by, bx, α)
end
return ty
end
function VectorInterface.scale!(ty::TensorMap, tx::TensorMap, α::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
scale!(ty.data, tx.data, α)
return ty
Expand All @@ -54,6 +73,14 @@ end
function VectorInterface.add!(ty::AbstractTensorMap, tx::AbstractTensorMap,
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
for ((cy, by), (cx, bx)) in zip(blocks(ty), blocks(tx))
add!(StridedView(by), StridedView(bx), α, β)
end
return ty
end
function VectorInterface.add!(ty::TensorMap, tx::TensorMap,
α::Number, β::Number)
space(ty) == space(tx) || throw(SpaceMismatch("$(space(ty))$(space(tx))"))
add!(ty.data, tx.data, α, β)
return ty
end
Expand Down

0 comments on commit c159984

Please sign in to comment.