Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ require (
github.com/thediveo/enumflag/v2 v2.0.7
github.com/zitadel/oidc/v3 v3.41.0
golang.org/x/crypto v0.40.0
golang.org/x/mod v0.26.0
golang.org/x/term v0.33.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
Expand Down
15 changes: 8 additions & 7 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import (
"github.com/spf13/cobra"
"github.com/spf13/cobra/doc"
"github.com/thediveo/enumflag/v2"
"golang.org/x/mod/semver"
"golang.org/x/term"
)

Expand Down Expand Up @@ -450,7 +451,6 @@ func checkOpenSSHVersion() {
func getOpenSSHVersion() string {
// OS-specific package manager queries
osType := detectOS()
log.Printf("Attempting OS-specific version detection for: %s", osType)

switch osType {
case OSTypeRHEL:
Expand Down Expand Up @@ -503,16 +503,16 @@ func getOpenSSHVersion() string {
return ""
}

func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) {
func isOpenSSHVersion8Dot1OrGreater(opensshVersionStr string) (bool, error) {
// To handle versions like 9.9p1; we only need the initial numeric part for the comparison
re, err := regexp.Compile(`^(\d+(?:\.\d+)*).*`)
if err != nil {
fmt.Println("Error compiling regex:", err)
return false, err
}

opensshVersion = strings.TrimPrefix(
strings.Split(opensshVersion, ", ")[0],
opensshVersion := strings.TrimPrefix(
strings.Split(opensshVersionStr, ", ")[0],
"OpenSSH_",
)

Expand All @@ -523,9 +523,10 @@ func isOpenSSHVersion8Dot1OrGreater(opensshVersion string) (bool, error) {
return false, errors.New("invalid OpenSSH version")
}

version := matches[1]

if version >= "8.1" {
version := "v" + matches[1] // semver requires that version strings start with 'v'
// OpenSSH doesn't use semantic versioning, but does use major.minor which after stripping the patch version can be compared using semver
if semver.Compare(version, "v8.1.0") >= 0 {
// if version is greater than or equal to v8.1.0
return true, nil
}

Expand Down
89 changes: 56 additions & 33 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,105 +17,127 @@
package main

import (
"errors"
"io"
"os"
"strconv"
"strings"
"testing"

"github.com/stretchr/testify/require"
)

func TestIsOpenSSHVersion8Dot1OrGreater(t *testing.T) {
t.Parallel()

tests := []struct {
name string
input string
wantIsGreater bool
wantErr error
wantErr string
}{
{
name: "Exact 8.1",
input: "OpenSSH_8.1",
wantIsGreater: true,
wantErr: nil,
wantErr: "",
},
{
name: "Above 8.1 (8.4)",
input: "OpenSSH_8.4",
wantIsGreater: true,
wantErr: nil,
wantErr: "",
},
{
name: "Regression test for 10.0 bug",
input: "OpenSSH_10.0",
wantIsGreater: true,
wantErr: "",
},
{
name: "Above 8.1 with patch (9.9p1)",
input: "OpenSSH_9.9p1",
wantIsGreater: true,
wantErr: nil,
wantErr: "",
},
{
name: "Below 8.1 (7.9)",
input: "OpenSSH_7.9",
wantIsGreater: false,
wantErr: nil,
wantErr: "",
},
{
name: "Multiple dotted version above 8.1 (8.1.2)",
input: "OpenSSH_8.1.2",
wantIsGreater: true,
wantErr: nil,
wantErr: "",
},
{
name: "Multiple dotted version below 8.1 (7.10.3)",
input: "OpenSSH_7.10.3",
wantIsGreater: false,
wantErr: nil,
wantErr: "",
},
{
name: "Malformed version string",
input: "OpenSSH_, something not right",
wantIsGreater: false,
wantErr: errors.New("invalid OpenSSH version"),
wantErr: "invalid OpenSSH version",
},
{
name: "No OpenSSH prefix at all",
input: "Completely invalid input",
wantIsGreater: false,
wantErr: errors.New("invalid OpenSSH version"),
wantErr: "invalid OpenSSH version",
},
{
name: "Includes trailing info (8.2, Raspbian-1)",
input: "OpenSSH_8.2, Raspbian-1",
wantIsGreater: true,
wantErr: nil,
wantErr: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIsGreater, gotErr := isOpenSSHVersion8Dot1OrGreater(tt.input)
RunOpenSSHVersionTest(t, tt.name, tt.input, tt.wantIsGreater, tt.wantErr)
})
}
}

if gotIsGreater != tt.wantIsGreater {
t.Errorf(
"isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v",
tt.input,
gotIsGreater,
tt.wantIsGreater,
)
}
func TestOpenSSHVersion8Dot1OrGreaterViaBruteForce(t *testing.T) {
t.Parallel()
for major := 7; major <= 15; major++ {
for minor := 0; minor <= 101; minor++ {
versionStr := "OpenSSH_" + strconv.Itoa(major) + "." + strconv.Itoa(minor)
testName := "Testing openssh version " + versionStr

if (gotErr != nil) != (tt.wantErr != nil) {
t.Errorf(
"isOpenSSHVersion8Dot1OrGreater(%q) error = %v; want %v",
tt.input,
gotErr,
tt.wantErr,
)
} else if gotErr != nil && tt.wantErr != nil {
if gotErr.Error() != tt.wantErr.Error() {
t.Errorf("Unexpected error message. got %q; want %q",
gotErr.Error(), tt.wantErr.Error())
}
expectedIsGreater := true
if major < 8 || (major == 8 && minor < 1) {
expectedIsGreater = false
}
})
RunOpenSSHVersionTest(t, testName, versionStr, expectedIsGreater, "")
}
}
}

func RunOpenSSHVersionTest(t *testing.T, testName string, versionOutput string, expectedIsGreater bool, expectedErr string) {
gotIsGreater, gotErr := isOpenSSHVersion8Dot1OrGreater(versionOutput)

require.Equal(t, expectedIsGreater, gotIsGreater,
"Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) got %v; want %v",
testName, versionOutput, gotIsGreater, expectedIsGreater)

if expectedErr != "" {
require.Error(t, gotErr,
"Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) expected error = %s",
testName, versionOutput, expectedErr, gotErr)
require.ErrorContains(t, gotErr, expectedErr,
"Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) expected error = %s, got %v",
testName, versionOutput, expectedErr, gotErr)
} else {
require.NoError(t, gotErr,
"Test %q failed: isOpenSSHVersion8Dot1OrGreater(%q) unexpected error = %v",
testName, versionOutput, gotErr)
}
}

Expand Down Expand Up @@ -149,6 +171,7 @@ func RunCliAndCaptureResult(t *testing.T, args []string) (string, int) {
}

func TestRun(t *testing.T) {
t.Parallel()
tests := []struct {
name string
args []string
Expand Down
Loading