Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions field.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ type Field struct {
IsNullable bool
// IsSensitive is field sensitive? (password, token)
IsSensitive bool
// IsEmbed is field embedded?
IsEmbed bool
// Validators represents the array of field validators for a field
Validators []string
// PlanModifiers represents the array of plan modifiers for a field
Expand Down Expand Up @@ -189,6 +191,7 @@ func BuildField(c *FieldBuildContext) (*Field, error) {
IsRepeated: c.IsRepeated(),
IsMap: c.IsMap(),
IsNullable: c.GetNullable(),
IsEmbed: c.IsEmbed(),
Validators: c.GetValidators(),
PlanModifiers: c.GetPlanModifiers(),
Path: c.GetPath(),
Expand Down
17 changes: 17 additions & 0 deletions field_build_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ type FieldBuildContext struct {
imports *Imports
path string
goType string
isEmbed bool
}

// NewFieldBuildContext creates FieldBuildContext
Expand Down Expand Up @@ -127,6 +128,7 @@ func NewFieldBuildContext(m MessageBuildContext, field *FieldDescriptorProtoExt,
imports: m.imports,
path: path,
goType: m.imports.PrependPackageNameIfMissing(t, m.config.DefaultPackageName),
isEmbed: field.IsEmbed(),
}

return c, nil
Expand Down Expand Up @@ -166,6 +168,16 @@ func (c *FieldBuildContext) GetNameWithTypeName() string {

// GetName returns field name
func (c *FieldBuildContext) GetName() string {
if c.IsEmbed() {
// Return the name of the struct with no prepended package or pointer.
goType := strings.TrimPrefix(c.GetGoType(), "*")
goTypeSplit := strings.SplitN(goType, ".", 2)
if len(goTypeSplit) == 2 {
return goTypeSplit[1]
}
return goType
}

name := c.field.GetName()
if name[0:1] == strings.ToLower(name[0:1]) {
return strcase.UpperCamelCase(name)
Expand Down Expand Up @@ -316,6 +328,11 @@ func (c *FieldBuildContext) IsCastType() bool {
return c.field.IsCastType()
}

// IsEmbed returns true if the field has gogo.embed flag set to true
func (c *FieldBuildContext) IsEmbed() bool {
return c.isEmbed
}

// GetComment returns field comment as a single line
func (c *FieldBuildContext) GetComment() string {
// ",2," marks that we are extracting comment for a message field. See descriptor.SourceCodeInfo source for details.
Expand Down
78 changes: 78 additions & 0 deletions field_build_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"testing"

"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
"github.com/stretchr/testify/require"
)

func TestFieldBuildContextGetName(t *testing.T) {
tests := []struct {
name string
embedded bool
fieldName string
goType string
expected string
}{
{
name: "regular name",
fieldName: "Name",
expected: "Name",
},
{
name: "embedded name",
embedded: true,
fieldName: "Name",
goType: "EmbeddedName",
expected: "EmbeddedName",
},
{
name: "embedded name with pointer",
embedded: true,
fieldName: "Name",
goType: "*EmbeddedName",
expected: "EmbeddedName",
},
{
name: "embedded name in another package",
embedded: true,
fieldName: "Name",
goType: "*someotherpackage.EmbeddedName",
expected: "EmbeddedName",
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
msg := &descriptor.FieldDescriptorProto{
Name: &test.fieldName,
}

fbc := FieldBuildContext{
field: &FieldDescriptorProtoExt{msg},
goType: test.goType,
isEmbed: test.embedded,
}
require.Equal(t, test.expected, fbc.GetName())
})
}
}
5 changes: 5 additions & 0 deletions field_descriptor_proto_ext.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ func (f *FieldDescriptorProtoExt) IsCustomType() bool {
return gogoproto.IsCustomType(f.FieldDescriptorProto)
}

// IsEmbed returns true if the field has the gogoproto.embed flag set to true
func (f *FieldDescriptorProtoExt) IsEmbed() bool {
return gogoproto.IsEmbed(f.FieldDescriptorProto)
}

// GetCastType returns field cast type name
func (f *FieldDescriptorProtoExt) GetCastType() string {
return gogoproto.GetCastType(f.FieldDescriptorProto)
Expand Down
20 changes: 15 additions & 5 deletions gen_copy_from.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ func (f *FieldCopyFromGenerator) errAttrConversionFailure(path string, typ strin

// nextField reads current field value from Terraform object and asserts it's type against expected
func (f *FieldCopyFromGenerator) nextField(g func(g *j.Group)) *j.Statement {
if f.IsEmbed {
return j.BlockFunc(g)
}
return j.Block(
// a, ok := ft.Attrs["key"]
j.List(j.Id("a"), j.Id("ok")).Op(":=").Id("tf.Attrs").Index(j.Lit(f.NameSnake)),
Expand Down Expand Up @@ -182,11 +185,12 @@ func (f *FieldCopyFromGenerator) genObject() *j.Statement {
// obj.Nested = Nested{}
g.Id(objFieldName).Op("=").Id(f.i.WithType(f.GoElemType)).Values()
}
// if !v.Null
g.If(j.Id("!v.Null && !v.Unknown")).BlockFunc(func(g *j.Group) {
fn := func(g *j.Group) {
if !m.IsEmpty {
// tf := v
g.Id("tf").Op(":=").Id("v")
if !f.IsEmbed {
// tf := v
g.Id("tf").Op(":=").Id("v")
}

if f.IsNullable {
// obj.Nested = &Nested{}
Expand All @@ -200,7 +204,13 @@ func (f *FieldCopyFromGenerator) genObject() *j.Statement {

m.GenerateFields(g)
}
})
}
if f.IsEmbed {
g.BlockFunc(fn)
} else {
// if !v.Null
g.If(j.Id("!v.Null && !v.Unknown")).BlockFunc(fn)
}
} else {
// We do not need nullable checks because all oneOf branches are nullable by design
// We do not need to assign OneOf explicitly to not overrite other OneOf branch values
Expand Down
27 changes: 21 additions & 6 deletions gen_copy_to.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ func (f *FieldCopyToGenerator) Generate() *j.Statement {

// nextField reads current field value from Terraform object and asserts it's type against expected
func (f *FieldCopyToGenerator) nextField(v string, g func(g *j.Group)) *j.Statement {
if f.IsEmbed {
return j.BlockFunc(g)
}
return j.Block(
// _, ok := ft.AttrsTypes["key"]
j.List(j.Id(v), j.Id("ok")).Op(":=").Id("tf.AttrTypes").Index(j.Lit(f.NameSnake)),
Expand Down Expand Up @@ -179,11 +182,18 @@ func (f *FieldCopyToGenerator) genObjectBody(m *MessageCopyToGenerator, fieldNam
if !m.IsEmpty {
g.Id("obj").Op(":=").Id(fieldName)
}
g.Id("tf").Op(":=").Id("&v")
if !f.IsEmbed {
g.Id("tf").Op(":=").Id("&v")
}
m.GenerateFields(g)
}
}

if f.IsEmbed {
g.BlockFunc(copyObj)
return
}

f.getAttr("v", f.Field.ElemValueType, g)
g.If(j.Id("!ok")).Block(
// v := types.Object{Attrs: make(map[string]attr.Value, len(o.AttrTypes)), AttrTypes: o.AttrTypes}
Expand All @@ -206,7 +216,6 @@ func (f *FieldCopyToGenerator) genObjectBody(m *MessageCopyToGenerator, fieldNam
} else {
g.BlockFunc(copyObj)
}
g.Id("v.Unknown").Op("=").False()
}

// assertTo asserts a to typ
Expand Down Expand Up @@ -251,10 +260,16 @@ func (f *FieldCopyToGenerator) genObject() *j.Statement {
f.genOneOfStub(g)
}

f.assertTo(f.Field.ElemType, g, func(g *j.Group) {
f.genObjectBody(m, fieldName, f.Field.ValueType, g)
g.Id("tf.Attrs").Index(j.Lit(f.NameSnake)).Op("=").Id("v")
})
if f.IsEmbed {
g.BlockFunc(func(g *j.Group) {
f.genObjectBody(m, fieldName, f.Field.ValueType, g)
})
} else {
f.assertTo(f.Field.ElemType, g, func(g *j.Group) {
f.genObjectBody(m, fieldName, f.Field.ValueType, g)
g.Id("tf.Attrs").Index(j.Lit(f.NameSnake)).Op("=").Id("v")
})
}
})
}

Expand Down
16 changes: 11 additions & 5 deletions gen_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (

"io"

"github.com/dave/jennifer/jen"
j "github.com/dave/jennifer/jen"
)

Expand Down Expand Up @@ -73,8 +72,15 @@ func (m *MessageSchemaGenerator) fieldsDictSchema() j.Dict {
d := j.Dict{}

for _, f := range m.Fields {
f := NewFieldSchemaGenerator(f, m.i)
d[j.Lit(f.NameSnake)] = f.Generate()
if f.IsEmbed {
for _, f := range f.Message.Fields {
f := NewFieldSchemaGenerator(f, m.i)
d[j.Lit(f.NameSnake)] = f.Generate()
}
} else {
f := NewFieldSchemaGenerator(f, m.i)
d[j.Lit(f.NameSnake)] = f.Generate()
}
}

if len(m.Message.InjectedFields) > 0 {
Expand Down Expand Up @@ -222,7 +228,7 @@ func (f *FieldSchemaGenerator) xNestedAttributes(typ string, m *MessageSchemaGen
}

func generatePlanModifiers(imports *Imports, pm []string) j.Code {
v := make([]jen.Code, len(pm))
v := make([]j.Code, len(pm))
for i, n := range pm {
v[i] = j.Id(imports.WithType(n))
}
Expand All @@ -231,7 +237,7 @@ func generatePlanModifiers(imports *Imports, pm []string) j.Code {
}

func generateValidators(imports *Imports, vals []string) j.Code {
v := make([]jen.Code, len(vals))
v := make([]j.Code, len(vals))
for i, n := range vals {
v[i] = j.Id(imports.WithType(n))
}
Expand Down
10 changes: 10 additions & 0 deletions test/copy_from_terraform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,13 @@ func TestCopyFromOneOfObjectNoBranch(t *testing.T) {

require.Equal(t, nil, target.OneOf)
}

func TestCopyFromEmbedded(t *testing.T) {
obj := copyFromTerraformObject(t)

target := Test{}
require.False(t, CopyTestFromTerraform(context.Background(), obj, &target).HasError())

require.Equal(t, int32(1), target.Embedded.EmbeddedOne)
require.Equal(t, int32(2), target.Embedded.EmbeddedTwo)
}
11 changes: 11 additions & 0 deletions test/copy_to_terraform_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,14 @@ func TestCopyToOneOfNoBranch(t *testing.T) {
require.True(t, o.Attrs["branch2"].(types.Object).Null)
require.True(t, o.Attrs["branch3"].(types.String).Null)
}

func TestCopyToEmbedded(t *testing.T) {
o := copyToTerraformObject(t)
testObj := createTestObj()

diags := CopyTestToTerraform(context.Background(), testObj, &o)
require.False(t, diags.HasError())

require.Equal(t, types.Int64{Value: 1}, o.Attrs["embedded_one"])
require.Equal(t, types.Int64{Value: 2}, o.Attrs["embedded_two"])
}
8 changes: 8 additions & 0 deletions test/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ func createTestObj() Test {
},

Map: map[string]string{"key1": "value1", "key2": "value2"},

Embedded: &Embedded{
EmbeddedOne: 1,
EmbeddedTwo: 2,
},
}
}

Expand Down Expand Up @@ -409,6 +414,9 @@ func copyFromTerraformObject(t *testing.T) types.Object {
"branch3": types.String{Null: true},
"empty_message_branch": types.Object{Null: true},
"string_branch": types.String{Null: true},

"embedded_one": types.Int64{Value: 1},
"embedded_two": types.Int64{Value: 2},
},
AttrTypes: obj.AttrTypes,
}
Expand Down
Loading