[claude] Make rowwise_scaled_linear_sparse_cutlass ABI stable#3725

Merged
andrewor14 merged 1 commit into main from sparse-kernels-abi
Jan 29, 2026

Conversation

@andrewor14
Contributor

@andrewor14 andrewor14 commented Jan 26, 2026

Test Plan:

pytest test/test_ops_rowwise_scaled_linear_sparse_cutlass.py

(needs #3768)

Prompt:

```
Make everything under this directory ABI stable:
ao/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass

Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt

Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp

Additional instructions:
Replace TORCH_CHECK with STD_TORCH_CHECK without changing the error message
Replace c10::cuda::CUDAGuard with DeviceGuard
When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex()
Where possible, use get_device_index instead of get_device
Before checking if CUDA_VERSION is defined, remember to include cuda.h

Don’t do these things:
Don’t define a dummy _C module that can be accessed from python
Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py if needed
Also add -DTORCH_TARGET_VERSION=0x020a000000000000 to both cxx and nvcc in setup.py
Don’t box kernels manually, just use TORCH_BOX
```
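A side note on the `-DTORCH_TARGET_VERSION=0x020a000000000000` flag above: the constant is consistent with an encoding that packs the major and minor version into the top two bytes of a 64-bit integer. That byte layout is an assumption inferred from the value itself, not something stated in this PR, but it makes the flag easy to regenerate for other versions:

```python
# Sketch: derive the TORCH_TARGET_VERSION constant for PyTorch 2.10,
# assuming (not confirmed here) that the encoding packs the major version
# into bits 56-63 and the minor version into bits 48-55.
MAJOR, MINOR = 2, 10

target_version = (MAJOR << 56) | (MINOR << 48)
print(hex(target_version))  # 0x20a000000000000, i.e. 0x020a000000000000
```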

Follow-up prompts:

```
Seems like STD_CUDA_KERNEL_LAUNCH_CHECK is undefined, is it missing an import?
```

```
Revert all checks that check if the layout is strided, where the original code looks like this:

TORCH_CHECK(Wq.layout() == at::Layout::Strided, …)

Then, rewrite these checks using the following instead to make it ABI stable:

using torch::headeronly::Layout;
int32_t wlayout;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(Wq.get(), &wlayout));
STD_TORCH_CHECK(torch::stable::detail::to<Layout>(
    torch::stable::detail::from(wlayout)) == Layout::Strided, …
)

Do not change the error message.
```

```
Can you save "torch::stable::detail::to<Layout>(torch::stable::detail::from(xq_layout))" to a variable to reduce duplicate code? Same for the other checks
```

@pytorch-bot

pytorch-bot bot commented Jan 26, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3725

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 26, 2026
@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jan 26, 2026
```cpp
  std::call_once(device_flags[device_index], initDeviceProperty, device_index);
  return &device_properties[device_index];
}
} // anonymous namespace
```


@mikaylagawarecki mikaylagawarecki Jan 27, 2026


Yes, this looks right. iiuc it is almost an exact copy of what's in flash_api_stable.cpp, and the only change is how the device_index is acquired on line 83, and I think that is an improvement :)

a small nit is maybe use t.get_device_index() instead of t.get_device()

@andrewor14 andrewor14 changed the title [draft][claude] Make rowwise_scaled_linear_sparse_cutlass ABI stable [claude] Make rowwise_scaled_linear_sparse_cutlass ABI stable Jan 27, 2026
@janeyx99
Contributor

LGTM, my main qs for these sort of PRs are "what's the test plan?". Does it pass existing tests?

@andrewor14 andrewor14 marked this pull request as ready for review January 28, 2026 16:13
@andrewor14 andrewor14 force-pushed the sparse-kernels-abi branch 2 times, most recently from 47e109e to 10eb8b3 Compare January 28, 2026 16:51
@andrewor14 andrewor14 requested a review from jerryzh168 January 28, 2026 17:25
@jerryzh168
Contributor

> LGTM, my main qs for these sort of PRs are "what's the test plan?". Does it pass existing tests?

we added 2.10 in CI recently, probably can run the related test with that

@andrewor14 andrewor14 force-pushed the sparse-kernels-abi branch 2 times, most recently from f37a4db to 89a0123 Compare January 28, 2026 17:41
```python
    extensions_cuda_dir,
    "to_sparse_semi_structured_cutlass_sm9x",
    "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
```
Contributor Author

@andrewor14 andrewor14 Jan 28, 2026


@jerryzh168 I also separated rowwise_scaled_linear_sparse_cutlass and to_sparse_semi_structured_cutlass_sm9x here to make my PR pass tests.

I'm thinking once I land this you can move these few lines to stable_cutlass_90a_sources? Then eventually we can delete cutlass_90a_sources. What do you think

Contributor

yeah I can rebase if you can get the CI to pass

@andrewor14 andrewor14 force-pushed the sparse-kernels-abi branch 2 times, most recently from 2da941b to c625428 Compare January 28, 2026 19:05
setup.py Outdated
```python
stable_cutlass_90a_sources is not None
and len(stable_cutlass_90a_sources) > 0
and build_for_sm90a
and _parse_version(torch.__version__) >= [2, 10, 0]
```
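The `_parse_version(torch.__version__) >= [2, 10, 0]` gate relies on a helper defined in ao's setup.py. A minimal sketch of such a helper (the real `_parse_version` may differ), which has to tolerate dev and local-build suffixes like `2.10.0.dev20260101+cu126`:

```python
import re

def parse_version(version: str) -> list:
    """Hypothetical stand-in for setup.py's _parse_version: collect the
    leading numeric components of a version string, stopping at
    non-numeric pieces such as 'dev20260101'."""
    parts = []
    for piece in version.split("+")[0].split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    return parts

# Python list comparison is lexicographic, so [2, 9, 1] < [2, 10, 0] as intended.
print(parse_version("2.10.0.dev20260101+cu126") >= [2, 10, 0])  # True
print(parse_version("2.9.1") >= [2, 10, 0])                     # False
```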
Contributor

do we need to account for fbcode as well?

```cpp
              " : Expected Xq argument to be strided, got layout ",
              Xq.layout());
  TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME,
  STD_TORCH_CHECK(Xq.is_contiguous() || Xq.stride(-1) == 1, OPERATOR_NAME,
```
Contributor Author

@janeyx99 this doesn't seem right, is there a good replacement for tensor.layout()?

Contributor

Would is_contiguous be too strict for the kernel? If not, is_contiguous is even stricter than checking that the layout is strided.

2.10 doesn't have Tensor.layout yet but we can add for 2.11. In the meantime, a not so nice workaround would be to move this check to python :/

```cpp
#define CUTLASS_STATUS_CHECK(status, message_prefix)                  \
  {                                                                   \
    TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix,  \
    STD_TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \
```
Contributor

you'll need to change this locally in rowwise_scaled_linear_sparse_cutlass.cuh file I think

Contributor Author

Hmm what do you mean? I already replaced all TORCH_CHECK in the .cuh file, but this one was failing compilation because it was inlined (and TORCH_CHECK was not replaced)

Contributor

@jerryzh168 jerryzh168 Jan 28, 2026


oh, I tried the same fix in my PR first, but found that [torchao/csrc/cuda/cutlass_extensions/common.h](https://github.com/pytorch/ao/pull/3725/files/6589ec0d78011214d24c850b5579ad5c16d5c35f#diff-d7b8367f1b64ac9e770e0f286c78b8bd83a515e113c89496fe62ef023c1e7057) is used by .cu files in both your PR and mine. The stable API include files (imports) are needed to use STD_TORCH_CHECK, so we can't make this change in common.h (the other half of the files don't use the stable ABI in either of our PRs); we just have to do it locally, as in https://github.com/pytorch/ao/pull/3727/changes#diff-818fa0804aed61661b3208bb032fdc1692086e8ce1e9ec0d191261e0fb8b93dbR37-R43, instead of in common.h.

this is something that could be simplified if we did the stable ABI change in the same PR, but it's not too complicated to work around

@andrewor14 andrewor14 force-pushed the sparse-kernels-abi branch 3 times, most recently from e985ce8 to 460ce6d Compare January 29, 2026 19:44
@andrewor14 andrewor14 requested a review from jerryzh168 January 29, 2026 22:27
Contributor

@jerryzh168 jerryzh168 left a comment


please add fbcode check, before merging

@andrewor14 andrewor14 merged commit df0fde3 into main Jan 29, 2026
22 of 27 checks passed
```cpp
  device_properties[device_index] = device_prop;
}

cudaDeviceProp* get_device_prop() {
```
Contributor Author

here
