diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index af797b51..9a9c9ee1 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -21,6 +21,9 @@ import ( "testing" "time" + "github.com/golang-jwt/jwt/v5" + "github.com/kylelemons/godebug/pretty" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" @@ -28,8 +31,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" - "github.com/golang-jwt/jwt/v5" - "github.com/kylelemons/godebug/pretty" ) // errorClient is an HTTP client for tests that should fail when confidential.Client sends a request @@ -65,10 +66,10 @@ func TestCertFromPEM(t *testing.T) { const ( authorityFmt = "https://%s/%s" - fakeAuthority = "https://fake_authority/fake" + fakeAuthority = "https://fake_authority/fake_tenant" fakeClientID = "fake_client_id" fakeSecret = "fake_secret" - fakeTokenEndpoint = "https://fake_authority/fake/token" + fakeTokenEndpoint = "https://fake_authority/fake_tenant/token" localhost = "http://localhost" refresh = "fake_refresh" token = "fake_token" @@ -76,7 +77,7 @@ const ( var tokenScope = []string{"the_scope"} -func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) { +func fakeClient(tk accesstokens.TokenResponse, credential Credential, fakeAuthority string, options ...Option) (Client, error) { client, err := New(fakeAuthority, fakeClientID, credential, options...) if err != nil { return Client{}, err @@ -86,7 +87,7 @@ func fakeClient(tk accesstokens.TokenResponse, credential Credential, options .. } client.base.Token.Authority = &fake.Authority{ InstanceResp: authority.InstanceDiscoveryResponse{ - TenantDiscoveryEndpoint: "https://fake_authority/fake/discovery/endpoint", + TenantDiscoveryEndpoint: fakeAuthority + "/discovery/endpoint", Metadata: []authority.InstanceDiscoveryMetadata{ { PreferredNetwork: "fake_authority", @@ -104,8 +105,12 @@ func fakeClient(tk accesstokens.TokenResponse, credential Credential, options .. }, } client.base.Token.Resolver = &fake.ResolveEndpoints{ - Endpoints: authority.NewEndpoints("https://fake_authority/fake/auth", - fakeTokenEndpoint, "https://fake_authority/fake/jwt", "fake_authority"), + Endpoints: authority.NewEndpoints( + fakeAuthority+"/auth", + fakeAuthority+"/token", + fakeAuthority+"/jwt", + fakeAuthority, + ), } client.base.Token.WSTrust = &fake.WSTrust{} return client, nil @@ -137,7 +142,7 @@ func TestAcquireTokenByCredential(t *testing.T) { ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "Bearer", - }, cred) + }, cred, fakeAuthority) if err != nil { t.Fatal(err) } @@ -304,7 +309,7 @@ func TestAcquireTokenByAssertionCallback(t *testing.T) { return "", errors.New("expected error") } cred := NewCredFromAssertionCallback(getAssertion) - client, err := fakeClient(accesstokens.TokenResponse{}, cred) + client, err := fakeClient(accesstokens.TokenResponse{}, cred, fakeAuthority) if err != nil { t.Fatal(err) } @@ -348,7 +353,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) { Oid: "123-456", TenantID: "fake", Subject: "nothing", - Issuer: "https://fake_authority/fake", + Issuer: fakeAuthority, Audience: "abc-123", ExpirationTime: time.Now().Add(time.Hour).Unix(), IssuedAt: time.Now().Add(-5 * time.Minute).Unix(), @@ -363,7 +368,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) { }, } - client, err := fakeClient(tr, cred) + client, err := fakeClient(tr, cred, fakeAuthority) if err != nil { t.Fatal(err) } @@ -590,7 +595,7 @@ func TestNewCredFromCert(t *testing.T) { AccessToken: token, ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, - }, cred, opts...) + }, cred, fakeAuthority, opts...) if err != nil { t.Fatal(err) } @@ -1382,7 +1387,7 @@ func TestWithAuthenticationScheme(t *testing.T) { ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, TokenType: "TokenType", - }, cred) + }, cred, fakeAuthority) if err != nil { t.Fatal(err) } @@ -1401,3 +1406,59 @@ func TestWithAuthenticationScheme(t *testing.T) { t.Fatalf(`unexpected access token "%s"`, result.AccessToken) } } + +func TestAcquireTokenByCredentialFromDSTS(t *testing.T) { + tests := map[string]struct { + cred string + }{ + "secret": {cred: "fake_secret"}, + "signed assertion": {cred: "fake_assertion"}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cred, err := NewCredFromSecret(test.cred) + if err != nil { + t.Fatal(err) + } + client, err := fakeClient(accesstokens.TokenResponse{ + AccessToken: token, + ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)}, + GrantedScopes: accesstokens.Scopes{Slice: tokenScope}, + TokenType: "Bearer", + }, cred, "https://fake_authority/dstsv2/"+authority.DSTSTenant) + if err != nil { + t.Fatal(err) + } + + // expect first attempt to fail + _, err = client.AcquireTokenSilent(context.Background(), tokenScope) + if err == nil { + t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err) + } + + tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope) + if err != nil { + t.Errorf("got err == %s, want err == nil", err) + } + if tk.AccessToken != token { + t.Errorf("unexpected access token %s", tk.AccessToken) + } + + tk, err = client.AcquireTokenSilent(context.Background(), tokenScope) + if err != nil { + t.Errorf("got err == %s, want err == nil", err) + } + if tk.AccessToken != token { + t.Errorf("unexpected access token %s", tk.AccessToken) + } + + // fail for another tenant + tk, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("other")) + if err == nil { + t.Errorf("unexpected nil error from AcquireTokenSilent: %s", err) + } + }) + } +} diff --git a/apps/internal/oauth/oauth.go b/apps/internal/oauth/oauth.go index 5dd9fe08..e0653134 100644 --- a/apps/internal/oauth/oauth.go +++ b/apps/internal/oauth/oauth.go @@ -10,6 +10,8 @@ import ( "io" "time" + "github.com/google/uuid" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" @@ -18,7 +20,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs" - "github.com/google/uuid" ) // ResolveEndpointer contains the methods for resolving authority endpoints. diff --git a/apps/internal/oauth/ops/authority/authority.go b/apps/internal/oauth/ops/authority/authority.go index 360a9f07..a49e0357 100644 --- a/apps/internal/oauth/ops/authority/authority.go +++ b/apps/internal/oauth/ops/authority/authority.go @@ -23,7 +23,7 @@ import ( const ( authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize" - instanceDiscoveryEndpoint = "https://%v/common/discovery/instance" + aadInstanceDiscoveryEndpoint = "https://%v/common/discovery/instance" tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration" regionName = "REGION_NAME" defaultAPIVersion = "2021-10-01" @@ -136,8 +136,12 @@ const ( const ( AAD = "MSSTS" ADFS = "ADFS" + DSTS = "DSTS" ) +// DSTSTenant is referenced throughout multiple files, let us use a const in case we ever need to change it. +const DSTSTenant = "7a433bfc-2514-4697-b467-e0933190487f" + // AuthenticationScheme is an extensibility mechanism designed to be used only by Azure Arc for proof of possession access tokens. type AuthenticationScheme interface { // Extra parameters that are added to the request to the /token endpoint. @@ -235,23 +239,26 @@ func NewAuthParams(clientID string, authorityInfo Info) AuthParams { // - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint // - the resulting authority URL is invalid func (p AuthParams) WithTenant(ID string) (AuthParams, error) { - switch ID { - case "", p.AuthorityInfo.Tenant: - // keep the default tenant because the caller didn't override it + if ID == "" || ID == p.AuthorityInfo.Tenant { return p, nil - case "common", "consumers", "organizations": - if p.AuthorityInfo.AuthorityType == AAD { + } + + var authority string + switch p.AuthorityInfo.AuthorityType { + case AAD: + if ID == "common" || ID == "consumers" || ID == "organizations" { return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID) } - // else we'll return a better error below - } - if p.AuthorityInfo.AuthorityType != AAD { - return p, errors.New("the authority doesn't support tenants") - } - if p.AuthorityInfo.Tenant == "consumers" { - return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`) + if p.AuthorityInfo.Tenant == "consumers" { + return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`) + } + authority = "https://" + path.Join(p.AuthorityInfo.Host, ID) + case ADFS: + return p, errors.New("ADFS authority doesn't support tenants") + case DSTS: + return p, errors.New("dSTS authority doesn't support tenants") } - authority := "https://" + path.Join(p.AuthorityInfo.Host, ID) + info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled) if err == nil { info.Region = p.AuthorityInfo.Region @@ -343,44 +350,50 @@ type Info struct { Host string CanonicalAuthorityURI string AuthorityType string - UserRealmURIPrefix string ValidateAuthority bool Tenant string Region string InstanceDiscoveryDisabled bool } -func firstPathSegment(u *url.URL) (string, error) { - pathParts := strings.Split(u.EscapedPath(), "/") - if len(pathParts) >= 2 { - return pathParts[1], nil - } - - return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) -} - // NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided. func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) { u, err := url.Parse(strings.ToLower(authority)) - if err != nil || u.Scheme != "https" { - return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`) + if err != nil { + return Info{}, fmt.Errorf("couldn't parse authority url: %w", err) + } + if u.Scheme != "https" { + return Info{}, errors.New("authority url scheme must be https") } - tenant, err := firstPathSegment(u) - if err != nil { - return Info{}, err + pathParts := strings.Split(u.EscapedPath(), "/") + if len(pathParts) < 2 { + return Info{}, errors.New(`authority must be an URL such as "https://login.microsoftonline.com/"`) } - authorityType := AAD - if tenant == "adfs" { + + var authorityType, tenant string + switch pathParts[1] { + case "adfs": authorityType = ADFS + case "dstsv2": + if len(pathParts) != 3 { + return Info{}, fmt.Errorf("dSTS authority must be an https URL such as https:///dstsv2/%s", DSTSTenant) + } + if pathParts[2] != DSTSTenant { + return Info{}, fmt.Errorf("dSTS authority only accepts a single tenant %q", DSTSTenant) + } + authorityType = DSTS + tenant = DSTSTenant + default: + authorityType = AAD + tenant = pathParts[1] } // u.Host includes the port, if any, which is required for private cloud deployments return Info{ Host: u.Host, - CanonicalAuthorityURI: fmt.Sprintf("https://%v/%v/", u.Host, tenant), + CanonicalAuthorityURI: authority, AuthorityType: authorityType, - UserRealmURIPrefix: fmt.Sprintf("https://%v/common/userrealm/", u.Hostname()), ValidateAuthority: validateAuthority, Tenant: tenant, InstanceDiscoveryDisabled: instanceDiscoveryDisabled, @@ -524,7 +537,7 @@ func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (I discoveryHost = authorityInfo.Host } - endpoint := fmt.Sprintf(instanceDiscoveryEndpoint, discoveryHost) + endpoint := fmt.Sprintf(aadInstanceDiscoveryEndpoint, discoveryHost) err = c.Comm.JSONCall(ctx, endpoint, http.Header{}, qv, nil, &resp) } return resp, err diff --git a/apps/internal/oauth/ops/authority/authority_test.go b/apps/internal/oauth/ops/authority/authority_test.go index 0ce103fc..6795a8f1 100644 --- a/apps/internal/oauth/ops/authority/authority_test.go +++ b/apps/internal/oauth/ops/authority/authority_test.go @@ -14,6 +14,8 @@ import ( "strings" "testing" + "github.com/google/uuid" + "github.com/kylelemons/godebug/pretty" ) @@ -212,7 +214,7 @@ func TestAADInstanceDiscovery(t *testing.T) { }, { desc: "Success with authorityInfo.Host not in trusted list", - endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, defaultHost), + endpoint: fmt.Sprintf(aadInstanceDiscoveryEndpoint, defaultHost), authInfo: Info{ Host: "host", Tenant: "tenant", @@ -225,7 +227,7 @@ func TestAADInstanceDiscovery(t *testing.T) { }, { desc: "Success with authorityInfo.Host in trusted list", - endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, "login.microsoftonline.de"), + endpoint: fmt.Sprintf(aadInstanceDiscoveryEndpoint, "login.microsoftonline.de"), authInfo: Info{ Host: "login.microsoftonline.de", Tenant: "tenant", @@ -305,7 +307,6 @@ func TestCreateAuthorityInfoFromAuthorityUri(t *testing.T) { Host: "login.microsoftonline.com", CanonicalAuthorityURI: authorityURI, AuthorityType: "MSSTS", - UserRealmURIPrefix: "https://login.microsoftonline.com/common/userrealm/", Tenant: "common", ValidateAuthority: true, } @@ -320,22 +321,32 @@ func TestCreateAuthorityInfoFromAuthorityUri(t *testing.T) { } func TestAuthParamsWithTenant(t *testing.T) { - uuid1 := "00000000-0000-0000-0000-000000000000" - uuid2 := strings.ReplaceAll(uuid1, "0", "1") + uuid1 := uuid.New().String() + uuid2 := uuid.New().String() host := "https://localhost/" - for _, test := range []struct { + + tests := map[string]struct { authority, expectedAuthority, tenant string expectError bool }{ - {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1}, - {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1}, - {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2}, - {authority: host + uuid1, tenant: "common", expectError: true}, - {authority: host + uuid1, tenant: "organizations", expectError: true}, - {authority: host + "adfs", tenant: uuid1, expectError: true}, - {authority: host + "consumers", tenant: uuid1, expectError: true}, - } { - t.Run("", func(t *testing.T) { + "do nothing if tenant override is empty": {authority: host + uuid1, tenant: "", expectedAuthority: host + uuid1}, + "do nothing if tenant override is empty for ADFS": {authority: host + "adfs", tenant: "", expectedAuthority: host + "adfs"}, + "do nothing if tenant override equals tenant": {authority: host + uuid1, tenant: uuid1, expectedAuthority: host + uuid1}, + + "override common to tenant": {authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1}, + "override organizations to tenant": {authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1}, + "override tenant to tenant2": {authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2}, + + "tenant can't be common for AAD": {authority: host + uuid1, tenant: "common", expectError: true}, + "tenant can't be consumers for AAD": {authority: host + uuid1, tenant: "consumers", expectError: true}, + "tenant can't be organizations for AAD": {authority: host + uuid1, tenant: "organizations", expectError: true}, + "can't override tenant for ADFS ever": {authority: host + "adfs", tenant: uuid1, expectError: true}, + "can't override tenant for dSTS ever": {authority: host + "dstsv2/" + DSTSTenant, tenant: uuid1, expectError: true}, + "can't override AAD tenant consumers": {authority: host + "consumers", tenant: uuid1, expectError: true}, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { info, err := NewInfoFromAuthorityURI(test.authority, false, false) if err != nil { t.Fatal(err) diff --git a/apps/internal/oauth/ops/internal/comm/comm.go b/apps/internal/oauth/ops/internal/comm/comm.go index 7d9ec7cd..d62aac74 100644 --- a/apps/internal/oauth/ops/internal/comm/comm.go +++ b/apps/internal/oauth/ops/internal/comm/comm.go @@ -18,10 +18,11 @@ import ( "strings" "time" + "github.com/google/uuid" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" customJSON "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/version" - "github.com/google/uuid" ) // HTTPClient represents an HTTP client. @@ -70,15 +71,13 @@ func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Hea unmarshal = customJSON.Unmarshal } - u, err := url.Parse(endpoint) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("%s?%s", endpoint, qv.Encode()), nil) if err != nil { - return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err) + return fmt.Errorf("could not create request: %w", err) } - u.RawQuery = qv.Encode() addStdHeaders(headers) - - req := &http.Request{Method: http.MethodGet, URL: u, Header: headers} + req.Header = headers if body != nil { // Note: In case your wondering why we are not gzip encoding.... diff --git a/apps/internal/oauth/resolvers.go b/apps/internal/oauth/resolvers.go index 0ade4117..4030ec8d 100644 --- a/apps/internal/oauth/resolvers.go +++ b/apps/internal/oauth/resolvers.go @@ -18,9 +18,6 @@ import ( "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority" ) -// ADFS is an active directory federation service authority type. -const ADFS = "ADFS" - type cacheEntry struct { Endpoints authority.Endpoints ValidForDomainsInList map[string]bool @@ -51,7 +48,7 @@ func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo return endpoints, nil } - endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName) + endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo) if err != nil { return authority.Endpoints{}, err } @@ -83,7 +80,7 @@ func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPr defer m.mu.Unlock() if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok { - if authorityInfo.AuthorityType == ADFS { + if authorityInfo.AuthorityType == authority.ADFS { domain, err := adfsDomainFromUpn(userPrincipalName) if err == nil { if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok { @@ -102,7 +99,7 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use updatedCacheEntry := createcacheEntry(endpoints) - if authorityInfo.AuthorityType == ADFS { + if authorityInfo.AuthorityType == authority.ADFS { // Since we're here, we've made a call to the backend. We want to ensure we're caching // the latest values from the server. if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok { @@ -119,9 +116,12 @@ func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, use m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry } -func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) { - if authorityInfo.Tenant == "adfs" { +func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info) (string, error) { + if authorityInfo.AuthorityType == authority.ADFS { return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil + } else if authorityInfo.AuthorityType == authority.DSTS { + return fmt.Sprintf("https://%s/dstsv2/%s/v2.0/.well-known/openid-configuration", authorityInfo.Host, authority.DSTSTenant), nil + } else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) { resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo) if err != nil { @@ -134,7 +134,6 @@ func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, aut return "", err } return resp.TenantDiscoveryEndpoint, nil - } return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil