Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 70 additions & 25 deletions lib/backend/firestore/firestorebk.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,13 @@ type Backend struct {
}

type record struct {
Key backend.Key `firestore:"key,omitempty"`
Timestamp int64 `firestore:"timestamp,omitempty"`
Expires int64 `firestore:"expires,omitempty"`
ID int64 `firestore:"id,omitempty"`
Value []byte `firestore:"value,omitempty"`
RevisionV2 string `firestore:"revision,omitempty"`
RevisionV1 string `firestore:"-"`
Key []byte `firestore:"key,omitempty"`
Timestamp int64 `firestore:"timestamp,omitempty"`
Expires int64 `firestore:"expires,omitempty"`
ID int64 `firestore:"id,omitempty"`
Value []byte `firestore:"value,omitempty"`
RevisionV2 string `firestore:"revision,omitempty"`
RevisionV1 string `firestore:"-"`
}

func (r *record) updates() []firestore.Update {
Expand Down Expand Up @@ -163,9 +163,21 @@ type legacyRecord struct {
Value string `firestore:"value,omitempty"`
}

// brokenRecord is an incorrect version of record used to marshal backend.Items.
// The Key type was inadvertently changed from a []byte to backend.Key which
// causes problems reading existing data prior to the conversion.
type brokenRecord struct {
Key backend.Key `firestore:"key,omitempty"`
Timestamp int64 `firestore:"timestamp,omitempty"`
Expires int64 `firestore:"expires,omitempty"`
Value []byte `firestore:"value,omitempty"`
ID int64 `firestore:"id,omitempty"`
RevisionV2 string `firestore:"revision,omitempty"`
}

func newRecord(from backend.Item, clock clockwork.Clock) record {
r := record{
Key: from.Key,
Key: []byte(from.Key.String()),
Value: from.Value,
Timestamp: clock.Now().UTC().Unix(),
ID: id(clock.Now()),
Expand All @@ -184,23 +196,48 @@ func newRecord(from backend.Item, clock clockwork.Clock) record {
}

func newRecordFromDoc(doc *firestore.DocumentSnapshot) (*record, error) {
k, err := doc.DataAt(keyDocProperty)
if err != nil {
return nil, trace.Wrap(err)
}

var r record
if err := doc.DataTo(&r); err != nil {
// If unmarshal failed, try using the old format of records, where
// Value was a string. This document could've been written by an older
// version of our code.
var rl legacyRecord
if doc.DataTo(&rl) != nil {
switch k.(type) {
case []any:
// If the key is a slice of any, then the key was mistakenly persisted
// as a backend.Key directly.
var br brokenRecord
if doc.DataTo(&br) != nil {
return nil, ConvertGRPCError(err)
}

r = record{
Key: backend.Key(rl.Key),
Value: []byte(rl.Value),
Timestamp: rl.Timestamp,
Expires: rl.Expires,
ID: rl.ID,
Key: br.Key,
Value: br.Value,
Timestamp: br.Timestamp,
Expires: br.Expires,
RevisionV2: br.RevisionV2,
ID: br.ID,
}
default:
if err := doc.DataTo(&r); err != nil {
// If unmarshal failed, try using the old format of records, where
// Value was a string. This document could've been written by an older
// version of our code.
var rl legacyRecord
if doc.DataTo(&rl) != nil {
return nil, ConvertGRPCError(err)
}
r = record{
Key: backend.Key(rl.Key),
Value: []byte(rl.Value),
Timestamp: rl.Timestamp,
Expires: rl.Expires,
ID: rl.ID,
}
}
}

if r.RevisionV2 == "" {
r.RevisionV1 = toRevisionV1(doc.UpdateTime)
}
Expand All @@ -218,7 +255,7 @@ func (r *record) isExpired(now time.Time) bool {

func (r *record) backendItem() backend.Item {
bi := backend.Item{
Key: r.Key,
Key: backend.Key(r.Key),
Value: r.Value,
ID: r.ID,
}
Expand Down Expand Up @@ -444,23 +481,31 @@ func (b *Backend) getRangeDocs(ctx context.Context, startKey, endKey backend.Key
limit = backend.DefaultRangeLimit
}
docs, err := b.svc.Collection(b.CollectionName).
Where(keyDocProperty, ">=", startKey).
Where(keyDocProperty, "<=", endKey).
Where(keyDocProperty, ">=", []byte(startKey.String())).
Where(keyDocProperty, "<=", []byte(endKey.String())).
Limit(limit).
Documents(ctx).GetAll()
if err != nil {
return nil, trace.Wrap(err)
}
legacyDocs, err := b.svc.Collection(b.CollectionName).
Where(keyDocProperty, ">=", string(startKey)).
Where(keyDocProperty, "<=", string(endKey)).
Where(keyDocProperty, ">=", startKey.String()).
Where(keyDocProperty, "<=", endKey.String()).
Limit(limit).
Documents(ctx).GetAll()
if err != nil {
return nil, trace.Wrap(err)
}
brokenDocs, err := b.svc.Collection(b.CollectionName).
Where(keyDocProperty, ">=", startKey).
Where(keyDocProperty, "<=", endKey).
Limit(limit).
Documents(ctx).GetAll()
if err != nil {
return nil, trace.Wrap(err)
}

allDocs := append(docs, legacyDocs...)
allDocs := append(append(docs, legacyDocs...), brokenDocs...)
if len(allDocs) >= backend.DefaultRangeLimit {
b.Warnf("Range query hit backend limit. (this is a bug!) startKey=%q,limit=%d", startKey, backend.DefaultRangeLimit)
}
Expand Down
96 changes: 96 additions & 0 deletions lib/backend/firestore/firestorebk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/api/option"
"google.golang.org/genproto/googleapis/rpc/code"
Expand Down Expand Up @@ -215,6 +216,101 @@ func TestReadLegacyRecord(t *testing.T) {
require.Equal(t, item.Expires, got.Expires)
}

func TestReadBrokenRecord(t *testing.T) {
cfg := firestoreParams()
ensureTestsEnabled(t)
ensureEmulatorRunning(t, cfg)

uut := newBackend(t, cfg)

ctx := context.Background()

prefix := test.MakePrefix()

// Create a valid record with the correct key type.
item := backend.Item{
Key: prefix("valid-record"),
Value: []byte("llamas"),
}
_, err := uut.Put(ctx, item)
require.NoError(t, err)

// Create a legacy record with a string key type.
lr := legacyRecord{
Key: prefix("legacy-record").String(),
Value: "sheep",
}
_, err = uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(backend.Key(lr.Key))).Set(ctx, lr)
require.NoError(t, err)

// Create a broken record with a backend.Key key type.
brokenItem := backend.Item{
Key: prefix("broken-record"),
Value: []byte("foo"),
Expires: uut.clock.Now().Add(time.Minute).Round(time.Second).UTC(),
}

// Write using broken record format, emulating data written by an older
// version of this backend.
br := brokenRecord{
Key: brokenItem.Key,
Value: brokenItem.Value,
Expires: brokenItem.Expires.UTC().Unix(),
Timestamp: uut.clock.Now().UTC().Unix(),
}
_, err = uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(brokenItem.Key)).Set(ctx, br)
require.NoError(t, err)

// Read the data back and make sure it matches the original item.
got, err := uut.Get(ctx, brokenItem.Key)
require.NoError(t, err)
require.Equal(t, brokenItem.Key, got.Key)
require.Equal(t, brokenItem.Value, got.Value)
require.Equal(t, brokenItem.Expires, got.Expires)

// Read the data back using a range query too.
gotRange, err := uut.GetRange(ctx, brokenItem.Key, brokenItem.Key, 1)
require.NoError(t, err)
require.Len(t, gotRange.Items, 1)

got = &gotRange.Items[0]
require.Equal(t, brokenItem.Key, got.Key)
require.Equal(t, brokenItem.Value, got.Value)
require.Equal(t, brokenItem.Expires, got.Expires)

// Retrieve the entire key range to validate that there are no duplicate records
results, err := uut.GetRange(ctx, prefix(""), backend.RangeEnd(prefix("")), 5)
require.NoError(t, err)

require.Len(t, results.Items, 3)

for _, result := range results.Items {
switch r := result.Key.String(); r {
case item.Key.String():
assert.Equal(t, item.Value, result.Value)
case br.Key.String():
assert.Equal(t, br.Value, result.Value)
case lr.Key:
assert.Equal(t, lr.Value, string(result.Value))
default:
t.Errorf("GetRange returned unexpected item key %s", r)
}
}

// Update the value and ensure that it's set to the correct key value
item.Value = []byte("llama")
_, err = uut.Update(ctx, item)
require.NoError(t, err)

doc, err := uut.svc.Collection(uut.CollectionName).Doc(uut.keyToDocumentID(item.Key)).Get(ctx)
require.NoError(t, err)

var r record
require.NoError(t, doc.DataTo(&r))
require.Equal(t, []byte(item.Key.String()), r.Key)
require.Equal(t, item.Value, r.Value)
}

type mockFirestoreServer struct {
// Embed for forward compatibility.
// Tests will keep working if more methods are added
Expand Down