Skip to content

Commit d58a503

Browse files
committed
[Fix] Refactor to avoid two passes of the same kernel
1 parent 973678a commit d58a503

File tree

1 file changed: +13 / -12 lines changed

1 file changed: +13 / -12 lines changed

transformer_engine/common/util/rocm_cast_kernels.cuh

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -457,6 +457,18 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
457457
const size_t rows = input.flat_first_dim();
458458
const size_t cols = input.flat_last_dim();
459459

460+
if constexpr (IS_DBIAS) {
461+
NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true.");
462+
NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true.");
463+
if (workspace->data.dptr == nullptr ||
464+
workspace->data.dtype != DType::kFloat32 ||
465+
workspace->data.shape != std::vector<size_t>{rows, cols}) {
466+
workspace->data.shape = {rows, cols};
467+
workspace->data.dtype = DType::kFloat32;
468+
return;
469+
}
470+
}
471+
460472
if (output && output->data.dptr) {
461473
if constexpr (IS_DACT) {
462474
NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
@@ -470,17 +482,6 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
470482
const void *ptr_to_reduce = nullptr;
471483
DType dtype_to_reduce;
472484

473-
NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true.");
474-
NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true.");
475-
476-
if (workspace->data.dptr == nullptr ||
477-
workspace->data.dtype != DType::kFloat32 ||
478-
workspace->data.shape != std::vector<size_t>{rows, cols}) {
479-
workspace->data.shape = {rows, cols};
480-
workspace->data.dtype = DType::kFloat32;
481-
return;
482-
}
483-
484485
workspace->amax = {};
485486
workspace->scale = {};
486487
workspace->scale_inv = {};
@@ -508,7 +509,7 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso
508509
dbias, rows, cols, stream, workspace);
509510
);
510511
);
511-
}
512+
}
512513
break;
513514
}
514515
case NVTE_MXFP8_1D_SCALING: {

0 commit comments

Comments
 (0)