-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Implement context-based cancellation in `/lib/utils/prompt`, for MFA prompts. This fixes the following scenario: ```sh User has both OTP and U2F devices registered. $ tsh mfa ls Name Type Added at Last used ----- ---- ----------------------------- ----------------------------- otp TOTP Wed, 21 Apr 2021 19:41:44 UTC Wed, 21 Apr 2021 19:44:32 UTC usb-a U2F Wed, 21 Apr 2021 19:44:34 UTC Wed, 21 Apr 2021 19:44:34 UTC Add a new OTP device, using existing U2F device: $ tsh mfa add Choose device type [TOTP, U2F]: totp Enter device name: otp2 Tap any *registered* security key or enter a code from a *registered* OTP device: <tap> # <- First OTP prompt here Open your TOTP app and create a new manual entry with these fields: Name: awly@localhost:3080 Issuer: Teleport Algorithm: SHA1 Number of digits: 6 Period: 30s Secret: 3UD42X2NN7EEZ6LUPG6NFMNOLDY6AJTS Once created, enter an OTP code generated by the app: 607738 # <- Second OTP prompt here MFA device "otp2" added. ``` Before this PR, the first OTP prompt (for `*registered* device`) would hang in the background. The OTP code from the newly-registered device is prompted later, but any text written ends up going to the first prompt. After this PR, the first prompt is canceled and the code from a new device goes to the second prompt as intended. Note: this is implemented using pure Go code (background goroutine consuming `os.Stdin`) rather than syscalls (e.g. `poll` or `select`) for portability.
- Loading branch information
Andrew Lytvynov
authored
May 3, 2021
1 parent
44d7ab5
commit f90cfc5
Showing
7 changed files
with
254 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
/* | ||
Copyright 2021 Gravitational, Inc. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package prompt | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
"os" | ||
"sync" | ||
) | ||
|
||
var ( | ||
stdinOnce = &sync.Once{} | ||
stdin *ContextReader | ||
) | ||
|
||
// Stdin returns a singleton ContextReader wrapped around os.Stdin. | ||
// | ||
// os.Stdin should not be used directly after the first call to this function | ||
// to avoid losing data. Closing this ContextReader will prevent all future | ||
// reads for all callers. | ||
func Stdin() *ContextReader { | ||
stdinOnce.Do(func() { | ||
stdin = NewContextReader(os.Stdin) | ||
}) | ||
return stdin | ||
} | ||
|
||
// ErrReaderClosed is returned from ContextReader.Read after it was closed. | ||
var ErrReaderClosed = errors.New("ContextReader has been closed") | ||
|
||
// ContextReader is a wrapper around io.Reader where each individual | ||
// ReadContext call can be canceled using a context. | ||
type ContextReader struct { | ||
r io.Reader | ||
data chan []byte | ||
close chan struct{} | ||
|
||
mu sync.RWMutex | ||
err error | ||
} | ||
|
||
// NewContextReader creates a new ContextReader wrapping r. Callers should not | ||
// use r after creating this ContextReader to avoid loss of data (the last read | ||
// will be lost). | ||
// | ||
// Callers are responsible for closing the ContextReader to release associated | ||
// resources. | ||
func NewContextReader(r io.Reader) *ContextReader { | ||
cr := &ContextReader{ | ||
r: r, | ||
data: make(chan []byte), | ||
close: make(chan struct{}), | ||
} | ||
go cr.read() | ||
return cr | ||
} | ||
|
||
func (r *ContextReader) setErr(err error) { | ||
r.mu.Lock() | ||
defer r.mu.Unlock() | ||
if r.err != nil { | ||
// Keep only the first encountered error. | ||
return | ||
} | ||
r.err = err | ||
} | ||
|
||
func (r *ContextReader) getErr() error { | ||
r.mu.RLock() | ||
defer r.mu.RUnlock() | ||
return r.err | ||
} | ||
|
||
func (r *ContextReader) read() { | ||
defer close(r.data) | ||
|
||
for { | ||
// Allocate a new buffer for every read because we need to send it to | ||
// another goroutine. | ||
buf := make([]byte, 4*1024) // 4kB, matches Linux page size. | ||
n, err := r.r.Read(buf) | ||
r.setErr(err) | ||
buf = buf[:n] | ||
if n == 0 { | ||
return | ||
} | ||
select { | ||
case <-r.close: | ||
return | ||
case r.data <- buf: | ||
} | ||
} | ||
} | ||
|
||
// ReadContext returns the next chunk of output from the reader. If ctx is | ||
// canceled before any data is available, ReadContext will return too. If r | ||
// was closed, ReadContext will return immediately with ErrReaderClosed. | ||
func (r *ContextReader) ReadContext(ctx context.Context) ([]byte, error) { | ||
select { | ||
case <-ctx.Done(): | ||
return nil, ctx.Err() | ||
case <-r.close: | ||
// Close was called, unblock immediately. | ||
// r.data might still be blocked if it's blocked on the Read call. | ||
return nil, r.getErr() | ||
case buf, ok := <-r.data: | ||
if !ok { | ||
// r.data was closed, so the read goroutine has finished. | ||
// No more data will be available, return the latest error. | ||
return nil, r.getErr() | ||
} | ||
return buf, nil | ||
} | ||
} | ||
|
||
// Close releases the background resources of r. All ReadContext calls will | ||
// unblock immediately. | ||
func (r *ContextReader) Close() { | ||
select { | ||
case <-r.close: | ||
// Already closed, do nothing. | ||
return | ||
default: | ||
close(r.close) | ||
r.setErr(ErrReaderClosed) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
Copyright 2021 Gravitational, Inc. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package prompt | ||
|
||
import ( | ||
"context" | ||
"io" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestContextReader(t *testing.T) { | ||
pr, pw := io.Pipe() | ||
t.Cleanup(func() { pr.Close() }) | ||
t.Cleanup(func() { pw.Close() }) | ||
|
||
write := func(t *testing.T, s string) { | ||
_, err := pw.Write([]byte(s)) | ||
require.NoError(t, err) | ||
} | ||
ctx := context.Background() | ||
|
||
r := NewContextReader(pr) | ||
|
||
t.Run("simple read", func(t *testing.T) { | ||
go write(t, "hello") | ||
buf, err := r.ReadContext(ctx) | ||
require.NoError(t, err) | ||
require.Equal(t, string(buf), "hello") | ||
}) | ||
|
||
t.Run("cancelled read", func(t *testing.T) { | ||
cancelCtx, cancel := context.WithCancel(ctx) | ||
go cancel() | ||
buf, err := r.ReadContext(cancelCtx) | ||
require.ErrorIs(t, err, context.Canceled) | ||
require.Empty(t, buf) | ||
|
||
go write(t, "after cancel") | ||
buf, err = r.ReadContext(ctx) | ||
require.NoError(t, err) | ||
require.Equal(t, string(buf), "after cancel") | ||
}) | ||
|
||
t.Run("close underlying reader", func(t *testing.T) { | ||
go func() { | ||
write(t, "before close") | ||
pw.CloseWithError(io.EOF) | ||
}() | ||
|
||
// Read the last chunk of data successfully. | ||
buf, err := r.ReadContext(ctx) | ||
require.NoError(t, err) | ||
require.Equal(t, string(buf), "before close") | ||
|
||
// Next read fails because underlying reader is closed. | ||
buf, err = r.ReadContext(ctx) | ||
require.ErrorIs(t, err, io.EOF) | ||
require.Empty(t, buf) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters