10 changes: 8 additions & 2 deletions transformer_engine/common/cast/core/common.cuh
@@ -30,11 +30,17 @@ inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_bl
 
 inline bool dimensions_supported_by_TMA(const Tensor *const t) {
   const size_t cols = t->flat_last_dim();
-  constexpr size_t TMA_bytes = 16;
-  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
+  const size_t alignment_requirement = (TMA_GMEM_ALIGNMENT * 8) / typeToNumBits(t->dtype());
   return cols % alignment_requirement == 0;
 }
 
+__device__ __forceinline__ unsigned char*
+align_smem_ptr_per_TMA_requirements(unsigned char* p) {
+  size_t addr = reinterpret_cast<size_t>(p);
+  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
+  return reinterpret_cast<unsigned char*>(addr);
+}
+
 namespace kernel {
 
 constexpr size_t THREADS_PER_BLOCK = 256;
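Note: a minimal, self-contained CUDA sketch of how the two alignment rules above fit together. It assumes TMA_GMEM_ALIGNMENT is 16 bytes (matching the TMA_bytes constant the first hunk replaces) and TMA_SHMEM_ALIGNMENT is 128 bytes (the usual shared-memory alignment requirement for TMA bulk copies); the kernel and helper names other than the arithmetic are illustrative, not Transformer Engine API.

// Sketch only: standalone CUDA illustrating the alignment logic above.
// TMA_GMEM_ALIGNMENT / TMA_SHMEM_ALIGNMENT values are assumptions.
#include <cstddef>
#include <cstdio>

constexpr size_t TMA_GMEM_ALIGNMENT = 16;    // bytes (assumed)
constexpr size_t TMA_SHMEM_ALIGNMENT = 128;  // bytes (assumed)

// Mirrors dimensions_supported_by_TMA: for a 16-bit dtype such as bf16,
// alignment_requirement = (16 * 8) / 16 = 8 elements, so the innermost
// dimension must be a multiple of 8.
bool cols_supported_by_tma(size_t cols, size_t bits_per_elem) {
  const size_t alignment_requirement = (TMA_GMEM_ALIGNMENT * 8) / bits_per_elem;
  return cols % alignment_requirement == 0;
}

// Same round-up trick as align_smem_ptr_per_TMA_requirements; it works
// because TMA_SHMEM_ALIGNMENT is a power of two.
__device__ __forceinline__ unsigned char* align_up_128(unsigned char* p) {
  size_t addr = reinterpret_cast<size_t>(p);
  addr = (addr + TMA_SHMEM_ALIGNMENT - 1) & ~(TMA_SHMEM_ALIGNMENT - 1);
  return reinterpret_cast<unsigned char*>(addr);
}

// Illustrative kernel: dynamic shared memory may not start on a 128-byte
// boundary, so bump the pointer up before using it as a TMA destination.
// The launch must over-allocate by up to TMA_SHMEM_ALIGNMENT - 1 bytes so
// the aligned region still fits.
__global__ void example_kernel() {
  extern __shared__ unsigned char smem_raw[];
  unsigned char* smem = align_up_128(smem_raw);
  (void)smem;
}

int main() {
  printf("%d\n", cols_supported_by_tma(1024, 16));  // 1: 1024 % 8 == 0
  printf("%d\n", cols_supported_by_tma(100, 16));   // 0: 100 % 8 != 0
}

The helper returns an address rounded up, never down, so callers keep the full usable buffer as long as they over-allocate by the alignment slack.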
13 changes: 13 additions & 0 deletions transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh
@@ -20,6 +20,7 @@
 #include "../../util/math.h"
 #include "../../util/ptx.cuh"
 #include "../../utils.cuh"
+#include "./specialized/gated_mxfp8_rowwise_swiglu.cuh"
 
 namespace transformer_engine {
 namespace dispatch {
@@ -696,6 +697,18 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
     scaling_type = ScalingType::BIDIMENSIONAL;
   }
 
+  // Optimized BWD/FWD SwiGLU MXFP8 Rowwise kernels for BF16/FP16 inputs
+  if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+    if constexpr ((!IS_BWD && (ActOP == &silu<fp32, fp32>))
+                  || (IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>))) {
+      if (((gated_input.dtype() == DType::kFloat16) || (gated_input.dtype() == DType::kBFloat16))
+          && (scaling_type == ScalingType::ROWWISE)) {
+        quantize_gated_rowwise<IS_BWD, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
+        return;
+      }
+    }
+  }
+
   const size_t rows = gated_input.flat_first_dim();
   const size_t cols = gated_input.flat_last_dim() / 2;
   const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
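Note: the guard above is compile-time dispatch on function-pointer template parameters. Only instantiations whose ActOP/DActOP are exactly silu/dsilu (and whose ParamOP is not ClampedSwiGLUParam) even contain the call to the specialized rowwise kernel; the runtime check then narrows to FP16/BF16 inputs with rowwise scaling. A minimal sketch of the pattern follows; quantize_gated_sketch, fast_rowwise_path, and generic_path are stand-in names, not the Transformer Engine API.

// Sketch only: compile-time fast-path dispatch via function-pointer
// template parameters and if constexpr.
#include <cmath>
#include <cstdio>
#include <type_traits>

using fp32 = float;

template <typename IType, typename OType>
OType silu(const IType x) {
  return OType(x) / (OType(1) + std::exp(-OType(x)));  // x * sigmoid(x)
}

template <typename IType, typename OType>
OType dsilu(const IType x) {
  const OType s = OType(1) / (OType(1) + std::exp(-OType(x)));
  return s * (OType(1) + OType(x) * (OType(1) - s));  // d/dx [x * sigmoid(x)]
}

struct SwiGLUParam {};         // stand-in for a supported ParamOP
struct ClampedSwiGLUParam {};  // the excluded parameter type

void fast_rowwise_path() { std::puts("specialized rowwise SwiGLU kernel"); }
void generic_path() { std::puts("generic gated quantization path"); }

template <bool IS_BWD, typename ParamOP,
          fp32 (*ActOP)(fp32), fp32 (*DActOP)(fp32)>
void quantize_gated_sketch(const bool is_fp16_or_bf16, const bool is_rowwise) {
  // Compile-time guard: only silu/dsilu instantiations with a non-clamped
  // ParamOP contain the fast-path call below.
  if constexpr (!std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
    if constexpr ((!IS_BWD && (ActOP == &silu<fp32, fp32>))
                  || (IS_BWD && (ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>))) {
      // Run-time guard: input dtype and scaling mode.
      if (is_fp16_or_bf16 && is_rowwise) {
        fast_rowwise_path();
        return;
      }
    }
  }
  generic_path();
}

int main() {
  quantize_gated_sketch<true, SwiGLUParam,
                        &silu<fp32, fp32>, &dsilu<fp32, fp32>>(true, true);   // fast path
  quantize_gated_sketch<true, ClampedSwiGLUParam,
                        &silu<fp32, fp32>, &dsilu<fp32, fp32>>(true, true);   // generic path
}

Because the checks sit inside if constexpr, instantiations that cannot take the fast path never reference the specialized kernel at all, so the new header's templates are only instantiated where they can actually be used.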