transformer_engine/common/common.h (132 changes: 84 additions & 48 deletions)

@@ -79,15 +79,15 @@ struct SimpleTensor {
   std::vector<size_t> shape;
   DType dtype;
 
-  SimpleTensor(void *dptr, const std::vector<size_t> &shape, DType dtype)
-      : dptr(dptr), shape(shape), dtype(dtype) {}
+  SimpleTensor(void *dptr, std::vector<size_t> shape, DType dtype)
+      : dptr{dptr}, shape{std::move(shape)}, dtype{dtype} {}
 
   SimpleTensor(const NVTEBasicTensor &tensor)  // NOLINT
       : dptr(tensor.data_ptr),
         shape(tensor.shape.data, tensor.shape.data + tensor.shape.ndim),
         dtype(static_cast<DType>(tensor.dtype)) {}
 
-  SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
+  SimpleTensor() : SimpleTensor(nullptr, std::vector<size_t>{0}, DType::kFloat32) {}
 
   operator NVTEBasicTensor() const {
     return {dptr, static_cast<NVTEDType>(dtype),
@@ -104,7 +104,8 @@ struct SimpleTensor {
 
   void clear() {
     dptr = nullptr;
-    shape.resize(0);
+    shape.resize(1);
+    shape[0] = 0;
     dtype = DType::kFloat32;
   }
 };
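The default-shape changes above encode the convention this PR introduces: under the old code a default-constructed SimpleTensor had shape={}, which is the shape of a 0-D scalar and implies one element (the empty product), so an uninitialized tensor was indistinguishable from a scalar. A minimal standalone sketch of the distinction (the free numel below mirrors SimpleTensor::numel, which multiplies the shape entries):

    #include <cstddef>
    #include <vector>

    // Element count is the product of the shape entries; the empty
    // product is 1, so shape {} means "scalar", not "no data".
    size_t numel(const std::vector<size_t> &shape) {
      size_t acc = 1;
      for (size_t d : shape) acc *= d;
      return acc;
    }

    int main() {
      std::vector<size_t> scalar = {};          // 0-D tensor: 1 element
      std::vector<size_t> uninitialized = {0};  // new sentinel: 0 elements
      return (numel(scalar) == 1 && numel(uninitialized) == 0) ? 0 : 1;
    }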
@@ -125,11 +126,11 @@ struct Tensor {
   Tensor()
       : data(),
         columnwise_data(),
-        amax(nullptr, {1}, DType::kFloat32),
-        columnwise_amax(nullptr, {1}, DType::kFloat32),
-        scale(nullptr, {1}, DType::kFloat32),
-        scale_inv(nullptr, {1}, DType::kFloat32),
-        columnwise_scale_inv(nullptr, {1}, DType::kFloat32),
+        amax(),
+        columnwise_amax(),
+        scale(),
+        scale_inv(),
+        columnwise_scale_inv(),
         scaling_mode(NVTE_DELAYED_TENSOR_SCALING),
         nvte_tensor(0) {}
 
@@ -154,11 +155,10 @@ struct Tensor {
     return acc;
   }
 
-  bool has_data() const noexcept { return data.dptr != nullptr; }
+  bool has_data() const noexcept { return data.dptr != nullptr && data.numel() != 0; }
 
-  // Check for size (not just pointer) for 0-dim or no token cases.
   bool has_columnwise_data() const noexcept {
-    return columnwise_data.dptr != nullptr || columnwise_data.shape.size() != 0;
+    return columnwise_data.dptr != nullptr && columnwise_data.numel() != 0;
   }
 
   DType dtype() const {
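The predicate change matters for tensors whose usage flags were toggled: dropping row-wise or column-wise usage can null the device pointer while the shape metadata is kept. The old column-wise check used ||, so such a buffer still reported as present. A small self-contained sketch of the behavior difference (Buf is a hypothetical stand-in for SimpleTensor):

    #include <cstddef>
    #include <vector>

    // Hypothetical mirror of SimpleTensor, for illustration only.
    struct Buf {
      void *dptr = nullptr;
      std::vector<size_t> shape = {0};
      size_t numel() const {
        size_t acc = 1;
        for (size_t d : shape) acc *= d;
        return acc;
      }
    };

    bool has_data_old(const Buf &b) {  // pre-PR column-wise logic
      return b.dptr != nullptr || b.shape.size() != 0;
    }
    bool has_data_new(const Buf &b) {  // post-PR logic
      return b.dptr != nullptr && b.numel() != 0;
    }

    int main() {
      Buf b;
      b.shape = {128, 64};  // shape kept, pointer dropped (e.g. after a usage reset)
      return (has_data_old(b) && !has_data_new(b)) ? 0 : 1;  // old: true, new: false
    }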
@@ -169,34 +169,52 @@ struct Tensor {
   }
 
   size_t dim() const {
-    if (!has_data() && has_columnwise_data()) {
+    // Check whether a tensor shape matches an uninitialized tensor
+    auto is_shape_trivial = [](const std::vector<size_t> &shape) -> bool {
+      return shape.size() == 1 && shape[0] == 0;
+    };
+
+    // Choose data buffer based on whether it is initialized
+    // Note: Logically each tensor format interprets its data
+    // differently, but for simplicity we assume they all use row-wise
+    // and column-wise data similarly.
+    bool use_columnwise_shape = false;
+    if (data.dptr != nullptr) {
+      use_columnwise_shape = false;
+    } else if (columnwise_data.dptr != nullptr) {
+      use_columnwise_shape = true;
+    } else if (!is_shape_trivial(data.shape)) {
+      use_columnwise_shape = false;
+    } else if (!is_shape_trivial(columnwise_data.shape)) {
+      use_columnwise_shape = true;
+    }
+
+    // Infer number of dims based on data
+    if (use_columnwise_shape) {
       return columnwise_data.shape.size();
-    } else {
-      return data.shape.size();
     }
+    return data.shape.size();
   }
 
   std::vector<size_t> shape() const {
-    /* Note: We sometimes experience spurious compiler errors
-     * (-Wstringop-overflow) from this function. It appears that GCC
-     * has some bugs with std::vector (see
-     * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569).
-     */
+    // Check whether a tensor shape matches an uninitialized tensor
+    auto is_shape_trivial = [](const std::vector<size_t> &shape) -> bool {
+      return shape.size() == 1 && shape.front() == 0;
+    };
+
+    // Each tensor format interprets its data differently
     switch (scaling_mode) {
       case NVTE_DELAYED_TENSOR_SCALING:
       case NVTE_NVFP4_1D_SCALING: {
         // Choose data buffer based on whether it is initialized
-        // Note: Uninitialized buffers currently have shape=[].
-        // However, this is logically incorrect. 0-D tensors have 1
-        // entry, and uninitialized tensors should have shape=[0].
         bool use_columnwise_shape = false;
         if (data.dptr != nullptr) {
           use_columnwise_shape = false;
         } else if (columnwise_data.dptr != nullptr) {
           use_columnwise_shape = true;
-        } else if (data.shape.size() != 0) {
+        } else if (!is_shape_trivial(data.shape)) {
           use_columnwise_shape = false;
-        } else if (columnwise_data.shape.size() != 0) {
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
           use_columnwise_shape = true;
         }
 
@@ -215,38 +233,56 @@ struct Tensor {
         }
         return data.shape;
       }
-      case NVTE_MXFP8_1D_SCALING:
-        if (!has_data() && has_columnwise_data()) {
+      case NVTE_MXFP8_1D_SCALING: {
+        // Choose data buffer based on whether it is initialized
+        bool use_columnwise_shape = false;
+        if (data.dptr != nullptr) {
+          use_columnwise_shape = false;
+        } else if (columnwise_data.dptr != nullptr) {
+          use_columnwise_shape = true;
+        } else if (!is_shape_trivial(data.shape)) {
+          use_columnwise_shape = false;
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
+          use_columnwise_shape = true;
+        }
+
+        // Infer shape based on data
+        if (use_columnwise_shape) {
           return columnwise_data.shape;
-        } else {
-          return data.shape;
         }
-        break;
+        return data.shape;
+      }
       case NVTE_BLOCK_SCALING_1D:
       case NVTE_BLOCK_SCALING_2D: {
-        if (!has_data() && has_columnwise_data()) {
-          std::vector<size_t> shape;
-          size_t ndim = columnwise_data.shape.size();
-          shape.reserve(ndim);
-          for (size_t i = 0; i + 1 < ndim; ++i) {
-            shape.push_back(columnwise_data.shape[i + 1]);
-          }
-          if (ndim > 0) {
-            shape.push_back(columnwise_data.shape[0]);
+        // Choose data buffer based on whether it is initialized
+        bool use_columnwise_shape = false;
+        if (data.dptr != nullptr) {
+          use_columnwise_shape = false;
+        } else if (columnwise_data.dptr != nullptr) {
+          use_columnwise_shape = true;
+        } else if (!is_shape_trivial(data.shape)) {
+          use_columnwise_shape = false;
+        } else if (!is_shape_trivial(columnwise_data.shape)) {
+          use_columnwise_shape = true;
+        }
+
+        // Infer shape based on data
+        if (use_columnwise_shape) {
+          // Column-wise data is transposed
+          std::vector<size_t> ret;
+          if (!columnwise_data.shape.empty()) {
+            ret.reserve(columnwise_data.shape.size());
+            for (size_t i = 1; i < columnwise_data.shape.size(); i++) {
+              ret.push_back(columnwise_data.shape[i]);
+            }
+            ret.push_back(columnwise_data.shape.front());
           }
-          return shape;
-        } else {
-          // NOTE: We may have removed the data pointer from
-          // data by setting usage. In that case, we return
-          // the non-null shape. It is our best guess at the most
-          // recent shape.
-          return data.shape;
+          return ret;
         }
-        break;
+        return data.shape;
       }
       default:
         NVTE_ERROR("Cannot parse tensor shape with scaling mode \"", to_string(scaling_mode), "\"");
         return {};
     }
   }
 
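For the block-scaling formats the column-wise buffer is stored transposed, with the last logical dimension first, so shape() has to rotate the stored dimensions back. A standalone sketch of that recovery step (logical_shape is a hypothetical name):

    #include <cstddef>
    #include <vector>

    // Rotate a transposed column-wise shape back to the logical shape,
    // as in the NVTE_BLOCK_SCALING cases above: {d_n, d_0, ..., d_n-1}
    // becomes {d_0, ..., d_n-1, d_n}.
    std::vector<size_t> logical_shape(const std::vector<size_t> &columnwise) {
      std::vector<size_t> ret;
      if (!columnwise.empty()) {
        ret.reserve(columnwise.size());
        for (size_t i = 1; i < columnwise.size(); i++) ret.push_back(columnwise[i]);
        ret.push_back(columnwise.front());
      }
      return ret;
    }

    // e.g. a column-wise buffer of shape {64, 128, 32} corresponds to a
    // logical tensor of shape {128, 32, 64}.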
@@ -467,15 +467,22 @@ class TensorWrapper {
    */
   TensorWrapper(void *dptr, const NVTEShape &shape, const DType dtype, float *amax_dptr = nullptr,
                 float *scale_dptr = nullptr, float *scale_inv_dptr = nullptr,
-                const NVTEShape scale_inv_shape = defaultShape,
+                NVTEShape scale_inv_shape = defaultShape,
                 const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) {
     tensor_ = nvte_create_tensor(scaling_mode);
     NVTEBasicTensor data = {dptr, static_cast<NVTEDType>(dtype), shape};
     nvte_set_tensor_param(&tensor_, kNVTERowwiseData, &data);
-    NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32, defaultShape};
+    NVTEBasicTensor amax = {amax_dptr, kNVTEFloat32,
+                            amax_dptr != nullptr ? defaultShape : emptyShape};
     nvte_set_tensor_param(&tensor_, kNVTEAmax, &amax);
-    NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32, defaultShape};
+    NVTEBasicTensor scale = {scale_dptr, kNVTEFloat32,
+                            scale_dptr != nullptr ? defaultShape : emptyShape};
     nvte_set_tensor_param(&tensor_, kNVTEScale, &scale);
+    if (scale_inv_dptr == nullptr && scale_inv_shape.ndim == defaultShape.ndim &&
+        scale_inv_shape.ndim == 1 && scale_inv_shape.data[0] == defaultShape.data[0]) {
+      // Scale-inv pointer has not been provided and shape matches default
+      scale_inv_shape = emptyShape;
+    }
     NVTEBasicTensor scale_inv = {scale_inv_dptr, kNVTEFloat32, scale_inv_shape};
     nvte_set_tensor_param(&tensor_, kNVTERowwiseScaleInv, &scale_inv);
   }
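One subtlety in this constructor: for amax and scale, a null pointer alone is enough to mark the parameter unset, but scale_inv also carries a caller-supplied shape, and a defaulted argument is indistinguishable from an explicitly passed defaultShape. The code therefore only downgrades scale_inv to emptyShape when the pointer is null and the shape still looks like the default. A sketch of that comparison, pulled out as a hypothetical helper:

    // Hypothetical helper equivalent to the condition above: treat the
    // scale-inv shape as "left at the default" only if both shapes are
    // 1-D with the same single entry.
    bool looks_like_default(const NVTEShape &s, const NVTEShape &def) {
      return s.ndim == def.ndim && s.ndim == 1 && s.data[0] == def.data[0];
    }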
@@ -626,7 +633,8 @@ class TensorWrapper {
    */
   const NVTEShape shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_shape(tensor_);
   }
@@ -637,7 +645,8 @@ class TensorWrapper {
    */
   const NVTEShape columnwise_shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_columnwise_shape(tensor_);
   }
@@ -761,7 +770,8 @@ class TensorWrapper {
    */
   const NVTEShape scale_inv_shape() const noexcept {
     if (tensor_ == nullptr) {
-      return nvte_make_shape(nullptr, 0);
+      const size_t zero = 0;
+      return nvte_make_shape(&zero, 1);
     }
     return nvte_tensor_scale_inv_shape(tensor_);
   }
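Returning a shape built from the stack-local zero is safe here under the reasonable assumption that NVTEShape stores its dimensions inline (as the defaultShape initializer below suggests) and that nvte_make_shape copies from the input pointer rather than retaining it. A hypothetical equivalent, for intuition only, not the library's implementation:

    // Copies ndim entries into the returned value; nothing points back
    // at the caller's buffer afterwards.
    NVTEShape make_shape_sketch(const size_t *data, size_t ndim) {
      NVTEShape s{};
      s.ndim = ndim;
      for (size_t i = 0; i < ndim; ++i) s.data[i] = data[i];
      return s;
    }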
@@ -780,6 +790,7 @@ class TensorWrapper {
   static constexpr size_t defaultData = 1;
   static constexpr NVTEShape defaultShape = {
       {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
+  static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
 
  private:
   NVTEShape convertShape(const NVTEShape &s) { return s; }
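Read together, the two constants give TensorWrapper the same two-state convention as Tensor: defaultShape is the 1-D shape [1] (one element, a placeholder scalar) and emptyShape is [0] (zero elements, parameter unset). A compile-time sketch, using a struct assumed to match NVTEShape's layout:

    #include <cstddef>

    struct ShapeLike {  // assumed layout: inline data array plus ndim
      size_t data[15];
      size_t ndim;
    };

    constexpr size_t numel(const ShapeLike &s) {
      size_t n = 1;
      for (size_t i = 0; i < s.ndim; ++i) n *= s.data[i];
      return n;
    }

    constexpr ShapeLike def{{1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};
    constexpr ShapeLike unset{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1};

    static_assert(numel(def) == 1, "placeholder scalar: one element");
    static_assert(numel(unset) == 0, "unset parameter: zero elements");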
transformer_engine/common/normalization/common.cpp (16 changes: 14 additions & 2 deletions)

@@ -127,7 +127,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
 
 template <typename KernelParamsType>
 std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
-  return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)};
+  size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 }
 
 template <typename KernelParamsType>
@@ -405,7 +411,13 @@ void CudnnNormalizationPlan::_build() {
 }
 
 std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
-  return {static_cast<size_t>(_graph.get_workspace_size())};
+  size_t workspace_size = _graph.get_workspace_size();
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 }
 
 void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void* beta_dptr,
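Both clamps exist because callers use an empty workspace tensor as the size-query signal: the first call fills in the required shape, the caller allocates, and the second call actually executes. If a plan legitimately needed zero bytes, the second call would still look like a query and the kernel would never run. A self-contained sketch of the protocol (hypothetical names; the real flow goes through transformer_engine::Tensor and the normalization entry points below):

    #include <cstdlib>
    #include <vector>

    struct Workspace {                  // stand-in for Tensor::data
      void *dptr = nullptr;
      std::vector<size_t> shape = {0};  // default: zero elements
      size_t numel() const {
        size_t n = 1;
        for (size_t d : shape) n *= d;
        return n;
      }
    };

    void norm_fwd(Workspace *ws) {
      if (ws->numel() == 0) {
        // Pass 1: size query. Report at least one byte so a filled-in
        // workspace can never be mistaken for another query.
        ws->shape = {1024};  // hypothetical required size
        return;
      }
      // Pass 2: workspace is allocated; launch the kernel here.
    }

    int main() {
      Workspace ws;
      norm_fwd(&ws);                      // query the required size
      ws.dptr = std::malloc(ws.numel());  // allocate
      norm_fwd(&ws);                      // execute
      std::free(ws.dptr);
    }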
transformer_engine/common/normalization/layernorm/ln_api.cpp (8 changes: 4 additions & 4 deletions)

@@ -51,7 +51,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
@@ -94,7 +94,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
       multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
       gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -146,7 +146,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
   NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
   NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(mu, "mu");
@@ -179,7 +179,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp (12 changes: 6 additions & 6 deletions)

@@ -39,7 +39,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");

@@ -79,7 +79,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
       multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
       gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -125,7 +125,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(rsigma, "rsigma");
@@ -156,7 +156,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -191,7 +191,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
 
-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(add, "add");
@@ -222,7 +222,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);
 
-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
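These call-site changes are forced by the new shape convention: a default-constructed workspace now has shape {0}, so workspace->data.shape.empty() would be false even though the tensor holds no elements, and every query/execute branch would be taken the wrong way. numel() is the test that survives the convention change, as this short sketch shows:

    std::vector<size_t> shape = {0};       // new uninitialized-tensor sentinel
    bool is_empty_vector = shape.empty();  // false: the vector has one entry
    // numel() == 0 is the check that correctly identifies "no data".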