Commit 63f9cd6
[Relax] Alloc BYOC workspace with R.builtin.alloc_tensor (#17110)
* [Relax] Alloc BYOC workspace with R.builtin.alloc_tensor

  This makes the allocation go through memory planning and makes it compatible with CUDA graphs.

* lint

* lint
1 parent 02fe0c5 commit 63f9cd6
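
In TVMScript terms (taken from the expected modules in the tests below), the change swaps the two-step VM allocation, which is opaque to memory planning, for a single planner-visible builtin call:

# Before: two bindings, allocated directly through the VM
lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
workspace_main = R.vm.alloc_tensor(lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8"))

# After: one binding that memory planning can fold into a storage pool
workspace_main = R.builtin.alloc_tensor(R.shape([65536]), R.dtype("uint8"), R.prim_value(0))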

5 files changed: +26 -24 lines changed

python/tvm/relax/testing/matmul.py

Lines changed: 2 additions & 1 deletion

@@ -25,14 +25,15 @@ def get_relax_matmul_module(
     x_shape,
     y_shape,
     in_dtype,
-    out_dtype,
+    out_dtype=None,
     transposed_y=False,
     bias_shape=None,
     activation=None,
     residual_bin_op=None,
     residual_activation=None,
 ):
     """Create a matmul op followd by epilogue operations."""
+    out_dtype = out_dtype if out_dtype is not None else in_dtype
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
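
A minimal usage sketch of the new default (the shapes and dtype here are illustrative, not from the commit): omitting out_dtype now falls back to in_dtype.

# Hypothetical call: out_dtype defaults to in_dtype when left unset,
# equivalent to passing out_dtype="float16" explicitly.
mod = get_relax_matmul_module((32, 64), (64, 128), in_dtype="float16")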

src/relax/op/op_common.h

Lines changed: 3 additions & 0 deletions

@@ -558,6 +558,9 @@ Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm d
                         StringImm storage_scope = StringImm("global"));
 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype);
 
+Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
+                     StringImm storage_scope = StringImm("global"));
+
 /**
  * \brief Return the argument of the call.
  * Note: If this is a call_tir, return the arguments passed to the TIR func
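
For orientation, the Python-level counterpart of this helper is the R.builtin.alloc_tensor op exercised by the updated tests below; a minimal sketch, assuming tvm.relax.op.builtin.alloc_tensor keeps the same (shape, dtype, runtime_device_index, storage_scope) ordering:

from tvm import relax

# Sketch: construct the same allocation from Python; the plain int and str
# arguments are assumed to be wrapped into PrimValue/DataTypeImm by the op builder.
alloc = relax.op.builtin.alloc_tensor(relax.ShapeExpr([65536]), "uint8", 0)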

src/relax/transform/allocate_workspace.cc

Lines changed: 1 addition & 2 deletions

@@ -144,8 +144,7 @@ class WorkspaceProvider : ExprMutator {
     if (!workspace_var_main_.defined()) {
       auto shape = ShapeExpr({Integer(max_workspace_size_)});
       auto ty = DataTypeImm(DataType::UInt(8));
-      auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
-      auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, ty);
+      auto workspace = MakeAllocTensor(shape, ty, PrimValue::Int64(0));
       workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
     }
     for (const auto& binding : block_node->bindings) {
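
A hedged sketch of how the rewritten transform slots into a pipeline (both pass names exist in tvm.relax.transform; the ordering shown is an assumption, not taken from this commit):

import tvm
from tvm import relax

# Sketch: AllocateWorkspace emits the workspace_main binding shown above;
# StaticPlanBlockMemory can then plan the R.builtin.alloc_tensor call into
# pooled storage, which is what keeps the allocation CUDA-graph friendly.
seq = tvm.ir.transform.Sequential(
    [
        relax.transform.AllocateWorkspace(),
        relax.transform.StaticPlanBlockMemory(),
    ]
)
mod = seq(mod)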

tests/python/relax/test_codegen_cutlass.py

Lines changed: 16 additions & 15 deletions

@@ -104,7 +104,9 @@ def build_cutlass(mod, assert_all_bindings_fused=True, num_final_bindings=1):
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
-        assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        assert (
+            len(mod["main"].body.blocks[0].bindings) == num_final_bindings
+        ), "Not all bindings are fused. " + str(mod["main"])
 
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
     mod = codegen_pass(mod)

@@ -714,7 +716,7 @@ def test_attention_offload(attention_size, attention_dtype):
     v_shape = (b, s_kv, n, h_v)
 
     mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -751,7 +753,7 @@ def test_attention_bias_offload(attention_bias_size):
     mod = get_relax_attention_module(
         q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
     )
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -786,9 +788,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
         q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -829,9 +831,9 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
     )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -932,9 +934,9 @@ def test_stacked_attention_split_offload(stacked_attention_size):
     )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -950,9 +952,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size):
         qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape
     )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=2)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=2)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)

@@ -1311,9 +1313,8 @@ def main(
         R.func_attr({"num_input": 4})
         cls = Expected
         with R.dataflow():
-            lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             lv_1 = R.reshape(bias, R.shape([128, 16, 8]))
             lv1 = R.reshape(lv_1, R.shape([4, 32, 16, 8]))

@@ -2419,7 +2420,7 @@ def test_sliding_window():
         1, 64, 64, 16, 8, 8, "none", "none", causal, "float16", window_size=window_size
     )
 
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=2)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
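
All of the num_final_bindings updates above follow from the allocation rewrite: the workspace previously took two bindings in the partitioned main block (alloc_storage plus alloc_tensor) and now takes one, so every expected count drops by exactly one. Roughly, the block inspected by the assertion in build_cutlass now holds:

# Two bindings remain after partitioning (previously three):
#   workspace_main = R.builtin.alloc_tensor(...)  # one binding, was two
#   gv = <fused cutlass call>(q, k, v, workspace_main)
assert len(mod["main"].body.blocks[0].bindings) == 2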

tests/python/relax/test_transform_allocate_workspace.py

Lines changed: 4 additions & 6 deletions

@@ -126,9 +126,8 @@ def entry_a(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main

@@ -144,9 +143,8 @@ def entry_b(
     ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
         cls = Expected
         with R.dataflow():
-            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
-            workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor(
-                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            workspace_main: R.Tensor((65536,), dtype="uint8") = R.builtin.alloc_tensor(
+                R.shape([65536]), R.dtype("uint8"), R.prim_value(0)
             )
             gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1(
                 q, k, v, workspace_main
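
These Expected functions are the reference output of the pass; a minimal check sketch, assuming the input/expected pair is named Module and Expected as is conventional in these tests:

# Sketch: run the pass and compare structurally against the expected module.
after = relax.transform.AllocateWorkspace()(Module)
tvm.ir.assert_structural_equal(after, Expected)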
