Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ func New(authority, clientID string, cred Credential, options ...Option) (Client

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
claims, loginHint, tenantID, domainHint, prompt string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL
Expand All @@ -369,7 +369,7 @@ type AuthCodeURLOption interface {

// AuthCodeURL creates a URL used to acquire an authorization code. Users need to call CreateAuthorizationCodeURLParameters and pass it in.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
o := authCodeURLOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
Expand All @@ -382,6 +382,7 @@ func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
ap.Claims = o.claims
ap.LoginHint = o.loginHint
ap.DomainHint = o.domainHint
ap.Prompt = o.prompt
return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
}

Expand Down Expand Up @@ -431,6 +432,29 @@ func WithDomainHint(domain string) interface {
}
}

// WithPrompt adds prompt query parameter in the auth url.
func WithPrompt(prompt shared.Prompt) interface {
AuthCodeURLOption
options.CallOption
} {
return struct {
AuthCodeURLOption
options.CallOption
}{
CallOption: options.NewCallOption(
func(a any) error {
switch t := a.(type) {
case *authCodeURLOptions:
t.prompt = prompt.String()
default:
return fmt.Errorf("unexpected options type %T", a)
}
return nil
},
),
}
}

// WithClaims sets additional claims to request for the token, such as those required by conditional access policies.
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
// This option is valid for any token acquisition method.
Expand Down
53 changes: 53 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
Expand Down Expand Up @@ -1774,6 +1775,58 @@ func TestWithDomainHint(t *testing.T) {
}
}

func TestWithPrompt(t *testing.T) {
prompt := shared.PromptLogin
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
if err != nil {
t.Fatal(err)
}
client.base.Token.AccessTokens = &fake.AccessTokens{}
client.base.Token.Authority = &fake.Authority{}
client.base.Token.Resolver = &fake.ResolveEndpoints{}
for _, expectPrompt := range []bool{true, false} {
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
validate := func(v url.Values) error {
if !v.Has("prompt") {
if !expectPrompt {
return nil
}
return errors.New("expected a prompt")
} else if !expectPrompt {
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
}

if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
}
return err
}
var urlOpts []AuthCodeURLOption
if expectPrompt {
urlOpts = append(urlOpts, WithPrompt(prompt))
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
if err == nil {
var parsed *url.URL
parsed, err = url.Parse(u)
if err == nil {
err = validate(parsed.Query())
}
}
if err != nil {
t.Fatal(err)
}
})
}
}

