Skip to content

Commit

Permalink
Implement submdspan_mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Jan 10, 2025
1 parent 08ee8b7 commit a503b53
Show file tree
Hide file tree
Showing 6 changed files with 957 additions and 13 deletions.
15 changes: 2 additions & 13 deletions libcudacxx/include/cuda/std/__mdspan/submdspan_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,28 +84,17 @@ struct full_extent_t
};
_CCCL_GLOBAL_CONSTANT full_extent_t full_extent{};

// [mdspan.submdspan.submdspan.mapping.result]
template <class _LayoutMapping>
struct submdspan_mapping_result
{
static_assert(true, // __is_layout_mapping<_LayoutMapping>,
"[mdspan.submdspan.submdspan.mapping.result] shall meet the layout mapping requirements");

_CCCL_NO_UNIQUE_ADDRESS _LayoutMapping mapping{};
size_t offset{};
};

// [mdspan.submdspan.helpers]
_CCCL_TEMPLATE(class _Tp)
_CCCL_REQUIRES((!__integral_constant_like<_Tp>) )
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __de_ice(_Tp __val)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __de_ice(_Tp __val) noexcept
{
return __val;
}

_CCCL_TEMPLATE(class _Tp)
_CCCL_REQUIRES(__integral_constant_like<_Tp>)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __de_ice(_Tp)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto __de_ice(_Tp) noexcept
{
return _Tp::value;
}
Expand Down
332 changes: 332 additions & 0 deletions libcudacxx/include/cuda/std/__mdspan/submdspan_mapping.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,332 @@
//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___MDSPAN_SUBMDSPAN_MAPPING_H
#define _LIBCUDACXX___MDSPAN_SUBMDSPAN_MAPPING_H

#include <cuda/std/detail/__config>

#include "cuda/std/__cccl/unreachable.h"

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__fwd/mdspan.h>
#include <cuda/std/__mdspan/concepts.h>
#include <cuda/std/__mdspan/extents.h>
#include <cuda/std/__mdspan/layout_left.h>
#include <cuda/std/__mdspan/layout_right.h>
#include <cuda/std/__mdspan/layout_stride.h>
#include <cuda/std/__mdspan/mdspan.h>
#include <cuda/std/__mdspan/submdspan_extents.h>
#include <cuda/std/__mdspan/submdspan_helper.h>
#include <cuda/std/__type_traits/make_unsigned.h>
#include <cuda/std/__type_traits/remove_const.h>
#include <cuda/std/__type_traits/type_list.h>
#include <cuda/std/__utility/integer_sequence.h>
#include <cuda/std/array>

#if _CCCL_STD_VER >= 2014

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// [mdspan.sub.map]

// [mdspan.submdspan.submdspan.mapping.result]
template <class _LayoutMapping>
struct submdspan_mapping_result
{
static_assert(true, // __is_layout_mapping<_LayoutMapping>,
"[mdspan.submdspan.submdspan.mapping.result] shall meet the layout mapping requirements");

_CCCL_NO_UNIQUE_ADDRESS _LayoutMapping mapping{};
size_t offset{};
};

// [mdspan.sub.map.common]
_CCCL_TEMPLATE(size_t _SliceIndex, class _LayoutMapping, class... _SliceSpecifiers)
_CCCL_REQUIRES(__is_strided_slice<__get_slice_type<_SliceIndex, _SliceSpecifiers...>>)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__get_submdspan_strides(const _LayoutMapping& __mapping, _SliceSpecifiers... __slices) noexcept
{
using _SliceType = __get_slice_type<_SliceIndex, _SliceSpecifiers...>;
_SliceType& __slice = _CUDA_VSTD::__get_slice_at<_SliceIndex>(__slices...);

using __unsigned_stride = make_unsigned_t<typename _SliceType::stride_type>;
using __unsigned_extent = make_unsigned_t<typename _SliceType::extent_type>;
return __mapping.stride(_SliceIndex)
* (static_cast<__unsigned_stride>(__slice.stride) < static_cast<__unsigned_extent>(__slice.extent)
? _CUDA_VSTD::__de_ice(__slice.stride)
: 1);
}

