Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/android/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
Expand Down
2 changes: 1 addition & 1 deletion client/internal/auth/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopCl

// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *internal.Config) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
Expand Down
13 changes: 13 additions & 0 deletions client/internal/auth/pkce_flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"crypto/subtle"
"crypto/tls"
"encoding/base64"
"errors"
"fmt"
Expand Down Expand Up @@ -143,6 +144,18 @@ func (p *PKCEAuthorizationFlow) WaitToken(ctx context.Context, _ AuthFlowInfo) (
func (p *PKCEAuthorizationFlow) startServer(server *http.Server, tokenChan chan<- *oauth2.Token, errChan chan<- error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
cert := p.providerConfig.ClientCertPair
if cert != nil {
tr := &http.Transport{
TLSClientConfig: &tls.Config{
Certificates: []tls.Certificate{*cert},
},
}
sslClient := &http.Client{Transport: tr}
ctx := context.WithValue(req.Context(), oauth2.HTTPClient, sslClient)
req = req.WithContext(ctx)
}

token, err := p.handleRequest(req)
if err != nil {
renderPKCEFlowTmpl(w, err)
Expand Down
30 changes: 30 additions & 0 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package internal

import (
"context"
"crypto/tls"
"fmt"
"net/url"
"os"
Expand Down Expand Up @@ -57,6 +58,8 @@ type ConfigInput struct {
DisableAutoConnect *bool
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
ClientCertPath string
ClientCertKeyPath string
}

// Config Configuration type
Expand Down Expand Up @@ -102,6 +105,13 @@ type Config struct {

// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
//Path to a certificate used for mTLS authentication
ClientCertPath string

//Path to corresponding private key of ClientCertPath
ClientCertKeyPath string

ClientCertKeyPair *tls.Certificate `json:"-"`
}

// ReadConfig read config file and return with Config. If it is not exists create a new with default values
Expand Down Expand Up @@ -385,6 +395,26 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {

}

if input.ClientCertKeyPath != "" {
config.ClientCertKeyPath = input.ClientCertKeyPath
updated = true
}

if input.ClientCertPath != "" {
config.ClientCertPath = input.ClientCertPath
updated = true
}

if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
if err != nil {
log.Error("Failed to load mTLS cert/key pair: ", err)
} else {
config.ClientCertKeyPair = &cert
log.Info("Loaded client mTLS cert/key pair")
}
}

return updated, nil
}

Expand Down
6 changes: 5 additions & 1 deletion client/internal/pkce_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package internal

import (
"context"
"crypto/tls"
"fmt"
"net/url"

Expand Down Expand Up @@ -36,10 +37,12 @@ type PKCEAuthProviderConfig struct {
RedirectURLs []string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
//ClientCertPair is used for mTLS authentication to the IDP
ClientCertPair *tls.Certificate
}

// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL) (PKCEAuthorizationFlow, error) {
func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL *url.URL, clientCert *tls.Certificate) (PKCEAuthorizationFlow, error) {
// validate our peer's Wireguard PRIVATE key
myPrivateKey, err := wgtypes.ParseKey(privateKey)
if err != nil {
Expand Down Expand Up @@ -93,6 +96,7 @@ func GetPKCEAuthorizationFlowInfo(ctx context.Context, privateKey string, mgmURL
Scope: protoPKCEAuthorizationFlow.GetProviderConfig().GetScope(),
RedirectURLs: protoPKCEAuthorizationFlow.GetProviderConfig().GetRedirectURLs(),
UseIDToken: protoPKCEAuthorizationFlow.GetProviderConfig().GetUseIDToken(),
ClientCertPair: clientCert,
},
}

Expand Down
2 changes: 1 addition & 1 deletion client/ios/NetBirdSDK/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (a *Auth) SaveConfigIfSSOSupported() (bool, error) {
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
supportsSSO = false
err = nil
Expand Down