diff --git a/api/types/constants.go b/api/types/constants.go index a2daa0b5ff5e0..e99f779f77d93 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -911,6 +911,8 @@ const ( // TeleportAzureMSIEndpoint is a special URL intercepted by TSH local proxy, serving Azure credentials. TeleportAzureMSIEndpoint = "azure-msi." + TeleportNamespace + // TeleportAzureIdentityEndpoint is a special URL intercepted by TSH local proxy, serving Azure credentials. + TeleportAzureIdentityEndpoint = "azure-identity." + TeleportNamespace // ConnectMyComputerNodeOwnerLabel is a label used to control access to the node managed by // Teleport Connect as part of Connect My Computer. See [teleterm.connectmycomputer.RoleSetup]. diff --git a/lib/srv/alpnproxy/azure_msi_middleware.go b/lib/srv/alpnproxy/azure_token_middleware.go similarity index 57% rename from lib/srv/alpnproxy/azure_msi_middleware.go rename to lib/srv/alpnproxy/azure_token_middleware.go index 6e462f2de8906..9a1704a2aea94 100644 --- a/lib/srv/alpnproxy/azure_msi_middleware.go +++ b/lib/srv/alpnproxy/azure_token_middleware.go @@ -24,6 +24,7 @@ import ( "fmt" "log/slog" "net/http" + "strings" "sync" "time" @@ -35,8 +36,11 @@ import ( "github.com/gravitational/teleport/lib/jwt" ) -// AzureMSIMiddleware implements a simplified version of MSI server serving auth tokens. -type AzureMSIMiddleware struct { +// AzureTokenMiddleware implements a simplified version of MSI and Identity +// servers serving auth tokens. +// +// https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference +type AzureTokenMiddleware struct { DefaultLocalProxyHTTPMiddleware // Identity is the Azure identity to be served by the server. Only single identity will be provided. @@ -58,14 +62,14 @@ type AzureMSIMiddleware struct { privateKeyMu sync.RWMutex } -var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{} +var _ LocalProxyHTTPMiddleware = &AzureTokenMiddleware{} -func (m *AzureMSIMiddleware) CheckAndSetDefaults() error { +func (m *AzureTokenMiddleware) CheckAndSetDefaults() error { if m.Clock == nil { m.Clock = clockwork.NewRealClock() } if m.Log == nil { - m.Log = slog.With(teleport.ComponentKey, "azure_msi") + m.Log = slog.With(teleport.ComponentKey, "azure_token") } if m.Secret == "" { @@ -83,37 +87,45 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error { return nil } -func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool { - if req.Host == types.TeleportAzureMSIEndpoint { - if err := m.msiEndpoint(rw, req); err != nil { - m.Log.WarnContext(req.Context(), "Bad MSI request", "error", err) - trace.WriteError(rw, trace.Wrap(err)) - } - return true +func (m *AzureTokenMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool { + var err error + switch req.Host { + case types.TeleportAzureMSIEndpoint: + err = m.handleEndpoint(rw, req, MSIResourceFieldName, strings.TrimPrefix(req.URL.Path, "/")) + case types.TeleportAzureIdentityEndpoint: + err = m.handleEndpoint(rw, req, IdentityResourceFieldName, req.Header.Get(IdentitySecretHeader)) + default: + m.Log.DebugContext(req.Context(), "Unsupported token host", "host", req.Host) + return false + } + if err != nil { + m.Log.WarnContext(req.Context(), "Bad token request", "error", err) + trace.WriteError(rw, trace.Wrap(err)) } - return false + return true } // SetPrivateKey updates the private key. -func (m *AzureMSIMiddleware) SetPrivateKey(privateKey crypto.Signer) { +func (m *AzureTokenMiddleware) SetPrivateKey(privateKey crypto.Signer) { m.privateKeyMu.Lock() defer m.privateKeyMu.Unlock() m.privateKey = privateKey } -func (m *AzureMSIMiddleware) getPrivateKey() (crypto.Signer, error) { +func (m *AzureTokenMiddleware) getPrivateKey() (crypto.Signer, error) { m.privateKeyMu.RLock() defer m.privateKeyMu.RUnlock() if m.privateKey == nil { // Use a plain error to return status code 500. - return nil, trace.Errorf("missing private key set in AzureMSIMiddleware") + return nil, trace.Errorf("missing private key set in AzureTokenMiddleware") } return m.privateKey, nil } -func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error { +// handleEndpoint handles the Azure identity token generation. +func (m *AzureTokenMiddleware) handleEndpoint(rw http.ResponseWriter, req *http.Request, resourceFieldName string, secret string) error { // request validation - if req.URL.Path != ("/" + m.Secret) { + if secret != m.Secret { return trace.BadParameter("invalid secret") } @@ -122,8 +134,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque return trace.BadParameter("expected Metadata header with value 'true'") } - err := req.ParseForm() - if err != nil { + if err := req.ParseForm(); err != nil { return trace.Wrap(err) } @@ -132,19 +143,19 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque return trace.BadParameter("missing value for parameter 'resource'") } - // check that msi_res_id matches expected Azure Identity - requestedAzureIdentity := req.Form.Get("msi_res_id") + // check that resource field matches expected Azure Identity + requestedAzureIdentity := req.Form.Get(resourceFieldName) if requestedAzureIdentity != m.Identity { m.Log.WarnContext(req.Context(), "Requested unexpected identity", "requested_identity", requestedAzureIdentity, "expected_identity", m.Identity) - return trace.BadParameter("unexpected value for parameter 'msi_res_id': %v", requestedAzureIdentity) + return trace.BadParameter("unexpected value for parameter '%s': %v", resourceFieldName, requestedAzureIdentity) } - respBody, err := m.fetchMSILoginResp(resource) + respBody, err := m.fetchLoginResp(resource) if err != nil { return trace.Wrap(err) } - m.Log.InfoContext(req.Context(), "MSI: returning token for identity", "identity", m.Identity) + m.Log.InfoContext(req.Context(), "Returning token for identity", "identity", m.Identity) rw.Header().Add("Content-Type", "application/json; charset=utf-8") rw.Header().Add("Content-Length", fmt.Sprintf("%v", len(respBody))) @@ -153,7 +164,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque return nil } -func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error) { +func (m *AzureTokenMiddleware) fetchLoginResp(resource string) ([]byte, error) { now := m.Clock.Now() notBefore := now.Add(-10 * time.Second) @@ -187,16 +198,18 @@ func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error) return out, nil } -func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) { +func (m *AzureTokenMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) { privateKey, err := m.getPrivateKey() if err != nil { return "", trace.Wrap(err) } // Create a new key that can sign and verify tokens. key, err := jwt.New(&jwt.Config{ - Clock: m.Clock, - PrivateKey: privateKey, - ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name + Clock: m.Clock, + PrivateKey: privateKey, + // TODO(gabrielcorado): use the cluster name. This value must match the + // one used by the proxy. + ClusterName: types.TeleportAzureMSIEndpoint, }) if err != nil { return "", trace.Wrap(err) @@ -209,3 +222,19 @@ func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) return token, nil } + +const ( + // IdentitySecretHeader is the HTTP header that contains the identity + // secret on App Service identity requests. + // + // https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference + IdentitySecretHeader = "X-IDENTITY-HEADER" + // IdentityResourceFieldName is the request field name that contains the + // Azure identity on App Service identity requests. + // + // https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference + IdentityResourceFieldName = "mi_res_id" + // MSIResourceFieldName is the request field name that contains the Azure + // Identity on MSI identity requests. + MSIResourceFieldName = "msi_res_id" +) diff --git a/lib/srv/alpnproxy/azure_msi_middleware_test.go b/lib/srv/alpnproxy/azure_token_middleware_test.go similarity index 63% rename from lib/srv/alpnproxy/azure_msi_middleware_test.go rename to lib/srv/alpnproxy/azure_token_middleware_test.go index c1fd2a42d5bea..f00eac9c15dac 100644 --- a/lib/srv/alpnproxy/azure_msi_middleware_test.go +++ b/lib/srv/alpnproxy/azure_token_middleware_test.go @@ -21,9 +21,11 @@ package alpnproxy import ( "crypto" "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" + "net/url" "strings" "testing" "time" @@ -37,23 +39,33 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -func TestAzureMSIMiddlewareHandleRequest(t *testing.T) { +func TestAzureTokenMiddlewareHandleRequest(t *testing.T) { t.Parallel() for _, alg := range []cryptosuites.Algorithm{cryptosuites.RSA2048, cryptosuites.ECDSAP256} { - t.Run(alg.String(), func(t *testing.T) { - testAzureMSIMiddlewareHandleRequest(t, alg) - }) + for _, endpoint := range []struct { + name string + endpoint string + resourceFieldName string + secret func(string) azureRequestModifier + }{ + {name: "msi", endpoint: types.TeleportAzureMSIEndpoint, resourceFieldName: MSIResourceFieldName, secret: msiSecretModifier}, + {name: "identity", endpoint: types.TeleportAzureIdentityEndpoint, resourceFieldName: IdentityResourceFieldName, secret: identitySecretModifier}, + } { + t.Run(alg.String()+"_"+endpoint.name, func(t *testing.T) { + testAzureTokenMiddlewareHandleRequest(t, alg, endpoint.endpoint, endpoint.secret, endpoint.resourceFieldName) + }) + } } } -func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorithm) { +func testAzureTokenMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorithm, endpoint string, endpointSecret func(string) azureRequestModifier, resourceFieldName string) { newPrivateKey := func() crypto.Signer { privateKey, err := cryptosuites.GenerateKeyWithAlgorithm(alg) require.NoError(t, err) return privateKey } privateKey := newPrivateKey() - m := &AzureMSIMiddleware{ + m := &AzureTokenMiddleware{ Identity: "azureTestIdentity", TenantID: "cafecafe-cafe-4aaa-cafe-cafecafecafe", ClientID: "decaffff-cafe-4aaa-cafe-cafecafecafe", @@ -66,8 +78,10 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith tests := []struct { name string url string + params map[string]string headers map[string]string privateKey crypto.Signer + secretFunc azureRequestModifier expectedHandle bool expectedCode int expectedBody string @@ -81,8 +95,9 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, { name: "invalid request, wrong secret", - url: "https://azure-msi.teleport.dev/bad-secret", - headers: nil, + url: endpoint, + headers: map[string]string{}, + secretFunc: endpointSecret("bad-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, @@ -90,8 +105,9 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, { name: "invalid request, missing secret", - url: "https://azure-msi.teleport.dev", - headers: nil, + url: endpoint, + headers: map[string]string{}, + secretFunc: emptySecretMethod, privateKey: privateKey, expectedHandle: true, expectedCode: 400, @@ -99,8 +115,9 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, { name: "invalid request, missing metadata", - url: "https://azure-msi.teleport.dev/my-secret", - headers: nil, + url: endpoint, + headers: map[string]string{}, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, @@ -108,8 +125,9 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, { name: "invalid request, bad metadata value", - url: "https://azure-msi.teleport.dev/my-secret", + url: endpoint, headers: map[string]string{"Metadata": "false"}, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, @@ -117,44 +135,63 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, { name: "invalid request, missing arguments", - url: "https://azure-msi.teleport.dev/my-secret", + url: endpoint, headers: map[string]string{"Metadata": "true"}, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}", }, { - name: "invalid request, missing resource", - url: "https://azure-msi.teleport.dev/my-secret?msi_res_id=azureTestIdentity", - headers: map[string]string{"Metadata": "true"}, + name: "invalid request, missing resource", + url: endpoint, + headers: map[string]string{"Metadata": "true"}, + params: map[string]string{ + resourceFieldName: "azureTestIdentity", + }, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, expectedBody: "{\n \"error\": {\n \"message\": \"missing value for parameter 'resource'\"\n }\n}", }, { - name: "invalid request, missing identity", - url: "https://azure-msi.teleport.dev/my-secret?resource=myresource", - headers: map[string]string{"Metadata": "true"}, + name: "invalid request, missing identity", + url: endpoint, + headers: map[string]string{"Metadata": "true"}, + params: map[string]string{ + "resource": "myresource", + }, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, - expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': \"\n }\n}", + expectedBody: fmt.Sprintf("{\n \"error\": {\n \"message\": \"unexpected value for parameter '%s': \"\n }\n}", resourceFieldName), }, { - name: "invalid request, wrong identity", - url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestWrongIdentity", - headers: map[string]string{"Metadata": "true"}, + name: "invalid request, wrong identity", + url: endpoint, + headers: map[string]string{"Metadata": "true"}, + params: map[string]string{ + resourceFieldName: "azureTestWrongIdentity", + "resource": "myresource", + }, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 400, - expectedBody: "{\n \"error\": {\n \"message\": \"unexpected value for parameter 'msi_res_id': azureTestWrongIdentity\"\n }\n}", + expectedBody: fmt.Sprintf("{\n \"error\": {\n \"message\": \"unexpected value for parameter '%s': azureTestWrongIdentity\"\n }\n}", resourceFieldName), }, { - name: "well-formatted request", - url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity", - headers: map[string]string{"Metadata": "true"}, + name: "well-formatted request", + url: endpoint, + headers: map[string]string{"Metadata": "true"}, + params: map[string]string{ + resourceFieldName: "azureTestIdentity", + "resource": "myresource", + }, + secretFunc: endpointSecret("my-secret"), privateKey: privateKey, expectedHandle: true, expectedCode: 200, @@ -213,13 +250,18 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }, }, { - name: "no private key set", - url: "https://azure-msi.teleport.dev/my-secret?resource=myresource&msi_res_id=azureTestIdentity", - headers: map[string]string{"Metadata": "true"}, + name: "no private key set", + url: endpoint, + headers: map[string]string{"Metadata": "true"}, + params: map[string]string{ + resourceFieldName: "azureTestIdentity", + "resource": "myresource", + }, + secretFunc: endpointSecret("my-secret"), privateKey: nil, expectedHandle: true, expectedCode: 500, - expectedBody: "{\n \"error\": {\n \"message\": \"missing private key set in AzureMSIMiddleware\"\n }\n}", + expectedBody: "{\n \"error\": {\n \"message\": \"missing private key set in AzureTokenMiddleware\"\n }\n}", }, } @@ -227,14 +269,22 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith t.Run(tt.name, func(t *testing.T) { m.SetPrivateKey(tt.privateKey) + params := url.Values{} + for name, value := range tt.params { + params.Set(name, value) + } + // prepare request - req, err := http.NewRequest("GET", tt.url, strings.NewReader("")) + req, err := http.NewRequest("GET", "https://"+tt.url+"?"+params.Encode(), strings.NewReader("")) require.NoError(t, err) for k, v := range tt.headers { req.Header.Set(k, v) } + if tt.secretFunc != nil { + tt.secretFunc(req) + } recorder := httptest.NewRecorder() // run handler @@ -261,3 +311,20 @@ func testAzureMSIMiddlewareHandleRequest(t *testing.T, alg cryptosuites.Algorith }) } } + +// azureRequestModifier modifies an Azure request. +type azureRequestModifier func(req *http.Request) + +func msiSecretModifier(secret string) azureRequestModifier { + return func(req *http.Request) { + req.URL = req.URL.JoinPath(secret) + } +} + +func identitySecretModifier(secret string) azureRequestModifier { + return func(req *http.Request) { + req.Header.Add(IdentitySecretHeader, secret) + } +} + +func emptySecretMethod(_ *http.Request) {} diff --git a/lib/srv/alpnproxy/forward_proxy.go b/lib/srv/alpnproxy/forward_proxy.go index f1fec62a8fab1..64ee8242a745c 100644 --- a/lib/srv/alpnproxy/forward_proxy.go +++ b/lib/srv/alpnproxy/forward_proxy.go @@ -202,7 +202,7 @@ func isAWSSSMWebsocketRequest(req *http.Request) bool { // request. func MatchAzureRequests(req *http.Request) bool { h := req.URL.Hostname() - return azure.IsAzureEndpoint(h) || types.TeleportAzureMSIEndpoint == h + return azure.IsAzureEndpoint(h) || types.TeleportAzureMSIEndpoint == h || types.TeleportAzureIdentityEndpoint == h } // MatchGCPRequests is a MatchFunc that returns true if request is an GCP API request. diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 04cf06ff5e950..3192fcb95471f 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -257,8 +257,10 @@ func (s *handler) parseAuthHeader(token string, pubKey crypto.PublicKey) (*jwt.A // Create a new key that can sign and verify tokens. key, err := jwt.New(&jwt.Config{ - Clock: s.Clock, - PublicKey: pubKey, + Clock: s.Clock, + PublicKey: pubKey, + // TODO(gabrielcorado): use the cluster name. This value must match the + // one used by the local proxy middleware. ClusterName: types.TeleportAzureMSIEndpoint, }) if err != nil { diff --git a/lib/web/app/transport.go b/lib/web/app/transport.go index 918dc0fd47283..b3af7e27b3122 100644 --- a/lib/web/app/transport.go +++ b/lib/web/app/transport.go @@ -287,8 +287,10 @@ func (t *transport) resignAzureJWTCookie(r *http.Request) error { // Create a new jwt key using the client public key to verify and parse the token. clientJWTKey, err := jwt.New(&jwt.Config{ - Clock: t.c.clock, - PublicKey: r.TLS.PeerCertificates[0].PublicKey, + Clock: t.c.clock, + PublicKey: r.TLS.PeerCertificates[0].PublicKey, + // TODO(gabrielcorado): use the cluster name. This value must match the + // one used by the proxy. ClusterName: types.TeleportAzureMSIEndpoint, }) if err != nil { @@ -301,8 +303,10 @@ func (t *transport) resignAzureJWTCookie(r *http.Request) error { return trace.Wrap(err) } wsJWTKey, err := jwt.New(&jwt.Config{ - Clock: t.c.clock, - PrivateKey: wsPrivateKey, + Clock: t.c.clock, + PrivateKey: wsPrivateKey, + // TODO(gabrielcorado): use the cluster name. This value must match the + // one used by the proxy. ClusterName: types.TeleportAzureMSIEndpoint, }) if err != nil { diff --git a/tool/tsh/common/app.go b/tool/tsh/common/app.go index 43d20a558898f..8f5c38cd38743 100644 --- a/tool/tsh/common/app.go +++ b/tool/tsh/common/app.go @@ -108,7 +108,7 @@ func onAppLogin(cf *CLIConf) error { return trace.Wrap(err) } - if err := printAppCommand(cf, tc, app, routeToApp); err != nil { + if err := printAppCommand(cf, tc, app, appInfo); err != nil { return trace.Wrap(err) } @@ -132,7 +132,8 @@ func localProxyRequiredForApp(tc *client.TeleportClient) bool { return tc.TLSRoutingConnUpgradeRequired } -func printAppCommand(cf *CLIConf, tc *client.TeleportClient, app types.Application, routeToApp proto.RouteToApp) error { +func printAppCommand(cf *CLIConf, tc *client.TeleportClient, app types.Application, appInfo *appInfo) error { + routeToApp := appInfo.RouteToApp output := cf.Stdout() if cf.Quiet { output = io.Discard @@ -151,11 +152,24 @@ func printAppCommand(cf *CLIConf, tc *client.TeleportClient, app types.Applicati return trace.BadParameter("app is Azure Cloud but Azure identity is missing") } - var args []string + azureApp, err := newAzureApp(tc, cf, appInfo) + if err != nil { + return trace.Wrap(err) + } + + resourceArgumentName := "--username" + // After the CLI started relying in MSAL by default, the param for the + // managed identity changed. + // + // https://learn.microsoft.com/en-us/cli/azure/release-notes-azure-cli?view=azure-cli-latest#profile + if azureApp.usingMSAL() { + resourceArgumentName = "--resource-id" + } + + args := []string{"az", "login", "--identity", resourceArgumentName, routeToApp.AzureIdentity} if cf.Debug { args = append(args, "--debug") } - args = append(args, "az", "login", "--identity", "-u", routeToApp.AzureIdentity) // automatically login with right identity. cmd := exec.Command(cf.executablePath, args...) @@ -164,8 +178,7 @@ func printAppCommand(cf *CLIConf, tc *client.TeleportClient, app types.Applicati cmd.Stdout = output logger.DebugContext(cf.Context, "Running automatic az login", "command", logutils.StringerAttr(cmd)) - err := cf.RunCommand(cmd) - if err != nil { + if err := cf.RunCommand(cmd); err != nil { return trace.Wrap(err, "failed to automatically login with `az login` using identity %q; run with --debug for details", routeToApp.AzureIdentity) } diff --git a/tool/tsh/common/app_azure.go b/tool/tsh/common/app_azure.go index 74a70e770115d..d582d7bc5d6bb 100644 --- a/tool/tsh/common/app_azure.go +++ b/tool/tsh/common/app_azure.go @@ -19,16 +19,20 @@ package common import ( + "bytes" "context" "crypto" "crypto/tls" + "encoding/json" "fmt" "os" "os/exec" "path/filepath" "sort" "strings" + "sync" + "github.com/coreos/go-semver/semver" "github.com/google/uuid" "github.com/gravitational/trace" @@ -44,6 +48,24 @@ import ( const ( azureCLIBinaryName = "az" + + // msiEndpointEnvVarName defines the name of environment variable that + // contains the MSI endpoint value. + msiEndpointEnvVarName = "MSI_ENDPOINT" + // identityEndpointEnvVarName defines the name of environment variable that + // contains the App Service Identity endpoint value. + identityEndpointEnvVarName = "IDENTITY_ENDPOINT" + // identityHeaderEnvVarName defines the name of environment variable that + // contains the App Service Identity secret value. + identityHeaderEnvVarName = "IDENTITY_HEADER" +) + +var ( + // azureCLIVersionMSALRequirement represents the version the login with + // managed identities started using MSAL by default. + // + // https://learn.microsoft.com/en-us/cli/azure/release-notes-azure-cli?view=azure-cli-latest#profile + azureCLIVersionMSALRequirement = semver.New("2.73.0") ) func onAzure(cf *CLIConf) error { @@ -73,13 +95,15 @@ func onAzure(cf *CLIConf) error { type azureApp struct { *localProxyApp - cf *CLIConf - msiSecret string + cf *CLIConf + tokenSecret string + // fetchAzureCLIVersion retrieves the Azure CLI version. + fetchCLIVersion func() (*semver.Version, error) } // newAzureApp creates a new Azure app. func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azureApp, error) { - msiSecret, err := getMSISecret() + msiSecret, err := getAzureTokenSecret() if err != nil { return nil, err } @@ -91,13 +115,42 @@ func newAzureApp(tc *client.TeleportClient, cf *CLIConf, appInfo *appInfo) (*azu return &azureApp{ localProxyApp: localProxyApp, cf: cf, - msiSecret: msiSecret, + tokenSecret: msiSecret, + fetchCLIVersion: sync.OnceValues(func() (*semver.Version, error) { + // Retrieve the core version as it contains the login-related changes. + versionInfo := struct { + CLICoreVersion string `json:"azure-cli-core"` + }{} + + var buf bytes.Buffer + cmd := exec.Command(azureCLIBinaryName, "version") + cmd.Stdout = &buf + if err := cf.RunCommand(cmd); err != nil { + return nil, trace.Wrap(err) + } + + if err := json.Unmarshal(buf.Bytes(), &versionInfo); err != nil { + return nil, trace.Wrap(err) + } + + ver, err := semver.NewVersion(versionInfo.CLICoreVersion) + return ver, trace.Wrap(err) + }), }, nil } -// getMSISecret will try to find the secret by parsing MSI_ENDPOINT env variable if present; it will return random hex string otherwise. -func getMSISecret() (string, error) { - endpoint := os.Getenv("MSI_ENDPOINT") +// getAzureTokenSecret will try to find the secret from the environment. +// If not found it will return random hex string. +func getAzureTokenSecret() (string, error) { + if secret, err := getAzureIdentitySecretToken(); !trace.IsNotFound(err) { + if err != nil { + return "", trace.Wrap(err) + } + + return secret, nil + } + + endpoint := os.Getenv(msiEndpointEnvVarName) if endpoint == "" { randomHex, err := utils.CryptoRandomHex(10) if err != nil { @@ -108,7 +161,7 @@ func getMSISecret() (string, error) { expectedPrefix := "https://" + types.TeleportAzureMSIEndpoint + "/" if !strings.HasPrefix(endpoint, expectedPrefix) { - return "", trace.BadParameter("MSI_ENDPOINT not empty, but doesn't start with %q as expected", expectedPrefix) + return "", trace.BadParameter("%q environment variable not empty, but doesn't start with %q as expected", msiEndpointEnvVarName, expectedPrefix) } secret := strings.TrimPrefix(endpoint, expectedPrefix) @@ -118,21 +171,43 @@ func getMSISecret() (string, error) { return secret, nil } +// getAzureIdentitySecretToken returns the secret token for App Service Identity. +func getAzureIdentitySecretToken() (string, error) { + endpoint := os.Getenv(identityEndpointEnvVarName) + secret := os.Getenv(identityHeaderEnvVarName) + if endpoint == "" && secret == "" { + return "", trace.NotFound("App Service Identity environment variables not provided") + } + + if endpoint == "" || secret == "" { + return "", trace.BadParameter("%q and %q environment variables should be provided when using App Service Identity", identityEndpointEnvVarName, identityHeaderEnvVarName) + } + + expectedPrefix := "https://" + types.TeleportAzureIdentityEndpoint + if !strings.HasPrefix(endpoint, expectedPrefix) { + return "", trace.BadParameter("%s not empty, but doesn't start with %q as expected", identityEndpointEnvVarName, expectedPrefix) + } + + return secret, nil +} + // StartLocalProxies sets up local proxies for serving Azure clients. // // At minimum clients should work with these variables set: // - HTTPS_PROXY, for routing the traffic through the proxy -// - MSI_ENDPOINT, for informing the client about credential provider endpoint +// - MSI_ENDPOINT or IDENTITY_ENDPOINT, for informing the client about credential provider endpoint // // The request flow to remote server (i.e. Azure APIs) looks like this: // clients -> local forward proxy -> local ALPN proxy -> remote server // -// However, with MSI_ENDPOINT variable set, clients will reach out to this address for tokens. -// We intercept calls to https://azure-msi.teleport.dev using alpnproxy.AzureMSIMiddleware. -// These calls are served entirely locally, which helps the overall performance experienced by the user. +// However, with MSI_ENDPOINT or IDENTITY_ENDPOINT variable set, clients will +// reach out to this address for tokens. +// We intercept calls to those token endpoints using alpnproxy.AzureTokensMiddleware. +// These calls are served entirely locally, which helps the overall performance +// experienced by the user. func (a *azureApp) StartLocalProxies(ctx context.Context) error { - azureMiddleware := &alpnproxy.AzureMSIMiddleware{ - Secret: a.msiSecret, + azureMiddleware := &alpnproxy.AzureTokenMiddleware{ + Secret: a.tokenSecret, // we could, in principle, get the actual TenantID either from live data or from static configuration, // but at this moment there is no clear advantage over simply issuing a new random identifier. TenantID: uuid.New().String(), @@ -166,10 +241,6 @@ func (a *azureApp) GetEnvVars() (map[string]string, error) { // 2. `az ...` in another console // without custom config dir the second invocation will hang, attempting to connect to (inaccessible without configuration) MSI. "AZURE_CONFIG_DIR": filepath.Join(profile.FullProfilePath(a.cf.HomePath), "azure", a.routeToApp.ClusterName, a.routeToApp.Name), - // setting MSI_ENDPOINT instructs Azure CLI to make managed identity calls on this address. - // the requests will be handled by tsh proxy. - "MSI_ENDPOINT": "https://" + types.TeleportAzureMSIEndpoint + "/" + a.msiSecret, - // Needed for az CLI to accept our certs. // This isn't portable and applications other than az CLI may have to set different env variables, // add the application cert to system root store (not recommended, ultimate fallback) @@ -177,6 +248,22 @@ func (a *azureApp) GetEnvVars() (map[string]string, error) { "REQUESTS_CA_BUNDLE": a.profile.AppLocalCAPath(a.cf.SiteName, a.routeToApp.Name), } + if a.usingMSAL() { + // Setting App service Identity environment variables instructs Azure + // CLI to make managed identity calls on this address. The requests will + // be handled by tsh proxy. This is only required when Azure CLI + // defaults to using MSAL. + // + // https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference + envVars[identityEndpointEnvVarName] = "https://" + types.TeleportAzureIdentityEndpoint + envVars[identityHeaderEnvVarName] = a.tokenSecret + } else { + // Setting MSI environment variable instructs Azure CLI to make managed + // identity calls on this address. The requests will be handled by tsh + // proxy. + envVars[msiEndpointEnvVarName] = "https://" + types.TeleportAzureMSIEndpoint + "/" + a.tokenSecret + } + // Set proxy settings. if a.localForwardProxy != nil { envVars["HTTPS_PROXY"] = "http://" + a.localForwardProxy.GetAddr() @@ -207,6 +294,19 @@ func (a *azureApp) RunCommand(cmd *exec.Cmd) error { return nil } +// usingMSAL returns true if the CLI is using Microsoft Authentication +// Library (MSAL). +func (a *azureApp) usingMSAL() bool { + ver, err := a.fetchCLIVersion() + if err != nil { + logger.WarnContext(a.cf.Context, "Unable to determine Azure CLI version. Assuming MSAL will be used.", "error", err) + return true + } + + logger.DebugContext(a.cf.Context, "Azure CLI version", "version", ver) + return ver.Compare(*azureCLIVersionMSALRequirement) >= 0 +} + func printAzureIdentities(identities []string) { fmt.Println(formatAzureIdentities(identities)) } diff --git a/tool/tsh/common/app_azure_test.go b/tool/tsh/common/app_azure_test.go index 2a0db09e4490c..8254105499399 100644 --- a/tool/tsh/common/app_azure_test.go +++ b/tool/tsh/common/app_azure_test.go @@ -22,8 +22,7 @@ import ( "context" "crypto/tls" "encoding/hex" - "encoding/json" - "io" + "fmt" "net/http" "net/url" "os/exec" @@ -31,6 +30,7 @@ import ( "strings" "testing" + "github.com/coreos/go-semver/semver" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -79,136 +79,136 @@ func TestAzure(t *testing.T) { require.NoError(t, err) } - // set MSI_ENDPOINT along with secret - t.Setenv("MSI_ENDPOINT", "https://azure-msi.teleport.dev/very-secret") - - // Log into Teleport cluster. - run([]string{"login", "--insecure", "--debug", "--proxy", proxyAddr.String()}) - - // Log into the "azure-api" app. - // Verify `tsh az login ...` gets called. - run([]string{"app", "login", "--insecure", "--azure-identity", "dummy_azure_identity", "azure-api"}, - setCmdRunner(func(cmd *exec.Cmd) error { - require.Equal(t, []string{"az", "login", "--identity", "-u", "dummy_azure_identity"}, cmd.Args[1:]) - return nil - })) - - // Log into the "azure-api" app -- now with --debug flag. - run([]string{"app", "login", "--insecure", "azure-api", "--debug"}, - setCmdRunner(func(cmd *exec.Cmd) error { - require.Equal(t, []string{"--debug", "az", "login", "--identity", "-u", "dummy_azure_identity"}, cmd.Args[1:]) - return nil - })) - - // basic requests to verify we can dial proxy as expected - // more comprehensive tests cover AzureMSIMiddleware directly - requests := []struct { - name string - url string - headers map[string]string - expectedCode int - expectedBody []byte - verifyBody func(t *testing.T, body []byte) + getEnvValue := func(cmdEnv []string, key string) string { + for _, env := range cmdEnv { + if strings.HasPrefix(env, key+"=") { + return strings.TrimPrefix(env, key+"=") + } + } + return "" + } + + versionWithoutMSAL := semver.New(azureCLIVersionMSALRequirement.String()) + versionWithoutMSAL.Minor -= 1 + + for name, tc := range map[string]struct { + setEnvironment func(t *testing.T) + cliVersion *semver.Version + tokenEndpointURL string + expectedLoginCommand []string + assertCommandEnv require.ValueAssertionFunc }{ - { - name: "incomplete request", - url: "https://azure-msi.teleport.dev/very-secret", - headers: nil, - expectedCode: 400, - expectedBody: []byte("{\n \"error\": {\n \"message\": \"expected Metadata header with value 'true'\"\n }\n}"), + "MSI": { + setEnvironment: func(t *testing.T) { + // This is required to avoid having a random generated secret. + t.Setenv(msiEndpointEnvVarName, "https://azure-msi.teleport.dev/very-secret") + }, + cliVersion: versionWithoutMSAL, + tokenEndpointURL: "https://azure-msi.teleport.dev/very-secret", + expectedLoginCommand: []string{"az", "login", "--identity", "--username", "dummy_azure_identity"}, + assertCommandEnv: func(t require.TestingT, val any, msgAndArgs ...any) { + env := val.([]string) + require.Equal(t, "https://azure-msi.teleport.dev/very-secret", getEnvValue(env, msiEndpointEnvVarName)) + }, }, - { - name: "well-formatted request", - url: "https://azure-msi.teleport.dev/very-secret?resource=myresource&msi_res_id=dummy_azure_identity", - headers: map[string]string{"Metadata": "true"}, - expectedCode: 200, - verifyBody: func(t *testing.T, body []byte) { - var req struct { - AccessToken string `json:"access_token"` - ClientID string `json:"client_id"` - Resource string `json:"resource"` - TokenType string `json:"token_type"` - ExpiresIn int `json:"expires_in"` - ExpiresOn int `json:"expires_on"` - ExtExpiresIn int `json:"ext_expires_in"` - NotBefore int `json:"not_before"` + "Identity": { + setEnvironment: func(t *testing.T) { + // This is required to avoid having a random generated secret. + t.Setenv(identityEndpointEnvVarName, "https://azure-identity.teleport.dev") + t.Setenv(identityHeaderEnvVarName, "very-secret") + }, + cliVersion: azureCLIVersionMSALRequirement, + tokenEndpointURL: "https://azure-identity.teleport.dev", + expectedLoginCommand: []string{"az", "login", "--identity", "--resource-id", "dummy_azure_identity"}, + assertCommandEnv: func(t require.TestingT, val any, msgAndArgs ...any) { + env := val.([]string) + require.Equal(t, "https://azure-identity.teleport.dev", getEnvValue(env, identityEndpointEnvVarName)) + require.Equal(t, "very-secret", getEnvValue(env, identityHeaderEnvVarName)) + }, + }, + } { + t.Run("With"+name, func(t *testing.T) { + handleAzVersion := func(cmd *exec.Cmd) bool { + if len(cmd.Args) > 0 && cmd.Args[1] == "version" { + fmt.Fprintf(cmd.Stdout, `{ "azure-cli": "%s", "azure-cli-core": "%s", "azure-cli-telemetry": "1.1.0", "extensions": {} }`, tc.cliVersion.String(), tc.cliVersion.String()) + return true } + return false + } - require.NoError(t, json.Unmarshal(body, &req)) + tc.setEnvironment(t) - require.NotEmpty(t, req.AccessToken) - require.NotEmpty(t, req.ClientID) - require.Equal(t, "myresource", req.Resource) - require.NotZero(t, req.ExpiresIn) - require.NotZero(t, req.ExpiresOn) - require.NotZero(t, req.ExtExpiresIn) - require.NotZero(t, req.NotBefore) - }, - }, - } + // Log into Teleport cluster. + run([]string{"login", "--insecure", "--debug", "--proxy", proxyAddr.String()}) + + // Log into the "azure-api" app. + // Verify `tsh az login ...` gets called. + run([]string{"app", "login", "--insecure", "--azure-identity", "dummy_azure_identity", "azure-api"}, + setCmdRunner(func(cmd *exec.Cmd) error { + if handleAzVersion(cmd) { + return nil + } - // Run `tsh az vm ls`. Verify executed command and environment. - run([]string{"az", "vm", "ls", "-g", "my-group"}, - setCmdRunner(func(cmd *exec.Cmd) error { - require.Equal(t, []string{"az", "vm", "ls", "-g", "my-group"}, cmd.Args) + require.Equal(t, tc.expectedLoginCommand, cmd.Args[1:]) + return nil + })) - getEnvValue := func(key string) string { - for _, env := range cmd.Env { - if strings.HasPrefix(env, key+"=") { - return strings.TrimPrefix(env, key+"=") + // Log into the "azure-api" app -- now with --debug flag. + run([]string{"app", "login", "--insecure", "azure-api", "--debug"}, + setCmdRunner(func(cmd *exec.Cmd) error { + if handleAzVersion(cmd) { + return nil } - } - return "" - } - require.Equal(t, filepath.Join(tmpHomePath, "azure/localhost/azure-api"), getEnvValue("AZURE_CONFIG_DIR")) - require.Equal(t, "https://azure-msi.teleport.dev/very-secret", getEnvValue("MSI_ENDPOINT")) - require.Equal(t, filepath.Join(tmpHomePath, "keys/127.0.0.1/alice@example.com-app/localhost/azure-api-localca.pem"), getEnvValue("REQUESTS_CA_BUNDLE")) - require.True(t, strings.HasPrefix(getEnvValue("HTTPS_PROXY"), "http://127.0.0.1:")) + require.Equal(t, append(tc.expectedLoginCommand, "--debug"), cmd.Args[1:]) + return nil + })) - // Validate MSI endpoint can be reached - caPool, err := utils.NewCertPoolFromPath(getEnvValue("REQUESTS_CA_BUNDLE")) - require.NoError(t, err) + // Run `tsh az vm ls`. Verify executed command and environment. + run([]string{"az", "vm", "ls", "-g", "my-group"}, + setCmdRunner(func(cmd *exec.Cmd) error { + if handleAzVersion(cmd) { + return nil + } - httpsProxy, err := url.Parse(getEnvValue("HTTPS_PROXY")) - require.NoError(t, err) + require.Equal(t, []string{"az", "vm", "ls", "-g", "my-group"}, cmd.Args) - client := &http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(httpsProxy), - TLSClientConfig: &tls.Config{ - RootCAs: caPool, - }, - }, - } + require.Equal(t, filepath.Join(tmpHomePath, "azure/localhost/azure-api"), getEnvValue(cmd.Env, "AZURE_CONFIG_DIR")) + require.Equal(t, filepath.Join(tmpHomePath, "keys/127.0.0.1/alice@example.com-app/localhost/azure-api-localca.pem"), getEnvValue(cmd.Env, "REQUESTS_CA_BUNDLE")) + require.True(t, strings.HasPrefix(getEnvValue(cmd.Env, "HTTPS_PROXY"), "http://127.0.0.1:")) + + tc.assertCommandEnv(t, cmd.Env) - for _, tc := range requests { - t.Run(tc.name, func(t *testing.T) { - req, err := http.NewRequest("GET", tc.url, nil) + // Validate MSI endpoint can be reached + caPool, err := utils.NewCertPoolFromPath(getEnvValue(cmd.Env, "REQUESTS_CA_BUNDLE")) require.NoError(t, err) - for k, v := range tc.headers { - req.Header.Set(k, v) + httpsProxy, err := url.Parse(getEnvValue(cmd.Env, "HTTPS_PROXY")) + require.NoError(t, err) + + // Dial using the Azure token service to ensure it will be + // reachable and handled when requested by the Azure CLI. + client := &http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(httpsProxy), + TLSClientConfig: &tls.Config{RootCAs: caPool}, + }, } - resp, err := client.Do(req) + req, err := http.NewRequest("GET", tc.tokenEndpointURL, nil) require.NoError(t, err) - require.Equal(t, tc.expectedCode, resp.StatusCode) - body, err := io.ReadAll(resp.Body) + // Given the missing params, the request should return error. + resp, err := client.Do(req) require.NoError(t, err) - require.NoError(t, resp.Body.Close()) + defer resp.Body.Close() + require.NotNil(t, resp) + require.Equal(t, http.StatusBadRequest, resp.StatusCode) - if tc.verifyBody != nil { - tc.verifyBody(t, body) - } else { - require.Equal(t, tc.expectedBody, body) - } - }) - } - - return nil - })) + return nil + })) + }) + } } func makeUserWithAzureRole(t *testing.T) (types.User, types.Role) { @@ -401,17 +401,19 @@ func Test_getAzureIdentityFromFlags(t *testing.T) { } } -func Test_getMSISecret(t *testing.T) { +func Test_getAzureTokenSecret(t *testing.T) { tests := []struct { - name string - env string - want string - wantFunc func(t require.TestingT, result string) - wantErr require.ErrorAssertionFunc + name string + msiEndpoint string + identityHeader string + identityEndpoint string + want string + wantFunc func(t require.TestingT, result string) + wantErr require.ErrorAssertionFunc }{ { - name: "no env", - env: "", + name: "no env", + msiEndpoint: "", wantFunc: func(t require.TestingT, result string) { bytes, err := hex.DecodeString(result) require.NoError(t, err) @@ -421,31 +423,55 @@ func Test_getMSISecret(t *testing.T) { wantErr: require.NoError, }, { - name: "MSI_ENDPOINT with secret", - env: "https://azure-msi.teleport.dev/mysecret", - want: "mysecret", - wantErr: require.NoError, + name: "MSI_ENDPOINT with secret", + msiEndpoint: "https://" + types.TeleportAzureMSIEndpoint + "/mysecret", + want: "mysecret", + wantErr: require.NoError, }, { - name: "MSI_ENDPOINT with invalid prefix", - env: "dummy", + name: "MSI_ENDPOINT with invalid prefix", + msiEndpoint: "dummy", wantErr: func(t require.TestingT, err error, i ...interface{}) { - require.ErrorContains(t, err, `MSI_ENDPOINT not empty, but doesn't start with "https://azure-msi.teleport.dev/" as expected`) + require.ErrorContains(t, err, `"MSI_ENDPOINT" environment variable not empty, but doesn't start with "https://azure-msi.teleport.dev/" as expected`) }, }, { - name: "MSI_ENDPOINT without secret", - env: "https://azure-msi.teleport.dev/", + name: "MSI_ENDPOINT without secret", + msiEndpoint: "https://" + types.TeleportAzureMSIEndpoint + "/", wantErr: func(t require.TestingT, err error, i ...interface{}) { require.ErrorContains(t, err, "MSI secret cannot be empty") }, }, + { + name: "IDENTITY_HEADER and IDENTITY_ENDPOINT present", + identityHeader: "secret", + identityEndpoint: "https://" + types.TeleportAzureIdentityEndpoint, + want: "secret", + wantErr: require.NoError, + }, + { + name: "IDENTITY_HEADER present without endpoint", + identityHeader: "secret", + wantErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, `IDENTITY_HEADER`) + }, + }, + { + name: "Identity and MSI present, identity takes precedence", + identityHeader: "secret", + identityEndpoint: "https://" + types.TeleportAzureIdentityEndpoint, + msiEndpoint: "https://azure-msi.teleport.dev/different-secret", + want: "secret", + wantErr: require.NoError, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - t.Setenv("MSI_ENDPOINT", tt.env) - result, err := getMSISecret() + t.Setenv(msiEndpointEnvVarName, tt.msiEndpoint) + t.Setenv(identityHeaderEnvVarName, tt.identityHeader) + t.Setenv(identityEndpointEnvVarName, tt.identityEndpoint) + result, err := getAzureTokenSecret() tt.wantErr(t, err) if tt.wantFunc != nil { tt.wantFunc(t, result)