diff --git a/schema/elements.go b/schema/elements.go index 01ef4ba0..055e3e40 100644 --- a/schema/elements.go +++ b/schema/elements.go @@ -83,18 +83,15 @@ const ( // Struct represents a type which is composed of a number of different fields. // Each field has a name and a type. -// -// TODO: in the future, we will add one-of groups (sometimes called unions). type Struct struct { // Each struct field appears exactly once in this list. The order in // this list defines the canonical field ordering. Fields []StructField `yaml:"fields,omitempty"` - // TODO: Implement unions, either this way or by inlining. - // Unions are groupings of fields with special rules. They may refer to + // Union is a grouping of fields with special rules. It may refer to // one or more fields in the above list. A given field from the above // list may be referenced in exactly 0 or 1 places in the below list. - // Unions []Union `yaml:"unions,omitempty"` + Union *Union `yaml:"union,omitempty"` // ElementRelationship states the relationship between the struct's items. // * `separable` (or unset) implies that each element is 100% independent. @@ -108,6 +105,45 @@ type Struct struct { ElementRelationship ElementRelationship `yaml:"elementRelationship,omitempty"` } +// UnionFields are mapping between the fields that are part of the union and +// their discriminated value. The discriminated value has to be set, and +// should not conflict with other discriminated value in the list. +type UnionField struct { + // FieldName is the name of the field that is part of the union. This + // is the serialized form of the field. + FieldName string `yaml:"fieldName"` + // DiscriminatedBy is the value of the discriminator to select that + // field. If the union doesn't have a discriminator, this field is + // ignored. + DiscriminatedBy string `yaml:"discriminatedBy"` +} + +// Union, or oneof, means that only one of multiple fields of a structure can be +// set at a time. For backward compatibility reasons, and to help "dumb clients" +// which are not aware of the union (or can't be aware of it because they +// don't know what fields are part of the union), the code tolerates multiple +// fields to be set but will try to detect which fields must be cleared (there +// should never be more than two though): +// - If there is a discriminator and its value has changed, clear all fields +// but the one specified by the discriminator +// - If there is no discriminator, or it hasn't changed, if new has two of the +// fields set, remove the one that was set in old. +// - If there is a discriminator, set it to the value we've kept (if it changed) +type Union struct { + // Discriminator, if present, is the name of the field that + // discriminates fields in the union. The mapping between the value of + // the discriminator and the field is done by using the Fields list + // below. + Discriminator *string `yaml:"discriminator,omitempty"` + + // This is the list of fields that belong to this union. All the + // fields present in here have to be part of the parent + // structure. Discriminator (if oneOf has one), is NOT included in + // this list. The value for field is how we map the name of the field + // to actual value for discriminator. + Fields []UnionField `yaml:"fields,omitempty"` +} + // StructField pairs a field name with a field type. type StructField struct { // Name is the field name. diff --git a/schema/schemaschema.go b/schema/schemaschema.go index 628c5f86..0f72fba6 100644 --- a/schema/schemaschema.go +++ b/schema/schemaschema.go @@ -84,9 +84,35 @@ var SchemaSchemaYAML = `types: namedType: structField elementRelationship: associative keys: [ "name" ] + - name: union + type: + namedType: union - name: elementRelationship type: scalar: string +- name: unionField + struct: + fields: + - name: fieldName + type: + scalar: string + - name: discriminatedBy + type: + scalar: string +- name: union + struct: + fields: + - name: discriminator + type: + scalar: string + - name: fields + type: + list: + elementRelationship: associative + elementType: + namedType: unionField + keys: + - fieldName - name: structField struct: fields: diff --git a/typed/typed.go b/typed/typed.go index 475361ad..093fd76c 100644 --- a/typed/typed.go +++ b/typed/typed.go @@ -146,6 +146,38 @@ func (tv TypedValue) RemoveItems(items *fieldpath.Set) *TypedValue { return &tv } +// NormalizeUnions takes the new object and normalizes the union: +// - If there is a discriminator and its value has changed, clean all +// fields but the one specified by the discriminator +// - If there is no discriminator, or it hasn't changed, if new has two +// of the fields set, remove the one that was set in old. +// - If there is a discriminator, set it to the value we've kept (if it changed) +// +// This can fail if: +// - Multiple new fields are set, +// - The discriminator is changed, and at least one new field is set. +func (tv TypedValue) NormalizeUnions(new *TypedValue) (*TypedValue, error) { + var errs ValidationErrors + var normalizeFn = func(w *mergingWalker) { + if err := normalizeUnion(w); err != nil { + errs = append(errs, w.error(err)...) + } + } + out, mergeErrs := merge(&tv, new, func(w *mergingWalker) { + if w.rhs != nil { + v := *w.rhs + w.out = &v + } + }, normalizeFn) + if mergeErrs != nil { + errs = append(errs, mergeErrs.(ValidationErrors)...) + } + if len(errs) > 0 { + return nil, errs + } + return out, nil +} + func merge(lhs, rhs *TypedValue, rule, postRule mergeRule) (*TypedValue, error) { if lhs.schema != rhs.schema { return nil, errorFormatter{}. diff --git a/typed/union.go b/typed/union.go new file mode 100644 index 00000000..7684c222 --- /dev/null +++ b/typed/union.go @@ -0,0 +1,215 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package typed + +import ( + "fmt" + + "sigs.k8s.io/structured-merge-diff/schema" + "sigs.k8s.io/structured-merge-diff/value" +) + +func normalizeUnion(w *mergingWalker) error { + atom, found := w.schema.Resolve(w.typeRef) + if !found { + panic(fmt.Sprintf("Unable to resolve schema in normalize union: %v/%v", w.schema, w.typeRef)) + } + // Unions can only be in structures, and the struct must not have been removed + if atom.Struct == nil || atom.Struct.Union == nil || w.out == nil { + return nil + } + + old := &value.Map{} + if w.lhs != nil { + old = w.lhs.MapValue + } + return newUnion(atom.Struct.Union).Normalize(old, w.rhs.MapValue, w.out.MapValue) +} + +type discriminated string +type field string + +type discriminatedNames struct { + f2d map[field]discriminated + d2f map[discriminated]field +} + +func newDiscriminatedName(f2d map[field]discriminated) discriminatedNames { + d2f := map[discriminated]field{} + for key, value := range f2d { + d2f[value] = key + } + return discriminatedNames{ + f2d: f2d, + d2f: d2f, + } +} + +func (dn discriminatedNames) toField(d discriminated) field { + if f, ok := dn.d2f[d]; ok { + return f + } + return field(d) +} + +func (dn discriminatedNames) toDiscriminated(f field) discriminated { + if d, ok := dn.f2d[f]; ok { + return d + } + return discriminated(f) +} + +type discriminator struct { + name string +} + +func (d *discriminator) Set(m *value.Map, v discriminated) { + if d == nil { + return + } + m.Set(d.name, value.StringValue(string(v))) +} + +func (d *discriminator) Get(m *value.Map) discriminated { + if d == nil || m == nil { + return "" + } + f, ok := m.Get(d.name) + if !ok { + return "" + } + if f.Value.StringValue == nil { + return "" + } + return discriminated(*f.Value.StringValue) +} + +type fieldsSet map[field]struct{} + +// newFieldsSet returns a map of the fields that are part of the union and are set +// in the given map. +func newFieldsSet(m *value.Map, fields []field) fieldsSet { + if m == nil { + return nil + } + set := fieldsSet{} + for _, f := range fields { + if _, ok := m.Get(string(f)); ok { + set.Add(f) + } + } + return set +} + +func (fs fieldsSet) Add(f field) { + if fs == nil { + fs = map[field]struct{}{} + } + fs[f] = struct{}{} +} + +func (fs fieldsSet) One() *field { + for f := range fs { + return &f + } + return nil +} + +func (fs fieldsSet) Has(f field) bool { + _, ok := fs[f] + return ok +} + +func (fs fieldsSet) List() []field { + fields := []field{} + for f := range fs { + fields = append(fields, f) + } + return fields +} + +func (fs fieldsSet) Difference(o fieldsSet) fieldsSet { + n := fieldsSet{} + for f := range fs { + if !o.Has(f) { + n.Add(f) + } + } + return n +} + +type union struct { + d *discriminator + dn discriminatedNames + f []field +} + +func newUnion(su *schema.Union) *union { + u := &union{} + if su.Discriminator != nil { + u.d = &discriminator{name: *su.Discriminator} + } + f2d := map[field]discriminated{} + for _, f := range su.Fields { + u.f = append(u.f, field(f.FieldName)) + f2d[field(f.FieldName)] = discriminated(f.DiscriminatedBy) + } + u.dn = newDiscriminatedName(f2d) + return u +} + +// clear removes all the fields in map that are part of the union, but +// the one we decided to keep. +func (u *union) clear(m *value.Map, f field) { + for _, fieldName := range u.f { + if field(fieldName) != f { + m.Delete(string(fieldName)) + } + } +} + +func (u *union) Normalize(old, new, out *value.Map) error { + os := newFieldsSet(old, u.f) + ns := newFieldsSet(new, u.f) + diff := ns.Difference(os) + + if len(ns) > 1 && len(diff) != 1 { + return fmt.Errorf("unable to guess new discriminator: %v", diff) + } + + discriminator := field("") + if len(ns) == 1 { + discriminator = *ns.One() + } else if len(diff) == 1 { + discriminator = *diff.One() + } + + if u.d.Get(old) != u.d.Get(new) && u.d.Get(new) != "" { + if len(diff) == 1 && u.d.Get(new) != u.dn.toDiscriminated(discriminator) { + return fmt.Errorf("discriminator and field changed: %v/%v", discriminator, u.d.Get(new)) + } + u.clear(out, u.dn.toField(u.d.Get(new))) + return nil + } + + if discriminator != "" { + u.clear(out, discriminator) + u.d.Set(out, u.dn.toDiscriminated(discriminator)) + } + + return nil +} diff --git a/typed/union_test.go b/typed/union_test.go new file mode 100644 index 00000000..25b791d2 --- /dev/null +++ b/typed/union_test.go @@ -0,0 +1,300 @@ +/* +Copyright 2019 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package typed_test + +import ( + "testing" + + "sigs.k8s.io/structured-merge-diff/typed" +) + +var unionParser = func() typed.ParseableType { + parser, err := typed.NewParser(`types: +- name: union + struct: + fields: + - name: discriminator + type: + scalar: string + - name: one + type: + scalar: numeric + - name: two + type: + scalar: numeric + - name: three + type: + scalar: numeric + - name: nodisc + type: + namedType: nondiscriminated + union: + discriminator: discriminator + fields: + - fieldName: one + discriminatedBy: One + - fieldName: two + discriminatedBy: TWO + - fieldName: three + discriminatedBy: three +- name: nondiscriminated + struct: + fields: + - name: a + type: + scalar: numeric + - name: b + type: + scalar: numeric + - name: c + type: + scalar: numeric + union: + fields: + - fieldName: a + discriminatedBy: A + - fieldName: b + discriminatedBy: B + - fieldName: c + discriminatedBy: C]`) + if err != nil { + panic(err) + } + return parser.Type("union") +}() + +func TestNormalizeUnions(t *testing.T) { + tests := []struct { + name string + old typed.YAMLObject + new typed.YAMLObject + out typed.YAMLObject + }{ + { + name: "nothing changed, add discriminator", + old: `{"one": 1}`, + new: `{"one": 1}`, + out: `{"one": 1, "discriminator": "One"}`, + }, + { + name: "proper union update, setting discriminator", + old: `{"one": 1}`, + new: `{"two": 1}`, + out: `{"two": 1, "discriminator": "TWO"}`, + }, + { + name: "proper union update, no discriminator", + old: `{"nodisc": {"a": 1}}`, + new: `{"nodisc": {"b": 1}}`, + out: `{"nodisc": {"b": 1}}`, + }, + { + name: "proper union update from not-set, setting discriminator", + old: `{}`, + new: `{"two": 1}`, + out: `{"two": 1, "discriminator": "TWO"}`, + }, + { + name: "proper union update from not-set, no discriminator", + old: `{}`, + new: `{"nodisc": {"b": 1}}`, + out: `{"nodisc": {"b": 1}}`, + }, + { + name: "remove union, with discriminator", + old: `{"one": 1}`, + new: `{}`, + out: `{}`, + }, + { + name: "remove union and discriminator", + old: `{"one": 1, "discriminator": "One"}`, + new: `{}`, + out: `{}`, + }, + { + name: "remove union, not discriminator", + old: `{"one": 1, "discriminator": "One"}`, + new: `{"discriminator": "One"}`, + out: `{"discriminator": "One"}`, + }, + { + name: "remove union, no discriminator", + old: `{"nodisc": {"b": 1}}`, + new: `{}`, + out: `{}`, + }, + { + name: "dumb client update, no discriminator", + old: `{"nodisc": {"a": 1}}`, + new: `{"nodisc": {"a": 2, "b": 1}}`, + out: `{"nodisc": {"b": 1}}`, + }, + { + name: "dumb client update, sets discriminator", + old: `{"one": 1}`, + new: `{"one": 2, "two": 1}`, + out: `{"two": 1, "discriminator": "TWO"}`, + }, + { + name: "dumb client doesn't update discriminator", + old: `{"one": 1, "discriminator": "One"}`, + new: `{"one": 2, "two": 1, "discriminator": "One"}`, + out: `{"two": 1, "discriminator": "TWO"}`, + }, + { + name: "multi-discriminator at the same time", + old: `{"one": 1, "nodisc": {"a": 1}}`, + new: `{"one": 1, "three": 1, "nodisc": {"a": 1, "b": 1}}`, + out: `{"three": 1, "discriminator": "three", "nodisc": {"b": 1}}`, + }, + { + name: "change discriminator, nothing else", + old: `{"discriminator": "One"}`, + new: `{"discriminator": "random"}`, + out: `{"discriminator": "random"}`, + }, + { + name: "change discriminator, nothing else, it drops other field", + old: `{"discriminator": "One", "one": 1}`, + new: `{"discriminator": "random", "one": 1}`, + out: `{"discriminator": "random"}`, + }, + { + name: "remove discriminator, nothing else", + old: `{"discriminator": "One", "one": 1}`, + new: `{"one": 1}`, + out: `{"one": 1, "discriminator": "One"}`, + }, + { + name: "remove discriminator, add new field", + old: `{"discriminator": "One", "one": 1}`, + new: `{"two": 1}`, + out: `{"two": 1, "discriminator": "TWO"}`, + }, + { + name: "both fields removed", + old: `{"one": 1, "two": 1}`, + new: `{}`, + out: `{}`, + }, + { + name: "one field removed", + old: `{"one": 1, "two": 1}`, + new: `{"one": 1}`, + out: `{"one": 1, "discriminator": "One"}`, + }, + // These use-cases shouldn't happen: + { + name: "one field removed, discriminator unchanged", + old: `{"one": 1, "two": 1, "discriminator": "TWO"}`, + new: `{"one": 1, "discriminator": "TWO"}`, + out: `{"one": 1, "discriminator": "One"}`, + }, + { + name: "one field removed, discriminator added", + old: `{"two": 2, "one": 1}`, + new: `{"one": 1, "discriminator": "TWO"}`, + out: `{"discriminator": "TWO"}`, + }, + { + name: "old object has two of same union, but we add third", + old: `{"discriminator": "One", "one": 1, "two": 1}`, + new: `{"discriminator": "One", "one": 1, "two": 1, "three": 1}`, + out: `{"discriminator": "three", "three": 1}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + old, err := unionParser.FromYAML(test.old) + if err != nil { + t.Fatalf("Failed to parse old object: %v", err) + } + new, err := unionParser.FromYAML(test.new) + if err != nil { + t.Fatalf("failed to parse new object: %v", err) + } + out, err := unionParser.FromYAML(test.out) + if err != nil { + t.Fatalf("failed to parse out object: %v", err) + } + got, err := old.NormalizeUnions(new) + if err != nil { + t.Fatalf("failed to normalize unions: %v", err) + } + comparison, err := out.Compare(got) + if err != nil { + t.Fatalf("failed to compare result and expected: %v", err) + } + if !comparison.IsSame() { + t.Errorf("Result is different from expected:\n%v", comparison) + } + }) + } +} + +func TestNormalizeUnionError(t *testing.T) { + tests := []struct { + name string + old typed.YAMLObject + new typed.YAMLObject + }{ + { + name: "new object has three of same union set", + old: `{"one": 1}`, + new: `{"one": 2, "two": 1, "three": 3}`, + }, + { + name: "client sends new field that and discriminator change", + old: `{}`, + new: `{"one": 1, "discriminator": "Two"}`, + }, + { + name: "client sends new fields that don't match discriminator change", + old: `{}`, + new: `{"one": 1, "two": 1, "discriminator": "One"}`, + }, + { + name: "old object has two of same union set", + old: `{"one": 1, "two": 2}`, + new: `{"one": 2, "two": 1}`, + }, + { + name: "one field removed, 2 left, discriminator unchanged", + old: `{"one": 1, "two": 1, "three": 1, "discriminator": "TWO"}`, + new: `{"one": 1, "two": 1, "discriminator": "TWO"}`, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + old, err := unionParser.FromYAML(test.old) + if err != nil { + t.Fatalf("Failed to parse old object: %v", err) + } + new, err := unionParser.FromYAML(test.new) + if err != nil { + t.Fatalf("failed to parse new object: %v", err) + } + _, err = old.NormalizeUnions(new) + if err == nil { + t.Fatal("Normalization should have failed, but hasn't.") + } + }) + } +} diff --git a/value/value.go b/value/value.go index a5dbc5f4..d1c6d344 100644 --- a/value/value.go +++ b/value/value.go @@ -84,6 +84,18 @@ func (m *Map) Set(key string, value Value) { m.index = nil // Since the append might have reallocated } +// Delete removes the key from the set. +func (m *Map) Delete(key string) { + items := []Field{} + for i := range m.Items { + if m.Items[i].Name != key { + items = append(items, m.Items[i]) + } + } + m.Items = items + m.index = nil // Since the list has changed +} + // StringValue returns s as a scalar string Value. func StringValue(s string) Value { s2 := String(s)