Skip to content

Commit

Permalink
Use batching where possible in dirac_quda.h
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Dec 11, 2024
1 parent 49a0211 commit 5d6e6e8
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions include/dirac_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2347,8 +2347,7 @@ namespace quda {
// Call operator over sets of fields: invokes dirac->M(out, in) and then, only when shift is
// non-zero, applies blas::axpy with that shift (presumably out += shift * in -- confirm against the blas API).
// NOTE(review): this text is a rendered diff hunk, so both the removed and the added statements are shown.
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->M(out, in);
// Removed by this commit: one blas::axpy call per field, re-testing the loop-invariant shift on every iteration.
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
// Added by this commit: a single blas::axpy call that is handed the whole in/out sets.
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand All @@ -2369,8 +2368,7 @@ namespace quda {
// Call operator over sets of fields: invokes dirac->MdagM(out, in) and then, only when shift is
// non-zero, applies blas::axpy with that shift (presumably out += shift * in -- confirm against the blas API).
// NOTE(review): this text is a rendered diff hunk, so both the removed and the added statements are shown.
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->MdagM(out, in);
// Removed by this commit: one blas::axpy call per field, re-testing the loop-invariant shift on every iteration.
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
// Added by this commit: a single blas::axpy call that is handed the whole in/out sets.
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override
Expand Down Expand Up @@ -2421,8 +2419,7 @@ namespace quda {
// Call operator over sets of fields: invokes dirac->MMdag(out, in) and then, only when shift is
// non-zero, applies blas::axpy with that shift (presumably out += shift * in -- confirm against the blas API).
// NOTE(review): this text is a rendered diff hunk, so both the removed and the added statements are shown.
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->MMdag(out, in);
// Removed by this commit: one blas::axpy call per field, re-testing the loop-invariant shift on every iteration.
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
// Added by this commit: a single blas::axpy call that is handed the whole in/out sets.
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override
Expand All @@ -2448,8 +2445,7 @@ namespace quda {
// Call operator over sets of fields: invokes dirac->Mdag(out, in) and then, only when shift is
// non-zero, applies blas::axpy with that shift (presumably out += shift * in -- confirm against the blas API).
// NOTE(review): this text is a rendered diff hunk, so both the removed and the added statements are shown.
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->Mdag(out, in);
// Removed by this commit: one blas::axpy call per field, re-testing the loop-invariant shift on every iteration.
for (auto i = 0u; i < in.size(); i++)
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
// Added by this commit: a single blas::axpy call that is handed the whole in/out sets.
if (shift != 0.0) blas::axpy(shift, in, out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand Down Expand Up @@ -2496,7 +2492,7 @@ namespace quda {
@param[in,out] vec vector to which gamma5 is applied in place
*/
void applyGamma5(ColorSpinorField &vec) const
void applyGamma5(cvector_ref<ColorSpinorField> &vec) const
{
auto dirac_type = dirac->getDiracType();
auto pc_type = dirac->getMatPCType();
Expand Down Expand Up @@ -2573,10 +2569,8 @@ namespace quda {
// Call operator over sets of fields: invokes dirac->M(out, in), then (only when shift is non-zero)
// applies blas::axpy with that shift (presumably out += shift * in -- confirm against the blas API),
// and finally passes the result to applyGamma5.
// NOTE(review): this text is a rendered diff hunk, so both the removed and the added statements are shown.
void operator()(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in) const override
{
dirac->M(out, in);
// Removed by this commit: per-field loop doing the optional axpy followed by applyGamma5 on out[i].
for (auto i = 0u; i < in.size(); i++) {
if (shift != 0.0) blas::axpy(shift, in[i], out[i]);
applyGamma5(out[i]);
}
// Added by this commit: one axpy call and one applyGamma5 call, each handed the whole set
// (the applyGamma5 signature shown earlier in this diff now takes cvector_ref<ColorSpinorField> &).
if (shift != 0.0) blas::axpy(shift, in, out);
applyGamma5(out);
}

int getStencilSteps() const override { return dirac->getStencilSteps(); }
Expand Down

0 comments on commit 5d6e6e8

Please sign in to comment.