From 22b2e8c003ec19149bb8e39943079ad422096887 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Krzysztof=20Skrz=C4=99tnicki?= Date: Fri, 2 Jan 2026 10:30:55 +0100 Subject: [PATCH] fix: add availability checks for instance metadata methods (#62457) * fix: add availability checks for instance metadata methods * clarify comment * ensure no race conditions occur --- lib/cloud/imds/azure/imds.go | 16 ++- lib/cloud/imds/azure/imds_test.go | 162 +++++++++++++++++++++++++----- 2 files changed, 154 insertions(+), 24 deletions(-) diff --git a/lib/cloud/imds/azure/imds.go b/lib/cloud/imds/azure/imds.go index fc43e42656dc9..315766edc1cd1 100644 --- a/lib/cloud/imds/azure/imds.go +++ b/lib/cloud/imds/azure/imds.go @@ -163,7 +163,9 @@ func (client *InstanceMetadataClient) getVersions(ctx context.Context) ([]string return versions.APIVersions, nil } -// IsAvailable checks if instance metadata is available. +// IsAvailable reports whether the Azure Instance Metadata Service is reachable. +// On first use it discovers and caches a supported IMDS API version (via /versions). +// Other methods call this internally to ensure the client is initialized. func (client *InstanceMetadataClient) IsAvailable(ctx context.Context) bool { if client.GetAPIVersion() != "" { return true @@ -225,6 +227,10 @@ type InstanceInfo struct { // GetInstanceInfo gets the Azure Instance information. func (client *InstanceMetadataClient) GetInstanceInfo(ctx context.Context) (*InstanceInfo, error) { + if !client.IsAvailable(ctx) { + return nil, trace.NotFound("Instance metadata is not available") + } + body, err := client.getRawMetadata(ctx, "/instance/compute", url.Values{"format": []string{"json"}}) if err != nil { return nil, trace.Wrap(err) @@ -253,12 +259,20 @@ func (client *InstanceMetadataClient) GetID(ctx context.Context) (string, error) // GetAttestedData gets attested data from the instance. func (client *InstanceMetadataClient) GetAttestedData(ctx context.Context, nonce string) ([]byte, error) { + if !client.IsAvailable(ctx) { + return nil, trace.NotFound("Instance metadata is not available") + } + body, err := client.getRawMetadata(ctx, "/attested/document", url.Values{"nonce": []string{nonce}, "format": []string{"json"}}) return body, trace.Wrap(err) } // GetAccessToken gets an oauth2 access token from the instance. func (client *InstanceMetadataClient) GetAccessToken(ctx context.Context, clientID string) (string, error) { + if !client.IsAvailable(ctx) { + return "", trace.NotFound("Instance metadata is not available") + } + params := url.Values{"resource": []string{"https://management.azure.com/"}} if clientID != "" { params["client_id"] = []string{clientID} diff --git a/lib/cloud/imds/azure/imds_test.go b/lib/cloud/imds/azure/imds_test.go index 8472d544ee6f7..720081155708d 100644 --- a/lib/cloud/imds/azure/imds_test.go +++ b/lib/cloud/imds/azure/imds_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "os" + "sync" "testing" "github.com/gravitational/trace" @@ -80,10 +81,9 @@ func TestAzureIsInstanceMetadataAvailable(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - ctx := context.Background() server := httptest.NewServer(tc.handler) clt := tc.client(t, server) - tc.assertion(t, clt.IsAvailable(ctx)) + tc.assertion(t, clt.IsAvailable(t.Context())) }) } } @@ -172,6 +172,69 @@ func TestParseMetadataClientError(t *testing.T) { } } +type mockIMDS struct { + t *testing.T + versionsCalled bool + lastAPIVersion string + + mu sync.Mutex +} + +func (m *mockIMDS) status() (versionsCalled bool, lastAPIVersion string) { + m.mu.Lock() + defer m.mu.Unlock() + return m.versionsCalled, m.lastAPIVersion +} + +func newMockIMDS(t *testing.T, overrides map[string]http.Handler) (*mockIMDS, *httptest.Server) { + t.Helper() + + m := &mockIMDS{t: t} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m.mu.Lock() + defer m.mu.Unlock() + + m.lastAPIVersion = r.URL.Query().Get("api-version") + + if r.URL.Path == "/versions" { + m.versionsCalled = true + } + + // /versions doesn't require api-version; all other endpoints do + if r.URL.Path != "/versions" && m.lastAPIVersion == "" { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"Bad request. api-version is invalid or was not specified in the request.","newest-versions":["2023-07-01","2021-02-01"]}`)) + return + } + + if overrides != nil { + handler, ok := overrides[r.URL.Path] + if ok { + handler.ServeHTTP(w, r) + return + } + } + + responses := map[string]string{ + "/versions": `{"apiVersions":["2021-02-01","2023-07-01"]}`, + "/instance/compute": `{"resourceId":"/subscriptions/test","location":"eastus"}`, + "/instance/compute/tagsList": `[{"name":"foo","value":"bar"}]`, + "/attested/document": `{"signature":"test"}`, + "/identity/oauth2/token": `{"access_token":"test-token"}`, + } + + if body, ok := responses[r.URL.Path]; ok { + _, _ = w.Write([]byte(body)) + return + } + http.NotFound(w, r) + })) + + t.Cleanup(srv.Close) + + return m, srv +} + func TestGetInstanceInfo(t *testing.T) { t.Parallel() for _, tc := range []struct { @@ -179,7 +242,7 @@ func TestGetInstanceInfo(t *testing.T) { statusCode int body []byte expectedInstanceInfo *InstanceInfo - errAssertion require.ErrorAssertionFunc + wantErr string }{ { name: "with resource ID", @@ -188,7 +251,7 @@ func TestGetInstanceInfo(t *testing.T) { expectedInstanceInfo: &InstanceInfo{ ResourceID: "test-id", }, - errAssertion: require.NoError, + wantErr: "", }, { name: "all fields", @@ -202,36 +265,38 @@ func TestGetInstanceInfo(t *testing.T) { SubscriptionID: "5187AF11-3581-4AB6-A654-59405CD40C44", VMID: "ED7DAC09-6E73-447F-BD18-AF4D1196C1E4", }, - errAssertion: require.NoError, + wantErr: "", }, { name: "request error", statusCode: http.StatusNotFound, - errAssertion: func(tt require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "not found") - }, + wantErr: "not found", }, { name: "empty body returns an error", statusCode: http.StatusOK, - errAssertion: func(tt require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, "error found in #0 byte") - }, + wantErr: "error found in #0 byte", }, } { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(tc.statusCode) - w.Write(tc.body) - })) + + _, server := newMockIMDS(t, map[string]http.Handler{ + "/instance/compute": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + w.Write(tc.body) + }), + }) client := NewInstanceMetadataClient(WithBaseURL(server.URL)) - instanceInfo, err := client.GetInstanceInfo(context.Background()) - tc.errAssertion(t, err) - if tc.expectedInstanceInfo != nil { + instanceInfo, err := client.GetInstanceInfo(t.Context()) + if tc.wantErr == "" { + require.NoError(t, err) require.Equal(t, tc.expectedInstanceInfo, instanceInfo) + } else { + require.Nil(t, instanceInfo) + require.ErrorContains(t, err, tc.wantErr) } }) } @@ -270,15 +335,66 @@ func TestGetInstanceID(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - w.WriteHeader(tc.statusCode) - w.Write(tc.body) - })) + + _, server := newMockIMDS(t, map[string]http.Handler{ + "/instance/compute": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tc.statusCode) + w.Write(tc.body) + }), + }) client := NewInstanceMetadataClient(WithBaseURL(server.URL)) - resourceID, err := client.GetID(context.Background()) + resourceID, err := client.GetID(t.Context()) tc.errAssertion(t, err) require.Equal(t, tc.expectedResourceID, resourceID) }) } } + +func TestMethodsEnsureInitialization(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + call func(ctx context.Context, c *InstanceMetadataClient) error + }{ + {"GetInstanceInfo", func(ctx context.Context, c *InstanceMetadataClient) error { + _, err := c.GetInstanceInfo(ctx) + return err + }}, + {"GetID", func(ctx context.Context, c *InstanceMetadataClient) error { + _, err := c.GetID(ctx) + return err + }}, + {"GetTags", func(ctx context.Context, c *InstanceMetadataClient) error { + _, err := c.GetTags(ctx) + return err + }}, + {"GetAttestedData", func(ctx context.Context, c *InstanceMetadataClient) error { + _, err := c.GetAttestedData(ctx, "") + return err + }}, + {"GetAccessToken", func(ctx context.Context, c *InstanceMetadataClient) error { + _, err := c.GetAccessToken(ctx, "") + return err + }}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + mock, srv := newMockIMDS(t, nil) + defer srv.Close() + + client := NewInstanceMetadataClient(WithBaseURL(srv.URL)) + require.Empty(t, client.GetAPIVersion(), "client should start uninitialized") + + err := tc.call(t.Context(), client) + require.NoError(t, err) + versionsCalled, lastAPIVersion := mock.status() + require.True(t, versionsCalled, "should call /versions to initialize") + require.Equal(t, "2023-07-01", lastAPIVersion, "should use negotiated api-version") + require.Equal(t, "2023-07-01", client.GetAPIVersion(), "client should be initialized") + }) + } +}