Skip to content

Commit

Permalink
Implement submdspan_extents
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Jan 10, 2025
1 parent f371e69 commit 08ee8b7
Show file tree
Hide file tree
Showing 7 changed files with 743 additions and 0 deletions.
87 changes: 87 additions & 0 deletions libcudacxx/include/cuda/std/__mdspan/concepts.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,23 @@
#endif // no system header

#include <cuda/std/__concepts/concept_macros.h>
#include <cuda/std/__concepts/convertible_to.h>
#include <cuda/std/__concepts/copyable.h>
#include <cuda/std/__concepts/equality_comparable.h>
#include <cuda/std/__concepts/same_as.h>
#include <cuda/std/__tuple_dir/tuple_element.h>
#include <cuda/std/__tuple_dir/tuple_like.h>
#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/__type_traits/is_nothrow_move_assignable.h>
#include <cuda/std/__type_traits/is_nothrow_move_constructible.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/is_signed.h>
#include <cuda/std/__type_traits/is_swappable.h>
#include <cuda/std/__type_traits/is_unsigned.h>
#include <cuda/std/__type_traits/remove_const.h>
#include <cuda/std/__type_traits/remove_cvref.h>
#include <cuda/std/span>

#if _CCCL_STD_VER >= 2014

Expand All @@ -53,6 +67,50 @@ template <class _Layout, class _Mapping>
_CCCL_INLINE_VAR constexpr bool __is_mapping_of =
_CCCL_TRAIT(is_same, typename _Layout::template mapping<typename _Mapping::extents_type>, _Mapping);

// [mdspan.layout.reqmts]/1
# if _CCCL_STD_VER >= 2020
template <class _Mapping>
concept __layout_mapping_req_type =
copyable<_Mapping> && equality_comparable<_Mapping> && //
is_nothrow_move_constructible_v<_Mapping> && is_move_assignable_v<_Mapping> && is_nothrow_swappable_v<_Mapping>;
# else // ^^^ _CCCL_STD_VER >= 2020 ^^^ / vvv _CCCL_STD_VER <= 2017 vvv
template <class _Mapping>
_CCCL_CONCEPT_FRAGMENT(
__layout_mapping_req_type_,
requires()( //
requires(copyable<_Mapping>),
requires(equality_comparable<_Mapping>),
requires(_CCCL_TRAIT(is_nothrow_move_constructible, _Mapping)),
requires(_CCCL_TRAIT(is_move_assignable, _Mapping)),
requires(_CCCL_TRAIT(is_nothrow_swappable, _Mapping))));

template <class _Mapping>
_CCCL_CONCEPT __layout_mapping_req_type = _CCCL_FRAGMENT(__layout_mapping_req_type_, _Mapping);
# endif // _CCCL_STD_VER <= 2017

// [mdspan.layout.reqmts]/2-4
# if _CCCL_STD_VER >= 2020
template <class _Mapping>
concept __layout_mapping_req_types = requires {
requires __is_extents_v<typename _Mapping::extents_type>;
requires same_as<typename _Mapping::index_type, typename _Mapping::extents_type::index_type>;
requires same_as<typename _Mapping::rank_type, typename _Mapping::extents_type::rank_type>;
requires __is_mapping_of<typename _Mapping::layout_type, _Mapping>;
};
# else // ^^^ _CCCL_STD_VER >= 2020 ^^^ / vvv _CCCL_STD_VER <= 2017 vvv
template <class _Mapping>
_CCCL_CONCEPT_FRAGMENT(
__layout_mapping_req_types_,
requires()( //
requires(__is_extents_v<typename _Mapping::extents_type>),
requires(same_as<typename _Mapping::index_type, typename _Mapping::extents_type::index_type>),
requires(same_as<typename _Mapping::rank_type, typename _Mapping::extents_type::rank_type>),
requires(__is_mapping_of<typename _Mapping::layout_type, _Mapping>)));

template <class _Mapping>
_CCCL_CONCEPT __layout_mapping_req_types = _CCCL_FRAGMENT(__layout_mapping_req_types_, _Mapping);
# endif // _CCCL_STD_VER <= 2017

// [mdspan.layout.stride.expo]/4
# if _CCCL_STD_VER >= 2020
template <class _Mapping>
Expand Down Expand Up @@ -96,6 +154,35 @@ _CCCL_CONCEPT __layout_mapping_alike = _CCCL_FRAGMENT(__layout_mapping_alike_, _

} // namespace __mdspan_detail

# if _CCCL_STD_VER >= 2020

