diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index 3dfaf01be0..ec09ab1622 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -103,6 +103,7 @@ namespace quda if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, 1); const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -111,6 +112,7 @@ namespace quda } else if (doBulk() && !ghost) { const int fwd_idx = linkIndexP1(coord, arg.dim, d); const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity); out[s] = mv_add(U, in, out[s]); @@ -124,6 +126,7 @@ namespace quda if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dim, d, arg.nFace); const Link L = arg.L(d, coord.x_cb, parity); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -132,6 +135,7 @@ namespace quda } else if (doBulk() && !ghost) { const int fwd3_idx = linkIndexP3(coord, arg.dim, d); const Link L = arg.L(d, coord.x_cb, parity); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity); out[s] = mv_add(L, in, out[s]); @@ -148,6 +152,7 @@ namespace quda const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dim, d, 3) : ghost_idx2; const Link U = arg.improved ? arg.U.Ghost(d, ghost_idx2, 1 - parity) : arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg)); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -158,6 +163,7 @@ namespace quda const int gauge_idx = back_idx; const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) : arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); out[s] = mv_add(conj(U), -in, out[s]); @@ -172,6 +178,7 @@ namespace quda // when updating replace arg.nFace with 1 here const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dim, d, 1); const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.nFace * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -181,6 +188,7 @@ namespace quda const int back3_idx = linkIndexM3(coord, arg.dim, d); const int gauge_idx = back3_idx; const Link L = arg.L(d, gauge_idx, 1 - parity); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity); out[s] = mv_add(conj(L), -in, out[s]); @@ -215,20 +223,24 @@ namespace quda array out; applyStaggered(out, arg, coord, parity, idx, thread_dim, active, src_idx); +#pragma unroll for (auto s = 0; s < n_src_tile; s++) out[s] *= arg.dagger_scale; if (xpay && mykernel_type == INTERIOR_KERNEL) { +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector x = arg.x[src_idx + s](coord.x_cb, my_spinor_parity); out[s] = arg.a * x - out[s]; } } else if (mykernel_type != INTERIOR_KERNEL) { +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector x = arg.out[src_idx + s](coord.x_cb, my_spinor_parity); out[s] = x + (xpay ? -out[s] : out[s]); } } if (mykernel_type != EXTERIOR_KERNEL_ALL || active) { +#pragma unroll for (auto s = 0; s < n_src_tile; s++) { arg.out[src_idx + s](coord.x_cb, my_spinor_parity) = out[s]; } } }