Skip to content

Commit 54eb835

Browse files
Cunxiao2002nicunxiao
andauthored
[BugFix] Correct direct copy from bf16 to fp8 (tile-ai#1090)
* [BugFix] Correct direct copy from bf16 to fp8 * fix lint * implement overloaded cast codegen for type conversion * fix lint * remove test * fix lint * trigger CI * Overload fp8 for implicit conversion * format * new format * fix: Reinterpret types to cute types in GEMM * new format * fix lint * new format * fix lint * format * trigger ci --------- Co-authored-by: nicunxiao <[email protected]>
1 parent 264a782 commit 54eb835

File tree

6 files changed

+61
-18
lines changed

6 files changed

+61
-18
lines changed

src/tl_templates/cuda/common.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
#include <cutlass/numeric_types.h>
1111
#include <math_constants.h>
1212

13+
#include <cutlass/bfloat16.h>
14+
#include <cutlass/float8.h>
15+
1316
using cutlass::bfloat16_t;
1417
using cutlass::half_t;
1518
using cutlass::tfloat32_t;
@@ -339,6 +342,37 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
339342
descriptor.reg32_[0] += (offset >> 4);
340343
}
341344

345+
// and add the desired implicit conversion from bfloat16_t.
346+
struct float_e4m3_t : public cute::float_e4m3_t {
347+
using cute::float_e4m3_t::float_e4m3_t;
348+
CUTLASS_HOST_DEVICE
349+
float_e4m3_t() = default;
350+
351+
CUTLASS_HOST_DEVICE
352+
explicit float_e4m3_t(__nv_bfloat16 x)
353+
: float_e4m3_t(static_cast<float>(x)) {}
354+
};
355+
356+
struct float_e5m2_t : public cute::float_e5m2_t {
357+
using cute::float_e5m2_t::float_e5m2_t;
358+
CUTLASS_HOST_DEVICE
359+
float_e5m2_t() = default;
360+
361+
CUTLASS_HOST_DEVICE
362+
explicit float_e5m2_t(__nv_bfloat16 x)
363+
: float_e5m2_t(static_cast<float>(x)) {}
364+
};
365+
366+
template <typename T> struct to_cute_type {
367+
using type = T;
368+
};
369+
template <> struct to_cute_type<tl::float_e4m3_t> {
370+
using type = cute::float_e4m3_t;
371+
};
372+
template <> struct to_cute_type<tl::float_e5m2_t> {
373+
using type = cute::float_e5m2_t;
374+
};
375+
342376
} // namespace tl
343377

344378
namespace cutlass {

src/tl_templates/cuda/cuda_fp8.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
#pragma once
22

3+
#include "common.h"
34
#include <cuda_fp8.h>
45
#include <cute/numeric/numeric_types.hpp>
56

6-
using fp8_e4_t = cute::float_e4m3_t;
7-
using fp8_e5_t = cute::float_e5m2_t;
7+
using fp8_e4_t = tl::float_e4m3_t;
8+
using fp8_e5_t = tl::float_e5m2_t;
89

910
struct __CUDA_ALIGN__(2) fp8_e4_2_t {
1011
fp8_e4_t x;

src/tl_templates/cuda/gemm_mma.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,12 +263,14 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
263263
typename C_type_raw>
264264
class GemmTensorOp {
265265
public:
266+
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
267+
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
266268
using A_type =
267-
typename std::conditional<std::is_same<A_type_raw, float>::value,
268-
tfloat32_t, A_type_raw>::type;
269+
typename std::conditional<std::is_same<A_type_cute, float>::value,
270+
tfloat32_t, A_type_cute>::type;
269271
using B_type =
270-
typename std::conditional<std::is_same<B_type_raw, float>::value,
271-
tfloat32_t, A_type_raw>::type;
272+
typename std::conditional<std::is_same<B_type_cute, float>::value,
273+
tfloat32_t, B_type_cute>::type;
272274
using C_type = C_type_raw;
273275

274276
using Instruction =

src/tl_templates/cuda/gemm_sm100.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
289289
typename C_type_raw>
290290
class GemmTensorOp {
291291
public:
292+
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
293+
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
292294
using A_type =
293-
typename std::conditional<std::is_same<A_type_raw, float>::value,
294-
tfloat32_t, A_type_raw>::type;
295+
typename std::conditional<std::is_same<A_type_cute, float>::value,
296+
tfloat32_t, A_type_cute>::type;
295297
using B_type =
296-
typename std::conditional<std::is_same<B_type_raw, float>::value,
297-
tfloat32_t, B_type_raw>::type;
298+
typename std::conditional<std::is_same<B_type_cute, float>::value,
299+
tfloat32_t, B_type_cute>::type;
298300
using C_type = C_type_raw;
299301

300302
static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
2121
typename B_type_raw, typename C_type_raw>
2222
class GemmTensorOp {
2323
public:
24-
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
25-
tfloat32_t, A_type_raw>;
26-
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
27-
tfloat32_t, B_type_raw>;
24+
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
25+
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
26+
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
27+
tfloat32_t, A_type_cute>;
28+
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
29+
tfloat32_t, A_type_cute>;
2830
using C_type = C_type_raw;
2931

3032
static constexpr GMMA::Major GmmaMajorA =

src/tl_templates/cuda/gemm_sp_sm90.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ class GemmTensorOp {
1313
public:
1414
static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");
1515

16-
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
17-
tfloat32_t, A_type_raw>;
18-
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
19-
tfloat32_t, B_type_raw>;
16+
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
17+
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
18+
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
19+
tfloat32_t, A_type_cute>;
20+
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
21+
tfloat32_t, B_type_cute>;
2022
using C_type = C_type_raw;
2123

2224
static constexpr bool need_tfloat32_cast =

0 commit comments

Comments
 (0)