Skip to content
This repository was archived by the owner on Jul 22, 2024. It is now read-only.

Commit 2da917e

Browse files
Ability to use encoding.TextUnmarshaler
1 parent 95075d6 commit 2da917e

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

decode_hooks.go

+24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package mapstructure
22

33
import (
4+
"encoding"
45
"errors"
56
"fmt"
67
"net"
@@ -230,3 +231,26 @@ func RecursiveStructToMapHookFunc() DecodeHookFunc {
230231
return f.Interface(), nil
231232
}
232233
}
234+
235+
// TextUnmarshallerHookFunc returns a DecodeHookFunc that applies
236+
// strings to the UnmarshalText function, when the target type
237+
// implements the encoding.TextUnmarshaler interface
238+
func TextUnmarshallerHookFunc() DecodeHookFuncType {
239+
return func(
240+
f reflect.Type,
241+
t reflect.Type,
242+
data interface{}) (interface{}, error) {
243+
if f.Kind() != reflect.String {
244+
return data, nil
245+
}
246+
result := reflect.New(t).Interface()
247+
unmarshaller, ok := result.(encoding.TextUnmarshaler)
248+
if !ok {
249+
return data, nil
250+
}
251+
if err := unmarshaller.UnmarshalText([]byte(data.(string))); err != nil {
252+
return nil, err
253+
}
254+
return result, nil
255+
}
256+
}

decode_hooks_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package mapstructure
22

33
import (
44
"errors"
5+
"math/big"
56
"net"
67
"reflect"
78
"testing"
@@ -419,3 +420,28 @@ func TestStructToMapHookFuncTabled(t *testing.T) {
419420

420421
}
421422
}
423+
424+
func TestTextUnmarshallerHookFunc(t *testing.T) {
425+
cases := []struct {
426+
f, t reflect.Value
427+
result interface{}
428+
err bool
429+
}{
430+
{reflect.ValueOf("42"), reflect.ValueOf(big.Int{}), big.NewInt(42), false},
431+
{reflect.ValueOf("invalid"), reflect.ValueOf(big.Int{}), nil, true},
432+
{reflect.ValueOf("5"), reflect.ValueOf("5"), "5", false},
433+
}
434+
435+
for i, tc := range cases {
436+
f := TextUnmarshallerHookFunc()
437+
actual, err := DecodeHookExec(f, tc.f, tc.t)
438+
if tc.err != (err != nil) {
439+
t.Fatalf("case %d: expected err %#v", i, tc.err)
440+
}
441+
if !reflect.DeepEqual(actual, tc.result) {
442+
t.Fatalf(
443+
"case %d: expected %#v, got %#v",
444+
i, tc.result, actual)
445+
}
446+
}
447+
}

0 commit comments

Comments
 (0)