diff --git a/pkg/blueprint/customizations.go b/pkg/blueprint/customizations.go index 37d840f..c578b71 100644 --- a/pkg/blueprint/customizations.go +++ b/pkg/blueprint/customizations.go @@ -4,7 +4,6 @@ import ( "fmt" "reflect" "slices" - "strings" "github.com/osbuild/images/pkg/cert" "github.com/osbuild/images/pkg/customizations/anaconda" @@ -15,7 +14,7 @@ type Customizations struct { Kernel *KernelCustomization `json:"kernel,omitempty" toml:"kernel,omitempty"` SSHKey []SSHKeyCustomization `json:"sshkey,omitempty" toml:"sshkey,omitempty"` User []UserCustomization `json:"user,omitempty" toml:"user,omitempty"` - Group []GroupCustomization `json:"group,omitempty" toml:"group,omitempty"` + Group GroupsCustomization `json:"group,omitempty" toml:"group,omitempty"` Timezone *TimezoneCustomization `json:"timezone,omitempty" toml:"timezone,omitempty"` Locale *LocaleCustomization `json:"locale,omitempty" toml:"locale,omitempty"` Firewall *FirewallCustomization `json:"firewall,omitempty" toml:"firewall,omitempty"` @@ -247,47 +246,6 @@ func (c *Customizations) GetTimezoneSettings() (*string, []string) { return c.Timezone.Timezone, c.Timezone.NTPServers } -func (c *Customizations) GetUsers() []UserCustomization { - if c == nil || (c.User == nil && c.SSHKey == nil) { - return nil - } - - var users []UserCustomization - - // prepend sshkey for backwards compat (overridden by users) - if len(c.SSHKey) > 0 { - for _, k := range c.SSHKey { - key := k.Key - users = append(users, UserCustomization{ - Name: k.User, - Key: &key, - }) - } - } - - users = append(users, c.User...) - - // sanitize user home directory in blueprint: if it has a trailing slash, - // it might lead to the directory not getting the correct selinux labels - for idx := range users { - u := users[idx] - if u.Home != nil { - homedir := strings.TrimRight(*u.Home, "/") - u.Home = &homedir - users[idx] = u - } - } - return users -} - -func (c *Customizations) GetGroups() []GroupCustomization { - if c == nil { - return nil - } - - return c.Group -} - func (c *Customizations) GetKernel() *KernelCustomization { var kernelName, kernelAppend string if c != nil && c.Kernel != nil { diff --git a/pkg/blueprint/customizations_test.go b/pkg/blueprint/customizations_test.go index 79011a0..99ed4e6 100644 --- a/pkg/blueprint/customizations_test.go +++ b/pkg/blueprint/customizations_test.go @@ -78,81 +78,6 @@ func TestGetKernel(t *testing.T) { assert.Equal(t, &expectedKernel, retKernel) } -func TestSSHKey(t *testing.T) { - expectedSSHKeys := []SSHKeyCustomization{ - { - User: "test-user", - Key: "test-key", - }, - } - TestCustomizations := Customizations{ - SSHKey: expectedSSHKeys, - } - - retUser := TestCustomizations.GetUsers()[0].Name - retKey := *TestCustomizations.GetUsers()[0].Key - - assert.Equal(t, expectedSSHKeys[0].User, retUser) - assert.Equal(t, expectedSSHKeys[0].Key, retKey) -} - -func TestGetUsers(t *testing.T) { - Desc := "Test descritpion" - Pass := "testpass" - Key := "testkey" - Home := "Home" - Shell := "Shell" - Groups := []string{ - "Group", - } - UID := 123 - GID := 321 - ExpireDate := 12345 - ForcePasswordReset := true - - expectedUsers := []UserCustomization{ - { - Name: "John", - Description: &Desc, - Password: &Pass, - Key: &Key, - Home: &Home, - Shell: &Shell, - Groups: Groups, - UID: &UID, - GID: &GID, - ExpireDate: &ExpireDate, - ForcePasswordReset: &ForcePasswordReset, - }, - } - - TestCustomizations := Customizations{ - User: expectedUsers, - } - - retUsers := TestCustomizations.GetUsers() - - assert.ElementsMatch(t, expectedUsers, retUsers) -} - -func TestGetGroups(t *testing.T) { - GID := 1234 - expectedGroups := []GroupCustomization{ - { - Name: "TestGroup", - GID: &GID, - }, - } - - TestCustomizations := Customizations{ - Group: expectedGroups, - } - - retGroups := TestCustomizations.GetGroups() - - assert.ElementsMatch(t, expectedGroups, retGroups) -} - func TestGetTimezoneSettings(t *testing.T) { expectedTimezone := "testZONE" expectedNTPServers := []string{ @@ -253,7 +178,11 @@ func TestNoCustomizationsInBlueprint(t *testing.T) { assert.Nil(t, TestBP.Customizations.GetHostname()) assert.Nil(t, TestBP.Customizations.GetUsers()) - assert.Nil(t, TestBP.Customizations.GetGroups()) + + groups, err := TestBP.Customizations.GetGroups() + assert.NoError(t, err) + assert.Nil(t, groups) + assert.Equal(t, &KernelCustomization{Name: "kernel"}, TestBP.Customizations.GetKernel()) assert.Nil(t, TestBP.Customizations.GetFirewall()) assert.Nil(t, TestBP.Customizations.GetServices()) diff --git a/pkg/blueprint/users_groups_customizations.go b/pkg/blueprint/users_groups_customizations.go new file mode 100644 index 0000000..720b31f --- /dev/null +++ b/pkg/blueprint/users_groups_customizations.go @@ -0,0 +1,81 @@ +package blueprint + +import ( + "errors" + "fmt" + "strings" +) + +func (c *Customizations) GetUsers() []UserCustomization { + if c == nil || (c.User == nil && c.SSHKey == nil) { + return nil + } + + var users []UserCustomization + + // prepend sshkey for backwards compat (overridden by users) + if len(c.SSHKey) > 0 { + for _, k := range c.SSHKey { + key := k.Key + users = append(users, UserCustomization{ + Name: k.User, + Key: &key, + }) + } + } + + users = append(users, c.User...) + + // sanitize user home directory in blueprint: if it has a trailing slash, + // it might lead to the directory not getting the correct selinux labels + for idx := range users { + u := users[idx] + if u.Home != nil { + homedir := strings.TrimRight(*u.Home, "/") + u.Home = &homedir + users[idx] = u + } + } + return users +} + +type GroupsCustomization []GroupCustomization + +func (g GroupsCustomization) Validate() error { + names := make(map[string]bool) + gids := make(map[int]bool) + + errs := make([]error, 0) + + for _, group := range g { + if names[group.Name] { + errs = append(errs, fmt.Errorf("duplicate group name: %s", group.Name)) + } + names[group.Name] = true + + if group.GID != nil { + if gids[*group.GID] { + errs = append(errs, fmt.Errorf("duplicate group ID: %d", *group.GID)) + } + gids[*group.GID] = true + } + } + + if err := errors.Join(errs...); err != nil { + return fmt.Errorf("invalid group customizations:\n%w", err) + } + + return nil +} + +func (c *Customizations) GetGroups() (GroupsCustomization, error) { + if c == nil { + return nil, nil + } + + if err := c.Group.Validate(); err != nil { + return nil, err + } + + return c.Group, nil +} diff --git a/pkg/blueprint/users_groups_customizations_test.go b/pkg/blueprint/users_groups_customizations_test.go new file mode 100644 index 0000000..79777bc --- /dev/null +++ b/pkg/blueprint/users_groups_customizations_test.go @@ -0,0 +1,202 @@ +package blueprint + +import ( + "testing" + + "github.com/osbuild/blueprint/internal/common" + "github.com/stretchr/testify/assert" +) + +func TestSSHKey(t *testing.T) { + expectedSSHKeys := []SSHKeyCustomization{ + { + User: "test-user", + Key: "test-key", + }, + } + TestCustomizations := Customizations{ + SSHKey: expectedSSHKeys, + } + + retUser := TestCustomizations.GetUsers()[0].Name + retKey := *TestCustomizations.GetUsers()[0].Key + + assert.Equal(t, expectedSSHKeys[0].User, retUser) + assert.Equal(t, expectedSSHKeys[0].Key, retKey) +} + +func TestGetUsers(t *testing.T) { + Desc := "Test descritpion" + Pass := "testpass" + Key := "testkey" + Home := "Home" + Shell := "Shell" + Groups := []string{ + "Group", + } + UID := 123 + GID := 321 + ExpireDate := 12345 + ForcePasswordReset := true + + expectedUsers := []UserCustomization{ + { + Name: "John", + Description: &Desc, + Password: &Pass, + Key: &Key, + Home: &Home, + Shell: &Shell, + Groups: Groups, + UID: &UID, + GID: &GID, + ExpireDate: &ExpireDate, + ForcePasswordReset: &ForcePasswordReset, + }, + } + + TestCustomizations := Customizations{ + User: expectedUsers, + } + + retUsers := TestCustomizations.GetUsers() + + assert.ElementsMatch(t, expectedUsers, retUsers) +} + +func TestGetGroups(t *testing.T) { + type testCase struct { + groups []GroupCustomization + expectedErrorMessage string + } + + testCases := map[string]testCase{ + "nil": { + groups: nil, + }, + "none": { + groups: []GroupCustomization{}, + }, + "single": { + groups: []GroupCustomization{ + { + Name: "TestGroup", + GID: common.ToPtr(1234), + }, + }, + }, + "multi": { + groups: []GroupCustomization{ + { + Name: "TestGroup", + GID: common.ToPtr(1234), + }, + { + Name: "sysgrp", + GID: common.ToPtr(998), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + }, + }, + "duplicate-names": { + groups: []GroupCustomization{ + { + Name: "TestGroup", + GID: common.ToPtr(1234), + }, + { + Name: "sysgrp", + GID: common.ToPtr(998), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + { + Name: "wheel", + GID: common.ToPtr(43), + }, + }, + expectedErrorMessage: "invalid group customizations:\nduplicate group name: wheel", + }, + "duplicate-gids": { + groups: []GroupCustomization{ + { + Name: "TestGroup", + GID: common.ToPtr(1234), + }, + { + Name: "sysgrp", + GID: common.ToPtr(42), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + }, + expectedErrorMessage: "invalid group customizations:\nduplicate group ID: 42", + }, + "duplicate-both": { + groups: []GroupCustomization{ + { + Name: "TestGroup", + GID: common.ToPtr(1234), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + }, + expectedErrorMessage: "invalid group customizations:\nduplicate group name: wheel\nduplicate group ID: 42", + }, + "duplicate-multi": { + groups: []GroupCustomization{ + { + Name: "test", + GID: common.ToPtr(1234), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + { + Name: "wheel", + GID: common.ToPtr(42), + }, + { + Name: "user", + GID: common.ToPtr(1234), + }, + { + Name: "test", + GID: common.ToPtr(4321), + }, + }, + expectedErrorMessage: "invalid group customizations:\nduplicate group name: wheel\nduplicate group ID: 42\nduplicate group ID: 1234\nduplicate group name: test", + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + assert := assert.New(t) + c := Customizations{ + Group: tc.groups, + } + + groups, err := c.GetGroups() + if tc.expectedErrorMessage != "" { + assert.EqualError(err, tc.expectedErrorMessage) + } else { + assert.NoError(err) + assert.ElementsMatch(tc.groups, groups) + } + }) + } +}