Skip to content

Commit

Permalink
[mosaic_gpu] Added a serialization pass
Browse files Browse the repository at this point in the history
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
  • Loading branch information
superbobry authored and Google-ML-Automation committed Jan 17, 2025
1 parent af66719 commit d34c40f
Show file tree
Hide file tree
Showing 10 changed files with 225 additions and 17 deletions.
7 changes: 6 additions & 1 deletion jax/_src/lib/mosaic_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
try:
try:
from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError:
except ImportError as e:
print(e)
from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error
except ImportError as e:
print("="*128)
print(e)
raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e
else:
_mosaic_gpu_ext.register_passes()
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def pallas_call_lowering(
outs = mosaic_core._mosaic_gpu_lowering_rule(
ctx.replace(avals_out=new_avals_out),
*args,
module=module.operation.get_asm(binary=True, enable_debug_info=True),
module=module,
out_types=lowering_result.out_structs,
input_output_aliases=input_output_aliases,
)
Expand Down
57 changes: 43 additions & 14 deletions jax/experimental/mosaic/gpu/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@

import jax
from jax._src.interpreters import mlir
from jax._src.lib import mosaic_gpu_dialect as dialect
from jaxlib.mlir import ir
from jaxlib.mlir import passmanager
from jaxlib.mlir.dialects import builtin
from jaxlib.mlir.dialects import func
from jaxlib.mlir.dialects import gpu
Expand All @@ -37,8 +39,6 @@
from jaxlib.mlir.dialects import nvvm
import numpy as np

from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401

if dialect is not None:
from . import dialect_lowering
from . import layout_inference
Expand All @@ -63,6 +63,9 @@
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")

# This tracks the latest Mosaic GPU IR version with a monthly delay.
FWD_COMPAT_IR_VERSION = 1

c = utils.c # This is too common to fully qualify.


Expand All @@ -88,7 +91,7 @@

@mosaic_gpu_p.def_abstract_eval
def _mosaic_gpu_abstract_eval(*_, module, out_types):
del module # Unused.
del module # Unused.
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]

# TODO(apaszke): Implement a proper system for managing kernel lifetimes
Expand All @@ -103,22 +106,28 @@ def _mosaic_gpu_lowering_rule(
input_output_aliases: tuple[tuple[int, int], ...] = (),
):
assert len(out_types) == len(ctx.avals_out)
kernel_id = hashlib.sha256(module).digest()
module = _run_serde_pass(
module,
serialize=True,
ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None,
)
module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
kernel_id = hashlib.sha256(module_asm).digest()
# Note that this is technically only a half measure. Someone might load a
# compiled module with a hash collision from disk. But that's so unlikely with
# SHA256 that it shouldn't be a problem.
if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None:
if kernel_text != module:
if kernel_text != module_asm:
raise RuntimeError("Hash collision!")
else:
KNOWN_KERNELS[kernel_id] = module
KNOWN_KERNELS[kernel_id] = module_asm
op = mlir.custom_call(
"mosaic_gpu",
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands=args,
operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
backend_config=kernel_id + module,
backend_config=kernel_id + module_asm,
operand_output_aliases=dict(input_output_aliases),
)
return op.results
Expand Down Expand Up @@ -425,6 +434,30 @@ def main(token_ptr, buffers):
return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr


def _run_serde_pass(
module: ir.Module, *, serialize: bool, ir_version: int | None = None
) -> ir.Module:
module = ir.Module.parse(
module.operation.get_asm(binary=True, enable_debug_info=True),
context=module.context,
)
pipeline = passmanager.PassManager.parse(
"builtin.module(mosaic_gpu-serde{serialize="
+ str(serialize).lower()
+ (f" target-version={ir_version}" if ir_version is not None else "")
+ "})",
module.context,
)
allow_unregistered_dialects = module.context.allow_unregistered_dialects
module.context.allow_unregistered_dialects = True
try:
pipeline.run(module.operation)
module.operation.verify()
finally:
module.context.allow_unregistered_dialects = allow_unregistered_dialects
return module


def _initialize_scratch(
launch_ctx : launch_context.LaunchContext,
scratch_arr: ir.Value,
Expand Down Expand Up @@ -472,6 +505,7 @@ def as_gpu_kernel(
cluster: tuple[int, int, int] = (1, 1, 1),
module_name: str = "unknown",
kernel_name: str | None = None,
ir_version: int | None = None,
thread_semantics: ThreadSemantics = ThreadSemantics.Lane,
):
if isinstance(in_shape, list):
Expand Down Expand Up @@ -504,13 +538,8 @@ def _check_args(*args):
f" {arg_treedef}, ({args=})"
)

module_asm = module.operation.get_asm(binary=True, enable_debug_info=True)
def bind(*args):
return mosaic_gpu_p.bind(
*args,
out_types=out_shape,
module=module_asm,
)
def bind(*args) -> Any:
return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape)

