-
Notifications
You must be signed in to change notification settings - Fork 59
[CHUNK_PREFILL] fp8kv cache #128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,6 +36,8 @@ struct chunk_prefill_args_t { | |
| int max_keys; | ||
| int total_seqlen_q; | ||
| int total_seqlen_k; | ||
| float k_scale = 1.0; | ||
| float v_scale = 1.0; | ||
| float sm_scale; | ||
| void* sm_sink; | ||
| int batch_size; | ||
|
|
@@ -145,6 +147,8 @@ struct KernelLauncher { | |
| stride_O, | ||
| reinterpret_cast<ElementQ*>(args.sm_sink)}, | ||
| {args.sm_scale, | ||
| args.k_scale, | ||
| args.v_scale, | ||
| static_cast<int*>(args.block_table), | ||
| args.block_size, | ||
| args.max_blocks_per_seq, | ||
|
|
@@ -315,35 +319,109 @@ struct FMHAConfig { | |
|
|
||
| template <typename chunk_policy, bool Paged, bool Causal, bool Local, bool Sink> | ||
| void policy_dispatch_impl( | ||
| sycl::queue& queue, CutlassType cuType, const chunk_prefill_args_t& args) { | ||
| sycl::queue& queue, | ||
| CutlassQKType& cuQKType, | ||
| const chunk_prefill_args_t& args) { | ||
| const int PipelineStages = 2; | ||
| if (cuType == CutlassType::half) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| half_t, | ||
| half_t, | ||
| half_t, | ||
| half_t>::kernel_dispatch(queue, args); | ||
| if (cuQKType.q_type == CutlassDType::half) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I would prefer to also make the Q/KV types template parameters here in the future. I feel the current change will increase compile time a little bit. |
||
| if (cuQKType.k_type == CutlassDType::half) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| half_t, | ||
| half_t, | ||
| half_t, | ||
| half_t>::kernel_dispatch(queue, args); | ||
| } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| half_t, | ||
| float_e4m3_t, | ||
| float_e4m3_t, | ||
| half_t>::kernel_dispatch(queue, args); | ||
| } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| half_t, | ||
| float_e5m2_t, | ||
| float_e5m2_t, | ||
| half_t>::kernel_dispatch(queue, args); | ||
| } | ||
| } else { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink>::kernel_dispatch(queue, args); | ||
| if (cuQKType.k_type == CutlassDType::bfloat16) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| bfloat16_t, | ||
| bfloat16_t, | ||
| bfloat16_t, | ||
| bfloat16_t>::kernel_dispatch(queue, args); | ||
| } else if (cuQKType.k_type == CutlassDType::float8_e4m3) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| bfloat16_t, | ||
| float_e4m3_t, | ||
| float_e4m3_t, | ||
| bfloat16_t>::kernel_dispatch(queue, args); | ||
| } else if (cuQKType.k_type == CutlassDType::float8_e5m2) { | ||
| return FMHAConfig< | ||
| typename chunk_policy::ShapeQK, | ||
| typename chunk_policy::ShapePV, | ||
| typename chunk_policy::ShapeOut, | ||
| typename chunk_policy::SubgroupLayoutQK, | ||
| void, | ||
| PipelineStages, | ||
| Paged, | ||
| Causal, | ||
| Local, | ||
| Sink, | ||
| bfloat16_t, | ||
| float_e5m2_t, | ||
| float_e5m2_t, | ||
| bfloat16_t>::kernel_dispatch(queue, args); | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -131,6 +131,9 @@ struct FMHAFwdMainloop< | |
| using TensorK = TensorK_; | ||
| using TensorV = TensorV_; | ||
|
|
||
| using ElementQ = typename TensorQ::engine_type::value_type; | ||
| using ElementK = typename TensorK::engine_type::value_type; | ||
|
|
||
| using TensorQ2D = | ||
| decltype(TensorQ_{}(append<rank_v<TensorQ_>>(make_coord(_, _), 0))); | ||
| using TensorK2D = | ||
|
|
@@ -178,13 +181,17 @@ struct FMHAFwdMainloop< | |
| using FragARow = decltype(reduce<1>(FragA{}, sycl::plus<void>{})); | ||
| using ElementA = typename TiledMMAPV::ValTypeD; | ||
|
|
||
| static constexpr bool Fp8KV = | ||
| is_any_of_v<ElementK, float_e5m2_t, float_e4m3_t>; | ||
| static constexpr bool CausalMask = CausalMask_; | ||
| static constexpr bool LocalMask = LocalMask_; | ||
| static constexpr bool PagedKV = PagedKV_; | ||
|
|
||
| // User-facing arguments | ||
| struct Arguments { | ||
| ElementS const scale; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is `scale` used for here?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It is the softmax scale. |
||
| ElementS const scale_k; | ||
| ElementS const scale_v; | ||
|
|
||
| // Paged KV Cache | ||
| int* ptr_page_table; | ||
|
|
@@ -215,6 +222,8 @@ struct FMHAFwdMainloop< | |
| ElementS val = args.scale * static_cast<ElementS>(kLog2e); | ||
| return Params{ | ||
| val, | ||
| args.scale_k, | ||
| args.scale_v, | ||
| args.ptr_page_table, | ||
| args.page_size, | ||
| args.max_pages_per_seq, | ||
|
|
@@ -375,6 +384,12 @@ struct FMHAFwdMainloop< | |
|
|
||
| reorder(tQrQ, tSrQ); | ||
| reorder(tKrK, tSrK); | ||
| if constexpr (Fp8KV) { | ||
| for (int i = 0; i < tSrK.size(); ++i) { | ||
| tSrK(i) = static_cast<ElementQ>( | ||
| params.scale_k * static_cast<float>(tSrK(i))); | ||
| } | ||
| } | ||
| cute::gemm(mma_qk, tSrQ, tSrK, tSrS); | ||
| } | ||
|
|
||
|
|
@@ -440,6 +455,13 @@ struct FMHAFwdMainloop< | |
| for (int VV = 0; VV < VTiles; VV++) { | ||
| copy(copy_v, tVgV_cache(_, _, _, VV), tVrV); | ||
| reorder(tVrV, tArV); | ||
| if constexpr (Fp8KV) { | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int i = 0; i < tArV.size(); ++i) { | ||
| tArV(i) = static_cast<ElementQ>( | ||
| params.scale_v * static_cast<float>(tArV(i))); | ||
| } | ||
| } | ||
| cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code here can be reformatted.