Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type LoginCmd struct {
ProviderArg string // OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>
ProviderAliasArg string
KeyTypeArg KeyType
PrintKeyArg bool // Print private key and SSH cert instead of writing them to the filesystem
SSHConfigured bool
Verbosity int // Default verbosity is 0, 1 is verbose, 2 is debug
overrideProvider *providers.OpenIdProvider // Used in tests to override the provider to inject a mock provider
Expand All @@ -98,12 +99,15 @@ type LoginCmd struct {
alg jwa.SignatureAlgorithm
client *client.OpkClient
principals []string

// For testing
OutWriter io.Writer // Captures non-logged output that would normally be written to stdout
}

// NewLogin creates a new LoginCmd instance with the provided arguments.
func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, configureArg bool, logDirArg string,
sendAccessTokenArg bool, disableBrowserOpenArg bool, printIdTokenArg bool,
providerArg string, keyPathArg string, providerAliasArg string, keyTypeArg KeyType,
providerArg string, printKeyArg bool, keyPathArg string, providerAliasArg string, keyTypeArg KeyType,
) *LoginCmd {
return &LoginCmd{
Fs: afero.NewOsFs(),
Expand All @@ -117,6 +121,7 @@ func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, c
PrintIdTokenArg: printIdTokenArg,
KeyPathArg: keyPathArg,
ProviderArg: providerArg,
PrintKeyArg: printKeyArg,
ProviderAliasArg: providerAliasArg,
KeyTypeArg: keyTypeArg,
}
Expand Down Expand Up @@ -455,7 +460,11 @@ func (l *LoginCmd) login(ctx context.Context, provider providers.OpenIdProvider,
}

// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
if l.PrintKeyArg {
w := l.out()
fmt.Fprintln(w, string(certBytes)) // Base64 encoded SSH cert
fmt.Fprintln(w, string(seckeySshPem)) // SSH private key in OpenSSH native format
} else if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := l.writeKeys(seckeyPath, seckeyPath+"-cert.pub", seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
Expand Down Expand Up @@ -580,6 +589,13 @@ func (l *LoginCmd) LoginWithRefresh(ctx context.Context, provider providers.Refr
}
}

func (l *LoginCmd) out() io.Writer {
if l.OutWriter != nil {
return l.OutWriter
}
return os.Stdout
}

func createSSHCert(pkt *pktoken.PKToken, signer crypto.Signer, principals []string) ([]byte, []byte, error) {
return createSSHCertWithAccessToken(pkt, nil, signer, principals)
}
Expand Down
64 changes: 45 additions & 19 deletions commands/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package commands

import (
"bytes"
"context"
"crypto"
"crypto/rand"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"golang.org/x/crypto/ed25519"
Expand Down Expand Up @@ -143,6 +145,16 @@ func TestLoginCmd(t *testing.T) {
},
wantError: false,
},
{
name: "Good path PrintKey",
envVars: map[string]string{},
loginCmd: LoginCmd{
Verbosity: 0,
PrintKeyArg: true,
LogDirArg: logDir,
},
wantError: false,
},
{
name: "Good path with SendAccessToken set in arg and config",
envVars: map[string]string{},
Expand Down Expand Up @@ -199,6 +211,10 @@ func TestLoginCmd(t *testing.T) {
tt.loginCmd.overrideProvider = &mockOp
tt.loginCmd.Fs = mockFs

// Allows us to capture non-logged CLI output
cliOutputBuffer := &bytes.Buffer{}
tt.loginCmd.OutWriter = cliOutputBuffer

err = tt.loginCmd.Run(context.Background())
if tt.wantError {
require.Error(t, err, "Expected error but got none")
Expand All @@ -208,29 +224,38 @@ func TestLoginCmd(t *testing.T) {
} else {
require.NoError(t, err, "Unexpected error")

homePath, err := os.UserHomeDir()
require.NoError(t, err)

sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa")
secKeyBytes, err := afero.ReadFile(mockFs, sshPath)
require.NoError(t, err)
require.NotNil(t, secKeyBytes)
require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----")

logBytes, err := afero.ReadFile(mockFs, logPath)
require.NoError(t, err)
require.NotNil(t, logBytes)
require.Contains(t, string(logBytes), "running login command with args:")

sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub")
pubKeyBytes, err := afero.ReadFile(mockFs, sshPubPath)
require.NoError(t, err)
var pubKeyBytes []byte

if tt.loginCmd.PrintKeyArg {
got := cliOutputBuffer.String()
gotLines := strings.Split(strings.TrimSpace(got), "\n")
require.GreaterOrEqual(t, len(gotLines), 2, "expected at least 2 lines in output")
require.Contains(t, gotLines[0], "cert-v01@openssh.com AAAA")
require.Contains(t, gotLines[1], "-----BEGIN OPENSSH PRIVATE KEY-----")
pubKeyBytes = []byte(gotLines[0])
} else {
homePath, err := os.UserHomeDir()
require.NoError(t, err)

sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa")
secKeyBytes, err := afero.ReadFile(mockFs, sshPath)
require.NoError(t, err)
require.NotNil(t, secKeyBytes)
require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----")

logBytes, err := afero.ReadFile(mockFs, logPath)
require.NoError(t, err)
require.NotNil(t, logBytes)
require.Contains(t, string(logBytes), "running login command with args:")

sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub")
pubKeyBytes, err = afero.ReadFile(mockFs, sshPubPath)
require.NoError(t, err)
}
certSmug, err := sshcert.NewFromAuthorizedKey("fake-cert-type", string(pubKeyBytes))
require.NoError(t, err)

accToken := certSmug.GetAccessToken()

if tt.wantAccessToken {
require.NotEmpty(t, accToken, "expected access token to be set in SSH cert")
} else {
Expand Down Expand Up @@ -379,10 +404,11 @@ func TestNewLogin(t *testing.T) {
providerArg := ""
keyPathArg := ""
providerAlias := ""
keyAsOutputArg := false
keyTypeArg := ECDSA

loginCmd := NewLogin(autoRefresh, configPathArg, createConfig, configureArg, logDir,
sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAlias, keyTypeArg)
sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyAsOutputArg, keyPathArg, providerAlias, keyTypeArg)
require.NotNil(t, loginCmd)
}

Expand Down
4 changes: 3 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ Arguments:
var sendAccessTokenArg bool
var disableBrowserOpenArg bool
var printIdTokenArg bool
var printKeyArg bool
var keyPathArg string
var keyTypeArg commands.KeyType
loginCmd := &cobra.Command{
Expand Down Expand Up @@ -183,7 +184,7 @@ Arguments:
providerAliasArg = args[0]
}

login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAliasArg, keyTypeArg)
login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, printKeyArg, keyPathArg, providerAliasArg, keyTypeArg)
if err := login.Run(ctx); err != nil {
log.Println("Error executing login command:", err)
return err
Expand All @@ -203,6 +204,7 @@ Arguments:
loginCmd.Flags().BoolVar(&printIdTokenArg, "print-id-token", false, "Set this flag to print out the contents of the id_token. Useful for inspecting claims")
loginCmd.Flags().BoolVar(&sendAccessTokenArg, "send-access-token", false, "Set this flag to send the Access Token as well as the PK Token in the SSH cert. The Access Token is used to call the userinfo endpoint to get claims not included in the ID Token")
loginCmd.Flags().StringVar(&providerArg, "provider", "", "OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>")
loginCmd.Flags().BoolVarP(&printKeyArg, "print-key", "p", false, "Print private key and SSH cert instead of writing them to the filesystem")
loginCmd.Flags().StringVarP(&keyPathArg, "private-key-file", "i", "", "Path where private keys is written")
loginCmd.Flags().VarP(enumflag.New(&keyTypeArg, "Key Type", map[commands.KeyType][]string{commands.ECDSA: {commands.ECDSA.String()}, commands.ED25519: {commands.ED25519.String()}}, enumflag.EnumCaseInsensitive), "key-type", "t", "Type of key to generate")
rootCmd.AddCommand(loginCmd)
Expand Down
Loading