Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added "connector_id" to skip straight to a connector (similar to when len(connector) is 1. #1481

Merged
merged 1 commit into from
Jul 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
return
}

// Redirect if a client chooses a specific connector_id
if authReq.ConnectorID != "" {
for _, c := range connectors {
if c.ID == authReq.ConnectorID {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ Should this return an error if connector_id was specified but doesn't match any?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I debated this myself. What could it ever do with an error other than break flow completely or just return to the index?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true, but it's still better than 🙈ignoring it, isn't it? I'd propose calling s.tokenErrHelper, maybe?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would not change anything with respect to showBacklink (#1123), since the number of connectors doesn't actually change. It would need an additional test such as authReq.ConnectorID != "" to suppress the link.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we'd actually want to suppress the link, would we? Maybe you'd like to use a different method...? 🤔

http.Redirect(w, r, s.absPath("/auth", c.ID)+"?req="+authReq.ID, http.StatusFound)
return
}
}
s.tokenErrHelper(w, errInvalidConnectorID, "Connector ID does not match a valid Connector", http.StatusNotFound)
return
}

if len(connectors) == 1 {
for _, c := range connectors {
// TODO(ericchiang): Make this pass on r.URL.RawQuery and let something latter
Expand Down
22 changes: 22 additions & 0 deletions server/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ const (
errUnsupportedGrantType = "unsupported_grant_type"
errInvalidGrant = "invalid_grant"
errInvalidClient = "invalid_client"
errInvalidConnectorID = "invalid_connector_id"
)

const (
Expand Down Expand Up @@ -391,6 +392,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
clientID := q.Get("client_id")
state := q.Get("state")
nonce := q.Get("nonce")
connectorID := q.Get("connector_id")
// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
scopes := strings.Fields(q.Get("scope"))
responseTypes := strings.Fields(q.Get("response_type"))
Expand All @@ -405,6 +407,16 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
return req, &authErr{"", "", errServerError, ""}
}

if connectorID != "" {
connectors, err := s.storage.ListConnectors()
if err != nil {
return req, &authErr{"", "", errServerError, "Unable to retrieve connectors"}
}
if !validateConnectorID(connectors, connectorID) {
return req, &authErr{"", "", errInvalidRequest, "Invalid ConnectorID"}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this inside the for loop, it'll fail if there are more than one connectors, and the second one is the requested one. It would be nice if we had another test case for this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a validate function. Added a second connector to the test server. Added test.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All right.So I ended up duplicating newTestServer to have newTestServerMultipleConnectors. It seems that the code workflow tests may depend on having a single connector, so adding in a second connector in order to test broke all of that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now tests for selecting a connector other than the first.

}
}

if !validateRedirectURI(client, redirectURI) {
description := fmt.Sprintf("Unregistered redirect_uri (%q).", redirectURI)
return req, &authErr{"", "", errInvalidRequest, description}
Expand Down Expand Up @@ -509,6 +521,7 @@ func (s *Server) parseAuthorizationRequest(r *http.Request) (req storage.AuthReq
Scopes: scopes,
RedirectURI: redirectURI,
ResponseTypes: responseTypes,
ConnectorID: connectorID,
}, nil
}

Expand Down Expand Up @@ -568,6 +581,15 @@ func validateRedirectURI(client storage.Client, redirectURI string) bool {
return err == nil && host == "localhost"
}

func validateConnectorID(connectors []storage.Connector, connectorID string) bool {
for _, c := range connectors {
if c.ID == connectorID {
return true
}
}
return false
}

// storageKeySet implements the oidc.KeySet interface backed by Dex storage
type storageKeySet struct {
storage.Storage
Expand Down
57 changes: 54 additions & 3 deletions server/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"strings"
"testing"

jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2"

"github.com/dexidp/dex/storage"
"github.com/dexidp/dex/storage/memory"
Expand Down Expand Up @@ -145,14 +145,66 @@ func TestParseAuthorizationRequest(t *testing.T) {
},
wantErr: true,
},
{
name: "choose connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose second connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "mock2",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
},
{
name: "choose invalid connector_id",
clients: []storage.Client{
{
ID: "bar",
RedirectURIs: []string{"https://example.com/bar"},
},
},
supportedResponseTypes: []string{"code", "id_token", "token"},
queryParams: map[string]string{
"connector_id": "bogus",
"client_id": "bar",
"redirect_uri": "https://example.com/bar",
"response_type": "code id_token",
"scope": "openid email profile",
},
wantErr: true,
},
}

for _, tc := range tests {
func() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

httpServer, server := newTestServer(ctx, t, func(c *Config) {
httpServer, server := newTestServerMultipleConnectors(ctx, t, func(c *Config) {
c.SupportedResponseTypes = tc.supportedResponseTypes
c.Storage = storage.WithStaticClients(c.Storage, tc.clients)
})
Expand All @@ -162,7 +214,6 @@ func TestParseAuthorizationRequest(t *testing.T) {
for k, v := range tc.queryParams {
params.Set(k, v)
}

var req *http.Request
if tc.usePOST {
body := strings.NewReader(params.Encode())
Expand Down
47 changes: 47 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,53 @@ func newTestServer(ctx context.Context, t *testing.T, updateConfig func(c *Confi
return s, server
}

func newTestServerMultipleConnectors(ctx context.Context, t *testing.T, updateConfig func(c *Config)) (*httptest.Server, *Server) {
var server *Server
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.ServeHTTP(w, r)
}))

config := Config{
Issuer: s.URL,
Storage: memory.New(logger),
Web: WebConfig{
Dir: "../web",
},
Logger: logger,
PrometheusRegistry: prometheus.NewRegistry(),
}
if updateConfig != nil {
updateConfig(&config)
}
s.URL = config.Issuer

connector := storage.Connector{
ID: "mock",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
connector2 := storage.Connector{
ID: "mock2",
Type: "mockCallback",
Name: "Mock",
ResourceVersion: "1",
}
if err := config.Storage.CreateConnector(connector); err != nil {
t.Fatalf("create connector: %v", err)
}
if err := config.Storage.CreateConnector(connector2); err != nil {
t.Fatalf("create connector: %v", err)
}

var err error
if server, err = newServer(ctx, config, staticRotationStrategy(testKey)); err != nil {
t.Fatal(err)
}
server.skipApproval = true // Don't prompt for approval, just immediately redirect with code.
return s, server
}

func TestNewTestServer(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
Expand Down