Skip to content

Commit

Permalink
MueLu CoalesceDropFactory_kokkos: Correctly handle "filtered matrix: …
Browse files Browse the repository at this point in the history
…Dirichlet threshold"

Signed-off-by: Christian Glusa <[email protected]>
  • Loading branch information
cgcgcg committed Jan 13, 2025
1 parent 499d67c commit 56c26cf
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

const bool useRootStencil = pL.get<bool>("filtered matrix: use root stencil");
const bool useSpreadLumping = pL.get<bool>("filtered matrix: use spread lumping");

const Scalar filteringDirichletThreshold = as<SC>(pL.get<double>("filtered matrix: Dirichlet threshold"));
TEUCHOS_ASSERT(!useRootStencil);
TEUCHOS_ASSERT(!useSpreadLumping);

Expand Down Expand Up @@ -692,18 +694,18 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

if (lumping) {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, true>(lclA, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, true>(lclA, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, true>(lclA, results, lclFilteredA);
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, true>(lclA, results, lclFilteredA, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_noreuse", range, fillFunctor);
}
} else {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, false>(lclA, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::PointwiseFillReuseFunctor<local_matrix_type, local_graph_type, false>(lclA, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, false>(lclA, results, lclFilteredA);
auto fillFunctor = MatrixConstruction::PointwiseFillNoReuseFunctor<local_matrix_type, false>(lclA, results, lclFilteredA, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_noreuse", range, fillFunctor);
}
}
Expand Down Expand Up @@ -854,6 +856,9 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

const bool useRootStencil = pL.get<bool>("filtered matrix: use root stencil");
const bool useSpreadLumping = pL.get<bool>("filtered matrix: use spread lumping");

const Scalar filteringDirichletThreshold = as<SC>(pL.get<double>("filtered matrix: Dirichlet threshold"));

TEUCHOS_ASSERT(!useRootStencil);
TEUCHOS_ASSERT(!useSpreadLumping);

Expand Down Expand Up @@ -1095,18 +1100,18 @@ std::tuple<GlobalOrdinal, typename MueLu::LWGraph_kokkos<LocalOrdinal, GlobalOrd

if (lumping) {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, true>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, true>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, false>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, true, false>(lclA, blkPartSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_lumped_noreuse", range, fillFunctor);
}
} else {
if (reuseGraph) {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, true>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, true>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_reuse", range, fillFunctor);
} else {
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, false>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph);
auto fillFunctor = MatrixConstruction::VectorFillFunctor<local_matrix_type, false, false>(lclA, blkSize, colTranslation, results, lclFilteredA, lclGraph, filteringDirichletThreshold);
Kokkos::parallel_for("MueLu::CoalesceDrop::Fill_unlumped_noreuse", range, fillFunctor);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,17 @@ class PointwiseFillReuseFunctor {
results_view results;
local_matrix_type filteredA;
local_graph_type graph;
scalar_type dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

public:
PointwiseFillReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_)
PointwiseFillReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_, scalar_type dirichletThreshold_)
: A(A_)
, results(results_)
, filteredA(filteredA_)
, graph(graph_) {}
, graph(graph_)
, dirichletThreshold(dirichletThreshold_) {}

KOKKOS_INLINE_FUNCTION
void operator()(const local_ordinal_type rlid) const {
Expand Down Expand Up @@ -300,6 +303,8 @@ class PointwiseFillReuseFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}
};
Expand All @@ -323,13 +328,16 @@ class PointwiseFillNoReuseFunctor {
local_matrix_type A;
results_view results;
local_matrix_type filteredA;
scalar_type dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

public:
PointwiseFillNoReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_)
PointwiseFillNoReuseFunctor(local_matrix_type& A_, results_view& results_, local_matrix_type& filteredA_, scalar_type dirichletThreshold_)
: A(A_)
, results(results_)
, filteredA(filteredA_) {}
, filteredA(filteredA_)
, dirichletThreshold(dirichletThreshold_) {}

KOKKOS_INLINE_FUNCTION
void operator()(const local_ordinal_type rlid) const {
Expand All @@ -356,6 +364,8 @@ class PointwiseFillNoReuseFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}
};
Expand Down Expand Up @@ -727,16 +737,19 @@ class VectorFillFunctor {
results_view results;
local_matrix_type filteredA;
local_graph_type graph;
scalar_type dirichletThreshold;
const scalar_type zero = ATS::zero();
const scalar_type one = ATS::one();

public:
VectorFillFunctor(local_matrix_type& A_, local_ordinal_type blockSize_, block_indices_view_type ghosted_point_to_block_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_)
VectorFillFunctor(local_matrix_type& A_, local_ordinal_type blockSize_, block_indices_view_type ghosted_point_to_block_, results_view& results_, local_matrix_type& filteredA_, local_graph_type& graph_, scalar_type dirichletThreshold_)
: A(A_)
, blockSize(blockSize_)
, ghosted_point_to_block(ghosted_point_to_block_)
, results(results_)
, filteredA(filteredA_)
, graph(graph_) {}
, graph(graph_)
, dirichletThreshold(dirichletThreshold_) {}

KOKKOS_INLINE_FUNCTION
void operator()(const local_ordinal_type brlid) const {
Expand Down Expand Up @@ -773,6 +786,8 @@ class VectorFillFunctor {
}
if constexpr (lumping) {
rowFilteredA.value(diagOffset) += diagCorrection;
if ((dirichletThreshold >= 0.0) && (ATS::real(rowFilteredA.value(diagOffset)) <= dirichletThreshold))
rowFilteredA.value(diagOffset) = one;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,11 @@ void ParameterListInterpreter<Scalar, LocalOrdinal, GlobalOrdinal, Node>::
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use lumping", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: reuse graph", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: reuse eigenvalue", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use root stencil", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: Dirichlet threshold", double, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: use spread lumping", bool, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: spread lumping diag dom growth factor", double, dropParams);
MUELU_TEST_AND_SET_PARAM_2LIST(paramList, defaultList, "filtered matrix: spread lumping diag dom cap", double, dropParams);
}

if (!amalgFact.is_null())
Expand Down Expand Up @@ -1313,7 +1318,6 @@ void ParameterListInterpreter<Scalar, LocalOrdinal, GlobalOrdinal, Node>::

// Matrix analysis
if (MUELU_TEST_PARAM_2LIST(paramList, defaultList, "matrix: compute analysis", bool, true)) {

RCP<Factory> matrixAnalysisFact = rcp(new MatrixAnalysisFactory());

if (!RAP.is_null())
Expand Down

0 comments on commit 56c26cf

Please sign in to comment.