Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 80 additions & 16 deletions lib/tbot/workloadidentity/workloadattest/kubernetes_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ import (
"os"
"regexp"
"strconv"
"strings"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
v1 "k8s.io/api/core/v1"

workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/tbot/workloadidentity/workloadattest/container"
)

Expand Down Expand Up @@ -68,6 +71,7 @@ type KubernetesAttestor struct {
log *slog.Logger
// rootPath specifies the location of `/`. This allows overriding for tests.
rootPath string
clock clockwork.Clock
}

// NewKubernetesAttestor creates a new KubernetesAttestor.
Expand All @@ -76,6 +80,7 @@ func NewKubernetesAttestor(cfg KubernetesAttestorConfig, log *slog.Logger) *Kube
return &KubernetesAttestor{
kubeletClient: kubeletClient,
log: log,
clock: clockwork.NewRealClock(),
}
}

Expand All @@ -95,21 +100,18 @@ func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadiden
"container_id", container.ID,
)

pod, err := a.getPodForID(ctx, container.PodID)
pod, containerStatus, err := a.getPodAndContainerStatus(ctx, container.PodID, container.ID)
if err != nil {
return nil, trace.Wrap(err, "finding pod by ID")
}
a.log.DebugContext(ctx, "Found pod", "pod_name", pod.Name)

