From d80d7e988207a0c850842eee59691ca05ab5f90d Mon Sep 17 00:00:00 2001 From: Roman Volosatovs Date: Thu, 21 Feb 2019 15:34:55 +0100 Subject: [PATCH] lang/go: Add Context.FieldType{PackageName,ImportPath} --- lang/go/context.go | 10 +++++ lang/go/package.go | 98 +++++++++++++++++++++++++++++++--------------- 2 files changed, 76 insertions(+), 32 deletions(-) diff --git a/lang/go/context.go b/lang/go/context.go index a652c50..a1a5df9 100644 --- a/lang/go/context.go +++ b/lang/go/context.go @@ -60,6 +60,16 @@ type Context interface { // OutputPath returns the output path relative to the plugin's output destination OutputPath(entity pgs.Entity) pgs.FilePath + + // FieldTypeImportPath returns name of the Field type's package as it would appear in + // Go source generated by the official protoc-gen-go plugin. + // For builtin types empty FieldPath will be returned. + FieldTypePackageName(field pgs.Field) pgs.Name + + // FieldTypeImportPath returns the Go import path of the type of the Field + // as it would be included in an import block in a Go file. + // For builtin types empty FieldPath will be returned. + FieldTypeImportPath(field pgs.Field) pgs.FilePath } type context struct{ p pgs.Parameters } diff --git a/lang/go/package.go b/lang/go/package.go index a4fab9e..181c099 100644 --- a/lang/go/package.go +++ b/lang/go/package.go @@ -12,29 +12,6 @@ import ( var nonAlphaNumPattern = regexp.MustCompile("[^a-zA-Z0-9]") -func (c context) PackageName(node pgs.Node) pgs.Name { - e, ok := node.(pgs.Entity) - if !ok { - e = node.(pgs.Package).Files()[0] - } - - _, pkg := c.optionPackage(e) - - // use import_path parameter ONLY if there is no go_package option in the file. - if ip := c.p.Str("import_path"); ip != "" && - e.File().Descriptor().GetOptions().GetGoPackage() == "" { - pkg = ip - } - - // if the package name is a Go keyword, prefix with '_' - if token.Lookup(pkg).IsKeyword() { - pkg = "_" + pkg - } - - // package name is kosher - return pgs.Name(pkg) -} - func gogoType(f pgs.Field) (pgs.FilePath, TypeName, bool) { ft := f.Type() switch { @@ -99,14 +76,54 @@ func gogoType(f pgs.Field) (pgs.FilePath, TypeName, bool) { return "", TypeName(typeName), true } -func (c gogoContext) PackageName(node pgs.Node) pgs.Name { - f, ok := node.(pgs.Field) +func (c context) PackageName(node pgs.Node) pgs.Name { + e, ok := node.(pgs.Entity) if !ok { - return c.context.PackageName(node) + e = node.(pgs.Package).Files()[0] + } + + _, pkg := c.optionPackage(e) + + // use import_path parameter ONLY if there is no go_package option in the file. + if ip := c.p.Str("import_path"); ip != "" && + e.File().Descriptor().GetOptions().GetGoPackage() == "" { + pkg = ip + } + + // if the package name is a Go keyword, prefix with '_' + if token.Lookup(pkg).IsKeyword() { + pkg = "_" + pkg + } + + // package name is kosher + return pgs.Name(pkg) +} + +func (c context) FieldTypePackageName(f pgs.Field) pgs.Name { + var en pgs.Entity + switch ft := f.Type(); { + case ft.IsEmbed(): + en = ft.Embed() + case ft.IsEnum(): + en = ft.Enum() + case ft.IsRepeated(), ft.IsMap(): + el := ft.Element() + switch { + case el.IsEmbed(): + en = el.Embed() + case el.IsEnum(): + en = el.Enum() + } + default: + return pgs.Name("") } + return c.PackageName(en) +} + +func (c gogoContext) FieldTypePackageName(f pgs.Field) pgs.Name { pkg, _, ok := gogoType(f) if !ok { - return c.context.PackageName(node) + return c.context.FieldTypePackageName(f) } return pgs.Name(nonAlphaNumPattern.ReplaceAllString(string(pkg), "_")) } @@ -117,14 +134,31 @@ func (c context) ImportPath(e pgs.Entity) pgs.FilePath { return pgs.FilePath(path) } -func (c gogoContext) ImportPath(e pgs.Entity) pgs.FilePath { - f, ok := e.(pgs.Field) - if !ok { - return c.context.ImportPath(e) +func (c context) FieldTypeImportPath(f pgs.Field) pgs.FilePath { + var en pgs.Entity + switch ft := f.Type(); { + case ft.IsEmbed(): + en = ft.Embed() + case ft.IsEnum(): + en = ft.Enum() + case ft.IsRepeated(), ft.IsMap(): + el := ft.Element() + switch { + case el.IsEmbed(): + en = el.Embed() + case el.IsEnum(): + en = el.Enum() + } + default: + return pgs.FilePath("") } + return c.ImportPath(en) +} + +func (c gogoContext) FieldTypeImportPath(f pgs.Field) pgs.FilePath { pkg, _, ok := gogoType(f) if !ok { - return c.context.ImportPath(e) + return c.context.FieldTypeImportPath(f) } return pkg }