Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions integration/proxy/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,24 @@ func newMockTSHDEventsServiceServer(t *testing.T, tc *libclient.TeleportClient,

grpcServer := grpc.NewServer()
api.RegisterTshdEventsServiceServer(grpcServer, tshdEventsService)
t.Cleanup(grpcServer.GracefulStop)

serveErr := make(chan error)
go func() {
err := grpcServer.Serve(ls)
assert.NoError(t, err)
Comment thread
Joerger marked this conversation as resolved.
Outdated
serveErr <- grpcServer.Serve(ls)
}()

t.Cleanup(func() {
grpcServer.GracefulStop()

// For test cases that did not send any grpc calls, test may finish
// before grpcServer.Serve is called and grpcServer.Serve will return
// grpc.ErrServerStopped.
err := <-serveErr
if err != grpc.ErrServerStopped {
assert.NoError(t, err)
}
})

return tshdEventsService, ls.Addr().String()
}

Expand Down
17 changes: 14 additions & 3 deletions integration/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -704,13 +704,24 @@ func newMockTSHDEventsServiceServer(t *testing.T) (service *mockTSHDEventsServic

grpcServer := grpc.NewServer()
api.RegisterTshdEventsServiceServer(grpcServer, tshdEventsService)
t.Cleanup(grpcServer.GracefulStop)

serveErr := make(chan error)
go func() {
err := grpcServer.Serve(ls)
assert.NoError(t, err)
serveErr <- grpcServer.Serve(ls)
}()

t.Cleanup(func() {
grpcServer.GracefulStop()

// For test cases that did not send any grpc calls, test may finish
// before grpcServer.Serve is called and grpcServer.Serve will return
// grpc.ErrServerStopped.
err := <-serveErr
if err != grpc.ErrServerStopped {
assert.NoError(t, err)
}
})

return tshdEventsService, ls.Addr().String()
}

Expand Down
35 changes: 19 additions & 16 deletions lib/teleterm/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ import (
"net"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -453,7 +455,7 @@ func TestRetryWithRelogin(t *testing.T) {
}
require.Equal(t, tt.wantFnCalls, fnCallCount,
"Unexpected number of calls to fn")
require.Equal(t, tt.wantReloginCalls, service.callCounts["Relogin"],
require.EqualValues(t, tt.wantReloginCalls, service.reloginCount.Load(),
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: require.Equal does not allow comparing int to uint32.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI you can cast the values too, but this is fine as well.

"Unexpected number of calls to service.Relogin")
})
}
Expand Down Expand Up @@ -505,7 +507,7 @@ func TestImportantModalSemaphore(t *testing.T) {
t.Error("relogin completed successfully without acquiring the important modal semaphore")
case <-sphaErrC:
t.Error("sendPendingHeadlessAuthentication completed successfully without acquiring the important modal semaphore")
case <-time.After(5 * customWaitDuration):
case <-time.After(100 * time.Millisecond):
}

// if the request's ctx is canceled, they will unblock and return an error instead.
Expand All @@ -532,7 +534,7 @@ func TestImportantModalSemaphore(t *testing.T) {
case err := <-sphaErrC:
require.NoError(t, err)
otherC = reloginErrC
case <-time.After(2 * customWaitDuration):
case <-time.After(time.Second):
t.Error("important modal operations failed to acquire unclaimed semaphore")
}

Expand All @@ -543,30 +545,30 @@ func TestImportantModalSemaphore(t *testing.T) {
select {
case err := <-otherC:
require.NoError(t, err)
case <-time.After(2 * customWaitDuration):
case <-time.After(time.Second):
t.Error("important modal operations failed to acquire unclaimed semaphore")
}

if time.Since(releaseTime) < 2*customWaitDuration {
t.Error("important modal semaphore should not be acquired before waiting the specified duration")
}

require.Equal(t, 1, service.callCounts["Relogin"], "Unexpected number of calls to service.Relogin")
require.Equal(t, 1, service.callCounts["SendPendingHeadlessAuthentication"], "Unexpected number of calls to service.SendPendingHeadlessAuthentication")
require.EqualValues(t, 1, service.reloginCount.Load(), "Unexpected number of calls to service.Relogin")
require.EqualValues(t, 1, service.sendPendingHeadlessAuthenticationCount.Load(), "Unexpected number of calls to service.SendPendingHeadlessAuthentication")
}

type mockTSHDEventsService struct {
*api.UnimplementedTshdEventsServiceServer
callCounts map[string]int
reloginErr error
reloginErr error
reloginCount atomic.Uint32
sendNotificationCount atomic.Uint32
sendPendingHeadlessAuthenticationCount atomic.Uint32
}

func newMockTSHDEventsServiceServer(t *testing.T) (service *mockTSHDEventsService, addr string) {
t.Helper()

tshdEventsService := &mockTSHDEventsService{
callCounts: make(map[string]int),
}
tshdEventsService := &mockTSHDEventsService{}

ls, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
Expand All @@ -578,35 +580,36 @@ func newMockTSHDEventsServiceServer(t *testing.T) (service *mockTSHDEventsServic
go func() {
serveErr <- grpcServer.Serve(ls)
Comment thread
Joerger marked this conversation as resolved.
Outdated
}()

t.Cleanup(func() {
grpcServer.GracefulStop()

// For test cases that did not send any grpc calls, test may finish
// before grpcServer.Serve is called and grpcServer.Serve will return
// grpc.ErrServerStopped.
err := <-serveErr
if len(tshdEventsService.callCounts) > 0 || err != grpc.ErrServerStopped {
require.NoError(t, err)
if err != grpc.ErrServerStopped {
assert.NoError(t, err)
}
})

return tshdEventsService, ls.Addr().String()
}

func (c *mockTSHDEventsService) Relogin(context.Context, *api.ReloginRequest) (*api.ReloginResponse, error) {
c.callCounts["Relogin"]++
c.reloginCount.Add(1)
if c.reloginErr != nil {
return nil, c.reloginErr
}
return &api.ReloginResponse{}, nil
}

func (c *mockTSHDEventsService) SendNotification(context.Context, *api.SendNotificationRequest) (*api.SendNotificationResponse, error) {
c.callCounts["SendNotification"]++
c.sendNotificationCount.Add(1)
return &api.SendNotificationResponse{}, nil
}

func (c *mockTSHDEventsService) SendPendingHeadlessAuthentication(context.Context, *api.SendPendingHeadlessAuthenticationRequest) (*api.SendPendingHeadlessAuthenticationResponse, error) {
c.callCounts["SendPendingHeadlessAuthentication"]++
c.sendPendingHeadlessAuthenticationCount.Add(1)
return &api.SendPendingHeadlessAuthenticationResponse{}, nil
}