Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into feature/prolongato…
Browse files Browse the repository at this point in the history
…r-mma
  • Loading branch information
hummingtree committed Dec 5, 2024
2 parents be3d180 + a54595d commit d94156a
Show file tree
Hide file tree
Showing 95 changed files with 2,793 additions and 805 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ include/jitify_options.hpp
.tags*
autom4te.cache/*
.vscode
cmake/CPM_*.cmake
19 changes: 9 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ if(QUDA_MAX_MULTI_BLAS_N GREATER 32)
message(SEND_ERROR "Maximum QUDA_MAX_MULTI_BLAS_N is 32.")
endif()

# For now we only support register tiles for the staggered dslash operators
set(QUDA_MAX_MULTI_RHS_TILE "1" CACHE STRING "maximum tile size for MRHS kernels (staggered only)")
# Reject configurations where the tile size exceeds QUDA_MAX_MULTI_RHS
if(QUDA_MAX_MULTI_RHS_TILE GREATER QUDA_MAX_MULTI_RHS)
message(SEND_ERROR "QUDA_MAX_MULTI_RHS_TILE is greater than QUDA_MAX_MULTI_RHS")
endif()

set(QUDA_PRECISION
"14"
CACHE STRING "which precisions to instantiate in QUDA (4-bit number - double, single, half, quarter)")
Expand Down Expand Up @@ -275,6 +281,7 @@ mark_as_advanced(QUDA_ALTERNATIVE_I_TO_F)

mark_as_advanced(QUDA_MAX_MULTI_BLAS_N)
mark_as_advanced(QUDA_MAX_MULTI_RHS)
mark_as_advanced(QUDA_MAX_MULTI_RHS_TILE)
mark_as_advanced(QUDA_PRECISION)
mark_as_advanced(QUDA_RECONSTRUCT)
mark_as_advanced(QUDA_CLOVER_CHOLESKY_PROMOTE)
Expand Down Expand Up @@ -420,20 +427,12 @@ if(QUDA_DOWNLOAD_EIGEN)
CPMAddPackage(
NAME Eigen
VERSION ${QUDA_EIGEN_VERSION}
URL https://gitlab.com/libeigen/eigen/-/archive/${QUDA_EIGEN_VERSION}/eigen-${QUDA_EIGEN_VERSION}.tar.bz2
URL https://gitlab.com/libeigen/eigen/-/archive/e67c494cba7180066e73b9f6234d0b2129f1cdf5.tar.bz2
URL_HASH SHA256=98d244932291506b75c4ae7459af29b1112ea3d2f04660686a925d9ef6634583
DOWNLOAD_ONLY YES
SYSTEM YES)
target_include_directories(Eigen SYSTEM INTERFACE ${Eigen_SOURCE_DIR})
install(DIRECTORY ${Eigen_SOURCE_DIR}/Eigen TYPE INCLUDE)

# Eigen 3.4 needs to be patched on Neon with nvc++
if (${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
set(CMAKE_PATCH_EIGEN OFF CACHE BOOL "Internal use only; do not modify")
if (NOT CMAKE_PATCH_EIGEN)
execute_process(COMMAND patch -N "${Eigen_SOURCE_DIR}/Eigen/src/Core/arch/NEON/Complex.h" "${CMAKE_SOURCE_DIR}/cmake/eigen34_neon.diff")
set(CMAKE_PATCH_EIGEN ON CACHE BOOL "Internal use only; do not modify" FORCE)
endif()
endif()
else()
# fall back to using find_package
find_package(Eigen QUIET)
Expand Down
8 changes: 0 additions & 8 deletions cmake/eigen34_neon.diff

This file was deleted.

80 changes: 80 additions & 0 deletions include/blas_3d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#pragma once

#include <quda_internal.h>
#include <color_spinor_field.h>

namespace quda
{

namespace blas3d
{

// Local enum selecting the kind of timeslice copy performed by copy():
// extract to the 3-d field, insert from the 3-d field, or swap the two
enum class copyType { COPY_TO_3D, COPY_FROM_3D, SWAP_3D };

/**
@brief Extract / insert / swap a timeslice between a 4-d field and a 3-d field
@param[in] slice Which slice
@param[in] type Extracting a time slice (COPY_TO_3D) or
inserting a timeslice (COPY_FROM_3D) or swapping time slices
(SWAP_3D)
@param[in,out] x 3-d field
@param[in,out] y 4-d field
*/
void copy(int slice, copyType type, ColorSpinorField &x, ColorSpinorField &y);

