Skip to content

Commit e520b9b

Browse files
authored
[Utility][Container] Support non-nullable types in Array::Map (#17094)
[Container] Support non-nullable types in Array::Map Prior to this commit, the `Array::Map` member function could only be applied to nullable object types. This was due to the internal use of `U()` as the default value for initializing the output `ArrayNode`, where `U` is the return type of the mapping function. This default constructor is only available for nullable types, and would result in a compile-time failure for non-nullable types. This commit replaces `U()` with `ObjectRef()` in `Array::Map`, removing this limitation. Since all items in the output array are overwritten before returning to the calling scope, initializing the output array with `ObjectRef()` does not violate type safety.
1 parent a4f20f0 commit e520b9b

File tree

1 file changed

+12
-2
lines changed
  • include/tvm/runtime/container

1 file changed

+12
-2
lines changed

include/tvm/runtime/container/array.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -827,8 +827,13 @@ class Array : public ObjectRef {
827827
// consisting of any previous elements that had mapped to
828828
// themselves (if any), and the element that didn't map to
829829
// itself.
830+
//
831+
// We cannot use `U()` as the default object, as `U` may be
832+
// a non-nullable type. Since the default `ObjectRef()`
833+
// will be overwritten before returning, all objects will be
834+
// of type `U` for the calling scope.
830835
all_identical = false;
831-
output = ArrayNode::CreateRepeated(arr->size(), U());
836+
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
832837
output->InitRange(0, arr->begin(), it);
833838
output->SetItem(it - arr->begin(), std::move(mapped));
834839
it++;
@@ -843,7 +848,12 @@ class Array : public ObjectRef {
843848
// compatible types isn't strictly necessary, as the first
844849
// mapped.same_as(*it) would return false, but we might as well
845850
// avoid it altogether.
846-
output = ArrayNode::CreateRepeated(arr->size(), U());
851+
//
852+
// We cannot use `U()` as the default object, as `U` may be a
853+
// non-nullable type. Since the default `ObjectRef()` will be
854+
// overwritten before returning, all objects will be of type `U`
855+
// for the calling scope.
856+
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
847857
}
848858

849859
// Normal path for incompatible types, or post-copy path for

0 commit comments

Comments
 (0)