diff --git a/policy/files/table.go b/policy/files/table.go index a9a05148..536d30e5 100644 --- a/policy/files/table.go +++ b/policy/files/table.go @@ -16,7 +16,75 @@ package files -import "strings" +import ( + "strings" + "unicode" +) + +// fieldsEscaped splits a string on whitespace boundaries, but preserves +// whitespace that is escaped with a backslash. This allows for values +// containing spaces to be represented in the policy file. +func fieldsEscaped(s string) []string { + var currentField strings.Builder + escaped := false + fields := []string{} + + for _, r := range s { + if escaped { // This will write the next character (including if it's an escape character or space) + // If we're in escaped mode, add the character regardless of what it is + currentField.WriteRune(r) + escaped = false + continue + } + + if r == '\\' { + // Enter escaped mode for the next character + escaped = true + continue + } + + if unicode.IsSpace(r) { + // We found a space and we're not in escaped mode, so this is a field boundary + if currentField.Len() > 0 { + fields = append(fields, currentField.String()) + currentField.Reset() + } + } else { + // Not a space, add to current field + currentField.WriteRune(r) + } + } + + // Add the last field if there is one + if currentField.Len() > 0 { + fields = append(fields, currentField.String()) + } + + return fields +} + +// writeEscaped takes an array of strings and returns a single string with each +// element separated by a space. Any spaces or backslashes within the input strings +// are escaped with a backslash to preserve them when parsing with fieldsEscaped. +func writeEscaped(fields []string) string { + var result strings.Builder + + for i, field := range fields { + if i > 0 { + result.WriteRune(' ') + } + + for _, r := range field { + // Escape backslashes and spaces + if r == '\\' || unicode.IsSpace(r) { + result.WriteRune('\\') + } + result.WriteRune(r) + } + } + + return result.String() +} type Table struct { rows [][]string @@ -30,7 +98,8 @@ func NewTable(content []byte) *Table { if row == "" { continue } - columns := strings.Fields(row) + // Parse the row using fieldsEscaped to handle escaped spaces and backslashes + columns := fieldsEscaped(row) table = append(table, columns) } return &Table{rows: table} @@ -51,7 +120,7 @@ func (t *Table) AddRow(row ...string) { func (t Table) ToString() string { var sb strings.Builder for _, row := range t.rows { - sb.WriteString(strings.Join(row, " ") + "\n") + sb.WriteString(writeEscaped(row) + "\n") } return sb.String() } diff --git a/policy/files/table_test.go b/policy/files/table_test.go index 963cb2bb..bae7145a 100644 --- a/policy/files/table_test.go +++ b/policy/files/table_test.go @@ -67,3 +67,165 @@ https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0 096c }) } } + +func TestFieldsEscaped(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "simple space-separated words", + input: "hello world test", + expected: []string{"hello", "world", "test"}, + }, + { + name: "escaped spaces", + input: `hello\ world test\ case`, + expected: []string{"hello world", "test case"}, + }, + { + name: "escaped backslashes", + input: `hello\\world test\\case`, + expected: []string{"hello\\world", "test\\case"}, + }, + { + name: "mixed escapes", + input: `hello\ world\\test case\\\ final`, + expected: []string{"hello world\\test", "case\\ final"}, + }, + { + name: "multiple consecutive spaces", + input: "hello world test", + expected: []string{"hello", "world", "test"}, + }, + { + name: "trailing escape", + input: `hello world\`, + expected: []string{"hello", "world"}, + }, + { + name: "disappearing escapes", + input: `\a\b\c\d\e\f\ghi\\\\\\jkl \mno\pqr\`, + expected: []string{"abcdefghi\\\\\\jkl", "mnopqr"}, + }, + { + name: "escaped special characters", + input: `hello\#world test\$case`, + expected: []string{"hello#world", "test$case"}, + }, + { + name: "multiple escaped spaces", + input: `hello\ \ \ world`, + expected: []string{"hello world"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fieldsEscaped(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestWriteEscaped(t *testing.T) { + tests := []struct { + name string + input []string + expected string + }{ + { + name: "empty slice", + input: []string{}, + expected: "", + }, + { + name: "single word", + input: []string{"hello"}, + expected: "hello", + }, + { + name: "simple words", + input: []string{"hello", "world", "test"}, + expected: "hello world test", + }, + { + name: "words with spaces", + input: []string{"hello world", "test case"}, + expected: `hello\ world test\ case`, + }, + { + name: "words with backslashes", + input: []string{"hello\\world", "test\\case"}, + expected: `hello\\world test\\case`, + }, + { + name: "mixed special characters", + input: []string{"hello world\\test", "case\\ final"}, + expected: `hello\ world\\test case\\\ final`, + }, + { + name: "multiple spaces", + input: []string{"hello world", "test"}, + expected: `hello\ \ \ world test`, + }, + { + name: "special characters", + input: []string{"hello#world", "test$case"}, + expected: "hello#world test$case", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := writeEscaped(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestRoundTrip tests that writeEscaped and fieldsEscaped work correctly together +func TestRoundTrip(t *testing.T) { + tests := []struct { + name string + input []string + }{ + { + name: "empty slice", + input: []string{}, + }, + { + name: "simple words", + input: []string{"hello", "world", "test"}, + }, + { + name: "words with spaces", + input: []string{"hello world", "test case", "final test"}, + }, + { + name: "words with backslashes", + input: []string{"hello\\world", "test\\case", "final\\test"}, + }, + { + name: "mixed content", + input: []string{"hello world\\test", "case\\ final", "test\\case space"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Convert to string and back to fields + escaped := writeEscaped(tt.input) + result := fieldsEscaped(escaped) + + // Check if the round trip preserves the original input + assert.Equal(t, tt.input, result) + }) + } +} diff --git a/policy/policyloader_test.go b/policy/policyloader_test.go index d5af822d..28ca31e3 100644 --- a/policy/policyloader_test.go +++ b/policy/policyloader_test.go @@ -345,6 +345,32 @@ func TestLoadSystemDefaultPolicy_Success(t *testing.T) { require.Equal(t, testPolicy, gotPolicy) } +func TestLoadSystemDefaultPolicyWithSpaces_Success(t *testing.T) { + t.Parallel() + + mockUserLookup := &MockUserLookup{User: ValidUser} + policyLoader := NewTestSystemPolicyLoader(afero.NewMemMapFs(), mockUserLookup) + mockFs := policyLoader.FileLoader.Fs + // Create policy file at default path with valid file + testPolicy := &policy.Policy{ + Users: []policy.User{ + { + IdentityAttribute: "oidc:groups:group with space", + Principals: []string{"test"}, + Issuer: "https://example.com", + }, + }, + } + testPolicyFile, err := testPolicy.ToTable() + require.NoError(t, err) + err = afero.WriteFile(mockFs, policy.SystemDefaultPolicyPath, testPolicyFile, 0640) + require.NoError(t, err) + gotPolicy, _, err := policyLoader.LoadSystemPolicy() + + require.NoError(t, err) + require.Equal(t, testPolicy, gotPolicy) +} + func TestDump_Success(t *testing.T) { // Test that Dump writes the policy to the mock filesystem when there are no // errors