diff --git a/api/types/saml.go b/api/types/saml.go index 0295aeb3e66c1..5f042f2f4047c 100644 --- a/api/types/saml.go +++ b/api/types/saml.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "encoding/json" "slices" "strings" "time" @@ -524,3 +525,98 @@ func (r *SAMLAuthRequest) Check() error { } return nil } + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s SAMLForceAuthn) MarshalYAML() (interface{}, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + return val, nil +} + +// UnmarshalYAML supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalYAML(unmarshal func(interface{}) error) error { + var val any + if err := unmarshal(&val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.decode(val)) +} + +// MarshalJSON marshals SAMLForceAuthn to string. +func (s SAMLForceAuthn) MarshalJSON() ([]byte, error) { + val, err := s.encode() + if err != nil { + return nil, trace.Wrap(err) + } + out, err := json.Marshal(val) + return out, trace.Wrap(err) +} + +// UnmarshalJSON supports parsing SAMLForceAuthn from string. +func (s *SAMLForceAuthn) UnmarshalJSON(data []byte) error { + var val any + if err := json.Unmarshal(data, &val); err != nil { + return trace.Wrap(err) + } + return trace.Wrap(s.decode(val)) +} + +func (s *SAMLForceAuthn) encode() (string, error) { + switch *s { + case SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED: + return "", nil + case SAMLForceAuthn_FORCE_AUTHN_NO: + return "no", nil + case SAMLForceAuthn_FORCE_AUTHN_YES: + return "yes", nil + default: + return "", trace.BadParameter("SAMLForceAuthn invalid value %v", *s) + } +} + +func (s *SAMLForceAuthn) decode(val any) error { + switch v := val.(type) { + case string: + // try parsing as a boolean + switch strings.ToLower(v) { + case "": + *s = SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED + case "yes", "yeah", "y", "true", "1", "on": + *s = SAMLForceAuthn_FORCE_AUTHN_YES + case "no", "nope", "n", "false", "0", "off": + *s = SAMLForceAuthn_FORCE_AUTHN_NO + default: + return trace.BadParameter("SAMLForceAuthn invalid value %v", val) + } + case bool: + if v { + *s = SAMLForceAuthn_FORCE_AUTHN_YES + } else { + *s = SAMLForceAuthn_FORCE_AUTHN_NO + } + case int32: + return trace.Wrap(s.setFromEnum(v)) + case int64: + return trace.Wrap(s.setFromEnum(int32(v))) + case int: + return trace.Wrap(s.setFromEnum(int32(v))) + case float64: + return trace.Wrap(s.setFromEnum(int32(v))) + case float32: + return trace.Wrap(s.setFromEnum(int32(v))) + default: + return trace.BadParameter("SAMLForceAuthn invalid type %T", val) + } + return nil +} + +// setFromEnum sets the value from enum value as int32. +func (s *SAMLForceAuthn) setFromEnum(val int32) error { + if _, ok := SAMLForceAuthn_name[val]; !ok { + return trace.BadParameter("invalid SAMLForceAuthn enum %v", val) + } + *s = SAMLForceAuthn(val) + return nil +} diff --git a/api/types/saml_test.go b/api/types/saml_test.go index 228b0afd9f35b..933df64e150f6 100644 --- a/api/types/saml_test.go +++ b/api/types/saml_test.go @@ -17,9 +17,13 @@ limitations under the License. package types_test import ( + "encoding/json" + "fmt" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" "github.com/gravitational/teleport/api/types" ) @@ -107,3 +111,62 @@ func TestSAMLForceAuthn(t *testing.T) { }) } } + +func TestSAMLForceAuthn_Encoding(t *testing.T) { + for _, tt := range []struct { + forceAuthn types.SAMLForceAuthn + expectEncoded string + }{ + { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_UNSPECIFIED, + expectEncoded: "", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_YES, + expectEncoded: "yes", + }, { + forceAuthn: types.SAMLForceAuthn_FORCE_AUTHN_NO, + expectEncoded: "no", + }, + } { + t.Run(tt.forceAuthn.String(), func(t *testing.T) { + type object struct { + ForceAuthn types.SAMLForceAuthn `json:"force_authn" yaml:"force_authn"` + } + o := object{ + ForceAuthn: tt.forceAuthn, + } + objectJSON := fmt.Sprintf(`{"force_authn":%q}`, tt.expectEncoded) + objectYAML := fmt.Sprintf("force_authn: %q\n", tt.expectEncoded) + + t.Run("JSON", func(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + gotJSON, err := json.Marshal(o) + assert.NoError(t, err, "unexpected error from json.Marshal") + assert.Equal(t, objectJSON, string(gotJSON), "unexpected json.Marshal value") + }) + + t.Run("Unmarshal", func(t *testing.T) { + var gotObject object + err := json.Unmarshal([]byte(objectJSON), &gotObject) + assert.NoError(t, err, "unexpected error from json.Unmarshal") + assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected json.Unmarshal value") + }) + }) + + t.Run("YAML", func(t *testing.T) { + t.Run("Marshal", func(t *testing.T) { + gotYAML, err := yaml.Marshal(o) + assert.NoError(t, err, "unexpected error from yaml.Marshal") + assert.Equal(t, objectYAML, string(gotYAML), "unexpected yaml.Marshal value") + }) + + t.Run("Unmarshal", func(t *testing.T) { + var gotObject object + err := yaml.Unmarshal([]byte(objectYAML), &gotObject) + assert.NoError(t, err, "unexpected error from yaml.Unmarshal") + assert.Equal(t, tt.forceAuthn, gotObject.ForceAuthn, "unexpected yaml.Unmarshal value") + }) + }) + }) + } +}