From 0feb1d0fbd8152ddc55ab21e04dcce3474435be7 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 16 Aug 2023 13:06:49 -0700 Subject: [PATCH 1/4] Move mockServer into a separate file. --- api/client/client_test.go | 374 ------------------------------ api/client/mock_server_test.go | 405 +++++++++++++++++++++++++++++++++ 2 files changed, 405 insertions(+), 374 deletions(-) create mode 100644 api/client/mock_server_test.go diff --git a/api/client/client_test.go b/api/client/client_test.go index 7b6623780e5a8..ac61514d892fe 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -18,9 +18,7 @@ package client import ( "context" - "crypto/tls" "flag" - "fmt" "net" "os" "testing" @@ -28,388 +26,16 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" - "github.com/gravitational/trace/trail" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" ) -// mockServer mocks an Auth Server. -type mockServer struct { - addr string - grpc *grpc.Server - *proto.UnimplementedAuthServiceServer -} - -func newMockServer(addr string) *mockServer { - m := &mockServer{ - addr: addr, - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, - } - proto.RegisterAuthServiceServer(m.grpc, m) - return m -} - -func (m *mockServer) Stop() { - m.grpc.Stop() -} - -func (m *mockServer) Addr() string { - return m.addr -} - -type ConfigOpt func(*Config) - -func WithConfig(cfg Config) ConfigOpt { - return func(config *Config) { - *config = cfg - } -} - -func (m *mockServer) NewClient(ctx context.Context, opts ...ConfigOpt) (*Client, error) { - cfg := Config{ - Addrs: []string{m.addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - } - - for _, opt := range opts { - opt(&cfg) - } - - return New(ctx, cfg) -} - -// startMockServer starts a new mock server. Parallel tests cannot use the same addr. -func startMockServer(t *testing.T) *mockServer { - l, err := net.Listen("tcp", "localhost:") - require.NoError(t, err) - return startMockServerWithListener(t, l) -} - -// startMockServerWithListener starts a new mock server with the provided listener -func startMockServerWithListener(t *testing.T, l net.Listener) *mockServer { - srv := newMockServer(l.Addr().String()) - - errCh := make(chan error, 1) - go func() { - errCh <- srv.grpc.Serve(l) - }() - - t.Cleanup(func() { - srv.grpc.Stop() - require.NoError(t, <-errCh) - }) - - return srv -} - -func (m *mockServer) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { - return &proto.PingResponse{}, nil -} - -func (m *mockServer) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { - resources, err := testResources[types.ResourceWithLabels](req.ResourceType, req.Namespace) - if err != nil { - return nil, trail.ToGRPC(err) - } - - resp := &proto.ListResourcesResponse{ - Resources: make([]*proto.PaginatedResource, 0, len(resources)), - TotalCount: int32(len(resources)), - } - - var ( - takeResources = req.StartKey == "" - lastResourceName string - ) - for _, resource := range resources { - if resource.GetName() == req.StartKey { - takeResources = true - continue - } - - if !takeResources { - continue - } - - var protoResource *proto.PaginatedResource - switch req.ResourceType { - case types.KindDatabaseServer: - database, ok := resource.(*types.DatabaseServerV3) - if !ok { - return nil, trace.Errorf("database server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseServer{DatabaseServer: database}} - case types.KindAppServer: - app, ok := resource.(*types.AppServerV3) - if !ok { - return nil, trace.Errorf("application server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServer{AppServer: app}} - case types.KindNode: - srv, ok := resource.(*types.ServerV2) - if !ok { - return nil, trace.Errorf("node has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: srv}} - case types.KindKubeServer: - srv, ok := resource.(*types.KubernetesServerV3) - if !ok { - return nil, trace.Errorf("kubernetes server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_KubernetesServer{KubernetesServer: srv}} - case types.KindWindowsDesktop: - desktop, ok := resource.(*types.WindowsDesktopV3) - if !ok { - return nil, trace.Errorf("windows desktop has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_WindowsDesktop{WindowsDesktop: desktop}} - case types.KindAppOrSAMLIdPServiceProvider: - appServerOrSP, ok := resource.(*types.AppServerOrSAMLIdPServiceProviderV1) - if !ok { - return nil, trace.Errorf("AppServerOrSAMLIdPServiceProvider has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServerOrSAMLIdPServiceProvider{AppServerOrSAMLIdPServiceProvider: appServerOrSP}} - } - resp.Resources = append(resp.Resources, protoResource) - lastResourceName = resource.GetName() - if len(resp.Resources) == int(req.Limit) { - break - } - } - - if len(resp.Resources) != len(resources) { - resp.NextKey = lastResourceName - } - - return resp, nil -} - -func (m *mockServer) AddMFADeviceSync(ctx context.Context, req *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error) { - return nil, status.Error(codes.AlreadyExists, "Already Exists") -} - -const fiveMBNode = "fiveMBNode" - -func testResources[T types.ResourceWithLabels](resourceType, namespace string) ([]T, error) { - size := 50 - // Artificially make each node ~ 100KB to force - // ListResources to fail with chunks of >= 40. - labelSize := 100000 - resources := make([]T, 0, size) - - switch resourceType { - case types.KindDatabaseServer: - for i := 0; i < size; i++ { - resource, err := types.NewDatabaseServerV3(types.Metadata{ - Name: fmt.Sprintf("db-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.DatabaseServerSpecV3{ - Hostname: "localhost", - HostID: fmt.Sprintf("host-%d", i), - Database: &types.DatabaseV3{ - Metadata: types.Metadata{ - Name: fmt.Sprintf("db-%d", i), - }, - Spec: types.DatabaseSpecV3{ - Protocol: types.DatabaseProtocolPostgreSQL, - URI: "localhost", - }, - }, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindAppServer: - for i := 0; i < size; i++ { - app, err := types.NewAppV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - }, types.AppSpecV3{ - URI: "localhost", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resource, err := types.NewAppServerV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.AppServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - App: app, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindNode: - for i := 0; i < size; i++ { - nodeLabelSize := labelSize - if namespace == fiveMBNode && i == 0 { - // Artificially make a node ~ 5MB to force - // ListNodes to fail regardless of chunk size. - nodeLabelSize = 5000000 - } - - var err error - resource, err := types.NewServerWithLabels(fmt.Sprintf("node-%d", i), types.KindNode, types.ServerSpecV2{}, - map[string]string{ - "label": string(make([]byte, nodeLabelSize)), - }, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindKubeServer: - for i := 0; i < size; i++ { - var err error - name := fmt.Sprintf("kube-service-%d", i) - kube, err := types.NewKubernetesClusterV3(types.Metadata{ - Name: name, - Labels: map[string]string{"name": name}, - }, - types.KubernetesClusterSpecV3{}, - ) - if err != nil { - return nil, trace.Wrap(err) - } - resource, err := types.NewKubernetesServerV3( - types.Metadata{ - Name: name, - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, - types.KubernetesServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - Cluster: kube, - }, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindWindowsDesktop: - for i := 0; i < size; i++ { - var err error - name := fmt.Sprintf("windows-desktop-%d", i) - resource, err := types.NewWindowsDesktopV3( - name, - map[string]string{"label": string(make([]byte, labelSize))}, - types.WindowsDesktopSpecV3{ - Addr: "_", - HostID: "_", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindAppOrSAMLIdPServiceProvider: - for i := 0; i < size; i++ { - // Alternate between adding Apps and SAMLIdPServiceProviders. If `i` is even, add an app. - if i%2 == 0 { - app, err := types.NewAppV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - }, types.AppSpecV3{ - URI: "localhost", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - appServer, err := types.NewAppServerV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.AppServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - App: app, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resource := &types.AppServerOrSAMLIdPServiceProviderV1{ - Resource: &types.AppServerOrSAMLIdPServiceProviderV1_AppServer{ - AppServer: appServer, - }, - } - - resources = append(resources, any(resource).(T)) - } else { - sp := &types.SAMLIdPServiceProviderV1{ResourceHeader: types.ResourceHeader{Metadata: types.Metadata{Name: fmt.Sprintf("saml-app-%d", i), Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }}}} - - resource := &types.AppServerOrSAMLIdPServiceProviderV1{ - Resource: &types.AppServerOrSAMLIdPServiceProviderV1_SAMLIdPServiceProvider{ - SAMLIdPServiceProvider: sp, - }, - } - resources = append(resources, any(resource).(T)) - } - } - default: - return nil, trace.Errorf("unsupported resource type %s", resourceType) - } - - return resources, nil -} - -// mockInsecureCredentials mocks insecure Client credentials. -// it returns a nil tlsConfig which allows the client to run in insecure mode. -// TODO(Joerger) replace insecure credentials with proper TLS credentials. -type mockInsecureTLSCredentials struct{} - -func (mc *mockInsecureTLSCredentials) Dialer(cfg Config) (ContextDialer, error) { - return nil, trace.NotImplemented("no dialer") -} - -func (mc *mockInsecureTLSCredentials) TLSConfig() (*tls.Config, error) { - return nil, nil -} - -func (mc *mockInsecureTLSCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { - return nil, trace.NotImplemented("no ssh config") -} - func TestNew(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/api/client/mock_server_test.go b/api/client/mock_server_test.go new file mode 100644 index 0000000000000..8c02ab987c37d --- /dev/null +++ b/api/client/mock_server_test.go @@ -0,0 +1,405 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package client + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/gravitational/trace/trail" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" +) + +// mockServer mocks an Auth Server. +type mockServer struct { + addr string + grpc *grpc.Server + *proto.UnimplementedAuthServiceServer +} + +func newMockServer(addr string) *mockServer { + m := &mockServer{ + addr: addr, + grpc: grpc.NewServer(), + UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, + } + proto.RegisterAuthServiceServer(m.grpc, m) + return m +} + +func (m *mockServer) Stop() { + m.grpc.Stop() +} + +func (m *mockServer) Addr() string { + return m.addr +} + +type ConfigOpt func(*Config) + +func WithConfig(cfg Config) ConfigOpt { + return func(config *Config) { + *config = cfg + } +} + +func (m *mockServer) NewClient(ctx context.Context, opts ...ConfigOpt) (*Client, error) { + cfg := Config{ + Addrs: []string{m.addr}, + Credentials: []Credentials{ + &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials + }, + DialOpts: []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option + }, + } + + for _, opt := range opts { + opt(&cfg) + } + + return New(ctx, cfg) +} + +// startMockServer starts a new mock server. Parallel tests cannot use the same addr. +func startMockServer(t *testing.T) *mockServer { + l, err := net.Listen("tcp", "localhost:") + require.NoError(t, err) + return startMockServerWithListener(t, l) +} + +// startMockServerWithListener starts a new mock server with the provided listener +func startMockServerWithListener(t *testing.T, l net.Listener) *mockServer { + srv := newMockServer(l.Addr().String()) + + errCh := make(chan error, 1) + go func() { + errCh <- srv.grpc.Serve(l) + }() + + t.Cleanup(func() { + srv.grpc.Stop() + require.NoError(t, <-errCh) + }) + + return srv +} + +func (m *mockServer) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{}, nil +} + +func (m *mockServer) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { + resources, err := testResources[types.ResourceWithLabels](req.ResourceType, req.Namespace) + if err != nil { + return nil, trail.ToGRPC(err) + } + + resp := &proto.ListResourcesResponse{ + Resources: make([]*proto.PaginatedResource, 0, len(resources)), + TotalCount: int32(len(resources)), + } + + var ( + takeResources = req.StartKey == "" + lastResourceName string + ) + for _, resource := range resources { + if resource.GetName() == req.StartKey { + takeResources = true + continue + } + + if !takeResources { + continue + } + + var protoResource *proto.PaginatedResource + switch req.ResourceType { + case types.KindDatabaseServer: + database, ok := resource.(*types.DatabaseServerV3) + if !ok { + return nil, trace.Errorf("database server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseServer{DatabaseServer: database}} + case types.KindAppServer: + app, ok := resource.(*types.AppServerV3) + if !ok { + return nil, trace.Errorf("application server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServer{AppServer: app}} + case types.KindNode: + srv, ok := resource.(*types.ServerV2) + if !ok { + return nil, trace.Errorf("node has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: srv}} + case types.KindKubeServer: + srv, ok := resource.(*types.KubernetesServerV3) + if !ok { + return nil, trace.Errorf("kubernetes server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_KubernetesServer{KubernetesServer: srv}} + case types.KindWindowsDesktop: + desktop, ok := resource.(*types.WindowsDesktopV3) + if !ok { + return nil, trace.Errorf("windows desktop has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_WindowsDesktop{WindowsDesktop: desktop}} + case types.KindAppOrSAMLIdPServiceProvider: + appServerOrSP, ok := resource.(*types.AppServerOrSAMLIdPServiceProviderV1) + if !ok { + return nil, trace.Errorf("AppServerOrSAMLIdPServiceProvider has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServerOrSAMLIdPServiceProvider{AppServerOrSAMLIdPServiceProvider: appServerOrSP}} + } + resp.Resources = append(resp.Resources, protoResource) + lastResourceName = resource.GetName() + if len(resp.Resources) == int(req.Limit) { + break + } + } + + if len(resp.Resources) != len(resources) { + resp.NextKey = lastResourceName + } + + return resp, nil +} + +func (m *mockServer) AddMFADeviceSync(ctx context.Context, req *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error) { + return nil, status.Error(codes.AlreadyExists, "Already Exists") +} + +const fiveMBNode = "fiveMBNode" + +func testResources[T types.ResourceWithLabels](resourceType, namespace string) ([]T, error) { + size := 50 + // Artificially make each node ~ 100KB to force + // ListResources to fail with chunks of >= 40. + labelSize := 100000 + resources := make([]T, 0, size) + + switch resourceType { + case types.KindDatabaseServer: + for i := 0; i < size; i++ { + resource, err := types.NewDatabaseServerV3(types.Metadata{ + Name: fmt.Sprintf("db-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.DatabaseServerSpecV3{ + Hostname: "localhost", + HostID: fmt.Sprintf("host-%d", i), + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: fmt.Sprintf("db-%d", i), + }, + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolPostgreSQL, + URI: "localhost", + }, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindAppServer: + for i := 0; i < size; i++ { + app, err := types.NewAppV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + }, types.AppSpecV3{ + URI: "localhost", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resource, err := types.NewAppServerV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.AppServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + App: app, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindNode: + for i := 0; i < size; i++ { + nodeLabelSize := labelSize + if namespace == fiveMBNode && i == 0 { + // Artificially make a node ~ 5MB to force + // ListNodes to fail regardless of chunk size. + nodeLabelSize = 5000000 + } + + var err error + resource, err := types.NewServerWithLabels(fmt.Sprintf("node-%d", i), types.KindNode, types.ServerSpecV2{}, + map[string]string{ + "label": string(make([]byte, nodeLabelSize)), + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindKubeServer: + for i := 0; i < size; i++ { + var err error + name := fmt.Sprintf("kube-service-%d", i) + kube, err := types.NewKubernetesClusterV3(types.Metadata{ + Name: name, + Labels: map[string]string{"name": name}, + }, + types.KubernetesClusterSpecV3{}, + ) + if err != nil { + return nil, trace.Wrap(err) + } + resource, err := types.NewKubernetesServerV3( + types.Metadata{ + Name: name, + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, + types.KubernetesServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + Cluster: kube, + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindWindowsDesktop: + for i := 0; i < size; i++ { + var err error + name := fmt.Sprintf("windows-desktop-%d", i) + resource, err := types.NewWindowsDesktopV3( + name, + map[string]string{"label": string(make([]byte, labelSize))}, + types.WindowsDesktopSpecV3{ + Addr: "_", + HostID: "_", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindAppOrSAMLIdPServiceProvider: + for i := 0; i < size; i++ { + // Alternate between adding Apps and SAMLIdPServiceProviders. If `i` is even, add an app. + if i%2 == 0 { + app, err := types.NewAppV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + }, types.AppSpecV3{ + URI: "localhost", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + appServer, err := types.NewAppServerV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.AppServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + App: app, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resource := &types.AppServerOrSAMLIdPServiceProviderV1{ + Resource: &types.AppServerOrSAMLIdPServiceProviderV1_AppServer{ + AppServer: appServer, + }, + } + + resources = append(resources, any(resource).(T)) + } else { + sp := &types.SAMLIdPServiceProviderV1{ResourceHeader: types.ResourceHeader{Metadata: types.Metadata{Name: fmt.Sprintf("saml-app-%d", i), Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }}}} + + resource := &types.AppServerOrSAMLIdPServiceProviderV1{ + Resource: &types.AppServerOrSAMLIdPServiceProviderV1_SAMLIdPServiceProvider{ + SAMLIdPServiceProvider: sp, + }, + } + resources = append(resources, any(resource).(T)) + } + } + default: + return nil, trace.Errorf("unsupported resource type %s", resourceType) + } + + return resources, nil +} + +// mockInsecureCredentials mocks insecure Client credentials. +// it returns a nil tlsConfig which allows the client to run in insecure mode. +// TODO(Joerger) replace insecure credentials with proper TLS credentials. +type mockInsecureTLSCredentials struct{} + +func (mc *mockInsecureTLSCredentials) Dialer(cfg Config) (ContextDialer, error) { + return nil, trace.NotImplemented("no dialer") +} + +func (mc *mockInsecureTLSCredentials) TLSConfig() (*tls.Config, error) { + return nil, nil +} + +func (mc *mockInsecureTLSCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { + return nil, trace.NotImplemented("no ssh config") +} From a446a8dd4461ff97a2704a9b46130b39c872127b Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 16 Aug 2023 13:18:40 -0700 Subject: [PATCH 2/4] Refactor mock server and client test. --- api/client/client_test.go | 433 ++++++++++++++++++++++++++------- api/client/mock_server_test.go | 356 ++------------------------- 2 files changed, 366 insertions(+), 423 deletions(-) diff --git a/api/client/client_test.go b/api/client/client_test.go index ac61514d892fe..35c812731e6f1 100644 --- a/api/client/client_test.go +++ b/api/client/client_test.go @@ -19,6 +19,7 @@ package client import ( "context" "flag" + "fmt" "net" "os" "testing" @@ -26,95 +27,87 @@ import ( "github.com/google/go-cmp/cmp" "github.com/gravitational/trace" + "github.com/gravitational/trace/trail" "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" ) +func TestMain(m *testing.M) { + flag.Parse() + if testing.Verbose() { + logrus.SetLevel(logrus.DebugLevel) + } + os.Exit(m.Run()) +} + +type pingService struct { + *proto.UnimplementedAuthServiceServer +} + +func (s *pingService) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { + return &proto.PingResponse{}, nil +} + func TestNew(t *testing.T) { t.Parallel() ctx := context.Background() - srv := startMockServer(t) + srv := startMockServer(t, &pingService{}) tests := []struct { - desc string - config Config - assertErr require.ErrorAssertionFunc + desc string + modifyConfig func(*Config) + assertErr require.ErrorAssertionFunc }{{ - desc: "successfully dial tcp address.", - config: Config{ - Addrs: []string{srv.Addr()}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }, - assertErr: require.NoError, + desc: "successfully dial tcp address.", + modifyConfig: func(c *Config) { /* noop */ }, + assertErr: require.NoError, }, { desc: "synchronously dial addr/cred pairs and succeed with the 1 good pair.", - config: Config{ - Addrs: []string{"bad addr", srv.Addr(), "bad addr"}, - Credentials: []Credentials{ - &tlsConfigCreds{nil}, - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - &tlsConfigCreds{nil}, - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, + modifyConfig: func(c *Config) { + c.Addrs = append(c.Addrs, "bad addr", "bad addr") + c.Credentials = append([]Credentials{&tlsConfigCreds{nil}, &tlsConfigCreds{nil}}, c.Credentials...) }, assertErr: require.NoError, }, { desc: "fail to dial with a bad address.", - config: Config{ - DialTimeout: time.Second, - Addrs: []string{"bad addr"}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, + modifyConfig: func(c *Config) { + c.Addrs = []string{"bad addr"} }, assertErr: func(t require.TestingT, err error, _ ...interface{}) { require.Error(t, err) - require.Contains(t, err.Error(), "all connection methods failed") + require.ErrorContains(t, err, "all connection methods failed") }, }, { desc: "fail to dial with no address or dialer.", - config: Config{ - DialTimeout: time.Second, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, + modifyConfig: func(c *Config) { + c.Addrs = nil }, assertErr: func(t require.TestingT, err error, _ ...interface{}) { require.Error(t, err) - require.Contains(t, err.Error(), "no connection methods found, try providing Dialer or Addrs in config") + require.ErrorContains(t, err, "no connection methods found, try providing Dialer or Addrs in config") }, }} for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - clt, err := srv.NewClient(ctx, WithConfig(tt.config)) - tt.assertErr(t, err) + cfg := srv.clientCfg() + tt.modifyConfig(&cfg) - if err == nil { - t.Cleanup(func() { require.NoError(t, clt.Close()) }) - // requests to the server should succeed. - _, err = clt.Ping(ctx) - require.NoError(t, err) + clt, err := New(ctx, cfg) + tt.assertErr(t, err) + if err != nil { + return } + + // Requests to the server should succeed. + _, err = clt.Ping(ctx) + assert.NoError(t, err, "Ping failed") + assert.NoError(t, clt.Close(), "Close failed") }) } } @@ -123,22 +116,16 @@ func TestNewDialBackground(t *testing.T) { t.Parallel() ctx := context.Background() - // get listener but don't serve it yet. + // Create a server but don't serve it yet. l, err := net.Listen("tcp", "localhost:") require.NoError(t, err) addr := l.Addr().String() + srv := newMockServer(t, addr, &pingService{}) // Create client before the server is listening. - clt, err := New(ctx, Config{ - DialInBackground: true, - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }) + cfg := srv.clientCfg() + cfg.DialInBackground = true + clt, err := New(ctx, cfg) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, clt.Close()) }) @@ -148,8 +135,8 @@ func TestNewDialBackground(t *testing.T) { _, err = clt.Ping(cancelCtx) require.Error(t, err) - // Start the server and wait for the client connection to be ready. - startMockServerWithListener(t, l) + // Server the listener and wait for the client connection to be ready. + srv.serve(t, l) require.NoError(t, clt.waitForConnectionReady(ctx)) // requests to the server should succeed. @@ -161,32 +148,27 @@ func TestWaitForConnectionReady(t *testing.T) { t.Parallel() ctx := context.Background() + // Create a server but don't serve it yet. l, err := net.Listen("tcp", "localhost:") require.NoError(t, err) addr := l.Addr().String() + srv := newMockServer(t, addr, &proto.UnimplementedAuthServiceServer{}) // Create client before the server is listening. - clt, err := New(ctx, Config{ - DialInBackground: true, - Addrs: []string{addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - }) + cfg := srv.clientCfg() + cfg.DialInBackground = true + clt, err := New(ctx, cfg) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, clt.Close()) }) - // WaitForConnectionReady should return false once the + // WaitForConnectionReady should return an error once the // context is canceled if the server isn't open to connections. - cancelCtx, cancel := context.WithTimeout(ctx, time.Second*3) + cancelCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) defer cancel() require.Error(t, clt.waitForConnectionReady(cancelCtx)) // WaitForConnectionReady should return nil if the server is open to connections. - startMockServerWithListener(t, l) + srv.serve(t, l) require.NoError(t, clt.waitForConnectionReady(ctx)) // WaitForConnectionReady should return an error if the grpc connection is closed. @@ -194,10 +176,282 @@ func TestWaitForConnectionReady(t *testing.T) { require.Error(t, clt.waitForConnectionReady(ctx)) } +type listResourcesService struct { + *proto.UnimplementedAuthServiceServer +} + +func (s *listResourcesService) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { + resources, err := testResources[types.ResourceWithLabels](req.ResourceType, req.Namespace) + if err != nil { + return nil, trail.ToGRPC(err) + } + + resp := &proto.ListResourcesResponse{ + Resources: make([]*proto.PaginatedResource, 0, len(resources)), + TotalCount: int32(len(resources)), + } + + var ( + takeResources = req.StartKey == "" + lastResourceName string + ) + for _, resource := range resources { + if resource.GetName() == req.StartKey { + takeResources = true + continue + } + + if !takeResources { + continue + } + + var protoResource *proto.PaginatedResource + switch req.ResourceType { + case types.KindDatabaseServer: + database, ok := resource.(*types.DatabaseServerV3) + if !ok { + return nil, trace.Errorf("database server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseServer{DatabaseServer: database}} + case types.KindAppServer: + app, ok := resource.(*types.AppServerV3) + if !ok { + return nil, trace.Errorf("application server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServer{AppServer: app}} + case types.KindNode: + srv, ok := resource.(*types.ServerV2) + if !ok { + return nil, trace.Errorf("node has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: srv}} + case types.KindKubeServer: + srv, ok := resource.(*types.KubernetesServerV3) + if !ok { + return nil, trace.Errorf("kubernetes server has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_KubernetesServer{KubernetesServer: srv}} + case types.KindWindowsDesktop: + desktop, ok := resource.(*types.WindowsDesktopV3) + if !ok { + return nil, trace.Errorf("windows desktop has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_WindowsDesktop{WindowsDesktop: desktop}} + case types.KindAppOrSAMLIdPServiceProvider: + appServerOrSP, ok := resource.(*types.AppServerOrSAMLIdPServiceProviderV1) + if !ok { + return nil, trace.Errorf("AppServerOrSAMLIdPServiceProvider has invalid type %T", resource) + } + + protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServerOrSAMLIdPServiceProvider{AppServerOrSAMLIdPServiceProvider: appServerOrSP}} + } + resp.Resources = append(resp.Resources, protoResource) + lastResourceName = resource.GetName() + if len(resp.Resources) == int(req.Limit) { + break + } + } + + if len(resp.Resources) != len(resources) { + resp.NextKey = lastResourceName + } + + return resp, nil +} + +const fiveMBNode = "fiveMBNode" + +func testResources[T types.ResourceWithLabels](resourceType, namespace string) ([]T, error) { + size := 50 + // Artificially make each node ~ 100KB to force + // ListResources to fail with chunks of >= 40. + labelSize := 100000 + resources := make([]T, 0, size) + + switch resourceType { + case types.KindDatabaseServer: + for i := 0; i < size; i++ { + resource, err := types.NewDatabaseServerV3(types.Metadata{ + Name: fmt.Sprintf("db-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.DatabaseServerSpecV3{ + Hostname: "localhost", + HostID: fmt.Sprintf("host-%d", i), + Database: &types.DatabaseV3{ + Metadata: types.Metadata{ + Name: fmt.Sprintf("db-%d", i), + }, + Spec: types.DatabaseSpecV3{ + Protocol: types.DatabaseProtocolPostgreSQL, + URI: "localhost", + }, + }, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindAppServer: + for i := 0; i < size; i++ { + app, err := types.NewAppV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + }, types.AppSpecV3{ + URI: "localhost", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resource, err := types.NewAppServerV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.AppServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + App: app, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindNode: + for i := 0; i < size; i++ { + nodeLabelSize := labelSize + if namespace == fiveMBNode && i == 0 { + // Artificially make a node ~ 5MB to force + // ListNodes to fail regardless of chunk size. + nodeLabelSize = 5000000 + } + + var err error + resource, err := types.NewServerWithLabels(fmt.Sprintf("node-%d", i), types.KindNode, types.ServerSpecV2{}, + map[string]string{ + "label": string(make([]byte, nodeLabelSize)), + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindKubeServer: + for i := 0; i < size; i++ { + var err error + name := fmt.Sprintf("kube-service-%d", i) + kube, err := types.NewKubernetesClusterV3(types.Metadata{ + Name: name, + Labels: map[string]string{"name": name}, + }, + types.KubernetesClusterSpecV3{}, + ) + if err != nil { + return nil, trace.Wrap(err) + } + resource, err := types.NewKubernetesServerV3( + types.Metadata{ + Name: name, + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, + types.KubernetesServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + Cluster: kube, + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindWindowsDesktop: + for i := 0; i < size; i++ { + var err error + name := fmt.Sprintf("windows-desktop-%d", i) + resource, err := types.NewWindowsDesktopV3( + name, + map[string]string{"label": string(make([]byte, labelSize))}, + types.WindowsDesktopSpecV3{ + Addr: "_", + HostID: "_", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resources = append(resources, any(resource).(T)) + } + case types.KindAppOrSAMLIdPServiceProvider: + for i := 0; i < size; i++ { + // Alternate between adding Apps and SAMLIdPServiceProviders. If `i` is even, add an app. + if i%2 == 0 { + app, err := types.NewAppV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + }, types.AppSpecV3{ + URI: "localhost", + }) + if err != nil { + return nil, trace.Wrap(err) + } + + appServer, err := types.NewAppServerV3(types.Metadata{ + Name: fmt.Sprintf("app-%d", i), + Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }, + }, types.AppServerSpecV3{ + HostID: fmt.Sprintf("host-%d", i), + App: app, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + resource := &types.AppServerOrSAMLIdPServiceProviderV1{ + Resource: &types.AppServerOrSAMLIdPServiceProviderV1_AppServer{ + AppServer: appServer, + }, + } + + resources = append(resources, any(resource).(T)) + } else { + sp := &types.SAMLIdPServiceProviderV1{ResourceHeader: types.ResourceHeader{Metadata: types.Metadata{Name: fmt.Sprintf("saml-app-%d", i), Labels: map[string]string{ + "label": string(make([]byte, labelSize)), + }}}} + + resource := &types.AppServerOrSAMLIdPServiceProviderV1{ + Resource: &types.AppServerOrSAMLIdPServiceProviderV1_SAMLIdPServiceProvider{ + SAMLIdPServiceProvider: sp, + }, + } + resources = append(resources, any(resource).(T)) + } + } + default: + return nil, trace.Errorf("unsupported resource type %s", resourceType) + } + + return resources, nil +} + func TestListResources(t *testing.T) { t.Parallel() ctx := context.Background() - srv := startMockServer(t) + srv := startMockServer(t, &listResourcesService{}) testCases := map[string]struct { resourceType string @@ -226,7 +480,7 @@ func TestListResources(t *testing.T) { } // Create client - clt, err := srv.NewClient(ctx) + clt, err := New(ctx, srv.clientCfg()) require.NoError(t, err) for name, test := range testCases { @@ -298,10 +552,11 @@ func testGetResources[T types.ResourceWithLabels](t *testing.T, clt *Client, kin func TestGetResources(t *testing.T) { t.Parallel() - srv := startMockServer(t) + ctx := context.Background() + srv := startMockServer(t, &listResourcesService{}) // Create client - clt, err := srv.NewClient(context.Background()) + clt, err := New(ctx, srv.clientCfg()) require.NoError(t, err) t.Run("DatabaseServer", func(t *testing.T) { @@ -338,10 +593,10 @@ func TestGetResources(t *testing.T) { func TestGetResourcesWithFilters(t *testing.T) { t.Parallel() ctx := context.Background() - srv := startMockServer(t) + srv := startMockServer(t, &listResourcesService{}) // Create client - clt, err := srv.NewClient(ctx) + clt, err := New(ctx, srv.clientCfg()) require.NoError(t, err) testCases := map[string]struct { @@ -394,11 +649,3 @@ func TestGetResourcesWithFilters(t *testing.T) { }) } } - -func TestMain(m *testing.M) { - flag.Parse() - if testing.Verbose() { - logrus.SetLevel(logrus.DebugLevel) - } - os.Exit(m.Run()) -} diff --git a/api/client/mock_server_test.go b/api/client/mock_server_test.go index 8c02ab987c37d..66c694f926880 100644 --- a/api/client/mock_server_test.go +++ b/api/client/mock_server_test.go @@ -17,374 +17,70 @@ limitations under the License. package client import ( - "context" "crypto/tls" - "fmt" "net" "testing" + "time" "github.com/gravitational/trace" - "github.com/gravitational/trace/trail" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/status" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/types" ) // mockServer mocks an Auth Server. type mockServer struct { addr string grpc *grpc.Server - *proto.UnimplementedAuthServiceServer } -func newMockServer(addr string) *mockServer { +func newMockServer(t *testing.T, addr string, service proto.AuthServiceServer) *mockServer { + t.Helper() m := &mockServer{ - addr: addr, - grpc: grpc.NewServer(), - UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{}, - } - proto.RegisterAuthServiceServer(m.grpc, m) - return m -} - -func (m *mockServer) Stop() { - m.grpc.Stop() -} - -func (m *mockServer) Addr() string { - return m.addr -} - -type ConfigOpt func(*Config) - -func WithConfig(cfg Config) ConfigOpt { - return func(config *Config) { - *config = cfg - } -} - -func (m *mockServer) NewClient(ctx context.Context, opts ...ConfigOpt) (*Client, error) { - cfg := Config{ - Addrs: []string{m.addr}, - Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option - }, - } - - for _, opt := range opts { - opt(&cfg) + addr: addr, + grpc: grpc.NewServer(), } - return New(ctx, cfg) + proto.RegisterAuthServiceServer(m.grpc, service) + return m } // startMockServer starts a new mock server. Parallel tests cannot use the same addr. -func startMockServer(t *testing.T) *mockServer { +func startMockServer(t *testing.T, service proto.AuthServiceServer) *mockServer { l, err := net.Listen("tcp", "localhost:") require.NoError(t, err) - return startMockServerWithListener(t, l) + srv := newMockServer(t, l.Addr().String(), service) + srv.serve(t, l) + return srv } -// startMockServerWithListener starts a new mock server with the provided listener -func startMockServerWithListener(t *testing.T, l net.Listener) *mockServer { - srv := newMockServer(l.Addr().String()) - +func (m *mockServer) serve(t *testing.T, l net.Listener) { errCh := make(chan error, 1) go func() { - errCh <- srv.grpc.Serve(l) + errCh <- m.grpc.Serve(l) }() t.Cleanup(func() { - srv.grpc.Stop() - require.NoError(t, <-errCh) + m.grpc.Stop() + require.NoError(t, <-errCh, "mockServer gRPC server exited with unexpected error") }) - - return srv } -func (m *mockServer) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) { - return &proto.PingResponse{}, nil -} - -func (m *mockServer) ListResources(ctx context.Context, req *proto.ListResourcesRequest) (*proto.ListResourcesResponse, error) { - resources, err := testResources[types.ResourceWithLabels](req.ResourceType, req.Namespace) - if err != nil { - return nil, trail.ToGRPC(err) - } - - resp := &proto.ListResourcesResponse{ - Resources: make([]*proto.PaginatedResource, 0, len(resources)), - TotalCount: int32(len(resources)), - } - - var ( - takeResources = req.StartKey == "" - lastResourceName string - ) - for _, resource := range resources { - if resource.GetName() == req.StartKey { - takeResources = true - continue - } - - if !takeResources { - continue - } - - var protoResource *proto.PaginatedResource - switch req.ResourceType { - case types.KindDatabaseServer: - database, ok := resource.(*types.DatabaseServerV3) - if !ok { - return nil, trace.Errorf("database server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_DatabaseServer{DatabaseServer: database}} - case types.KindAppServer: - app, ok := resource.(*types.AppServerV3) - if !ok { - return nil, trace.Errorf("application server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServer{AppServer: app}} - case types.KindNode: - srv, ok := resource.(*types.ServerV2) - if !ok { - return nil, trace.Errorf("node has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_Node{Node: srv}} - case types.KindKubeServer: - srv, ok := resource.(*types.KubernetesServerV3) - if !ok { - return nil, trace.Errorf("kubernetes server has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_KubernetesServer{KubernetesServer: srv}} - case types.KindWindowsDesktop: - desktop, ok := resource.(*types.WindowsDesktopV3) - if !ok { - return nil, trace.Errorf("windows desktop has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_WindowsDesktop{WindowsDesktop: desktop}} - case types.KindAppOrSAMLIdPServiceProvider: - appServerOrSP, ok := resource.(*types.AppServerOrSAMLIdPServiceProviderV1) - if !ok { - return nil, trace.Errorf("AppServerOrSAMLIdPServiceProvider has invalid type %T", resource) - } - - protoResource = &proto.PaginatedResource{Resource: &proto.PaginatedResource_AppServerOrSAMLIdPServiceProvider{AppServerOrSAMLIdPServiceProvider: appServerOrSP}} - } - resp.Resources = append(resp.Resources, protoResource) - lastResourceName = resource.GetName() - if len(resp.Resources) == int(req.Limit) { - break - } - } - - if len(resp.Resources) != len(resources) { - resp.NextKey = lastResourceName - } - - return resp, nil -} - -func (m *mockServer) AddMFADeviceSync(ctx context.Context, req *proto.AddMFADeviceSyncRequest) (*proto.AddMFADeviceSyncResponse, error) { - return nil, status.Error(codes.AlreadyExists, "Already Exists") -} - -const fiveMBNode = "fiveMBNode" - -func testResources[T types.ResourceWithLabels](resourceType, namespace string) ([]T, error) { - size := 50 - // Artificially make each node ~ 100KB to force - // ListResources to fail with chunks of >= 40. - labelSize := 100000 - resources := make([]T, 0, size) - - switch resourceType { - case types.KindDatabaseServer: - for i := 0; i < size; i++ { - resource, err := types.NewDatabaseServerV3(types.Metadata{ - Name: fmt.Sprintf("db-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.DatabaseServerSpecV3{ - Hostname: "localhost", - HostID: fmt.Sprintf("host-%d", i), - Database: &types.DatabaseV3{ - Metadata: types.Metadata{ - Name: fmt.Sprintf("db-%d", i), - }, - Spec: types.DatabaseSpecV3{ - Protocol: types.DatabaseProtocolPostgreSQL, - URI: "localhost", - }, - }, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindAppServer: - for i := 0; i < size; i++ { - app, err := types.NewAppV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - }, types.AppSpecV3{ - URI: "localhost", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resource, err := types.NewAppServerV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.AppServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - App: app, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindNode: - for i := 0; i < size; i++ { - nodeLabelSize := labelSize - if namespace == fiveMBNode && i == 0 { - // Artificially make a node ~ 5MB to force - // ListNodes to fail regardless of chunk size. - nodeLabelSize = 5000000 - } - - var err error - resource, err := types.NewServerWithLabels(fmt.Sprintf("node-%d", i), types.KindNode, types.ServerSpecV2{}, - map[string]string{ - "label": string(make([]byte, nodeLabelSize)), - }, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindKubeServer: - for i := 0; i < size; i++ { - var err error - name := fmt.Sprintf("kube-service-%d", i) - kube, err := types.NewKubernetesClusterV3(types.Metadata{ - Name: name, - Labels: map[string]string{"name": name}, - }, - types.KubernetesClusterSpecV3{}, - ) - if err != nil { - return nil, trace.Wrap(err) - } - resource, err := types.NewKubernetesServerV3( - types.Metadata{ - Name: name, - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, - types.KubernetesServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - Cluster: kube, - }, - ) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindWindowsDesktop: - for i := 0; i < size; i++ { - var err error - name := fmt.Sprintf("windows-desktop-%d", i) - resource, err := types.NewWindowsDesktopV3( - name, - map[string]string{"label": string(make([]byte, labelSize))}, - types.WindowsDesktopSpecV3{ - Addr: "_", - HostID: "_", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resources = append(resources, any(resource).(T)) - } - case types.KindAppOrSAMLIdPServiceProvider: - for i := 0; i < size; i++ { - // Alternate between adding Apps and SAMLIdPServiceProviders. If `i` is even, add an app. - if i%2 == 0 { - app, err := types.NewAppV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - }, types.AppSpecV3{ - URI: "localhost", - }) - if err != nil { - return nil, trace.Wrap(err) - } - - appServer, err := types.NewAppServerV3(types.Metadata{ - Name: fmt.Sprintf("app-%d", i), - Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }, - }, types.AppServerSpecV3{ - HostID: fmt.Sprintf("host-%d", i), - App: app, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - resource := &types.AppServerOrSAMLIdPServiceProviderV1{ - Resource: &types.AppServerOrSAMLIdPServiceProviderV1_AppServer{ - AppServer: appServer, - }, - } - - resources = append(resources, any(resource).(T)) - } else { - sp := &types.SAMLIdPServiceProviderV1{ResourceHeader: types.ResourceHeader{Metadata: types.Metadata{Name: fmt.Sprintf("saml-app-%d", i), Labels: map[string]string{ - "label": string(make([]byte, labelSize)), - }}}} - - resource := &types.AppServerOrSAMLIdPServiceProviderV1{ - Resource: &types.AppServerOrSAMLIdPServiceProviderV1_SAMLIdPServiceProvider{ - SAMLIdPServiceProvider: sp, - }, - } - resources = append(resources, any(resource).(T)) - } - } - default: - return nil, trace.Errorf("unsupported resource type %s", resourceType) +func (m *mockServer) clientCfg() Config { + return Config{ + // Reduce dial timeout for tests. + DialTimeout: time.Second, + Addrs: []string{m.addr}, + Credentials: []Credentials{ + &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials + }, + DialOpts: []grpc.DialOption{ + grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option + }, } - - return resources, nil } // mockInsecureCredentials mocks insecure Client credentials. From 57af62d20ac4c6169945e43d72d9fe1ec04abef9 Mon Sep 17 00:00:00 2001 From: joerger Date: Wed, 16 Aug 2023 13:25:34 -0700 Subject: [PATCH 3/4] Use mTLS in client tests. --- api/client/mock_server_test.go | 128 +++++++++++++++++++++++++++------ 1 file changed, 108 insertions(+), 20 deletions(-) diff --git a/api/client/mock_server_test.go b/api/client/mock_server_test.go index 66c694f926880..c2b6dd41b2add 100644 --- a/api/client/mock_server_test.go +++ b/api/client/mock_server_test.go @@ -17,33 +17,46 @@ limitations under the License. package client import ( + "crypto/ed25519" + "crypto/rand" "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "net" "testing" "time" - "github.com/gravitational/trace" "github.com/stretchr/testify/require" - "golang.org/x/crypto/ssh" "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/keys" ) // mockServer mocks an Auth Server. type mockServer struct { - addr string - grpc *grpc.Server + addr string + grpc *grpc.Server + clientTLS *tls.Config + serverTLS *tls.Config } func newMockServer(t *testing.T, addr string, service proto.AuthServiceServer) *mockServer { t.Helper() m := &mockServer{ addr: addr, - grpc: grpc.NewServer(), } + m.generateTestCerts(t) + + m.grpc = grpc.NewServer( + grpc.Creds(credentials.NewTLS(m.serverTLS)), + ) + proto.RegisterAuthServiceServer(m.grpc, service) return m } @@ -75,27 +88,102 @@ func (m *mockServer) clientCfg() Config { DialTimeout: time.Second, Addrs: []string{m.addr}, Credentials: []Credentials{ - &mockInsecureTLSCredentials{}, // TODO(Joerger) replace insecure credentials - }, - DialOpts: []grpc.DialOption{ - grpc.WithTransportCredentials(insecure.NewCredentials()), // TODO(Joerger) remove insecure dial option + LoadTLS(m.clientTLS), }, } } -// mockInsecureCredentials mocks insecure Client credentials. -// it returns a nil tlsConfig which allows the client to run in insecure mode. -// TODO(Joerger) replace insecure credentials with proper TLS credentials. -type mockInsecureTLSCredentials struct{} +func (m *mockServer) generateTestCerts(t *testing.T) { + t.Helper() -func (mc *mockInsecureTLSCredentials) Dialer(cfg Config) (ContextDialer, error) { - return nil, trace.NotImplemented("no dialer") + caKey, caCert := generateCA(t) + m.serverTLS = generateChildTLSConfigFromCA(t, caKey, caCert) + m.clientTLS = generateChildTLSConfigFromCA(t, caKey, caCert) } -func (mc *mockInsecureTLSCredentials) TLSConfig() (*tls.Config, error) { - return nil, nil +func generateCA(t *testing.T) (*keys.PrivateKey, *x509.Certificate) { + t.Helper() + + caPub, caPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + caKey, err := keys.NewPrivateKey(caPriv, nil) + require.NoError(t, err) + + // Create a self signed certificate. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Issuer: entity, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, template, template, caPub, caKey) + require.NoError(t, err) + + x509Cert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + return caKey, x509Cert } -func (mc *mockInsecureTLSCredentials) SSHClientConfig() (*ssh.ClientConfig, error) { - return nil, trace.NotImplemented("no ssh config") +func generateChildTLSConfigFromCA(t *testing.T, caKey *keys.PrivateKey, caCert *x509.Certificate) *tls.Config { + t.Helper() + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := keys.NewPrivateKey(priv, nil) + require.NoError(t, err) + + // Create a certificate signed by the CA. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + DNSNames: []string{constants.APIDomain}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, pub, caKey) + require.NoError(t, err) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := key.TLSCertificate(certPEM) + require.NoError(t, err) + + pool := x509.NewCertPool() + pool.AddCert(caCert) + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: pool, + } } From 28dddaa7066a5d27c03c13fecf73c501062da114 Mon Sep 17 00:00:00 2001 From: joerger Date: Thu, 17 Aug 2023 18:22:24 -0700 Subject: [PATCH 4/4] Move mtls test helpers into a new package for reusability. --- api/client/mock_server_test.go | 121 ++-------------------------- api/testhelpers/mtls/mtls.go | 139 +++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 113 deletions(-) create mode 100644 api/testhelpers/mtls/mtls.go diff --git a/api/client/mock_server_test.go b/api/client/mock_server_test.go index c2b6dd41b2add..4840961a99258 100644 --- a/api/client/mock_server_test.go +++ b/api/client/mock_server_test.go @@ -17,13 +17,6 @@ limitations under the License. package client import ( - "crypto/ed25519" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "math/big" "net" "testing" "time" @@ -33,28 +26,25 @@ import ( "google.golang.org/grpc/credentials" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/constants" - "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/api/testhelpers/mtls" ) // mockServer mocks an Auth Server. type mockServer struct { - addr string - grpc *grpc.Server - clientTLS *tls.Config - serverTLS *tls.Config + addr string + grpc *grpc.Server + mtlsConfig *mtls.Config } func newMockServer(t *testing.T, addr string, service proto.AuthServiceServer) *mockServer { t.Helper() m := &mockServer{ - addr: addr, + addr: addr, + mtlsConfig: mtls.NewConfig(t), } - m.generateTestCerts(t) - m.grpc = grpc.NewServer( - grpc.Creds(credentials.NewTLS(m.serverTLS)), + grpc.Creds(credentials.NewTLS(m.mtlsConfig.ServerTLS)), ) proto.RegisterAuthServiceServer(m.grpc, service) @@ -88,102 +78,7 @@ func (m *mockServer) clientCfg() Config { DialTimeout: time.Second, Addrs: []string{m.addr}, Credentials: []Credentials{ - LoadTLS(m.clientTLS), + LoadTLS(m.mtlsConfig.ClientTLS), }, } } - -func (m *mockServer) generateTestCerts(t *testing.T) { - t.Helper() - - caKey, caCert := generateCA(t) - m.serverTLS = generateChildTLSConfigFromCA(t, caKey, caCert) - m.clientTLS = generateChildTLSConfigFromCA(t, caKey, caCert) -} - -func generateCA(t *testing.T) (*keys.PrivateKey, *x509.Certificate) { - t.Helper() - - caPub, caPriv, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - caKey, err := keys.NewPrivateKey(caPriv, nil) - require.NoError(t, err) - - // Create a self signed certificate. - - notBefore := time.Now() - notAfter := notBefore.Add(time.Minute) - entity := pkix.Name{ - Organization: []string{"teleport"}, - CommonName: "localhost", - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - require.NoError(t, err) - - template := &x509.Certificate{ - SerialNumber: serialNumber, - Issuer: entity, - Subject: entity, - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageCertSign, - IsCA: true, - BasicConstraintsValid: true, - } - - caCertDER, err := x509.CreateCertificate(rand.Reader, template, template, caPub, caKey) - require.NoError(t, err) - - x509Cert, err := x509.ParseCertificate(caCertDER) - require.NoError(t, err) - - return caKey, x509Cert -} - -func generateChildTLSConfigFromCA(t *testing.T, caKey *keys.PrivateKey, caCert *x509.Certificate) *tls.Config { - t.Helper() - - pub, priv, err := ed25519.GenerateKey(rand.Reader) - require.NoError(t, err) - - key, err := keys.NewPrivateKey(priv, nil) - require.NoError(t, err) - - // Create a certificate signed by the CA. - - notBefore := time.Now() - notAfter := notBefore.Add(time.Minute) - entity := pkix.Name{ - Organization: []string{"teleport"}, - CommonName: "localhost", - } - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - require.NoError(t, err) - - template := &x509.Certificate{ - SerialNumber: serialNumber, - Subject: entity, - NotBefore: notBefore, - NotAfter: notAfter, - KeyUsage: x509.KeyUsageDigitalSignature, - BasicConstraintsValid: true, - DNSNames: []string{constants.APIDomain}, - } - - certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, pub, caKey) - require.NoError(t, err) - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - tlsCert, err := key.TLSCertificate(certPEM) - require.NoError(t, err) - - pool := x509.NewCertPool() - pool.AddCert(caCert) - - return &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - RootCAs: pool, - } -} diff --git a/api/testhelpers/mtls/mtls.go b/api/testhelpers/mtls/mtls.go new file mode 100644 index 0000000000000..07476fbf65897 --- /dev/null +++ b/api/testhelpers/mtls/mtls.go @@ -0,0 +1,139 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mtls + +import ( + "crypto/ed25519" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/utils/keys" +) + +type Config struct { + ServerTLS *tls.Config + ClientTLS *tls.Config +} + +// NewConfig returns an mTLS config. +func NewConfig(t *testing.T) *Config { + t.Helper() + + caKey, caCert := generateCA(t) + serverTLS := generateChildTLSConfigFromCA(t, caKey, caCert) + clientTLS := generateChildTLSConfigFromCA(t, caKey, caCert) + clientTLS.ServerName = constants.APIDomain + + return &Config{ + ServerTLS: serverTLS, + ClientTLS: clientTLS, + } +} + +func generateCA(t *testing.T) (*keys.PrivateKey, *x509.Certificate) { + t.Helper() + + caPub, caPriv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + caKey, err := keys.NewPrivateKey(caPriv, nil) + require.NoError(t, err) + + // Create a self signed certificate. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Issuer: entity, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, template, template, caPub, caKey) + require.NoError(t, err) + + x509Cert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + return caKey, x509Cert +} + +func generateChildTLSConfigFromCA(t *testing.T, caKey *keys.PrivateKey, caCert *x509.Certificate) *tls.Config { + t.Helper() + + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := keys.NewPrivateKey(priv, nil) + require.NoError(t, err) + + // Create a certificate signed by the CA. + + notBefore := time.Now() + notAfter := notBefore.Add(time.Minute) + entity := pkix.Name{ + Organization: []string{"teleport"}, + CommonName: "localhost", + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: entity, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + DNSNames: []string{constants.APIDomain}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, caCert, pub, caKey) + require.NoError(t, err) + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + tlsCert, err := key.TLSCertificate(certPEM) + require.NoError(t, err) + + pool := x509.NewCertPool() + pool.AddCert(caCert) + + return &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + RootCAs: pool, + } +}