Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 33 additions & 76 deletions client/android/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,7 @@ package android
import (
"context"
"fmt"
"time"

"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"

"github.com/netbirdio/netbird/client/cmd"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
Expand Down Expand Up @@ -84,34 +76,21 @@ func (a *Auth) SaveConfigIfSSOSupported(listener SSOListener) {
}

func (a *Auth) saveConfigIfSSOSupported() (bool, error) {
supportsSSO := true
err := a.withBackOff(a.ctx, func() (err error) {
_, err = internal.GetPKCEAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL, nil)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.NotFound || s.Code() == codes.Unimplemented) {
_, err = internal.GetDeviceAuthorizationFlowInfo(a.ctx, a.config.PrivateKey, a.config.ManagementURL)
s, ok := gstatus.FromError(err)
if !ok {
return err
}
if s.Code() == codes.NotFound || s.Code() == codes.Unimplemented {
supportsSSO = false
err = nil
}

return err
}
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return false, fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()

return err
})
supportsSSO, err := authClient.IsSSOSupported(a.ctx)
if err != nil {
return false, fmt.Errorf("failed to check SSO support: %v", err)
}

if !supportsSSO {
return false, nil
}

if err != nil {
return false, fmt.Errorf("backoff cycle failed: %v", err)
}

err = profilemanager.WriteOutConfig(a.cfgPath, a.config)
return true, err
}
Expand All @@ -129,19 +108,17 @@ func (a *Auth) LoginWithSetupKeyAndSaveConfig(resultListener ErrListener, setupK
}

func (a *Auth) loginWithSetupKeyAndSaveConfig(setupKey string, deviceName string) error {
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()

//nolint
ctxWithValues := context.WithValue(a.ctx, system.DeviceNameCtxKey, deviceName)

err := a.withBackOff(a.ctx, func() error {
backoffErr := internal.Login(ctxWithValues, a.config, setupKey, "")
if s, ok := gstatus.FromError(backoffErr); ok && (s.Code() == codes.PermissionDenied) {
// we got an answer from management, exit backoff earlier
return backoff.Permanent(backoffErr)
}
return backoffErr
})
err, _ = authClient.Login(ctxWithValues, setupKey, "")
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}

return profilemanager.WriteOutConfig(a.cfgPath, a.config)
Expand All @@ -160,49 +137,41 @@ func (a *Auth) Login(resultListener ErrListener, urlOpener URLOpener, isAndroidT
}

func (a *Auth) login(urlOpener URLOpener, isAndroidTV bool) error {
var needsLogin bool
authClient, err := auth.NewAuth(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()

// check if we need to generate JWT token
err := a.withBackOff(a.ctx, func() (err error) {
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
return
})
needsLogin, err := authClient.IsLoginRequired(a.ctx)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("failed to check login requirement: %v", err)
}

jwtToken := ""
if needsLogin {
tokenInfo, err := a.foregroundGetTokenInfo(urlOpener, isAndroidTV)
tokenInfo, err := a.foregroundGetTokenInfo(authClient, urlOpener, isAndroidTV)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
jwtToken = tokenInfo.GetTokenToUse()
}

err = a.withBackOff(a.ctx, func() error {
err := internal.Login(a.ctx, a.config, "", jwtToken)

if err == nil {
go urlOpener.OnLoginSuccess()
}

if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
return nil
}
return err
})
err, _ = authClient.Login(a.ctx, "", jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}

go urlOpener.OnLoginSuccess()

return nil
}

func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, isAndroidTV, "")
func (a *Auth) foregroundGetTokenInfo(authClient *auth.Auth, urlOpener URLOpener, isAndroidTV bool) (*auth.TokenInfo, error) {
oAuthFlow, err := authClient.GetOAuthFlow(a.ctx, isAndroidTV)
if err != nil {
return nil, err
return nil, fmt.Errorf("failed to get OAuth flow: %v", err)
}

flowInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
Expand All @@ -212,22 +181,10 @@ func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener, isAndroidTV bool) (*a

go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)

waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, cancel := context.WithTimeout(a.ctx, waitTimeout)
defer cancel()
tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(a.ctx, flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}

return &tokenInfo, nil
}

func (a *Auth) withBackOff(ctx context.Context, bf func() error) error {
return backoff.RetryNotify(
bf,
backoff.WithContext(cmd.CLIBackOffSettings, ctx),
func(err error, duration time.Duration) {
log.Warnf("retrying Login to the Management service in %v due to error %v", duration, err)
})
}
46 changes: 14 additions & 32 deletions client/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"os/user"
"runtime"
"strings"
"time"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -277,18 +276,19 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
}

func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
}
defer authClient.Close()

needsLogin := false

err := WithBackOff(func() error {
err := internal.Login(ctx, config, "", "")
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
needsLogin = true
return nil
}
return err
})
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
err, isAuthError := authClient.Login(ctx, "", "")
if isAuthError {
needsLogin = true
} else if err != nil {
return fmt.Errorf("login check failed: %v", err)
}

jwtToken := ""
Expand All @@ -300,23 +300,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken = tokenInfo.GetTokenToUse()
}

var lastError error

err = WithBackOff(func() error {
err := internal.Login(ctx, config, setupKey, jwtToken)
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.InvalidArgument || s.Code() == codes.PermissionDenied) {
lastError = err
return nil
}
return err
})

if lastError != nil {
return fmt.Errorf("login failed: %v", lastError)
}

err, _ = authClient.Login(ctx, setupKey, jwtToken)
if err != nil {
return fmt.Errorf("backoff cycle failed: %v", err)
return fmt.Errorf("login failed: %v", err)
}

return nil
Expand Down Expand Up @@ -344,11 +330,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro

openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)

waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
defer c()

tokenInfo, err := oAuthFlow.WaitToken(waitCTX, flowInfo)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
return nil, fmt.Errorf("waiting for browser login failed: %v", err)
}
Expand Down
9 changes: 8 additions & 1 deletion client/embed/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
Expand Down Expand Up @@ -168,7 +169,13 @@ func (c *Client) Start(startCtx context.Context) error {
ctx := internal.CtxInitState(context.Background())
// nolint:staticcheck
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
if err := internal.Login(ctx, c.config, c.setupKey, c.jwtToken); err != nil {
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
if err != nil {
return fmt.Errorf("create auth client: %w", err)
}
defer authClient.Close()

if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
return fmt.Errorf("login: %w", err)
}

Expand Down
Loading
Loading