Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 11 additions & 55 deletions api/types/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -1118,59 +1118,15 @@ func (r *RequireMFAType) encode() (interface{}, error) {
// decode RequireMFAType from a string or boolean. This is necessary for
// backwards compatibility with the json/yaml tag "require_session_mfa",
// which used to be a boolean.
func (r *RequireMFAType) decode(val interface{}) error {
switch v := val.(type) {
case string:
switch v {
case RequireMFATypeHardwareKeyString:
*r = RequireMFAType_SESSION_AND_HARDWARE_KEY
case RequireMFATypeHardwareKeyTouchString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH
case RequireMFATypeHardwareKeyPINString:
*r = RequireMFAType_HARDWARE_KEY_PIN
case RequireMFATypeHardwareKeyTouchAndPINString:
*r = RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN
case "":
// default to off
*r = RequireMFAType_OFF
default:
// try parsing as a boolean
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
*r = RequireMFAType_SESSION
case "no", "nope", "n", "false", "0", "off":
*r = RequireMFAType_OFF
default:
return trace.BadParameter("RequireMFAType invalid value %v", val)
}
}
case bool:
if v {
*r = RequireMFAType_SESSION
} else {
*r = RequireMFAType_OFF
}
case int32:
return trace.Wrap(r.setFromEnum(v))
case int64:
return trace.Wrap(r.setFromEnum(int32(v)))
case int:
return trace.Wrap(r.setFromEnum(int32(v)))
case float64:
return trace.Wrap(r.setFromEnum(int32(v)))
case float32:
return trace.Wrap(r.setFromEnum(int32(v)))
default:
return trace.BadParameter("RequireMFAType invalid type %T", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (r *RequireMFAType) setFromEnum(val int32) error {
if _, ok := RequireMFAType_name[val]; !ok {
return trace.BadParameter("invalid required mfa mode %v", val)
}
*r = RequireMFAType(val)
return nil
func (r *RequireMFAType) decode(val any) error {
err := decodeEnum(r, val, map[any]RequireMFAType{
"": RequireMFAType_OFF, // default to off
false: RequireMFAType_OFF,
true: RequireMFAType_SESSION,
RequireMFATypeHardwareKeyString: RequireMFAType_SESSION_AND_HARDWARE_KEY,
RequireMFATypeHardwareKeyTouchString: RequireMFAType_HARDWARE_KEY_TOUCH,
RequireMFATypeHardwareKeyPINString: RequireMFAType_HARDWARE_KEY_PIN,
RequireMFATypeHardwareKeyTouchAndPINString: RequireMFAType_HARDWARE_KEY_TOUCH_AND_PIN,
}, RequireMFAType_name)
return trace.Wrap(err, "failed to decode require mfa type")
}
Comment on lines +1121 to 1132
Copy link
Copy Markdown
Contributor

@rosstimothy rosstimothy Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I prefer the more explicit implementation than farming things out to this helper. The call site here is quite dense and a bit hard to grok what is going on without jumping to the decodeEnum function.

Copy link
Copy Markdown
Contributor Author

@Joerger Joerger Oct 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough. My thinking here is that I like the readability of the helper's usage - you can see at a glance what values map to what enum without having to parse through the type switches, bool option handling, etc. But I'm not married to the change at all, I can just close this one.

Edit: it also did turn out more complex than I envisioned, so I understand the confusion concern.

82 changes: 82 additions & 0 deletions api/types/enum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
Copyright 2024 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"strings"

"github.com/gravitational/trace"
)

// decodeEnum decodes a protobuf enum from a representational value, usually a bool,
// string, or from the actual enum (int32) value. If the value is valid, it is saved
// in the given enum pointer.
func decodeEnum[T ~int32](p *T, val any, representationMap map[any]T, enumMap map[int32]string) error {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General use helper functions like this should be covered by tests.

if v, ok := representationMap[val]; ok {
*p = v
return nil
}

// try parsing as a bool value
if v, ok := val.(string); ok {
switch strings.ToLower(v) {
case "yes", "yeah", "y", "true", "1", "on":
if v, ok := representationMap[true]; ok {
*p = v
return nil
}
case "no", "nope", "n", "false", "0", "off":
if v, ok := representationMap[false]; ok {
*p = v
return nil
}
}
return trace.BadParameter("unknown enum value %v", val)
}

// parse as enum
var enumVal T
switch v := val.(type) {
case int:
enumVal = T(v)
case int32:
enumVal = T(v)
case int64:
enumVal = T(v)
case float64:
enumVal = T(v)
case float32:
enumVal = T(v)
default:
return trace.BadParameter("unknown enum value %v", val)
}

if err := checkEnum(enumMap, int32(enumVal)); err != nil {
return trace.BadParameter("unknown enum value %v", val)
}

*p = enumVal
return nil
}

// checkEnum checks if the given enum is valid.
func checkEnum(enumMap map[int32]string, val int32) error {
if _, ok := enumMap[val]; ok {
return nil
}
return trace.NotFound("enum %v not found in enum map", val)
}
87 changes: 16 additions & 71 deletions api/types/role.go
Original file line number Diff line number Diff line change
Expand Up @@ -1998,55 +1998,15 @@ func (h CreateHostUserMode) encode() (string, error) {
}

func (h *CreateHostUserMode) decode(val any) error {
var valS string
switch val := val.(type) {
case int32:
return trace.Wrap(h.setFromEnum(val))
case int64:
return trace.Wrap(h.setFromEnum(int32(val)))
case int:
return trace.Wrap(h.setFromEnum(int32(val)))
case float64:
return trace.Wrap(h.setFromEnum(int32(val)))
case float32:
return trace.Wrap(h.setFromEnum(int32(val)))
case string:
valS = val
case bool:
if val {
return trace.BadParameter("create_host_user_mode cannot be true, got %v", val)
}
valS = createHostUserModeOffString
default:
return trace.BadParameter("bad value type %T, expected string or int", val)
}

switch valS {
case "":
*h = CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED
case createHostUserModeOffString:
*h = CreateHostUserMode_HOST_USER_MODE_OFF
case createHostUserModeKeepString:
*h = CreateHostUserMode_HOST_USER_MODE_KEEP
case createHostUserModeInsecureDropString, createHostUserModeDropString:
*h = CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP
default:
return trace.BadParameter("invalid host user mode %v", val)
}
return nil
}

// setFromEnum sets the value from enum value as int32.
func (h *CreateHostUserMode) setFromEnum(val int32) error {
// Map drop to insecure-drop
if val == int32(CreateHostUserMode_HOST_USER_MODE_DROP) {
val = int32(CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP)
}
if _, ok := CreateHostUserMode_name[val]; !ok {
return trace.BadParameter("invalid host user mode %v", val)
}
*h = CreateHostUserMode(val)
return nil
err := decodeEnum(h, val, map[interface{}]CreateHostUserMode{
"": CreateHostUserMode_HOST_USER_MODE_UNSPECIFIED,
false: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeOffString: CreateHostUserMode_HOST_USER_MODE_OFF,
createHostUserModeKeepString: CreateHostUserMode_HOST_USER_MODE_KEEP,
createHostUserModeInsecureDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
createHostUserModeDropString: CreateHostUserMode_HOST_USER_MODE_INSECURE_DROP,
}, CreateHostUserMode_name)
return trace.Wrap(err, "failed to decode host user mode")
}

// UnmarshalYAML supports parsing CreateHostUserMode from string.
Expand Down Expand Up @@ -2114,28 +2074,13 @@ func (h CreateDatabaseUserMode) encode() (string, error) {
}

func (h *CreateDatabaseUserMode) decode(val any) error {
var str string
switch val := val.(type) {
case string:
str = val
default:
return trace.BadParameter("bad value type %T, expected string", val)
}

switch str {
case "":
*h = CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED
case createDatabaseUserModeOffString:
*h = CreateDatabaseUserMode_DB_USER_MODE_OFF
case createDatabaseUserModeKeepString:
*h = CreateDatabaseUserMode_DB_USER_MODE_KEEP
case createDatabaseUserModeBestEffortDropString:
*h = CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP
default:
return trace.BadParameter("invalid database user mode %v", val)
}

return nil
err := decodeEnum(h, val, map[interface{}]CreateDatabaseUserMode{
"": CreateDatabaseUserMode_DB_USER_MODE_UNSPECIFIED,
createDatabaseUserModeOffString: CreateDatabaseUserMode_DB_USER_MODE_OFF,
createDatabaseUserModeKeepString: CreateDatabaseUserMode_DB_USER_MODE_KEEP,
createDatabaseUserModeBestEffortDropString: CreateDatabaseUserMode_DB_USER_MODE_BEST_EFFORT_DROP,
}, CreateDatabaseUserMode_name)
return trace.Wrap(err, "failed to decode require mfa type")
}

// UnmarshalYAML supports parsing CreateDatabaseUserMode from string.
Expand Down