_CCCL_TEMPLATE(size_t _SliceIndex, class _LayoutMapping, class... _SliceSpecifiers)
_CCCL_REQUIRES((!__is_strided_slice<__get_slice_type<_SliceIndex, _SliceSpecifiers...>>) )
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__get_submdspan_strides(const _LayoutMapping& __mapping, _SliceSpecifiers...) noexcept
{
return __mapping.stride(_SliceIndex);
}

template <class _LayoutMapping, class... _SliceSpecifiers, size_t... _SliceIndexes>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto __submdspan_strides(
index_sequence<_SliceIndexes...>, const _LayoutMapping& __mapping, _SliceSpecifiers... __slices) noexcept
{
using _Extents = typename _LayoutMapping::extents_type;
using _IndexType = typename _Extents::index_type;
constexpr auto __map_rank_ = _CUDA_VSTD::__map_rank<_IndexType, _SliceSpecifiers...>();
const array<_IndexType, _Extents::rank()> __arr = {
_CUDA_VSTD::__get_submdspan_strides<_SliceIndexes>(__mapping, __slices...)...};

using _SubExtent = __get_subextents_t<_Extents, _SliceSpecifiers...>;
array<_IndexType, _SubExtent::rank()> __res = {};
for (size_t __index = 0; __index != _SubExtent::rank(); ++__index)
{
if (__map_rank_[__index] != dynamic_extent)
{
__res[__map_rank_[__index]] = __arr[__index];
}
}
return __res;
}

template <class _LayoutMapping, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__submdspan_strides(const _LayoutMapping& __mapping, _SliceSpecifiers... __slices)
{
return _CUDA_VSTD::__submdspan_strides(_CUDA_VSTD::index_sequence_for<_SliceSpecifiers...>(), __mapping, __slices...);
}

// [mdspan.sub.map.common-8]
template <class _LayoutMapping, class... _SliceSpecifiers, size_t... _SliceIndexes>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr size_t
__submdspan_offset(index_sequence<_SliceIndexes...>, const _LayoutMapping& __mapping, _SliceSpecifiers... __slices)
{
using _Extents = typename _LayoutMapping::extents_type;
using _IndexType = typename _Extents::index_type;
// If first_<index_type, k>(slices...)
const array<_IndexType, _Extents::rank()> __offsets = {
_CUDA_VSTD::__first_extent_from_slice<_IndexType, _SliceIndexes>(__slices...)...};

using _SubExtent = __get_subextents_t<_Extents, _SliceSpecifiers...>;
for (size_t __index = 0; __index != _SubExtent::rank(); ++__index)
{
// If first_<index_type, k>(slices...) equals extents().extent(k) for any rank index k of extents()
if (__offsets[__index] == __mapping.extents().extent(__index))
{
// then let offset be a value of type size_t equal to (*this).required_span_size()
return static_cast<size_t>(__mapping.required_span_size());
}
}
// Otherwise, let offset be a value of type size_t equal to (*this)(first_<index_type, P>(slices...)...).
return static_cast<size_t>(__mapping(__offsets[_SliceIndexes]...));
}

template <class _LayoutMapping, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr size_t
__submdspan_offset(const _LayoutMapping& __mapping, _SliceSpecifiers... __slices)
{
return _CUDA_VSTD::__submdspan_offset(_CUDA_VSTD::index_sequence_for<_SliceSpecifiers...>(), __mapping, __slices...);
}

// [mdspan.sub.map.common-9]
// [mdspan.sub.map.common-9.1]
template <class _SliceType>
_CCCL_CONCEPT __is_strided_slice_stride_of_one = _CCCL_REQUIRES_EXPR((_SliceType))(
requires(__is_strided_slice<_SliceType>),
requires(__integral_constant_like<typename _SliceType::stride_type>),
requires(_SliceType::stride_type::value == 1));

