Skip to content

Commit

Permalink
All sorts of cleanup, moved various is_*_[solve/solution/etc] routine…
Browse files Browse the repository at this point in the history
…s into host utils
  • Loading branch information
weinbe2 committed Dec 7, 2023
1 parent 30eb5ee commit e161cdc
Show file tree
Hide file tree
Showing 13 changed files with 170 additions and 190 deletions.
11 changes: 0 additions & 11 deletions tests/eigensolve_test_gtest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,6 @@ class EigensolveTest : public ::testing::TestWithParam<test_t>
EigensolveTest() : param(GetParam()) { }
};

bool is_chiral(QudaDslashType type)
{
switch (type) {
case QUDA_DOMAIN_WALL_DSLASH:
case QUDA_DOMAIN_WALL_4D_DSLASH:
case QUDA_MOBIUS_DWF_DSLASH:
case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true;
default: return false;
}
}

bool skip_test(test_t param)
{
// dwf-style solves must use a normal solver
Expand Down
2 changes: 1 addition & 1 deletion tests/host_reference/dslash_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ std::array<double, 2> verifyStaggeredInversion(quda::ColorSpinorField &tmp, quda
stag_mat(ref, fat_link, long_link, out, mass, dagger, dslash_type);

// correct for the massRescale function inside invertQuda
if (dslash_type == QUDA_LAPLACE_DSLASH)
if (is_laplace(dslash_type))
ax(0.5 / kappa, ref.data(), ref.Length(), ref.Precision());
} else if (inv_param.solution_type == QUDA_MATPC_SOLUTION) {
QudaParity parity = QUDA_INVALID_PARITY;
Expand Down
72 changes: 2 additions & 70 deletions tests/invert_test_gtest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,79 +16,11 @@ class InvertTest : public ::testing::TestWithParam<test_t>
InvertTest() : param(GetParam()) { }
};

bool is_normal_residual(QudaInverterType type)
{
switch (type) {
case QUDA_CGNR_INVERTER:
case QUDA_CA_CGNR_INVERTER: return true;
default: return false;
}
}

bool is_preconditioned_solve(QudaSolveType type)
{
switch (type) {
case QUDA_DIRECT_PC_SOLVE:
case QUDA_NORMOP_PC_SOLVE: return true;
default: return false;
}
}

bool is_full_solution(QudaSolutionType type)
{
switch (type) {
case QUDA_MAT_SOLUTION:
case QUDA_MATDAG_MAT_SOLUTION: return true;
default: return false;
}
}

bool is_normal_solve(test_t param)
{
auto inv_type = ::testing::get<0>(param);
auto solve_type = ::testing::get<2>(param);

switch (solve_type) {
case QUDA_NORMOP_SOLVE:
case QUDA_NORMOP_PC_SOLVE: return true;
default:
switch (inv_type) {
case QUDA_CGNR_INVERTER:
case QUDA_CGNE_INVERTER:
case QUDA_CA_CGNR_INVERTER:
case QUDA_CA_CGNE_INVERTER: return true;
default: return false;
}
}
}

bool is_chiral(QudaDslashType type)
{
switch (type) {
case QUDA_DOMAIN_WALL_DSLASH:
case QUDA_DOMAIN_WALL_4D_DSLASH:
case QUDA_MOBIUS_DWF_DSLASH:
case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true;
default: return false;
}
}

bool support_solution_accumulator_pipeline(QudaInverterType type)
{
switch (type) {
case QUDA_CG_INVERTER:
case QUDA_CA_CG_INVERTER:
case QUDA_CGNR_INVERTER:
case QUDA_CGNE_INVERTER:
case QUDA_PCG_INVERTER: return true;
default: return false;
}
}

bool skip_test(test_t param)
{
auto inverter_type = ::testing::get<0>(param);
auto solution_type = ::testing::get<1>(param);
auto solve_type = ::testing::get<2>(param);
auto prec_sloppy = ::testing::get<3>(param);
auto multishift = ::testing::get<4>(param);
auto solution_accumulator_pipeline = ::testing::get<5>(param);
Expand All @@ -102,7 +34,7 @@ bool skip_test(test_t param)
if (prec_sloppy < prec_precondition) return true; // sloppy precision >= preconditioner precision

// dwf-style solves must use a normal solver
if (is_chiral(dslash_type) && !is_normal_solve(param)) return true;
if (is_chiral(dslash_type) && !is_normal_solve(inverter_type, solve_type)) return true;
// FIXME this needs to be added to dslash_reference.cpp
if (is_chiral(dslash_type) && multishift > 1) return true;
// FIXME this needs to be added to dslash_reference.cpp
Expand Down
10 changes: 5 additions & 5 deletions tests/staggered_dslash_ctest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class StaggeredDslashTest : public ::testing::TestWithParam<::testing::tuple<int
|| (QUDA_RECONSTRUCT & getReconstructNibble(recon)) == 0)
return true;

if (dslash_type == QUDA_LAPLACE_DSLASH && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1))
if (is_laplace(dslash_type) && (::testing::get<0>(GetParam()) == 0 || ::testing::get<0>(GetParam()) == 1))
return true;

