Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions hopper/tile_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ class DynamicPersistentTileScheduler {

static Params
to_underlying_arguments(TileSchedulerArguments const& args) {
int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
long long const size_one_kv_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size);
int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V
// Swizzle is the size of each "section". Round swizzle to a power of 2
// If not PackGQA already, the size of each section can increase by qhead_per_khead
Expand Down Expand Up @@ -382,9 +382,9 @@ class SingleTileBwdLPTScheduler {
static Params
to_underlying_arguments(TileSchedulerArguments const& args) {
// Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k
int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float);
int const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
long long const size_one_qdo_head = long(args.seqlen_k) * long(args.headdim + args.headdim_v) * long(args.element_size);
long long const size_one_dqaccum_head = long(args.seqlen_k) * long(args.headdim) * sizeof(float);
long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum
// Swizzle is the size of each "section". Round swizzle to a power of 2
// Need to be careful about the case where only one head will fit
Expand Down