Skip to content

Commit 08965f0

Browse files
authored
[CUBLAS] Set fp32 compute and scale dtypes in fp16 matmul (#16892)
This commit replaces the fp16 compute dtype and scale dtype with fp32 in the cuBLAS matmul.
1 parent 3680a0d commit 08965f0

File tree

1 file changed

+0
-6
lines changed

1 file changed

+0
-6
lines changed

src/runtime/contrib/cublas/cublas.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
150150
cudaDataType_t c_type = CUDA_R_32F;
151151
float one_fp32 = 1.0;
152152
float zero_fp32 = 0.0;
153-
auto one_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(1.0);
154-
auto zero_fp16 = __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(0.0);
155153
int32_t one_i32 = 1;
156154
int32_t zero_i32 = 0;
157155
void* alpha = &one_fp32;
@@ -168,10 +166,6 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
168166

169167
if (TypeMatch(C->dtype, kDLFloat, 16)) {
170168
c_type = CUDA_R_16F;
171-
compute_type = CUBLAS_COMPUTE_16F;
172-
scale_type = CUDA_R_16F;
173-
alpha = &one_fp16;
174-
beta = &zero_fp16;
175169
} else if (TypeMatch(C->dtype, kDLInt, 32)) {
176170
c_type = CUDA_R_32I;
177171
compute_type = CUBLAS_COMPUTE_32I;

0 commit comments

Comments
 (0)