diff --git a/go.sum b/go.sum index 808e807a7e0..eecbee18a70 100644 --- a/go.sum +++ b/go.sum @@ -6,7 +6,6 @@ contrib.go.opencensus.io/exporter/prometheus v0.1.0 h1:SByaIoWwNgMdPSgl5sMqM2KDE contrib.go.opencensus.io/exporter/prometheus v0.1.0/go.mod h1:cGFniUXGZlKRjzOyuZJ6mgB+PgBcCIa79kEKR8YCW+A= github.com/99designs/gqlgen v0.13.1-0.20200928230741-819e751c2416 h1:8qbuDq7x3pPeEUmfa2wPKuN2G5Q/+znZWAJWZJXTjDA= github.com/99designs/gqlgen v0.13.1-0.20200928230741-819e751c2416/go.mod h1:NV130r6f4tpRWuAI+zsrSdooO/eWUv+Gyyoi3rEfXIk= -github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9 h1:HD8gA2tkByhMAwYaFAX9w2l7vxvBQ5NMoxDrkhqhtn4= github.com/AndreasBriese/bbloom v0.0.0-20190306092124-e2d15f34fcf9/go.mod h1:bOvUY6CB00SOBii9/FifXqc0awNKxLFCL/+pkDPuyl8= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -73,9 +72,7 @@ github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8Nz github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20180511133405-39ca1b05acc7/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA= -github.com/cpuguy83/go-md2man v1.0.10 h1:BSKMNlYxDvnunlTymqtgONjNnaRV1sTpcovwwjF22jk= github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d h1:U+s90UTSYgptZMwQh2aRr3LuazLJIa+Pg3Kc1ylSYVY= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/d4l3k/messagediff v1.2.1 h1:ZcAIMYsUg0EAp9X+tt8/enBE/Q8Yd5kzPynLyKptt9U= github.com/d4l3k/messagediff v1.2.1/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= @@ -90,7 +87,6 @@ github.com/dgraph-io/dgo/v200 v200.0.0-20200805103119-a3544c464dd6 h1:toHzMCdCUg github.com/dgraph-io/dgo/v200 v200.0.0-20200805103119-a3544c464dd6/go.mod h1:rHa+h3kI4M8ASOirxyIyNeXBfHFgeskVUum2OrDMN3U= github.com/dgraph-io/graphql-transport-ws v0.0.0-20200916064635-48589439591b h1:PDEhlwHpkEQ5WBfOOKZCNZTXFDGyCEWTYDhxGQbyIpk= github.com/dgraph-io/graphql-transport-ws v0.0.0-20200916064635-48589439591b/go.mod h1:7z3c/5w0sMYYZF5bHsrh8IH4fKwG5O5Y70cPH1ZLLRQ= -github.com/dgraph-io/ristretto v0.0.4-0.20201013194302-6d6fac64beae h1:yh5085twGpsgfuu56DXKOM3SKyZKQPskJIoMNb3jzos= github.com/dgraph-io/ristretto v0.0.4-0.20201013194302-6d6fac64beae/go.mod h1:bDI4cDaalvYSji3vBVDKrn9ouDZrwN974u8ZO/AhYXs= github.com/dgraph-io/ristretto v0.0.4-0.20201013234705-28aba7a42dfa h1:gAJJ+Ln7gBkeLEoKIuFL1p+YjbihmgQYWpOIu8XE6JM= github.com/dgraph-io/ristretto v0.0.4-0.20201013234705-28aba7a42dfa/go.mod h1:bDI4cDaalvYSji3vBVDKrn9ouDZrwN974u8ZO/AhYXs= diff --git a/raftwal/encryption_test.go b/raftwal/encryption_test.go new file mode 100644 index 00000000000..980468ba981 --- /dev/null +++ b/raftwal/encryption_test.go @@ -0,0 +1,67 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package raftwal + +import ( + "io/ioutil" + "math/rand" + "os" + "testing" + + "github.com/dgraph-io/dgraph/x" + "github.com/stretchr/testify/require" + "go.etcd.io/etcd/raft/raftpb" +) + +func TestEntryReadWrite(t *testing.T) { + x.WorkerConfig.EncryptionKey = []byte("badger16byteskey") + dir, err := ioutil.TempDir("", "raftwal") + require.NoError(t, err) + el, err := openWal(dir) + require.NoError(t, err) + defer os.RemoveAll(dir) + + // generate some random data + data := make([]byte, rand.Intn(1000)) + rand.Read(data) + + require.NoError(t, el.AddEntries([]raftpb.Entry{{Index: 1, Term: 1, Data: data}})) + entries := el.allEntries(0, 100, 10000) + require.Equal(t, 1, len(entries)) + require.Equal(t, uint64(1), entries[0].Index) + require.Equal(t, uint64(1), entries[0].Term) + require.Equal(t, data, entries[0].Data) + + // Open the wal file again. + el2, err := openWal(dir) + require.NoError(t, err) + entries = el2.allEntries(0, 100, 10000) + require.Equal(t, 1, len(entries)) + require.Equal(t, uint64(1), entries[0].Index) + require.Equal(t, uint64(1), entries[0].Term) + require.Equal(t, data, entries[0].Data) + + // Opening it with a wrong key fails. + x.WorkerConfig.EncryptionKey = []byte("other16byteskeys") + _, err = openWal(dir) + require.EqualError(t, err, "Encryption key mismatch") + + // Opening it without encryption key fails. + x.WorkerConfig.EncryptionKey = nil + _, err = openWal(dir) + require.EqualError(t, err, "Logfile is encrypted but encryption key is nil") +} diff --git a/raftwal/log.go b/raftwal/log.go index d43e69892fd..3189a3c8962 100644 --- a/raftwal/log.go +++ b/raftwal/log.go @@ -17,6 +17,8 @@ package raftwal import ( + "crypto/aes" + cryptorand "crypto/rand" "encoding/binary" "fmt" "os" @@ -24,7 +26,11 @@ import ( "sort" "strconv" "strings" + "time" + "github.com/dgraph-io/badger/v2" + "github.com/dgraph-io/badger/v2/pb" + "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" "github.com/golang/glog" @@ -43,6 +49,9 @@ const ( maxNumEntries = 30000 // logFileOffset is offset in the log file where data is stored. logFileOffset = 1 << 20 // 1MB + // encOffset is offset in the log file where keyID (first 8 bytes) + // and baseIV (remaining 8 bytes) are stored. + encOffset = logFileOffset - 16 // 1MB - 16B // logFileSize is the initial size of the log file. logFileSize = 16 << 30 // entrySize is the size in bytes of a single entry. @@ -75,6 +84,10 @@ func marshalEntry(b []byte, term, index, do, typ uint64) { type logFile struct { *z.MmapFile fid int64 + + registry *badger.KeyRegistry + dataKey *pb.DataKey + baseIV []byte } func logFname(dir string, id int64) string { @@ -86,19 +99,55 @@ func logFname(dir string, id int64) string { func openLogFile(dir string, fid int64) (*logFile, error) { glog.V(3).Infof("opening log file: %d\n", fid) fpath := logFname(dir, fid) + lf := &logFile{ + fid: fid, + } + var err error + encKey := x.WorkerConfig.EncryptionKey + // Initialize the registry for logFile if encryption in enabled. + // NOTE: If encryption is enabled then there is no going back because if we disable it + // later then the older log files which were previously encrypted can't be opened. + if len(encKey) > 0 { + krOpt := badger.KeyRegistryOptions{ + ReadOnly: false, + Dir: dir, + EncryptionKey: encKey, + EncryptionKeyRotationDuration: 10 * 24 * time.Hour, + InMemory: false, + } + // This won't open Badger. It would only use its key registry. + if lf.registry, err = badger.OpenKeyRegistry(krOpt); err != nil { + return nil, err + } + } // Open the file in read-write mode and create it if it doesn't exist yet. - mf, err := z.OpenMmapFile(fpath, os.O_RDWR|os.O_CREATE, logFileSize) + lf.MmapFile, err = z.OpenMmapFile(fpath, os.O_RDWR|os.O_CREATE, logFileSize) if err == z.NewFile { glog.V(3).Infof("New file: %d\n", fid) - z.ZeroOut(mf.Data, 0, logFileOffset) - } else { + z.ZeroOut(lf.Data, 0, logFileOffset) + if err = lf.bootstrap(); err != nil { + return nil, err + } + } else if err != nil { x.Check(err) - } + } else { + buf := lf.Data[encOffset : encOffset+16] + keyID := binary.BigEndian.Uint64(buf[:8]) - lf := &logFile{ - MmapFile: mf, - fid: fid, + // If keyID is non-zero, then the opened file is encrypted. + if keyID != 0 { + // Logfile is encrypted but encryption key is not provided. + if encKey == nil { + return nil, errors.New("Logfile is encrypted but encryption key is nil") + } + // retrieve datakey from the keyID of the logfile. + if lf.dataKey, err = lf.registry.DataKey(keyID); err != nil { + return nil, err + } + lf.baseIV = buf[8:] + y.AssertTrue(len(lf.baseIV) == 8) + } } return lf, nil } @@ -129,6 +178,16 @@ func (lf *logFile) GetRaftEntry(idx int) raftpb.Entry { re.Data = append(re.Data, data...) } } + // Decrypt the data if encryption is enabled. + if lf.dataKey != nil && len(re.Data) > 0 { + // No need to worry about mmap. Because, XORBlock allocates a byte array to do the + // XOR. So, the given slice is not being mutated. + // NOTE: We can potentially use allocator for this allocation. + decoded, err := y.XORBlockAllocate( + re.Data, lf.dataKey.Data, lf.generateIV(entry.DataOffset())) + x.Check(err) + re.Data = decoded + } return re } @@ -242,3 +301,50 @@ func getLogFiles(dir string) ([]*logFile, error) { }) return files, nil } + +// KeyID returns datakey's ID. +func (lf *logFile) keyID() uint64 { + if lf.dataKey == nil { + // If there is no datakey, then we'll return 0. Which means no encryption. + return 0 + } + return lf.dataKey.KeyId +} + +// generateIV will generate IV by appending given offset with the base IV. +func (lf *logFile) generateIV(offset uint64) []byte { + iv := make([]byte, aes.BlockSize) + // IV is of 16 bytes, in which first 8 bytes are obtained from baseIV + // and the remaining 8 bytes is obtained from the offset. + y.AssertTrue(8 == copy(iv[:8], lf.baseIV)) + binary.BigEndian.PutUint64(iv[8:], offset) + return iv +} + +// bootstrap will initialize the log file with key id and baseIV. +// The below figure shows the layout of log file. +// +----------------+------------------+------------------+ +// | keyID(8 bytes) | baseIV(8 bytes) | entry... | +// +----------------+------------------+------------------+ +func (lf *logFile) bootstrap() error { + // registry is nil if we don't have encryption enabled. + if lf.registry == nil { + return nil + } + var err error + // generate data key for the log file. + if lf.dataKey, err = lf.registry.LatestDataKey(); err != nil { + return y.Wrapf(err, "Error while retrieving datakey in logFile.bootstrap") + } + buf := lf.Data[encOffset : encOffset+16] + // Put keyID in the first 8 bytes. + binary.BigEndian.PutUint64(buf[:8], lf.keyID()) + + // fill in random bytes in the last 8 bytes of buf. + if _, err := cryptorand.Read(buf[8:]); err != nil { + return y.Wrapf(err, "Error while creating base IV, while creating logfile") + } + // Initialize base IV. + lf.baseIV = buf[8:] + return nil +} diff --git a/raftwal/storage_test.go b/raftwal/storage_test.go index d19bff3f842..abd22c054e8 100644 --- a/raftwal/storage_test.go +++ b/raftwal/storage_test.go @@ -40,6 +40,7 @@ import ( "reflect" "testing" + "github.com/dgraph-io/dgraph/x" "github.com/stretchr/testify/require" "go.etcd.io/etcd/raft" "go.etcd.io/etcd/raft/raftpb" @@ -347,155 +348,163 @@ func TestEntryFile(t *testing.T) { require.Equal(t, "abc", string(entries[0].Data)) } -func TestStorageBig(t *testing.T) { - dir, err := ioutil.TempDir("", "raftwal") - require.NoError(t, err) - ds := Init(dir) - t.Logf("Creating dir: %s\n", dir) - // defer os.RemoveAll(dir) +func TestStorageOnlySnap(t *testing.T) { + test := func(t *testing.T, key []byte) { + x.WorkerConfig.EncryptionKey = key + dir, err := ioutil.TempDir("", "raftwal") + require.NoError(t, err) + ds := Init(dir) + t.Logf("Creating dir: %s\n", dir) - ent := raftpb.Entry{ - Term: 1, - Type: raftpb.EntryNormal, - } + buf := make([]byte, 128) + rand.Read(buf) + N := uint64(1000) - addEntries := func(start, end uint64) { - t.Logf("adding entries: %d -> %d\n", start, end) - for idx := start; idx <= end; idx++ { - ent.Index = idx - require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) - li, err := ds.LastIndex() - require.NoError(t, err) - require.Equal(t, idx, li) - } - } + snap := &raftpb.Snapshot{} + snap.Metadata.Index = N + snap.Metadata.ConfState = raftpb.ConfState{} + snap.Data = buf - N := uint64(100000) - addEntries(1, N) - num := ds.NumEntries() - require.Equal(t, int(N), num) + require.NoError(t, ds.meta.StoreSnapshot(snap)) - check := func(start, end uint64) { - ents, err := ds.Entries(start, end, math.MaxInt64) + out, err := ds.Snapshot() require.NoError(t, err) - require.Equal(t, int(end-start), len(ents)) - for i, e := range ents { - require.Equal(t, start+uint64(i), e.Index) - } - } - _, err = ds.Entries(0, 1, math.MaxInt64) - require.Equal(t, raft.ErrCompacted, err) - - check(3, N) - check(10000, 20000) - check(20000, 33000) - check(33000, 45000) - check(45000, N) - - // Around file boundaries. - check(1, N) - check(30000, N) - check(30001, N) - check(60000, N) - check(60001, N) - check(60000, 90000) - check(N, N+1) - - _, err = ds.Entries(N+1, N+10, math.MaxInt64) - require.Error(t, raft.ErrUnavailable, err) - - // Jump back a few files. - addEntries(N/3, N) - check(3, N) - check(10000, 20000) - check(20000, 33000) - check(33000, 45000) - check(45000, N) - check(N, N+1) - - buf := make([]byte, 128) - rand.Read(buf) - - cs := &raftpb.ConfState{} - require.NoError(t, ds.CreateSnapshot(N-100, cs, buf)) - fi, err := ds.FirstIndex() - require.NoError(t, err) - require.Equal(t, N-100+1, fi) + require.Equal(t, N, out.Metadata.Index) - snap, err := ds.Snapshot() - require.NoError(t, err) - require.Equal(t, N-100, snap.Metadata.Index) - require.Equal(t, buf, snap.Data) - - require.Equal(t, 0, len(ds.wal.files)) - - files, err := getLogFiles(dir) - require.NoError(t, err) - require.Equal(t, 1, len(files)) - - // Jumping back. - ent.Index = N - 50 - require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) + fi, err := ds.FirstIndex() + require.NoError(t, err) + require.Equal(t, N+1, fi) - start := N - 100 + 1 - ents := ds.wal.allEntries(start, math.MaxInt64, math.MaxInt64) - require.Equal(t, 50, len(ents)) - for idx, ent := range ents { - require.Equal(t, int(start)+idx, int(ent.Index)) + li, err := ds.LastIndex() + require.NoError(t, err) + require.Equal(t, N, li) } + t.Run("without encryption", func(t *testing.T) { test(t, nil) }) + t.Run("with encryption", func(t *testing.T) { test(t, []byte("badger16byteskey")) }) +} - ent.Index = N - require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) - ents = ds.wal.allEntries(start, math.MaxInt64, math.MaxInt64) - require.Equal(t, 51, len(ents)) - for idx, ent := range ents { - if idx == 50 { - require.Equal(t, N, ent.Index) - } else { - require.Equal(t, int(start)+idx, int(ent.Index)) +func TestStorageBig(t *testing.T) { + test := func(t *testing.T, key []byte) { + x.WorkerConfig.EncryptionKey = key + dir, err := ioutil.TempDir("", "raftwal") + require.NoError(t, err) + ds := Init(dir) + defer os.RemoveAll(dir) + + ent := raftpb.Entry{ + Term: 1, + Type: raftpb.EntryNormal, } - } - require.NoError(t, ds.Sync()) - - ks := Init(dir) - ents = ks.wal.allEntries(start, math.MaxInt64, math.MaxInt64) - require.Equal(t, 51, len(ents)) - for idx, ent := range ents { - if idx == 50 { - require.Equal(t, N, ent.Index) - } else { - require.Equal(t, int(start)+idx, int(ent.Index)) + + addEntries := func(start, end uint64) { + t.Logf("adding entries: %d -> %d\n", start, end) + for idx := start; idx <= end; idx++ { + ent.Index = idx + require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) + li, err := ds.LastIndex() + require.NoError(t, err) + require.Equal(t, idx, li) + } } - } -} + N := uint64(100000) + addEntries(1, N) + num := ds.NumEntries() + require.Equal(t, int(N), num) -func TestStorageOnlySnap(t *testing.T) { - dir, err := ioutil.TempDir("", "raftwal") - require.NoError(t, err) - ds := Init(dir) - t.Logf("Creating dir: %s\n", dir) + check := func(start, end uint64) { + ents, err := ds.Entries(start, end, math.MaxInt64) + require.NoError(t, err) + require.Equal(t, int(end-start), len(ents)) + for i, e := range ents { + require.Equal(t, start+uint64(i), e.Index) + } + } + _, err = ds.Entries(0, 1, math.MaxInt64) + require.Equal(t, raft.ErrCompacted, err) + + check(3, N) + check(10000, 20000) + check(20000, 33000) + check(33000, 45000) + check(45000, N) + + // Around file boundaries. + check(1, N) + check(30000, N) + check(30001, N) + check(60000, N) + check(60001, N) + check(60000, 90000) + check(N, N+1) + + _, err = ds.Entries(N+1, N+10, math.MaxInt64) + require.Error(t, raft.ErrUnavailable, err) + + // Jump back a few files. + addEntries(N/3, N) + check(3, N) + check(10000, 20000) + check(20000, 33000) + check(33000, 45000) + check(45000, N) + check(N, N+1) + + buf := make([]byte, 128) + rand.Read(buf) + + cs := &raftpb.ConfState{} + require.NoError(t, ds.CreateSnapshot(N-100, cs, buf)) + fi, err := ds.FirstIndex() + require.NoError(t, err) + require.Equal(t, N-100+1, fi) - buf := make([]byte, 128) - rand.Read(buf) - N := uint64(1000) + snap, err := ds.Snapshot() + require.NoError(t, err) + require.Equal(t, N-100, snap.Metadata.Index) + require.Equal(t, buf, snap.Data) - snap := &raftpb.Snapshot{} - snap.Metadata.Index = N - snap.Metadata.ConfState = raftpb.ConfState{} - snap.Data = buf + require.Equal(t, 0, len(ds.wal.files)) - require.NoError(t, ds.meta.StoreSnapshot(snap)) + files, err := getLogFiles(dir) + require.NoError(t, err) + require.Equal(t, 1, len(files)) - out, err := ds.Snapshot() - require.NoError(t, err) - require.Equal(t, N, out.Metadata.Index) + // Jumping back. + ent.Index = N - 50 + require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) - fi, err := ds.FirstIndex() - require.NoError(t, err) - require.Equal(t, N+1, fi) + start := N - 100 + 1 + ents := ds.wal.allEntries(start, math.MaxInt64, math.MaxInt64) + require.Equal(t, 50, len(ents)) + for idx, ent := range ents { + require.Equal(t, int(start)+idx, int(ent.Index)) + } - li, err := ds.LastIndex() - require.NoError(t, err) - require.Equal(t, N, li) + ent.Index = N + require.NoError(t, ds.wal.AddEntries([]raftpb.Entry{ent})) + ents = ds.wal.allEntries(start, math.MaxInt64, math.MaxInt64) + require.Equal(t, 51, len(ents)) + for idx, ent := range ents { + if idx == 50 { + require.Equal(t, N, ent.Index) + } else { + require.Equal(t, int(start)+idx, int(ent.Index)) + } + } + require.NoError(t, ds.Sync()) + + ks := Init(dir) + ents = ks.wal.allEntries(start, math.MaxInt64, math.MaxInt64) + require.Equal(t, 51, len(ents)) + for idx, ent := range ents { + if idx == 50 { + require.Equal(t, N, ent.Index) + } else { + require.Equal(t, int(start)+idx, int(ent.Index)) + } + } + } + t.Run("without encryption", func(t *testing.T) { test(t, nil) }) + t.Run("with encryption", func(t *testing.T) { test(t, []byte("badger16byteskey")) }) } diff --git a/raftwal/wal.go b/raftwal/wal.go index 8ccf50681e6..135aef96a94 100644 --- a/raftwal/wal.go +++ b/raftwal/wal.go @@ -17,8 +17,10 @@ package raftwal import ( + "bytes" "sort" + "github.com/dgraph-io/badger/v2/y" "github.com/dgraph-io/dgraph/x" "github.com/dgraph-io/ristretto/z" "github.com/golang/glog" @@ -166,8 +168,18 @@ func (l *wal) AddEntries(entries []raftpb.Entry) error { } l.nextEntryIdx, offset = 0, logFileOffset } + // If encryption is enabled then encrypt the data. + if l.current.dataKey != nil { + var ebuf bytes.Buffer + curr := l.current + if err := y.XORBlockStream( + &ebuf, re.Data, curr.dataKey.Data, curr.generateIV(uint64(offset))); err != nil { + return err + } + re.Data = ebuf.Bytes() + } - // Write re.Data to a new slice at the end of the file. + // Allocate slice for the data and copy bytes. destBuf, next := l.current.AllocateSlice(len(re.Data), offset) x.AssertTrue(copy(destBuf, re.Data) == len(re.Data))