template <class _Tp, class _IndexType>
concept __index_pair_like =
__pair_like<_Tp> //
&& convertible_to<tuple_element_t<0, _Tp>, _IndexType> //
&& convertible_to<tuple_element_t<1, _Tp>, _IndexType>;

# else // ^^^ _CCCL_STD_VER >= 2020 ^^^ / vvv _CCCL_STD_VER <= 2017 vvv

template <class _Tp, class _IndexType>
_CCCL_CONCEPT_FRAGMENT(
__index_pair_like_,
requires()( //
requires(__pair_like<_Tp>),
requires(convertible_to<tuple_element_t<0, _Tp>, _IndexType>),
requires(convertible_to<tuple_element_t<1, _Tp>, _IndexType>) //
));
template <class _Tp, class _IndexType>
_CCCL_CONCEPT __index_pair_like = _CCCL_FRAGMENT(__index_pair_like_, _Tp, _IndexType);

# endif // _CCCL_STD_VER <= 2017

// [mdspan.submdspan.strided.slice]/3

template <class _Tp>
_CCCL_CONCEPT __index_like =
_CCCL_TRAIT(is_signed, _Tp) || _CCCL_TRAIT(is_unsigned, _Tp) || __integral_constant_like<_Tp>;

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _CCCL_STD_VER >= 2014
Expand Down
187 changes: 187 additions & 0 deletions libcudacxx/include/cuda/std/__mdspan/submdspan_extents.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___MDSPAN_SUBMDSPAN_EXTENTS_H
#define _LIBCUDACXX___MDSPAN_SUBMDSPAN_EXTENTS_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__concepts/concept_macros.h>
#include <cuda/std/__concepts/convertible_to.h>
#include <cuda/std/__fwd/mdspan.h>
#include <cuda/std/__mdspan/concepts.h>
#include <cuda/std/__mdspan/extents.h>
#include <cuda/std/__mdspan/submdspan_helper.h>
#include <cuda/std/__tuple_dir/tuple_like.h>
#include <cuda/std/__tuple_dir/tuple_size.h>
#include <cuda/std/__type_traits/is_integral.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/__type_traits/is_signed.h>
#include <cuda/std/__type_traits/is_unsigned.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/array>
#include <cuda/std/tuple>

#if _CCCL_STD_VER >= 2014

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// [mdspan.sub.extents]
// [mdspan.sub.extents-4.2.2]
template <class _Extents, class _SliceType>
_CCCL_CONCEPT __subextents_is_index_pair = _CCCL_REQUIRES_EXPR((_Extents, _SliceType))(
requires(__index_pair_like<_SliceType, typename _Extents::index_type>),
requires(__integral_constant_like<tuple_element_t<0, _SliceType>>),
requires(__integral_constant_like<tuple_element_t<1, _SliceType>>));

// [mdspan.sub.extents-4.2.3]
template <class _Extents, class _SliceType>
_CCCL_CONCEPT __subextents_is_strided_slice_zero_extent = _CCCL_REQUIRES_EXPR((_Extents, _SliceType))(
requires(__is_strided_slice<_SliceType>),
requires(__integral_constant_like<typename _SliceType::extent_type>),
requires(typename _SliceType::extent_type() == 0));

// [mdspan.sub.extents-4.2.4]
template <class _Extents, class _SliceType>
_CCCL_CONCEPT __subextents_is_strided_slice = _CCCL_REQUIRES_EXPR((_Extents, _SliceType))(
requires(__is_strided_slice<_SliceType>),
requires(__integral_constant_like<typename _SliceType::extent_type>),
requires(__integral_constant_like<typename _SliceType::stride_type>));

