From eaabfea6b713a5590399caea25ab9b09a25a5f15 Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Sun, 27 Mar 2022 16:13:36 -0600 Subject: [PATCH 1/2] Fix tsh player issues This commit fixes race conditions in the tsh session player by using a condition variable to detect state changes rather than unsafely polling a variable that is written by a separate goroutine. In addition, fix an off by one error when resuming playback after pausing. The player's position variable has always stored the index of the last succesfully played event, so when we resume playback we should start at position+1 to not re-play the previous event twice. Fixes #11479 --- lib/client/api.go | 2 +- lib/client/player.go | 152 +++++++++++++++++++++++------------- lib/client/player_test.go | 158 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 258 insertions(+), 54 deletions(-) create mode 100644 lib/client/player_test.go diff --git a/lib/client/api.go b/lib/client/api.go index 91b63661cf331..906746601d5f3 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3534,7 +3534,7 @@ func playSession(sessionEvents []events.EventFields, stream []byte) error { ) // playback control goroutine go func() { - defer player.Stop() + defer player.RequestStop() var key [1]byte for { _, err := term.Stdin().Read(key[:]) diff --git a/lib/client/player.go b/lib/client/player.go index 872667aedec6d..93d9cb49797a1 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -25,10 +25,13 @@ import ( "github.com/gravitational/teleport/lib/client/terminal" "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" ) +type tshPlayerState int + const ( - stateStopped = iota + stateStopped tshPlayerState = iota stateStopping statePlaying ) @@ -37,40 +40,41 @@ const ( // and allows to control it type sessionPlayer struct { sync.Mutex + cond *sync.Cond + + state tshPlayerState + position int // position is the index of the last event successfully played back + + clock clockwork.Clock stream []byte sessionEvents []events.EventFields term *terminal.Terminal - state int - position int - // stopC is used to tell the caller that player has finished playing - stopC chan int + stopC chan int + stopOnce sync.Once } func newSessionPlayer(sessionEvents []events.EventFields, stream []byte, term *terminal.Terminal) *sessionPlayer { - return &sessionPlayer{ + p := &sessionPlayer{ + clock: clockwork.NewRealClock(), + position: -1, // position is the last successfully written event stream: stream, sessionEvents: sessionEvents, - stopC: make(chan int), term: term, + stopC: make(chan int), } + p.cond = sync.NewCond(p) + return p } func (p *sessionPlayer) Play() { p.playRange(0, 0) } -func (p *sessionPlayer) Stop() { +func (p *sessionPlayer) Stopped() bool { p.Lock() defer p.Unlock() - if p.stopC != nil { - close(p.stopC) - p.stopC = nil - } -} - -func (p *sessionPlayer) Stopped() bool { return p.state == stateStopped } @@ -78,7 +82,7 @@ func (p *sessionPlayer) Rewind() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position > 0 { @@ -86,11 +90,17 @@ func (p *sessionPlayer) Rewind() { } } +func (p *sessionPlayer) stopRequested() bool { + p.Lock() + defer p.Unlock() + return p.state == stateStopping +} + func (p *sessionPlayer) Forward() { p.Lock() defer p.Unlock() if p.state != stateStopped { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } if p.position < len(p.sessionEvents) { @@ -102,20 +112,44 @@ func (p *sessionPlayer) TogglePause() { p.Lock() defer p.Unlock() if p.state == statePlaying { - p.state = stateStopping + p.setState(stateStopping) p.waitUntil(stateStopped) } else { - p.playRange(p.position, 0) + p.playRange(p.position+1, 0) p.waitUntil(statePlaying) } } -func (p *sessionPlayer) waitUntil(state int) { +// RequestStop makes an asynchronous request for the player to stop playing. +// Playback may not stop before this method returns. +func (p *sessionPlayer) RequestStop() { + p.Lock() + defer p.Unlock() + + switch p.state { + case stateStopped, stateStopping: + // do nothing if stop already in progress + default: + p.setState(stateStopping) + } +} + +// waitUntil waits for the specified state to be reached. +// Callers must hold the lock on p.Mutex before calling. +func (p *sessionPlayer) waitUntil(state tshPlayerState) { for state != p.state { - time.Sleep(time.Millisecond) + p.cond.Wait() } } +// setState sets the current player state and notifies any +// goroutines waiting in waitUntil(). Callers must hold the +// lock on p.Mutex before calling. +func (p *sessionPlayer) setState(state tshPlayerState) { + p.state = state + p.cond.Broadcast() +} + // timestampFrame prints 'event timestamp' in the top right corner of the // terminal after playing every 'print' event func timestampFrame(term *terminal.Terminal, message string) { @@ -146,56 +180,45 @@ func timestampFrame(term *terminal.Terminal, message string) { // that playback starts from there. func (p *sessionPlayer) playRange(from, to int) { if to > len(p.sessionEvents) || from < 0 { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() return } if to == 0 { to = len(p.sessionEvents) } // clear screen between runs: - os.Stdout.Write([]byte("\x1bc")) - // wait: waits between events during playback - prev := time.Duration(0) - wait := func(i int, e events.EventFields) { - ms := time.Duration(e.GetInt("ms")) - // before "from"? play that instantly: - if i >= from { - delay := ms - prev - // make playback smoother: - if delay < 10 { - delay = 0 - } - if delay > 250 && delay < 500 { - delay = 250 - } - if delay > 500 && delay < 1000 { - delay = 500 - } - if delay > 1000 { - delay = 1000 - } - timestampFrame(p.term, e.GetString("time")) - time.Sleep(time.Millisecond * delay) - } - prev = ms - } + // os.Stdout.Write([]byte("\x1bc")) + // playback goroutine: go func() { defer func() { - p.state = stateStopped + p.Lock() + p.setState(stateStopped) + p.Unlock() }() - p.state = statePlaying + + p.Lock() + p.setState(statePlaying) + p.Unlock() + + prev := time.Duration(0) i, offset, bytes := 0, 0, 0 for i = 0; i < to; i++ { - if p.state == stateStopping { + if p.stopRequested() { return } + e := p.sessionEvents[i] switch e.GetString(events.EventType) { // 'print' event (output) case events.SessionPrintEvent: - wait(i, e) + // delay is only necessary once we've caught up to the "from" event + if i >= from { + p.applyDelay(&prev, e) + } offset = e.GetInt("offset") bytes = e.GetInt("bytes") os.Stdout.Write(p.stream[offset : offset+bytes]) @@ -215,7 +238,30 @@ func (p *sessionPlayer) playRange(from, to int) { } // played last event? if i == len(p.sessionEvents) { - p.Stop() + // we defer here so the stop notification happens after the deferred final state update + defer p.stopOnce.Do(func() { close(p.stopC) }) } }() } + +// applyDelay waits until it is time to play back the current event (e). +func (p *sessionPlayer) applyDelay(previousTimestamp *time.Duration, e events.EventFields) { + eventTime := time.Duration(e.GetInt("ms") * int(time.Millisecond)) + delay := eventTime - *previousTimestamp + + // make playback smoother: + switch { + case delay < 10*time.Millisecond: + delay = 0 + case delay > 250*time.Millisecond && delay < 500*time.Millisecond: + delay = 250 * time.Millisecond + case delay > 500*time.Millisecond && delay < 1*time.Second: + delay = 500 * time.Millisecond + case delay > time.Second: + delay = time.Second + } + + timestampFrame(p.term, e.GetString("time")) + p.clock.Sleep(delay) + *previousTimestamp = eventTime +} diff --git a/lib/client/player_test.go b/lib/client/player_test.go new file mode 100644 index 0000000000000..8abfb99afa184 --- /dev/null +++ b/lib/client/player_test.go @@ -0,0 +1,158 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "bytes" + "testing" + "time" + + "github.com/gravitational/teleport/lib/client/terminal" + "github.com/gravitational/teleport/lib/events" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +// TestEmptyPlay verifies that a playback of 0 events +// immediately transitions to a stopped state. +func TestEmptyPlay(t *testing.T) { + c := clockwork.NewFakeClock() + p := newSessionPlayer(nil, nil, testTerm(t)) + p.clock = c + + p.Play() + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + + require.True(t, p.Stopped()) +} + +// TestStop verifies that we can stop playback. +func TestStop(t *testing.T) { + c := clockwork.NewFakeClock() + events := printEvents(100, 200) + p := newSessionPlayer(events, nil, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + p.RequestStop() + + // advance the clock: + // at this point, the player will write the first event and then + // see that we requested a stop + c.Advance(100 * time.Millisecond) + + require.Eventually(t, p.Stopped, 2*time.Second, 200*time.Millisecond) +} + +// TestPlayPause verifies the play/pause functionality. +func TestPlayPause(t *testing.T) { + c := clockwork.NewFakeClock() + + // in this test, we let the player play 2 of the 3 events, + // then pause it and verify the pause state before resuming + // playback for the final event. + events := printEvents(100, 200, 300) + var stream []byte // intentionally empty, we dont care about stream contents here + p := newSessionPlayer(events, stream, testTerm(t)) + p.clock = c + + p.Play() + + // wait for player to see the first event and apply the delay + c.BlockUntil(1) + + // advance the clock: + // at this point, the player will write the first event + c.Advance(100 * time.Millisecond) + + // wait for the player to sleep on the 2nd event + c.BlockUntil(1) + + // pause playback + // note: we don't use p.TogglePause here, as it waits for the state transition, + // and the state won't transition proceed until we advance the clock + p.Lock() + p.setState(stateStopping) + p.Unlock() + + // advance the clock again: + // the player will write the second event and + // then realize that it's been asked to pause + c.Advance(100 * time.Millisecond) + + p.Lock() + p.waitUntil(stateStopped) + p.Unlock() + + ch := make(chan struct{}) + go func() { + // resume playback + p.TogglePause() + ch <- struct{}{} + }() + + // playback should resume for the 3rd and final event: + // in this case, the first two events are written immediately without delay, + // and we block here until the player is sleeping prior to the 3rd event + c.BlockUntil(1) + + // make sure that we've resumed + <-ch + require.False(t, p.Stopped()) + + // advance the clock a final time, forcing the player to write the last event + // note: on the resume, we play the successful events immediately, and then sleep + // up to the resume point, which is why we advance by 300ms here + c.Advance(300 * time.Millisecond) + + select { + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for player to complete") + case <-p.stopC: + } + require.True(t, p.Stopped()) +} + +func testTerm(t *testing.T) *terminal.Terminal { + t.Helper() + term, err := terminal.New(bytes.NewReader(nil), &bytes.Buffer{}, &bytes.Buffer{}) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, term.Close()) + }) + return term +} + +func printEvents(delays ...int) []events.EventFields { + result := make([]events.EventFields, len(delays)) + for i := range result { + result[i] = events.EventFields{ + events.EventType: events.SessionPrintEvent, + "ms": delays[i], + } + } + return result +} From 79161fdfbeef222afe5776d496adc7c8ad6bead0 Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Mon, 28 Mar 2022 13:37:06 -0600 Subject: [PATCH 2/2] Address review feedback --- lib/client/player.go | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/lib/client/player.go b/lib/client/player.go index 93d9cb49797a1..3a9b042b72014 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -189,14 +189,21 @@ func (p *sessionPlayer) playRange(from, to int) { to = len(p.sessionEvents) } // clear screen between runs: - // os.Stdout.Write([]byte("\x1bc")) + os.Stdout.Write([]byte("\x1bc")) // playback goroutine: go func() { + var i int + defer func() { p.Lock() p.setState(stateStopped) p.Unlock() + + // played last event? + if i == len(p.sessionEvents) { + p.stopOnce.Do(func() { close(p.stopC) }) + } }() p.Lock() @@ -204,7 +211,7 @@ func (p *sessionPlayer) playRange(from, to int) { p.Unlock() prev := time.Duration(0) - i, offset, bytes := 0, 0, 0 + offset, bytes := 0, 0 for i = 0; i < to; i++ { if p.stopRequested() { return @@ -217,7 +224,7 @@ func (p *sessionPlayer) playRange(from, to int) { case events.SessionPrintEvent: // delay is only necessary once we've caught up to the "from" event if i >= from { - p.applyDelay(&prev, e) + prev = p.applyDelay(prev, e) } offset = e.GetInt("offset") bytes = e.GetInt("bytes") @@ -234,20 +241,18 @@ func (p *sessionPlayer) playRange(from, to int) { default: continue } + p.Lock() p.position = i - } - // played last event? - if i == len(p.sessionEvents) { - // we defer here so the stop notification happens after the deferred final state update - defer p.stopOnce.Do(func() { close(p.stopC) }) + p.Unlock() } }() } -// applyDelay waits until it is time to play back the current event (e). -func (p *sessionPlayer) applyDelay(previousTimestamp *time.Duration, e events.EventFields) { +// applyDelay waits until it is time to play back the current event. +// It returns the duration from the start of the session up until the current event. +func (p *sessionPlayer) applyDelay(previousTimestamp time.Duration, e events.EventFields) time.Duration { eventTime := time.Duration(e.GetInt("ms") * int(time.Millisecond)) - delay := eventTime - *previousTimestamp + delay := eventTime - previousTimestamp // make playback smoother: switch { @@ -263,5 +268,5 @@ func (p *sessionPlayer) applyDelay(previousTimestamp *time.Duration, e events.Ev timestampFrame(p.term, e.GetString("time")) p.clock.Sleep(delay) - *previousTimestamp = eventTime + return eventTime }