func TestWithAuthenticationScheme(t *testing.T) {
ctx := context.Background()
authScheme := mock.NewTestAuthnScheme()
Expand Down
26 changes: 26 additions & 0 deletions apps/internal/shared/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,29 @@ func (acc Account) IsZero() bool {

// DefaultClient is our default shared HTTP client.
var DefaultClient = &http.Client{}

type Prompt int64

const (
PromptNone Prompt = iota
PromptLogin
PromptSelectAccount
PromptConsent
PromptCreate
)

func (p Prompt) String() string {
switch p {
case PromptNone:
return "none"
case PromptLogin:
return "login"
case PromptSelectAccount:
return "select_account"
case PromptConsent:
return "consent"
case PromptCreate:
return "create"
}
return ""
}
40 changes: 34 additions & 6 deletions apps/public/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ func New(clientID string, options ...Option) (Client, error) {

// authCodeURLOptions contains options for AuthCodeURL
type authCodeURLOptions struct {
claims, loginHint, tenantID, domainHint string
claims, loginHint, tenantID, domainHint, prompt string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL
Expand All @@ -159,7 +159,7 @@ type AuthCodeURLOption interface {

// AuthCodeURL creates a URL used to acquire an authorization code.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID], [WithPrompt]
func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
o := authCodeURLOptions{}
if err := options.ApplyOptions(&o, opts); err != nil {
Expand All @@ -172,6 +172,7 @@ func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string,
ap.Claims = o.claims
ap.LoginHint = o.loginHint
ap.DomainHint = o.domainHint
ap.Prompt = o.prompt
return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, ap)
}

Expand Down Expand Up @@ -526,9 +527,9 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error {

// interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow.
type interactiveAuthOptions struct {
claims, domainHint, loginHint, redirectURI, tenantID string
openURL func(url string) error
authnScheme AuthenticationScheme
claims, domainHint, loginHint, redirectURI, tenantID, prompt string
openURL func(url string) error
authnScheme AuthenticationScheme
}

// AcquireInteractiveOption is implemented by options for AcquireTokenInteractive
Expand Down Expand Up @@ -590,6 +591,33 @@ func WithDomainHint(domain string) interface {
}
}

// WithPrompt adds the IdP prompt query parameter in the auth url.
func WithPrompt(prompt shared.Prompt) interface {
AcquireInteractiveOption
AuthCodeURLOption
options.CallOption
} {
return struct {
AcquireInteractiveOption
AuthCodeURLOption
options.CallOption
}{
CallOption: options.NewCallOption(
func(a any) error {
switch t := a.(type) {
case *authCodeURLOptions:
t.prompt = prompt.String()
case *interactiveAuthOptions:
t.prompt = prompt.String()
default:
return fmt.Errorf("unexpected options type %T", a)
}
return nil
},
),
}
}

// WithRedirectURI sets a port for the local server used in interactive authentication, for
// example http://localhost:port. All URI components other than the port are ignored.
func WithRedirectURI(redirectURI string) interface {
Expand Down Expand Up @@ -674,7 +702,7 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string,
authParams.LoginHint = o.loginHint
authParams.DomainHint = o.domainHint
authParams.State = uuid.New().String()
authParams.Prompt = "select_account"
authParams.Prompt = o.prompt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We’re now changing the default from "select_account" to an empty string "".
Will need to test this a bit more to ensure everything behaves as expected.

Copy link
Contributor

@rayluo rayluo Oct 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was not aware that the "select_account" was the default value. That seems to be debatable choice in the first place. The default value on the protocol level is empty or completely absent.

In any case, changing the default from "select_account" to empty "" shall not cause catastrophic result because, again that was the default value in the specs and presumably most other libraries do that all the time. The perceived behavior will change, though, as the account selector will not pop up as often; but I think that shall be a welcome change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be tested live ?

if o.authnScheme != nil {
authParams.AuthnScheme = o.authnScheme
}
Expand Down
81 changes: 77 additions & 4 deletions apps/public/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)

Expand All @@ -43,16 +44,16 @@ func fakeBrowserOpenURL(authURL string) error {
if m := q.Get("code_challenge_method"); m != "S256" {
return fmt.Errorf("unexpected code_challenge_method '%s'", m)
}
if q.Get("prompt") == "" {
return errors.New("missing query param 'prompt")
}
// if q.Get("prompt") == "" {
// return errors.New("missing query param 'prompt")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commenting this out will fix the failing test, but as I mentioned earlier, changing the default value from "select_account" to "" might introduce unintended behavior.
We'll need to test this further, and once confirmed, re-enable it and update the test to assert the correct value based on the expected response.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I should put "select_account" as a default for AcquireInteractiveOption?

// }
state := q.Get("state")
if state == "" {
return errors.New("missing query param 'state'")
}
redirect := q.Get("redirect_uri")
if redirect == "" {
return errors.New("missing query param 'redirect_uri'")
return errors.New(" 'redirect_uri'")
}
// now send the info to our local redirect server
resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state))
Expand Down Expand Up @@ -935,6 +936,78 @@ func TestWithDomainHint(t *testing.T) {
}
}

func TestWithPrompt(t *testing.T) {
prompt := shared.PromptCreate
client, err := New("client-id")
if err != nil {
t.Fatal(err)
}
client.base.Token.AccessTokens = &fake.AccessTokens{}
client.base.Token.Authority = &fake.Authority{}
client.base.Token.Resolver = &fake.ResolveEndpoints{}
for _, expectPrompt := range []bool{true, false} {
t.Run(fmt.Sprint(expectPrompt), func(t *testing.T) {
called := false
validate := func(v url.Values) error {
if !v.Has("prompt") {
if !expectPrompt {
return nil
}
return errors.New("expected a prompt")
} else if !expectPrompt {
return fmt.Errorf("expected no prompt, got %v", v["prompt"][0])
}

if actual := v["prompt"]; len(actual) != 1 || actual[0] != prompt.String() {
err = fmt.Errorf(`unexpected prompt "%v"`, actual[0])
}
return err
}
browserOpenURL := func(authURL string) error {
called = true
parsed, err := url.Parse(authURL)
if err != nil {
return err
}
query, err := url.ParseQuery(parsed.RawQuery)
if err != nil {
return err
}
if err = validate(query); err != nil {
t.Fatal(err)
return err
}
// this helper validates the other params and completes the redirect
return fakeBrowserOpenURL(authURL)
}
acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)}
var urlOpts []AuthCodeURLOption
if expectPrompt {
acquireOpts = append(acquireOpts, WithPrompt(prompt))
urlOpts = append(urlOpts, WithPrompt(prompt))
}
_, err = client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...)
if err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("browserOpenURL wasn't called")
}
u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
if err == nil {
var parsed *url.URL
parsed, err = url.Parse(u)
if err == nil {
err = validate(parsed.Query())
}
}
if err != nil {
t.Fatal(err)
}
})
}
}

func TestWithAuthenticationScheme(t *testing.T) {
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo, tenant := "login.microsoftonline.com", "tenant"
Expand Down