From 78963ae54ad554013e09ba9b995613bdc3f4a181 Mon Sep 17 00:00:00 2001 From: Yonatan Striem-Amit Date: Tue, 8 Apr 2025 18:49:01 -0400 Subject: [PATCH 1/6] support escaping spaces in tables --- policy/files/table.go | 76 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 3 deletions(-) diff --git a/policy/files/table.go b/policy/files/table.go index a9a05148..2c35201d 100644 --- a/policy/files/table.go +++ b/policy/files/table.go @@ -16,7 +16,75 @@ package files -import "strings" +import ( + "strings" + "unicode" +) + +// fieldsEscaped splits a string on whitespace boundaries, but preserves +// whitespace that is escaped with a backslash. This allows for values +// containing spaces to be represented in the policy file. +func fieldsEscaped(s string) []string { + var fields []string + var currentField strings.Builder + escaped := false + + for _, r := range s { + if escaped { // This will write the next character (including if it's an escape character or space) + // If we're in escaped mode, add the character regardless of what it is + currentField.WriteRune(r) + escaped = false + continue + } + + if r == '\\' { + // Enter escaped mode for the next character + escaped = true + continue + } + + if unicode.IsSpace(r) { + // We found a space and we're not in escaped mode, so this is a field boundary + if currentField.Len() > 0 { + fields = append(fields, currentField.String()) + currentField.Reset() + } + } else { + // Not a space, add to current field + currentField.WriteRune(r) + } + } + + // Add the last field if there is one + if currentField.Len() > 0 { + fields = append(fields, currentField.String()) + } + + return fields +} + +// writeEscaped takes an array of strings and returns a single string with each +// element separated by a space. Any spaces or backslashes within the input strings +// are escaped with a backslash to preserve them when parsing with fieldsEscaped. +func writeEscaped(fields []string) string { + var result strings.Builder + + for i, field := range fields { + if i > 0 { + result.WriteRune(' ') + } + + for _, r := range field { + // Escape backslashes and spaces + if r == '\\' || unicode.IsSpace(r) { + result.WriteRune('\\') + } + result.WriteRune(r) + } + } + + return result.String() +} type Table struct { rows [][]string @@ -30,7 +98,8 @@ func NewTable(content []byte) *Table { if row == "" { continue } - columns := strings.Fields(row) + // Parse the row using fieldsEscaped to handle escaped spaces and backslashes + columns := fieldsEscaped(row) table = append(table, columns) } return &Table{rows: table} @@ -51,7 +120,8 @@ func (t *Table) AddRow(row ...string) { func (t Table) ToString() string { var sb strings.Builder for _, row := range t.rows { - sb.WriteString(strings.Join(row, " ") + "\n") + // URL-encode each column before writing + sb.WriteString(writeEscaped(row) + "\n") } return sb.String() } From c11838895af2853ae2c579ced43d1fa001d4fc7a Mon Sep 17 00:00:00 2001 From: Yonatan Striem-Amit Date: Tue, 8 Apr 2025 19:49:58 -0400 Subject: [PATCH 2/6] Support additional scopes in custom providers --- commands/login.go | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/commands/login.go b/commands/login.go index 75813dd1..3b6e12c5 100644 --- a/commands/login.go +++ b/commands/login.go @@ -91,8 +91,8 @@ func (l *LoginCmd) Run(ctx context.Context) error { var provider providers.OpenIdProvider if l.providerArg != "" { parts := strings.Split(l.providerArg, ",") - if len(parts) != 2 && len(parts) != 3 { - return fmt.Errorf("invalid provider argument format. Expected format , or ,, got (%s)", l.providerArg) + if len(parts) < 2 { + return fmt.Errorf("invalid provider argument format. Expected format , or ,, or ,,, got (%s)\n", l.providerArg) } issuerArg := parts[0] clientIDArg := parts[1] @@ -146,10 +146,18 @@ func (l *LoginCmd) Run(ctx context.Context) error { opts.GQSign = false opts.OpenBrowser = openBrowser - if len(parts) == 3 { + if len(parts) >= 3 { opts.ClientSecret = parts[2] } + if len(parts) >= 4 { + // Add all additional scopes from parts[3:] to opts.Scopes + additionalScopes := parts[3:] + if len(additionalScopes) > 0 { + opts.Scopes = append(opts.Scopes, additionalScopes...) + } + } + provider = providers.NewGoogleOpWithOptions(opts) } } else if l.providerFromLdFlags != nil { From 852e9cb3398f8606ae23854cf17ff8f313b39568 Mon Sep 17 00:00:00 2001 From: Yonatan Striem-Amit Date: Tue, 8 Apr 2025 20:02:24 -0400 Subject: [PATCH 3/6] Fixed test --- main_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main_test.go b/main_test.go index ec19f406..8e9311e6 100644 --- a/main_test.go +++ b/main_test.go @@ -218,7 +218,7 @@ func TestRun(t *testing.T) { { name: "Login command with provider bad provider value", args: []string{"opkssh", "login", "--provider=badvalue"}, - wantOutput: "Error: invalid provider argument format. Expected format , or ,, got (badvalue)", + wantOutput: "invalid provider argument format. Expected format , or ,, or ,,, got (badvalue)", wantExit: 1, }, { From 1a8dc1da043cd2f1a79fec0f184c370a7c5aa9e6 Mon Sep 17 00:00:00 2001 From: Yonatan Striem-Amit Date: Tue, 15 Apr 2025 18:27:11 -0400 Subject: [PATCH 4/6] Add unit tests --- policy/files/table.go | 2 +- policy/files/table_test.go | 157 +++++++++++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+), 1 deletion(-) diff --git a/policy/files/table.go b/policy/files/table.go index 2c35201d..827b432e 100644 --- a/policy/files/table.go +++ b/policy/files/table.go @@ -25,9 +25,9 @@ import ( // whitespace that is escaped with a backslash. This allows for values // containing spaces to be represented in the policy file. func fieldsEscaped(s string) []string { - var fields []string var currentField strings.Builder escaped := false + fields := []string{} for _, r := range s { if escaped { // This will write the next character (including if it's an escape character or space) diff --git a/policy/files/table_test.go b/policy/files/table_test.go index 963cb2bb..b73b4804 100644 --- a/policy/files/table_test.go +++ b/policy/files/table_test.go @@ -67,3 +67,160 @@ https://login.microsoftonline.com/9188040d-6c67-4c5b-b112-36a304b66dad/v2.0 096c }) } } + +func TestFieldsEscaped(t *testing.T) { + tests := []struct { + name string + input string + expected []string + }{ + { + name: "empty string", + input: "", + expected: []string{}, + }, + { + name: "simple space-separated words", + input: "hello world test", + expected: []string{"hello", "world", "test"}, + }, + { + name: "escaped spaces", + input: `hello\ world test\ case`, + expected: []string{"hello world", "test case"}, + }, + { + name: "escaped backslashes", + input: `hello\\world test\\case`, + expected: []string{"hello\\world", "test\\case"}, + }, + { + name: "mixed escapes", + input: `hello\ world\\test case\\\ final`, + expected: []string{"hello world\\test", "case\\ final"}, + }, + { + name: "multiple consecutive spaces", + input: "hello world test", + expected: []string{"hello", "world", "test"}, + }, + { + name: "trailing escape", + input: `hello world\`, + expected: []string{"hello", "world"}, + }, + { + name: "escaped special characters", + input: `hello\#world test\$case`, + expected: []string{"hello#world", "test$case"}, + }, + { + name: "multiple escaped spaces", + input: `hello\ \ \ world`, + expected: []string{"hello world"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := fieldsEscaped(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestWriteEscaped(t *testing.T) { + tests := []struct { + name string + input []string + expected string + }{ + { + name: "empty slice", + input: []string{}, + expected: "", + }, + { + name: "single word", + input: []string{"hello"}, + expected: "hello", + }, + { + name: "simple words", + input: []string{"hello", "world", "test"}, + expected: "hello world test", + }, + { + name: "words with spaces", + input: []string{"hello world", "test case"}, + expected: `hello\ world test\ case`, + }, + { + name: "words with backslashes", + input: []string{"hello\\world", "test\\case"}, + expected: `hello\\world test\\case`, + }, + { + name: "mixed special characters", + input: []string{"hello world\\test", "case\\ final"}, + expected: `hello\ world\\test case\\\ final`, + }, + { + name: "multiple spaces", + input: []string{"hello world", "test"}, + expected: `hello\ \ \ world test`, + }, + { + name: "special characters", + input: []string{"hello#world", "test$case"}, + expected: "hello#world test$case", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := writeEscaped(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestRoundTrip tests that writeEscaped and fieldsEscaped work correctly together +func TestRoundTrip(t *testing.T) { + tests := []struct { + name string + input []string + }{ + { + name: "empty slice", + input: []string{}, + }, + { + name: "simple words", + input: []string{"hello", "world", "test"}, + }, + { + name: "words with spaces", + input: []string{"hello world", "test case", "final test"}, + }, + { + name: "words with backslashes", + input: []string{"hello\\world", "test\\case", "final\\test"}, + }, + { + name: "mixed content", + input: []string{"hello world\\test", "case\\ final", "test\\case space"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Convert to string and back to fields + escaped := writeEscaped(tt.input) + result := fieldsEscaped(escaped) + + // Check if the round trip preserves the original input + assert.Equal(t, tt.input, result) + }) + } +} From 11e834b044a7633bb71f5a867764e535adb03d45 Mon Sep 17 00:00:00 2001 From: Ethan Heilman Date: Thu, 17 Apr 2025 15:44:10 -0400 Subject: [PATCH 5/6] Adds unit tests --- policy/files/table_test.go | 5 +++++ policy/policyloader_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/policy/files/table_test.go b/policy/files/table_test.go index b73b4804..bae7145a 100644 --- a/policy/files/table_test.go +++ b/policy/files/table_test.go @@ -109,6 +109,11 @@ func TestFieldsEscaped(t *testing.T) { input: `hello world\`, expected: []string{"hello", "world"}, }, + { + name: "disappearing escapes", + input: `\a\b\c\d\e\f\ghi\\\\\\jkl \mno\pqr\`, + expected: []string{"abcdefghi\\\\\\jkl", "mnopqr"}, + }, { name: "escaped special characters", input: `hello\#world test\$case`, diff --git a/policy/policyloader_test.go b/policy/policyloader_test.go index d5af822d..28ca31e3 100644 --- a/policy/policyloader_test.go +++ b/policy/policyloader_test.go @@ -345,6 +345,32 @@ func TestLoadSystemDefaultPolicy_Success(t *testing.T) { require.Equal(t, testPolicy, gotPolicy) } +func TestLoadSystemDefaultPolicyWithSpaces_Success(t *testing.T) { + t.Parallel() + + mockUserLookup := &MockUserLookup{User: ValidUser} + policyLoader := NewTestSystemPolicyLoader(afero.NewMemMapFs(), mockUserLookup) + mockFs := policyLoader.FileLoader.Fs + // Create policy file at default path with valid file + testPolicy := &policy.Policy{ + Users: []policy.User{ + { + IdentityAttribute: "oidc:groups:group with space", + Principals: []string{"test"}, + Issuer: "https://example.com", + }, + }, + } + testPolicyFile, err := testPolicy.ToTable() + require.NoError(t, err) + err = afero.WriteFile(mockFs, policy.SystemDefaultPolicyPath, testPolicyFile, 0640) + require.NoError(t, err) + gotPolicy, _, err := policyLoader.LoadSystemPolicy() + + require.NoError(t, err) + require.Equal(t, testPolicy, gotPolicy) +} + func TestDump_Success(t *testing.T) { // Test that Dump writes the policy to the mock filesystem when there are no // errors From 6a58188ce6cbd87c1c0644202abd3a078d4d730e Mon Sep 17 00:00:00 2001 From: Yonatan Striem-Amit Date: Thu, 17 Apr 2025 16:23:31 -0400 Subject: [PATCH 6/6] Removed old irrelevant comment --- policy/files/table.go | 1 - 1 file changed, 1 deletion(-) diff --git a/policy/files/table.go b/policy/files/table.go index 827b432e..536d30e5 100644 --- a/policy/files/table.go +++ b/policy/files/table.go @@ -120,7 +120,6 @@ func (t *Table) AddRow(row ...string) { func (t Table) ToString() string { var sb strings.Builder for _, row := range t.rows { - // URL-encode each column before writing sb.WriteString(writeEscaped(row) + "\n") } return sb.String()