Skip to content

Commit

Permalink
small improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Dec 11, 2024
1 parent 187efb1 commit 993577d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 14 deletions.
7 changes: 4 additions & 3 deletions src/operators/infinitepepo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ end
Base.size(T::InfinitePEPO) = size(T.A)
Base.size(T::InfinitePEPO, i) = size(T.A, i)
Base.length(T::InfinitePEPO) = length(T.A)
Base.eltype(T::InfinitePEPO) = eltype(T.A[1])
VectorInterface.scalartype(T::InfinitePEPO) = scalartype(T.A)
Base.eltype(T::InfinitePEPO) = eltype(typeof(T))
Base.eltype(::Type{<:InfinitePEPO{T}}) where {T} = T
VectorInterface.scalartype(::Type{T}) where {T<:InfinitePEPO} = scalartype(eltype(T))

## Copy
Base.copy(T::InfinitePEPO) = InfinitePEPO(copy(T.A))
Base.similar(T::InfinitePEPO, args...) = InfinitePEPO(similar.(T.A, args...))
Base.similar(T::InfinitePEPO, args...) = InfinitePEPO(similar(T.A, args...))
Base.repeat(T::InfinitePEPO, counts...) = InfinitePEPO(repeat(T.A, counts...))

Base.getindex(T::InfinitePEPO, args...) = Base.getindex(T.A, args...)
Expand Down
7 changes: 4 additions & 3 deletions src/states/infinitepeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,13 @@ end
Base.size(T::InfinitePEPS) = size(T.A)
Base.size(T::InfinitePEPS, i) = size(T.A, i)
Base.length(T::InfinitePEPS) = length(T.A)
Base.eltype(T::InfinitePEPS) = eltype(T.A[1])
VectorInterface.scalartype(T::InfinitePEPS) = scalartype(T.A)
Base.eltype(T::InfinitePEPS) = eltype(typeof(T))
Base.eltype(::Type{<:InfinitePEPS{T}}) where {T} = T
VectorInterface.scalartype(::Type{T}) where {T<:InfinitePEPS} = scalartype(eltype(T))

## Copy
Base.copy(T::InfinitePEPS) = InfinitePEPS(copy(T.A))
Base.similar(T::InfinitePEPS, args...) = InfinitePEPS(similar.(T.A, args...))
Base.similar(T::InfinitePEPS, args...) = InfinitePEPS(similar(T.A, args...))
Base.repeat(T::InfinitePEPS, counts...) = InfinitePEPS(repeat(T.A, counts...))

Base.getindex(T::InfinitePEPS, args...) = Base.getindex(T.A, args...)
Expand Down
9 changes: 4 additions & 5 deletions src/states/infiniteweightpeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,17 @@ end
Base.size(W::SUWeight) = size(W.data)
Base.size(W::SUWeight, i) = size(W.data, i)
Base.length(W::SUWeight) = length(W.data)
Base.eltype(W::SUWeight) = eltype(W.data[1])
Base.eltype(W::SUWeight) = eltype(typeof(W))
Base.eltype(::Type{SUWeight{E}}) where {E} = E
VectorInterface.scalartype(::Type{T}) where {T<:SUWeight} = scalartype(eltype(T))

Base.getindex(W::SUWeight, args...) = Base.getindex(W.data, args...)
Base.setindex!(W::SUWeight, args...) = (Base.setindex!(W.data, args...); W)
Base.axes(W::SUWeight, args...) = axes(W.data, args...)

function compare_weights(wts1::SUWeight, wts2::SUWeight)
@assert size(wts1) == size(wts2)
wtdiff = sum(
_singular_value_distance((wt1, wt2)) for (wt1, wt2) in zip(wts1.data, wts2.data)
)
return wtdiff / length(wts1)
return sum(_singular_value_distance, zip(wts1.data, wts2.data)) / length(wts1)
end

"""
Expand Down
6 changes: 3 additions & 3 deletions src/utility/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function _elementwise_mult(a::AbstractTensorMap, b::AbstractTensorMap)
return dst
end

_safe_pow(a, pow, tol) = (pow < 0 && abs(a) < tol) ? zero(a) : a .^ pow
_safe_pow(a, pow, tol) = (pow < 0 && abs(a) < tol) ? zero(a) : a^pow

"""
sdiag_pow(S::AbstractTensorMap, pow::Real; tol::Real=eps(scalartype(S))^(3 / 4))
Expand Down Expand Up @@ -60,9 +60,9 @@ function ChainRulesCore.rrule(
)
tol *= norm(S, Inf)
spow = sdiag_pow(S, pow; tol)
spow_minus1_conj = sdiag_pow(S', pow - 1; tol)
spow_minus1_conj = scale!(sdiag_pow(S', pow - 1; tol), pow)
function sdiag_pow_pullback(c̄)
return (ChainRulesCore.NoTangent(), pow * _elementwise_mult(c̄, spow_minus1_conj))
return (ChainRulesCore.NoTangent(), _elementwise_mult(c̄, spow_minus1_conj))
end
return spow, sdiag_pow_pullback
end
Expand Down

0 comments on commit 993577d

Please sign in to comment.