Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions commands/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ type LoginCmd struct {
ProviderArg string // OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>
ProviderAliasArg string
KeyTypeArg KeyType
PrintKeyArg bool // Print private key and SSH cert instead of writing them to the filesystem
SSHConfigured bool
Verbosity int // Default verbosity is 0, 1 is verbose, 2 is debug
overrideProvider *providers.OpenIdProvider // Used in tests to override the provider to inject a mock provider
Expand All @@ -98,12 +99,13 @@ type LoginCmd struct {
alg jwa.SignatureAlgorithm
client *client.OpkClient
principals []string
OutWriter io.Writer // optional
}

// NewLogin creates a new LoginCmd instance with the provided arguments.
func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, configureArg bool, logDirArg string,
sendAccessTokenArg bool, disableBrowserOpenArg bool, printIdTokenArg bool,
providerArg string, keyPathArg string, providerAliasArg string, keyTypeArg KeyType,
providerArg string, printKeyArg bool, keyPathArg string, providerAliasArg string, keyTypeArg KeyType,
) *LoginCmd {
return &LoginCmd{
Fs: afero.NewOsFs(),
Expand All @@ -117,6 +119,7 @@ func NewLogin(autoRefreshArg bool, configPathArg string, createConfigArg bool, c
PrintIdTokenArg: printIdTokenArg,
KeyPathArg: keyPathArg,
ProviderArg: providerArg,
PrintKeyArg: printKeyArg,
ProviderAliasArg: providerAliasArg,
KeyTypeArg: keyTypeArg,
}
Expand Down Expand Up @@ -455,7 +458,11 @@ func (l *LoginCmd) login(ctx context.Context, provider providers.OpenIdProvider,
}

// Write ssh secret key and public key to filesystem
if seckeyPath != "" {
if l.PrintKeyArg {
w := l.out()
fmt.Fprintln(w, string(certBytes)) // Base64 encoded SSH cert
fmt.Fprintln(w, string(seckeySshPem)) // SSH private key in OpenSSH native format
} else if seckeyPath != "" {
// If we have set seckeyPath then write it there
if err := l.writeKeys(seckeyPath, seckeyPath+"-cert.pub", seckeySshPem, certBytes); err != nil {
return nil, fmt.Errorf("failed to write SSH keys to filesystem: %w", err)
Expand Down Expand Up @@ -580,6 +587,13 @@ func (l *LoginCmd) LoginWithRefresh(ctx context.Context, provider providers.Refr
}
}

func (l *LoginCmd) out() io.Writer {
if l.OutWriter != nil {
return l.OutWriter
}
return os.Stdout
}

func createSSHCert(pkt *pktoken.PKToken, signer crypto.Signer, principals []string) ([]byte, []byte, error) {
return createSSHCertWithAccessToken(pkt, nil, signer, principals)
}
Expand Down
69 changes: 49 additions & 20 deletions commands/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
package commands

import (
"bytes"
"context"
"crypto"
"crypto/rand"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"testing"

"golang.org/x/crypto/ed25519"
Expand Down Expand Up @@ -113,11 +115,14 @@ func TestLoginCmd(t *testing.T) {
DefaultProvider: "mockOp",
}

cliOutput := bytes.Buffer{}
Comment thread
EthanHeilman marked this conversation as resolved.
Outdated

tests := []struct {
name string
envVars map[string]string
loginCmd LoginCmd
ClientConfig *config.ClientConfig
printKeyArg bool
Comment thread
EthanHeilman marked this conversation as resolved.
Outdated
wantAccessToken bool
wantError bool
errorString string
Expand All @@ -143,6 +148,17 @@ func TestLoginCmd(t *testing.T) {
},
wantError: false,
},
{
name: "Good path SSH cert and private key to filesystems",
Comment thread
EthanHeilman marked this conversation as resolved.
Outdated
envVars: map[string]string{},
loginCmd: LoginCmd{
Verbosity: 0,
PrintKeyArg: true,
LogDirArg: logDir,
OutWriter: &cliOutput,
},
wantError: false,
},
{
name: "Good path with SendAccessToken set in arg and config",
envVars: map[string]string{},
Expand Down Expand Up @@ -208,34 +224,46 @@ func TestLoginCmd(t *testing.T) {
} else {
require.NoError(t, err, "Unexpected error")

homePath, err := os.UserHomeDir()
require.NoError(t, err)

sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa")
secKeyBytes, err := afero.ReadFile(mockFs, sshPath)
require.NoError(t, err)
require.NotNil(t, secKeyBytes)
require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----")

logBytes, err := afero.ReadFile(mockFs, logPath)
require.NoError(t, err)
require.NotNil(t, logBytes)
require.Contains(t, string(logBytes), "running login command with args:")

sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub")
pubKeyBytes, err := afero.ReadFile(mockFs, sshPubPath)
require.NoError(t, err)

var pubKeyBytes []byte

if tt.loginCmd.PrintKeyArg {
if tt.loginCmd.PrintKeyArg {
got := cliOutput.String()
gotLines := strings.Split(strings.TrimSpace(got), "\n")
require.GreaterOrEqual(t, len(gotLines), 2, "expected at least 2 lines in output")
require.Contains(t, gotLines[0], "ecdsa-sha2-nistp256-cert-v01@openssh.com AAAA")
Comment thread
EthanHeilman marked this conversation as resolved.
Outdated
require.Contains(t, gotLines[1], "-----BEGIN OPENSSH PRIVATE KEY-----")
pubKeyBytes = []byte(gotLines[0])
}
Comment thread
EthanHeilman marked this conversation as resolved.
Outdated
} else {
homePath, err := os.UserHomeDir()
require.NoError(t, err)

sshPath := filepath.Join(homePath, ".ssh", "id_ecdsa")
secKeyBytes, err := afero.ReadFile(mockFs, sshPath)
require.NoError(t, err)
require.NotNil(t, secKeyBytes)
require.Contains(t, string(secKeyBytes), "-----BEGIN OPENSSH PRIVATE KEY-----")

logBytes, err := afero.ReadFile(mockFs, logPath)
require.NoError(t, err)
require.NotNil(t, logBytes)
require.Contains(t, string(logBytes), "running login command with args:")

sshPubPath := filepath.Join(homePath, ".ssh", "id_ecdsa-cert.pub")
pubKeyBytes, err = afero.ReadFile(mockFs, sshPubPath)
require.NoError(t, err)
}
certSmug, err := sshcert.NewFromAuthorizedKey("fake-cert-type", string(pubKeyBytes))
require.NoError(t, err)

accToken := certSmug.GetAccessToken()

if tt.wantAccessToken {
require.NotEmpty(t, accToken, "expected access token to be set in SSH cert")
} else {
require.Empty(t, accToken, "expected access token to not be set in SSH cert")
}

}
})
}
Expand Down Expand Up @@ -379,10 +407,11 @@ func TestNewLogin(t *testing.T) {
providerArg := ""
keyPathArg := ""
providerAlias := ""
keyAsOutputArg := false
keyTypeArg := ECDSA

loginCmd := NewLogin(autoRefresh, configPathArg, createConfig, configureArg, logDir,
sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAlias, keyTypeArg)
sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyAsOutputArg, keyPathArg, providerAlias, keyTypeArg)
require.NotNil(t, loginCmd)
}

Expand Down
4 changes: 3 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ Arguments:
var sendAccessTokenArg bool
var disableBrowserOpenArg bool
var printIdTokenArg bool
var printKeyArg bool
var keyPathArg string
var keyTypeArg commands.KeyType
loginCmd := &cobra.Command{
Expand Down Expand Up @@ -183,7 +184,7 @@ Arguments:
providerAliasArg = args[0]
}

login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, keyPathArg, providerAliasArg, keyTypeArg)
login := commands.NewLogin(autoRefreshArg, configPathArg, createConfigArg, configureArg, logDirArg, sendAccessTokenArg, disableBrowserOpenArg, printIdTokenArg, providerArg, printKeyArg, keyPathArg, providerAliasArg, keyTypeArg)
if err := login.Run(ctx); err != nil {
log.Println("Error executing login command:", err)
return err
Expand All @@ -203,6 +204,7 @@ Arguments:
loginCmd.Flags().BoolVar(&printIdTokenArg, "print-id-token", false, "Set this flag to print out the contents of the id_token. Useful for inspecting claims")
loginCmd.Flags().BoolVar(&sendAccessTokenArg, "send-access-token", false, "Set this flag to send the Access Token as well as the PK Token in the SSH cert. The Access Token is used to call the userinfo endpoint to get claims not included in the ID Token")
loginCmd.Flags().StringVar(&providerArg, "provider", "", "OpenID Provider specification in the format: <issuer>,<client_id> or <issuer>,<client_id>,<client_secret> or <issuer>,<client_id>,<client_secret>,<scopes>")
loginCmd.Flags().BoolVarP(&printKeyArg, "print-key", "p", false, "Print private key and SSH cert instead of writing them to the filesystem")
loginCmd.Flags().StringVarP(&keyPathArg, "private-key-file", "i", "", "Path where private keys is written")
loginCmd.Flags().VarP(enumflag.New(&keyTypeArg, "Key Type", map[commands.KeyType][]string{commands.ECDSA: {commands.ECDSA.String()}, commands.ED25519: {commands.ED25519.String()}}, enumflag.EnumCaseInsensitive), "key-type", "t", "Type of key to generate")
rootCmd.AddCommand(loginCmd)
Expand Down
Loading