diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index d64e737306..365e629a5d 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3838,7 +3838,6 @@ int computeGaugeForceQuda(void* mom, void* siteLink, int*** input_path_buf, int gParamMom.create = qudaGaugeParam->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE; gParamMom.field = &cpuMom; gParamMom.reconstruct = QUDA_RECONSTRUCT_10; - gParamMom.link_type = QUDA_ASQTAD_MOM_LINKS; gParamMom.setPrecision(qudaGaugeParam->cuda_prec, true); GaugeField cudaMom = qudaGaugeParam->use_resident_mom ? momResident.create_alias() : GaugeField(gParamMom); @@ -4498,6 +4497,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double checkGaugeParam(gauge_param); if (!gaugePrecise) errorQuda("No resident gauge field"); + if (!cloverPrecise) errorQuda("No resident clover field"); GaugeFieldParam fParam(*gauge_param, h_mom, QUDA_ASQTAD_MOM_LINKS); // create the host momentum field @@ -4506,13 +4506,15 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double // create the device momentum field fParam.location = QUDA_CUDA_FIELD_LOCATION; - fParam.create = QUDA_COPY_FIELD_CREATE; + fParam.create = gauge_param->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE; fParam.field = &cpuMom; + fParam.reconstruct = QUDA_RECONSTRUCT_10; fParam.setPrecision(gauge_param->cuda_prec, true); if (gauge_param->use_resident_mom && !momResident.Length()) errorQuda("No resident momentum field to use"); GaugeField cudaMom = gauge_param->use_resident_mom ? momResident.create_alias() : GaugeField(fParam); - + if (gauge_param->use_resident_mom && gauge_param->overwrite_mom) cudaMom.zero(); + // create the device force field fParam.link_type = QUDA_GENERAL_LINKS; fParam.create = QUDA_ZERO_FIELD_CREATE; @@ -4610,27 +4612,29 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef using namespace quda; auto profile = pushProfile(profileTMCloverForce, inv_param->secs, inv_param->gflops); + checkGaugeParam(gauge_param); + if (!gaugePrecise) errorQuda("No resident gauge field"); + if (!cloverPrecise) errorQuda("No resident clover field"); + if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(inv_param); if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaGaugeParam(gauge_param); double kappa = inv_param->kappa; double k_csw_ov_8 = kappa * inv_param->clover_csw / 8.0; - checkGaugeParam(gauge_param); - if (!gaugePrecise) errorQuda("No resident gauge field"); - if (!cloverPrecise) errorQuda("No resident clover field"); - GaugeFieldParam gParamMom(*gauge_param, h_mom, QUDA_ASQTAD_MOM_LINKS); - GaugeField cpuMom(gParamMom); + GaugeField cpuMom = !gauge_param->use_resident_mom ? GaugeField(gParamMom) : GaugeField(); //create the device momentum field gParamMom.location = QUDA_CUDA_FIELD_LOCATION; - gParamMom.create = QUDA_COPY_FIELD_CREATE; + gParamMom.create = gauge_param->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE; gParamMom.field = &cpuMom; + gParamMom.reconstruct = QUDA_RECONSTRUCT_10; gParamMom.setPrecision(gauge_param->cuda_prec, true); if (gauge_param->use_resident_mom && !momResident.Length()) errorQuda("No resident momentum field to use"); GaugeField gpuMom = gauge_param->use_resident_mom ? momResident.create_alias() : GaugeField(gParamMom); + if (gauge_param->use_resident_mom && gauge_param->overwrite_mom) gpuMom.zero(); // create the device force field gParamMom.link_type = QUDA_GENERAL_LINKS; diff --git a/tests/TMCloverForce_test.cpp b/tests/TMCloverForce_test.cpp index 5b2774d62f..e793d51901 100644 --- a/tests/TMCloverForce_test.cpp +++ b/tests/TMCloverForce_test.cpp @@ -78,10 +78,10 @@ std::tuple clover_force_test(test_t param) bool detratio = ::testing::get<0>(param); int nvector = ::testing::get<1>(param); - std::vector out_nvector(nvector * Nsrc); - std::vector> in(Nsrc, std::vector(nvector)); - std::vector out_nvector0(nvector * Nsrc); - std::vector> in0(Nsrc, std::vector(nvector)); + std::vector out_nvector(nvector); + std::vector in(nvector); + std::vector out_nvector0(nvector); + std::vector in0(nvector); quda::ColorSpinorParam cs_param; constructWilsonTestSpinorParam(&cs_param, &inv_param, &gauge_param); @@ -91,15 +91,13 @@ std::tuple clover_force_test(test_t param) inv_param.num_offset = nvector; for (int i = 0; i < nvector; i++) { // Allocate memory and set pointers - for (int n = 0; n < Nsrc; n++) { - out_nvector[n * nvector + i] = quda::ColorSpinorField(cs_param); - spinorNoise(out_nvector[n * nvector + i], rng, QUDA_NOISE_GAUSS); - in[n][i] = out_nvector[n * nvector + i].data(); - - out_nvector0[n * nvector + i] = quda::ColorSpinorField(cs_param); - spinorNoise(out_nvector0[n * nvector + i], rng, QUDA_NOISE_GAUSS); - in0[n][i] = out_nvector0[n * nvector + i].data(); - } + out_nvector[i] = quda::ColorSpinorField(cs_param); + spinorNoise(out_nvector[i], rng, QUDA_NOISE_GAUSS); + in[i] = out_nvector[i].data(); + + out_nvector0[i] = quda::ColorSpinorField(cs_param); + spinorNoise(out_nvector0[i], rng, QUDA_NOISE_GAUSS); + in0[i] = out_nvector0[i].data(); } std::vector coeff(nvector); @@ -108,16 +106,16 @@ std::tuple clover_force_test(test_t param) coeff[i] += coeff[i] * (i + 1) / 10.0; } gauge_param.gauge_order = QUDA_MILC_GAUGE_ORDER; + gauge_param.overwrite_mom = 1; if (getTuning() == QUDA_TUNE_YES) - computeTMCloverForceQuda(mom.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, &gauge_param, &inv_param, + computeTMCloverForceQuda(mom.data(), in.data(), in0.data(), coeff.data(), nvector, &gauge_param, &inv_param, detratio); // Multiple execution to exclude warmup time in the first run double time_sec = 0.0; double gflops = 0.0; for (int i = 0; i < niter; i++) { - mom = mom_ref; // restore initial momentum for correctness - computeTMCloverForceQuda(mom.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, &gauge_param, &inv_param, + computeTMCloverForceQuda(mom.data(), in.data(), in0.data(), coeff.data(), nvector, &gauge_param, &inv_param, detratio); time_sec += inv_param.secs; gflops += inv_param.gflops; @@ -127,7 +125,8 @@ std::tuple clover_force_test(test_t param) std::array u = {gauge.data(0), gauge.data(1), gauge.data(2), gauge.data(3)}; if (verify_results) { gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER; - TMCloverForce_reference(mom_ref.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, u, clover, clover_inv, + mom_ref.zero(); + TMCloverForce_reference(mom_ref.data(), in.data(), in0.data(), coeff.data(), nvector, u, clover, clover_inv, &gauge_param, &inv_param, detratio); *check_out = compare_floats(mom.data(), mom_ref.data(), 4 * V * mom_site_size, getTolerance(cuda_prec), gauge_param.cpu_prec); @@ -218,7 +217,7 @@ int main(int argc, char **argv) test_rc = RUN_ALL_TESTS(); } else { - clover_force_test({detratio, nvector}); + clover_force_test({detratio, Nsrc}); } destroy(); diff --git a/tests/utils/command_line_params.cpp b/tests/utils/command_line_params.cpp index a483407528..ee70377dab 100644 --- a/tests/utils/command_line_params.cpp +++ b/tests/utils/command_line_params.cpp @@ -302,8 +302,8 @@ bool smear_delete_two_link = true; bool enable_testing = false; -int nvector = 1; bool detratio = false; + namespace { CLI::TransformPairs ca_basis_map {{"power", QUDA_POWER_BASIS}, {"chebyshev", QUDA_CHEBYSHEV_BASIS}}; @@ -1122,7 +1122,5 @@ void add_quark_smear_option_group(std::shared_ptr quda_app) void add_clover_force_option_group(std::shared_ptr quda_app) { auto opgroup = quda_app->add_option_group("Clover force", "Options controlling clover force testing"); - opgroup->add_option("--nvector", nvector, "Test multiple quark fields. Default is 1"); opgroup->add_option("--detratio", detratio, "Test a ratio of determinants. Default is false"); - } diff --git a/tests/utils/command_line_params.h b/tests/utils/command_line_params.h index 749c066555..fb52e5764d 100644 --- a/tests/utils/command_line_params.h +++ b/tests/utils/command_line_params.h @@ -430,5 +430,4 @@ extern std::array grid_partition; extern bool enable_testing; -extern int nvector; -extern bool detratio; \ No newline at end of file +extern bool detratio;