From 573f36224cd3e5f24bcef3e84beaa66a9e71114a Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Oct 2023 23:49:36 +0200 Subject: [PATCH 1/2] workarounds for numba ssa bug --- sklearn_numba_dpex/common/topk.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sklearn_numba_dpex/common/topk.py b/sklearn_numba_dpex/common/topk.py index 46790ee..6d24fe7 100644 --- a/sklearn_numba_dpex/common/topk.py +++ b/sklearn_numba_dpex/common/topk.py @@ -792,6 +792,11 @@ def compute_radixes( # fmt: on # If `col_idx` is outside the bounds of the input, ignore this location. is_in_bounds = col_idx < n_cols + + mask_for_desired_value_ = uint_type(0) # workaround + item_lexicographically_mapped = uint_type(0) # workaround + desired_masked_value_ = uint_type(0) # workaround + radix_position_ = uint_type(0) # workaround if is_in_bounds: item = array_in_uint[row_idx, col_idx] @@ -1030,6 +1035,7 @@ def check_radix_histogram( # NB: `numba_dpex` seem to produce inefficient (branching) code for `break`, # use `if/else` instead desired_mask_value_search = True + count = np.int64(0) # workaround for _ in range(radix_size): if desired_mask_value_search: count = counts[row_idx, current_count_idx] From c1a15d7eb89e013763ebba58a529f5acddef0c36 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 16 Oct 2023 18:53:10 +0200 Subject: [PATCH 2/2] comments --- sklearn_numba_dpex/common/topk.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/sklearn_numba_dpex/common/topk.py b/sklearn_numba_dpex/common/topk.py index 6d24fe7..207ebe8 100644 --- a/sklearn_numba_dpex/common/topk.py +++ b/sklearn_numba_dpex/common/topk.py @@ -793,10 +793,11 @@ def compute_radixes( # If `col_idx` is outside the bounds of the input, ignore this location. is_in_bounds = col_idx < n_cols - mask_for_desired_value_ = uint_type(0) # workaround - item_lexicographically_mapped = uint_type(0) # workaround - desired_masked_value_ = uint_type(0) # workaround - radix_position_ = uint_type(0) # workaround + # Workaround for https://github.com/numba/numba/issues/9242 + mask_for_desired_value_ = uint_type(0) + item_lexicographically_mapped = uint_type(0) + desired_masked_value_ = uint_type(0) + radix_position_ = uint_type(0) if is_in_bounds: item = array_in_uint[row_idx, col_idx] @@ -1035,7 +1036,7 @@ def check_radix_histogram( # NB: `numba_dpex` seem to produce inefficient (branching) code for `break`, # use `if/else` instead desired_mask_value_search = True - count = np.int64(0) # workaround + count = np.int64(0) # Workaround for https://github.com/numba/numba/issues/9242 for _ in range(radix_size): if desired_mask_value_search: count = counts[row_idx, current_count_idx]