@@ -770,121 +770,53 @@ inline const TTypeNode* RelayExprNode::type_as() const {
770770
771771namespace tvm {
772772namespace runtime {
773-
774- // Automatic conversion into IntImm, Integer, and Bool, when called
775- // through the FFI. Automatic conversions into PrimExpr are
776- // registered in "tvm/tir/expr.h", as it includes conversions to the
777- // TIR-only StringImm.
778- //
779- // While the FFI only requires the From() method, these
780- // implementations also define a TryFrom() method to avoid duplicate
781- // logic in the PrimExpr conversion.
782-
773+ // common rule for RetValue and ArgValue
783774template <>
784- struct PackedFuncValueConverter <tvm::IntImm> {
785- template <typename PODSubclass>
786- static Optional<tvm::IntImm> TryFrom (const PODSubclass& val) {
787- if (auto opt = val.TryAsInt ()) {
788- int64_t value = opt.value ();
789- auto dtype =
790- (value > std::numeric_limits<int >::max () || value < std::numeric_limits<int >::min ())
791- ? DataType::Int (64 )
792- : DataType::Int (32 );
793- return IntImm (dtype, value);
794- } else if (auto opt = val.TryAsBool ()) {
795- return IntImm (DataType::Int (32 ), opt.value ());
796- } else {
797- return NullOpt;
775+ struct PackedFuncValueConverter <PrimExpr> {
776+ static PrimExpr From (const TVMPODValue_& val) {
777+ if (val.type_code () == kTVMNullptr ) {
778+ return PrimExpr (ObjectPtr<Object>(nullptr ));
798779 }
799- }
800-
801- template <typename PODSubclass>
802- static tvm::IntImm From (const PODSubclass& val) {
803- if (auto opt = TryFrom (val)) {
804- return opt.value ();
805- } else {
806- return val.template AsObjectRef <tvm::IntImm>();
780+ if (val.type_code () == kDLInt ) {
781+ int64_t value = val.operator int64_t ();
782+ if (value > std::numeric_limits<int >::max () || value < std::numeric_limits<int >::min ()) {
783+ return IntImm (runtime::DataType::Int (64 ), value);
784+ }
785+ return IntImm (runtime::DataType::Int (32 ), val.operator int ());
807786 }
808- }
809- };
810-
811- template <>
812- struct PackedFuncValueConverter <tvm::Integer> {
813- template <typename PODSubclass>
814- static tvm::Integer From (const PODSubclass& val) {
815- if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom (val)) {
816- return Integer (opt.value ());
817- } else {
818- return val.template AsObjectRef <tvm::Integer>();
787+ if (val.type_code () == kDLFloat ) {
788+ return FloatImm (runtime::DataType::Float (32 ), val.operator double ());
819789 }
820- }
821- };
822790
823- template <>
824- struct PackedFuncValueConverter <tvm::Bool> {
825- template <typename PODSubclass>
826- static Optional<tvm::Bool> TryFrom (const PODSubclass& val) {
827- if (auto opt = val.TryAsBool ()) {
828- return tvm::Bool (opt.value ());
829- } else if (auto opt = val.TryAsInt ()) {
830- int value = opt.value ();
831- ICHECK (value == 0 || value == 1 )
832- << " ValueError: boolean value can only be 0 or 1, but get " << value;
833- return tvm::Bool (static_cast <bool >(value));
834- } else {
835- return NullOpt;
836- }
837- }
838-
839- template <typename PODSubclass>
840- static tvm::Bool From (const PODSubclass& val) {
841- if (auto opt = TryFrom (val)) {
842- return opt.value ();
843- } else {
844- return val.template AsObjectRef <tvm::Bool>();
845- }
791+ return PrimExpr::FromObject_ (val.AsObjectRef <ObjectRef>());
846792 }
847793};
848794
849795template <>
850- struct PackedFuncValueConverter <tvm::FloatImm> {
851- static Optional<tvm::FloatImm> TryFrom (const TVMPODValue_& val) {
852- if (auto opt = val.TryAsFloat ()) {
853- return FloatImm (runtime::DataType::Float (32 ), opt.value ());
854- } else {
855- return NullOpt;
796+ struct PackedFuncValueConverter <tvm::Integer> {
797+ static tvm::Integer From (const TVMPODValue_& val) {
798+ if (val.type_code () == kTVMNullptr ) {
799+ return Integer (ObjectPtr<Object>(nullptr ));
856800 }
857- }
858-
859- template <typename PODSubclass>
860- static tvm::FloatImm From (const PODSubclass& val) {
861- if (auto opt = TryFrom (val)) {
862- return opt.value ();
863- } else {
864- return val.template AsObjectRef <tvm::FloatImm>();
801+ if (val.type_code () == kTVMArgInt ) {
802+ return Integer (val.operator int ());
865803 }
804+ return val.AsObjectRef <tvm::Integer>();
866805 }
867806};
868807
869- /* \brief Backwards compatibility wrapper for IntImm arguments
870- *
871- * In previous versions of TVM, IntImm was the default FFI type for
872- * integer arguments, instead of runtime::Int. For backwards
873- * compatibility where the callee has been updated to expected a
874- * runtime::Int, the caller has not been updated to provide a
875- * runtime::Int (e.g. relay script parsing), and the auto-unboxing of
876- * runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
877- * allow the IntImm to be generated.
878- */
879808template <>
880- struct PackedFuncValueConverter <runtime::Int> {
881- template <typename PODSubclass>
882- static runtime::Int From (const PODSubclass& val) {
883- if (val.template IsObjectRef <tvm::IntImm>()) {
884- return runtime::Int (val.template AsObjectRef <tvm::IntImm>()->value );
885- } else {
886- return val.template AsObjectRef <runtime::Int>();
809+ struct PackedFuncValueConverter <tvm::Bool> {
810+ static tvm::Bool From (const TVMPODValue_& val) {
811+ if (val.type_code () == kTVMNullptr ) {
812+ return Bool (ObjectPtr<Object>(nullptr ));
813+ }
814+ if (val.type_code () == kTVMArgInt ) {
815+ int v = val.operator int ();
816+ ICHECK (v == 0 || v == 1 ) << " ValueError: boolean value can only be 0 or 1, but get " << v;
817+ return Bool (static_cast <bool >(v));
887818 }
819+ return val.AsObjectRef <tvm::Bool>();
888820 }
889821};
890822
0 commit comments