/**
@brief Swap the slice in two given fields
@param[in] slice The slice we wish to swap in the fields
@param[in,out] x Field whose slice we wish to swap
@param[in,out] y Field whose slice we wish to swap
*/
void swap(int slice, ColorSpinorField &x, ColorSpinorField &y);

/**
@brief Compute a set of real-valued inner products <x, y>, where each inner
product is restricted to a timeslice.
@param[out] result Vector of spatial inner products
@param[in] x Left vector field
@param[in] y Right vector field
*/
void reDotProduct(std::vector<double> &result, const ColorSpinorField &x, const ColorSpinorField &y);

/**
@brief Compute a set of complex-valued inner products <a, b>, where each inner
product is restricted to a timeslice.
@param[out] result Vector of spatial inner products
@param[in] a Left vector field
@param[in] b Right vector field
*/
void cDotProduct(std::vector<Complex> &result, const ColorSpinorField &a, const ColorSpinorField &b);

/**
@brief Timeslice real-valued scaling of the field
@param[in] a Vector of scale factors (length = local temporal extent)
@param[in,out] x Field we wish to scale
NOTE(review): a is documented as an input but is taken by non-const
reference, unlike axpby / caxpby below - confirm whether it is modified.
*/
void ax(std::vector<double> &a, ColorSpinorField &x);

/**
@brief Timeslice real-valued axpby computation
@param[in] a Vector of scale factors (length = local temporal extent)
@param[in] x Input field
@param[in] b Vector of scale factors (length = local temporal extent)
@param[in,out] y Field we are updating
*/
void axpby(const std::vector<double> &a, const ColorSpinorField &x, const std::vector<double> &b,
ColorSpinorField &y);

/**
@brief Timeslice complex-valued axpby computation
@param[in] a Vector of scale factors (length = local temporal extent)
@param[in] x Input field
@param[in] b Vector of scale factors (length = local temporal extent)
@param[in,out] y Field we are updating
*/
void caxpby(const std::vector<Complex> &a, const ColorSpinorField &x, const std::vector<Complex> &b,
ColorSpinorField &y);

} // namespace blas3d
} // namespace quda
34 changes: 34 additions & 0 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,8 @@ namespace quda
//! for deflation etc.
if (is_composite) printfQuda("Number of elements = %d\n", composite_dim);
}

/**
   @brief Overwrite a single entry of the x array
   (presumably the per-dimension extents - confirm against the enclosing struct)
   @param[in] D Index of the entry to overwrite
   @param[in] d New value stored at x[D]
*/
void change_dim(int D, int d)
{
  x[D] = d;
}
};

struct DslashConstant;
Expand Down Expand Up @@ -1054,6 +1056,38 @@ namespace quda
*/
void spinorDistanceReweight(ColorSpinorField &src, double alpha0, int t0);

/**
   @brief Verify that two fields have the same number of spin components.
   @param[in] func Name of the calling function, reported on mismatch
   @param[in] file Source file of the call site, reported on mismatch
   @param[in] line Line number of the call site, reported on mismatch
   @param[in] a First field
   @param[in] b Second field
   @return a.Nspin() when both fields agree; on mismatch errorQuda is
   invoked and 0 is returned should it hand control back
*/
inline int Spin_(const char *func, const char *file, int line, const ColorSpinorField &a, const ColorSpinorField &b)
{
  // Guard clause: report the mismatch first, the common path falls through
  if (a.Nspin() != b.Nspin()) {
    errorQuda("Spin %d %d do not match (%s:%d in %s())", a.Nspin(), b.Nspin(), file, line, func);
    return 0;
  }
  return a.Nspin();
}

