Skip to content

Commit 26f8e13

Browse files
committed
float: return error when marshaling NaN or Inf
1 parent 5d2d1e5 commit 26f8e13

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

float.go

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"encoding/json"
66
"fmt"
7+
"math"
78
"reflect"
89
"strconv"
910
)
@@ -91,6 +92,12 @@ func (f Float) MarshalJSON() ([]byte, error) {
9192
if !f.Valid {
9293
return []byte("null"), nil
9394
}
95+
if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) {
96+
return nil, &json.UnsupportedValueError{
97+
Value: reflect.ValueOf(f.Float64),
98+
Str: strconv.FormatFloat(f.Float64, 'g', -1, 64),
99+
}
100+
}
94101
return []byte(strconv.FormatFloat(f.Float64, 'f', -1, 64)), nil
95102
}
96103

float_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package null
22

33
import (
44
"encoding/json"
5+
"math"
56
"testing"
67
)
78

@@ -170,6 +171,20 @@ func TestFloatScan(t *testing.T) {
170171
assertNullFloat(t, null, "scanned null")
171172
}
172173

174+
func TestFloatInfNaN(t *testing.T) {
175+
nan := NewFloat(math.NaN(), true)
176+
_, err := nan.MarshalJSON()
177+
if err == nil {
178+
t.Error("expected error for NaN, got nil")
179+
}
180+
181+
inf := NewFloat(math.Inf(1), true)
182+
_, err = inf.MarshalJSON()
183+
if err == nil {
184+
t.Error("expected error for Inf, got nil")
185+
}
186+
}
187+
173188
func assertFloat(t *testing.T, f Float, from string) {
174189
if f.Float64 != 1.2345 {
175190
t.Errorf("bad %s float: %f ≠ %f\n", from, f.Float64, 1.2345)

zero/float.go

+7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"database/sql"
55
"encoding/json"
66
"fmt"
7+
"math"
78
"reflect"
89
"strconv"
910
)
@@ -92,6 +93,12 @@ func (f Float) MarshalJSON() ([]byte, error) {
9293
if !f.Valid {
9394
n = 0
9495
}
96+
if math.IsInf(f.Float64, 0) || math.IsNaN(f.Float64) {
97+
return nil, &json.UnsupportedValueError{
98+
Value: reflect.ValueOf(f.Float64),
99+
Str: strconv.FormatFloat(f.Float64, 'g', -1, 64),
100+
}
101+
}
95102
return []byte(strconv.FormatFloat(n, 'f', -1, 64)), nil
96103
}
97104

zero/float_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package zero
22

33
import (
44
"encoding/json"
5+
"math"
56
"testing"
67
)
78

@@ -176,6 +177,20 @@ func TestFloatScan(t *testing.T) {
176177
assertNullFloat(t, null, "scanned null")
177178
}
178179

180+
func TestFloatInfNaN(t *testing.T) {
181+
nan := NewFloat(math.NaN(), true)
182+
_, err := nan.MarshalJSON()
183+
if err == nil {
184+
t.Error("expected error for NaN, got nil")
185+
}
186+
187+
inf := NewFloat(math.Inf(1), true)
188+
_, err = inf.MarshalJSON()
189+
if err == nil {
190+
t.Error("expected error for Inf, got nil")
191+
}
192+
}
193+
179194
func assertFloat(t *testing.T, f Float, from string) {
180195
if f.Float64 != 1.2345 {
181196
t.Errorf("bad %s float: %f ≠ %f\n", from, f.Float64, 1.2345)

0 commit comments

Comments
 (0)