|
27 | 27 |
|
28 | 28 | namespace tvm { |
29 | 29 |
|
| 30 | +GlobalVar WithFields(GlobalVar global_var, Optional<String> opt_name_hint, Optional<Type> opt_type, |
| 31 | + Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
| 32 | + String name_hint = opt_name_hint.value_or(global_var->name_hint); |
| 33 | + Type type = opt_type.value_or(global_var->checked_type()); |
| 34 | + VirtualDevice virtual_device = opt_virtual_device.value_or(global_var->virtual_device()); |
| 35 | + Span span = opt_span.value_or(global_var->span); |
| 36 | + bool all_fields_unchanged = |
| 37 | + name_hint.same_as(global_var->name_hint) && type.same_as(global_var->checked_type()) && |
| 38 | + virtual_device.same_as(global_var->virtual_device()) && span.same_as(global_var->span); |
| 39 | + if (!all_fields_unchanged) { |
| 40 | + GlobalVarNode* cow_global_var_node = global_var.CopyOnWrite(); |
| 41 | + cow_global_var_node->name_hint = name_hint; |
| 42 | + cow_global_var_node->checked_type_ = type; |
| 43 | + cow_global_var_node->virtual_device_ = virtual_device; |
| 44 | + cow_global_var_node->span = span; |
| 45 | + } |
| 46 | + |
| 47 | + return global_var; |
| 48 | +} |
| 49 | + |
30 | 50 | VirtualDevice RelayExprNode::virtual_device() const { |
31 | 51 | if (!this->virtual_device_.defined()) { |
32 | 52 | // virtual_device_ should always be defined, unless we imported this node from JSON using an old |
@@ -77,6 +97,25 @@ TensorType ConstantNode::tensor_type() const { |
77 | 97 | return TensorType(shape, dtype); |
78 | 98 | } |
79 | 99 |
|
| 100 | +Constant WithFields(Constant constant, Optional<runtime::NDArray> opt_data, |
| 101 | + Optional<VirtualDevice> opt_virtual_device, Optional<Span> opt_span) { |
| 102 | + runtime::NDArray data = opt_data.value_or(constant->data); |
| 103 | + VirtualDevice virtual_device = opt_virtual_device.value_or(constant->virtual_device()); |
| 104 | + Span span = opt_span.value_or(constant->span); |
| 105 | + |
| 106 | + bool all_fields_unchanged = data.same_as(constant->data) && |
| 107 | + virtual_device.same_as(constant->virtual_device()) && |
| 108 | + span.same_as(constant->span); |
| 109 | + |
| 110 | + if (!all_fields_unchanged) { |
| 111 | + ConstantNode* cow_constant_node = constant.CopyOnWrite(); |
| 112 | + cow_constant_node->data = data; |
| 113 | + cow_constant_node->virtual_device_ = virtual_device; |
| 114 | + cow_constant_node->span = span; |
| 115 | + } |
| 116 | + return constant; |
| 117 | +} |
| 118 | + |
80 | 119 | Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) { |
81 | 120 | ObjectPtr<TupleNode> n = make_object<TupleNode>(); |
82 | 121 | n->fields = std::move(fields); |
|
0 commit comments