Skip to content

Commit

Permalink
ODESolvers: CalcYs using lincomb
Browse files Browse the repository at this point in the history
  • Loading branch information
lwJi committed May 15, 2024
1 parent 98622f4 commit 1f5edf0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 72 deletions.
26 changes: 14 additions & 12 deletions ODESolvers/schedule.ccl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ if(use_subcycling_wip) {
SCHEDULE GROUP ODESolvers_SyncKs
{
} "Group of Sync Ks at RM Boundary"

SCHEDULE GROUP ODESolvers_CalcYfFromKcs1
{
} "Group of Calc Yfs from Kcs at RM Boundary for RK1"
Expand Down Expand Up @@ -49,28 +50,29 @@ if(use_subcycling_wip) {
SCHEDULE GROUP ODESolvers_SyncState
{
} "Group of Sync State"
SCHEDULE GROUP ODESolvers_CalcY2

SCHEDULE GROUP ODESolvers_SetK1
{
} "Group of Calc Y2"
SCHEDULE GROUP ODESolvers_CalcY3
} "Group of Set K1"
SCHEDULE GROUP ODESolvers_SetK2
{
} "Group of Calc Y3"
SCHEDULE GROUP ODESolvers_CalcY4
} "Group of Set K2"
SCHEDULE GROUP ODESolvers_SetK3
{
} "Group of Calc Y4"
} "Group of Set K3"

SCHEDULE ODESolvers_Solve_CalcY2 IN ODESolvers_CalcY2
SCHEDULE ODESolvers_Solve_SetK1 IN ODESolvers_SetK1
{
LANG: C
} "Calculate Y2"
SCHEDULE ODESolvers_Solve_CalcY3 IN ODESolvers_CalcY3
} "Set K1 from RHS"
SCHEDULE ODESolvers_Solve_SetK2 IN ODESolvers_SetK2
{
LANG: C
} "Calculate Y3"
SCHEDULE ODESolvers_Solve_CalcY4 IN ODESolvers_CalcY4
} "Set K2 from RHS"
SCHEDULE ODESolvers_Solve_SetK3 IN ODESolvers_SetK3
{
LANG: C
} "Calculate Y4"
} "Set K3 from RHS"
}
else {
SCHEDULE ODESolvers_Solve AT evol
Expand Down
85 changes: 25 additions & 60 deletions ODESolvers/src/odesolvers_solve_subcycling.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -79,79 +79,44 @@ SetK(const Loop::GridDescBaseDevice &grid, const Loop::GF3D2<CCTK_REAL> &K,
CCTK_ATTRIBUTE_ALWAYS_INLINE { K(p.I) = rhs(p.I); });
}

/// Point-wise update of one grid function: Y = u0 + rhs * dt.
///
/// @param grid Grid descriptor whose `loop_int_device<0, 0, 0>` drives the
///             traversal, using `grid.nghostzones` as its argument
///             (presumably this visits the interior points of a
///             vertex-centred grid -- confirm against the Loop API).
/// @param Y    Destination grid function; written at every visited point.
/// @param u0   Read-only grid function supplying the base value.
/// @param rhs  Read-only grid function that is scaled by `dt`.
/// @param dt   Scalar factor applied to `rhs`.
CCTK_HOST CCTK_ATTRIBUTE_ALWAYS_INLINE inline void
CalcY(const Loop::GridDescBaseDevice &grid, const Loop::GF3D2<CCTK_REAL> &Y,
      const Loop::GF3D2<const CCTK_REAL> &u0,
      const Loop::GF3D2<const CCTK_REAL> &rhs, const CCTK_REAL dt) {
  // All operands are captured by value so the kernel owns its own copies.
  const auto update_point =
      [=] CCTK_DEVICE(const Loop::PointDesc &point)
          CCTK_ATTRIBUTE_ALWAYS_INLINE {
            Y(point.I) = u0(point.I) + rhs(point.I) * dt;
          };
  grid.loop_int_device<0, 0, 0>(grid.nghostzones, update_point);
}

