Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue #63: invalid reads after write #64

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions basic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package diskv
import (
"bytes"
"errors"
"io"
"math/rand"
"regexp"
"strings"
Expand Down Expand Up @@ -428,3 +429,60 @@ func TestHybridStore(t *testing.T) {
}

}

// Make sure that temporary files used for atomic writes never
// show up in the key listing
func TestIgnoreAtomicTempFiles(t *testing.T) {
var (
basePath = "test-data"
)
// Simplest transform function: put all the data files into the base dir.
flatTransform := func(s string) []string { return []string{} }

// Initialize a new diskv store, rooted at "my-data-dir",
// with no cache.
d := New(Options{
BasePath: basePath,
Transform: flatTransform,
CacheSizeMax: 0,
})

// Write something in so everything is set up
d.Write("foo", []byte("bar"))

// Start to write an entry using a stream, but do not
// put anything into it yet!
key := "key1"
data := make([]byte, 1024*1024)
rand.Read(data)

// Get a pipe
rdr, wtr := io.Pipe()

// Start the write
go d.WriteStream(key, rdr, true)

// Now list keys: there should be 1 key.
keys := d.Keys(nil)
var count int
for _ = range keys {
count++
}
if count != 1 {
t.Fatalf("Expected 1 key, got %d", count)
}

// Now complete the write
wtr.Write(data)
wtr.Close()

// And make sure we see exactly two keys
keys = d.Keys(nil)
for _ = range keys {
count++
}
if count != 2 {
t.Fatalf("Expected 2 keys, got %d", count)
}
d.EraseAll()
}
80 changes: 39 additions & 41 deletions diskv.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@ import (
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
)

const (
defaultBasePath = "diskv"
defaultFilePerm os.FileMode = 0666
defaultPathPerm os.FileMode = 0777

DefaultAtomicPrefix = ".diskv_atomic_temp"
)