/**
   @brief Verify that an arbitrary number of fields all have the same
   number of spin components as the first field.
   @param[in] func Name of the calling function, forwarded to the pairwise check
   @param[in] file Source file of the call site, forwarded to the pairwise check
   @param[in] line Line number of the call site, forwarded to the pairwise check
   @param[in] a Reference field every other field is compared against
   @param[in] b Second field
   @param[in] args Remaining fields to compare against a
   @return Bitwise AND of the pairwise results, which equals the common
   spin count when every pairwise check agrees
*/
template <typename... Args>
inline int Spin_(const char *func, const char *file, int line, const ColorSpinorField &a, const ColorSpinorField &b,
                 const Args &...args)
{
  // Compare a against b, then recurse to compare a against the rest of the pack
  const int first_pair = Spin_(func, file, line, a, b);
  const int remainder = Spin_(func, file, line, a, args...);
  return first_pair & remainder;
}

// Convenience wrapper around Spin_ that supplies the call site's function
// name, file and line so a mismatch is reported at the caller's location
#define checkSpin(...) Spin_(__func__, __FILE__, __LINE__, __VA_ARGS__)

/**
@brief Helper function for determining if the preconditioning
type of the fields is the same.
Expand Down
9 changes: 6 additions & 3 deletions include/color_spinor_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,16 +971,19 @@ namespace quda
norm_t *norm = nullptr;
int norm_offset = 0;
if constexpr (fixed) {
if constexpr (block_float) {
if constexpr (fixed && block_float && nColor == 3 && nSpin == 1 && nVec == 1) {
norm = v.norm;
norm_offset = parity * v.norm_offset + 4 * x_cb + 3;
} else if constexpr (block_float) {
norm = v.norm;
norm_offset = v.norm_offset;
norm_offset = parity * v.norm_offset + x_cb;
} else {
scale = v.scale;
scale_inv = v.scale_inv;
}
}
return fieldorder_wrapper<Float, storeFloat, block_float, norm_t>(
v.v, accessor.index(parity, x_cb, s, c, n, volumeCB), scale, scale_inv, norm, parity * norm_offset + x_cb);
v.v, accessor.index(parity, x_cb, s, c, n, volumeCB), scale, scale_inv, norm, norm_offset);
}

/** Returns the number of field colors */
Expand Down
22 changes: 14 additions & 8 deletions include/comm_key.h
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
#pragma once

#include <array.h>

