Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion lib/cloud/imds/azure/imds.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ func (client *InstanceMetadataClient) getVersions(ctx context.Context) ([]string
return versions.APIVersions, nil
}

// IsAvailable checks if instance metadata is available.
// IsAvailable reports whether the Azure Instance Metadata Service is reachable.
// On first use it discovers and caches a supported IMDS API version (via /versions).
// Other methods call this internally to ensure the client is initialized.
func (client *InstanceMetadataClient) IsAvailable(ctx context.Context) bool {
if client.GetAPIVersion() != "" {
return true
Expand Down Expand Up @@ -225,6 +227,10 @@ type InstanceInfo struct {

// GetInstanceInfo gets the Azure Instance information.
func (client *InstanceMetadataClient) GetInstanceInfo(ctx context.Context) (*InstanceInfo, error) {
if !client.IsAvailable(ctx) {
return nil, trace.NotFound("Instance metadata is not available")
}

body, err := client.getRawMetadata(ctx, "/instance/compute", url.Values{"format": []string{"json"}})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -253,12 +259,20 @@ func (client *InstanceMetadataClient) GetID(ctx context.Context) (string, error)

// GetAttestedData gets attested data from the instance.
func (client *InstanceMetadataClient) GetAttestedData(ctx context.Context, nonce string) ([]byte, error) {
if !client.IsAvailable(ctx) {
return nil, trace.NotFound("Instance metadata is not available")
}

body, err := client.getRawMetadata(ctx, "/attested/document", url.Values{"nonce": []string{nonce}, "format": []string{"json"}})
return body, trace.Wrap(err)
}

// GetAccessToken gets an oauth2 access token from the instance.
func (client *InstanceMetadataClient) GetAccessToken(ctx context.Context, clientID string) (string, error) {
if !client.IsAvailable(ctx) {
return "", trace.NotFound("Instance metadata is not available")
}

params := url.Values{"resource": []string{"https://management.azure.com/"}}
if clientID != "" {
params["client_id"] = []string{clientID}
Expand Down
162 changes: 139 additions & 23 deletions lib/cloud/imds/azure/imds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -80,10 +81,9 @@ func TestAzureIsInstanceMetadataAvailable(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

ctx := context.Background()
server := httptest.NewServer(tc.handler)
clt := tc.client(t, server)
tc.assertion(t, clt.IsAvailable(ctx))
tc.assertion(t, clt.IsAvailable(t.Context()))
})
}
}
Expand Down Expand Up @@ -172,14 +172,77 @@ func TestParseMetadataClientError(t *testing.T) {
}
}

type mockIMDS struct {
t *testing.T
versionsCalled bool
lastAPIVersion string

mu sync.Mutex
}

func (m *mockIMDS) status() (versionsCalled bool, lastAPIVersion string) {
m.mu.Lock()
defer m.mu.Unlock()
return m.versionsCalled, m.lastAPIVersion
}

func newMockIMDS(t *testing.T, overrides map[string]http.Handler) (*mockIMDS, *httptest.Server) {
t.Helper()

m := &mockIMDS{t: t}
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
m.mu.Lock()
defer m.mu.Unlock()

m.lastAPIVersion = r.URL.Query().Get("api-version")

if r.URL.Path == "/versions" {
m.versionsCalled = true
}

// /versions doesn't require api-version; all other endpoints do
if r.URL.Path != "/versions" && m.lastAPIVersion == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(`{"error":"Bad request. api-version is invalid or was not specified in the request.","newest-versions":["2023-07-01","2021-02-01"]}`))
return
}

if overrides != nil {
handler, ok := overrides[r.URL.Path]
if ok {
handler.ServeHTTP(w, r)
return
}
}

responses := map[string]string{
"/versions": `{"apiVersions":["2021-02-01","2023-07-01"]}`,
"/instance/compute": `{"resourceId":"/subscriptions/test","location":"eastus"}`,
"/instance/compute/tagsList": `[{"name":"foo","value":"bar"}]`,
"/attested/document": `{"signature":"test"}`,
"/identity/oauth2/token": `{"access_token":"test-token"}`,
}

if body, ok := responses[r.URL.Path]; ok {
_, _ = w.Write([]byte(body))
return
}
http.NotFound(w, r)
}))

