From 43e00239a14e1068e1bd43a5e757e5d8bf53c5ed Mon Sep 17 00:00:00 2001 From: Ethan Heilman Date: Fri, 9 Jan 2026 13:38:18 -0500 Subject: [PATCH 1/2] Fix openssh version detection bug --- go.mod | 1 + go.sum | 2 ++ main.go | 15 ++++++----- main_test.go | 72 +++++++++++++++++++++++++++++++++++----------------- 4 files changed, 60 insertions(+), 30 deletions(-) diff --git a/go.mod b/go.mod index 6cbfc34a..a532f976 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/thediveo/enumflag/v2 v2.0.7 github.com/zitadel/oidc/v3 v3.41.0 golang.org/x/crypto v0.40.0 + golang.org/x/mod v0.26.0 golang.org/x/term v0.33.0 ) diff --git a/go.sum b/go.sum index d30db961..7ca356e2 100644 --- a/go.sum +++ b/go.sum @@ -237,6 +237,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= diff --git a/main.go b/main.go index 7e0d0240..9dd67a21 100644 --- a/main.go +++ b/main.go @@ -40,6 +40,7 @@ import ( "github.com/spf13/cobra" "github.com/spf13/cobra/doc" "github.com/thediveo/enumflag/v2" + "golang.org/x/mod/semver" "golang.org/x/term" ) @@ -450,7 +451,6 @@ func checkOpenSSHVersion() { func getOpenSSHVersion() string { // OS-specific package manager queries osType := detectOS() - log.Printf("Attempting OS-specific version detection for: %s", osType) switch osType { case OSTypeRHEL: @@ -503,7 +503,7 @@ func getOpenSSHVersion() string { return "" } -func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) { +func isOpenSSHVersion8Dot1OrGreater(opensshVersionStr string) (bool, error) { // To handle versions like 9.9p1; we only need the initial numeric part for the comparison re, err := regexp.Compile(`^(\d+(?:\.\d+)*).*`) if err != nil { @@ -511,8 +511,8 @@ func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) { return false, err } - opensshVersion = strings.TrimPrefix( - strings.Split(opensshVersion, ", ")[0], + opensshVersion := strings.TrimPrefix( + strings.Split(opensshVersionStr, ", ")[0], "OpenSSH_", ) @@ -523,9 +523,10 @@ func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) { return false, errors.New("invalid OpenSSH version") } - version := matches[1] - - if version >= "8.1" { + version := "v" + matches[1] // semver requires that version strings start with 'v' + // OpenSSH doesn't use semantic versioning, but does use major.minor which after striping the patch version can be compared using semver + if semver.Compare(version, "v8.1.0") >= 0 { + // if version is greater than or equal to v8.1.0 return true, nil } diff --git a/main_test.go b/main_test.go index 4df123bc..bc2671dc 100644 --- a/main_test.go +++ b/main_test.go @@ -20,6 +20,7 @@ import ( "errors" "io" "os" + "strconv" "strings" "testing" @@ -27,6 +28,8 @@ import ( ) func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) { + t.Parallel() + tests := []struct { name string input string @@ -45,6 +48,12 @@ func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) { wantIsGreater: true, wantErr: nil, }, + { + name: "Regression test for 10.0 bug", + input: "OpenSSH_10.0", + wantIsGreater: true, + wantErr: nil, + }, { name: "Above 8.1 with patch (9.9p1)", input: "OpenSSH_9.9p1", @@ -91,31 +100,47 @@ func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotIsGreater, gotErr := isOpenSSHVersion8Dot1OrGreater(tt.input) + RunOpenSSHVersionTest(t, tt.name, tt.input, tt.wantIsGreater, tt.wantErr) + }) + } +} - if gotIsGreater != tt.wantIsGreater { - t.Errorf( - "isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v", - tt.input, - gotIsGreater, - tt.wantIsGreater, - ) - } +func TestOpenSSHVersion8Dot1OrGreaterViaBruteForce(t *testing.T) { + t.Parallel() + for major := 9; major <= 15; major++ { + for minor := 0; minor < 100; minor++ { + versionStr := "OpenSSH_" + strconv.Itoa(major) + "." + strconv.Itoa(minor) + expectedIsGreater := true + testName := "Testing openssh version " + versionStr + RunOpenSSHVersionTest(t, testName, versionStr, expectedIsGreater, nil) + } + } +} - if (gotErr != nil) != (tt.wantErr != nil) { - t.Errorf( - "isOpenSSHVersion8Dot1OrGreater(%q) error = %v; want %v", - tt.input, - gotErr, - tt.wantErr, - ) - } else if gotErr != nil && tt.wantErr != nil { - if gotErr.Error() != tt.wantErr.Error() { - t.Errorf("Unexpected error message. got %q; want %q", - gotErr.Error(), tt.wantErr.Error()) - } - } - }) +func RunOpenSSHVersionTest(t *testing.T, testName string, versionOutput string, expectedIsGreater bool, expectedErr error) { + gotIsGreater, gotErr := isOpenSSHVersion8Dot1OrGreater(versionOutput) + + if gotIsGreater != expectedIsGreater { + t.Errorf( + "isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v", + versionOutput, + gotIsGreater, + expectedIsGreater, + ) + } + + if (gotErr != nil) != (expectedErr != nil) { + t.Errorf( + "isOpenSSHVersion8Dot1OrGreater(%q) error = %v; want %v", + versionOutput, + gotErr, + expectedErr, + ) + } else if gotErr != nil && expectedErr != nil { + if gotErr.Error() != expectedErr.Error() { + t.Errorf("Unexpected error message. got %q; want %q", + gotErr.Error(), expectedErr.Error()) + } } } @@ -149,6 +174,7 @@ func RunCliAndCaptureResult(t *testing.T, args []string) (string, int) { } func TestRun(t *testing.T) { + t.Parallel() tests := []struct { name string args []string From 397dc8eae6efa0ba7be82d74d65d3ad671084d6e Mon Sep 17 00:00:00 2001 From: Ethan Heilman Date: Fri, 9 Jan 2026 15:40:34 -0500 Subject: [PATCH 2/2] Fixes issues found in machine review --- main.go | 2 +- main_test.go | 71 +++++++++++++++++++++++++--------------------------- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/main.go b/main.go index 9dd67a21..e4c443c5 100644 --- a/main.go +++ b/main.go @@ -524,7 +524,7 @@ func isOpenSSHVersion8Dot1OrGreater(opensshVersionStr string) (bool, error) { } version := "v" + matches[1] // semver requires that version strings start with 'v' - // OpenSSH doesn't use semantic versioning, but does use major.minor which after striping the patch version can be compared using semver + // OpenSSH doesn't use semantic versioning, but does use major.minor which after stripping the patch version can be compared using semver if semver.Compare(version, "v8.1.0") >= 0 { // if version is greater than or equal to v8.1.0 return true, nil diff --git a/main_test.go b/main_test.go index bc2671dc..6757a325 100644 --- a/main_test.go +++ b/main_test.go @@ -17,7 +17,6 @@ package main import ( - "errors" "io" "os" "strconv" @@ -34,67 +33,67 @@ func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) { name string input string wantIsGreater bool - wantErr error + wantErr string }{ { name: "Exact 8.1", input: "OpenSSH_8.1", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, { name: "Above 8.1 (8.4)", input: "OpenSSH_8.4", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, { name: "Regression test for 10.0 bug", input: "OpenSSH_10.0", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, { name: "Above 8.1 with patch (9.9p1)", input: "OpenSSH_9.9p1", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, { name: "Below 8.1 (7.9)", input: "OpenSSH_7.9", wantIsGreater: false, - wantErr: nil, + wantErr: "", }, { name: "Multiple dotted version above 8.1 (8.1.2)", input: "OpenSSH_8.1.2", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, { name: "Multiple dotted version below 8.1 (7.10.3)", input: "OpenSSH_7.10.3", wantIsGreater: false, - wantErr: nil, + wantErr: "", }, { name: "Malformed version string", input: "OpenSSH_, something not right", wantIsGreater: false, - wantErr: errors.New("invalid OpenSSH version"), + wantErr: "invalid OpenSSH version", }, { name: "No OpenSSH prefix at all", input: "Completely invalid input", wantIsGreater: false, - wantErr: errors.New("invalid OpenSSH version"), + wantErr: "invalid OpenSSH version", }, { name: "Includes trailing info (8.2, Raspbian-1)", input: "OpenSSH_8.2, Raspbian-1", wantIsGreater: true, - wantErr: nil, + wantErr: "", }, } @@ -107,40 +106,38 @@ func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) { func TestOpenSSHVersion8Dot1OrGreaterViaBruteForce(t *testing.T) { t.Parallel() - for major := 9; major <= 15; major++ { - for minor := 0; minor < 100; minor++ { + for major := 7; major <= 15; major++ { + for minor := 0; minor <= 101; minor++ { versionStr := "OpenSSH_" + strconv.Itoa(major) + "." + strconv.Itoa(minor) - expectedIsGreater := true testName := "Testing openssh version " + versionStr - RunOpenSSHVersionTest(t, testName, versionStr, expectedIsGreater, nil) + + expectedIsGreater := true + if major < 8 || (major == 8 && minor < 1) { + expectedIsGreater = false + } + RunOpenSSHVersionTest(t, testName, versionStr, expectedIsGreater, "") } } } -func RunOpenSSHVersionTest(t *testing.T, testName string, versionOutput string, expectedIsGreater bool, expectedErr error) { +func RunOpenSSHVersionTest(t *testing.T, testName string, versionOutput string, expectedIsGreater bool, expectedErr string) { gotIsGreater, gotErr := isOpenSSHVersion8Dot1OrGreater(versionOutput) - if gotIsGreater != expectedIsGreater { - t.Errorf( - "isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v", - versionOutput, - gotIsGreater, - expectedIsGreater, - ) - } + require.Equal(t, expectedIsGreater, gotIsGreater, + "Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v", + testName, versionOutput, gotIsGreater, expectedIsGreater) - if (gotErr != nil) != (expectedErr != nil) { - t.Errorf( - "isOpenSSHVersion8Dot1OrGreater(%q) error = %v; want %v", - versionOutput, - gotErr, - expectedErr, - ) - } else if gotErr != nil && expectedErr != nil { - if gotErr.Error() != expectedErr.Error() { - t.Errorf("Unexpected error message. got %q; want %q", - gotErr.Error(), expectedErr.Error()) - } + if expectedErr != "" { + require.Error(t, gotErr, + "Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) expected error = %s", + testName, versionOutput, expectedErr, gotErr) + require.ErrorContains(t, gotErr, expectedErr, + "Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) expected error = %s, got %v", + testName, versionOutput, expectedErr, gotErr) + } else { + require.NoError(t, gotErr, + "Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) unexpected error = %v", + testName, versionOutput, gotErr) } }