diff --git a/color.go b/color.go index b2795fed..39f4a7e7 100644 --- a/color.go +++ b/color.go @@ -220,22 +220,30 @@ func (c *Color) unset() { // a low-level function, and users should use the higher-level functions, such // as color.Fprint, color.Print, etc. func (c *Color) SetWriter(w io.Writer) *Color { + _, _ = c.setWriter(w) + return c +} + +func (c *Color) setWriter(w io.Writer) (int, error) { if c.isNoColorSet() { - return c + return 0, nil } - fmt.Fprint(w, c.format()) - return c + return fmt.Fprint(w, c.format()) } // UnsetWriter resets all escape attributes and clears the output with the give // io.Writer. Usually should be called after SetWriter(). func (c *Color) UnsetWriter(w io.Writer) { + _, _ = c.unsetWriter(w) +} + +func (c *Color) unsetWriter(w io.Writer) (int, error) { if c.isNoColorSet() { - return + return 0, nil } - fmt.Fprintf(w, "%s[%dm", escape, Reset) + return fmt.Fprintf(w, "%s[%dm", escape, Reset) } // Add is used to chain SGR parameters. Use as many as parameters to combine @@ -251,10 +259,20 @@ func (c *Color) Add(value ...Attribute) *Color { // On Windows, users should wrap w with colorable.NewColorable() if w is of // type *os.File. func (c *Color) Fprint(w io.Writer, a ...interface{}) (n int, err error) { - c.SetWriter(w) - defer c.UnsetWriter(w) + n, err = c.setWriter(w) + if err != nil { + return n, err + } + + nn, err := fmt.Fprint(w, a...) + n += nn + if err != nil { + return + } - return fmt.Fprint(w, a...) + nn, err = c.unsetWriter(w) + n += nn + return n, err } // Print formats using the default formats for its operands and writes to @@ -274,10 +292,20 @@ func (c *Color) Print(a ...interface{}) (n int, err error) { // On Windows, users should wrap w with colorable.NewColorable() if w is of // type *os.File. func (c *Color) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) { - c.SetWriter(w) - defer c.UnsetWriter(w) + n, err = c.setWriter(w) + if err != nil { + return n, err + } + + nn, err := fmt.Fprintf(w, format, a...) + n += nn + if err != nil { + return + } - return fmt.Fprintf(w, format, a...) + nn, err = c.unsetWriter(w) + n += nn + return n, err } // Printf formats according to a format specifier and writes to standard output. diff --git a/color_test.go b/color_test.go index 586039b4..880dcfb3 100644 --- a/color_test.go +++ b/color_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "strings" "testing" "github.com/mattn/go-colorable" @@ -446,6 +447,64 @@ func TestColor_Sprintln_Newline(t *testing.T) { } } +func TestColor_Fprint(t *testing.T) { + rb := new(strings.Builder) + c := New(FgRed) + + n, err := c.Fprint(rb, "foo", "bar") + if err != nil { + t.Errorf("Fprint error: %v", err) + } + got := rb.String() + want := "\x1b[31mfoobar\x1b[0m" + + if want != got { + t.Errorf("Fprint error\n\nwant: %q\n got: %q", want, got) + } + if n != len(got) { + t.Errorf("Fprint byte count does not match actual bytes written\n\nwant: %d\n got: %d", len(got), n) + } +} + +func TestColor_Fprintln(t *testing.T) { + rb := new(strings.Builder) + c := New(FgRed) + + n, err := c.Fprintln(rb, "foo", "bar") + if err != nil { + t.Errorf("Fprint error: %v", err) + } + got := rb.String() + want := "\x1b[31mfoo bar\x1b[0m\n" + + if want != got { + t.Errorf("Fprintln error\n\nwant: %q\n got: %q", want, got) + } + if n != len(got) { + t.Errorf("Fprintln byte count does not match actual bytes written\n\nwant: %d\n got: %d", len(got), n) + } +} + +func TestColor_Fprintf(t *testing.T) { + rb := new(strings.Builder) + c := New(FgRed) + + n, err := c.Fprintf(rb, "%-7s %-7s %5d\n", "hello", "world", 123) + if err != nil { + t.Errorf("Fprint error: %v", err) + } + + want := "\x1b[31mhello world 123\n\x1b[0m" + + got := rb.String() + if want != got { + t.Errorf("Fprintf error\n\nwant: %q\n got: %q", want, got) + } + if n != len(got) { + t.Errorf("Fprintf byte count does not match actual bytes written\n\nwant: %d\n got: %d", len(got), n) + } +} + func TestColor_Fprintln_Newline(t *testing.T) { rb := new(bytes.Buffer) c := New(FgRed)