diff --git a/libcudacxx/include/cuda/std/__type_traits/is_floating_point.h b/libcudacxx/include/cuda/std/__type_traits/is_floating_point.h index 913bacdb2a6..6baa8b115e1 100644 --- a/libcudacxx/include/cuda/std/__type_traits/is_floating_point.h +++ b/libcudacxx/include/cuda/std/__type_traits/is_floating_point.h @@ -21,6 +21,7 @@ #endif // no system header #include +#include #include _LIBCUDACXX_BEGIN_NAMESPACE_STD @@ -37,6 +38,16 @@ struct __cccl_is_floating_point : public true_type template <> struct __cccl_is_floating_point : public true_type {}; +#ifdef _LIBCUDACXX_HAS_NVFP16 +template <> +struct __cccl_is_floating_point<__half> : public true_type +{}; +#endif // _LIBCUDACXX_HAS_NVFP16 +#ifdef _LIBCUDACXX_HAS_NVBF16 +template <> +struct __cccl_is_floating_point<__nv_bfloat16> : public true_type +{}; +#endif // _LIBCUDACXX_HAS_NVBF16 template struct _CCCL_TYPE_VISIBILITY_DEFAULT is_floating_point : public __cccl_is_floating_point> diff --git a/libcudacxx/test/libcudacxx/std/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp b/libcudacxx/test/libcudacxx/std/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp index 1baff1fe485..3c0e72aae93 100644 --- a/libcudacxx/test/libcudacxx/std/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp +++ b/libcudacxx/test/libcudacxx/std/utilities/meta/meta.unary/meta.unary.cat/is_floating_point.pass.cpp @@ -82,6 +82,12 @@ int main(int, char**) test_is_floating_point(); test_is_floating_point(); test_is_floating_point(); +#ifdef _LIBCUDACXX_HAS_NVFP16 + test_is_floating_point<__half>(); +#endif // _LIBCUDACXX_HAS_NVFP16 +#ifdef _LIBCUDACXX_HAS_NVBF16 + test_is_floating_point<__nv_bfloat16>(); +#endif // _LIBCUDACXX_HAS_NVBF16 test_is_not_floating_point(); test_is_not_floating_point();