diff --git a/installer/BUILD.bazel b/installer/BUILD.bazel index ab425037f6..673c5f30e8 100644 --- a/installer/BUILD.bazel +++ b/installer/BUILD.bazel @@ -33,5 +33,6 @@ test_suite( tests = [ "//installer/pkg/workflow:go_default_test", "//installer/pkg/config-generator:go_default_test", + "//installer/pkg/validate:go_default_test", ], ) diff --git a/installer/pkg/validate/BUILD.bazel b/installer/pkg/validate/BUILD.bazel new file mode 100644 index 0000000000..1724d95b64 --- /dev/null +++ b/installer/pkg/validate/BUILD.bazel @@ -0,0 +1,15 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_test( + name = "go_default_test", + srcs = ["validate_test.go"], + embed = [":go_default_library"], +) + +go_library( + name = "go_default_library", + srcs = ["validate.go"], + importpath = "github.com/coreos/tectonic-installer/installer/pkg/validate", + visibility = ["//visibility:public"], + deps = ["//installer/pkg/config:go_default_library"], +) diff --git a/installer/pkg/validate/validate.go b/installer/pkg/validate/validate.go new file mode 100644 index 0000000000..b917c8e167 --- /dev/null +++ b/installer/pkg/validate/validate.go @@ -0,0 +1,364 @@ +package validate + +import ( + "errors" + "fmt" + "net" + "regexp" + "strconv" + "strings" + "unicode/utf8" +) + +func isMatch(re string, v string) bool { + return regexp.MustCompile(re).MatchString(v) +} + +// NonEmpty checks if the given string contains at least one non-whitespace character and returns an error if not. +func NonEmpty(v string) error { + if utf8.RuneCountInString(strings.TrimSpace(v)) == 0 { + return errors.New("cannot be empty") + } + return nil +} + +// Int checks if the given string is a valid integer and returns an error if not. +func Int(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + if _, err := strconv.Atoi(v); err != nil { + return errors.New("invalid integer") + } + return nil +} + +// IntRange checks if the given string is a valid integer between `min` and `max` and returns an error if not. +func IntRange(v string, min int, max int) error { + i, err := strconv.Atoi(v) + if err != nil { + return Int(v) + } + if i < min { + return fmt.Errorf("cannot be less than %v", min) + } + if i > max { + return fmt.Errorf("cannot be greater than %v", max) + } + return nil +} + +// IntOdd checks if the given string is a valid integer and that it is odd and returns an error if not. +func IntOdd(v string) error { + i, err := strconv.Atoi(v) + if err != nil { + return Int(v) + } + if i%2 != 1 { + return errors.New("must be an odd integer") + } + return nil +} + +// ClusterName checks if the given string is a valid name for a cluster and returns an error if not. +func ClusterName(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + if length := utf8.RuneCountInString(v); length < 1 || length > 253 { + return errors.New("must be between 1 and 253 characters") + } + + if strings.ToLower(v) != v { + return errors.New("must be lower case") + } + + if !isMatch("^[a-z0-9-.]*$", v) { + return errors.New("only lower case alphanumeric [a-z0-9], dashes and dots are allowed") + } + + isAlphaNum := regexp.MustCompile("^[a-z0-9]$").MatchString + + // If we got this far, we know the string is ASCII and has at least one character + if !isAlphaNum(v[:1]) || !isAlphaNum(v[len(v)-1:]) { + return errors.New("must start and end with a lower case alphanumeric character [a-z0-9]") + } + + for _, segment := range strings.Split(v, ".") { + // Each segment can have up to 63 characters + if utf8.RuneCountInString(segment) > 63 { + return errors.New("no segment between dots can be more than 63 characters") + } + if !isAlphaNum(segment[:1]) || !isAlphaNum(segment[len(segment)-1:]) { + return errors.New("segments between dots must start and end with a lower case alphanumeric character [a-z0-9]") + } + } + + return nil +} + +// AWSClusterName checks if the given string is a valid name for a cluster on AWS and returns an error if not. +// See AWS docs: +// http://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/cfn-using-console-create-stack-parameters.html +// http://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-elasticloadbalancingv2-loadbalancer.html#cfn-elasticloadbalancingv2-loadbalancer-name +func AWSClusterName(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + if length := utf8.RuneCountInString(v); length < 1 || length > 28 { + return errors.New("must be between 1 and 28 characters") + } + + if strings.ToLower(v) != v { + return errors.New("must be lower case") + } + + if strings.HasPrefix(v, "-") || strings.HasSuffix(v, "-") { + return errors.New("must not start or end with '-'") + } + + if !isMatch("^[a-z][-a-z0-9]*$", v) { + return errors.New("must be a lower case AWS Stack Name: [a-z][-a-z0-9]*") + } + + return nil +} + +// MAC checks if the given string is a valid MAC address and returns an error if not. +// Based on net.ParseMAC. +func MAC(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + if _, err := net.ParseMAC(v); err != nil { + return errors.New("invalid MAC Address") + } + return nil +} + +// IPv4 checks if the given string is a valid IP v4 address and returns an error if not. +// Based on net.ParseIP. +func IPv4(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + if ip := net.ParseIP(v); ip == nil || !strings.Contains(v, ".") { + return errors.New("invalid IPv4 address") + } + return nil +} + +// SubnetCIDR checks if the given string is a valid CIDR for a master nodes or worker nodes subnet and returns an error if not. +func SubnetCIDR(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + split := strings.Split(v, "/") + + if len(split) == 1 { + return errors.New("must provide a CIDR netmask (eg, /24)") + } + + if len(split) != 2 { + return errors.New("invalid IPv4 address") + } + + ip := split[0] + + if err := IPv4(ip); err != nil { + return errors.New("invalid IPv4 address") + } + + if mask, err := strconv.Atoi(split[1]); err != nil || mask < 0 || mask > 32 { + return errors.New("invalid netmask size (must be between 0 and 32)") + } + + // Catch any invalid CIDRs not caught by the checks above + if _, _, err := net.ParseCIDR(v); err != nil { + return errors.New("invalid CIDR") + } + + if strings.HasPrefix(ip, "172.17.") { + return errors.New("overlaps with default Docker Bridge subnet (172.17.0.0/16)") + } + + return nil +} + +// AWSSubnetCIDR checks if the given string is a valid CIDR for a master nodes or worker nodes subnet in an AWS VPC and returns an error if not. +func AWSSubnetCIDR(v string) error { + if err := SubnetCIDR(v); err != nil { + return err + } + + _, network, err := net.ParseCIDR(v) + if err != nil { + return errors.New("invalid CIDR") + } + if mask, _ := network.Mask.Size(); mask < 16 || mask > 28 { + return errors.New("AWS subnets must be between /16 and /28") + } + + return nil +} + +// DomainName checks if the given string is a valid domain name and returns an error if not. +func DomainName(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + split := strings.Split(v, ".") + for i, segment := range split { + // Trailing dot is OK + if len(segment) == 0 && i == len(split)-1 { + continue + } + if !isMatch("^[a-zA-Z0-9-]{1,63}$", segment) { + return errors.New("invalid domain name") + } + } + return nil +} + +// Host checks if the given string is either a valid IPv4 address or a valid domain name and returns an error if not. +func Host(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + // Either a valid IP address or domain name + if IPv4(v) != nil && DomainName(v) != nil { + return errors.New("invalid host (must be a domain name or IP address)") + } + return nil +} + +// Port checks if the given string is a valid port number and returns an error if not. +func Port(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + if IntRange(v, 1, 65535) != nil { + return errors.New("invalid port number") + } + return nil +} + +// HostPort checks if the given string is valid : format and returns an error if not. +func HostPort(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + split := strings.Split(v, ":") + if len(split) != 2 { + return errors.New("must use : format") + } + if err := Host(split[0]); err != nil { + return err + } + return Port(split[1]) +} + +// Email checks if the given string is a valid email address and returns an error if not. +func Email(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + invalidError := errors.New("invalid email address") + + split := strings.Split(v, "@") + if len(split) != 2 { + return invalidError + } + localPart := split[0] + domain := split[1] + + if NonEmpty(localPart) != nil { + return invalidError + } + + // No whitespace allowed in local-part + if isMatch(`\s`, localPart) { + return invalidError + } + + return DomainName(domain) +} + +const base64RegExp = `[A-Za-z0-9+\/]+={0,2}` + +// Certificate checks if the given string is a valid certificate in PEM format and returns an error if not. +// Ignores leading and trailing whitespace. +func Certificate(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + trimmed := strings.TrimSpace(v) + + // Don't let users hang themselves + if isMatch(`-BEGIN [\w-]+ PRIVATE KEY-`, trimmed) { + return errors.New("invalid certificate (appears to be a private key)") + } + + if !isMatch("(?s:^-----BEGIN CERTIFICATE-----\n"+base64RegExp+"\n-----END CERTIFICATE-----$)", trimmed) { + return errors.New("invalid certificate") + } + return nil +} + +// PrivateKey checks if the given string is a valid private key in PEM format and returns an error if not. +// Ignores leading and trailing whitespace. +func PrivateKey(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + trimmed := strings.TrimSpace(v) + + if !isMatch("(?s:^-----BEGIN [A-Z]{2,10} PRIVATE KEY-----\n"+base64RegExp+"\n-----END [A-Z]{2,10} PRIVATE KEY-----$)", trimmed) { + return errors.New("invalid private key") + } + return nil +} + +// OpenSSHPublicKey checks if the given string is a valid OpenSSH public key and returns an error if not. +// Ignores leading and trailing whitespace. +func OpenSSHPublicKey(v string) error { + if err := NonEmpty(v); err != nil { + return err + } + + trimmed := strings.TrimSpace(v) + + // Don't let users hang themselves + if isMatch(`-BEGIN [\w-]+ PRIVATE KEY-`, trimmed) { + return errors.New("invalid SSH public key (appears to be a private key)") + } + + if strings.Contains(trimmed, "\n") { + return errors.New("invalid SSH public key (should not contain any newline characters)") + } + + invalidError := errors.New("invalid SSH public key") + + keyParts := regexp.MustCompile(`\s+`).Split(trimmed, -1) + if len(keyParts) < 2 { + return invalidError + } + + keyType := keyParts[0] + keyBase64 := keyParts[1] + if !isMatch(`^[\w-]+$`, keyType) || !isMatch("^"+base64RegExp+"$", keyBase64) { + return invalidError + } + + return nil +} diff --git a/installer/pkg/validate/validate_test.go b/installer/pkg/validate/validate_test.go new file mode 100644 index 0000000000..e25922fb6e --- /dev/null +++ b/installer/pkg/validate/validate_test.go @@ -0,0 +1,437 @@ +package validate + +import ( + "strings" + "testing" +) + +const caseMsg = "must be lower case" +const emptyMsg = "cannot be empty" +const invalidDomainMsg = "invalid domain name" +const invalidHostMsg = "invalid host (must be a domain name or IP address)" +const invalidIPMsg = "invalid IPv4 address" +const invalidIntMsg = "invalid integer" +const invalidPortMsg = "invalid port number" +const noCIDRNetmaskMsg = "must provide a CIDR netmask (eg, /24)" + +type test struct { + in string + expected string +} + +type validator func(string) error + +func runTests(t *testing.T, funcName string, fn validator, tests []test) { + for _, test := range tests { + err := fn(test.in) + if (err == nil && test.expected != "") || (err != nil && err.Error() != test.expected) { + t.Errorf("For %s(%q), expected %q, got %q", funcName, test.in, test.expected, err) + } + } +} + +func TestNonEmpty(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", ""}, + {".", ""}, + {"日本語", ""}, + } + runTests(t, "NonEmpty", NonEmpty, tests) +} + +func TestInt(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"2 3", invalidIntMsg}, + {"1.1", invalidIntMsg}, + {"abc", invalidIntMsg}, + {"日本語", invalidIntMsg}, + {"1 abc", invalidIntMsg}, + {"日本語2", invalidIntMsg}, + {"0", ""}, + {"1", ""}, + {"999999", ""}, + {"-1", ""}, + } + runTests(t, "Int", Int, tests) +} + +func TestIntRange(t *testing.T) { + tests := []struct { + in string + min int + max int + expected string + }{ + {"", 4, 6, emptyMsg}, + {" ", 4, 6, emptyMsg}, + {"2 3", 1, 2, invalidIntMsg}, + {"1.1", 0, 0, invalidIntMsg}, + {"abc", -2, -1, invalidIntMsg}, + {"日本語", 99, 100, invalidIntMsg}, + {"5", 4, 6, ""}, + {"5", 5, 5, ""}, + {"5", 6, 8, "cannot be less than 6"}, + {"5", 6, 4, "cannot be less than 6"}, + {"5", 2, 4, "cannot be greater than 4"}, + } + + for _, test := range tests { + err := IntRange(test.in, test.min, test.max) + if (err == nil && test.expected != "") || (err != nil && err.Error() != test.expected) { + t.Errorf("For IntRange(%q, %v, %v), expected %q, got %q", test.in, test.min, test.max, test.expected, err) + } + } +} + +func TestIntOdd(t *testing.T) { + notOddMsg := "must be an odd integer" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"0", notOddMsg}, + {"1", ""}, + {"2", notOddMsg}, + {"99", ""}, + {"100", notOddMsg}, + {"abc", invalidIntMsg}, + {"1 abc", invalidIntMsg}, + {"日本語", invalidIntMsg}, + } + runTests(t, "IntOdd", IntOdd, tests) +} + +func TestClusterName(t *testing.T) { + const charsMsg = "only lower case alphanumeric [a-z0-9], dashes and dots are allowed" + const lengthMsg = "must be between 1 and 253 characters" + const segmentLengthMsg = "no segment between dots can be more than 63 characters" + const startEndCharMsg = "must start and end with a lower case alphanumeric character [a-z0-9]" + const segmentStartEndCharMsg = "segments between dots must start and end with a lower case alphanumeric character [a-z0-9]" + + maxSizeName := strings.Repeat("123456789.", 25) + "123" + maxSizeSegment := strings.Repeat("1234567890", 6) + "123" + + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", ""}, + {"A", caseMsg}, + {"abc D", caseMsg}, + {"1", ""}, + {".", startEndCharMsg}, + {"a.", startEndCharMsg}, + {".a", startEndCharMsg}, + {"a.a", ""}, + {"-a", startEndCharMsg}, + {"a-", startEndCharMsg}, + {"a.-a", segmentStartEndCharMsg}, + {"a-.a", segmentStartEndCharMsg}, + {"a%a", charsMsg}, + {"日本語", charsMsg}, + {"a日本語a", charsMsg}, + {maxSizeName, ""}, + {maxSizeName + "a", lengthMsg}, + {maxSizeSegment + ".abc", ""}, + {maxSizeSegment + "a.abc", segmentLengthMsg}, + } + runTests(t, "ClusterName", ClusterName, tests) +} + +func TestAWSClusterName(t *testing.T) { + const charsMsg = "must be a lower case AWS Stack Name: [a-z][-a-z0-9]*" + const lengthMsg = "must be between 1 and 28 characters" + const hyphenMsg = "must not start or end with '-'" + + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", ""}, + {"A", caseMsg}, + {"abc D", caseMsg}, + {"1", charsMsg}, + {".", charsMsg}, + {"a.", charsMsg}, + {".a", charsMsg}, + {"a.a", charsMsg}, + {"a%a", charsMsg}, + {"a-a", ""}, + {"-abc", hyphenMsg}, + {"abc-", hyphenMsg}, + {"日本語", charsMsg}, + {"a日本語a", charsMsg}, + {"a234567890123456789012345678", ""}, + {"12345678901234567890123456789", lengthMsg}, + {"A2345678901234567890123456789", lengthMsg}, + } + runTests(t, "AWSClusterName", AWSClusterName, tests) +} + +func TestMAC(t *testing.T) { + const invalidMsg = "invalid MAC Address" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"abc", invalidMsg}, + {"12:34:45:78:9A:BC", ""}, + {"12-34-45-78-9A-BC", ""}, + {"12:34:45:78:9a:bc", ""}, + {"12:34:45:78:9X:YZ", invalidMsg}, + {"12.34.45.78.9A.BC", invalidMsg}, + } + runTests(t, "MAC", MAC, tests) +} + +func TestIPv4(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"0.0.0.0", ""}, + {"1.2.3.4", ""}, + {"1.2.3.", invalidIPMsg}, + {"1.2.3.4.", invalidIPMsg}, + {"1.2.3.a", invalidIPMsg}, + {"255.255.255.255", ""}, + } + runTests(t, "IPv4", IPv4, tests) +} + +func TestSubnetCIDR(t *testing.T) { + const netmaskSizeMsg = "invalid netmask size (must be between 0 and 32)" + + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"/16", invalidIPMsg}, + {"0.0.0.0/0", ""}, + {"0.0.0.0/32", ""}, + {"1.2.3.4", noCIDRNetmaskMsg}, + {"1.2.3.", noCIDRNetmaskMsg}, + {"1.2.3.4.", noCIDRNetmaskMsg}, + {"1.2.3.4/0", ""}, + {"1.2.3.4/1", ""}, + {"1.2.3.4/31", ""}, + {"1.2.3.4/32", ""}, + {"1.2.3./16", invalidIPMsg}, + {"1.2.3.4./16", invalidIPMsg}, + {"1.2.3.4/33", netmaskSizeMsg}, + {"1.2.3.4/-1", netmaskSizeMsg}, + {"1.2.3.4/abc", netmaskSizeMsg}, + {"172.17.1.2", noCIDRNetmaskMsg}, + {"172.17.1.2/", netmaskSizeMsg}, + {"172.17.1.2/33", netmaskSizeMsg}, + {"172.17.1.2/20", "overlaps with default Docker Bridge subnet (172.17.0.0/16)"}, + {"255.255.255.255/1", ""}, + {"255.255.255.255/32", ""}, + } + runTests(t, "SubnetCIDR", SubnetCIDR, tests) +} + +func TestAWSsubnetCIDR(t *testing.T) { + const awsNetmaskSizeMsg = "AWS subnets must be between /16 and /28" + + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"/20", invalidIPMsg}, + {"1.2.3.4", noCIDRNetmaskMsg}, + {"1.2.3.4/15", awsNetmaskSizeMsg}, + {"1.2.3.4/16", ""}, + {"1.2.3.4/28", ""}, + {"1.2.3.4/29", awsNetmaskSizeMsg}, + } + runTests(t, "AWSSubnetCIDR", AWSSubnetCIDR, tests) +} + +func TestDomainName(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", ""}, + {".", invalidDomainMsg}, + {"日本語", invalidDomainMsg}, + {"日本語.com", invalidDomainMsg}, + {"abc.日本語.com", invalidDomainMsg}, + {"a日本語a.com", invalidDomainMsg}, + {"abc", ""}, + {"ABC", ""}, + {"ABC123", ""}, + {"ABC123.COM123", ""}, + {"1", ""}, + {"0.0", ""}, + {"1.2.3.4", ""}, + {"1.2.3.4.", ""}, + {"abc.", ""}, + {"abc.com", ""}, + {"abc.com.", ""}, + {"a.b.c.d.e.f", ""}, + {".abc", invalidDomainMsg}, + {".abc.com", invalidDomainMsg}, + {".abc.com", invalidDomainMsg}, + } + runTests(t, "DomainName", DomainName, tests) +} + +func TestHost(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", ""}, + {".", invalidHostMsg}, + {"日本語", invalidHostMsg}, + {"日本語.com", invalidHostMsg}, + {"abc.日本語.com", invalidHostMsg}, + {"a日本語a.com", invalidHostMsg}, + {"abc", ""}, + {"ABC", ""}, + {"ABC123", ""}, + {"ABC123.COM123", ""}, + {"1", ""}, + {"0.0", ""}, + {"1.2.3.4", ""}, + {"1.2.3.4.", ""}, + {"abc.", ""}, + {"abc.com", ""}, + {"abc.com.", ""}, + {"a.b.c.d.e.f", ""}, + {".abc", invalidHostMsg}, + {".abc.com", invalidHostMsg}, + {".abc.com", invalidHostMsg}, + } + runTests(t, "Host", Host, tests) +} + +func TestPort(t *testing.T) { + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", invalidPortMsg}, + {".", invalidPortMsg}, + {"日本語", invalidPortMsg}, + {"0", invalidPortMsg}, + {"1", ""}, + {"123", ""}, + {"12345", ""}, + {"65535", ""}, + {"65536", invalidPortMsg}, + } + runTests(t, "Port", Port, tests) +} + +func TestHostPort(t *testing.T) { + const invalidHostPortMsg = "must use : format" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {".", invalidHostPortMsg}, + {"日本語", invalidHostPortMsg}, + {"abc.com", invalidHostPortMsg}, + {"abc.com:0", invalidPortMsg}, + {"abc.com:1", ""}, + {"abc.com:65535", ""}, + {"abc.com:65536", invalidPortMsg}, + {"abc.com:abc", invalidPortMsg}, + {"1.2.3.4:1234", ""}, + {"1.2.3.4:abc", invalidPortMsg}, + {"日本語:1234", invalidHostMsg}, + } + runTests(t, "HostPort", HostPort, tests) +} + +func TestEmail(t *testing.T) { + const invalidMsg = "invalid email address" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", invalidMsg}, + {".", invalidMsg}, + {"日本語", invalidMsg}, + {"a@abc.com", ""}, + {"A@abc.com", ""}, + {"1@abc.com", ""}, + {"a.B.1.あ@abc.com", ""}, + {"ア@abc.com", ""}, + {"中文@abc.com", ""}, + {"a@abc.com", ""}, + {"a@ABC.com", ""}, + {"a@123.com", ""}, + {"a@日本語.com", invalidDomainMsg}, + {"a@.com", invalidDomainMsg}, + {"@abc.com", invalidMsg}, + } + runTests(t, "Email", Email, tests) +} + +func TestCertificate(t *testing.T) { + const invalidMsg = "invalid certificate" + const privateKeyMsg = "invalid certificate (appears to be a private key)" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", invalidMsg}, + {".", invalidMsg}, + {"日本語", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", ""}, + {"-----BEGIN CERTIFICATE-----\nabc\n-----END CERTIFICATE-----", ""}, + {"-----BEGIN CERTIFICATE-----\nabc=\n-----END CERTIFICATE-----", ""}, + {"-----BEGIN CERTIFICATE-----\nabc==\n-----END CERTIFICATE-----", ""}, + {"-----BEGIN CERTIFICATE-----\nabc===\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\na%a\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\n\nab\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\nab\n\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----\n-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc\n-----END RSA PRIVATE KEY-----", privateKeyMsg}, + } + runTests(t, "Certificate", Certificate, tests) +} + +func TestPrivateKey(t *testing.T) { + const invalidMsg = "invalid private key" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", invalidMsg}, + {".", invalidMsg}, + {"日本語", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\na\n-----END RSA PRIVATE KEY-----", ""}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc\n-----END RSA PRIVATE KEY-----", ""}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc=\n-----END RSA PRIVATE KEY-----", ""}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc==\n-----END RSA PRIVATE KEY-----", ""}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc===\n-----END RSA PRIVATE KEY-----", invalidMsg}, + {"-----BEGIN EC PRIVATE KEY-----\nabc\n-----END EC PRIVATE KEY-----", ""}, + {"-----BEGIN RSA PRIVATE KEY-----\na%a\n-----END RSA PRIVATE KEY-----", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\n\nab\n-----END RSA PRIVATE KEY-----", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\nab\n\n-----END RSA PRIVATE KEY-----", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\na\n-----END RSA PRIVATE KEY-----\n-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN CERTIFICATE-----\na\n-----END CERTIFICATE-----", invalidMsg}, + } + runTests(t, "PrivateKey", PrivateKey, tests) +} + +func TestOpenSSHPublicKey(t *testing.T) { + const invalidMsg = "invalid SSH public key" + const multiLineMsg = "invalid SSH public key (should not contain any newline characters)" + const privateKeyMsg = "invalid SSH public key (appears to be a private key)" + tests := []test{ + {"", emptyMsg}, + {" ", emptyMsg}, + {"a", invalidMsg}, + {".", invalidMsg}, + {"日本語", invalidMsg}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL", ""}, + {"ssh-rsa \t AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL", ""}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL you@example.com", ""}, + {"\nssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL you@example.com", ""}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL you@example.com\n", ""}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL\nssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL", multiLineMsg}, + {"ssh-rsa\nAAAAB3NzaC1yc2EAAAADAQABAAACAQDxL you@example.com", multiLineMsg}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL\nyou@example.com", multiLineMsg}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAACAQDxL", ""}, + {"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCt3BebCHqnSsgpLjo4kVvyfY/z2BS8t27r/7du+O2pb4xYkr7n+KFpbOz523vMTpQ+o1jY4u4TgexglyT9nqasWgLOvo1qjD1agHme8LlTPQSk07rXqOB85Uq5p7ig2zoOejF6qXhcc3n1c7+HkxHrgpBENjLVHOBpzPBIAHkAGaZcl07OCqbsG5yxqEmSGiAlh/IiUVOZgdDMaGjCRFy0wk0mQaGD66DmnFc1H5CzcPjsxr0qO65e7lTGsE930KkO1Vc+RHCVwvhdXs+c2NhJ2/3740Kpes9n1/YullaWZUzlCPDXtRuy6JRbFbvy39JUgHWGWzB3d+3f8oJ/N4qZ cardno:000603633110", ""}, + {"-----BEGIN CERTIFICATE-----abcd-----END CERTIFICATE-----", invalidMsg}, + {"-----BEGIN RSA PRIVATE KEY-----\nabc\n-----END RSA PRIVATE KEY-----", privateKeyMsg}, + } + runTests(t, "OpenSSHPublicKey", OpenSSHPublicKey, tests) +}