Skip to content

Commit 8c1361e

Browse files
tqchencbalint13
andauthored
[BugFix][UMA] Protect target registration (apache#13624) (#4)
This PR address fixes for UMA target registration. * Fix the doc issue apache#13304 * Continues stalled PR apache#12731 Changes: * Incorporates all proposed fixes from mentioned [PR apache#12731](apache#12731) * Address test case concerns and discussions from [PR apache#12731](apache#12731) * **NEW:** Already exiting target cannot be created, explicit error on this. * **NEW:** Attributes having special/reserved scope cannot be created explicitly. It also address proper test cases for all the above. Signed-off-by: tqchen <[email protected]> Co-authored-by: Balint Cristian <[email protected]>
1 parent 1d98634 commit 8c1361e

File tree

4 files changed

+42
-16
lines changed

4 files changed

+42
-16
lines changed

gallery/tutorial/uma.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
#
5858

5959
######################################################################
60-
# .. image:: https://raw.githubusercontent.com/apache/tvm-site/main/images/tutorial/uma_vanilla_block_diagram.png
60+
# .. image:: https://raw.githubusercontent.com/tlc-pack/web-data/main/images/tutorial/uma_vanilla_block_diagram.png
6161
# :width: 100%
6262
# :alt: A block diagram of Vanilla
6363
#

python/tvm/relay/backend/contrib/uma/backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,12 @@ def register(self) -> None:
278278
"""
279279
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
280280

281-
for name, attr in self._target_attrs:
281+
for name, attr in self._target_attrs.items():
282282
if attr is None:
283283
raise ValueError("Target attribute None is not supported.")
284-
285-
if registration_func(self.target_name, self._target_attrs):
284+
# skip if target is already registered
285+
if self.target_name not in tvm.target.Target.list_kinds():
286+
registration_func(self.target_name, self._target_attrs)
286287
self._relay_to_relay.register()
287288
self._relay_to_tir.register()
288289
self._tir_to_runtime.register()

src/relay/backend/contrib/uma/targets.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,23 @@ namespace tvm {
3131
namespace relay {
3232
namespace contrib {
3333
namespace uma {
34-
tvm::transform::Pass RelayToTIR(String target_name);
34+
transform::Pass RelayToTIR(String target_name);
3535
runtime::Module TIRToRuntime(IRModule mod, Target target);
3636
} // namespace uma
3737
} // namespace contrib
3838
} // namespace relay
3939

4040
TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
4141
.set_body_typed([](String target_name, Map<String, ObjectRef> attr_options) -> bool {
42-
// @todo(cgerum): We probably should get rid of target.register rather sooner than later
43-
// And use a proper registry for uma backends
44-
for (const String registered_target_name : ::tvm::TargetKindRegEntry::ListTargetKinds()) {
42+
// create only new target and init only once
43+
for (const String registered_target_name : TargetKindRegEntry::ListTargetKinds()) {
4544
if (registered_target_name == target_name) {
46-
return false;
45+
LOG(FATAL) << "TVM UMA Error: Target is already registered: " << target_name;
4746
}
4847
}
4948

5049
auto target_kind =
51-
::tvm::TargetKindRegEntry::RegisterOrGet(target_name)
50+
TargetKindRegEntry::RegisterOrGet(target_name)
5251
.set_name()
5352
.set_default_device_type(kDLCPU)
5453
.add_attr_option<Array<String>>("keys")
@@ -58,20 +57,27 @@ TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
5857
.add_attr_option<Array<String>>("libs")
5958
.add_attr_option<Target>("host")
6059
.add_attr_option<Integer>("from_device")
61-
.set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR,
60+
.set_attr<FTVMRelayToTIR>(attr::kRelayToTIR,
6261
relay::contrib::uma::RelayToTIR(target_name))
6362
.set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime);
6463

64+
// target kind attrs inventory
65+
auto kind = TargetKind::Get(target_name).value();
66+
auto list_attrs = TargetKindRegEntry::ListTargetKindOptions(kind);
67+
6568
for (auto& attr_option : attr_options) {
6669
auto option_name = attr_option.first;
6770
auto default_value = attr_option.second;
71+
if (list_attrs.find(option_name) != list_attrs.end()) {
72+
LOG(FATAL) << "TVM UMA Error: Attribute is already registered: " << option_name;
73+
}
6874
if (default_value->IsInstance<StringObj>()) {
6975
target_kind.add_attr_option<String>(option_name, Downcast<String>(default_value));
7076
} else if (default_value->IsInstance<IntImmNode>()) {
7177
target_kind.add_attr_option<Integer>(option_name, Downcast<Integer>(default_value));
7278
} else {
73-
LOG(FATAL) << "Only String, Integer, or Bool are supported. Given attribute option type: "
74-
<< attr_option.second->GetTypeKey();
79+
LOG(FATAL) << "TypeError: Only String, Integer, or Bool are supported. "
80+
<< "Given attribute option type: " << attr_option.second->GetTypeKey();
7581
}
7682
}
7783
return true;

tests/python/contrib/test_uma/test_target.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,23 +63,42 @@ def test_uma_target(target_name, target_attrs, target_args):
6363
[
6464
("float_attr", 3.14),
6565
("none_attr", None),
66+
("model", "my_model"),
6667
],
6768
)
6869
def test_invalid_attr_option(attr_name: str, target_attr: Union[str, int, bool, float, None]):
70+
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
6971
if target_attr is None:
7072
# None cannot be caught as TVMError, as it causes a SIGKILL, therefore it must be prevented to be
7173
# entered into relay.backend.contrib.uma.RegisterTarget at Python level.
72-
with pytest.raises(ValueError):
74+
with pytest.raises(ValueError, match=r"Target attribute None is not supported."):
7375
uma_backend = VanillaAcceleratorBackend()
7476
uma_backend._target_attrs = {attr_name: target_attr}
7577
uma_backend.register()
78+
elif "model" in attr_name:
79+
target_name = f"{attr_name}_{target_attr}"
80+
target_attr = {attr_name: target_attr}
81+
with pytest.raises(tvm.TVMError, match=r"Attribute is already registered: .*"):
82+
registration_func(target_name, target_attr)
7683
else:
77-
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
7884
target_name = f"{attr_name}_{target_attr}"
7985
target_attr = {attr_name: target_attr}
80-
with pytest.raises(tvm.TVMError, match=r"Only String, Integer, or Bool are supported. .*"):
86+
with pytest.raises(TypeError, match=r"Only String, Integer, or Bool are supported. .*"):
8187
registration_func(target_name, target_attr)
8288

8389

90+
@pytest.mark.parametrize(
91+
"target_name",
92+
[
93+
"llvm",
94+
"c",
95+
],
96+
)
97+
def test_target_duplication(target_name: str):
98+
with pytest.raises(tvm.TVMError, match=r"TVM UMA Error: Target is already registered: .*"):
99+
registration_func = tvm.get_global_func("relay.backend.contrib.uma.RegisterTarget")
100+
registration_func(target_name, {})
101+
102+
84103
if __name__ == "__main__":
85104
tvm.testing.main()

0 commit comments

Comments
 (0)