diff --git a/decode_test.go b/decode_test.go index 9269f12b..b05c466e 100644 --- a/decode_test.go +++ b/decode_test.go @@ -714,6 +714,14 @@ var unmarshalTests = []struct { "---\nhello\n...\n}not yaml", "hello", }, + { + "a: 5\n", + &struct{ A jsonNumberT }{"5"}, + }, + { + "a: 5.5\n", + &struct{ A jsonNumberT }{"5.5"}, + }, } type M map[interface{}]interface{} diff --git a/encode.go b/encode.go index a14435e8..0ee738e1 100644 --- a/encode.go +++ b/encode.go @@ -13,6 +13,19 @@ import ( "unicode/utf8" ) +// jsonNumber is the interface of the encoding/json.Number datatype. +// Repeating the interface here avoids a dependency on encoding/json, and also +// supports other libraries like jsoniter, which use a similar datatype with +// the same interface. Detecting this interface is useful when dealing with +// structures containing json.Number, which is a string under the hood. The +// encoder should prefer the use of Int64(), Float64() and string(), in that +// order, when encoding this type. +type jsonNumber interface { + Float64() (float64, error) + Int64() (int64, error) + String() string +} + type encoder struct { emitter yaml_emitter_t event yaml_event_t @@ -89,6 +102,21 @@ func (e *encoder) marshal(tag string, in reflect.Value) { } iface := in.Interface() switch m := iface.(type) { + case jsonNumber: + integer, err := m.Int64() + if err == nil { + // In this case the json.Number is a valid int64 + in = reflect.ValueOf(integer) + break + } + float, err := m.Float64() + if err == nil { + // In this case the json.Number is a valid float64 + in = reflect.ValueOf(float) + break + } + // fallback case - no number could be obtained + in = reflect.ValueOf(m.String()) case time.Time, *time.Time: // Although time.Time implements TextMarshaler, // we don't want to treat it as a string for YAML diff --git a/encode_test.go b/encode_test.go index f0911a76..4a266008 100644 --- a/encode_test.go +++ b/encode_test.go @@ -15,6 +15,24 @@ import ( "gopkg.in/yaml.v2" ) +type jsonNumberT string + +func (j jsonNumberT) Int64() (int64, error) { + val, err := strconv.Atoi(string(j)) + if err != nil { + return 0, err + } + return int64(val), nil +} + +func (j jsonNumberT) Float64() (float64, error) { + return strconv.ParseFloat(string(j), 64) +} + +func (j jsonNumberT) String() string { + return string(j) +} + var marshalIntTest = 123 var marshalTests = []struct { @@ -367,6 +385,18 @@ var marshalTests = []struct { map[string]string{"a": "你好 #comment"}, "a: '你好 #comment'\n", }, + { + map[string]interface{}{"a": jsonNumberT("5")}, + "a: 5\n", + }, + { + map[string]interface{}{"a": jsonNumberT("100.5")}, + "a: 100.5\n", + }, + { + map[string]interface{}{"a": jsonNumberT("bogus")}, + "a: bogus\n", + }, } func (s *S) TestMarshal(c *C) {