template <class _LayoutMapping, class _SliceType>
_LIBCUDACXX_HIDE_FROM_ABI constexpr bool __is_unit_stride_slice()
{
// [mdspan.sub.map.common-9.1]
if constexpr (__is_strided_slice_stride_of_one<_SliceType>)
{
return true;
}
// [mdspan.sub.map.common-9.2]
else if constexpr (__index_pair_like<_SliceType, typename _LayoutMapping::index_type>)
{
return true;
}
// [mdspan.sub.map.common-9.3]
else if constexpr (_CCCL_TRAIT(is_convertible, _SliceType, full_extent_t))
{
return true;
}
else
{
return false;
}
_CCCL_UNREACHABLE();
}

// [mdspan.sub.map.left]
template <class _LayoutMapping, class _SubExtents, class... _SliceSpecifiers>
_LIBCUDACXX_HIDE_FROM_ABI constexpr bool __can_layout_left()
{
// [mdspan.sub.map.left-1.2]
if constexpr (_SubExtents::rank() == 0)
{
return true;
}
// [mdspan.sub.map.left-1.3.1]
// Note we can simplify metaprogramming here a bit because unit-stride slice is true if that condition holds
else if constexpr (_CCCL_FOLD_AND(_CCCL_TRAIT(is_convertible, _SliceSpecifiers, full_extent_t)))
{
return true;
}
else
{
// [mdspan.sub.map.left-1.3.2]
return _CUDA_VSTD::__is_unit_stride_slice<_LayoutMapping,
__type_index_c<_SubExtents::rank() - 1, _SliceSpecifiers...>>();
}
_CCCL_UNREACHABLE();
}

template <class _Extents, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__submdspan_mapping_impl(const typename layout_left::mapping<_Extents>& __mapping, _SliceSpecifiers... __slices)
{
// [mdspan.sub.map.left-1.1]
if constexpr (_Extents::rank() == 0)
{
return submdspan_mapping_result{__mapping, 0};
}

// [mdspan.sub.map.left-1.2]
// [mdspan.sub.map.left-1.3]
using _SubExtents = __get_subextents_t<_Extents, _SliceSpecifiers...>;
const auto __sub_ext = _CUDA_VSTD::submdspan_extents(__mapping.extents(), __slices...);
const auto __offset = _CUDA_VSTD::__submdspan_offset(__mapping, __slices...);
if constexpr (_CUDA_VSTD::
__can_layout_left<typename layout_left::mapping<_Extents>, _SubExtents, _SliceSpecifiers...>())
{
return submdspan_mapping_result<layout_left::mapping<_SubExtents>>{layout_left::mapping{__sub_ext}, __offset};
}
// [mdspan.sub.map.left-1.4]
// TODO: Implement padded layouts
else
{
// [mdspan.sub.map.left-1.5]
const auto __sub_strides = _CUDA_VSTD::__submdspan_strides(__mapping, __slices...);
return submdspan_mapping_result<layout_stride::mapping<_SubExtents>>{
layout_stride::mapping{__sub_ext, __sub_strides}, __offset};
}
_CCCL_UNREACHABLE();
}

template <class _LayoutMapping, class _SubExtents, class... _SliceSpecifiers>
_LIBCUDACXX_HIDE_FROM_ABI constexpr bool __can_layout_right()
{
// [mdspan.sub.map.right-1.2]
if constexpr (_SubExtents::rank() == 0)
{
return true;
}
// [mdspan.sub.map.right-1.3.1]
// Note we can simplify metaprogramming here a bit because unit-stride slice is true if that condition holds
else if constexpr (_CCCL_FOLD_AND(_CCCL_TRAIT(is_convertible, _SliceSpecifiers, full_extent_t)))
{
return true;
}
else
{
// [mdspan.sub.map.right-1.3.2]
return _CUDA_VSTD::__is_unit_stride_slice<_LayoutMapping,
__type_index_c<_SubExtents::rank() - 1, _SliceSpecifiers...>>();
}
_CCCL_UNREACHABLE();
}

