Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
344 changes: 250 additions & 94 deletions cmd/kas-keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,10 +236,7 @@ func getTableRows(kasKey *policy.KasKey) [][]string {

// TODO: Handle wrapping the generated key with provider config.
func policyCreateKasKey(cmd *cobra.Command, args []string) {
var (
wrappingKeyID string
providerConfigID string
)
var wrappingKeyID string

c := cli.New(cmd, args)
h := NewHandler(c)
Expand All @@ -249,99 +246,16 @@ func policyCreateKasKey(cmd *cobra.Command, args []string) {
kasIdentifier := c.Flags.GetRequiredString("kas")
metadataLabels = c.Flags.GetStringSlice("label", metadataLabels, cli.FlagsStringSliceOptions{Min: 0})

alg, err := algToEnum(c.Flags.GetRequiredString("algorithm"))
// Use the helper function to get and validate key parameters
alg, mode, wrappingKeyID, err := prepareKeyParams(c)
if err != nil {
cli.ExitWithError("Invalid algorithm", err)
cli.ExitWithError("Invalid key parameters", err)
}

mode, err := modeToEnum(c.Flags.GetRequiredString("mode"))
// Use the helper function to prepare key contexts
publicKeyCtx, privateKeyCtx, providerConfigID, err := prepareKeyContexts(c, mode, alg, wrappingKeyID)
if err != nil {
cli.ExitWithError("Invalid mode", err)
}

wrappingKeyID = c.Flags.GetOptionalString("wrapping-key-id")
if mode != policy.KeyMode_KEY_MODE_PUBLIC_KEY_ONLY && wrappingKeyID == "" {
formattedMode, _ := enumToMode(mode)
cli.ExitWithError(fmt.Sprintf("wrapping-key-id is required for mode %s", formattedMode), nil)
}

providerConfigID = c.Flags.GetOptionalString("provider-config-id")
if (mode == policy.KeyMode_KEY_MODE_PROVIDER_ROOT_KEY || mode == policy.KeyMode_KEY_MODE_REMOTE) && providerConfigID == "" {
formattedMode, _ := enumToMode(mode)
cli.ExitWithError(fmt.Sprintf("provider-config-id is required for mode %s", formattedMode), nil)
}

var publicKeyCtx *policy.PublicKeyCtx
var privateKeyCtx *policy.PrivateKeyCtx
switch mode {
case policy.KeyMode_KEY_MODE_CONFIG_ROOT_KEY:
wrappingKey := c.Flags.GetRequiredString("wrapping-key")
privateKeyPem, publicKeyPem, err := generateKeys(alg)
if err != nil {
cli.ExitWithError("Failed to generate keys", err)
}

privateKey, err := wrapKey(privateKeyPem, wrappingKey)
if err != nil {
cli.ExitWithError("Failed to wrap key", err)
}

pubPemBase64 := base64.StdEncoding.EncodeToString([]byte(publicKeyPem))
privPemBase64 := base64.StdEncoding.EncodeToString(privateKey)
publicKeyCtx = &policy.PublicKeyCtx{
Pem: pubPemBase64,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
WrappedKey: privPemBase64,
}
case policy.KeyMode_KEY_MODE_PROVIDER_ROOT_KEY:
providerConfigID = c.Flags.GetRequiredString("provider-config-id")
publicPem := c.Flags.GetRequiredString("public-key-pem")
privatePem := c.Flags.GetRequiredString("private-key-pem")
_, err = base64.StdEncoding.DecodeString(publicPem)
if err != nil {
cli.ExitWithError("pem must be base64 encoded", err)
}
_, err = base64.StdEncoding.DecodeString(privatePem)
if err != nil {
cli.ExitWithError("pem must be base64 encoded", err)
}
publicKeyCtx = &policy.PublicKeyCtx{
Pem: publicPem,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
WrappedKey: privatePem,
}
case policy.KeyMode_KEY_MODE_REMOTE:
pem := c.Flags.GetRequiredString("public-key-pem")
providerConfigID = c.Flags.GetRequiredString("provider-config-id")

_, err = base64.StdEncoding.DecodeString(pem)
if err != nil {
cli.ExitWithError("pem must be base64 encoded", err)
}

publicKeyCtx = &policy.PublicKeyCtx{
Pem: pem,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
}
case policy.KeyMode_KEY_MODE_PUBLIC_KEY_ONLY:
pem := c.Flags.GetRequiredString("public-key-pem")
_, err = base64.StdEncoding.DecodeString(pem)
if err != nil {
cli.ExitWithError("pem must be base64 encoded", err)
}
publicKeyCtx = &policy.PublicKeyCtx{
Pem: pem,
}
case policy.KeyMode_KEY_MODE_UNSPECIFIED:
fallthrough
default:
cli.ExitWithError("Invalid mode", nil)
cli.ExitWithError("Failed to prepare key contexts", err)
}

kasLookup, err := resolveKasIdentifier(kasIdentifier)
Expand Down Expand Up @@ -539,6 +453,70 @@ func policyListKasKeys(cmd *cobra.Command, args []string) {
HandleSuccess(cmd, "", t, keys)
}

func policyRotateKasKey(cmd *cobra.Command, args []string) {
var wrappingKeyID string

c := cli.New(cmd, args)
h := NewHandler(c)
defer h.Close()

// Get parameters for the old key
oldKey := c.Flags.GetRequiredString("key")

// Get parameters for creating the new key
newKeyID := c.Flags.GetRequiredString("key-id")
metadataLabels = c.Flags.GetStringSlice("label", metadataLabels, cli.FlagsStringSliceOptions{Min: 0})

// Use the helper function to get and validate key parameters
alg, mode, wrappingKeyID, err := prepareKeyParams(c)
if err != nil {
cli.ExitWithError("Invalid key parameters", err)
}

// Use the helper function to prepare key contexts
publicKeyCtx, privateKeyCtx, providerConfigID, err := prepareKeyContexts(c, mode, alg, wrappingKeyID)
if err != nil {
cli.ExitWithError("Failed to prepare key contexts", err)
}

// Create the new key request with the contexts created by the helper
newKey := &kasregistry.RotateKeyRequest_NewKey{
KeyId: newKeyID,
Algorithm: alg,
KeyMode: mode,
PublicKeyCtx: publicKeyCtx,
PrivateKeyCtx: privateKeyCtx,
ProviderConfigId: providerConfigID,
Metadata: getMetadataMutable(metadataLabels),
}

var identifier *kasregistry.KasKeyIdentifier
if utils.ClassifyString(oldKey) != utils.StringTypeUUID {
identifier, err = getKasKeyIdentifier(c)
if err != nil {
cli.ExitWithError("Invalid key identifier", err)
}
}

// Call the rotate key function
rotateKeyResult, err := h.RotateKasKey(
c.Context(),
oldKey,
identifier,
newKey,
)
if err != nil {
cli.ExitWithError("Failed to rotate key", err)
}

rows := getTableRows(rotateKeyResult.KasKey)
if mdRows := getMetadataRows(rotateKeyResult.KasKey.GetKey().GetMetadata()); mdRows != nil {
rows = append(rows, mdRows...)
}
t := cli.NewTabular(rows...)
HandleSuccess(cmd, rotateKeyResult.KasKey.GetKey().GetId(), t, rotateKeyResult)
}

func resolveKasIdentifier(ident string) (handlers.KasIdentifier, error) {
// If the identifier is empty, it means no KAS filter is applied.
// Return an empty KasIdentifier and no error.
Expand All @@ -564,6 +542,118 @@ func resolveKasIdentifier(ident string) (handlers.KasIdentifier, error) {
return kasLookup, nil
}

// prepareKeyParams parses and validates the common key parameters used by both create and rotate operations.
// It returns the algorithm, mode, wrapping key ID, and any error that occurred.
func prepareKeyParams(c *cli.Cli) (policy.Algorithm, policy.KeyMode, string, error) {
// Parse algorithm
alg, err := algToEnum(c.Flags.GetRequiredString("algorithm"))
if err != nil {
return alg, 0, "", err
}

// Parse mode
mode, err := modeToEnum(c.Flags.GetRequiredString("mode"))
if err != nil {
return alg, mode, "", err
}

// Get wrapping key ID and validate based on mode
wrappingKeyID := c.Flags.GetOptionalString("wrapping-key-id")
if mode != policy.KeyMode_KEY_MODE_PUBLIC_KEY_ONLY && wrappingKeyID == "" {
formattedMode, _ := enumToMode(mode)
return alg, mode, "", fmt.Errorf("wrapping-key-id is required for mode %s", formattedMode)
}

return alg, mode, wrappingKeyID, nil
}

// prepareKeyContexts prepares the key contexts based on the specified mode and parameters.
// This function encapsulates the common logic between key creation and key rotation.
func prepareKeyContexts(
c *cli.Cli,
mode policy.KeyMode,
alg policy.Algorithm,
wrappingKeyID string,
) (*policy.PublicKeyCtx, *policy.PrivateKeyCtx, string, error) {
var publicKeyCtx *policy.PublicKeyCtx
var privateKeyCtx *policy.PrivateKeyCtx
var providerConfigID string

switch mode {
case policy.KeyMode_KEY_MODE_CONFIG_ROOT_KEY:
// Local mode: generate keys locally and wrap with provided wrapping key
wrappingKey := c.Flags.GetRequiredString("wrapping-key")
privateKeyPem, publicKeyPem, err := generateKeys(alg)
if err != nil {
return nil, nil, "", errors.Join(errors.New("failed to generate keys"), err)
}

privateKey, err := wrapKey(privateKeyPem, wrappingKey)
if err != nil {
return nil, nil, "", errors.Join(errors.New("failed to wrap key"), err)
}

pubPemBase64 := base64.StdEncoding.EncodeToString([]byte(publicKeyPem))
privPemBase64 := base64.StdEncoding.EncodeToString(privateKey)
publicKeyCtx = &policy.PublicKeyCtx{
Pem: pubPemBase64,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
WrappedKey: privPemBase64,
}
case policy.KeyMode_KEY_MODE_PROVIDER_ROOT_KEY:
providerConfigID = c.Flags.GetRequiredString("provider-config-id")
publicPem := c.Flags.GetRequiredString("public-key-pem")
privatePem := c.Flags.GetRequiredString("private-key-pem")
_, err := base64.StdEncoding.DecodeString(publicPem)
if err != nil {
return nil, nil, "", errors.Join(errors.New("public key pem must be base64 encoded"), err)
}
_, err = base64.StdEncoding.DecodeString(privatePem)
if err != nil {
return nil, nil, "", errors.Join(errors.New("private key pem must be base64 encoded"), err)
}
publicKeyCtx = &policy.PublicKeyCtx{
Pem: publicPem,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
WrappedKey: privatePem,
}
case policy.KeyMode_KEY_MODE_REMOTE:
pem := c.Flags.GetRequiredString("public-key-pem")
providerConfigID = c.Flags.GetRequiredString("provider-config-id")

_, err := base64.StdEncoding.DecodeString(pem)
if err != nil {
return nil, nil, "", errors.Join(errors.New("pem must be base64 encoded"), err)
}

publicKeyCtx = &policy.PublicKeyCtx{
Pem: pem,
}
privateKeyCtx = &policy.PrivateKeyCtx{
KeyId: wrappingKeyID,
}
case policy.KeyMode_KEY_MODE_PUBLIC_KEY_ONLY:
pem := c.Flags.GetRequiredString("public-key-pem")
_, err := base64.StdEncoding.DecodeString(pem)
if err != nil {
return nil, nil, "", errors.Join(errors.New("pem must be base64 encoded"), err)
}
publicKeyCtx = &policy.PublicKeyCtx{
Pem: pem,
}
case policy.KeyMode_KEY_MODE_UNSPECIFIED:
fallthrough
default:
return nil, nil, "", errors.New("invalid mode")
}

return publicKeyCtx, privateKeyCtx, providerConfigID, nil
}

func init() {
// Create Kas Key
createDoc := man.Docs.GetCommand("policy/kas-registry/key/create",
Expand Down Expand Up @@ -671,6 +761,72 @@ func init() {
)
injectListPaginationFlags(listDoc)

policyKasRegistryKeysCmd.AddSubcommands(createDoc, getDoc, updateDoc, listDoc)
// Rotate Kas Key
rotateDoc := man.Docs.GetCommand("policy/kas-registry/key/rotate",
man.WithRun(policyRotateKasKey),
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("key").Name,
rotateDoc.GetDocFlag("key").Shorthand,
rotateDoc.GetDocFlag("key").Default,
rotateDoc.GetDocFlag("key").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("kas").Name,
rotateDoc.GetDocFlag("kas").Shorthand,
rotateDoc.GetDocFlag("kas").Default,
rotateDoc.GetDocFlag("kas").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("key-id").Name,
rotateDoc.GetDocFlag("key-id").Shorthand,
rotateDoc.GetDocFlag("key-id").Default,
rotateDoc.GetDocFlag("key-id").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("algorithm").Name,
rotateDoc.GetDocFlag("algorithm").Shorthand,
rotateDoc.GetDocFlag("algorithm").Default,
rotateDoc.GetDocFlag("algorithm").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("mode").Name,
rotateDoc.GetDocFlag("mode").Shorthand,
rotateDoc.GetDocFlag("mode").Default,
rotateDoc.GetDocFlag("mode").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("wrapping-key-id").Name,
rotateDoc.GetDocFlag("wrapping-key-id").Shorthand,
rotateDoc.GetDocFlag("wrapping-key-id").Default,
rotateDoc.GetDocFlag("wrapping-key-id").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("wrapping-key").Name,
rotateDoc.GetDocFlag("wrapping-key").Shorthand,
rotateDoc.GetDocFlag("wrapping-key").Default,
rotateDoc.GetDocFlag("wrapping-key").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("provider-config-id").Name,
rotateDoc.GetDocFlag("provider-config-id").Shorthand,
rotateDoc.GetDocFlag("provider-config-id").Default,
rotateDoc.GetDocFlag("provider-config-id").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("public-key-pem").Name,
rotateDoc.GetDocFlag("public-key-pem").Shorthand,
rotateDoc.GetDocFlag("public-key-pem").Default,
rotateDoc.GetDocFlag("public-key-pem").Description,
)
rotateDoc.Flags().StringP(
rotateDoc.GetDocFlag("private-key-pem").Name,
rotateDoc.GetDocFlag("private-key-pem").Shorthand,
rotateDoc.GetDocFlag("private-key-pem").Default,
rotateDoc.GetDocFlag("private-key-pem").Description,
)
injectLabelFlags(&rotateDoc.Command, true)

policyKasRegistryKeysCmd.AddSubcommands(createDoc, getDoc, updateDoc, listDoc, rotateDoc)
policyKasRegCmd.AddCommand(&policyKasRegistryKeysCmd.Command)
}
Loading
Loading