template <int RKSTAGES>
CCTK_HOST CCTK_ATTRIBUTE_ALWAYS_INLINE inline void
CalcYs(CCTK_ARGUMENTS, vector<int> &Ys, vector<int> &u0s, vector<int> &rhss,
const array<vector<int>, RKSTAGES> &kss, const CCTK_REAL dt,
const CCTK_INT stage) {
SetK(CCTK_ARGUMENTS, vector<int> &rhss,
const array<vector<int>, RKSTAGES> &kss, const CCTK_INT stage) {
assert(stage > 0 && stage <= 4);
const CCTK_REAL dt_stage = (stage == 4) ? dt : dt * CCTK_REAL(0.5);
const Loop::GridDescBaseDevice grid(cctkGH);
const int tl = 0;
// TODO: we need flags of different centering types for the refinement
// boundary; maybe make them a group
for (size_t i = 0; i < Ys.size(); ++i) {
const int nvars = CCTK_NumVarsInGroupI(Ys[i]);
for (size_t i = 0; i < rhss.size(); ++i) {
const int nvars = CCTK_NumVarsInGroupI(rhss[i]);
const Loop::GF3D2layout layout(cctkGH,
Subcycling::get_group_indextype(Ys[i]));

const int Y_0 = CCTK_FirstVarIndexI(Ys[i]);
const int u0_0 = CCTK_FirstVarIndexI(u0s[i]);
Subcycling::get_group_indextype(rhss[i]));
const int rhs_0 = CCTK_FirstVarIndexI(rhss[i]);
const int K_0 = CCTK_FirstVarIndexI(kss[stage - 2][i]);
const int K_0 = CCTK_FirstVarIndexI(kss[stage - 1][i]);
for (int vi = 0; vi < nvars; vi++) {
const Loop::GF3D2<CCTK_REAL> K(
layout,
static_cast<CCTK_REAL *>(CCTK_VarDataPtrI(cctkGH, tl, K_0 + vi)));
const Loop::GF3D2<CCTK_REAL> Y(
layout,
static_cast<CCTK_REAL *>(CCTK_VarDataPtrI(cctkGH, tl, Y_0 + vi)));
const Loop::GF3D2<const CCTK_REAL> u0(
layout,
static_cast<CCTK_REAL *>(CCTK_VarDataPtrI(cctkGH, tl, u0_0 + vi)));
const Loop::GF3D2<const CCTK_REAL> rhs(
layout,
static_cast<CCTK_REAL *>(CCTK_VarDataPtrI(cctkGH, tl, rhs_0 + vi)));

switch (RKSTAGES) {
case 4: {
SetK(grid, K, rhs);
CalcY(grid, Y, u0, rhs, dt_stage);
break;
}
default: {
CCTK_ERROR("Unsupported RK stages with subcycling");
break;
}
}
SetK(grid, K, rhs);
}
}
}

extern "C" void ODESolvers_Solve_CalcY2(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_CalcY2;
CalcYs<rkstages>(CCTK_PASS_CTOC, VarGroups, OldGroups, RhsGroups, KsGroups,
CCTK_DELTA_TIME, 2);
extern "C" void ODESolvers_Solve_SetK1(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_SetK1;
SetK<rkstages>(CCTK_PASS_CTOC, RhsGroups, KsGroups, 1);
}
extern "C" void ODESolvers_Solve_CalcY3(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_CalcY3;
CalcYs<rkstages>(CCTK_PASS_CTOC, VarGroups, OldGroups, RhsGroups, KsGroups,
CCTK_DELTA_TIME, 3);
extern "C" void ODESolvers_Solve_SetK2(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_SetK2;
SetK<rkstages>(CCTK_PASS_CTOC, RhsGroups, KsGroups, 2);
}
extern "C" void ODESolvers_Solve_CalcY4(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_CalcY4;
CalcYs<rkstages>(CCTK_PASS_CTOC, VarGroups, OldGroups, RhsGroups, KsGroups,
CCTK_DELTA_TIME, 4);
extern "C" void ODESolvers_Solve_SetK3(CCTK_ARGUMENTS) {
DECLARE_CCTK_ARGUMENTS_ODESolvers_Solve_SetK3;
SetK<rkstages>(CCTK_PASS_CTOC, RhsGroups, KsGroups, 3);
}

extern "C" void ODESolvers_Solve_Subcycling(CCTK_ARGUMENTS) {
Expand Down Expand Up @@ -297,11 +262,11 @@ extern "C" void ODESolvers_Solve_Subcycling(CCTK_ARGUMENTS) {
// rhs.check_valid(make_valid_int(),
// "ODESolvers after calling ODESolvers_RHS");
// var = Y2 = y0 + h/2 k1
CallScheduleGroup(cctkGH, "ODESolvers_CalcY2");
CallScheduleGroup(cctkGH, "ODESolvers_SetK1");
// statecomp_t::lincomb(ks[0], 0, make_array(CCTK_REAL(1)),
// make_array(&rhs), make_valid_int());
// statecomp_t::lincomb(var, 1, make_array(dt / 2), make_array(&rhs),
// make_valid_int());
statecomp_t::lincomb(var, 1, make_array(dt / 2), make_array(&rhs),
make_valid_int());
CallScheduleGroup(cctkGH, "ODESolvers_SyncState");
// var.check_valid(make_valid_int(),
// "ODESolvers after defining new state vector");
Expand All @@ -318,11 +283,11 @@ extern "C" void ODESolvers_Solve_Subcycling(CCTK_ARGUMENTS) {
// rhs.check_valid(make_valid_int(),
// "ODESolvers after calling ODESolvers_RHS");
// var = Y3 = y0 + h/2 k2
CallScheduleGroup(cctkGH, "ODESolvers_CalcY3");
CallScheduleGroup(cctkGH, "ODESolvers_SetK2");
// statecomp_t::lincomb(ks[1], 0, make_array(CCTK_REAL(1)),
// make_array(&rhs), make_valid_int());
// statecomp_t::lincomb(var, 0, make_array(CCTK_REAL(1), dt / 2),
// make_array(&old, &rhs), make_valid_int());
statecomp_t::lincomb(var, 0, make_array(CCTK_REAL(1), dt / 2),
make_array(&old, &rhs), make_valid_int());
CallScheduleGroup(cctkGH, "ODESolvers_SyncState");
// var.check_valid(make_valid_int(),
// "ODESolvers after defining new state vector");
Expand All @@ -339,11 +304,11 @@ extern "C" void ODESolvers_Solve_Subcycling(CCTK_ARGUMENTS) {
// rhs.check_valid(make_valid_int(),
// "ODESolvers after calling ODESolvers_RHS");
// var = Y4 = y0 + h k3
CallScheduleGroup(cctkGH, "ODESolvers_CalcY4");
CallScheduleGroup(cctkGH, "ODESolvers_SetK3");
// statecomp_t::lincomb(ks[2], 0, make_array(CCTK_REAL(1)),
// make_array(&rhs), make_valid_int());
// statecomp_t::lincomb(var, 0, make_array(CCTK_REAL(1), dt),
// make_array(&old, &rhs), make_valid_int());
statecomp_t::lincomb(var, 0, make_array(CCTK_REAL(1), dt),
make_array(&old, &rhs), make_valid_int());
CallScheduleGroup(cctkGH, "ODESolvers_SyncState");
// var.check_valid(make_valid_int(),
// "ODESolvers after defining new state vector");
Expand Down

0 comments on commit 1f5edf0

Please sign in to comment.