var ctr *workloadidentityv1pb.WorkloadAttrsKubernetesContainer
for _, status := range pod.Status.ContainerStatuses {
if status.ContainerID != container.ID {
continue
}
if containerStatus != nil {
ctr = &workloadidentityv1pb.WorkloadAttrsKubernetesContainer{
Name: status.Name,
Image: status.Image,
ImageDigest: imageDigestRegex.FindString(status.ImageID),
Name: containerStatus.Name,
Image: containerStatus.Image,
ImageDigest: imageDigestRegex.FindString(containerStatus.ImageID),
}
}

Expand All @@ -126,19 +128,81 @@ func (a *KubernetesAttestor) Attest(ctx context.Context, pid int) (*workloadiden
return att, nil
}

// getPodForID retrieves the pod information for the provided pod ID.
// https://github.com/kubernetes/kubernetes/blob/master/pkg/kubelet/server/server.go#L371
func (a *KubernetesAttestor) getPodForID(ctx context.Context, podID string) (*v1.Pod, error) {
// getPodAndContainerStatus asks the kubelet for the pod identified by podID
// and, within that pod, the status of the container identified by
// containerID.
//
// The kubelet's view may lag behind reality, so when the pod is found but the
// expected container status is missing, the lookup is retried with a backoff
// (exponential driver with a 100ms base, 2s maximum) until either the status
// shows up or an overall 30 second deadline derived from ctx expires.
//
// Returns:
//   - the pod and the matching container status when both are found;
//   - the pod and a nil container status (with a nil error) when the context
//     is done before a matching status appears, leaving it to the caller to
//     cope with the missing container information;
//   - an error if the retrier cannot be created or a lookup attempt fails
//     (including the pod not being found); lookup errors are not retried.
func (a *KubernetesAttestor) getPodAndContainerStatus(ctx context.Context, podID, containerID string) (*v1.Pod, *v1.ContainerStatus, error) {
	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	log := a.log.With("pod_id", podID, "container_id", containerID)

	retry, err := retryutils.NewRetryV2(retryutils.RetryV2Config{
		Driver: retryutils.NewExponentialDriver(100 * time.Millisecond),
		Max:    2 * time.Second,
		Clock:  a.clock,
	})
	if err != nil {
		return nil, nil, trace.Wrap(err, "creating retrier")
	}

	var (
		pod             *v1.Pod
		containerStatus *v1.ContainerStatus
	)
	for {
		pod, containerStatus, err = a.tryGetPodAndContainerStatus(ctx, podID, containerID)
		if err != nil {
			// Hard failures (kubelet unreachable, pod unknown) are not
			// retried; only a missing container status is.
			return nil, nil, trace.Wrap(err)
		}
		if containerStatus != nil {
			return pod, containerStatus, nil
		}

		// It's possible for a workload container to start and request a SVID
		// before the kubelet has updated its state, in which case we might
		// get back no container status at all, or in the case of a restart,
		// the previous run's status.
		log.DebugContext(ctx, "Kubelet did not return expected container status; its state might be stale")

		retry.Inc()
		select {
		case <-ctx.Done():
			// Out of time: hand back the pod we did find, without a
			// container status, rather than failing the whole lookup.
			return pod, nil, nil
		case <-retry.After():
		}
	}
}

func (a *KubernetesAttestor) tryGetPodAndContainerStatus(ctx context.Context, podID, containerID string) (*v1.Pod, *v1.ContainerStatus, error) {
pods, err := a.kubeletClient.ListAllPods(ctx)
if err != nil {
return nil, trace.Wrap(err, "listing all pods")
return nil, nil, trace.Wrap(err, "listing all pods")
}
for _, pod := range pods.Items {
if string(pod.UID) == podID {
return &pod, nil

var pod *v1.Pod
for _, p := range pods.Items {
if string(p.UID) == podID {
pod = &p
break
}
}
if pod == nil {
return nil, nil, trace.NotFound("pod %q not found", podID)
}

var containerStatus *v1.ContainerStatus
for _, status := range pod.Status.ContainerStatuses {
// Kubelet returns the container ID prefixed by `<type>://`.
if _, id, _ := strings.Cut(status.ContainerID, "://"); id == containerID {
containerStatus = &status
break
}
}
return nil, trace.NotFound("pod %q not found", podID)
return pod, containerStatus, nil
}

// kubeletClient is a HTTP client for the Kubelet API
Expand Down
31 changes: 23 additions & 8 deletions lib/tbot/workloadidentity/workloadattest/kubernetes_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"testing"

"github.com/google/go-cmp/cmp"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
Expand All @@ -55,11 +56,31 @@ func TestKubernetesAttestor_Attest(t *testing.T) {
mockContainerID := "9da25af0b548c8c60aa60f77f299ba727bf72d58248bd7528eb5390ffcce555a"

// Setup mock Kubelet Secure API
var requests int
mockKubeletAPI := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path != "/pods" {
http.NotFound(w, req)
return
}

// Don't return the container status in the first response, to simulate
// the kubelet API's eventual consistency.
var containerStatuses []v1.ContainerStatus
switch {
case requests == 1:
containerStatuses = append(containerStatuses, v1.ContainerStatus{
ContainerID: "docker://totally-wrong-container-id",
})
case requests > 1:
containerStatuses = append(containerStatuses, v1.ContainerStatus{
ContainerID: "docker://" + mockContainerID,
Name: "container-1",
Image: "my.registry.io/my-app:v1",
ImageID: "docker-pullable://my.registry.io/my-app@sha256:84c998f7610b356a5eed24f801c01b273cf3e83f081f25c9b16aa8136c2cafb1",
})
}
requests++

out := v1.PodList{
Items: []v1.Pod{
{
Expand All @@ -75,14 +96,7 @@ func TestKubernetesAttestor_Attest(t *testing.T) {
ServiceAccountName: "my-service-account",
},
Status: v1.PodStatus{
ContainerStatuses: []v1.ContainerStatus{
{
ContainerID: mockContainerID,
Name: "container-1",
Image: "my.registry.io/my-app:v1",
ImageID: "docker-pullable://my.registry.io/my-app@sha256:84c998f7610b356a5eed24f801c01b273cf3e83f081f25c9b16aa8136c2cafb1",
},
},
ContainerStatuses: containerStatuses,
},
},
},
Expand Down Expand Up @@ -121,6 +135,7 @@ func TestKubernetesAttestor_Attest(t *testing.T) {
},
}, log)
attestor.rootPath = tmpDir
attestor.clock = clockwork.NewRealClock()
attestor.kubeletClient.getEnv = func(s string) string {
env := map[string]string{
"TELEPORT_NODE_NAME": host,
Expand Down
Loading