Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/auth-clientCredentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ var clientCredentialsCmd = man.Docs.GetCommand("auth/client-credentials",

func auth_clientCredentials(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

var clientId string
var clientSecret string
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-login.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func auth_codeLogin(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

c.Print("Initiating login...")
tok, publicClientID, err := auth.LoginWithPKCE(
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-logout.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func auth_logout(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)
c.Println("Initiating logout...")

// we can only revoke access tokens stored for the code login flow, not client credentials
Expand Down
2 changes: 1 addition & 1 deletion cmd/auth-printAccessToken.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var auth_printAccessTokenCmd = man.Docs.GetCommand("auth/print-access-token",

func auth_printAccessToken(cmd *cobra.Command, args []string) {
c := cli.New(cmd, args)
cp := InitProfile(c, false)
_, cp := InitProfile(c, false)

ac := cp.GetAuthCredentials()
switch ac.AuthType {
Expand Down
128 changes: 82 additions & 46 deletions cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
/*
Copyright © 2023 NAME HERE <EMAIL ADDRESS>
*/
package cmd

import (
Expand Down Expand Up @@ -35,7 +32,10 @@ type version struct {
BuildTime string `json:"build_time"`
}

func InitProfile(c *cli.Cli, onlyNew bool) *profiles.ProfileStore {
// InitProfile initializes the profile store and loads the profile specified in the flags
// if onlyNew is set to true, a new profile will be created and returned
// returns the profile and the current profile store
func InitProfile(c *cli.Cli, onlyNew bool) (*profiles.Profile, *profiles.ProfileStore) {
var err error
profileName := c.FlagHelper.GetOptionalString("profile")

Expand All @@ -45,126 +45,157 @@ func InitProfile(c *cli.Cli, onlyNew bool) *profiles.ProfileStore {
}

// short circuit if onlyNew is set to enable creating a new profile
if onlyNew {
return nil
if onlyNew && profileName == "" {
return profile, nil
}

// check if there exists a default profile and warn if not with steps to create one
if profile.GetGlobalConfig().GetDefaultProfile() == "" {
c.ExitWithWarning("No default profile set. Use `" + config.AppName + " profile create <profile> <endpoint>` to create a default profile.")
c.ExitWithWarning(fmt.Sprintf("No default profile set. Use `%s profile create <profile> <endpoint>` to create a default profile.", config.AppName))
}
c.Printf("Using profile [%s]\n", profile.GetGlobalConfig().GetDefaultProfile())

if profileName == "" {
profileName = profile.GetGlobalConfig().GetDefaultProfile()
}

c.Printf("Using profile [%s]\n", profileName)

// load profile
cp, err := profile.UseProfile(profileName)
if err != nil {
c.ExitWithError("Failed to load profile "+profileName, err)
c.ExitWithError(fmt.Sprintf("Failed to load profile: %s", profileName), err)
}

return cp
return profile, cp
}

// instantiates a new handler with authentication via client credentials
// TODO make this a preRun hook
//
//nolint:nestif // separate refactor [https://github.com/opentdf/otdfctl/issues/383]
func NewHandler(c *cli.Cli) handlers.Handler {
// if global flags are set then validate and create a temporary profile in memory
var cp *profiles.ProfileStore

// Non-profile flags
host := c.FlagHelper.GetOptionalString("host")
tlsNoVerify := c.FlagHelper.GetOptionalBool("tls-no-verify")
withClientCreds := c.FlagHelper.GetOptionalString("with-client-creds")
withClientCredsFile := c.FlagHelper.GetOptionalString("with-client-creds-file")
withAccessToken := c.FlagHelper.GetOptionalString("with-access-token")
var inMemoryProfile bool

// if global flags are set then validate and create a temporary profile in memory
var cp *profiles.ProfileStore
authFlags := []string{"--with-access-token", "--with-client-creds", "--with-client-creds-file"}
nonProfileFlags := append([]string{"--host", "--tls-no-verify"}, authFlags...)
hasNonProfileFlags := host != "" || tlsNoVerify || withClientCreds != "" || withClientCredsFile != "" || withAccessToken != ""

//nolint:nestif // nested if statements are necessary for validation
if host != "" || tlsNoVerify || withClientCreds != "" || withClientCredsFile != "" {
err := errors.New(
"when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, " +
"profiles will not be used and all required flags must be set",
)
if hasNonProfileFlags {
err := fmt.Errorf("when using global flags %s, profiles will not be used and all required flags must be set", cli.PrettyList(nonProfileFlags))

// host must be set
if host == "" {
cli.ExitWithError("Host must be set", err)
}

// either with-client-creds or with-client-creds-file must be set
if withClientCreds == "" && withClientCredsFile == "" {
cli.ExitWithError("Either --with-client-creds or --with-client-creds-file must be set", err)
} else if withClientCreds != "" && withClientCredsFile != "" {
cli.ExitWithError("Only one of --with-client-creds or --with-client-creds-file can be set", err)
authFlagsCounter := 0
if withAccessToken != "" {
authFlagsCounter++
}

var cc auth.ClientCredentials
if withClientCreds != "" {
cc, err = auth.GetClientCredsFromJSON([]byte(withClientCreds))
} else {
cc, err = auth.GetClientCredsFromFile(withClientCredsFile)
authFlagsCounter++
}
if err != nil {
cli.ExitWithError("Failed to get client credentials", err)
if withClientCredsFile != "" {
authFlagsCounter++
}
if authFlagsCounter == 0 {
cli.ExitWithError(fmt.Sprintf("One of %s must be set", cli.PrettyList(authFlags)), err)
} else if authFlagsCounter > 1 {
cli.ExitWithError(fmt.Sprintf("Only one of %s must be set", cli.PrettyList(authFlags)), err)
}

inMemoryProfile = true
profile, err = profiles.New(profiles.WithInMemoryStore())
if err != nil || profile == nil {
cli.ExitWithError("Failed to initialize a temporary profile", err)
cli.ExitWithError("Failed to initialize in-memory profile", err)
}

if err := profile.AddProfile("temp", host, tlsNoVerify, true); err != nil {
cli.ExitWithError("Failed to create temporary profile", err)
cli.ExitWithError("Failed to create in-memory profile", err)
}

// add credentials to the temporary profile
cp, err = profile.UseProfile("temp")
if err != nil {
cli.ExitWithError("Failed to load temporary profile", err)
cli.ExitWithError("Failed to load in-memory profile", err)
}

// add credentials to the temporary profile
if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_CLIENT_CREDENTIALS,
ClientId: cc.ClientId,
ClientSecret: cc.ClientSecret,
}); err != nil {
cli.ExitWithError("Failed to set client credentials", err)
// get credentials from flags
if withAccessToken != "" {
claims, err := auth.ParseClaimsJWT(withAccessToken)
if err != nil {
cli.ExitWithError("Failed to get access token", err)
}

if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_ACCESS_TOKEN,
AccessToken: profiles.AuthCredentialsAccessToken{
AccessToken: withAccessToken,
Expiration: claims.Expiration,
},
}); err != nil {
cli.ExitWithError("Failed to set access token", err)
}
} else {
var cc auth.ClientCredentials
if withClientCreds != "" {
cc, err = auth.GetClientCredsFromJSON([]byte(withClientCreds))
} else if withClientCredsFile != "" {
cc, err = auth.GetClientCredsFromFile(withClientCredsFile)
}
if err != nil {
cli.ExitWithError("Failed to get client credentials", err)
}

// add credentials to the temporary profile
if err := cp.SetAuthCredentials(profiles.AuthCredentials{
AuthType: profiles.PROFILE_AUTH_TYPE_CLIENT_CREDENTIALS,
ClientId: cc.ClientId,
ClientSecret: cc.ClientSecret,
}); err != nil {
cli.ExitWithError("Failed to set client credentials", err)
}
}
if err := cp.Save(); err != nil {
cli.ExitWithError("Failed to save profile", err)
}
} else {
cp = InitProfile(c, false)
profile, cp = InitProfile(c, false)
}

if err := auth.ValidateProfileAuthCredentials(c.Context(), cp); err != nil {
if errors.Is(err, auth.ErrPlatformConfigNotFound) {
cli.ExitWithError(fmt.Sprintf("Failed to get platform configuration. Is the platform accepting connections at '%s'?", cp.GetEndpoint()), nil)
}
if inMemoryProfile {
cli.ExitWithError("Failed to authenticate with flag-provided client credentials", err)
cli.ExitWithError("Failed to authenticate with flag-provided client credentials.", err)
}
if errors.Is(err, auth.ErrProfileCredentialsNotFound) {
cli.ExitWithWarning("Profile missing credentials. Please login or add client credentials.")
}

if errors.Is(err, auth.ErrAccessTokenExpired) {
cli.ExitWithWarning("Access token expired. Please login again.")
cli.ExitWithWarning("Access token expired. Please login or add flag-provided credentials.")
}
if errors.Is(err, auth.ErrAccessTokenNotFound) {
cli.ExitWithWarning("No access token found. Please login or add client credentials.")
cli.ExitWithWarning("No access token found. Please login or add flag-provided credentials.")
}
cli.ExitWithError("Failed to get access token", err)
cli.ExitWithError("Failed to get access token.", err)
}

h, err := handlers.New(handlers.WithProfile(cp))
if err != nil {
cli.ExitWithError("Failed to create handler", err)
cli.ExitWithError("Unexpected error", err)
}
return h
}
Expand All @@ -181,7 +212,7 @@ func init() {
BuildTime: config.BuildTime,
}

c.Println(config.AppName + " version " + config.Version + " (" + config.BuildTime + ") " + config.CommitSha)
c.Println(fmt.Sprintf("%s version %s (%s) %s", config.AppName, config.Version, config.BuildTime, config.CommitSha))
c.ExitWithJSON(v)
return
}
Expand Down Expand Up @@ -243,5 +274,10 @@ func init() {
rootCmd.GetDocFlag("with-client-creds").Default,
rootCmd.GetDocFlag("with-client-creds").Description,
)
RootCmd.PersistentFlags().String(
rootCmd.GetDocFlag("with-access-token").Name,
rootCmd.GetDocFlag("with-access-token").Default,
rootCmd.GetDocFlag("with-access-token").Description,
)
RootCmd.AddGroup(&cobra.Group{ID: TDF})
}
2 changes: 2 additions & 0 deletions docs/man/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ command:
- fatal
- panic
default: info
- name: with-access-token
description: access token for authentication via bearer token
- name: with-client-creds-file
description: path to a JSON file containing a 'clientId' and 'clientSecret' for auth via client-credentials flow
- name: with-client-creds
Expand Down
9 changes: 5 additions & 4 deletions e2e/auth.bats
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ teardown_file() {
BAD_HOST='--host http://localhost:9000'
run_otdfctl $BAD_HOST $WITH_CREDS policy attributes list
assert_failure
assert_output --partial "Failed to get platform configuration. Is the platform accepting connections at 'http://localhost:9000'?"
assert_output --partial "Failed to get platform configuration. Is the platform accepting connections at"
}

@test "helpful error if bad credentials" {
Expand All @@ -43,17 +43,18 @@ teardown_file() {
BAD_CREDS="--with-client-creds '{clientId:"badClient",clientSecret:"badSecret"}'"
run_otdfctl $HOST $BAD_CREDS policy attributes list
assert_failure
assert_output --partial "Failed to get client credentials: failed to decode creds JSON"
assert_output --partial "Failed to get client credentials"
}

@test "helpful error if missing client credentials" {
run_otdfctl $HOST policy attributes list
assert_failure
assert_output --partial "Either --with-client-creds or --with-client-creds-file must be set: when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, profiles will not be used and all required flags must be set"
assert_output --partial "One of"
assert_output --partial "must be set: when using global flags"
}

@test "helpful error if missing host" {
run_otdfctl $WITH_CREDS policy attributes list
assert_failure
assert_output --partial "Host must be set: when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, profiles will not be used and all required flags must be set"
assert_output --partial "Host must be set: when using global flags"
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-jose/go-jose/v3 v3.0.3 // indirect
github.com/go-jose/go-jose/v4 v4.0.4 // indirect
github.com/go-logr/logr v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
Expand Down
Loading
Loading