diff --git a/src/encoding/pem/pem.go b/src/encoding/pem/pem.go index 90fe3dc50cf16..506196b1db9bd 100644 --- a/src/encoding/pem/pem.go +++ b/src/encoding/pem/pem.go @@ -10,8 +10,10 @@ package pem import ( "bytes" "encoding/base64" + "errors" "io" "sort" + "strings" ) // A Block represents a PEM encoded structure. @@ -110,27 +112,37 @@ func Decode(data []byte) (p *Block, rest []byte) { } // TODO(agl): need to cope with values that spread across lines. - key, val := line[0:i], line[i+1:] + key, val := line[:i], line[i+1:] key = bytes.TrimSpace(key) val = bytes.TrimSpace(val) p.Headers[string(key)] = string(val) rest = next } - i := bytes.Index(rest, pemEnd) - if i < 0 { + var endIndex int + // If there were no headers, the END line might occur + // immediately, without a leading newline. + if len(p.Headers) == 0 && bytes.HasPrefix(rest, pemEnd[1:]) { + endIndex = 0 + } else { + endIndex = bytes.Index(rest, pemEnd) + } + + if endIndex < 0 { return decodeError(data, rest) } - base64Data := removeWhitespace(rest[0:i]) + base64Data := removeWhitespace(rest[:endIndex]) p.Bytes = make([]byte, base64.StdEncoding.DecodedLen(len(base64Data))) n, err := base64.StdEncoding.Decode(p.Bytes, base64Data) if err != nil { return decodeError(data, rest) } - p.Bytes = p.Bytes[0:n] + p.Bytes = p.Bytes[:n] - _, rest = getLine(rest[i+len(pemEnd):]) + // the -1 is because we might have only matched pemEnd without the + // leading newline if the PEM block was empty. + _, rest = getLine(rest[endIndex+len(pemEnd)-1:]) return } @@ -246,6 +258,9 @@ func Encode(out io.Writer, b *Block) error { // For consistency of output, write other headers sorted by key. sort.Strings(h) for _, k := range h { + if strings.Contains(k, ":") { + return errors.New("pem: cannot encode a header key that contains a colon") + } if err := writeHeader(out, k, b.Headers[k]); err != nil { return err } diff --git a/src/encoding/pem/pem_test.go b/src/encoding/pem/pem_test.go index 92451feff89ed..1913f44c1fb74 100644 --- a/src/encoding/pem/pem_test.go +++ b/src/encoding/pem/pem_test.go @@ -8,7 +8,9 @@ import ( "bytes" "io/ioutil" "reflect" + "strings" "testing" + "testing/quick" ) type GetLineTest struct { @@ -44,6 +46,32 @@ func TestDecode(t *testing.T) { if !reflect.DeepEqual(result, privateKey) { t.Errorf("#1 got:%#v want:%#v", result, privateKey) } + + isEmpty := func(block *Block) bool { + return block != nil && block.Type == "EMPTY" && len(block.Headers) == 0 && len(block.Bytes) == 0 + } + result, remainder = Decode(remainder) + if !isEmpty(result) { + t.Errorf("#2 should be empty but got:%#v", result) + } + result, remainder = Decode(remainder) + if !isEmpty(result) { + t.Errorf("#3 should be empty but got:%#v", result) + } + result, remainder = Decode(remainder) + if !isEmpty(result) { + t.Errorf("#4 should be empty but got:%#v", result) + } + + result, remainder = Decode(remainder) + if result == nil || result.Type != "HEADERS" || len(result.Headers) != 1 { + t.Errorf("#5 expected single header block but got :%v", result) + } + + if len(remainder) != 0 { + t.Errorf("expected nothing remaining of pemData, but found %s", string(remainder)) + } + result, _ = Decode([]byte(pemPrivateKey2)) if !reflect.DeepEqual(result, privateKey2) { t.Errorf("#2 got:%#v want:%#v", result, privateKey2) @@ -117,6 +145,44 @@ func TestLineBreaker(t *testing.T) { } } +func TestFuzz(t *testing.T) { + testRoundtrip := func(block *Block) bool { + for key := range block.Headers { + if strings.Contains(key, ":") { + // Keys with colons cannot be encoded. + return true + } + } + + var buf bytes.Buffer + err := Encode(&buf, block) + decoded, rest := Decode(buf.Bytes()) + + switch { + case err != nil: + t.Errorf("Encode of %#v resulted in error: %s", block, err) + case !reflect.DeepEqual(block, decoded): + t.Errorf("Encode of %#v decoded as %#v", block, decoded) + case len(rest) != 0: + t.Errorf("Encode of %#v decoded correctly, but with %x left over", block, rest) + default: + return true + } + return false + } + + // Explicitly test the empty block. + if !testRoundtrip(&Block{ + Type: "EMPTY", + Headers: make(map[string]string), + Bytes: []byte{}, + }) { + return + } + + quick.Check(testRoundtrip, nil) +} + func BenchmarkEncode(b *testing.B) { data := &Block{Bytes: make([]byte, 65536)} b.SetBytes(int64(len(data.Bytes))) @@ -188,7 +254,32 @@ BTiHcL3s3KrJu1vDVrshvxfnz71KTeNnZH8UbOqT5i7fPGyXtY1XJddcbI/Q6tXf wHFsZc20TzSdsVLBtwksUacpbDogcEVMctnNrB8FIrB3vZEv9Q0Z1VeY7nmTpF+6 a+z2P7acL7j6A6Pr3+q8P9CPiPC7zFonVzuVPyB8GchGR2hytyiOVpuD9+k8hcuw ZWAaUoVtWIQ52aKS0p19G99hhb+IVANC4akkdHV4SP8i7MVNZhfUmg== ------END RSA PRIVATE KEY-----` +-----END RSA PRIVATE KEY----- + + +-----BEGIN EMPTY----- +-----END EMPTY----- + +-----BEGIN EMPTY----- + +-----END EMPTY----- + +-----BEGIN EMPTY----- + + +-----END EMPTY----- + +# This shouldn't be recognised because of the missing newline after the +headers. +-----BEGIN HEADERS----- +Header: 1 +-----END HEADERS----- + +# This should be valid, however. +-----BEGIN HEADERS----- +Header: 1 + +-----END HEADERS-----` var certificate = &Block{Type: "CERTIFICATE", Headers: map[string]string{},