17 changes: 8 additions & 9 deletions copy/copy.go
@@ -24,12 +24,12 @@ type digestingReader struct {
 	source           io.Reader
 	digest           hash.Hash
 	expectedDigest   []byte
-	failureIndicator *bool
+	validationFailed bool
 }
 
-// newDigestingReader returns an io.Reader with contents of source, which will eventually return a non-EOF error
-// and set *failureIndicator to true if the source stream does not match expectedDigestString.
-func newDigestingReader(source io.Reader, expectedDigestString string, failureIndicator *bool) (io.Reader, error) {
+// newDigestingReader returns an io.Reader implementation with contents of source, which will eventually return a non-EOF error
+// and set validationFailed to true if the source stream does not match expectedDigestString.
+func newDigestingReader(source io.Reader, expectedDigestString string) (*digestingReader, error) {
 	fields := strings.SplitN(expectedDigestString, ":", 2)
 	if len(fields) != 2 {
 		return nil, fmt.Errorf("Invalid digest specification %s", expectedDigestString)
@@ -50,7 +50,7 @@ func newDigestingReader(source io.Reader, expectedDigestString string, failureIn
 		source:           source,
 		digest:           digest,
 		expectedDigest:   expectedDigest,
-		failureIndicator: failureIndicator,
+		validationFailed: false,
 	}, nil
 }
 
@@ -67,7 +67,7 @@ func (d *digestingReader) Read(p []byte) (int, error) {
 	if err == io.EOF {
 		actualDigest := d.digest.Sum(nil)
 		if subtle.ConstantTimeCompare(actualDigest, d.expectedDigest) != 1 {
-			*d.failureIndicator = true
+			d.validationFailed = true
 			return 0, fmt.Errorf("Digest did not match, expected %s, got %s", hex.EncodeToString(d.expectedDigest), hex.EncodeToString(actualDigest))
 		}
 	}
@@ -123,15 +123,14 @@ func Image(ctx *types.SystemContext, policyContext *signature.PolicyContext, des
 		// Note that we don't use a stronger "validationSucceeded" indicator, because
 		// dest.PutBlob may detect that the layer already exists, in which case we don't
 		// read stream to the end, and validation does not happen.
-		validationFailed := false // This is a new instance on each loop iteration.
-		digestingReader, err := newDigestingReader(stream, digest, &validationFailed)
+		digestingReader, err := newDigestingReader(stream, digest)
 		if err != nil {
 			return fmt.Errorf("Error preparing to verify blob %s: %v", digest, err)
 		}
 		if err := dest.PutBlob(digest, digestingReader); err != nil {
 			return fmt.Errorf("Error writing blob: %v", err)
 		}
-		if validationFailed { // Coverage: This should never happen.
+		if digestingReader.validationFailed { // Coverage: This should never happen.
 			return fmt.Errorf("Internal error uploading blob %s, digest verification failed but was ignored", digest)
 		}
 	}
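For readers skimming the diff, the following is a minimal sketch (not part of this PR) of the new calling pattern inside the copy package: the validation flag now lives on the *digestingReader returned by newDigestingReader instead of being written through a caller-supplied *bool. The exampleVerify helper and its error messages are hypothetical illustrations, assumed to sit alongside the package code shown above.

// Hypothetical helper, assumed to live in the copy package next to newDigestingReader.
package copy

import (
	"bytes"
	"fmt"
	"io"
	"io/ioutil"
)

// exampleVerify streams a blob through the digesting reader and reports a digest mismatch.
func exampleVerify(blob []byte, expectedDigest string) error {
	reader, err := newDigestingReader(bytes.NewReader(blob), expectedDigest)
	if err != nil {
		return fmt.Errorf("Error preparing to verify blob %s: %v", expectedDigest, err)
	}
	// Read computes the digest incrementally and compares it against expectedDigest at EOF.
	if _, err := io.Copy(ioutil.Discard, reader); err != nil {
		return err
	}
	// Only meaningful if the stream was read to the end (see the PR comment above about PutBlob).
	if reader.validationFailed {
		return fmt.Errorf("digest verification failed for %s", expectedDigest)
	}
	return nil
}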
13 changes: 5 additions & 8 deletions copy/copy_test.go
@@ -20,8 +20,7 @@ func TestNewDigestingReader(t *testing.T) {
 		"sha256:0",  // Invalid hex value
 		"sha256:01", // Invalid length of hex value
 	} {
-		validationFailed := false
-		_, err := newDigestingReader(source, input, &validationFailed)
+		_, err := newDigestingReader(source, input)
 		assert.Error(t, err, input)
 	}
 }
@@ -38,25 +37,23 @@ func TestDigestingReaderRead(t *testing.T) {
 	// Valid input
 	for _, c := range cases {
 		source := bytes.NewReader(c.input)
-		validationFailed := false
-		reader, err := newDigestingReader(source, c.digest, &validationFailed)
+		reader, err := newDigestingReader(source, c.digest)
 		require.NoError(t, err, c.digest)
 		dest := bytes.Buffer{}
 		n, err := io.Copy(&dest, reader)
 		assert.NoError(t, err, c.digest)
 		assert.Equal(t, int64(len(c.input)), n, c.digest)
 		assert.Equal(t, c.input, dest.Bytes(), c.digest)
-		assert.False(t, validationFailed, c.digest)
+		assert.False(t, reader.validationFailed, c.digest)
 	}
 	// Modified input
 	for _, c := range cases {
 		source := bytes.NewReader(bytes.Join([][]byte{c.input, []byte("x")}, nil))
-		validationFailed := false
-		reader, err := newDigestingReader(source, c.digest, &validationFailed)
+		reader, err := newDigestingReader(source, c.digest)
 		require.NoError(t, err, c.digest)
 		dest := bytes.Buffer{}
 		_, err = io.Copy(&dest, reader)
 		assert.Error(t, err, c.digest)
-		assert.True(t, validationFailed)
+		assert.True(t, reader.validationFailed)
 	}
 }
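The test cases above feed digest strings of the form "algorithm:hex" to newDigestingReader, which splits on ":" and hex-decodes the second field. As a hedged aside (not code from this PR), a valid sha256 digest string for arbitrary test input could be built as below; exampleDigestString is a hypothetical name used only for illustration.

// Hypothetical helper; not part of the PR.
package copy

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

// exampleDigestString builds an "algorithm:hex" digest string of the shape
// newDigestingReader expects for its expectedDigestString parameter.
func exampleDigestString(input []byte) string {
	sum := sha256.Sum256(input)
	return fmt.Sprintf("sha256:%s", hex.EncodeToString(sum[:]))
}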