Skip to content

Commit cae2680

Browse files
[Relay] [Virtual Device] Store function result virtual device in virtual_device_ field (#9848)
* VStore function result virtual devices in virtual_device_ field * Address Mark's 'mega nit' * Promote function result virtual device to first class * Add kVirtualDevice * move kVirtualDevice * Fix annotation test * Progress on parsing & printing * Fix printing of virtual device attribute * flake
1 parent 6591cba commit cae2680

File tree

8 files changed

+49
-17
lines changed

8 files changed

+49
-17
lines changed

include/tvm/ir/expr.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,11 @@ class RelayExprNode : public BaseExprNode {
178178
*
179179
* For expressions that have the function type, the virtual device describes where the result of
180180
* the call to the function or closure is stored (instead of where the function itself is stored).
181+
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
182+
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
183+
* device of body. For more details, see the documentation in
184+
* src/relay/transforms/device_planner.cc.
185+
*
181186
* The VirtualDevice's Target field describes how the body of the function should be compiled.
182187
*
183188
* Set to VirtualDevice::FullyUnconstrained by default.
@@ -190,6 +195,13 @@ class RelayExprNode : public BaseExprNode {
190195
/*!
191196
* \return The virtual device (VirtualDevice).
192197
* If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained().
198+
* Note that for function types, the virtual device is the device where the result of a
199+
* call to the function is stored, not where the function itself lives.
200+
* For example, the virtual device of f = fn(x) { body } is the virtual device of f(y), not where
201+
* the function itself is stored. Note that f(y)'s virtual device will be the same as the virtual
202+
* device of body.
203+
*
204+
* See the documentation of the virtual_device_ field (above) for more details.
193205
*/
194206
VirtualDevice virtual_device() const;
195207

include/tvm/ir/function.h

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -200,16 +200,6 @@ constexpr const char* kGlobalSymbol = "global_symbol";
200200
*/
201201
constexpr const char* kParamVirtualDevice = "param_virtual_devices";
202202

203-
/*!
204-
* \brief The \p VirtualDevice which will hold the function result.
205-
*
206-
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
207-
* may be included as an annotation on user programs.
208-
*
209-
* Type: VirtualDevice
210-
*/
211-
constexpr const char* kResultVirtualDevice = "result_virtual_device";
212-
213203
} // namespace attr
214204
} // namespace tvm
215205
#endif // TVM_IR_FUNCTION_H_

include/tvm/target/virtual_device.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ using MemoryScope = String;
162162
*
163163
* These operations are needed during device planning.
164164
*/
165+
165166
class VirtualDeviceNode : public AttrsNode<VirtualDeviceNode> {
166167
private:
167168
/*!
@@ -361,6 +362,13 @@ class VirtualDeviceCache {
361362
std::unordered_set<VirtualDevice, StructuralHash, StructuralEqual> cache_;
362363
};
363364

365+
/*! brief The attribute key for the virtual device. This key will be promoted to first class on
366+
* functions. For use in the parser and printer only.
367+
*
368+
* Type: VirtualDevice
369+
*/
370+
constexpr const char* kVirtualDevice = "result_virtual_device";
371+
364372
} // namespace tvm
365373

366374
#endif // TVM_TARGET_VIRTUAL_DEVICE_H_

src/parser/parser.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/runtime/logging.h>
3232
#include <tvm/runtime/object.h>
3333
#include <tvm/runtime/registry.h>
34+
#include <tvm/target/virtual_device.h>
3435

3536
#include <fstream>
3637

@@ -1137,7 +1138,19 @@ class Parser {
11371138

11381139
// TODO(@jroesch): attributes should never be null, they should always be empty.
11391140
if (raw_attrs.size()) {
1140-
return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1141+
// Promote kVirtualDevice to first-class
1142+
if (raw_attrs.count(kVirtualDevice)) {
1143+
ObjectRef vid = raw_attrs.at(kVirtualDevice);
1144+
ICHECK(vid.as<VirtualDeviceNode>())
1145+
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
1146+
<< vid->GetTypeKey();
1147+
raw_attrs.erase(kVirtualDevice);
1148+
Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1149+
func->virtual_device_ = vid;
1150+
return func;
1151+
} else {
1152+
return relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));
1153+
}
11411154
} else {
11421155
return relay::Function(params, body, ret_type, generics, tvm::DictAttrs());
11431156
}

src/printer/relay_text_printer.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,10 +445,16 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
445445
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
446446
params.push_back(d);
447447
}
448+
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
449+
Doc vid_doc;
450+
vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
451+
params.push_back(vid_doc);
452+
}
448453
doc << Doc::Concat(params) << ") ";
449454
if (fn->ret_type.defined()) {
450455
doc << "-> " << Print(fn->ret_type) << " ";
451456
}
457+
452458
doc << PrintBody(fn->body);
453459
return doc;
454460
}

src/relay/backend/graph_plan_memory.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
279279
smap.Set(GetRef<Expr>(kv.first), storage_info);
280280
}
281281
// Either all or none of the nodes should be annotated.
282+
VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
283+
<< std::endl;
282284
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
283285
LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
284286
<< "expressions are assigned with virtual device types. Either all "

src/relay/op/memory/on_device.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,11 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
144144

145145
Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
146146
VirtualDevice result_virtual_device) {
147-
return WithAttrs(std::move(function),
148-
{{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)},
149-
{tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}});
147+
auto func = WithAttr(
148+
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
149+
tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices));
150+
VLOG(1) << "Annotated func: " << PrettyPrint(func);
151+
return func;
150152
}
151153

152154
TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);
@@ -166,8 +168,7 @@ Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_vir
166168
}
167169

168170
VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
169-
auto opt_virtual_device = function_node->GetAttr<VirtualDevice>(tvm::attr::kResultVirtualDevice);
170-
return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained());
171+
return function_node->virtual_device();
171172
}
172173

173174
VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {

tests/python/relay/op/annotation/test_annotation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_function_on_device():
7070
assert len(func.attrs["param_virtual_devices"]) == 2
7171
assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU
7272
assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA
73-
assert func.attrs["result_virtual_device"].device_type_int == 2 # ie KDLCUDA
73+
assert func.virtual_device_.device_type_int == 2 # ie KDLCUDA
7474

7575

7676
if __name__ == "__main__":

0 commit comments

Comments
 (0)