diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d974d7..00ed39f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - Add support for HJSON. #131 - Add new parse.Config to adjust parsing of varibles returned by a Resolve. #139 +- Add call to InitDefaults when map, primitives, or structs implement Initializer interface during Unpack. #104 ### Changed - Moved internal/parse to parse module. #139 diff --git a/initializer.go b/initializer.go new file mode 100644 index 0000000..3614f3f --- /dev/null +++ b/initializer.go @@ -0,0 +1,59 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package ucfg + +import ( + "reflect" +) + +// Initializer interface provides initialization of default values support to Unpack. +// The InitDefaults method will be executed for any type passed directly or indirectly to +// Unpack. +type Initializer interface { + InitDefaults() +} + +func tryInitDefaults(val reflect.Value) reflect.Value { + t := val.Type() + + var initializer Initializer + if t.Implements(iInitializer) { + initializer = val.Interface().(Initializer) + initializer.InitDefaults() + return val + } else if reflect.PtrTo(t).Implements(iInitializer) { + tmp := pointerize(reflect.PtrTo(t), t, val) + initializer = tmp.Interface().(Initializer) + initializer.InitDefaults() + + // Return the element in the pointer so the value is set into the + // field and not a pointer to the value. + return tmp.Elem() + } + return val +} + +func hasInitDefaults(t reflect.Type) bool { + if t.Implements(iInitializer) { + return true + } + if reflect.PtrTo(t).Implements(iInitializer) { + return true + } + return false +} diff --git a/initializer_test.go b/initializer_test.go new file mode 100644 index 0000000..5f25558 --- /dev/null +++ b/initializer_test.go @@ -0,0 +1,235 @@ +// Licensed to Elasticsearch B.V. under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Elasticsearch B.V. licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package ucfg + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type myIntInitializer int +type myMapInitializer map[string]string + +func (i *myIntInitializer) InitDefaults() { + *i = myIntInitializer(3) +} + +func (m *myMapInitializer) InitDefaults() { + (*m)["init"] = "defaults" +} + +type structInitializer struct { + I int + J int +} + +func (s *structInitializer) InitDefaults() { + s.J = 10 +} + +type structNoInitalizer struct { + I myIntInitializer +} + +type nestedStructInitializer struct { + M myMapInitializer + N structInitializer + O int + P structNoInitalizer +} + +func (n *nestedStructInitializer) InitDefaults() { + n.O = 20 + + // overridden by InitDefaults from structInitializer + n.N.J = 15 +} + +type ptrNestedStructInitializer struct { + M *myMapInitializer + N *structInitializer + O int + P *structNoInitalizer +} + +func (n *ptrNestedStructInitializer) InitDefaults() { + n.O = 20 +} + +func TestInitDefaultsPrimitive(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{}) + + // unpack S + r := &struct { + I myIntInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myIntInitializer(3), r.I) +} + +func TestInitDefaultsPrimitiveSet(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{ + "i": 25, + }) + + // unpack S + r := &struct { + I myIntInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myIntInitializer(25), r.I) +} + +func TestInitDefaultsMap(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{}) + + // unpack S + r := &struct { + M myMapInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myMapInitializer{ + "init": "defaults", + }, r.M) +} + +func TestInitDefaultsMapUpdate(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{ + "m": map[string]interface{}{ + "other": "config", + }, + }) + + // unpack S + r := &struct { + M myMapInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myMapInitializer{ + "init": "defaults", + "other": "config", + }, r.M) +} + +func TestInitDefaultsMapReplace(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{ + "m": map[string]interface{}{ + "init": "replace", + "other": "config", + }, + }) + + // unpack S + r := &struct { + M myMapInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myMapInitializer{ + "init": "replace", + "other": "config", + }, r.M) +} + +func TestInitDefaultsSingle(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{ + "s": map[string]interface{}{ + "i": 5, + }, + }) + + // unpack S + r := &struct { + S structInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, 5, r.S.I) + assert.Equal(t, 10, r.S.J) +} + +func TestInitDefaultsNested(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{ + "s": map[string]interface{}{ + "n": map[string]interface{}{ + "i": 5, + }, + }, + }) + + // unpack S + r := &struct { + S nestedStructInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myMapInitializer{ + "init": "defaults", + }, r.S.M) + assert.Equal(t, 5, r.S.N.I) + assert.Equal(t, 10, r.S.N.J) + assert.Equal(t, 20, r.S.O) + assert.Equal(t, myIntInitializer(3), r.S.P.I) +} + +func TestInitDefaultsNestedEmpty(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{}) + + // unpack S + r := &struct { + S nestedStructInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Equal(t, myMapInitializer{ + "init": "defaults", + }, r.S.M) + assert.Equal(t, 0, r.S.N.I) + assert.Equal(t, 10, r.S.N.J) + assert.Equal(t, 20, r.S.O) + assert.Equal(t, myIntInitializer(3), r.S.P.I) +} + +func TestInitDefaultsPtrNestedEmpty(t *testing.T) { + c, _ := NewFrom(map[string]interface{}{}) + + // unpack S + r := &struct { + S ptrNestedStructInitializer + }{} + + err := c.Unpack(r) + assert.NoError(t, err) + assert.Nil(t, r.S.M) + assert.Nil(t, r.S.N) + assert.Nil(t, r.S.P) + assert.Equal(t, 20, r.S.O) +} diff --git a/reify.go b/reify.go index 663cca7..994ba83 100644 --- a/reify.go +++ b/reify.go @@ -85,8 +85,19 @@ import ( // The struct tag options `replace`, `append`, and `prepend` overwrites the // global value merging strategy (e.g. ReplaceValues, AppendValues, ...) for all sub-fields. // +// When unpacking into a map, primitive, or struct Unpack will call InitDefaults if +// the type implements the Initializer interface. The Initializer interface is not supported +// on arrays or slices. InitDefaults is initialized top-down, meaning that if struct contains +// a map, struct, or primitive that also implements the Initializer interface the contained +// type will be initialized after the struct that contains it. (e.g. if we have +// type A struct { B B }, with both A, and B implementing InitDefaults, then A.InitDefaults +// is called before B.InitDefaults). In the case that a struct contains a pointer to +// a type that implements the Initializer interface and the configuration doesn't contain a +// value for that field then the pointer will not be initialized and InitDefaults will not +// be called. +// // Fields available in a struct or a map, but not in the Config object, will not -// be touched. Default values should be set in the target value before calling Unpack. +// be touched by Unpack unless they are initialized from InitDefaults. // // Type aliases like "type myTypeAlias T" are unpacked using Unpack if the alias // implements the Unpacker interface. Otherwise unpacking rules for type T will be used. @@ -180,6 +191,11 @@ func reifyMap(opts *options, to reflect.Value, from *Config) Error { return raiseKeyInvalidTypeUnpack(to.Type(), from) } + if to.IsNil() { + to.Set(reflect.MakeMap(to.Type())) + } + tryInitDefaults(to) + fields := from.fields.dict() if len(fields) == 0 { if err := tryValidate(to); err != nil { @@ -188,9 +204,6 @@ func reifyMap(opts *options, to reflect.Value, from *Config) Error { return nil } - if to.IsNil() { - to.Set(reflect.MakeMap(to.Type())) - } for k, value := range fields { opts.activeFields = newFieldSet(parentFields) key := reflect.ValueOf(k) @@ -235,6 +248,7 @@ func reifyStruct(opts *options, orig reflect.Value, cfg *Config) Error { return err } } else { + tryInitDefaults(to) numField := to.NumField() for i := 0; i < numField; i++ { stField := to.Type().Field(i) @@ -285,7 +299,7 @@ func reifyStruct(opts *options, orig reflect.Value, cfg *Config) Error { } else { name = fieldName(name, stField.Name) fopts := fieldOptions{opts: opts, tag: tagOpts, validators: validators} - if err := reifyGetField(cfg, fopts, name, vField); err != nil { + if err := reifyGetField(cfg, fopts, name, vField, stField.Type); err != nil { return err } } @@ -305,6 +319,7 @@ func reifyGetField( opts fieldOptions, name string, to reflect.Value, + fieldType reflect.Type, ) Error { p := parsePath(name, opts.opts.pathSep) value, err := p.GetValue(cfg, opts.opts) @@ -316,10 +331,28 @@ func reifyGetField( } if isNil(value) { - if err := runValidators(nil, opts.validators); err != nil { - return raiseValidation(cfg.ctx, cfg.metadata, name, err) + // When fieldType is a pointer and the value is nil, return nil as the + // underlying type should not be allocated. + if fieldType.Kind() == reflect.Ptr { + if err := runValidators(nil, opts.validators); err != nil { + return raiseValidation(cfg.ctx, cfg.metadata, name, err) + } + return nil + } + + // Primitive types return early when it doesn't implement the Initializer interface. + if fieldType.Kind() != reflect.Map && fieldType.Kind() != reflect.Struct && !hasInitDefaults(fieldType) { + if err := runValidators(nil, opts.validators); err != nil { + return raiseValidation(cfg.ctx, cfg.metadata, name, err) + } + return nil + } + + // None primitive types always get initialized even if it doesn't implement the + // Initializer interface, because nested types might implement the Initializer interface. + if value == nil { + value = &cfgNil{cfgPrimitive{cfg.ctx, cfg.metadata}} } - return nil } v, err := reifyMergeValue(opts, to, value) @@ -327,7 +360,7 @@ func reifyGetField( return err } - to.Set(v) + to.Set(pointerize(to.Type(), v.Type(), v)) return nil } @@ -627,7 +660,8 @@ func reifyPrimitive( ) (reflect.Value, Error) { // zero initialize value if val==nil if isNil(val) { - return pointerize(t, baseType, reflect.Zero(baseType)), nil + v := pointerize(t, baseType, reflect.Zero(baseType)) + return tryInitDefaults(v), nil } var v reflect.Value diff --git a/ucfg.go b/ucfg.go index 69f6eaa..bddd423 100644 --- a/ucfg.go +++ b/ucfg.go @@ -63,8 +63,9 @@ var ( tInterfaceArray = reflect.TypeOf([]interface{}(nil)) // interface types - tError = reflect.TypeOf((*error)(nil)).Elem() - tValidator = reflect.TypeOf((*Validator)(nil)).Elem() + tError = reflect.TypeOf((*error)(nil)).Elem() + iInitializer = reflect.TypeOf((*Initializer)(nil)).Elem() + tValidator = reflect.TypeOf((*Validator)(nil)).Elem() // primitives tBool = reflect.TypeOf(true) diff --git a/ucfg_test.go b/ucfg_test.go index 91bf52d..7518388 100644 --- a/ucfg_test.go +++ b/ucfg_test.go @@ -400,7 +400,7 @@ func TestRemove(t *testing.T) { }{ "exist": { cfg: map[string]interface{}{"field": "test"}, - wants: nil, + wants: map[string]interface{}{}, spec: spec{has: true, path: "field", idx: -1}, }, "unknown field": { diff --git a/unpack.go b/unpack.go index 85ba181..00fb92c 100644 --- a/unpack.go +++ b/unpack.go @@ -159,6 +159,11 @@ func implementsUnpacker(t reflect.Type) bool { } func unpackWith(opts *options, v reflect.Value, with value) Error { + // short circuit nil values + if isNil(with) { + return nil + } + ctx := with.Context() meta := with.meta() diff --git a/unpack_test.go b/unpack_test.go index 0ef66fa..6acf107 100644 --- a/unpack_test.go +++ b/unpack_test.go @@ -146,7 +146,9 @@ func TestReifyUnpackers(t *testing.T) { } // apply configurations - for _, c := range configs { + for i, c := range configs { + t.Logf("Unpacking config (%v): %#v", i, c) + cfg, err := NewFrom(c) if err != nil { t.Fatal(err) @@ -204,7 +206,9 @@ func TestReifyUnpackersPtr(t *testing.T) { } // apply configurations - for _, c := range configs { + for i, c := range configs { + t.Logf("Unpacking config (%v): %#v", i, c) + cfg, err := NewFrom(c) if err != nil { t.Fatal(err)