Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions csrc/rocm/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ torch::Tensor wvSplitK_int8(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount);

// W4A16 skinny GEMM entry point: packed int4 weights, fp16/bf16 activations,
// per-channel scale (as described where the op is registered).
// in_bias is optional; the result is returned as a torch::Tensor.
// NOTE(review): which of in_a / in_b carries the packed weights is not visible
// from this declaration — confirm against the implementation.
// NOTE(review): CuCount is presumably the number of compute units to use —
// confirm.
torch::Tensor wvSplitK_int4(const at::Tensor& in_a, const at::Tensor& in_b,
const at::Tensor& in_scale,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount);

// W4A16 grouped skinny GEMM entry point: packed int4 weights with per-group
// scales (as described where the op is registered).
// Takes the same arguments as wvSplitK_int4 plus group_size; in_bias is
// optional and the result is returned as a torch::Tensor.
// NOTE(review): group_size presumably is the number of weights sharing one
// scale entry — confirm against the implementation.
torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b,
const at::Tensor& in_scale,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount, const int64_t group_size);

#ifdef VLLM_SKINNY_GEMM_SWEEP
// Sweep entry point taking ytile and unrl as runtime arguments in addition to
// the inputs, optional bias and CuCount. Only declared when
// VLLM_SKINNY_GEMM_SWEEP is defined.
// NOTE(review): presumably the benchmarking counterpart of wvSplitK, like the
// other *_sweep ops — confirm against the implementation.
torch::Tensor wvSplitK_sweep(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount, const int64_t ytile,
const int64_t unrl);

torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a,
const at::Tensor& in_b,
const at::Tensor& in_scale,
Expand All @@ -22,11 +38,6 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a,
const int64_t unrl, const int64_t achunk,
const int64_t wvprgrp);

// W4A16 skinny GEMM entry point: packed int4 weights, fp16/bf16 activations,
// per-channel scale (as described where the op is registered).
// in_bias is optional; the result is returned as a torch::Tensor.
// NOTE(review): which of in_a / in_b carries the packed weights is not visible
// from this declaration — confirm against the implementation.
torch::Tensor wvSplitK_int4(const at::Tensor& in_a, const at::Tensor& in_b,
const at::Tensor& in_scale,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount);

torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a,
const at::Tensor& in_b,
const at::Tensor& in_scale,
Expand All @@ -35,11 +46,6 @@ torch::Tensor wvSplitK_int4_sweep(const at::Tensor& in_a,
const int64_t unrl, const int64_t achunk,
const int64_t wvprgrp);

// W4A16 grouped skinny GEMM entry point: packed int4 weights with per-group
// scales (as described where the op is registered).
// in_bias is optional; the result is returned as a torch::Tensor.
// NOTE(review): group_size presumably is the number of weights sharing one
// scale entry — confirm against the implementation.
torch::Tensor wvSplitK_int4_g(const at::Tensor& in_a, const at::Tensor& in_b,
const at::Tensor& in_scale,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount, const int64_t group_size);

torch::Tensor wvSplitK_int4g_sweep(
const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale,
const int64_t CuCount, const int64_t group_size, const int64_t ytile,
Expand All @@ -49,12 +55,6 @@ torch::Tensor wvSplitK_int4g_hf_sweep(
const at::Tensor& in_a, const at::Tensor& in_b, const at::Tensor& in_scale,
const int64_t CuCount, const int64_t group_size, const int64_t ytile,
const int64_t unrl, const int64_t achunk, const int64_t wvprgrp);

#ifdef VLLM_SKINNY_GEMM_SWEEP
// Sweep entry point taking ytile and unrl as runtime arguments in addition to
// the inputs, optional bias and CuCount. Only declared when
// VLLM_SKINNY_GEMM_SWEEP is defined.
// NOTE(review): presumably the benchmarking counterpart of wvSplitK — confirm.
torch::Tensor wvSplitK_sweep(const at::Tensor& in_a, const at::Tensor& in_b,
const std::optional<at::Tensor>& in_bias,
const int64_t CuCount, const int64_t ytile,
const int64_t unrl);
#endif

torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b,
Expand Down
405 changes: 208 additions & 197 deletions csrc/rocm/skinny_gemms_int4.cu

Large diffs are not rendered by default.

120 changes: 61 additions & 59 deletions csrc/rocm/skinny_gemms_int8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,9 @@ torch::Tensor wvSplitK_int8(const at::Tensor& in_a, const at::Tensor& in_b,
return out_c;
}

// Sweep function: all 4 tuning params dispatched at runtime (fp16 only).
// Used for benchmarking only — not for production.
// Sweep function disabled by default to reduce compile time.
// Build with -DVLLM_SKINNY_GEMM_SWEEP to enable.
#ifdef VLLM_SKINNY_GEMM_SWEEP
torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a,
const at::Tensor& in_b,
const at::Tensor& in_scale,
Expand Down Expand Up @@ -459,62 +460,62 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a,

const int THRDS = is_gfx11_int8() ? 32 : 64;

// SWEEP_LAUNCH: launches a single instantiation of wvSplitK_int8_hf_sml_ whose
// template arguments are fptype plus the six macro arguments.
// The block is (_THRDS x _WVPRGRP); __wvPrGrp is derived at runtime via
// mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP) and passed to the kernel.
// Expands to a braced statement and relies on fptype, grid, stream, K_in, M_in,
// wptr, aptr, sptr, biasptr, cptr and CuCount being in scope where it is used.
#define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \
{ \
dim3 block(_THRDS, _WVPRGRP); \
int __wvPrGrp = mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP); \
wvSplitK_int8_hf_sml_<fptype, _THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, \
_N> \
<<<grid, block, 0, stream>>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \
biasptr, cptr, __wvPrGrp, CuCount); \
}
// SWEEP_LAUNCH: launches a single instantiation of wvSplitK_int8_hf_sml_ whose
// template arguments are fptype plus the six macro arguments.
// The block is (_THRDS x _WVPRGRP); __wvPrGrp is derived at runtime via
// mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP) and passed to the kernel.
// Expands to a braced statement and relies on fptype, grid, stream, K_in, M_in,
// wptr, aptr, sptr, biasptr, cptr and CuCount being in scope where it is used.
#define SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, _N) \
{ \
dim3 block(_THRDS, _WVPRGRP); \
int __wvPrGrp = mindiv_int8(M_in, CuCount * _YTILE, _WVPRGRP); \
wvSplitK_int8_hf_sml_<fptype, _THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, \
_N> \
<<<grid, block, 0, stream>>>(K_in, M_in, 1, 1, wptr, aptr, sptr, \
biasptr, cptr, __wvPrGrp, CuCount); \
}

