Commit 76459c9

build(pypi): add cu128 windows build

2 parents: c8474a0 + 85a0fea

File tree: 5 files changed, +80 -42 lines

.github/workflows/cuda12.8-whl-release.yml

Lines changed: 37 additions & 0 deletions

@@ -53,11 +53,48 @@ jobs:
           retention-days: 1
           name: linux-${{ matrix.pyver }}
 
+  windows-build:
+    strategy:
+      matrix:
+        pyver: ['3.9', '3.10', '3.11', '3.12', '3.13']
+    runs-on: windows-latest
+    steps:
+      - name: Set git for windows
+        run: |
+          git config --global core.longpaths true
+      - name: Checkout repository
+        uses: actions/checkout@v3
+      - name: Set up python
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.pyver }}
+      - name: Install python packages
+        run: |
+          pip install build change-wheel-version
+      - name: Setup CUDA Toolkit
+        id: cuda-toolkit
+        shell: pwsh
+        run: ./builder/windows/setup_cuda.ps1
+        env:
+          INPUT_CUDA_VERSION: '12.8.1'
+      - name: Build wheel
+        run: |
+          python -m build --wheel -o build/wheel
+          Get-ChildItem -Path "build" -Filter "*.whl" | ForEach-Object { change_wheel_version $_.FullName --local-version cu128 --delete-old-wheel }
+      - name: Upload Artifacts
+        uses: actions/upload-artifact@v4
+        with:
+          if-no-files-found: error
+          path: build/wheel/*
+          retention-days: 1
+          name: windows-${{ matrix.pyver }}
+
   publish:
     runs-on: ubuntu-latest
     environment: 'prod'
     needs:
       - linux-build
+      - windows-build
     steps:
       - name: Checkout repository
         uses: actions/checkout@v3

.github/workflows/linux-x64-gpu.yml

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        cudaver: [11.8, 12.4]
+        cudaver: [11.8, 12.4, 12.8]
     name: cuda-${{ matrix.cudaver }}
     runs-on: ubuntu-latest
     steps:
.github/workflows/windows-x64-gpu.yml

Lines changed: 1 addition & 1 deletion

@@ -29,7 +29,7 @@ jobs:
     strategy:
       fail-fast: false
      matrix:
-        cudaver: [11.8.0, 12.1.0]
+        cudaver: [11.8.0, 12.5.0, 12.8.1]
     name: cuda-${{ matrix.cudaver }}
     runs-on: windows-latest
     steps:

builder/windows/setup_cuda.ps1

Lines changed: 2 additions & 0 deletions

@@ -26,6 +26,8 @@ if ($CUDA_VERSION_FULL -eq "12.1.0") {
     $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_522.06_windows.exe"
 } elseif ($CUDA_VERSION_FULL -eq "12.5.0") {
     $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.5.0/local_installers/cuda_12.5.0_555.85_windows.exe"
+} elseif ($CUDA_VERSION_FULL -eq "12.8.1") {
+    $downloadUrl = "https://developer.download.nvidia.com/compute/cuda/12.8.1/local_installers/cuda_12.8.1_572.61_windows.exe"
 } else {
     Write-Output "Unsupported CUDA version specified"
     exit 1

src/turbomind/kernels/gemm/scaled_gmma_fp8_sm90.h

Lines changed: 39 additions & 40 deletions

@@ -14,45 +14,44 @@ namespace turbomind::gemm {
 template<int TILE_M, int TILE_N, int TILE_K, int BATCH_M, int BATCH_N, int PIPE_M, int PIPE_N>
 struct ScaledGmmaFP8_TN {
 
-    static constexpr auto select_gmma_operation()
-    {
-        static_assert(TILE_M % (BATCH_M * PIPE_M) == 0);
-        static_assert(TILE_N % (BATCH_N * PIPE_N) == 0);
-
-        constexpr int M = TILE_M / (BATCH_M * PIPE_M);
-        constexpr int N = TILE_N / (BATCH_N * PIPE_N);
+    template<int tile_m  = TILE_M,
+             int tile_n  = TILE_N,
+             int batch_m = BATCH_M,
+             int batch_n = BATCH_N,
+             int pipe_m  = PIPE_M,
+             int pipe_n  = PIPE_N>
+    struct select_gmma_operation {
+        static constexpr int M = tile_m / (batch_m * pipe_m);
+        static constexpr int N = tile_n / (batch_n * pipe_n);
 
+        static_assert(tile_m % (batch_m * pipe_m) == 0);
+        static_assert(tile_n % (batch_n * pipe_n) == 0);
         static_assert(M % 64 == 0);
 
-        using namespace cute::SM90::GMMA;
-
-        if constexpr (N % 256 == 0) {
-            return MMA_64x256x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 224 == 0) {
-            return MMA_64x224x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 192 == 0) {
-            return MMA_64x192x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 160 == 0) {
-            return MMA_64x160x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 128 == 0) {
-            return MMA_64x128x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 96 == 0) {
-            return MMA_64x96x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else if constexpr (N % 64 == 0) {
-            return MMA_64x64x32_F32E4M3E4M3_SS_TN<>{};
-        }
-        else {
-            static_assert(N == 0, "unsupported configuration");
-        }
-    }
+        using type = std::conditional_t<
+            N % 256 == 0,
+            cute::SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN<>,
+            std::conditional_t<
+                N % 224 == 0,
+                cute::SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN<>,
+                std::conditional_t<
+                    N % 192 == 0,
+                    cute::SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN<>,
+                    std::conditional_t<
+                        N % 160 == 0,
+                        cute::SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN<>,
+                        std::conditional_t<
+                            N % 128 == 0,
+                            cute::SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN<>,
+                            std::conditional_t<N % 96 == 0,
+                                               cute::SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN<>,
+                                               std::conditional_t<N % 64 == 0,
+                                                                  cute::SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN<>,
+                                                                  void>>>>>>>;
+        static_assert(!std::is_same_v<type, void>, "unsupported configuration");
+    };
 
-    using Operation = decltype(select_gmma_operation());
+    using Operation = select_gmma_operation<>::type;
 
     static constexpr typename cute::MMA_Traits<Operation>::Shape_MNK OP_Shape{};
@@ -242,11 +241,11 @@ struct ScaledGmmaFP8_TN {
                             int n = ((i_n * PIPE_N) + p_n * BATCH_N) + b_n;
                             func(frag[i_m][i_n][p_m][p_n][b_m][b_n], m, n);
                         }  // BATCH_N
-                    }  // BATCH_M
-                }  // PIPE_N
-            }  // PIPE_M
-        }  // ITER_N
-    }  // ITER_M
+                    }      // BATCH_M
+                }          // PIPE_N
+            }              // PIPE_M
+        }                  // ITER_N
+    }                      // ITER_M
     }
 
     template<class Frag, class Func>

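The header change above replaces an `if constexpr` chain that returned operation objects from a `static constexpr auto` function with a `select_gmma_operation` trait built from nested `std::conditional_t`, presumably to sidestep compiler quirks now that the same commit adds MSVC-built Windows wheels. Below is a minimal standalone sketch of that type-selection pattern under stated assumptions: `OpTag<N>` and `select_op` are hypothetical stand-ins for the `cute::SM90::GMMA::MMA_64xNx32_F32E4M3E4M3_SS_TN` operation types and the trait in the diff, with only three of the seven tile widths kept for brevity.

// Sketch of selecting a type by divisibility with nested std::conditional_t.
// OpTag<N> is a hypothetical stand-in for the GMMA operation types above.
#include <cstdio>
#include <type_traits>

template<int N>
struct OpTag {
    static constexpr int n = N;  // N-dimension tile width this "op" covers
};

template<int N>
struct select_op {
    // Try the widest supported tile first, then fall through to narrower
    // ones; `void` marks "no match".
    using type = std::conditional_t<
        N % 256 == 0, OpTag<256>,
        std::conditional_t<
            N % 128 == 0, OpTag<128>,
            std::conditional_t<N % 64 == 0, OpTag<64>, void>>>;

    // Fires only when select_op<N> is actually instantiated with an
    // unsupported N, mirroring the "unsupported configuration" assert
    // in the diff above.
    static_assert(!std::is_same_v<type, void>, "unsupported configuration");
};

int main() {
    // 192 % 256 != 0 and 192 % 128 != 0, but 192 % 64 == 0 -> OpTag<64>.
    using Op = select_op<192>::type;
    std::printf("selected N tile: %d\n", Op::n);  // prints 64
}

Like the diff's `using Operation = select_gmma_operation<>::type;`, the selected type is pulled out through a member alias rather than deduced from a function's return value.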