// Patch summary (two files under lib/tbot/workloadidentity/workloadattest).
//
// kubernetes_unix.go
//   - KubernetesAttestor gains a `clock` field. NewKubernetesAttestor sets it
//     to clockwork.NewRealClock(), and it is handed to the retrier as Clock.
//   - getPodForID is replaced by getPodAndContainerStatus, which calls
//     tryGetPodAndContainerStatus in a loop under a 30s context timeout.
//     Any error from a try (including "pod not found") is returned at once;
//     only "pod found, no matching container status" is retried. If the
//     context ends while waiting, the pod is returned with a nil container
//     status and Attest leaves `ctr` nil.
//   - The retrier is built with NewExponentialDriver(100ms) and Max: 2s
//     (presumably the first delay and the backoff cap; confirm in retryutils).
//   - Container statuses are matched on the part of status.ContainerID that
//     follows "://".
//     NOTE(review): strings.Cut returns an empty suffix when "://" is absent,
//     so an unprefixed ContainerID cannot match a non-empty containerID.
//     Confirm the kubelet always includes the prefix.
//     NOTE(review): after the loop `pod` appears to always be non-nil, because
//     a nil-error try always returns a pod, so the trailing
//     `return nil, nil, err` looks unreachable.
//
// kubernetes_unix_test.go
//   - The mock /pods handler counts requests: request 0 returns no container
//     statuses, request 1 returns a status with the wrong container ID, and
//     later requests return the expected "docker://"-prefixed status.
//   - The test assigns attestor.clock = clockwork.NewRealClock().
//     NOTE(review): NewKubernetesAttestor already assigns a real clock, so
//     this may be redundant if the test builds the attestor through it.
diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go index 456fd5648fcfb..85f25813c2d37 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go @@ -34,12 +34,15 @@ import ( "os" "regexp" "strconv" + "strings" "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" v1 "k8s.io/api/core/v1" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest/container" ) @@ -68,6 +71,7 @@ type KubernetesAttestor struct { log *slog.Logger // rootPath specifies the location of `/`. This allows overriding for tests. rootPath string + clock clockwork.Clock } // NewKubernetesAttestor creates a new KubernetesAttestor. @@ -76,6 +80,7 @@ func NewKubernetesAttestor(cfg KubernetesAttestorConfig, log *slog.Logger) *Kube return &KubernetesAttestor{ kubeletClient: kubeletClient, log: log, + clock: clockwork.NewRealClock(), } } @@ -95,21 +100,18 @@ func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadiden "container_id", container.ID, ) - pod, err := a.getPodForID(ctx, container.PodID) + pod, containerStatus, err := a.getPodAndContainerStatus(ctx, container.PodID, container.ID) if err != nil { return nil, trace.Wrap(err, "finding pod by ID") } a.log.DebugContext(ctx, "Found pod", "pod_name", pod.Name) var ctr *workloadidentityv1pb.WorkloadAttrsKubernetesContainer - for _, status := range pod.Status.ContainerStatuses { - if status.ContainerID != container.ID { - continue - } + if containerStatus != nil { ctr = &workloadidentityv1pb.WorkloadAttrsKubernetesContainer{ - Name: status.Name, - Image: status.Image, - ImageDigest: imageDigestRegex.FindString(status.ImageID), + Name: containerStatus.Name, + Image: 
containerStatus.Image, + ImageDigest: imageDigestRegex.FindString(containerStatus.ImageID), } } @@ -126,19 +128,81 @@ func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadiden return att, nil } -// getPodForID retrieves the pod information for the provided pod ID. -// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/server/server.go#L371 -func (a *KubernetesAttestor) getPodForID(ctx context.Context, podID string) (*v1.Pod, error) { +func (a *KubernetesAttestor) getPodAndContainerStatus(ctx context.Context, podID, containerID string) (*v1.Pod, *v1.ContainerStatus, error) { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + log := a.log.With("pod_id", podID, "container_id", containerID) + + retry, err := retryutils.NewRetryV2(retryutils.RetryV2Config{ + Driver: retryutils.NewExponentialDriver(100 * time.Millisecond), + Max: 2 * time.Second, + Clock: a.clock, + }) + if err != nil { + return nil, nil, trace.Wrap(err, "creating retrier") + } + + var ( + pod *v1.Pod + containerStatus *v1.ContainerStatus + ) +LOOP: + for { + pod, containerStatus, err = a.tryGetPodAndContainerStatus(ctx, podID, containerID) + switch { + case err != nil: + return nil, nil, err + case containerStatus == nil: + // It's possible for a workload container to start and request an SVID + // before the kubelet has updated its state, in which case we might + // get back no container status at all, or in the case of a restart, + // the previous run's status. 
+ log.DebugContext(ctx, "Kubelet did not return expected container status; its state might be stale") + default: + break LOOP + } + + retry.Inc() + select { + case <-ctx.Done(): + break LOOP + case <-retry.After(): + } + } + + if pod != nil { + return pod, containerStatus, nil + } + return nil, nil, err +} + +func (a *KubernetesAttestor) tryGetPodAndContainerStatus(ctx context.Context, podID, containerID string) (*v1.Pod, *v1.ContainerStatus, error) { pods, err := a.kubeletClient.ListAllPods(ctx) if err != nil { - return nil, trace.Wrap(err, "listing all pods") + return nil, nil, trace.Wrap(err, "listing all pods") } - for _, pod := range pods.Items { - if string(pod.UID) == podID { - return &pod, nil + + var pod *v1.Pod + for _, p := range pods.Items { + if string(p.UID) == podID { + pod = &p + break + } + } + if pod == nil { + return nil, nil, trace.NotFound("pod %q not found", podID) + } + + var containerStatus *v1.ContainerStatus + for _, status := range pod.Status.ContainerStatuses { + // Kubelet returns the container ID prefixed by the runtime name and `://` (e.g. `docker://`). 
+ if _, id, _ := strings.Cut(status.ContainerID, "://"); id == containerID { + containerStatus = &status + break } } - return nil, trace.NotFound("pod %q not found", podID) + return pod, containerStatus, nil } // kubeletClient is a HTTP client for the Kubelet API diff --git a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go index 22141157f50f3..19eeaa2a27250 100644 --- a/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go +++ b/lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go @@ -32,6 +32,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/testing/protocmp" @@ -55,11 +56,31 @@ func TestKubernetesAttestor_Attest(t *testing.T) { mockContainerID := "9da25af0b548c8c60aa60f77f299ba727bf72d58248bd7528eb5390ffcce555a" // Setup mock Kubelet Secure API + var requests int mockKubeletAPI := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { if req.URL.Path != "/pods" { http.NotFound(w, req) return } + + // Don't return the container status in the first response, to simulate + // the kubelet API's eventual consistency. 
+ var containerStatuses []v1.ContainerStatus + switch { + case requests == 1: + containerStatuses = append(containerStatuses, v1.ContainerStatus{ + ContainerID: "docker://totally-wrong-container-id", + }) + case requests > 1: + containerStatuses = append(containerStatuses, v1.ContainerStatus{ + ContainerID: "docker://" + mockContainerID, + Name: "container-1", + Image: "my.registry.io/my-app:v1", + ImageID: "docker-pullable://my.registry.io/my-app@sha256:84c998f7610b356a5eed24f801c01b273cf3e83f081f25c9b16aa8136c2cafb1", + }) + } + requests++ + out := v1.PodList{ Items: []v1.Pod{ { @@ -75,14 +96,7 @@ func TestKubernetesAttestor_Attest(t *testing.T) { ServiceAccountName: "my-service-account", }, Status: v1.PodStatus{ - ContainerStatuses: []v1.ContainerStatus{ - { - ContainerID: mockContainerID, - Name: "container-1", - Image: "my.registry.io/my-app:v1", - ImageID: "docker-pullable://my.registry.io/my-app@sha256:84c998f7610b356a5eed24f801c01b273cf3e83f081f25c9b16aa8136c2cafb1", - }, - }, + ContainerStatuses: containerStatuses, }, }, }, @@ -121,6 +135,7 @@ func TestKubernetesAttestor_Attest(t *testing.T) { }, }, log) attestor.rootPath = tmpDir + attestor.clock = clockwork.NewRealClock() attestor.kubeletClient.getEnv = func(s string) string { env := map[string]string{ "TELEPORT_NODE_NAME": host,