Skip to content

Commit

Permalink
Implement math roots functions
Browse files Browse the repository at this point in the history
  • Loading branch information
miscco committed Jan 14, 2025
1 parent 3ee6a0c commit 12c7fbb
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 18 deletions.
27 changes: 27 additions & 0 deletions libcudacxx/include/cuda/std/__cccl/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@
# undef _CCCL_BUILTIN_BSWAP128
#endif // _CCCL_CUDA_COMPILER(NVCC)

#if _CCCL_CHECK_BUILTIN(builtin_cbrt) || _CCCL_COMPILER(GCC)
# define _CCCL_BUILTIN_CBRTF(...) __builtin_cbrtf(__VA_ARGS__)
# define _CCCL_BUILTIN_CBRT(...) __builtin_cbrt(__VA_ARGS__)
# define _CCCL_BUILTIN_CBRTL(...) __builtin_cbrtl(__VA_ARGS__)
#endif // _CCCL_CHECK_BUILTIN(builtin_cbrt)

// Below 11.7 nvcc treats the builtin as a host only function
// clang-cuda fails with fatal error: error in backend: Undefined external symbol "cbrt"
#if _CCCL_CUDACC_BELOW(11, 7) || _CCCL_CUDA_COMPILER(CLANG)
# undef _CCCL_BUILTIN_CBRTF
# undef _CCCL_BUILTIN_CBRT
# undef _CCCL_BUILTIN_CBRTL
#endif // _CCCL_CUDACC_BELOW(11, 7) || _CCCL_CUDA_COMPILER(CLANG)

#if _CCCL_CHECK_BUILTIN(builtin_ceil) || _CCCL_COMPILER(GCC)
# define _CCCL_BUILTIN_CEILF(...) __builtin_ceilf(__VA_ARGS__)
# define _CCCL_BUILTIN_CEIL(...) __builtin_ceil(__VA_ARGS__)
Expand Down Expand Up @@ -576,6 +590,19 @@
# undef _CCCL_BUILTIN_SIGNBIT
#endif // _CCCL_CUDACC_BELOW(11, 7)

#if _CCCL_CHECK_BUILTIN(builtin_sqrt) || _CCCL_COMPILER(GCC)
# define _CCCL_BUILTIN_SQRTF(...) __builtin_sqrtf(__VA_ARGS__)
# define _CCCL_BUILTIN_SQRT(...) __builtin_sqrt(__VA_ARGS__)
# define _CCCL_BUILTIN_SQRTL(...) __builtin_sqrtl(__VA_ARGS__)
#endif // _CCCL_CHECK_BUILTIN(builtin_sqrt)

// Below 11.7 nvcc treats the builtin as a host only function
#if _CCCL_CUDACC_BELOW(11, 7)
# undef _CCCL_BUILTIN_SQRTF
# undef _CCCL_BUILTIN_SQRT
# undef _CCCL_BUILTIN_SQRTL
#endif // _CCCL_CUDACC_BELOW(11, 7)

#if _CCCL_CHECK_BUILTIN(builtin_trunc) || _CCCL_COMPILER(GCC)
# define _CCCL_BUILTIN_TRUNCF(...) __builtin_truncf(__VA_ARGS__)
# define _CCCL_BUILTIN_TRUNC(...) __builtin_trunc(__VA_ARGS__)
Expand Down
5 changes: 0 additions & 5 deletions libcudacxx/include/cuda/std/__cmath/nvbf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ _LIBCUDACXX_HIDE_FROM_ABI __nv_bfloat16 atan2(__nv_bfloat16 __x, __nv_bfloat16 _
return __float2bfloat16(::atan2f(__bfloat162float(__x), __bfloat162float(__y)));
}

_LIBCUDACXX_HIDE_FROM_ABI __nv_bfloat16 sqrt(__nv_bfloat16 __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(::sqrtf(__bfloat162float(__x)));))
}

// floating point helper
_LIBCUDACXX_HIDE_FROM_ABI __nv_bfloat16 __constexpr_copysign(__nv_bfloat16 __x, __nv_bfloat16 __y) noexcept
{
Expand Down
5 changes: 0 additions & 5 deletions libcudacxx/include/cuda/std/__cmath/nvfp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,6 @@ _LIBCUDACXX_HIDE_FROM_ABI __half atan2(__half __x, __half __y)
return __float2half(::atan2f(__half2float(__x), __half2float(__y)));
}

