Skip to content

Commit c2d142d

Browse files
committed
build(pypi): add cu128 windows build
2 parents c8474a0 + 85a0fea commit c2d142d

File tree

5 files changed

+85
-14
lines changed

5 files changed

+85
-14
lines changed

.github/workflows/cuda12.8-whl-release.yml

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,48 @@ jobs:
5353
retention-days: 1
5454
name: linux-${{ matrix.pyver }}
5555

56+
windows-build:
57+
strategy:
58+
matrix:
59+
pyver: ['3.9', '3.10', '3.11', '3.12', '3.13']
60+
runs-on: windows-latest
61+
steps:
62+
- name: Set git for windows
63+
run: |
64+
git config --global core.longpaths true
65+
- name: Checkout repository
66+
uses: actions/checkout@v3
67+
- name: Set up python
68+
uses: actions/setup-python@v4
69+
with:
70+
python-version: ${{ matrix.pyver }}
71+
- name: Install python packages
72+
run: |
73+
pip install build change-wheel-version
74+
- name: Setup CUDA Toolkit
75+
id: cuda-toolkit
76+
shell: pwsh
77+
run: ./builder/windows/setup_cuda.ps1
78+
env:
79+
INPUT_CUDA_VERSION: '12.8.1'
80+
- name: Build wheel
81+
run: |
82+
python -m build --wheel -o build/wheel
83+
Get-ChildItem -Path "build" -Filter "*.whl" | ForEach-Object { change_wheel_version $_.FullName --local-version cu128 --delete-old-wheel }
84+
- name: Upload Artifacts
85+
uses: actions/upload-artifact@v4
86+
with:
87+
if-no-files-found: error
88+
path: build/wheel/*
89+
retention-days: 1
90+
name: windows-${{ matrix.pyver }}
91+
5692
publish:
5793
runs-on: ubuntu-latest
5894
environment: 'prod'
5995
needs:
6096
- linux-build
97+
- windows-build
6198
steps:
6299
- name: Checkout repository
63100
uses: actions/checkout@v3

.github/workflows/linux-x64-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
strategy:
3030
fail-fast: false
3131
matrix:
32-
cudaver: [11.8, 12.4]
32+
cudaver: [11.8, 12.4, 12.8]
3333
name: cuda-${{ matrix.cudaver }}
3434
runs-on: ubuntu-latest
3535
steps:

.github/workflows/windows-x64-gpu.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
strategy:
3030
fail-fast: false
3131
matrix:
32-
cudaver: [11.8.0, 12.1.0]
32+
cudaver: [11.8.0, 12.5.0, 12.8.1]
3333
name: cuda-${{ matrix.cudaver }}
3434
runs-on: windows-latest
3535
steps:

builder/windows/setup_cuda.ps1

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ if ($CUDA_VERSION_FULL -eq "12.1.0") {
2626
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe"
2727
} elseif ($CUDA_VERSION_FULL -eq "12.5.0") {
2828
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe"
29+
} elseif ($CUDA_VERSION_FULL -eq "12.8.1") {
30+
$downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe"
2931
} else {
3032
Write-Output "Unsupported CUDA version specified"
3133
exit 1

src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,44 @@
1111

1212
namespace turbomind::gemm {
1313

14+
namespace {
15+
16+
template<int tile>
17+
struct select_gmma_operation;
18+
template<>
19+
struct select_gmma_operation<256> {
20+
using type = cute::SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN<>;
21+
};
22+
template<>
23+
struct select_gmma_operation<224> {
24+
using type = cute::SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN<>;
25+
};
26+
template<>
27+
struct select_gmma_operation<192> {
28+
using type = cute::SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>;
29+
};
30+
template<>
31+
struct select_gmma_operation<160> {
32+
using type = cute::SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN<>;
33+
};
34+
template<>
35+
struct select_gmma_operation<128> {
36+
using type = cute::SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN<>;
37+
};
38+
template<>
39+
struct select_gmma_operation<96> {
40+
using type = cute::SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>;
41+
};
42+
template<>
43+
struct select_gmma_operation<64> {
44+
using type = cute::SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN<>;
45+
};
46+
47+
} // namespace
48+
1449
template<int TILE_M, int TILE_N, int TILE_K, int BATCH_M, int BATCH_N, int PIPE_M, int PIPE_N>
1550
struct ScaledGmmaFP8_TN {
16-
17-
static constexpr auto select_gmma_operation()
51+
static constexpr auto select_gmma_size()
1852
{
1953
static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);
2054
static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);
@@ -24,35 +58,33 @@ struct ScaledGmmaFP8_TN {
2458

2559
static_assert(M % 64 == 0);
2660

27-
using namespace cute::SM90::GMMA;
28-
2961
if constexpr (N % 256 == 0) {
30-
return MMA_64x256x32_F32E4M3E4M3_SS_TN<>{};
62+
return 256;
3163
}
3264
else if constexpr (N % 224 == 0) {
33-
return MMA_64x224x32_F32E4M3E4M3_SS_TN<>{};
65+
return 224;
3466
}
3567
else if constexpr (N % 192 == 0) {
36-
return MMA_64x192x32_F32E4M3E4M3_SS_TN<>{};
68+
return 192;
3769
}
3870
else if constexpr (N % 160 == 0) {
39-
return MMA_64x160x32_F32E4M3E4M3_SS_TN<>{};
71+
return 160;
4072
}
4173
else if constexpr (N % 128 == 0) {
42-
return MMA_64x128x32_F32E4M3E4M3_SS_TN<>{};
74+
return 128;
4375
}
4476
else if constexpr (N % 96 == 0) {
45-
return MMA_64x96x32_F32E4M3E4M3_SS_TN<>{};
77+
return 96;
4678
}
4779
else if constexpr (N % 64 == 0) {
48-
return MMA_64x64x32_F32E4M3E4M3_SS_TN<>{};
80+
return 64;
4981
}
5082
else {
5183
static_assert(N == 0, "unsupported configuration");
5284
}
5385
}
5486

55-
using Operation = decltype(select_gmma_operation());
87+
using Operation = typename select_gmma_operation<select_gmma_size()>::type;
5688

5789
static constexpr typename cute::MMA_Traits<Operation>::Shape_MNK OP_Shape{};
5890

0 commit comments

Comments (0)