Skip to content

Commit

Permalink
use fmt.Stringer during pretty print if a value is supporting it
Browse files Browse the repository at this point in the history
  • Loading branch information
adamluzsi committed Jun 19, 2023
1 parent 747f601 commit b1734ef
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 48 deletions.
92 changes: 48 additions & 44 deletions pp/Format.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package pp

import (
"bytes"
"context"
"fmt"
"github.com/adamluzsi/testcase/internal/reflects"
"io"
Expand Down Expand Up @@ -37,7 +36,10 @@ type visitor struct {
stack int
}

var typeTimeDuration = reflect.TypeOf(time.Duration(0))
var (
typeTimeDuration = reflect.TypeOf((*time.Duration)(nil)).Elem()
typeTimeTime = reflect.TypeOf((*time.Time)(nil)).Elem()
)

func (v *visitor) Visit(w io.Writer, rv reflect.Value, depth int) {
defer debugRecover()
Expand All @@ -48,82 +50,86 @@ func (v *visitor) Visit(w io.Writer, rv reflect.Value, depth int) {
defer td()

if rv.Kind() == reflect.Invalid {
fmt.Fprint(w, "nil")
_, _ = fmt.Fprint(w, "nil")
return
}

rv = reflects.Accessible(rv)

if rv.Type() == typeTimeDuration {
d := time.Duration(rv.Int())
fmt.Fprintf(w, "/* %s */ %#v", d.String(), d)
_, _ = fmt.Fprintf(w, "/* %s */ %#v", d.String(), d)
return
}

if rv.Type() == typeTimeTime {
_, _ = fmt.Fprintf(w, "%#v", rv.Interface())
return
}

if v.tryStringer(w, rv, depth) {
return
}

if rv.CanInt() {
fmt.Fprintf(w, "%#v", rv.Int())
_, _ = fmt.Fprintf(w, "%#v", rv.Int())
return
}

if rv.CanUint() {
fmt.Fprintf(w, "%d", rv.Uint())
_, _ = fmt.Fprintf(w, "%d", rv.Uint())
return
}

if rv.CanFloat() {
fmt.Fprintf(w, "%#v", rv.Float())
_, _ = fmt.Fprintf(w, "%#v", rv.Float())
return
}

switch rv.Kind() {
case reflect.Array, reflect.Slice:
if v.tryStringer(w, rv, depth) {
return
}
if v.tryByteSlice(w, rv) {
return
}
if v.tryNilSlice(w, rv) {
return
}

fmt.Fprintf(w, "%s{", v.getTypeName(rv))
_, _ = fmt.Fprintf(w, "%s{", v.getTypeName(rv))
vLen := rv.Len()
for i := 0; i < vLen; i++ {
v.newLine(w, depth+1)
v.Visit(w, rv.Index(i), depth+1)
fmt.Fprintf(w, ",")
_, _ = fmt.Fprintf(w, ",")
}
if 0 < vLen {
v.newLine(w, depth)
}
fmt.Fprint(w, "}")
_, _ = fmt.Fprint(w, "}")

case reflect.Map:
fmt.Fprintf(w, "%s{", v.getTypeName(rv))
_, _ = fmt.Fprintf(w, "%s{", v.getTypeName(rv))
keys := rv.MapKeys()
v.sortMapKeys(keys)
for _, key := range keys {
v.newLine(w, depth+1)
v.Visit(w, key, depth+1) // key
fmt.Fprintf(w, ": ")
_, _ = fmt.Fprintf(w, ": ")
v.Visit(w, rv.MapIndex(key), depth+1) // value
fmt.Fprintf(w, ",")
_, _ = fmt.Fprintf(w, ",")
}
if 0 < len(keys) {
v.newLine(w, depth)
}
fmt.Fprint(w, "}")
_, _ = fmt.Fprint(w, "}")

case reflect.Struct:
switch rv.Type() {
case reflect.TypeOf(time.Time{}):
fmt.Fprintf(w, "%#v", rv.Interface())
default:
v.visitStructure(w, rv, depth)
}
v.visitStructure(w, rv, depth)

case reflect.Interface:
fmt.Fprintf(w, "(%s)(", v.getTypeName(rv))
_, _ = fmt.Fprintf(w, "(%s)(", v.getTypeName(rv))
v.Visit(w, rv.Elem(), depth)
fmt.Fprint(w, ")")
_, _ = fmt.Fprint(w, ")")

case reflect.Pointer:
if rv.IsNil() {
Expand All @@ -133,28 +139,28 @@ func (v *visitor) Visit(w io.Writer, rv reflect.Value, depth int) {

elem := rv.Elem()
if v.isRecursion(elem) {
fmt.Fprintf(w, "(%s)(", v.getTypeName(rv))
fmt.Fprintf(w, "%#v", rv.Pointer())
fmt.Fprint(w, ")")
_, _ = fmt.Fprintf(w, "(%s)(", v.getTypeName(rv))
_, _ = fmt.Fprintf(w, "%#v", rv.Pointer())
_, _ = fmt.Fprint(w, ")")
return
}

fmt.Fprintf(w, "&")
_, _ = fmt.Fprintf(w, "&")
v.Visit(w, rv.Elem(), depth)

case reflect.Chan:
fmt.Fprintf(w, "make(%s, %d)", rv.Type().String(), rv.Cap())
_, _ = fmt.Fprintf(w, "make(%s, %d)", rv.Type().String(), rv.Cap())

case reflect.String:
fmt.Fprintf(w, "%#v", rv.String())
_, _ = fmt.Fprintf(w, "%#v", rv.String())

default:
v, ok := reflects.TryToMakeAccessible(rv)
if !ok {
fmt.Fprint(w, "/* inaccessible */")
_, _ = fmt.Fprint(w, "/* inaccessible */")
return
}
fmt.Fprintf(w, "%#v", v.Interface())
_, _ = fmt.Fprintf(w, "%#v", v.Interface())
}
}

Expand All @@ -169,11 +175,11 @@ func (v *visitor) recursionGuard(w io.Writer, rv reflect.Value) (_td func(), _ok
return func() { delete(v.visited, rv) }, true
}
if rv.CanAddr() {
fmt.Fprintf(w, "%#v", rv.UnsafeAddr())
_, _ = fmt.Fprintf(w, "%#v", rv.UnsafeAddr())
} else if rv.CanInterface() {
fmt.Fprintf(w, "%#v", rv.Interface())
_, _ = fmt.Fprintf(w, "%#v", rv.Interface())
} else {
fmt.Fprintf(w, "%v", rv)
_, _ = fmt.Fprintf(w, "%v", rv)
}
return func() {}, false
}
Expand Down Expand Up @@ -228,27 +234,27 @@ func (v *visitor) tryStringer(w io.Writer, rv reflect.Value, depth int) bool {
return false
}

fmt.Fprintf(w, "/* %s */ ", rv.Type().String())
_, _ = fmt.Fprintf(w, "/* %s */ ", rv.Type().String())
v.Visit(w, rv.MethodByName("String").Call([]reflect.Value{})[0], depth)
return true
}

func (v *visitor) visitStructure(w io.Writer, rv reflect.Value, depth int) {
fmt.Fprintf(w, "%s{", rv.Type().String())
_, _ = fmt.Fprintf(w, "%s{", rv.Type().String())
fieldNum := rv.NumField()
for i, fNum := 0, fieldNum; i < fNum; i++ {
name := rv.Type().Field(i).Name
field := rv.FieldByName(name)

v.newLine(w, depth+1)
fmt.Fprintf(w, "%s: ", name)
_, _ = fmt.Fprintf(w, "%s: ", name)
v.Visit(w, field, depth+1)
fmt.Fprintf(w, ",")
_, _ = fmt.Fprintf(w, ",")
}
if 0 < fieldNum {
v.newLine(w, depth)
}
fmt.Fprint(w, "}")
_, _ = fmt.Fprint(w, "}")
}

func (v *visitor) newLine(w io.Writer, depth int) {
Expand Down Expand Up @@ -309,7 +315,7 @@ func (v *visitor) tryByteSlice(w io.Writer, rv reflect.Value) bool {
content = strings.ReplaceAll(content, "`", "`+\"`\"+`")
}

fmt.Fprintf(w, "%s(%s%s%s)", typeName, quoteChar, content, quoteChar)
_, _ = fmt.Fprintf(w, "%s(%s%s%s)", typeName, quoteChar, content, quoteChar)
return true
}

Expand All @@ -328,5 +334,3 @@ func (v *visitor) tryNilSlice(w io.Writer, rv reflect.Value) bool {
_, _ = fmt.Fprintf(w, "(%s)(nil)", rv.Type().String())
return true
}

var ctxValType = reflect.TypeOf(context.WithValue(context.Background(), struct{}{}, 42)).Elem()
22 changes: 18 additions & 4 deletions pp/Format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,11 @@ func TestFormat(t *testing.T) {

s.And("it implements fmt.Stringer", func(s *testcase.Spec) {
v.Let(s, func(t *testcase.T) any {
return ExampleFmtStringer("foo/bar/baz")
return ExampleSliceFmtStringer("foo/bar/baz")
})

s.Then("it will use the .String() method representation", func(t *testcase.T) {
t.Must.Equal(`/* pp_test.ExampleFmtStringer */ "foo/bar/baz"`, act(t))
t.Must.Equal(`/* pp_test.ExampleSliceFmtStringer */ "foo/bar/baz"`, act(t))
})
})

Expand Down Expand Up @@ -301,6 +301,16 @@ func TestFormat(t *testing.T) {
t.Must.Equal(fmt.Sprintf("/* %s */ %d", expected.String(), expected), act(t))
})
})

s.When("it implements fmt.Stringer", func(s *testcase.Spec) {
v.Let(s, func(t *testcase.T) any {
return ExampleFmtStringer{V: "foo/bar/baz"}
})

s.Then("it will use the .String() method representation", func(t *testcase.T) {
t.Must.Equal(`/* pp_test.ExampleFmtStringer */ "foo/bar/baz"`, act(t))
})
})
}

func TestFormat_recursion(t *testing.T) {
Expand Down Expand Up @@ -434,9 +444,13 @@ func TestFormat_nil(t *testing.T) {
assert.Equal(t, "nil", pp.Format(nil))
}

type ExampleFmtStringer []byte
type ExampleSliceFmtStringer []byte

func (e ExampleSliceFmtStringer) String() string { return string(e) }

type ExampleFmtStringer struct{ V string }

func (e ExampleFmtStringer) String() string { return string(e) }
func (e ExampleFmtStringer) String() string { return e.V }

func TestFormat_timeTime(t *testing.T) {
tm := time.Date(2022, time.July, 26, 17, 36, 19, 882377000, time.UTC)
Expand Down

0 comments on commit b1734ef

Please sign in to comment.