Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ otdfctl.yaml

# Ignore the tructl binary
otdfctl
otdfctl_testbuild

# Misc
creds.json

# Hugo
public/
Expand Down
16 changes: 13 additions & 3 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,20 @@ func InitProfile(cmd *cobra.Command, onlyNew bool) *profiles.ProfileStore {
// TODO make this a preRun hook
func NewHandler(cmd *cobra.Command) handlers.Handler {
fh := cli.NewFlagHelper(cmd)

// Non-profile flags
host := fh.GetOptionalString("host")
tlsNoVerify := fh.GetOptionalBool("tls-no-verify")
withClientCreds := fh.GetOptionalString("with-client-creds")
withClientCredsFile := fh.GetOptionalString("with-client-creds-file")

// if global flags are set then validate and create a temporary profile in memory
var cp *profiles.ProfileStore
if host != "" || withClientCreds != "" || withClientCredsFile != "" {
err := errors.New("when using global flags --host, --with-client-creds, or --with-client-creds-file, " +
"profiles will not be used and all required flags must be set")
if host != "" || tlsNoVerify || withClientCreds != "" || withClientCredsFile != "" {
err := errors.New(
"when using global flags --host, --tls-no-verify, --with-client-creds, or --with-client-creds-file, " +
"profiles will not be used and all required flags must be set",
)

// host must be set
if host == "" {
Expand Down Expand Up @@ -170,6 +174,12 @@ func init() {
rootCmd.GetDocFlag("version").Description,
)

RootCmd.PersistentFlags().String(
rootCmd.GetDocFlag("profile").Name,
rootCmd.GetDocFlag("profile").Default,
rootCmd.GetDocFlag("profile").Description,
)

RootCmd.PersistentFlags().String(
rootCmd.GetDocFlag("host").Name,
rootCmd.GetDocFlag("host").Default,
Expand Down
3 changes: 3 additions & 0 deletions docs/man/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ command:
- name: version
description: show version
default: false
- name: profile
description: profile to use for interacting with the platform
default:
- name: host
description: Hostname of the platform (i.e. https://localhost)
default:
Expand Down
56 changes: 44 additions & 12 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ type platformConfiguration struct {
publicClientID string
}

type oidcClientCredentials struct {
clientID string
clientSecret string
isPublic bool
}

// Retrieves credentials by reading specified file
func GetClientCredsFromFile(filepath string) (ClientCredentials, error) {
creds := ClientCredentials{}
Expand Down Expand Up @@ -185,16 +191,13 @@ func GetTokenWithProfile(ctx context.Context, profile *profiles.ProfileStore) (*

// Uses the OAuth2 client credentials flow to obtain a token.
func GetTokenWithClientCreds(ctx context.Context, endpoint string, clientId string, clientSecret string, tlsNoVerify bool) (*oauth2.Token, error) {
pc, err := getPlatformConfiguration(endpoint, "", tlsNoVerify)
if err != nil && !errors.Is(err, sdk.ErrPlatformPublicClientIDNotFound) {
return nil, err
}

rp, err := oidcrp.NewRelyingPartyOIDC(ctx, pc.issuer, clientId, clientSecret, "", []string{"email"})
rp, err := newOidcRelyingParty(ctx, endpoint, tlsNoVerify, oidcClientCredentials{
clientID: clientId,
clientSecret: clientSecret,
})
if err != nil {
return nil, err
}

return oidcrp.ClientCredentials(ctx, rp, url.Values{})
}

Expand Down Expand Up @@ -268,15 +271,44 @@ func LoginWithPKCE(host, publicClientID string, tlsNoVerify bool) (*oauth2.Token

// Revokes the access token
func RevokeAccessToken(endpoint, publicClientID, refreshToken string, tlsNoVerify bool) error {
pCfg, err := getPlatformConfiguration(endpoint, publicClientID, tlsNoVerify)
rp, err := newOidcRelyingParty(context.Background(), endpoint, tlsNoVerify, oidcClientCredentials{
clientID: publicClientID,
isPublic: true,
})
if err != nil {
return fmt.Errorf("failed to get platform configuration: %w", err)
return err
}
return oidcrp.RevokeToken(context.Background(), rp, refreshToken, "refresh_token")
}

rp, err := oidcrp.NewRelyingPartyOIDC(context.Background(), pCfg.issuer, pCfg.publicClientID, "", "", nil)
func newOidcRelyingParty(ctx context.Context, endpoint string, tlsNoVerify bool, clientCreds oidcClientCredentials) (oidcrp.RelyingParty, error) {
if clientCreds.clientID == "" {
return nil, errors.New("client ID is required")
}
if clientCreds.clientSecret == "" && !clientCreds.isPublic {
return nil, errors.New("client secret is required")
}
if clientCreds.clientSecret != "" && clientCreds.isPublic {
return nil, errors.New("client secret must be empty for public clients")
}

var pcClient string
if clientCreds.isPublic {
pcClient = clientCreds.clientID
}

pc, err := getPlatformConfiguration(endpoint, pcClient, tlsNoVerify)
if err != nil {
return err
return nil, err
}

return oidcrp.RevokeToken(context.Background(), rp, refreshToken, "refresh_token")
return oidcrp.NewRelyingPartyOIDC(
ctx,
pc.issuer,
clientCreds.clientID,
clientCreds.clientSecret,
"",
nil,
oidcrp.WithHTTPClient(utils.NewHttpClient(tlsNoVerify)),
)
}
16 changes: 16 additions & 0 deletions pkg/utils/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package utils

import (
"crypto/tls"
"net/http"
)

func NewHttpClient(tlsNoVerify bool) *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: tlsNoVerify,
},
},
}
}