Skip to content

Commit

Permalink
replace use of std math in kernels with quda versions
Browse files (browse the repository at this point in the history)
allow building without cufft
change some casting
  • Loading branch information
jcosborn committed Dec 11, 2024
1 parent a54595d commit 40df91b
Show file tree
Hide file tree
Showing 29 changed files with 246 additions and 215 deletions.
2 changes: 1 addition & 1 deletion include/clover_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ namespace quda {

template <typename T = void *> auto data(bool inverse = false) const
{
return inverse ? reinterpret_cast<T>(cloverInv.data()) : reinterpret_cast<T>(clover.data());
return inverse ? static_cast<T>(cloverInv.data()) : static_cast<T>(clover.data());
}

/**
Expand Down
4 changes: 2 additions & 2 deletions include/communicator_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ namespace quda
inline void check_displacement(const int displacement[], int ndim)
{
for (int i = 0; i < ndim; i++) {
if (abs(displacement[i]) > max_displacement) {
if (std::abs(displacement[i]) > max_displacement) {
errorQuda("Requested displacement[%d] = %d is greater than maximum allowed", i, displacement[i]);
}
}
Expand Down Expand Up @@ -232,7 +232,7 @@ namespace quda
disable_peer_to_peer_bidir = true;
}

enable_peer_to_peer = abs(enable_peer_to_peer);
enable_peer_to_peer = std::abs(enable_peer_to_peer);

} else { // !enable_peer_to_peer_env
if (getVerbosity() > QUDA_SILENT && rank == 0)
Expand Down
156 changes: 31 additions & 125 deletions include/complex_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <cstdint>
#include <type_traits>
#include <quda_arch.h> // for double2 / float2
#include <math_helper.h>

namespace quda {
namespace gauge {
Expand All @@ -41,93 +42,10 @@ namespace quda {
// doesn't try to call the complex sqrt, but the standard sqrt
namespace quda
{
template <typename ValueType>
__host__ __device__
inline ValueType cos(ValueType x){
return std::cos(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType sin(ValueType x){
return std::sin(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType tan(ValueType x){
return std::tan(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType acos(ValueType x){
return std::acos(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType asin(ValueType x){
return std::asin(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType atan(ValueType x){
return std::atan(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType atan2(ValueType x,ValueType y){
return std::atan2(x,y);
}
template <typename ValueType>
__host__ __device__
inline ValueType cosh(ValueType x){
return std::cosh(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType sinh(ValueType x){
return std::sinh(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType tanh(ValueType x){
return std::tanh(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType exp(ValueType x){
return std::exp(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType log(ValueType x){
return std::log(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType log10(ValueType x){
return std::log10(x);
}
template <typename ValueType, typename ExponentType>
__host__ __device__
inline ValueType pow(ValueType x, ExponentType e){
return std::pow(x,static_cast<ValueType>(e));
}
template <typename ValueType>
__host__ __device__
inline ValueType sqrt(ValueType x){
return std::sqrt(x);
}
template <typename ValueType>
__host__ __device__
inline ValueType abs(ValueType x){
return std::abs(x);
}

__host__ __device__ inline float conj(float x) { return x; }
__host__ __device__ inline double conj(double x) { return x; }

template <typename ValueType> struct complex;
//template <> struct complex<float>;
//template <> struct complex<double>;


/// Returns the magnitude of z.
Expand Down Expand Up @@ -738,7 +656,7 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
__host__ __device__
inline complex<float> operator/(const complex<float>& lhs, const complex<float>& rhs){

float s = fabsf(rhs.real()) + fabsf(rhs.imag());
float s = abs(rhs.real()) + abs(rhs.imag());
float oos = 1.0f / s;
float ars = lhs.real() * oos;
float ais = lhs.imag() * oos;
Expand All @@ -754,7 +672,7 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
__host__ __device__
inline complex<double> operator/(const complex<double>& lhs, const complex<double>& rhs){

double s = fabs(rhs.real()) + fabs(rhs.imag());
double s = abs(rhs.real()) + abs(rhs.imag());
double oos = 1.0 / s;
double ars = lhs.real() * oos;
double ais = lhs.imag() * oos;
Expand Down Expand Up @@ -859,29 +777,17 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
template <typename ValueType>
__host__ __device__
inline ValueType abs(const complex<ValueType>& z){
return ::hypot(z.real(),z.imag());
}
template <>
__host__ __device__
inline float abs(const complex<float>& z){
return ::hypotf(z.real(),z.imag());
}
template<>
__host__ __device__
inline double abs(const complex<double>& z){
return ::hypot(z.real(),z.imag());
return hypot(z.real(), z.imag());
}
template <> __host__ __device__ inline float abs(const complex<float> &z) { return hypot(z.real(), z.imag()); }
template <> __host__ __device__ inline double abs(const complex<double> &z) { return hypot(z.real(), z.imag()); }

template <typename ValueType>
__host__ __device__
inline ValueType arg(const complex<ValueType>& z){
return atan2(z.imag(),z.real());
}
template<>
__host__ __device__
inline float arg(const complex<float>& z){
return atan2f(z.imag(),z.real());
}
template <> __host__ __device__ inline float arg(const complex<float> &z) { return atan2(z.imag(), z.real()); }
template<>
__host__ __device__
inline double arg(const complex<double>& z){
Expand All @@ -897,19 +803,19 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
template <typename ValueType>
__host__ __device__
inline complex<ValueType> polar(const ValueType & m, const ValueType & theta){
return complex<ValueType>(m * ::cos(theta),m * ::sin(theta));
return complex<ValueType>(m * cos(theta), m * sin(theta));
}

template <>
__host__ __device__
inline complex<float> polar(const float & magnitude, const float & angle){
return complex<float>(magnitude * ::cosf(angle),magnitude * ::sinf(angle));
return complex<float>(magnitude * cos(angle), magnitude * sin(angle));
}

template <>
__host__ __device__
inline complex<double> polar(const double & magnitude, const double & angle){
return complex<double>(magnitude * ::cos(angle),magnitude * ::sin(angle));
return complex<double>(magnitude * cos(angle), magnitude * sin(angle));
}

// Transcendental functions implementation
Expand All @@ -918,56 +824,56 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
inline complex<ValueType> cos(const complex<ValueType>& z){
const ValueType re = z.real();
const ValueType im = z.imag();
return complex<ValueType>(::cos(re) * ::cosh(im), -::sin(re) * ::sinh(im));
return complex<ValueType>(cos(re) * cosh(im), -sin(re) * sinh(im));
}

template <>
__host__ __device__
inline complex<float> cos(const complex<float>& z){
const float re = z.real();
const float im = z.imag();
return complex<float>(cosf(re) * coshf(im), -sinf(re) * sinhf(im));
return complex<float>(cos(re) * cosh(im), -sin(re) * sinh(im));
}

template <typename ValueType>
__host__ __device__
inline complex<ValueType> cosh(const complex<ValueType>& z){
const ValueType re = z.real();
const ValueType im = z.imag();
return complex<ValueType>(::cosh(re) * ::cos(im), ::sinh(re) * ::sin(im));
return complex<ValueType>(cosh(re) * cos(im), sinh(re) * sin(im));
}

template <>
__host__ __device__
inline complex<float> cosh(const complex<float>& z){
const float re = z.real();
const float im = z.imag();
return complex<float>(::coshf(re) * ::cosf(im), ::sinhf(re) * ::sinf(im));
return complex<float>(cosh(re) * cos(im), sinh(re) * sin(im));
}


template <typename ValueType>
__host__ __device__
inline complex<ValueType> exp(const complex<ValueType>& z){
return polar(::exp(z.real()),z.imag());
return polar(exp(z.real()), z.imag());
}

template <>
__host__ __device__
inline complex<float> exp(const complex<float>& z){
return polar(::expf(z.real()),z.imag());
return polar(exp(z.real()), z.imag());
}

template <typename ValueType>
__host__ __device__
inline complex<ValueType> log(const complex<ValueType>& z){
return complex<ValueType>(::log(abs(z)),arg(z));
return complex<ValueType>(log(abs(z)), arg(z));
}

template <>
__host__ __device__
inline complex<float> log(const complex<float>& z){
return complex<float>(::logf(abs(z)),arg(z));
return complex<float>(log(abs(z)), arg(z));
}


Expand All @@ -977,7 +883,7 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
// Using the explicit literal prevents compile time warnings in
// devices that don't support doubles
return log(z)/ValueType(2.30258509299404568402);
// return log(z)/ValueType(::log(10.0));
// return log(z)/ValueType(log(10.0));
}

template <typename ValueType>
Expand All @@ -995,13 +901,13 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
template <typename ValueType>
__host__ __device__
inline complex<ValueType> pow(const ValueType & x, const complex<ValueType> & exponent){
return exp(::log(x)*exponent);
return exp(log(x) * exponent);
}

template <>
__host__ __device__
inline complex<float> pow(const float & x, const complex<float> & exponent){
return exp(::logf(x)*exponent);
return exp(log(x) * exponent);
}

template <typename ValueType>
Expand All @@ -1015,43 +921,43 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());
inline complex<ValueType> sin(const complex<ValueType>& z){
const ValueType re = z.real();
const ValueType im = z.imag();
return complex<ValueType>(::sin(re) * ::cosh(im), ::cos(re) * ::sinh(im));
return complex<ValueType>(sin(re) * cosh(im), cos(re) * sinh(im));
}

template <>
__host__ __device__
inline complex<float> sin(const complex<float>& z){
const float re = z.real();
const float im = z.imag();
return complex<float>(::sinf(re) * ::coshf(im), ::cosf(re) * ::sinhf(im));
return complex<float>(sin(re) * cosh(im), cos(re) * sinh(im));
}

template <typename ValueType>
__host__ __device__
inline complex<ValueType> sinh(const complex<ValueType>& z){
const ValueType re = z.real();
const ValueType im = z.imag();
return complex<ValueType>(::sinh(re) * ::cos(im), ::cosh(re) * ::sin(im));
return complex<ValueType>(sinh(re) * cos(im), cosh(re) * sin(im));
}

template <>
__host__ __device__
inline complex<float> sinh(const complex<float>& z){
const float re = z.real();
const float im = z.imag();
return complex<float>(::sinhf(re) * ::cosf(im), ::coshf(re) * ::sinf(im));
return complex<float>(sinh(re) * cos(im), cosh(re) * sin(im));
}

template <typename ValueType>
__host__ __device__
inline complex<ValueType> sqrt(const complex<ValueType>& z){
return polar(::sqrt(abs(z)),arg(z)/ValueType(2));
return polar(sqrt(abs(z)), arg(z) / ValueType(2));
}

template <>
__host__ __device__
inline complex<float> sqrt(const complex<float>& z){
return polar(::sqrtf(abs(z)),arg(z)/float(2));
return polar(sqrt(abs(z)), arg(z) / float(2));
}

template <typename ValueType>
Expand Down Expand Up @@ -1131,11 +1037,11 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());

ValueType d = ValueType(1.0) - z.real();
d = imag2 + d * d;
complex<ValueType> ret(ValueType(0.25) * (::log(n) - ::log(d)),0);
complex<ValueType> ret(ValueType(0.25) * (log(n) - log(d)), 0);

d = ValueType(1.0) - z.real() * z.real() - imag2;

ret.imag(ValueType(0.5) * ::atan2(ValueType(2.0) * z.imag(), d));
ret.imag(ValueType(0.5) * atan2(ValueType(2.0) * z.imag(), d));
return ret;
//return (log(ValueType(1)+z)-log(ValueType(1)-z))/ValueType(2);
}
Expand All @@ -1149,11 +1055,11 @@ lhs.real()*rhs.imag()+lhs.imag()*rhs.real());

float d = float(1.0) - z.real();
d = imag2 + d * d;
complex<float> ret(float(0.25) * (::logf(n) - ::logf(d)),0);
complex<float> ret(float(0.25) * (log(n) - log(d)), 0);

d = float(1.0) - z.real() * z.real() - imag2;

ret.imag(float(0.5) * ::atan2f(float(2.0) * z.imag(), d));
ret.imag(float(0.5) * atan2(float(2.0) * z.imag(), d));
return ret;
//return (log(ValueType(1)+z)-log(ValueType(1)-z))/ValueType(2);

Expand Down
1 change: 1 addition & 0 deletions include/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <type_traits>
#include <target_device.h>
#include <register_traits.h>
#include <math_helper.cuh>

namespace quda
{
Expand Down
4 changes: 2 additions & 2 deletions include/device_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace quda
{
size_t bytes = _size * sizeof(real);
if (bytes > 0) {
_device_data = reinterpret_cast<real *>(pool_device_malloc(bytes));
_device_data = static_cast<real *>(pool_device_malloc(bytes));
qudaMemcpy(_device_data, host_vector.data(), bytes, qudaMemcpyHostToDevice);
}
}
Expand Down Expand Up @@ -62,7 +62,7 @@ namespace quda
_size = size_;
size_t bytes = _size * sizeof(real);
if (bytes > 0) {
_device_data = reinterpret_cast<real *>(pool_device_malloc(bytes));
_device_data = static_cast<real *>(pool_device_malloc(bytes));
qudaMemset(_device_data, 0, bytes);
}
}
Expand Down
Loading

0 comments on commit 40df91b

Please sign in to comment.