Skip to content

Commit f9a9b84

Browse files
committed
[Cleanup] Accept Variant<...> instead of ObjectRef when possible
Prior to the implementation of `Variant<...>` in #15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied.
1 parent 4ecae58 commit f9a9b84

File tree

9 files changed: +15 additions, −13 deletions

include/tvm/tir/function.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class TensorIntrin : public ObjectRef {
264264
* B[vi, vj] = A[vi, vj]
265265
* \endcode
266266
*/
267-
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map);
267+
PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map);
268268

269269
/*!
270270
* \brief PrimFunc specific attribute names.

src/relax/op/tensor/create.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace relax {
3636
TVM_REGISTER_NODE_TYPE(InitAttrs);
3737

3838
/* relax.full */
39-
Expr full(ObjectRef shape, Expr fill_value, DataType dtype) {
39+
Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype) {
4040
Expr shape_in_expr{nullptr};
4141
if (const auto* expr = shape.as<ExprNode>()) {
4242
shape_in_expr = GetRef<Expr>(expr);

src/relax/op/tensor/create.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace relax {
3939
* If dtype is not given, it will by default use the dtype of fill_value.
4040
* \return The result tensor.
4141
*/
42-
Expr full(ObjectRef shape, Expr fill_value, DataType dtype);
42+
Expr full(Variant<Expr, Array<PrimExpr>> shape, Expr fill_value, DataType dtype);
4343

4444
/*!
4545
* \brief Construct a tensor such that

src/relax/op/tensor/manipulate.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ TVM_REGISTER_OP("relax.permute_dims")
652652
.set_attr<Bool>("FPurity", Bool(true));
653653

654654
/* relax.reshape */
655-
Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
655+
Expr ConvertNewShapeToExpr(const Expr& data, const Variant<Expr, Array<PrimExpr>>& shape) {
656656
const ArrayNode* array;
657657
// Treat shape expressions as constant arrays to handle special values.
658658
if (const auto* e = shape.as<ShapeExprNode>()) {
@@ -745,7 +745,7 @@ Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) {
745745
return ShapeExpr(array_ref);
746746
}
747747

748-
Expr reshape(Expr x, ObjectRef shape) {
748+
Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape) {
749749
Expr shape_in_expr = ConvertNewShapeToExpr(x, shape);
750750
static const Op& op = Op::Get("relax.reshape");
751751
return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {});
@@ -810,7 +810,7 @@ TVM_REGISTER_OP("relax.reshape")
810810
/* relax.split */
811811
TVM_REGISTER_NODE_TYPE(SplitAttrs);
812812

813-
Expr split(Expr x, ObjectRef indices_or_sections, int axis) {
813+
Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis) {
814814
ObjectPtr<SplitAttrs> attrs = make_object<SplitAttrs>();
815815
if (const auto* indices = indices_or_sections.as<ArrayNode>()) {
816816
for (int i = 0; i < static_cast<int>(indices->size()); ++i) {

src/relax/op/tensor/manipulate.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ Expr permute_dims(Expr x, Optional<Array<Integer>> axes);
8888
* It is required to be either an Array of PrimExpr, or a Shape in Relax
8989
* \return The reshaped result.
9090
*/
91-
Expr reshape(Expr x, ObjectRef shape);
91+
Expr reshape(Expr x, Variant<Expr, Array<PrimExpr>> shape);
9292

9393
/*!
9494
* \brief Split input tensor along axis by sections or indices.
@@ -103,7 +103,7 @@ Expr reshape(Expr x, ObjectRef shape);
103103
* \param axis The axis over which to split.
104104
* \return The computed result.
105105
*/
106-
Expr split(Expr x, ObjectRef indices_or_sections, int axis);
106+
Expr split(Expr x, Variant<IntImm, Array<IntImm>> indices_or_sections, int axis);
107107

108108
/*!
109109
* \brief Squeeze axes in the array.

src/relay/transforms/to_mixed_precision.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ using CachedCastNodes = std::unordered_map<std::pair<const ExprNode*, DataType>,
6666
// Return array is of type : [MixedTypeConversionCategory (int), String, String]
6767
// The fields are : [ConversionCategory, accumulation_datatype, output_datatype]
6868
// Call is a call node, DataType is the mixed precision type
69-
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<ObjectRef>(
69+
using FTVMMixedPrecisionConversionType = runtime::TypedPackedFunc<Array<Variant<Integer, String>>(
7070
const Call& call_node, const std::string& target_dtype_str)>;
7171

7272
/*! \brief This class transforms the given relay module into a version where
@@ -372,7 +372,7 @@ class MixedPrecisionPass : public MixedModeMutator {
372372
if (attr_map.count(op)) {
373373
// Calculate the conversion category and dtypes from registered attribute.
374374
FTVMMixedPrecisionConversionType func = attr_map[op];
375-
Array<ObjectRef> op_descriptor =
375+
Array<Variant<Integer, String>> op_descriptor =
376376
func(GetRef<Call>(pre_call_node), DLDataType2String(mixed_precision_type_));
377377
ICHECK(op_descriptor.size() == 3)
378378
<< "got the wrong number of returned arguments (expected 3 got " << op_descriptor.size()

src/tir/ir/expr.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,9 @@ Call::Call(DataType dtype, RelayExpr op, Array<PrimExpr> args, Span span) {
546546
}
547547

548548
TVM_REGISTER_GLOBAL("tir.Call")
549-
.set_body_typed([](DataType type, RelayExpr op, Array<ObjectRef> args, Span span) {
549+
.set_body_typed([](DataType type, RelayExpr op,
550+
Array<Variant<runtime::String, IterVar, BufferRegion, PrimExpr>> args,
551+
Span span) {
550552
Array<PrimExpr> prim_expr_args;
551553
for (const auto& it : args) {
552554
ICHECK(it->IsInstance<runtime::StringObj>() || it->IsInstance<PrimExprNode>() ||

src/tir/ir/specialize.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ void UpdateSpecializeVarMap(const PrimFunc& func, const Var& param, const PrimEx
414414

415415
/**************** Implementation ****************/
416416

417-
PrimFunc Specialize(PrimFunc func, const Map<Var, ObjectRef>& param_map) {
417+
PrimFunc Specialize(PrimFunc func, const Map<Var, Variant<Buffer, PrimExpr>>& param_map) {
418418
VarMap var_map;
419419
for (const auto& kv : param_map) {
420420
const Var& param = kv.first;

src/tir/transforms/inline_private_functions.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ class PrimFuncInliner : StmtExprMutator {
231231
<< "Inlining of PrimFuncs with buffer arguments is not yet supported, "
232232
<< "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map;
233233

234-
Map<Var, ObjectRef> param_map;
234+
Map<Var, Variant<tir::Buffer, tvm::PrimExpr>> param_map;
235235
for (size_t i = 0; i < callee->params.size(); i++) {
236236
param_map.Set(callee->params[i], args[i]);
237237
}

Comments (0)