diff --git a/lib/autoupdate/agent/installer.go b/lib/autoupdate/agent/installer.go index 6de16e17cb4c6..00d2174c31061 100644 --- a/lib/autoupdate/agent/installer.go +++ b/lib/autoupdate/agent/installer.go @@ -307,6 +307,7 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, if err != nil { return nil, trace.Wrap(err) } + startTime := time.Now() resp, err := li.HTTP.Do(req) if err != nil { return nil, trace.Wrap(err) @@ -330,13 +331,23 @@ func (li *LocalInstaller) download(ctx context.Context, w io.Writer, max int64, } // Calculate checksum concurrently with download. shaReader := sha256.New() - n, err := io.CopyN(w, io.TeeReader(resp.Body, shaReader), size) + tee := io.TeeReader(resp.Body, shaReader) + tee = io.TeeReader(tee, &progressLogger{ + ctx: ctx, + log: li.Log, + level: slog.LevelInfo, + name: path.Base(resp.Request.URL.Path), + max: int(resp.ContentLength), + lines: 5, + }) + n, err := io.CopyN(w, tee, size) if err != nil { return nil, trace.Wrap(err) } if resp.ContentLength >= 0 && n != resp.ContentLength { return nil, trace.Errorf("mismatch in Teleport download size") } + li.Log.InfoContext(ctx, "Download complete.", "duration", time.Since(startTime), "size", n) return shaReader.Sum(nil), nil } diff --git a/lib/autoupdate/agent/logger.go b/lib/autoupdate/agent/logger.go new file mode 100644 index 0000000000000..41333d96729cd --- /dev/null +++ b/lib/autoupdate/agent/logger.go @@ -0,0 +1,106 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "fmt" + "log/slog" + + "github.com/gravitational/trace" +) + +// progressLogger logs progress of any data written as it approaches max. +// max(lines, call_count(Write)) lines are written for each multiple of max. +// progressLogger uses the variability of chunk size as a proxy for speed, and avoids +// logging extraneous lines that do not improve UX for waiting humans. +type progressLogger struct { + ctx context.Context + log *slog.Logger + level slog.Level + name string + max int + lines int + + l int + n int +} + +func (w *progressLogger) Write(p []byte) (n int, err error) { + w.n += len(p) + if w.n >= w.max*(w.l+1)/w.lines { + msg := fmt.Sprintf("%s - progress: %d%%", w.name, w.n*100/w.max) + w.log.Log(w.ctx, w.level, msg) //nolint:sloglint // msg cannot be constant + w.l++ + } + return len(p), nil +} + +// lineLogger logs each line written to it. +type lineLogger struct { + ctx context.Context + log *slog.Logger + level slog.Level + prefix string + + last bytes.Buffer +} + +func (w *lineLogger) out(s string) { + w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant +} + +func (w *lineLogger) Write(p []byte) (n int, err error) { + lines := bytes.Split(p, []byte("\n")) + // Finish writing line + if len(lines) > 0 { + n, err = w.last.Write(lines[0]) + lines = lines[1:] + } + // Quit if no newline + if len(lines) == 0 || err != nil { + return n, trace.Wrap(err) + } + + // Newline found, log line + w.out(w.last.String()) + n += 1 + w.last.Reset() + + // Log lines that are already newline-terminated + for _, line := range lines[:len(lines)-1] { + w.out(string(line)) + n += len(line) + 1 + } + + // Store remaining line non-newline-terminated line. + n2, err := w.last.Write(lines[len(lines)-1]) + n += n2 + return n, trace.Wrap(err) +} + +// Flush logs any trailing bytes that were never terminated with a newline. +func (w *lineLogger) Flush() { + if w.last.Len() == 0 { + return + } + w.out(w.last.String()) + w.last.Reset() +} diff --git a/lib/autoupdate/agent/logger_test.go b/lib/autoupdate/agent/logger_test.go new file mode 100644 index 0000000000000..9eec9348c7b60 --- /dev/null +++ b/lib/autoupdate/agent/logger_test.go @@ -0,0 +1,186 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package agent + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLineLogger(t *testing.T) { + t.Parallel() + + out := &bytes.Buffer{} + ll := lineLogger{ + ctx: context.Background(), + log: slog.New(slog.NewTextHandler(out, + &slog.HandlerOptions{ReplaceAttr: msgOnly}, + )), + } + + for _, e := range []struct { + v string + n int + }{ + {v: "", n: 0}, + {v: "a", n: 1}, + {v: "b\n", n: 2}, + {v: "c\nd", n: 3}, + {v: "e\nf\ng", n: 5}, + {v: "h", n: 1}, + {v: "", n: 0}, + {v: "\n", n: 1}, + {v: "i\n", n: 2}, + {v: "j", n: 1}, + } { + n, err := ll.Write([]byte(e.v)) + require.NoError(t, err) + require.Equal(t, e.n, n) + } + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String()) + ll.Flush() + require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String()) +} + +func msgOnly(_ []string, a slog.Attr) slog.Attr { + switch a.Key { + case "time", "level": + return slog.Attr{} + } + return slog.Attr{Key: a.Key, Value: a.Value} +} + +func TestProgressLogger(t *testing.T) { + t.Parallel() + + type write struct { + n int + out string + } + for _, tt := range []struct { + name string + max, lines int + writes []write + }{ + { + name: "even", + max: 100, + lines: 5, + writes: []write{ + {n: 10}, + {n: 10, out: "20%"}, + {n: 10}, + {n: 10, out: "40%"}, + {n: 10}, + {n: 10, out: "60%"}, + {n: 10}, + {n: 10, out: "80%"}, + {n: 10}, + {n: 10, out: "100%"}, + {n: 10}, + {n: 10, out: "120%"}, + }, + }, + { + name: "fast", + max: 100, + lines: 5, + writes: []write{ + {n: 100, out: "100%"}, + {n: 100, out: "200%"}, + }, + }, + { + name: "over fast", + max: 100, + lines: 5, + writes: []write{ + {n: 200, out: "200%"}, + }, + }, + { + name: "slow down when uneven", + max: 100, + lines: 5, + writes: []write{ + {n: 50, out: "50%"}, + {n: 10, out: "60%"}, + {n: 10, out: "70%"}, + {n: 10, out: "80%"}, + {n: 10}, + {n: 10, out: "100%"}, + {n: 10}, + {n: 10, out: "120%"}, + }, + }, + { + name: "slow down when very uneven", + max: 100, + lines: 5, + writes: []write{ + {n: 50, out: "50%"}, + {n: 1, out: "51%"}, + {n: 1}, + {n: 20, out: "72%"}, + {n: 10, out: "82%"}, + {n: 10}, + {n: 10, out: "102%"}, + }, + }, + { + name: "close", + max: 1000, + lines: 5, + writes: []write{ + {n: 999, out: "99%"}, + {n: 1, out: "100%"}, + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + out := &bytes.Buffer{} + ll := progressLogger{ + ctx: context.Background(), + log: slog.New(slog.NewTextHandler(out, + &slog.HandlerOptions{ReplaceAttr: msgOnly}, + )), + name: "test", + max: tt.max, + lines: tt.lines, + } + for _, e := range tt.writes { + n, err := ll.Write(make([]byte, e.n)) + require.NoError(t, err) + require.Equal(t, e.n, n) + v, err := io.ReadAll(out) + require.NoError(t, err) + if len(v) > 0 { + e.out = fmt.Sprintf(`msg="test - progress: %s"`+"\n", e.out) + } + require.Equal(t, e.out, string(v)) + } + }) + } +} diff --git a/lib/autoupdate/agent/process.go b/lib/autoupdate/agent/process.go index d2b2b081ad7fe..1117b76a4fc0e 100644 --- a/lib/autoupdate/agent/process.go +++ b/lib/autoupdate/agent/process.go @@ -377,55 +377,3 @@ func (c *localExec) Run(ctx context.Context, name string, args ...string) (int, } return code, trace.Wrap(err) } - -// lineLogger logs each line written to it. -type lineLogger struct { - ctx context.Context - log *slog.Logger - level slog.Level - prefix string - - last bytes.Buffer -} - -func (w *lineLogger) out(s string) { - w.log.Log(w.ctx, w.level, w.prefix+s) //nolint:sloglint // msg cannot be constant -} - -func (w *lineLogger) Write(p []byte) (n int, err error) { - lines := bytes.Split(p, []byte("\n")) - // Finish writing line - if len(lines) > 0 { - n, err = w.last.Write(lines[0]) - lines = lines[1:] - } - // Quit if no newline - if len(lines) == 0 || err != nil { - return n, trace.Wrap(err) - } - - // Newline found, log line - w.out(w.last.String()) - n += 1 - w.last.Reset() - - // Log lines that are already newline-terminated - for _, line := range lines[:len(lines)-1] { - w.out(string(line)) - n += len(line) + 1 - } - - // Store remaining line non-newline-terminated line. - n2, err := w.last.Write(lines[len(lines)-1]) - n += n2 - return n, trace.Wrap(err) -} - -// Flush logs any trailing bytes that were never terminated with a newline. -func (w *lineLogger) Flush() { - if w.last.Len() == 0 { - return - } - w.out(w.last.String()) - w.last.Reset() -} diff --git a/lib/autoupdate/agent/process_test.go b/lib/autoupdate/agent/process_test.go index c558a7539831a..e3d3d00ff085b 100644 --- a/lib/autoupdate/agent/process_test.go +++ b/lib/autoupdate/agent/process_test.go @@ -19,7 +19,6 @@ package agent import ( - "bytes" "context" "errors" "fmt" @@ -32,49 +31,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestLineLogger(t *testing.T) { - t.Parallel() - - out := &bytes.Buffer{} - ll := lineLogger{ - ctx: context.Background(), - log: slog.New(slog.NewTextHandler(out, - &slog.HandlerOptions{ReplaceAttr: msgOnly}, - )), - } - - for _, e := range []struct { - v string - n int - }{ - {v: "", n: 0}, - {v: "a", n: 1}, - {v: "b\n", n: 2}, - {v: "c\nd", n: 3}, - {v: "e\nf\ng", n: 5}, - {v: "h", n: 1}, - {v: "", n: 0}, - {v: "\n", n: 1}, - {v: "i\n", n: 2}, - {v: "j", n: 1}, - } { - n, err := ll.Write([]byte(e.v)) - require.NoError(t, err) - require.Equal(t, e.n, n) - } - require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\n", out.String()) - ll.Flush() - require.Equal(t, "msg=ab\nmsg=c\nmsg=de\nmsg=f\nmsg=gh\nmsg=i\nmsg=j\n", out.String()) -} - -func msgOnly(_ []string, a slog.Attr) slog.Attr { - switch a.Key { - case "time", "level": - return slog.Attr{} - } - return slog.Attr{Key: a.Key, Value: a.Value} -} - func TestWaitForStablePID(t *testing.T) { t.Parallel()