diff --git a/integration/kube_integration_test.go b/integration/kube_integration_test.go index 778f08fb971ee..a77d8e330fc67 100644 --- a/integration/kube_integration_test.go +++ b/integration/kube_integration_test.go @@ -183,6 +183,7 @@ func TestKube(t *testing.T) { t.Run("Exec", suite.bind(testKubeExec)) t.Run("Deny", suite.bind(testKubeDeny)) t.Run("PortForward", suite.bind(testKubePortForward)) + t.Run("PortForwardPodDisconnect", suite.bind(testKubePortForwardPodDisconnect)) t.Run("TransportProtocol", suite.bind(testKubeTransportProtocol)) t.Run("TrustedClustersClientCert", suite.bind(testKubeTrustedClustersClientCert)) t.Run("TrustedClustersSNI", suite.bind(testKubeTrustedClustersSNI)) @@ -530,11 +531,11 @@ func testKubePortForward(t *testing.T, suite *KubeSuite) { builder func(*rest.Config, kubePortForwardArgs) (*kubePortForwarder, error) }{ { - name: "SPDY portForwarder", + name: "SPDY", builder: newPortForwarder, }, { - name: "SPDY over Websocket portForwarder", + name: "SPDY over Websocket", builder: newPortForwarderSPDYOverWebsocket, }, } @@ -558,7 +559,9 @@ func testKubePortForward(t *testing.T, suite *KubeSuite) { }) require.NoError(t, err) + // Forward local port to container port. forwarderCh := make(chan error) + t.Cleanup(func() { forwarder.Close() }) go func() { forwarderCh <- forwarder.ForwardPorts() }() select { @@ -566,7 +569,6 @@ func testKubePortForward(t *testing.T, suite *KubeSuite) { t.Fatalf("Timeout waiting for port forwarding.") case <-forwarder.readyC: } - t.Cleanup(func() {}) resp, err := http.Get(fmt.Sprintf("http://localhost:%v", localPort)) require.NoError(t, err) @@ -593,6 +595,169 @@ func testKubePortForward(t *testing.T, suite *KubeSuite) { } +// testKubePortForwardPodDisconnect tests Kubernetes port forwarding +// with pod disconnection. 
+func testKubePortForwardPodDisconnect(t *testing.T, suite *KubeSuite) { + tconf := suite.teleKubeConfig(Host) + + teleport := helpers.NewInstance(t, helpers.InstanceConfig{ + ClusterName: helpers.Site, + HostID: helpers.HostID, + NodeName: Host, + Priv: suite.priv, + Pub: suite.pub, + Logger: suite.log, + }) + + username := suite.me.Username + kubeGroups := []string{kube.TestImpersonationGroup} + role, err := types.NewRole("kubemaster", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{username}, + KubeGroups: kubeGroups, + KubernetesLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + KubernetesResources: []types.KubernetesResource{ + { + Kind: "pods", Name: types.Wildcard, Namespace: types.Wildcard, Verbs: []string{types.Wildcard}, APIGroup: types.Wildcard, + }, + }, + }, + }) + require.NoError(t, err) + teleport.AddUserWithRole(username, role) + + err = teleport.CreateEx(t, nil, tconf) + require.NoError(t, err) + + err = teleport.Start() + require.NoError(t, err) + defer teleport.StopAll() + + // set up kube configuration using proxy + _, proxyClientConfig, err := kube.ProxyClient(kube.ProxyConfig{ + T: teleport, + Username: username, + KubeGroups: kubeGroups, + }) + require.NoError(t, err) + + tests := []struct { + name string + builder func(*rest.Config, kubePortForwardArgs) (*kubePortForwarder, error) + }{ + { + name: "SPDY", + builder: newPortForwarder, + }, + { + name: "SPDY over Websocket", + builder: newPortForwarderSPDYOverWebsocket, + }, + } + + for _, tt := range tests { + t.Run(tt.name, + func(t *testing.T) { + // TODO(rana): Improve k8s isolation per test. + // Each test can have an isolated k8s environment. + // The isolated environment may have its own namespace, pods, etc. + // This would involve updating CI k8s RBAC (fixtures/ci-teleport-rbac/ci-teleport.yaml). + // Existing tests can be updated to use an isolated k8s environment. 
+ // Current k8s integration testing reuses a single k8s environment and pod across tests. + // Some tests which delete pods (this one), or require multiple pods would benefit + // from isolated k8s environments. + // In this test, with k8s isolation per test, pod creation would be moved + // from `t.Cleanup()` to test setup. + t.Cleanup(func() { + // Current CI RBAC allows only for a pod named "test-pod". + // Kube integration test suite uses a single instance of + // "test-pod" across multiple tests. + // Here we continue the use and maintenance of the single "test-pod" pod approach. + // On successful test, "test-pod" is deleted, and re-created for the next test. + pod := newPod(testNamespace, testPod) + if _, err := suite.CoreV1().Pods(testNamespace).Create(context.Background(), pod, metav1.CreateOptions{}); err != nil { + require.True(t, kubeerrors.IsAlreadyExists(err), "Failed to create test pod: %s.", err) + } + + // Wait for pod to be running. + require.Eventually(t, func() bool { + rsp, err := suite.CoreV1().Pods(testNamespace).Get(context.Background(), testPod, metav1.GetOptions{}) + if err != nil { + t.Logf("Get pod error: %s", err) + return false + } + if rsp.Status.Phase == v1.PodRunning { + return true + } + return false + }, 60*time.Second, 500*time.Millisecond) + }) + + // Setup port-forwarding configuration. + listener, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, listener.Close()) + }) + localPort := listener.Addr().(*net.TCPAddr).Port + forwarder, err := tt.builder(proxyClientConfig, kubePortForwardArgs{ + ports: []string{fmt.Sprintf("%d:80", localPort)}, + podName: testPod, + podNamespace: testNamespace, + }) + require.NoError(t, err) + + // Forward local port to container port. + forwarderCh := make(chan error, 1) + t.Cleanup(func() { forwarder.Close() }) + go func() { forwarderCh <- forwarder.ForwardPorts() }() + + // Wait for port-forwarding to be ready. 
+ select { + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for port forward start") + case <-forwarder.readyC: + } + + // Validate that port-forwarding is working. + resp, err := http.Get(fmt.Sprintf("http://localhost:%d", localPort)) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + + // Delete the pod. + err = suite.CoreV1().Pods(testNamespace).Delete(context.Background(), testPod, metav1.DeleteOptions{}) + require.NoError(t, err) + + // Wait for pod deletion. + require.Eventually(t, func() bool { + if _, err := suite.CoreV1().Pods(testNamespace).Get(context.Background(), testPod, metav1.GetOptions{}); err != nil { + return kubeerrors.IsNotFound(err) + } + return false + }, 60*time.Second, 500*time.Millisecond) + + // Attempt an http GET after pod deletion. + // This enables error reporting from KubeAPI back to client. + //nolint:bodyclose // http response is expected to be nil and return an error + _, err = http.Get(fmt.Sprintf("http://localhost:%d", localPort)) + require.Error(t, err) + + // Wait for port-forwarding to exit. + select { + case <-time.After(5 * time.Second): + t.Fatal("Timed out waiting for port forward exit") + case err := <-forwarderCh: + require.ErrorIs(t, err, portforward.ErrLostConnectionToPod) + } + }, + ) + } + +} + // TestKubeTrustedClustersClientCert tests scenario with trusted clusters // using metadata encoded in the certificate func testKubeTrustedClustersClientCert(t *testing.T, suite *KubeSuite) { @@ -832,8 +997,11 @@ loop: }) require.NoError(t, err) + // Forward local port to container port. 
forwarderCh := make(chan error) + t.Cleanup(func() { forwarder.Close() }) go func() { forwarderCh <- forwarder.ForwardPorts() }() + defer func() { require.NoError(t, <-forwarderCh, "Forward ports exited with error") }() @@ -1101,9 +1269,12 @@ loop: podNamespace: pod.Namespace, }) require.NoError(t, err) - forwarderCh := make(chan error) + // Forward local port to container port. + forwarderCh := make(chan error) + t.Cleanup(func() { forwarder.Close() }) go func() { forwarderCh <- forwarder.ForwardPorts() }() + defer func() { require.NoError(t, <-forwarderCh, "Forward ports exited with error") }() diff --git a/lib/kube/proxy/portforward_spdy.go b/lib/kube/proxy/portforward_spdy.go index ad0b5e391a6e0..4de8a45dcbe27 100644 --- a/lib/kube/proxy/portforward_spdy.go +++ b/lib/kube/proxy/portforward_spdy.go @@ -19,6 +19,7 @@ package proxy import ( "context" "fmt" + "io" "log/slog" "net" "net/http" @@ -199,7 +200,12 @@ func (h *portForwardProxy) forwardStreamPair(p *httpStreamPair, remotePort int64 wg.Add(1) go func() { defer wg.Done() - if err := utils.ProxyConn(h.context, p.errorStream, targetErrorStream); err != nil { + // Close the target error stream to indicate no more writes. + if err := targetErrorStream.Close(); err != nil { + h.logger.DebugContext(h.context, "Unable to close target error stream", "error", err) + } + // Enables error propagation from Kube API server to kubectl client. + if _, err := io.Copy(p.errorStream, targetErrorStream); err != nil { h.logger.DebugContext(h.context, "Unable to proxy portforward error-stream", "error", err) } }() @@ -297,6 +303,8 @@ func (h *portForwardProxy) requestID(stream httpstream.Stream) (string, error) { // when the httpstream.Connection is closed. 
func (h *portForwardProxy) run() { h.logger.DebugContext(h.context, "Waiting for port forward streams") + var wg sync.WaitGroup + defer wg.Wait() for { select { case <-h.context.Done(): @@ -305,6 +313,9 @@ func (h *portForwardProxy) run() { case <-h.sourceConn.CloseChan(): h.logger.DebugContext(h.context, "Upgraded connection closed") return + case <-h.targetConn.CloseChan(): + h.logger.DebugContext(h.context, "Target connection closed") + return case stream := <-h.streamChan: requestID, err := h.requestID(stream) if err != nil { @@ -323,7 +334,11 @@ func (h *portForwardProxy) run() { err := trace.BadParameter("error processing stream for request %s: %v", requestID, err) p.sendErr(err) } else if complete { - go h.portForward(p) + wg.Add(1) + go func() { + defer wg.Done() + h.portForward(p) + }() } } } diff --git a/lib/kube/proxy/portforward_test.go b/lib/kube/proxy/portforward_test.go index d93638515d898..19dc2948569b1 100644 --- a/lib/kube/proxy/portforward_test.go +++ b/lib/kube/proxy/portforward_test.go @@ -20,6 +20,7 @@ package proxy import ( "context" + "errors" "fmt" "io" "net" @@ -139,7 +140,13 @@ func TestPortForwardKubeService(t *testing.T) { readyCh := make(chan struct{}) // errCh receives a single error from ForwardPorts goroutine. errCh := make(chan error) - t.Cleanup(func() { require.NoError(t, <-errCh) }) + t.Cleanup(func() { + // ErrLostConnectionToPod is an expected error. + // Server allowed to communicate error to client. + if err := <-errCh; !errors.Is(err, portforward.ErrLostConnectionToPod) { + require.NoError(t, err) + } + }) // stopCh control the port forwarding lifecycle. When it gets closed the // port forward will terminate. stopCh := make(chan struct{}) @@ -524,7 +531,13 @@ func TestPortForwardUnderlyingProtocol(t *testing.T) { readyCh := make(chan struct{}) // errCh receives a single error from ForwardPorts goroutine. 
errCh := make(chan error) - t.Cleanup(func() { require.NoError(t, <-errCh) }) + t.Cleanup(func() { + // ErrLostConnectionToPod is an expected error. + // Server allowed to communicate error to client. + if err := <-errCh; !errors.Is(err, portforward.ErrLostConnectionToPod) { + require.NoError(t, err) + } + }) // stopCh control the port forwarding lifecycle. When it gets closed the // port forward will terminate. stopCh := make(chan struct{}) diff --git a/lib/kube/proxy/portforward_websocket.go b/lib/kube/proxy/portforward_websocket.go index 4ca01ea473062..629090cc77a8f 100644 --- a/lib/kube/proxy/portforward_websocket.go +++ b/lib/kube/proxy/portforward_websocket.go @@ -266,11 +266,16 @@ func (h *websocketPortforwardHandler) forwardStreamPair(p *websocketChannelPair) }() wg := &sync.WaitGroup{} - wg.Add(1) + wg.Add(1) go func() { defer wg.Done() - if err := utils.ProxyConn(h.context, p.errorStream, targetErrorStream); err != nil { + // Close the target error stream to indicate no more writes. + if err := targetErrorStream.Close(); err != nil { + h.logger.DebugContext(h.context, "Unable to close target error stream", "error", err) + } + // Enables error propagation from Kube API server to kubectl client. + if _, err := io.Copy(p.errorStream, targetErrorStream); err != nil { h.logger.DebugContext(h.context, "Unable to proxy portforward error-stream", "error", err) } }()