Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 59 additions & 18 deletions tool/tsh/recording_export.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ package main
import (
"bytes"
"context"
"errors"
"fmt"
"image"
"image/draw"
"image/jpeg"
"image/png"
"strings"

"github.com/gravitational/trace"
"github.com/icza/mjpeg"
Expand Down Expand Up @@ -54,43 +56,56 @@ func onExportRecording(cf *CLIConf) error {
authClient := proxyClient.CurrentCluster()
defer authClient.Close()

fname := cf.OutFile
if fname == "" {
fname = cf.SessionID + ".avi"
filenamePrefix := cf.SessionID
if cf.OutFile != "" {
// trim the extension if it was provided (we'll add it later on)
filenamePrefix = strings.TrimSuffix(
strings.TrimSuffix(cf.OutFile, ".avi"), ".AVI")
}

frames, err := writeMovie(cf.Context, authClient, session.ID(cf.SessionID), fname)
// there may be a partial file, even if we encountered an error,
// so we indicate to the user when we wrote something
if frames > 0 {
fmt.Printf("wrote recording to %v\n", fname)
_, err = writeMovie(cf.Context, authClient, session.ID(cf.SessionID), filenamePrefix, fmt.Printf)
return trace.Wrap(err)
}

func makeAVIFileName(prefix string, currentFile int) string {
if currentFile == 0 {
return prefix + ".avi"
}

return trace.Wrap(err)
return fmt.Sprintf("%v-%d.avi", prefix, currentFile)
}

// writeMovie writes streams the events for the specified session into a movie file
// identified by fname. It returns the number of frames that were written and an error.
func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, fname string) (int, error) {
// writeMovie writes the events for the specified session into one or more movie files
// beginning with the specified prefix. It returns the number of frames that were written and an error.
func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, prefix string,
write func(format string, args ...any) (int, error)) (frames int, err error) {

var screen *image.NRGBA
var movie mjpeg.AviWriter

lastEmitted := int64(-1)
buf := new(bytes.Buffer)
frameCount := 0

var frameCount, fileCount int
var width, height int32
currentFilename := makeAVIFileName(prefix, fileCount)

evts, errs := ss.StreamSessionEvents(ctx, sid, 0)
loop:
for {
select {
case err := <-errs:
if movie != nil {
movie.Close()
if err := movie.Close(); err == nil && write != nil && frames > 0 {
write("wrote %v\n", currentFilename)
}
}
return frameCount, trace.Wrap(err)
case <-ctx.Done():
if movie != nil {
movie.Close()
if err := movie.Close(); err == nil && write != nil && frames > 0 {
write("wrote %v\n", currentFilename)
}
}
return frameCount, ctx.Err()
case evt, more := <-evts:
Expand Down Expand Up @@ -123,12 +138,13 @@ loop:
// the window during a session. If this changes, we'd have to
// find the maximum window size first.
log.Debugf("allocating %dx%d screen", msg.Width, msg.Height)
width, height = int32(msg.Width), int32(msg.Height)
screen = image.NewNRGBA(image.Rectangle{
Min: image.Pt(0, 0),
Max: image.Pt(int(msg.Width), int(msg.Height)),
})

movie, err = mjpeg.New(fname, int32(msg.Width), int32(msg.Height), framesPerSecond)
movie, err = mjpeg.New(currentFilename, width, height, framesPerSecond)
if err != nil {
return frameCount, trace.Wrap(err)
}
Expand Down Expand Up @@ -170,7 +186,27 @@ loop:
return frameCount, trace.Wrap(err)
}
for i := 0; i < int(framesToEmit); i++ {
if err := movie.AddFrame(buf.Bytes()); err != nil {
err := movie.AddFrame(buf.Bytes())
if errors.Is(err, mjpeg.ErrTooLarge) {
// this file can't get any larger - time to open a new file
if err := movie.Close(); err != nil {
return frameCount, trace.WrapWithMessage(err, "failed to write partial recording")
}
if write != nil {
write("wrote %v\n", currentFilename)
}
fileCount++
currentFilename = makeAVIFileName(prefix, fileCount)
movie, err = mjpeg.New(currentFilename, width, height, framesPerSecond)
if err != nil {
return frameCount, trace.Wrap(err)
}

// write the frame to the new file
if err := movie.AddFrame(buf.Bytes()); err != nil {
return frameCount, trace.Wrap(err)
}
} else if err != nil {
return frameCount, trace.Wrap(err)
}
frameCount++
Expand All @@ -190,7 +226,12 @@ loop:
return 0, trace.BadParameter("operation canceled")
}

return frameCount, trace.Wrap(movie.Close())
err = movie.Close()
if err == nil && write != nil {
write("wrote %v\n", currentFilename)
}

return frameCount, trace.Wrap(err)
}

func imgFromPNGMessage(msg tdp.Message) (image.Image, error) {
Expand Down
10 changes: 5 additions & 5 deletions tool/tsh/recording_export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestWriteMovieCanBeCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel()

frames, err := writeMovie(ctx, fs, "test", "test.avi")
frames, err := writeMovie(ctx, fs, "test", "test.avi", nil)
require.Equal(t, context.Canceled, err)
require.Equal(t, 0, frames)
}
Expand All @@ -68,7 +68,7 @@ func TestWriteMovieDoesNotSupportSSH(t *testing.T) {
}
fs := eventstest.NewFakeStreamer(events, 0)

frames, err := writeMovie(context.Background(), fs, "test", "test.avi")
frames, err := writeMovie(context.Background(), fs, "test", "test.avi", nil)
require.True(t, trace.IsBadParameter(err), "expected bad paramater error, got %v", err)
require.Equal(t, 0, frames)
}
Expand All @@ -88,7 +88,7 @@ func TestWriteMovieMultipleScreenSpecs(t *testing.T) {

fs := eventstest.NewFakeStreamer(events, 0)
t.Cleanup(func() { os.RemoveAll("test.avi") })
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi")
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi", nil)
require.True(t, trace.IsBadParameter(err), "expected bad paramater error, got %v", err)
require.Equal(t, 0, frames)
}
Expand All @@ -105,7 +105,7 @@ func TestWriteMovieWritesOneFrame(t *testing.T) {
}
fs := eventstest.NewFakeStreamer(events, 0)
t.Cleanup(func() { os.RemoveAll("test.avi") })
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi")
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi", nil)
require.NoError(t, err)
require.Equal(t, 1, frames)
}
Expand All @@ -122,7 +122,7 @@ func TestWriteMovieWritesManyFrames(t *testing.T) {
}
fs := eventstest.NewFakeStreamer(events, 0)
t.Cleanup(func() { os.RemoveAll("test.avi") })
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi")
frames, err := writeMovie(context.Background(), fs, session.ID("test"), "test.avi", nil)
require.NoError(t, err)
require.Equal(t, framesPerSecond, frames)
}
Expand Down