// PathKey represents a string key that has been transformed to
Expand Down Expand Up @@ -76,12 +80,12 @@ type Options struct {
CacheSizeMax uint64 // bytes
PathPerm os.FileMode
FilePerm os.FileMode
// If TempDir is set, it will enable filesystem atomic writes by
// writing temporary files to that location before being moved
// to BasePath.
// Note that TempDir MUST be on the same device/partition as
// BasePath.
// Note: TempDir is deprecated, all writes are now atomic.
TempDir string
// AtomicPrefix sets the name of a directory which will be created
// within BasePath to store temporary files for atomic writes.
// It defaults to DefaultAtomicPrefix; you probably don't need to change it.
AtomicPrefix string

Index Index
IndexLess LessFunction
Expand All @@ -96,6 +100,7 @@ type Diskv struct {
mu sync.RWMutex
cache map[string][]byte
cacheSize uint64
rnd *rand.Rand
}

// New returns an initialized Diskv structure, ready to use.
Expand All @@ -105,6 +110,9 @@ func New(o Options) *Diskv {
if o.BasePath == "" {
o.BasePath = defaultBasePath
}
if o.AtomicPrefix == "" {
o.AtomicPrefix = DefaultAtomicPrefix
}

if o.AdvancedTransform == nil {
if o.Transform == nil {
Expand Down Expand Up @@ -132,12 +140,18 @@ func New(o Options) *Diskv {
Options: o,
cache: map[string][]byte{},
cacheSize: 0,
rnd: rand.New(rand.NewSource(time.Now().UnixNano())),
}

if d.Index != nil && d.IndexLess != nil {
d.Index.Initialize(d.IndexLess, d.Keys(nil))
}

// Just in case there were any failures during writes previously, we
// remove the atomic write directory (and any temp files within it).
// The directory will be created the first time we do a Write.
os.RemoveAll(d.atomicTempPath())

return d
}

Expand Down Expand Up @@ -196,41 +210,19 @@ func (d *Diskv) WriteStream(key string, r io.Reader, sync bool) error {
return d.writeStreamWithLock(pathKey, r, sync)
}

// createKeyFileWithLock either creates the key file directly, or
// creates a temporary file in TempDir if it is set.
func (d *Diskv) createKeyFileWithLock(pathKey *PathKey) (*os.File, error) {
if d.TempDir != "" {
if err := os.MkdirAll(d.TempDir, d.PathPerm); err != nil {
return nil, fmt.Errorf("temp mkdir: %s", err)
}
f, err := ioutil.TempFile(d.TempDir, "")
if err != nil {
return nil, fmt.Errorf("temp file: %s", err)
}

if err := os.Chmod(f.Name(), d.FilePerm); err != nil {
f.Close() // error deliberately ignored
os.Remove(f.Name()) // error deliberately ignored
return nil, fmt.Errorf("chmod: %s", err)
}
return f, nil
}

mode := os.O_WRONLY | os.O_CREATE | os.O_TRUNC // overwrite if exists
f, err := os.OpenFile(d.completeFilename(pathKey), mode, d.FilePerm)
if err != nil {
return nil, fmt.Errorf("open file: %s", err)
}
return f, nil
}

// writeStream does no input validation checking.
func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) error {
// fullPath is the on-disk location of the key
fullPath := d.completeFilename(pathKey)

if err := d.ensurePathWithLock(pathKey); err != nil {
return fmt.Errorf("ensure path: %s", err)
}

f, err := d.createKeyFileWithLock(pathKey)
// Get a temporary file we can write to.
// We'll move it when we're all done.
d.ensureAtomicTempDir()
f, err := ioutil.TempFile(d.atomicTempPath(), pathKey.FileName)
if err != nil {
return fmt.Errorf("create key file: %s", err)
}
Expand Down Expand Up @@ -269,12 +261,10 @@ func (d *Diskv) writeStreamWithLock(pathKey *PathKey, r io.Reader, sync bool) er
return fmt.Errorf("file close: %s", err)
}

fullPath := d.completeFilename(pathKey)
if f.Name() != fullPath {
if err := os.Rename(f.Name(), fullPath); err != nil {
os.Remove(f.Name()) // error deliberately ignored
return fmt.Errorf("rename: %s", err)
}
// Move the temporary file to the final location.
if err := os.Rename(f.Name(), fullPath); err != nil {
os.Remove(f.Name()) // error deliberately ignored
return fmt.Errorf("rename: %s", err)
}

if d.Index != nil {
Expand Down Expand Up @@ -596,7 +586,7 @@ func (d *Diskv) walker(c chan<- string, prefix string, cancel <-chan struct{}) f

key := d.InverseTransform(pathKey)

if info.IsDir() || !strings.HasPrefix(key, prefix) {
if info.IsDir() || !strings.HasPrefix(key, prefix) || strings.HasPrefix(dir, d.AtomicPrefix) {
return nil // "pass"
}

Expand Down Expand Up @@ -627,6 +617,14 @@ func (d *Diskv) completeFilename(pathKey *PathKey) string {
return filepath.Join(d.pathFor(pathKey), pathKey.FileName)
}

func (d *Diskv) ensureAtomicTempDir() error {
return os.MkdirAll(d.atomicTempPath(), d.PathPerm)
}

func (d *Diskv) atomicTempPath() string {
return filepath.Join(d.BasePath, d.AtomicPrefix)
}

// cacheWithLock attempts to cache the given key-value pair in the store's
// cache. It can fail if the value is larger than the cache's maximum size.
func (d *Diskv) cacheWithLock(key string, val []byte) error {
Expand Down
52 changes: 52 additions & 0 deletions issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,55 @@ func TestIssue40(t *testing.T) {
// is no room in the cache for this entry and it panics.
d.Read(k2)
}

// Test issue #63, where a reader obtained from ReadStream will start
// to return invalid data if WriteStream is called before you finish
// reading.
func TestIssue63(t *testing.T) {
var (
basePath = "test-data"
)
// Simplest transform function: put all the data files into the base dir.
flatTransform := func(s string) []string { return []string{} }

// Initialize a new diskv store, rooted at "my-data-dir",
// with no cache.
d := New(Options{
BasePath: basePath,
Transform: flatTransform,
CacheSizeMax: 0,
})

defer d.EraseAll()

// Write a big entry
k1 := "key1"
d1 := make([]byte, 1024*1024)
rand.Read(d1)
d.Write(k1, d1)

// Open a reader. We set the direct flag to be sure we're going straight to disk.
s1, err := d.ReadStream(k1, true)
if err != nil {
t.Fatal(err)
}

// Now generate a second big entry and put it in the *same* key
d2 := make([]byte, 1024*1024)
rand.Read(d2)
d.Write(k1, d2)

// Now read from that stream we opened
out, err := ioutil.ReadAll(s1)
if err != nil {
t.Fatal(err)
}
if len(out) != len(d1) {
t.Fatalf("Invalid read: got %v bytes expected %v\n", len(out), len(d1))
}
for i := range out {
if out[i] != d1[i] {
t.Fatalf("Output differs from expected at byte %v", i)
}
}
}