
Commit 54e31f3

[Relax] Capture symbolic vars in struct info of weights (#16834)
1 parent 9862c84 commit 54e31f3

File tree

2 files changed: +121 -15 lines changed


src/relax/transform/rewrite_cuda_graph.cc

Lines changed: 33 additions & 15 deletions
@@ -239,13 +239,31 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
       if (pair.second->IsInstance<FunctionNode>()) {
         // If a function has the num_input attribute, the last func->params.size() - num_inputs
         // inputs are assumed to be fixed and thus they can be captured into a cuda graph.
+        // The symbolic variables in the struct info of the fixed inputs (weights) are also allowed
+        // to be captured.
+        // If hints for capturing symbolic variables are provided via the
+        // 'relax.rewrite_cuda_graph.capture_symbolic_vars' annotation, the actual variables with
+        // these names are extracted from the struct info for capturing.
         const auto& func = Downcast<Function>(pair.second);
-        if (auto num_input = func->attrs.GetAttr<Integer>(attr::kNumInput)) {
-          for (size_t i = num_input.value().IntValue(); i < func->params.size(); ++i) {
+        auto num_inputs =
+            func->attrs.GetAttr<Integer>(attr::kNumInput).value_or(Integer(func->params.size()));
+        auto capture_symbolic_var_name_hints = ExtractSymbolicVarHints(func);
+        for (int i = 0; i < static_cast<int>(func->params.size()); ++i) {
+          Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
+              Downcast<StructInfo>(func->params[i]->struct_info_.value()));
+          if (i < num_inputs.IntValue()) {
+            for (const auto& symbolic_var : symbolic_vars) {
+              if (capture_symbolic_var_name_hints.count(symbolic_var->name_hint)) {
+                capture_symbolic_vars_.insert(symbolic_var.get());
+              }
+            }
+          } else {
             static_vars_.insert(func->params[i].get());
+            for (const auto& symbolic_var : symbolic_vars) {
+              capture_symbolic_vars_.insert(symbolic_var.get());
+            }
           }
         }
-        CollectSymbolicVarHints(func);
         disabled_storage_vars_ = OutputStorageCollector::Collect(func);
         VisitExpr(func);
       }
@@ -284,17 +302,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   }

   /*!
-   * \brief Collect the name hints of the symbolic variables that are allowed to be captured.
+   * \brief Extract the name hints of the symbolic variables that are allowed to be captured
+   * from the function attributes.
    */
-  void CollectSymbolicVarHints(const Function& func) {
-    capture_symbolic_vars_.clear();
-    if (auto symbolic_vars =
-            func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")) {
-      for (const auto& var : symbolic_vars.value()) {
-        capture_symbolic_vars_.insert(var);
-      }
-    }
+  std::unordered_set<String> ExtractSymbolicVarHints(const Function& func) {
+    auto symbolic_var_names =
+        func->attrs.GetAttr<Array<String>>("relax.rewrite_cuda_graph.capture_symbolic_vars")
+            .value_or(Array<String>());
+    return {symbolic_var_names.begin(), symbolic_var_names.end()};
   }
+
   /*!
    *\brief Start a new static region. This method should be called when encountering a
    * CUDA kernel launch (calls to PrimFunc or ExternFunc) that only depends on static parameters.
@@ -467,7 +484,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
     bool is_static = true;
     tir::PostOrderVisit(expr, [&](const ObjectRef& e) {
       if (auto var = e.as<tir::VarNode>()) {
-        if (!capture_symbolic_vars_.count(var->name_hint)) {
+        if (!capture_symbolic_vars_.count(var)) {
          is_static = false;
          return;
        }
@@ -596,8 +613,9 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   FunctionScope current_function_scope_;
   // Variables whose buffer address is fixed
   std::unordered_set<const VarNode*> static_vars_;
-  // The name of the variables that are allowed to be symbolic
-  std::unordered_set<String> capture_symbolic_vars_;
+  // Symbolic variables that are allowed to be captured. This can come from symbolic shapes of
+  // weights or hints in the function annotations.
+  std::unordered_set<const tir::VarNode*> capture_symbolic_vars_;
   // Binding to the FuncBuilder if the binding is lifted. This is used to update the inputs/outputs
   // of the lifted function when its binding is used outside.
   std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
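
For illustration, a minimal sketch (not part of this commit) of how the two capture paths look from the user side. Only the attribute keys "num_input" and "relax.rewrite_cuda_graph.capture_symbolic_vars" and the pass name come from this change and the existing codebase; the module, op, and shapes below are assumptions.

# Hedged sketch: supplying the capture attributes and applying the pass.
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Example:
    @R.function
    def main(
        x: R.Tensor((8,), "float16"),    # runtime input
        w: R.Tensor(("m",), "float16"),  # fixed weight with symbolic shape `m`
    ) -> R.Tensor((8,), "float16"):
        # Params after "num_input" are treated as fixed weights: their buffer addresses
        # are static, and with this commit the symbolic vars in their struct info (here
        # `m`) become capturable as well. The annotation below additionally allows
        # capturing symbolic vars by name, now resolved to the actual tir::Var objects
        # found in the params' struct info.
        R.func_attr(
            {
                "num_input": 1,
                "relax.rewrite_cuda_graph.capture_symbolic_vars": ["m"],
            }
        )
        y: R.Tensor((8,), "float16") = R.add(x, x)
        return y


# Kernels whose shapes depend only on captured symbolic vars can now stay inside the
# captured CUDA graph region instead of forcing a graph break.
mod = relax.transform.RewriteCUDAGraph()(Example)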

tests/python/relax/test_transform_rewrite_cuda_graph.py

Lines changed: 88 additions & 0 deletions
@@ -1088,5 +1088,93 @@ def main(x: R.Tensor((8,), dtype="float32")) -> R.Tuple(R.Tensor((8,), dtype="fl
             return gv


+class TestStaticInputWithSymbolicShape(BaseCompare):
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((8,), "float16"), w: R.Tensor(("m",))):
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True, "num_input": 1})
+            storage1 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc1 = R.memory.alloc_tensor(storage1, 0, R.shape([8]), "float16")
+            _ = R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
+            storage2 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc2 = R.memory.alloc_tensor(storage2, 0, R.shape([8]), "float16")
+            _1 = R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
+            storage3 = R.memory.alloc_storage(R.shape([8]), 0, "global", "float16")
+            alloc3 = R.memory.alloc_tensor(storage3, 0, R.shape([8]), "float16")
+            _2 = R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
+            gv = (alloc3,)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            storage2: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage1, storage2
+            return gv
+
+        @R.function(private=True)
+        def main_cuda_graph_capture(
+            alloc1: R.Tensor((8,), dtype="float16"),
+            w: R.Tensor(("m",)),
+            alloc2: R.Tensor((8,), dtype="float16"),
+            shape_expr: R.Shape(["m"]),
+        ) -> R.Tuple:
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True})
+            R.call_packed("dummy", alloc1, w, alloc2, sinfo_args=(R.Tuple,))
+            R.tuple()
+            return R.tuple()
+
+        @R.function
+        def main(
+            x: R.Tensor((8,), dtype="float16"), w: R.Tensor(("m",))
+        ) -> R.Tuple(R.Tensor((8,), dtype="float16")):
+            m = T.int64()
+            R.func_attr({"num_input": 1, "relax.force_pure": True})
+            cls = Expected
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage1: R.Object = gv[0]
+            alloc1: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_packed("dummy", x, w, alloc1, sinfo_args=(R.Tuple,))
+            storage2: R.Object = gv[1]
+            alloc2: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage2, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (
+                    cls.main_cuda_graph_capture,
+                    (alloc1, w, alloc2, R.shape([m])),
+                    R.prim_value(0),
+                    R.shape([m]),
+                ),
+                sinfo_args=(R.Tuple,),
+            )
+            storage3: R.Object = R.memory.alloc_storage(
+                R.shape([8]), R.prim_value(0), R.str("global"), R.dtype("float16")
+            )
+            alloc3: R.Tensor((8,), dtype="float16") = R.memory.alloc_tensor(
+                storage3, R.prim_value(0), R.shape([8]), R.dtype("float16")
+            )
+            R.call_packed("dummy", alloc2, w, alloc3, sinfo_args=(R.Tuple,))
+            gv_1: R.Tuple(R.Tensor((8,), dtype="float16")) = (alloc3,)
+            return gv_1
+
+
 if __name__ == "__main__":
     tvm.testing.main()
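
For reference, a hypothetical sketch of the check the BaseCompare fixture presumably performs for this new case (an assumption based on the test's structure, not code added by this commit):

import tvm
from tvm import relax

# Apply the pass to the Before module defined above and compare against Expected.
After = relax.transform.RewriteCUDAGraph()(TestStaticInputWithSymbolicShape.Before)
tvm.ir.assert_structural_equal(After, TestStaticInputWithSymbolicShape.Expected)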
