diff --git a/installer/pkg/validate/BUILD.bazel b/installer/pkg/validate/BUILD.bazel index 9df2dd22512..d03c33ca7cb 100644 --- a/installer/pkg/validate/BUILD.bazel +++ b/installer/pkg/validate/BUILD.bazel @@ -3,7 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_test( name = "go_default_test", size = "small", - srcs = ["validate_test.go"], + srcs = [ + "last_ip_test.go", + "validate_test.go", + ], embed = [":go_default_library"], deps = ["//pkg/asset/tls:go_default_library"], ) diff --git a/installer/pkg/validate/last_ip_test.go b/installer/pkg/validate/last_ip_test.go new file mode 100644 index 00000000000..d4c8610c19e --- /dev/null +++ b/installer/pkg/validate/last_ip_test.go @@ -0,0 +1,49 @@ +package validate + +import ( + "net" + "testing" +) + +func TestLastIP(t *testing.T) { + cases := []struct { + in net.IPNet + out net.IP + }{ + { + in: net.IPNet{ + IP: net.ParseIP("192.168.0.0").To4(), + Mask: net.CIDRMask(24, 32), + }, + out: net.ParseIP("192.168.0.255"), + }, + { + in: net.IPNet{ + IP: net.ParseIP("192.168.0.0").To4(), + Mask: net.CIDRMask(22, 32), + }, + out: net.ParseIP("192.168.3.255"), + }, + { + in: net.IPNet{ + IP: net.ParseIP("192.168.0.0").To4(), + Mask: net.CIDRMask(32, 32), + }, + out: net.ParseIP("192.168.0.0"), + }, + { + in: net.IPNet{ + IP: net.ParseIP("0.0.0.0").To4(), + Mask: net.CIDRMask(0, 32), + }, + out: net.ParseIP("255.255.255.255"), + }, + } + + var out net.IP + for i, c := range cases { + if out = lastIP(&c.in); out.String() != c.out.String() { + t.Errorf("test case %d: expected %s but got %s", i, c.out, out) + } + } +} diff --git a/installer/pkg/validate/validate_test.go b/installer/pkg/validate/validate_test.go index 8227be6bdb7..0c5e1d98147 100644 --- a/installer/pkg/validate/validate_test.go +++ b/installer/pkg/validate/validate_test.go @@ -1,15 +1,15 @@ -package validate +package validate_test import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" "io/ioutil" - "net" "os" "strings" "testing" + "github.com/openshift/installer/installer/pkg/validate" "github.com/openshift/installer/pkg/asset/tls" ) @@ -46,7 +46,7 @@ func TestNonEmpty(t *testing.T) { {".", ""}, {"日本語", ""}, } - runTests(t, "NonEmpty", NonEmpty, tests) + runTests(t, "NonEmpty", validate.NonEmpty, tests) } func TestInt(t *testing.T) { @@ -64,7 +64,7 @@ func TestInt(t *testing.T) { {"999999", ""}, {"-1", ""}, } - runTests(t, "Int", Int, tests) + runTests(t, "Int", validate.Int, tests) } func TestIntRange(t *testing.T) { @@ -88,7 +88,7 @@ func TestIntRange(t *testing.T) { } for _, test := range tests { - err := IntRange(test.in, test.min, test.max) + err := validate.IntRange(test.in, test.min, test.max) if (err == nil && test.expected != "") || (err != nil && err.Error() != test.expected) { t.Errorf("For IntRange(%q, %v, %v), expected %q, got %q", test.in, test.min, test.max, test.expected, err) } @@ -109,7 +109,7 @@ func TestIntOdd(t *testing.T) { {"1 abc", invalidIntMsg}, {"日本語", invalidIntMsg}, } - runTests(t, "IntOdd", IntOdd, tests) + runTests(t, "IntOdd", validate.IntOdd, tests) } func TestClusterName(t *testing.T) { @@ -145,7 +145,7 @@ func TestClusterName(t *testing.T) { {maxSizeSegment + ".abc", ""}, {maxSizeSegment + "a.abc", segmentLengthMsg}, } - runTests(t, "ClusterName", ClusterName, tests) + runTests(t, "ClusterName", validate.ClusterName, tests) } func TestAWSClusterName(t *testing.T) { @@ -174,7 +174,7 @@ func TestAWSClusterName(t *testing.T) { {"12345678901234567890123456789", lengthMsg}, {"A2345678901234567890123456789", lengthMsg}, } - runTests(t, "AWSClusterName", AWSClusterName, tests) + runTests(t, "AWSClusterName", validate.AWSClusterName, tests) } func TestMAC(t *testing.T) { @@ -189,7 +189,7 @@ func TestMAC(t *testing.T) { {"12:34:45:78:9X:YZ", invalidMsg}, {"12.34.45.78.9A.BC", invalidMsg}, } - runTests(t, "MAC", MAC, tests) + runTests(t, "MAC", validate.MAC, tests) } func TestIPv4(t *testing.T) { @@ -203,7 +203,7 @@ func TestIPv4(t *testing.T) { {"1.2.3.a", invalidIPMsg}, {"255.255.255.255", ""}, } - runTests(t, "IPv4", IPv4, tests) + runTests(t, "IPv4", validate.IPv4, tests) } func TestSubnetCIDR(t *testing.T) { @@ -234,7 +234,7 @@ func TestSubnetCIDR(t *testing.T) { {"255.255.255.255/1", ""}, {"255.255.255.255/32", ""}, } - runTests(t, "SubnetCIDR", SubnetCIDR, tests) + runTests(t, "SubnetCIDR", validate.SubnetCIDR, tests) } func TestAWSsubnetCIDR(t *testing.T) { @@ -250,7 +250,7 @@ func TestAWSsubnetCIDR(t *testing.T) { {"1.2.3.4/28", ""}, {"1.2.3.4/29", awsNetmaskSizeMsg}, } - runTests(t, "AWSSubnetCIDR", AWSSubnetCIDR, tests) + runTests(t, "AWSSubnetCIDR", validate.AWSSubnetCIDR, tests) } func TestDomainName(t *testing.T) { @@ -279,7 +279,7 @@ func TestDomainName(t *testing.T) { {".abc.com", invalidDomainMsg}, {".abc.com", invalidDomainMsg}, } - runTests(t, "DomainName", DomainName, tests) + runTests(t, "DomainName", validate.DomainName, tests) } func TestHost(t *testing.T) { @@ -308,7 +308,7 @@ func TestHost(t *testing.T) { {".abc.com", invalidHostMsg}, {".abc.com", invalidHostMsg}, } - runTests(t, "Host", Host, tests) + runTests(t, "Host", validate.Host, tests) } func TestPort(t *testing.T) { @@ -325,7 +325,7 @@ func TestPort(t *testing.T) { {"65535", ""}, {"65536", invalidPortMsg}, } - runTests(t, "Port", Port, tests) + runTests(t, "Port", validate.Port, tests) } func TestHostPort(t *testing.T) { @@ -345,7 +345,7 @@ func TestHostPort(t *testing.T) { {"1.2.3.4:abc", invalidPortMsg}, {"日本語:1234", invalidHostMsg}, } - runTests(t, "HostPort", HostPort, tests) + runTests(t, "HostPort", validate.HostPort, tests) } func TestEmail(t *testing.T) { @@ -369,7 +369,7 @@ func TestEmail(t *testing.T) { {"a@.com", invalidDomainMsg}, {"@abc.com", invalidMsg}, } - runTests(t, "Email", Email, tests) + runTests(t, "Email", validate.Email, tests) } func TestCertificate(t *testing.T) { @@ -426,7 +426,7 @@ func TestCertificate(t *testing.T) { {"-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", badPem}, {string(keyinPem), privateKeyMsg}, } - runTests(t, "Certificate", Certificate, tests) + runTests(t, "Certificate", validate.Certificate, tests) } func TestPrivateKey(t *testing.T) { @@ -463,7 +463,7 @@ func TestPrivateKey(t *testing.T) { {"-----BEGIN RSA PRIVATE KEY-----\na\n-----END RSA PRIVATE KEY-----\n-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", invalidMsg}, {"-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", invalidMsg}, } - runTests(t, "PrivateKey", PrivateKey, tests) + runTests(t, "PrivateKey", validate.PrivateKey, tests) } func TestOpenSSHPublicKey(t *testing.T) { @@ -489,7 +489,7 @@ func TestOpenSSHPublicKey(t *testing.T) { {"-----BEGIN CERTIFICATE-----abcd-----END CERTIFICATE-----", invalidMsg}, {"-----BEGIN RSA PRIVATE KEY-----\nabc\n-----END RSA PRIVATE KEY-----", privateKeyMsg}, } - runTests(t, "OpenSSHPublicKey", OpenSSHPublicKey, tests) + runTests(t, "OpenSSHPublicKey", validate.OpenSSHPublicKey, tests) } func TestCIDRsDontOverlap(t *testing.T) { @@ -526,7 +526,7 @@ func TestCIDRsDontOverlap(t *testing.T) { } for i, c := range cases { - if err := CIDRsDontOverlap(c.a, c.b); (err != nil) != c.err { + if err := validate.CIDRsDontOverlap(c.a, c.b); (err != nil) != c.err { no := "no" if c.err { no = "an" @@ -536,49 +536,6 @@ func TestCIDRsDontOverlap(t *testing.T) { } } -func TestLastIP(t *testing.T) { - cases := []struct { - in net.IPNet - out net.IP - }{ - { - in: net.IPNet{ - IP: net.ParseIP("192.168.0.0").To4(), - Mask: net.CIDRMask(24, 32), - }, - out: net.ParseIP("192.168.0.255"), - }, - { - in: net.IPNet{ - IP: net.ParseIP("192.168.0.0").To4(), - Mask: net.CIDRMask(22, 32), - }, - out: net.ParseIP("192.168.3.255"), - }, - { - in: net.IPNet{ - IP: net.ParseIP("192.168.0.0").To4(), - Mask: net.CIDRMask(32, 32), - }, - out: net.ParseIP("192.168.0.0"), - }, - { - in: net.IPNet{ - IP: net.ParseIP("0.0.0.0").To4(), - Mask: net.CIDRMask(0, 32), - }, - out: net.ParseIP("255.255.255.255"), - }, - } - - var out net.IP - for i, c := range cases { - if out = lastIP(&c.in); out.String() != c.out.String() { - t.Errorf("test case %d: expected %s but got %s", i, c.out, out) - } - } -} - func TestJSONFile(t *testing.T) { cases := []struct { buf []byte @@ -617,7 +574,7 @@ func TestJSONFile(t *testing.T) { if _, err := f.Write(c.buf); err != nil { t.Errorf("test case %d: failed to write to temporary file: %v", i, err) } - if err := JSONFile(f.Name()); (err != nil) != c.err { + if err := validate.JSONFile(f.Name()); (err != nil) != c.err { no := "no" if c.err { no = "an" @@ -644,7 +601,7 @@ func TestFileExists(t *testing.T) { }, } for i, c := range cases { - if err := FileExists(c.path); (err != nil) != c.err { + if err := validate.FileExists(c.path); (err != nil) != c.err { no := "no" if c.err { no = "an" @@ -692,7 +649,7 @@ func TestFileHeader(t *testing.T) { t.Errorf("test case %d: failed to write to temporary file: %v", i, err) } f.Close() - if err := FileHeader(f.Name(), c.expected); (err != nil) != c.err { + if err := validate.FileHeader(f.Name(), c.expected); (err != nil) != c.err { no := "no" if c.err { no = "an" diff --git a/pkg/asset/installconfig/BUILD.bazel b/pkg/asset/installconfig/BUILD.bazel index c0edbbf8e1e..d603db4b0ad 100644 --- a/pkg/asset/installconfig/BUILD.bazel +++ b/pkg/asset/installconfig/BUILD.bazel @@ -7,11 +7,13 @@ go_library( "doc.go", "installconfig.go", "platform.go", + "ssh.go", "stock.go", ], importpath = "github.com/openshift/installer/pkg/asset/installconfig", visibility = ["//visibility:public"], deps = [ + "//installer/pkg/validate:go_default_library", "//pkg/asset:go_default_library", "//pkg/types:go_default_library", "//vendor/github.com/ghodss/yaml:go_default_library", diff --git a/pkg/asset/installconfig/ssh.go b/pkg/asset/installconfig/ssh.go new file mode 100644 index 00000000000..1861c637e1a --- /dev/null +++ b/pkg/asset/installconfig/ssh.go @@ -0,0 +1,114 @@ +package installconfig + +import ( + "bufio" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/openshift/installer/installer/pkg/validate" + "github.com/openshift/installer/pkg/asset" +) + +type sshPublicKey struct { + inputReader *bufio.Reader +} + +// Dependencies returns no dependencies. +func (a *sshPublicKey) Dependencies() []asset.Asset { + return nil +} + +// Generate generates the SSH public key asset. +func (a *sshPublicKey) Generate(map[asset.Asset]*asset.State) (state *asset.State, err error) { + var paths []string + var pubKeys map[string]string + home := os.Getenv("HOME") + if home != "" { + paths, err = filepath.Glob(filepath.Join(home, ".ssh", "*.pub")) + if err != nil { + return nil, err + } + + if len(paths) > 0 { + pubKeys = map[string]string{} + } + + for _, path := range paths { + pubKeyBytes, err := ioutil.ReadFile(path) + if err != nil { + continue + } + pubKey := string(pubKeyBytes) + + err = validate.OpenSSHPublicKey(pubKey) + if err != nil { + continue + } + + pubKeys[path] = pubKey + } + + paths = []string{} + for path := range pubKeys { + paths = append(paths, path) + } + sort.Strings(paths) + } + + promptLines := []string{"SSH Public Key:"} + if len(paths) == 0 { + promptLines = append( + promptLines, + "Enter an empty string or your public key (e.g. 'ssh-rsa AAAA...')", + ) + } else { + promptLines = append( + promptLines, + "Enter an empty string, your public key (e.g. 'ssh-rsa AAAA...'), or one of the following numbers:", + ) + for i, path := range paths { + promptLines = append( + promptLines, + fmt.Sprintf("%d: %s", i+1, path), + ) + } + } + prompt := strings.Join(promptLines, "\n") + + var input string + for { + fmt.Println(prompt) + input, err = a.inputReader.ReadString('\n') + if err != nil && err != io.EOF { + fmt.Println("Could not understand response. Please retry.") + continue + } + if input != "" && input[len(input)-1] == '\n' { + input = input[:len(input)-1] + } + var i int + n, err := fmt.Sscanf(input, "%d", &i) + if n == len(input) && err == nil && i > 0 && i <= len(paths) { + path := paths[i-1] + input = pubKeys[path] + } else { + err = validate.OpenSSHPublicKey(input) + if err != nil { + fmt.Println(err) + continue + } + } + break + } + + return &asset.State{ + Contents: []asset.Content{ + {Data: []byte(input)}, + }, + }, nil +} diff --git a/pkg/asset/installconfig/stock.go b/pkg/asset/installconfig/stock.go index 4dc35c252ab..0e33a9125b3 100644 --- a/pkg/asset/installconfig/stock.go +++ b/pkg/asset/installconfig/stock.go @@ -58,9 +58,8 @@ func (s *StockImpl) EstablishStock(directory string, inputReader *bufio.Reader) Prompt: "Password:", InputReader: inputReader, } - s.sshKey = &asset.UserProvided{ - Prompt: "SSH Key:", - InputReader: inputReader, + s.sshKey = &sshPublicKey{ + inputReader: inputReader, } s.baseDomain = &asset.UserProvided{ Prompt: "Base Domain:", diff --git a/pkg/asset/tls/BUILD.bazel b/pkg/asset/tls/BUILD.bazel index 48e99b3c0b5..6c1c64a3005 100644 --- a/pkg/asset/tls/BUILD.bazel +++ b/pkg/asset/tls/BUILD.bazel @@ -11,6 +11,7 @@ go_test( deps = [ "//pkg/asset:go_default_library", "//pkg/types:go_default_library", + "//vendor/github.com/stretchr/testify/assert:go_default_library", ], )