diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index 283ab6a64b3..3dfc324af4c 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -63,14 +63,13 @@ void reduceProductTo(int64_t& z, int64_t& y, int64_t& x, const int64_t max) { } } -std::unique_ptr < ReductionParams > - 2dInnerReductionHeuristic( - const int64_t total_reduction_numel, - const int64_t total_iteration_numel, - const int64_t inner_most_dimension_numel, - const int64_t n_tensor_inputs, - const int64_t max_input_dtype_size, - const size_t vectorize_factor) { +std::unique_ptr inner2dReductionHeuristic( + const int64_t total_reduction_numel, + const int64_t total_iteration_numel, + const int64_t inner_most_dimension_numel, + const int64_t n_tensor_inputs, + const int64_t max_input_dtype_size, + const size_t vectorize_factor) { // Set some targets for parallelization const int64_t n_elems = total_reduction_numel * total_iteration_numel; @@ -493,7 +492,7 @@ std::unique_ptr < ReductionParams > << (rparams->unroll_factor_inner_reduction > 1) << ", " << rparams->cross_grid_inner_reduction << std::endl; } - return innerReductionHeuristic( + return inner2dReductionHeuristic( total_reduction_numel, total_iteration_numel, total_reduction_numel, @@ -506,14 +505,13 @@ std::unique_ptr < ReductionParams > return rparams; } -std::unique_ptr < ReductionParams > - 3dInnerReductionHeuristic( - const int64_t total_reduction_numel, - const int64_t total_iteration_numel, - const int64_t inner_most_dimension_numel, - const int64_t n_tensor_inputs, - const int64_t max_input_dtype_size, - const size_t vectorize_factor) { +std::unique_ptr inner3dReductionHeuristic( + const int64_t total_reduction_numel, + const int64_t total_iteration_numel, + const int64_t inner_most_dimension_numel, + const int64_t n_tensor_inputs, + const int64_t max_input_dtype_size, + const size_t vectorize_factor) { // Set some targets for parallelization const int64_t n_elems = total_reduction_numel * total_iteration_numel; @@ -936,7 +934,7 @@ std::unique_ptr < ReductionParams > << (rparams->unroll_factor_inner_reduction > 1) << ", " << rparams->cross_grid_inner_reduction << std::endl; } - return innerReductionHeuristic( + return inner2dReductionHeuristic( total_reduction_numel, total_iteration_numel, total_reduction_numel, @@ -1497,7 +1495,7 @@ std::unique_ptr reductionHeuristic( const size_t vectorize_factor) { if (fastest_dim_reduction) { if (total_reduction_numel == inner_most_dimension_numel) { - return 2dInnerReductionHeuristic( + return inner2dReductionHeuristic( total_reduction_numel, total_iteration_numel, inner_most_dimension_numel, @@ -1505,7 +1503,7 @@ std::unique_ptr reductionHeuristic( (int64_t)max_input_dtype_size, vectorize_factor); } else { - return 3dInnerReductionHeuristic( + return inner3dReductionHeuristic( total_reduction_numel, total_iteration_numel, inner_most_dimension_numel,