[claude] Make rowwise_scaled_linear_sparse_cutlass ABI stable #3725
andrewor14 merged 1 commit into main
Conversation
torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh
```cpp
  std::call_once(device_flags[device_index], initDeviceProperty, device_index);
  return &device_properties[device_index];
}
} // anonymous namespace
```
Yes, this looks right. IIUC it is almost an exact copy of what's in flash_api_stable.cpp, and the only change is how the device_index is acquired on line 83, which I think is an improvement :)
A small nit: maybe use t.get_device_index() instead of t.get_device().
LGTM. My main question for these sorts of PRs is "what's the test plan?" Does it pass existing tests?
We added 2.10 in CI recently, so we can probably run the related test with that.
```python
    extensions_cuda_dir,
    "to_sparse_semi_structured_cutlass_sm9x",
    "to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
```
@jerryzh168 I also separated rowwise_scaled_linear_sparse_cutlass and to_sparse_semi_structured_cutlass_sm9x here to make my PR pass tests.
I'm thinking once I land this, you can move these few lines to stable_cutlass_90a_sources? Then eventually we can delete cutlass_90a_sources. What do you think?
yeah I can rebase if you can get the CI to pass
setup.py
```python
stable_cutlass_90a_sources is not None
and len(stable_cutlass_90a_sources) > 0
and build_for_sm90a
and _parse_version(torch.__version__) >= [2, 10, 0]
```
Do we need to account for fbcode as well?
I was copy-pasting all these functions: https://github.com/pytorch/ao/pull/3727/changes#diff-60f61ab7a8d1910d86d9fda2261620314edcae5894d5aaa236b821c7256badd7R155-R182 — it seems much more complicated than before.
torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass/rowwise_scaled_linear_sparse_cutlass.cuh
```cpp
    " : Expected Xq argument to be strided, got layout ",
    Xq.layout());
TORCH_CHECK(X_scale.dim() == Xq.dim() - 1, OPERATOR_NAME,
STD_TORCH_CHECK(Xq.is_contiguous() || Xq.stride(-1) == 1, OPERATOR_NAME,
```
@janeyx99 this doesn't seem right; is there a good replacement for tensor.layout()?
Would is_contiguous be too strict for the kernel? If not, note that is_contiguous is even stricter than checking that the layout is strided.
2.10 doesn't have Tensor.layout yet, but we can add it for 2.11. In the meantime, a not-so-nice workaround would be to move this check to Python :/
```cpp
#define CUTLASS_STATUS_CHECK(status, message_prefix)                     \
  {                                                                      \
    TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix,     \
    STD_TORCH_CHECK(status == cutlass::Status::kSuccess, message_prefix, \
```
You'll need to change this locally in the rowwise_scaled_linear_sparse_cutlass.cuh file, I think.
Hmm, what do you mean? I already replaced all TORCH_CHECK calls in the .cuh file, but this one was failing compilation because it was inlined (and its TORCH_CHECK was not replaced).
Oh, I followed the same fix in my PR at first, but found that [torchao/csrc/cuda/cutlass_extensions/common.h](https://github.com/pytorch/ao/pull/3725/files/6589ec0d78011214d24c850b5579ad5c16d5c35f#diff-d7b8367f1b64ac9e770e0f286c78b8bd83a515e113c89496fe62ef023c1e7057) is used by .cu files in both your PR and mine. The stable-API includes are needed to use STD_TORCH_CHECK, so we can't modify this in common.h (since the other half of the files are not using the stable ABI in either PR); we just have to do this locally, as in https://github.com/pytorch/ao/pull/3727/changes#diff-818fa0804aed61661b3208bb032fdc1692086e8ce1e9ec0d191261e0fb8b93dbR37-R43, instead of in common.h.
This is something that could be simplified if we did the stable ABI change in the same PR, but it's not too complicated to work around.
jerryzh168 left a comment:
Please add the fbcode check before merging.
**Prompt:**
```
Make everything under this directory ABI stable:
ao/torchao/csrc/cuda/rowwise_scaled_linear_sparse_cutlass
Read these for instructions:
pytorch/docs/source/notes/libtorch_stable_abi.md
cppdocs/_sources/stable.rst.txt
Use these files for an example:
Before: flash-attention/hopper/flash_api.cpp
After: flash-attention/hopper/flash_api_stable.cpp
Additional instructions:
Replace TORCH_CHECK with STD_TORCH_CHECK without changing the error message
Replace c10::cuda::CUDAGuard with DeviceGuard
When calling aoti_torch_get_current_cuda_stream, get the device index from a tensor, not from the general torch::stable::accelerator::getCurrentDeviceIndex()
Where possible, use get_device_index instead of get_device
Before checking if CUDA_VERSION is defined, remember to include cuda.h
Don’t do these things:
Don’t define a dummy _C module that can be accessed from python
Don’t declare aoti_torch_get_current_cuda_stream, just include it from torch/csrc/inductor/aoti_torch/c/shim.h and add -DUSE_CUDA to both cxx and nvcc in setup.py if needed
Also add -DTORCH_TARGET_VERSION=0x020a000000000000 to both cxx and nvcc in setup.py
Don’t box kernels manually, just use TORCH_BOX
```
**Follow-up prompts:**
```
Seems like STD_CUDA_KERNEL_LAUNCH_CHECK is undefined, is it missing an import?
```
```
Revert all checks that check if the layout is strided, where the original code looks like this:
TORCH_CHECK(Wq.layout() == at::Layout::Strided, …)
Then, rewrite these checks using the following instead to make it ABI stable:
using torch::headeronly::Layout
int32_t wlayout;
TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(Wq.get(), &wlayout));
STD_TORCH_CHECK(torch::stable::detail::to<Layout>(
torch::stable::detail::from(wlayout)) == Layout::Strided, …
)
Do not change the error message.
```
```
Can you save "torch::stable::detail::to<Layout>(torch::stable::detail::from(xq_layout))" to a variable to reduce duplicate code? Same for the other checks
```
```cpp
  device_properties[device_index] = device_prop;
}

cudaDeviceProp* get_device_prop() {
```
**Test Plan:**
(needs #3768)