Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/types/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,8 @@ const (

// TeleportAzureMSIEndpoint is a special URL intercepted by TSH local proxy, serving Azure credentials.
TeleportAzureMSIEndpoint = "azure-msi." + TeleportNamespace
// TeleportAzureIdentityEndpoint is a special URL intercepted by TSH local proxy, serving Azure credentials.
TeleportAzureIdentityEndpoint = "azure-identity." + TeleportNamespace

// ConnectMyComputerNodeOwnerLabel is a label used to control access to the node managed by
// Teleport Connect as part of Connect My Computer. See [teleterm.connectmycomputer.RoleSetup].
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"fmt"
"log/slog"
"net/http"
"strings"
"sync"
"time"

Expand All @@ -35,8 +36,11 @@ import (
"github.com/gravitational/teleport/lib/jwt"
)

// AzureMSIMiddleware implements a simplified version of MSI server serving auth tokens.
type AzureMSIMiddleware struct {
// AzureTokenMiddleware implements a simplified version of MSI and Identity
// servers serving auth tokens.
//
// https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference
type AzureTokenMiddleware struct {
DefaultLocalProxyHTTPMiddleware

// Identity is the Azure identity to be served by the server. Only single identity will be provided.
Expand All @@ -58,14 +62,14 @@ type AzureMSIMiddleware struct {
privateKeyMu sync.RWMutex
}

var _ LocalProxyHTTPMiddleware = &AzureMSIMiddleware{}
var _ LocalProxyHTTPMiddleware = &AzureTokenMiddleware{}

func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
func (m *AzureTokenMiddleware) CheckAndSetDefaults() error {
if m.Clock == nil {
m.Clock = clockwork.NewRealClock()
}
if m.Log == nil {
m.Log = slog.With(teleport.ComponentKey, "azure_msi")
m.Log = slog.With(teleport.ComponentKey, "azure_token")
}

if m.Secret == "" {
Expand All @@ -83,37 +87,45 @@ func (m *AzureMSIMiddleware) CheckAndSetDefaults() error {
return nil
}

func (m *AzureMSIMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
if req.Host == types.TeleportAzureMSIEndpoint {
if err := m.msiEndpoint(rw, req); err != nil {
m.Log.WarnContext(req.Context(), "Bad MSI request", "error", err)
trace.WriteError(rw, trace.Wrap(err))
}
return true
func (m *AzureTokenMiddleware) HandleRequest(rw http.ResponseWriter, req *http.Request) bool {
var err error
switch req.Host {
case types.TeleportAzureMSIEndpoint:
err = m.handleEndpoint(rw, req, MSIResourceFieldName, strings.TrimPrefix(req.URL.Path, "/"))
case types.TeleportAzureIdentityEndpoint:
err = m.handleEndpoint(rw, req, IdentityResourceFieldName, req.Header.Get(IdentitySecretHeader))
default:
m.Log.DebugContext(req.Context(), "Unsupported token host", "host", req.Host)
return false
}
if err != nil {
m.Log.WarnContext(req.Context(), "Bad token request", "error", err)
trace.WriteError(rw, trace.Wrap(err))
}

return false
return true
}

// SetPrivateKey updates the private key.
func (m *AzureMSIMiddleware) SetPrivateKey(privateKey crypto.Signer) {
func (m *AzureTokenMiddleware) SetPrivateKey(privateKey crypto.Signer) {
m.privateKeyMu.Lock()
defer m.privateKeyMu.Unlock()
m.privateKey = privateKey
}
func (m *AzureMSIMiddleware) getPrivateKey() (crypto.Signer, error) {
func (m *AzureTokenMiddleware) getPrivateKey() (crypto.Signer, error) {
m.privateKeyMu.RLock()
defer m.privateKeyMu.RUnlock()
if m.privateKey == nil {
// Use a plain error to return status code 500.
return nil, trace.Errorf("missing private key set in AzureMSIMiddleware")
return nil, trace.Errorf("missing private key set in AzureTokenMiddleware")
}
return m.privateKey, nil
}

func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Request) error {
// handleEndpoint handles the Azure identity token generation.
func (m *AzureTokenMiddleware) handleEndpoint(rw http.ResponseWriter, req *http.Request, resourceFieldName string, secret string) error {
// request validation
if req.URL.Path != ("/" + m.Secret) {
if secret != m.Secret {
return trace.BadParameter("invalid secret")
}

Expand All @@ -122,8 +134,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
return trace.BadParameter("expected Metadata header with value 'true'")
}

err := req.ParseForm()
if err != nil {
if err := req.ParseForm(); err != nil {
return trace.Wrap(err)
}

Expand All @@ -132,19 +143,19 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
return trace.BadParameter("missing value for parameter 'resource'")
}

// check that msi_res_id matches expected Azure Identity
requestedAzureIdentity := req.Form.Get("msi_res_id")
// check that resource field matches expected Azure Identity
requestedAzureIdentity := req.Form.Get(resourceFieldName)
if requestedAzureIdentity != m.Identity {
m.Log.WarnContext(req.Context(), "Requested unexpected identity", "requested_identity", requestedAzureIdentity, "expected_identity", m.Identity)
return trace.BadParameter("unexpected value for parameter 'msi_res_id': %v", requestedAzureIdentity)
return trace.BadParameter("unexpected value for parameter '%s': %v", resourceFieldName, requestedAzureIdentity)
}

respBody, err := m.fetchMSILoginResp(resource)
respBody, err := m.fetchLoginResp(resource)
if err != nil {
return trace.Wrap(err)
}

m.Log.InfoContext(req.Context(), "MSI: returning token for identity", "identity", m.Identity)
m.Log.InfoContext(req.Context(), "Returning token for identity", "identity", m.Identity)

rw.Header().Add("Content-Type", "application/json; charset=utf-8")
rw.Header().Add("Content-Length", fmt.Sprintf("%v", len(respBody)))
Expand All @@ -153,7 +164,7 @@ func (m *AzureMSIMiddleware) msiEndpoint(rw http.ResponseWriter, req *http.Reque
return nil
}

func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error) {
func (m *AzureTokenMiddleware) fetchLoginResp(resource string) ([]byte, error) {
now := m.Clock.Now()

notBefore := now.Add(-10 * time.Second)
Expand Down Expand Up @@ -187,16 +198,18 @@ func (m *AzureMSIMiddleware) fetchMSILoginResp(resource string) ([]byte, error)
return out, nil
}

func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) {
func (m *AzureTokenMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error) {
privateKey, err := m.getPrivateKey()
if err != nil {
return "", trace.Wrap(err)
}
// Create a new key that can sign and verify tokens.
key, err := jwt.New(&jwt.Config{
Clock: m.Clock,
PrivateKey: privateKey,
ClusterName: types.TeleportAzureMSIEndpoint, // todo get cluster name
Clock: m.Clock,
PrivateKey: privateKey,
// TODO(gabrielcorado): use the cluster name. This value must match the
// one used by the proxy.
ClusterName: types.TeleportAzureMSIEndpoint,
})
if err != nil {
return "", trace.Wrap(err)
Expand All @@ -209,3 +222,19 @@ func (m *AzureMSIMiddleware) toJWT(claims jwt.AzureTokenClaims) (string, error)

return token, nil
}

const (
// IdentitySecretHeader is the HTTP header that contains the identity
// secret on App Service identity requests.
//
// https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference
IdentitySecretHeader = "X-IDENTITY-HEADER"
// IdentityResourceFieldName is the request field name that contains the
// Azure identity on App Service identity requests.
//
// https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#rest-endpoint-reference
IdentityResourceFieldName = "mi_res_id"
// MSIResourceFieldName is the request field name that contains the Azure
// Identity on MSI identity requests.
MSIResourceFieldName = "msi_res_id"
)
Loading
Loading