@@ -38,27 +38,27 @@ static const DataType kFloat32 = DataType::Float(32);
3838static const DataType kFloat64 = DataType::Float(64 );
3939static const DataType kBool = DataType::Bool();
4040
41- bool IsSimpleScalar (const relay::ConstantNode* constant_node) {
42- if (!constant_node->is_scalar ()) {
43- return false ;
44- }
45- DataType dtype (constant_node->data ->dtype );
41+ bool IsSimpleScalarDtype (DataType dtype) {
4642 return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 ||
4743 dtype == kFloat32 || dtype == kFloat64 || dtype == kBool ;
4844}
4945
46+ bool IsSimpleScalar (const relay::ConstantNode* constant_node) {
47+ return constant_node->is_scalar () && IsSimpleScalarDtype (DataType (constant_node->data ->dtype ));
48+ }
49+
5050runtime::NDArray IntImmToNDArray (const IntImm& int_imm) {
5151 DLDevice dev = {DLDeviceType::kDLCPU , 0 };
5252 auto data = runtime::NDArray::Empty ({}, int_imm->dtype , dev);
53- if (int_imm.dtype () == kInt64 ) {
54- auto array = reinterpret_cast <int64_t *>(data->data );
55- array[0 ] = int_imm->value ;
53+ if (int_imm.dtype () == kInt16 ) {
54+ auto array = reinterpret_cast <int16_t *>(data->data );
55+ array[0 ] = static_cast < int16_t >( int_imm->value ) ;
5656 } else if (int_imm.dtype () == kInt32 ) {
5757 auto array = reinterpret_cast <int32_t *>(data->data );
5858 array[0 ] = static_cast <int32_t >(int_imm->value );
59- } else if (int_imm.dtype () == kInt16 ) {
60- auto array = reinterpret_cast <int16_t *>(data->data );
61- array[0 ] = static_cast < int16_t >( int_imm->value ) ;
59+ } else if (int_imm.dtype () == kInt64 ) {
60+ auto array = reinterpret_cast <int64_t *>(data->data );
61+ array[0 ] = int_imm->value ;
6262 } else {
6363 LOG (FATAL) << " Unrecognized numeric literal dtype: " << DLDataType2String (int_imm.dtype ());
6464 }
@@ -68,15 +68,15 @@ runtime::NDArray IntImmToNDArray(const IntImm& int_imm) {
6868runtime::NDArray FloatImmToNDArray (const FloatImm& float_imm) {
6969 DLDevice dev = {DLDeviceType::kDLCPU , 0 };
7070 auto data = runtime::NDArray::Empty ({}, float_imm->dtype , dev);
71- if (float_imm.dtype () == kFloat64 ) {
72- auto array = reinterpret_cast <double *>(data->data );
73- array[0 ] = float_imm->value ;
71+ if (float_imm.dtype () == kFloat16 ) {
72+ auto array = reinterpret_cast <uint16_t *>(data->data );
73+ array[0 ] = __gnu_f2h_ieee ( static_cast < float >( float_imm->value )) ;
7474 } else if (float_imm.dtype () == kFloat32 ) {
7575 auto array = reinterpret_cast <float *>(data->data );
7676 array[0 ] = static_cast <float >(float_imm->value );
77- } else if (float_imm.dtype () == kFloat16 ) {
78- auto array = reinterpret_cast <uint16_t *>(data->data );
79- array[0 ] = __gnu_f2h_ieee ( static_cast < float >( float_imm->value )) ;
77+ } else if (float_imm.dtype () == kFloat64 ) {
78+ auto array = reinterpret_cast <double *>(data->data );
79+ array[0 ] = float_imm->value ;
8080 } else {
8181 LOG (FATAL) << " Unrecognized numeric literal dtype: " << DLDataType2String (float_imm.dtype ());
8282 }
0 commit comments