namespace quda
{

struct CommKey {

static constexpr int n_dim = 4;
array<int, n_dim> key = {0, 0, 0, 0};

int array[n_dim] = {0, 0, 0, 0};
constexpr int product() { return key[0] * key[1] * key[2] * key[3]; }

constexpr inline int product() { return array[0] * array[1] * array[2] * array[3]; }
constexpr int &operator[](int d) { return key[d]; }

constexpr inline int &operator[](int d) { return array[d]; }
constexpr const int &operator[](int d) const { return key[d]; }

constexpr inline const int &operator[](int d) const { return array[d]; }
constexpr auto data() { return key.data; }

constexpr inline int *data() { return array; }
constexpr auto data() const { return key.data; }

constexpr inline const int *data() const { return array; }
constexpr bool is_valid() const { return (key[0] > 0) && (key[1] > 0) && (key[2] > 0) && (key[3] > 0); }

constexpr inline bool is_valid() const
bool operator==(const CommKey &other) const
{
return (array[0] > 0) && (array[1] > 0) && (array[2] > 0) && (array[3] > 0);
bool is_same = true;
for (auto i = 0; i < n_dim; i++)
if (key[i] != other.key[i]) is_same = false;
return is_same;
}
};

Expand Down
16 changes: 15 additions & 1 deletion include/comm_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,26 @@ namespace quda
int comm_dim(int dim);

/**
Return the coording of this process in the dimension dim
Return the global number of processes in the dimension dim
@param dim Dimension which we are querying
@return Length of process dimensions
*/
int comm_dim_global(int dim);

/**
Return the coordinate of this process in the dimension dim
@param dim Dimension which we are querying
@return Coordinate of this process
*/
int comm_coord(int dim);

/**
Return the global coordinate of this process in the dimension dim
@param dim Dimension which we are querying
@return Coordinate of this process
*/
int comm_coord_global(int dim);

/**
* Declare a message handle for sending `nbytes` to the `rank` with `tag`.
*/
Expand Down
38 changes: 16 additions & 22 deletions include/complex_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,23 +360,19 @@ struct complex
typedef ValueType value_type;

// Constructors
__host__ __device__ inline complex<ValueType>(const ValueType &re = ValueType(), const ValueType &im = ValueType())
__host__ __device__ inline complex(const ValueType &re = ValueType(), const ValueType &im = ValueType())
{
real(re);
imag(im);
}

template <class X>
__host__ __device__
inline complex<ValueType>(const complex<X> & z)
template <class X> __host__ __device__ inline complex(const complex<X> &z)
{
real(z.real());
imag(z.imag());
}

template <class X>
__host__ __device__
inline complex<ValueType>(const std::complex<X> & z)
template <class X> __host__ __device__ inline complex(const std::complex<X> &z)
{
real(z.real());
imag(z.imag());
Expand Down Expand Up @@ -436,12 +432,11 @@ struct complex
template <> struct complex<float> : public float2 {
public:
typedef float value_type;
complex<float>() = default;
constexpr complex<float>(const float &re, const float &im = float()) : float2 {re, im} { }
complex() = default;
constexpr complex(const float &re, const float &im = float()) : float2 {re, im} { }

template <typename X>
constexpr complex<float>(const std::complex<X> &z) :
float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
constexpr complex(const std::complex<X> &z) : float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
{
}

Expand Down Expand Up @@ -500,16 +495,15 @@ template <> struct complex<float> : public float2 {
template <> struct complex<double> : public double2 {
public:
typedef double value_type;
complex<double>() = default;
constexpr complex<double>(const double &re, const double &im = double()) : double2 {re, im} { }
complex() = default;
constexpr complex(const double &re, const double &im = double()) : double2 {re, im} { }

template <typename X>
constexpr complex<double>(const std::complex<X> &z) :
double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
constexpr complex(const std::complex<X> &z) : double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
{
}

template <typename T> __host__ __device__ inline complex<double> &operator=(const complex<T> &z)
template <typename T> __host__ __device__ inline complex &operator=(const complex<T> &z)
{
real(z.real());
imag(z.imag());
Expand Down Expand Up @@ -572,9 +566,9 @@ template <> struct complex<int8_t> : public char2 {
public:
typedef int8_t value_type;

complex<int8_t>() = default;
complex() = default;

constexpr complex<int8_t>(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }
constexpr complex(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }

__host__ __device__ inline complex<int8_t> &operator+=(const complex<int8_t> &z)
{
Expand Down Expand Up @@ -608,9 +602,9 @@ struct complex <short> : public short2
public:
typedef short value_type;

complex<short>() = default;
complex() = default;

constexpr complex<short>(const short &re, const short &im = short()) : short2 {re, im} { }
constexpr complex(const short &re, const short &im = short()) : short2 {re, im} { }

__host__ __device__ inline complex<short> &operator+=(const complex<short> &z)
{
Expand Down Expand Up @@ -644,9 +638,9 @@ struct complex <int> : public int2
public:
typedef int value_type;

complex<int>() = default;
complex() = default;

constexpr complex<int>(const int &re, const int &im = int()) : int2 {re, im} { }
constexpr complex(const int &re, const int &im = int()) : int2 {re, im} { }

__host__ __device__ inline complex<int> &operator+=(const complex<int> &z)
{
Expand Down
Loading

0 comments on commit d94156a

Please sign in to comment.