struct __get_subextent
{
template <class _Extents, size_t _SliceIndex, class _SliceType>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI static constexpr size_t __get_static_subextents() noexcept
{
// [mdspan.sub.extents-4.2.1]
if constexpr (convertible_to<_SliceType, full_extent_t>)
{
return _Extents::static_extent(_SliceIndex);
}
// [mdspan.sub.extents-4.2.2]
else if constexpr (__subextents_is_index_pair<_Extents, _SliceType>)
{
return _CUDA_VSTD::__de_ice(tuple_element_t<1, _SliceType>())
- _CUDA_VSTD::__de_ice(tuple_element_t<0, _SliceType>());
}
// [mdspan.sub.extents-4.2.3]
else if constexpr (__subextents_is_strided_slice_zero_extent<_Extents, _SliceType>)
{
return 0;
}
// [mdspan.sub.extents-4.2.4]
else if constexpr (__subextents_is_strided_slice<_Extents, _SliceType>)
{
return 1
+ (_CUDA_VSTD::__de_ice(_SliceType::extent_type()) - 1) / _CUDA_VSTD::__de_ice(_SliceType::stride_type());
}
// [mdspan.sub.extents-4.2.5]
else
{
return dynamic_extent;
}
_CCCL_UNREACHABLE();
}

template <size_t _SliceIndex, class _Extent, class... _Slices>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI static constexpr typename _Extent::index_type
__get_dynamic_subextents(const _Extent& __src, _Slices... __slices) noexcept
{
using _SliceType = __get_slice_type<_SliceIndex, _Slices...>;
// [mdspan.sub.extents-5.1]
if constexpr (__is_strided_slice<_SliceType>)
{
_SliceType& __slice = _CUDA_VSTD::__get_slice_at<_SliceIndex>(__slices...);
return __slice.extent == 0
? 0
: 1 + (_CUDA_VSTD::__de_ice(__slice.extent) - 1) / _CUDA_VSTD::__de_ice(__slice.stride);
}
// [mdspan.sub.extents-5.2]
else
{
return _CUDA_VSTD::__last_extent_from_slice<_SliceIndex>(__src, __slices...)
- _CUDA_VSTD::__first_extent_from_slice<typename _Extent::index_type, _SliceIndex>(__slices...);
}
_CCCL_UNREACHABLE();
}

template <class _IndexType, class... _Slices>
static constexpr auto __map_rank = _CUDA_VSTD::__map_rank<_IndexType, _Slices...>();

template <class _Extent, class... _Slices, size_t... _SliceIndices, size_t _SliceIndex, size_t... _Remaining>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__impl(index_sequence<_SliceIndices...>,
index_sequence<_SliceIndex, _Remaining...>,
const _Extent& __src,
_Slices... __slices) noexcept
{
using _IndexType = typename _Extent::index_type;
using _SliceType = __get_slice_type<_SliceIndex, _Slices...>;
if constexpr (convertible_to<_SliceType, typename _Extent::index_type>)
{
return __impl(index_sequence<_SliceIndices...>{}, index_sequence<_Remaining...>{}, __src, __slices...);
}
else
{
return __impl(
index_sequence<_SliceIndices..., _SliceIndex>{}, index_sequence<_Remaining...>{}, __src, __slices...);
}
_CCCL_UNREACHABLE();
}

template <class _Extent, class... _Slices, size_t... _SliceIndices>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__impl(index_sequence<_SliceIndices...>, index_sequence<>, const _Extent& __src, _Slices... __slices) noexcept
{
using _IndexType = typename _Extent::index_type;
using _SubExtents =
extents<_IndexType,
__get_static_subextents<_Extent, _SliceIndices, __get_slice_type<_SliceIndices, _Slices...>>()...>;
return _SubExtents{__get_dynamic_subextents<_SliceIndices>(__src, __slices...)...};
}

template <class _Extent, class... _Slices>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto operator()(const _Extent& __src, _Slices... __slices) noexcept
{
return __impl(index_sequence<>{}, _CUDA_VSTD::index_sequence_for<_Slices...>(), __src, __slices...);
}
};

template <class _IndexType, class _SliceType>
_CCCL_INLINE_VAR constexpr bool __is_valid_subextents =
convertible_to<_SliceType, _IndexType> || __index_pair_like<_SliceType, _IndexType>
|| _CCCL_TRAIT(is_convertible, _SliceType, full_extent_t) || __is_strided_slice<_SliceType>;

_CCCL_TEMPLATE(class _Extents, class... _Slices)
_CCCL_REQUIRES((_Extents::rank() == sizeof...(_Slices)))
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto submdspan_extents(const _Extents& __src, _Slices... __slices)
{
static_assert(_CCCL_FOLD_AND((__is_valid_subextents<typename _Extents::index_type, _Slices>) ),
"[mdspan.sub.extents] For each rank index k of src.extents(), exactly one of the following is true:");
return __get_subextent{}(__src, __slices...);
}

template <class _Extents, class... _Slices>
using __get_subextents_t =
decltype(_CUDA_VSTD::submdspan_extents(_CUDA_VSTD::declval<_Extents>(), _CUDA_VSTD::declval<_Slices>()...));

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _CCCL_STD_VER >= 2014

#endif // _LIBCUDACXX___MDSPAN_SUBMDSPAN_EXTENTS_H
Loading

0 comments on commit 08ee8b7

Please sign in to comment.