Skip to content

Commit

Permalink
put nil preserve logic behind a field tag
Browse files Browse the repository at this point in the history
  • Loading branch information
whyrusleeping committed Jan 4, 2024
1 parent 0831c9a commit 075d157
Show file tree
Hide file tree
Showing 7 changed files with 1,121 additions and 936 deletions.
60 changes: 49 additions & 11 deletions gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ type Field struct {
Pkg string
Const *string

OmitEmpty bool
IterLabel string
OmitEmpty bool
PreserveNil bool
IterLabel string

MaxLen int
}
Expand Down Expand Up @@ -226,16 +227,18 @@ func ParseTypeInfo(i interface{}) (*GenTypeInfo, error) {
}

_, omitempty := tags["omitempty"]
_, preservenil := tags["preservenil"]

out.Fields = append(out.Fields, Field{
Name: f.Name,
MapKey: mapk,
Pointer: pointer,
Type: ft,
Pkg: pkg,
OmitEmpty: omitempty,
MaxLen: usrMaxLen,
Const: constval,
Name: f.Name,
MapKey: mapk,
Pointer: pointer,
Type: ft,
Pkg: pkg,
OmitEmpty: omitempty,
PreserveNil: preservenil,
MaxLen: usrMaxLen,
Const: constval,
})
}

Expand All @@ -259,6 +262,8 @@ func tagparse(v string) (map[string]string, error) {
out[strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
} else if elem == "omitempty" {
out["omitempty"] = "true"
} else if elem == "preservenil" {
out["preservenil"] = "true"
} else if elem == "ignore" || elem == "-" {
out["ignore"] = "true"
} else {
Expand Down Expand Up @@ -528,18 +533,22 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
return xerrors.Errorf("Byte array in field {{ .Name }} was too long")
}
{{ if .PreserveNil }}
if {{ .Name }} == nil {
_, err := w.Write(cbg.CborNull)
if err != nil {
return err
}
} else {
{{ end }}
{{ MajorType "cw" "cbg.MajByteString" (print "len(" .Name ")" ) }}
if _, err := cw.Write({{ .Name }}[:]); err != nil {
return err
}
{{ if .PreserveNil }}
}
{{ end }}
`)
}

Expand All @@ -554,12 +563,14 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
return xerrors.Errorf("Slice value in field {{ .Name }} was too long")
}
{{ if .PreserveNil }}
if {{ .Name }} == nil {
_, err := w.Write(cbg.CborNull)
if err != nil {
return err
}
} else {
{{ end }}
{{ MajorType "cw" "cbg.MajArray" ( print "len(" .Name ")" ) }}
for _, v := range {{ .Name }} {`)
if err != nil {
Expand Down Expand Up @@ -588,7 +599,15 @@ func emitCborMarshalSliceField(w io.Writer, f Field) error {
}

// array end
fmt.Fprintf(w, "\t\t}\n\t}\n")
if err := doTemplate(w, f, `
{{ if .PreserveNil }}
}
{{ end }}
}
`); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -1099,6 +1118,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {
}

err := doTemplate(w, f, `
{{ if .PreserveNil }}
{
b, err := cr.ReadByte()
if err != nil {
Expand All @@ -1109,6 +1129,7 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {
return err
}
{{ end }}
maj, extra, err = {{ ReadHeader "cr" }}
if err != nil {
return err
Expand All @@ -1126,12 +1147,21 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {
if maj != cbg.MajByteString {
return fmt.Errorf("expected byte array")
}
{{ if .PreserveNil }}
{{ .Name }} = make({{ .TypeName }}, extra)
{{ else }}
if extra > 0 {
{{ .Name }} = make({{ .TypeName }}, extra)
}
{{ end }}
if _, err := io.ReadFull(cr, {{ .Name }}[:]); err != nil {
return err
}
{{ if .PreserveNil }}
}
}
{{ end }}
`)
}

Expand All @@ -1147,7 +1177,13 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {
if maj != cbg.MajArray {
return fmt.Errorf("expected cbor array")
}
{{ if .PreserveNil }}
{{ .Name }} = make({{ .TypeName }}, extra)
{{ else }}
if extra > 0 {
{{ .Name }} = make({{ .TypeName }}, extra)
}
{{ end }}
for {{ .IterLabel }} := 0; {{ .IterLabel }} < int(extra); {{ .IterLabel }}++ {
`)
if err != nil {
Expand Down Expand Up @@ -1228,8 +1264,10 @@ func emitCborUnmarshalSliceField(w io.Writer, f Field) error {
}

if err := doTemplate(w, f, `
{{ if .PreserveNil }}
}
}
{{ end }}
}
}
`); err != nil {
Expand Down
1 change: 1 addition & 0 deletions testgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ func main() {
types.TestConstField{},
types.TestCanonicalFieldOrder{},
types.MapStringString{},
types.TestSliceNilPreserve{},
); err != nil {
panic(err)
}
Expand Down
Loading

0 comments on commit 075d157

Please sign in to comment.