Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-mutating versions of pop, popfirst, etc. (#66) #68

Merged
merged 13 commits into from
Sep 14, 2020
101 changes: 100 additions & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using Requires
using LinearAlgebra
using SparseArrays

using Base: OneTo
using Base: OneTo, @propagate_inbounds

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
Expand Down Expand Up @@ -543,6 +543,105 @@ function restructure(x::Array,y)
reshape(convert(Array,y),size(x)...)
end

"""
insert(collection, index, item)

Return a new instance of `collection` with `item` inserted into at the given `index`.
"""
Base.@propagate_inbounds function insert(collection, index, item)
@boundscheck checkbounds(collection, index)
ret = similar(collection, length(collection) + 1)
@inbounds for i in indices(ret)
Copy link
Collaborator

@chriselrod chriselrod Sep 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would split this into separate loops. The compiler may do so automatically, but I'd rather not rely on it.
Something like:

@inbounds for i in firstindex(ret):index-1
    ret[i] = collection[i]
end
@inbounds ret[index] = item
@inbounds for i in index+1:lastindex(ret)
    ret[i] = collection[i-1]
end

to make sure it SIMDs well if possible.

if i < index
ret[i] = collection[i]
elseif i == index
ret[i] = item
else
ret[i] = collection[i - 1]
end
end
return ret
end

function insert(x::Tuple, index::Integer, item)
@boundscheck if !checkindex(Bool, static_first(x):static_last(x), index)
throw(BoundsError(x, index))
end
return unsafe_insert(x, Int(index), item)
end

@inline function unsafe_insert(x::Tuple, i::Int, item)
if i === 1
return (item, x...)
else
return (first(x), unsafe_insert(Base.tail(x), i - 1, item)...)
end
end

"""
deleteat(collection, index)

Return a new instance of `collection` with the item at the given `index` removed.
"""
@propagate_inbounds function deleteat(collection::AbstractVector, index)
@boundscheck if !checkindex(Bool, eachindex(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end
@propagate_inbounds function deleteat(collection::Tuple, index)
@boundscheck if !checkindex(Bool, static_first(collection):static_last(collection), index)
throw(BoundsError(collection, index))
end
return unsafe_deleteat(collection, index)
end

function unsafe_deleteat(src::AbstractVector, index::Integer)
dst = similar(src, length(src) - 1)
@inbounds for i in indices(dst)
if i < index
dst[i] = src[i]
else
dst[i] = src[i + 1]
end
end
return dst
end

@inline function unsafe_deleteat(src::AbstractVector, inds::AbstractVector)
dst = similar(src, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in indices(src)
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return dst
end

@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
dst_index = firstindex(dst)
@inbounds for src_index in OneTo(length(src))
if !in(src_index, inds)
dst[dst_index] = src[src_index]
dst_index += one(dst_index)
end
end
return Tuple(dst)
end

@inline function unsafe_deleteat(x::Tuple, i::Integer)
if i === one(i)
return Base.tail(x)
elseif i == length(x)
return Base.front(x)
else
return (first(x), unsafe_deleteat(Base.tail(x), i - one(i))...)
end
end

function __init__()

@require SuiteSparse="4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down
61 changes: 21 additions & 40 deletions src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,6 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)

# add methods to support ArrayInterface

_get(x) = x
_get(::Static{V}) where {V} = V
_get(::Type{Static{V}}) where {V} = V
_convert(::Type{T}, x) where {T} = convert(T, x)
_convert(::Type{T}, ::Val{V}) where {T,V} = Val(convert(T, V))

"""
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}

Expand All @@ -57,28 +51,23 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
from other valid indices. Therefore, users should not expect the same checks are used
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
"""
struct OptionallyStaticUnitRange{T <: Integer, F <: Integer, L <: Integer} <: AbstractUnitRange{T}
struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRange{Int}
start::F
stop::L

function OptionallyStaticUnitRange{T}(start, stop) where {T<:Real}
if _get(start) isa T
if _get(stop) isa T
return new{T,typeof(start),typeof(stop)}(start, stop)
function OptionallyStaticUnitRange(start, stop)
if eltype(start) <: Int
if eltype(stop) <: Int
return new{typeof(start),typeof(stop)}(start, stop)
else
return OptionallyStaticUnitRange{T}(start, _convert(T, stop))
return OptionallyStaticUnitRange(start, Int(stop))
end
else
return OptionallyStaticUnitRange{T}(_convert(T, start), stop)
return OptionallyStaticUnitRange(Int(start), stop)
end
end

function OptionallyStaticUnitRange(start, stop)
T = promote_type(typeof(_get(start)), typeof(_get(stop)))
return OptionallyStaticUnitRange{T}(start, stop)
end

function OptionallyStaticUnitRange(x::AbstractRange)
function OptionallyStaticUnitRange(x::AbstractRange)
if step(x) == 1
fst = static_first(x)
lst = static_last(x)
Expand All @@ -94,12 +83,12 @@ Base.:(:)(::Static{L}, U::Integer) where {L} = OptionallyStaticUnitRange(Static(
Base.:(:)(::Static{L}, ::Static{U}) where {L,U} = OptionallyStaticUnitRange(Static(L), Static(U))

Base.first(r::OptionallyStaticUnitRange) = r.start
Base.step(r::OptionallyStaticUnitRange{T}) where {T} = oneunit(T)
Base.step(::OptionallyStaticUnitRange) = Static(1)
Base.last(r::OptionallyStaticUnitRange) = r.stop

known_first(::Type{<:OptionallyStaticUnitRange{<:Any,Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange{T}}) where {T} = one(T)
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,<:Any,Static{L}}}) where {L} = L
known_first(::Type{<:OptionallyStaticUnitRange{Static{F}}}) where {F} = F
known_step(::Type{<:OptionallyStaticUnitRange}) = 1
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,Static{L}}}) where {L} = L

