From 5d6e6e89ca61dfbd5dbcef3f672bcd0d97fb1460 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 10 Dec 2024 20:28:42 -0800 Subject: [PATCH] Use batching where possible in dirac_quda.h --- include/dirac_quda.h | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/include/dirac_quda.h b/include/dirac_quda.h index a14daef015..d6b123603a 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -2347,8 +2347,7 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->M(out, in); - for (auto i = 0u; i < in.size(); i++) - if (shift != 0.0) blas::axpy(shift, in[i], out[i]); + if (shift != 0.0) blas::axpy(shift, in, out); } int getStencilSteps() const override { return dirac->getStencilSteps(); } @@ -2369,8 +2368,7 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->MdagM(out, in); - for (auto i = 0u; i < in.size(); i++) - if (shift != 0.0) blas::axpy(shift, in[i], out[i]); + if (shift != 0.0) blas::axpy(shift, in, out); } int getStencilSteps() const override @@ -2421,8 +2419,7 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->MMdag(out, in); - for (auto i = 0u; i < in.size(); i++) - if (shift != 0.0) blas::axpy(shift, in[i], out[i]); + if (shift != 0.0) blas::axpy(shift, in, out); } int getStencilSteps() const override @@ -2448,8 +2445,7 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->Mdag(out, in); - for (auto i = 0u; i < in.size(); i++) - if (shift != 0.0) blas::axpy(shift, in[i], out[i]); + if (shift != 0.0) blas::axpy(shift, in, out); } int getStencilSteps() const override { return dirac->getStencilSteps(); } @@ -2496,7 +2492,7 @@ namespace quda { @param vec[in,out] vector to which gamma5 is applied in place */ - void applyGamma5(ColorSpinorField &vec) const + void applyGamma5(cvector_ref &vec) const { auto dirac_type = dirac->getDiracType(); auto pc_type = dirac->getMatPCType(); @@ -2573,10 +2569,8 @@ namespace quda { void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->M(out, in); - for (auto i = 0u; i < in.size(); i++) { - if (shift != 0.0) blas::axpy(shift, in[i], out[i]); - applyGamma5(out[i]); - } + if (shift != 0.0) blas::axpy(shift, in, out); + applyGamma5(out); } int getStencilSteps() const override { return dirac->getStencilSteps(); }