From 69d6c4a4f24b6ad877ffc130808622717148af80 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Sun, 30 Mar 2025 11:40:55 -0700 Subject: [PATCH] fix bmm fp8 --- sgl-kernel/csrc/torch_extension.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sgl-kernel/csrc/torch_extension.cc b/sgl-kernel/csrc/torch_extension.cc index 3633c9f407e..263a9d15ca6 100644 --- a/sgl-kernel/csrc/torch_extension.cc +++ b/sgl-kernel/csrc/torch_extension.cc @@ -82,7 +82,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { /* * From FlashInfer */ - m.def("bmm_fp8", bmm_fp8); + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); m.def("min_p_sampling_from_probs", min_p_sampling_from_probs); m.def("top_k_renorm_probs", top_k_renorm_probs); m.def("top_p_renorm_probs", top_p_renorm_probs);