diff --git a/lib/runtime/imports_old.go b/lib/runtime/imports_old.go index 2acd67ef1b..7a275e91b8 100644 --- a/lib/runtime/imports_old.go +++ b/lib/runtime/imports_old.go @@ -60,6 +60,7 @@ import ( "encoding/binary" "fmt" "math/big" + "reflect" "unsafe" "github.com/ChainSafe/gossamer/lib/common" @@ -735,7 +736,7 @@ func ext_local_storage_get(context unsafe.Pointer, kind, key, keyLen, valueLen i runtimeCtx := instanceContext.Data().(*Ctx) var res []byte var err error - switch kind { + switch NodeStorageType(kind) { case NodeStorageTypePersistent: res, err = runtimeCtx.nodeStorage.PersistentStorage.Get(keyM) case NodeStorageTypeLocal: @@ -757,10 +758,43 @@ func ext_local_storage_get(context unsafe.Pointer, kind, key, keyLen, valueLen i } //export ext_local_storage_compare_and_set -func ext_local_storage_compare_and_set(context unsafe.Pointer, kind, key, keyLen, oldValue, oldValueLen, newValue, newValueLen int32) int32 { +func ext_local_storage_compare_and_set(context unsafe.Pointer, kind, keyPtr, keyLen, oldValuePtr, oldValueLen, newValuePtr, newValueLen int32) int32 { logger.Trace("[ext_local_storage_compare_and_set] executing...") - logger.Warn("[ext_local_storage_compare_and_set] Not yet implemented.") - return 0 + instanceContext := wasm.IntoInstanceContext(context) + memory := instanceContext.Memory().Data() + + key := memory[keyPtr : keyPtr+keyLen] + runtimeCtx := instanceContext.Data().(*Ctx) + var storedValue []byte + var err error + var nodeStorage BasicStorage + + switch NodeStorageType(kind) { + case NodeStorageTypePersistent: + nodeStorage = runtimeCtx.nodeStorage.PersistentStorage + storedValue, err = nodeStorage.Get(key) + case NodeStorageTypeLocal: + nodeStorage = runtimeCtx.nodeStorage.LocalStorage + storedValue, err = nodeStorage.Get(key) + } + + if err != nil { + logger.Error("[ext_local_storage_compare_and_set]", "error", err) + return 1 + } + + oldValue := memory[oldValuePtr : oldValuePtr+oldValueLen] + + if reflect.DeepEqual(storedValue, oldValue) { + newValue := memory[newValuePtr : newValuePtr+newValueLen] + err := nodeStorage.Put(key, newValue) + if err != nil { + logger.Error("[ext_local_storage_compare_and_set]", "error", err) + return 1 + } + return 0 + } + return 1 } //export ext_network_state @@ -814,7 +848,7 @@ func ext_local_storage_set(context unsafe.Pointer, kind, key, keyLen, value, val runtimeCtx := instanceContext.Data().(*Ctx) var err error - switch kind { + switch NodeStorageType(kind) { case NodeStorageTypePersistent: err = runtimeCtx.nodeStorage.PersistentStorage.Put(keyM, valueM) case NodeStorageTypeLocal: diff --git a/lib/runtime/imports_old_test.go b/lib/runtime/imports_old_test.go index a765f18d98..d431c3dee4 100644 --- a/lib/runtime/imports_old_test.go +++ b/lib/runtime/imports_old_test.go @@ -1154,7 +1154,7 @@ func TestExt_local_storage_set_local(t *testing.T) { t.Fatal("could not find exported function") } - _, err := testFunc(NodeStorageTypeLocal, keyPtr, keyLen, valuePtr, valueLen) + _, err := testFunc(int32(NodeStorageTypeLocal), keyPtr, keyLen, valuePtr, valueLen) require.NoError(t, err) resValue, err := runtime.ctx.nodeStorage.LocalStorage.Get(key) @@ -1184,7 +1184,7 @@ func TestExt_local_storage_set_persistent(t *testing.T) { t.Fatal("could not find exported function") } - _, err := testFunc(NodeStorageTypePersistent, keyPtr, keyLen, valuePtr, valueLen) + _, err := testFunc(int32(NodeStorageTypePersistent), keyPtr, keyLen, valuePtr, valueLen) require.NoError(t, err) resValue, err := runtime.ctx.nodeStorage.PersistentStorage.Get(key) @@ -1212,7 +1212,7 @@ func TestExt_local_storage_get_local(t *testing.T) { t.Fatal("could not find exported function") } - res, err := testFunc(NodeStorageTypeLocal, keyPtr, keyLen, valueLen) + res, err := testFunc(int32(NodeStorageTypeLocal), keyPtr, keyLen, valueLen) require.Nil(t, err) require.Equal(t, value, mem[res.ToI32():res.ToI32()+int32(valueLen)]) @@ -1238,12 +1238,111 @@ func TestExt_local_storage_get_persistent(t *testing.T) { t.Fatal("could not find exported function") } - res, err := testFunc(NodeStorageTypePersistent, keyPtr, keyLen, valueLen) + res, err := testFunc(int32(NodeStorageTypePersistent), keyPtr, keyLen, valueLen) require.Nil(t, err) require.Equal(t, value, mem[res.ToI32():res.ToI32()+int32(valueLen)]) } +type CompareSetTest struct { + storageType NodeStorageType + key []byte + value []byte + oldValue []byte + newValue []byte + result int32 + storageValue []byte +} + +var CompareSetTests = []CompareSetTest{ + { // persistent, condition match + storageType: NodeStorageTypePersistent, + key: []byte("mykey"), + value: []byte("value"), + oldValue: []byte("value"), + newValue: []byte("newValue"), + result: 0, + storageValue: []byte("newValue"), + }, + { // persistent, condition don't match + storageType: NodeStorageTypePersistent, + key: []byte("mykey"), + value: []byte("value"), + oldValue: []byte("oldValue"), + newValue: []byte("newValue"), + result: 1, + storageValue: []byte("value"), + }, + { // local, condition match + storageType: NodeStorageTypeLocal, + key: []byte("mykey"), + value: []byte("value"), + oldValue: []byte("value"), + newValue: []byte("newValue"), + result: 0, + storageValue: []byte("newValue"), + }, + { // local, condition don't match + storageType: NodeStorageTypeLocal, + key: []byte("mykey"), + value: []byte("value"), + oldValue: []byte("oldValue"), + newValue: []byte("newValue"), + result: 1, + storageValue: []byte("value"), + }, +} + +func TestExt_local_storage_compare_and_set(t *testing.T) { + for _, v := range CompareSetTests { + runtime := NewTestRuntime(t, TEST_RUNTIME) + mem := runtime.vm.Memory.Data() + // setup and init storage + var nodeStorage BasicStorage + switch v.storageType { + case NodeStorageTypePersistent: + nodeStorage = runtime.ctx.nodeStorage.PersistentStorage + case NodeStorageTypeLocal: + nodeStorage = runtime.ctx.nodeStorage.LocalStorage + } + nodeStorage.Put(v.key, v.value) + keyLen := uint32(len(v.key)) + keyPtr, err := runtime.malloc(keyLen) + require.NoError(t, err) + copy(mem[keyPtr:keyPtr+keyLen], v.key) + + oldValueLen := uint32(len(v.oldValue)) + oldValuePtr, err := runtime.malloc(oldValueLen) + require.NoError(t, err) + copy(mem[oldValuePtr:oldValuePtr+oldValueLen], v.oldValue) + + newValueLen := uint32(len(v.newValue)) + newValuePtr, err := runtime.malloc(newValueLen) + require.NoError(t, err) + copy(mem[newValuePtr:newValuePtr+newValueLen], v.newValue) + + // call wasm function + testFunc, ok := runtime.vm.Exports["test_ext_local_storage_compare_and_set"] + if !ok { + t.Fatal("could not find exported function") + } + + res, err := testFunc(int32(v.storageType), int32(keyPtr), int32(keyLen), int32(oldValuePtr), int32(oldValueLen), + int32(newValuePtr), int32(newValueLen)) + require.NoError(t, err) + + // confirm results + require.Equal(t, v.result, res.ToI32()) + checkFunc, ok := runtime.vm.Exports["test_ext_local_storage_get"] + if !ok { + t.Fatal("could not find exported function") + } + checkRes, err := checkFunc(int32(v.storageType), int32(keyPtr), int32(keyLen), int32(newValueLen)) + require.NoError(t, err) + require.Equal(t, v.storageValue, mem[checkRes.ToI32():checkRes.ToI32()+int32(len(v.storageValue))]) + } +} + func TestExt_is_validator(t *testing.T) { // test with validator runtime := NewTestRuntimeWithRole(t, TEST_RUNTIME, byte(4)) diff --git a/lib/runtime/runtime.go b/lib/runtime/runtime.go index 4d0dbea27a..711a04e11f 100644 --- a/lib/runtime/runtime.go +++ b/lib/runtime/runtime.go @@ -29,11 +29,14 @@ import ( var memory, memErr = wasm.NewMemory(17, 0) var logger = log.New("pkg", "runtime") +// NodeStorageType type to identify offchain storage type +type NodeStorageType byte + // NodeStorageTypePersistent flag to identify offchain storage as persistent (db) -const NodeStorageTypePersistent int32 = 1 +const NodeStorageTypePersistent NodeStorageType = 1 // NodeStorageTypeLocal flog to identify offchain storage as local (memory) -const NodeStorageTypeLocal int32 = 2 +const NodeStorageTypeLocal NodeStorageType = 2 // NodeStorage struct for storage of runtime offchain worker data type NodeStorage struct {