diff --git a/commands/login.go b/commands/login.go index c6d33357..196e077c 100644 --- a/commands/login.go +++ b/commands/login.go @@ -85,6 +85,7 @@ type LoginCmd struct { ProviderArg string // OpenID Provider specification in the format: , or ,, or ,,, ProviderAliasArg string KeyTypeArg KeyType + PrintKeyArg bool // Print private key and SSH cert instead of writing them to the filesystem SSHConfigured bool Verbosity int // Default verbosity is 0, 1 is verbose, 2 is debug overrideProvider *providers.OpenIdProvider // Used in tests to override the provider to inject a mock provider @@ -98,12 +99,15 @@ type LoginCmd struct { alg jwa.SignatureAlgorithm client *client.OpkClient principals []string + + // For testing + OutWriter io.Writer // Captures non-logged output that would normally be written to stdout } // NewLogin creates a new LoginCmd instance with the provided arguments. func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, configureArg bool, logDirArg string, sendAccessTokenArg bool, disableBrowserOpenArg bool, printIdTokenArg bool, - providerArg string, keyPathArg string, providerAliasArg string, keyTypeArg KeyType, + providerArg string, printKeyArg bool, keyPathArg string, providerAliasArg string, keyTypeArg KeyType, ) *LoginCmd { return &LoginCmd{ Fs: afero.NewOsFs(), @@ -117,6 +121,7 @@ func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, c PrintIdTokenArg: printIdTokenArg, KeyPathArg: keyPathArg, ProviderArg: providerArg, + PrintKeyArg: printKeyArg, ProviderAliasArg: providerAliasArg, KeyTypeArg: keyTypeArg, } @@ -455,7 +460,11 @@ func (l *LoginCmd) login(ctx context.Context, provider providers.OpenIdProvider, } // Write ssh secret key and public key to filesystem - if seckeyPath != "" { + if l.PrintKeyArg { + w := l.out() + fmt.Fprintln(w, string(certBytes)) // Base64 encoded SSH cert + fmt.Fprintln(w, string(seckeySshPem)) // SSH private key in OpenSSH native format + } else if seckeyPath != "" { // If we have set seckeyPath then write it there if err := l.writeKeys(seckeyPath, seckeyPath+"-cert.pub", seckeySshPem, certBytes); err != nil { return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err) @@ -580,6 +589,13 @@ func (l *LoginCmd) LoginWithRefresh(ctx context.Context, provider providers.Refr } } +func (l *LoginCmd) out() io.Writer { + if l.OutWriter != nil { + return l.OutWriter + } + return os.Stdout +} + func createSSHCert(pkt *pktoken.PKToken, signer crypto.Signer, principals []string) ([]byte, []byte, error) { return createSSHCertWithAccessToken(pkt, nil, signer, principals) } diff --git a/commands/login_test.go b/commands/login_test.go index 420c3d6e..4db6fb5c 100644 --- a/commands/login_test.go +++ b/commands/login_test.go @@ -17,6 +17,7 @@ package commands import ( + "bytes" "context" "crypto" "crypto/rand" @@ -24,6 +25,7 @@ import ( "fmt" "os" "path/filepath" + "strings" "testing" "golang.org/x/crypto/ed25519" @@ -143,6 +145,16 @@ func TestLoginCmd(t *testing.T) { }, wantError: false, }, + { + name: "Good path PrintKey", + envVars: map[string]string{}, + loginCmd: LoginCmd{ + Verbosity: 0, + PrintKeyArg: true, + LogDirArg: logDir, + }, + wantError: false, + }, { name: "Good path with SendAccessToken set in arg and config", envVars: map[string]string{}, @@ -199,6 +211,10 @@ func TestLoginCmd(t *testing.T) { tt.loginCmd.overrideProvider = &mockOp tt.loginCmd.Fs = mockFs + // Allows us to capture non-logged CLI output + cliOutputBuffer := &bytes.Buffer{} + tt.loginCmd.OutWriter = cliOutputBuffer + err = tt.loginCmd.Run(context.Background()) if tt.wantError { require.Error(t, err, "Expected error but got none") @@ -208,29 +224,38 @@ func TestLoginCmd(t *testing.T) { } else { require.NoError(t, err, "Unexpected error") - homePath, err := os.UserHomeDir() - require.NoError(t, err) - - sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa") - secKeyBytes, err := afero.ReadFile(mockFs, sshPath) - require.NoError(t, err) - require.NotNil(t, secKeyBytes) - require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----") - - logBytes, err := afero.ReadFile(mockFs, logPath) - require.NoError(t, err) - require.NotNil(t, logBytes) - require.Contains(t, string(logBytes), "running login command with args:") - - sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub") - pubKeyBytes, err := afero.ReadFile(mockFs, sshPubPath) - require.NoError(t, err) + var pubKeyBytes []byte + if tt.loginCmd.PrintKeyArg { + got := cliOutputBuffer.String() + gotLines := strings.Split(strings.TrimSpace(got), "\n") + require.GreaterOrEqual(t, len(gotLines), 2, "expected at least 2 lines in output") + require.Contains(t, gotLines[0], "cert-v01@openssh.com AAAA") + require.Contains(t, gotLines[1], "-----BEGIN OPENSSH PRIVATE KEY-----") + pubKeyBytes = []byte(gotLines[0]) + } else { + homePath, err := os.UserHomeDir() + require.NoError(t, err) + + sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa") + secKeyBytes, err := afero.ReadFile(mockFs, sshPath) + require.NoError(t, err) + require.NotNil(t, secKeyBytes) + require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----") + + logBytes, err := afero.ReadFile(mockFs, logPath) + require.NoError(t, err) + require.NotNil(t, logBytes) + require.Contains(t, string(logBytes), "running login command with args:") + + sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub") + pubKeyBytes, err = afero.ReadFile(mockFs, sshPubPath) + require.NoError(t, err) + } certSmug, err := sshcert.NewFromAuthorizedKey("fake-cert-type", string(pubKeyBytes)) require.NoError(t, err) accToken := certSmug.GetAccessToken() - if tt.wantAccessToken { require.NotEmpty(t, accToken, "expected access token to be set in SSH cert") } else { @@ -379,10 +404,11 @@ func TestNewLogin(t *testing.T) { providerArg := "" keyPathArg := "" providerAlias := "" + keyAsOutputArg := false keyTypeArg := ECDSA loginCmd := NewLogin(autoRefresh, configPathArg, createConfig, configureArg, logDir, - sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAlias, keyTypeArg) + sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyAsOutputArg, keyPathArg, providerAlias, keyTypeArg) require.NotNil(t, loginCmd) } diff --git a/main.go b/main.go index 501b9ddd..5e15c989 100644 --- a/main.go +++ b/main.go @@ -151,6 +151,7 @@ Arguments: var sendAccessTokenArg bool var disableBrowserOpenArg bool var printIdTokenArg bool + var printKeyArg bool var keyPathArg string var keyTypeArg commands.KeyType loginCmd := &cobra.Command{ @@ -183,7 +184,7 @@ Arguments: providerAliasArg = args[0] } - login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAliasArg, keyTypeArg) + login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, printKeyArg, keyPathArg, providerAliasArg, keyTypeArg) if err := login.Run(ctx); err != nil { log.Println("Error executing login command:", err) return err @@ -203,6 +204,7 @@ Arguments: loginCmd.Flags().BoolVar(&printIdTokenArg, "print-id-token", false, "Set this flag to print out the contents of the id_token. Useful for inspecting claims") loginCmd.Flags().BoolVar(&sendAccessTokenArg, "send-access-token", false, "Set this flag to send the Access Token as well as the PK Token in the SSH cert. The Access Token is used to call the userinfo endpoint to get claims not included in the ID Token") loginCmd.Flags().StringVar(&providerArg, "provider", "", "OpenID Provider specification in the format: , or ,, or ,,,") + loginCmd.Flags().BoolVarP(&printKeyArg, "print-key", "p", false, "Print private key and SSH cert instead of writing them to the filesystem") loginCmd.Flags().StringVarP(&keyPathArg, "private-key-file", "i", "", "Path where private keys is written") loginCmd.Flags().VarP(enumflag.New(&keyTypeArg, "Key Type", map[commands.KeyType][]string{commands.ECDSA: {commands.ECDSA.String()}, commands.ED25519: {commands.ED25519.String()}}, enumflag.EnumCaseInsensitive), "key-type", "t", "Type of key to generate") rootCmd.AddCommand(loginCmd)