Skip to content

Commit

Permalink
Specialize is_floating_point for half and bfloat
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jan 14, 2025
1 parent d5d3aa6 commit 6b3fa17
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
11 changes: 11 additions & 0 deletions libcudacxx/include/cuda/std/__type_traits/is_floating_point.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#endif // no system header

#include <cuda/std/__type_traits/integral_constant.h>
#include <cuda/std/__type_traits/is_extended_floating_point.h>
#include <cuda/std/__type_traits/remove_cv.h>

_LIBCUDACXX_BEGIN_NAMESPACE_STD
Expand All @@ -37,6 +38,16 @@ struct __cccl_is_floating_point<double> : public true_type
template <>
struct __cccl_is_floating_point<long double> : public true_type
{};
#ifdef _LIBCUDACXX_HAS_NVFP16
template <>
struct __cccl_is_floating_point<__half> : public true_type
{};
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
template <>
struct __cccl_is_floating_point<__nv_bfloat16> : public true_type
{};
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Tp>
struct _CCCL_TYPE_VISIBILITY_DEFAULT is_floating_point : public __cccl_is_floating_point<remove_cv_t<_Tp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,12 @@ int main(int, char**)
test_is_floating_point<float>();
test_is_floating_point<double>();
test_is_floating_point<long double>();
#ifdef _LIBCUDACXX_HAS_NVFP16
test_is_floating_point<__half>();
#endif // _LIBCUDACXX_HAS_NVFP16
#ifdef _LIBCUDACXX_HAS_NVBF16
test_is_floating_point<__nv_bfloat16>();
#endif // _LIBCUDACXX_HAS_NVBF16

test_is_not_floating_point<short>();
test_is_not_floating_point<unsigned short>();
Expand Down

0 comments on commit 6b3fa17

Please sign in to comment.