// SWEEP_N: turns the runtime value N_in into the compile-time _N argument of
// SWEEP_LAUNCH. Values 1 through 4 are dispatched; anything else fails
// TORCH_CHECK with "Unsupported N=". SWEEP_LAUNCH expands to a braced
// statement, which is why each case can be followed directly by `break`.
#define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \
switch (N_in) { \
case 1: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \
case 2: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \
case 3: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \
case 4: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \
default: \
TORCH_CHECK(false, "Unsupported N=", N_in); \
}
// SWEEP_N: turns the runtime value N_in into the compile-time _N argument of
// SWEEP_LAUNCH. Values 1 through 4 are dispatched; anything else fails
// TORCH_CHECK with "Unsupported N=". SWEEP_LAUNCH expands to a braced
// statement, which is why each case can be followed directly by `break`.
#define SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL) \
switch (N_in) { \
case 1: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 1) break; \
case 2: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 2) break; \
case 3: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 3) break; \
case 4: \
SWEEP_LAUNCH(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, _UNRL, 4) break; \
default: \
TORCH_CHECK(false, "Unsupported N=", N_in); \
}

// SWEEP_UNRL: turns the runtime value unrl into the compile-time _UNRL
// argument of SWEEP_N. Only 1, 2 and 4 are accepted; any other value fails
// TORCH_CHECK with "Unsupported unrl=".
#define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \
if (unrl == 1) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \
} else if (unrl == 2) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \
} else if (unrl == 4) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \
} else { \
TORCH_CHECK(false, "Unsupported unrl=", unrl); \
}
// SWEEP_UNRL: turns the runtime value unrl into the compile-time _UNRL
// argument of SWEEP_N. Only 1, 2 and 4 are accepted; any other value fails
// TORCH_CHECK with "Unsupported unrl=".
#define SWEEP_UNRL(_THRDS, _YTILE, _WVPRGRP, _ACHUNK) \
if (unrl == 1) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 1) \
} else if (unrl == 2) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 2) \
} else if (unrl == 4) { \
SWEEP_N(_THRDS, _YTILE, _WVPRGRP, _ACHUNK, 4) \
} else { \
TORCH_CHECK(false, "Unsupported unrl=", unrl); \
}

// SWEEP_YTILE: turns the runtime value ytile into the compile-time _YTILE
// argument of SWEEP_UNRL. Only 1, 2 and 4 are accepted; any other value fails
// TORCH_CHECK with "Unsupported ytile=".
#define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \
if (ytile == 1) { \
SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \
} else if (ytile == 2) { \
SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \
} else if (ytile == 4) { \
SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \
} else { \
TORCH_CHECK(false, "Unsupported ytile=", ytile); \
}
// SWEEP_YTILE: turns the runtime value ytile into the compile-time _YTILE
// argument of SWEEP_UNRL. Only 1, 2 and 4 are accepted; any other value fails
// TORCH_CHECK with "Unsupported ytile=".
#define SWEEP_YTILE(_THRDS, _WVPRGRP, _ACHUNK) \
if (ytile == 1) { \
SWEEP_UNRL(_THRDS, 1, _WVPRGRP, _ACHUNK) \
} else if (ytile == 2) { \
SWEEP_UNRL(_THRDS, 2, _WVPRGRP, _ACHUNK) \
} else if (ytile == 4) { \
SWEEP_UNRL(_THRDS, 4, _WVPRGRP, _ACHUNK) \
} else { \
TORCH_CHECK(false, "Unsupported ytile=", ytile); \
}

