Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tile_engine/ops/gemm/gemm_instance_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,11 +450,11 @@ def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};

// Traits
static constexpr bool kPadM = {"true" if pad_m == "true" else "false"};
static constexpr bool kPadN = {"true" if pad_n == "true" else "false"};
static constexpr bool kPadK = {"true" if pad_k == "true" else "false"};
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
static constexpr bool TransposeC = false;
static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"};
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
static constexpr bool UseStructuredSparsity = false;
static constexpr bool Preshuffle = false;
Expand Down Expand Up @@ -576,7 +576,7 @@ def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
}}

// Get grid and block sizes
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
const dim3 blocks = GemmKernel::BlockSize();

if(stream.log_level_ > 0) {{
Expand Down
Loading