Skip to content

Commit c655b0f

Browse files
committed
- preserve legacy print behavior for unrecognised dtypes.
1 parent 71a8c2f commit c655b0f

File tree

4 files changed

+38
-23
lines changed

4 files changed

+38
-23
lines changed

src/printer/relay_text_printer.cc

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,14 +348,15 @@ Doc RelayTextPrinter::VisitExpr_(const VarNode* op) { return AllocVar(GetRef<Var
348348

349349
Doc RelayTextPrinter::VisitExpr_(const ConstantNode* op) {
350350
// Print out simple scalars directly.
351-
if (op->is_scalar()) {
351+
if (support::IsSimpleScalar(op)) {
352352
return Doc::Text(support::NDArrayScalarToString(op->data));
353353
}
354-
// default fall-back, record it as meta node.
354+
// Fallbock: record it as a meta node.
355355
Doc doc;
356356
// Don't append optional_info. Because the entry function is Print,
357357
// and it will append the optional_info afterwards.
358-
return doc << PrintExpr(GetRef<Expr>(op), true, false, false);
358+
return doc << PrintExpr(GetRef<Expr>(op), /*meta=*/true, /*try_inline=*/false,
359+
/*optional_info=*/false);
359360
}
360361

361362
Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
@@ -772,11 +773,21 @@ Doc RelayTextPrinter::VisitAttr_(const ArrayNode* op) {
772773
}
773774

774775
Doc RelayTextPrinter::VisitAttr_(const tir::IntImmNode* op) {
775-
return Doc::Text(support::IntImmToString(GetRef<IntImm>(op)));
776+
if (support::IsSimpleScalarDtype(op->dtype)) {
777+
return Doc::Text(support::IntImmToString(GetRef<IntImm>(op)));
778+
} else {
779+
// Fallback: Print int64_t without width suffix.
780+
return Doc::Text(std::to_string(op->value));
781+
}
776782
}
777783

778784
Doc RelayTextPrinter::VisitAttr_(const tir::FloatImmNode* op) {
779-
return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op)));
785+
if (support::IsSimpleScalarDtype(op->dtype)) {
786+
return Doc::Text(support::FloatImmToString(GetRef<FloatImm>(op)));
787+
} else {
788+
// Fallbock: Print double without width suffix.
789+
return Doc::Text(std::to_string(op->value));
790+
}
780791
}
781792

782793
Doc RelayTextPrinter::VisitAttr_(const tir::StringImmNode* op) {

src/support/scalars.cc

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,27 +38,27 @@ static const DataType kFloat32 = DataType::Float(32);
3838
static const DataType kFloat64 = DataType::Float(64);
3939
static const DataType kBool = DataType::Bool();
4040

41-
bool IsSimpleScalar(const relay::ConstantNode* constant_node) {
42-
if (!constant_node->is_scalar()) {
43-
return false;
44-
}
45-
DataType dtype(constant_node->data->dtype);
41+
bool IsSimpleScalarDtype(DataType dtype) {
4642
return dtype == kInt16 || dtype == kInt32 || dtype == kInt64 || dtype == kFloat16 ||
4743
dtype == kFloat32 || dtype == kFloat64 || dtype == kBool;
4844
}
4945

46+
bool IsSimpleScalar(const relay::ConstantNode* constant_node) {
47+
return constant_node->is_scalar() && IsSimpleScalarDtype(DataType(constant_node->data->dtype));
48+
}
49+
5050
runtime::NDArray IntImmToNDArray(const IntImm& int_imm) {
5151
DLDevice dev = {DLDeviceType::kDLCPU, 0};
5252
auto data = runtime::NDArray::Empty({}, int_imm->dtype, dev);
53-
if (int_imm.dtype() == kInt64) {
54-
auto array = reinterpret_cast<int64_t*>(data->data);
55-
array[0] = int_imm->value;
53+
if (int_imm.dtype() == kInt16) {
54+
auto array = reinterpret_cast<int16_t*>(data->data);
55+
array[0] = static_cast<int16_t>(int_imm->value);
5656
} else if (int_imm.dtype() == kInt32) {
5757
auto array = reinterpret_cast<int32_t*>(data->data);
5858
array[0] = static_cast<int32_t>(int_imm->value);
59-
} else if (int_imm.dtype() == kInt16) {
60-
auto array = reinterpret_cast<int16_t*>(data->data);
61-
array[0] = static_cast<int16_t>(int_imm->value);
59+
} else if (int_imm.dtype() == kInt64) {
60+
auto array = reinterpret_cast<int64_t*>(data->data);
61+
array[0] = int_imm->value;
6262
} else {
6363
LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(int_imm.dtype());
6464
}
@@ -68,15 +68,15 @@ runtime::NDArray IntImmToNDArray(const IntImm& int_imm) {
6868
runtime::NDArray FloatImmToNDArray(const FloatImm& float_imm) {
6969
DLDevice dev = {DLDeviceType::kDLCPU, 0};
7070
auto data = runtime::NDArray::Empty({}, float_imm->dtype, dev);
71-
if (float_imm.dtype() == kFloat64) {
72-
auto array = reinterpret_cast<double*>(data->data);
73-
array[0] = float_imm->value;
71+
if (float_imm.dtype() == kFloat16) {
72+
auto array = reinterpret_cast<uint16_t*>(data->data);
73+
array[0] = __gnu_f2h_ieee(static_cast<float>(float_imm->value));
7474
} else if (float_imm.dtype() == kFloat32) {
7575
auto array = reinterpret_cast<float*>(data->data);
7676
array[0] = static_cast<float>(float_imm->value);
77-
} else if (float_imm.dtype() == kFloat16) {
78-
auto array = reinterpret_cast<uint16_t*>(data->data);
79-
array[0] = __gnu_f2h_ieee(static_cast<float>(float_imm->value));
77+
} else if (float_imm.dtype() == kFloat64) {
78+
auto array = reinterpret_cast<double*>(data->data);
79+
array[0] = float_imm->value;
8080
} else {
8181
LOG(FATAL) << "Unrecognized numeric literal dtype: " << DLDataType2String(float_imm.dtype());
8282
}

src/support/scalars.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
namespace tvm {
3636
namespace support {
3737

38+
/*! \brief Returns true if a tensor of empty shape and given dtype is considered a Relay scalar. */
39+
bool IsSimpleScalarDtype(DataType dtype);
40+
3841
/*! \brief Returns true if \p constant_node is a float/int/bool scalar. */
3942
bool IsSimpleScalar(const relay::ConstantNode* constant_node);
4043

tests/python/relay/test_target_hooks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,5 @@ def test_runtime_module_generation(check_result):
7373

7474

7575
if __name__ == "__main__":
76-
sys.exit(pytest.main([__file__] + sys.argv[1:]))
76+
# sys.exit(pytest.main([__file__] + sys.argv[1:]))
77+
test_runtime_module_generation(check_aot_executor_result)

0 commit comments

Comments
 (0)