2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -99,6 +99,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
src/transform/*.cc
src/op/*.cc
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
)

# Include CUDA source files if CUDA is enabled
71 changes: 55 additions & 16 deletions src/ir.cc
@@ -7,13 +7,33 @@
*
*/

#include <tvm/arith/analyzer.h>
#include <tvm/script/ir_builder/tir/ir.h>

namespace tvm {
namespace tl {

constexpr const char *tilelang_is_cpu_kernel_frame =
"tilelang.is_cpu_kernel_frame";

using namespace script::ir_builder::tir;

static ForFrame MakeIterVarFrame(std::string name, PrimExpr dom) {
using namespace tvm::tir;
Var var = Var(name);
// Create a frame that represents a loop over the given domain.
ObjectPtr<ForFrameNode> n = make_object<ForFrameNode>();
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](Array<Var> vars, Array<Range> doms,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
};
return ForFrame(n);
}
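// Editor's sketch (not part of this patch): entering and then exiting the
// frame built above wraps the enclosed statements in one serial loop over
// [0, dom). The hypothetical helper below builds the equivalent statement
// directly, using the same tvm::tir::For constructor as f_make_for_loop.
static tir::Stmt MakeSerialLoopSketch(tir::Var v, PrimExpr extent,
                                      tir::Stmt body) {
  // Equivalent to `for (v = 0; v < extent; ++v) body` in the generated TIR.
  return tir::For(v, /*min=*/0, extent, tir::ForKind::kSerial, body);
}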

ForFrame ParallelFor(Array<PrimExpr> extents,
Map<String, ObjectRef> annotations) {
using namespace tvm::tir;
@@ -121,24 +141,43 @@ KernelLaunchFrame KernelLaunch(Array<PrimExpr> grid_size,
Array<PrimExpr> block_size,
Map<String, ObjectRef> attrs) {
ObjectPtr<KernelLaunchFrameNode> n = make_object<KernelLaunchFrameNode>();
ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0)
n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
if (grid_size.size() > 1)
n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1]));
if (grid_size.size() > 2)
n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2]));
if (block_size.defined()) {
ICHECK(block_size.size() <= 3);
if (block_size.size() > 0)
n->frames.push_back(LaunchThread("threadIdx.x", block_size[0]));
if (block_size.size() > 1)
n->frames.push_back(LaunchThread("threadIdx.y", block_size[1]));
if (block_size.size() > 2)
n->frames.push_back(LaunchThread("threadIdx.z", block_size[2]));

// If the kernel is a CPU kernel, we don't need to launch any threads.
bool is_cpu_kernel_frame =
attrs.defined() && attrs.count(tilelang_is_cpu_kernel_frame);

if (is_cpu_kernel_frame) {
ICHECK(grid_size.size() >= 0);
ICHECK(block_size.size() == 0) << "CPU kernel cannot have block size";
ICHECK(attrs.defined());
// create grid loop var
for (int i = 0; i < grid_size.size(); i++) {
n->frames.push_back(
MakeIterVarFrame("block_var_" + std::to_string(i), grid_size[i]));
}
// Launch CPU Kernel
} else {
n->frames.push_back(Block(""));
// Launch GPU Kernel
ICHECK(grid_size.size() <= 3);
if (grid_size.size() > 0)
n->frames.push_back(LaunchThread("blockIdx.x", grid_size[0]));
if (grid_size.size() > 1)
n->frames.push_back(LaunchThread("blockIdx.y", grid_size[1]));
if (grid_size.size() > 2)
n->frames.push_back(LaunchThread("blockIdx.z", grid_size[2]));
if (block_size.defined()) {
ICHECK(block_size.size() <= 3);
if (block_size.size() > 0)
n->frames.push_back(LaunchThread("threadIdx.x", block_size[0]));
if (block_size.size() > 1)
n->frames.push_back(LaunchThread("threadIdx.y", block_size[1]));
if (block_size.size() > 2)
n->frames.push_back(LaunchThread("threadIdx.z", block_size[2]));
} else {
n->frames.push_back(Block(""));
}
}

if (attrs.defined()) {
auto empty_block = Block("");
empty_block->annotations = attrs;
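Taken together, the branch above makes KernelLaunch lower a grid either to blockIdx/threadIdx bindings (GPU) or to nested serial loops (CPU), selected purely by the presence of the "tilelang.is_cpu_kernel_frame" key in attrs; its value is never inspected. A minimal sketch of a caller requesting the CPU path follows. The call site itself is hypothetical and not part of this diff; only the names from src/ir.cc are real.

using namespace tvm;
using namespace tvm::tl;

// Example grid extents; any PrimExpr works.
PrimExpr M = 128, N = 64;
Map<String, ObjectRef> attrs;
// Presence of the key selects the CPU path; the value is not inspected.
attrs.Set("tilelang.is_cpu_kernel_frame", Bool(true));
// block_size must stay empty for CPU kernels; {M, N} is lowered to two
// nested serial loops over block_var_0 and block_var_1.
KernelLaunchFrame frame = KernelLaunch({M, N}, /*block_size=*/{}, attrs);

Keying on attrs.count() rather than the attribute's value lets the frontend attach any marker object, but it also means setting the key to false would still select the CPU path.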
19 changes: 14 additions & 5 deletions src/op/elem.cc
@@ -138,6 +138,8 @@ For Copy::MakeSIMTLoop(arith::Analyzer *analyzer) const {
}

Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
Target target = T.target;
bool is_cpu_target = target->GetTargetDeviceType() == kDLCPU;
Stmt ldsm_stmt = LowerLDSMCopy(T, analyzer);
if (ldsm_stmt.defined())
return ldsm_stmt;
@@ -148,12 +150,19 @@ Stmt Copy::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
auto simt_loop = MakeSIMTLoop(analyzer);
auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));

For vectorized_thread_loop;
auto par_op = std::make_unique<ParallelOp>(fused_loop);
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);

if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(fused_loop);
} else {
par_op->InferLayout({T.target, T.block_size, T.layout_map, T.buffer_remap},
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
vectorized_thread_loop = VectorizeLoop(thread_loop);
}

if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
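The elem.cc change mirrors the kernel-frame split: on a CPU target there is no thread hierarchy, so Copy::Lower vectorizes the fused SIMT loop directly instead of first partitioning it over T.thread_var. A sketch of the target test this dispatch relies on; the helper name is illustrative, not from the patch.

#include <tvm/target/target.h>

// kDLCPU is the dlpack device type for host CPUs, so e.g. llvm targets take
// the serial-loop path while cuda-like targets keep the thread-partitioned
// path.
static bool IsCpuTarget(const tvm::Target &target) {
  return target->GetTargetDeviceType() == kDLCPU;
}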