Merged
Commits (19)
03004ee
[Fix] Added dbias and dgelu kernels for ROCm
AllenFarcas Oct 6, 2025
fc86c36
[Fix] Added copyright
AllenFarcas Oct 6, 2025
ae3c13e
[Fix] Refactored ROCm implementation for FP8 quantization separately
AllenFarcas Oct 7, 2025
5e56c74
[Fix] Refactored ROCm implementation in ROCm specific file. Fixed ind…
AllenFarcas Oct 7, 2025
662ef81
[Fix] Adjusted indentation and moved fp8_quantize_rocm to rocm_cast_k…
AllenFarcas Oct 8, 2025
5ca2aa2
[Fix] Adjusted ifndef indentation
AllenFarcas Oct 8, 2025
191ecab
[Fix] Working version of fp8_quantize_rocm
AllenFarcas Oct 13, 2025
dd8a1bb
[Fix] Refactored fp8_quantize_rocm and implemented partial dbias redu…
AllenFarcas Oct 16, 2025
9c5cf99
[Fix] Fixed indentation
AllenFarcas Oct 16, 2025
e581af9
[Fix] Added MXFP8 scaling mode support
AllenFarcas Oct 16, 2025
d0052e5
[Fix] Fixed indentation
AllenFarcas Oct 16, 2025
5ad2960
[Fix] Added hip related alloc and free
AllenFarcas Oct 16, 2025
86385f2
[Fix] Refactored code for efficiency and performance
AllenFarcas Oct 23, 2025
a55b81e
[Fix] Moved everything new to rocm_cast_kernels
AllenFarcas Oct 23, 2025
7f9ead8
[Fix] Removing inserted space
AllenFarcas Oct 23, 2025
c9ac6cf
[Fix] Refactor to avoid two passes of the same kernel
AllenFarcas Oct 27, 2025
b893a87
[Fix] Refactor for kernel performance, reduce call redundancy.
AllenFarcas Oct 27, 2025
9803b31
[Fix] Fixed memory corruption error reading/writing partial sum works…
AllenFarcas Nov 12, 2025
18888a4
[Fix] Fixed indentation
AllenFarcas Nov 14, 2025
6 changes: 3 additions & 3 deletions tests/cpp/operator/CMakeLists.txt
@@ -19,6 +19,8 @@ list(APPEND test_cuda_sources
test_cast_transpose_dbias.cu
test_cast_transpose_dbias_dgelu.cu
test_cast_transpose_dgeglu.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu
test_act.cu
test_normalization.cu
test_normalization_mxfp8.cu
@@ -29,9 +31,7 @@ list(APPEND test_cuda_sources
../test_common.cu)
if(USE_CUDA)
list(APPEND test_cuda_sources
test_cast_float8blockwise.cu
test_cast_dbias.cu
test_cast_dbias_dgelu.cu)
test_cast_float8blockwise.cu)
else()
list(APPEND test_cuda_sources
test_cublaslt_gemm.cu)
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_cast_dbias.cu
@@ -149,10 +149,12 @@ class CastDBiasTestSuite : public ::testing::TestWithParam<std::tuple<transforme
TEST_P(CastDBiasTestSuite, TestCastDBias) {
using namespace transformer_engine;
using namespace test;
#ifndef __HIP_PLATFORM_AMD__
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
#endif

const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
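The same guard recurs in several of the test files in this PR. A minimal sketch of the pattern, using a hypothetical test name for illustration: the compute-capability skip is compiled only for CUDA builds, since comparing getDeviceComputeCapability() against blackwellComputeCapability is an NVIDIA-specific check, while ROCm builds run the test body unconditionally.

TEST_P(CastDBiasTestSuite, HypotheticalGuardExample) {  // hypothetical test name
#ifndef __HIP_PLATFORM_AMD__
  // CUDA build: skip on pre-Blackwell architectures.
  if (getDeviceComputeCapability() < blackwellComputeCapability) {
    GTEST_SKIP();
  }
#endif
  // ROCm build: no capability gate; the body always executes.
}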
3 changes: 3 additions & 0 deletions tests/cpp/operator/test_cast_dbias_dgelu.cu
@@ -164,10 +164,13 @@ class CastDBiasDGeluTestSuite : public ::testing::TestWithParam<std::tuple<trans
TEST_P(CastDBiasDGeluTestSuite, TestCastDBiasDgelu) {
using namespace transformer_engine;
using namespace test;

#ifndef __HIP_PLATFORM_AMD__
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
#endif

const DType input_type = std::get<0>(GetParam());
const DType output_type = std::get<1>(GetParam());
4 changes: 4 additions & 0 deletions tests/cpp/operator/test_cast_float8blockwise.cu
@@ -478,9 +478,11 @@ class FusedCastFloat8VectorwiseTestSuite
}

TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
#ifndef __HIP_PLATFORM_AMD__
if (getDeviceComputeCapability() < hopperComputeCapability) {
GTEST_SKIP();
}
#endif

using namespace transformer_engine;
using namespace test;
@@ -529,9 +531,11 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
}

TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
#ifndef __HIP_PLATFORM_AMD__
if (getDeviceComputeCapability() < hopperComputeCapability) {
GTEST_SKIP();
}
#endif

using namespace transformer_engine;
using namespace test;
2 changes: 2 additions & 0 deletions tests/cpp/operator/test_normalization.cu
@@ -35,9 +35,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
return;
}

#ifndef __HIP_PLATFORM_AMD__
if (getDeviceComputeCapability() < blackwellComputeCapability && use_cudnn) {
GTEST_SKIP() << "cuDNN normalizations not supported on pre-Blackwell GPUs yet!";
}
#endif

using WeightType = InputType;
DType itype = TypeInfo<InputType>::dtype;
6 changes: 3 additions & 3 deletions tests/cpp/test_common.cu
@@ -531,13 +531,13 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
const T *ref_data = reinterpret_cast<const T*>(ref);
for (size_t i = 0; i < N; ++i) {
#ifndef __HIP_PLATFORM_AMD__
#ifndef __HIP_PLATFORM_AMD__
double t = static_cast<double>(test_data[i]);
double r = static_cast<double>(ref_data[i]);
#else
#else
double t = static_cast<double>(static_cast<float>(test_data[i]));
double r = static_cast<double>(static_cast<float>(ref_data[i]));
#endif
#endif
bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
/* For Float32 the floating point comparison is enough to error out */
bool assertion = mismatch && test.dtype() == DType::kFloat32;
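A minimal sketch of the widening used in the ROCm branch above, written as a hypothetical helper that is not part of the test code. The assumption is that some element types used on ROCm convert cleanly to float but not directly to double, so the AMD path widens in two steps.

// Hypothetical helper illustrating the two branches above.
// Assumption: on ROCm, T may convert to float but not directly to double.
template <typename T>
double to_double(const T &x) {
#ifndef __HIP_PLATFORM_AMD__
  return static_cast<double>(x);                      // CUDA: direct widening cast
#else
  return static_cast<double>(static_cast<float>(x));  // ROCm: widen through float first
#endif
}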
26 changes: 14 additions & 12 deletions transformer_engine/common/util/cast_kernels.cuh
@@ -1201,7 +1201,6 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
#endif //#ifndef __HIP_PLATFORM_AMD__

// Supported by the Arch < 10.0
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
@@ -1227,6 +1226,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}
}
#endif //#ifndef __HIP_PLATFORM_AMD__

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
@@ -1251,17 +1251,19 @@ void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *no
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");

#ifndef __HIP_PLATFORM_AMD__
// Supported by the Arch >= 10.0
if (is_supported_by_CC_100()) {
fp8_quantize_arch_ge_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
} else {
#endif //#ifndef __HIP_PLATFORM_AMD__
// Supported by the Arch < 10.0
fp8_quantize_arch_l_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
#ifndef __HIP_PLATFORM_AMD__
}
// NVIDIA
// Supported by the Arch >= 10.0
if (is_supported_by_CC_100()) {
fp8_quantize_arch_ge_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
} else { // Supported by the Arch < 10.0
fp8_quantize_arch_l_100<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
}
#else
// AMD
fp8_quantize_rocm<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output,
dbias, workspace, stream);
#endif //#ifndef __HIP_PLATFORM_AMD__
}

141 changes: 141 additions & 0 deletions transformer_engine/common/util/rocm_cast_kernels.cuh
@@ -385,4 +385,145 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
}
}

// Forward declaration of functions defined in `cast_kernels.cuh`
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream);

template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output,
cudaStream_t stream);

template <typename ParamOP, float (*OP)(float, const ParamOP &)>
void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output,
cudaStream_t stream);

constexpr size_t TILE_DIM = 32;
template <typename DTypeReduce>
__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) {
__shared__ float tile[TILE_DIM][TILE_DIM];

int tile_start_col = blockIdx.x * TILE_DIM;
int tile_start_row = blockIdx.y * TILE_DIM;
int thread_col_in_tile = threadIdx.x;
int thread_row_in_tile = threadIdx.y;

int global_col = tile_start_col + thread_col_in_tile;
int global_row = tile_start_row + thread_row_in_tile;

if (global_row < rows && global_col < cols) {
tile[thread_row_in_tile][thread_col_in_tile] = static_cast<float>(input[global_row * cols + global_col]);
} else {
tile[thread_row_in_tile][thread_col_in_tile] = 0.0f;
}
__syncthreads();

for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) {
if (thread_row_in_tile < stride) {
tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile];
}
__syncthreads();
}

if (thread_row_in_tile == 0 && global_col < cols) {
partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile];
}
}

template <typename DTypeReduce, typename DBiasTypeOut>
void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows,
const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) {
dim3 block_dim_partial(TILE_DIM, TILE_DIM);
dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM));

const size_t partial_rows = grid_dim_partial.y;
float* partial_workspace = reinterpret_cast<float*>(partial_sum_workspace->data.dptr);

partial_reduce_kernel<DTypeReduce><<<grid_dim_partial, block_dim_partial, 0, stream>>>(
workspace_ptr,
partial_workspace,
rows, cols);

reduce_dbias<DBiasTypeOut>(partial_workspace, dbias, partial_rows, cols, stream);
}

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop,
Tensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();

if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true.");
NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true.");
if (workspace->data.dptr == nullptr ||
workspace->data.dtype != DType::kFloat32 ||
workspace->data.shape != std::vector<size_t>{rows, cols}) {
workspace->data.shape = {rows, cols};
workspace->data.dtype = DType::kFloat32;
return;
}

const void *ptr_to_reduce = nullptr;
DType dtype_to_reduce;

workspace->amax = {};
workspace->scale = {};
workspace->scale_inv = {};

if constexpr (IS_DACT) {
// The values to reduce are the result of the dAct function.
NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT.");
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, workspace, stream);
if (output && output->data.dptr) {
CastVectorizedUnaryKernelLauncher<transformer_engine::Empty, nullptr>(*workspace, noop, output, stream);
}
ptr_to_reduce = workspace->data.dptr;
dtype_to_reduce = workspace->data.dtype;
} else {
if (output && output->data.dptr) {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
}
// The values to reduce are just the input values.
ptr_to_reduce = input.data.dptr;
dtype_to_reduce = input.data.dtype;
}

NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias tensor.");

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
dbias->data.dtype, DBiasTypeOut,
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
dtype_to_reduce, DTypeReduce,
reduce_dbias_rocm<DTypeReduce, DBiasTypeOut>(
reinterpret_cast<const DTypeReduce *>(ptr_to_reduce),
dbias, rows, cols, stream, workspace);
);
);
} else {
if (output && output->data.dptr) {
if constexpr (IS_DACT) {
NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output.");
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(input, act_input, output, stream);
} else {
CastVectorizedUnaryKernelLauncher<ParamOP, OP>(input, noop, output, stream);
}
}
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
mxfp8_quantize<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP>(input, act_input, noop, output, dbias,
workspace, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}
}

} // namespace transformer_engine
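A host-side reference for the two-stage dbias reduction introduced in this file, as a minimal sketch: partial_reduce_kernel collapses each tile of TILE_DIM consecutive rows into one row of per-column partial sums, and reduce_dbias then finishes the reduction over those DIVUP(rows, TILE_DIM) partial rows. The helper name and the use of std::vector are illustrative only.

#include <cstddef>
#include <vector>

// Reference for what partial_reduce_kernel produces: per-column sums over each
// tile of kTileDim consecutive rows of the input.
std::vector<float> partial_reduce_reference(const std::vector<float> &input,
                                            std::size_t rows, std::size_t cols) {
  constexpr std::size_t kTileDim = 32;  // matches TILE_DIM above
  const std::size_t partial_rows = (rows + kTileDim - 1) / kTileDim;  // DIVUP(rows, TILE_DIM)
  std::vector<float> partial(partial_rows * cols, 0.0f);
  for (std::size_t r = 0; r < rows; ++r) {
    for (std::size_t c = 0; c < cols; ++c) {
      partial[(r / kTileDim) * cols + c] += input[r * cols + c];
    }
  }
  // reduce_dbias then sums these partial_rows rows into the final dbias of length cols.
  return partial;
}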