From 1d967f0fc7f6dd0556fa8552c8f48e9c0f12e9cc Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Tue, 17 Dec 2024 01:32:47 -0700 Subject: [PATCH] Tacho : rval on device Signed-off-by: iyamazaki --- .../src/impl/Tacho_NumericTools_LevelSet.hpp | 48 +++++++++++++------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp index 25068de037d7..813d6c386d70 100644 --- a/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp +++ b/packages/shylu/shylu_node/tacho/src/impl/Tacho_NumericTools_LevelSet.hpp @@ -2190,8 +2190,13 @@ class NumericToolsLevelSet : public NumericToolsBase { // const ordinal_type team_size_factor[2] = { 16, 16 }, vector_size_factor[2] = { 32, 32}; const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + // returned value from team Chol + colind_view d_rval("rval",1); + auto h_rval = Kokkos::create_mirror_view(host_memory_space(), d_rval); { typedef TeamFunctor_FactorizeChol functor_type; + functor_type functor(_info, _factorize_mode, _level_sids, _buf, d_rval.data()); + #if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> team_policy_factorize; @@ -2204,11 +2209,8 @@ class NumericToolsLevelSet : public NumericToolsBase { typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> team_policy_update; #endif - - int rval = 0; team_policy_factor policy_factor(1, 1, 1); team_policy_update policy_update(1, 1, 1); - functor_type functor(_info, _factorize_mode, _level_sids, _buf, &rval); // get max vector size const ordinal_type vmax = policy_factor.vector_length_max(); @@ -2259,10 +2261,14 @@ class NumericToolsLevelSet : public NumericToolsBase { Kokkos::fence(); time_device += tick.seconds(); tick.reset(); } - Kokkos::fence(); + Kokkos::deep_copy(h_rval, d_rval); + int rval = h_rval(0); if (rval != 0) { TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "POTRF (team) returns non-zero error code."); } + //if (_status != 0) { + // TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "POTRF (device) returns non-zero error code."); + //} Kokkos::parallel_for("update factor", policy_update, functor); if (verbose) { @@ -3926,8 +3932,13 @@ class NumericToolsLevelSet : public NumericToolsBase { const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; #endif const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + // returned value from team LDL + colind_view d_rval("rval",1); + auto h_rval = Kokkos::create_mirror_view(host_memory_space(), d_rval); { typedef TeamFunctor_FactorizeLDL functor_type; + functor_type functor(_info, _factorize_mode, _level_sids, _piv, _diag, _buf, d_rval.data()); + #if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) typedef Kokkos::TeamPolicy, exec_space, typename functor_type::DummyTag> team_policy_factorize; @@ -3940,12 +3951,10 @@ class NumericToolsLevelSet : public NumericToolsBase { typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> team_policy_update; #endif - int rval = 0; - team_policy_factor policy_factor(1, 1, 1); - team_policy_update policy_update(1, 1, 1); - functor_type functor(_info, _factorize_mode, _level_sids, _piv, _diag, _buf, &rval); // get max vector length + team_policy_factor policy_factor(1, 1, 1); + team_policy_update policy_update(1, 1, 1); const ordinal_type vmax = policy_factor.vector_length_max(); { for (ordinal_type lvl = (_team_serial_level_cut - 1); lvl >= 0; --lvl) { @@ -3994,10 +4003,14 @@ class NumericToolsLevelSet : public NumericToolsBase { Kokkos::fence(); time_device += tick.seconds(); tick.reset(); } - Kokkos::fence(); + Kokkos::deep_copy(h_rval, d_rval); + int rval = h_rval(0); if (rval != 0) { TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "SYTRF (team) returns non-zero error code."); } + //if (_status != 0) { + // TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "SYTRF (device) returns non-zero error code."); + //} Kokkos::parallel_for("update factor", policy_update, functor); if (verbose) { @@ -4271,6 +4284,10 @@ class NumericToolsLevelSet : public NumericToolsBase { const ordinal_type team_size_factor[2] = {64, 64}, vector_size_factor[2] = {8, 4}; #endif const ordinal_type team_size_update[2] = {16, 8}, vector_size_update[2] = {32, 32}; + + // returned value from team LU + colind_view d_rval("rval",1); + auto h_rval = Kokkos::create_mirror_view(host_memory_space(), d_rval); { typedef TeamFunctor_FactorizeLU functor_type; #if defined(TACHO_TEST_LEVELSET_TOOLS_KERNEL_OVERHEAD) @@ -4285,10 +4302,9 @@ class NumericToolsLevelSet : public NumericToolsBase { typedef Kokkos::TeamPolicy, exec_space, typename functor_type::UpdateTag> team_policy_update; #endif - int rval = 0; team_policy_factor policy_factor(1, 1, 1); team_policy_update policy_update(1, 1, 1); - functor_type functor(_info, _factorize_mode, _level_sids, _piv, _buf, &rval); + functor_type functor(_info, _factorize_mode, _level_sids, _piv, _buf, d_rval.data()); if (pivot_tol > 0.0) { functor.setDiagPertubationTol(pivot_tol); } @@ -4342,13 +4358,15 @@ class NumericToolsLevelSet : public NumericToolsBase { Kokkos::fence(); time_device += tick.seconds(); tick.reset(); } - Kokkos::fence(); + Kokkos::deep_copy(h_rval, d_rval); + int rval = h_rval(0); if (rval != 0) { TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "GETRF (team) returns non-zero error code."); } - if (_status != 0) { - TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "GETRF (device) returns non-zero error code."); - } + //if (_status != 0) { + // TACHO_TEST_FOR_EXCEPTION(rval, std::runtime_error, "GETRF (device) returns non-zero error code."); + //} + Kokkos::parallel_for("update factor", policy_update, functor); if (verbose) { Kokkos::fence(); time_update += tick.seconds();