Commit 7d5ef84

[CUDA] Various int8 fix (cublas, cutlass, etc) (#10596)
* [CUTLASS] avoid tile size 256 for int8 + align1 case
* allow selecting int8 dense strategy for vulkan
* fixed cublas batch matmul for int8
* fixed int8 dense tensorcore strategy
* add cutlass conv align1 + int8 case
* support int8 mixed precision cublas bmm
* black
1 parent 4d88a45 commit 7d5ef84

File tree

7 files changed, +69 -35 lines changed


python/tvm/contrib/cutlass/gen_conv2d.py

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@
 from .conv2d_profiler import Conv2dProfilerEmitter
 from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
 from .library import (
+    DataType,
     EpilogueFunctor,
     SwizzlingFunctor,
     TensorDescription,
@@ -133,6 +134,10 @@ def enumerate_conv2d_operators(
             B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
             C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)
 
+            if element_c == DataType.s32 and A.alignment == 1:
+                tile.threadblock_shape[0] = min(tile.threadblock_shape[0], 128)
+                tile.threadblock_shape[1] = min(tile.threadblock_shape[1], 128)
+
             op = Conv2dOperation(
                 conv_kind,
                 IteratorAlgorithm.Optimized,

python/tvm/contrib/cutlass/gen_gemm.py

Lines changed: 9 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .gemm_profiler import GemmProfilerEmitter
 from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP
 from .library import (
+    DataType,
     EpilogueFunctor,
     SwizzlingFunctor,
     TensorDescription,
@@ -87,6 +88,14 @@ def enumerate_gemm_operators(
             B = TensorDescription(element_b, LayoutType.ColumnMajor, alignment)
             C = TensorDescription(element_c, LayoutType.RowMajor, alignment)
 
+            if element_c == DataType.s32 and A.alignment == 1:
+                tile_description.threadblock_shape[0] = min(
+                    tile_description.threadblock_shape[0], 128
+                )
+                tile_description.threadblock_shape[1] = min(
+                    tile_description.threadblock_shape[1], 128
+                )
+
             op = GemmOperation(
                 tile_description.minimum_compute_capability,
                 tile_description,
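
Both CUTLASS generator changes above apply the same guard: when the accumulator type is int32 and the A operand only has alignment 1, the M and N threadblock tile sizes are capped at 128, so the 256-wide tiles mentioned in the commit title are never emitted. A standalone sketch of that guard follows; the function name is illustrative and not part of the generator's API:

def clamp_tile_for_int8_align1(threadblock_shape, out_is_int32, alignment):
    # Cap the M/N threadblock dimensions at 128 for the int32-accumulator,
    # alignment-1 case; the K dimension is left untouched.
    if out_is_int32 and alignment == 1:
        threadblock_shape[0] = min(threadblock_shape[0], 128)
        threadblock_shape[1] = min(threadblock_shape[1], 128)
    return threadblock_shape

print(clamp_tile_for_int8_align1([256, 128, 64], True, 1))  # -> [128, 128, 64]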

python/tvm/relay/op/strategy/cuda.py

Lines changed: 23 additions & 31 deletions
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
     if (
-        target.kind.name == "cuda"
+        target.kind.name in ["cuda", "vulkan"]
         and data.dtype == "int8"
         and weights.dtype == "int8"
         and out_type.dtype == "int32"
@@ -860,36 +860,28 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                 name="dense_large_batch.gpu",
                 plevel=5,
             )
-        if target.kind.name == "cuda":
-            if nvcc.have_tensorcore(target=target):
-                if (
-                    (
-                        data.dtype in ["float16", "int8", "uint8"]
-                        and (
-                            (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
-                            or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
-                            or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
-                        )
-                    )
-                    or (
-                        data.dtype in ["int4", "uint4"]
-                        and i % 32 == 0
-                        and b % 8 == 0
-                        and o % 8 == 0
-                    )
-                    or (
-                        data.dtype in ["int1", "uint1"]
-                        and i % 128 == 0
-                        and b % 8 == 0
-                        and o % 8 == 0
-                    )
-                ):
-                    strategy.add_implementation(
-                        wrap_compute_dense(topi.cuda.dense_tensorcore),
-                        wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
-                        name="dense_tensorcore.cuda",
-                        plevel=20,
+
+    if target.kind.name == "cuda":
+        if nvcc.have_tensorcore(target=target):
+            if (
+                (
+                    data.dtype in ["float16", "int8", "uint8"]
+                    and (
+                        (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+                        or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+                        or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
                     )
+                )
+                or (data.dtype in ["int4", "uint4"] and i % 32 == 0 and b % 8 == 0 and o % 8 == 0)
+                or (data.dtype in ["int1", "uint1"] and i % 128 == 0 and b % 8 == 0 and o % 8 == 0)
+            ):
+                strategy.add_implementation(
+                    wrap_compute_dense(topi.cuda.dense_tensorcore),
+                    wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
+                    name="dense_tensorcore.cuda",
+                    plevel=20,
+                )
+
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_cublas),
@@ -927,7 +919,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         )
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
+            wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas, need_out_dtype=True),
             wrap_topi_schedule(topi.generic.schedule_extern),
             name="batch_matmul_cublas.cuda",
             plevel=30,
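
The first hunk is what the commit message calls "allow selecting int8 dense strategy for vulkan": the int8 dense path is now gated on target.kind.name in ["cuda", "vulkan"] instead of CUDA only. Below is a hypothetical end-to-end sketch, not part of the diff, of the kind of workload that can now take that path on a Vulkan target; shapes are arbitrary and building assumes a working Vulkan driver:

import tvm
from tvm import relay

# int8 x int8 -> int32 dense, which the updated dense_strategy_cuda can now
# route to the int8 schedule when the target is "vulkan" as well as "cuda".
data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
dense = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], dense))
lib = relay.build(mod, target="vulkan")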

python/tvm/topi/cuda/batch_matmul.py

Lines changed: 1 addition & 1 deletion
@@ -229,7 +229,7 @@ def batch_matmul_cublas(
     b, k, n = get_const_tuple(y.shape)
     if all([isinstance(s, int) for s in [b, m, n, k]]):
         cfg.add_flop(b * m * k * n * 2)
-    return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b)
+    return cublas.batch_matmul(x, y, transa=transpose_a, transb=transpose_b, dtype=out_dtype)
 
 
 @autotvm.register_topi_schedule("batch_matmul_cublas.cuda")

src/runtime/contrib/cublas/cublas.cc

Lines changed: 1 addition & 1 deletion
@@ -290,7 +290,7 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
   transa = IsInPlaceTransposed(A) ? !transa : transa;
   transb = IsInPlaceTransposed(B) ? !transb : transb;
 
-  ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, false)) << "Unsupported data type";
+  ICHECK(CheckMixPrecisionType(A->dtype, C->dtype, true)) << "Unsupported data type";
   ICHECK(!TypeMatch(A->dtype, kDLInt, 8) || ColumnStride(A) % 4 == 0)
       << "leading dimension must divide 4 for int8 gemm";
   ICHECK(!TypeMatch(B->dtype, kDLInt, 8) || ColumnStride(B) % 4 == 0)

tests/python/contrib/test_cublas.py

Lines changed: 10 additions & 2 deletions
@@ -120,8 +120,14 @@ def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
 
     dev = tvm.cuda(0)
     f = tvm.build(s, [A, B, C], "cuda")
-    a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev)
-    b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev)
+
+    if "int" in in_dtype:
+        a = tvm.nd.array(np.random.uniform(1, 10, size=Ashape).astype(in_dtype), dev)
+        b = tvm.nd.array(np.random.uniform(1, 10, size=Bshape).astype(in_dtype), dev)
+    else:
+        a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev)
+        b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev)
+
     c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev)
     f(a, b, c)
     tvm.testing.assert_allclose(
@@ -161,6 +167,8 @@ def test_batch_matmul():
         (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2
     )
 
+    verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "int8", "int32")
+
 
 if __name__ == "__main__":
     test_matmul_add()
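
The new test case drives the int8 -> int32 cuBLAS batch matmul end to end. A minimal sketch, not part of the commit, of that same path through tvm.contrib.cublas using the test's shapes; it assumes a CUDA GPU with cuBLAS available:

import tvm
from tvm import te
from tvm.contrib import cublas

batch, m, k, n = 16, 1024, 128, 236
A = te.placeholder((batch, m, k), name="A", dtype="int8")
B = te.placeholder((batch, k, n), name="B", dtype="int8")
# Passing dtype="int32" requests the mixed-precision kernel that
# CallBatchGemmEx accepts after the cublas.cc change above.
C = cublas.batch_matmul(A, B, transa=False, transb=False, dtype="int32")
s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], "cuda")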

tests/python/contrib/test_cutlass.py

Lines changed: 20 additions & 0 deletions
@@ -725,6 +725,26 @@ def test_conv2d():
         ref_target="llvm",
     )
 
+    # align1 + int8 case
+    d_shape = (16, 3, 32, 32)
+    w_shape = (32, 3, 3, 3)
+    mod_nchw = get_conv2d_nchw(
+        d_shape, w_shape, padding, out_dtype="int32", data_dtype="uint8", weight_dtype="int8"
+    )
+
+    verify_conv2d(
+        mod_nchw,
+        mod_nchw,
+        d_shape,
+        w_shape,
+        sm=80,
+        atol=1e-5,
+        rtol=1e-5,
+        ref_target="llvm",
+        data_dtype="uint8",
+        weight_dtype="int8",
+    )
+
 
 def test_conv2d_fusion():
     d_shape = (16, 16, 32, 32)
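
For reference, the "align1 + int8 case" in this test is a conv2d whose three input channels prevent vectorized (alignment greater than 1) access in CUTLASS, combined with uint8 data, int8 weights, and int32 accumulation. A hypothetical sketch of the Relay module that get_conv2d_nchw builds for these arguments; the padding value is assumed from the surrounding test:

import tvm
from tvm import relay

data = relay.var("data", shape=(16, 3, 32, 32), dtype="uint8")
weight = relay.var("weight", shape=(32, 3, 3, 3), dtype="int8")
conv = relay.nn.conv2d(
    data, weight, padding=(1, 1), channels=32, kernel_size=(3, 3), out_dtype="int32"
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))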
