diff --git a/api/client/client_test.go b/api/client/client_test.go index cb4a0a379151d..0c490fa7be5cf 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -18,21 +18,19 @@ package client import ( "context" - "crypto/tls" + "flag" "fmt" "net" + "os" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/emptypb" "github.com/gravitational/teleport/api/client/proto" @@ -40,86 +38,150 @@ import ( "github.com/gravitational/teleport/api/types" ) -// mockServer mocks an Auth Server. -type mockServer struct { - addr string - grpc *grpc.Server - *proto.UnimplementedAuthServiceServer -} - -func newMockServer(addr string) *mockServer { - m := &mockServer{ - addr: addr, - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, +func TestMain(m *testing.M) { + flag.Parse() + if testing.Verbose() { + logrus.SetLevel(logrus.DebugLevel) } - proto.RegisterAuthServiceServer(m.grpc, m) - return m + os.Exit(m.Run()) } -func (m *mockServer) Stop() { - m.grpc.Stop() +type pingService struct { + *proto.UnimplementedAuthServiceServer } -func (m *mockServer) Addr() string { - return m.addr +func (s *pingService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{}, nil } -type ConfigOpt func(*Config) - -func WithConfig(cfg Config) ConfigOpt { - return func(config *Config) { - *config = cfg - } -} +func TestNew(t *testing.T) { + t.Parallel() + ctx := context.Background() + srv := startMockServer(t, &pingService{}) -func (m *mockServer) NewClient(ctx context.Context, opts ...ConfigOpt) (*Client, error) { - cfg := Config{ - Addrs: []string{m.addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials + tests := []struct { + desc string + modifyConfig func(*Config) + assertErr require.ErrorAssertionFunc + }{{ + desc: "successfully dial tcp address.", + modifyConfig: func(c *Config) { /* noop */ }, + assertErr: require.NoError, + }, { + desc: "synchronously dial addr/cred pairs and succeed with the 1 good pair.", + modifyConfig: func(c *Config) { + c.Addrs = append(c.Addrs, "bad addr", "bad addr") + c.Credentials = append([]Credentials{&tlsConfigCreds{nil}, &tlsConfigCreds{nil}}, c.Credentials...) }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option + assertErr: require.NoError, + }, { + desc: "fail to dial with a bad address.", + modifyConfig: func(c *Config) { + c.Addrs = []string{"bad addr"} }, - } + assertErr: func(t require.TestingT, err error, _ ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, "all connection methods failed") + }, + }, { + desc: "fail to dial with no address or dialer.", + modifyConfig: func(c *Config) { + c.Addrs = nil + }, + assertErr: func(t require.TestingT, err error, _ ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, "no connection methods found, try providing Dialer or Addrs in config") + }, + }} - for _, opt := range opts { - opt(&cfg) - } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + cfg := srv.clientCfg() + tt.modifyConfig(&cfg) + + clt, err := New(ctx, cfg) + tt.assertErr(t, err) + if err != nil { + return + } - return New(ctx, cfg) + // Requests to the server should succeed. + _, err = clt.Ping(ctx) + assert.NoError(t, err, "Ping failed") + assert.NoError(t, clt.Close(), "Close failed") + }) + } } -// startMockServer starts a new mock server. Parallel tests cannot use the same addr. -func startMockServer(t *testing.T) *mockServer { - l, err := net.Listen("tcp", "") +func TestNewDialBackground(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create a server but don't serve it yet. + l, err := net.Listen("tcp", "localhost:") + require.NoError(t, err) + addr := l.Addr().String() + srv := newMockServer(t, addr, &pingService{}) + + // Create client before the server is listening. + cfg := srv.clientCfg() + cfg.DialInBackground = true + clt, err := New(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) + + // requests to the server will result in a connection error. + cancelCtx, cancel := context.WithTimeout(ctx, time.Second*3) + defer cancel() + _, err = clt.Ping(cancelCtx) + require.Error(t, err) + + // Server the listener and wait for the client connection to be ready. + srv.serve(t, l) + require.NoError(t, clt.waitForConnectionReady(ctx)) + + // requests to the server should succeed. + _, err = clt.Ping(ctx) require.NoError(t, err) - return startMockServerWithListener(t, l) } -// startMockServerWithListener starts a new mock server with the provided listener -func startMockServerWithListener(t *testing.T, l net.Listener) *mockServer { - srv := newMockServer(l.Addr().String()) +func TestWaitForConnectionReady(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create a server but don't serve it yet. + l, err := net.Listen("tcp", "localhost:") + require.NoError(t, err) + addr := l.Addr().String() + srv := newMockServer(t, addr, &proto.UnimplementedAuthServiceServer{}) - errCh := make(chan error, 1) - go func() { - errCh <- srv.grpc.Serve(l) - }() + // Create client before the server is listening. + cfg := srv.clientCfg() + cfg.DialInBackground = true + clt, err := New(ctx, cfg) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, clt.Close()) }) - t.Cleanup(func() { - srv.grpc.Stop() - require.NoError(t, <-errCh) - }) + // WaitForConnectionReady should return an error once the + // context is canceled if the server isn't open to connections. + cancelCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + require.Error(t, clt.waitForConnectionReady(cancelCtx)) - return srv + // WaitForConnectionReady should return nil if the server is open to connections. + srv.serve(t, l) + require.NoError(t, clt.waitForConnectionReady(ctx)) + + // WaitForConnectionReady should return an error if the grpc connection is closed. + require.NoError(t, clt.Close()) + require.Error(t, clt.waitForConnectionReady(ctx)) } -func (m *mockServer) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { - return &proto.PingResponse{}, nil +type listResourcesService struct { + *proto.UnimplementedAuthServiceServer } -func (m *mockServer) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { +func (s *listResourcesService) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { resources, err := testResources(req.ResourceType, req.Namespace) if err != nil { return nil, trail.ToGRPC(err) @@ -204,10 +266,6 @@ func (m *mockServer) ListResources(ctx context.Context, req *proto.ListResources return resp, nil } -func (m *mockServer) AddMFADeviceSync(ctx context.Context, req *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error) { - return nil, status.Error(codes.AlreadyExists, "Already Exists") -} - const fiveMBNode = "fiveMBNode" func testResources(resourceType, namespace string) ([]types.ResourceWithLabels, error) { @@ -342,185 +400,10 @@ func testResources(resourceType, namespace string) ([]types.ResourceWithLabels, return resources, nil } -// mockInsecureCredentials mocks insecure Client credentials. -// it returns a nil tlsConfig which allows the client to run in insecure mode. -// TODO(Joerger) replace insecure credentials with proper TLS credentials. -type mockInsecureTLSCredentials struct{} - -func (mc *mockInsecureTLSCredentials) Dialer(cfg Config) (ContextDialer, error) { - return nil, trace.NotImplemented("no dialer") -} - -func (mc *mockInsecureTLSCredentials) TLSConfig() (*tls.Config, error) { - return nil, nil -} - -func (mc *mockInsecureTLSCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { - return nil, trace.NotImplemented("no ssh config") -} - -func TestNew(t *testing.T) { - t.Parallel() - ctx := context.Background() - srv := startMockServer(t) - - tests := []struct { - desc string - config Config - assertErr require.ErrorAssertionFunc - }{{ - desc: "successfully dial tcp address.", - config: Config{ - Addrs: []string{srv.Addr()}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }, - assertErr: require.NoError, - }, { - desc: "synchronously dial addr/cred pairs and succeed with the 1 good pair.", - config: Config{ - Addrs: []string{"bad addr", srv.Addr(), "bad addr"}, - Credentials: []Credentials{ - &tlsConfigCreds{nil}, - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - &tlsConfigCreds{nil}, - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }, - assertErr: require.NoError, - }, { - desc: "fail to dial with a bad address.", - config: Config{ - DialTimeout: time.Second, - Addrs: []string{"bad addr"}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }, - assertErr: func(t require.TestingT, err error, _ ...interface{}) { - require.Error(t, err) - require.Contains(t, err.Error(), "all connection methods failed") - }, - }, { - desc: "fail to dial with no address or dialer.", - config: Config{ - DialTimeout: time.Second, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }, - assertErr: func(t require.TestingT, err error, _ ...interface{}) { - require.Error(t, err) - require.Contains(t, err.Error(), "no connection methods found, try providing Dialer or Addrs in config") - }, - }} - - for _, tt := range tests { - t.Run(tt.desc, func(t *testing.T) { - clt, err := srv.NewClient(ctx, WithConfig(tt.config)) - tt.assertErr(t, err) - - if err == nil { - t.Cleanup(func() { require.NoError(t, clt.Close()) }) - // requests to the server should succeed. - _, err = clt.Ping(ctx) - require.NoError(t, err) - } - }) - } -} - -func TestNewDialBackground(t *testing.T) { - t.Parallel() - ctx := context.Background() - - // get listener but don't serve it yet. - l, err := net.Listen("tcp", "") - require.NoError(t, err) - addr := l.Addr().String() - - // Create client before the server is listening. - clt, err := New(ctx, Config{ - DialInBackground: true, - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, clt.Close()) }) - - // requests to the server will result in a connection error. - cancelCtx, cancel := context.WithTimeout(ctx, time.Second*3) - defer cancel() - _, err = clt.Ping(cancelCtx) - require.Error(t, err) - - // Start the server and wait for the client connection to be ready. - startMockServerWithListener(t, l) - require.NoError(t, clt.waitForConnectionReady(ctx)) - - // requests to the server should succeed. - _, err = clt.Ping(ctx) - require.NoError(t, err) -} - -func TestWaitForConnectionReady(t *testing.T) { - t.Parallel() - ctx := context.Background() - - l, err := net.Listen("tcp", "") - require.NoError(t, err) - addr := l.Addr().String() - - // Create client before the server is listening. - clt, err := New(ctx, Config{ - DialInBackground: true, - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, clt.Close()) }) - - // WaitForConnectionReady should return false once the - // context is canceled if the server isn't open to connections. - cancelCtx, cancel := context.WithTimeout(ctx, time.Second*3) - defer cancel() - require.Error(t, clt.waitForConnectionReady(cancelCtx)) - - // WaitForConnectionReady should return nil if the server is open to connections. - startMockServerWithListener(t, l) - require.NoError(t, clt.waitForConnectionReady(ctx)) - - // WaitForConnectionReady should return an error if the grpc connection is closed. - require.NoError(t, clt.Close()) - require.Error(t, clt.waitForConnectionReady(ctx)) -} - func TestListResources(t *testing.T) { t.Parallel() ctx := context.Background() - srv := startMockServer(t) + srv := startMockServer(t, &listResourcesService{}) testCases := map[string]struct { resourceType string @@ -549,7 +432,7 @@ func TestListResources(t *testing.T) { } // Create client - clt, err := srv.NewClient(ctx) + clt, err := New(ctx, srv.clientCfg()) require.NoError(t, err) for name, test := range testCases { @@ -588,10 +471,10 @@ func TestListResources(t *testing.T) { func TestGetResources(t *testing.T) { t.Parallel() ctx := context.Background() - srv := startMockServer(t) + srv := startMockServer(t, &listResourcesService{}) // Create client - clt, err := srv.NewClient(ctx) + clt, err := New(ctx, srv.clientCfg()) require.NoError(t, err) testCases := map[string]struct { @@ -640,11 +523,11 @@ func TestGetResources(t *testing.T) { } } -type mockAccessRequestServer struct { - *mockServer +type accessRequestService struct { + *proto.UnimplementedAuthServiceServer } -func (g *mockAccessRequestServer) GetAccessRequests(ctx context.Context, f *types.AccessRequestFilter) (*proto.AccessRequests, error) { +func (s *accessRequestService) GetAccessRequests(ctx context.Context, f *types.AccessRequestFilter) (*proto.AccessRequests, error) { req, err := types.NewAccessRequest("foo", "bob", "admin") if err != nil { return nil, trace.Wrap(err) @@ -658,85 +541,49 @@ func (g *mockAccessRequestServer) GetAccessRequests(ctx context.Context, f *type // TestAccessRequestDowngrade tests that the client will downgrade to the non stream API for fetching access requests // if the stream API is not available. func TestAccessRequestDowngrade(t *testing.T) { + t.Parallel() ctx := context.Background() - l, err := net.Listen("tcp", "") - require.NoError(t, err) - - m := &mockAccessRequestServer{ - &mockServer{ - addr: l.Addr().String(), - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, - }, - } - proto.RegisterAuthServiceServer(m.grpc, m) - t.Cleanup(m.grpc.Stop) - remoteErr := make(chan error) - go func() { - remoteErr <- m.grpc.Serve(l) - }() + server := startMockServer(t, &accessRequestService{}) - clt, err := m.NewClient(ctx) + clt, err := New(ctx, server.clientCfg()) require.NoError(t, err) items, err := clt.GetAccessRequests(ctx, types.AccessRequestFilter{}) require.NoError(t, err) require.Len(t, items, 1) - m.grpc.Stop() - require.NoError(t, <-remoteErr) } -type mockRoleServer struct { - *mockServer +type roleService struct { + *proto.UnimplementedAuthServiceServer roles map[string]*types.RoleV6 } -func newMockRoleServer() *mockRoleServer { - m := &mockRoleServer{ - &mockServer{ - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, - }, - make(map[string]*types.RoleV6), - } - proto.RegisterAuthServiceServer(m.grpc, m) - return m -} - -func startMockRoleServer(t *testing.T) string { - l, err := net.Listen("tcp", "") - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, l.Close()) }) - go newMockRoleServer().grpc.Serve(l) - return l.Addr().String() -} - -func (m *mockRoleServer) GetRole(ctx context.Context, req *proto.GetRoleRequest) (*types.RoleV6, error) { - conn, ok := m.roles[req.Name] +func (s *roleService) GetRole(ctx context.Context, req *proto.GetRoleRequest) (*types.RoleV6, error) { + role, ok := s.roles[req.Name] if !ok { return nil, trace.NotFound("not found") } - return conn, nil + return role, nil } -func (m *mockRoleServer) GetRoles(ctx context.Context, _ *emptypb.Empty) (*proto.GetRolesResponse, error) { - var connectors []*types.RoleV6 - for _, conn := range m.roles { - connectors = append(connectors, conn) +func (s *roleService) GetRoles(ctx context.Context, _ *emptypb.Empty) (*proto.GetRolesResponse, error) { + var roles []*types.RoleV6 + for _, role := range s.roles { + roles = append(roles, role) } return &proto.GetRolesResponse{ - Roles: connectors, + Roles: roles, }, nil } -func (m *mockRoleServer) UpsertRole(ctx context.Context, role *types.RoleV6) (*emptypb.Empty, error) { - m.roles[role.Metadata.Name] = role +func (s *roleService) UpsertRole(ctx context.Context, role *types.RoleV6) (*emptypb.Empty, error) { + s.roles[role.Metadata.Name] = role return &emptypb.Empty{}, nil } -func (m *mockRoleServer) GetCurrentUserRoles(_ *emptypb.Empty, stream proto.AuthService_GetCurrentUserRolesServer) error { - for _, role := range m.roles { +func (s *roleService) GetCurrentUserRoles(_ *emptypb.Empty, stream proto.AuthService_GetCurrentUserRolesServer) error { + for _, role := range s.roles { if err := stream.Send(role); err != nil { return trace.Wrap(err) } @@ -748,19 +595,14 @@ func (m *mockRoleServer) GetCurrentUserRoles(_ *emptypb.Empty, stream proto.Auth // Test that client will perform properly with an old server // DELETE IN 13.0.0 func TestSetRoleRequireSessionMFABackwardsCompatibility(t *testing.T) { + t.Parallel() ctx := context.Background() - addr := startMockRoleServer(t) - // Create client - clt, err := New(ctx, Config{ - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, + server := startMockServer(t, &roleService{ + roles: make(map[string]*types.RoleV6), }) + + clt, err := New(ctx, server.clientCfg()) require.NoError(t, err) role := &types.RoleV6{ @@ -808,58 +650,32 @@ func TestSetRoleRequireSessionMFABackwardsCompatibility(t *testing.T) { }) } -type mockAuthPreferenceServer struct { - *mockServer +type authPreferenceService struct { + *proto.UnimplementedAuthServiceServer pref *types.AuthPreferenceV2 } -func newMockAuthPreferenceServer() *mockAuthPreferenceServer { - m := &mockAuthPreferenceServer{ - mockServer: &mockServer{ - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, - }, - } - proto.RegisterAuthServiceServer(m.grpc, m) - return m -} - -func startMockAuthPreferenceServer(t *testing.T) string { - l, err := net.Listen("tcp", "") - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, l.Close()) }) - go newMockAuthPreferenceServer().grpc.Serve(l) - return l.Addr().String() -} - -func (m *mockAuthPreferenceServer) GetAuthPreference(ctx context.Context, _ *emptypb.Empty) (*types.AuthPreferenceV2, error) { - if m.pref == nil { +func (s *authPreferenceService) GetAuthPreference(ctx context.Context, _ *emptypb.Empty) (*types.AuthPreferenceV2, error) { + if s.pref == nil { return nil, trace.NotFound("not found") } - return m.pref, nil + return s.pref, nil } -func (m *mockAuthPreferenceServer) SetAuthPreference(ctx context.Context, pref *types.AuthPreferenceV2) (*emptypb.Empty, error) { - m.pref = pref +func (s *authPreferenceService) SetAuthPreference(ctx context.Context, pref *types.AuthPreferenceV2) (*emptypb.Empty, error) { + s.pref = pref return &emptypb.Empty{}, nil } // Test that client will perform properly with an old server // DELETE IN 13.0.0 func TestSetAuthPreferenceRequireSessionMFABackwardsCompatibility(t *testing.T) { + t.Parallel() ctx := context.Background() - addr := startMockAuthPreferenceServer(t) - // Create client - clt, err := New(ctx, Config{ - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }) + server := startMockServer(t, &authPreferenceService{}) + + clt, err := New(ctx, server.clientCfg()) require.NoError(t, err) pref := &types.AuthPreferenceV2{ diff --git a/api/client/mock_server_test.go b/api/client/mock_server_test.go new file mode 100644 index 0000000000000..4840961a99258 --- /dev/null +++ b/api/client/mock_server_test.go @@ -0,0 +1,84 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/testhelpers/mtls" +) + +// mockServer mocks an Auth Server. +type mockServer struct { + addr string + grpc *grpc.Server + mtlsConfig *mtls.Config +} + +func newMockServer(t *testing.T, addr string, service proto.AuthServiceServer) *mockServer { + t.Helper() + m := &mockServer{ + addr: addr, + mtlsConfig: mtls.NewConfig(t), + } + + m.grpc = grpc.NewServer( + grpc.Creds(credentials.NewTLS(m.mtlsConfig.ServerTLS)), + ) + + proto.RegisterAuthServiceServer(m.grpc, service) + return m +} + +// startMockServer starts a new mock server. Parallel tests cannot use the same addr. +func startMockServer(t *testing.T, service proto.AuthServiceServer) *mockServer { + l, err := net.Listen("tcp", "localhost:") + require.NoError(t, err) + srv := newMockServer(t, l.Addr().String(), service) + srv.serve(t, l) + return srv +} + +func (m *mockServer) serve(t *testing.T, l net.Listener) { + errCh := make(chan error, 1) + go func() { + errCh <- m.grpc.Serve(l) + }() + + t.Cleanup(func() { + m.grpc.Stop() + require.NoError(t, <-errCh, "mockServer gRPC server exited with unexpected error") + }) +} + +func (m *mockServer) clientCfg() Config { + return Config{ + // Reduce dial timeout for tests. + DialTimeout: time.Second, + Addrs: []string{m.addr}, + Credentials: []Credentials{ + LoadTLS(m.mtlsConfig.ClientTLS), + }, + } +} diff --git a/api/testhelpers/mtls/mtls.go b/api/testhelpers/mtls/mtls.go new file mode 100644 index 0000000000000..07476fbf65897 --- /dev/null +++ b/api/testhelpers/mtls/mtls.go @@ -0,0 +1,139 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mtls + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/keys" +) + +type Config struct { + ServerTLS *tls.Config + ClientTLS *tls.Config +} + +// NewConfig returns an mTLS config. +func NewConfig(t *testing.T) *Config { + t.Helper() + + caKey, caCert := generateCA(t) + serverTLS := generateChildTLSConfigFromCA(t, caKey, caCert) + clientTLS := generateChildTLSConfigFromCA(t, caKey, caCert) + clientTLS.ServerName = constants.APIDomain + + return &Config{ + ServerTLS: serverTLS, + ClientTLS: clientTLS, + } +} + +func generateCA(t *testing.T) (*keys.PrivateKey, *x509.Certificate) { + t.Helper() + + caPub, caPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + caKey, err := keys.NewPrivateKey(caPriv, nil) + require.NoError(t, err) + + // Create a self signed certificate. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Issuer: entity, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, template, template, caPub, caKey) + require.NoError(t, err) + + x509Cert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + return caKey, x509Cert +} + +func generateChildTLSConfigFromCA(t *testing.T, caKey *keys.PrivateKey, caCert *x509.Certificate) *tls.Config { + t.Helper() + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := keys.NewPrivateKey(priv, nil) + require.NoError(t, err) + + // Create a certificate signed by the CA. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + DNSNames: []string{constants.APIDomain}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, pub, caKey) + require.NoError(t, err) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := key.TLSCertificate(certPEM) + require.NoError(t, err) + + pool := x509.NewCertPool() + pool.AddCert(caCert) + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: pool, + } +}