const std::array<bool, 16> partition_enabled {true, true, true, false, true, false, false, false,
Expand Down Expand Up @@ -123,12 +123,12 @@ int main(int argc, char **argv)

// Only these fermions are supported in this file
if (is_laplace_enabled) {
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH)
if (!is_staggered(dslash_type) && !is_laplace(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
} else {
if (dslash_type == QUDA_LAPLACE_DSLASH)
if (is_laplace(dslash_type))
errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON");
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
}

Expand All @@ -146,7 +146,7 @@ int main(int argc, char **argv)
eps_naik = 0.0; // to avoid potential headaches
}

if (dslash_type == QUDA_LAPLACE_DSLASH && dtest_type != dslash_test_type::Mat)
if (is_laplace(dslash_type) && dtest_type != dslash_test_type::Mat)
errorQuda("Test type %s is not supported for the Laplace operator", get_string(dtest_type_map, dtest_type).c_str());

int test_rc = RUN_ALL_TESTS();
Expand Down
8 changes: 4 additions & 4 deletions tests/staggered_dslash_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ int main(int argc, char **argv)

// Only these fermions are supported in this file
if (is_laplace_enabled) {
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH)
if (!is_staggered(dslash_type) && !is_laplace(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
} else {
if (dslash_type == QUDA_LAPLACE_DSLASH)
if (is_laplace(dslash_type))
errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON");
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
}

Expand All @@ -109,7 +109,7 @@ int main(int argc, char **argv)
eps_naik = 0.0; // to avoid potential headaches
}

if (dslash_type == QUDA_LAPLACE_DSLASH && dtest_type != dslash_test_type::Mat)
if (is_laplace(dslash_type) && dtest_type != dslash_test_type::Mat)
errorQuda("Test type %s is not supported for the Laplace operator", get_string(dtest_type_map, dtest_type).c_str());

int test_rc = RUN_ALL_TESTS();
Expand Down
8 changes: 5 additions & 3 deletions tests/staggered_dslash_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ struct StaggeredDslashTestWrapper {
cpuLong = GaugeField(cpuLongParam);

// Override link reconstruct as appropriate for staggered or asqtad
if (dslash_type == QUDA_STAGGERED_DSLASH || dslash_type == QUDA_ASQTAD_DSLASH) {
if (is_staggered(dslash_type)) {
if (link_recon == QUDA_RECONSTRUCT_12) link_recon = QUDA_RECONSTRUCT_13;
if (link_recon == QUDA_RECONSTRUCT_8) link_recon = QUDA_RECONSTRUCT_9;
}
Expand Down Expand Up @@ -342,19 +342,21 @@ struct StaggeredDslashTestWrapper {

host_timer.start();

if (dslash_type == QUDA_LAPLACE_DSLASH) {
if (is_laplace(dslash_type)) {
switch (dtest_type) {
case dslash_test_type::Mat: dirac->M(cudaSpinorOut, cudaSpinor); break;
default: errorQuda("Test type %d not defined on Laplace operator", static_cast<int>(dtest_type));
}
} else {
} else if (is_staggered(dslash_type)) {
switch (dtest_type) {
case dslash_test_type::Dslash: dirac->Dslash(cudaSpinorOut, cudaSpinor, parity); break;
case dslash_test_type::MatPC: dirac->M(cudaSpinorOut, cudaSpinor); break;
case dslash_test_type::Mat: dirac->M(cudaSpinorOut, cudaSpinor); break;
case dslash_test_type::MatDagMat: dirac->MdagM(cudaSpinorOut, cudaSpinor); break;
default: errorQuda("Test type %d not defined on staggered dslash", static_cast<int>(dtest_type));
}
} else {
errorQuda("Invalid dslash type %d", dslash_type);
}

host_timer.stop();
Expand Down
6 changes: 3 additions & 3 deletions tests/staggered_eigensolve_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,12 @@ int main(int argc, char **argv)

// Only these fermions are supported in this file
if (is_laplace_enabled) {
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH)
if (!is_staggered(dslash_type) && !is_laplace(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
} else {
if (dslash_type == QUDA_LAPLACE_DSLASH)
if (is_laplace(dslash_type))
errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON");
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
}

Expand Down
4 changes: 2 additions & 2 deletions tests/staggered_eigensolve_test_gtest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ bool skip_test(test_t test_param)
// matpc

// this is only legal for the staggered and asqtad op
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
return true;

// we can only compute the real part for Lanczos, and real or magnitude for Arnoldi
Expand All @@ -55,7 +55,7 @@ bool skip_test(test_t test_param)
// matdag_mat

// this is only legal for the staggered and asqtad op
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
return true;

switch (eig_type) {
Expand Down
8 changes: 4 additions & 4 deletions tests/staggered_invert_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -466,18 +466,18 @@ int main(int argc, char **argv)

// Only these fermions are supported in this file
if (is_laplace_enabled) {
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH && dslash_type != QUDA_LAPLACE_DSLASH)
if (!is_staggered(dslash_type) && !is_laplace(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
} else {
if (dslash_type == QUDA_LAPLACE_DSLASH)
if (is_laplace(dslash_type))
errorQuda("The Laplace dslash is not enabled, cmake configure with -DQUDA_LAPLACE=ON");
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH)
if (!is_staggered(dslash_type))
errorQuda("dslash_type %s not supported", get_dslash_str(dslash_type));
}

// Need to add support for LAPLACE MG?
if (inv_multigrid) {
if (dslash_type != QUDA_STAGGERED_DSLASH && dslash_type != QUDA_ASQTAD_DSLASH) {
if (!is_staggered(dslash_type)) {
errorQuda("dslash_type %s not supported for multigrid preconditioner\n", get_dslash_str(dslash_type));
}
}
Expand Down
Loading

0 comments on commit e161cdc

Please sign in to comment.