From 24c7636a4eb55cc8f4fe8e56faf234e8fdd2c53c Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Sun, 10 Nov 2024 07:56:39 -0500 Subject: [PATCH] Add and test 2-arg `complex` --- src/tensors/tensor.jl | 3 +++ test/tensors.jl | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/tensors/tensor.jl b/src/tensors/tensor.jl index 928ad99a..60195b91 100644 --- a/src/tensors/tensor.jl +++ b/src/tensors/tensor.jl @@ -392,6 +392,9 @@ function Base.complex(t::AbstractTensorMap) return copy!(similar(t, complex(scalartype(t))), t) end end +function Base.complex(r::AbstractTensorMap{<:Real}, i::AbstractTensorMap{<:Real}) + return add(r, i, im * one(scalartype(i))) +end # Conversion between TensorMap and Dict, for read and write purpose #------------------------------------------------------------------ diff --git a/test/tensors.jl b/test/tensors.jl index 485eb882..99d73b7c 100644 --- a/test/tensors.jl +++ b/test/tensors.jl @@ -185,8 +185,21 @@ for V in spacelist W = V1 ⊗ V2 for T in (Float64, ComplexF64, ComplexF32) t = @constinferred randn(T, W, W) - @test real(convert(Array, t)) == convert(Array, @constinferred real(t)) - @test imag(convert(Array, t)) == convert(Array, @constinferred imag(t)) + + tr = @constinferred real(t) + @test scalartype(tr) <: Real + @test real(convert(Array, t)) == convert(Array, tr) + + ti = @constinferred imag(t) + @test scalartype(ti) <: Real + @test imag(convert(Array, t)) == convert(Array, ti) + + tc = @inferred complex(t) + @test scalartype(tc) <: Complex + @test complex(convert(Array, t)) == convert(Array, tc) + + tc2 = @inferred complex(tr, ti) + @test tc2 ≈ tc end end end