From 8feab682097b0949d0504ec0ee73f4637aeb1f57 Mon Sep 17 00:00:00 2001
From: Adam Paszke
Date: Thu, 5 Sep 2024 04:57:02 -0700
Subject: [PATCH] [Mosaic GPU] Remove the unnecessary scratch space operand

And clean up the C++ dispatch code. We don't use HBM scratch anymore since
we pass TMA descriptors as kernel arguments.

PiperOrigin-RevId: 671327420
---
 .../mosaic_gpu/pallas_call_registration.py |  1 -
 jax/experimental/mosaic/gpu/__init__.py    | 27 +++++++------------
 jaxlib/mosaic/gpu/custom_call.cc           | 25 +++++++----------
 3 files changed, 19 insertions(+), 34 deletions(-)

diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
index 9f28fa7c2944..5b46caf1553a 100644
--- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
+++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -70,6 +70,5 @@ def pallas_call_lowering(
       ctx,
       *args,
       module=module.operation.get_asm(binary=True, enable_debug_info=True),
-      gmem_scratch_bytes=lowering_result.gmem_scratch_bytes,
       out_types=lowering_result.out_structs,
   )
diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py
index eb8eba9dfacf..2e2941fca5b1 100644
--- a/jax/experimental/mosaic/gpu/__init__.py
+++ b/jax/experimental/mosaic/gpu/__init__.py
@@ -88,14 +88,14 @@

 @mosaic_gpu_p.def_abstract_eval
-def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
-  del module, gmem_scratch_bytes  # Unused.
+def _mosaic_gpu_abstract_eval(*_, module, out_types):
+  del module  # Unused.
   return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]


 # TODO(apaszke): Implement a proper system for managing kernel lifetimes
 KNOWN_KERNELS = {}

-def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
+def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
   del out_types  # Unused.
   kernel_id = hashlib.sha256(module).digest()
   # Note that this is technically only a half measure. Someone might load a
@@ -108,19 +108,13 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
     KNOWN_KERNELS[kernel_id] = module
   op = mlir.custom_call(
       "mosaic_gpu",
-      result_types=[
-          *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
-          mlir.aval_to_ir_type(
-              jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
-          ),
-      ],
+      result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
       operands=args,
       operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
-      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out]
-      + [[0]],
+      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
       backend_config=kernel_id + module,
   )
-  return op.results[:-1]  # Skip the scratch space.
+  return op.results


 mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")

@@ -766,8 +760,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
         ir.Attribute.parse("#llvm.linkage<external>"),
         addr_space=ir.IntegerAttr.get(i32, 4),  # GPU constant memory.
     )
-    @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty)
-    def main(token_ptr, buffers, gmem_scratch_ptr):
+    @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
+    def main(token_ptr, buffers):
       nonlocal gmem_scratch_bytes
       token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
       arg_refs = []
@@ -803,7 +797,7 @@ def main(token_ptr, buffers, gmem_scratch_ptr):
   sym_tab.insert(global_scratch)
   module.operation.verify()

-  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
+  return module, out_shape, unwrap_output_tuple


 def as_gpu_kernel(
@@ -822,7 +816,7 @@ def as_gpu_kernel(
   elif not isinstance(in_shape, tuple):
     in_shape = (in_shape,)

-  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
+  module, out_shape, unwrap_output_tuple = (
       _lower_as_gpu_kernel(
           body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
           module_name, prof_spec
@@ -844,7 +838,6 @@ def bind(*args):
         *args,
         out_types=out_shape,
         module=module_asm,
-        gmem_scratch_bytes=gmem_scratch_bytes,
     )

   if prof_spec is not None:
diff --git a/jaxlib/mosaic/gpu/custom_call.cc b/jaxlib/mosaic/gpu/custom_call.cc
index d9b1e0775ecc..2e5723b184a8 100644
--- a/jaxlib/mosaic/gpu/custom_call.cc
+++ b/jaxlib/mosaic/gpu/custom_call.cc
@@ -353,20 +353,16 @@ absl::StatusOr<std::unique_ptr<mlir::ExecutionEngine>> Compile(
 class CompiledKernel {
  public:
   CompiledKernel(std::unique_ptr<mlir::ExecutionEngine> engine, void* ctx,
-                 void* scratch_addr, MosaicHostFunc* host_launch)
-      : engine_(std::move(engine)),
-        ctx_(ctx),
-        scratch_addr_(scratch_addr),
-        host_launch_(host_launch) {}
+                 MosaicHostFunc* host_launch)
+      : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {}

-  std::tuple<void*, void*, MosaicHostFunc*> GetHostLaunch() {
-    return std::make_tuple(ctx_, scratch_addr_, host_launch_);
+  std::tuple<void*, MosaicHostFunc*> GetHostLaunch() {
+    return std::make_tuple(ctx_, host_launch_);
   }

  private:
   std::unique_ptr<mlir::ExecutionEngine> engine_;
   void* ctx_;  // TODO(apaszke): Destroy this properly
-  void* scratch_addr_;
   MosaicHostFunc* host_launch_;
 };

@@ -384,7 +380,7 @@ GetKernelCache() {
 // Each compiled kernel has a unique init func, and each kernel is used from
 // a single HLO module. So it should be safe to not include the CUDA context
 // in the key.
-absl::StatusOr<std::tuple<void*, void*, MosaicHostFunc*>> CompileAndInit(
+absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
     CacheKey key, const char* module) {
   auto cache_and_mutex = GetKernelCache();
   auto* cache = cache_and_mutex.first;
@@ -426,10 +422,8 @@ absl::StatusOr<std::tuple<void*, void*, MosaicHostFunc*>> CompileAndInit(
     void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
     reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
     cache->insert_or_assign(
-        key,
-        CompiledKernel(std::move(*maybe_engine), kernel_ptr,
-                       nullptr,  // TODO(apaszke): Clean this up.
-                       reinterpret_cast<MosaicHostFunc*>(*main)));
+        key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
+                            reinterpret_cast<MosaicHostFunc*>(*main)));
   }
   return cache->at(key).GetHostLaunch();
 }
@@ -454,9 +448,8 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
                          ctx_and_kernel.status().message().size());
     return;
   }
-  void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers,
-                   &std::get<1>(*ctx_and_kernel)};
-  std::get<2>(*ctx_and_kernel)(args);
+  void* args[3] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers};
+  std::get<1>(*ctx_and_kernel)(args);
 }

 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
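--
Sketch (not part of the patch): after this change, CompileAndInit hands back
a (ctx, host_launch) pair and the custom call dispatches with a three-slot
argument array. A minimal sketch of that launch path, assuming MosaicHostFunc
is the void(void**) function type used in custom_call.cc; LaunchSketch is an
illustrative name, not a function in the file:

  #include <tuple>

  using MosaicHostFunc = void(void**);

  void LaunchSketch(std::tuple<void*, MosaicHostFunc*>& ctx_and_kernel,
                    void* stream, void** buffers) {
    // Each slot holds the address of one launch argument; the generated
    // host function reads them positionally as (ctx, stream, buffers).
    // TMA descriptors now travel as ordinary kernel arguments, so the
    // fourth HBM-scratch slot is gone.
    void* args[3] = {&std::get<0>(ctx_and_kernel), &stream, &buffers};
    std::get<1>(ctx_and_kernel)(args);
  }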