From d34c40f6b6371bf1a41e08e047d7138484931421 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Fri, 17 Jan 2025 03:12:12 -0800 Subject: [PATCH] [mosaic_gpu] Added a serialization pass The pass adds versioning to the Mosaic GPU IR in the lowered custom calls and can apply forward/backward migration rules. Currently, no rules are necessary since we are at version 1. PiperOrigin-RevId: 716596848 --- jax/_src/lib/mosaic_gpu.py | 7 +- .../mosaic_gpu/pallas_call_registration.py | 2 +- jax/experimental/mosaic/gpu/core.py | 57 +++++++++--- jaxlib/mlir/_mlir_libs/BUILD.bazel | 1 + jaxlib/mosaic/gpu/BUILD | 4 + jaxlib/mosaic/gpu/custom_call.cc | 10 ++- jaxlib/mosaic/gpu/integrations/c/passes.cc | 2 + jaxlib/mosaic/gpu/mosaic_gpu_ext.cc | 5 ++ jaxlib/mosaic/gpu/serde.cc | 68 +++++++++++++++ jaxlib/mosaic/gpu/serde.h | 86 +++++++++++++++++++ 10 files changed, 225 insertions(+), 17 deletions(-) create mode 100644 jaxlib/mosaic/gpu/serde.cc create mode 100644 jaxlib/mosaic/gpu/serde.h diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py index 494112093029..233b51db4d6f 100644 --- a/jax/_src/lib/mosaic_gpu.py +++ b/jax/_src/lib/mosaic_gpu.py @@ -17,7 +17,12 @@ try: try: from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error - except ImportError: + except ImportError as e: + print(e) from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error except ImportError as e: + print("="*128) + print(e) raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e +else: + _mosaic_gpu_ext.register_passes() diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 18d8baf6e95e..19c8b0ad5b30 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -73,7 +73,7 @@ def pallas_call_lowering( outs = mosaic_core._mosaic_gpu_lowering_rule( ctx.replace(avals_out=new_avals_out), 
*args, - module=module.operation.get_asm(binary=True, enable_debug_info=True), + module=module, out_types=lowering_result.out_structs, input_output_aliases=input_output_aliases, ) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py index ab70c19b9e96..e50a4672b938 100644 --- a/jax/experimental/mosaic/gpu/core.py +++ b/jax/experimental/mosaic/gpu/core.py @@ -28,7 +28,9 @@ import jax from jax._src.interpreters import mlir +from jax._src.lib import mosaic_gpu_dialect as dialect from jaxlib.mlir import ir +from jaxlib.mlir import passmanager from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import func from jaxlib.mlir.dialects import gpu @@ -37,8 +39,6 @@ from jaxlib.mlir.dialects import nvvm import numpy as np -from jax._src.lib import mosaic_gpu_dialect as dialect # noqa: F401 - if dialect is not None: from . import dialect_lowering from . import layout_inference @@ -63,6 +63,9 @@ PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") +# This tracks the latest Mosaic GPU IR version with a monthly delay. +FWD_COMPAT_IR_VERSION = 1 + c = utils.c # This is too common to fully qualify. @@ -88,7 +91,7 @@ @mosaic_gpu_p.def_abstract_eval def _mosaic_gpu_abstract_eval(*_, module, out_types): - del module # Unused. + del module # Unused. return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] # TODO(apaszke): Implement a proper system for managing kernel lifetimes @@ -103,22 +106,28 @@ def _mosaic_gpu_lowering_rule( input_output_aliases: tuple[tuple[int, int], ...] 
= (), ): assert len(out_types) == len(ctx.avals_out) - kernel_id = hashlib.sha256(module).digest() + module = _run_serde_pass( + module, + serialize=True, + ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None, + ) + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + kernel_id = hashlib.sha256(module_asm).digest() # Note that this is technically only a half measure. Someone might load a # compiled module with a hash collision from disk. But that's so unlikely with # SHA256 that it shouldn't be a problem. if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: - if kernel_text != module: + if kernel_text != module_asm: raise RuntimeError("Hash collision!") else: - KNOWN_KERNELS[kernel_id] = module + KNOWN_KERNELS[kernel_id] = module_asm op = mlir.custom_call( "mosaic_gpu", result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], operands=args, operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], - backend_config=kernel_id + module, + backend_config=kernel_id + module_asm, operand_output_aliases=dict(input_output_aliases), ) return op.results @@ -425,6 +434,30 @@ def main(token_ptr, buffers): return module, out_shape, unwrap_output_tuple, launch_ctx, scratch_arr +def _run_serde_pass( + module: ir.Module, *, serialize: bool, ir_version: int | None = None +) -> ir.Module: + module = ir.Module.parse( + module.operation.get_asm(binary=True, enable_debug_info=True), + context=module.context, + ) + pipeline = passmanager.PassManager.parse( + "builtin.module(mosaic_gpu-serde{serialize=" + + str(serialize).lower() + + (f" target-version={ir_version}" if ir_version is not None else "") + + "})", + module.context, + ) + allow_unregistered_dialects = module.context.allow_unregistered_dialects + module.context.allow_unregistered_dialects = True + try: + pipeline.run(module.operation) + module.operation.verify() + finally: + 
module.context.allow_unregistered_dialects = allow_unregistered_dialects + return module + + def _initialize_scratch( launch_ctx : launch_context.LaunchContext, scratch_arr: ir.Value, @@ -472,6 +505,7 @@ def as_gpu_kernel( cluster: tuple[int, int, int] = (1, 1, 1), module_name: str = "unknown", kernel_name: str | None = None, + ir_version: int | None = None, thread_semantics: ThreadSemantics = ThreadSemantics.Lane, ): if isinstance(in_shape, list): @@ -504,13 +538,8 @@ def _check_args(*args): f" {arg_treedef}, ({args=})" ) - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - def bind(*args): - return mosaic_gpu_p.bind( - *args, - out_types=out_shape, - module=module_asm, - ) + def bind(*args) -> Any: + return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape) if prof_spec is not None: @jax.jit diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 812d73e3343e..b82cf31d7a72 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -382,6 +382,7 @@ cc_library( deps = [ "//jaxlib/mosaic:tpu_dialect_capi_objects", "//jaxlib/mosaic/dialect/gpu:gpu_dialect_capi_objects", + "//jaxlib/mosaic/gpu:mlir_capi_objects", "@llvm-project//mlir:CAPIArithObjects", "@llvm-project//mlir:CAPIGPUObjects", "@llvm-project//mlir:CAPIIRObjects", diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index f72776ca3e7e..444586f19ea7 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -44,13 +44,16 @@ cc_library( srcs = [ "launch_lowering.cc", "passes.cc", + "serde.cc", ], hdrs = [ "launch_lowering.h", "passes.h", + "serde.h", ], deps = [ "//jaxlib:pass_boilerplate", + "//jaxlib/mosaic:serde", "@llvm-project//llvm:Support", "@llvm-project//mlir:DataLayoutInterfaces", "@llvm-project//mlir:FuncDialect", @@ -185,6 +188,7 @@ nanobind_extension( "//conditions:default": [], }), deps = [ + ":mlir_capi", "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", 
"@com_google_absl//absl/cleanup", diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc index 383f68ddd3d1..34685368663b 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -86,6 +86,7 @@ limitations under the License. #include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" #include "jaxlib/mosaic/gpu/passes.h" +#include "jaxlib/mosaic/gpu/serde.h" #include "jaxlib/mosaic/gpu/target.h" #include "xla/service/custom_call_status.h" #include "xla/service/custom_call_target_registry.h" @@ -455,12 +456,19 @@ absl::StatusOr> GetHostAndInitFuncNames( absl::StatusOr CompileAndInit(const char* module) { mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + context.allowUnregisteredDialects(true); InitContext(&context); mlir::ParserConfig parse_config(&context); auto module_op = mlir::parseSourceString(module, parse_config); if (!module_op) { - return absl::InternalError("Failed to parse module"); + return absl::InternalError("Failed to parse Mosaic GPU module"); + } + auto manager = mlir::PassManager::on(module_op->getContext()); + manager.addPass(mosaic::gpu::createSerdePass( + mosaic::gpu::SerdePassOptions{.serialize = false})); + if (manager.run(module_op.get()).failed()) { + return absl::InternalError("Failed to deserialize Mosaic GPU module"); } auto maybe_engine = Compile(*module_op); if (!maybe_engine.ok()) { diff --git a/jaxlib/mosaic/gpu/integrations/c/passes.cc b/jaxlib/mosaic/gpu/integrations/c/passes.cc index 065d11fd33e1..524b4443cd8e 100644 --- a/jaxlib/mosaic/gpu/integrations/c/passes.cc +++ b/jaxlib/mosaic/gpu/integrations/c/passes.cc @@ -16,11 +16,13 @@ limitations under the License. 
#include "jaxlib/mosaic/gpu/integrations/c/passes.h" #include "jaxlib/mosaic/gpu/launch_lowering.h" +#include "jaxlib/mosaic/gpu/serde.h" extern "C" { void mlirMosaicGpuRegisterPasses() { mosaic::gpu::registerGpuLaunchLoweringPass(); + mosaic::gpu::registerSerdePass(); } } diff --git a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc index 8889a4983765..e5c85ac5801f 100644 --- a/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc +++ b/jaxlib/mosaic/gpu/mosaic_gpu_ext.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -28,6 +29,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" +#include "jaxlib/mosaic/gpu/integrations/c/passes.h" #include "xla/ffi/api/c_api.h" #include "xla/ffi/api/ffi.h" @@ -194,6 +196,9 @@ void callback_complete(CUcontext context, uint32_t streamId, } NB_MODULE(_mosaic_gpu_ext, m) { + m.def("register_passes", []() { + mlirMosaicGpuRegisterPasses(); + }); m.def("registrations", []() { return nb::make_tuple( nb::make_tuple("mgpu_event_record", EncapsulateFunction(EventRecord)), diff --git a/jaxlib/mosaic/gpu/serde.cc b/jaxlib/mosaic/gpu/serde.cc new file mode 100644 index 000000000000..f4cf846acc11 --- /dev/null +++ b/jaxlib/mosaic/gpu/serde.cc @@ -0,0 +1,68 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/mosaic/gpu/serde.h" + +#include "llvm/include/llvm/ADT/StringMap.h" +#include "llvm/include/llvm/ADT/StringRef.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/Support/LLVM.h" +#include "jaxlib/mosaic/serde.h" + +namespace mosaic::gpu { + +namespace { + +constexpr llvm::StringRef kMangledDialect = "stable_mosaic_gpu."; +constexpr llvm::StringRef kVersionAttrName = "stable_mosaic_gpu.version"; +// When this is bumped, we should file a TODO to update the forward-compatible +// version in Mosaic GPU lowering in a month! +constexpr int kVersion = 1; + +using SerdeRuleType = jaxlib::mosaic::SerdeRuleType; + +const llvm::StringMap& upgrade_rules() { + static auto rules = new llvm::StringMap{}; + return *rules; +} + +const llvm::StringMap& downgrade_rules() { + static auto rules = new llvm::StringMap{}; + return *rules; +} + +} // namespace + +void SerdePass::runOnOperation() { + mlir::ModuleOp module = getOperation(); + if (!serialize.hasValue()) { + module.emitError("serialize option must be specified"); + return signalPassFailure(); + } + int serialize_version = -1; + if (serialize) { + serialize_version = target_version.hasValue() ? target_version : kVersion; + } + if (mlir::failed(jaxlib::mosaic::RunSerde( + module, upgrade_rules(), downgrade_rules(), serialize, + {.dialect_prefix = kMangledDialect, + .highest_version = kVersion, + .version_attr_name = kVersionAttrName, + .serialize_version = serialize_version}))) { + signalPassFailure(); + } +} + +} // namespace mosaic::gpu diff --git a/jaxlib/mosaic/gpu/serde.h b/jaxlib/mosaic/gpu/serde.h new file mode 100644 index 000000000000..6187d72b4cd5 --- /dev/null +++ b/jaxlib/mosaic/gpu/serde.h @@ -0,0 +1,86 @@ +/* Copyright 2025 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_SERDE_H_ +#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_SERDE_H_ + +#include +#include + +#include "llvm/include/llvm/ADT/StringRef.h" +#include "llvm/include/llvm/Support/CommandLine.h" +#include "mlir/include/mlir/IR/BuiltinOps.h" +#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/include/mlir/Pass/Pass.h" +#include "mlir/include/mlir/Pass/PassRegistry.h" +#include "jaxlib/pass_boilerplate.h" + +namespace mosaic::gpu { + +struct SerdePassOptions { + bool serialize; + int target_version; +}; + +struct SerdePass : public jaxlib::mlir::Pass { + using jaxlib::mlir::Pass::Pass; + + static constexpr llvm::StringLiteral kArgumentName = "mosaic_gpu-serde"; + static constexpr llvm::StringLiteral kPassName = "MosaicGPUSerdePass"; + + SerdePass() = default; + + explicit SerdePass(SerdePassOptions options) { + serialize = options.serialize; + target_version = options.target_version; + } + + SerdePass(const SerdePass &other) { + serialize = other.serialize; + target_version = other.target_version; + } + + SerdePass &operator=(const SerdePass &other) { + serialize = other.serialize; + target_version = other.target_version; + return *this; + } + + void runOnOperation(); + + protected: + ::mlir::Pass::Option serialize{*this, "serialize", llvm::cl::desc("")}; + ::mlir::Pass::Option target_version{*this, "target-version", + llvm::cl::desc("")}; +}; + +inline std::unique_ptr<::mlir::Pass> createSerdePass() { + return 
std::make_unique<SerdePass>(); +} + +inline std::unique_ptr<::mlir::Pass> createSerdePass( + SerdePassOptions options) { + return std::make_unique<SerdePass>(std::move(options)); +} + +inline void registerSerdePass() { + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return createSerdePass(); + }); +} + +} // namespace mosaic::gpu + +#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_GPU_SERDE_H_