Make to_sparse_semi_structured_cutlass_sm9x ABI stable
Merged

36 commits
7282131
Make to_sparse_semi_structured_cutlass_sm9x ABI stable
jerryzh168 Jan 26, 2026
6904b15
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 27, 2026
91d707e
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 27, 2026
19300fa
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 27, 2026
864de15
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 27, 2026
f75c166
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 27, 2026
c8bb8e5
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
1c86c1b
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
92f81f6
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
22ea328
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
950e5bc
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
56b67ff
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
b9aeb0e
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
840860d
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
9c8cf49
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
a6cd36d
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
e5e8cff
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
232e364
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
e47b77d
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
0c155e6
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
e7441f9
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
9fa4849
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
cfb35fe
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
b6d8f09
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
3a88f76
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
4422455
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
4e90861
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 28, 2026
80ac545
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 28, 2026
0dd28f8
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 29, 2026
60850d2
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 29, 2026
ef5f9ee
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 29, 2026
4611c8e
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 29, 2026
68faa64
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 30, 2026
66e8ab8
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 30, 2026
a5c3ff1
Update base for Update on "Make to_sparse_semi_structured_cutlass_sm9…
jerryzh168 Jan 30, 2026
82a9537
Update on "Make to_sparse_semi_structured_cutlass_sm9x ABI stable"
jerryzh168 Jan 30, 2026
6 changes: 3 additions & 3 deletions .github/workflows/1xH100_tests.yml
@@ -27,7 +27,7 @@ jobs:
runs-on: linux.aws.h100
torch-spec: '--pre torch torchvision torchaudio mslk --index-url https://download.pytorch.org/whl/nightly/cu126'
gpu-arch-type: "cuda"
gpu-arch-version: "12.4"
gpu-arch-version: "12.6"
permissions:
id-token: write
contents: read
@@ -46,10 +46,10 @@ jobs:
pip install uv
pip install ${{ matrix.torch-spec }}
uv pip install -r dev-requirements.txt
pip install . --no-build-isolation
pip install . --no-build-isolation -vv
pytest test/integration --verbose -s
pytest test/dtypes/test_affine_quantized_float.py --verbose -s
python test/quantization/quantize_/workflows/float8/test_float8_tensor.py
pytest test/quantization/quantize_/workflows/float8 -v
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
./test/float8/test_everything_single_gpu.sh
127 changes: 52 additions & 75 deletions setup.py
@@ -45,38 +45,6 @@ def read_version(file_path="version.txt"):
return file.readline().strip()


# Check if torch version is at least 2.10.0 (for stable ABI support)
# util copied from torchao/utils.py
def _parse_version(version_string):
"""
Parse version string representing pre-release with -1

Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
"""
# Check for pre-release indicators
is_prerelease = bool(re.search(r"(git|dev)", version_string))
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
if match:
major, minor, patch = map(int, match.groups())
if is_prerelease:
patch = -1
return [major, minor, patch]
else:
raise ValueError(f"Invalid version string format: {version_string}")


def _is_fbcode():
return not hasattr(torch.version, "git_version")


def _torch_version_at_least(min_version):
if _is_fbcode():
return True

# Parser for local identifiers
return _parse_version(torch.__version__) >= _parse_version(min_version)


SPINQUANT_REL_PATH = Path("torchao") / "prototype" / "spinquant"
HADAMARD_JSON = "_hadamard_matrices.json"
HADAMARD_PKL = "_hadamard_matrices.pkl"
@@ -182,6 +150,38 @@ def use_debug_mode():
)


# Check if torch version is at least 2.10.0 (for stable ABI support)
# util copied from torchao/utils.py
Contributor: why not import the util

Contributor Author: this is the code that installs torchao, so it assumes torchao is not installed yet. Do you mean just importing the file? Is that too confusing?

Contributor: I think we want to keep setup.py as simple as possible. Importing torchao/utils.py might have side effects outside the scope of this file.

def _parse_version(version_string):
"""
Parse version string representing pre-release with -1

Examples: "2.5.0.dev20240708+cu121" -> [2, 5, -1], "2.5.0" -> [2, 5, 0]
"""
# Check for pre-release indicators
is_prerelease = bool(re.search(r"(git|dev)", version_string))
match = re.match(r"(\d+)\.(\d+)\.(\d+)", version_string)
if match:
major, minor, patch = map(int, match.groups())
if is_prerelease:
patch = -1
return [major, minor, patch]
else:
raise ValueError(f"Invalid version string format: {version_string}")


def _is_fbcode():
return not hasattr(torch.version, "git_version")


def _torch_version_at_least(min_version):
if _is_fbcode():
return True

# Parser for local identifiers
return _parse_version(torch.__version__) >= _parse_version(min_version)


def detect_hipify_v2():
try:
from torch.utils.hipify import __version__
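As discussed in the review thread above, _parse_version and _torch_version_at_least are copied into setup.py rather than imported from torchao/utils.py, since setup.py runs before torchao is installed and importing the package could have side effects. A quick sketch of how the copied helpers behave, assuming the definitions above are in scope (the version strings are hypothetical):

# Pre-release markers ("dev"/"git") map the patch component to -1.
assert _parse_version("2.5.0.dev20240708+cu121") == [2, 5, -1]
assert _parse_version("2.10.0") == [2, 10, 0]

# _torch_version_at_least compares the parsed lists lexicographically (and
# always returns True in fbcode builds), so a 2.10.0 final or any 2.11+ build
# passes the ">= 2.10.0" gate, while a 2.10.0 pre-release ([2, 10, -1]) does not.
assert [2, 10, 0] >= [2, 10, 0]
assert not ([2, 10, -1] >= [2, 10, 0])
assert [2, 11, -1] >= [2, 10, 0]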
@@ -649,7 +649,6 @@ def get_extensions():

use_cutlass = False
cutlass_90a_sources = None
stable_cutlass_90a_sources = None
build_for_sm90a = False
build_for_sm100a = False
if use_cuda and not IS_WINDOWS:
@@ -681,34 +680,31 @@
)

build_for_sm90a, build_for_sm100a = get_cutlass_build_flags()
# Define sm90a sources

# Define sm90a sources that use stable ABI (requires torch >= 2.10.0)
cutlass_90a_sources = [
# TODO: move this to stable_cutlass_90a_sources in #3727
os.path.join(
extensions_cuda_dir,
"to_sparse_semi_structured_cutlass_sm9x",
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
]

stable_cutlass_90a_sources = [
os.path.join(
extensions_cuda_dir,
"rowwise_scaled_linear_sparse_cutlass",
"rowwise_scaled_linear_sparse_cutlass_f8f8.cu",
),
]
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
stable_cutlass_90a_sources.append(
cutlass_90a_sources.append(
os.path.join(
extensions_cuda_dir,
"rowwise_scaled_linear_sparse_cutlass",
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
)
)

# Always remove sm90a sources from main sources
sources = [s for s in sources if s not in cutlass_90a_sources]
sources = [s for s in sources if s not in stable_cutlass_90a_sources]

else:
# Remove CUTLASS-based kernels from the sources list. An
@@ -770,17 +766,30 @@
)

# Only build the cutlass_90a extension if sm90a is in the architecture flags
# TODO: delete this after #3726 and #3727
# and if torch version >= 2.10
if (
cutlass_90a_sources is not None
and len(cutlass_90a_sources) > 0
and build_for_sm90a
and _torch_version_at_least("2.10.0")
):
cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
# Only use sm90a architecture for these sources, ignoring other flags
cutlass_90a_extra_compile_args["nvcc"].append(
"-gencode=arch=compute_90a,code=sm_90a"
cutlass_90a_extra_compile_args["nvcc"].extend(
[
"-DUSE_CUDA",
"-gencode=arch=compute_90a,code=sm_90a",
"-DTORCH_TARGET_VERSION=0x020a000000000000",
]
)
# Add compile flags for stable ABI support (requires torch >= 2.10)
cutlass_90a_extra_compile_args["cxx"].extend(
[
"-DUSE_CUDA",
"-DTORCH_TARGET_VERSION=0x020a000000000000",
]
)
# stable ABI cutlass_90a module
ext_modules.append(
extension(
"torchao._C_cutlass_90a",
@@ -791,38 +800,6 @@
)
)

# Only build the cutlass_90a extension if sm90a is in the architecture flags
# and if torch version >= 2.10
if (
stable_cutlass_90a_sources is not None
and len(stable_cutlass_90a_sources) > 0
and build_for_sm90a
and _torch_version_at_least("2.10.0")
):
stable_cutlass_90a_extra_compile_args = copy.deepcopy(extra_compile_args)
# Only use sm90a architecture for these sources, ignoring other flags
stable_cutlass_90a_extra_compile_args["nvcc"].append(
"-gencode=arch=compute_90a,code=sm_90a"
)
# Add -DUSE_CUDA for stable ABI support (using features in torch 2.10)
stable_cutlass_90a_extra_compile_args["cxx"].append("-DUSE_CUDA")
stable_cutlass_90a_extra_compile_args["cxx"].append(
"-DTORCH_TARGET_VERSION=0x020a000000000000"
)
stable_cutlass_90a_extra_compile_args["nvcc"].append("-DUSE_CUDA")
stable_cutlass_90a_extra_compile_args["nvcc"].append(
"-DTORCH_TARGET_VERSION=0x020a000000000000"
)
ext_modules.append(
extension(
"torchao._C_cutlass_90a_stable",
stable_cutlass_90a_sources,
py_limited_api=True,
extra_compile_args=stable_cutlass_90a_extra_compile_args,
extra_link_args=extra_link_args,
)
)

# Build CMakeLists from /torchao/csrc/cpu - additional options become available : TORCHAO_BUILD_CPU_AARCH64, TORCHAO_BUILD_KLEIDIAI, TORCHAO_BUILD_MPS_OPS, TORCHAO_PARALLEL_BACKEND
if build_macos_arm_auto or os.getenv("BUILD_TORCHAO_EXPERIMENTAL") == "1":
build_options = BuildOptions()
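The cutlass_90a extension is now built with py_limited_api=True, and both the cxx and nvcc flags above add -DTORCH_TARGET_VERSION=0x020a000000000000 alongside the torch >= 2.10.0 gate. A small sketch of how that constant lines up with the version requirement, under the assumption that the top two bytes encode the major and minor version (consistent with the value used here):

# Decode the TORCH_TARGET_VERSION constant used above (byte layout assumed).
TORCH_TARGET_VERSION = 0x020A000000000000
major = (TORCH_TARGET_VERSION >> 56) & 0xFF  # 0x02 -> 2
minor = (TORCH_TARGET_VERSION >> 48) & 0xFF  # 0x0A -> 10
print(f"targets torch {major}.{minor}")      # targets torch 2.10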
@@ -23,13 +23,20 @@
)
from torchao.quantization.utils import compute_error
from torchao.sparsity import apply_fake_sparsity
from torchao.utils import is_sm_at_least_90
from torchao.utils import (
is_sm_at_least_90,
torch_version_at_least,
)

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


@unittest.skipIf(
not torch_version_at_least("2.10.0"),
"Need torch >= 2.10.0 for availability of ABI kernels",
)
class TestSparse2x4Float8Tensor(common_utils.TestCase):
@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
61 changes: 59 additions & 2 deletions torchao/csrc/cuda/cutlass_extensions/common.h
@@ -5,17 +5,74 @@
// LICENSE file in the root directory of this source tree.
#pragma once

#include <deque>
#include <mutex>

#include <cuda_runtime.h>
#include <cutlass/cutlass.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#include <torch/csrc/stable/tensor.h>
#include <torch/headeronly/util/Exception.h>

#define CUTLASS_STATUS_CHECK(status, message_prefix) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \
" : Got CUTLASS error: ", cutlassGetStatusString(status)); \
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \
" : Got CUTLASS error: ", cutlassGetStatusString(status)); \
}

namespace torchao {

namespace {

std::deque<std::once_flag> device_flags;
std::vector<cudaDeviceProp> device_properties;

inline void initDevicePropertiesVectors() {
static bool init_flag [[maybe_unused]] = []() {
int device_count;
cudaError_t err = cudaGetDeviceCount(&device_count);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceCount failed: " +
std::string(cudaGetErrorString(err)));
}
device_flags.resize(device_count);
device_properties.resize(device_count);
return true;
}();
}

inline void initDeviceProperty(int device_index) {
cudaDeviceProp device_prop{};
cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
std::string(cudaGetErrorString(err)));
}
device_properties[device_index] = device_prop;
}

inline cudaDeviceProp* get_device_prop() {
initDevicePropertiesVectors();
int device_index;
cudaError_t err = cudaGetDevice(&device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDevice failed: " +
std::string(cudaGetErrorString(err)));
}

std::call_once(device_flags[device_index], initDeviceProperty, device_index);
return &device_properties[device_index];
}

inline cudaStream_t get_current_cuda_stream(const torch::stable::Tensor& t) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(
static_cast<int32_t>(t.get_device_index()), &stream_ptr));
return static_cast<cudaStream_t>(stream_ptr);
}

} // anonymous namespace

template <typename Kernel>
struct enable_2x_kernel_for_sm80_or_later : Kernel {
template <typename... Args>
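The new get_device_prop() helper caches one cudaDeviceProp per device behind a std::once_flag, and get_current_cuda_stream() goes through the aoti_torch shim, in line with the stable-ABI goal of the PR. For intuition only, a Python-level analogue of the per-device caching (not part of the PR; it just mirrors the idea using the existing torch.cuda API):

# Analogy only: memoize device properties per index, like the
# device_flags / device_properties pair in the header above.
from functools import lru_cache

import torch

@lru_cache(maxsize=None)
def cached_device_properties(device_index: int):
    return torch.cuda.get_device_properties(device_index)

if torch.cuda.is_available():
    props = cached_device_properties(torch.cuda.current_device())
    print(props.major, props.minor)  # e.g. 9 0 on an H100 (compute capability 9.0)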
@@ -41,15 +41,6 @@

#include "cutlass_extensions/common.h"

// Redefine CUTLASS_STATUS_CHECK to use STD_TORCH_CHECK for ABI stability
// TODO: just edit the original one in torchao/csrc/cuda/cutlass_extensions/common.h after #3727
#undef CUTLASS_STATUS_CHECK
#define CUTLASS_STATUS_CHECK(status, message_prefix) \
{ \
STD_TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \
" : Got CUTLASS error: ", cutlassGetStatusString(status)); \
}

#endif

#define OPERATOR_NAME "rowwise_scaled_linear_sparse_cutlass"
@@ -66,53 +57,6 @@ namespace {
inline tsa::DeviceGuard make_device_guard(const Tensor& t) {
return tsa::DeviceGuard(static_cast<tsa::DeviceIndex>(t.get_device_index()));
}

std::deque<std::once_flag> device_flags;
std::vector<cudaDeviceProp> device_properties;

void initVectors() {
static bool init_flag [[maybe_unused]] = []() {
int device_count;
cudaError_t err = cudaGetDeviceCount(&device_count);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
std::string(cudaGetErrorString(err)));
}
device_flags.resize(device_count);
device_properties.resize(device_count);
return true;
}();
}

void initDeviceProperty(int device_index) {
cudaDeviceProp device_prop{};
cudaError_t err = cudaGetDeviceProperties(&device_prop, device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDeviceProperties failed: " +
std::string(cudaGetErrorString(err)));
}
device_properties[device_index] = device_prop;
}

cudaDeviceProp* get_device_prop() {
initVectors();
int device_index;
cudaError_t err = cudaGetDevice(&device_index);
if (err != cudaSuccess) {
STD_TORCH_CHECK(false, "cudaGetDevice failed: " +
std::string(cudaGetErrorString(err)));
}

std::call_once(device_flags[device_index], initDeviceProperty, device_index);
return &device_properties[device_index];
}

inline cudaStream_t get_current_cuda_stream(const Tensor& t) {
void* stream_ptr = nullptr;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(
static_cast<int32_t>(t.get_device_index()), &stream_ptr));
return static_cast<cudaStream_t>(stream_ptr);
}
} // anonymous namespace

template<