diff --git a/3rdparty/tvm b/3rdparty/tvm index 90581fe9e..afc079350 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 90581fe9e5287bbcf1844ad14255a1e1e8cdf7f0 +Subproject commit afc079350def46a78931c6edeb7bad3fb248b4e1 diff --git a/pyproject.toml b/pyproject.toml index 66424c02c..088737d47 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ # Extra constraint to tvm-ffi for abi issue, # should be removed after our tvm's update. # See discussion in tilelang#1373 and apache/tvm-ffi#307 - "apache-tvm-ffi<=0.1.1", + "apache-tvm-ffi>=0.1.2", "cloudpickle", "ml-dtypes", "numpy>=1.23.5", diff --git a/requirements-dev.txt b/requirements-dev.txt index 6cd968731..ef8e98b63 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,6 @@ # Requirements to run local build with `--no-build-isolation` or other developments -apache-tvm-ffi~=0.1.0 +apache-tvm-ffi>=0.1.2 build cmake>=3.26 cython>=3.0.0 diff --git a/requirements.txt b/requirements.txt index 3ad186ed4..58e851d75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # Runtime requirements -apache-tvm-ffi~=0.1.0 +apache-tvm-ffi>=0.1.2 cloudpickle ml-dtypes numpy>=1.23.5 diff --git a/src/ir.cc b/src/ir.cc index 3d2b3ecdc..82a94cb8e 100644 --- a/src/ir.cc +++ b/src/ir.cc @@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) { n->vars.push_back(var); n->doms.push_back(Range(0, dom)); n->f_make_for_loop = [](const Array &vars, const Array &doms, - const Stmt &body) -> Stmt { + const Array> &steps, + Stmt body) -> Stmt { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); - return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body); + Optional step = + !steps.empty() ? steps[0] : Optional(std::nullopt); + return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body, + /*thread_binding=*/std::nullopt, + /*annotations=*/tvm::ffi::Map{}, + /*step=*/step); }; return ForFrame(n); } ForFrame ParallelFor(const Array &extents, - const Map &annotations) { + const Map &annotations) { using namespace tvm::tir; ObjectPtr n = tvm::ffi::make_object(); n->vars.reserve(extents.size()); @@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array &extents, n->vars.push_back(Var("v", extent.dtype())); n->doms.push_back(Range(make_const(dtype, 0), extent)); } - n->f_make_for_loop = [annotations](const Array &vars, - const Array &doms, - Stmt body) -> Stmt { + n->f_make_for_loop = + [annotations](const Array &vars, const Array &doms, + const Array> &steps, Stmt body) -> Stmt { ICHECK_EQ(vars.size(), doms.size()); int n = vars.size(); for (int i = n - 1; i >= 0; --i) { Range dom = doms[i]; Var var = vars[i]; + Optional step = + i < steps.size() ? steps[i] : Optional(std::nullopt); body = For(var, dom->min, dom->extent, ForKind::kParallel, body, - /*thread_binding=*/std::nullopt, /*annotations=*/annotations); + /*thread_binding=*/std::nullopt, /*annotations=*/annotations, + /*step=*/step); } return body; }; @@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, n->vars.push_back(Var("v", dtype)); n->doms.push_back(Range(std::move(start), stop)); n->f_make_for_loop = [=](const Array &vars, const Array &doms, + const Array> &steps, Stmt body) -> Stmt { ICHECK_EQ(vars.size(), doms.size()); int n = vars.size(); ICHECK(n == 1); - Map anno; + Map anno; if (num_stages > 0) anno.Set("num_stages", PrimExpr(num_stages)); if (!order.empty()) @@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages, anno.Set("tl_pipeline_sync", sync); if (!groups.empty()) anno.Set("tl_pipeline_group", groups); + Optional step = + !steps.empty() ? steps[0] : Optional(std::nullopt); body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body, - /*thread_binding=*/std::nullopt, /*annotations=*/anno); + /*thread_binding=*/std::nullopt, /*annotations=*/anno, + /*step=*/step); return body; }; return ForFrame(n); @@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array &domain, const PrimExpr &wave_size, grouped_domain.push_back(group_size); n->f_make_for_loop = [=](const Array &vars, const Array &doms, - const Stmt &body) -> Stmt { + const Array> &steps, + Stmt body) -> Stmt { ICHECK_EQ(vars.size(), doms.size()); - Map anno; + Map anno; Array idxs(grouped_domain.size(), PrimExpr()); PrimExpr rem = loop_var * wave_size + index; @@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array &domain, const PrimExpr &wave_size, if (analyzer.CanProveGreaterEqual(waves, 2)) { new_body = SeqStmt({out_if, body}); } - Stmt outer = - For(loop_var, 0, waves, ForKind::kSerial, new_body, std::nullopt, anno); + Optional step = + !steps.empty() ? steps[0] : Optional(std::nullopt); + Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body, + /*thread_binding=*/std::nullopt, /*annotations=*/anno, + /*step=*/step); for (int i = 0; i < vars.size() - 1; ++i) { outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer); } diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc index 40cb81402..37a6d5896 100644 --- a/src/transform/atomicadd_vectorize.cc +++ b/src/transform/atomicadd_vectorize.cc @@ -203,7 +203,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator { vmap.Set(old_var, new_var * vector_size_); Stmt body = Substitute(fnode->body, vmap); return For(new_var, 0, extent / vector_size_, fnode->kind, body, - fnode->thread_binding, fnode->annotations, fnode->span); + fnode->thread_binding, fnode->annotations, fnode->step, + fnode->span); } } return ret; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index e8a18b004..836a52b4e 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -232,7 +232,8 @@ class VectorizeRewriter : public StmtExprMutator { Stmt body = Substitute(fnode->body, vmap); body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body); body = For(outer_var, 0, extent / vector_size_, fnode->kind, body, - fnode->thread_binding, fnode->annotations, fnode->span); + fnode->thread_binding, fnode->annotations, fnode->step, + fnode->span); return body; } } else {