Skip to content

Commit

Permalink
Only download host mom field if we need to for clover force. Cleanup of clover force test code
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Dec 8, 2023
1 parent ddd0fa9 commit ad814bb
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 32 deletions.
22 changes: 13 additions & 9 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3838,7 +3838,6 @@ int computeGaugeForceQuda(void* mom, void* siteLink, int*** input_path_buf, int
gParamMom.create = qudaGaugeParam->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE;
gParamMom.field = &cpuMom;
gParamMom.reconstruct = QUDA_RECONSTRUCT_10;
gParamMom.link_type = QUDA_ASQTAD_MOM_LINKS;
gParamMom.setPrecision(qudaGaugeParam->cuda_prec, true);

GaugeField cudaMom = qudaGaugeParam->use_resident_mom ? momResident.create_alias() : GaugeField(gParamMom);
Expand Down Expand Up @@ -4498,6 +4497,7 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double

checkGaugeParam(gauge_param);
if (!gaugePrecise) errorQuda("No resident gauge field");
if (!cloverPrecise) errorQuda("No resident clover field");

GaugeFieldParam fParam(*gauge_param, h_mom, QUDA_ASQTAD_MOM_LINKS);
// create the host momentum field
Expand All @@ -4506,13 +4506,15 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double

// create the device momentum field
fParam.location = QUDA_CUDA_FIELD_LOCATION;
fParam.create = QUDA_COPY_FIELD_CREATE;
fParam.create = gauge_param->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE;
fParam.field = &cpuMom;
fParam.reconstruct = QUDA_RECONSTRUCT_10;
fParam.setPrecision(gauge_param->cuda_prec, true);

if (gauge_param->use_resident_mom && !momResident.Length()) errorQuda("No resident momentum field to use");
GaugeField cudaMom = gauge_param->use_resident_mom ? momResident.create_alias() : GaugeField(fParam);

if (gauge_param->use_resident_mom && gauge_param->overwrite_mom) cudaMom.zero();

// create the device force field
fParam.link_type = QUDA_GENERAL_LINKS;
fParam.create = QUDA_ZERO_FIELD_CREATE;
Expand Down Expand Up @@ -4610,27 +4612,29 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef
using namespace quda;
auto profile = pushProfile(profileTMCloverForce, inv_param->secs, inv_param->gflops);

checkGaugeParam(gauge_param);
if (!gaugePrecise) errorQuda("No resident gauge field");
if (!cloverPrecise) errorQuda("No resident clover field");

if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaInvertParam(inv_param);
if (getVerbosity() >= QUDA_DEBUG_VERBOSE) printQudaGaugeParam(gauge_param);

double kappa = inv_param->kappa;
double k_csw_ov_8 = kappa * inv_param->clover_csw / 8.0;

checkGaugeParam(gauge_param);
if (!gaugePrecise) errorQuda("No resident gauge field");
if (!cloverPrecise) errorQuda("No resident clover field");

GaugeFieldParam gParamMom(*gauge_param, h_mom, QUDA_ASQTAD_MOM_LINKS);
GaugeField cpuMom(gParamMom);
GaugeField cpuMom = !gauge_param->use_resident_mom ? GaugeField(gParamMom) : GaugeField();

//create the device momentum field
gParamMom.location = QUDA_CUDA_FIELD_LOCATION;
gParamMom.create = QUDA_COPY_FIELD_CREATE;
gParamMom.create = gauge_param->overwrite_mom ? QUDA_ZERO_FIELD_CREATE : QUDA_COPY_FIELD_CREATE;
gParamMom.field = &cpuMom;
gParamMom.reconstruct = QUDA_RECONSTRUCT_10;
gParamMom.setPrecision(gauge_param->cuda_prec, true);

if (gauge_param->use_resident_mom && !momResident.Length()) errorQuda("No resident momentum field to use");
GaugeField gpuMom = gauge_param->use_resident_mom ? momResident.create_alias() : GaugeField(gParamMom);
if (gauge_param->use_resident_mom && gauge_param->overwrite_mom) gpuMom.zero();

// create the device force field
gParamMom.link_type = QUDA_GENERAL_LINKS;
Expand Down
35 changes: 17 additions & 18 deletions tests/TMCloverForce_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,10 @@ std::tuple<int, double> clover_force_test(test_t param)
bool detratio = ::testing::get<0>(param);
int nvector = ::testing::get<1>(param);

std::vector<quda::ColorSpinorField> out_nvector(nvector * Nsrc);
std::vector<std::vector<void *>> in(Nsrc, std::vector<void *>(nvector));
std::vector<quda::ColorSpinorField> out_nvector0(nvector * Nsrc);
std::vector<std::vector<void *>> in0(Nsrc, std::vector<void *>(nvector));
std::vector<quda::ColorSpinorField> out_nvector(nvector);
std::vector<void *> in(nvector);
std::vector<quda::ColorSpinorField> out_nvector0(nvector);
std::vector<void *> in0(nvector);

quda::ColorSpinorParam cs_param;
constructWilsonTestSpinorParam(&cs_param, &inv_param, &gauge_param);
Expand All @@ -91,15 +91,13 @@ std::tuple<int, double> clover_force_test(test_t param)
inv_param.num_offset = nvector;
for (int i = 0; i < nvector; i++) {
// Allocate memory and set pointers
for (int n = 0; n < Nsrc; n++) {
out_nvector[n * nvector + i] = quda::ColorSpinorField(cs_param);
spinorNoise(out_nvector[n * nvector + i], rng, QUDA_NOISE_GAUSS);
in[n][i] = out_nvector[n * nvector + i].data();

out_nvector0[n * nvector + i] = quda::ColorSpinorField(cs_param);
spinorNoise(out_nvector0[n * nvector + i], rng, QUDA_NOISE_GAUSS);
in0[n][i] = out_nvector0[n * nvector + i].data();
}
out_nvector[i] = quda::ColorSpinorField(cs_param);
spinorNoise(out_nvector[i], rng, QUDA_NOISE_GAUSS);
in[i] = out_nvector[i].data();

out_nvector0[i] = quda::ColorSpinorField(cs_param);
spinorNoise(out_nvector0[i], rng, QUDA_NOISE_GAUSS);
in0[i] = out_nvector0[i].data();
}

std::vector<double> coeff(nvector);
Expand All @@ -108,16 +106,16 @@ std::tuple<int, double> clover_force_test(test_t param)
coeff[i] += coeff[i] * (i + 1) / 10.0;
}
gauge_param.gauge_order = QUDA_MILC_GAUGE_ORDER;
gauge_param.overwrite_mom = 1;
if (getTuning() == QUDA_TUNE_YES)
computeTMCloverForceQuda(mom.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, &gauge_param, &inv_param,
computeTMCloverForceQuda(mom.data(), in.data(), in0.data(), coeff.data(), nvector, &gauge_param, &inv_param,
detratio);

// Multiple execution to exclude warmup time in the first run
double time_sec = 0.0;
double gflops = 0.0;
for (int i = 0; i < niter; i++) {
mom = mom_ref; // restore initial momentum for correctness
computeTMCloverForceQuda(mom.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, &gauge_param, &inv_param,
computeTMCloverForceQuda(mom.data(), in.data(), in0.data(), coeff.data(), nvector, &gauge_param, &inv_param,
detratio);
time_sec += inv_param.secs;
gflops += inv_param.gflops;
Expand All @@ -127,7 +125,8 @@ std::tuple<int, double> clover_force_test(test_t param)
std::array<void *, 4> u = {gauge.data(0), gauge.data(1), gauge.data(2), gauge.data(3)};
if (verify_results) {
gauge_param.gauge_order = QUDA_QDP_GAUGE_ORDER;
TMCloverForce_reference(mom_ref.data(), in[0].data(), in0[0].data(), coeff.data(), nvector, u, clover, clover_inv,
mom_ref.zero();
TMCloverForce_reference(mom_ref.data(), in.data(), in0.data(), coeff.data(), nvector, u, clover, clover_inv,
&gauge_param, &inv_param, detratio);
*check_out
= compare_floats(mom.data(), mom_ref.data(), 4 * V * mom_site_size, getTolerance(cuda_prec), gauge_param.cpu_prec);
Expand Down Expand Up @@ -218,7 +217,7 @@ int main(int argc, char **argv)

test_rc = RUN_ALL_TESTS();
} else {
clover_force_test({detratio, nvector});
clover_force_test({detratio, Nsrc});
}

destroy();
Expand Down
4 changes: 1 addition & 3 deletions tests/utils/command_line_params.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,8 @@ bool smear_delete_two_link = true;

bool enable_testing = false;

int nvector = 1;
bool detratio = false;

namespace
{
CLI::TransformPairs<QudaCABasis> ca_basis_map {{"power", QUDA_POWER_BASIS}, {"chebyshev", QUDA_CHEBYSHEV_BASIS}};
Expand Down Expand Up @@ -1122,7 +1122,5 @@ void add_quark_smear_option_group(std::shared_ptr<QUDAApp> quda_app)
// Registers the "Clover force" command-line option group on the supplied app.
//
// quda_app: application object on which add_option_group() is called to create
//           the group; the options below are then attached to that group.
//
// Options added (each is bound to a variable declared outside this function):
//   --nvector   bound to `nvector`
//   --detratio  bound to `detratio`
void add_clover_force_option_group(std::shared_ptr<QUDAApp> quda_app)
{
auto opgroup = quda_app->add_option_group("Clover force", "Options controlling clover force testing");
// Number of quark fields to test; the help text states the default is 1.
opgroup->add_option("--nvector", nvector, "Test multiple quark fields. Default is 1");
// Whether to test a ratio of determinants; the help text states the default is false.
opgroup->add_option("--detratio", detratio, "Test a ratio of determinants. Default is false");

}
3 changes: 1 addition & 2 deletions tests/utils/command_line_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,5 +430,4 @@ extern std::array<int, 4> grid_partition;

extern bool enable_testing;

extern int nvector;
extern bool detratio;
extern bool detratio;

0 comments on commit ad814bb

Please sign in to comment.