diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 97b130952d..4fe682efcd 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -79,15 +79,15 @@ struct SimpleTensor {
   std::vector<size_t> shape;
   DType dtype;
 
-  SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
-      : dptr(dptr), shape(shape), dtype(dtype) {}
+  SimpleTensor(void *dptr, std::vector<size_t> shape, DType dtype)
+      : dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {}
 
   SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
       : dptr(tensor.data_ptr),
         shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
         dtype(static_cast<DType>(tensor.dtype)) {}
-  SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
+  SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
 
   operator NVTEBasicTensor() const {
     return {dptr, static_cast<NVTEDType>(dtype),
@@ -104,7 +104,8 @@ struct SimpleTensor {
 
   void clear() {
     dptr = nullptr;
-    shape.resize(0);
+    shape.resize(1);
+    shape[0] = 0;
     dtype = DType::kFloat32;
   }
 };
@@ -125,11 +126,11 @@ struct Tensor {
   Tensor()
       : data(),
         columnwise_data(),
-        amax(nullptr, {1}, DType::kFloat32),
-        columnwise_amax(nullptr, {1}, DType::kFloat32),
-        scale(nullptr, {1}, DType::kFloat32),
-        scale_inv(nullptr, {1}, DType::kFloat32),
-        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
+        amax(),
+        columnwise_amax(),
+        scale(),
+        scale_inv(),
+        columnwise_scale_inv(),
         scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
         nvte_tensor(0) {}
@@ -154,11 +155,10 @@ struct Tensor {
     return acc;
   }
 
-  bool has_data() const noexcept { return data.dptr != nullptr; }
+  bool has_data() const noexcept { return data.dptr != nullptr && data.numel() != 0; }
 
-  // Check for size (not just pointer) for 0-dim or no token cases.
   bool has_columnwise_data() const noexcept {
-    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
+    return columnwise_data.dptr != nullptr && columnwise_data.numel() != 0;
   }
 
   DType dtype() const {
@@ -169,34 +169,52 @@ struct Tensor {
   }
 
   size_t dim() const {
-    if (!has_data() && has_columnwise_data()) {
+    // Check whether a tensor shape matches an uninitialized tensor
+    auto is_shape_trivial = [](const std::vector<size_t> &shape) -> bool {
+      return shape.size() == 1 && shape[0] == 0;
+    };
+
+    // Choose data buffer based on whether it is initialized
+    // Note: Logically each tensor format interprets its data
+    // differently, but for simplicity we assume they all use row-wise
+    // and column-wise data similarly.
+    bool use_columnwise_shape = false;
+    if (data.dptr != nullptr) {
+      use_columnwise_shape = false;
+    } else if (columnwise_data.dptr != nullptr) {
+      use_columnwise_shape = true;
+    } else if (!is_shape_trivial(data.shape)) {
+      use_columnwise_shape = false;
+    } else if (!is_shape_trivial(columnwise_data.shape)) {
+      use_columnwise_shape = true;
+    }
+
+    // Infer number of dims based on data
+    if (use_columnwise_shape) {
       return columnwise_data.shape.size();
-    } else {
-      return data.shape.size();
     }
+    return data.shape.size();
   }
 
   std::vector<size_t> shape() const {
-    /* Note: We sometimes experience spurious compiler errors
-     * (-Wstringop-overflow) from this function. It appears that GCC
-     * has some bugs with std::vector (see
-     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
-     */
+    // Check whether a tensor shape matches an uninitialized tensor
+    auto is_shape_trivial = [](const std::vector<size_t> &shape) -> bool {
+      return shape.size() == 1 && shape.front() == 0;
+    };
+
+    // Each tensor format interprets its data differently
     switch (scaling_mode) {
       case NVTE_DELAYED_TENSOR_SCALING:
       case NVTE_NVFP4_1D_SCALING: {
         // Choose data buffer based on whether it is initialized
-        // Note: Uninitialized buffers currently have shape=[].
-        // However, this is logically incorrect. 0-D tensors have 1
-        // entry, and uninitialized tensors should have shape=[0].
         bool use_columnwise_shape = false;
         if (data.dptr != nullptr) {
           use_columnwise_shape = false;
         } else if (columnwise_data.dptr != nullptr) {
           use_columnwise_shape = true;
-        } else if (data.shape.size() != 0) {
+        } else if (!is_shape_trivial(data.shape)) {
           use_columnwise_shape = false;
-        } else if (columnwise_data.shape.size() != 0) {
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
           use_columnwise_shape = true;
         }
@@ -215,38 +233,56 @@ struct Tensor {
         }
         return data.shape;
       }
-      case NVTE_MXFP8_1D_SCALING:
-        if (!has_data() && has_columnwise_data()) {
+      case NVTE_MXFP8_1D_SCALING: {
+        // Choose data buffer based on whether it is initialized
+        bool use_columnwise_shape = false;
+        if (data.dptr != nullptr) {
+          use_columnwise_shape = false;
+        } else if (columnwise_data.dptr != nullptr) {
+          use_columnwise_shape = true;
+        } else if (!is_shape_trivial(data.shape)) {
+          use_columnwise_shape = false;
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
+          use_columnwise_shape = true;
+        }
+
+        // Infer shape based on data
+        if (use_columnwise_shape) {
           return columnwise_data.shape;
-        } else {
-          return data.shape;
         }
-        break;
+        return data.shape;
+      }
       case NVTE_BLOCK_SCALING_1D:
       case NVTE_BLOCK_SCALING_2D: {
-        if (!has_data() && has_columnwise_data()) {
-          std::vector<size_t> shape;
-          size_t ndim = columnwise_data.shape.size();
-          shape.reserve(ndim);
-          for (size_t i = 0; i + 1 < ndim; ++i) {
-            shape.push_back(columnwise_data.shape[i + 1]);
-          }
-          if (ndim > 0) {
-            shape.push_back(columnwise_data.shape[0]);
+        // Choose data buffer based on whether it is initialized
+        bool use_columnwise_shape = false;
+        if (data.dptr != nullptr) {
+          use_columnwise_shape = false;
+        } else if (columnwise_data.dptr != nullptr) {
+          use_columnwise_shape = true;
+        } else if (!is_shape_trivial(data.shape)) {
+          use_columnwise_shape = false;
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
+          use_columnwise_shape = true;
+        }
+
+        // Infer shape based on data
+        if (use_columnwise_shape) {
+          // Column-wise data is transposed
+          std::vector<size_t> ret;
+          if (!columnwise_data.shape.empty()) {
+            ret.reserve(columnwise_data.shape.size());
+            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
+              ret.push_back(columnwise_data.shape[i]);
+            }
+            ret.push_back(columnwise_data.shape.front());
           }
-          return shape;
-        } else {
-          // NOTE: We may have removed the data pointer from
-          // data by setting usage. In that case, we return
-          // the non-null shape. It is our best guess at the most
-          // recent shape.
-          return data.shape;
+          return ret;
         }
-        break;
+        return data.shape;
       }
       default:
         NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode),
                    "\"");
-        return {};
     }
   }
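The heart of this patch is the convention change above: a default-constructed SimpleTensor now reports shape [0] (zero elements) rather than shape [] (a 0-D tensor, which logically holds one element). A minimal standalone sketch of the distinction, with numel standing in for SimpleTensor::numel:

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for SimpleTensor::numel(): the element count is the product of
// the dims, and an empty product is 1.
std::size_t numel(const std::vector<std::size_t> &shape) {
  return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
                         std::multiplies<std::size_t>());
}

// numel({})     == 1  -> a 0-D scalar, NOT an uninitialized tensor
// numel({0})    == 0  -> the new marker for an uninitialized buffer
// numel({2, 3}) == 6  -> an ordinary tensor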
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index 1a901ab82d..62c3a47c86 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -467,15 +467,22 @@ class TensorWrapper {
    */
   TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
                 float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
-                const NVTEShape scale_inv_shape = defaultShape,
+                NVTEShape scale_inv_shape = defaultShape,
                 const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
     tensor_ = nvte_create_tensor(scaling_mode);
     NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
     nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
-    NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape};
+    NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
+                            amax_dptr != nullptr ? defaultShape : emptyShape};
     nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
-    NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape};
+    NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
+                             scale_dptr != nullptr ? defaultShape : emptyShape};
     nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
+    if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
+        scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
+      // Scale-inv pointer has not been provided and shape matches default
+      scale_inv_shape = emptyShape;
+    }
     NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
     nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
   }
@@ -626,7 +633,8 @@ class TensorWrapper {
    */
   const NVTEShape shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_shape(tensor_);
   }
@@ -637,7 +645,8 @@ class TensorWrapper {
    */
   const NVTEShape columnwise_shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_columnwise_shape(tensor_);
   }
@@ -761,7 +770,8 @@ class TensorWrapper {
    */
   const NVTEShape scale_inv_shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_scale_inv_shape(tensor_);
   }
@@ -780,6 +790,7 @@ class TensorWrapper {
   static constexpr size_t defaultData = 1;
   static constexpr NVTEShape defaultShape = {
       {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
+  static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
 
  private:
   NVTEShape convertShape(const NVTEShape &s) { return s; }
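A hypothetical usage sketch of the constructor change above (assumes the transformer_engine namespace; dptr and shape stand for a real data pointer and its NVTEShape). Scaling parameters whose pointers are omitted are now recorded with emptyShape ([0], zero elements) instead of defaultShape ([1]), so downstream queries can distinguish "no buffer" from "one-element buffer":

void tensor_wrapper_example(void *dptr, const NVTEShape &shape) {
  float amax = 0.0f;
  // amax provided: the kNVTEAmax param keeps the one-element defaultShape
  TensorWrapper with_amax(dptr, shape, DType::kFloat8E4M3, &amax);
  // amax omitted: the kNVTEAmax param is recorded with emptyShape [0]
  TensorWrapper without_amax(dptr, shape, DType::kFloat8E4M3);
}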
diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp
index 337b165080..70e814e806 100644
--- a/transformer_engine/common/normalization/common.cpp
+++ b/transformer_engine/common/normalization/common.cpp
@@ -127,7 +127,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
 
 template <typename KernelParamsType>
 std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
-  return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)};
+  size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 }
 
 template <typename KernelParamsType>
@@ -405,7 +411,13 @@ void CudnnNormalizationPlan::_build() {
 }
 
 std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
-  return {static_cast<size_t>(_graph.get_workspace_size())};
+  size_t workspace_size = _graph.get_workspace_size();
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 }
 
 void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp
index 5785fd2233..b83ae25f25 100644
--- a/transformer_engine/common/normalization/layernorm/ln_api.cpp
+++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp
@@ -51,7 +51,7 @@ void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
              "RSigma must be 1D tensor with shape (x.shape[0],).");
   NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(x, "x");
     CheckInputTensor(gamma, "gamma");
     CheckInputTensor(beta, "beta");
@@ -94,7 +94,7 @@ void layernorm_fwd(const Tensor& x,      // BxSxhidden_size
                                                multiprocessorCount, zero_centered_gamma, is_aligned,
                                                z->scaling_mode, training, gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -146,7 +146,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
   NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
   NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(mu, "mu");
@@ -179,7 +179,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
                                multiprocessorCount, zero_centered_gamma, is_aligned,
                                NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
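The numel() tests above (and in the rmsnorm file below) implement a two-pass workspace protocol: a zero-element workspace means "size query", which only works if a real workspace can never have zero elements; hence the clamping in getWorkspaceShape. A minimal sketch of the convention, using hypothetical helper names:

#include <cstddef>

// Hypothetical helpers (not in the patch) spelling out the convention.
bool is_workspace_query(std::size_t workspace_numel) {
  // Zero elements now means "first pass: report the required shape".
  // Previously this was detected via shape.empty().
  return workspace_numel == 0;
}

std::size_t reported_workspace_bytes(std::size_t computed_bytes) {
  // Never report zero: a zero-byte "real" workspace on the second pass
  // would be indistinguishable from another size query.
  return computed_bytes == 0 ? 1 : computed_bytes;
}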
diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
index a3b05f7a29..ea6c972bf5 100644
--- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
+++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
@@ -39,7 +39,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
              "RSigma must be 1D tensor with shape (x.shape[0],).");
   NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(x, "x");
     CheckInputTensor(gamma, "gamma");
 
@@ -79,7 +79,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
                                                multiprocessorCount, zero_centered_gamma, is_aligned,
                                                z->scaling_mode, training, gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -125,7 +125,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(rsigma, "rsigma");
@@ -156,7 +156,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
                                multiprocessorCount, zero_centered_gamma, is_aligned,
                                NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -191,7 +191,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(add, "add");
@@ -222,7 +222,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
                                multiprocessorCount, zero_centered_gamma, is_aligned,
                                NVTE_DELAYED_TENSOR_SCALING, true, gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu
index 06735e3104..72cb3e1a71 100644
--- a/transformer_engine/common/swizzle/swizzle.cu
+++ b/transformer_engine/common/swizzle/swizzle.cu
@@ -332,68 +332,120 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_
 }  // namespace
 
 void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) {
-  NVTE_CHECK(
-      input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING,
-      "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
-  NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()),
-             "Input tensor has invalid dtype (", to_string(input->dtype()), ").");
-
-  // Do nothing if tensor is empty
-  if (input->data.numel() == 0) {
-    return;
-  }
+  // Check scaling mode
+  const auto& scaling_mode = input->scaling_mode;
+  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
+             "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ").");
 
+  // Check tensors
   CheckInputTensor(*input, "scaling_factor_input");
   CheckInputTensor(*output, "scaling_factor_output");
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING:
+      NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
+                 to_string(input->dtype()), ").");
+      break;
+    case NVTE_NVFP4_1D_SCALING:
+      NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
+                 to_string(input->dtype()), ").");
+      break;
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
 
-  auto& scaling_mode = input->scaling_mode;
-  NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING,
-             "Unsupported scaling mode for swizzling.");
-
-  bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
+  // Check if scaling factors are non-trivial
+  const bool has_rowwise_scale_inv =
+      (input->scale_inv.dptr != nullptr && input->scale_inv.numel() != 0);
+  const bool has_columnwise_scale_inv =
+      (input->columnwise_scale_inv.dptr != nullptr && input->columnwise_scale_inv.numel() != 0);
+  NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
+             "Input tensor has both row-wise and column-wise scaling factors");
+  if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
+    return;
+  }
 
-  // 1D block scaling, row-wise or colum-wise
-  int m, k;
-  if (input->has_data()) {
-    m = input->scale_inv.shape[0];
-    k = input->scale_inv.shape[1];
-  } else {
-    if (nvfp4) {
-      m = input->columnwise_scale_inv.shape[0];
-      k = input->columnwise_scale_inv.shape[1];
-    } else {
-      m = input->columnwise_scale_inv.shape[1];
-      k = input->columnwise_scale_inv.shape[0];
-    }
-  }
+  // Deduce tensor dims
+  int m{0}, k{0};
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING: {
+      if (has_rowwise_scale_inv) {
+        NVTE_CHECK(input->scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
+        m = input->scale_inv.shape[0];
+        k = input->scale_inv.shape[1];
+      } else if (has_columnwise_scale_inv) {
+        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
+                   ".");
+        m = input->columnwise_scale_inv.shape[1];
+        k = input->columnwise_scale_inv.shape[0];
+      }
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      if (has_rowwise_scale_inv) {
+        NVTE_CHECK(input->scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
+        m = input->scale_inv.shape[0];
+        k = input->scale_inv.shape[1];
+      } else if (has_columnwise_scale_inv) {
+        NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
+                   "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
+                   ".");
+        m = input->columnwise_scale_inv.shape[0];
+        k = input->columnwise_scale_inv.shape[1];
+      }
+      break;
+    }
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
 
+  // Check dims
   constexpr int SF_TILE_DIM_M = 128;
   constexpr int SF_TILE_DIM_K = 4;
-
   NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
   NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
-  NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
-  if (output->has_data()) {
-    NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(),
-                                        output->scale_inv.shape.end(), 1, std::multiplies<int>()),
-               "Input.scale_inv size is not equal to Output.scale_inv size!");
+
+  // Check that output tensor matches input tensor
+  if (has_rowwise_scale_inv) {
+    NVTE_CHECK(output->scale_inv.dptr != nullptr,
+               "Output tensor does not have row-wise scaling factors.");
+    NVTE_CHECK(m * k == output->scale_inv.numel(), "Expected output tensor to have ", m * k,
+               " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
   }
-  if (output->has_columnwise_data()) {
-    NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(),
-                                        output->columnwise_scale_inv.shape.end(), 1,
-                                        std::multiplies<int>()),
-               "Input.columnwise_scale_inv size is not equal to "
-               "Output.columnwise_scale_inv size!");
+  if (has_columnwise_scale_inv) {
+    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
+               "Output tensor does not have column-wise scaling factors.");
+    NVTE_CHECK(
+        m * k == output->columnwise_scale_inv.numel(), "Expected output tensor to have ", m * k,
+        " column-wise scaling factors, but got shape=", output->columnwise_scale_inv.shape, ".");
   }
 
-  int num_tiles_m = m / SF_TILE_DIM_M;
-  int num_tiles_k = k / SF_TILE_DIM_K;
+  // Choose swizzle implementation
+  bool rowwise_swizzle{false}, columnwise_swizzle{false};
+  switch (scaling_mode) {
+    case NVTE_MXFP8_1D_SCALING: {
+      rowwise_swizzle = has_rowwise_scale_inv;
+      columnwise_swizzle = has_columnwise_scale_inv;
+      break;
+    }
+    case NVTE_NVFP4_1D_SCALING: {
+      // NVFP4 column-wise data is transposed, so row-wise and
+      // column-wise scales have same swizzling format
+      rowwise_swizzle = true;
+      columnwise_swizzle = false;
+      break;
+    }
+    default:
+      NVTE_ERROR("Invalid scaling mode");
+  }
 
-  // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle.
-  const bool rowwise_swizzle = input->has_data() || nvfp4;
-  const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4;
+  const dim3 block_size(TB_DIM, TB_DIM);
+  const int num_tiles_m = m / SF_TILE_DIM_M;
+  const int num_tiles_k = k / SF_TILE_DIM_K;
 
-  dim3 block_size(TB_DIM, TB_DIM);
+  // Perform row-wise swizzle
   if (rowwise_swizzle) {
     int vec_load_size = (num_tiles_k - 1) % 4 + 1;
     /* there is no int3 and misaligned if using int4/int2 */
@@ -404,18 +456,30 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
 
     int original_M, original_K;
     void *input_scale_inv_ptr, *output_scale_inv_ptr;
-
-    if (!nvfp4 || input->has_data()) {
-      int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE;
-      original_M = input->flat_first_dim();
-      original_K = input->flat_last_dim() / block_scale_size;
-      input_scale_inv_ptr = input->scale_inv.dptr;
-      output_scale_inv_ptr = output->scale_inv.dptr;
-    } else {
-      original_M = input->flat_last_dim();
-      original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
-      input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
-      output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
+    switch (scaling_mode) {
+      case NVTE_MXFP8_1D_SCALING: {
+        original_M = input->flat_first_dim();
+        original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE;
+        input_scale_inv_ptr = input->scale_inv.dptr;
+        output_scale_inv_ptr = output->scale_inv.dptr;
+        break;
+      }
+      case NVTE_NVFP4_1D_SCALING: {
+        if (has_rowwise_scale_inv) {
+          original_M = input->flat_first_dim();
+          original_K = input->flat_last_dim() / NVFP4_BLOCK_SIZE;
+          input_scale_inv_ptr = input->scale_inv.dptr;
+          output_scale_inv_ptr = output->scale_inv.dptr;
+        } else if (has_columnwise_scale_inv) {
+          original_M = input->flat_last_dim();
+          original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE;
+          input_scale_inv_ptr = input->columnwise_scale_inv.dptr;
+          output_scale_inv_ptr = output->columnwise_scale_inv.dptr;
+        }
+        break;
+      }
+      default:
+        NVTE_ERROR("Invalid scaling mode");
     }
 
     switch (vec_load_size) {
@@ -447,7 +511,10 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
+
+  // Perform column-wise swizzle
   if (columnwise_swizzle) {
     int vec_load_size = (num_tiles_m - 1) % 4 + 1;
     if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */
@@ -456,8 +523,6 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
     int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t);
     const int original_M = input->flat_last_dim();
     const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE;
-    // NVFP4 shouldn't end up here because it only needs rowwise swizzle
-    NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle");
 
     switch (vec_load_size) {
       case 4:
@@ -491,9 +556,8 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s
         NVTE_ERROR("Not valid vec_load_size.");
         break;
     }
+    NVTE_CHECK_CUDA(cudaGetLastError());
   }
-
-  NVTE_CHECK_CUDA(cudaGetLastError());
 }
 
 template <typename Type>
@@ -595,14 +659,15 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
                    (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)),
                "Not implemented scaling mode " + to_string(scaling_mode) + ".");
     // We don't allow empty tensors. They should be filtered out before calling this function.
-    if (input[i]->data.numel() == 0) {
-      NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty.");
-    }
+    NVTE_CHECK(input[i]->numel() != 0, "Tensor input[", i, "] is empty.");
     CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]");
     CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]");
-    all_has_data &= input[i]->has_data();
-    all_has_columnwise_data &= input[i]->has_columnwise_data();
-    all_nvfp4 &= is_nvfp4_scaling(scaling_mode);
+    all_has_data =
+        (all_has_data && input[i]->scale_inv.dptr != nullptr && input[i]->scale_inv.numel() != 0);
+    all_has_columnwise_data =
+        (all_has_columnwise_data && input[i]->columnwise_scale_inv.dptr != nullptr &&
+         input[i]->columnwise_scale_inv.numel() != 0);
+    all_nvfp4 = all_nvfp4 && is_nvfp4_scaling(scaling_mode);
   }
   NVTE_CHECK(all_has_data || all_has_columnwise_data,
              "All tensors should have data or columnwise data.");
@@ -644,18 +709,19 @@ void multi_tensor_swizzle_scaling_factors(const std::vector<Tensor*>& input,
     NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");
     NVTE_CHECK(k > 0, "Input scale inverse should be 2D!");
 
-    if (output[i]->has_data()) {
-      NVTE_CHECK(
-          m * k == std::accumulate(output[i]->scale_inv.shape.begin(),
-                                   output[i]->scale_inv.shape.end(), 1, std::multiplies<int>()),
-          "Input.scale_inv size is not equal to Output.scale_inv size!");
+    if (all_has_data) {
+      NVTE_CHECK(output[i]->scale_inv.dptr != nullptr, "Output tensor ", i,
+                 " does not have row-wise scaling factors.");
+      NVTE_CHECK(m * k == output[i]->scale_inv.numel(), "Expected output tensor ", i, " to have ",
+                 m * k,
+                 " row-wise scaling factors, but got shape=", output[i]->scale_inv.shape, ".");
     }
-    if (output[i]->has_columnwise_data()) {
-      NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(),
-                                          output[i]->columnwise_scale_inv.shape.end(), 1,
-                                          std::multiplies<int>()),
-                 "Input.columnwise_scale_inv size is not equal to "
-                 "Output.columnwise_scale_inv size!");
+    if (all_has_columnwise_data) {
+      NVTE_CHECK(output[i]->columnwise_scale_inv.dptr != nullptr, "Output tensor ", i,
+                 " does not have column-wise scaling factors.");
+      NVTE_CHECK(m * k == output[i]->columnwise_scale_inv.numel(), "Expected output tensor ", i,
+                 " to have ", m * k, " column-wise scaling factors, but got shape=",
+                 output[i]->columnwise_scale_inv.shape, ".");
    }
 
     int num_tiles_k = k / SF_TILE_DIM_K;
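The has_rowwise_scale_inv / has_columnwise_scale_inv tests above recur throughout this patch; a generic sketch of the predicate, with BufferT standing in for SimpleTensor:

// Both conditions matter under the new convention: an unused buffer may
// carry a null pointer, a zero-element shape [0], or both.
template <typename BufferT>
bool has_scale_inv(const BufferT &buf) {
  return buf.dptr != nullptr && buf.numel() != 0;
}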
diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp
index 35e8b683ad..c3a2192dc4 100644
--- a/transformer_engine/common/transformer_engine.cpp
+++ b/transformer_engine/common/transformer_engine.cpp
@@ -88,15 +88,30 @@ void CheckNoopTensor(const Tensor &t, const std::string &name) {
 
 void CheckScaleTensorShape(const Tensor &t, const std::string &name) {
   NVTE_CHECK(t.scaling_mode != NVTE_INVALID_SCALING, "Invalid scaling mode!");
   if (is_tensor_scaling(t.scaling_mode)) {
-    // per-tensor scaling
-    if (t.has_data()) {
-      NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
-                 "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
-    }
-    if (t.has_columnwise_data()) {
-      NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
-                 "\" has invalid columnwise_scale_inv shape (expected (1), got ",
-                 t.columnwise_scale_inv.shape, ")");
+    if (is_fp8_dtype(t.dtype())) {
+      // FP8 tensor with tensor scaling
+      if (t.has_data()) {
+        NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape,
+                   ")");
+      }
+      if (t.has_columnwise_data()) {
+        NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected 1 entry, got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
+    } else {
+      // High-precision tensor
+      if (t.has_data()) {
+        NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name,
+                   "\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape,
+                   ")");
+      }
+      if (t.has_columnwise_data()) {
+        NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name,
+                   "\" has invalid columnwise_scale_inv shape (expected 0 entries, got ",
+                   t.columnwise_scale_inv.shape, ")");
+      }
     }
   } else {
     if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) {
@@ -524,7 +539,8 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) {
 NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) {
   auto *t = transformer_engine::convertNVTETensor(tensor);
   if (t == nullptr) {
-    return nvte_make_shape(nullptr, 0);
+    const size_t zero = 0;
+    return nvte_make_shape(&zero, 1);
   }
   return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size());
 }
@@ -563,7 +579,8 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name,
 
 NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) {
   if (tensor == nullptr) {
-    return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)};
+    const size_t zero = 0;
+    return {nullptr, kNVTEFloat32, nvte_make_shape(&zero, 1)};
   }
   const auto &t = *transformer_engine::convertNVTETensorCheck(tensor);
   switch (param_name) {
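The nvte_make_shape(&zero, 1) fallbacks above change what a null tensor reports; a short illustrative snippet (the trailing comment states the new behavior):

void null_tensor_shape_example() {
  // A null tensor now reports a 1-D shape [0] (zero elements) rather than
  // a 0-D shape, which would denote a one-element scalar.
  NVTEShape s = nvte_tensor_scale_inv_shape(nullptr);
  // Now: s.ndim == 1 && s.data[0] == 0
  (void)s;
}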
"Invalid scaling mode!"); if (is_tensor_scaling(t.scaling_mode)) { - // per-tensor scaling - if (t.has_data()) { - NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, - "\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")"); - } - if (t.has_columnwise_data()) { - NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, - "\" has invalid columnwise_scale_inv shape (expected (1), got ", - t.columnwise_scale_inv.shape, ")"); + if (is_fp8_dtype(t.dtype())) { + // FP8 tensor with tensor scaling + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid scale_inv shape (expected 1 entry, got ", t.scale_inv.shape, + ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected 1 entry, got ", + t.columnwise_scale_inv.shape, ")"); + } + } else { + // High-precision tensor + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.numel() == 0, "Tensor \"", name, + "\" has invalid scale_inv shape (expected 0 entries, got ", t.scale_inv.shape, + ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.numel() == 0, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected 0 entries, got ", + t.columnwise_scale_inv.shape, ")"); + } } } else { if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { @@ -524,7 +539,8 @@ void *nvte_tensor_columnwise_scale_inv(const NVTETensor tensor) { NVTEShape nvte_tensor_scale_inv_shape(const NVTETensor tensor) { auto *t = transformer_engine::convertNVTETensor(tensor); if (t == nullptr) { - return nvte_make_shape(nullptr, 0); + const size_t zero = 0; + return nvte_make_shape(&zero, 1); } return nvte_make_shape(t->scale_inv.shape.data(), t->scale_inv.shape.size()); } @@ -563,7 +579,8 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam param_name) { if (tensor == nullptr) { - return {nullptr, kNVTEFloat32, nvte_make_shape(nullptr, 0)}; + const size_t zero = 0; + return {nullptr, kNVTEFloat32, nvte_make_shape(&zero, 1)}; } const auto &t = *transformer_engine::convertNVTETensorCheck(tensor); switch (param_name) { diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 6329e79ae7..a04ab902ca 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -198,8 +198,6 @@ void populate_cast_transpose_dbias_workspace_config(const Tensor &cast_output, / workspace->data.dtype); const size_t required_size = get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); - NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (", - num_rows_partial_dbias, ",", row_length, "), found ())"); NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), "; found dims=", workspace->data.shape, diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index 3c51ce3dab..75fa07a5b3 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -388,8 +388,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ 
diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
index 15404ad9a6..cc17c89a82 100644
--- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp
@@ -43,10 +43,10 @@ bool is_low_precision(const DType type) {
 std::vector<size_t> getGemmOutputShape(const NVTEShape& A_shape, const bool transa,
                                        const NVTEShape& B_shape, const bool transb) {
   // Flatten outer dims to get 2D matrices
-  const size_t A0 = product(A_shape, 0, A_shape.ndim - 1);
-  const size_t A1 = A_shape.data[A_shape.ndim - 1];
-  const size_t B0 = product(B_shape, 0, B_shape.ndim - 1);
-  const size_t B1 = B_shape.data[B_shape.ndim - 1];
+  const size_t A0 = A_shape.ndim > 0 ? product(A_shape, 0, A_shape.ndim - 1) : 1;
+  const size_t A1 = A_shape.ndim > 0 ? A_shape.data[A_shape.ndim - 1] : 1;
+  const size_t B0 = B_shape.ndim > 0 ? product(B_shape, 0, B_shape.ndim - 1) : 1;
+  const size_t B1 = B_shape.ndim > 0 ? B_shape.data[B_shape.ndim - 1] : 1;
 
   // Check matrix dims
   NVTE_CHECK((transa ? A1 : A0) == (transb ? B0 : B1), "Invalid matrix dimensions for GEMM (A=(",
diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp
index 134185ac82..ce547d302e 100644
--- a/transformer_engine/pytorch/csrc/util.cpp
+++ b/transformer_engine/pytorch/csrc/util.cpp
@@ -142,13 +142,20 @@ std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
     auto& tensor = tensors[i];
     void* scale_inv_dptr = scale_inv_dptrs[i];
     void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);
-    // auto input_shape = nvte_shape_to_vector(tensor.shape());
+
+    // Empty tensors don't require scale swizzling
+    if (tensor.numel() == 0) {
+      continue;
+    }
+
+    // Tensor shape
     NVTEShape nvte_input_shape;
     if (rowwise) {
       nvte_input_shape = tensor.shape();
     } else {
       nvte_input_shape = tensor.get_columnwise_data().shape;
     }
+    auto input_shape = nvte_shape_to_vector(nvte_input_shape);
 
     // Reconstruct input only to avoid swizzling both directions if not needed.
     // Use any 8 bit type, it's irrelevant.
@@ -202,14 +209,14 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp
   size_t data_flat_last_dim = 1;
   if (rowwise) {
     data = input.get_rowwise_data();
-    for (int i = 0; i < data.shape.ndim - 1; ++i) {
+    for (size_t i = 0; i < data.shape.ndim - 1; ++i) {
       data_flat_first_dim *= data.shape.data[i];
     }
     data_flat_last_dim = data.shape.data[data.shape.ndim - 1];
   } else {
     data = input.get_columnwise_data();
     data_flat_first_dim = data.shape.data[0];
-    for (int i = 1; i < data.shape.ndim; ++i) {
+    for (size_t i = 1; i < data.shape.ndim; ++i) {
      data_flat_last_dim *= data.shape.data[i];
    }
  }
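The ndim > 0 guards in getGemmOutputShape above matter because NVTEShape::ndim is unsigned: for a 0-D shape, which empty tensors can now legitimately produce, ndim - 1 would wrap around. A standalone sketch of the guarded flattening, with flatten_to_2d as a hypothetical helper:

#include <cstddef>
#include <utility>
#include <vector>

// Collapse all leading dims into one, as getGemmOutputShape does. Without
// the emptiness check, size() - 1 wraps around for a 0-D shape because
// size_t is unsigned.
std::pair<std::size_t, std::size_t> flatten_to_2d(const std::vector<std::size_t> &shape) {
  if (shape.empty()) {
    return {1, 1};  // treat a 0-D tensor as a 1x1 matrix, matching the guard
  }
  std::size_t outer = 1;
  for (std::size_t i = 0; i + 1 < shape.size(); ++i) {
    outer *= shape[i];
  }
  return {outer, shape.back()};
}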