Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #63: invalid reads after write #64

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 33 additions & 22 deletions diskv.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
)

const (
Expand All @@ -40,6 +42,11 @@ var (
errImportDirectory = errors.New("can't import a directory")
)

func init() {
	// Seed the global math/rand source from the wall clock so that the
	// values returned by rand.Int() differ from one process run to the next.
	seed := time.Now().UnixNano()
	rand.Seed(seed)
}
floren marked this conversation as resolved.
Show resolved Hide resolved

// TransformFunction transforms a key into a slice of strings, with each
// element in the slice representing a directory in the file path where the
// key's entry will eventually be stored.
Expand Down Expand Up @@ -76,6 +83,7 @@ type Options struct {
CacheSizeMax uint64 // bytes
PathPerm os.FileMode
FilePerm os.FileMode
// Note: TempDir is deprecated; all writes are now atomic.
// If TempDir is set, it will enable filesystem atomic writes by
// writing temporary files to that location before being moved
// to BasePath.
Expand Down Expand Up @@ -196,28 +204,18 @@ func (d *Diskv) WriteStream(key string, r io.Reader, sync bool) error {
return d.writeStreamWithLock(pathKey, r, sync)
}

// createKeyFileWithLock either creates the key file directly, or
// creates a temporary file in TempDir if it is set.
// createKeyFileWithLock creates the key file with a random extension. This
// will be automatically renamed by writeStreamWithLock once the write has been
// completed. This solves issue #63, where calling ReadStream, then updating the
// key before reading completes, leads to the reader getting invalid data.
func (d *Diskv) createKeyFileWithLock(pathKey *PathKey) (*os.File, error) {
if d.TempDir != "" {
if err := os.MkdirAll(d.TempDir, d.PathPerm); err != nil {
return nil, fmt.Errorf("temp mkdir: %s", err)
}
f, err := ioutil.TempFile(d.TempDir, "")
if err != nil {
return nil, fmt.Errorf("temp file: %s", err)
}

if err := os.Chmod(f.Name(), d.FilePerm); err != nil {
f.Close() // error deliberately ignored
os.Remove(f.Name()) // error deliberately ignored
return nil, fmt.Errorf("chmod: %s", err)
}
return f, nil
}

mode := os.O_WRONLY | os.O_CREATE | os.O_TRUNC // overwrite if exists
f, err := os.OpenFile(d.completeFilename(pathKey), mode, d.FilePerm)
// Figure out the path and append a random number
path := fmt.Sprintf("%s.%d", d.completeFilename(pathKey), rand.Int())
// It's incredibly unlikely that the destination file will exist, but
// we want to be absolutely sure: O_EXCL means we'll get an error if the
// file already exists.
mode := os.O_WRONLY | os.O_CREATE | os.O_EXCL
f, err := os.OpenFile(path, mode, d.FilePerm)
if err != nil {
return nil, fmt.Errorf("open file: %s", err)
}
floren marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -226,14 +224,27 @@ func (d *Diskv) createKeyFileWithLock(pathKey *PathKey) (*os.File, error) {

// writeStreamWithLock does no input validation checking.
func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) error {
// fullPath is the on-disk location of the key
fullPath := d.completeFilename(pathKey)

if err := d.ensurePathWithLock(pathKey); err != nil {
return fmt.Errorf("ensure path: %s", err)
}

// createKeyFileWithLock gives us a temporary file we can write to.
// We'll move it when we're all done.
f, err := d.createKeyFileWithLock(pathKey)
if err != nil {
return fmt.Errorf("create key file: %s", err)
}
// In case something bad happens, we want to delete the temporary file,
// lest we leave junk in the store.
defer func() {
if r := recover(); r != nil {
os.Remove(f.Name())
panic(r)
}
}()
floren marked this conversation as resolved.
Show resolved Hide resolved

wc := io.WriteCloser(&nopWriteCloser{f})
if d.Compression != nil {
Expand Down Expand Up @@ -269,7 +280,7 @@ func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) er
return fmt.Errorf("file close: %s", err)
}

fullPath := d.completeFilename(pathKey)
// Move the temporary file to the final location.
if f.Name() != fullPath {
floren marked this conversation as resolved.
Show resolved Hide resolved
if err := os.Rename(f.Name(), fullPath); err != nil {
os.Remove(f.Name()) // error deliberately ignored
Expand Down
52 changes: 52 additions & 0 deletions issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,55 @@ func TestIssue40(t *testing.T) {
// is no room in the cache for this entry and it panics.
d.Read(k2)
}

// Test issue #63, where a reader obtained from ReadStream will start
// to return invalid data if WriteStream is called before you finish
// reading.
func TestIssue63(t *testing.T) {
	basePath := "test-data"

	// Simplest transform function: put all the data files into the base dir.
	flatTransform := func(s string) []string { return []string{} }

	// Initialize a new diskv store, rooted at basePath, with no cache.
	d := New(Options{
		BasePath:     basePath,
		Transform:    flatTransform,
		CacheSizeMax: 0,
	})

	defer d.EraseAll()

	// Write a big entry. Fail fast if setup goes wrong, so that a broken
	// write is not misreported as an invalid read further down.
	k1 := "key1"
	d1 := make([]byte, 1024*1024)
	if _, err := rand.Read(d1); err != nil {
		t.Fatalf("generating first value: %s", err)
	}
	if err := d.Write(k1, d1); err != nil {
		t.Fatalf("writing first value: %s", err)
	}

	// Open a reader. We set the direct flag to be sure we're going straight to disk.
	s1, err := d.ReadStream(k1, true)
	if err != nil {
		t.Fatal(err)
	}
	// Registered after EraseAll, so the stream is released before the
	// store's files are removed.
	defer s1.Close()

	// Now generate a second big entry and put it in the *same* key
	// while the reader above is still open.
	d2 := make([]byte, 1024*1024)
	if _, err := rand.Read(d2); err != nil {
		t.Fatalf("generating second value: %s", err)
	}
	if err := d.Write(k1, d2); err != nil {
		t.Fatalf("writing second value: %s", err)
	}

	// Now read from that stream we opened; it must still yield the
	// original value, not a mix of old and new data.
	out, err := ioutil.ReadAll(s1)
	if err != nil {
		t.Fatal(err)
	}
	if len(out) != len(d1) {
		t.Fatalf("Invalid read: got %v bytes expected %v\n", len(out), len(d1))
	}
	for i := range out {
		if out[i] != d1[i] {
			t.Fatalf("Output differs from expected at byte %v", i)
		}
	}
}
floren marked this conversation as resolved.
Show resolved Hide resolved