diff --git a/internal/login/login.go b/internal/login/login.go index e705dfcb..44405321 100644 --- a/internal/login/login.go +++ b/internal/login/login.go @@ -232,12 +232,13 @@ func generateState() (string, error) { } // check for Windows Subsystem for Linux -func isWsl() bool { +func isWsl(sc util.Syscall) bool { // the common factor between WSL distros is the Microsoft-specific kernel version, so we check for that // SUSE, WSLv1: 4.4.0-19041-Microsoft // Ubuntu, WSLv2: 4.19.128-microsoft-standard + const wslIdentifier = "microsoft" var uname syscall.Utsname - if err := syscall.Uname(&uname); err == nil { + if err := sc.Uname(&uname); err == nil { var kernel []byte for _, b := range uname.Release { if b == 0 { @@ -245,22 +246,23 @@ func isWsl() bool { } kernel = append(kernel, byte(b)) } - return strings.Contains(strings.ToLower(string(kernel)), "microsoft") + return strings.Contains(strings.ToLower(string(kernel)), wslIdentifier) } return false } func openBrowser(url string) error { + const rundllParameters = "url.dll,FileProtocolHandler" var err error switch runtime.GOOS { case "linux": - if isWsl() { - err = exec.Command("rundll32.exe", "url.dll,FileProtocolHandler", url).Start() + if isWsl(util.DefaultSyscall) { + err = exec.Command("rundll32.exe", rundllParameters, url).Start() } else { err = exec.Command("xdg-open", url).Start() } case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + err = exec.Command("rundll32", rundllParameters, url).Start() case "darwin": err = exec.Command("open", url).Start() default: diff --git a/internal/login/login_test.go b/internal/login/login_test.go index 9a2dbd47..66255179 100644 --- a/internal/login/login_test.go +++ b/internal/login/login_test.go @@ -4,13 +4,16 @@ package login import ( "encoding/json" + "errors" "fmt" "net/http" "net/http/httptest" + "syscall" "testing" "time" "github.com/spf13/viper" + "github.com/stretchr/testify/assert" "github.com/twitchdev/twitch-cli/internal/util" ) @@ -157,3 +160,35 @@ func TestUserAuthServer(t *testing.T) { a.Equal(state, ur.State, "State mismatch") a.Equal(code, ur.Code, "Code mismatch") } + +func TestIsWsl(t *testing.T) { + a := assert.New(t) + + var ( + ubuntu20Wsl2 = [65]int8{52, 46, 49, 57, 46, 49, 50, 56, 45, 109, 105, 99, 114, 111, 115, 111, 102, 116, 45, 115, 116, 97, 110, 100, 97, 114, 100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + archReal = [65]int8{53, 46, 49, 49, 46, 49, 49, 45, 97, 114, 99, 104, 49, 45, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + ) + + result := isWsl(util.Syscall{ + Uname: func(buf *syscall.Utsname) (err error) { + buf.Release = ubuntu20Wsl2 + return nil + }, + }) + a.True(result) + + result = isWsl(util.Syscall{ + Uname: func(buf *syscall.Utsname) (err error) { + buf.Release = archReal + return nil + }, + }) + a.False(result) + + result = isWsl(util.Syscall{ + Uname: func(buf *syscall.Utsname) (err error) { + return errors.New("mocked error") + }, + }) + a.False(result) +} diff --git a/internal/util/syscall.go b/internal/util/syscall.go new file mode 100644 index 00000000..ac04cf56 --- /dev/null +++ b/internal/util/syscall.go @@ -0,0 +1,12 @@ +package util + +import "syscall" + +// Syscall wraps syscalls used in the application for unit testing purposes +type Syscall struct { + Uname func(buf *syscall.Utsname) (err error) +} + +var DefaultSyscall = Syscall{ + Uname: syscall.Uname, +}