diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp
index 41e0bab1624..3c9e42996b0 100644
--- a/hopper/tile_scheduler.hpp
+++ b/hopper/tile_scheduler.hpp
@@ -251,7 +251,8 @@ class DynamicPersistentTileScheduler {
 
     static Params
     to_underlying_arguments(TileSchedulerArguments const& args) {
-        int const size_one_kv_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
+        // Multiply in long long: plain `long` is only 32 bits on LLP64 targets (e.g. MSVC), so long(...) casts could still overflow there.
+        long long const size_one_kv_head = static_cast<long long>(args.seqlen_k) * static_cast<long long>(args.headdim + args.headdim_v) * static_cast<long long>(args.element_size);
         int const size_l2 = 32 * 1024 * 1024; // 32 MB for K & V
         // Swizzle is the size of each "section". Round swizzle to a power of 2
         // If not PackGQA already, the size of each section can increase by qhead_per_khead
@@ -382,9 +383,10 @@ class SingleTileBwdLPTScheduler {
     static Params
     to_underlying_arguments(TileSchedulerArguments const& args) {
         // Since it's the bwd pass, seqlen_k get passed to args.seqlen and seqlen_q is passed to args.seqlen_k
-        int const size_one_qdo_head = args.seqlen_k * (args.headdim + args.headdim_v) * args.element_size;
-        int const size_one_dqaccum_head = args.seqlen_k * args.headdim * sizeof(float);
-        int const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
+        // Multiply in long long: plain `long` is only 32 bits on LLP64 targets (e.g. MSVC), so long(...) casts could still overflow there.
+        long long const size_one_qdo_head = static_cast<long long>(args.seqlen_k) * static_cast<long long>(args.headdim + args.headdim_v) * static_cast<long long>(args.element_size);
+        long long const size_one_dqaccum_head = static_cast<long long>(args.seqlen_k) * static_cast<long long>(args.headdim) * static_cast<long long>(sizeof(float));
+        long long const size_one_head = size_one_qdo_head + size_one_dqaccum_head;
         int const size_l2 = 40 * 1024 * 1024; // 40 MB for Q, dO, and dQaccum
         // Swizzle is the size of each "section". Round swizzle to a power of 2
         // Need to be careful about the case where only one head will fit