Skip to content

Commit df7708f

Browse files
committed
unified test
1 parent 754c83e commit df7708f

File tree

3 files changed

+403
-2
lines changed

3 files changed

+403
-2
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
28 28   #include <tvm/tir/stmt_functor.h>
29 29   #include <tvm/tir/index_map.h>
30 30
31    - #include <algorithm>
32 31   #include <cmath>
33 32   #include <string>
34 33   #include <utility>

src/tir/transforms/lower_warp_memory.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ Pass LowerWarpMemory() {
472 472   auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
473 473   auto* n = f.CopyOnWrite();
474 474   auto target = f->GetAttr<Target>(tvm::attr::kTarget);
475     - int warp_size = 32;
    475 + int warp_size = target.value()->GetAttr<Integer>("thread_warp_size", 1).value();
476 476   WarpMemoryRewriter warp_memory_rewriter(warp_size);
477 477   auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body));
478 478   n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt);

0 commit comments

Comments
 (0)