function Base.isempty(r::OptionallyStaticUnitRange)
if known_first(r) === oneunit(eltype(r))
Expand All @@ -112,10 +101,8 @@ end
unsafe_isempty_one_to(lst) = lst <= zero(lst)
unsafe_isempty_unit_range(fst, lst) = fst > lst

unsafe_isempty_unit_range(fst::T, lst::T) where {T} = Integer(lst - fst + one(T))

unsafe_length_one_to(lst::T) where {T<:Int} = T(lst)
unsafe_length_one_to(lst::T) where {T} = Integer(lst - zero(lst))
unsafe_length_one_to(lst::Int) = lst
unsafe_length_one_to(::Static{L}) where {L} = lst

Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
if known_first(r) === oneunit(r)
Expand Down Expand Up @@ -144,15 +131,15 @@ end
@inline _try_static(::Static{M}, ::Static{N}) where {M, N} = @assert false "Unequal Indices: Static{$M}() != Static{$N}()"
function _try_static(::Static{N}, x) where {N}
@assert N == x "Unequal Indices: Static{$N}() != x == $x"
Static{N}()
return Static{N}()
end
function _try_static(x, ::Static{N}) where {N}
@assert N == x "Unequal Indices: x == $x != Static{$N}()"
Static{N}()
return Static{N}()
end
function _try_static(x, y)
@assert x == y "Unequal Indicess: x == $x != $y == y"
x
return x
end

###
Expand All @@ -172,24 +159,19 @@ end
end
end

function Base.length(r::OptionallyStaticUnitRange{T}) where {T}
function Base.length(r::OptionallyStaticUnitRange)
if isempty(r)
return zero(T)
return 0
else
if known_one(r) === one(T)
if known_first(r) === 0
return unsafe_length_one_to(last(r))
else
return unsafe_length_unit_range(first(r), last(r))
end
end
end

function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{Int,Int64,Int128}}
return Base.checked_add(Base.checked_sub(lst, fst), one(T))
end
function unsafe_length_unit_range(fst::T, lst::T) where {T<:Union{UInt,UInt64,UInt128}}
return Base.checked_add(lst - fst, one(T))
end
unsafe_length_unit_range(start::Integer, stop::Integer) = Int(start - stop + 1)

"""
indices(x[, d])
Expand Down Expand Up @@ -231,4 +213,3 @@ end
lst = _try_static(static_last(x), static_last(y))
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
end

3 changes: 3 additions & 0 deletions src/static.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,13 @@ end
Base.promote_rule(::Type{<:Static}, ::Type{<:Static}) = Int
Base.:(%)(::Static{N}, ::Type{Integer}) where {N} = N

Base.eltype(::Type{T}) where {T<:Static} = Int
Base.iszero(::Static{0}) = true
Base.iszero(::Static) = false
Base.isone(::Static{1}) = true
Base.isone(::Static) = false
Base.zero(::Type{T}) where {T<:Static} = Static{0}()
Copy link
Collaborator

@chriselrod chriselrod Sep 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we even want to define aliases like const Zero = Static{0}?

Base.one(::Type{T}) where {T<:Static} = Static{1}()

for T = [:Real, :Rational, :Integer]
@eval begin
Expand Down
22 changes: 22 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ end
@testset "Static" begin
@test iszero(Static(0))
@test !iszero(Static(1))
@test @inferred(one(Static)) === Static(1)
@test @inferred(zero(Static)) === Static(0)
@test eltype(one(Static)) <: Int
# test for ambiguities and correctness
for i ∈ [Static(0), Static(1), Static(2), 3]
for j ∈ [Static(0), Static(1), Static(2), 3]
Expand All @@ -271,3 +274,22 @@ end
end
end

@testset "insert/deleteat" begin
@test @inferred(ArrayInterface.insert([1,2,3], 2, -2)) == [1, -2, 2, 3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], 2)) == [1, 3]

@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 2])) == [3]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [1, 3])) == [2]
@test @inferred(ArrayInterface.deleteat([1, 2, 3], [2, 3])) == [1]


@test @inferred(ArrayInterface.insert((1,2,3), 1, -2)) == (-2, 1, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 2, -2)) == (1, -2, 2, 3)
@test @inferred(ArrayInterface.insert((1,2,3), 3, -2)) == (1, 2, -2, 3)

@test @inferred(ArrayInterface.deleteat((1, 2, 3), 1)) == (2, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 2)) == (1, 3)
@test @inferred(ArrayInterface.deleteat((1, 2, 3), 3)) == (1, 2)
@test ArrayInterface.deleteat((1, 2, 3), [1, 2]) == (3,)
end