Commit 61328ae
Author: Ivan Sidorenko
[CUBLAS][FP8] Enable R.matmul + R.multiply offloading

This commit enables offloading of the following pattern to cuBLAS:

    mm = R.linear(data, weights)
    scale = R.multiply(a_scale, w_scale)
    out = R.multiply(mm, scale)
    out = R.cast(out, dtype)
1 parent: 28d32b5
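For orientation, a minimal, hypothetical TVMScript sketch of the pattern being offloaded (shapes and dtypes are illustrative, mirroring the FP8 test added below; R.linear in the commit message corresponds to a matmul with a transposed weight):

import tvm
from tvm.script import relax as R

@tvm.script.ir_module
class PatternExample:
    @R.function
    def main(
        data: R.Tensor((10, 32), "e4m3_float8"),
        weights: R.Tensor((64, 32), "e4m3_float8"),
        a_scale: R.Tensor((1,), "float32"),
        w_scale: R.Tensor((1,), "float32"),
    ) -> R.Tensor((10, 64), "float16"):
        with R.dataflow():
            # matmul against the transposed weight, accumulating in fp32
            w_t = R.permute_dims(weights, axes=[1, 0])
            mm = R.matmul(data, w_t, out_dtype="float32")
            # combined per-tensor scale, applied to the accumulator
            scale = R.multiply(a_scale, w_scale)
            out = R.multiply(mm, scale)
            out = R.astype(out, "float16")
            R.output(out)
        return out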

File tree: 7 files changed (+156, −12 lines)

python/tvm/relax/backend/contrib/cublas.py
Lines changed: 10 additions & 1 deletion

@@ -25,7 +25,11 @@
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
-from ..patterns import make_matmul_pattern, make_matmul_dequantize_pattern
+from ..patterns import (
+    make_matmul_pattern,
+    make_matmul_dequantize_pattern,
+    make_matmul_multiply_pattern,
+)
 from ..utils import has_leaking_intermediate_variables
 
 
@@ -202,6 +206,11 @@ def _check_matmul(context: PatternCheckContext) -> bool:
             *make_matmul_dequantize_pattern(transposed_rhs=True),
             _check_matmul,
         ),
+        (
+            "cublas.matmul_transposed_multiply",
+            *make_matmul_multiply_pattern(transposed_rhs=True),
+            _check_matmul,
+        ),
     ]
 )
python/tvm/relax/backend/patterns.py
Lines changed: 38 additions & 0 deletions

@@ -376,6 +376,44 @@ def make_matmul_dequantize_pattern(
     return out, annotations
 
 
+def make_matmul_multiply_pattern(
+    transposed_rhs: bool = False,
+) -> Tuple[DFPattern, Mapping[str, DFPattern]]:
+    """
+    Create pattern for matrix multiplication followed by multiply operations.
+
+    Parameters
+    ----------
+    transposed_rhs: bool
+        Whether the right hand side of the matmul is transposed.
+
+    Returns
+    -------
+    pattern: DFPattern
+        The resulting pattern describing a matrix multiplication followed by multiply.
+
+    annotations: Mapping[str, DFPattern]
+        A mapping from name to sub pattern. It can be used to extract important expressions from
+        match result, to power the partition check function and codegen.
+    """
+
+    lhs = wildcard()
+    rhs = wildcard()
+    scaleA = wildcard()
+    scaleB = wildcard()
+    annotations = {"lhs": lhs, "rhs": rhs, "scaleA": scaleA, "scaleB": scaleB}
+
+    if transposed_rhs:
+        rhs = is_op("relax.permute_dims")(rhs)
+    out = is_op("relax.matmul")(lhs, rhs)
+    annotations["root"] = out
+    scale = is_op("relax.multiply")(scaleA.has_shape((1,)), scaleB.has_shape((1,)))
+    out = is_op("relax.multiply")(out, scale)
+    out = is_op("relax.astype")(out)
+
+    return out, annotations
+
+
 def make_attention_rewrite_pattern(
     qkv_layout: str, out_layout: str, with_bias: bool, with_cast: bool, with_kv_repeat: bool = False
 ):
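A small sketch of inspecting the new pattern on its own; the annotation keys are assumed from the diff above. The annotations dict maps names to sub-patterns so the partition check function and codegen can pull the matched lhs/rhs/scale expressions out of a match result.

from tvm.relax.backend.patterns import make_matmul_multiply_pattern

pattern, annotations = make_matmul_multiply_pattern(transposed_rhs=True)
# Expect the names registered in the pattern body:
print(sorted(annotations))  # ['lhs', 'rhs', 'root', 'scaleA', 'scaleB']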

src/relax/backend/contrib/cublas/codegen.cc
Lines changed: 4 additions & 1 deletion

@@ -62,14 +62,17 @@ class CublasJSONSerializer : public JSONSerializer {
       inputs_tmp.insert(inputs_tmp.end(), res.begin(), res.end());
     }
 
-    ICHECK(inputs_tmp.size() <= 3);
+    ICHECK(inputs_tmp.size() <= 4);
     NodeEntries inputs(inputs_tmp.size());
 
     auto arg_idx = backend::ExtractArgIdx(composite_name, fn);
     inputs[0] = inputs_tmp[arg_idx["lhs"]->value];
     inputs[1] = inputs_tmp[arg_idx["rhs"]->value];
     if (inputs_tmp.size() == 3) {
       inputs[2] = inputs_tmp[arg_idx["bias"]->value];
+    } else if (inputs_tmp.size() == 4) {
+      inputs[2] = inputs_tmp[arg_idx["scaleA"]->value];
+      inputs[3] = inputs_tmp[arg_idx["scaleB"]->value];
     }
 
     auto node = std::make_shared<JSONGraphNode>(composite_name, /* name_ */

src/runtime/contrib/cublas/cublas.cc
Lines changed: 12 additions & 2 deletions

@@ -137,8 +137,9 @@ int roundoff(int v, int d) { return (v + d - 1) / d * d; }
 
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size, cublasLtEpilogue_t epilogue,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue,
                   std::optional<float> dq_scale) {
   ICHECK(TypeEqual(A->dtype, B->dtype));
   // Reversed strides indicates an in-place transpose operation.
@@ -193,6 +194,15 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                                                    &bias->data, sizeof(float*)));
   }
 
+  if (scaleA != nullptr && scaleB != nullptr) {
+    auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
+    auto scaleB_data = static_cast<char*>(scaleB->data) + scaleB->byte_offset;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
+                                                      &scaleA_data, sizeof(float*)));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
+                                                      &scaleB_data, sizeof(float*)));
+  }
+
   if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
                                                       &epilogue, sizeof(epilogue)));
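Setting the A/B scale pointers makes cuBLASLt scale the operands by the two device-side scalars, which for per-tensor scales amounts to multiplying the product by scaleA * scaleB. A NumPy sketch of the resulting math (a reference for the fused computation, not the cuBLASLt API):

import numpy as np

def reference_scaled_matmul(a, b, scale_a, scale_b, out_dtype="float16"):
    # Accumulate in float32, apply both scale factors, then cast down,
    # matching out = cast((a @ b) * (scale_a * scale_b), out_dtype).
    acc = a.astype(np.float32) @ b.astype(np.float32)
    return (acc * (scale_a * scale_b)).astype(out_dtype)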

src/runtime/contrib/cublas/cublas_json_runtime.cc
Lines changed: 10 additions & 5 deletions

@@ -97,12 +97,15 @@ class CublasJSONRuntime : public JSONRuntimeBase {
       return dl_tensors[eid];
     };
 
-    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias) {
-      const DLTensor* bias = nullptr;
+    auto get_inputs = [=](const JSONGraphNode& node, bool has_bias, bool has_scale) {
+      const DLTensor *bias = nullptr, *scaleA = nullptr, *scaleB = nullptr;
       if (has_bias) {
         bias = get_input(node, 2);
+      } else if (has_scale) {
+        scaleA = get_input(node, 2);
+        scaleB = get_input(node, 3);
       }
-      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias);
+      return std::make_tuple(get_input(node, 0), get_input(node, 1), bias, scaleA, scaleB);
     };
 
     for (size_t i = 0; i < nodes_.size(); ++i) {
@@ -127,15 +130,17 @@ class CublasJSONRuntime : public JSONRuntimeBase {
          epilogue = CUBLASLT_EPILOGUE_BIAS;
        }
 
-        auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);
+        bool has_scale = op_name.find("multiply") != std::string::npos;
+        auto [a_ptr, b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr] =
+            get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT, has_scale);
 
        std::optional<float> dq_scale = std::nullopt;
        if (op_name.find("dequantize") != std::string::npos) {
          dq_scale = std::stof(node.GetAttr<std::vector<std::string>>("dq_scale")[0]);
        }
 
        tvm::contrib::CallCublasLt(entry_ptr->handle, stream, entry_ptr->matmul_pref_desc, a_ptr,
-                                   b_ptr, bias_ptr, out_ptr, transa, transb,
+                                   b_ptr, bias_ptr, scaleA_ptr, scaleB_ptr, out_ptr, transa, transb,
                                   entry_ptr->workspace_ptr, entry_ptr->workspace_size, epilogue,
                                   dq_scale);
       }

src/runtime/contrib/cublas/cublas_utils.h
Lines changed: 3 additions & 3 deletions

@@ -123,9 +123,9 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
 /*! \brief Execute matrix multiply followed by the specified epilogue, using cuBLASLt. */
 void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                   cublasLtMatmulPreference_t matmul_pref_desc, const DLTensor* A, const DLTensor* B,
-                  const DLTensor* bias, const DLTensor* C, bool transa, bool transb,
-                  void* workspace_ptr, size_t workspace_size,
-                  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
+                  const DLTensor* bias, const DLTensor* scaleA, const DLTensor* scaleB,
+                  const DLTensor* C, bool transa, bool transb, void* workspace_ptr,
+                  size_t workspace_size, cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_DEFAULT,
                   std::optional<float> dq_scale = std::nullopt);
 
 }  // namespace contrib

tests/python/relax/test_codegen_cublas.py
Lines changed: 79 additions & 0 deletions

@@ -134,6 +134,40 @@ def get_relax_matmul_dequantize_module(
     return tvm.IRModule({"main": func})
 
 
+def get_relax_matmul_multiply_module(
+    x_shape,
+    y_shape,
+    z_shape,
+    in_dtype,
+    acc_dtype,
+    out_dtype,
+    transposed_y=False,
+):
+    """Create a matmul op followed by multiply operations."""
+    with IRBuilder() as builder:
+        with relax_builder.function():
+            R.func_name("main")
+            x = R.arg("x", R.Tensor(x_shape, in_dtype))
+            y = R.arg("y", R.Tensor(y_shape, in_dtype))
+            scaleA = R.arg("scaleA", R.Tensor(z_shape, acc_dtype))
+            scaleB = R.arg("scaleB", R.Tensor(z_shape, acc_dtype))
+
+            with R.dataflow() as frame:
+                if transposed_y:
+                    axes = list(range(len(y_shape) - 2)) + [-1, -2]
+                    y = R.emit(R.permute_dims(y, axes=axes))
+                result = R.emit(R.matmul(x, y, out_dtype=acc_dtype))
+                z = R.emit(R.multiply(scaleA, scaleB))
+                result = R.emit(R.multiply(result, z))
+                if acc_dtype != out_dtype:
+                    result = R.emit(R.astype(result, out_dtype))
+                R.output(result)
+            R.func_ret_value(frame.output_vars[0])
+
+    func = builder.get()
+    return tvm.IRModule({"main": func})
+
+
 @pytest.mark.parametrize(
     "x_shape, y_shape, transpose_y, epilogue",
     [
@@ -327,6 +361,36 @@ def test_matmul_fp8_dequantize_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
 
 
+@tvm.testing.requires_cuda_compute_version(9)
+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+def test_matmul_fp8_multiply_offload():
+    x_shape = (10, 32)
+    y_shape = (64, 32)
+    z_shape = (1,)
+    in_dtype, acc_dtype = ("e4m3_float8", "float32")
+
+    mod = get_relax_matmul_multiply_module(
+        x_shape,
+        y_shape,
+        z_shape,
+        in_dtype,
+        acc_dtype,
+        "float16",
+        transposed_y=True,
+    )
+
+    numpytype = "float8_e4m3fn"
+    x = np.random.uniform(low=0, high=5, size=x_shape).astype(numpytype)
+    y = np.random.uniform(low=0, high=5, size=y_shape).astype(numpytype)
+    scaleA = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    scaleB = np.random.uniform(low=0, high=5, size=z_shape).astype(acc_dtype)
+    args = (x, y, scaleA, scaleB)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref = build_and_run(mod, args, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [
@@ -371,6 +435,21 @@ def test_cublas_partition_fp8_matmul_dequantize(M, N, K, scale, zp, num_bindings):
     assert len(mod["main"].body.blocks[0].bindings) == num_bindings
 
 
+def test_cublas_partition_fp8_matmul_multiply():
+    M, N, K = (32, 64, 128)
+    mod = get_relax_matmul_multiply_module(
+        (M, K),
+        (N, K),
+        (1,),
+        "e4m3_float8",
+        "float32",
+        "float16",
+        transposed_y=True,
+    )
+    mod = partition_for_cublas(mod)
+    assert len(mod["main"].body.blocks[0].bindings) == 1
+
+
 def test_cublas_partition_matmul_without_bias():
     # cuBLAS does not handle 2D bias (residual input)
     mod = get_relax_matmul_module((16, 32), (32, 32), "float16", "float16", bias_shape=(16, 32))
