Add pragma unroll
maddyscientist committed Dec 3, 2024
1 parent 0a3f608 commit 22393c5
Showing 1 changed file with 12 additions and 0 deletions.
include/kernels/dslash_staggered.cuh (12 additions, 0 deletions)
@@ -103,6 +103,7 @@ namespace quda
if (doHalo<kernel_type>(d) && ghost) {
const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, 1);
const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg));
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector in
= arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity);
@@ -111,6 +112,7 @@ namespace quda
} else if (doBulk<kernel_type>() && !ghost) {
const int fwd_idx = linkIndexP1(coord, arg.dim, d);
const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg));
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity);
out[s] = mv_add(U, in, out[s]);
@@ -124,6 +126,7 @@ namespace quda
if (doHalo<kernel_type>(d) && ghost) {
const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, arg.nFace);
const Link L = arg.L(d, coord.x_cb, parity);
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
const Vector in
= arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity);
@@ -132,6 +135,7 @@ namespace quda
} else if (doBulk<kernel_type>() && !ghost) {
const int fwd3_idx = linkIndexP3(coord, arg.dim, d);
const Link L = arg.L(d, coord.x_cb, parity);
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity);
out[s] = mv_add(L, in, out[s]);
@@ -148,6 +152,7 @@ namespace quda
const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dim, d, 3) : ghost_idx2;
const Link U = arg.improved ? arg.U.Ghost(d, ghost_idx2, 1 - parity) :
arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg));
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector in
= arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity);
@@ -158,6 +163,7 @@ namespace quda
const int gauge_idx = back_idx;
const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) :
arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg));
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity);
out[s] = mv_add(conj(U), -in, out[s]);
@@ -172,6 +178,7 @@ namespace quda
// when updating replace arg.nFace with 1 here
const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1);
const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity);
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
const Vector in
= arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity);
@@ -181,6 +188,7 @@ namespace quda
const int back3_idx = linkIndexM3(coord, arg.dim, d);
const int gauge_idx = back3_idx;
const Link L = arg.L(d, gauge_idx, 1 - parity);
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity);
out[s] = mv_add(conj(L), -in, out[s]);
@@ -215,20 +223,24 @@ namespace quda
array<Vector, n_src_tile> out;
applyStaggered<nParity, mykernel_type, n_src_tile>(out, arg, coord, parity, idx, thread_dim, active, src_idx);

#pragma unroll
for (auto s = 0; s < n_src_tile; s++) out[s] *= arg.dagger_scale;

if (xpay && mykernel_type == INTERIOR_KERNEL) {
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector x = arg.x[src_idx + s](coord.x_cb, my_spinor_parity);
out[s] = arg.a * x - out[s];
}
} else if (mykernel_type != INTERIOR_KERNEL) {
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) {
Vector x = arg.out[src_idx + s](coord.x_cb, my_spinor_parity);
out[s] = x + (xpay ? -out[s] : out[s]);
}
}
if (mykernel_type != EXTERIOR_KERNEL_ALL || active) {
#pragma unroll
for (auto s = 0; s < n_src_tile; s++) { arg.out[src_idx + s](coord.x_cb, my_spinor_parity) = out[s]; }
}
}
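Each of the twelve annotated loops iterates over n_src_tile, which is a template parameter (see applyStaggered<nParity, mykernel_type, n_src_tile> and array<Vector, n_src_tile> out above), so the trip count is known at compile time. #pragma unroll asks the compiler to unroll each loop completely, which turns every out[s] access into a constant index and lets the accumulator array live in registers instead of local memory. Below is a minimal standalone sketch of the same pattern, with hypothetical kernel and variable names rather than QUDA's actual types:

// Minimal sketch of the pattern, with hypothetical names (not QUDA's API).
// Because n_src_tile is a template parameter, each loop has a compile-time
// trip count: #pragma unroll unrolls it fully, every acc[s] index becomes a
// constant, and the accumulators stay in registers rather than spilling to
// local memory.
template <int n_src_tile>
__global__ void axpy_tiled(const float *x, const float *y, float *out, float a, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;

  float acc[n_src_tile]; // one accumulator per source in the tile

#pragma unroll
  for (int s = 0; s < n_src_tile; s++) acc[s] = a * x[s * n + i];

#pragma unroll
  for (int s = 0; s < n_src_tile; s++) acc[s] += y[s * n + i];

#pragma unroll
  for (int s = 0; s < n_src_tile; s++) out[s * n + i] = acc[s];
}

Launched as, e.g., axpy_tiled<4><<<grid, block>>>(x, y, out, a, n). nvcc will often unroll such constant-bound loops on its own, but the explicit pragma makes the intent unambiguous and guards the register residency of the accumulators against the heuristic declining to unroll.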