Skip to content

Commit

Permalink
Use a (Go) type for every TOML key with TOML type information
Browse files Browse the repository at this point in the history
  • Loading branch information
arp242 committed Nov 24, 2021
1 parent 97903e9 commit 91e9a00
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 253 deletions.
123 changes: 71 additions & 52 deletions encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func NewEncoder(w io.Writer) *Encoder {
}
}

func (enc *Encoder) MetaData(m *MetaData) { enc.meta = m }
func (enc *Encoder) MetaData(m MetaData) { enc.meta = &m }

// Encode writes a TOML representation of the Go value to the Encoder's writer.
//
Expand Down Expand Up @@ -142,12 +142,13 @@ func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {

func (enc *Encoder) encode(key Key, rv reflect.Value) {
if enc.meta != nil && enc.meta.comments != nil {
c, ok := enc.meta.comments[key.String()]
if ok {
enc.w.WriteByte('\n')
enc.w.WriteString("# ")
enc.w.WriteString(strings.ReplaceAll(c, "\n", "\n# "))
enc.w.WriteByte('\n')
comments := enc.meta.comments[key.String()]
for _, c := range comments {
if c.where == commentDoc {
enc.w.WriteString("# ")
enc.w.WriteString(strings.ReplaceAll(c.text, "\n", "\n# "))
enc.w.WriteByte('\n')
}
}
}

Expand All @@ -161,6 +162,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
enc.writeKeyValue(key, rv, false)
return
// TODO: #76 would make this superfluous after implemented.
// TODO: remove in v2
case Primitive:
enc.encode(key, reflect.ValueOf(t.undecoded))
return
Expand All @@ -175,7 +177,7 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
enc.writeKeyValue(key, rv, false)
case reflect.Array, reflect.Slice:
if typeEqual(tomlArrayTable, tomlTypeOfGo(rv)) {
if typeEqual(ArrayTable{}, tomlTypeOfGo(rv)) {
enc.eArrayOfTables(key, rv)
} else {
enc.writeKeyValue(key, rv, false)
Expand All @@ -200,26 +202,45 @@ func (enc *Encoder) encode(key Key, rv reflect.Value) {
default:
encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k))
}

// TODO: there's a newline printed already.
// if enc.meta != nil && enc.meta.comments != nil {
// comments := enc.meta.comments[key.String()]
// for _, c := range comments {
// if c.where == commentComment {
// enc.w.WriteString(" # ")
// enc.w.WriteString(strings.ReplaceAll(c.text, "\n", "\n# "))
// enc.w.WriteByte('\n')
// }
// }
// }
}

func (enc *Encoder) writeIntBase(as formatAs) int {
base := 10
switch {
case as&IntBinary != 0:
func (enc *Encoder) writeInt(typ tomlType, v uint64) {
var (
iTyp = asInt(typ)
base = int(iTyp.Base)
)
switch iTyp.Base {
case 0:
base = 10
case 2:
enc.wf("0b")
base = 2
case as&IntOctal != 0:
case 8:
enc.wf("0o")
base = 8
case as&IntHex != 0:
case 16:
enc.wf("0x")
base = 16
}
return base

n := strconv.FormatUint(uint64(v), base)
if base != 10 && iTyp.Width > 0 && len(n) < int(iTyp.Width) {
enc.wf(strings.Repeat("0", int(iTyp.Width)-len(n)))
}
enc.wf(n)
}

// eElement encodes any value that can be an array element.
func (enc *Encoder) eElement(rv reflect.Value, as formatAs) {
func (enc *Encoder) eElement(rv reflect.Value, typ tomlType) {
switch v := rv.Interface().(type) {
case time.Time: // Using TextMarshaler adds extra quotes, which we don't want.
format := time.RFC3339Nano
Expand All @@ -232,7 +253,9 @@ func (enc *Encoder) eElement(rv reflect.Value, as formatAs) {
format = "15:04:05.999999999"
}

// XXX: this breaks some tests; should fix those.
// XXX: this breaks some tests; should fix those. Can also remove
// internal/tz.go I think.

// format := time.RFC3339Nano
// switch {
// case as&DatetimeLocal != 0:
Expand All @@ -254,34 +277,31 @@ func (enc *Encoder) eElement(rv reflect.Value, as formatAs) {
if err != nil {
encPanic(err)
}
enc.writeQuoted(string(s), as)
enc.writeQuoted(string(s), asString(typ))
return
case encoding.TextMarshaler:
s, err := v.MarshalText()
if err != nil {
encPanic(err)
}
enc.writeQuoted(string(s), as)
enc.writeQuoted(string(s), asString(typ))
return
}

switch rv.Kind() {
case reflect.Bool:
enc.wf(strconv.FormatBool(rv.Bool()))
case reflect.String:
enc.writeQuoted(rv.String(), as)

enc.writeQuoted(rv.String(), asString(typ))
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v := rv.Int()
if v < 0 { // Make sure sign is before "0x".
enc.wf("-")
v = -v
}
base := enc.writeIntBase(as)
enc.wf(strconv.FormatUint(uint64(v), base))
enc.writeInt(typ, uint64(v))
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
base := enc.writeIntBase(as)
enc.wf(strconv.FormatUint(rv.Uint(), base))
enc.writeInt(typ, rv.Uint())

case reflect.Float32, reflect.Float64:
f := rv.Float()
Expand All @@ -294,7 +314,7 @@ func (enc *Encoder) eElement(rv reflect.Value, as formatAs) {
if rv.Kind() == reflect.Float32 {
n = 32
}
if as&FloatExponent != 0 {
if asFloat(typ).Exponent {
enc.wf(strconv.FormatFloat(f, 'e', -1, n))
} else {
enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, n)))
Expand All @@ -308,7 +328,7 @@ func (enc *Encoder) eElement(rv reflect.Value, as formatAs) {
case reflect.Map:
enc.eMap(nil, rv, true)
case reflect.Interface:
enc.eElement(rv.Elem(), as)
enc.eElement(rv.Elem(), typ)
default:
encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface()))
}
Expand All @@ -323,16 +343,15 @@ func floatAddDecimal(fstr string) string {
return fstr
}

func (enc *Encoder) writeQuoted(s string, as formatAs) {
multi := as&StringMultiline != 0
if as&StringLiteral != 0 {
if multi {
func (enc *Encoder) writeQuoted(s string, typ String) {
if typ.Literal {
if typ.Multiline {
enc.wf("'''%s'''\n", s)
} else {
enc.wf(`'%s'`, s)
}
} else {
if multi {
if typ.Multiline {
enc.wf(`"""%s"""`+"\n",
strings.ReplaceAll(dblQuotedReplacer.Replace(s), "\\n", "\n"))
} else {
Expand All @@ -346,7 +365,7 @@ func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) {
enc.wf("[")
for i := 0; i < length; i++ {
elem := rv.Index(i)
enc.eElement(elem, 0) // TODO: add formatAs
enc.eElement(elem, nil) // TODO: add type
if i != length-1 {
enc.wf(", ")
}
Expand Down Expand Up @@ -565,46 +584,46 @@ func tomlTypeOfGo(rv reflect.Value) tomlType {
}
switch rv.Kind() {
case reflect.Bool:
return tomlBool
return Bool{}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
reflect.Uint64:
return tomlInt{}
return Int{}
case reflect.Float32, reflect.Float64:
return tomlFloat{}
return Float{}
case reflect.Array, reflect.Slice:
if typeEqual(tomlTable{}, tomlArrayType(rv)) {
return tomlArrayTable
if typeEqual(Table{}, tomlArrayType(rv)) {
return ArrayTable{}
}
return tomlArray{}
return Array{}
case reflect.Ptr, reflect.Interface:
return tomlTypeOfGo(rv.Elem())
case reflect.String:
return tomlString{}
return String{}
case reflect.Map:
return tomlTable{}
return Table{}
case reflect.Struct:
switch rv.Interface().(type) {
case time.Time:
return tomlDatetime{}
return Datetime{}
case encoding.TextMarshaler:
return tomlString{}
return String{}
default:
// Someone used a pointer receiver: we can make it work for pointer
// values.
if rv.CanAddr() {
_, ok := rv.Addr().Interface().(encoding.TextMarshaler)
if ok {
return tomlString{}
return String{}
}
}
return tomlTable{}
return Table{}
}
default:
_, ok := rv.Interface().(encoding.TextMarshaler)
if ok {
return tomlString{}
return String{}
}
encPanic(errors.New("unsupported type: " + rv.Kind().String()))
panic("") // Need *some* return value
Expand Down Expand Up @@ -709,13 +728,13 @@ func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
}
enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))

var as formatAs
var typ tomlType
if enc.meta != nil {
if t, ok := enc.meta.types[key.String()]; ok {
as = t.formatAs()
typ = t
}
}
enc.eElement(val, as)
enc.eElement(val, typ)
if !inline {
enc.newline()
}
Expand Down
95 changes: 77 additions & 18 deletions encode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,80 @@ import (
"time"
)

// Copy from _example/example.go
type (
example struct {
Title string `toml:"title"`
Integers []int `toml:"integers"`
Times []fmtTime `toml:"times"`
Duration []duration `toml:"duration"`
Distros []distro `toml:"distros"`
Servers map[string]server `toml:"servers"`
Characters map[string][]struct {
Name string `toml:"name"`
Rank string `toml:"rank"`
} `toml:"characters"`
}

server struct {
IP string `toml:"ip"`
Hostname string `toml:"hostname"`
Enabled bool `toml:"enabled"`
}

distro struct {
Name string `toml:"name"`
Packages string `toml:"packages"`
}

duration struct{ time.Duration }
fmtTime struct{ time.Time }
)

func (d *duration) UnmarshalText(text []byte) (err error) {
d.Duration, err = time.ParseDuration(string(text))
return err
}

func (d duration) MarshalText() ([]byte, error) {
return []byte(d.Duration.String()), nil
}

func (t fmtTime) String() string {
f := "2006-01-02 15:04:05.999999999"
if t.Time.Hour() == 0 {
f = "2006-01-02"
}
if t.Time.Year() == 0 {
f = "15:04:05.999999999"
}
if t.Time.Location() == time.UTC {
f += " UTC"
} else {
f += " -0700"
}
return t.Time.Format(`"` + f + `"`)
}

func TestXXX(t *testing.T) {
var decoded example
meta, err := DecodeFile("_example/example.toml", &decoded)
if err != nil {
t.Fatal(err)
}

buf := new(bytes.Buffer)
enc := NewEncoder(buf)
enc.MetaData(meta)
err = enc.Encode(decoded)
if err != nil {
t.Fatal(err)
}

fmt.Println(buf)

}

func TestEncodeRoundTrip(t *testing.T) {
type Config struct {
Age int
Expand Down Expand Up @@ -499,27 +573,12 @@ func TestEncodeHints(t *testing.T) {
t.Fatal(err)
}

fmt.Printf("mapping: %#v\n", meta.mapping)
fmt.Printf("types: %#v\n", meta.types)
fmt.Printf("keys : %#v\n", meta.keys)

// meta.Format() // Set format.
// meta.Doc() // Doc comment above key.
// meta.Comment() // Inline comment (like this)

//meta.Format("ml", toml.StringMultiline)
//meta.Doc("ml", "Some comment.")

//meta.Format("n", toml.IntHex | toml.IntLeadingZero(2))

// FormatString("ml", StringMultiline).
// FormatString("lit", StringLiteral).
// FormatString("cmt", StringLiteral|StringMultiline).
// Comment("cmt", "Well, hello there!"))
meta.Doc("ml", "Hello").Comment("ml", "inline")
meta.SetType("n", Int{Width: 4, Base: 16})

buf := new(bytes.Buffer)
enc := NewEncoder(buf)
enc.MetaData(&meta)
enc.MetaData(meta)
err = enc.Encode(foo)
if err != nil {
t.Fatal(err)
Expand Down
Loading

0 comments on commit 91e9a00

Please sign in to comment.