template <class _Extents, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__submdspan_mapping_impl(const typename layout_right::mapping<_Extents>& __mapping, _SliceSpecifiers... __slices)
{
// [mdspan.sub.map.right-1.1]
if constexpr (_Extents::rank() == 0)
{
return submdspan_mapping_result{__mapping, 0};
}
else
{
// [mdspan.sub.map.right-1.2]
// [mdspan.sub.map.right-1.3]
using _SubExtents = __get_subextents_t<_Extents, _SliceSpecifiers...>;
const auto __sub_ext = _CUDA_VSTD::submdspan_extents(__mapping.extents(), __slices...);
const auto __offset = _CUDA_VSTD::__submdspan_offset(__mapping, __slices...);
if constexpr (_CUDA_VSTD::
__can_layout_right<typename layout_left::mapping<_Extents>, _SubExtents, _SliceSpecifiers...>())
{
return submdspan_mapping_result<layout_right::mapping<_SubExtents>>{layout_right::mapping{__sub_ext}, __offset};
}
// [mdspan.sub.map.right-1.4]
// TODO: Implement padded layouts
else
{
// [mdspan.sub.map.right-1.5]
const auto __sub_strides = _CUDA_VSTD::__submdspan_strides(__mapping, __slices...);
return submdspan_mapping_result<layout_stride::mapping<_SubExtents>>{
layout_stride::mapping{__sub_ext, __sub_strides}, __offset};
}
}
_CCCL_UNREACHABLE();
}

template <class _Extents, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
__submdspan_mapping_impl(const typename layout_stride::mapping<_Extents>& __mapping, _SliceSpecifiers... __slices)
{
// [mdspan.sub.map.stride-1.1]
if constexpr (_Extents::rank() == 0)
{
return submdspan_mapping_result{__mapping, 0};
}
else
{
// [mdspan.sub.map.stride-1.2]
using _SubExtents = __get_subextents_t<_Extents, _SliceSpecifiers...>;
const auto __sub_ext = _CUDA_VSTD::submdspan_extents(__mapping.extents(), __slices...);
const auto __offset = _CUDA_VSTD::__submdspan_offset(__mapping, __slices...);
const auto __sub_strides = _CUDA_VSTD::__submdspan_strides(__mapping, __slices...);
return submdspan_mapping_result{layout_stride::mapping{__sub_ext, __sub_strides}, __offset};
}
}

template <class _LayoutMapping, class... _SliceSpecifiers>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
submdspan_mapping(const _LayoutMapping& __mapping, _SliceSpecifiers... __slices)
{
return _CUDA_VSTD::__submdspan_mapping_impl(__mapping, __slices...);
}

_CCCL_TEMPLATE(class _Tp, class _Extents, class _Layout, class _Accessor, class... _SliceSpecifiers)
_CCCL_REQUIRES(true)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI constexpr auto
submdspan(const mdspan<_Tp, _Extents, _Layout, _Accessor>& __src, _SliceSpecifiers... __slices)
{
auto __sub_map_result = _CUDA_VSTD::submdspan_mapping(__src.mapping(), __slices...);
return mdspan(__src.accessor().offset(__src.data_handle(), __sub_map_result.offset),
__sub_map_result.mapping,
typename _Accessor::offset_policy(__src.accessor()));
}

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _CCCL_STD_VER >= 2014

#endif // _LIBCUDACXX___MDSPAN_SUBMDSPAN_MAPPING_H
1 change: 1 addition & 0 deletions libcudacxx/include/cuda/std/mdspan
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ _CCCL_PUSH_MACROS
#include <cuda/std/__mdspan/mdspan.h>
#include <cuda/std/__mdspan/submdspan_extents.h>
#include <cuda/std/__mdspan/submdspan_helper.h>
#include <cuda/std/__mdspan/submdspan_mapping.h>
#include <cuda/std/version>

_CCCL_POP_MACROS
Expand Down
Loading

0 comments on commit a503b53

Please sign in to comment.