if prof_spec is not None:
@jax.jit
Expand Down
1 change: 1 addition & 0 deletions jaxlib/mlir/_mlir_libs/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ cc_library(
deps = [
"//jaxlib/mosaic:tpu_dialect_capi_objects",
"//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects",
"//jaxlib/mosaic/gpu:mlir_capi_objects",
"@llvm-project//mlir:CAPIArithObjects",
"@llvm-project//mlir:CAPIGPUObjects",
"@llvm-project//mlir:CAPIIRObjects",
Expand Down
4 changes: 4 additions & 0 deletions jaxlib/mosaic/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,16 @@ cc_library(
srcs = [
"launch_lowering.cc",
"passes.cc",
"serde.cc",
],
hdrs = [
"launch_lowering.h",
"passes.h",
"serde.h",
],
deps = [
"//jaxlib:pass_boilerplate",
"//jaxlib/mosaic:serde",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:DataLayoutInterfaces",
"@llvm-project//mlir:FuncDialect",
Expand Down Expand Up @@ -185,6 +188,7 @@ nanobind_extension(
"//conditions:default": [],
}),
deps = [
":mlir_capi",
"//jaxlib:kernel_nanobind_helpers",
"//jaxlib/cuda:cuda_vendor",
"@com_google_absl//absl/cleanup",
Expand Down
10 changes: 9 additions & 1 deletion jaxlib/mosaic/gpu/custom_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ limitations under the License.
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
#include "jaxlib/mosaic/gpu/launch_lowering.h"
#include "jaxlib/mosaic/gpu/passes.h"
#include "jaxlib/mosaic/gpu/serde.h"
#include "jaxlib/mosaic/gpu/target.h"
#include "xla/service/custom_call_status.h"
#include "xla/service/custom_call_target_registry.h"
Expand Down Expand Up @@ -455,12 +456,19 @@ absl::StatusOr<std::pair<std::string, std::string>> GetHostAndInitFuncNames(

absl::StatusOr<CompiledKernel> CompileAndInit(const char* module) {
mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED);
context.allowUnregisteredDialects(true);
InitContext(&context);
mlir::ParserConfig parse_config(&context);
auto module_op =
mlir::parseSourceString<mlir::ModuleOp>(module, parse_config);
if (!module_op) {
return absl::InternalError("Failed to parse module");
return absl::InternalError("Failed to parse Mosaic GPU module");
}
auto manager = mlir::PassManager::on<mlir::ModuleOp>(module_op->getContext());
manager.addPass(mosaic::gpu::createSerdePass(
mosaic::gpu::SerdePassOptions{.serialize = false}));
if (manager.run(module_op.get()).failed()) {
return absl::InternalError("Failed to deserialize Mosaic GPU module");
}
auto maybe_engine = Compile(*module_op);
if (!maybe_engine.ok()) {
Expand Down
2 changes: 2 additions & 0 deletions jaxlib/mosaic/gpu/integrations/c/passes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ limitations under the License.
#include "jaxlib/mosaic/gpu/integrations/c/passes.h"

#include "jaxlib/mosaic/gpu/launch_lowering.h"
#include "jaxlib/mosaic/gpu/serde.h"

extern "C" {

void mlirMosaicGpuRegisterPasses() {
mosaic::gpu::registerGpuLaunchLoweringPass();
mosaic::gpu::registerSerdePass();
}

}
5 changes: 5 additions & 0 deletions jaxlib/mosaic/gpu/mosaic_gpu_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License.

#include <cstddef>
#include <cstdint>
#include <memory>
#include <new>
#include <stdexcept>
#include <string>
Expand All @@ -28,6 +29,7 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_nanobind_helpers.h"
#include "jaxlib/mosaic/gpu/integrations/c/passes.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

Expand Down Expand Up @@ -194,6 +196,9 @@ void callback_complete(CUcontext context, uint32_t streamId,
}

NB_MODULE(_mosaic_gpu_ext, m) {
m.def("register_passes", []() {
mlirMosaicGpuRegisterPasses();
});
m.def("registrations", []() {
return nb::make_tuple(
nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)),
Expand Down
68 changes: 68 additions & 0 deletions jaxlib/mosaic/gpu/serde.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/mosaic/gpu/serde.h"

#include "llvm/include/llvm/ADT/StringMap.h"
#include "llvm/include/llvm/ADT/StringRef.h"
#include "mlir/include/mlir/IR/BuiltinOps.h"
#include "mlir/include/mlir/Support/LLVM.h"
#include "jaxlib/mosaic/serde.h"

namespace mosaic::gpu {

namespace {

constexpr llvm::StringRef kMangledDialect = "stable_mosaic_gpu.";
constexpr llvm::StringRef kVersionAttrName = "stable_mosaic_gpu.version";
// When this is bumped, we should file a TODO to update the forward-compatible
// version in Mosaic GPU lowering in a month!
constexpr int kVersion = 1;

using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;

const llvm::StringMap<SerdeRuleType>& upgrade_rules() {
static auto rules = new llvm::StringMap<SerdeRuleType>{};
return *rules;
}

const llvm::StringMap<SerdeRuleType>& downgrade_rules() {
static auto rules = new llvm::StringMap<SerdeRuleType>{};
return *rules;
}

} // namespace

void SerdePass::runOnOperation() {
mlir::ModuleOp module = getOperation();
if (!serialize.hasValue()) {
module.emitError("serialize option must be specified");
return signalPassFailure();
}
int serialize_version = -1;
if (serialize) {
serialize_version = target_version.hasValue() ? target_version : kVersion;
}
if (mlir::failed(jaxlib::mosaic::RunSerde(
module, upgrade_rules(), downgrade_rules(), serialize,
{.dialect_prefix = kMangledDialect,
.highest_version = kVersion,
.version_attr_name = kVersionAttrName,
.serialize_version = serialize_version}))) {
signalPassFailure();
}
}

} // namespace mosaic::gpu
Loading

0 comments on commit d34c40f

Please sign in to comment.