Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 90581f to afc079
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
# Extra constraint on tvm-ffi because of an ABI issue;
# it should be removed once our TVM is updated.
# See discussion in tilelang#1373 and apache/tvm-ffi#307
"apache-tvm-ffi<=0.1.1",
"apache-tvm-ffi>=0.1.2",
"cloudpickle",
"ml-dtypes",
"numpy>=1.23.5",
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Requirements for running a local build with `--no-build-isolation` or for other development work

apache-tvm-ffi~=0.1.0
apache-tvm-ffi>=0.1.2
build
cmake>=3.26
cython>=3.0.0
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Runtime requirements
apache-tvm-ffi~=0.1.0
apache-tvm-ffi>=0.1.2
cloudpickle
ml-dtypes
numpy>=1.23.5
Expand Down
43 changes: 30 additions & 13 deletions src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,22 @@ static ForFrame MakeIterVarFrame(const std::string &name, const PrimExpr &dom) {
n->vars.push_back(var);
n->doms.push_back(Range(0, dom));
n->f_make_for_loop = [](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), 1);
ICHECK_EQ(doms.size(), 1);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt,
/*annotations=*/tvm::ffi::Map<tvm::ffi::String, tvm::ffi::Any>{},
/*step=*/step);
};
return ForFrame(n);
}

ForFrame ParallelFor(const Array<PrimExpr> &extents,
const Map<String, ObjectRef> &annotations) {
const Map<String, tvm::ffi::Any> &annotations) {
using namespace tvm::tir;
ObjectPtr<ForFrameNode> n = tvm::ffi::make_object<ForFrameNode>();
n->vars.reserve(extents.size());
Expand All @@ -63,16 +69,19 @@ ForFrame ParallelFor(const Array<PrimExpr> &extents,
n->vars.push_back(Var("v", extent.dtype()));
n->doms.push_back(Range(make_const(dtype, 0), extent));
}
n->f_make_for_loop = [annotations](const Array<Var> &vars,
const Array<Range> &doms,
Stmt body) -> Stmt {
n->f_make_for_loop =
[annotations](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps, Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
for (int i = n - 1; i >= 0; --i) {
Range dom = doms[i];
Var var = vars[i];
Optional<PrimExpr> step =
i < steps.size() ? steps[i] : Optional<PrimExpr>(std::nullopt);
body = For(var, dom->min, dom->extent, ForKind::kParallel, body,
/*thread_binding=*/std::nullopt, /*annotations=*/annotations);
/*thread_binding=*/std::nullopt, /*annotations=*/annotations,
/*step=*/step);
}
return body;
};
Expand All @@ -90,11 +99,12 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
n->vars.push_back(Var("v", dtype));
n->doms.push_back(Range(std::move(start), stop));
n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
int n = vars.size();
ICHECK(n == 1);
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
if (num_stages > 0)
anno.Set("num_stages", PrimExpr(num_stages));
if (!order.empty())
Expand All @@ -105,8 +115,11 @@ ForFrame PipelinedFor(PrimExpr start, const PrimExpr &stop, int num_stages,
anno.Set("tl_pipeline_sync", sync);
if (!groups.empty())
anno.Set("tl_pipeline_group", groups);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
body = For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kSerial, body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno);
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
return body;
};
return ForFrame(n);
Expand Down Expand Up @@ -145,9 +158,10 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
grouped_domain.push_back(group_size);

n->f_make_for_loop = [=](const Array<Var> &vars, const Array<Range> &doms,
const Stmt &body) -> Stmt {
const Array<Optional<PrimExpr>> &steps,
Stmt body) -> Stmt {
ICHECK_EQ(vars.size(), doms.size());
Map<String, ObjectRef> anno;
Map<String, tvm::ffi::Any> anno;
Array<PrimExpr> idxs(grouped_domain.size(), PrimExpr());
PrimExpr rem = loop_var * wave_size + index;

Expand All @@ -168,8 +182,11 @@ ForFrame PersistentFor(const Array<PrimExpr> &domain, const PrimExpr &wave_size,
if (analyzer.CanProveGreaterEqual(waves, 2)) {
new_body = SeqStmt({out_if, body});
}
Stmt outer =
For(loop_var, 0, waves, ForKind::kSerial, new_body, std::nullopt, anno);
Optional<PrimExpr> step =
!steps.empty() ? steps[0] : Optional<PrimExpr>(std::nullopt);
Stmt outer = For(loop_var, 0, waves, ForKind::kSerial, new_body,
/*thread_binding=*/std::nullopt, /*annotations=*/anno,
/*step=*/step);
for (int i = 0; i < vars.size() - 1; ++i) {
outer = tvm::tir::LetStmt(vars[i], idxs[i + 1], outer);
}
Expand Down
3 changes: 2 additions & 1 deletion src/transform/atomicadd_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ class AtomicAddVectorizeRewriter : public StmtExprMutator {
vmap.Set(old_var, new_var * vector_size_);
Stmt body = Substitute(fnode->body, vmap);
return For(new_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
fnode->thread_binding, fnode->annotations, fnode->step,
fnode->span);
}
}
return ret;
Expand Down
3 changes: 2 additions & 1 deletion src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ class VectorizeRewriter : public StmtExprMutator {
Stmt body = Substitute(fnode->body, vmap);
body = For(inner_var, 0, vector_size_, ForKind::kVectorized, body);
body = For(outer_var, 0, extent / vector_size_, fnode->kind, body,
fnode->thread_binding, fnode->annotations, fnode->span);
fnode->thread_binding, fnode->annotations, fnode->step,
fnode->span);
return body;
}
} else {
Expand Down
Loading