diff --git a/copy/copy.go b/copy/copy.go index 0a2241fb83..79fd3b1886 100644 --- a/copy/copy.go +++ b/copy/copy.go @@ -24,12 +24,12 @@ type digestingReader struct { source io.Reader digest hash.Hash expectedDigest []byte - failureIndicator *bool + validationFailed bool } -// newDigestingReader returns an io.Reader with contents of source, which will eventually return a non-EOF error -// and set *failureIndicator to true if the source stream does not match expectedDigestString. -func newDigestingReader(source io.Reader, expectedDigestString string, failureIndicator *bool) (io.Reader, error) { +// newDigestingReader returns an io.Reader implementation with contents of source, which will eventually return a non-EOF error +// and set validationFailed to true if the source stream does not match expectedDigestString. +func newDigestingReader(source io.Reader, expectedDigestString string) (*digestingReader, error) { fields := strings.SplitN(expectedDigestString, ":", 2) if len(fields) != 2 { return nil, fmt.Errorf("Invalid digest specification %s", expectedDigestString) @@ -50,7 +50,7 @@ func newDigestingReader(source io.Reader, expectedDigestString string, failureIn source: source, digest: digest, expectedDigest: expectedDigest, - failureIndicator: failureIndicator, + validationFailed: false, }, nil } @@ -67,7 +67,7 @@ func (d *digestingReader) Read(p []byte) (int, error) { if err == io.EOF { actualDigest := d.digest.Sum(nil) if subtle.ConstantTimeCompare(actualDigest, d.expectedDigest) != 1 { - *d.failureIndicator = true + d.validationFailed = true return 0, fmt.Errorf("Digest did not match, expected %s, got %s", hex.EncodeToString(d.expectedDigest), hex.EncodeToString(actualDigest)) } } @@ -123,15 +123,14 @@ func Image(ctx *types.SystemContext, policyContext *signature.PolicyContext, des // Note that we don't use a stronger "validationSucceeded" indicator, because // dest.PutBlob may detect that the layer already exists, in which case we don't // read stream to the end, and validation does not happen. - validationFailed := false // This is a new instance on each loop iteration. - digestingReader, err := newDigestingReader(stream, digest, &validationFailed) + digestingReader, err := newDigestingReader(stream, digest) if err != nil { return fmt.Errorf("Error preparing to verify blob %s: %v", digest, err) } if err := dest.PutBlob(digest, digestingReader); err != nil { return fmt.Errorf("Error writing blob: %v", err) } - if validationFailed { // Coverage: This should never happen. + if digestingReader.validationFailed { // Coverage: This should never happen. return fmt.Errorf("Internal error uploading blob %s, digest verification failed but was ignored", digest) } } diff --git a/copy/copy_test.go b/copy/copy_test.go index 0a0fd6793d..c5aaf5c4a0 100644 --- a/copy/copy_test.go +++ b/copy/copy_test.go @@ -20,8 +20,7 @@ func TestNewDigestingReader(t *testing.T) { "sha256:0", // Invalid hex value "sha256:01", // Invalid length of hex value } { - validationFailed := false - _, err := newDigestingReader(source, input, &validationFailed) + _, err := newDigestingReader(source, input) assert.Error(t, err, input) } } @@ -38,25 +37,23 @@ func TestDigestingReaderRead(t *testing.T) { // Valid input for _, c := range cases { source := bytes.NewReader(c.input) - validationFailed := false - reader, err := newDigestingReader(source, c.digest, &validationFailed) + reader, err := newDigestingReader(source, c.digest) require.NoError(t, err, c.digest) dest := bytes.Buffer{} n, err := io.Copy(&dest, reader) assert.NoError(t, err, c.digest) assert.Equal(t, int64(len(c.input)), n, c.digest) assert.Equal(t, c.input, dest.Bytes(), c.digest) - assert.False(t, validationFailed, c.digest) + assert.False(t, reader.validationFailed, c.digest) } // Modified input for _, c := range cases { source := bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil)) - validationFailed := false - reader, err := newDigestingReader(source, c.digest, &validationFailed) + reader, err := newDigestingReader(source, c.digest) require.NoError(t, err, c.digest) dest := bytes.Buffer{} _, err = io.Copy(&dest, reader) assert.Error(t, err, c.digest) - assert.True(t, validationFailed) + assert.True(t, reader.validationFailed) } }