[Mosaic GPU] Remove the unnecessary scratch space operand
Also clean up the C++ dispatch code. We no longer use HBM scratch,
since we pass TMA descriptors as kernel arguments.

PiperOrigin-RevId: 671327420
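In concrete terms: the old lowering appended a throwaway u8[gmem_scratch_bytes] buffer to the custom call's results and then sliced it back off, while the new lowering's results map 1:1 onto the kernel's declared outputs. A minimal sketch of that bookkeeping, using toy stand-ins rather than JAX's actual lowering types:

def old_result_trimming(out_avals, gmem_scratch_bytes):
    # The custom call used to materialize one extra u8 buffer for GMEM scratch...
    result_types = [*out_avals, f"u8[{gmem_scratch_bytes}]"]
    results = [f"value<{t}>" for t in result_types]
    return results[:-1]  # ...which the caller then had to slice off.

def new_result_trimming(out_avals):
    # Results now correspond exactly to the kernel's declared outputs.
    return [f"value<{t}>" for t in out_avals]

assert old_result_trimming(["f32[128]"], 1024) == new_result_trimming(["f32[128]"])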
apaszke authored and jax authors committed Sep 5, 2024
1 parent f3b91b2 commit 8feab68
Showing 3 changed files with 19 additions and 34 deletions.
1 change: 0 additions & 1 deletion jax/_src/pallas/mosaic_gpu/pallas_call_registration.py
@@ -70,6 +70,5 @@ def pallas_call_lowering(
       ctx,
       *args,
       module=module.operation.get_asm(binary=True, enable_debug_info=True),
-      gmem_scratch_bytes=lowering_result.gmem_scratch_bytes,
       out_types=lowering_result.out_structs,
   )
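The keyword deleted above is a primitive parameter: whatever is passed at bind time is forwarded verbatim to the primitive's abstract-eval and lowering rules, so the parameter has to disappear from all of them in lockstep (see the next file). A self-contained toy primitive illustrating that contract (hypothetical, not part of this change):

import jax.numpy as jnp
from jax import core

toy_p = core.Primitive("toy")
# Keyword params given to bind() arrive unchanged in every registered rule,
# so all rule signatures must agree on the parameter list.
toy_p.def_impl(lambda x, *, scale: x * scale)
toy_p.def_abstract_eval(lambda x, *, scale: core.ShapedArray(x.shape, x.dtype))

print(toy_p.bind(jnp.arange(3.0), scale=2.0))  # [0. 2. 4.]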
27 changes: 10 additions & 17 deletions jax/experimental/mosaic/gpu/__init__.py
@@ -88,14 +88,14 @@
 
 
 @mosaic_gpu_p.def_abstract_eval
-def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
-  del module, gmem_scratch_bytes  # Unused.
+def _mosaic_gpu_abstract_eval(*_, module, out_types):
+  del module  # Unused.
   return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
 
 # TODO(apaszke): Implement a proper system for managing kernel lifetimes
 KNOWN_KERNELS = {}
 
-def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
+def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types):
   del out_types  # Unused.
   kernel_id = hashlib.sha256(module).digest()
   # Note that this is technically only a half measure. Someone might load a
@@ -108,19 +108,13 @@ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes)
     KNOWN_KERNELS[kernel_id] = module
   op = mlir.custom_call(
       "mosaic_gpu",
-      result_types=[
-          *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
-          mlir.aval_to_ir_type(
-              jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
-          ),
-      ],
+      result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
       operands=args,
       operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in],
-      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out]
-      + [[0]],
+      result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out],
       backend_config=kernel_id + module,
   )
-  return op.results[:-1]  # Skip the scratch space.
+  return op.results
 
 mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
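Note the backend_config above: it concatenates the 32-byte SHA-256 digest with the serialized module, so the runtime can recover both pieces at a fixed offset. A quick sketch of that packing (the split shown here is illustrative; the actual parsing lives on the C++ side in custom_call.cc below):

import hashlib

module = b"...serialized MLIR bytecode..."
kernel_id = hashlib.sha256(module).digest()  # A SHA-256 digest is always 32 bytes.
backend_config = kernel_id + module

# The receiver splits at the fixed digest length to recover both pieces.
recovered_id, recovered_module = backend_config[:32], backend_config[32:]
assert (recovered_id, recovered_module) == (kernel_id, module)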

@@ -766,8 +760,8 @@ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
         ir.Attribute.parse("#llvm.linkage<external>"),
         addr_space=ir.IntegerAttr.get(i32, 4),  # GPU constant memory.
     )
-    @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty)
-    def main(token_ptr, buffers, gmem_scratch_ptr):
+    @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
+    def main(token_ptr, buffers):
       nonlocal gmem_scratch_bytes
       token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
       arg_refs = []
@@ -803,7 +797,7 @@ def main(token_ptr, buffers, gmem_scratch_ptr):
   sym_tab.insert(global_scratch)
   module.operation.verify()
 
-  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
+  return module, out_shape, unwrap_output_tuple
 
 
 def as_gpu_kernel(
@@ -822,7 +816,7 @@
   elif not isinstance(in_shape, tuple):
     in_shape = (in_shape,)
 
-  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
+  module, out_shape, unwrap_output_tuple = (
       _lower_as_gpu_kernel(
           body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape,
           module_name, prof_spec
@@ -844,7 +838,6 @@ def bind(*args):
         *args,
         out_types=out_shape,
         module=module_asm,
-        gmem_scratch_bytes=gmem_scratch_bytes,
     )
 
   if prof_spec is not None:
25 changes: 9 additions & 16 deletions jaxlib/mosaic/gpu/custom_call.cc
@@ -353,20 +353,16 @@ absl::StatusOr<std::unique_ptr<mlir::ExecutionEngine>> Compile(
 class CompiledKernel {
  public:
   CompiledKernel(std::unique_ptr<mlir::ExecutionEngine> engine, void* ctx,
-                 void* scratch_addr, MosaicHostFunc* host_launch)
-      : engine_(std::move(engine)),
-        ctx_(ctx),
-        scratch_addr_(scratch_addr),
-        host_launch_(host_launch) {}
+                 MosaicHostFunc* host_launch)
+      : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {}
 
-  std::tuple<void*, void*, MosaicHostFunc*> GetHostLaunch() {
-    return std::make_tuple(ctx_, scratch_addr_, host_launch_);
+  std::tuple<void*, MosaicHostFunc*> GetHostLaunch() {
+    return std::make_tuple(ctx_, host_launch_);
   }
 
  private:
   std::unique_ptr<mlir::ExecutionEngine> engine_;
   void* ctx_;  // TODO(apaszke): Destroy this properly
-  void* scratch_addr_;
   MosaicHostFunc* host_launch_;
 };
 
@@ -384,7 +380,7 @@ GetKernelCache() {
 // Each compiled kernel has a unique init func, and each kernel is used from
 // a single HLO module. So it should be safe to not include the CUDA context
 // in the key.
-absl::StatusOr<std::tuple<void*, void*, MosaicHostFunc*>> CompileAndInit(
+absl::StatusOr<std::tuple<void*, MosaicHostFunc*>> CompileAndInit(
     CacheKey key, const char* module) {
   auto cache_and_mutex = GetKernelCache();
   auto* cache = cache_and_mutex.first;
@@ -426,10 +422,8 @@ absl::StatusOr<std::tuple<void*, void*, MosaicHostFunc*>> CompileAndInit(
     void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr};
     reinterpret_cast<MosaicInitFunc*>(*init)(init_args);
     cache->insert_or_assign(
-        key,
-        CompiledKernel(std::move(*maybe_engine), kernel_ptr,
-                       nullptr,  // TODO(apaszke): Clean this up.
-                       reinterpret_cast<MosaicHostFunc*>(*main)));
+        key, CompiledKernel(std::move(*maybe_engine), kernel_ptr,
+                            reinterpret_cast<MosaicHostFunc*>(*main)));
   }
   return cache->at(key).GetHostLaunch();
 }
@@ -454,9 +448,8 @@ void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque,
                              ctx_and_kernel.status().message().size());
     return;
   }
-  void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers,
-                   &std::get<1>(*ctx_and_kernel)};
-  std::get<2>(*ctx_and_kernel)(args);
+  void* args[3] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers};
+  std::get<1>(*ctx_and_kernel)(args);
 }
 
 XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall,
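The net effect on dispatch: CompileAndInit now hands back a (ctx, host_launch) pair instead of a triple, and the launch packs three pointers instead of four. The flow, sketched in Python with toy stand-ins for the C++ types (names illustrative only, not the real runtime API):

_kernel_cache = {}

def compile_and_init(kernel_id, module):
    # Compile at most once per kernel id, mirroring the C++ kernel cache.
    if kernel_id not in _kernel_cache:
        ctx = object()  # stand-in for the initialized module context

        def host_launch(args):
            print("launch with", len(args), "packed args")

        _kernel_cache[kernel_id] = (ctx, host_launch)
    return _kernel_cache[kernel_id]

def mosaic_gpu_custom_call(stream, buffers, opaque):
    kernel_id, module = opaque[:32], opaque[32:]  # fixed-size SHA-256 prefix
    ctx, host_launch = compile_and_init(kernel_id, module)
    host_launch([ctx, stream, buffers])  # no scratch pointer anymore

mosaic_gpu_custom_call("stream", "buffers", b"\x00" * 32 + b"...module...")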
