@@ -457,6 +457,18 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
457457 const size_t rows = input.flat_first_dim ();
458458 const size_t cols = input.flat_last_dim ();
459459
460+ if constexpr (IS_DBIAS) {
461+ NVTE_CHECK (dbias, " DBias tensor must be provided when IS_DBIAS is true." );
462+ NVTE_CHECK (workspace, " Workspace must be provided when IS_DBIAS is true." );
463+ if (workspace->data .dptr == nullptr ||
464+ workspace->data .dtype != DType::kFloat32 ||
465+ workspace->data .shape != std::vector<size_t >{rows, cols}) {
466+ workspace->data .shape = {rows, cols};
467+ workspace->data .dtype = DType::kFloat32 ;
468+ return ;
469+ }
470+ }
471+
460472 if (output && output->data .dptr ) {
461473 if constexpr (IS_DACT) {
462474 NVTE_CHECK (act_input, " Gradient tensor must be provided for DACT output." );
@@ -470,17 +482,6 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
470482 const void *ptr_to_reduce = nullptr ;
471483 DType dtype_to_reduce;
472484
473- NVTE_CHECK (dbias, " DBias tensor must be provided when IS_DBIAS is true." );
474- NVTE_CHECK (workspace, " Workspace must be provided when IS_DBIAS is true." );
475-
476- if (workspace->data .dptr == nullptr ||
477- workspace->data .dtype != DType::kFloat32 ||
478- workspace->data .shape != std::vector<size_t >{rows, cols}) {
479- workspace->data .shape = {rows, cols};
480- workspace->data .dtype = DType::kFloat32 ;
481- return ;
482- }
483-
484485 workspace->amax = {};
485486 workspace->scale = {};
486487 workspace->scale_inv = {};
@@ -508,7 +509,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
508509 dbias, rows, cols, stream, workspace);
509510 );
510511 );
511- }
512+ }
512513 break ;
513514 }
514515 case NVTE_MXFP8_1D_SCALING: {
0 commit comments