diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 165a662f..38ce7572 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -432,6 +432,77 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } } +func TestAcquireTokenByAuthCodeTokenExpiry(t *testing.T) { + accessToken := "initial-access-token" + newAccessToken := "new-access-token" + homeTenant := "home-tenant" + clientInfo := base64.RawStdEncoding.EncodeToString([]byte( + fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant), + )) + lmo := "login.microsoftonline.com" + + originalTime := base.Now + defer func() { + base.Now = originalTime + }() + + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + + mockClient := mock.NewClient() + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common"))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(newAccessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000))) + + client, err := New(fmt.Sprintf(authorityFmt, lmo, "common"), fakeClientID, cred, WithHTTPClient(mockClient), WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + + // Acquire token using auth code + ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) + } + + account := ar.Account + + // First silent call should return cached token + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) + } + if ar.Metadata.TokenSource != base.TokenSourceCache { + t.Fatalf("expected token source %v, got %v", base.TokenSourceCache, ar.Metadata.TokenSource) + } + + // Move time forward past RefreshOn (1001 seconds) + fixedTime := time.Now().Add(time.Duration(1001) * time.Second) + base.Now = func() time.Time { + return fixedTime + } + + // Second silent call should automatically refresh and return new token + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != newAccessToken { + t.Fatalf("expected %q, got %q", newAccessToken, ar.AccessToken) + } + // Verify the token came from the identity provider (refresh), not cache + if ar.Metadata.TokenSource != base.TokenSourceIdentityProvider { + t.Fatalf("expected token source %v, got %v", base.TokenSourceIdentityProvider, ar.Metadata.TokenSource) + } +} func TestInvalidJsonErrFromResponse(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { diff --git a/apps/internal/base/base.go b/apps/internal/base/base.go index 61c1c4ce..06b6ad2e 100644 --- a/apps/internal/base/base.go +++ b/apps/internal/base/base.go @@ -367,8 +367,19 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen // If the token is not same, we don't need to refresh it. // Which means it refreshed. if str, err := m.Read(ctx, authParams); err == nil && str.AccessToken.Secret == ar.AccessToken { - if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { - return b.AuthResultFromToken(ctx, authParams, tr) + switch silent.RequestType { + case accesstokens.ATConfidential: + if tr, er := b.Token.Credential(ctx, authParams, silent.Credential); er == nil { + return b.AuthResultFromToken(ctx, authParams, tr) + } + case accesstokens.ATPublic: + token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, silent.Credential, storageTokenResponse.RefreshToken) + if err != nil { + return ar, err + } + return b.AuthResultFromToken(ctx, authParams, token) + case accesstokens.ATUnknown: + return ar, errors.New("silent request type cannot be ATUnknown") } } } diff --git a/apps/public/public_test.go b/apps/public/public_test.go index fa019ca5..e70a85fb 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -16,6 +16,7 @@ import ( "time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake" @@ -1046,3 +1047,58 @@ func getNewClientWithMockedResponses( return client, nil } + +func TestAcquireTokenSilentWithRefreshOnIsExpired(t *testing.T) { + accessToken := "*" + homeTenant := "home-tenant" + clientInfo := base64.RawStdEncoding.EncodeToString([]byte( + fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant), + )) + lmo := "login.microsoftonline.com" + originalTime := base.Now + defer func() { + base.Now = originalTime + }() + mockClient := mock.NewClient() + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "common"))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000))) + mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody("new-"+accessToken, mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant)), "rt", clientInfo, 36000, 1000))) + + client, err := New("common", + WithAuthority(fmt.Sprintf(authorityFmt, lmo, "common")), + WithHTTPClient(mockClient), + WithInstanceDiscovery(false)) + if err != nil { + t.Fatal(err) + } + // the auth flow isn't important, we just need to populate the cache + ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) + } + account := ar.Account + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != accessToken { + t.Fatalf("expected %q, got %q", accessToken, ar.AccessToken) + } + // moving time forward to expire the current token + fixedTime := time.Now().Add(time.Duration(36001) * time.Second) + base.Now = func() time.Time { + return fixedTime + } + // calling the acquire token again + ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)) + if err != nil { + t.Fatal(err) + } + if ar.AccessToken != "new-"+accessToken { + t.Fatalf("expected %q, got %q", "new-"+accessToken, ar.AccessToken) + } + +}