@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
     b, i = get_const_tuple(data.shape)
     o, _ = get_const_tuple(weights.shape)
     if (
-        target.kind.name == "cuda"
+        target.kind.name in ["cuda", "vulkan"]
         and data.dtype == "int8"
         and weights.dtype == "int8"
         and out_type.dtype == "int32"
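The hunk above widens the int8 dense check so Vulkan targets can reuse the CUDA int8 schedule. A minimal sketch of a workload that would now take this branch; the shapes and the plain `"vulkan"` target string are illustrative assumptions, and the build requires TVM compiled with Vulkan support:

```python
# Illustrative sketch only (not part of the patch): an int8 dense with int32
# output, which this hunk now routes to dense_int8.cuda on Vulkan as well.
import tvm
from tvm import relay

data = relay.var("data", shape=(16, 64), dtype="int8")
weight = relay.var("weight", shape=(32, 64), dtype="int8")
out = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

# Assumes a TVM build with the Vulkan runtime/codegen enabled, so that
# target.kind.name == "vulkan" in the strategy check above.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="vulkan")
```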
@@ -860,36 +860,28 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
                 name="dense_large_batch.gpu",
                 plevel=5,
             )
-    if target.kind.name == "cuda":
-        if nvcc.have_tensorcore(target=target):
-            if (
-                (
-                    data.dtype in ["float16", "int8", "uint8"]
-                    and (
-                        (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
-                        or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
-                        or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
-                    )
-                )
-                or (
-                    data.dtype in ["int4", "uint4"]
-                    and i % 32 == 0
-                    and b % 8 == 0
-                    and o % 8 == 0
-                )
-                or (
-                    data.dtype in ["int1", "uint1"]
-                    and i % 128 == 0
-                    and b % 8 == 0
-                    and o % 8 == 0
-                )
-            ):
-                strategy.add_implementation(
-                    wrap_compute_dense(topi.cuda.dense_tensorcore),
-                    wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
-                    name="dense_tensorcore.cuda",
-                    plevel=20,
-                )
+
+        if target.kind.name == "cuda":
+            if nvcc.have_tensorcore(target=target):
+                if (
+                    (
+                        data.dtype in ["float16", "int8", "uint8"]
+                        and (
+                            (i % 16 == 0 and b % 16 == 0 and o % 16 == 0)
+                            or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0)
+                            or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0)
+                        )
+                    )
+                    or (data.dtype in ["int4", "uint4"] and i % 32 == 0 and b % 8 == 0 and o % 8 == 0)
+                    or (data.dtype in ["int1", "uint1"] and i % 128 == 0 and b % 8 == 0 and o % 8 == 0)
+                ):
+                    strategy.add_implementation(
+                        wrap_compute_dense(topi.cuda.dense_tensorcore),
+                        wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
+                        name="dense_tensorcore.cuda",
+                        plevel=20,
+                    )
+
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
             wrap_compute_dense(topi.cuda.dense_cublas),
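The divisibility checks in the tensor-core branch above line up with the WMMA fragment shapes that Tensor Cores support, reading `(b, o, i)` as `(m, n, k)`. A short reference as Python data; this is my annotation of the condition, not code taken from the patch:

```python
# My reading of the modulus checks above: each dtype group is accepted when
# (b, o, i) is divisible by one of the supported WMMA fragment shapes (m, n, k).
WMMA_FRAGMENT_SHAPES = {
    "float16/int8/uint8": [(16, 16, 16), (8, 32, 16), (32, 8, 16)],
    "int4/uint4": [(8, 8, 32)],
    "int1/uint1": [(8, 8, 128)],
}
```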
@@ -927,7 +919,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
         )
     if target.kind.name == "cuda" and "cublas" in target.libs:
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
+            wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas, need_out_dtype=True),
             wrap_topi_schedule(topi.generic.schedule_extern),
             name="batch_matmul_cublas.cuda",
             plevel=30,
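Passing `need_out_dtype=True` lets the cuBLAS batch_matmul compute see the requested output dtype, so that, for example, int8 inputs can accumulate into an int32 result. A rough sketch of what the flag implies inside the wrapper, assuming it simply forwards `out_type.dtype`; the real `wrap_compute_batch_matmul` may differ in detail:

```python
# Rough sketch of the wrapper behaviour implied by need_out_dtype=True;
# an assumption about its shape, not the exact TVM source.
def wrap_compute_batch_matmul(topi_compute, need_out_dtype=False):
    def _compute_batch_matmul(attrs, inputs, out_type):
        args = [inputs[0], inputs[1]]
        if need_out_dtype:
            # Forward the requested output dtype (e.g. "int32") so the
            # extern cuBLAS call can accumulate into it.
            args.append(out_type.dtype)
        return [topi_compute(*args)]
    return _compute_batch_matmul
```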