Skip to content
Merged
29 changes: 13 additions & 16 deletions csrc/trtllm_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,20 +44,20 @@ struct TrtllmGenGemmRunnerOptions {
int64_t select_kernel_fp8(int32_t M, int32_t N, int32_t K,
const gemm::gemm::GemmInterface& interface) {
static constexpr const char* KERNEL_NAME_HIGH_N_K_RATIO =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_cga1x1x1_16dp256b_rM_TN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128u2_s6_et64x8_m64x8x32_c1x1x1_16dp256b_rM_TN_"
"transOut_"
"noShflA_dsFp8_schPd2x2x1x3_sm100f";

static constexpr const char* KERNEL_NAME_LOW_N_K_RATIO =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_16dp256b_rM_TN_"
"transOut_noShflA_dsFp8_schedS_sm100f";

static constexpr const char* KERNEL_NAME_LARGE_N =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_cga1x1x1_16dp256b_rM_TN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128u2_s6_et64x32_m64x32x32_c1x1x1_16dp256b_rM_TN_"
"transOut_noShflA_dsFp8_schPd2x2x1x3_sm100f";

static constexpr const char* KERNEL_NAME_DEFAULT =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_cga1x1x1_16dp256b_rM_TN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128u2_s6_et64x16_m64x16x32_c1x1x1_16dp256b_rM_TN_"
"transOut_noShflA_dsFp8_schedS_sm100f";

double const n_k_ratio = static_cast<double>(N) / static_cast<double>(K);
Expand Down Expand Up @@ -124,10 +124,9 @@ class TrtllmGenGemmRunner {
gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
gemmData.mProblemDimensions.mK = k;
// TODO(jimmyzho) disable until fix trtllm-gen
// gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
// gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
// gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;

Expand All @@ -148,10 +147,9 @@ class TrtllmGenGemmRunner {
gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
gemmData.mProblemDimensions.mK = k;
// TODO(jimmyzho) disable until fix trtllm-gen
// gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
// gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
// gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;

Expand Down Expand Up @@ -204,10 +202,9 @@ class TrtllmGenGemmRunner {
gemmData.mProblemDimensions.mM = mOptions.transposeMmaOutput ? n : m;
gemmData.mProblemDimensions.mN = mOptions.transposeMmaOutput ? m : n;
gemmData.mProblemDimensions.mK = k;
// TODO(jimmyzho) disable until fix trtllm-gen
// gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
// gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
// gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
gemmData.mProblemDimensions.mRank = 0;
gemmData.mProblemDimensions.mWorldSize = 1;

Expand Down
17 changes: 8 additions & 9 deletions csrc/trtllm_low_latency_gemm_runner.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ gemm::gemm::GemmData createGemmData(int64_t m, int64_t n, int64_t k) {
gemmData.mProblemDimensions.mM = n;
gemmData.mProblemDimensions.mN = m;
gemmData.mProblemDimensions.mK = k;
// TODO(jimmyzho) disable until fix trtllm-gen
gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM;
gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN;
gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK;
Expand All @@ -64,28 +63,28 @@ gemm::gemm::GemmData createGemmData(int64_t m, int64_t n, int64_t k) {
*/
int64_t select_kernel(int32_t m, int32_t n, int32_t k, const gemm::gemm::GemmInterface& interface) {
static constexpr const char* KERNEL_MMAN_8_TILEK_128 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x128_s7_et128x8_m128x8x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_8_TILEK_256 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x8x256_s4_et128x8_m128x8x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_16_TILEK_128 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x32_m128x64x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x128_s7_et128x32_m128x64x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_16_TILEK_256 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x32_m128x64x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x64x256_s3_et128x32_m128x64x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_32_TILEK_128 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x128_s9_et128x32_m128x32x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_32_TILEK_256 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x32x256_s5_et128x32_m128x32x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_64_TILEK_128 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x128_s7_et128x16_m128x16x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";
static constexpr const char* KERNEL_MMAN_64_TILEK_256 =
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_cga1x1x1_16dp256b_rM_BN_"
"gemm_Bfloat16_E4m3E4m3_Fp32_t128x16x256_s5_et128x16_m128x16x32_c1x1x1_16dp256b_rM_BN_"
"transOut_schedS_sm100f";

std::string kernel_name;
Expand Down
8 changes: 4 additions & 4 deletions flashinfer/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ class ArtifactPath:

TRTLLM_GEN_FMHA: str = "55bba55929d4093682e32d817bd11ffb0441c749/fmha/trtllm-gen/"
TRTLLM_GEN_BMM: str = (
"b55211623be7f5697c5262ffd8361fc06c147bc9/batched_gemm-b3c1646-c111d7c/"
"31e75d429ff3f710de1251afdd148185f53da44d/batched_gemm-4daf11e-c111d7c/"
)
TRTLLM_GEN_GEMM: str = (
"b117d5a6b2dd2228aa966a938eac398cf336d8c0/gemm-b3c1646-1fddea2/"
"31e75d429ff3f710de1251afdd148185f53da44d/gemm-4daf11e-1fddea2/"
)
CUDNN_SDPA: str = "a72d85b019dc125b9f711300cb989430f762f5a6/fmha/cudnn/"
# For DEEPGEMM, we also need to update KernelMap.KERNEL_MAP_HASH in flashinfer/deep_gemm.py
Expand All @@ -158,11 +158,11 @@ class CheckSumHash:
"f2c0aad1e74391c4267a2f9a20ec819358b59e04588385cffb452ed341500b99"
)
TRTLLM_GEN_BMM: str = (
"0af823880730c4f0b3832d2208fab035946694b83444410b9309db5613d60195"
"2c2361bdf1deb0a2ea0f130f2d57dd62864f4400a706ac19a625d492b03460cb"
)
DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf"
TRTLLM_GEN_GEMM: str = (
"18262161e624f7da9d2d04c528c645a5ff7f5efd774024a0b2eb92748ab18bb9"
"64b7114a429ea153528dd4d4b0299363d7320964789eb5efaefec66f301523c7"
)
map_checksums: dict[str, str] = {
safe_urljoin(ArtifactPath.TRTLLM_GEN_FMHA, "checksums.txt"): TRTLLM_GEN_FMHA,
Expand Down
85 changes: 54 additions & 31 deletions flashinfer/jit/gemm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,30 @@
current_compilation_context,
)
from ..cubin_loader import (
# download_trtllm_headers,
get_artifact,
get_meta_hash,
ensure_symlink,
verify_symlinked_headers,
)
from ..utils import dtype_cutlass_map, filename_safe_dtype_map, write_if_different

GEMM_EXPORT_HEADERS = [
"Enums.h",
"GemmInterface.h",
"GemmOptions.h",
"KernelParams.h",
"KernelParamsDecl.h",
"KernelTraits.h",
"TmaDescriptor.h",
"trtllm/gen/CommonUtils.h",
"trtllm/gen/CudaArchDecl.h",
"trtllm/gen/CudaKernelLauncher.h",
"trtllm/gen/DtypeDecl.h",
"trtllm/gen/MmaDecl.h",
"trtllm/gen/SfLayoutDecl.h",
"trtllm/gen/SparsityDecl.h",
]


def gen_gemm_module() -> JitSpec:
return gen_jit_spec(
Expand Down Expand Up @@ -645,18 +663,22 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"

# TODO(jimmyzho): Re-enable after fixing trtllm-gen cubin generation issues.
# header_path = f"{include_path}/trtllmGen_gemm_export"
# header_dest_dir = (
# jit_env.FLASHINFER_CUBIN_DIR
# / "flashinfer"
# / "trtllm"
# / "gemm"
# / "trtllmGen_gemm_export"
# )
# download_trtllm_headers(
# "gemm", header_dest_dir, header_path, ArtifactPath.TRTLLM_GEN_GEMM, checksum
# )
# Fetch GEMM export headers via get_artifact() and symlink for C++ includes.
gemm_export_path = f"{include_path}/trtllmGen_gemm_export"
for header in GEMM_EXPORT_HEADERS:
h = get_artifact(
f"{gemm_export_path}/{header}", get_meta_hash(checksum, header)
)
assert h, f"{header} not found"
symlink_path = (
jit_env.FLASHINFER_CUBIN_DIR
/ "flashinfer"
/ "trtllm"
/ "gemm"
/ "trtllmGen_gemm_export"
)
ensure_symlink(symlink_path, jit_env.FLASHINFER_CUBIN_DIR / gemm_export_path)
verify_symlinked_headers(symlink_path, GEMM_EXPORT_HEADERS, checksum)

return gen_jit_spec(
"trtllm_gemm",
Expand All @@ -670,11 +692,9 @@ def gen_trtllm_gen_gemm_module() -> JitSpec:
f'-DTLLM_GEN_GEMM_CUBIN_PATH=\\"{ArtifactPath.TRTLLM_GEN_GEMM}\\"',
]
+ sm100a_nvcc_flags,
# link "include" sub-directory in cache
extra_include_paths=[
jit_env.FLASHINFER_CUBIN_DIR,
jit_env.FLASHINFER_CUBIN_DIR / include_path,
# jit_env.FLASHINFER_CUBIN_DIR,
# jit_env.FLASHINFER_CUBIN_DIR / include_path,
],
)

Expand Down Expand Up @@ -817,18 +837,22 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
# make sure "flashinferMetaInfo.h" is downloaded or cached
assert metainfo, f"{header_name}.h not found"

# TODO(jimmyzho): Re-enable after fixing trtllm-gen cubin generation issues.
# header_path = f"{include_path}/trtllmGen_gemm_export"
# header_dest_dir = (
# jit_env.FLASHINFER_CUBIN_DIR
# / "flashinfer"
# / "trtllm"
# / "gemm"
# / "trtllmGen_gemm_export"
# )
# download_trtllm_headers(
# "gemm", header_dest_dir, header_path, ArtifactPath.TRTLLM_GEN_GEMM, checksum
# )
# Fetch GEMM export headers via get_artifact() and symlink for C++ includes.
gemm_export_path = f"{include_path}/trtllmGen_gemm_export"
for header in GEMM_EXPORT_HEADERS:
h = get_artifact(
f"{gemm_export_path}/{header}", get_meta_hash(checksum, header)
)
assert h, f"{header} not found"
symlink_path = (
jit_env.FLASHINFER_CUBIN_DIR
/ "flashinfer"
/ "trtllm"
/ "gemm"
/ "trtllmGen_gemm_export"
)
ensure_symlink(symlink_path, jit_env.FLASHINFER_CUBIN_DIR / gemm_export_path)
verify_symlinked_headers(symlink_path, GEMM_EXPORT_HEADERS, checksum)

return gen_jit_spec(
"trtllm_low_latency_gemm",
Expand All @@ -843,9 +867,8 @@ def gen_trtllm_low_latency_gemm_module() -> JitSpec:
]
+ sm100a_nvcc_flags,
extra_include_paths=[
jit_env.FLASHINFER_CUBIN_DIR / include_path
# jit_env.FLASHINFER_CUBIN_DIR,
# jit_env.FLASHINFER_CUBIN_DIR / include_path,
jit_env.FLASHINFER_CUBIN_DIR,
jit_env.FLASHINFER_CUBIN_DIR / include_path,
],
extra_ldflags=["-lcuda"],
)
Loading
Loading