From 9e297e3a28e2e0dfca53e1bd9fb2a9cd64b2d8d3 Mon Sep 17 00:00:00 2001 From: Jesse Peterson Date: Wed, 11 Sep 2024 10:20:11 -0700 Subject: [PATCH] optionize txn kv tests --- storage/kv/test/txn.go | 48 ++++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/storage/kv/test/txn.go b/storage/kv/test/txn.go index 1444b30..f9ea1df 100644 --- a/storage/kv/test/txn.go +++ b/storage/kv/test/txn.go @@ -9,7 +9,23 @@ import ( "github.com/micromdm/nanolib/storage/kv" ) -func TestTxnSimple(t *testing.T, ctx context.Context, b kv.TxnCRUDBucket) { +type txnConfig struct { + noReadAfterRollback bool +} + +type TxnOption func(*txnConfig) + +func WithNoReadAfterRollback() TxnOption { + return func(c *txnConfig) { + c.noReadAfterRollback = true + } +} + +func TestTxnSimple(t *testing.T, ctx context.Context, b kv.TxnCRUDBucket, opts ...TxnOption) { + config := new(txnConfig) + for _, opt := range opts { + opt(config) + } // first, set a value in the "parent" bucket err := b.Set(ctx, "test-txn-key-1", []byte("test-txn-val-1")) if err != nil { @@ -61,14 +77,22 @@ func TestTxnSimple(t *testing.T, ctx context.Context, b kv.TxnCRUDBucket) { t.Errorf("have: %v, want: %v", string(have), string(want)) } - // read the value we just reset in the txn and make sure it was rolled back - val, err = bt.Get(ctx, "test-txn-key-1") + if !config.noReadAfterRollback { + // read the value we just reset in the txn and make sure it was rolled back + val, err = bt.Get(ctx, "test-txn-key-1") + if err != nil { + t.Fatal(err) + } + if have, want := val, []byte("test-txn-val-1"); !bytes.Equal(have, want) { + t.Errorf("have: %v, want: %v", string(have), string(want)) + } + } + + // create a txn again + bt, err = b.BeginCRUDBucketTxn(ctx) if err != nil { t.Fatal(err) } - if have, want := val, []byte("test-txn-val-1"); !bytes.Equal(have, want) { - t.Errorf("have: %v, want: %v", string(have), string(want)) - } // okay, let's re-reset the value again err = bt.Set(ctx, "test-txn-key-1", []byte("test-txn-val-2")) @@ -119,11 +143,13 @@ func TestTxnSimple(t *testing.T, ctx context.Context, b kv.TxnCRUDBucket) { t.Fatal(err) } - // and try and read the values we just set (but discarded) - // should error with a key not found - _, err = bt.Get(ctx, "test-txn-key-2") - if !errors.Is(err, kv.ErrKeyNotFound) { - t.Fatal(err) + if !config.noReadAfterRollback { + // and try and read the values we just set (but discarded) + // should error with a key not found + _, err = bt.Get(ctx, "test-txn-key-2") + if !errors.Is(err, kv.ErrKeyNotFound) { + t.Fatal(err) + } } // .. same for the parent bucket