Skip to content

Commit

Permalink
[xla:cpu] Add support for pthreadpool_parallelize_1d
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 708458932
  • Loading branch information
ezhulenev authored and Google-ML-Automation committed Dec 21, 2024
1 parent e8e098f commit 6b8b5a1
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 9 deletions.
8 changes: 4 additions & 4 deletions xla/backends/cpu/runtime/xnnpack/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ xla_cc_test(
deps = [
":parallel_loop_runner",
"//xla/tsl/concurrency:async_value",
"//xla/tsl/platform:env",
"//xla/tsl/platform:test",
"//xla/tsl/platform:test_benchmark",
"//xla/tsl/platform:test_main",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/cleanup",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:env",
"@tsl//tsl/platform:test",
"@tsl//tsl/platform:test_benchmark",
"@tsl//tsl/platform:test_main",
],
)

Expand Down
29 changes: 29 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,35 @@ static Task3DTile2DIndex Delinearize(size_t task_index, size_t range_i,
// (2) If done event is not available, we have to overwrite it with a new one
// that will be set to concrete state after the task is executed.

void ParallelLoopRunner::Parallelize(size_t range, Task1D task) {
DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
DCHECK_GT(range, 0) << "Expected at least one task";

// Fast path for the degenerate parallel loop with single task.
if (ABSL_PREDICT_TRUE(range == 1)) {
// Execute task in the caller thread if done event is already available.
if (ABSL_PREDICT_TRUE(done_event_.IsConcrete())) {
task(0);
return;
}

// Schedule task when done event becomes available.
ScheduleOne([task = std::move(task)] { task(0); });
return;
}

// Schedule `parallel_config.num_parallel_tasks` into the underlying thread
// pool when done event becomes available.
auto parallel_config = ComputeParallelTaskConfig(range);
auto parallel_task = [parallel_config,
task = std::move(task)](size_t parallel_task_index) {
auto [begin, end] = parallel_config.ParallelTaskRange(parallel_task_index);
for (size_t i = begin; i < end; ++i) task(i);
};

ScheduleAll(parallel_config.num_parallel_tasks, std::move(parallel_task));
}

void ParallelLoopRunner::Parallelize(size_t range, size_t tile,
Task1DTile1D task) {
DCHECK(done_event_) << "Parallel loop runner is in moved-from state";
Expand Down
8 changes: 8 additions & 0 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ class ParallelLoopRunner {
static tsl::AsyncValueRef<tsl::Chain> TakeDoneEvent(
ParallelLoopRunner&& runner);

// Task invoked by the 1D `Parallelize` overload; `offset` is the loop index
// the task is being run for.
using Task1D = std::function<void(size_t offset)>;

using Task1DTile1D = std::function<void(size_t offset, size_t extent)>;

using Task2DTile1D =
Expand All @@ -61,6 +63,12 @@ class ParallelLoopRunner {
std::function<void(size_t offset_i, size_t offset_j, size_t offset_k,
size_t extent_j, size_t extent_k)>;

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range; i++)
// task(i);
//
// `range` is expected to be greater than zero (the implementation DCHECKs
// it). `task` is taken by value.
void Parallelize(size_t range, Task1D task);

// This function implements a parallel version of the following loop:
//
// for (size_t i = 0; i < range; i += tile)
Expand Down
32 changes: 28 additions & 4 deletions xla/backends/cpu/runtime/xnnpack/parallel_loop_runner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,41 @@ limitations under the License.
#include "absl/cleanup/cleanup.h"
#include "absl/types/span.h"
#include "xla/tsl/concurrency/async_value_ref.h"
#include "tsl/platform/env.h"
#include "tsl/platform/test.h"
#include "tsl/platform/test_benchmark.h"
#include "tsl/platform/threadpool.h"
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/test.h"
#include "xla/tsl/platform/test_benchmark.h"
#include "xla/tsl/platform/threadpool.h"

#define EIGEN_USE_THREADS
#include "unsupported/Eigen/CXX11/Tensor"

namespace xla::cpu {
namespace {

// Runs several back-to-back 1D parallel loops over the same buffer and checks
// that every element was incremented exactly once per loop.
TEST(ParallelLoopRunnerTest, Parallelize1D) {
  tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
  Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
                                 threads.NumThreads());
  ParallelLoopRunner runner(&device);

  constexpr int32_t d0 = 128;
  constexpr int32_t kNumLoops = 5;

  // Zero-initialized counters, one per loop index; released by the cleanup
  // object when the test body exits.
  auto* data = new int32_t[d0]();
  auto cleanup = absl::Cleanup([&]() { delete[] data; });

  auto increment = [&](size_t offset) { data[offset] += 1; };

  for (int32_t loop = 0; loop < kNumLoops; ++loop) {
    runner.Parallelize(d0, increment);
  }

  // Wait for all scheduled loops to finish before inspecting the counters.
  tsl::BlockUntilReady(ParallelLoopRunner::TakeDoneEvent(std::move(runner)));

  auto counters = absl::MakeSpan(data, d0);
  ASSERT_TRUE(absl::c_all_of(
      counters, [](int32_t value) { return value == kNumLoops; }));
}

TEST(ParallelLoopRunnerTest, Parallelize1DTile1D) {
tsl::thread::ThreadPool threads(tsl::Env::Default(), "test", 8);
Eigen::ThreadPoolDevice device(threads.AsEigenThreadPool(),
Expand Down
18 changes: 17 additions & 1 deletion xla/backends/cpu/runtime/xnnpack/xnn_threadpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,22 @@ static size_t GetThreadsCount(pthreadpool_t threadpool) { // NOLINT
return Cast(threadpool)->runner()->num_threads();
}

// Runs `function(context, i)` for every `i` in `[0, range)`: serially in the
// caller thread when no thread pool is given, otherwise through the parallel
// loop runner that backs `threadpool`. `flags` is not used.
static void Parallelize1D(  // NOLINT
    pthreadpool_t threadpool, pthreadpool_task_1d_t function, void* context,
    size_t range, uint32_t flags) {
  // No thread pool: fall back to a plain sequential loop.
  if (ABSL_PREDICT_FALSE(threadpool == nullptr)) {
    for (size_t offset = 0; offset < range; ++offset) {
      (*function)(context, offset);
    }
    return;
  }

  // Adapt the C-style callback and its context to the runner's task type and
  // hand it over directly.
  Cast(threadpool)->runner()->Parallelize(
      range,
      [function, context](size_t offset) { (*function)(context, offset); });
}

static void Parallelize1DTile1D( // NOLINT
pthreadpool_t threadpool, pthreadpool_task_1d_tile_1d_t function,
void* context, size_t range, size_t tile, uint32_t flags) {
Expand Down Expand Up @@ -243,7 +259,7 @@ extern "C" void pthreadpool_parallelize_1d(pthreadpool_t threadpool,
pthreadpool_task_1d_t function,
void* context, size_t range,
uint32_t flags) {
// NOTE(review): this is a rendered diff hunk. The LOG(FATAL) line below looks
// like the pre-change body and the call after it like its replacement;
// confirm against the actual file.
LOG(FATAL) << "Not implemented";
// Forwards all arguments unchanged to the xla::cpu implementation.
xla::cpu::Parallelize1D(threadpool, function, context, range, flags);
}

extern "C" void pthreadpool_parallelize_1d_with_thread(
Expand Down

0 comments on commit 6b8b5a1

Please sign in to comment.