Skip to content

Commit 819ec5c

Browse files
committed
[Fix] Refactor for kernel performance, reduce call redundancy.
1 parent d58a503 commit 819ec5c

File tree

1 file changed

+15 -11 lines changed

1 file changed

+15 -11 lines changed

transformer_engine/common/util/rocm_cast_kernels.cuh

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -467,18 +467,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
467467
workspace->data.dtype = DType::kFloat32;
468468
return;
469469
}
470-
}
471-
472-
if (output && output->data.dptr) {
473-
if constexpr (IS_DACT) {
474-
NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
475-
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
476-
} else {
477-
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
478-
}
479-
}
480470

481-
if constexpr (IS_DBIAS) {
482471
const void *ptr_to_reduce = nullptr;
483472
DType dtype_to_reduce;
484473

@@ -490,9 +479,15 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
490479
// The values to reduce are the result of the dAct function.
491480
NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT.");
492481
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, workspace, stream);
482+
if (output && output->data.dptr) {
483+
CastVectorizedUnaryKernelLauncher<transformer_engine::Empty, nullptr>(*workspace, noop, output, stream);
484+
}
493485
ptr_to_reduce = workspace->data.dptr;
494486
dtype_to_reduce = workspace->data.dtype;
495487
} else {
488+
if (output && output->data.dptr) {
489+
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
490+
}
496491
// The values to reduce are just the input values.
497492
ptr_to_reduce = input.data.dptr;
498493
dtype_to_reduce = input.data.dtype;
@@ -509,6 +504,15 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
509504
dbias, rows, cols, stream, workspace);
510505
);
511506
);
507+
} else {
508+
if (output && output->data.dptr) {
509+
if constexpr (IS_DACT) {
510+
NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
511+
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
512+
} else {
513+
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
514+
}
515+
}
512516
}
513517
break;
514518
}

0 commit comments

Comments
 (0)