Skip to content

Commit

Permalink
runtimes/cuda,zml: setup xla_gpu_cuda_data_dir via environment variable
Browse files Browse the repository at this point in the history
Passing `xla_gpu_cuda_data_dir` through the `XLA_FLAGS` environment variable is the intended mechanism [1]. Thankfully, the other PJRT plugins don't read this variable, so setting it globally is safe.

1. openxla/xla#21428
  • Loading branch information
steeve committed Jan 15, 2025
1 parent 5309a25 commit c540ea7
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 16 deletions.
2 changes: 2 additions & 0 deletions runtimes/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ zig_library(
"//runtimes:cuda.enabled": [
":libpjrt_cuda",
"//async",
"//stdx",
"@rules_zig//zig/runfiles",
],
"//conditions:default": [":empty"],
}),
Expand Down
30 changes: 28 additions & 2 deletions runtimes/cuda/cuda.zig
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
const builtin = @import("builtin");
const std = @import("std");

const asynk = @import("async");
const pjrt = @import("pjrt");
const bazel_builtin = @import("bazel_builtin");
const builtin = @import("builtin");
const c = @import("c");
const pjrt = @import("pjrt");
const runfiles = @import("runfiles");
const stdx = @import("stdx");

pub fn isEnabled() bool {
return @hasDecl(c, "ZML_RUNTIME_CUDA");
Expand All @@ -12,6 +17,23 @@ fn hasNvidiaDevice() bool {
return true;
}

/// Prepends `--xla_gpu_cuda_data_dir=<runfiles path>` to the XLA_FLAGS
/// environment variable so XLA can locate the CUDA toolkit sandbox.
/// Must be called _before_ the PJRT CUDA plugin is loaded, because XLA
/// reads XLA_FLAGS at plugin initialization time.
/// See https://github.com/openxla/xla/issues/21428
fn setupXlaGpuCudaDirFlag() !void {
    // All allocations are transient: the environment copy made by setenv
    // outlives them, so a throwaway arena is sufficient.
    var arena = std.heap.ArenaAllocator.init(std.heap.c_allocator);
    defer arena.deinit();
    const allocator = arena.allocator();

    var r_ = try runfiles.Runfiles.create(.{ .allocator = allocator }) orelse {
        stdx.debug.panic("Unable to find CUDA directory", .{});
    };

    const source_repo = bazel_builtin.current_repository;
    const r = r_.withSourceRepo(source_repo);
    // Fail loudly with context instead of a bare `.?` unwrap panic if the
    // runfile lookup comes back empty.
    const cuda_data_dir = (try r.rlocationAlloc(allocator, "libpjrt_cuda/sandbox")) orelse {
        stdx.debug.panic("Unable to find CUDA directory", .{});
    };

    // Preserve any flags the user already set. Only a missing variable maps
    // to "no flags"; real failures (e.g. OutOfMemory) must propagate.
    const xla_flags = std.process.getEnvVarOwned(allocator, "XLA_FLAGS") catch |err| switch (err) {
        error.EnvironmentVariableNotFound => "",
        else => return err,
    };
    const new_xla_flagsZ = try std.fmt.allocPrintZ(allocator, "--xla_gpu_cuda_data_dir={s} {s}", .{ cuda_data_dir, xla_flags });

    // setenv copies the string, so the arena-owned buffer may be freed after.
    _ = c.setenv("XLA_FLAGS", new_xla_flagsZ, 1);
}

pub fn load() !*const pjrt.Api {
if (comptime !isEnabled()) {
return error.Unavailable;
Expand All @@ -23,5 +45,9 @@ pub fn load() !*const pjrt.Api {
return error.Unavailable;
}

// CUDA path has to be set _before_ loading the PJRT plugin.
// See https://github.com/openxla/xla/issues/21428
try setupXlaGpuCudaDirFlag();

return try pjrt.Api.loadFrom("libpjrt_cuda.so");
}
17 changes: 3 additions & 14 deletions zml/module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ const std = @import("std");

const asynk = @import("async");
const dialect = @import("mlir/dialects");
const runfiles = @import("runfiles");
const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto");

Expand Down Expand Up @@ -902,28 +901,18 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
}
}
switch (platform.target) {
.cuda => cuda_dir: {
// NVIDIA recommends to disable Triton GEMM on JAX:
.cuda => {
// NVIDIA recommends these settings
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
// setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
log.warn("Bazel runfile not found !", .{});
break :cuda_dir;
};
defer r_.deinit(arena);
const source_repo = @import("bazel_builtin").current_repository;
const r = r_.withSourceRepo(source_repo);
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
},
.rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
Expand Down

0 comments on commit c540ea7

Please sign in to comment.