From c6c99004cc0fe502c3fb9bf603b7c9802aebb025 Mon Sep 17 00:00:00 2001 From: "Balos, Cody, J" Date: Mon, 13 May 2024 23:05:17 -0700 Subject: [PATCH] add ARKStepCreateAdjointSolver routine --- include/arkode/arkode.h | 4 ++ include/arkode/arkode_arkstep.h | 6 +++ include/sunadjoint/sunadjoint_solver.h | 6 +-- src/arkode/arkode_arkstep.c | 45 +++++++++++++++++++ src/arkode/arkode_impl.h | 5 +++ src/arkode/arkode_io.c | 16 +++++++ src/sunadjoint/sunadjoint_solver.c | 1 + .../arkode/C_serial/ark_test_sunadjoint.c | 16 +++---- 8 files changed, 85 insertions(+), 14 deletions(-) diff --git a/include/arkode/arkode.h b/include/arkode/arkode.h index 14027693af..810e21d61c 100644 --- a/include/arkode/arkode.h +++ b/include/arkode/arkode.h @@ -31,6 +31,7 @@ #include #include #include +#include #ifdef __cplusplus /* wrapper to enable C++ usage */ extern "C" { @@ -275,6 +276,9 @@ SUNDIALS_EXPORT int ARKodeSetPostprocessStageFn(void* arkode_mem, ARKPostProcessFn ProcessStage); SUNDIALS_EXPORT int ARKodeSetStagePredictFn(void* arkode_mem, ARKStagePredictFn PredictStage); +SUNDIALS_EXPORT +int ARKodeSetCheckpointScheme(void* arkode_mem, + SUNAdjointCheckpointScheme checkpoint_scheme); /* Integrate the ODE over an interval in t */ SUNDIALS_EXPORT int ARKodeEvolve(void* arkode_mem, sunrealtype tout, diff --git a/include/arkode/arkode_arkstep.h b/include/arkode/arkode_arkstep.h index aa6cf26611..744966b3ad 100644 --- a/include/arkode/arkode_arkstep.h +++ b/include/arkode/arkode_arkstep.h @@ -24,6 +24,7 @@ #include #include #include +#include #ifdef __cplusplus /* wrapper to enable C++ usage */ extern "C" { @@ -435,6 +436,11 @@ SUNDIALS_EXPORT int ARKStepCreateMRIStepInnerStepper(void* arkode_mem, SUNDIALS_EXPORT int ARKStepCreateSUNStepper(void* arkode_mem, SUNStepper* stepper); +/* Adjoint solver functions */ +SUNDIALS_EXPORT +int ARKStepCreateAdjointSolver(void* arkode_mem, sunindextype num_cost, + N_Vector sf, SUNAdjointSolver* adj_solver_ptr); + /* Relaxation functions */ SUNDIALS_DEPRECATED_EXPORT_MSG("use ARKodeSetRelaxFn instead") int ARKStepSetRelaxFn(void* arkode_mem, ARKRelaxFn rfn, ARKRelaxJacFn rjac); diff --git a/include/sunadjoint/sunadjoint_solver.h b/include/sunadjoint/sunadjoint_solver.h index 45d41f44d5..e085eb4434 100644 --- a/include/sunadjoint/sunadjoint_solver.h +++ b/include/sunadjoint/sunadjoint_solver.h @@ -39,8 +39,8 @@ typedef struct SUNAdjointSolver_* SUNAdjointSolver; extern "C" { #endif -// IDEA: In lieu of Stepper_ID each package that supports adjoint can have a function that creates the adjoint solver. -// E.g., SUNAdjointSolver ARKStepCreateAdjointSolver(); +// TODO(CJB): I think this should be a private function that is only used +// within the package CreateAdjointSolver routines. SUNDIALS_EXPORT SUNErrCode SUNAdjointSolver_Create(SUNStepper stepper, sunindextype num_cost_fns, N_Vector sf, @@ -52,7 +52,7 @@ SUNErrCode SUNAdjointSolver_Create(SUNStepper stepper, Solves the adjoint system. :param adj_solver: The adjoint solver object. - :param tf: The final output time from the forward integration. + :param tf: The final output time from the forward integration. This is the "starting" time for adjoint solver's backwards integration. :param tout: The time at which the adjoint solution is desired. :param sens: The vector of sensitivity solutions dg/dy0 and dg/dp. diff --git a/src/arkode/arkode_arkstep.c b/src/arkode/arkode_arkstep.c index e99f4fc6d8..1660c0fd9e 100644 --- a/src/arkode/arkode_arkstep.c +++ b/src/arkode/arkode_arkstep.c @@ -3247,6 +3247,51 @@ int arkStep_SUNStepperReset(SUNStepper stepper, sunrealtype tR, N_Vector yR) return (ARKodeReset(arkode_mem, tR, yR)); } +/*--------------------------------------------------------------- + Utility routines for interfacing with SUNAdjointSolver + ---------------------------------------------------------------*/ + +int arkStep_fe_Adj(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data) +{ + return 0; +} + +int arkStep_fi_Adj(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data) +{ + return 0; +} + +int ARKStepCreateAdjointSolver(void* arkode_mem, sunindextype num_cost, + N_Vector sf, SUNAdjointSolver* adj_solver_ptr) +{ + ARKodeMem ark_mem; + ARKodeARKStepMem step_mem; + int retval = arkStep_AccessARKODEStepMem(arkode_mem, + "ARKStepCreateAdjointSolver", + &ark_mem, &step_mem); + if (retval) + { + arkProcessError(NULL, ARK_ILL_INPUT, __LINE__, __func__, __FILE__, + "The ARKStep memory pointer is NULL"); + return ARK_ILL_INPUT; + } + + // TODO(CJB): should we reinit to tretlast or tcur? tcur could be past the time the + // user asked for in the forward integration if they do not use tstop mode. + ARKodeResize(arkode_mem, sf, -ONE, ark_mem->tretlast, NULL, NULL); + ARKRhsFn fe_adj = step_mem->fe ? arkStep_fe_Adj : NULL; + ARKRhsFn fi_adj = step_mem->fi ? arkStep_fi_Adj: NULL; + ARKStepReInit(arkode_mem, fe_adj, fi_adj, ark_mem->tretlast, sf); + + // SUNAdjointSolver will own the SUNStepper and destroy it + SUNStepper stepper; + ARKStepCreateSUNStepper(arkode_mem, &stepper); + SUNAdjointSolver_Create(stepper, num_cost, sf, ark_mem->checkpoint_scheme, + ark_mem->sunctx, adj_solver_ptr); + + return ARK_SUCCESS; +} + /*--------------------------------------------------------------- Utility routines for interfacing with MRIStep ---------------------------------------------------------------*/ diff --git a/src/arkode/arkode_impl.h b/src/arkode/arkode_impl.h index dda9d0bd19..fe65c6b4bd 100644 --- a/src/arkode/arkode_impl.h +++ b/src/arkode/arkode_impl.h @@ -27,6 +27,8 @@ #include #include #include +#include +#include #include "arkode_adapt_impl.h" #include "arkode_relaxation_impl.h" @@ -488,6 +490,9 @@ struct ARKodeMemRec sunbooleantype use_compensated_sums; + /* Adjoint solver data */ + SUNAdjointCheckpointScheme checkpoint_scheme; + /* XBraid interface variables */ sunbooleantype force_pass; /* when true the step attempt loop will ignore the return value (kflag) from arkCheckTemporalError diff --git a/src/arkode/arkode_io.c b/src/arkode/arkode_io.c index 98dc8b2cb1..1f96d8439c 100644 --- a/src/arkode/arkode_io.c +++ b/src/arkode/arkode_io.c @@ -2219,6 +2219,22 @@ int ARKodeSetMaxConvFails(void* arkode_mem, int maxncf) return (ARK_SUCCESS); } +int ARKodeSetCheckpointScheme(void* arkode_mem, SUNAdjointCheckpointScheme checkpoint_scheme) +{ + ARKodeMem ark_mem; + if (arkode_mem == NULL) + { + arkProcessError(NULL, ARK_MEM_NULL, __LINE__, __func__, __FILE__, + MSG_ARK_NO_MEM); + return (ARK_MEM_NULL); + } + ark_mem = (ARKodeMem)arkode_mem; + + ark_mem->checkpoint_scheme = checkpoint_scheme; + + return (ARK_SUCCESS); +} + /*=============================================================== ARKODE optional output utility functions ===============================================================*/ diff --git a/src/sunadjoint/sunadjoint_solver.c b/src/sunadjoint/sunadjoint_solver.c index 01ffe46651..e38369eacf 100644 --- a/src/sunadjoint/sunadjoint_solver.c +++ b/src/sunadjoint/sunadjoint_solver.c @@ -82,6 +82,7 @@ SUNErrCode SUNAdjointSolver_Destroy(SUNAdjointSolver* adj_solver_ptr) { SUNAdjointSolver adj_solver = *adj_solver_ptr; // SUNAdjointCheckpointScheme_Destroy(adj_solver->checkpoint_scheme); + SUNStepper_Destroy(&adj_solver->stepper); free(adj_solver); *adj_solver_ptr = NULL; return SUN_SUCCESS; diff --git a/test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c b/test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c index 14b93fa06d..077255df87 100644 --- a/test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c +++ b/test/unit_tests/arkode/C_serial/ark_test_sunadjoint.c @@ -107,15 +107,9 @@ int adjoint_solution(SUNContext sunctx, void* arkode_mem, // TODO(CJB): Load sf with the sensitivity terminal conditions N_VConst(0.0, sf); - // TODO(CJB): this block of code needs to be less complicated, should wrap it up in something like ARKStepCreateAdjointSolver() + SUNAdjointSolver adj_solver; + ARKStepCreateAdjointSolver(arkode_mem, num_cost, sf, &adj_solver); // lotka_volterra_adjoint is J*lambda - user will provide J, we internally will create RHS that is J*lambda - ARKodeResize(arkode_mem, sf, 1.0, tf, NULL, NULL); - ARKStepReInit(arkode_mem, lotka_volterra_adjoint, NULL, tf, sf); - SUNStepper stepper = NULL; - ARKStepCreateSUNStepper(arkode_mem, &stepper); - SUNAdjointSolver adj_solver = NULL; - SUNAdjointSolver_Create(stepper, num_cost, sf, checkpoint_scheme, sunctx, - &adj_solver); // SUNAdjointSolver_SetJacFn(adj_solver, ); // SUNAdjointSolver_SetJacPFn(adj_solver, ); @@ -127,8 +121,9 @@ int adjoint_solution(SUNContext sunctx, void* arkode_mem, N_VPrint(sf); N_VDestroy(sf); - SUNStepper_Destroy(&stepper); SUNAdjointSolver_Destroy(&adj_solver); + + return 0; } int main(int argc, char* argv[]) @@ -154,8 +149,7 @@ int main(int argc, char* argv[]) // Enable checkpointing during the forward solution SUNAdjointCheckpointScheme checkpoint_scheme = NULL; - // SUNAdjointCheckpointScheme_NewEmpty(sunctx, &checkpoint_scheme); - // ARKodeSetCheckpointScheme(arkode_mem, checkpoint_scheme); + ARKodeSetCheckpointScheme(arkode_mem, checkpoint_scheme); // // Compute the forward solution