Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions include/tvm/runtime/container/variant.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/container/variant.h
* \brief Runtime Variant container types.
*/
#ifndef TVM_RUNTIME_CONTAINER_VARIANT_H_
#define TVM_RUNTIME_CONTAINER_VARIANT_H_

#include <tvm/runtime/object.h>

#include <tuple>
#include <type_traits>
#include <utility>

namespace tvm {
namespace runtime {

namespace detail {
template <typename Parent, typename ChildTuple>
constexpr bool parent_is_base_of_any = false;

template <typename Parent, typename... Child>
constexpr bool parent_is_base_of_any<Parent, std::tuple<Child...>> =
((std::is_base_of_v<Parent, Child> && !std::is_same_v<Parent, Child>) || ...);

/* \brief Utility to check if any parent is a base class of any child
*
* The type-checking in Variant relies on all types being from
* independent types, such that `Object::IsInstance` is sufficient to
* determine which variant is populated.
*
* For example, suppose the illegal `Variant<tir::Var, tir::PrimExpr>`
* were allowed (e.g. to represent either the defintion of a variable
* or the usage of a variable). If a function returned
* `tir::PrimExpr`, it could result in either variant being filled, as
* the underlying type at runtime could be a `tir::Var`. This
* behavior is different from `std::variant`, which determines the
* active variant based solely on the compile-time type, and could
* produce very unexpected results if the variants have different
* semantic interpretations.
*/
template <typename ParentTuple, typename ChildTuple>
static constexpr bool any_parent_is_base_of_any_child = false;

template <typename ChildTuple, typename... Parent>
static constexpr bool any_parent_is_base_of_any_child<std::tuple<Parent...>, ChildTuple> =
(parent_is_base_of_any<Parent, ChildTuple> || ...);
} // namespace detail

template <typename... V>
class Variant : public ObjectRef {
static constexpr bool all_inherit_from_objectref = (std::is_base_of_v<ObjectRef, V> && ...);
static_assert(all_inherit_from_objectref,
"All types used in Variant<...> must inherit from ObjectRef");

static constexpr bool a_variant_inherits_from_another_variant =
detail::any_parent_is_base_of_any_child<std::tuple<V...>, std::tuple<V...>>;
static_assert(!a_variant_inherits_from_another_variant,
"Due to implementation limitations, "
"no type stored in a tvm::runtime::Variant "
"may be a subclass of any other type "
"stored in the same variant.");

public:
/* \brief Helper utility to check if the type is part of the variant */
template <typename T>
static constexpr bool is_variant = (std::is_same_v<T, V> || ...);

/* \brief Helper utility for SFINAE if the type is part of the variant */
template <typename T>
using enable_if_variant = std::enable_if_t<is_variant<T>>;

template <typename T, typename = enable_if_variant<T>>
Variant(T value) : ObjectRef(std::move(value)) {} // NOLINT(*)

template <typename T, typename = enable_if_variant<T>>
Variant& operator=(T value) {
ObjectRef::operator=(std::move(value));
return *this;
}

// These functions would normally be declared with the
// TVM_DEFINE_OBJECT_REF_METHODS macro. However, we need additional
// type-checking inside the ObjectPtr<Object> constructor.
using ContainerType = Object;
Variant() : ObjectRef() {}
explicit Variant(ObjectPtr<Object> node) : ObjectRef(node) {
CHECK(node == nullptr || (node->IsInstance<typename V::ContainerType>() || ...))
<< "Variant<"
<< static_cast<const std::stringstream&>(
(std::stringstream() << ... << V::ContainerType::_type_key))
.str()
<< "> cannot hold an object of type " << node->GetTypeKey();
}
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Variant);
};

} // namespace runtime

// expose the functions to the root namespace.
using runtime::Variant;

} // namespace tvm

#endif // TVM_RUNTIME_CONTAINER_VARIANT_H_
54 changes: 51 additions & 3 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/variant.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
Expand Down Expand Up @@ -680,9 +681,6 @@ class TVMArgValue : public TVMPODValue_ {
} else if (type_code_ == kTVMStr) {
return std::string(value_.v_str);
} else {
ICHECK(IsObjectRef<tvm::runtime::String>())
<< "Could not convert TVM object of type " << runtime::Object::TypeIndex2Key(type_code_)
<< " to a string.";
return AsObjectRef<tvm::runtime::String>().operator std::string();
}
}
Expand Down Expand Up @@ -2063,6 +2061,56 @@ struct PackedFuncValueConverter<Optional<T>> {
}
};