// SWEEP_WVPRGRP: outermost level of the macro dispatch chain. Turns the
// runtime value wvprgrp into the compile-time _WVPRGRP argument of SWEEP_YTILE.
// Only 8, 12 and 16 are accepted; any other value fails TORCH_CHECK with
// "Unsupported wvprgrp=".
#define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \
if (wvprgrp == 8) { \
SWEEP_YTILE(_THRDS, 8, _ACHUNK) \
} else if (wvprgrp == 12) { \
SWEEP_YTILE(_THRDS, 12, _ACHUNK) \
} else if (wvprgrp == 16) { \
SWEEP_YTILE(_THRDS, 16, _ACHUNK) \
} else { \
TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \
}
// SWEEP_WVPRGRP: outermost level of the macro dispatch chain. Turns the
// runtime value wvprgrp into the compile-time _WVPRGRP argument of SWEEP_YTILE.
// Only 8, 12 and 16 are accepted; any other value fails TORCH_CHECK with
// "Unsupported wvprgrp=".
#define SWEEP_WVPRGRP(_THRDS, _ACHUNK) \
if (wvprgrp == 8) { \
SWEEP_YTILE(_THRDS, 8, _ACHUNK) \
} else if (wvprgrp == 12) { \
SWEEP_YTILE(_THRDS, 12, _ACHUNK) \
} else if (wvprgrp == 16) { \
SWEEP_YTILE(_THRDS, 16, _ACHUNK) \
} else { \
TORCH_CHECK(false, "Unsupported wvprgrp=", wvprgrp); \
}

if (THRDS == 32) {
if (achunk == 8) {
Expand All @@ -538,11 +539,12 @@ torch::Tensor wvSplitK_int8_sweep(const at::Tensor& in_a,
}
}

#undef SWEEP_LAUNCH
#undef SWEEP_N
#undef SWEEP_UNRL
#undef SWEEP_YTILE
#undef SWEEP_WVPRGRP
#undef SWEEP_LAUNCH
#undef SWEEP_N
#undef SWEEP_UNRL
#undef SWEEP_YTILE
#undef SWEEP_WVPRGRP

return out_c;
}
#endif // VLLM_SKINNY_GEMM_SWEEP
28 changes: 14 additions & 14 deletions csrc/rocm/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,32 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
"Tensor? in_bias, int CuCount) -> Tensor");
rocm_ops.impl("wvSplitK_int8", torch::kCUDA, &wvSplitK_int8);

// W8A16 skinny GEMM sweep: all tuning params as runtime args (benchmark only)
rocm_ops.def(
"wvSplitK_int8_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, "
"int wvprgrp) -> Tensor");
rocm_ops.impl("wvSplitK_int8_sweep", torch::kCUDA, &wvSplitK_int8_sweep);

// W4A16 skinny GEMM: packed int4 weights, fp16/bf16 activations, per-channel
// scale
rocm_ops.def(
"wvSplitK_int4(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount) -> Tensor");
rocm_ops.impl("wvSplitK_int4", torch::kCUDA, &wvSplitK_int4);

// W4A16 skinny GEMM sweep: all tuning params as runtime args (benchmark only)
rocm_ops.def(
"wvSplitK_int4_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, "
"int wvprgrp) -> Tensor");
rocm_ops.impl("wvSplitK_int4_sweep", torch::kCUDA, &wvSplitK_int4_sweep);

// W4A16 grouped skinny GEMM: packed int4 weights, per-group scales
rocm_ops.def(
"wvSplitK_int4_g(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount, int group_size) -> Tensor");
rocm_ops.impl("wvSplitK_int4_g", torch::kCUDA, &wvSplitK_int4_g);

#ifdef VLLM_SKINNY_GEMM_SWEEP
rocm_ops.def(
"wvSplitK_int8_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, "
"int wvprgrp) -> Tensor");
rocm_ops.impl("wvSplitK_int8_sweep", torch::kCUDA, &wvSplitK_int8_sweep);

rocm_ops.def(
"wvSplitK_int4_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "
"Tensor? in_bias, int CuCount, int ytile, int unrl, int achunk, "
"int wvprgrp) -> Tensor");
rocm_ops.impl("wvSplitK_int4_sweep", torch::kCUDA, &wvSplitK_int4_sweep);

rocm_ops.def(
"wvSplitK_int4g_sweep(Tensor in_a, Tensor in_b, Tensor in_scale, "
"int CuCount, int group_size, int ytile, int unrl, int achunk, "
Expand All @@ -79,6 +78,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
"int wvprgrp) -> Tensor");
rocm_ops.impl("wvSplitK_int4g_hf_sweep", torch::kCUDA,
&wvSplitK_int4g_hf_sweep);
#endif // VLLM_SKINNY_GEMM_SWEEP

// Custom gemm op for skinny matrix-matrix multiplication
rocm_ops.def(
Expand Down
Loading
Loading