Skip to content

Commit

Permalink
fix the Nx=1 case and the (Nx=1 AND Ny=1) case for unpacking the tru…
Browse files Browse the repository at this point in the history
…ncated FFT into a full multifab. When Nx=1 the spectral box is truncated in y (indices 0 through Ny/2, i.e. Ny/2+1 points) since there is only 1 cell in x; analogously, when Nx=Ny=1 the spectral box is truncated in z (indices 0 through Nz/2, i.e. Nz/2+1 points).
  • Loading branch information
ajnonaka committed Jan 4, 2025
1 parent af2ac09 commit c940ce7
Showing 1 changed file with 107 additions and 43 deletions.
150 changes: 107 additions & 43 deletions src_analysis/StructFact.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,11 +307,22 @@ void StructFact::ComputeFFT(const MultiFab& variables,

BL_PROFILE_VAR("StructFact::ComputeFFT()", ComputeFFT);

bool is_flattened = false;

Box domain = variables.boxArray().minimalBox();
if (domain.bigEnd(AMREX_SPACEDIM-1) == 0) {
is_flattened = true; // flattened case
bool chopped_in_x = false;
bool chopped_in_y = false;
bool chopped_in_z = false;

// figure out which direction the spectral box will be chopped
if (domain.length(0) > 1) {
chopped_in_x = true;
} else if (domain.length(1) > 1) {
chopped_in_y = true;
#if (AMREX_SPACEDIM == 3)
} else if (domain.length(2) > 1) {
chopped_in_z = true;
#endif
} else {
Abort("Calling ComputeFFT for a MultiFab with only 1 cell");
}

// compute number of points in the domain and the square root
Expand Down Expand Up @@ -379,61 +390,114 @@ void StructFact::ComputeFFT(const MultiFab& variables,
Array4<Real> const& realpart = variables_dft_real_onegrid.array(mfi);
Array4<Real> const& imagpart = variables_dft_imag_onegrid.array(mfi);

amrex::ParallelFor(bx,
[=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
{
/*
Unpacking rules:
/*
Unpacking rules:
For domains from (0,0,0) to (Nx-1,Ny-1,Nz-1) and chopped_in_x (i.e., Nx > 1)
For domains from (0,0,0) to (Nx-1,Ny-1,Nz-1)
For any cells with i index > Nx/2, these values are complex conjugates of the corresponding
entry where (Nx-i,Ny-j,Nz-k) UNLESS that index is zero, in which case you use 0.
For any cells with i index > Nx/2, these values are complex conjugates of the corresponding
entry where (Nx-i,Ny-j,Nz-k) UNLESS that index is zero, in which case you use 0.
e.g. for an 8^3 domain, any cell with i index
e.g. for an 8^3 domain, any cell with i index
Cell (6,2,3) is complex conjugate of (2,6,5)
Cell (6,2,3) is complex conjugate of (2,6,5)
Cell (4,1,0) is complex conjugate of (4,7,0) (note that the FFT is computed for 0 <= i <= Nx/2)
Cell (4,1,0) is complex conjugate of (4,7,0) (note that the FFT is computed for 0 <= i <= Nx/2)
*/
if (i <= bx.length(0)/2) {
// copy value
realpart(i,j,k) = spectral(i,j,k).real();
imagpart(i,j,k) = spectral(i,j,k).imag();
} else {
// copy complex conjugate
int iloc = bx.length(0)-i;
int jloc, kloc;
if (is_flattened) {
The analogy extends for the chopped_in_y and z directions
*/

if (chopped_in_x) {
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
{
if (i <= bx.length(0)/2) {
// copy value
realpart(i,j,k) = spectral(i,j,k).real();
imagpart(i,j,k) = spectral(i,j,k).imag();
} else {
// copy complex conjugate
int iloc = bx.length(0)-i;
int jloc = (j == 0) ? 0 : bx.length(1)-j;
#if (AMREX_SPACEDIM == 2)
jloc = 0;
int kloc = 0;
#elif (AMREX_SPACEDIM == 3)
jloc = (j == 0) ? 0 : bx.length(1)-j;
int kloc = (k == 0) ? 0 : bx.length(2)-k;
#endif
kloc = 0;
if (unpack) {
realpart(i,j,k) = spectral(iloc,jloc,kloc).real();
imagpart(i,j,k) = -spectral(iloc,jloc,kloc).imag();
}
else {
realpart(i,j,k) = 0.0;
imagpart(i,j,k) = 0.0;
}
}

realpart(i,j,k) /= sqrtnpts;
imagpart(i,j,k) /= sqrtnpts;
});
}

if (chopped_in_y) {
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
{
if (j <= bx.length(1)/2) {
// copy value
realpart(i,j,k) = spectral(i,j,k).real();
imagpart(i,j,k) = spectral(i,j,k).imag();
} else {
jloc = (j == 0) ? 0 : bx.length(1)-j;
// copy complex conjugate
int iloc = (i == 0) ? 0 : bx.length(0)-i;
int jloc = bx.length(1)-j;
#if (AMREX_SPACEDIM == 2)
kloc = 0;
int kloc = 0;
#elif (AMREX_SPACEDIM == 3)
kloc = (k == 0) ? 0 : bx.length(2)-k;
int kloc = (k == 0) ? 0 : bx.length(2)-k;
#endif
if (unpack) {
realpart(i,j,k) = spectral(iloc,jloc,kloc).real();
imagpart(i,j,k) = -spectral(iloc,jloc,kloc).imag();
}
else {
realpart(i,j,k) = 0.0;
imagpart(i,j,k) = 0.0;
}
}

if (unpack) {
realpart(i,j,k) = spectral(iloc,jloc,kloc).real();
imagpart(i,j,k) = -spectral(iloc,jloc,kloc).imag();
}
else {
realpart(i,j,k) = 0.0;
imagpart(i,j,k) = 0.0;
realpart(i,j,k) /= sqrtnpts;
imagpart(i,j,k) /= sqrtnpts;
});
}

if (chopped_in_z) {
amrex::ParallelFor(bx, [=] AMREX_GPU_DEVICE (int i, int j, int k) noexcept
{
if (k <= bx.length(2)/2) {
// copy value
realpart(i,j,k) = spectral(i,j,k).real();
imagpart(i,j,k) = spectral(i,j,k).imag();
} else {
// copy complex conjugate
int iloc = (i == 0) ? 0 : bx.length(0)-i;
int jloc = (j == 0) ? 0 : bx.length(1)-j;
int kloc = bx.length(2)-k;

if (unpack) {
realpart(i,j,k) = spectral(iloc,jloc,kloc).real();
imagpart(i,j,k) = -spectral(iloc,jloc,kloc).imag();
}
else {
realpart(i,j,k) = 0.0;
imagpart(i,j,k) = 0.0;
}
}
}

realpart(i,j,k) /= sqrtnpts;
imagpart(i,j,k) /= sqrtnpts;
});
}
realpart(i,j,k) /= sqrtnpts;
imagpart(i,j,k) /= sqrtnpts;
});
}

} // end MFIter

variables_dft_real.ParallelCopy(variables_dft_real_onegrid,0,comp,1);
variables_dft_imag.ParallelCopy(variables_dft_imag_onegrid,0,comp,1);
Expand Down

0 comments on commit c940ce7

Please sign in to comment.