template <typename... VariantTypes>
struct PackedFuncValueConverter<Variant<VariantTypes...>> {
using VType = Variant<VariantTypes...>;

// Can't just take `const TVMPODValue&` as an argument, because
// `TVMArgValue` and `TVMRetValue` have different implementations
// for `operator std::string()`.
template <typename PODSubclass>
static VType From(const PODSubclass& val) {
if (auto opt = TryAsObjectRef<VariantTypes...>(val)) {
return opt.value();
}

if (auto opt = TryValueConverter<PODSubclass, VariantTypes...>(val)) {
return opt.value();
}

LOG(FATAL) << "Expected one of "
<< static_cast<const std::stringstream&>(
(std::stringstream() << ... << VariantTypes::ContainerType::_type_key))
.str()
<< " but got " << ArgTypeCode2Str(val.type_code());
}

template <typename VarFirst, typename... VarRest>
static Optional<VType> TryAsObjectRef(const TVMPODValue_& val) {
if (val.IsObjectRef<VarFirst>()) {
return VType(val.AsObjectRef<VarFirst>());
} else if constexpr (sizeof...(VarRest)) {
return TryAsObjectRef<VarRest...>(val);
} else {
return NullOpt;
}
}

template <typename PODSubclass, typename VarFirst, typename... VarRest>
static Optional<VType> TryValueConverter(const PODSubclass& val) {
try {
return VType(PackedFuncValueConverter<VarFirst>::From(val));
} catch (const InternalError&) {
}

if constexpr (sizeof...(VarRest)) {
return TryValueConverter<PODSubclass, VarRest...>(val);
} else {
return NullOpt;
}
}
};

inline bool String::CanConvertFrom(const TVMArgValue& val) {
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}
Expand Down
12 changes: 12 additions & 0 deletions src/support/ffi_testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/ir/attrs.h>
#include <tvm/ir/env_func.h>
#include <tvm/runtime/container/variant.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/tensor.h>
Expand Down Expand Up @@ -165,4 +166,15 @@ TVM_REGISTER_GLOBAL("testing.sleep_in_ffi").set_body_typed([](double timeout) {
std::this_thread::sleep_for(duration);
});

TVM_REGISTER_GLOBAL("testing.ReturnsVariant").set_body_typed([](int x) -> Variant<String, IntImm> {
if (x % 2 == 0) {
return IntImm(DataType::Int(64), x / 2);
} else {
return String("argument was odd");
}
});

TVM_REGISTER_GLOBAL("testing.AcceptsVariant")
.set_body_typed([](Variant<String, Integer> arg) -> String { return arg->GetTypeKey(); });

} // namespace tvm
47 changes: 47 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/container/variant.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

Expand Down Expand Up @@ -853,3 +854,49 @@ TEST(Optional, PackedCall) {
test_ffi(s, static_cast<int>(kTVMObjectHandle));
test_ffi(String(s), static_cast<int>(kTVMObjectRValueRefArg));
}

TEST(Variant, Construct) {
Variant<PrimExpr, String> variant;
variant = PrimExpr(1);
ICHECK(variant.as<PrimExpr>());
ICHECK(!variant.as<String>());

variant = String("hello");
ICHECK(variant.as<String>());
ICHECK(!variant.as<PrimExpr>());
}

TEST(Variant, InvalidTypeThrowsError) {
auto expected_to_throw = []() {
ObjectPtr<Object> node = make_object<Object>();
Variant<PrimExpr, String> variant(node);
};

EXPECT_THROW(expected_to_throw(), InternalError);
}
Comment on lines +858 to +876
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A rather small set of tests, albeit for a fairly small API surface as compared to Array and Map. Are there other tests we could add? Maybe check assignment?

TEST(Variant, Assignment) {
  Variant<PrimExpr, String> variant;
  Variant<PrimExpr, String> variant2 = String("foo");
  variant = PrimExpr(1);
  variant2 = variant;

  ICHECK(variant2.as<PrimExpr>());
  # check the value of variant2
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I made the API surface as small as possible, but there were additional tests that should be included. I've added tests to validate that reference equality is preserved across Variant assignments, and that the values are correctly preserved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Lunderberg!


TEST(Variant, ReferenceIdentifyPreservedThroughAssignment) {
Variant<PrimExpr, String> variant;
ICHECK(!variant.defined());

String string_obj = "dummy_test";
variant = string_obj;
ICHECK(variant.defined());
ICHECK(variant.same_as(string_obj));
ICHECK(string_obj.same_as(variant));

String out_string_obj = Downcast<String>(variant);
ICHECK(string_obj.same_as(out_string_obj));
}

TEST(Variant, ExtractValueFromAssignment) {
Variant<PrimExpr, String> variant = String("hello");
ICHECK_EQ(variant.as<String>().value(), "hello");
}

TEST(Variant, AssignmentFromVariant) {
Variant<PrimExpr, String> variant = String("hello");
auto variant2 = variant;
ICHECK(variant2.as<String>());
ICHECK_EQ(variant2.as<String>().value(), "hello");
}
26 changes: 26 additions & 0 deletions tests/python/unittest/test_ir_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,5 +112,31 @@ def test_ndarray_container():
assert isinstance(arr[0], tvm.nd.NDArray)


def test_return_variant_type():
func = tvm.get_global_func("testing.ReturnsVariant")
res_even = func(42)
assert isinstance(res_even, tvm.tir.IntImm)
assert res_even == 21

res_odd = func(17)
assert isinstance(res_odd, tvm.runtime.String)
assert res_odd == "argument was odd"


def test_pass_variant_type():
func = tvm.get_global_func("testing.AcceptsVariant")

assert func("string arg") == "runtime.String"
assert func(17) == "IntImm"


def test_pass_incorrect_variant_type():
func = tvm.get_global_func("testing.AcceptsVariant")
float_arg = tvm.tir.FloatImm("float32", 0.5)

with pytest.raises(Exception):
func(float_arg)


if __name__ == "__main__":
tvm.testing.main()