
Commit 1c35c39

[Unity] Add Relax multi-device e2e cases (#15823)
* [Unity] filter out non-GPU primfuncs in default_gpu_schedule
* Add relax heterogeneous e2e case
* Remove get_prim_func_device
* Update test cases
* Fix flake8
* fix lint
* Add test case for change of default_gpu_schedule
* fix comment
1 parent f328e9b commit 1c35c39

File tree

15 files changed: +471 additions, -51 deletions


python/tvm/driver/build_module.py

Lines changed: 20 additions & 7 deletions
@@ -243,20 +243,33 @@ def build(
 
     if not isinstance(inputs, (dict, container.Map)):
         target = Target.current() if target is None else target
-        target = target if target else "llvm"
-        target_input_mod = {target: input_mod}
+        if target is None and isinstance(input_mod, tvm.IRModule):
+            target_mod = {}
+            for gvar, func in input_mod.functions.items():
+                tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm"
+                if tgt not in target_mod:
+                    target_mod[tgt] = {}
+                target_mod[tgt][gvar] = func
+
+            target_input_mod = {}
+            for tgt in target_mod.keys():
+                tir_mod = tvm.IRModule(target_mod[tgt])
+                tir_mod.with_attrs(input_mod.attrs)
+                target_input_mod[tgt] = tir_mod
+        else:
+            target_input_mod = {target: input_mod}
     else:
-        target_input_mod = inputs
+        target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()}
 
     # Because modules can be created from a variety of sources, we annotate them
     # with the relevant attributes here to ensure they propagate
     annotated_mods = {}
-    for tar, mod in target_input_mod.items():
-        if not isinstance(tar, (str, Target)):
+    for tgt, mod in target_input_mod.items():
+        if not isinstance(tgt, (str, Target)):
             raise ValueError("The key of inputs must be str or " "Target when inputs is dict.")
         if not isinstance(mod, tvm.IRModule):
-            raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")
-        annotated_mods[tar] = mod.with_attr("runtime", runtime)
+            raise ValueError("inputs must be Schedule, IRModule, " "or dict of str to IRModule.")
+        annotated_mods[tgt] = mod.with_attr("runtime", runtime)
 
     # TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same host target
     # defaulting logic, but there's currently no way to get back the decided host.
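
The net effect is that tvm.build can now take a single IRModule whose PrimFuncs each carry their own "target" attribute and compile every group separately when no explicit target is given. A minimal sketch of that usage; the function name add_one_cpu and the one-target module are illustrative only, not part of this commit:

import tvm
from tvm import te
from tvm.target import Target

# Hypothetical single-function module; a real heterogeneous module would mix
# functions tagged with different targets (e.g. "llvm" and "cuda").
n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
func = (
    te.create_prim_func([A, B])
    .with_attr("global_symbol", "add_one_cpu")
    .with_attr("target", Target("llvm"))
)
mod = tvm.IRModule({"add_one_cpu": func})

# With target=None, build() groups functions by their "target" attribute
# (falling back to "llvm" for untagged functions) and compiles each group.
lib = tvm.build(mod, target=None)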

python/tvm/relax/utils.py

Lines changed: 24 additions & 2 deletions
@@ -28,7 +28,7 @@
 from .expr import Tuple as rx_Tuple
 from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
 from ..te import Tensor as te_Tensor, create_prim_func
-from ..ir import Array, Attrs, Type, Map
+from ..ir import Array, Attrs, Type, Map, VDevice
 from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo
 
 
@@ -418,6 +418,24 @@ def _populate_used_vars(expr):
         diff = used_vars - bound_vars
         return list(diff)
 
+    def _get_vdevice(arg: Any) -> Optional[VDevice]:
+        """get the virtual device from arguments."""
+        vdevice = None
+        if isinstance(arg, Expr):  # type: ignore
+            if isinstance(arg.struct_info, TensorStructInfo):
+                vdevice = arg.struct_info.vdevice
+        elif isinstance(arg, (list, Array, tuple)):
+            for x in arg:
+                vdevice = _get_vdevice(x)
+                if vdevice is not None:
+                    return vdevice
+        elif isinstance(arg, (dict, Map)):
+            for k in arg:
+                vdevice = _get_vdevice(arg[k])
+                if vdevice is not None:
+                    return vdevice
+        return vdevice
+
     def _shape_with_old_tir_var(
         shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr]
     ):
@@ -456,7 +474,11 @@ def _shape_with_old_tir_var(
     tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}
 
     output_sinfo = [
-        TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
+        TensorStructInfo(
+            _shape_with_old_tir_var(out.shape, tir_var_inverse_map),
+            out.dtype,
+            _get_vdevice(args),
+        )
        for out in outs
    ]
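
_get_vdevice is a private helper: it returns the vdevice of the first argument whose struct info carries one, recursing through lists, tuples, and maps, and that result is then threaded into every output TensorStructInfo so call_tir outputs inherit the device of their inputs. A standalone sketch of the same search order over plain Python containers (no TVM objects; find_first_vdevice and the sample input are illustrative only):

from typing import Any, Optional

def find_first_vdevice(arg: Any) -> Optional[str]:
    """Mirror the _get_vdevice search order: a leaf with a device wins,
    otherwise recurse into list/tuple elements and dict values."""
    if isinstance(arg, str):  # a string stands in for an Expr annotated with a vdevice
        return arg
    if isinstance(arg, (list, tuple)):
        for item in arg:
            found = find_first_vdevice(item)
            if found is not None:
                return found
    elif isinstance(arg, dict):
        for value in arg.values():
            found = find_first_vdevice(value)
            if found is not None:
                return found
    return None

print(find_first_vdevice([None, {"w": None, "x": "cuda:0"}, "llvm"]))  # -> cuda:0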

python/tvm/relax/vm_build.py

Lines changed: 20 additions & 12 deletions
@@ -79,7 +79,7 @@ def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module:
             vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
         """
 
-        # TODO(tvm-team): Update runtime.Module interfac
+        # TODO(tvm-team): Update runtime.Module interface
         # to query these properties as bitmask.
         def _not_runnable(x):
             return x.type_key in ("c", "static_library")
@@ -179,13 +179,17 @@ def _vmcodegen(
         raise ValueError(f"Unknown exec_mode {exec_mode}")
 
 
-def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
+def _autodetect_system_lib_req(
+    target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] = None
+):
     """Automatically detect system lib requirement"""
-    host = target if target.host is None else target.host
-    if system_lib is None:
-        system_lib = False
-        if "wasm" in host.attrs.get("mtriple", ""):
-            system_lib = True
+    if target is not None:
+        host = target if target.host is None else target.host
+        if system_lib is None:
+            system_lib = False
+            if "wasm" in host.attrs.get("mtriple", ""):
+                system_lib = True
+
     if system_lib:
         # use packed-func to avoid relay dep.
         return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib})
@@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
 
 def _vmlink(
     builder: "relax.ExecBuilder",
-    target: Union[str, tvm.target.Target],
+    target: Optional[Union[str, tvm.target.Target]],
     tir_mod: Optional[tvm.IRModule] = None,
     ext_libs: List[tvm.runtime.Module] = None,
     params: Optional[Dict[str, list]] = None,
@@ -213,8 +217,10 @@ def _vmlink(
     builder: relax.ExecBuilder
         Builder used to collect executables.
 
-    target : Union[str, tvm.target.Target]
+    target : Optional[Union[str, tvm.target.Target]]
         A build target which can have optional host side compilation target.
+        If the target is not specified, the target in the vdevice list will be used.
+        For multi-target compilation, the vdevice should be annotated.
 
     tir_mod: IRModule
         The input TIR IRModule to be linked together.
@@ -239,14 +245,16 @@ def _vmlink(
     lib = None
     if tir_mod is not None:
         lib = tvm.build(
-            tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib)
+            tir_mod,
+            target=target,
+            runtime=_autodetect_system_lib_req(target, system_lib),
         )
     return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params))  # type: ignore
 
 
 def build(
     mod: tvm.IRModule,
-    target: Union[str, tvm.target.Target],
+    target: Optional[Union[str, tvm.target.Target]] = None,
     params: Optional[Dict[str, list]] = None,
     pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
     exec_mode: str = "bytecode",
@@ -261,7 +269,7 @@ def build(
     mod: IRModule
         The input IRModule to be built.
 
-    target : Union[str, tvm.target.Target]
+    target : Optional[Union[str, tvm.target.Target]]
         A build target which can have optional host side compilation target.
 
         When TVM compiles device specific program such as CUDA,
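
relax.build keeps its existing single-target form; per the docstring addition above, when target is omitted the targets recorded in the module's vdevice list are used instead, so multi-target modules should carry vdevice annotations. A minimal end-to-end sketch of the single-target form (the Doubler module is illustrative only):

import numpy as np
import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Doubler:
    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        gv = R.add(x, x)
        return gv


# Legalize high-level ops to TIR, build for a single target, and run on CPU.
mod = relax.transform.LegalizeOps()(Doubler)
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.nd.array(np.ones((4,), "float32")))
print(out.numpy())  # [2. 2. 2. 2.]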

python/tvm/runtime/relax_vm.py

Lines changed: 1 addition & 6 deletions
@@ -54,7 +54,7 @@ def __init__(
 
         Parameters
         ----------
-        mod: Union[tvm.runtime.Module, tvm.relax.Executable]
+        rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
             Runtime module exported by the result of build.
 
         device : Union[Device, List[Device]]
@@ -107,11 +107,6 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])
                 )
             devs = [dev]
 
-        if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]):
-            raise RuntimeError(
-                "CPU host is required to be the last element of the device list if provided."
-            )
-
         # CPU is required for executing shape functions
         if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
             devs.append(tvm.cpu())

python/tvm/testing/utils.py

Lines changed: 20 additions & 0 deletions
@@ -832,6 +832,16 @@ def _any_gpu_exists():
     )
 
 
+def _multi_gpu_exists():
+    return (
+        (tvm.cuda(0).exist and tvm.cuda(1).exist)
+        or (tvm.rocm(0).exist and tvm.rocm(1).exist)
+        or (tvm.opencl(0).exist and tvm.opencl(1).exist)
+        or (tvm.metal(0).exist and tvm.metal(1).exist)
+        or (tvm.vulkan(0).exist and tvm.vulkan(1).exist)
+    )
+
+
 # Mark a test as requiring llvm to run
 requires_llvm = Feature(
     "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
@@ -847,6 +857,16 @@ def _any_gpu_exists():
 # :py:func:`tvm.testing.requires_gpu`.
 uses_gpu = requires_gpu(support_required="optional")
 
+# Mark a test as requiring multiple GPUs to run.
+requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists)
+
+# Mark to differentiate tests that use multiple GPUs in some capacity.
+#
+# These tests will be run on test nodes with multiple GPUs.
+# To mark a test that must have multiple GPUs present to run, use
+# :py:func:`tvm.testing.requires_multi_gpu`.
+uses_multi_gpu = requires_multi_gpu(support_required="optional")
+
 # Mark a test as requiring the x86 Architecture to run.
 requires_x86 = Feature(
     "x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64"

src/ir/module.cc

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) {
 
 IRModule IRModuleNode::ShallowCopy() {
   return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
-                  this->attrs);
+                  this->attrs, this->global_infos);
 }
 
 std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(

src/relax/transform/call_tir_rewrite.cc

Lines changed: 31 additions & 8 deletions
@@ -18,7 +18,8 @@
  */
 /*!
  * \file src/relax/transform/call_tir_rewrite.cc
- * \brief Perform explicit tensor allocation for call_tir.
+ * \brief Perform explicit tensor allocation for call_tir,
+ * call_tir_inplace, and call_dps_packed.
  */
 #include <tvm/relax/attrs/op.h>
 #include <tvm/relax/expr_functor.h>
@@ -28,6 +29,7 @@
 #include <tvm/tir/op.h>
 
 #include "../../relay/transforms/pattern_utils.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relax {
@@ -43,6 +45,19 @@ namespace relax {
 
 class CallTIRMutator : public ExprMutator {
  public:
+  explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {}
+
+  IRModule Run() {
+    for (const auto& [gv, func] : mod_->functions) {
+      if (func->IsInstance<FunctionNode>()) {
+        auto updated_func = Downcast<Function>(this->VisitExpr(func));
+        builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
+      }
+    }
+    return builder_->GetContextIRModule();
+  }
+
+ private:
   using ExprMutator::VisitExpr_;
   Expr VisitExpr_(const CallNode* call) override {
     // post-order mutation
@@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator {
       const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
       ICHECK(tensor_sinfo->shape.defined())
           << "the TensorStructInfo shape of call_tir has not populated";
+      int dev_index = 0;
+      if (tensor_sinfo->vdevice.defined()) {
+        dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
+      }
       if (!is_inplace) {
         outs.push_back(
-            builder_->Emit(Call(alloc_tensor_op,  //
+            builder_->Emit(Call(alloc_tensor_op,
                                 {Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
-                                 DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)},  //
+                                 DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(dev_index)},
                                 Attrs()),
                            "alloc"));
       } else {
@@ -150,16 +169,20 @@ class CallTIRMutator : public ExprMutator {
 
     return GetRef<Expr>(call);
   }
-};
 
-Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
+  /*! \brief The context IRModule. */
+  IRModule mod_;
+};
 
 namespace transform {
 
 Pass CallTIRRewrite() {
-  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
-      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
-  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); };
+  return CreateModulePass(/*pass_function=*/pass_func,
+                          /*opt_level=*/0,
+                          /*pass_name=*/"CallTIRRewrite",
+                          /*required=*/{});
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
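
CallTIRRewrite is now a module-level pass, and the alloc_tensor it emits carries the device index looked up from the output's vdevice (index 0 when nothing is annotated). A sketch of driving it directly from Python; the Before module and the hand-picked pass order are illustrative assumptions (relax.build's normal pipeline arranges this lowering for you):

import tvm
from tvm import relax
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class Before:
    @R.function
    def main(x: R.Tensor((4,), "float32")) -> R.Tensor((4,), "float32"):
        gv = R.add(x, x)
        return gv


# Lower just far enough to expose the explicit allocations that the
# (now module-level) CallTIRRewrite pass inserts for each call_tir output.
seq = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),
        relax.transform.ToNonDataflow(),
        relax.transform.RemovePurityChecking(),
        relax.transform.CallTIRRewrite(),
    ]
)
lowered = seq(Before)
lowered.show()  # each call_tir output becomes alloc_tensor(shape, dtype, device_index)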

src/relax/transform/legalize_ops.cc

Lines changed: 42 additions & 0 deletions
@@ -26,6 +26,7 @@
 #include <tvm/relax/analysis.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/op_attr_types.h>
+#include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 
 namespace tvm {
@@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator {
         builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
       }
     }
+    // Fill the "kTarget" attribute of PrimFunc
+    for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
+      const tir::PrimFuncNode* prim_func;
+      if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
+        auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
+        builder_->UpdateFunction(gv, f);
+      }
+    }
     return builder_->GetContextIRModule();
   }
 
@@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator {
     return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
   }
 
+  Target GetTarget(const Array<StructInfo>& sinfos) {
+    for (auto sinfo : sinfos) {
+      if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
+        if (tinfo->vdevice.defined()) {
+          auto vdevice = tinfo->vdevice.value();
+          if (vdevice->target.defined()) {
+            return vdevice->target;
+          }
+        }
+      } else if (const auto* tup_sinfo = sinfo.as<TupleStructInfoNode>()) {
+        return GetTarget(tup_sinfo->fields);
+      }
+    }
+    return Target();
+  }
+
+  void SaveTarget(const Expr& expr) {
+    if (expr->IsInstance<CallNode>()) {
+      auto call = Downcast<Call>(expr);
+      auto target = GetTarget(call->sinfo_args);
+      const GlobalVarNode* gvar_node;
+      if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
+        this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
+      }
+    }
+  }
+
   Expr VisitExpr_(const CallNode* call) final {
     Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
     static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator {
       builder_->BeginBindingBlock();
     }
     Expr legalized = legalization_func(builder_, visited_call);
+
+    // Save the expected target info. into tmap_
+    SaveTarget(legalized);
+
     legalized = builder_->Normalize(legalized);
 
     BindingBlock prologue = builder_->EndBlock();
@@ -196,6 +236,8 @@ class LegalizeMutator : public ExprMutator {
   IRModule mod_;
   /*! \brief The customized legalization function map. */
   Map<String, PackedFunc> cmap_;
+  /*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
+  Map<GlobalVar, Target> tmap_;
   /*!
    * \brief A boolean value indicating if to print warnings for CallNode whose op's
    * legalization function is not registered.