t.Cleanup(srv.Close)

return m, srv
}

func TestGetInstanceInfo(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
statusCode int
body []byte
expectedInstanceInfo *InstanceInfo
errAssertion require.ErrorAssertionFunc
wantErr string
}{
{
name: "with resource ID",
Expand All @@ -188,7 +251,7 @@ func TestGetInstanceInfo(t *testing.T) {
expectedInstanceInfo: &InstanceInfo{
ResourceID: "test-id",
},
errAssertion: require.NoError,
wantErr: "",
},
{
name: "all fields",
Expand All @@ -202,36 +265,38 @@ func TestGetInstanceInfo(t *testing.T) {
SubscriptionID: "5187AF11-3581-4AB6-A654-59405CD40C44",
VMID: "ED7DAC09-6E73-447F-BD18-AF4D1196C1E4",
},
errAssertion: require.NoError,
wantErr: "",
},
{
name: "request error",
statusCode: http.StatusNotFound,
errAssertion: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "not found")
},
wantErr: "not found",
},
{
name: "empty body returns an error",
statusCode: http.StatusOK,
errAssertion: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorContains(t, err, "error found in #0 byte")
},
wantErr: "error found in #0 byte",
},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tc.statusCode)
w.Write(tc.body)
}))

_, server := newMockIMDS(t, map[string]http.Handler{
"/instance/compute": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.statusCode)
w.Write(tc.body)
}),
})

client := NewInstanceMetadataClient(WithBaseURL(server.URL))
instanceInfo, err := client.GetInstanceInfo(context.Background())
tc.errAssertion(t, err)
if tc.expectedInstanceInfo != nil {
instanceInfo, err := client.GetInstanceInfo(t.Context())
if tc.wantErr == "" {
require.NoError(t, err)
require.Equal(t, tc.expectedInstanceInfo, instanceInfo)
} else {
require.Nil(t, instanceInfo)
require.ErrorContains(t, err, tc.wantErr)
}
})
}
Expand Down Expand Up @@ -270,15 +335,66 @@ func TestGetInstanceID(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(tc.statusCode)
w.Write(tc.body)
}))

_, server := newMockIMDS(t, map[string]http.Handler{
"/instance/compute": http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(tc.statusCode)
w.Write(tc.body)
}),
})

client := NewInstanceMetadataClient(WithBaseURL(server.URL))
resourceID, err := client.GetID(context.Background())
resourceID, err := client.GetID(t.Context())
tc.errAssertion(t, err)
require.Equal(t, tc.expectedResourceID, resourceID)
})
}
}

func TestMethodsEnsureInitialization(t *testing.T) {
t.Parallel()

tests := []struct {
name string
call func(ctx context.Context, c *InstanceMetadataClient) error
}{
{"GetInstanceInfo", func(ctx context.Context, c *InstanceMetadataClient) error {
_, err := c.GetInstanceInfo(ctx)
return err
}},
{"GetID", func(ctx context.Context, c *InstanceMetadataClient) error {
_, err := c.GetID(ctx)
return err
}},
{"GetTags", func(ctx context.Context, c *InstanceMetadataClient) error {
_, err := c.GetTags(ctx)
return err
}},
{"GetAttestedData", func(ctx context.Context, c *InstanceMetadataClient) error {
_, err := c.GetAttestedData(ctx, "")
return err
}},
{"GetAccessToken", func(ctx context.Context, c *InstanceMetadataClient) error {
_, err := c.GetAccessToken(ctx, "")
return err
}},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
mock, srv := newMockIMDS(t, nil)
defer srv.Close()

client := NewInstanceMetadataClient(WithBaseURL(srv.URL))
require.Empty(t, client.GetAPIVersion(), "client should start uninitialized")

err := tc.call(t.Context(), client)
require.NoError(t, err)
versionsCalled, lastAPIVersion := mock.status()
require.True(t, versionsCalled, "should call /versions to initialize")
require.Equal(t, "2023-07-01", lastAPIVersion, "should use negotiated api-version")
require.Equal(t, "2023-07-01", client.GetAPIVersion(), "client should be initialized")
})
}
}
Loading