_LIBCUDACXX_HIDE_FROM_ABI __half sqrt(__half __x)
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(::sqrtf(__half2float(__x)));))
}

// floating point helper
_LIBCUDACXX_HIDE_FROM_ABI __half __constexpr_copysign(__half __x, __half __y) noexcept
{
Expand Down
171 changes: 171 additions & 0 deletions libcudacxx/include/cuda/std/__cmath/roots.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// -*- C++ -*-
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX___CMATH_ROOTS_H
#define _LIBCUDACXX___CMATH_ROOTS_H

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
# pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
# pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
# pragma system_header
#endif // no system header

#include <cuda/std/__cmath/common.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_integral.h>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// sqrt

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI float sqrt(float __x) noexcept
{
#if defined(_CCCL_BUILTIN_SQRTF)
return _CCCL_BUILTIN_SQRTF(__x);
#else // ^^^ _CCCL_BUILTIN_SQRTF ^^^ // vvv !_CCCL_BUILTIN_SQRTF vvv
return ::sqrtf(__x);
#endif // !_CCCL_BUILTIN_SQRTF
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI float sqrtf(float __x) noexcept
{
#if defined(_CCCL_BUILTIN_SQRTF)
return _CCCL_BUILTIN_SQRTF(__x);
#else // ^^^ _CCCL_BUILTIN_SQRTF ^^^ // vvv !_CCCL_BUILTIN_SQRTF vvv
return ::sqrtf(__x);
#endif // !_CCCL_BUILTIN_SQRTF
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI double sqrt(double __x) noexcept
{
#if defined(_CCCL_BUILTIN_SQRT)
return _CCCL_BUILTIN_SQRT(__x);
#else // ^^^ _CCCL_BUILTIN_SQRT ^^^ // vvv !_CCCL_BUILTIN_SQRT vvv
return ::sqrt(__x);
#endif // !_CCCL_BUILTIN_SQRT
}

#if !defined(_LIBCUDACXX_HAS_NO_LONG_DOUBLE)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI long double sqrt(long double __x) noexcept
{
# if defined(_CCCL_BUILTIN_SQRTL)
return _CCCL_BUILTIN_SQRTL(__x);
# else // ^^^ _CCCL_BUILTIN_SQRTL ^^^ // vvv !_CCCL_BUILTIN_SQRTL vvv
return ::sqrtl(__x);
# endif // !_CCCL_BUILTIN_SQRTL
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI long double sqrtl(long double __x) noexcept
{
# if defined(_CCCL_BUILTIN_SQRTL)
return _CCCL_BUILTIN_SQRTL(__x);
# else // ^^^ _CCCL_BUILTIN_SQRTL ^^^ // vvv !_CCCL_BUILTIN_SQRTL vvv
return ::sqrtl(__x);
# endif // !_CCCL_BUILTIN_SQRTL
}
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE

#if defined(_LIBCUDACXX_HAS_NVFP16)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI __half sqrt(__half __x) noexcept
{
NV_IF_ELSE_TARGET(NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2half(_CUDA_VSTD::sqrt(__half2float(__x)));))
}
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI __nv_bfloat16 sqrt(__nv_bfloat16 __x) noexcept
{
NV_IF_ELSE_TARGET(
NV_IS_DEVICE, (return ::hsqrt(__x);), (return __float2bfloat16(_CUDA_VSTD::sqrt(__bfloat162float(__x)));))
}
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Integer, enable_if_t<_CCCL_TRAIT(is_integral, _Integer), int> = 0>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI double sqrt(_Integer __x) noexcept
{
return _CUDA_VSTD::sqrt((double) __x);
}

// cbrt

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI float cbrt(float __x) noexcept
{
#if defined(_CCCL_BUILTIN_CBRTF)
return _CCCL_BUILTIN_CBRTF(__x);
#else // ^^^ _CCCL_BUILTIN_CBRTF ^^^ // vvv !_CCCL_BUILTIN_CBRTF vvv
return ::cbrtf(__x);
#endif // !_CCCL_BUILTIN_CBRTF
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI float cbrtf(float __x) noexcept
{
#if defined(_CCCL_BUILTIN_CBRTF)
return _CCCL_BUILTIN_CBRTF(__x);
#else // ^^^ _CCCL_BUILTIN_CBRTF ^^^ // vvv !_CCCL_BUILTIN_CBRTF vvv
return ::cbrtf(__x);
#endif // !_CCCL_BUILTIN_CBRTF
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI double cbrt(double __x) noexcept
{
#if defined(_CCCL_BUILTIN_CBRT)
return _CCCL_BUILTIN_CBRT(__x);
#else // ^^^ _CCCL_BUILTIN_CBRT ^^^ // vvv !_CCCL_BUILTIN_CBRT vvv
return ::cbrt(__x);
#endif // !_CCCL_BUILTIN_CBRT
}

#if !defined(_LIBCUDACXX_HAS_NO_LONG_DOUBLE)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI long double cbrt(long double __x) noexcept
{
# if defined(_CCCL_BUILTIN_CBRTL)
return _CCCL_BUILTIN_CBRTL(__x);
# else // ^^^ _CCCL_BUILTIN_CBRTL ^^^ // vvv !_CCCL_BUILTIN_CBRTL vvv
return ::cbrtl(__x);
# endif // !_CCCL_BUILTIN_CBRTL
}

_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI long double cbrtl(long double __x) noexcept
{
# if defined(_CCCL_BUILTIN_CBRTL)
return _CCCL_BUILTIN_CBRTL(__x);
# else // ^^^ _CCCL_BUILTIN_CBRTL ^^^ // vvv !_CCCL_BUILTIN_CBRTL vvv
return ::cbrtl(__x);
# endif // !_CCCL_BUILTIN_CBRTL
}
#endif // !_LIBCUDACXX_HAS_NO_LONG_DOUBLE

#if defined(_LIBCUDACXX_HAS_NVFP16)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI __half cbrt(__half __x) noexcept
{
return __float2half(_CUDA_VSTD::cbrt(__half2float(__x)));
}
#endif // _LIBCUDACXX_HAS_NVFP16

#if defined(_LIBCUDACXX_HAS_NVBF16)
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI __nv_bfloat16 cbrt(__nv_bfloat16 __x) noexcept
{
return __float2bfloat16(_CUDA_VSTD::cbrt(__bfloat162float(__x)));
}
#endif // _LIBCUDACXX_HAS_NVBF16

template <class _Integer, enable_if_t<_CCCL_TRAIT(is_integral, _Integer), int> = 0>
_CCCL_NODISCARD _LIBCUDACXX_HIDE_FROM_ABI double cbrt(_Integer __x) noexcept
{
return _CUDA_VSTD::cbrt((double) __x);
}

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _LIBCUDACXX___CMATH_ROOTS_H
9 changes: 1 addition & 8 deletions libcudacxx/include/cuda/std/detail/libcxx/include/cmath
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,7 @@ long double truncl(long double x);
#include <cuda/std/__cmath/lerp.h>
#include <cuda/std/__cmath/logarithms.h>
#include <cuda/std/__cmath/min_max.h>
#include <cuda/std/__cmath/roots.h>
#include <cuda/std/__cmath/rounding_functions.h>
#include <cuda/std/__cmath/traits.h>
#include <cuda/std/__cstdlib/abs.h>
Expand Down Expand Up @@ -371,8 +372,6 @@ using ::sinf;
using ::sinh;
using ::sinhf;

using ::sqrt;
using ::sqrtf;
using ::tan;
using ::tanf;

Expand Down Expand Up @@ -413,8 +412,6 @@ using ::sinf;
using ::sinh;
using ::sinhf;

using ::sqrt;
using ::sqrtf;
using ::tan;
using ::tanf;

Expand All @@ -427,8 +424,6 @@ using ::asinh;
using ::asinhf;
using ::atanh;
using ::atanhf;
using ::cbrt;
using ::cbrtf;

using ::copysign;
using ::copysignf;
Expand Down Expand Up @@ -476,13 +471,11 @@ using ::modfl;
using ::powl;
using ::sinhl;
using ::sinl;
using ::sqrtl;
using ::tanl;

using ::acoshl;
using ::asinhl;
using ::atanhl;
using ::cbrtl;
using ::tanhl;

using ::copysignl;
Expand Down
Loading

0 comments on commit 12c7fbb

Please sign in to comment.