@@ -467,18 +467,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
467467 workspace->data .dtype = DType::kFloat32 ;
468468 return ;
469469 }
470- }
471-
472- if (output && output->data .dptr ) {
473- if constexpr (IS_DACT) {
474- NVTE_CHECK (act_input, " Gradient tensor must be provided for DACT output." );
475- CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
476- } else {
477- CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
478- }
479- }
480470
481- if constexpr (IS_DBIAS) {
482471 const void *ptr_to_reduce = nullptr ;
483472 DType dtype_to_reduce;
484473
@@ -490,9 +479,15 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
490479 // The values to reduce are the result of the dAct function.
491480 NVTE_CHECK (act_input, " Gradient tensor must be provided for DBias + DACT." );
492481 CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, workspace, stream);
482+ if (output && output->data .dptr ) {
483+ CastVectorizedUnaryKernelLauncher<transformer_engine::Empty, nullptr >(*workspace, noop, output, stream);
484+ }
493485 ptr_to_reduce = workspace->data .dptr ;
494486 dtype_to_reduce = workspace->data .dtype ;
495487 } else {
488+ if (output && output->data .dptr ) {
489+ CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
490+ }
496491 // The values to reduce are just the input values.
497492 ptr_to_reduce = input.data .dptr ;
498493 dtype_to_reduce = input.data .dtype ;
@@ -509,6 +504,15 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
509504 dbias, rows, cols, stream, workspace);
510505 );
511506 );
507+ } else {
508+ if (output && output->data .dptr ) {
509+ if constexpr (IS_DACT) {
510+ NVTE_CHECK (act_input, " Gradient tensor must be provided for DACT output." );
511+ CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
512+ } else {
513+ CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
514+ }
515+ }
512516 }
513517 break ;
514518 }
0 commit comments