diff --git a/protoc-go-inject-tags/main.go b/protoc-go-inject-tags/main.go index 39d32aa..f9a6dac 100644 --- a/protoc-go-inject-tags/main.go +++ b/protoc-go-inject-tags/main.go @@ -34,6 +34,7 @@ func init() { flag.StringVar(&ignore, "ignore", "", "ignore fields for struct eg: form:file,json:file") fset = token.NewFileSet() } + func main() { flag.Parse() @@ -131,39 +132,51 @@ func handleTags(fields *ast.FieldList) { continue } - var newName string + var jsonName string for _, name := range field.Names { - newName = strings.ToLower(strings.Join(camelcase.Split(name.Name), "_")) + jsonName = strings.ToLower(strings.Join(camelcase.Split(name.Name), "_")) } - tv := strings.ReplaceAll(field.Tag.Value, "`", "") tagList := strings.Split(tv, " ") - fieldTags := make(map[string]*fieldTag, 0) + fieldTags := make(map[string]*fieldTag) var sortTags []string + for _, t := range tagList { tf := strings.Split(t, ":") if len(tf) != 2 { continue } - fieldTags[tf[0]] = &fieldTag{ - tagName: strings.TrimSpace(tf[0]), - tagValue: strings.Replace(tf[1], "\"", "", -1), + tagName := strings.TrimSpace(tf[0]) + tagValue := strings.Replace(tf[1], "\"", "", -1) + fieldTags[tagName] = &fieldTag{ + tagName: tagName, + tagValue: tagValue, + } + sortTags = append(sortTags, tagName) + + if tagName == "protobuf" { + if jsonIndex := strings.Index(tagValue, "json="); jsonIndex != -1 { + jsonValue := tagValue[jsonIndex+5:] + jsonValue = strings.Split(jsonValue, ",")[0] + jsonName = jsonValue + } } - sortTags = append(sortTags, tf[0]) } for _, tag := range tagsValue { - if _, ok := fieldTags[tag]; ok { - continue - } - - fieldTags[tag] = &fieldTag{tagName: tag, tagValue: fmt.Sprintf("%s,omitempty", newName)} + fieldTags[tag] = &fieldTag{tagName: tag, tagValue: fmt.Sprintf("%s,omitempty", jsonName)} sortTags = append(sortTags, tag) } + fieldTags["json"] = &fieldTag{tagName: "json", tagValue: fmt.Sprintf("%s,omitempty", jsonName)} + var newTags string + seenTags := make(map[string]bool) for _, tag := range sortTags { - newTags += fmt.Sprintf("%s:\"%s\" ", fieldTags[tag].tagName, fieldTags[tag].tagValue) + if !seenTags[tag] { + newTags += fmt.Sprintf("%s:\"%s\" ", fieldTags[tag].tagName, fieldTags[tag].tagValue) + seenTags[tag] = true + } } for _, v := range ignoreFields { @@ -172,6 +185,6 @@ func handleTags(fields *ast.FieldList) { } } - field.Tag.Value = fmt.Sprintf("`%s`", newTags) + field.Tag.Value = fmt.Sprintf("`%s`", strings.TrimSpace(newTags)) } }