From 9027c97f16853d48fa100ffa2b17cea502949a3f Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 16 Dec 2021 12:39:16 -0800 Subject: [PATCH 1/2] [Relay] s/SEScope/VirtualDevice/g Nobody liked 'SEScope', and 'DeviceMcDeviceFace' is too verbose, so it seems 'VirtualDevice' has the popular vote. --- include/tvm/ir/expr.h | 14 +- include/tvm/ir/function.h | 12 +- include/tvm/relay/attrs/device_copy.h | 10 +- include/tvm/relay/attrs/memory.h | 6 +- include/tvm/relay/attrs/on_device.h | 29 +- include/tvm/relay/expr.h | 20 +- include/tvm/relay/function.h | 2 +- include/tvm/relay/transform.h | 14 +- include/tvm/target/compilation_config.h | 37 +- .../target/{se_scope.h => virtual_device.h} | 162 +++---- python/tvm/relay/op/annotation/annotation.py | 12 +- python/tvm/relay/op/tensor.py | 10 +- python/tvm/relay/transform/transform.py | 10 +- python/tvm/target/__init__.py | 2 +- .../target/{se_scope.py => virtual_device.py} | 9 +- src/printer/relay_text_printer.cc | 10 +- src/relay/backend/aot_executor_codegen.cc | 30 +- src/relay/backend/build_module.cc | 2 +- src/relay/backend/graph_executor_codegen.cc | 8 +- src/relay/backend/graph_plan_memory.cc | 36 +- src/relay/backend/interpreter.cc | 8 +- src/relay/backend/te_compiler.cc | 58 +-- src/relay/backend/te_compiler.h | 5 +- src/relay/backend/utils.cc | 26 +- src/relay/backend/utils.h | 8 +- src/relay/backend/vm/compiler.cc | 129 ++--- src/relay/backend/vm/compiler.h | 8 +- src/relay/backend/vm/lambda_lift.cc | 9 +- src/relay/ir/expr.cc | 46 +- src/relay/ir/expr_functor.cc | 14 +- src/relay/ir/function.cc | 4 +- src/relay/op/memory/device_copy.cc | 26 +- src/relay/op/memory/device_copy.h | 29 +- src/relay/op/memory/memory.cc | 4 +- src/relay/op/memory/memory.h | 4 +- src/relay/op/memory/on_device.cc | 74 +-- src/relay/op/memory/on_device.h | 67 +-- src/relay/transforms/device_aware_visitors.cc | 82 ++-- src/relay/transforms/device_aware_visitors.h | 104 +++-- src/relay/transforms/device_domains.cc | 130 +++--- 
src/relay/transforms/device_domains.h | 86 ++-- src/relay/transforms/device_planner.cc | 343 +++++++------- src/relay/transforms/fold_constant.cc | 22 +- src/relay/transforms/memory_alloc.cc | 89 ++-- src/relay/transforms/to_a_normal_form.cc | 16 +- src/target/compilation_config.cc | 49 +- src/target/{se_scope.cc => virtual_device.cc} | 54 ++- src/tir/analysis/device_constraint_utils.cc | 107 ++--- src/tir/analysis/device_constraint_utils.h | 28 +- tests/cpp/relay/op/memory/on_device_test.cc | 28 +- .../relay/transforms/device_domains_test.cc | 12 +- tests/cpp/target/compilation_config_test.cc | 66 +-- tests/cpp/target/se_scope_test.cc | 119 ----- tests/cpp/target/virtual_device_test.cc | 121 +++++ .../relay/op/annotation/test_annotation.py | 22 +- tests/python/relay/op/test_tensor.py | 20 +- .../relay/test_pass_dead_code_elimination.py | 12 +- tests/python/relay/test_pass_plan_devices.py | 442 +++++++++--------- ...est_se_scope.py => test_virtual_device.py} | 32 +- 59 files changed, 1514 insertions(+), 1424 deletions(-) rename include/tvm/target/{se_scope.h => virtual_device.h} (65%) rename python/tvm/target/{se_scope.py => virtual_device.py} (72%) rename src/target/{se_scope.cc => virtual_device.cc} (71%) delete mode 100644 tests/cpp/target/se_scope_test.cc create mode 100644 tests/cpp/target/virtual_device_test.cc rename tests/python/target/{test_se_scope.py => test_virtual_device.py} (54%) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index a6e5c8de73a7..8937bb7b1016 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -39,8 +39,8 @@ namespace tvm { using tvm::runtime::String; -// Forward-declare SEScope to avoid circular imports. -class SEScope; +// Forward-declare VirtualDevice to avoid circular imports. +class VirtualDevice; /*! * \brief Base type of all the expressions. @@ -169,7 +169,7 @@ class RelayExprNode : public BaseExprNode { inline const TTypeNode* type_as() const; /*! 
- * \brief The virtual device (SEScope) for this node (the result of device planning). + * \brief The virtual device (VirtualDevice) for this node (the result of device planning). * For first-order expressions (non functions), this describes where the result of evaluating the * expression should be stored. Note that currently, all composite first-order values (tuples, * references, ADTs) must be stored on the same virtual device. This means that it is not possible @@ -178,7 +178,7 @@ class RelayExprNode : public BaseExprNode { * * For expressions that have the function type, the virtual device describes where the result of * the call to the function or closure is stored (instead of where the function itself is stored). - * The SEScope's Target field describes how the body of the function should be compiled. + * The VirtualDevice's Target field describes how the body of the function should be compiled. * * \note Unfortunately, the type of virtual_device_ needs to be ObjectRef to avoid a circular * import. @@ -186,10 +186,10 @@ class RelayExprNode : public BaseExprNode { mutable ObjectRef virtual_device_; /*! - * \return The virtual device (SEScope). - * If the virtual device is not defined, returns SEScope::FullyUnconstrained(). + * \return The virtual device (VirtualDevice). + * If the virtual device is not defined, returns VirtualDevice::FullyUnconstrained(). */ - SEScope virtual_device() const; + VirtualDevice virtual_device() const; static constexpr const char* _type_key = "RelayExpr"; static constexpr const uint32_t _type_child_slots = 22; diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index e466cde097ac..051c05dd3d01 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -191,24 +191,24 @@ constexpr const char* kTarget = "target"; constexpr const char* kGlobalSymbol = "global_symbol"; /*! - * \brief The SEScope which will hold each of the functions parameters. 
+ * \brief The \p VirtualDevice which will hold each of the function's parameters.
*/ struct DeviceCopyAttrs : public tvm::AttrsNode { - SEScope src_se_scope = SEScope::FullyUnconstrained(); - SEScope dst_se_scope = SEScope::FullyUnconstrained(); + VirtualDevice src_virtual_device = VirtualDevice::FullyUnconstrained(); + VirtualDevice dst_virtual_device = VirtualDevice::FullyUnconstrained(); TVM_DECLARE_ATTRS(DeviceCopyAttrs, "relay.attrs.DeviceCopyAttrs") { - TVM_ATTR_FIELD(src_se_scope) + TVM_ATTR_FIELD(src_virtual_device) .describe("The (virtual) device and scope where the op copies data from."); - TVM_ATTR_FIELD(dst_se_scope) + TVM_ATTR_FIELD(dst_virtual_device) .describe("The (virtual) device and scope where the op copies data to."); } }; diff --git a/include/tvm/relay/attrs/memory.h b/include/tvm/relay/attrs/memory.h index 952d4affc584..07d6cc7e271e 100644 --- a/include/tvm/relay/attrs/memory.h +++ b/include/tvm/relay/attrs/memory.h @@ -26,7 +26,7 @@ #include #include -#include +#include #include #include @@ -43,13 +43,13 @@ Expr ToTupleType(const Type& t, const std::vector& exprs); */ struct AllocStorageAttrs : public tvm::AttrsNode { DataType dtype; - SEScope se_scope = SEScope::FullyUnconstrained(); + VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained(); TVM_DECLARE_ATTRS(AllocStorageAttrs, "relay.attrs.AllocStorageAttrs") { TVM_ATTR_FIELD(dtype) .describe("The dtype of the tensor to allocate.") .set_default(DataType::Float(32, 1)); - TVM_ATTR_FIELD(se_scope).describe("The SEScope on which to allocate memory."); + TVM_ATTR_FIELD(virtual_device).describe("The virtual device on which to allocate memory."); } }; diff --git a/include/tvm/relay/attrs/on_device.h b/include/tvm/relay/attrs/on_device.h index 0931865fa88e..3facc3a597f1 100644 --- a/include/tvm/relay/attrs/on_device.h +++ b/include/tvm/relay/attrs/on_device.h @@ -25,7 +25,7 @@ #define TVM_RELAY_ATTRS_ON_DEVICE_H_ #include -#include +#include #include @@ -37,42 +37,43 @@ namespace relay { * * The Relay call: * \code - * on_device(sub_expr, se_scope=S) + * 
on_device(sub_expr, virtual_device=S) * \endcode - * constrains \p sub_expr to execute and store its result on the \p SEScope \p S. + * constrains \p sub_expr to execute and store its result on the \p VirtualDevice \p S. * However the annotation itself may appear in an expression to be executed and stored on a - * different \p SEScope. If so the compiler will automatically insert a "device_copy" call to - * mediate the transition between \p SEScopes. + * different \p VirtualDevice. If so the compiler will automatically insert a "device_copy" call to + * mediate the transition between \p VirtualDevices. * * E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then: * \code - * multiply(on_device(add(%x, %y), se_scope=GPU), %z) + * multiply(on_device(add(%x, %y), virtual_device=GPU), %z) * \endcode * indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU. * The compiler will rewrite this to: * \code - * multiply(device_copy(add(%x, %y), src_se_scope=GPU, dst_se_scope=CPU), %z) + * multiply(device_copy(add(%x, %y), src_virtual_device=GPU, dst_virtual_device=CPU), %z) * \endcode * * The \p constraint_body (default true) and \p constraint_result (default false) fields can be - * used by passes for finer-grained control over how the \p SEScope constraint should be applied. + * used by passes for finer-grained control over how the \p VirtualDevice constraint should be + * applied. */ struct OnDeviceAttrs : public tvm::AttrsNode { /*! - * \brief The \p SEScope to constraint to apply to the body, result, or both body and result + * \brief The \p VirtualDevice to constraint to apply to the body, result, or both body and result * of the "on_device" call. */ - SEScope se_scope = SEScope::FullyUnconstrained(); + VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained(); /*! * \brief If false (the default), the result of the "on_device" call is not constrained to be - * \p se_scope. + * \p virtual_device. 
*/ bool constrain_result = false; /*! * \brief If true (the default), the body of the "on_device" call is constrained to be \p - * se_scope. + * virtual_device. */ bool constrain_body = true; @@ -87,9 +88,9 @@ struct OnDeviceAttrs : public tvm::AttrsNode { bool is_normal() const { return !constrain_result && constrain_body; } TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") { - TVM_ATTR_FIELD(se_scope) + TVM_ATTR_FIELD(virtual_device) .describe("The (virtual) device to constrain to.") - .set_default(SEScope::FullyUnconstrained()); + .set_default(VirtualDevice::FullyUnconstrained()); TVM_ATTR_FIELD(constrain_result) .describe("Whether the constraint applies to the overall expression") .set_default(false); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 8bec72490ab1..04dd9223719e 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -158,7 +158,7 @@ class Tuple : public Expr { * ret_tuple->span = tuple->span. */ Tuple WithFields(Tuple tuple, Optional> opt_fields = Optional>(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -264,7 +264,7 @@ class Var : public Expr { */ Var WithFields(Var var, Optional opt_vid = Optional(), Optional opt_type_annotation = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -391,7 +391,7 @@ Call WithFields(Call call, Optional opt_op = Optional(), Optional> opt_args = Optional>(), Optional opt_attrs = Optional(), Optional> opt_type_args = Optional>(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! 
@@ -487,7 +487,7 @@ class Let : public Expr { Let WithFields(Let let, Optional opt_var = Optional(), Optional opt_value = Optional(), Optional opt_body = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! @@ -574,7 +574,7 @@ class If : public Expr { If WithFields(If if_expr, Optional opt_cond = Optional(), Optional opt_true_branch = Optional(), Optional opt_false_branch = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get index-th field out of a tuple. */ @@ -640,7 +640,7 @@ class TupleGetItem : public Expr { */ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple = Optional(), Optional opt_index = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Create a new Reference out of initial value. */ @@ -701,7 +701,7 @@ class RefCreate : public Expr { * ret_ref_create->value = opt_value.value()). */ RefCreate WithFields(RefCreate ref_create, Optional opt_value = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Get value out of Reference. */ @@ -761,7 +761,7 @@ class RefRead : public Expr { * if opt_ref.value() != ref_read->ref, then ret_ref_read->ref = opt_ref.value()). */ RefRead WithFields(RefRead ref_read, Optional opt_ref = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. 
*/ @@ -829,7 +829,7 @@ class RefWrite : public Expr { */ RefWrite WithFields(RefWrite ref_write, Optional opt_ref = Optional(), Optional opt_value = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /*! diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h index 1b8ed4443456..d9bf7acaa037 100644 --- a/include/tvm/relay/function.h +++ b/include/tvm/relay/function.h @@ -148,7 +148,7 @@ Function WithFields(Function function, Optional> opt_params = Optiona Optional opt_ret_type = Optional(), Optional> opt_ty_params = Optional>(), Optional opt_attrs = Optional(), - Optional opt_virtual_device = Optional(), + Optional opt_virtual_device = Optional(), Optional opt_span = Optional()); /* diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 2d6cdeaa8ca1..dfc49cb5e466 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -31,8 +31,8 @@ #include #include #include -#include #include +#include #include @@ -449,22 +449,22 @@ TVM_DLL Pass RelayToTIRTargetHook(); * \brief A pass for manifesting explicit memory allocations and rewriting * specific dialects. * - * \param cpu_se_scope SEScope for computations and data which must reside on a CPU, such as - * shapes and shape functions. + * \param cpu_virtual_device VirtualDevice for computations and data which must reside on a CPU, + * such as shapes and shape functions. * * \return The pass. */ -TVM_DLL Pass ManifestAlloc(SEScope cpu_se_scope); +TVM_DLL Pass ManifestAlloc(VirtualDevice cpu_virtual_device); /*! - * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p SEScope on which - * every Relay sub-expression should run and the result stored. Captures the result of that + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the \p VirtualDevice on + * which every Relay sub-expression should run and the result stored. 
Captures the result of that * analysis using new "on_device" and "device_copy" CallNodes. * * See tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} * for help recovering the device for an arbitrary sub-expression in downstream transformations. * - * \param config Describes the targets and default \p SEScope for all primitive operators and + * \param config Describes the targets and default \p VirtualDevice for all primitive operators and * host sub-expressions. * * \return The pass. diff --git a/include/tvm/target/compilation_config.h b/include/tvm/target/compilation_config.h index 45ff774f1742..1c47a0f806a3 100644 --- a/include/tvm/target/compilation_config.h +++ b/include/tvm/target/compilation_config.h @@ -26,12 +26,12 @@ #ifndef TVM_TARGET_COMPILATION_CONFIG_H_ #define TVM_TARGET_COMPILATION_CONFIG_H_ -#include +#include namespace tvm { /*! - * \brief Gathers the \p Targets and distinguished \p SEScopes in canonical form needed to + * \brief Gathers the \p Targets and distinguished \p VirtualDevices in canonical form needed to * compile a Relay module. Centralizes any setup and validation logic needed to transition * from configuration options conveyed implicitly (eg in \p PassContexts) or explicitly * (eg a a list of \p Targets) to the configuration. @@ -82,13 +82,13 @@ class CompilationConfigNode : public Object { Array primitive_targets; /*! - * \brief \p SEScope for primitive operators which are not otherwise constrained to a particular - * device. + * \brief \p VirtualDevice for primitive operators which are not otherwise constrained to a + * particular device. */ - SEScope default_primitive_se_scope = SEScope::FullyUnconstrained(); + VirtualDevice default_primitive_virtual_device = VirtualDevice::FullyUnconstrained(); - /*! \brief SEScope for the host. */ - SEScope host_se_scope = SEScope::FullyUnconstrained(); + /*! \brief VirtualDevice for the host. 
*/ + VirtualDevice host_virtual_device = VirtualDevice::FullyUnconstrained(); /*! * \brief If defined then compile and/or run in 'homogenous execution mode'. In this mode all @@ -104,24 +104,25 @@ class CompilationConfigNode : public Object { void VisitAttrs(AttrVisitor* v); /*! - * \brief Returns a \p SEScope agreeing with \p se_scope on all its constrained fields, however: + * \brief Returns a \p VirtualDevice agreeing with \p virtual_device on all its constrained + * fields, however: * - If the target is null then it is filled in from the known available primitive targets by * matching on device type. Fails if no such target is known. - * - The returned object is unique for the field values w.r.t. all other \p SEScopes returned - * by this method. + * - The returned object is unique for the field values w.r.t. all other \p VirtualDevices + * returned by this method. * - * We call the result the 'canonical' \p SEScope. Two canonical \p SEScopes are structurally - * equal if and only if they are pointer equal. + * We call the result the 'canonical' \p VirtualDevice. Two canonical \p VirtualDevices are + * structurally equal if and only if they are pointer equal. */ - SEScope CanonicalSEScope(const SEScope& se_scope) const; + VirtualDevice CanonicalVirtualDevice(const VirtualDevice& virtual_device) const; static constexpr const char* _type_key = "CompilationConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(CompilationConfigNode, Object) private: /*! - * \brief Establishes the default \p SEScope for primitives and the \p SEScope for the host - * given: + * \brief Establishes the default \p VirtualDevice for primitives and the \p VirtualDevice for the + * host given: * - the vector of available primitive \p Targets. * - any host \p Target. * - any "relay.fallback_device_type" attribute on \p pass_ctx. 
@@ -134,7 +135,7 @@ class CompilationConfigNode : public Object { * CAUTION: Recreated the primitive_targets so that they all have the given/constructed * host_target as their host (cf CheckAndUpdateHostConsistency). */ - void EstablishDefaultSEScopes(const transform::PassContext& pass_ctx); + void EstablishDefaultVirtualDevices(const transform::PassContext& pass_ctx); /*! * \brief Returns a freshly constructed \p Target to represent \p device_type. @@ -147,9 +148,9 @@ class CompilationConfigNode : public Object { Target FindPrimitiveTargetOrFail(DLDeviceType device_type) const; /*! - * \brief A cache of constructed SEScopes. + * \brief A cache of constructed virtual devices. */ - mutable SEScopeCache se_scope_cache_; + mutable VirtualDeviceCache virtual_device_cache_; friend class CompilationConfig; }; diff --git a/include/tvm/target/se_scope.h b/include/tvm/target/virtual_device.h similarity index 65% rename from include/tvm/target/se_scope.h rename to include/tvm/target/virtual_device.h index 314bf054d7ea..07011ea412a3 100644 --- a/include/tvm/target/se_scope.h +++ b/include/tvm/target/virtual_device.h @@ -18,12 +18,13 @@ */ /*! - * \file tvm/target/se_scope.h - * \brief A compile time representation for a Storage or Execution Scope. + * \file tvm/target/virtual_device.h + * \brief A compile time representation for where data is to be stored at runtime, and how to + * compile code to compute it. */ -#ifndef TVM_TARGET_SE_SCOPE_H_ -#define TVM_TARGET_SE_SCOPE_H_ +#ifndef TVM_TARGET_VIRTUAL_DEVICE_H_ +#define TVM_TARGET_VIRTUAL_DEVICE_H_ #include #include @@ -44,9 +45,13 @@ namespace tvm { using MemoryScope = String; /*! - * \brief Describes at compile time where data is to be stored down to the device and memory - * scope level, or where execution is to take place, down to the device level. It is a quadruple of: - * - A \p device_type (\p DLDeviceType). May be kInvalidDeviceType if unconstrained. 
+ * \brief Describes at compile time the constraints on where data is to be stored at runtime + * down to the (virtual) device and memory scope level, and how to compile code to compute that + * data. Used by the \p PlanDevices pass to collect and solve (virtual) device constraints for + * the whole Relay program. + * + * Is a quadruple of: + * - A \p device_type (\p DLDeviceType). May be \p kInvalidDeviceType if unconstrained. * - A \p virtual_device_id (\p int). This allows us to distinguish distinct devices * with the same \p Target, for example in a multi-GPU system. May be -1 if unconstrained. * See "Virtual Devices" below. @@ -60,19 +65,19 @@ using MemoryScope = String; * choose a value consistent with the whole program. However if a \p target is given then the \p * device_type must equal \p target->kind->device_type. * - * Note that currently we assume if a function returns its result on a particular device + * Note that currently we assume if a function returns its result on a particular (virtual) device * then the function body is also executed on that device. See the overview comment in * src/relay/transforms/device_planner.cc for more details. * * By 'data' we include both tensors and additional supporting datastructures such as shapes, - * Relay AST items, Relay tuples, and Relay references. Typically non-tensor data must reside - * on a 'CPU'-like device with good support for scalars. + * Relay ADT items (including tuples), Relay references, and Relay closures. Typically non-tensor + * data must reside on a 'CPU'-like host device with good support for scalars. * * By 'execution' we include both (fused) primitive operators, and all the Relay expressions * surrounding them which coordinates data and control flow. Again, typically non-primitive * operators must be executed on a 'CPU'-like device with good support for control flow. 
* - * Since TVM targets such a wide range of systems it is not possible for \p SEScope to impose + * Since TVM targets such a wide range of systems it is not possible for \p VirtualDevice to impose * much semantics on these fields, particularly for \p virtual_device_id and \p memory_scope. * Instead we assume downstream passes and codegen will interpret an validate these fields * appropriately. @@ -84,7 +89,7 @@ using MemoryScope = String; * compile time) describe a physical device on the target system. Obviously the target must agree * with the device's microarchitecture, but we otherwise don't impose any constraints between them: * - It's ok to use different \p Targets for the same \p Device, eg to squeeze some extra perf - * out of a particular primitive. + * out of a particular primitive using particular compiler flags. * - It's ok to use the same \p Target for multiple \p Devices, eg if we have multiple CPUs. * * Traditionally TVM assumes at most one \p Target per \p DLDeviceType. We are moving away from that @@ -133,14 +138,14 @@ using MemoryScope = String; * a memory scope to only be accessible to a device when code is compiled with particular * \p Target options. * - * \p SEScopes themselves have no system-level understanding. Currently device planning will - * simply insert "device_copy" operators wherever \p SEScopes are not exactly pointwise equal. - * We may revisit this in the future as the work on memory pools matures. + * \p VirtualDevices themselves have no system-level understanding. Currently the \p PlanDevices + * pass will simply insert "device_copy" operators wherever \p VirtualDevices are not exactly + * pointwise equal. We may revisit this in the future as the work on memory pools matures. * * Joining and Defaulting * ---------------------- - * It is possible to 'join' two \p SEScopes to yield the most constrained \p SEScope which agrees - * with both join arguments. 
Eg: + * It is possible to 'join' two \p VirtualDevices to yield the most constrained \p VirtualDevice + * which agrees with both join arguments. Eg: * \code * Join((kDLCPU, -1, "llvm", ""), (kInvalidDeviceType, 3, null, "global)) * => (kDLCPU, 3, "llvm", "global") @@ -156,9 +161,8 @@ using MemoryScope = String; * \endcode * * These operations are needed during device planning. - * */ -class SEScopeNode : public AttrsNode { +class VirtualDeviceNode : public AttrsNode { private: /*! * \brief The \p DLDeviceType (represented as an int) of the virtual device. If \p target is @@ -187,7 +191,7 @@ class SEScopeNode : public AttrsNode { /*! * \brief The \p Target describing how to compile for the virtual device. * - * Null denotes unconstrained. Note that if a target later becomes known for this \p SEScope + * Null denotes unconstrained. Note that if a target later becomes known for this \p VirtualDevice * then it must be consistent with the \p device_type if already known. This is enforced by the * Join and Default methods. */ @@ -201,8 +205,8 @@ class SEScopeNode : public AttrsNode { MemoryScope memory_scope; /*! - * \brief Returns true if scope is fully unconstrained, ie no target/device type, device id - * or memory scope is specified. + * \brief Returns true if virtual device is 'fully unconstrained', ie no target/device type, + * device id or memory scope is specified. */ bool IsFullyUnconstrained() const { return !target.defined() && device_type() == kInvalidDeviceType && virtual_device_id == -1 && @@ -210,18 +214,18 @@ class SEScopeNode : public AttrsNode { } /*! - * \brief Returns true if scope is fully constrained, ie target, device id and memory scope are - * all specified. + * \brief Returns true if virtual device is 'fully constrained', ie target, device id and memory + * scope are all specified. */ bool IsFullyConstrained() const { return target.defined() && virtual_device_id != -1 && !memory_scope.empty(); } /*! 
- * \brief Returns the (virtual) \p Device implied by this \p SEScope. Both the \p device_type and - * \p virtual_device_must be constrained. The returned \p Device may not correspond to any - * physical device available at compile time or even runtime: see "Virtual vs Physical Devices" - * above. + * \brief Returns the (virtual) \p Device implied by this \p VirtualDevice. Both the \p + * device_type and \p virtual_device_must be constrained. The returned \p Device may not + * correspond to any physical device available at compile time or even runtime: see "Virtual vs + * Physical Devices" above. */ Device ToDevice() const { ICHECK(device_type() != kInvalidDeviceType); @@ -232,7 +236,7 @@ class SEScopeNode : public AttrsNode { return device; } - TVM_DECLARE_ATTRS(SEScopeNode, "SEScope") { + TVM_DECLARE_ATTRS(VirtualDeviceNode, "VirtualDevice") { TVM_ATTR_FIELD(device_type_int) .describe("The type of the virtual device.") .set_default(kInvalidDeviceType); @@ -247,74 +251,72 @@ class SEScopeNode : public AttrsNode { .set_default(""); } - friend class SEScope; + friend class VirtualDevice; }; /*! - * \brief Managed reference class to \p SEScopeNode. - * - * \sa SEScopeNode. + * \brief Managed reference class to \p VirtualDeviceNode. */ -class SEScope : public ObjectRef { +class VirtualDevice : public ObjectRef { public: /*! - * \brief Construct an SEScope. - * \param device_type The device type for the virtual device, or kInvalidDeviceType if + * \brief Construct a virtual device. + * \param device_type The device type for the virtual device, or \p kInvalidDeviceType if * unconstrained. If \p target is defined then must match its \p target->kind->device_type. * \param virtual_device_id The device id for the virtual device, or -1 if unconstrained. * \param target The target describing how to compile for the virtual device, or null if * unconstrained. * \param memory_scope The memory scope w.r.t. the virtual device which holds data, or "" if * unconstrained. 
- * \return The SEScope + * \return The virtual device. */ - explicit SEScope(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, - Target target = {}, MemoryScope memory_scope = {}); + explicit VirtualDevice(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); - /*! \brief Returns the unique fully unconstrained \p SEScope. */ - static SEScope FullyUnconstrained(); + /*! \brief Returns the unique fully unconstrained \p VirtualDevice. */ + static VirtualDevice FullyUnconstrained(); /*! - * \brief Returns the \p SEScope for \p device_type and (if not -1) \p virtual_device_id. + * \brief Returns the \p VirtualDevice for \p device_type and (if not -1) \p virtual_device_id. * The target and memory scope will be unconstrained. */ - static SEScope ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { + static VirtualDevice ForDeviceType(DLDeviceType device_type, int virtual_device_id = -1) { ICHECK_GT(device_type, 0); - return SEScope(device_type, virtual_device_id); + return VirtualDevice(device_type, virtual_device_id); } - static SEScope ForDeviceType(int device_type, int virtual_device_id = -1) { + static VirtualDevice ForDeviceType(int device_type, int virtual_device_id = -1) { return ForDeviceType(static_cast(device_type), virtual_device_id); } - static SEScope ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { + static VirtualDevice ForDeviceType(const Integer& device_type, int virtual_device_id = -1) { return ForDeviceType(static_cast(device_type->value), virtual_device_id); } - /*! \brief Returns the \p SEScope for \p device. */ - static SEScope ForDevice(const Device& device) { + /*! \brief Returns the \p VirtualDevice for \p device. */ + static VirtualDevice ForDevice(const Device& device) { return ForDeviceType(device.device_type, device.device_id); } - /*! \brief Returns the \p SEScope for \p device and \p target. 
*/ - static SEScope ForDeviceAndTarget(const Device& device, Target target) { - return SEScope(device.device_type, device.device_id, std::move(target)); + /*! \brief Returns the \p VirtualDevice for \p device and \p target. */ + static VirtualDevice ForDeviceAndTarget(const Device& device, Target target) { + return VirtualDevice(device.device_type, device.device_id, std::move(target)); } - /*! \brief Returns the \p SEScope for \p target. */ - static SEScope ForTarget(Target target) { + /*! \brief Returns the \p VirtualDevice for \p target. */ + static VirtualDevice ForTarget(Target target) { DLDeviceType device_type = static_cast(target->kind->device_type); - return SEScope(device_type, /*virtual_device_id=*/0, std::move(target)); + return VirtualDevice(device_type, /*virtual_device_id=*/0, std::move(target)); } - /*! \brief Returns the \p SEScope for \p memory_scope alone. */ - static SEScope ForMemoryScope(MemoryScope memory_scope) { - return SEScope(kInvalidDeviceType, -1, {}, std::move(memory_scope)); + /*! \brief Returns the \p VirtualDevice for \p memory_scope alone. */ + static VirtualDevice ForMemoryScope(MemoryScope memory_scope) { + return VirtualDevice(kInvalidDeviceType, -1, {}, std::move(memory_scope)); } - /*! \brief Returns the \p SEScope for \p device, \p target and \p memory_scope. */ - TVM_DLL static SEScope ForDeviceTargetAndMemoryScope(const Device& device, Target target, - MemoryScope memory_scope) { - return SEScope(device.device_type, device.device_id, std::move(target), - std::move(memory_scope)); + /*! \brief Returns the \p VirtualDevice for \p device, \p target and \p memory_scope. */ + TVM_DLL static VirtualDevice ForDeviceTargetAndMemoryScope(const Device& device, Target target, + MemoryScope memory_scope) { + return VirtualDevice(device.device_type, device.device_id, std::move(target), + std::move(memory_scope)); } /*! @@ -322,41 +324,43 @@ class SEScope : public ObjectRef { * \p lhs and \p rhs on all their constrained fields. 
Returns the null optional if no such * join exists, ie there's disagreement on at least one constrained field. */ - static Optional Join(const SEScope& lhs, const SEScope& rhs); + static Optional Join(const VirtualDevice& lhs, const VirtualDevice& rhs); /*! * \brief Returns the 'default' of \p lhs and \p rhs. The result will be \p lhs, except any * unconstrained fields in \p lhs will take their value from \p rhs. Always well-defined. */ - static SEScope Default(const SEScope& lhs, const SEScope& rhs); + static VirtualDevice Default(const VirtualDevice& lhs, const VirtualDevice& rhs); - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SEScope, ObjectRef, SEScopeNode); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(VirtualDevice, ObjectRef, VirtualDeviceNode); - friend class SEScopeCache; // Private implementation helper. + friend class VirtualDeviceCache; // Private implementation helper. }; /*! - * \brief A cache of \p SEScopes. This can be used: - * - To avoid ending up with lots of identical instances, since the space of SEScopes for any + * \brief A cache of \p VirtualDevices. This can be used: + * - To avoid ending up with lots of identical instances, since the space of VirtualDevices for any * one compilation is very small but the number of points they need to be constructed can * be very large (eg during device planning). - * - So we can assume \p SEScopes are pointer equal if and only if they are structurally equal. - * This simplifies the unification of 'device domains' which are built on \p SEScopes. + * - So we can assume \p VirtualDevices are pointer equal if and only if they are structurally + * equal. This simplifies the unification of 'device domains' which are built on \p VirtualDevices. */ -class SEScopeCache { +class VirtualDeviceCache { public: - /*! \brief Returns the unique \p SEScope representing given fields. 
*/ - SEScope Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, - Target target = {}, MemoryScope memory_scope = {}); + /*! \brief Returns the unique \p VirtualDevice representing given fields. */ + VirtualDevice Make(DLDeviceType device_type = kInvalidDeviceType, int virtual_device_id = -1, + Target target = {}, MemoryScope memory_scope = {}); - /*! \brief Returns the unique \p SEScope structurally equal to the given \p se_scope. */ - SEScope Unique(const SEScope& scope); + /*! + * \brief Returns the unique \p VirtualDevice structurally equal to the given \p virtual_device. + */ + VirtualDevice Unique(const VirtualDevice& virtual_device); private: - /*! \brief Already constructed SEScopes. */ - std::unordered_set cache_; + /*! \brief Already constructed VirtualDevices. */ + std::unordered_set cache_; }; } // namespace tvm -#endif // TVM_TARGET_SE_SCOPE_H_ +#endif // TVM_TARGET_VIRTUAL_DEVICE_H_ diff --git a/python/tvm/relay/op/annotation/annotation.py b/python/tvm/relay/op/annotation/annotation.py index cb4e628ebc92..f2ce6c563eab 100644 --- a/python/tvm/relay/op/annotation/annotation.py +++ b/python/tvm/relay/op/annotation/annotation.py @@ -23,11 +23,11 @@ from .. import op as reg -def _make_se_scope(device): +def _make_virtual_device(device): if isinstance(device, _Device): - return target.make_se_scope(device) + return target.make_virtual_device(device) if isinstance(device, str): - return target.make_se_scope(_nd.device(device)) + return target.make_virtual_device(_nd.device(device)) raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) @@ -59,7 +59,7 @@ def on_device(body, device, constrain_result=False, constrain_body=True): result : tvm.relay.Expr The annotated expression. 
""" - return _make.OnDevice(body, _make_se_scope(device), constrain_result, constrain_body) + return _make.OnDevice(body, _make_virtual_device(device), constrain_result, constrain_body) def function_on_device(function, param_devices, result_device): @@ -83,7 +83,9 @@ def function_on_device(function, param_devices, result_device): The annotated function. """ return _make.FunctionOnDevice( - function, [_make_se_scope(d) for d in param_devices], _make_se_scope(result_device) + function, + [_make_virtual_device(d) for d in param_devices], + _make_virtual_device(result_device), ) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index d9847a453569..20b883ba2616 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -27,11 +27,11 @@ from . import op as reg -def _make_se_scope(device): +def _make_virtual_device(device): if isinstance(device, _Device): - return target.make_se_scope(device) + return target.make_virtual_device(device) if isinstance(device, str): - return target.make_se_scope(_nd.device(device)) + return target.make_virtual_device(_nd.device(device)) raise ValueError("expecting a Device or device name, but received a %s" % (type(device))) @@ -1211,7 +1211,9 @@ def device_copy(data, src_device, dst_device): result : tvm.relay.Expr The copied result. 
""" - return _make.DeviceCopy(data, _make_se_scope(src_device), _make_se_scope(dst_device)) + return _make.DeviceCopy( + data, _make_virtual_device(src_device), _make_virtual_device(dst_device) + ) def shape_of(data, dtype="int32"): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 4369009559ba..696bd6258ee6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1164,12 +1164,12 @@ def SimplifyExpr(): def PlanDevices(config): """ - Uses existing "on_device" and "device_copy" CallNodes to infer the SEScope on which + Uses existing "on_device" and "device_copy" calls to infer the virtual device on which every Relay sub-expression should run and the result stored. Captures the result of that - analysis using new "on_device" and "device_copy" CallNodes. Sub-expressions which are - not otherwise constrained are assigned to the default_primitive_se_scope. However data and - computations which must be hosted on a CPU (such as shapes and shape functions) use the - cpu_se_scope. + analysis using new "on_device" and "device_copy" calls. Sub-expressions which are + not otherwise constrained are assigned to the default primitive virtual device describe by + config. However data and computations which must be hosted on a CPU (such as shapes and shape functions) + use the host virtual device of the config. 
Parameters ---------- diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 1b90a506624e..6c13ceddc21e 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -71,7 +71,7 @@ riscv_cpu, hexagon, ) -from .se_scope import make_se_scope +from .virtual_device import make_virtual_device from .compilation_config import make_compilation_config from .tag import list_tags from .generic_func import GenericFunc diff --git a/python/tvm/target/se_scope.py b/python/tvm/target/virtual_device.py similarity index 72% rename from python/tvm/target/se_scope.py rename to python/tvm/target/virtual_device.py index 83df5ae3448a..a88d405ac9ea 100644 --- a/python/tvm/target/se_scope.py +++ b/python/tvm/target/virtual_device.py @@ -14,9 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Python bindings for creating SEScopes.""" +"""Python bindings for creating VirtualDevices.""" from . import _ffi_api -def make_se_scope(device, target=None, memory_scope=""): - return _ffi_api.SEScope_ForDeviceTargetAndMemoryScope(device, target, memory_scope) +# TODO(mbs): We need an official Python class representation given the importance of this structure. 
+ + +def make_virtual_device(device, target=None, memory_scope=""): + return _ffi_api.VirtualDevice_ForDeviceTargetAndMemoryScope(device, target, memory_scope) diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index d0c2cfebbbd8..fdc6c37e527a 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -37,7 +37,7 @@ #include #include #include -#include +#include #include #include "../ir/attr_functor.h" @@ -906,14 +906,14 @@ Doc RelayTextPrinter::PrintAttributeValue(const ObjectRef& value, bool force_met printed_attr << Doc::StrLiteral(GetRef(str_obj)); } else if (force_meta) { printed_attr = meta_->GetMetaNode(Downcast(value)); - } else if (const auto* se_scope_node = value.as()) { + } else if (const auto* virtual_device_node = value.as()) { if (show_meta_data_) { - printed_attr = meta_->GetMetaNode(GetRef(se_scope_node)); + printed_attr = meta_->GetMetaNode(GetRef(virtual_device_node)); } else { - // Special case: The ReprPrinter for SEScopeNodes is much easier to work with while + // Special case: The ReprPrinter for VirtualDeviceNodes is much easier to work with while // debugging. 
std::ostringstream os; - os << GetRef(se_scope_node); + os << GetRef(virtual_device_node); return Doc::Text(os.str()); } } else if (const auto* base_attr_node = value.as()) { diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9ea1e423c119..d901f8a26c4f 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -135,18 +135,19 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void VisitExpr_(const TupleNode* op) final { std::vector storage_ids; - std::vector se_scopes; + std::vector virtual_devices; std::vector storage_sizes_in_bytes; Expr expr = GetRef(op); for (Expr field : op->fields) { auto sid = GetStorage(field); storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end()); - se_scopes.insert(se_scopes.end(), sid->se_scopes.begin(), sid->se_scopes.end()); + virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(), + sid->virtual_devices.end()); storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(), sid->storage_sizes_in_bytes.begin(), sid->storage_sizes_in_bytes.end()); } - storage_device_map_[expr] = StorageInfo(storage_ids, se_scopes, storage_sizes_in_bytes); + storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes); AssignReturnSid(expr); } @@ -155,7 +156,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { auto sids = GetStorage(op->tuple); ICHECK_LT(static_cast(op->index), sids->storage_ids.size()); storage_device_map_[expr] = - StorageInfo({sids->storage_ids[op->index]}, {sids->se_scopes[op->index]}, + StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]}, {sids->storage_sizes_in_bytes[op->index]}); AssignReturnSid(expr); } @@ -221,24 +222,25 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { */ void CreateStorage(const ExprNode* op) { Expr expr = GetRef(op); - return 
CreateStorage(expr, GetSEScope(expr)); + return CreateStorage(expr, GetVirtualDevice(expr)); } /*! - * \brief Create storage to hold the result of evaluating \p expr in \p se_scope. + * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device. */ - void CreateStorage(const Expr& expr, const SEScope& se_scope) { - ICHECK(!se_scope->IsFullyUnconstrained()) << "invalid SEScope for expr:" << std::endl - << PrettyPrint(expr); + void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) { + ICHECK(!virtual_device->IsFullyUnconstrained()) + << "invalid virtual device for expr:" << std::endl + << PrettyPrint(expr); std::vector storage_ids; - std::vector se_scopes; + std::vector virtual_devices; std::vector storage_sizes_in_bytes; for (const auto& ttype : FlattenTupleType(expr->checked_type())) { storage_ids.push_back(next_available_sid_++); - se_scopes.push_back(se_scope); + virtual_devices.push_back(virtual_device); storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype)); } - storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(se_scopes), + storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices), std::move(storage_sizes_in_bytes)); } @@ -736,7 +738,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { use_unpacked_api_ = executor_config->GetAttr("unpacked-api").value_or(Bool(false)); // TODO(mbs): Plumb from compiler config - SEScope host_se_scope = SEScope::ForTarget(target_host_); + VirtualDevice host_virtual_device = VirtualDevice::ForTarget(target_host_); IRModule lowered_mod = tec::LowerTEPass( mod_name, @@ -753,7 +755,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { // lowering process directly. 
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment); }, - host_se_scope)(mod); + host_virtual_device)(mod); auto lowered_main = lowered_mod->Lookup("main"); auto lowered_main_func = GetRef(lowered_main.as()); diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ab86dbf41b3f..ccfd30476f67 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -424,7 +424,7 @@ class RelayBuildModule : public runtime::ModuleNode { lowered_funcs.Set(ext_dev, IRModule()); } - const Target& host_target = config_->host_se_scope->target; + const Target& host_target = config_->host_virtual_device->target; const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.LLVMModuleCreate"); // Generate a placeholder function that attaches linked params as its arguments. diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index 16b1ddb3c82f..f61fe9b402b3 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -245,7 +245,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorfunction_metadata_); }, - config->host_se_scope)(mod); + config->host_virtual_device)(mod); Optional main_func_info = lowered_mod->GetAttr("main_func_info"); @@ -328,10 +328,10 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslatorattrs_["storage_id"] = std::move(storage_ids); // type std::vector device_types; - for (const auto& se_scope : storage_info->se_scopes) { + for (const auto& virtual_device : storage_info->virtual_devices) { // TODO(mbs): Keeping only the device type. 
- ICHECK_GT(se_scope->device_type(), 0); - device_types.push_back(se_scope->device_type()); + ICHECK_GT(virtual_device->device_type(), 0); + device_types.push_back(virtual_device->device_type()); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) { diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 3ee318740ba8..2ad27a0d20b0 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -53,19 +53,21 @@ struct StorageToken { size_t max_bytes{0}; /*! \brief The corresponding tensor type. */ TensorType ttype{nullptr}; - /*! \brief SEScope on which the memory will reside. */ - SEScope se_scope = SEScope::FullyUnconstrained(); + /*! \brief VirtualDevice on which the memory will reside. */ + VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained(); /*! \brief The storage id */ int64_t storage_id{-1}; - bool is_valid() const { return !se_scope->IsFullyUnconstrained(); } + bool is_valid() const { return !virtual_device->IsFullyUnconstrained(); } - bool is_compatible(const StorageToken& that) const { return se_scope == that.se_scope; } + bool is_compatible(const StorageToken& that) const { + return virtual_device == that.virtual_device; + } std::string ToString() const { std::ostringstream os; os << "{storage_id: " << storage_id << ", max_bytes: " << max_bytes - << ", ttype: " << PrettyPrint(ttype) << ", se_scope: " << se_scope << "}"; + << ", ttype: " << PrettyPrint(ttype) << ", virtual_device: " << virtual_device << "}"; return os.str(); } }; @@ -167,14 +169,14 @@ class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor { * the result of evaluating \p op. 
*/ void CreateToken(const ExprNode* expr_node, bool can_realloc) { - return CreateTokenOnDevice(expr_node, GetSEScope(GetRef(expr_node)), can_realloc); + return CreateTokenOnDevice(expr_node, GetVirtualDevice(GetRef(expr_node)), can_realloc); } /*! * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding * the result of evaluating \p op on \p device_type. */ - virtual void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, + virtual void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device, bool can_realloc) = 0; }; @@ -193,13 +195,14 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { protected: using StorageAllocaBaseVisitor::VisitExpr_; - void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) override { + void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device, + bool can_realloc) override { ICHECK(!token_map_.count(op)); std::vector tokens; for (const auto& ttype : FlattenTupleType(op->checked_type())) { auto* token = arena_->make(); token->ttype = ttype; - token->se_scope = se_scope; + token->virtual_device = virtual_device; tokens.push_back(token); } token_map_[op] = tokens; @@ -256,8 +259,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; storage_ids.reserve(kv.second.size()); - std::vector se_scopes; - se_scopes.reserve(kv.second.size()); + std::vector virtual_devices; + virtual_devices.reserve(kv.second.size()); std::vector sid_sizes_byte; sid_sizes_byte.reserve(kv.second.size()); @@ -268,10 +271,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } num_nodes++; storage_ids.push_back(tok->storage_id); - se_scopes.push_back(tok->se_scope); + virtual_devices.push_back(tok->virtual_device); sid_sizes_byte.push_back(GetMemorySize(tok)); } - auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(se_scopes), + auto 
storage_info = backend::StorageInfo(std::move(storage_ids), std::move(virtual_devices), std::move(sid_sizes_byte)); smap.Set(GetRef(kv.first), storage_info); } @@ -286,20 +289,21 @@ class StorageAllocator : public StorageAllocaBaseVisitor { protected: // override create token by getting token as prototype requirements. - void CreateTokenOnDevice(const ExprNode* op, const SEScope& se_scope, bool can_realloc) final { + void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device, + bool can_realloc) final { ICHECK(!token_map_.count(op)); auto it = prototype_.find(op); ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { - ICHECK(tok->se_scope == se_scope); + ICHECK(tok->virtual_device == virtual_device); if (can_realloc) { tokens.push_back(Request(tok)); } else { // Allocate a new token, StorageToken* allocated_tok = Alloc(tok, GetMemorySize(tok)); - allocated_tok->se_scope = tok->se_scope; + allocated_tok->virtual_device = tok->virtual_device; // ensure it never get de-allocated. allocated_tok->ref_counter += 1; tokens.push_back(allocated_tok); diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index 82a04551c145..2bea8101d645 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -474,7 +474,7 @@ class Interpreter : public ExprFunctor, // whether the shape and or data needs to be passed, and flattening of tuples. // Similarly, num_shape_outputs will account for flattening of tuples. - // TODO(mbs): Take this from the host_se_scope. + // TODO(mbs): Take this from the host_virtual_device. 
Device shape_device; shape_device.device_type = static_cast(prim_shape_target->kind->device_type); shape_device.device_id = 0; @@ -754,7 +754,7 @@ class Interpreter : public ExprFunctor, return InvokePrimitiveOp(call_lowered_props.lowered_func, all_prim_fn_vars, config_->optional_homogeneous_target, prim_shape_fn_var, all_prim_shape_fn_vars, prim_shape_fn_states, num_shape_inputs, - num_shape_outputs, config_->host_se_scope->target, args); + num_shape_outputs, config_->host_virtual_device->target, args); } else { // All other calls // Evaluate all arguments std::vector args; @@ -945,7 +945,7 @@ class Interpreter : public ExprFunctor, * functions needed by the rewritten module. */ IRModule Prepare(IRModule mod, CompilationConfig config) { - SEScope host_se_scope = config->host_se_scope; + VirtualDevice host_virtual_device = config->host_virtual_device; // Run minimal transforms on module to establish invariants needed by interpreter. transform::Sequential seq( {transform::SimplifyInference(), @@ -962,7 +962,7 @@ IRModule Prepare(IRModule mod, CompilationConfig config) { /*expand_constructor=*/true, /*expand_global_var=*/false), transform::InferType(), tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, - std::move(host_se_scope))}); + std::move(host_virtual_device))}); transform::PassContext pass_ctx = transform::PassContext::Current(); With ctx(pass_ctx); diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 528df647fe4a..901661dd87a3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -503,13 +503,13 @@ using AnalysisRemapping = std::unordered_maptarget); + CCacheKey shape_key(func, host_virtual_device_->target); CachedFunc lowered_shape_func = compiler_->LowerShapeFunc(shape_key); // Capture the shape function's global var and parameters 'states' in call @@ -707,8 +707,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { DeviceCopyProps device_copy_props = 
GetDeviceCopyProps(function_node->body); if (device_copy_props.body.defined()) { ICHECK_EQ(new_args.size(), 1); - return DeviceCopy(new_args[0], device_copy_props.src_se_scope, - device_copy_props.dst_se_scope); + return DeviceCopy(new_args[0], device_copy_props.src_virtual_device, + device_copy_props.dst_virtual_device); } } @@ -746,9 +746,9 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { target = Target("ext_dev"); } else { // The target corresponding to the call_node expression's annotation. - SEScope se_scope = GetSEScope(GetRef(call_node)); - ICHECK(!se_scope->IsFullyUnconstrained()); - target = se_scope->target; + VirtualDevice virtual_device = GetVirtualDevice(GetRef(call_node)); + ICHECK(!virtual_device->IsFullyUnconstrained()); + target = virtual_device->target; ICHECK(target.defined()); } @@ -769,10 +769,10 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { String module_name_; TECompiler compiler_; /*! - * \brief The \p SEScope for the host, which is where all shape-related data and computation + * \brief The \p VirtualDevice for the host, which is where all shape-related data and computation * must live. */ - SEScope host_se_scope_; + VirtualDevice host_virtual_device_; // Cache ops that need to be frequently used later to reduce lookup overhead. 
const Op& debug_op_; }; @@ -808,10 +808,11 @@ Target GetTargetFromInteger(DLDeviceType dev_type, tec::TargetMap targets) { } Pass LowerTensorExpr(const String& module_name, TECompiler compiler, ProcessFn process_fn, - SEScope host_se_scope) { + VirtualDevice host_virtual_device) { runtime::TypedPackedFunc pass_func = [=](Function func, IRModule module, PassContext ctx) { - LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, host_se_scope); + LowerTensorExprMutator lower_te(module, process_fn, module_name, compiler, + host_virtual_device); return Downcast(lower_te.Mutate(func)); }; return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); @@ -828,7 +829,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa } // This is a Map> - // TODO(mbs): Collapsing SEScopes to just device type. + // TODO(mbs): Collapsing VirtualDevices to just device type. std::unordered_map, backend::EnumClassHash> sid_workspace; // This is a Map @@ -841,10 +842,10 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa for (const auto& kv : storage_info_map) { const backend::StorageInfo& storage_info = kv.second; const std::vector& storage_ids = storage_info->storage_ids; - const std::vector& se_scopes = storage_info->se_scopes; - CHECK_EQ(storage_ids.size(), se_scopes.size()); - for (uint32_t i = 0; i < se_scopes.size(); i++) { - DLDeviceType device_type = se_scopes[i]->device_type(); + const std::vector& virtual_devices = storage_info->virtual_devices; + CHECK_EQ(storage_ids.size(), virtual_devices.size()); + for (uint32_t i = 0; i < virtual_devices.size(); i++) { + DLDeviceType device_type = virtual_devices[i]->device_type(); sid_workspace[device_type][storage_ids[i]] = 0; device_io[device_type] = 0; device_consts[device_type] = 0; @@ -877,18 +878,18 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa << "has size " << size_bytes << " and storage info:" << std::endl << 
storage_info; const std::vector& storage_ids = storage_info->storage_ids; - const std::vector& se_scopes = storage_info->se_scopes; + const std::vector& virtual_devices = storage_info->virtual_devices; if (expr->IsInstance()) { - for (const auto& se_scope : se_scopes) { - DLDeviceType device_type = se_scope->device_type(); + for (const auto& virtual_device : virtual_devices) { + DLDeviceType device_type = virtual_device->device_type(); ICHECK_EQ(device_consts.count(device_type), 1); device_consts[device_type] += size_bytes; } } else if (expr->IsInstance() || expr.same_as(func->body)) { - CHECK_GE(se_scopes.size(), 1) << "must be at least one device"; - for (const auto& se_scope : se_scopes) { - DLDeviceType device_type = se_scope->device_type(); + CHECK_GE(virtual_devices.size(), 1) << "must be at least one device"; + for (const auto& virtual_device : virtual_devices) { + DLDeviceType device_type = virtual_device->device_type(); device_io[device_type] += size_bytes; } } else { @@ -899,7 +900,7 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, tec::TargetMa // Here we record the largest size of the tensor // that share the same storage id, because storage_id will // be shared between multiple tensors that are not live simultaneously. - DLDeviceType device_type = se_scopes[i]->device_type(); + DLDeviceType device_type = virtual_devices[i]->device_type(); if (size_bytes > sid_workspace[device_type][storage_ids[i]]) { sid_workspace[device_type][storage_ids[i]] = size_bytes; } @@ -1045,7 +1046,7 @@ void UpdateFunctionMetadata(BaseFunc func, } IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn process_fn, - SEScope host_se_scope) { + VirtualDevice host_virtual_device) { TECompiler compiler(module); // TODO(mbs): This is all unnecessarily convoluted. 
Better would be to accumulate the rewritten @@ -1061,7 +1062,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // - Calls to functions tagged with "Primitive" are compiled to PrimFuncs, and calls updated // (using call_lowered convention). IRModule updated_module = LowerTensorExpr(module_name, compiler, std::move(process_fn), - std::move(host_se_scope))(module); + std::move(host_virtual_device))(module); // The Functions tagged with "Compiler" are now residing in the cache ready to be // compiled by LowerExternalFunctions. However we still need a record of them in the @@ -1161,10 +1162,11 @@ Map GetPerTargetModules(IRModule mod) { return per_target_modules; } -Pass LowerTEPass(const String& module_name, ProcessFn process_fn, SEScope host_se_scope) { +Pass LowerTEPass(const String& module_name, ProcessFn process_fn, + VirtualDevice host_virtual_device) { runtime::TypedPackedFunc pass_func = [=](IRModule module, PassContext ctx) { - return LowerTE(module, module_name, process_fn, host_se_scope); + return LowerTE(module, module_name, process_fn, host_virtual_device); }; return tvm::transform::Sequential( diff --git a/src/relay/backend/te_compiler.h b/src/relay/backend/te_compiler.h index 60dd5fe2c6b3..b6f2218e2319 100644 --- a/src/relay/backend/te_compiler.h +++ b/src/relay/backend/te_compiler.h @@ -214,10 +214,11 @@ IRModule LowerTE( * \param module_name The name of this module * \param process_fn Callback allowing one-level up code generators to process * each function that we lower - * \param host_se_scope \p SEScope for host data and computations + * \param host_virtual_device \p VirtualDevice for host data and computations * \returns The pass which lowers primative functions to TIR */ -transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, SEScope host_se_scope); +transform::Pass LowerTEPass(const String& module_name, ProcessFn process_fn, + VirtualDevice host_virtual_device); } // namespace tec } // 
namespace relay } // namespace tvm diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 252c43f9bc03..608d4cdb9f85 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -43,9 +43,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) for (auto id : node->storage_ids) { p->stream << id << ","; } - p->stream << "], se_scopes=["; - for (const auto& se_scope : node->se_scopes) { - p->stream << se_scope << ","; + p->stream << "], virtual_devices=["; + for (const auto& virtual_device : node->virtual_devices) { + p->stream << virtual_device << ","; } p->stream << "], storage_size_in_bytes=["; for (auto bytes : node->storage_sizes_in_bytes) { @@ -54,13 +54,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "])"; }); -StorageInfo::StorageInfo(std::vector storage_ids, std::vector se_scopes, +StorageInfo::StorageInfo(std::vector storage_ids, + std::vector virtual_devices, std::vector storage_sizes_in_bytes) { - ICHECK_EQ(storage_ids.size(), se_scopes.size()); + ICHECK_EQ(storage_ids.size(), virtual_devices.size()); ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size()); auto node = make_object(); node->storage_ids = std::move(storage_ids); - node->se_scopes = std::move(se_scopes); + node->virtual_devices = std::move(virtual_devices); node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); data_ = std::move(node); } @@ -74,17 +75,18 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfo") for (auto s : sids) { sids_v.push_back(s); } - std::vector se_scopes_v; - se_scopes_v.reserve(device_types.size()); + std::vector virtual_devices_v; + virtual_devices_v.reserve(device_types.size()); for (const auto& device_type : device_types) { - se_scopes_v.emplace_back(SEScope::ForDeviceType(device_type)); + virtual_devices_v.emplace_back(VirtualDevice::ForDeviceType(device_type)); } std::vector size_in_bytes_v; size_in_bytes_v.reserve(sizes_in_bytes.size()); for (auto s : sizes_in_bytes) { size_in_bytes_v.push_back(s); } - return 
StorageInfo(std::move(sids_v), std::move(se_scopes_v), std::move(size_in_bytes_v)); + return StorageInfo(std::move(sids_v), std::move(virtual_devices_v), + std::move(size_in_bytes_v)); }); TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) { @@ -98,8 +100,8 @@ TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageI // This is the legacy interface for devices as DLDeviceTypes (represented by integers) TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) { Array device_types; - for (const auto& se_scope : si->se_scopes) { - device_types.push_back(se_scope->device_type()); + for (const auto& virtual_device : si->virtual_devices) { + device_types.push_back(virtual_device->device_type()); } return device_types; }); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 64f7c653ea5a..df25a8641792 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include @@ -62,8 +62,8 @@ class StorageInfoNode : public Object { // TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout. /*! \brief The set of storage ids where the expression is stored. */ std::vector storage_ids; - /* \brief The SEScopes these expressions are stored within. */ - std::vector se_scopes; + /* \brief The virtual devices these expressions are stored within. */ + std::vector virtual_devices; /* \brief The sizes of each storage element, in bytes. */ std::vector storage_sizes_in_bytes; @@ -77,7 +77,7 @@ class StorageInfoNode : public Object { /*! \brief The storage information for a single expression. 
*/ class StorageInfo : public ObjectRef { public: - StorageInfo(std::vector storage_ids, std::vector se_scopes, + StorageInfo(std::vector storage_ids, std::vector virtual_devices, std::vector storage_sizes_in_bytes); TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode); }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 23aee452ba09..73f4b672a81c 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -235,12 +235,12 @@ std::vector ToAllocTensorShape(NDArray shape) { class VMFunctionCompiler : DeviceAwareExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, SEScope host_se_scope) + VMFunctionCompiler(VMCompilerContext* context, VirtualDevice host_virtual_device) : DeviceAwareExprFunctor(context->module), last_register_(0), registers_num_(0), context_(context), - host_se_scope_(std::move(host_se_scope)) {} + host_virtual_device_(std::move(host_virtual_device)) {} VMFunction Compile(const GlobalVar& var, const Function& func) { std::vector param_device_indexes; @@ -252,21 +252,21 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // Do that flattening on-the-fly here. 
Function inner_func = Downcast(func->body); std::vector params; - std::vector param_se_scopes; + std::vector param_virtual_devices; params.reserve(func->params.size() + inner_func->params.size()); - param_se_scopes.reserve(func->params.size() + inner_func->params.size()); + param_virtual_devices.reserve(func->params.size() + inner_func->params.size()); param_device_indexes.reserve(func->params.size() + inner_func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { params.emplace_back(func->params[i]); - SEScope param_se_scope = GetFunctionParamSEScope(func.get(), i); - param_se_scopes.push_back(param_se_scope); - param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); + VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i); + param_virtual_devices.push_back(param_virtual_device); + param_device_indexes.push_back(GetDeviceIndex(param_virtual_device)); } for (size_t i = 0; i < inner_func->params.size(); ++i) { params.emplace_back(inner_func->params[i]); - SEScope param_se_scope = GetFunctionParamSEScope(inner_func.get(), i); - param_se_scopes.push_back(param_se_scope); - param_device_indexes.push_back(GetDeviceIndex(param_se_scope)); + VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i); + param_virtual_devices.push_back(param_virtual_device); + param_device_indexes.push_back(GetDeviceIndex(param_virtual_device)); } std::vector type_params; type_params.reserve(func->type_params.size() + inner_func->type_params.size()); @@ -278,12 +278,13 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, type_params, func->attrs, func->span); - VisitExpr(MaybeFunctionOnDevice(flattened_func, param_se_scopes, - GetFunctionResultSEScope(inner_func.get()))); + VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices, + GetFunctionResultVirtualDevice(inner_func.get()))); } else { 
param_device_indexes.reserve(func->params.size()); for (size_t i = 0; i < func->params.size(); ++i) { - param_device_indexes.push_back(GetDeviceIndex(GetFunctionParamSEScope(func.get(), i))); + param_device_indexes.push_back( + GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i))); } VisitExpr(func); } @@ -333,42 +334,44 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { } /*! - * \brief Returns the "device index" to represent \p se_scope for primitives + * \brief Returns the "device index" to represent \p virtual_device for primitives * in emitted code. Note that the host device is always at index 0. */ - Index GetDeviceIndex(const SEScope& se_scope) { - ICHECK(!se_scope->IsFullyUnconstrained()); - auto itr = std::find(context_->se_scopes_.begin(), context_->se_scopes_.end(), se_scope); - if (itr != context_->se_scopes_.end()) { - return std::distance(context_->se_scopes_.begin(), itr); + Index GetDeviceIndex(const VirtualDevice& virtual_device) { + ICHECK(!virtual_device->IsFullyUnconstrained()); + auto itr = std::find(context_->virtual_devices_.begin(), context_->virtual_devices_.end(), + virtual_device); + if (itr != context_->virtual_devices_.end()) { + return std::distance(context_->virtual_devices_.begin(), itr); } - ICHECK_GT(context_->se_scopes_.size(), 0); - ICHECK_NE(se_scope, host_se_scope_); // the host scope is always at index 0 + ICHECK_GT(context_->virtual_devices_.size(), 0); + ICHECK_NE(virtual_device, host_virtual_device_); // the host scope is always at index 0 - if (se_scope->device_type() == context_->se_scopes_.front()->device_type()) { + if (virtual_device->device_type() == context_->virtual_devices_.front()->device_type()) { // It's ok if we see distinct scopes which share the host device type. This is because - // we allow the SEScope for the host to be different from the SEScope for primitive - // operations which both happen to be on the same device (typically CPU). 
+ // we allow the VirtualDevice for the host to be different from the VirtualDevice for + // primitive operations which both happen to be on the same device (typically CPU). return 0; } - // However, otherwise we allow at most one SEScope per device type. + // However, otherwise we allow at most one VirtualDevice per device type. // TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy // instructions can do the right thing. - itr = std::find_if(context_->se_scopes_.begin() + 1, context_->se_scopes_.end(), - [&se_scope](const SEScope& existing_se_scope) { - return existing_se_scope->device_type() == se_scope->device_type(); + itr = std::find_if(context_->virtual_devices_.begin() + 1, context_->virtual_devices_.end(), + [&virtual_device](const VirtualDevice& existing_virtual_device) { + return existing_virtual_device->device_type() == + virtual_device->device_type(); }); - CHECK(itr == context_->se_scopes_.end()) + CHECK(itr == context_->virtual_devices_.end()) << "The VM does not currently support using more than one device with the same device type " "for primitives, however the program is using the distinct scopes " - << se_scope << " and " << *itr << " of device type " << se_scope->device_type(); + << virtual_device << " and " << *itr << " of device type " << virtual_device->device_type(); - ICHECK(se_scope != host_se_scope_); - Index index = context_->se_scopes_.size(); - VLOG(2) << "se_scope[" << index << "] = " << se_scope; - context_->se_scopes_.push_back(se_scope); + ICHECK(virtual_device != host_virtual_device_); + Index index = context_->virtual_devices_.size(); + VLOG(2) << "virtual_device[" << index << "] = " << virtual_device; + context_->virtual_devices_.push_back(virtual_device); return index; } @@ -380,7 +383,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { NDArray data = const_node->data; size_t const_index = context_->constants.size(); auto con = GetRef(const_node); - Index device_index = 
GetDeviceIndex(GetSEScope(con)); + Index device_index = GetDeviceIndex(GetVirtualDevice(con)); VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]"; context_->const_device_indexes.push_back(device_index); context_->constants.push_back(const_node->data); @@ -542,8 +545,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { // TODO(mbs): device_copy cleanup. VisitExpr(device_copy_props.body); RegName src_reg = last_register_; - Index src_index = GetDeviceIndex(device_copy_props.src_se_scope); - Index dst_index = GetDeviceIndex(device_copy_props.dst_se_scope); + Index src_index = GetDeviceIndex(device_copy_props.src_virtual_device); + Index dst_index = GetDeviceIndex(device_copy_props.dst_virtual_device); // Since scopes distinguish by targets (including any target hosts) but at runtime we // deal only with devices, the copy may be unnecessary. if (src_index != dst_index) { @@ -619,7 +622,7 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { auto dtype = alloc_attrs->dtype; Emit(Instruction::AllocStorage(size_register, alignment, dtype, - GetDeviceIndex(alloc_attrs->se_scope), + GetDeviceIndex(alloc_attrs->virtual_device), NewRegister())); }) .Match("vm.shape_of", @@ -819,8 +822,8 @@ class VMFunctionCompiler : DeviceAwareExprFunctor { size_t registers_num_; /*! \brief Global shared meta data */ VMCompilerContext* context_; - /*! \brief SEScope for data and computation which must reside on a CPU. */ - SEScope host_se_scope_; + /*! \brief VirtualDevice for data and computation which must reside on a CPU. */ + VirtualDevice host_virtual_device_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { @@ -873,9 +876,9 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) config_ = CompilationConfig(PassContext::Current(), std::move(targets), std::move(target_host)); // The first device is always for the host. 
- CHECK(context_.se_scopes_.empty()); - VLOG(2) << "se_scope[0] = " << config_->host_se_scope << " (host)"; - context_.se_scopes_.push_back(config_->host_se_scope); + CHECK(context_.virtual_devices_.empty()); + VLOG(2) << "virtual_device[0] = " << config_->host_virtual_device << " (host)"; + context_.virtual_devices_.push_back(config_->host_virtual_device); // Run the optimizations necessary to target the VM. context_.module = OptimizeModuleImpl(std::move(mod)); @@ -896,7 +899,7 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) continue; } auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, config_->host_se_scope); + VMFunctionCompiler func_compiler(&context_, config_->host_virtual_device); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -911,12 +914,12 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) } // Populate virtual devices and the host device index. - for (const auto& se_scope : context_.se_scopes_) { - ICHECK(!se_scope->IsFullyUnconstrained()); - ICHECK_GT(se_scope->device_type(), 0); + for (const auto& virtual_device : context_.virtual_devices_) { + ICHECK(!virtual_device->IsFullyUnconstrained()); + ICHECK_GT(virtual_device->device_type(), 0); // TODO(mbs): We forget the memory scope. 
- exec_->virtual_devices.push_back( - Device{/*device_type=*/se_scope->device_type(), /*device_id=*/se_scope->virtual_device_id}); + exec_->virtual_devices.push_back(Device{/*device_type=*/virtual_device->device_type(), + /*device_id=*/virtual_device->virtual_device_id}); } exec_->host_device_index = kHostDeviceIndex; @@ -952,25 +955,25 @@ void VMCompiler::Lower(IRModule mod, TargetMap targets, tvm::Target target_host) } } -transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) { +transform::Sequential VMCompiler::MemoryOpt(const VirtualDevice& host_virtual_device) { Array pass_seqs; // Remove unused functions Array entry_functions{"main"}; pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_se_scope)); + pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower any new shape functions and device_copies. - pass_seqs.push_back(FuseAndLowerOperators(host_se_scope)); + pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_se_scope)); + pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); // Fuse & lower any new allocations. - pass_seqs.push_back(FuseAndLowerOperators(host_se_scope)); + pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is // incomplete to provide memory resuse optimizations. 
Disable it until we can @@ -982,10 +985,10 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) { pass_seqs.push_back(transform::FoldConstant()); // Fuse & lower yet again - pass_seqs.push_back(FuseAndLowerOperators(host_se_scope)); + pass_seqs.push_back(FuseAndLowerOperators(host_virtual_device)); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_se_scope)); + pass_seqs.push_back(transform::ManifestAlloc(host_virtual_device)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -998,7 +1001,7 @@ transform::Sequential VMCompiler::MemoryOpt(const SEScope& host_se_scope) { return transform::Sequential(std::move(pass_seqs)); } -transform::Sequential VMCompiler::FuseAndLowerOperators(const SEScope& host_se_scope) { +transform::Sequential VMCompiler::FuseAndLowerOperators(const VirtualDevice& host_virtual_device) { Array pass_seqs; // Hoist operators to "primitive" Functions. pass_seqs.push_back(FuseOps()); @@ -1011,7 +1014,7 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const SEScope& host_se_s backend::UpdateConstants(func, ¶ms_); } }, - host_se_scope)); + host_virtual_device)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); @@ -1022,8 +1025,8 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, const Target& target_host) { config_ = CompilationConfig(PassContext::Current(), targets, target_host); // The first device always corresponds to the host. - CHECK(context_.se_scopes_.empty()); - context_.se_scopes_.push_back(config_->host_se_scope); + CHECK(context_.virtual_devices_.empty()); + context_.virtual_devices_.push_back(config_->host_virtual_device); // TODO(mbs): exec_ is not allocated. What is the API here? 
CHECK(exec_ == nullptr); return OptimizeModuleImpl(std::move(mod)); @@ -1082,13 +1085,13 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { backend::UpdateConstants(func, ¶ms_); } }, - config_->host_se_scope)); + config_->host_virtual_device)); // Since lowered functions are bound in the IRModule, we can now eliminate any unused // let-bound functions. pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); - // Now that we have PrimFuncs, flow and solve SEScope constraints again to account for + // Now that we have PrimFuncs, flow and solve VirtualDevice constraints again to account for // any memory scopes which lowering has settled on. pass_seqs.push_back(transform::PlanDevices(config_)); @@ -1099,7 +1102,7 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(config_->host_se_scope)); + pass_seqs.push_back(MemoryOpt(config_->host_virtual_device)); pass_seqs.push_back(transform::InferType()); transform::Sequential seq(pass_seqs); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index b8dd9d637b45..906e5148b593 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -80,8 +80,8 @@ struct VMCompilerContext { std::vector const_device_indexes; // Map from names of primitive functions already allocated to their primitive function index. std::unordered_map primitive_map; - // The SEScopes corresponding to each device index. - std::vector se_scopes_; + // The virtual devices corresponding to each device index. 
+ std::vector virtual_devices_; }; class VMCompiler : public runtime::ModuleNode { @@ -136,8 +136,8 @@ class VMCompiler : public runtime::ModuleNode { IRModule OptimizeModuleImpl(IRModule mod); - transform::Sequential MemoryOpt(const SEScope& host_se_scope); - transform::Sequential FuseAndLowerOperators(const SEScope& host_se_scope); + transform::Sequential MemoryOpt(const VirtualDevice& host_virtual_device); + transform::Sequential FuseAndLowerOperators(const VirtualDevice& host_virtual_device); /*! * \brief Populate the global function names in a map where the value is used diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index ffd0e466eb24..0457459b3847 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -112,7 +112,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { auto free_type_vars = FreeTypeVars(func, module_); Array captured_vars; - std::vector captured_var_se_scopes; + std::vector captured_var_virtual_devices; bool recursive = false; for (const auto& var : free_vars) { if (!letrec_.empty() && var == letrec_.back()) { @@ -120,7 +120,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { continue; } captured_vars.push_back(var); - captured_var_se_scopes.push_back(GetSEScope(var)); + captured_var_virtual_devices.push_back(GetVirtualDevice(var)); } // Freshen all the captured vars. 
@@ -132,7 +132,7 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { rebinding_map.Set(free_var, var); } - SEScope result_se_scope = GetSEScope(func_node->body); + VirtualDevice result_virtual_device = GetVirtualDevice(func_node->body); if (recursive) { if (!captured_vars.empty()) { @@ -195,7 +195,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator { lifted_func = Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(), free_type_vars, /*attrs=*/{}, func->span); - lifted_func = MaybeFunctionOnDevice(lifted_func, captured_var_se_scopes, result_se_scope); + lifted_func = + MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device); lifted_func = MarkClosure(lifted_func); } diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 18e83f998e24..b680a49af887 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -23,15 +23,15 @@ */ #include #include -#include +#include namespace tvm { -SEScope RelayExprNode::virtual_device() const { +VirtualDevice RelayExprNode::virtual_device() const { if (virtual_device_.defined()) { - return Downcast(this->virtual_device_); + return Downcast(this->virtual_device_); } - return SEScope::FullyUnconstrained(); + return VirtualDevice::FullyUnconstrained(); } namespace relay { @@ -86,9 +86,9 @@ TVM_REGISTER_GLOBAL("relay.ir.Tuple").set_body_typed([](tvm::Array return Tuple(fields, span); }); Tuple WithFields(Tuple tuple, Optional> opt_fields, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Array fields = opt_fields.value_or(tuple->fields); - SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple->span); bool all_fields_unchanged = true; @@ -132,10 +132,10 @@ Var::Var(Id vid, Type type_annotation, Span span) { } Var WithFields(Var var, 
Optional opt_vid, Optional opt_type_annotation, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Id vid = opt_vid.value_or(var->vid); Type type_annotation = opt_type_annotation.value_or(var->type_annotation); - SEScope virtual_device = opt_virtual_device.value_or(var->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(var->virtual_device()); Span span = opt_span.value_or(var->span); bool unchanged = vid.same_as(var->vid) && type_annotation.same_as(var->type_annotation) && @@ -180,12 +180,12 @@ Call::Call(Expr op, Array args, Attrs attrs, Array type_args, Span s Call WithFields(Call call, Optional opt_op, Optional> opt_args, Optional opt_attrs, Optional> opt_type_args, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr op = opt_op.value_or(call->op); Array args = opt_args.value_or(call->args); Attrs attrs = opt_attrs.value_or(call->attrs); Array type_args = opt_type_args.value_or(call->type_args); - SEScope virtual_device = opt_virtual_device.value_or(call->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(call->virtual_device()); Span span = opt_span.value_or(call->span); bool unchanged = op.same_as(call->op) && attrs.same_as(call->attrs) && span.same_as(call->span); @@ -253,11 +253,11 @@ Let::Let(Var var, Expr value, Expr body, Span span) { } Let WithFields(Let let, Optional opt_var, Optional opt_value, Optional opt_body, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Var var = opt_var.value_or(let->var); Expr value = opt_value.value_or(let->value); Expr body = opt_body.value_or(let->body); - SEScope virtual_device = opt_virtual_device.value_or(let->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(let->virtual_device()); Span span = opt_span.value_or(let->span); bool unchanged = var.same_as(let->var) && 
value.same_as(let->value) && body.same_as(let->body) && @@ -296,12 +296,12 @@ If::If(Expr cond, Expr true_branch, Expr false_branch, Span span) { } If WithFields(If if_expr, Optional opt_cond, Optional opt_true_branch, - Optional opt_false_branch, Optional opt_virtual_device, + Optional opt_false_branch, Optional opt_virtual_device, Optional opt_span) { Expr cond = opt_cond.value_or(if_expr->cond); Expr true_branch = opt_true_branch.value_or(if_expr->true_branch); Expr false_branch = opt_false_branch.value_or(if_expr->false_branch); - SEScope virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(if_expr->virtual_device()); Span span = opt_span.value_or(if_expr->span); bool unchanged = cond.same_as(if_expr->cond) && true_branch.same_as(if_expr->true_branch) && @@ -341,11 +341,11 @@ TupleGetItem::TupleGetItem(Expr tuple, int index, Span span) { } TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional opt_tuple, - Optional opt_index, Optional opt_virtual_device, + Optional opt_index, Optional opt_virtual_device, Optional opt_span) { Expr tuple = opt_tuple.value_or(tuple_get_item->tuple); Integer index = opt_index.value_or(tuple_get_item->index); - SEScope virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(tuple->virtual_device()); Span span = opt_span.value_or(tuple_get_item->span); bool unchanged = tuple.same_as(tuple_get_item->tuple) && (index == tuple_get_item->index) && @@ -380,9 +380,9 @@ RefCreate::RefCreate(Expr value, Span span) { } RefCreate WithFields(RefCreate ref_create, Optional opt_value, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr value = opt_value.value_or(ref_create->value); - SEScope virtual_device = opt_virtual_device.value_or(ref_create->virtual_device()); + VirtualDevice virtual_device = 
opt_virtual_device.value_or(ref_create->virtual_device()); Span span = opt_span.value_or(ref_create->span); bool unchanged = value.same_as(ref_create->value) && span.same_as(ref_create->span); @@ -414,10 +414,10 @@ RefRead::RefRead(Expr ref, Span span) { data_ = std::move(n); } -RefRead WithFields(RefRead ref_read, Optional opt_ref, Optional opt_virtual_device, - Optional opt_span) { +RefRead WithFields(RefRead ref_read, Optional opt_ref, + Optional opt_virtual_device, Optional opt_span) { Expr ref = opt_ref.value_or(ref_read->ref); - SEScope virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(ref_read->virtual_device()); Span span = opt_span.value_or(ref_read->span); bool unchanged = ref.same_as(ref_read->ref) && span.same_as(ref_read->span); @@ -449,10 +449,10 @@ RefWrite::RefWrite(Expr ref, Expr value, Span span) { } RefWrite WithFields(RefWrite ref_write, Optional opt_ref, Optional opt_value, - Optional opt_virtual_device, Optional opt_span) { + Optional opt_virtual_device, Optional opt_span) { Expr ref = opt_ref.value_or(ref_write->ref); Expr value = opt_value.value_or(ref_write->value); - SEScope virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(ref_write->virtual_device()); Span span = opt_span.value_or(ref_write->span); bool unchanged = ref.same_as(ref_write->ref) && value.same_as(ref_write->value) && diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index a08de39d0abb..2d6f75ae3948 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -478,11 +478,11 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { if (const FunctionNode* func = expr.as()) { Expr new_body = ExprBinder(args_map).VisitExpr(func->body); Array new_params; - std::vector new_param_se_scopes; + std::vector new_param_virtual_devices; for (size_t i = 0; i < 
func->params.size(); ++i) { if (!args_map.count(func->params[i])) { new_params.push_back(func->params[i]); - new_param_se_scopes.push_back(GetFunctionParamSEScope(func, i)); + new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i)); } } if (new_body.same_as(func->body) && new_params.size() == func->params.size()) { @@ -490,7 +490,8 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { } auto ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); + ret = + MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func)); std::unordered_set set; for (const auto& v : FreeVars(expr)) { set.insert(v); @@ -498,19 +499,20 @@ Expr Bind(const Expr& expr, const tvm::Map& args_map) { for (const auto& v : FreeVars(ret)) { if (set.count(v) == 0) { new_params.push_back(v); - if (!GetFunctionResultSEScope(func)->IsFullyUnconstrained()) { + if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) { // TODO(mbs): The function has been annotated with a device, which means we are supposed // to be preserving device annotations on every transformation. However there's no // such context for the free vars in args_map. 
LOG(WARNING) << "introduced free var '" << PrettyPrint(v) << "' into function body but no device is known for it"; } - new_param_se_scopes.push_back(SEScope::FullyUnconstrained()); + new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained()); } } ret = Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span); - ret = MaybeFunctionOnDevice(ret, new_param_se_scopes, GetFunctionResultSEScope(func)); + ret = + MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func)); ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size()); return std::move(ret); } else { diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc index 4c5b867e49da..43305402557a 100644 --- a/src/relay/ir/function.cc +++ b/src/relay/ir/function.cc @@ -42,14 +42,14 @@ Function::Function(tvm::Array params, Expr body, Type ret_type, Function WithFields(Function function, Optional> opt_params, Optional opt_body, Optional opt_ret_type, Optional> opt_ty_params, - Optional opt_attrs, Optional opt_virtual_device, + Optional opt_attrs, Optional opt_virtual_device, Optional opt_span) { Array params = opt_params.value_or(function->params); Expr body = opt_body.value_or(function->body); Type ret_type = opt_ret_type.value_or(function->ret_type); Array ty_params = opt_ty_params.value_or(function->type_params); DictAttrs attrs = opt_attrs.value_or(function->attrs); - SEScope virtual_device = opt_virtual_device.value_or(function->virtual_device()); + VirtualDevice virtual_device = opt_virtual_device.value_or(function->virtual_device()); Span span = opt_span.value_or(function->span); bool unchanged = body.same_as(function->body) && ret_type.same_as(function->ret_type) && diff --git a/src/relay/op/memory/device_copy.cc b/src/relay/op/memory/device_copy.cc index 690ad4a99317..a59e25ce1e13 100644 --- a/src/relay/op/memory/device_copy.cc +++ b/src/relay/op/memory/device_copy.cc @@ -50,12 +50,12 @@ const Op& DeviceCopyOp() { return op; } 
-Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { - ICHECK(!src_se_scope->IsFullyUnconstrained()); - ICHECK(!dst_se_scope->IsFullyUnconstrained()); +Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device) { + ICHECK(!src_virtual_device->IsFullyUnconstrained()); + ICHECK(!dst_virtual_device->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->src_se_scope = std::move(src_se_scope); - attrs->dst_se_scope = std::move(dst_se_scope); + attrs->src_virtual_device = std::move(src_virtual_device); + attrs->dst_virtual_device = std::move(dst_virtual_device); Span span = expr->span; return Call(DeviceCopyOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, std::move(span)); @@ -63,12 +63,13 @@ Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { TVM_REGISTER_GLOBAL("relay.op._make.DeviceCopy").set_body_typed(DeviceCopy); -Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope) { - if (src_se_scope == dst_se_scope) { +Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device, + VirtualDevice dst_virtual_device) { + if (src_virtual_device == dst_virtual_device) { // No copy needed. 
return expr; } - return DeviceCopy(std::move(expr), std::move(src_se_scope), std::move(dst_se_scope)); + return DeviceCopy(std::move(expr), std::move(src_virtual_device), std::move(dst_virtual_device)); } RELAY_REGISTER_OP("device_copy") @@ -98,13 +99,14 @@ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node) { const auto* device_copy_attrs = call_node->attrs.as(); ICHECK(device_copy_attrs != nullptr) << "device_copy requires DeviceCopyAttrs"; // Follow nesting: - // device_copy(device_copy(expr, src_se_scope=S, dst_se_scope=T), - // src_se_scope=T, dst_se_scope=U) ==> {expr, S, U} + // device_copy(device_copy(expr, src_virtual_device=S, dst_virtual_device=T), + // src_virtual_device=T, dst_virtual_device=U) ==> {expr, S, U} auto inner = GetDeviceCopyProps(call_node->args[0]); if (inner.body.defined()) { - return {inner.body, inner.src_se_scope, device_copy_attrs->dst_se_scope}; + return {inner.body, inner.src_virtual_device, device_copy_attrs->dst_virtual_device}; } else { - return {call_node->args[0], device_copy_attrs->src_se_scope, device_copy_attrs->dst_se_scope}; + return {call_node->args[0], device_copy_attrs->src_virtual_device, + device_copy_attrs->dst_virtual_device}; } } return {}; diff --git a/src/relay/op/memory/device_copy.h b/src/relay/op/memory/device_copy.h index 728deb79b351..bb74324d5444 100644 --- a/src/relay/op/memory/device_copy.h +++ b/src/relay/op/memory/device_copy.h @@ -40,42 +40,41 @@ const Op& DeviceCopyOp(); /*! * \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and - * stored at \p src_se_scope but then copied to \p dst_se_scope. + * stored at \p src_virtual_device but then copied to \p dst_virtual_device. */ -Expr DeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); +Expr DeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device); /*! 
* \brief Wraps \p expr in a "device_copy" CallNode indicating it should be evaluated and - * stored at \p src_se_scope but then copied to \p dst_se_scope.However, return \p expr - * directly if \p src_se_scope and \p dst_se_scope are (structurally) the same. + * stored at \p src_virtual_device but then copied to \p dst_virtual_device.However, return \p expr + * directly if \p src_virtual_device and \p dst_virtual_device are (structurally) the same. */ -Expr MaybeDeviceCopy(Expr expr, SEScope src_se_scope, SEScope dst_se_scope); +Expr MaybeDeviceCopy(Expr expr, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device); /*! \brief Result of \p GetDeviceCopyProps. */ struct DeviceCopyProps { Expr body; // = null - SEScope src_se_scope = SEScope::FullyUnconstrained(); - SEScope dst_se_scope = SEScope::FullyUnconstrained(); + VirtualDevice src_virtual_device = VirtualDevice::FullyUnconstrained(); + VirtualDevice dst_virtual_device = VirtualDevice::FullyUnconstrained(); DeviceCopyProps() = default; - DeviceCopyProps(Expr body, SEScope src_se_scope, SEScope dst_se_scope) + DeviceCopyProps(Expr body, VirtualDevice src_virtual_device, VirtualDevice dst_virtual_device) : body(std::move(body)), - src_se_scope(std::move(src_se_scope)), - dst_se_scope(std::move(dst_se_scope)) {} + src_virtual_device(std::move(src_virtual_device)), + dst_virtual_device(std::move(dst_virtual_device)) {} }; /*! - * \brief Returns the body expression, source, and destination \p SEScopes for \p call_node + * \brief Returns the body expression, source, and destination \p VirtualDevices for \p call_node * if it is a "device_copy" CallNode. Otherwise returns the null expression and unconstrained - * device and scopes. + * virtual device. */ DeviceCopyProps GetDeviceCopyProps(const CallNode* call_node); /*! - * \brief Returns the body expression, source, and destination \p SEScopes for \p expr if it - * is a "device_copy" Call. 
Otherwise returns the null expression and unconstrained device and - * scopes. + * \brief Returns the body expression, source, and destination \p VirtualDevices for \p expr if it + * is a "device_copy" Call. Otherwise returns the null expression and unconstrained virtual device. */ DeviceCopyProps GetDeviceCopyProps(const Expr& expr); diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index 315ad9c3b6a5..b546bd5384e1 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -50,10 +50,10 @@ TVM_REGISTER_NODE_TYPE(AllocTensorAttrs); // The passing value in attrs and args doesn't seem super great. // We should consider a better solution, i.e the type relation // being able to see the arguments as well? -Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint) { +Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, DataType dtype_hint) { auto attrs = make_object(); attrs->dtype = dtype_hint; - attrs->se_scope = std::move(se_scope); + attrs->virtual_device = std::move(virtual_device); static const Op& op = Op::Get("memory.alloc_storage"); return Call(op, {std::move(size), std::move(alignment)}, Attrs(std::move(attrs)), {}); } diff --git a/src/relay/op/memory/memory.h b/src/relay/op/memory/memory.h index 9e93afdcfa37..690854c38293 100644 --- a/src/relay/op/memory/memory.h +++ b/src/relay/op/memory/memory.h @@ -25,7 +25,7 @@ #ifndef TVM_RELAY_OP_MEMORY_MEMORY_H_ #define TVM_RELAY_OP_MEMORY_MEMORY_H_ -#include +#include #include @@ -34,7 +34,7 @@ namespace tvm { namespace relay { -Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint); +Expr AllocStorage(Expr size, Expr alignment, VirtualDevice virtual_device, DataType dtype_hint); /*! \brief Returns the "memory.alloc_tensor" operator. 
*/ const Op& MemoryAllocTensorOp(); Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype, diff --git a/src/relay/op/memory/on_device.cc b/src/relay/op/memory/on_device.cc index 0fd86d3de67c..48e93ccf654d 100644 --- a/src/relay/op/memory/on_device.cc +++ b/src/relay/op/memory/on_device.cc @@ -43,11 +43,12 @@ const Op& OnDeviceOp() { return op; } -Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { - ICHECK((!constrain_result && !constrain_body) || !se_scope->IsFullyUnconstrained()); +Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) { + ICHECK((!constrain_result && !constrain_body) || !virtual_device->IsFullyUnconstrained()); auto attrs = make_object(); - attrs->se_scope = - (constrain_result || constrain_body) ? std::move(se_scope) : SEScope::FullyUnconstrained(); + attrs->virtual_device = (constrain_result || constrain_body) + ? std::move(virtual_device) + : VirtualDevice::FullyUnconstrained(); attrs->constrain_result = constrain_result; attrs->constrain_body = constrain_body; Span span = body->span; // about to be moved @@ -57,8 +58,9 @@ Call OnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain TVM_REGISTER_GLOBAL("relay.op.annotation._make.OnDevice").set_body_typed(OnDevice); -Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) { - if (se_scope->IsFullyUnconstrained()) { +Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result, + bool constrain_body) { + if (virtual_device->IsFullyUnconstrained()) { // Nothing to annotate with. return body; } @@ -72,40 +74,40 @@ Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result, bool cons } if (body->IsInstance()) { // If a primitive function then it is device polymorphic. Otherwise the device is captured - // by the function's "result_se_scope" attribute. 
+ // by the function's "result_virtual_device" attribute. return body; } OnDeviceProps props = GetOnDeviceProps(body); if (props.body.defined()) { // The user is asking for - // on_device(on_device(body, se_scope=inner), se_scope=outer) + // on_device(on_device(body, virtual_device=inner), virtual_device=outer) // ^ ^ ^ // outer middle inner // First recover the implied constraints (if any) for outer and inner, and check they don't // contradict. - const SEScope& inner = props.se_scope; - const SEScope& outer = se_scope; + const VirtualDevice& inner = props.virtual_device; + const VirtualDevice& outer = virtual_device; bool constrain_outer = constrain_result; bool constrain_inner = props.constrain_body; if (constrain_outer && constrain_inner) { - ICHECK(inner == outer) - << "Cannot constrain result and body of nested on_device calls to different SEScopes"; + ICHECK(inner == outer) << "Cannot constrain result and body of nested on_device calls to " + "different virtual devices"; } // There are two possible ways the middle sub-expression may be constrained, check they don't // contradict. bool constrain_middle_via_outer = constrain_body; bool constrain_middle_via_inner = props.constrain_result; if (constrain_middle_via_outer && constrain_middle_via_inner) { - ICHECK(inner == outer) - << "Cannot constrain intermediate result of nested on_device calls to different SEScopes"; + ICHECK(inner == outer) << "Cannot constrain intermediate result of nested on_device calls to " + "different virtual devices"; } // We can now ignore the middle constraint. - // If the outer on_device has any constraint then use se_scope given for it. - // Otherwise we can use the existing inner se_scope. + // If the outer on_device has any constraint then use virtual_device given for it. + // Otherwise we can use the existing inner virtual_device. return OnDevice(props.body, (constrain_inner || constrain_outer) ? 
outer : inner, constrain_outer, constrain_inner); } else { - return OnDevice(body, std::move(se_scope), constrain_result, constrain_body); + return OnDevice(body, std::move(virtual_device), constrain_result, constrain_body); } } @@ -127,7 +129,7 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node) { ICHECK(call_node->attrs.defined()) << "on_device requires attributes"; const auto* on_device_attrs = call_node->attrs.as(); ICHECK(on_device_attrs != nullptr) << "on_device requires OnDeviceAttrs"; - return {call_node->args[0], on_device_attrs->se_scope, on_device_attrs->constrain_result, + return {call_node->args[0], on_device_attrs->virtual_device, on_device_attrs->constrain_result, on_device_attrs->constrain_body}; } return {}; @@ -140,38 +142,42 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) { return {}; } -Function FunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope) { - return WithAttrs(std::move(function), {{tvm::attr::kParamSEScopes, std::move(param_se_scopes)}, - {tvm::attr::kResultSEScope, std::move(result_se_scope)}}); +Function FunctionOnDevice(Function function, Array param_virtual_devices, + VirtualDevice result_virtual_device) { + return WithAttrs(std::move(function), + {{tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices)}, + {tvm::attr::kResultVirtualDevice, std::move(result_virtual_device)}}); } TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice); -Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope) { - if (std::all_of(param_se_scopes.begin(), param_se_scopes.end(), - [](const SEScope& se_scope) { return se_scope->IsFullyUnconstrained(); }) && - result_se_scope->IsFullyUnconstrained()) { +Function MaybeFunctionOnDevice(Function function, Array param_virtual_devices, + VirtualDevice result_virtual_device) { + if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(), + [](const 
VirtualDevice& virtual_device) { + return virtual_device->IsFullyUnconstrained(); + }) && + result_virtual_device->IsFullyUnconstrained()) { // Nothing to annotate. return function; } - return FunctionOnDevice(function, std::move(param_se_scopes), std::move(result_se_scope)); + return FunctionOnDevice(function, std::move(param_virtual_devices), + std::move(result_virtual_device)); } -SEScope GetFunctionResultSEScope(const FunctionNode* function_node) { - auto opt_se_scope = function_node->GetAttr(tvm::attr::kResultSEScope); - return opt_se_scope.value_or(SEScope::FullyUnconstrained()); +VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) { + auto opt_virtual_device = function_node->GetAttr(tvm::attr::kResultVirtualDevice); + return opt_virtual_device.value_or(VirtualDevice::FullyUnconstrained()); } -SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i) { +VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) { ICHECK_LT(i, function_node->params.size()) << "param index " << i << " out of range for function of arity " << function_node->params.size(); - auto opt_array = function_node->GetAttr>(tvm::attr::kParamSEScopes); + auto opt_array = function_node->GetAttr>(tvm::attr::kParamVirtualDevice); if (!opt_array) { // No annotation. - return SEScope::FullyUnconstrained(); + return VirtualDevice::FullyUnconstrained(); } ICHECK_EQ(opt_array.value().size(), function_node->params.size()) << "annotation parameters do not match function arity"; diff --git a/src/relay/op/memory/on_device.h b/src/relay/op/memory/on_device.h index 2ebaf034c760..7489e3b62b0c 100644 --- a/src/relay/op/memory/on_device.h +++ b/src/relay/op/memory/on_device.h @@ -39,25 +39,25 @@ namespace relay { const Op& OnDeviceOp(); /*! - * \brief Wraps \p body in an "on_device" CallNode for \p se_scope. + * \brief Wraps \p body in an "on_device" CallNode for \p virtual_device. * * See \p OnDeviceAttrs for an overview. 
*/ -Call OnDevice(Expr body, SEScope se_scope, bool constrain_result = false, +Call OnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false, bool constrain_body = true); /*! \brief Result of \p GetOnDeviceProps. */ struct OnDeviceProps { Expr body; // = null - SEScope se_scope = SEScope::FullyUnconstrained(); + VirtualDevice virtual_device = VirtualDevice::FullyUnconstrained(); bool constrain_result = false; bool constrain_body = false; OnDeviceProps() = default; - OnDeviceProps(Expr body, SEScope se_scope, bool constrain_result, bool constrain_body) + OnDeviceProps(Expr body, VirtualDevice virtual_device, bool constrain_result, bool constrain_body) : body(std::move(body)), - se_scope(std::move(se_scope)), + virtual_device(std::move(virtual_device)), constrain_result(constrain_result), constrain_body(constrain_body) {} @@ -70,7 +70,8 @@ struct OnDeviceProps { * props. */ inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { - return OnDevice(std::move(body), props.se_scope, props.constrain_result, props.constrain_body); + return OnDevice(std::move(body), props.virtual_device, props.constrain_result, + props.constrain_body); } /*! @@ -80,50 +81,50 @@ inline Call OnDeviceWithProps(Expr body, const OnDeviceProps& props) { * choices. */ inline Call OnDeviceCopyOk(Expr body) { - return OnDevice(std::move(body), SEScope::FullyUnconstrained(), + return OnDevice(std::move(body), VirtualDevice::FullyUnconstrained(), /*constrain_result=*/false, /*constrain_body=*/false); } /*! - * \brief Wraps \p expr in an "on_device" CallNode for \p se_scope and \p constraint if the - * \p SEScope for \p expr cannot otherwise be recovered by the lexical scoping convention. + * \brief Wraps \p expr in an "on_device" CallNode for \p virtual_device and \p constraint if the + * \p VirtualDevice for \p expr cannot otherwise be recovered by the lexical scoping convention. 
* This means we will NOT wrap if: - * - \p se_scope is full unconstrained, which signals there are no device annotations + * - \p virtual_device is full unconstrained, which signals there are no device annotations * already in play. * - \p expr is an operator or primitive function literal. These are device polymorphic. * - \p expr is a non-primitive function literal. The device is captured by the - * "result_se_scope" attribute on the function itself. + * "result_virtual_device" attribute on the function itself. * - \p expr is a global var. The device is on the function attributes the global is bound to. * - \p expr is a local var. The device is tracked by the device aware visitors for us. * - \p expr is a constructor. These are device polymorphic. * Nested on_device calls will never be constructed, they are instead merged on-the-fly. */ -Expr MaybeOnDevice(Expr body, SEScope se_scope, bool constrain_result = false, +Expr MaybeOnDevice(Expr body, VirtualDevice virtual_device, bool constrain_result = false, bool constrain_body = true); /*! \brief As for MaybeOnDevice, but with both body and result constrained. */ -inline Expr MaybeOnDeviceFixed(Expr body, SEScope se_scope) { - return MaybeOnDevice(std::move(body), std::move(se_scope), /*constrain_result=*/true, +inline Expr MaybeOnDeviceFixed(Expr body, VirtualDevice virtual_device) { + return MaybeOnDevice(std::move(body), std::move(virtual_device), /*constrain_result=*/true, /*constrain_body=*/true); } /*! \brief As for MaybeOnDevice, but with fields other than body taken from \p props. */ inline Expr MaybeOnDeviceWithProps(Expr body, const OnDeviceProps& props) { - return MaybeOnDevice(std::move(body), props.se_scope, props.constrain_result, + return MaybeOnDevice(std::move(body), props.virtual_device, props.constrain_result, props.constrain_body); } /*! 
- * \brief Returns the body expression, \p SEScope, and constraint field for \p call_node if it + * \brief Returns the body expression, \p VirtualDevice, and constraint field for \p call_node if it * is an "on_device" CallNode. Otherwise returns the null expression, the unconstrained - * \p SEScope, and \p kBody. + * \p VirtualDevice, and \p kBody. */ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); /*! - * \brief Returns the body expression, \p SEScope, and constraint field for \p expr if it is an - * "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p SEScope, - * and \p kBody. + * \brief Returns the body expression, \p VirtualDevice, and constraint field for \p expr if it is + * an "on_device" CallNode. Otherwise returns the null expression, the unconstrained \p + * VirtualDevice, and \p kBody. */ OnDeviceProps GetOnDeviceProps(const Expr& expr); @@ -154,29 +155,31 @@ const NodeType* AsIgnoringOnDevice(const Expr& expr) { } /*! - * \brief Returns \p function annotated with "param_se_scopes" and "result_se_scope" - * attributes capturing parameter and result \p SEScopes respectively. + * \brief Returns \p function annotated with "param_virtual_devices" and "result_virtual_device" + * attributes capturing parameter and result \p VirtualDevices respectively. */ -Function FunctionOnDevice(Function function, Array param_se_scopes, SEScope body_se_scope); +Function FunctionOnDevice(Function function, Array param_virtual_devices, + VirtualDevice body_virtual_device); /*! * \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and - * result \p SEScopes are unconstrained. + * result \p VirtualDevices are unconstrained. */ -Function MaybeFunctionOnDevice(Function function, Array param_se_scopes, - SEScope result_se_scope); +Function MaybeFunctionOnDevice(Function function, Array param_virtual_devices, + VirtualDevice result_virtual_device); /*! 
- * \brief Returns the \p SEScope for the resut of \p function_node, or the unconstrained - * \p SEScope if function does not have the "result_se_scope" annotation. + * \brief Returns the \p VirtualDevice for the resut of \p function_node, or the unconstrained + * \p VirtualDevice if function does not have the "result_virtual_device" annotation. */ -SEScope GetFunctionResultSEScope(const FunctionNode* function_node); +VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node); /*! - * \brief Returns the \p SEScope for the \p i'th parameter of \p function_node, or - * the unconstrained \p SEScope if function does not have the "param_se_scopes" annotation. + * \brief Returns the \p VirtualDevice for the \p i'th parameter of \p function_node, or + * the unconstrained \p VirtualDevice if function does not have the "param_virtual_devices" + * annotation. */ -SEScope GetFunctionParamSEScope(const FunctionNode* function_node, size_t i); +VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc index 29965d2dac97..10584da51976 100644 --- a/src/relay/transforms/device_aware_visitors.cc +++ b/src/relay/transforms/device_aware_visitors.cc @@ -38,52 +38,52 @@ LexicalOnDeviceMixin::LexicalOnDeviceMixin(const Optional& maybe_mod) if (maybe_mod) { for (const auto& kv : maybe_mod.value()->functions) { if (const auto* function_node = kv.second.as()) { - SEScope se_scope = GetFunctionResultSEScope(function_node); - if (!se_scope->IsFullyUnconstrained()) { - VLOG(2) << "global '" << kv.first->name_hint << "' has scope " << se_scope; - global_var_se_scopes_.emplace(kv.first, se_scope); + VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node); + if (!virtual_device->IsFullyUnconstrained()) { + VLOG(2) << "global '" << kv.first->name_hint << "' has virtual device " << 
virtual_device; + global_var_virtual_devices_.emplace(kv.first, virtual_device); } } } } } -SEScope LexicalOnDeviceMixin::GetSEScope(const Expr& expr) const { +VirtualDevice LexicalOnDeviceMixin::GetVirtualDevice(const Expr& expr) const { OnDeviceProps props = GetOnDeviceProps(expr); if (props.body.defined() && props.is_fixed()) { - return props.se_scope; + return props.virtual_device; } else if (const auto* var_node = expr.as()) { // Lookup variable binding. - auto itr = var_se_scopes_.find(GetRef(var_node)); - if (itr != var_se_scopes_.end()) { + auto itr = var_virtual_devices_.find(GetRef(var_node)); + if (itr != var_virtual_devices_.end()) { return itr->second; } // else: fallthrough to unconstrained } else if (const auto* global_var_node = expr.as()) { // Lookup global variable. - auto itr = global_var_se_scopes_.find(GetRef(global_var_node)); - if (itr != global_var_se_scopes_.end()) { + auto itr = global_var_virtual_devices_.find(GetRef(global_var_node)); + if (itr != global_var_virtual_devices_.end()) { return itr->second; } // else: fallthrough to unconstrained } else if (const auto* function_node = expr.as()) { if (function_node->HasNonzeroAttr(attr::kPrimitive)) { - if (!expr_se_scopes_.empty()) { + if (!expr_virtual_devices_.empty()) { // Use the currently in-scope device type. - return expr_se_scopes_.back(); + return expr_virtual_devices_.back(); } // else: fallthrough to unconstrained } else { - return GetFunctionResultSEScope(function_node); + return GetFunctionResultVirtualDevice(function_node); } } else { - if (!expr_se_scopes_.empty()) { + if (!expr_virtual_devices_.empty()) { // Use the currently in-scope device type. 
- return expr_se_scopes_.back(); + return expr_virtual_devices_.back(); } // else: fallthrough to unconstrained } - return SEScope::FullyUnconstrained(); + return VirtualDevice::FullyUnconstrained(); } void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } @@ -93,34 +93,34 @@ void LexicalOnDeviceMixin::ExitFunctionBody() { --function_nesting_; } -void LexicalOnDeviceMixin::PushSEScope(const SEScope& se_scope) { - if (se_scope->IsFullyUnconstrained()) { +void LexicalOnDeviceMixin::PushVirtualDevice(const VirtualDevice& virtual_device) { + if (virtual_device->IsFullyUnconstrained()) { return; } - expr_se_scopes_.emplace_back(se_scope); + expr_virtual_devices_.emplace_back(virtual_device); } -void LexicalOnDeviceMixin::PopSEScope() { - if (expr_se_scopes_.empty()) { +void LexicalOnDeviceMixin::PopVirtualDevice() { + if (expr_virtual_devices_.empty()) { return; } - expr_se_scopes_.pop_back(); + expr_virtual_devices_.pop_back(); } -void LexicalOnDeviceMixin::PushBoundVar(Var var, const SEScope& se_scope) { - if (se_scope->IsFullyUnconstrained()) { +void LexicalOnDeviceMixin::PushBoundVar(Var var, const VirtualDevice& virtual_device) { + if (virtual_device->IsFullyUnconstrained()) { return; } - ICHECK(var_se_scopes_.find(var) == var_se_scopes_.end()); - var_se_scopes_.emplace(std::move(var), se_scope); + ICHECK(var_virtual_devices_.find(var) == var_virtual_devices_.end()); + var_virtual_devices_.emplace(std::move(var), virtual_device); } void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { - auto itr = var_se_scopes_.find(var); - if (itr == var_se_scopes_.end()) { + auto itr = var_virtual_devices_.find(var); + if (itr == var_virtual_devices_.end()) { return; } - var_se_scopes_.erase(itr); + var_virtual_devices_.erase(itr); } // TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing @@ -133,17 +133,17 @@ void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come 
into scope. for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + PushVirtualDevice(GetFunctionResultVirtualDevice(function_node)); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopSEScope(); + PopVirtualDevice(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -158,7 +158,7 @@ void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec). - PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value)); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -178,10 +178,10 @@ void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. - PushSEScope(props.se_scope); + PushVirtualDevice(props.virtual_device); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopSEScope(); + PopVirtualDevice(); } else { DeviceAwareVisitExpr_(call_node); } @@ -219,17 +219,17 @@ Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { } else { // Function parameters come into scope. 
for (size_t i = 0; i < function_node->params.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); } // Entering scope of function body. - PushSEScope(GetFunctionResultSEScope(function_node)); + PushVirtualDevice(GetFunctionResultVirtualDevice(function_node)); EnterFunctionBody(); Expr result = DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. ExitFunctionBody(); - PopSEScope(); + PopVirtualDevice(); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { PopBoundVar(function_node->params[i]); @@ -246,7 +246,7 @@ Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { while (const auto* inner_let_node = expr.as()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - PushBoundVar(inner_let_node->var, GetSEScope(inner_let_node->value)); + PushBoundVar(inner_let_node->var, GetVirtualDevice(inner_let_node->value)); std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); expr = inner_let_node->body; @@ -269,10 +269,10 @@ Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { OnDeviceProps props = GetOnDeviceProps(call_node); if (props.body.defined() && props.is_fixed()) { // Entering lexical scope of fixed "on_device" call. - PushSEScope(props.se_scope); + PushVirtualDevice(props.virtual_device); Expr expr = VisitExpr(props.body); // Leaving lexical scope of "on_device" call. 
- PopSEScope(); + PopVirtualDevice(); return MaybeOnDeviceWithProps(expr, props); } else { return DeviceAwareVisitExpr_(call_node); diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h index 044cda85c579..9340c03fc2d5 100644 --- a/src/relay/transforms/device_aware_visitors.h +++ b/src/relay/transforms/device_aware_visitors.h @@ -42,7 +42,7 @@ namespace relay { namespace transform { /*! - * \brief Helper class for expression transformers which need to keep track of the \p SEScope + * \brief Helper class for expression transformers which need to keep track of the \p VirtualDevice * holding the results of expressions. This is recovered from function attributes and "on_device" * CallNodes added by the PlanDevices pass. * @@ -53,11 +53,11 @@ class LexicalOnDeviceMixin { explicit LexicalOnDeviceMixin(const Optional& maybe_mod); /*! - * \brief Returns the \p SEScope on which the result of \p expr should/will be stored, assuming - * {Push,Pop}{SEScope,BoundVar} have been correctly called. May return the unconstrained - * \p SEScope if the device planning pass has not been run. + * \brief Returns the \p VirtualDevice on which the result of \p expr should/will be stored, + * assuming {Push,Pop}{VirtualDevice,BoundVar} have been correctly called. May return the + * unconstrained \p VirtualDevice if the device planning pass has not been run. */ - SEScope GetSEScope(const Expr& expr) const; + VirtualDevice GetVirtualDevice(const Expr& expr) const; /*! \brief Indicate a function body is being entered. */ void EnterFunctionBody(); @@ -65,19 +65,21 @@ class LexicalOnDeviceMixin { /*! \brief Indicate a function body has been processed. */ void ExitFunctionBody(); - /*! \brief Push an \p SEScope onto the lexical SEScope stack. Ignore if unconstrained. */ - void PushSEScope(const SEScope& se_scope); + /*! \brief Push an \p VirtualDevice onto the lexical VirtualDevice stack. Ignore if unconstrained. 
+ */ + void PushVirtualDevice(const VirtualDevice& virtual_device); - /*! \brief Pop an \p SEScope from the lexical SEScope stack. Ignore if stack is empty. */ - void PopSEScope(); + /*! \brief Pop an \p VirtualDevice from the lexical VirtualDevice stack. Ignore if stack is empty. + */ + void PopVirtualDevice(); - /*! \brief Remember that \p var will be stored at \p se_scope. Ignore if unconstrained. + /*! \brief Remember that \p var will be stored at \p virtual_device. Ignore if unconstrained. * * CAUTION: Despite the name we don't support re-entering the same function body. */ - void PushBoundVar(Var var, const SEScope& se_scope); + void PushBoundVar(Var var, const VirtualDevice& virtual_device); - /*! \brief Remove the binding for \p var to its \p SEScope. Ignore if var is not bound. */ + /*! \brief Remove the binding for \p var to its \p VirtualDevice. Ignore if var is not bound. */ void PopBoundVar(const Var& var); /*! @@ -93,36 +95,37 @@ class LexicalOnDeviceMixin { int function_nesting_ = 0; /*! - * \brief The stack of lexically enclosing "on_device" \p SEScopes, from outermost to + * \brief The stack of lexically enclosing "on_device" \p VirtualDevices, from outermost to * innermost. When visiting an expression other than a variable we can assume the expression's - * result is to be stored on \p expr_se_scopes.back(). + * result is to be stored on \p expr_virtual_devices.back(). */ - std::vector expr_se_scopes_; + std::vector expr_virtual_devices_; /*! - * \brief A map from in-scope local variables to their \p SEScopes. We may assume the variable is - * only ever bound to a value stored on this \p SEScope at runtime. + * \brief A map from in-scope local variables to their \p VirtualDevices. We may assume the + * variable is only ever bound to a value stored on this \p VirtualDevice at runtime. * * Note: We're playing it safe and keying by object refs here just in case the Relay expression * being rewritten has no module or other global to keep it alive. 
*/ - std::unordered_map var_se_scopes_; + std::unordered_map + var_virtual_devices_; /*! - * \brief A map from global variables to their \p SEScopes, ie the "result_se_scope" of the - * function they are bound to in the module we are working on. We calculate and store this + * \brief A map from global variables to their \p VirtualDevices, ie the "result_virtual_device" + * of the function they are bound to in the module we are working on. We calculate and store this * explicitly so that we don't need to hold on to any module, which is often in the process of * being rewritten. */ - std::unordered_map - global_var_se_scopes_; + std::unordered_map + global_var_virtual_devices_; }; template class DeviceAwareExprFunctor; /*! - * \brief ExprFunctor which tracks \p SEScopes. We only support 'visitor' style implementation + * \brief ExprFunctor which tracks \p VirtualDevices. We only support 'visitor' style implementation * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without * any memoization. */ @@ -143,21 +146,21 @@ class DeviceAwareExprFunctor : public ExprFunctorparams.size(); ++i) { - PushBoundVar(function_node->params[i], GetFunctionParamSEScope(function_node, i)); + PushBoundVar(function_node->params[i], GetFunctionParamVirtualDevice(function_node, i)); } // Entering scope of function body. - SEScope se_scope = GetFunctionResultSEScope(function_node); - VLOG(2) << "entering " << se_scope << " for function:" << std::endl + VirtualDevice virtual_device = GetFunctionResultVirtualDevice(function_node); + VLOG(2) << "entering " << virtual_device << " for function:" << std::endl << PrettyPrint(GetRef(function_node)); - PushSEScope(se_scope); + PushVirtualDevice(virtual_device); EnterFunctionBody(); DeviceAwareVisitExpr_(function_node); // Leaving scope of function body. 
ExitFunctionBody(); - PopSEScope(); - VLOG(2) << "leaving " << se_scope << " for function:" << std::endl + PopVirtualDevice(); + VLOG(2) << "leaving " << virtual_device << " for function:" << std::endl << PrettyPrint(GetRef(function_node)); // Function parameters go out of scope. for (size_t i = 0; i < function_node->params.size(); ++i) { @@ -173,9 +176,10 @@ class DeviceAwareExprFunctor : public ExprFunctor()) { // Let-bound var (in pre visited version) goes into scope. // (We'll just assume this is a letrec.) - SEScope se_scope = GetSEScope(inner_let_node->value); - VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has scope " << se_scope; - PushBoundVar(inner_let_node->var, se_scope); + VirtualDevice virtual_device = GetVirtualDevice(inner_let_node->value); + VLOG(2) << "var '" << inner_let_node->var->name_hint() << "' has virtual device " + << virtual_device; + PushBoundVar(inner_let_node->var, virtual_device); PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); bindings.emplace_back(inner_let_node); expr = inner_let_node->body; @@ -196,13 +200,13 @@ class DeviceAwareExprFunctor : public ExprFunctor(call_node)); - PushSEScope(props.se_scope); + PushVirtualDevice(props.virtual_device); VisitExpr(props.body); // Leaving lexical scope of "on_device" call. - PopSEScope(); - VLOG(2) << "leaving " << props.se_scope << " for on_device:" << std::endl + PopVirtualDevice(); + VLOG(2) << "leaving " << props.virtual_device << " for on_device:" << std::endl << PrettyPrint(GetRef(call_node)); } else { DeviceAwareVisitExpr_(call_node); @@ -210,8 +214,8 @@ class DeviceAwareExprFunctor : public ExprFunctor : public ExprFunctor& maybe_mod) @@ -267,8 +271,8 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { void VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be - * tracked automatically. 
Default implementation defers to ExprMutator::VisitExpr_. For + * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters + * will be tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); @@ -281,9 +285,9 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the - * let-bound variable will be tracked automatically. Default implementation just visits var and - * value. + * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for + * the let-bound variable will be tracked automatically. Default implementation just visits var + * and value. */ virtual void PreVisitLetBinding_(const Var& var, const Expr& value); @@ -300,7 +304,7 @@ class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { virtual void PostVisitLetBlock_(const LetNode* let_node); }; -/*! \brief ExprMutator which tracks \p SEScopes. */ +/*! \brief ExprMutator which tracks \p VirtualDevices. */ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { public: explicit DeviceAwareExprMutator(const Optional& maybe_mod) @@ -311,8 +315,8 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { Expr VisitExpr_(const CallNode* call_node) final; /*! - * \brief These are as for VisitExpr_. \p SEScopes for expressions and function parameters will be - * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * \brief These are as for VisitExpr_. \p VirtualDevices for expressions and function parameters + * will be tracked automatically. 
Default implementation defers to ExprMutator::VisitExpr_. For * functions the function_nesting count will already include that of \p function_node. */ virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); @@ -325,9 +329,9 @@ class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { virtual void PreVisitLetBlock_(const LetNode* let_node); /*! - * \brief Visit a let-bound expression before the let body has been visited. \p SEScopes for the - * let-bound variable will be tracked automatically. Default implementation just visits var and - * value. + * \brief Visit a let-bound expression before the let body has been visited. \p VirtualDevices for + * the let-bound variable will be tracked automatically. Default implementation just visits var + * and value. */ virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc index fd46a6dc0563..95249f902b48 100644 --- a/src/relay/transforms/device_domains.cc +++ b/src/relay/transforms/device_domains.cc @@ -37,43 +37,44 @@ namespace relay { namespace transform { DeviceDomains::DeviceDomains(CompilationConfig config) : config_(std::move(config)) { - host_domain_ = MakeFirstOrderDomain(config_->host_se_scope); + host_domain_ = MakeFirstOrderDomain(config_->host_virtual_device); } -DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const SEScope& se_scope) { - if (se_scope->IsFullyConstrained()) { - auto itr = fully_constrained_se_scope_to_domain_.find(se_scope); - if (itr != fully_constrained_se_scope_to_domain_.end()) { +DeviceDomainPtr DeviceDomains::MakeFirstOrderDomain(const VirtualDevice& virtual_device) { + if (virtual_device->IsFullyConstrained()) { + auto itr = fully_constrained_virtual_device_to_domain_.find(virtual_device); + if (itr != fully_constrained_virtual_device_to_domain_.end()) { return itr->second; } - DeviceDomainPtr domain = std::make_shared(se_scope); - 
fully_constrained_se_scope_to_domain_.emplace(se_scope, domain); + DeviceDomainPtr domain = std::make_shared(virtual_device); + fully_constrained_virtual_device_to_domain_.emplace(virtual_device, domain); return domain; } else { - return std::make_shared(se_scope); + return std::make_shared(virtual_device); } } -DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const SEScope& se_scope) { +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, const VirtualDevice& virtual_device) { if (const auto* func_type_node = type.as()) { std::vector args_and_result; args_and_result.reserve(func_type_node->arg_types.size() + 1); for (const auto& arg_type : func_type_node->arg_types) { - args_and_result.emplace_back(MakeDomain(arg_type, SEScope::FullyUnconstrained())); + args_and_result.emplace_back(MakeDomain(arg_type, VirtualDevice::FullyUnconstrained())); } - args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, se_scope)); + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, virtual_device)); return std::make_shared(std::move(args_and_result)); } else { - return MakeFirstOrderDomain(se_scope); + return MakeFirstOrderDomain(virtual_device); } } -DeviceDomainPtr DeviceDomains::ForSEScope(const Type& type, const SEScope& non_canonical_se_scope) { - // Generally se_scope will have come from an annotation so resolve it to ensure we have +DeviceDomainPtr DeviceDomains::ForVirtualDevice(const Type& type, + const VirtualDevice& non_canonical_virtual_device) { + // Generally the virtual device will have come from an annotation so resolve it to ensure we have // its canonical representation. 
- SEScope se_scope = config_->CanonicalSEScope(non_canonical_se_scope); - ICHECK(!se_scope->IsFullyUnconstrained()); - return MakeDomain(type, se_scope); + VirtualDevice virtual_device = config_->CanonicalVirtualDevice(non_canonical_virtual_device); + ICHECK(!virtual_device->IsFullyUnconstrained()); + return MakeDomain(type, virtual_device); } DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { @@ -110,17 +111,18 @@ DeviceDomainPtr DeviceDomains::JoinOrNull(const DeviceDomainPtr& lhs, const Devi << "do not have the same kind and can't be unified."; if (lhs->args_and_result_.empty()) { // Directly compare first-order. - if (rhs->se_scope_->IsFullyUnconstrained()) { + if (rhs->virtual_device_->IsFullyUnconstrained()) { return lhs; } - if (lhs->se_scope_->IsFullyUnconstrained()) { + if (lhs->virtual_device_->IsFullyUnconstrained()) { return rhs; } - Optional joined_se_scope = SEScope::Join(lhs->se_scope_, rhs->se_scope_); - if (!joined_se_scope) { + Optional joined_virtual_device = + VirtualDevice::Join(lhs->virtual_device_, rhs->virtual_device_); + if (!joined_virtual_device) { return nullptr; } - return MakeFirstOrderDomain(config_->CanonicalSEScope(joined_se_scope.value())); + return MakeFirstOrderDomain(config_->CanonicalVirtualDevice(joined_virtual_device.value())); } else { // Recurse for higher-order. std::vector args_and_result; @@ -205,41 +207,42 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // all the argument and result devices domains must be equal, ignoring memory scopes. // So at this point we'll let all the arguments and result be free so that memory scopes can // differ. - // TODO(mbs): As per header comments, need to revisit when can setup sub-SEScope constraints. + // TODO(mbs): As per header comments, need to revisit when can setup sub-virtual device + // constraints. 
return DomainFor(call_lowered_props.lowered_func); } else if (on_device_props.body.defined()) { // By default: - // on_device(expr, se_scope=) + // on_device(expr, virtual_device=) // on_device : fn():?x? // However we'll interpret the constrain_body and constrain_result fields to decide // on free vs constrained domains for the argument and result respectively. if (on_device_props.constrain_body) { args_and_result.emplace_back( - ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); + ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device)); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } if (on_device_props.constrain_result) { args_and_result.emplace_back( - ForSEScope(on_device_props.body->checked_type(), on_device_props.se_scope)); + ForVirtualDevice(on_device_props.body->checked_type(), on_device_props.virtual_device)); } else { args_and_result.emplace_back(Free(on_device_props.body->checked_type())); } } else if (device_copy_props.body.defined()) { - // device_copy(expr, src_se_scope=, dst_se_scope=) + // device_copy(expr, src_virtual_device=, dst_virtual_device=) // device_copy: fn(): - args_and_result.emplace_back( - ForSEScope(device_copy_props.body->checked_type(), device_copy_props.src_se_scope)); - args_and_result.emplace_back( - ForSEScope(device_copy_props.body->checked_type(), device_copy_props.dst_se_scope)); + args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(), + device_copy_props.src_virtual_device)); + args_and_result.emplace_back(ForVirtualDevice(device_copy_props.body->checked_type(), + device_copy_props.dst_virtual_device)); } else if (call->op == alloc_storage_op) { ICHECK_EQ(call->args.size(), 2U); - // alloc_storage(size, alignment, se_scope=) + // alloc_storage(size, alignment, virtual_device=) // alloc_storage: fn(, ): const auto* attrs = call->attrs.as(); args_and_result.emplace_back(host_domain_); 
args_and_result.emplace_back(host_domain_); - args_and_result.emplace_back(ForSEScope(call->checked_type(), attrs->se_scope)); + args_and_result.emplace_back(ForVirtualDevice(call->checked_type(), attrs->virtual_device)); } else if (call->op == alloc_tensor_op) { ICHECK_EQ(call->args.size(), 3U); // alloc_tensor(storage, offset, shape) @@ -277,7 +280,7 @@ DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { // (arg1, ..., argn) // : fn(?x?, ..., ?x?):?x? // (all args and result must be first-order). - auto free_domain = MakeFirstOrderDomain(SEScope::FullyUnconstrained()); + auto free_domain = MakeFirstOrderDomain(VirtualDevice::FullyUnconstrained()); for (size_t i = 0; i < call->args.size(); ++i) { args_and_result.emplace_back(free_domain); } @@ -314,12 +317,12 @@ void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { auto rhs_domain = DomainFor(rhs); if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible SEScopes for expressions:" << std::endl + LOG(FATAL) << "Incompatible virtual devices for expressions:" << std::endl << PrettyPrint(lhs) << std::endl - << "with scope:" << std::endl + << "with virtual device:" << std::endl << ToString(lhs_domain) << "and:" << std::endl << PrettyPrint(rhs) << std::endl - << "with scope:" << std::endl + << "with virtual device:" << std::endl << ToString(rhs_domain); } } @@ -332,21 +335,21 @@ void DeviceDomains::OptionalUnifyExprExact(const Expr& lhs, const Expr& rhs) { if (UnifyOrNull(lhs_domain, rhs_domain) == nullptr) { // Rollback domain_to_equiv_ = domain_to_equiv_snapshot; - VLOG(2) << "Unable to unify SEScopes for expression:" << std::endl + VLOG(2) << "Unable to unify virtual devices for expression:" << std::endl << PrettyPrint(lhs) << std::endl - << "with scope:" << std::endl + << "with virtual device:" << std::endl << ToString(lhs_domain) << std::endl << "and expression:" << std::endl << PrettyPrint(rhs) << std::endl - << "with 
scope:" << std::endl + << "with virtual device:" << std::endl << ToString(rhs_domain) << std::endl - << ". Leaving scopes non-unified."; + << ". Leaving virtual devices non-unified."; } else { - VLOG(2) << "Unified SEScopes for expression:" << std::endl + VLOG(2) << "Unified virtual devices for expression:" << std::endl << PrettyPrint(lhs) << std::endl << "and expression:" << std::endl << PrettyPrint(rhs) << std::endl - << "to scope:" << std::endl + << "to virtual devices:" << std::endl << ToString(lhs_domain); } } @@ -355,11 +358,11 @@ void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expe auto actual_domain = DomainFor(expr); if (UnifyOrNull(actual_domain, expected_domain) == nullptr) { // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl + LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl << PrettyPrint(expr) << std::endl - << "with actual scope:" << std::endl + << "with actual virtual device:" << std::endl << ToString(actual_domain) << std::endl - << "and expected scope:" << std::endl + << "and expected virtual device:" << std::endl << ToString(expected_domain); } } @@ -369,11 +372,11 @@ void DeviceDomains::UnifyExprCollapsed(const Expr& expr_first_order, auto actual_domain_first_order = DomainFor(expr_first_order); if (!UnifyCollapsedOrFalse(actual_domain_first_order, expected_domain_maybe_higher_order)) { // TODO(mbs): Proper diagnostics. 
- LOG(FATAL) << "Incompatible SEScopes for expression:" << std::endl + LOG(FATAL) << "Incompatible virtual devices for expression:" << std::endl << PrettyPrint(expr_first_order) << std::endl - << "with actual scope:" << std::endl + << "with actual virtual devices:" << std::endl << ToString(actual_domain_first_order) << std::endl - << "and expected scope:" << std::endl + << "and expected virtual device:" << std::endl << ToString(expected_domain_maybe_higher_order); } } @@ -382,7 +385,7 @@ bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) { domain = Lookup(domain); if (domain->args_and_result_.empty()) { // First-order. - return domain->se_scope_->IsFullyConstrained(); + return domain->virtual_device_->IsFullyConstrained(); } else { // Higher-order. return std::all_of( @@ -391,30 +394,31 @@ bool DeviceDomains::IsFullyConstrained(DeviceDomainPtr domain) { } } -void DeviceDomains::SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope) { - ICHECK(!default_se_scope->IsFullyUnconstrained()); +void DeviceDomains::SetDefault(DeviceDomainPtr domain, + const VirtualDevice& default_virtual_device) { + ICHECK(!default_virtual_device->IsFullyUnconstrained()); domain = Lookup(domain); if (domain->args_and_result_.empty()) { - DeviceDomainPtr defaulted_domain_ptr = - UnifyOrNull(domain, MakeFirstOrderDomain(config_->CanonicalSEScope( - SEScope::Default(domain->se_scope_, default_se_scope)))); + DeviceDomainPtr defaulted_domain_ptr = UnifyOrNull( + domain, MakeFirstOrderDomain(config_->CanonicalVirtualDevice( + VirtualDevice::Default(domain->virtual_device_, default_virtual_device)))); ICHECK_NOTNULL(defaulted_domain_ptr); } else { for (const auto& sub_domain : domain->args_and_result_) { - SetDefault(sub_domain, default_se_scope); + SetDefault(sub_domain, default_virtual_device); } } } void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, - const SEScope& default_se_scope) { + const VirtualDevice& 
default_virtual_device) { if (domain_maybe_higher_order->args_and_result_.empty()) { - SetDefault(domain_maybe_higher_order, default_se_scope); + SetDefault(domain_maybe_higher_order, default_virtual_device); } else { // First set default for result domain. - SetDefault(ResultDomain(domain_maybe_higher_order), default_se_scope); + SetDefault(ResultDomain(domain_maybe_higher_order), default_virtual_device); // Then use current result domain as default for everything else. - SetDefault(domain_maybe_higher_order, ResultSEScope(domain_maybe_higher_order)); + SetDefault(domain_maybe_higher_order, ResultVirtualDevice(domain_maybe_higher_order)); } } @@ -431,11 +435,11 @@ std::string DeviceDomains::ToString(DeviceDomainPtr domain) { std::ostringstream os; if (domain->args_and_result_.empty()) { // First-order. - if (!domain->se_scope_->IsFullyConstrained()) { + if (!domain->virtual_device_->IsFullyConstrained()) { os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; } - if (!domain->se_scope_->IsFullyUnconstrained()) { - os << domain->se_scope_; + if (!domain->virtual_device_->IsFullyUnconstrained()) { + os << domain->virtual_device_; } } else { // higher-order diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h index 223c7d42bfa1..983ecb4b6d5d 100644 --- a/src/relay/transforms/device_domains.h +++ b/src/relay/transforms/device_domains.h @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include @@ -51,7 +51,7 @@ class DeviceDomains; * * \code * D ::= ?x? -- first order, free - * | -- first order, bound to specific device and memory scope + * | -- first order, bound to specific virtual device * | fn(D1, ..., Dn):Dr -- higher order * \endcode * @@ -59,31 +59,32 @@ class DeviceDomains; * a notion of the 'result domain' of a domain: * \code * result_domain(?x?) = ?x? 
- * result_domain() = + * result_domain() = * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) * \endcode * - * TODO(mbs): We currently don't allow sub-SEScope constraints. Eg for a function we can - * express that the argument and result SEScopes must be exactly equal, but we cannot express + * TODO(mbs): We currently don't allow sub-VirtualDevice constraints. Eg for a function we can + * express that the argument and result VirtualDevices must be exactly equal, but we cannot express * that though the devices and targets for arguments and results must be equal, it is ok for * memory scopes to differ. At the moment we can get away with this since we run PlanDevices * twice: once with all memory scopes unconstrained, then again with just memory scopes as * the new property to flow. However we're on thin ice here and better would be to allow - * constraints on SEScopes to be exploded into their device/target component and their - * memory scope component. Should we fold layout constraints into SEScopes then they would + * constraints on VirtualDevices to be exploded into their device/target component and their + * memory scope component. Should we fold layout constraints into VirtualDevices then they would * probably be grouped with memory scopes. */ class DeviceDomain { public: /*! - * \brief Constructs a first-order domain for \p se_scope, which may be - * fully free (ie se_scope is unconstrained), partially free (ie se_scope has at least on - * of its target, device id or memory scopes known), or fully fixed (ie se_scope has its target, - * device id and memory scopes set). + * \brief Constructs a first-order domain for \p virtual_device, which may be + * fully free (ie virtual_device is unconstrained), partially free (ie virtual_device has at + * least one of its target, device id or memory scopes known), or fully fixed (ie virtual_device + * has its target, device id and memory scopes set). 
* * CAUTION: Use DeviceDomains::MakeFirstOrderDomain instead of this ctor. */ - explicit DeviceDomain(SEScope se_scope) : se_scope_(std::move(se_scope)) {} + explicit DeviceDomain(VirtualDevice virtual_device) + : virtual_device_(std::move(virtual_device)) {} /*! * \brief Constructs a higher-order domain, where \p args_and_result contain the @@ -92,13 +93,14 @@ class DeviceDomain { * CAUTION: Use DeviceDomains::MakeHigherOrderDomain instead of this ctor. */ explicit DeviceDomain(std::vector args_and_result) - : se_scope_(SEScope::FullyUnconstrained()), args_and_result_(std::move(args_and_result)) {} + : virtual_device_(VirtualDevice::FullyUnconstrained()), + args_and_result_(std::move(args_and_result)) {} bool is_higher_order() const { return !args_and_result_.empty(); } - SEScope first_order_se_scope() const { + VirtualDevice first_order_virtual_device() const { ICHECK(args_and_result_.empty()) << "expecting domain to be first-order"; - return se_scope_; + return virtual_device_; } size_t function_arity() const { @@ -124,7 +126,7 @@ class DeviceDomain { * (for example, the \p target and \p device_type are constrained but the \p virtual_device_id and * \p memory_scope are still unconstrained), or fully constrained (everything is known). */ - const SEScope se_scope_; + const VirtualDevice virtual_device_; /*! * \brief If this is a function domain then the sub-domains for each of the function's @@ -146,10 +148,10 @@ class DeviceDomains { const CompilationConfig& config() const { return config_; } /*! - * \brief Returns the domain representing \p se_scope. If \p se_scope is fully constrained - * then the domain will be unique that \p se_scope. + * \brief Returns the domain representing \p virtual_device. If \p virtual_device is fully + * constrained then the domain will be unique that \p virtual_device. */ - DeviceDomainPtr MakeFirstOrderDomain(const SEScope& se_scope); + DeviceDomainPtr MakeFirstOrderDomain(const VirtualDevice& virtual_device); /*! 
* \brief Returns a higher-order domain with \p args_and_results. @@ -159,21 +161,24 @@ class DeviceDomains { } /*! - * \brief Returns a domain appropriate for \p type who's result domain is bound to \p se_scope. - * If \p type is a function then all parameter domains will be completely free. It is valid for - * \p se_scope to be fully unconstrained. + * \brief Returns a domain appropriate for \p type whose result domain is bound to \p + * virtual_device. If \p type is a function then all parameter domains will be completely free. It + * is valid for \p virtual_device to be fully unconstrained. */ - DeviceDomainPtr MakeDomain(const Type& type, const SEScope& se_scope); + DeviceDomainPtr MakeDomain(const Type& type, const VirtualDevice& virtual_device); /*! - * \brief Returns a domain with the given result appropriate \p non_canonical_se_scope, - * which cannot be fully unconstrained. We first canonicalize the scope to unsure it has + * \brief Returns a domain with the given result appropriate \p non_canonical_virtual_device, + * which cannot be fully unconstrained. We first canonicalize the virtual device to ensure it has * a target and is unique. */ - DeviceDomainPtr ForSEScope(const Type& type, const SEScope& non_canonical_se_scope); + DeviceDomainPtr ForVirtualDevice(const Type& type, + const VirtualDevice& non_canonical_virtual_device); /*! \brief Returns a free domain appropriate for \p type. */ - DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, SEScope::FullyUnconstrained()); } + DeviceDomainPtr Free(const Type& type) { + return MakeDomain(type, VirtualDevice::FullyUnconstrained()); + } /*! \brief Returns the domain representing the equivalence class containing \p domain. */ DeviceDomainPtr Lookup(DeviceDomainPtr domain); @@ -274,16 +279,16 @@ class DeviceDomains { /*! \brief Returns true if \p domain is fully constrainted. */ bool IsFullyConstrained(DeviceDomainPtr domain); - /*! 
\brief Force all \p SEScopes in \p domain to default to \p default_se_scope. */ - void SetDefault(DeviceDomainPtr domain, const SEScope& default_se_scope); + /*! \brief Force all \p VirtualDevices in \p domain to default to \p default_virtual_device. */ + void SetDefault(DeviceDomainPtr domain, const VirtualDevice& default_virtual_device); /*! - * \brief If \p domain is higher-order default it's result domain to \p default_se_scope. - * Then force all remaining \p SEScopes to the result domain (freshly defaulted or original). - * If \p domain is first-order same as \p SetDefault. + * \brief If \p domain is higher-order default it's result domain to \p default_virtual_device. + * Then force all remaining \p VirtualDevices to the result domain (freshly defaulted or + * original). If \p domain is first-order same as \p SetDefault. */ void SetResultDefaultThenParams(const DeviceDomainPtr& domain_maybe_higher_order, - const SEScope& default_se_scope); + const VirtualDevice& default_virtual_device); /*! * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). @@ -291,11 +296,11 @@ class DeviceDomains { DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); /*! - * \brief Returns the result \p SEScope (possibly unconstrained) for \p domain + * \brief Returns the result \p VirtualDevice (possibly unconstrained) for \p domain * (see defn in DeviceDomain comment). */ - SEScope ResultSEScope(const DeviceDomainPtr& domain) { - return ResultDomain(domain)->first_order_se_scope(); + VirtualDevice ResultVirtualDevice(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_virtual_device(); } /*! \brief Returns one-line description of \p domain for debugging. */ @@ -332,16 +337,17 @@ class DeviceDomains { std::unordered_map domain_to_equiv_; /*! - * \brief Maps fully constrained \p SEScopes to their corresponding domains. 
By sharing those - * domains we can ensure: + * \brief Maps fully constrained \p VirtualDevices to their corresponding domains. By sharing + * those domains we can ensure: * * \code * domain0 != domain1 && domain0 fully constrained && domain1 fully constrained * ==> domain0 and domain1 are incompatible * \endcode */ - std::unordered_map - fully_constrained_se_scope_to_domain_; + std::unordered_map + fully_constrained_virtual_device_to_domain_; }; } // namespace transform diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc index bad8363f4783..d40dd6c95089 100644 --- a/src/relay/transforms/device_planner.cc +++ b/src/relay/transforms/device_planner.cc @@ -19,14 +19,14 @@ /*! * \file src/relay/transforms/device_planner.cc - * \brief Determines a unique \p SEScope to hold the result of every Relay sub-expression. + * \brief Determines a unique \p VirtualDevice to hold the result of every Relay sub-expression. * This pass can be run multiple times, and can be run both before and after lowering. * - * TODO(mbs): Rename SEScope |-> VirtualDevice, and use 'virtual device' (or just 'device') + * TODO(mbs): Rename VirtualDevice |-> VirtualDevice, and use 'virtual device' (or just 'device') * throughout. * * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. - * We represent D by an \p SEScope, which means we can track anywhere from an arbitrary device + * We represent D by an \p VirtualDevice, which means we can track anywhere from an arbitrary device * of some \p DLDeviceType to a specific memory scope on a specific (virtual) \p Device who's * code is compiled with a specific \p Target. * @@ -37,17 +37,17 @@ * resolve any remaining undetermined devices, and encoding the results on the output in a form * that's reasonably friendly to downstream passes. 
* - * Specific \p SEScopes flow into the constraints from five places: + * Specific \p VirtualDevices flow into the constraints from five places: * - Existing "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a - * 'src_se_scope' and 'dst_se_scope' \p SEScope. Those constrain the argument and context of - * the call respectively. It is ok if source and destination devices are the same, such no-op - * copies will be removed after accounting for the device preference. - * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an 'se_scope', - * which constrains the argument of the call, but (usually, see below) leaves the context - * unconstrained. These are called 'annotations' in the rest of the code, have no operational - * significance by themselves, but may trigger the insertion of a new "device_copy" call by - * this pass. In two situations the result of an "on_device" CallNode may also be constrained - * to the given 'se_scope': + * 'src_virtual_device' and 'dst_virtual_device' \p VirtualDevice. Those constrain the argument + * and context of the call respectively. It is ok if source and destination devices are the same, + * such no-op copies will be removed after accounting for the device preference. + * - Existing "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify an + * 'virtual_device', which constrains the argument of the call, but (usually, see below) leaves the + * context unconstrained. These are called 'annotations' in the rest of the code, have no + * operational significance by themselves, but may trigger the insertion of a new "device_copy" call + * by this pass. In two situations the result of an "on_device" CallNode may also be constrained to + * the given 'virtual_device': * - The "on_device" call occurs at the top-level of a function body, or occurs as an * immediately let-bound expression. 
In this situation the extra degree of freedom in * the function result and let-binding leads to surprising device copies, so we simply @@ -56,12 +56,12 @@ * it ourselves during an earlier invocation of this pass. This helps make this pass * idempotent. * - Some special operators require their arguments or results to be on the 'host' (typcially - * a CPU) \p SEScope, see below. + * a CPU) \p VirtualDevice, see below. * - Any \p PrimFuncs in the \p IRModule (if \p LowerTEPass has already run) may constrain their - * argument buffers to have a specific memory scope, which is part of \p SEScope. - * - Annotations left over from a previous run of this pass, such as 'param_se_scopes' and - * 'result_se_scope' function attributes we introduce below. This is so the pass is idempotent - * and can be re-run to flow additional memory scope constraints. + * argument buffers to have a specific memory scope, which is part of \p VirtualDevice. + * - Annotations left over from a previous run of this pass, such as 'param_virtual_devices' and + * 'result_virtual_device' function attributes we introduce below. This is so the pass is + * idempotent and can be re-run to flow additional memory scope constraints. * * We proceed in four phases: * @@ -114,8 +114,8 @@ * * Phase 2 * ------- - * After flowing constraints we apply some defaulting heuristics (using a global default \p SEScope) - * to fix the device for any as-yet unconstrained sub-expressions. + * After flowing constraints we apply some defaulting heuristics (using a global default \p + * VirtualDevice) to fix the device for any as-yet unconstrained sub-expressions. * - Unconstrained function result devices default to the global default device. * - Unconstrained function parameters devices default to the device for the function result. * - Unconstrained let-bound expression devices default to the device for the overall let. 
@@ -127,9 +127,9 @@ * Phase 3 * ------- * Finally, the result of this analysis is reified into the result as: - * - Additional "param_se_scopes" (an \p Array) and "result_se_scope" (an \p SEScope) - * attributes for every function (both top-level and local). These describe the devices for - * the function's parameters and the result. + * - Additional "param_virtual_devices" (an \p Array) and "result_virtual_device" + * (a \p VirtualDevice) attributes for every function (both top-level and local). These describe + * the devices for the function's parameters and the result. * - Additional "device_copy" CallNodes where a copy is required in order to respect the * intent of the original "on_device" CallNodes. * - Additional "on_device" CallNodes where the device type of an expression is not trivially @@ -155,11 +155,11 @@ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion * to ANF must respect the lexical scoping convention: * \code - * f(on_device(g(h(a, b), c), se_scope=CPU)) + * f(on_device(g(h(a, b), c), virtual_device=CPU)) * ==> - * let %x0 = on_device(h(a, b), se_scope=CPU) - * let %x1 = on_device(g(%x0), se_scope=CPU) - * f(on_device(%x1, se_scope=CPU)) + * let %x0 = on_device(h(a, b), virtual_device=CPU) + * let %x1 = on_device(g(%x0), virtual_device=CPU) + * f(on_device(%x1, virtual_device=CPU)) * \endcode * * This pass can be run before FuseOps so that it can use device-specific fusion rules. @@ -188,7 +188,7 @@ * minimize cross-device calls by moving device copies out of functions. E.g.: * \code * def @f() { // execute on CPU - * let x = on_device(...GPU computation..., se_scope=GPU); + * let x = on_device(...GPU computation..., virtual_device=GPU); * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) * } * def @main() { @@ -220,7 +220,7 @@ * \code * let f = fn(x, y) { ... 
} * let g = fn(f, z) { f(z, z) } - * g(f, on_device(..., se_scope=CPU)) + * g(f, on_device(..., virtual_device=CPU)) * \endcode * the parameters \p x and \p y will be on the CPU. * @@ -298,14 +298,14 @@ namespace { * * - Don't let the device for %x remain unconstrained: * \code - * let %x = on_device(e, se_scope=d) - * ==> let %x = on_device(e, se_scope=d, constraint=kBoth) + * let %x = on_device(e, virtual_device=d) + * ==> let %x = on_device(e, virtual_device=d, constraint=kBoth) * \endcode * * - Don't let the function result remain unconstrained: * \code - * fn(%x) { on_device(e, se_scope=d) } - * ==> fn(%x) { on_device(e, se_scope=d, constraint=kBoth) + * fn(%x) { on_device(e, virtual_device=d) } + * ==> fn(%x) { on_device(e, virtual_device=d, constraint=kBoth) * \endcode * * - Project-then-copy rather than copy-then-project: @@ -321,7 +321,7 @@ namespace { * call_lowered(@prim, (a, b)) * ==> copy_ok(call_lowered(@prim, (copy_ok(a), copy_ok(b)))) * where - * copy_ok(x) = on_device(x, se_scope=SEScope::FullyUnconstrained, + * copy_ok(x) = on_device(x, virtual_device=VirtualDevice::FullyUnconstrained, * constrain_body=False, constrain_result=False) * \endcode */ @@ -338,7 +338,7 @@ class RewriteOnDevices : public ExprMutator { if (props.body.defined() && props.is_normal()) { VLOG(2) << "wrapping tuple get item:" << std::endl << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl - << "with \"on_device\" for SEScope " << props.se_scope; + << "with \"on_device\" for VirtualDevice " << props.virtual_device; return OnDeviceWithProps(tuple_get_item, props); } else { return tuple_get_item; @@ -355,8 +355,8 @@ class RewriteOnDevices : public ExprMutator { if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising let-bound expression of let:" << std::endl << PrettyPrint(expr) << std::endl - << "to be fixed to SEScope " << props.se_scope; - value = MaybeOnDeviceFixed(props.body, props.se_scope); + << "to be fixed to VirtualDevice " << props.virtual_device; + 
value = MaybeOnDeviceFixed(props.body, props.virtual_device); } bindings.emplace_back(inner_let, value); expr = inner_let_node->body; @@ -375,8 +375,8 @@ class RewriteOnDevices : public ExprMutator { if (props.body.defined() && props.is_normal()) { VLOG(2) << "revising body of function:" << std::endl << PrettyPrint(GetRef(function_node)) << std::endl - << "to be fixed to SEScope " << props.se_scope; - body = MaybeOnDeviceFixed(props.body, props.se_scope); + << "to be fixed to VirtualDevice " << props.virtual_device; + body = MaybeOnDeviceFixed(props.body, props.virtual_device); } return WithFields(GetRef(function_node), function_node->params, std::move(body)); } @@ -412,12 +412,12 @@ class RewriteOnDevices : public ExprMutator { * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. * * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, - * from \code on_device(%x, se_scope=d) \endcode we know \p %x must be on device \p d, and thus - * so must \p %y. + * from \code on_device(%x, virtual_device=d) \endcode we know \p %x must be on device \p d, and + * thus so must \p %y. * * Constraints can flow in interesting ways. E.g. in: * \code - * let %f = fn(%x, %y) { add(%x, on_device(%y, se_scope=d)) } + * let %f = fn(%x, %y) { add(%x, on_device(%y, virtual_device=d)) } * let %g = fn(%f, %x, %y) { %f(%x, %y) } * %g(%f, %a, %b) * \endcode @@ -468,21 +468,21 @@ class DeviceAnalyzer : public ExprVisitor { ICHECK(func_type_node); ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); - Array se_scopes = + Array virtual_devices = tir::GetPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node)); // Build the implied domain (in terms of the function's Relay type) implied by any memory scope // constrains in the function's buffers, for both arguments and results. 
std::vector args_and_result_domains; - args_and_result_domains.reserve(se_scopes.size()); + args_and_result_domains.reserve(virtual_devices.size()); for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { - const SEScope& param_se_scope = se_scopes[i]; - VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; - args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_se_scope)); + const VirtualDevice& param_virtual_device = virtual_devices[i]; + VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(param_virtual_device)); } - const SEScope& ret_se_scope = se_scopes.back(); - VLOG(2) << "ret_se_scope = " << ret_se_scope; - args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_se_scope)); + const VirtualDevice& ret_virtual_device = virtual_devices.back(); + VLOG(2) << "ret_virtual_device = " << ret_virtual_device; + args_and_result_domains.push_back(domains_->MakeFirstOrderDomain(ret_virtual_device)); return domains_->MakeHigherOrderDomain(std::move(args_and_result_domains)); } @@ -520,13 +520,14 @@ class DeviceAnalyzer : public ExprVisitor { // The above must match. if (domains_->UnifyOrNull(func_domain, implied_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. - LOG(FATAL) << "Function parameters and result SEScopes do not match those of call. Call:" - << std::endl - << PrettyPrint(call) << std::endl - << "with function virtual devices:" << std::endl - << domains_->ToString(func_domain) << std::endl - << "and implied call virtual devices:" << std::endl - << domains_->ToString(implied_domain); + LOG(FATAL) + << "Function parameters and result VirtualDevices do not match those of call. 
Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function virtual devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call virtual devices:" << std::endl + << domains_->ToString(implied_domain); } VLOG(2) << "final call function domain:" << std::endl @@ -584,27 +585,28 @@ class DeviceAnalyzer : public ExprVisitor { VisitExpr(function_node->params[i]); } - // If the function already has SEScope attributes then we can further constrain the + // If the function already has VirtualDevice attributes then we can further constrain the // function's domain to match them. - if (!GetFunctionResultSEScope(function_node)->IsFullyUnconstrained()) { + if (!GetFunctionResultVirtualDevice(function_node)->IsFullyUnconstrained()) { std::vector args_and_result; for (size_t i = 0; i < function_node->params.size(); ++i) { - args_and_result.emplace_back(domains_->ForSEScope( - function_node->params[i]->checked_type(), GetFunctionParamSEScope(function_node, i))); + args_and_result.emplace_back( + domains_->ForVirtualDevice(function_node->params[i]->checked_type(), + GetFunctionParamVirtualDevice(function_node, i))); } - args_and_result.emplace_back(domains_->ForSEScope(function_node->body->checked_type(), - GetFunctionResultSEScope(function_node))); + args_and_result.emplace_back(domains_->ForVirtualDevice( + function_node->body->checked_type(), GetFunctionResultVirtualDevice(function_node))); auto annotation_domain = domains_->MakeHigherOrderDomain(std::move(args_and_result)); if (domains_->UnifyOrNull(func_domain, annotation_domain) == nullptr) { // higher-order // TODO(mbs): Proper diagnostics. - LOG(FATAL) - << "Function SEScopes are incompatible with its \"on_device\" annotation. 
Function:" - << std::endl - << PrettyPrint(function) << std::endl - << "with function virtual devices:" << std::endl - << domains_->ToString(func_domain) << std::endl - << "and annotation virtual devices:" << std::endl - << domains_->ToString(annotation_domain); + LOG(FATAL) << "Function VirtualDevices are incompatible with its \"on_device\" annotation. " + "Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function virtual devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation virtual devices:" << std::endl + << domains_->ToString(annotation_domain); } } @@ -783,7 +785,7 @@ class FreeOnDeviceDefaulter : public ExprVisitor { * \code * def @main(%x, %y, %z) { * let %a = add(%x, %y); - * multiply(%a, on_device(%z, se_scope=d)) + * multiply(%a, on_device(%z, virtual_device=d)) * } * \endcode * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, @@ -801,7 +803,8 @@ class DeviceDefaulter : public ExprVisitor { std::unique_ptr Default() { VLOG_CONTEXT << "DeviceDefaulter"; - VLOG(0) << "defaulting to SEScope " << domains_->config()->default_primitive_se_scope; + VLOG(0) << "defaulting to VirtualDevice " + << domains_->config()->default_primitive_virtual_device; for (const auto& kv : mod_->functions) { if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) { VLOG(2) << "defaulting devices for '" << kv.first->name_hint << "'"; @@ -825,7 +828,7 @@ class DeviceDefaulter : public ExprVisitor { if (!domains_->IsFullyConstrained(func_domain)) { VLOG(2) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); domains_->SetResultDefaultThenParams(func_domain, - domains_->config()->default_primitive_se_scope); + domains_->config()->default_primitive_virtual_device); VLOG(2) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); } VisitExpr(function_node->body); @@ -845,7 +848,7 @@ class DeviceDefaulter : public 
ExprVisitor { // defaulted. VLOG(2) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); domains_->SetResultDefaultThenParams(func_domain, - domains_->config()->default_primitive_se_scope); + domains_->config()->default_primitive_virtual_device); VLOG(2) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); } return ExprVisitor::VisitExpr_(call_node); @@ -858,12 +861,12 @@ class DeviceDefaulter : public ExprVisitor { Let let = Downcast(expr); // If the let-var device is still free force it to match the overall let. auto let_domain = domains_->DomainFor(let); // may be higher-order - SEScope let_se_scope = domains_->ResultSEScope(let_domain); - ICHECK(!let_se_scope->IsFullyUnconstrained()); + VirtualDevice let_virtual_device = domains_->ResultVirtualDevice(let_domain); + ICHECK(!let_virtual_device->IsFullyUnconstrained()); auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order if (!domains_->IsFullyConstrained(let_var_domain)) { VLOG(2) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); - domains_->SetDefault(let_var_domain, let_se_scope); + domains_->SetDefault(let_var_domain, let_virtual_device); VLOG(2) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); } VisitExpr(let->var); @@ -889,7 +892,7 @@ class DeviceDefaulter : public ExprVisitor { * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard * any existing "device_copy" CallNodes which are no-ops. * - * - Functions are given "param_se_scopes" and "result_se_scope" attributes to capture + * - Functions are given "param_virtual_devices" and "result_virtual_device" attributes to capture * the device type for its parameters and result. 
* * - Additional "device_copy" CallNodes are inserted wherever there's a transition between @@ -910,10 +913,10 @@ class DeviceDefaulter : public ExprVisitor { * * For example, we'll end up with programs that look like: * \code - * def @main(%x, %y, param_se_scopes=[...], result_se_scope=...) { - * let %a = on_device(..., se_scope=..., is_fixed=True) - * @f(%a, device_copy(on_device(..., se_scope=..., is_fixed=True), - * src_se_scope=..., dst_se_scope=...)) + * def @main(%x, %y, param_virtual_devices=[...], result_virtual_device=...) { + * let %a = on_device(..., virtual_device=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., virtual_device=..., is_fixed=True), + * src_virtual_device=..., dst_virtual_device=...)) * } * \endcode */ @@ -961,19 +964,21 @@ class DeviceCapturer : public ExprMutator { ICHECK(func_type_node); ICHECK_EQ(func_domain->function_arity(), func_type_node->arg_types.size()); - std::vector arg_and_result_se_scopes; - arg_and_result_se_scopes.reserve(func_type_node->arg_types.size() + 1); + std::vector arg_and_result_virtual_devices; + arg_and_result_virtual_devices.reserve(func_type_node->arg_types.size() + 1); for (size_t i = 0; i < func_type_node->arg_types.size(); ++i) { - SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); - VLOG(2) << "param_se_scope[" << i << "] = " << param_se_scope; - arg_and_result_se_scopes.push_back(param_se_scope); + VirtualDevice param_virtual_device = + domains_->ResultVirtualDevice(func_domain->function_param(i)); + VLOG(2) << "param_virtual_device[" << i << "] = " << param_virtual_device; + arg_and_result_virtual_devices.push_back(param_virtual_device); } - SEScope ret_se_scope = domains_->ResultSEScope(func_domain->function_result()); - VLOG(2) << "ret_se_scope = " << ret_se_scope; - arg_and_result_se_scopes.push_back(ret_se_scope); + VirtualDevice ret_virtual_device = + domains_->ResultVirtualDevice(func_domain->function_result()); + VLOG(2) << "ret_virtual_device = " << 
ret_virtual_device; + arg_and_result_virtual_devices.push_back(ret_virtual_device); return tir::ApplyPrimFuncArgAndResultConstraints(prim_func, GetRef(func_type_node), - arg_and_result_se_scopes); + arg_and_result_virtual_devices); } // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode @@ -1002,26 +1007,28 @@ class DeviceCapturer : public ExprMutator { // Gather the parameter and result device types for the function attributes. ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); - SEScope result_se_scope = domains_->ResultSEScope(func_domain); - ICHECK(!result_se_scope->IsFullyUnconstrained()); - Array param_se_scopes; - param_se_scopes.reserve(function_node->params.size()); + VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain); + ICHECK(!result_virtual_device->IsFullyUnconstrained()); + Array param_virtual_devices; + param_virtual_devices.reserve(function_node->params.size()); for (size_t i = 0; i < function_node->params.size(); ++i) { - SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); - ICHECK(!param_se_scope->IsFullyUnconstrained()); - param_se_scopes.push_back(param_se_scope); + VirtualDevice param_virtual_device = + domains_->ResultVirtualDevice(func_domain->function_param(i)); + ICHECK(!param_virtual_device->IsFullyUnconstrained()); + param_virtual_devices.push_back(param_virtual_device); } // Rewrite the body. Note that the body may have begun with an "on_device" so // be prepared to insert a "device_copy". 
Expr body = VisitChild( - /*lexical_se_scope=*/result_se_scope, - /*expected_se_scope=*/result_se_scope, - /*child_se_scope=*/GetSEScope(function_node->body), function_node->body); + /*lexical_virtual_device=*/result_virtual_device, + /*expected_virtual_device=*/result_virtual_device, + /*child_virtual_device=*/GetVirtualDevice(function_node->body), function_node->body); Function func = WithFields(GetRef(function_node), std::move(function_node->params), std::move(body)); - return FunctionOnDevice(func, std::move(param_se_scopes), std::move(result_se_scope)); + return FunctionOnDevice(func, std::move(param_virtual_devices), + std::move(result_virtual_device)); } Expr VisitExpr_(const CallNode* call_node) final { @@ -1031,7 +1038,7 @@ class DeviceCapturer : public ExprMutator { // (However we'll preserve the form in the result below.) auto vanilla_call = GetAnyCall(call_node); - SEScope call_se_scope = GetSEScope(call); + VirtualDevice call_virtual_device = GetVirtualDevice(call); auto on_device_props = GetOnDeviceProps(call_node); if (on_device_props.body.defined()) { @@ -1042,17 +1049,19 @@ class DeviceCapturer : public ExprMutator { DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); if (device_copy_props.body.defined()) { - SEScope src_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.src_se_scope); - SEScope dst_se_scope = domains_->config()->CanonicalSEScope(device_copy_props.dst_se_scope); - ICHECK_EQ(call_se_scope, dst_se_scope); - if (src_se_scope == dst_se_scope) { + VirtualDevice src_virtual_device = + domains_->config()->CanonicalVirtualDevice(device_copy_props.src_virtual_device); + VirtualDevice dst_virtual_device = + domains_->config()->CanonicalVirtualDevice(device_copy_props.dst_virtual_device); + ICHECK_EQ(call_virtual_device, dst_virtual_device); + if (src_virtual_device == dst_virtual_device) { // We can pinch out existing "device_copy" CallNodes if their source and destinations // match. 
return VisitExpr(device_copy_props.body); } else { - return VisitChild(/*lexical_se_scope=*/dst_se_scope, - /*expected_se_scope=*/dst_se_scope, - /*child_se_scope=*/src_se_scope, device_copy_props.body); + return VisitChild(/*lexical_virtual_device=*/dst_virtual_device, + /*expected_virtual_device=*/dst_virtual_device, + /*child_virtual_device=*/src_virtual_device, device_copy_props.body); } } @@ -1060,16 +1069,17 @@ class DeviceCapturer : public ExprMutator { auto func_domain = domains_->DomainForCallee(call); // higher-order VLOG(2) << "considering call:" << std::endl << PrettyPrint(call) << std::endl - << "in scope " << call_se_scope << " with function virtual devices:" << std::endl + << "in virtual device " << call_virtual_device + << " with function virtual devices:" << std::endl << domains_->ToString(func_domain); - SEScope result_se_scope = domains_->ResultSEScope(func_domain); - ICHECK(!result_se_scope->IsFullyUnconstrained()); + VirtualDevice result_virtual_device = domains_->ResultVirtualDevice(func_domain); + ICHECK(!result_virtual_device->IsFullyUnconstrained()); // The callee is on the current device. Expr op = VisitChild( - /*lexical_se_scope=*/call_se_scope, - /*expected_se_scope=*/call_se_scope, - /*child_se_scope=*/result_se_scope, vanilla_call->op); + /*lexical_virtual_device=*/call_virtual_device, + /*expected_virtual_device=*/call_virtual_device, + /*child_virtual_device=*/result_virtual_device, vanilla_call->op); // Each argument can be on the device for the corresponding function parameter. 
However if // any of those differ from the overall call device then wrap them in an "on_device" to @@ -1078,13 +1088,14 @@ class DeviceCapturer : public ExprMutator { args.reserve(vanilla_call->args.size()); ICHECK_EQ(func_domain->function_arity(), vanilla_call->args.size()); for (size_t i = 0; i < vanilla_call->args.size(); ++i) { - SEScope param_se_scope = domains_->ResultSEScope(func_domain->function_param(i)); - ICHECK(!param_se_scope->IsFullyUnconstrained()) + VirtualDevice param_virtual_device = + domains_->ResultVirtualDevice(func_domain->function_param(i)); + ICHECK(!param_virtual_device->IsFullyUnconstrained()) << "for parameter " << i << " for call:" << std::endl << PrettyPrint(call); - args.push_back(VisitChild(/*lexical_se_scope=*/call_se_scope, - /*expected_se_scope=*/param_se_scope, - /*child_se_scope=*/GetSEScope(vanilla_call->args[i]), + args.push_back(VisitChild(/*lexical_virtual_device=*/call_virtual_device, + /*expected_virtual_device=*/param_virtual_device, + /*child_virtual_device=*/GetVirtualDevice(vanilla_call->args[i]), vanilla_call->args[i])); } @@ -1100,27 +1111,28 @@ class DeviceCapturer : public ExprMutator { Expr VisitExpr_(const LetNode* let_node) final { Expr expr = GetRef(let_node); // Iterate through chained lets, provided they all agree on their device type. - SEScope let_se_scope = GetSEScope(expr); + VirtualDevice let_virtual_device = GetVirtualDevice(expr); std::vector> bindings; while (const auto* inner_let_node = expr.as()) { Expr inner_let = GetRef(inner_let_node); - if (GetSEScope(inner_let) != let_se_scope) { + if (GetVirtualDevice(inner_let) != let_virtual_device) { // We have a device transition which needs to be handled. break; } // The let-bound value can be on a different device than the overall let. - // By using the fully-unconstrained SEScope for the 'lexical' scope we'll force the let-bound - // value to *always* be wrapped by an "on_device" (see introductory comment for motivation.) 
- Expr value = - VisitChild(/*lexical_se_scope=*/SEScope::FullyUnconstrained(), - /*expected_se_scope=*/GetSEScope(inner_let_node->var), - /*child_se_scope=*/GetSEScope(inner_let_node->value), inner_let_node->value); + // By using the fully-unconstrained virtual device for the 'lexical' scope we'll force the + // let-bound value to *always* be wrapped by an "on_device" (see introductory comment for + // motivation.) + Expr value = VisitChild(/*lexical_virtual_device=*/VirtualDevice::FullyUnconstrained(), + /*expected_virtual_device=*/GetVirtualDevice(inner_let_node->var), + /*child_virtual_device=*/GetVirtualDevice(inner_let_node->value), + inner_let_node->value); bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); expr = inner_let_node->body; } - Expr body = VisitChild(/*lexical_se_scope=*/let_se_scope, - /*expected_se_scope=*/let_se_scope, - /*child_se_scope=*/GetSEScope(expr), expr); + Expr body = VisitChild(/*lexical_virtual_device=*/let_virtual_device, + /*expected_virtual_device=*/let_virtual_device, + /*child_virtual_device=*/GetVirtualDevice(expr), expr); for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, /*span=*/std::get<2>(*itr)); @@ -1175,68 +1187,70 @@ class DeviceCapturer : public ExprMutator { return WithFields(std::move(match), std::move(data), std::move(clauses)); } - SEScope GetSEScope(const Expr& expr) { + VirtualDevice GetVirtualDevice(const Expr& expr) { // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. OnDeviceProps props = GetOnDeviceProps(expr); Expr true_expr = props.body.defined() ? props.body : expr; ICHECK(domains_->contains(true_expr)); // If expr is higher order we'll return only the result domain's device. 
- SEScope se_scope = domains_->ResultSEScope(domains_->DomainFor(true_expr)); - ICHECK(!se_scope->IsFullyUnconstrained()) - << "no SEScope was determined for expression:" << std::endl + VirtualDevice virtual_device = domains_->ResultVirtualDevice(domains_->DomainFor(true_expr)); + ICHECK(!virtual_device->IsFullyUnconstrained()) + << "no VirtualDevice was determined for expression:" << std::endl << PrettyPrint(true_expr); - return std::move(se_scope); + return std::move(virtual_device); } /*! - * \brief Reconcile the \p child_se_scope for \p child with both the \p expected_se_scope - * (as required by the expression context the \p child is in) and the \p lexical_se_scope - * (as a downstream transform would infer based only on lexically enclosing "on_device" - * CallNodes and function attributes.) Generally \p lexical_se_scope and \p - * expected_se_scope are the same by definition, but may differ in arguments to functions + * \brief Reconcile the \p child_virtual_device for \p child with both the \p + * expected_virtual_device (as required by the expression context the \p child is in) and the \p + * lexical_virtual_device (as a downstream transform would infer based only on lexically enclosing + * "on_device" CallNodes and function attributes.) Generally \p lexical_virtual_device and \p + * expected_virtual_device are the same by definition, but may differ in arguments to functions * and let-bound expressions. * - * If \p child_se_scope differs from \p expected_se_scope, wrap it as: + * If \p child_virtual_device differs from \p expected_virtual_device, wrap it as: * \code - * device_copy(on_device(child', se_scope=child_se_scope), - * src_dev_type=child_se_scope, dst_dev_type=expected_se_scope) + * device_copy(on_device(child', virtual_device=child_virtual_device), + * src_dev_type=child_virtual_device, dst_dev_type=expected_virtual_device) * \endcode * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the * child. 
* - * If \p expected_se_scope differs from \p lexical_se_scope, then (also) wrap + * If \p expected_virtual_device differs from \p lexical_virtual_device, then (also) wrap * the expression as: * \code - * on_device(..., se_scope=expected_se_scope) + * on_device(..., virtual_device=expected_virtual_device) * \endcode * * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped * by a "device_copy", even though those copies will generally all be to the same destination * device. */ - Expr VisitChild(const SEScope& lexical_se_scope, const SEScope& expected_se_scope, - const SEScope& child_se_scope, const Expr& child) { - ICHECK(!expected_se_scope->IsFullyUnconstrained()); + Expr VisitChild(const VirtualDevice& lexical_virtual_device, + const VirtualDevice& expected_virtual_device, + const VirtualDevice& child_virtual_device, const Expr& child) { + ICHECK(!expected_virtual_device->IsFullyUnconstrained()); if (child->IsInstance() || child->IsInstance()) { // Primitive operators and contructors don't need to be rewritten and can have a // different domain at each call site. return child; } Expr result = VisitExpr(child); - if (child_se_scope != expected_se_scope) { - VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " << child_se_scope - << " to virtual device " << expected_se_scope << " for:" << std::endl + if (child_virtual_device != expected_virtual_device) { + VLOG(2) << "creating " << DeviceCopyOp()->name << " from virtual device " + << child_virtual_device << " to virtual device " << expected_virtual_device + << " for:" << std::endl << PrettyPrint(result); // Also wrap the child in an "on_device" so downstream transforms can track devices // lexically. 
- result = MaybeOnDeviceFixed(result, child_se_scope); - result = DeviceCopy(result, child_se_scope, expected_se_scope); + result = MaybeOnDeviceFixed(result, child_virtual_device); + result = DeviceCopy(result, child_virtual_device, expected_virtual_device); } - if (expected_se_scope != lexical_se_scope) { - VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " << expected_se_scope - << " for:" << std::endl + if (expected_virtual_device != lexical_virtual_device) { + VLOG(2) << "creating " << OnDeviceOp()->name << " for virtual device " + << expected_virtual_device << " for:" << std::endl << PrettyPrint(result); - result = MaybeOnDeviceFixed(result, expected_se_scope); + result = MaybeOnDeviceFixed(result, expected_virtual_device); } return result; } @@ -1246,9 +1260,10 @@ class DeviceCapturer : public ExprMutator { * is expected to be on the same device as the \p parent. */ Expr VisitChild(const Expr& parent, const Expr& child) { - SEScope expected_se_scope = GetSEScope(parent); - SEScope child_se_scope = GetSEScope(child); - return VisitChild(expected_se_scope, expected_se_scope, child_se_scope, child); + VirtualDevice expected_virtual_device = GetVirtualDevice(parent); + VirtualDevice child_virtual_device = GetVirtualDevice(child); + return VisitChild(expected_virtual_device, expected_virtual_device, child_virtual_device, + child); } /*! \brief Module we are rewriting, so we can lookup global variables. */ @@ -1282,7 +1297,7 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) { VLOG(3) << "Domains after defaulting: " << std::endl << domains->ToString(); // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture - // the above map, and attach additional "param_se_scopes" and "result_se_scope" + // the above map, and attach additional "param_virtual_devices" and "result_virtual_device" // attributes to all function definitions. 
 return DeviceCapturer(mod, std::move(domains)).Capture(); }, diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 831d28b48540..dd8195797e8d 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -41,8 +41,8 @@ namespace transform { namespace { /*! * \brief Returns whether \p expr is a literal \p Constant, optionally wrapped by an "on_device" - * annotation CallNode (which serves only to associate an \p SEScope to the constant and has no - * operational effect). + * annotation CallNode (which serves only to associate a \p VirtualDevice with the constant and has + * no operational effect). */ bool IsSimpleConstant(const Expr& expr) { return AsIgnoringOnDevice(expr) != nullptr; @@ -86,19 +86,19 @@ class ConstantFolder : public MixedModeMutator { // the variable. // // We need to retain any "on_device" annotation so that downstream 'device aware' - // passes can still retrieve the \p SEScope for the constant in its new position(s). Eg: - // def @f(..., result_se_scope=D) { - // let %x = on_device(... something we eval to a constant..., se_scope=E) + // passes can still retrieve the virtual device for the constant in its new position(s). Eg: + // def @f(..., result_virtual_device=D) { + // let %x = on_device(... something we eval to a constant..., virtual_device=E) // @f(..., %x, ...) // } - // Here the default scope is D, whereas the argument %x to @f is on E (and @f expects - // that). No on_device annotation is required in the call according to the convention used - // by the device-aware visitors. + // Here the default virtual device is D, whereas the argument %x to @f is on E (and @f + // expects that). No on_device annotation is required in the call according to the + // convention used by the device-aware visitors. // // However once we've inlined the constant we need to insert an on_device, again to // respect the convention used by the device-aware visitors. 
- // def @f(..., result_se_scope=D) { - // @f(..., on_device(...the constant..., se_scope=E), ...) + // def @f(..., result_virtual_device=D) { + // @f(..., on_device(...the constant..., virtual_device=E), ...) // } VLOG(1) << "Replacing let-binding for " << op->var->name_hint() << " with constant:" << std::endl @@ -214,7 +214,7 @@ class ConstantFolder : public MixedModeMutator { Expr result = tuple_node->fields[tuple_get_item_node->index]; OnDeviceProps props = GetOnDeviceProps(post_tuple_get_item_node->tuple); if (props.body.defined()) { - // (on_device((x, y, z), se_scope=D).1 ==> on_device(y, se_scope=D) + // (on_device((x, y, z), virtual_device=D).1 ==> on_device(y, virtual_device=D) return MaybeOnDeviceWithProps(result, props); } else { return result; diff --git a/src/relay/transforms/memory_alloc.cc b/src/relay/transforms/memory_alloc.cc index 25827d5e918d..74da99bc3b1b 100644 --- a/src/relay/transforms/memory_alloc.cc +++ b/src/relay/transforms/memory_alloc.cc @@ -61,10 +61,10 @@ namespace relay { class DialectRewriter : public transform::DeviceAwareExprMutator { public: - DialectRewriter(IRModule mod, SEScope host_se_scope) + DialectRewriter(IRModule mod, VirtualDevice host_virtual_device) : transform::DeviceAwareExprMutator(mod), mod_(std::move(mod)), - host_se_scope_(std::move(host_se_scope)) {} + host_virtual_device_(std::move(host_virtual_device)) {} Function Rewrite(const Function& expr) { return Downcast(Mutate(expr)); } @@ -79,10 +79,10 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { for (auto field : tuple_node->fields) { auto new_field = Mutate(field); if (new_field->IsInstance()) { - SEScope se_scope = GetSEScope(field); - ICHECK(!se_scope->IsFullyUnconstrained()); + VirtualDevice virtual_device = GetVirtualDevice(field); + ICHECK(!virtual_device->IsFullyUnconstrained()); Var const_var("const", Type(nullptr)); - new_field = scope.Push(const_var, MaybeOnDeviceFixed(new_field, se_scope)); + new_field = scope.Push(const_var, 
MaybeOnDeviceFixed(new_field, virtual_device)); } new_fields.push_back(new_field); } @@ -93,9 +93,9 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { std::pair PreVisitLetBinding_(const Var& var, const Expr& value) final { Expr new_value = Mutate(value); - SEScope se_scope = GetSEScope(value); - ICHECK(!se_scope->IsFullyUnconstrained()); - scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, se_scope)); + VirtualDevice virtual_device = GetVirtualDevice(value); + ICHECK(!virtual_device->IsFullyUnconstrained()); + scopes_.back().Push(var, MaybeOnDeviceFixed(new_value, virtual_device)); // Since we always need a let block on which to bind sub-expressions the rewritten bindings // are tracked in the current scopes. But return the rewritten binding anyway. return {var, new_value}; @@ -132,8 +132,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Call call = GetRef(call_node); VLOG(1) << "converting lowered call to DPS:" << std::endl << PrettyPrint(call); - SEScope se_scope = GetSEScope(call); - ICHECK(!se_scope->IsFullyUnconstrained()); + VirtualDevice virtual_device = GetVirtualDevice(call); + ICHECK(!virtual_device->IsFullyUnconstrained()); LetList& scope = scopes_.back(); std::vector new_args; @@ -171,19 +171,20 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { // by a companion shape function. if (IsDynamic(ret_type)) { return DynamicInvoke(&scope, call_lowered_props.lowered_func, ins, call_lowered_props.attrs, - out_types, ret_type, se_scope); + out_types, ret_type, virtual_device); } // Handle ordinary primitive calls. 
Array outputs; for (size_t i = 0; i < out_types.size(); ++i) { - outputs.push_back(MakeStaticAllocation(&scope, out_types[i], se_scope, std::to_string(i))); + outputs.push_back( + MakeStaticAllocation(&scope, out_types[i], virtual_device, std::to_string(i))); } Tuple outs(outputs); Expr invoke = InvokeTVMOp(call_lowered_props.lowered_func, ins, outs, Downcast(call_lowered_props.attrs.metadata.at("relay_attrs"))); - scope.Push(MaybeOnDeviceFixed(invoke, se_scope)); + scope.Push(MaybeOnDeviceFixed(invoke, virtual_device)); return ToTupleType(ret_type, std::vector(outputs.begin(), outputs.end())); } @@ -199,7 +200,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { /*! Returns an \p alloc_tensor call for a tensor of \p shape and \p dtype over \p storage. */ inline Expr AllocTensor(const Expr& storage, tvm::relay::Expr shape, DataType dtype, Array assert_shape) { - Expr offset = MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_se_scope_); + Expr offset = + MaybeOnDeviceFixed(MakeConstantScalar(DataType::Int(64), 0), host_virtual_device_); return tvm::relay::AllocTensor(storage, std::move(offset), std::move(shape), dtype, assert_shape); } @@ -234,28 +236,28 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { } // Allocate a tensor with a statically known shape. 
- Var MakeStaticAllocation(LetList* scope, const TensorType& type, const SEScope& se_scope, - String name_hint) { + Var MakeStaticAllocation(LetList* scope, const TensorType& type, + const VirtualDevice& virtual_device, String name_hint) { std::vector int_shape; for (auto it : type->shape) { const auto* imm = it.as(); CHECK(imm) << "expect static int shape"; int_shape.push_back(imm->value); } - Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_se_scope_); - Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_se_scope_); + Expr shape = MaybeOnDeviceFixed(MakeConstant(int_shape), host_virtual_device_); + Expr size = MaybeOnDeviceFixed(ComputeStorage(type), host_virtual_device_); // Alignment is directly captured in the instruction rather than calculated, so we // don't want to wrap it with an "on_device". Expr alignment = ComputeAlignment(type->dtype); // Run type inference later to get the correct type. Var var("storage_" + name_hint, Type(nullptr)); - Expr value = AllocStorage(size, alignment, se_scope, type->dtype); - auto sto = scope->Push(var, MaybeOnDeviceFixed(value, se_scope)); + Expr value = AllocStorage(size, alignment, virtual_device, type->dtype); + auto sto = scope->Push(var, MaybeOnDeviceFixed(value, virtual_device)); // TODO(@jroesch): There is a bug with typing based on the constant shape. auto tensor = AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape); Var tensor_var("tensor_" + name_hint, Type(nullptr)); - return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, se_scope)); + return scope->Push(tensor_var, MaybeOnDeviceFixed(tensor, virtual_device)); } /*! 
@@ -294,21 +296,21 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { Expr sh_of = Mutate(ShapeOf(exprs[j])); Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr)); shape_func_ins.push_back( - scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_se_scope_))); + scope->Push(in_shape_var, MaybeOnDeviceFixed(sh_of, host_virtual_device_))); input_pos++; } } else if (state == tec::kNeedInputData) { auto new_arg = Mutate(arg); // already accounts for device - SEScope arg_se_scope = GetSEScope(arg); - ICHECK(!arg_se_scope->IsFullyUnconstrained()); + VirtualDevice arg_virtual_device = GetVirtualDevice(arg); + ICHECK(!arg_virtual_device->IsFullyUnconstrained()); // The dynamic shape function is expecting its data on the host/CPU, so insert a // device_copy otherwise. (We'll need to fuse & lower these copies in the same way // we fuse & lower other operators we insert for, eg, dynamic tensor size calculation.) - new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_se_scope), arg_se_scope, - host_se_scope_); + new_arg = MaybeDeviceCopy(MaybeOnDeviceFixed(new_arg, arg_virtual_device), + arg_virtual_device, host_virtual_device_); Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr)); shape_func_ins.push_back( - scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_se_scope_))); + scope->Push(in_shape_var, MaybeOnDeviceFixed(new_arg, host_virtual_device_))); input_pos++; } else { // TODO(@jroesch): handle kNeedBoth @@ -327,8 +329,8 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { ICHECK(tensor_type_node); // Put the shape func on the host. This also ensures that everything between // shape_of and shape_func is similarly on the host. 
- Var alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), host_se_scope_, - "out_shape_" + std::to_string(i)); + Var alloc = MakeStaticAllocation(scope, GetRef(tensor_type_node), + host_virtual_device_, "out_shape_" + std::to_string(i)); out_shapes.push_back(alloc); } @@ -336,26 +338,27 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto shape_call = InvokeTVMOp(prim_fn_var, Tuple(shape_func_ins), Tuple(out_shapes), Downcast(attrs.metadata.at("relay_attrs"))); Var shape_func_var("shape_func", Type(nullptr)); - scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_se_scope_)); + scope->Push(shape_func_var, MaybeOnDeviceFixed(shape_call, host_virtual_device_)); return out_shapes; } // Generate the code for invoking the TVM primitive \p func who's results have dynamic shapes. Expr DynamicInvoke(LetList* scope, const Expr& func, const Tuple& ins, const CallLoweredAttrs& attrs, const std::vector& out_types, - const Type& ret_type, const SEScope& se_scope) { + const Type& ret_type, const VirtualDevice& virtual_device) { Array out_shapes = EmitShapeFunc(scope, ins, attrs); std::vector storages; CHECK_EQ(out_shapes.size(), out_types.size()); for (size_t i = 0; i < out_shapes.size(); ++i) { auto out_shape = out_shapes[i]; auto out_type = out_types[i]; - auto size = MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_se_scope_); + auto size = + MaybeOnDeviceFixed(ComputeStorageInRelay(out_shape, out_type), host_virtual_device_); // Alignment is directly captured in the instruction so don't wrap in "on_device". 
auto alignment = ComputeAlignment(out_type->dtype); Var sto_var("storage_" + std::to_string(i), Type(nullptr)); - auto val = AllocStorage(size, alignment, se_scope, out_type->dtype); - storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, se_scope))); + auto val = AllocStorage(size, alignment, virtual_device, out_type->dtype); + storages.push_back(scope->Push(sto_var, MaybeOnDeviceFixed(val, virtual_device))); } Array outs; @@ -365,13 +368,13 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { auto storage = storages[i]; auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape); Var out_var("out_" + std::to_string(i), Type(nullptr)); - outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, se_scope))); + outs.push_back(scope->Push(out_var, MaybeOnDeviceFixed(alloc, virtual_device))); } Tuple tuple_outs(outs); auto call = InvokeTVMOp(func, ins, tuple_outs, Downcast(attrs.metadata.at("relay_attrs"))); - scope->Push(MaybeOnDeviceFixed(call, se_scope)); + scope->Push(MaybeOnDeviceFixed(call, virtual_device)); return ToTupleType(ret_type, std::vector(tuple_outs->fields.begin(), tuple_outs->fields.end())); } @@ -395,7 +398,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { CHECK(imm) << "expect static int shape"; shape.push_back(imm->value); } - shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_se_scope_); + shape_expr = MaybeOnDeviceFixed(MakeConstant(shape), host_virtual_device_); } return ReshapeTensor(ins->fields[0], shape_expr, ret_ty->shape); } @@ -404,7 +407,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator { const Op& device_copy_op_ = Op::Get("device_copy"); runtime::DataType compute_dtype_ = runtime::DataType::Int(64); IRModule mod_; - SEScope host_se_scope_; + VirtualDevice host_virtual_device_; std::vector scopes_; }; @@ -421,16 +424,16 @@ Pass ManifestAllocImportStorage() { /*required=*/{}); } -Pass ManifestAllocImpl(SEScope host_se_scope) { - auto 
pass_func = [host_se_scope](Function func, IRModule mod, PassContext ctxt) { - return DialectRewriter(mod, host_se_scope).Rewrite(func); +Pass ManifestAllocImpl(VirtualDevice host_virtual_device) { + auto pass_func = [host_virtual_device](Function func, IRModule mod, PassContext ctxt) { + return DialectRewriter(mod, host_virtual_device).Rewrite(func); }; return CreateFunctionPass(pass_func, 0, "ManifestAllocImpl", {}); } -Pass ManifestAlloc(SEScope host_se_scope) { +Pass ManifestAlloc(VirtualDevice cpu_virtual_device) { std::vector passes = {ManifestAllocImportStorage(), InferType(), - ManifestAllocImpl(std::move(host_se_scope)), InferType()}; + ManifestAllocImpl(std::move(cpu_virtual_device)), InferType()}; return Sequential(passes, "ManifestAlloc"); } diff --git a/src/relay/transforms/to_a_normal_form.cc b/src/relay/transforms/to_a_normal_form.cc index 741de6d7ea9b..321839d81e3e 100644 --- a/src/relay/transforms/to_a_normal_form.cc +++ b/src/relay/transforms/to_a_normal_form.cc @@ -211,14 +211,14 @@ class Fill : ExprFunctor, private transform::Lexi } Expr Atomic(const Expr& e, const Var& v) { - Expr annotated_expr = MaybeOnDeviceFixed(e, GetSEScope(e)); + Expr annotated_expr = MaybeOnDeviceFixed(e, GetVirtualDevice(e)); return v.defined() ? GetScope(e)->let_list->Push(v, annotated_expr) : annotated_expr; } // Bind expression `now` to var `v` if the original expression is in the include set, or if // v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly Expr Compound(const Expr& orig, const Expr& now, const Var& v) { - Expr annotated_expr = MaybeOnDeviceFixed(now, GetSEScope(orig)); + Expr annotated_expr = MaybeOnDeviceFixed(now, GetVirtualDevice(orig)); Var var = v.defined() ? 
v : Var::GenSym(); bool not_included = include_set_ && include_set_->find(orig) == include_set_->end(); if (!v.defined() && not_included) { @@ -232,10 +232,10 @@ class Fill : ExprFunctor, private transform::Lexi OnDeviceProps props = GetOnDeviceProps(c); if (props.body.defined() && props.is_fixed()) { // Keep track of expression device type for lexically enclosing sub-expressions. - PushSEScope(props.se_scope); + PushVirtualDevice(props.virtual_device); Expr body = VisitExpr(props.body, v); // We are done with this sub-expression. - PopSEScope(); + PopVirtualDevice(); // Preserve the "on_device" annotations. return OnDeviceWithProps(body, props); } @@ -293,9 +293,9 @@ class Fill : ExprFunctor, private transform::Lexi } else { // Keep track of expression and bound variable device types for lexically enclosing // sub-expressions. - PushSEScope(GetFunctionResultSEScope(f)); + PushVirtualDevice(GetFunctionResultVirtualDevice(f)); for (size_t i = 0; i < f->params.size(); ++i) { - PushBoundVar(f->params[i], GetFunctionParamSEScope(f, i)); + PushBoundVar(f->params[i], GetFunctionParamVirtualDevice(f, i)); } EnterFunctionBody(); ret = Function(f->params, GetSubScope(e, 0)->let_list->Get(VisitExpr(f->body)), f->ret_type, @@ -305,7 +305,7 @@ class Fill : ExprFunctor, private transform::Lexi for (size_t i = 0; i < f->params.size(); ++i) { PopBoundVar(f->params[i]); } - PopSEScope(); + PopVirtualDevice(); } if (function_nesting() == 0) { ICHECK(!v.defined()); @@ -320,7 +320,7 @@ class Fill : ExprFunctor, private transform::Lexi Expr VisitExpr_(const LetNode* l, const Var& v) final { Expr e = GetRef(l); // Keep track of bound variable device types for lexically enclosing sub-expressions. - PushBoundVar(l->var, GetSEScope(l->value)); + PushBoundVar(l->var, GetVirtualDevice(l->value)); VisitExpr(l->value, l->var); Expr ret = GetSubScope(e, 0)->let_list->Get(VisitExpr(l->body)); // We are done with these sub-expressions. 
diff --git a/src/target/compilation_config.cc b/src/target/compilation_config.cc index e9c4daf7bf05..0401eebe51ef 100644 --- a/src/target/compilation_config.cc +++ b/src/target/compilation_config.cc @@ -33,26 +33,27 @@ void CompilationConfigNode::VisitAttrs(AttrVisitor* v) { v->Visit("legacy_target_map", &legacy_target_map); v->Visit("host_target", &host_target); v->Visit("primitive_targets", &primitive_targets); - v->Visit("default_primitive_se_scope", &default_primitive_se_scope); - v->Visit("host_se_scope", &host_se_scope); + v->Visit("default_primitive_virtual_device", &default_primitive_virtual_device); + v->Visit("host_virtual_device", &host_virtual_device); v->Visit("optional_homogenous_target", &optional_homogeneous_target); - // NOTE: The se_scope_cache_ is not accessible via FFI. + // NOTE: The virtual_device_cache_ is not accessible via FFI. } -SEScope CompilationConfigNode::CanonicalSEScope(const SEScope& se_scope) const { - if (se_scope->target.defined()) { - return se_scope_cache_.Unique(se_scope); +VirtualDevice CompilationConfigNode::CanonicalVirtualDevice( + const VirtualDevice& virtual_device) const { + if (virtual_device->target.defined()) { + return virtual_device_cache_.Unique(virtual_device); } - DLDeviceType device_type = se_scope->device_type(); + DLDeviceType device_type = virtual_device->device_type(); // TODO(mbs): Proper diagnostics. 
CHECK(device_type != kInvalidDeviceType) - << "SEScope annotations must include at least a device_type"; - Target target = FindPrimitiveTargetOrFail(se_scope->device_type()); - return se_scope_cache_.Unique( - SEScope(device_type, se_scope->virtual_device_id, target, se_scope->memory_scope)); + << "VirtualDevice annotations must include at least a device_type"; + Target target = FindPrimitiveTargetOrFail(virtual_device->device_type()); + return virtual_device_cache_.Unique(VirtualDevice(device_type, virtual_device->virtual_device_id, + target, virtual_device->memory_scope)); } -void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContext& pass_ctx) { +void CompilationConfigNode::EstablishDefaultVirtualDevices(const transform::PassContext& pass_ctx) { // // Gather the hints as to what our default device type for the 'host' should be, and // create an appropriate target if we don't already have one. @@ -105,9 +106,10 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex } // - // Establish the host SEScope. + // Establish the host VirtualDevice. // - host_se_scope = se_scope_cache_.Unique(SEScope(host_device_type, + host_virtual_device = + virtual_device_cache_.Unique(VirtualDevice(host_device_type, /*virtual_device_id=*/0, host_target)); // @@ -149,11 +151,12 @@ void CompilationConfigNode::EstablishDefaultSEScopes(const transform::PassContex } // - // Establish the default primitive SEScope, choosing a known Target to match the device type. + // Establish the default primitive VirtualDevice, choosing a known Target to match the device + // type. 
// - default_primitive_se_scope = se_scope_cache_.Unique( - SEScope(default_primitive_device_type, - /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); + default_primitive_virtual_device = virtual_device_cache_.Unique(VirtualDevice( + default_primitive_device_type, + /*virtual_device_id=*/0, FindPrimitiveTargetOrFail(default_primitive_device_type))); } /* static */ Target CompilationConfigNode::MakeDefaultTarget(DLDeviceType device_type) { @@ -205,7 +208,7 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, // Complete the targets vector and establish default scopes. After this primitive_targets will // contain the definitive list of all required targets, target_host will be defined, and // all primitive targets will have host target_host. - node->EstablishDefaultSEScopes(pass_ctx); + node->EstablishDefaultVirtualDevices(pass_ctx); // LEGACY: Reconstruct the target map from all the primitive targets. // Note that we require pointer equality between targets in legacy_target_map and @@ -214,8 +217,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, node->legacy_target_map.Set(Integer(primitive_target->kind->device_type), primitive_target); } - ICHECK(node->default_primitive_se_scope->target.defined()); - ICHECK(node->host_se_scope->target.defined()); + ICHECK(node->default_primitive_virtual_device->target.defined()); + ICHECK(node->host_virtual_device->target.defined()); ICHECK_GT(node->primitive_targets.size(), 0U); // Legacy: Some passes only support homogenous compilation and expect the target to be @@ -227,8 +230,8 @@ CompilationConfig::CompilationConfig(const transform::PassContext& pass_ctx, DLOG(INFO) << "Target " << target->ToDebugString() << " of device type " << target->kind->device_type << " is available for primitives"; } - DLOG(INFO) << "Using default primitive scope " << node->default_primitive_se_scope; - DLOG(INFO) << "Using host scope " << node->host_se_scope; 
+ DLOG(INFO) << "Using default primitive virtual device " << node->default_primitive_virtual_device; + DLOG(INFO) << "Using host virtual device " << node->host_virtual_device; data_ = std::move(node); } diff --git a/src/target/se_scope.cc b/src/target/virtual_device.cc similarity index 71% rename from src/target/se_scope.cc rename to src/target/virtual_device.cc index 8e6c6fe7f2a2..cde58d3cc22c 100644 --- a/src/target/se_scope.cc +++ b/src/target/virtual_device.cc @@ -18,21 +18,22 @@ */ /*! - * \file tvm/target/se_scope.cc - * \brief Implementation of \p SEScope for representing a Storage or Execution scope. + * \file tvm/target/virtual_device.cc + * \brief A compile time representation for where data is to be stored at runtime, and how to + * compile code to compute it. */ #include #include -#include +#include namespace tvm { -TVM_REGISTER_NODE_TYPE(SEScopeNode); +TVM_REGISTER_NODE_TYPE(VirtualDeviceNode); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = ref.as(); - p->stream << "SEScope("; + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = ref.as(); + p->stream << "VirtualDevice("; if (node->IsFullyUnconstrained()) { p->stream << "?"; } else { @@ -65,12 +66,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << ")"; }); -SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, - MemoryScope memory_scope) { +VirtualDevice::VirtualDevice(DLDeviceType device_type, int virtual_device_id, Target target, + MemoryScope memory_scope) { ICHECK(!target.defined() || device_type == target->kind->device_type) << "target " << target->ToDebugString() << " has device type " << target->kind->device_type - << " but scope has device type " << device_type; - auto node = make_object(); + << " but virtual device has device type " << device_type; + auto node = make_object(); node->device_type_int = device_type; node->virtual_device_id = virtual_device_id; 
node->target = std::move(target); @@ -78,13 +79,13 @@ SEScope::SEScope(DLDeviceType device_type, int virtual_device_id, Target target, data_ = std::move(node); } -/* static */ SEScope SEScope::FullyUnconstrained() { - static const SEScope unconstrained{}; +/* static */ VirtualDevice VirtualDevice::FullyUnconstrained() { + static const VirtualDevice unconstrained{}; return unconstrained; } /* static */ -Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { +Optional VirtualDevice::Join(const VirtualDevice& lhs, const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } @@ -124,11 +125,12 @@ Optional SEScope::Join(const SEScope& lhs, const SEScope& rhs) { } else { joined_memory_scope = rhs->memory_scope; } - return SEScope(joined_device_type, joined_virtual_device_id, joined_target, joined_memory_scope); + return VirtualDevice(joined_device_type, joined_virtual_device_id, joined_target, + joined_memory_scope); } /* static */ -SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { +VirtualDevice VirtualDevice::Default(const VirtualDevice& lhs, const VirtualDevice& rhs) { if (lhs == rhs) { return lhs; } @@ -160,13 +162,14 @@ SEScope SEScope::Default(const SEScope& lhs, const SEScope& rhs) { } else { defaulted_memory_scope = rhs->memory_scope; } - return SEScope(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, - defaulted_memory_scope); + return VirtualDevice(defaulted_device_type, defaulted_virtual_device_id, defaulted_target, + defaulted_memory_scope); } -SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Target target, - MemoryScope memory_scope) { - SEScope prototype(device_type, virtual_device_id, std::move(target), std::move(memory_scope)); +VirtualDevice VirtualDeviceCache::Make(DLDeviceType device_type, int virtual_device_id, + Target target, MemoryScope memory_scope) { + VirtualDevice prototype(device_type, virtual_device_id, std::move(target), + std::move(memory_scope)); auto itr = 
cache_.find(prototype); if (itr == cache_.end()) { cache_.emplace(prototype); @@ -180,11 +183,12 @@ SEScope SEScopeCache::Make(DLDeviceType device_type, int virtual_device_id, Targ } } -SEScope SEScopeCache::Unique(const SEScope& scope) { - return Make(scope->device_type(), scope->virtual_device_id, scope->target, scope->memory_scope); +VirtualDevice VirtualDeviceCache::Unique(const VirtualDevice& virtual_device) { + return Make(virtual_device->device_type(), virtual_device->virtual_device_id, + virtual_device->target, virtual_device->memory_scope); } -TVM_REGISTER_GLOBAL("target.SEScope_ForDeviceTargetAndMemoryScope") - .set_body_typed(SEScope::ForDeviceTargetAndMemoryScope); +TVM_REGISTER_GLOBAL("target.VirtualDevice_ForDeviceTargetAndMemoryScope") + .set_body_typed(VirtualDevice::ForDeviceTargetAndMemoryScope); } // namespace tvm diff --git a/src/tir/analysis/device_constraint_utils.cc b/src/tir/analysis/device_constraint_utils.cc index 8412cb8b8923..ae2b9500649e 100644 --- a/src/tir/analysis/device_constraint_utils.cc +++ b/src/tir/analysis/device_constraint_utils.cc @@ -32,7 +32,7 @@ #include "./device_constraint_utils.h" #include -#include +#include #include #include @@ -104,11 +104,11 @@ void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func, * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently * only memory scope is extracted. Fails if constraints are not consistent, ie \p type is a tuple * type and the \p prim_func is attempting to map different fields of that tuple to different memory - * scopes. Returns the fully unconstrained \p SEScope if no memory scopes constraints arise from - * the \p prim_func, ie all storage scope strings in pointer types are empty. + * scopes. Returns the fully unconstrained \p VirtualDevice if no memory scopes constraints arise + * from the \p prim_func, ie all storage scope strings in pointer types are empty. 
*/ -SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type, - size_t* current_primfunc_param_index) { +VirtualDevice ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type, + size_t* current_primfunc_param_index) { std::string memory_scope; // default empty => no constraint for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); @@ -120,25 +120,26 @@ SEScope ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& ty } else if (buffer_memory_scope.empty()) { // No constraint. } else { - // Tuples must be homogenous on their SEScope and thus memory scope. + // Tuples must be homogenous on their VirtualDevice and thus memory scope. ICHECK_EQ(buffer_memory_scope, memory_scope); } ++*current_primfunc_param_index; } - return SEScope::ForMemoryScope(memory_scope); + return VirtualDevice::ForMemoryScope(memory_scope); } /*! * \brief Insert into param_constraints an entry for each parameter of \p prim_func starting from * \p *current_primfunc_param_index for the flattened form of a Rleay parameters of \p type. Each - * entry maps to \p se_scope. + * entry maps to \p virtual_device. 
*/ -void InsertParamConstraints(const tir::PrimFunc& prim_func, const Type& type, - const SEScope& se_scope, size_t* current_primfunc_param_index, - std::unordered_map* param_constraints) { +void InsertParamConstraints( + const tir::PrimFunc& prim_func, const Type& type, const VirtualDevice& virtual_device, + size_t* current_primfunc_param_index, + std::unordered_map* param_constraints) { for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { std::pair kv = FindPointerParam(prim_func, current_primfunc_param_index); - param_constraints->emplace(kv.first.get(), se_scope); + param_constraints->emplace(kv.first.get(), virtual_device); ++*current_primfunc_param_index; } } @@ -186,22 +187,22 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { * memory scopes needed to change. */ PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type, - const Array& arg_and_result_se_scopes) { + const Array& arg_and_result_virtual_devices) { size_t current_primfunc_param_index = 0; - std::unordered_map param_constraints; + std::unordered_map param_constraints; // For each Relay function parameter... for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) { const Type& param_type = relay_func_type->arg_types[i]; - const SEScope& param_se_scope = arg_and_result_se_scopes[i]; - InsertParamConstraints(prim_func, param_type, param_se_scope, ¤t_primfunc_param_index, - ¶m_constraints); + const VirtualDevice& param_virtual_device = arg_and_result_virtual_devices[i]; + InsertParamConstraints(prim_func, param_type, param_virtual_device, + ¤t_primfunc_param_index, ¶m_constraints); } // For the Relay function result... 
const Type& ret_type = relay_func_type->ret_type; - const SEScope& ret_se_scope = arg_and_result_se_scopes.back(); - InsertParamConstraints(prim_func, ret_type, ret_se_scope, ¤t_primfunc_param_index, + const VirtualDevice& ret_virtual_device = arg_and_result_virtual_devices.back(); + InsertParamConstraints(prim_func, ret_type, ret_virtual_device, ¤t_primfunc_param_index, ¶m_constraints); // Make sure we accounted for all prim_func parameters. @@ -214,10 +215,10 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { // For each constrained parameter... for (const auto& kv : param_constraints) { const tir::Var param = GetRef(kv.first); - const SEScope& se_scope = kv.second; + const VirtualDevice& virtual_device = kv.second; const tir::Buffer& buffer = prim_func->buffer_map[param]; // Rewrite the buffer to account for constraint. - const Buffer new_buffer = RewriteBuffer(buffer, se_scope); + const Buffer new_buffer = RewriteBuffer(buffer, virtual_device); if (!new_buffer.same_as(buffer)) { any_change = true; } @@ -357,10 +358,10 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { BufferRegion new_source = VisitItem(match_buffer_region_node->source.get()); // The buffer field however is a definitional occurrence, aliased on top of the source. // Transfer any memory scope from the source to the destination. - Optional opt_se_scope = GetBufferConstraint(new_source->buffer); + Optional opt_virtual_device = GetBufferConstraint(new_source->buffer); tir::Buffer new_buffer; - if (opt_se_scope.defined()) { - new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_se_scope.value()); + if (opt_virtual_device.defined()) { + new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_virtual_device.value()); } else { new_buffer = match_buffer_region_node->buffer; } @@ -407,21 +408,21 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { } /*! 
- * \brief Rewrites \p buffer so as to follow the constraints in \p se_scope + * \brief Rewrites \p buffer so as to follow the constraints in \p virtual_device * (currently just memory scope). * * Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but * also returns the new buffer. */ - Buffer RewriteBuffer(const Buffer& buffer, const SEScope& se_scope) { + Buffer RewriteBuffer(const Buffer& buffer, const VirtualDevice& virtual_device) { ICHECK(buffer->data->type_annotation.defined()); const auto* pointer_type_node = buffer->data->type_annotation.as(); ICHECK(pointer_type_node); - if (pointer_type_node->storage_scope == se_scope->memory_scope) { + if (pointer_type_node->storage_scope == virtual_device->memory_scope) { // No change. return buffer; } - PointerType new_pointer_type(pointer_type_node->element_type, se_scope->memory_scope); + PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope); Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); var_subst_.emplace(buffer->data.get(), new_data); Buffer new_buffer(new_data, buffer->dtype, buffer->shape, buffer->strides, buffer->elem_offset, @@ -432,14 +433,15 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { } /*! - * \brief Returns the SEScope capturing any memory scope in \p buffer. Returns nullptr if + * \brief Returns the VirtualDevice capturing any memory scope in \p buffer. Returns nullptr if * buffer's data var does not have a type annotation of \p PointerType. Returns the fully - * unconstrained \p SEScope if no memory scope is given. + * unconstrained \p VirtualDevice if no memory scope is given. */ - static Optional GetBufferConstraint(const tir::Buffer& buffer) { + static Optional GetBufferConstraint(const tir::Buffer& buffer) { const auto* pointer_type_node = PointerInBuffer(buffer); - return pointer_type_node == nullptr ? 
Optional() - : SEScope::ForMemoryScope(pointer_type_node->storage_scope); + return pointer_type_node == nullptr + ? Optional() + : VirtualDevice::ForMemoryScope(pointer_type_node->storage_scope); } /*! @@ -455,59 +457,60 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator { } // namespace -Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, - const FuncType& relay_func_type) { +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type) { // Build the implied domain (in terms of the function's Relay type) implied by any memory scope // constrains in the function's buffers, for both arguments and results. - Array se_scopes; - se_scopes.reserve(relay_func_type->arg_types.size() + 1); + Array virtual_devices; + virtual_devices.reserve(relay_func_type->arg_types.size() + 1); // For each Relay function parameter... size_t current_primfunc_param_index = 0; for (const auto& param_type : relay_func_type->arg_types) { - SEScope param_se_scope = + VirtualDevice param_virtual_device = ConsistentParamConstraint(prim_func, param_type, ¤t_primfunc_param_index); - se_scopes.push_back(param_se_scope); + virtual_devices.push_back(param_virtual_device); } // For the Relay function result... const Type& ret_type = relay_func_type->ret_type; - SEScope ret_se_scope = + VirtualDevice ret_virtual_device = ConsistentParamConstraint(prim_func, ret_type, ¤t_primfunc_param_index); - se_scopes.push_back(ret_se_scope); + virtual_devices.push_back(ret_virtual_device); // Make sure all parameters of the prim_func have been accounted for. 
CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); - return se_scopes; + return virtual_devices; } TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints") .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) { Array memory_scopes; memory_scopes.reserve(relay_func_type->type_params.size() + 1); - for (const auto& se_scope : GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) { - memory_scopes.push_back(se_scope->memory_scope); + for (const auto& virtual_device : + GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) { + memory_scopes.push_back(virtual_device->memory_scope); } return memory_scopes; }); -PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, - const FuncType& relay_func_type, - const Array& arg_and_result_se_scopes) { +PrimFunc ApplyPrimFuncArgAndResultConstraints( + const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_virtual_devices) { return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type, - arg_and_result_se_scopes); + arg_and_result_virtual_devices); } TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints") .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type, const Array& arg_and_result_memory_scopes) { - Array se_scopes; - se_scopes.reserve(arg_and_result_memory_scopes.size()); + Array virtual_devices; + virtual_devices.reserve(arg_and_result_memory_scopes.size()); for (const auto& memory_scope : arg_and_result_memory_scopes) { - se_scopes.push_back(SEScope::ForMemoryScope(memory_scope)); + virtual_devices.push_back(VirtualDevice::ForMemoryScope(memory_scope)); } - return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, se_scopes); + return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, virtual_devices); }); } // namespace tir diff --git a/src/tir/analysis/device_constraint_utils.h 
b/src/tir/analysis/device_constraint_utils.h index be0f199f5226..717bf5280c00 100644 --- a/src/tir/analysis/device_constraint_utils.h +++ b/src/tir/analysis/device_constraint_utils.h @@ -23,21 +23,21 @@ * parameters. * * These utilities are used by the \p PlanDevices pass to extract memory (aka 'storage') scope - * information from \p PrimFuncs and convert them back into \p SEScope form w.r.t. the original - * Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and conversion - * to destination-passing style aka DPS). + * information from \p PrimFuncs and convert them back into \p VirtualDevice form w.r.t. the + * original Relay type of the \p PrimFunc (ie before flattening of tuple arguments/results and + * conversion to destination-passing style aka DPS). * * A utility is also supplied to go the other way: impose memory scopes on \p PrimFunc parameters. * However that's still in EXPERIMENTAL form. * * We may extend these utilities to also gather/apply layout information should we add that to - * \p SEScope. + * \p VirtualDevice. */ #ifndef TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ #define TVM_TIR_ANALYSIS_DEVICE_CONSTRAINT_UTILS_H_ -#include +#include #include namespace tvm { @@ -71,26 +71,26 @@ namespace tir { */ /*! - * \brief Returns the \p SEScopes capturing the memory (aka storage) scope constraints for all the - * arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's + * \brief Returns the \p VirtualDevices capturing the memory (aka storage) scope constraints for all + * the arguments and result of \p prim_func. However the result will be w.r.t. the \p prim_func's * representation as a Relay \p Function of \p relay_func_type_ before lowering and conversion to * DPS. 
*/ -Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, - const FuncType& relay_func_type); +Array GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, + const FuncType& relay_func_type); /* * \brief Returns \p prim_func written to capture the memory (aka storage) scope constraints - * for each of the \p prim_func's parameters given by \p arg_and_result_se_scopes. However, - * \p arg_and_result_se_scopes should be w.r.t. the \p prim_func's representation as a Relay + * for each of the \p prim_func's parameters given by \p arg_and_result_virtual_devices. However, + * \p arg_and_result_virtual_devices should be w.r.t. the \p prim_func's representation as a Relay * \p Function of \p relay_func_type before lowering and conversion to DPS. * * CAUTION: This is experimental. The resulting \p PrimFunc may not have fully accounted for all * new memory scopes. */ -PrimFunc ApplyPrimFuncArgAndResultConstraints(const PrimFunc& prim_func, - const FuncType& relay_func_type, - const Array& arg_and_result_se_scopes); +PrimFunc ApplyPrimFuncArgAndResultConstraints( + const PrimFunc& prim_func, const FuncType& relay_func_type, + const Array& arg_and_result_virtual_devices); } // namespace tir } // namespace tvm diff --git a/tests/cpp/relay/op/memory/on_device_test.cc b/tests/cpp/relay/op/memory/on_device_test.cc index 45d4f881c454..6f0a0b0d8beb 100644 --- a/tests/cpp/relay/op/memory/on_device_test.cc +++ b/tests/cpp/relay/op/memory/on_device_test.cc @@ -30,22 +30,22 @@ TEST(OnDeviceOp, Name) { EXPECT_EQ(OnDeviceOp()->name, "on_device"); } TEST(OnDevice, Default) { Var body("x", {}); - SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); - Call call = OnDevice(body, se_scope); + VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, virtual_device); EXPECT_EQ(call->op, OnDeviceOp()); EXPECT_EQ(call->args.size(), 1); EXPECT_EQ(call->args[0], body); const auto* attrs = call->attrs.as(); ASSERT_TRUE(attrs 
!= nullptr); - EXPECT_EQ(attrs->se_scope, se_scope); + EXPECT_EQ(attrs->virtual_device, virtual_device); EXPECT_FALSE(attrs->constrain_result); EXPECT_TRUE(attrs->constrain_body); } TEST(OnDevice, Fixed) { Var body("x", {}); - SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); - Call call = OnDevice(body, se_scope, /*constrain_result=*/true); + VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, virtual_device, /*constrain_result=*/true); const auto* attrs = call->attrs.as(); ASSERT_TRUE(attrs != nullptr); EXPECT_TRUE(attrs->constrain_result); @@ -54,8 +54,8 @@ TEST(OnDevice, Fixed) { TEST(OnDevice, Free) { Var body("x", {}); - SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); - Call call = OnDevice(body, se_scope, /*constrain_result=*/false, /*constrain_body=*/false); + VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, virtual_device, /*constrain_result=*/false, /*constrain_body=*/false); const auto* attrs = call->attrs.as(); ASSERT_TRUE(attrs != nullptr); EXPECT_FALSE(attrs->constrain_result); @@ -64,23 +64,23 @@ TEST(OnDevice, Free) { TEST(GetOnDeviceProps, Correct) { Var body("x", {}); - SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); - Call call = OnDevice(body, se_scope, /*constrain_result=*/true, /*constrain_body=*/false); + VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3); + Call call = OnDevice(body, virtual_device, /*constrain_result=*/true, /*constrain_body=*/false); OnDeviceProps props = GetOnDeviceProps(call); ASSERT_TRUE(props.body.defined()); - ASSERT_EQ(props.se_scope, se_scope); + ASSERT_EQ(props.virtual_device, virtual_device); ASSERT_TRUE(props.constrain_result); ASSERT_FALSE(props.constrain_body); } TEST(MaybeOnDevice, Wrapped) { - SEScope se_scope = SEScope::ForDeviceType(kDLCPU, 3); + VirtualDevice virtual_device = VirtualDevice::ForDeviceType(kDLCPU, 3); Var body("x", {}); - Call inner = OnDevice(body, 
se_scope); - Call outer = OnDevice(inner, se_scope); + Call inner = OnDevice(body, virtual_device); + Call outer = OnDevice(inner, virtual_device); OnDeviceProps props = GetOnDeviceProps(outer); ASSERT_TRUE(props.body.defined()); - ASSERT_EQ(props.se_scope, se_scope); + ASSERT_EQ(props.virtual_device, virtual_device); ASSERT_FALSE(props.constrain_result); ASSERT_TRUE(props.constrain_body); } diff --git a/tests/cpp/relay/transforms/device_domains_test.cc b/tests/cpp/relay/transforms/device_domains_test.cc index 5df7984d003a..7314f6425189 100644 --- a/tests/cpp/relay/transforms/device_domains_test.cc +++ b/tests/cpp/relay/transforms/device_domains_test.cc @@ -45,8 +45,8 @@ IRModule TestModule() { } TEST(DeviceDomains, SmokeTest) { - SEScope cpu = SEScope::ForDeviceType(kDLCPU); - SEScope cuda = SEScope::ForDeviceType(kDLCUDA); + VirtualDevice cpu = VirtualDevice::ForDeviceType(kDLCPU); + VirtualDevice cuda = VirtualDevice::ForDeviceType(kDLCUDA); TargetMap target_map; target_map.Set(Integer(static_cast(kDLCPU)), Target("llvm")); target_map.Set(Integer(static_cast(kDLCUDA)), Target("cuda")); @@ -66,11 +66,11 @@ TEST(DeviceDomains, SmokeTest) { arg_and_results.push_back(result_domain); DeviceDomainPtr implied_add_domain = domains.MakeHigherOrderDomain(std::move(arg_and_results)); EXPECT_FALSE(domains.UnifyOrNull(actual_add_domain, implied_add_domain) == nullptr); - EXPECT_FALSE(domains.UnifyOrNull( - x_domain, domains.ForSEScope(f->params[0]->checked_type(), cuda)) == nullptr); + EXPECT_FALSE(domains.UnifyOrNull(x_domain, domains.ForVirtualDevice(f->params[0]->checked_type(), + cuda)) == nullptr); - EXPECT_EQ(domains.ResultSEScope(y_domain), config->CanonicalSEScope(cuda)); - EXPECT_EQ(domains.ResultSEScope(result_domain), config->CanonicalSEScope(cuda)); + EXPECT_EQ(domains.ResultVirtualDevice(y_domain), config->CanonicalVirtualDevice(cuda)); + EXPECT_EQ(domains.ResultVirtualDevice(result_domain), config->CanonicalVirtualDevice(cuda)); } } // namespace diff --git 
a/tests/cpp/target/compilation_config_test.cc b/tests/cpp/target/compilation_config_test.cc index 31b936807edc..2b1041b47d0b 100644 --- a/tests/cpp/target/compilation_config_test.cc +++ b/tests/cpp/target/compilation_config_test.cc @@ -48,9 +48,9 @@ TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) { legacy_target_map.Set(Integer(static_cast(kDLCUDA)), cuda_target); CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); - SEScope expected_default_primitive_se_scope(kDLCUDA, 0, - Target::WithHost(cuda_target, host_target)); - SEScope expected_host_se_scope(kDLCPU, 0, host_target); + VirtualDevice expected_default_primitive_virtual_device( + kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); + VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target); ASSERT_EQ(config->legacy_target_map.size(), 1); EXPECT_TRUE(StructuralEqual()((*config->legacy_target_map.begin()).second, @@ -60,9 +60,9 @@ TEST(CompilationConfig, Constructor_Homogeneous_FallbackCPUHost) { ASSERT_EQ(config->primitive_targets.size(), 1); EXPECT_TRUE( StructuralEqual()(config->primitive_targets[0], Target::WithHost(cuda_target, host_target))); - EXPECT_TRUE( - StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); - EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device, + expected_default_primitive_virtual_device)); + EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device)); ASSERT_TRUE(config->optional_homogeneous_target.defined()); EXPECT_TRUE(StructuralEqual()(config->optional_homogeneous_target, Target::WithHost(cuda_target, host_target))); @@ -107,9 +107,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { Target::WithHost(cuda_target, host_target)); CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{}); - SEScope 
expected_default_primitive_se_scope(kDLCUDA, 0, - Target::WithHost(cuda_target, host_target)); - SEScope expected_host_se_scope(kDLCPU, 0, host_target); + VirtualDevice expected_default_primitive_virtual_device( + kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); + VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target); ASSERT_EQ(config->legacy_target_map.size(), 2); for (const auto& pair : config->legacy_target_map) { @@ -121,9 +121,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_FallbackCPUHost) { } EXPECT_TRUE(config->host_target.defined()); EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); - EXPECT_TRUE( - StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); - EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device, + expected_default_primitive_virtual_device)); + EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device)); EXPECT_FALSE(config->optional_homogeneous_target.defined()); } @@ -140,9 +140,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { Target::WithHost(cuda_target, host_target)); CompilationConfig config(pass_ctx, legacy_target_map, host_target); - SEScope expected_default_primitive_se_scope(kDLCUDA, 0, - Target::WithHost(cuda_target, host_target)); - SEScope expected_host_se_scope(kDLCPU, 0, host_target); + VirtualDevice expected_default_primitive_virtual_device( + kDLCUDA, 0, Target::WithHost(cuda_target, host_target)); + VirtualDevice expected_host_virtual_device(kDLCPU, 0, host_target); ASSERT_EQ(config->legacy_target_map.size(), 2); for (const auto& pair : config->legacy_target_map) { @@ -155,9 +155,9 @@ TEST(CompilationConfig, Constructor_Hetrogeneous_ExplicitHost) { EXPECT_TRUE(config->host_target.defined()); EXPECT_TRUE(StructuralEqual()(config->host_target, host_target)); ASSERT_EQ(config->primitive_targets.size(), 
2); - EXPECT_TRUE( - StructuralEqual()(config->default_primitive_se_scope, expected_default_primitive_se_scope)); - EXPECT_TRUE(StructuralEqual()(config->host_se_scope, expected_host_se_scope)); + EXPECT_TRUE(StructuralEqual()(config->default_primitive_virtual_device, + expected_default_primitive_virtual_device)); + EXPECT_TRUE(StructuralEqual()(config->host_virtual_device, expected_host_virtual_device)); EXPECT_FALSE(config->optional_homogeneous_target.defined()); } @@ -188,40 +188,40 @@ TEST(CompilationConfig, Constructor_DefaultNoMatchingPrimitiveTarget) { CompilationConfig config(pass_ctx, legacy_target_map, /*optional_host_target_arg=*/{})); } -TEST(CompilationConfig, CanonicalSEScope) { +TEST(CompilationConfig, CanonicalVirtualDevice) { Target host_target = TestDefaultCpuTarget(); Target cuda_target = TestCudaTarget(); Target cpu_target = TestCpuTarget(); CompilationConfig config = TestCompilationConfig(); { - SEScope in = SEScope(kDLCPU); - SEScope actual = config->CanonicalSEScope(in); + VirtualDevice in = VirtualDevice(kDLCPU); + VirtualDevice actual = config->CanonicalVirtualDevice(in); ASSERT_TRUE(actual->target.defined()); EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cpu_target, host_target))); - EXPECT_EQ(config->CanonicalSEScope(in), actual); + EXPECT_EQ(config->CanonicalVirtualDevice(in), actual); } { - SEScope in = SEScope(kDLCUDA); - SEScope actual = config->CanonicalSEScope(in); + VirtualDevice in = VirtualDevice(kDLCUDA); + VirtualDevice actual = config->CanonicalVirtualDevice(in); ASSERT_TRUE(actual->target.defined()); EXPECT_TRUE(StructuralEqual()(actual->target, Target::WithHost(cuda_target, host_target))); - EXPECT_EQ(config->CanonicalSEScope(in), actual); + EXPECT_EQ(config->CanonicalVirtualDevice(in), actual); } } -TEST(CompilationConfig, CanonicalSEScope_NoDevice) { +TEST(CompilationConfig, CanonicalVirtualDevice_NoDevice) { CompilationConfig config = TestCompilationConfig(); - SEScope fully_unconstrained; - 
EXPECT_ANY_THROW(config->CanonicalSEScope(fully_unconstrained)); - SEScope missing_device(kInvalidDeviceType, 3, {}, "local"); - EXPECT_ANY_THROW(config->CanonicalSEScope(missing_device)); + VirtualDevice fully_unconstrained; + EXPECT_ANY_THROW(config->CanonicalVirtualDevice(fully_unconstrained)); + VirtualDevice missing_device(kInvalidDeviceType, 3, {}, "local"); + EXPECT_ANY_THROW(config->CanonicalVirtualDevice(missing_device)); } -TEST(CompilationConfig, CanonicalSEScope_NoMatchingTarget) { +TEST(CompilationConfig, CanonicalVirtualDevice_NoMatchingTarget) { CompilationConfig config = TestCompilationConfig(); - SEScope no_such_target(kDLMetal); - EXPECT_ANY_THROW(config->CanonicalSEScope(no_such_target)); + VirtualDevice no_such_target(kDLMetal); + EXPECT_ANY_THROW(config->CanonicalVirtualDevice(no_such_target)); } } // namespace diff --git a/tests/cpp/target/se_scope_test.cc b/tests/cpp/target/se_scope_test.cc deleted file mode 100644 index 166ba46faf37..000000000000 --- a/tests/cpp/target/se_scope_test.cc +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -#include -#include -#include - -namespace tvm { -namespace { - -TEST(SEScope, Join_Defined) { - { - Target target_a = Target("cuda"); - SEScope lhs = SEScope(kDLCUDA, 3); - SEScope rhs = SEScope(kDLCUDA, -1, target_a, "global"); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_TRUE(actual.operator bool()); - SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); - } - { - Target target_a = Target("cuda"); - SEScope lhs = SEScope(kDLCUDA, -1, target_a, "global"); - SEScope rhs = SEScope(kDLCUDA, 3); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_TRUE(actual.operator bool()); - SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); - } - { - Target target_a = Target("cuda"); - SEScope lhs = SEScope(kDLCUDA); - SEScope rhs = SEScope(kDLCUDA, 2, target_a); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_TRUE(actual.operator bool()); - SEScope expected = SEScope(kDLCUDA, 2, target_a); - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); - } - { - Target target_a = Target("cuda"); - SEScope lhs = SEScope(); - SEScope rhs = SEScope(kDLCUDA, 3, target_a, "global"); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_TRUE(actual.operator bool()); - SEScope expected = rhs; - EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); - } -} - -TEST(SEScope, Join_Undefined) { - { - SEScope lhs = SEScope(kDLCUDA); - SEScope rhs = SEScope(kDLCPU); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_FALSE(actual); - } - { - SEScope lhs = SEScope(kDLCUDA, 3); - SEScope rhs = SEScope(kDLCUDA, 4); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_FALSE(actual); - } - { - SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda")); - SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda")); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_FALSE(actual); - } - { - SEScope lhs = SEScope(kDLCUDA, 3, Target("cuda"), "local"); - 
SEScope rhs = SEScope(kDLCUDA, 3, Target("cuda"), "global"); - Optional actual = SEScope::Join(lhs, rhs); - EXPECT_FALSE(actual); - } -} - -TEST(SEScope, Default) { - Target target_a = Target("cuda"); - SEScope lhs = SEScope(kDLCUDA, -1, Target(), "global"); - SEScope rhs = SEScope(kDLCUDA, 3, target_a, "local"); - SEScope actual = SEScope::Default(lhs, rhs); - SEScope expected = SEScope(kDLCUDA, 3, target_a, "global"); - EXPECT_TRUE(StructuralEqual()(actual, expected)); -} - -TEST(SEScope, Constructor_Invalid) { EXPECT_ANY_THROW(SEScope(kDLCPU, -1, Target("cuda"))); } - -TEST(SEScopeCache, Memoized) { - SEScopeCache cache; - Target target_a = Target("cuda"); - Target target_b = Target("llvm"); - SEScope se_scope_a = cache.Make(kDLCUDA, 3, target_a, "local"); - SEScope se_scope_b = cache.Make(kDLCPU, 1, target_b, "global"); - - EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), se_scope_a); - EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), se_scope_b); - EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), se_scope_a); - EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), se_scope_a); - EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, "global"), se_scope_a); -} - -} // namespace -} // namespace tvm diff --git a/tests/cpp/target/virtual_device_test.cc b/tests/cpp/target/virtual_device_test.cc new file mode 100644 index 000000000000..35e078713d1b --- /dev/null +++ b/tests/cpp/target/virtual_device_test.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +namespace tvm { +namespace { + +TEST(VirtualDevice, Join_Defined) { + { + Target target_a = Target("cuda"); + VirtualDevice lhs = VirtualDevice(kDLCUDA, 3); + VirtualDevice rhs = VirtualDevice(kDLCUDA, -1, target_a, "global"); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, target_a, "global"); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 3); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + VirtualDevice lhs = VirtualDevice(kDLCUDA); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 2, target_a); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + VirtualDevice expected = VirtualDevice(kDLCUDA, 2, target_a); + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); + } + { + Target target_a = Target("cuda"); + VirtualDevice lhs = VirtualDevice(); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "global"); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_TRUE(actual.operator bool()); + VirtualDevice expected = rhs; + EXPECT_TRUE(StructuralEqual()(actual.value(), expected)); 
+ } +} + +TEST(VirtualDevice, Join_Undefined) { + { + VirtualDevice lhs = VirtualDevice(kDLCUDA); + VirtualDevice rhs = VirtualDevice(kDLCPU); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + VirtualDevice lhs = VirtualDevice(kDLCUDA, 3); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 4); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda")); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda")); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_FALSE(actual); + } + { + VirtualDevice lhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "local"); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, Target("cuda"), "global"); + Optional actual = VirtualDevice::Join(lhs, rhs); + EXPECT_FALSE(actual); + } +} + +TEST(VirtualDevice, Default) { + Target target_a = Target("cuda"); + VirtualDevice lhs = VirtualDevice(kDLCUDA, -1, Target(), "global"); + VirtualDevice rhs = VirtualDevice(kDLCUDA, 3, target_a, "local"); + VirtualDevice actual = VirtualDevice::Default(lhs, rhs); + VirtualDevice expected = VirtualDevice(kDLCUDA, 3, target_a, "global"); + EXPECT_TRUE(StructuralEqual()(actual, expected)); +} + +TEST(VirtualDevice, Constructor_Invalid) { + EXPECT_ANY_THROW(VirtualDevice(kDLCPU, -1, Target("cuda"))); +} + +TEST(VirtualDeviceCache, Memoized) { + VirtualDeviceCache cache; + Target target_a = Target("cuda"); + Target target_b = Target("llvm"); + VirtualDevice virtual_device_a = cache.Make(kDLCUDA, 3, target_a, "local"); + VirtualDevice virtual_device_b = cache.Make(kDLCPU, 1, target_b, "global"); + + EXPECT_EQ(cache.Make(kDLCUDA, 3, target_a, "local"), virtual_device_a); + EXPECT_EQ(cache.Make(kDLCPU, 1, target_b, "global"), virtual_device_b); + EXPECT_NE(cache.Make(kDLCUDA, 2, target_a, "local"), virtual_device_a); + EXPECT_NE(cache.Make(kDLCPU, 3, target_b, "local"), virtual_device_a); + EXPECT_NE(cache.Make(kDLCUDA, 3, target_a, 
"global"), virtual_device_a); +} + +} // namespace +} // namespace tvm diff --git a/tests/python/relay/op/annotation/test_annotation.py b/tests/python/relay/op/annotation/test_annotation.py index 5ad2a59e39ab..2352821f7bee 100644 --- a/tests/python/relay/op/annotation/test_annotation.py +++ b/tests/python/relay/op/annotation/test_annotation.py @@ -26,10 +26,10 @@ def test_on_device_via_string(): assert isinstance(call, relay.Call) assert len(call.args) == 1 assert call.args[0] == x - assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA - assert call.attrs.se_scope.virtual_device_id == 0 - assert call.attrs.se_scope.target is None - assert call.attrs.se_scope.memory_scope == "" + assert call.attrs.virtual_device.device_type_int == 2 # ie kDLCUDA + assert call.attrs.virtual_device.virtual_device_id == 0 + assert call.attrs.virtual_device.target is None + assert call.attrs.virtual_device.memory_scope == "" assert call.attrs.constrain_body assert not call.attrs.constrain_result @@ -37,7 +37,7 @@ def test_on_device_via_string(): def test_on_device_via_device(): x = relay.Var("x") call = relay.annotation.on_device(x, tvm.device("cpu")) - assert call.attrs.se_scope.device_type_int == 1 # ie kDLCPU + assert call.attrs.virtual_device.device_type_int == 1 # ie kDLCPU def test_on_device_invalid_device(): @@ -48,7 +48,7 @@ def test_on_device_invalid_device(): def test_on_device_fixed(): x = relay.Var("x") call = relay.annotation.on_device(x, "cuda", constrain_result=True) - assert call.attrs.se_scope.device_type_int == 2 # ie kDLCUDA + assert call.attrs.virtual_device.device_type_int == 2 # ie kDLCUDA assert call.attrs.constrain_body assert call.attrs.constrain_result @@ -56,7 +56,7 @@ def test_on_device_fixed(): def test_on_device_free(): x = relay.Var("x") call = relay.annotation.on_device(x, "cuda", constrain_result=False, constrain_body=False) - assert call.attrs.se_scope.device_type_int == -1 # ie kInvalidDeviceType + assert 
call.attrs.virtual_device.device_type_int == -1 # ie kInvalidDeviceType assert not call.attrs.constrain_body assert not call.attrs.constrain_result @@ -67,10 +67,10 @@ def test_function_on_device(): f = relay.Function([x, y], relay.add(x, y)) func = relay.annotation.function_on_device(f, ["cpu", "cuda"], "cuda") assert isinstance(func, relay.Function) - assert len(func.attrs["param_se_scopes"]) == 2 - assert func.attrs["param_se_scopes"][0].device_type_int == 1 # ie kDLCPU - assert func.attrs["param_se_scopes"][1].device_type_int == 2 # ie kDLCUDA - assert func.attrs["result_se_scope"].device_type_int == 2 # ie KDLCUDA + assert len(func.attrs["param_virtual_devices"]) == 2 + assert func.attrs["param_virtual_devices"][0].device_type_int == 1 # ie kDLCPU + assert func.attrs["param_virtual_devices"][1].device_type_int == 2 # ie kDLCUDA + assert func.attrs["result_virtual_device"].device_type_int == 2 # ie kDLCUDA if __name__ == "__main__": diff --git a/tests/python/relay/op/test_tensor.py b/tests/python/relay/op/test_tensor.py index 4d2c1766972a..2d561cf79eae 100644 --- a/tests/python/relay/op/test_tensor.py +++ b/tests/python/relay/op/test_tensor.py @@ -26,14 +26,14 @@ def test_device_copy_via_string(): assert isinstance(call, relay.Call) assert len(call.args) == 1 assert call.args[0] == x - assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA - assert call.attrs.src_se_scope.virtual_device_id == 0 - assert call.attrs.src_se_scope.target is None - assert call.attrs.src_se_scope.memory_scope == "" + assert call.attrs.src_virtual_device.device_type_int == 2 # ie kDLCUDA + assert call.attrs.src_virtual_device.virtual_device_id == 0 + assert call.attrs.src_virtual_device.target is None + assert call.attrs.src_virtual_device.memory_scope ==
"" + assert call.attrs.dst_virtual_device.device_type_int == 1 # ie kDLCPU + assert call.attrs.dst_virtual_device.virtual_device_id == 0 + assert call.attrs.dst_virtual_device.target is None + assert call.attrs.dst_virtual_device.memory_scope == "" def test_device_copy_via_device(): @@ -42,8 +42,8 @@ def test_device_copy_via_device(): assert isinstance(call, relay.Call) assert len(call.args) == 1 assert call.args[0] == x - assert call.attrs.src_se_scope.device_type_int == 2 # ie kDLCUDA - assert call.attrs.dst_se_scope.device_type_int == 1 # ie kDLCPU + assert call.attrs.src_virtual_device.device_type_int == 2 # ie kDLCUDA + assert call.attrs.dst_virtual_device.device_type_int == 1 # ie kDLCPU if __name__ == "__main__": diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index 3893da45fcaa..bc19bcdb1739 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -19,8 +19,8 @@ from tvm.relay.testing import inception_v3 import pytest -cpu_scope = tvm.target.make_se_scope(tvm.cpu(), tvm.target.Target("llvm")) -metatable = {"SEScope": [cpu_scope]} +cpu_scope = tvm.target.make_virtual_device(tvm.cpu(), tvm.target.Target("llvm")) +metatable = {"VirtualDevice": [cpu_scope]} core = tvm.IRModule() core.import_from_std("core.rly") @@ -234,7 +234,7 @@ def test_impure_op(): def @main() { let %size: int64 = cast(1024, dtype="int64"); let %alignment: int64 = cast(64, dtype="int64"); - let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]); + let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]); 0 } """, @@ -249,7 +249,7 @@ def @main() { def @main() { let %x = memory.alloc_storage(cast(1024, dtype="int64"), cast(64, dtype="int64"), - se_scope=meta[SEScope][0]); + virtual_device=meta[VirtualDevice][0]); 0 } """, @@ -271,7 +271,7 @@ def test_impure_func(): def @f() -> int { let %size: 
int64 = cast(1024, dtype="int64"); let %alignment: int64 = cast(64, dtype="int64"); - let %x = memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][0]); + let %x = memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][0]); 0 } def @main() -> int { @@ -290,7 +290,7 @@ def @main() -> int { def @f() -> int { let %x = memory.alloc_storage(cast(1024, dtype="int64"), cast(64, dtype="int64"), - se_scope=meta[SEScope][0]); + virtual_device=meta[VirtualDevice][0]); 0 } def @main() -> int { diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py index ee9cfc909585..82e40af389b2 100644 --- a/tests/python/relay/test_pass_plan_devices.py +++ b/tests/python/relay/test_pass_plan_devices.py @@ -41,13 +41,13 @@ tvm.tir.IntImm("int32", GPU_DEVICE.device_type): GPU_TARGET, } -HOST = tvm.target.make_se_scope(HOST_DEVICE, HOST_TARGET) # device_type=1 -CPU = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET) # device_type=1 -GPU = tvm.target.make_se_scope(GPU_DEVICE, GPU_TARGET) # device_type=2 +HOST = tvm.target.make_virtual_device(HOST_DEVICE, HOST_TARGET) # device_type=1 +CPU = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET) # device_type=1 +GPU = tvm.target.make_virtual_device(GPU_DEVICE, GPU_TARGET) # device_type=2 DEFAULT = GPU -CPU_SCOPE_A = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA") -CPU_SCOPE_B = tvm.target.make_se_scope(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB") +CPU_SCOPE_A = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, memory_scope="scopeA") +CPU_SCOPE_B = tvm.target.make_virtual_device(CPU_DEVICE, CPU_TARGET, memory_scope="scopeB") CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int}) @@ -109,7 +109,7 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a def test_plain(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Everything defaults to 
GPU def input(): @@ -134,8 +134,8 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); %1 = add(%c, %d); subtract(%0, %1) @@ -153,7 +153,7 @@ def ref(a, b, c, d): def test_left_add_on_cpu(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Force some args to be on CPU, rest default to GPU. def input(): @@ -163,7 +163,7 @@ def input(): def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %2 = add(%c, %d); subtract(%1, %2) } @@ -179,11 +179,11 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %3 = add(%c, %d); subtract(%2, %3) } @@ 
-200,7 +200,7 @@ def ref(a, b, c, d): def test_left_add_on_cpu_via_copy(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # As for test_left_add_on_cpu, but with an explicit device_copy. def input(): @@ -210,7 +210,7 @@ def input(): def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = device_copy(%0, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %1 = device_copy(%0, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %2 = add(%c, %d); subtract(%1, %2) } @@ -226,11 +226,11 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %3 = add(%c, %d); subtract(%2, %3) } @@ -247,7 +247,7 @@ def ref(a, b, c, d): def test_both_adds_on_cpu(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -257,8 +257,8 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = add(%c, %d); - %2 = on_device(%0, se_scope=meta[SEScope][0]); - %3 = on_device(%1, 
se_scope=meta[SEScope][0]); + %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]); + %3 = on_device(%1, virtual_device=meta[VirtualDevice][0]); subtract(%2, %3) } """, @@ -273,14 +273,14 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = add(%c, %d); - %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); - %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%2, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %4 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); subtract(%4, %5) } """, @@ -296,7 +296,7 @@ def ref(a, b, c, d): def test_sharing(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # The same add sub-expression is annotated twice. 
def input(): @@ -305,8 +305,8 @@ def input(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0]); - %2 = on_device(%0, se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); + %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]); subtract(%1, %2) } """, @@ -320,12 +320,12 @@ def expected(): """ #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %2 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %3 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %4 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %2 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %4 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); subtract(%3, %4) } """, @@ -342,7 +342,7 @@ def ref(a, b): def test_let_on_cpu(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # The device for a let-bound expression can flow from uses of the let-bound var. 
def input(): @@ -353,7 +353,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %l = add(%a, %b); let %r = add(%c, %d); - %0 = on_device(%l, se_scope=meta[SEScope][0]); + %0 = on_device(%l, virtual_device=meta[VirtualDevice][0]); subtract(%0, %r) } """, @@ -368,12 +368,12 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%a, %b); - let %l = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - let %r = on_device(add(%c, %d), se_scope=meta[SEScope][1], constrain_result=True); - %1 = device_copy(%l, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + let %l = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + let %r = on_device(add(%c, %d), virtual_device=meta[VirtualDevice][1], constrain_result=True); + %1 = device_copy(%l, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); subtract(%1, %r) } """, @@ -389,7 +389,7 @@ def ref(a, b, c, d): def test_func_param_on_cpu(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Devices for function parameters flow to call sites. 
def input(): @@ -400,7 +400,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { let %f = fn (%x, %y) { %0 = add(%x, %y); - on_device(%0, se_scope=meta[SEScope][0]) + on_device(%0, virtual_device=meta[VirtualDevice][0]) }; %1 = %f(%a, %b); %2 = add(%c, %d); @@ -418,10 +418,10 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], - result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][0]) { let %f = fn (%x, %y, - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { add(%x, %y) }; %0 = %f(%a, %b); @@ -441,7 +441,7 @@ def ref(a, b, c, d): def test_func_result_on_cpu(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Devices for call sites flow to function results. 
def input(): @@ -454,7 +454,7 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], add(%x, %y) }; %0 = %f(%a, %b); - %1 = on_device(%0, se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %2 = add(%c, %d); subtract(%1, %2) } @@ -470,15 +470,15 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { let %f = fn (%x, %y, - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { add(%x, %y) }; %1 = %f(%a, %b); - %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); - %3 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %2 = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %3 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %4 = add(%c, %d); subtract(%3, %4) } @@ -495,7 +495,7 @@ def ref(a, b, c, d): def test_higher_order(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # The constraint on %a flows back to %y via %f and %h def input(): @@ -505,7 +505,7 @@ def input(): def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%g) { fn (%a) { - %0 = on_device(%a, se_scope=meta[SEScope][0]); + %0 = on_device(%a, virtual_device=meta[VirtualDevice][0]); %1 = %g(%0); add(%1, %x) } @@ -528,15 +528,15 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), 
float32], %y: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - let %f = fn (%g, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - %0 = device_copy(%a, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + let %f = fn (%g, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + %0 = device_copy(%a, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %1 = %g(%0); add(%1, %x) } }; - let %h = fn (%b, param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { + let %h = fn (%b, param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { negative(%b) }; %2 = %f(%h); @@ -562,7 +562,7 @@ def h(b): def test_function_in_tuple(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Since %f ends up in a tuple its argument and result is forced to be on the CPU def input(): @@ -571,7 +571,7 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { - %0 = on_device(%b, se_scope=meta[SEScope][0]); + %0 = on_device(%b, virtual_device=meta[VirtualDevice][0]); add(%a, %0) }; let %t = (%f, %x); @@ -590,12 +590,12 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], 
result_virtual_device=meta[VirtualDevice][0]) { let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { add(%a, %b) }; - let %t = on_device((%f, %x), se_scope=meta[SEScope][0], constrain_result=True); + let %t = on_device((%f, %x), virtual_device=meta[VirtualDevice][0], constrain_result=True); %0 = %t.1; %1 = %t.0; %1(%0, %y) @@ -614,14 +614,14 @@ def ref(x, y): def test_device_copy(): const = rand((5, 7)) - metatable = {"SEScope": [CPU, GPU], "relay.Constant": [relay.const(const)]} + metatable = {"VirtualDevice": [CPU, GPU], "relay.Constant": [relay.const(const)]} def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { - %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %0 = device_copy(%x, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%0, meta[relay.Constant][0]) } """, @@ -635,8 +635,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - %0 = device_copy(%x, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + %0 = device_copy(%x, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%0, meta[relay.Constant][0]) } """, @@ -652,7 +652,7 @@ def ref(x): def test_shape_of(): - metatable = {"SEScope": [HOST, GPU]} + metatable = {"VirtualDevice": [HOST, GPU]} # We need to use constrain_result=True in the on_device call so that the tensor will be on the GPU. Otherwise the # result defaults to the result device for @main which is the CPU, thus forcing a copy. 
@@ -662,7 +662,7 @@ def input(): """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32]) { - %0 = on_device(%x, se_scope=meta[SEScope][1], constrain_result=True); + %0 = on_device(%x, virtual_device=meta[VirtualDevice][1], constrain_result=True); vm.shape_of(%0, dtype="int64") } """, @@ -676,7 +676,7 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(?, ?), float32], - param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) { vm.shape_of(%x, dtype="int64") } """, @@ -692,14 +692,14 @@ def ref(x): def test_alloc_storage(): - metatable = {"SEScope": [HOST, GPU]} + metatable = {"VirtualDevice": [HOST, GPU]} def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @main(%size: int64, %alignment: int64) { - memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) + memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1]) } """, "from_string", @@ -712,8 +712,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%size: int64, %alignment: int64, - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { - memory.alloc_storage(%size, %alignment, se_scope=meta[SEScope][1]) + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + memory.alloc_storage(%size, %alignment, virtual_device=meta[VirtualDevice][1]) } """, "from_string", @@ -727,7 +727,10 @@ def @main(%size: int64, %alignment: int64, def test_alloc_tensor(): shape = np.array([3, 2]) - metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(shape, dtype="int64")]} + metatable = { + "VirtualDevice": [HOST, GPU], + "relay.Constant": [relay.const(shape, dtype="int64")], + } def input(): return tvm.parser.parse( @@ -747,9 +750,9 @@ def expected(): return tvm.parser.parse( """ #[version = "0.0.5"] - def @main(%sto: Storage[], 
param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(0, se_scope=meta[SEScope][0], constrain_result=True); - %1 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); + def @main(%sto: Storage[], param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + %0 = on_device(0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %1 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True); memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) } """, @@ -764,7 +767,10 @@ def @main(%sto: Storage[], param_se_scopes=[meta[SEScope][1]], result_se_scope=m def test_reshape_tensor(): newshape = [2, 4, 2] - metatable = {"SEScope": [HOST, GPU], "relay.Constant": [relay.const(newshape, dtype="int64")]} + metatable = { + "VirtualDevice": [HOST, GPU], + "relay.Constant": [relay.const(newshape, dtype="int64")], + } def input(): return tvm.parser.parse( @@ -784,8 +790,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(2, 8), float32], - param_se_scopes=[meta[SEScope][1]], result_se_scope=meta[SEScope][1]) { - %0 = on_device(meta[relay.Constant][0], se_scope=meta[SEScope][0], constrain_result=True); + param_virtual_devices=[meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][1]) { + %0 = on_device(meta[relay.Constant][0], virtual_device=meta[VirtualDevice][0], constrain_result=True); vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) } """, @@ -801,7 +807,7 @@ def ref(x): def test_dynamic_input(): - metatable = {"SEScope": [GPU]} + metatable = {"VirtualDevice": [GPU]} # There's nothing special about inferring devices for partially unknown types. 
def input(): @@ -822,7 +828,7 @@ def expected(): """ #[version = "0.0.5"] def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { add(%x0, %x1) } """, @@ -838,7 +844,7 @@ def ref(x0, x1): def test_redundant_annotation(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -846,9 +852,9 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %2 = subtract(%1, %z); - %3 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%0, virtual_device=meta[VirtualDevice][0]); add(%2, %3) } """, @@ -862,14 +868,14 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %3 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %4 = subtract(%2, %z); - %5 = device_copy(%3, 
src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%4, %5) } """, @@ -886,7 +892,7 @@ def ref(x, y, z): def test_annotate_expr(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -894,9 +900,9 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][1]); %2 = subtract(%1, %z); - on_device(%2, se_scope=meta[SEScope][0]) + on_device(%2, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -909,11 +915,11 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][1], meta[SEScope][0]], - result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); subtract(%2, %z) } """, @@ -929,7 +935,7 @@ def ref(x, y, z): def test_annotate_all(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -937,9 +943,9 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, 
se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %2 = subtract(%1, %z); - on_device(%2, se_scope=meta[SEScope][0]) + on_device(%2, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -952,8 +958,8 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], - result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); subtract(%0, %z) } @@ -982,7 +988,7 @@ def test_conv_network(): | <--- CPU """ - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -992,12 +998,12 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 %weight: Tensor[(64, 64, 3, 3), float32]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %2 = on_device(%0, se_scope=meta[SEScope][0]); - %3 = on_device(%1, se_scope=meta[SEScope][0]); + %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]); + %3 = on_device(%1, virtual_device=meta[VirtualDevice][0]); %4 = add(%2, %3); - %5 = on_device(%4, se_scope=meta[SEScope][1]); + %5 = on_device(%4, virtual_device=meta[VirtualDevice][1]); %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - on_device(%6, se_scope=meta[SEScope][0]) + on_device(%6, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -1011,17 +1017,17 @@ def expected(): #[version = "0.0.5"] def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], %weight: Tensor[(64, 64, 3, 3), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], 
meta[SEScope][0]], - result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][0]) { %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); - %3 = on_device(%2, se_scope=meta[SEScope][0], constrain_result=True); - %4 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %5 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %3 = on_device(%2, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %4 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %5 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %6 = add(%4, %5); - %7 = on_device(%6, se_scope=meta[SEScope][1], constrain_result=True); - %8 = device_copy(%7, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %7 = on_device(%6, virtual_device=meta[VirtualDevice][1], constrain_result=True); + %8 = device_copy(%7, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) } """, @@ -1035,7 +1041,7 @@ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 5 def test_tuple_get_item(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # Note that the device copy should be placed after projection rather than before. This is handled by # a heuristic in the pass. 
@@ -1045,12 +1051,12 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(3, 3, 4), float32]) { let %t = split(%x, indices_or_sections=3); - %0 = on_device(%t, se_scope=meta[SEScope][0]); - %1 = on_device(%t, se_scope=meta[SEScope][0]); + %0 = on_device(%t, virtual_device=meta[VirtualDevice][0]); + %1 = on_device(%t, virtual_device=meta[VirtualDevice][0]); %2 = %0.0; %3 = %1.1; %4 = subtract(%2, %3); - on_device(%4, se_scope=meta[SEScope][1]) + on_device(%4, virtual_device=meta[VirtualDevice][1]) } """, "from_string", @@ -1063,15 +1069,15 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(3, 3, 4), float32], - param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { %0 = split(%x, indices_or_sections=3); - let %t = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); + let %t = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); %1 = %t.0; - %2 = on_device(%1, se_scope=meta[SEScope][0], constrain_result=True); + %2 = on_device(%1, virtual_device=meta[VirtualDevice][0], constrain_result=True); %3 = %t.1; - %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); - %5 = device_copy(%2, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %6 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %4 = on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %5 = device_copy(%2, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %6 = device_copy(%4, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); subtract(%5, %6) } """, @@ -1101,7 +1107,7 @@ def test_propogation(): | <--- CPU """ - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -1109,16 +1115,16 @@ def input(): #[version = 
"0.0.5"] def @main(%x: Tensor[(5, 7), float32]) { %0 = negative(%x); - %1 = on_device(%0, se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %2 = negative(%1); - %3 = on_device(%0, se_scope=meta[SEScope][0]); + %3 = on_device(%0, virtual_device=meta[VirtualDevice][0]); %4 = negative(%3); - %5 = on_device(%2, se_scope=meta[SEScope][1]); - %6 = on_device(%4, se_scope=meta[SEScope][1]); + %5 = on_device(%2, virtual_device=meta[VirtualDevice][1]); + %6 = on_device(%4, virtual_device=meta[VirtualDevice][1]); %7 = add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1]); + %8 = on_device(%7, virtual_device=meta[VirtualDevice][1]); %9 = negative(%8); - on_device(%9, se_scope=meta[SEScope][0]) + on_device(%9, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -1131,17 +1137,17 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { %0 = negative(%x); - %1 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - %3 = on_device(%0, se_scope=meta[SEScope][0], constrain_result=True); - %4 = device_copy(%3, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + %3 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %4 = device_copy(%3, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %5 = negative(%2); %6 = negative(%4); %7 = add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); - %9 = device_copy(%8, src_se_scope=meta[SEScope][1], 
dst_se_scope=meta[SEScope][0]); + %8 = on_device(%7, virtual_device=meta[VirtualDevice][1], constrain_result=True); + %9 = device_copy(%8, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); negative(%9) } """, @@ -1173,7 +1179,7 @@ def test_fusible_network(): | <--- CPU """ - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -1181,14 +1187,14 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][1]); %2 = negative(%1); - %3 = on_device(%2, se_scope=meta[SEScope][0]); + %3 = on_device(%2, virtual_device=meta[VirtualDevice][0]); %4 = negative(%0); %5 = add(%3, %4); - %6 = on_device(%5, se_scope=meta[SEScope][1]); + %6 = on_device(%5, virtual_device=meta[VirtualDevice][1]); %7 = negative(%6); - on_device(%7, se_scope=meta[SEScope][0]) + on_device(%7, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -1201,17 +1207,17 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][1]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][1]], result_virtual_device=meta[VirtualDevice][0]) { %0 = add(%x, %y); - %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); - %2 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], constrain_result=True); + %2 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); %3 = negative(%2); - %4 = on_device(%3, se_scope=meta[SEScope][0], constrain_result=True); - %5 = device_copy(%4, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + 
%4 = on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_result=True); + %5 = device_copy(%4, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); %6 = negative(%0); %7 = add(%5, %6); - %8 = on_device(%7, se_scope=meta[SEScope][1], constrain_result=True); - %9 = device_copy(%8, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %8 = on_device(%7, virtual_device=meta[VirtualDevice][1], constrain_result=True); + %9 = device_copy(%8, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); negative(%9) } """, @@ -1241,7 +1247,7 @@ def test_unpropagatable_graph(): | <--- CPU """ - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -1251,10 +1257,10 @@ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { %0 = add(%a, %b); %1 = multiply(%c, %d); - %2 = on_device(%0, se_scope=meta[SEScope][0]); - %3 = on_device(%1, se_scope=meta[SEScope][1]); + %2 = on_device(%0, virtual_device=meta[VirtualDevice][0]); + %3 = on_device(%1, virtual_device=meta[VirtualDevice][1]); %4 = subtract(%2, %3); - on_device(%4, se_scope=meta[SEScope][0]) + on_device(%4, virtual_device=meta[VirtualDevice][0]) } """, "from_string", @@ -1268,12 +1274,12 @@ def expected(): #[version = "0.0.5"] def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][1], meta[SEScope][1]], - result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][0]) { %0 = multiply(%c, %d); - %1 = on_device(%0, se_scope=meta[SEScope][1], constrain_result=True); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][1], 
constrain_result=True); %2 = add(%a, %b); - %3 = device_copy(%1, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][0]); + %3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][0]); subtract(%2, %3) } """, @@ -1289,7 +1295,7 @@ def ref(a, b, c, d): def test_conditional(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. def input(): @@ -1298,7 +1304,7 @@ def input(): #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { let %f = fn (%a) { - %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); + %0 = on_device(%y, virtual_device=meta[VirtualDevice][0], constrain_result=True); add(%a, %0) }; let %g = fn (%a1) { @@ -1322,19 +1328,19 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0], meta[SEScope][0]], - result_se_scope=meta[SEScope][0]) { - let %f = fn (%a, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][0]) { + let %f = fn (%a, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { add(%a, %y) }; - let %g = fn (%a1, param_se_scopes=[meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + let %g = fn (%a1, param_virtual_devices=[meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { subtract(%a1, %y) }; let %h = on_device(if (%x) { %f } else { %g - }, se_scope=meta[SEScope][0], constrain_result=True); + }, virtual_device=meta[VirtualDevice][0], constrain_result=True); %h(%z) } """, @@ -1357,14 +1363,14 @@ def g(a): def test_global(): - metatable = {"SEScope": [CPU, 
GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { - %0 = on_device(%b, se_scope=meta[SEScope][0]); + %0 = on_device(%b, virtual_device=meta[VirtualDevice][0]); add(%a, %0) } @@ -1382,15 +1388,15 @@ def expected(): """ #[version = "0.0.5"] def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], - result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { - %0 = device_copy(%b, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], + result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); add(%a, %0) } def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][1]], - result_se_scope=meta[SEScope][1]) -> Tensor[(5, 7), float32] { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][1]) -> Tensor[(5, 7), float32] { @f(%y, %x) } """, @@ -1409,7 +1415,7 @@ def f(a, b): def test_ref(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -1417,7 +1423,7 @@ def input(): #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { let %r = ref(%x); - %0 = on_device(%y, se_scope=meta[SEScope][0]); + %0 = on_device(%y, virtual_device=meta[VirtualDevice][0]); ref_write(%r, %0); %1 = ref_read(%r); add(%x, %1) @@ -1433,10 +1439,10 @@ def expected(): """ #[version = "0.0.5"] def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][0]], result_se_scope=meta[SEScope][1]) 
{ - let %r = on_device(ref(%x), se_scope=meta[SEScope][1], constrain_result=True); - %0 = device_copy(%y, src_se_scope=meta[SEScope][0], dst_se_scope=meta[SEScope][1]); - on_device(ref_write(%r, %0), se_scope=meta[SEScope][1], constrain_result=True); + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][1]) { + let %r = on_device(ref(%x), virtual_device=meta[VirtualDevice][1], constrain_result=True); + %0 = device_copy(%y, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]); + on_device(ref_write(%r, %0), virtual_device=meta[VirtualDevice][1], constrain_result=True); %1 = ref_read(%r); add(%x, %1) } @@ -1456,7 +1462,7 @@ def ref(x, y): def test_adt(): - metatable = {"SEScope": [CPU, GPU]} + metatable = {"VirtualDevice": [CPU, GPU]} def input(): return tvm.parser.parse( @@ -1467,7 +1473,7 @@ def input(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { - %0 = on_device(%y, se_scope=meta[SEScope][0], constrain_result=True); + %0 = on_device(%y, virtual_device=meta[VirtualDevice][0], constrain_result=True); %1 = Nil; %2 = Cons(%0, %1); let %l = Cons(%x, %2); @@ -1490,10 +1496,10 @@ def expected(): Nil, } def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][0]], result_se_scope=meta[SEScope][0]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][0]], result_virtual_device=meta[VirtualDevice][0]) { %0 = Nil; %1 = Cons(%y, %0); - let %l = on_device(Cons(%x, %1), se_scope=meta[SEScope][0], constrain_result=True); + let %l = on_device(Cons(%x, %1), virtual_device=meta[VirtualDevice][0], constrain_result=True); match? 
(%l) { Cons(%z, _) => %z } @@ -1516,7 +1522,7 @@ def test_free_on_device(): a device_copy to be inserted if necessary, but otherwise does not prevent the flow of device information.""" metatable = { - "SEScope": [ + "VirtualDevice": [ CPU, # no memory scope constraint CPU_SCOPE_A, # constrain to scopeA CPU_SCOPE_B, @@ -1529,22 +1535,22 @@ def input(): """ #[version = "0.0.5"] def @on_scope_b(%x: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][2]], - result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] { + param_virtual_devices=[meta[VirtualDevice][2]], + result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { %x } def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][1], meta[SEScope][2]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][1], meta[VirtualDevice][2]], + result_virtual_device=meta[VirtualDevice][1]) { // %a's memory scope is unconstrained, so will take on "scopeB" and on_device has no effect - %0 = @on_scope_b(on_device(%a, se_scope=meta[SEScope][0], constrain_body=False)); + %0 = @on_scope_b(on_device(%a, virtual_device=meta[VirtualDevice][0], constrain_body=False)); // %b's memory scope is "scopeA", so will require a "scopeA"->"scopeB" copy. - %1 = @on_scope_b(on_device(%b, se_scope=meta[SEScope][0], constrain_body=False)); + %1 = @on_scope_b(on_device(%b, virtual_device=meta[VirtualDevice][0], constrain_body=False)); // %c's memory scope is "scopeB", so no copy required. - %2 = @on_scope_b(on_device(%c, se_scope=meta[SEScope][0], constrain_body=False)); + %2 = @on_scope_b(on_device(%c, virtual_device=meta[VirtualDevice][0], constrain_body=False)); // result's memory scope is is on "scopeA", so will require a "scopeB"->"scopeA" copy. 
%3 = add(add(%0, %1), %2); - on_device(%3, se_scope=meta[SEScope][0], constrain_body=False) + on_device(%3, virtual_device=meta[VirtualDevice][0], constrain_body=False) } """, "from_string", @@ -1557,20 +1563,20 @@ def expected(): """ #[version = "0.0.5"] def @on_scope_b(%x: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][2]], - result_se_scope=meta[SEScope][2]) -> Tensor[(5, 7), float32] { + param_virtual_devices=[meta[VirtualDevice][2]], + result_virtual_device=meta[VirtualDevice][2]) -> Tensor[(5, 7), float32] { %x } def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32], - param_se_scopes=[meta[SEScope][2], meta[SEScope][1], meta[SEScope][2]], - result_se_scope=meta[SEScope][1]) { + param_virtual_devices=[meta[VirtualDevice][2], meta[VirtualDevice][1], meta[VirtualDevice][2]], + result_virtual_device=meta[VirtualDevice][1]) { %0 = @on_scope_b(%a); - %1 = device_copy(%b, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]); + %1 = device_copy(%b, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]); %2 = @on_scope_b(%1); %3 = @on_scope_b(%c); %4 = add(add(%0, %2), %3); - %5 = on_device(%4, se_scope=meta[SEScope][2], constrain_result=True); - device_copy(%5, src_se_scope=meta[SEScope][2], dst_se_scope=meta[SEScope][1]) + %5 = on_device(%4, virtual_device=meta[VirtualDevice][2], constrain_result=True); + device_copy(%5, src_virtual_device=meta[VirtualDevice][2], dst_virtual_device=meta[VirtualDevice][1]) } """, "from_string", @@ -1616,12 +1622,12 @@ def expected_gem(a: T.handle, b: T.handle, c: T.handle, d: T.handle) -> None: D[vi, vj] = D[vi, vj] + A[vi, vk] * B[vj, vk] metatable = { - "SEScope": [ - CPU, # meta[SEScope][0], no memory scope - CPU_SCOPE_A, # meta[SEScope][1], "scopeA" + "VirtualDevice": [ + CPU, # meta[VirtualDevice][0], no memory scope + CPU_SCOPE_A, # meta[VirtualDevice][1], "scopeA" CPU_SCOPE_B, ] - } # meta[SEScope][2], "scopeB" + } # 
meta[VirtualDevice][2], "scopeB" gem_ty = relay.FuncType( [ relay.TensorType((128, 128), "float32"), @@ -1645,8 +1651,8 @@ def input(): def @main(%x : Tensor[(128, 128), float32], %y : Tensor[(128, 128), float32], %z : Tensor[(128, 128), float32], - param_se_scopes=[meta[SEScope][0], meta[SEScope][2], meta[SEScope][1]], - result_se_scope=meta[SEScope][2]) { + param_virtual_devices=[meta[VirtualDevice][0], meta[VirtualDevice][2], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][2]) { call_lowered(@gem, (%x, %y, %z)) } """, @@ -1668,13 +1674,13 @@ def expected(): def @main(%x : Tensor[(128, 128), float32], %y : Tensor[(128, 128), float32], %z : Tensor[(128, 128), float32], - param_se_scopes=[meta[SEScope][1], meta[SEScope][2], meta[SEScope][1]], - result_se_scope=meta[SEScope][2]) { - %0 = device_copy(%z, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]); - %1 = on_device(%0, se_scope=meta[SEScope][2], constrain_result=True); + param_virtual_devices=[meta[VirtualDevice][1], meta[VirtualDevice][2], meta[VirtualDevice][1]], + result_virtual_device=meta[VirtualDevice][2]) { + %0 = device_copy(%z, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]); + %1 = on_device(%0, virtual_device=meta[VirtualDevice][2], constrain_result=True); %2 = call_lowered(@gem, (%x, %y, %1)); - %3 = on_device(%2, se_scope=meta[SEScope][1], constrain_result=True); - device_copy(%3, src_se_scope=meta[SEScope][1], dst_se_scope=meta[SEScope][2]) + %3 = on_device(%2, virtual_device=meta[VirtualDevice][1], constrain_result=True); + device_copy(%3, src_virtual_device=meta[VirtualDevice][1], dst_virtual_device=meta[VirtualDevice][2]) } """, "from_string", diff --git a/tests/python/target/test_se_scope.py b/tests/python/target/test_virtual_device.py similarity index 54% rename from tests/python/target/test_se_scope.py rename to tests/python/target/test_virtual_device.py index 0a9384fa9c04..eec77bcc1f4f 100644 --- 
a/tests/python/target/test_se_scope.py +++ b/tests/python/target/test_virtual_device.py @@ -19,30 +19,30 @@ import tvm -def test_make_se_scope_for_device(): - se_scope = tvm.target.make_se_scope(tvm.device("cuda")) - assert se_scope.device_type == 2 +def test_make_virtual_device_for_device(): + virtual_device = tvm.target.make_virtual_device(tvm.device("cuda")) + assert virtual_device.device_type == 2 # ie kDLCUDA - assert se_scope.virtual_device_id == 0 - assert se_scope.target is None - assert se_scope.memory_scope == "" + assert virtual_device.virtual_device_id == 0 + assert virtual_device.target is None + assert virtual_device.memory_scope == "" -def test_make_se_scope_for_device_and_target(): +def test_make_virtual_device_for_device_and_target(): target = tvm.target.Target("cuda") - se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target) - assert se_scope.device_type == 2 # ie kDLCUDA - assert se_scope.target == target - assert se_scope.memory_scope == "" + virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target) + assert virtual_device.device_type == 2 # ie kDLCUDA + assert virtual_device.target == target + assert virtual_device.memory_scope == "" -def test_make_se_scope_for_device_target_and_memory_scope(): +def test_make_virtual_device_for_device_target_and_memory_scope(): target = tvm.target.Target("cuda") scope = "local" - se_scope = tvm.target.make_se_scope(tvm.device("cuda"), target, scope) - assert se_scope.device_type == 2 # ie kDLCUDA - assert se_scope.target == target - assert se_scope.memory_scope == scope + virtual_device = tvm.target.make_virtual_device(tvm.device("cuda"), target, scope) + assert virtual_device.device_type == 2 # ie kDLCUDA + assert virtual_device.target == target + assert virtual_device.memory_scope == scope if __name__ == "__main__": From cb9ea2f32f57947b75a5b5beb72abc43cd59f900 Mon Sep 17 00:00:00 2001 From: Mark Shields Date: Thu, 16 Dec 2021 14:16:01 -0800 Subject: [PATCH 2/2] [checkpoint] lint --- 
python/tvm/relay/transform/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 696bd6258ee6..bbe4bc2de9b2 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1168,8 +1168,8 @@ def PlanDevices(config): every Relay sub-expression should run and the result stored. Captures the result of that analysis using new "on_device" and "device_copy" calls. Sub-expressions which are not otherwise constrained are assigned to the default primitive virtual device describe by - config. However data and computations which must be hosted on a CPU (such as shapes and shape functions) - use the host virtual device of the config. + config. However data and computations which must be hosted on a CPU (such as shapes and + shape functions) use the host virtual device of the config. Parameters ----------