29 changes: 23 additions & 6 deletions backends/cuda/runtime/shims/memory.cpp
@@ -210,10 +210,6 @@ AOTITorchError aoti_torch_empty_strided(

// This requires us to reserve CUDA memory and put it into an ETensor
void* ptr;
int64_t numel = 1;
for (int64_t i = 0; i < ndim; i++) {
numel *= sizes_ptr[i];
}

ET_CHECK_OK_OR_RETURN_ERROR(validate_dtype(dtype));

@@ -223,7 +219,29 @@ AOTITorchError aoti_torch_empty_strided(
InvalidArgument,
"Invalid element size for dtype: %d",
dtype);
int64_t nbytes = numel * element_size;

// Calculate the storage size based on strides, matching PyTorch's behavior.
// This is critical when sizes and strides don't match the expected contiguous
// layout. Reference: PyTorch's computeStorageNbytes in EmptyTensor.cpp.
int64_t storage_size = 1; // storage offset (0) + 1
for (int64_t i = 0; i < ndim; i++) {
if (sizes_ptr[i] == 0) {
storage_size = 0;
break;
}
// For each dimension, add stride[i] * (size[i] - 1)
// This gives us the maximum offset in that dimension
int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 0;
if (strides_ptr == nullptr) {
// Calculate contiguous stride if not provided
stride_i = 1;
Review comment (Contributor) with a suggested change:
- int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 0;
- if (strides_ptr == nullptr) {
-   // Calculate contiguous stride if not provided
-   stride_i = 1;
+ int64_t stride_i = (strides_ptr != nullptr) ? strides_ptr[i] : 1;
+ if (strides_ptr == nullptr) {
+   // Calculate contiguous stride if not provided
for (int64_t j = i + 1; j < ndim; j++) {
stride_i *= sizes_ptr[j];
}
}
storage_size += stride_i * (sizes_ptr[i] - 1);
}
int64_t nbytes = storage_size * element_size;

if (device_type == static_cast<int32_t>(SupportedDevices::CUDA)) {
ET_CUDA_CHECK_OR_RETURN_ERROR(
@@ -259,7 +277,6 @@ AOTITorchError aoti_torch_empty_strided(

// This tensor owns the memory it allocated, set reference count to 1
memory_to_n_tensor[ptr] = 1;

return Error::Ok;
}
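For reference, here is a minimal standalone sketch of the storage-size rule the hunk above implements (the free function and its name are illustrative, not part of the patch): the allocation must cover the largest reachable element offset, which is 1 + the sum over i of strides[i] * (sizes[i] - 1), and collapses to zero as soon as any dimension has size 0.

#include <cstdint>

// Illustrative sketch only: number of elements the backing storage must hold
// for a tensor with the given sizes/strides (storage offset assumed to be 0),
// mirroring PyTorch's computeStorageNbytes and the logic in the hunk above.
inline int64_t storage_numel_for(
    int64_t ndim,
    const int64_t* sizes,
    const int64_t* strides) {
  int64_t storage_size = 1; // storage offset (0) + 1
  for (int64_t i = 0; i < ndim; i++) {
    if (sizes[i] == 0) {
      return 0; // a zero-sized dimension means no storage is needed
    }
    int64_t stride_i = 1;
    if (strides != nullptr) {
      stride_i = strides[i];
    } else {
      // Contiguous (row-major) fallback when strides are not provided.
      for (int64_t j = i + 1; j < ndim; j++) {
        stride_i *= sizes[j];
      }
    }
    storage_size += stride_i * (sizes[i] - 1);
  }
  return storage_size;
}

// Example: sizes {3, 4} with column-major strides {1, 3} needs
// 1 + 1 * (3 - 1) + 3 * (4 - 1) = 12 elements, i.e. 48 bytes for float32.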

111 changes: 107 additions & 4 deletions backends/cuda/runtime/shims/tests/test_aoti_torch_empty_strided.cpp
@@ -509,11 +509,11 @@ TEST_F(AOTITorchEmptyStridedTest, ZeroElementTensor) {
EXPECT_EQ(sizes_ptr[2], 3);
}

// Test different data types (only float32 is currently supported)
// Test different data types (currently we support bf16, fp32 and int32)
TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
std::vector<int64_t> sizes = {2, 3};

// Test float32 (dtype 6) - currently the only supported type
// Test float32 (dtype 6) - one of the supported types
Tensor* tensor_float32;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
@@ -527,7 +527,7 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor_float32, nullptr);

// Test unsupported data types should return error
// Test int32 (dtype 3) - one of the supported types
Tensor* tensor_int32;
error = aoti_torch_empty_strided(
sizes.size(),
@@ -538,7 +538,8 @@ TEST_F(AOTITorchEmptyStridedTest, DifferentDataTypes) {
0, // device index
&tensor_int32);

EXPECT_EQ(error, Error::InvalidArgument); // Should fail for unsupported dtype
EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor_int32, nullptr);

// Test an unsupported data type (float64) - should return an error
Tensor* tensor_float64;
@@ -586,3 +587,105 @@ TEST_F(AOTITorchEmptyStridedTest, MultiDimensionalTensors) {
EXPECT_EQ(tensor_5d->size(3), 4);
EXPECT_EQ(tensor_5d->size(4), 5);
}

// Test incontiguous tensor creation - transpose-like layout
TEST_F(AOTITorchEmptyStridedTest, IncontiguousTransposeLayout) {
// Create a tensor with transpose-like strides (column-major)
// For a 3x4 tensor in column-major order, strides should be [1, 3]
// This means each row step is 1, and each column step is 3
std::vector<int64_t> sizes = {3, 4};
std::vector<int64_t> strides = {1, 3}; // Column-major (incontiguous)

Tensor* tensor;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
sizes.data(),
strides.data(),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDevices::CUDA),
0, // device index
&tensor);

EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor, nullptr);

// Verify tensor properties
EXPECT_EQ(tensor->dim(), 2);
EXPECT_EQ(tensor->size(0), 3);
EXPECT_EQ(tensor->size(1), 4);

// Verify the strides are what we specified
int64_t* strides_ptr;
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
EXPECT_EQ(strides_ptr[0], 1); // Column-major stride for dimension 0
EXPECT_EQ(strides_ptr[1], 3); // Column-major stride for dimension 1

// Verify that memory was allocated correctly for the incontiguous layout.
// Storage size should be:
// stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) + 1
// = 1 * (3 - 1) + 3 * (4 - 1) + 1 = 2 + 9 + 1 = 12 elements.
// Total bytes = 12 * 4 = 48 bytes (for float32).
EXPECT_EQ(tensor->numel(), 12); // numel is still 3*4=12 for logical shape

// The tensor should be accessible and writable
void* data_ptr = tensor->mutable_data_ptr();
EXPECT_NE(data_ptr, nullptr);

// Verify we can use CUDA to write to the memory
std::vector<float> test_data(12, 1.0f);
cudaError_t cuda_err = cudaMemcpy(
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
EXPECT_EQ(cuda_err, cudaSuccess);
}

// Test incontiguous tensor creation - expanded/broadcasted stride pattern
TEST_F(AOTITorchEmptyStridedTest, IncontiguousExpandedStrides) {
// Create a tensor with expanded strides (simulating broadcasting)
// A 2x3x4 tensor where the first dimension has stride 0 (expanded)
// This creates a tensor where the first dimension is "broadcasted"
std::vector<int64_t> sizes = {2, 3, 4};
std::vector<int64_t> strides = {0, 4, 1}; // First dimension has stride 0

Tensor* tensor;
AOTITorchError error = aoti_torch_empty_strided(
sizes.size(),
sizes.data(),
strides.data(),
static_cast<int32_t>(SupportedDTypes::FLOAT32),
static_cast<int32_t>(SupportedDevices::CUDA),
0, // device index
&tensor);

EXPECT_EQ(error, Error::Ok);
EXPECT_NE(tensor, nullptr);

// Verify tensor properties
EXPECT_EQ(tensor->dim(), 3);
EXPECT_EQ(tensor->size(0), 2);
EXPECT_EQ(tensor->size(1), 3);
EXPECT_EQ(tensor->size(2), 4);

// Verify the strides are what we specified
int64_t* strides_ptr;
EXPECT_EQ(aoti_torch_get_strides(tensor, &strides_ptr), Error::Ok);
EXPECT_EQ(strides_ptr[0], 0); // Expanded dimension stride
EXPECT_EQ(strides_ptr[1], 4);
EXPECT_EQ(strides_ptr[2], 1);

// Verify that memory was allocated correctly for this incontiguous layout.
// Storage size should be:
// stride[0] * (size[0] - 1) + stride[1] * (size[1] - 1) + stride[2] * (size[2] - 1) + 1
// = 0 * (2 - 1) + 4 * (3 - 1) + 1 * (4 - 1) + 1 = 0 + 8 + 3 + 1 = 12 elements.
// Note: numel() returns the logical number of elements (2*3*4=24), not the
// storage size.
EXPECT_EQ(tensor->numel(), 24); // Logical numel is 2*3*4=24

// The tensor should be accessible and writable
void* data_ptr = tensor->mutable_data_ptr();
EXPECT_NE(data_ptr, nullptr);

// Verify we can use CUDA to write to the allocated memory
// We only need to allocate 12 elements (storage size), not 24
std::vector<float> test_data(12, 2.0f);
cudaError_t cuda_err = cudaMemcpy(
data_ptr, test_data.data(), 12 * sizeof(float), cudaMemcpyHostToDevice);
EXPECT_EQ(cuda_err, cudaSuccess);
}
5 changes: 5 additions & 0 deletions extension/tensor/CMakeLists.txt
@@ -24,6 +24,11 @@ target_include_directories(
)
target_compile_options(extension_tensor PUBLIC ${_common_compile_options})

# Define USE_CUDA_BACKEND when building with CUDA backend support
if(EXECUTORCH_BUILD_CUDA)
target_compile_definitions(extension_tensor PUBLIC USE_CUDA_BACKEND)
endif()

# Install libraries
install(
TARGETS extension_tensor
5 changes: 5 additions & 0 deletions extension/tensor/targets.bzl
@@ -10,6 +10,10 @@ def define_common_targets():
for aten_mode in get_aten_mode_options():
aten_suffix = ("_aten" if aten_mode else "")

# Define USE_CUDA_BACKEND when the executorch.use_cuda_backend build config is set
use_cuda_backend = native.read_config("executorch", "use_cuda_backend", "false") == "true"
preprocessor_flags = ["-DUSE_CUDA_BACKEND"] if use_cuda_backend else []

runtime.cxx_library(
name = "tensor" + aten_suffix,
srcs = [
@@ -25,6 +29,7 @@ def define_common_targets():
visibility = [
"@EXECUTORCH_CLIENTS",
],
preprocessor_flags = preprocessor_flags,
deps = [
"//executorch/runtime/core/exec_aten/util:dim_order_util" + aten_suffix,
"//executorch/runtime/core/exec_aten/util:tensor_util" + aten_suffix,
7 changes: 6 additions & 1 deletion extension/tensor/tensor_ptr.cpp
@@ -79,6 +79,11 @@ TensorPtr make_tensor_ptr(
});
}
}

// Skip the stride calculation and the incontiguous tensor check for the CUDA
// backend, since AOTI-CUDA handles both contiguous and incontiguous tensors.
// This will be removed after the SlimTensor migration.
#ifndef USE_CUDA_BACKEND
Review comment (Contributor): This is not needed anymore?

Reply (Contributor Author): You are absolutely right. Removed.

std::vector<executorch::aten::StridesType> computed_strides(dim);

auto error = runtime::dim_order_to_stride(
@@ -98,8 +103,8 @@
sizes[i]);
}
}

strides = std::move(computed_strides);
#endif // USE_CUDA_BACKEND

#ifndef USE_ATEN_LIB
executorch::aten::TensorImpl tensor_impl(
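For context, the block guarded by #ifndef USE_CUDA_BACKEND above derives contiguous strides from the default dim order and checks that the provided strides describe a contiguous tensor. Below is a minimal sketch of that row-major stride computation, assuming all sizes are positive; the standalone function is illustrative and not the dim_order_to_stride API itself.

#include <cstdint>
#include <vector>

// Illustrative sketch only: row-major contiguous strides for the given sizes,
// i.e. what the guarded block expects for the default dim order.
inline std::vector<int64_t> contiguous_strides(const std::vector<int64_t>& sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t running = 1;
  for (size_t i = sizes.size(); i-- > 0;) {
    strides[i] = running; // stride of dimension i
    running *= sizes[i];  // product of the trailing sizes
  }
  return strides;
}

// Example: sizes {2, 3, 4} -> strides {12, 4, 1}.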