diff --git a/templates/go/register.go b/templates/go/register.go index 0fa9b81d0..b6681b324 100644 --- a/templates/go/register.go +++ b/templates/go/register.go @@ -4,11 +4,12 @@ import ( "text/template" "github.com/lyft/protoc-gen-star" + pgsgo "github.com/lyft/protoc-gen-star/lang/go" "github.com/lyft/protoc-gen-validate/templates/goshared" ) func Register(tpl *template.Template, params pgs.Parameters) { - goshared.Register(tpl, params) + goshared.Register(tpl, params, pgsgo.InitContext) template.Must(tpl.Parse(fileTpl)) template.Must(tpl.New("required").Parse(requiredTpl)) template.Must(tpl.New("timestamp").Parse(timestampTpl)) diff --git a/templates/gogo/register.go b/templates/gogo/register.go index 59efec29c..1a61ac8bf 100644 --- a/templates/gogo/register.go +++ b/templates/gogo/register.go @@ -4,12 +4,12 @@ import ( "text/template" "github.com/lyft/protoc-gen-star" - + pgsgo "github.com/lyft/protoc-gen-star/lang/go" shared "github.com/lyft/protoc-gen-validate/templates/goshared" ) func Register(tpl *template.Template, params pgs.Parameters) { - shared.Register(tpl, params) + shared.Register(tpl, params, pgsgo.InitGoGoContext) template.Must(tpl.Parse(fileTpl)) template.Must(tpl.New("required").Parse(requiredTpl)) template.Must(tpl.New("timestamp").Parse(timestampTpl)) diff --git a/templates/goshared/register.go b/templates/goshared/register.go index 55a05fa4e..c91374708 100644 --- a/templates/goshared/register.go +++ b/templates/goshared/register.go @@ -14,8 +14,8 @@ import ( "github.com/lyft/protoc-gen-validate/templates/shared" ) -func Register(tpl *template.Template, params pgs.Parameters) { - fns := goSharedFuncs{pgsgo.InitContext(params)} +func Register(tpl *template.Template, params pgs.Parameters, initContext func(pgs.Parameters) pgsgo.Context) { + fns := goSharedFuncs{initContext(params)} tpl.Funcs(map[string]interface{}{ "accessor": fns.accessor, @@ -91,8 +91,11 @@ func (fns goSharedFuncs) accessor(ctx shared.RuleContext) string { if ctx.AccessorOverride != "" { return ctx.AccessorOverride } - - return fmt.Sprintf("m.Get%s()", fns.Name(ctx.Field)) + name := fns.Name(ctx.Field) + if name == "" { + return fmt.Sprintf("m.%s", fns.Type(ctx.Field).Value()) + } + return fmt.Sprintf("m.Get%s()", name) } func (fns goSharedFuncs) errName(m pgs.Message) pgs.Name {