Skip to content

Commit 1a271f0

Browse files
committed
[FFI] Make JSON Parser/Write fastmath safe (apache#18212)
This PR adds fallbacks for nan and inf detection/creation under fastmath mode.
1 parent 03e8a6b commit 1a271f0

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

src/ffi/extra/json_parser.cc

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ class JSONParserContext {
167167
++cur_;
168168
if (cur_ != end_ && *cur_ == 'I') {
169169
if (this->MatchLiteral("Infinity", 8)) {
170-
*out = -std::numeric_limits<double>::infinity();
170+
*out = FastMathSafeNegInf();
171171
return true;
172172
} else {
173173
this->SetCurrentPosForBetterErrorMsg(start_pos);
@@ -177,7 +177,7 @@ class JSONParserContext {
177177
}
178178
} else if (*cur_ == 'I') {
179179
if (this->MatchLiteral("Infinity", 8)) {
180-
*out = std::numeric_limits<double>::infinity();
180+
*out = FastMathSafePosInf();
181181
return true;
182182
} else {
183183
this->SetCurrentPosForBetterErrorMsg(start_pos);
@@ -186,7 +186,7 @@ class JSONParserContext {
186186
}
187187
} else if (*cur_ == 'N') {
188188
if (this->MatchLiteral("NaN", 3)) {
189-
*out = std::numeric_limits<double>::quiet_NaN();
189+
*out = FastMathSafeNaN();
190190
return true;
191191
} else {
192192
this->SetCurrentPosForBetterErrorMsg(start_pos);
@@ -296,6 +296,33 @@ class JSONParserContext {
296296
void SetErrorExpectingComma() { error_msg_ = GetSyntaxErrorContext("Expecting \',\' delimiter"); }
297297

298298
private:
299+
static double FastMathSafePosInf() {
300+
#ifdef __FAST_MATH__
301+
const uint64_t inf_bits = 0x7FF0000000000000ULL;
302+
return *reinterpret_cast<const double*>(&inf_bits);
303+
#else
304+
return std::numeric_limits<double>::infinity();
305+
#endif
306+
}
307+
308+
static double FastMathSafeNegInf() {
309+
#ifdef __FAST_MATH__
310+
const uint64_t inf_bits = 0xFFF0000000000000ULL;
311+
return *reinterpret_cast<const double*>(&inf_bits);
312+
#else
313+
return -std::numeric_limits<double>::infinity();
314+
#endif
315+
}
316+
317+
static double FastMathSafeNaN() {
318+
#ifdef __FAST_MATH__
319+
const uint64_t nan_bits = 0x7FF8000000000000ULL;
320+
return *reinterpret_cast<const double*>(&nan_bits);
321+
#else
322+
return std::numeric_limits<double>::quiet_NaN();
323+
#endif
324+
}
325+
299326
// Full string parsing with escape and unicode handling
300327
bool NextStringWithFullHandling(Any* out, const char* start_pos) {
301328
// copy over the prefix that was already parsed

src/ffi/extra/json_writer.cc

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,37 @@ class JSONWriter {
6060
private:
6161
explicit JSONWriter(int indent) : indent_(indent), out_iter_(result_) {}
6262

63+
static bool FastMathSafeIsNaN(double x) {
64+
#ifdef __FAST_MATH__
65+
// Bit-level NaN detection (IEEE 754 double)
66+
// IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754
67+
// NaN is encoded as all 1s in the exponent and non-zero in the mantissa
68+
static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size");
69+
uint64_t bits = *reinterpret_cast<const uint64_t*>(&x);
70+
uint64_t exponent = (bits >> 52) & 0x7FF;
71+
uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull;
72+
return (exponent == 0x7FF) && (mantissa != 0);
73+
#else
74+
// Safe to use std::isnan when fast-math is off
75+
return std::isnan(x);
76+
#endif
77+
}
78+
79+
static bool FastMathSafeIsInf(double x) {
80+
#ifdef __FAST_MATH__
81+
// IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754
82+
// Inf is encoded as all 1s in the exponent and zero in the mantissa
83+
static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size");
84+
uint64_t bits = *reinterpret_cast<const uint64_t*>(&x);
85+
uint64_t exponent = (bits >> 52) & 0x7FF;
86+
uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull;
87+
// inf is encoded as all 1s in the exponent and zero in the mantissa
88+
return (exponent == 0x7FF) && (mantissa == 0);
89+
#else
90+
return std::isinf(x);
91+
#endif
92+
}
93+
6394
void WriteValue(const json::Value& value) {
6495
switch (value.type_index()) {
6596
case TypeIndex::kTVMFFINone: {
@@ -120,9 +151,9 @@ class JSONWriter {
120151
// largest possible string representation of a double is around 24 chars plus
121152
// one null terminator keep 32 to be safe
122153
char buffer[32];
123-
if (std::isnan(value)) {
154+
if (FastMathSafeIsNaN(value)) {
124155
WriteLiteral("NaN", 3);
125-
} else if (std::isinf(value)) {
156+
} else if (FastMathSafeIsInf(value)) {
126157
if (value < 0) {
127158
WriteLiteral("-Infinity", 9);
128159
} else {

tests/cpp/extra/test_json_parser.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,37 @@ namespace {
2828

2929
using namespace tvm::ffi;
3030

31+
inline bool FastMathSafeIsNaN(double x) {
32+
#ifdef __FAST_MATH__
33+
// Bit-level NaN detection (IEEE 754 double)
34+
// IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754
35+
// NaN is encoded as all 1s in the exponent and non-zero in the mantissa
36+
static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size");
37+
uint64_t bits = *reinterpret_cast<const uint64_t*>(&x);
38+
uint64_t exponent = (bits >> 52) & 0x7FF;
39+
uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull;
40+
return (exponent == 0x7FF) && (mantissa != 0);
41+
#else
42+
// Safe to use std::isnan when fast-math is off
43+
return std::isnan(x);
44+
#endif
45+
}
46+
47+
inline bool FastMathSafeIsInf(double x) {
48+
#ifdef __FAST_MATH__
49+
// IEEE 754 standard: https://en.wikipedia.org/wiki/IEEE_754
50+
// Inf is encoded as all 1s in the exponent and zero in the mantissa
51+
static_assert(sizeof(double) == sizeof(uint64_t), "Unexpected double size");
52+
uint64_t bits = *reinterpret_cast<const uint64_t*>(&x);
53+
uint64_t exponent = (bits >> 52) & 0x7FF;
54+
uint64_t mantissa = bits & 0xFFFFFFFFFFFFFull;
55+
// inf is encoded as all 1s in the exponent and zero in the mantissa
56+
return (exponent == 0x7FF) && (mantissa == 0);
57+
#else
58+
return std::isinf(x);
59+
#endif
60+
}
61+
3162
TEST(JSONParser, BoolNull) {
3263
// boolean value
3364
EXPECT_EQ(json::Parse("true").cast<bool>(), true);
@@ -61,11 +92,11 @@ TEST(JSONParser, Number) {
6192
// parsing scientific notation
6293
EXPECT_EQ(json::Parse("1.456e12").cast<double>(), 1.456e12);
6394
// NaN
64-
EXPECT_EQ(std::isnan(json::Parse("NaN").cast<double>()), true);
95+
EXPECT_EQ(FastMathSafeIsNaN(json::Parse("NaN").cast<double>()), true);
6596
// Infinity
66-
EXPECT_EQ(std::isinf(json::Parse("Infinity").cast<double>()), true);
97+
EXPECT_EQ(FastMathSafeIsInf(json::Parse("Infinity").cast<double>()), true);
6798
// -Infinity
68-
EXPECT_EQ(std::isinf(-json::Parse("-Infinity").cast<double>()), true);
99+
EXPECT_EQ(FastMathSafeIsInf(-json::Parse("-Infinity").cast<double>()), true);
69100

70101
// Test zero variants
71102
EXPECT_EQ(json::Parse("0").cast<int64_t>(), 0);

0 commit comments

Comments
 (0)