diff --git a/cmd/util/cmd/execution-state-extract/cmd.go b/cmd/util/cmd/execution-state-extract/cmd.go index becf31c744a..1215c8bb133 100644 --- a/cmd/util/cmd/execution-state-extract/cmd.go +++ b/cmd/util/cmd/execution-state-extract/cmd.go @@ -7,11 +7,11 @@ import ( "path" "strings" - runtimeCommon "github.com/onflow/cadence/runtime/common" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "github.com/onflow/flow-go/cmd/util/cmd/common" + common2 "github.com/onflow/flow-go/cmd/util/common" "github.com/onflow/flow-go/cmd/util/ledger/migrations" "github.com/onflow/flow-go/cmd/util/ledger/util" "github.com/onflow/flow-go/model/bootstrap" @@ -259,24 +259,13 @@ func run(*cobra.Command, []string) { } } - var exportedAddresses []runtimeCommon.Address + var exportPayloadsForOwners map[string]struct{} if len(flagOutputPayloadByAddresses) > 0 { - - addresses := strings.Split(flagOutputPayloadByAddresses, ",") - - for _, hexAddr := range addresses { - b, err := hex.DecodeString(strings.TrimSpace(hexAddr)) - if err != nil { - log.Fatal().Err(err).Msgf("cannot hex decode address %s for payload export", strings.TrimSpace(hexAddr)) - } - - addr, err := runtimeCommon.BytesToAddress(b) - if err != nil { - log.Fatal().Err(err).Msgf("cannot decode address %x for payload export", b) - } - - exportedAddresses = append(exportedAddresses, addr) + var err error + exportPayloadsForOwners, err = common2.ParseOwners(strings.Split(flagOutputPayloadByAddresses, ",")) + if err != nil { + log.Fatal().Err(err).Msgf("failed to parse addresses") } } @@ -334,12 +323,12 @@ func run(*cobra.Command, []string) { var outputMsg string if len(flagOutputPayloadFileName) > 0 { // Output is payload file - if len(exportedAddresses) == 0 { + if len(exportPayloadsForOwners) == 0 { outputMsg = fmt.Sprintf("exporting all payloads to %s", flagOutputPayloadFileName) } else { outputMsg = fmt.Sprintf( - "exporting payloads by addresses %v to %s", - flagOutputPayloadByAddresses, + "exporting payloads for owners %v to %s", + common2.OwnersToString(exportPayloadsForOwners), flagOutputPayloadFileName, ) } @@ -397,7 +386,7 @@ func run(*cobra.Command, []string) { !flagNoMigration, flagInputPayloadFileName, flagOutputPayloadFileName, - exportedAddresses, + exportPayloadsForOwners, flagSortPayloads, opts, ) @@ -410,7 +399,7 @@ func run(*cobra.Command, []string) { flagNWorker, !flagNoMigration, flagOutputPayloadFileName, - exportedAddresses, + exportPayloadsForOwners, flagSortPayloads, opts, ) diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract.go b/cmd/util/cmd/execution-state-extract/execution_state_extract.go index 388c1134a68..319f5fc057e 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract.go @@ -10,7 +10,6 @@ import ( syncAtomic "sync/atomic" "time" - "github.com/onflow/cadence/runtime/common" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "go.uber.org/atomic" @@ -39,7 +38,7 @@ func extractExecutionState( nWorker int, // number of concurrent worker to migration payloads runMigrations bool, outputPayloadFile string, - exportPayloadsByAddresses []common.Address, + exportPayloadsForOwners map[string]struct{}, sortPayloads bool, opts migrators.Options, ) error { @@ -145,7 +144,7 @@ func extractExecutionState( payloads, nWorker, outputPayloadFile, - exportPayloadsByAddresses, + exportPayloadsForOwners, false, // payloads represents entire state. sortPayloads, ) @@ -217,7 +216,7 @@ func extractExecutionStateFromPayloads( runMigrations bool, inputPayloadFile string, outputPayloadFile string, - exportPayloadsByAddresses []common.Address, + exportPayloadsForOwners map[string]struct{}, sortPayloads bool, opts migrators.Options, ) error { @@ -253,7 +252,7 @@ func extractExecutionStateFromPayloads( payloads, nWorker, outputPayloadFile, - exportPayloadsByAddresses, + exportPayloadsForOwners, inputPayloadsFromPartialState, sortPayloads, ) @@ -288,7 +287,7 @@ func exportPayloads( payloads []*ledger.Payload, nWorker int, outputPayloadFile string, - exportPayloadsByAddresses []common.Address, + exportPayloadsForOwners map[string]struct{}, inputPayloadsFromPartialState bool, sortPayloads bool, ) error { @@ -308,7 +307,7 @@ func exportPayloads( log, outputPayloadFile, payloads, - exportPayloadsByAddresses, + exportPayloadsForOwners, inputPayloadsFromPartialState, ) if err != nil { diff --git a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go index 26ab069fedf..a57cd966f05 100644 --- a/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go +++ b/cmd/util/cmd/execution-state-extract/execution_state_extract_test.go @@ -406,10 +406,22 @@ func TestExtractPayloadsFromExecutionState(t *testing.T) { // Verify exported payloads. partialState, payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputPayloadFileName) require.NoError(t, err) - require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) require.True(t, partialState) + nonGlobalPayloads := make([]*ledger.Payload, 0, len(selectedKeysValues)) for _, payloadFromFile := range payloadsFromFile { + key, err := payloadFromFile.Key() + require.NoError(t, err) + + owner := key.KeyParts[0].Value + if len(owner) > 0 { + nonGlobalPayloads = append(nonGlobalPayloads, payloadFromFile) + } + } + + require.Equal(t, len(selectedKeysValues), len(nonGlobalPayloads)) + + for _, payloadFromFile := range nonGlobalPayloads { k, err := payloadFromFile.Key() require.NoError(t, err) diff --git a/cmd/util/cmd/extract-payloads-by-address/cmd.go b/cmd/util/cmd/extract-payloads-by-address/cmd.go index 3d66ea65cf1..5759b4443bd 100644 --- a/cmd/util/cmd/extract-payloads-by-address/cmd.go +++ b/cmd/util/cmd/extract-payloads-by-address/cmd.go @@ -1,16 +1,13 @@ package extractpayloads import ( - "encoding/hex" - "fmt" "os" "strings" "github.com/rs/zerolog/log" "github.com/spf13/cobra" - "github.com/onflow/cadence/runtime/common" - + "github.com/onflow/flow-go/cmd/util/common" "github.com/onflow/flow-go/cmd/util/ledger/util" ) @@ -60,14 +57,14 @@ func run(*cobra.Command, []string) { log.Fatal().Msgf("Output file %s exists", flagOutputPayloadFileName) } - addresses, err := parseAddresses(strings.Split(flagAddresses, ",")) + owners, err := common.ParseOwners(strings.Split(flagAddresses, ",")) if err != nil { log.Fatal().Err(err) } log.Info().Msgf( - "extracting payloads with address %v from %s to %s", - addresses, + "extracting payloads with owners %s from %s to %s", + common.OwnersToString(owners), flagInputPayloadFileName, flagOutputPayloadFileName, ) @@ -77,39 +74,22 @@ func run(*cobra.Command, []string) { log.Fatal().Err(err) } - numOfPayloadWritten, err := util.CreatePayloadFile(log.Logger, flagOutputPayloadFileName, payloads, addresses, inputPayloadsFromPartialState) + numOfPayloadWritten, err := util.CreatePayloadFile( + log.Logger, + flagOutputPayloadFileName, + payloads, + owners, + inputPayloadsFromPartialState, + ) if err != nil { log.Fatal().Err(err) } log.Info().Msgf( - "extracted %d payloads with address %v from %s to %s", + "extracted %d payloads with owners %s from %s to %s", numOfPayloadWritten, - addresses, + common.OwnersToString(owners), flagInputPayloadFileName, flagOutputPayloadFileName, ) } - -func parseAddresses(hexAddresses []string) ([]common.Address, error) { - if len(hexAddresses) == 0 { - return nil, fmt.Errorf("at least one address must be provided") - } - - addresses := make([]common.Address, len(hexAddresses)) - for i, hexAddr := range hexAddresses { - b, err := hex.DecodeString(strings.TrimSpace(hexAddr)) - if err != nil { - return nil, fmt.Errorf("address is not hex encoded %s: %w", strings.TrimSpace(hexAddr), err) - } - - addr, err := common.BytesToAddress(b) - if err != nil { - return nil, fmt.Errorf("cannot decode address %x", b) - } - - addresses[i] = addr - } - - return addresses, nil -} diff --git a/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go b/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go index a30574b926a..3c1bb267cc1 100644 --- a/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go +++ b/cmd/util/cmd/extract-payloads-by-address/extract_payloads_test.go @@ -1,9 +1,7 @@ package extractpayloads import ( - "bytes" "crypto/rand" - "encoding/hex" "path/filepath" "strings" "testing" @@ -94,10 +92,22 @@ func TestExtractPayloads(t *testing.T) { // Verify exported payloads. partialState, payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputFile) require.NoError(t, err) - require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) require.True(t, partialState) + nonGlobalPayloads := make([]*ledger.Payload, 0, len(selectedKeysValues)) for _, payloadFromFile := range payloadsFromFile { + key, err := payloadFromFile.Key() + require.NoError(t, err) + + owner := key.KeyParts[0].Value + if len(owner) > 0 { + nonGlobalPayloads = append(nonGlobalPayloads, payloadFromFile) + } + } + + require.Equal(t, len(selectedKeysValues), len(nonGlobalPayloads)) + + for _, payloadFromFile := range nonGlobalPayloads { k, err := payloadFromFile.Key() require.NoError(t, err) @@ -108,9 +118,7 @@ func TestExtractPayloads(t *testing.T) { }) }) - t.Run("no payloads", func(t *testing.T) { - - emptyAddress := common.Address{} + t.Run("empty address", func(t *testing.T) { unittest.RunWithTempDir(t, func(datadir string) { @@ -127,9 +135,6 @@ func TestExtractPayloads(t *testing.T) { keys, values := getSampleKeyValues(i) for j, key := range keys { - if bytes.Equal(key.KeyParts[0].Value, emptyAddress[:]) { - continue - } keysValues[key.String()] = keyPair{ key: key, value: values[j], @@ -147,7 +152,7 @@ func TestExtractPayloads(t *testing.T) { Cmd.SetArgs([]string{ "--input-filename", inputFile, "--output-filename", outputFile, - "--addresses", hex.EncodeToString(emptyAddress[:]), + "--addresses", ",", }) err = Cmd.Execute() @@ -156,8 +161,21 @@ func TestExtractPayloads(t *testing.T) { // Verify exported payloads. partialState, payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), outputFile) require.NoError(t, err) - require.Equal(t, 0, len(payloadsFromFile)) require.True(t, partialState) + + var nonGlobalPayloads []*ledger.Payload + for _, payloadFromFile := range payloadsFromFile { + key, err := payloadFromFile.Key() + require.NoError(t, err) + + owner := key.KeyParts[0].Value + if len(owner) > 0 { + nonGlobalPayloads = append(nonGlobalPayloads, payloadFromFile) + } + } + + require.Equal(t, 0, len(nonGlobalPayloads)) + }) }) } diff --git a/cmd/util/common/address.go b/cmd/util/common/address.go new file mode 100644 index 00000000000..60060b1072d --- /dev/null +++ b/cmd/util/common/address.go @@ -0,0 +1,60 @@ +package common + +import ( + "encoding/hex" + "fmt" + "strings" + + "github.com/onflow/flow-go/model/flow" +) + +func ParseOwners(hexAddresses []string) (map[string]struct{}, error) { + if len(hexAddresses) == 0 { + return nil, fmt.Errorf("at least one address must be provided") + } + + addresses := make(map[string]struct{}, len(hexAddresses)) + for _, hexAddr := range hexAddresses { + hexAddr = strings.TrimSpace(hexAddr) + + if len(hexAddr) > 0 { + addr, err := ParseAddress(hexAddr) + if err != nil { + return nil, err + } + + addresses[string(addr[:])] = struct{}{} + } else { + // global registers has empty address + addresses[""] = struct{}{} + } + } + + return addresses, nil +} + +func ParseAddress(hexAddr string) (flow.Address, error) { + b, err := hex.DecodeString(hexAddr) + if err != nil { + return flow.Address{}, fmt.Errorf( + "address is not hex encoded %s: %w", + strings.TrimSpace(hexAddr), + err, + ) + } + + return flow.BytesToAddress(b), nil +} + +func OwnersToString(owners map[string]struct{}) string { + var sb strings.Builder + index := 0 + for owner := range owners { + if index > 0 { + sb.WriteRune(',') + } + _, _ = fmt.Fprintf(&sb, "%x", owner) + index++ + } + return sb.String() +} diff --git a/cmd/util/ledger/util/payload_file.go b/cmd/util/ledger/util/payload_file.go index 1aad4a1bc10..4b419d19736 100644 --- a/cmd/util/ledger/util/payload_file.go +++ b/cmd/util/ledger/util/payload_file.go @@ -2,7 +2,6 @@ package util import ( "bufio" - "bytes" "encoding/binary" "fmt" "io" @@ -12,8 +11,6 @@ import ( "github.com/fxamacker/cbor/v2" "github.com/rs/zerolog" - "github.com/onflow/cadence/runtime/common" - "github.com/onflow/flow-go/ledger" "github.com/onflow/flow-go/ledger/complete/wal" ) @@ -119,11 +116,11 @@ func CreatePayloadFile( logger zerolog.Logger, payloadFile string, payloads []*ledger.Payload, - addresses []common.Address, + owners map[string]struct{}, inputPayloadsFromPartialState bool, ) (int, error) { - partialState := inputPayloadsFromPartialState || len(addresses) > 0 + partialState := inputPayloadsFromPartialState || len(owners) > 0 f, err := os.Create(payloadFile) if err != nil { @@ -132,9 +129,6 @@ func CreatePayloadFile( defer f.Close() writer := bufio.NewWriterSize(f, defaultBufioWriteSize) - if err != nil { - return 0, fmt.Errorf("can't create bufio writer for %s: %w", payloadFile, err) - } defer writer.Flush() // TODO: replace CRC-32 checksum. @@ -155,14 +149,14 @@ func CreatePayloadFile( return 0, fmt.Errorf("can't write payload file head for %s: %w", payloadFile, err) } - includeAllPayloads := len(addresses) == 0 + includeAllPayloads := len(owners) == 0 // Write payloads. var writtenPayloadCount int if includeAllPayloads { writtenPayloadCount, err = writePayloads(logger, crc32Writer, payloads) } else { - writtenPayloadCount, err = writeSelectedPayloads(logger, crc32Writer, payloads, addresses) + writtenPayloadCount, err = writeSelectedPayloads(logger, crc32Writer, payloads, owners) } if err != nil { @@ -209,7 +203,12 @@ func writePayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payloa return len(payloads), nil } -func writeSelectedPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledger.Payload, addresses []common.Address) (int, error) { +func writeSelectedPayloads( + logger zerolog.Logger, + w io.Writer, + payloads []*ledger.Payload, + owners map[string]struct{}, +) (int, error) { logger.Info().Msgf("filtering %d payloads and writing selected payloads to file", len(payloads)) enc := cbor.NewEncoder(w) @@ -217,7 +216,7 @@ func writeSelectedPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledge var includedPayloadCount int var payloadScratchBuffer [1024 * 2]byte for _, p := range payloads { - include, err := includePayloadByAddresses(p, addresses) + include, err := includePayloadByOwners(p, owners) if err != nil { return 0, err } @@ -239,8 +238,8 @@ func writeSelectedPayloads(logger zerolog.Logger, w io.Writer, payloads []*ledge return includedPayloadCount, nil } -func includePayloadByAddresses(payload *ledger.Payload, addresses []common.Address) (bool, error) { - if len(addresses) == 0 { +func includePayloadByOwners(payload *ledger.Payload, owners map[string]struct{}) (bool, error) { + if len(owners) == 0 { // Include all payloads return true, nil } @@ -250,15 +249,16 @@ func includePayloadByAddresses(payload *ledger.Payload, addresses []common.Addre return false, fmt.Errorf("can't get key from payload: %w", err) } - owner := k.KeyParts[0].Value + owner := string(k.KeyParts[0].Value) - for _, address := range addresses { - if bytes.Equal(owner, address[:]) { - return true, nil - } + // Always include payloads for global registers, + // i.e. with empty owner + if owner == "" { + return true, nil } - return false, nil + _, ok := owners[owner] + return ok, nil } func ReadPayloadFile(logger zerolog.Logger, payloadFile string) (bool, []*ledger.Payload, error) { diff --git a/cmd/util/ledger/util/payload_file_test.go b/cmd/util/ledger/util/payload_file_test.go index 2ce69dc5876..26b2092a623 100644 --- a/cmd/util/ledger/util/payload_file_test.go +++ b/cmd/util/ledger/util/payload_file_test.go @@ -142,10 +142,16 @@ func TestPayloadFile(t *testing.T) { keysValues := make(map[string]keyPair) var payloads []*ledger.Payload + var globalRegisterCount int for i := 0; i < size; i++ { keys, values := getSampleKeyValues(i) for j, key := range keys { + + if len(key.KeyParts[0].Value) == 0 { + globalRegisterCount++ + } + keysValues[key.String()] = keyPair{ key: key, value: values[j], @@ -176,9 +182,9 @@ func TestPayloadFile(t *testing.T) { } } - addresses := make([]common.Address, 0, len(selectedAddresses)) + addresses := make(map[string]struct{}, len(selectedAddresses)) for address := range selectedAddresses { - addresses = append(addresses, address) + addresses[string(address[:])] = struct{}{} } numOfPayloadWritten, err := util.CreatePayloadFile( @@ -189,7 +195,11 @@ func TestPayloadFile(t *testing.T) { false, // input payloads represent entire state ) require.NoError(t, err) - require.Equal(t, len(selectedKeysValues), numOfPayloadWritten) + require.Equal( + t, + len(selectedKeysValues)+globalRegisterCount, + numOfPayloadWritten, + ) partialState, err := util.IsPayloadFilePartialState(payloadFileName) require.NoError(t, err) @@ -197,10 +207,22 @@ func TestPayloadFile(t *testing.T) { partialState, payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), payloadFileName) require.NoError(t, err) - require.Equal(t, len(selectedKeysValues), len(payloadsFromFile)) require.True(t, partialState) + nonGlobalPayloads := make([]*ledger.Payload, 0, len(selectedKeysValues)) for _, payloadFromFile := range payloadsFromFile { + key, err := payloadFromFile.Key() + require.NoError(t, err) + + owner := key.KeyParts[0].Value + if len(owner) > 0 { + nonGlobalPayloads = append(nonGlobalPayloads, payloadFromFile) + } + } + + require.Equal(t, len(selectedKeysValues), len(nonGlobalPayloads)) + + for _, payloadFromFile := range nonGlobalPayloads { k, err := payloadFromFile.Key() require.NoError(t, err) @@ -213,6 +235,7 @@ func TestPayloadFile(t *testing.T) { }) t.Run("no payloads found with filter", func(t *testing.T) { + emptyAddress := common.Address{} unittest.RunWithTempDir(t, func(datadir string) { @@ -224,13 +247,20 @@ func TestPayloadFile(t *testing.T) { keysValues := make(map[string]keyPair) var payloads []*ledger.Payload + var globalRegisterCount int + for i := 0; i < size; i++ { keys, values := getSampleKeyValues(i) for j, key := range keys { + if len(key.KeyParts[0].Value) == 0 { + globalRegisterCount++ + } + if bytes.Equal(key.KeyParts[0].Value, emptyAddress[:]) { continue } + keysValues[key.String()] = keyPair{ key: key, value: values[j], @@ -244,11 +274,13 @@ func TestPayloadFile(t *testing.T) { zerolog.Nop(), payloadFileName, payloads, - []common.Address{emptyAddress}, + map[string]struct{}{ + string(emptyAddress[:]): {}, + }, false, ) require.NoError(t, err) - require.Equal(t, 0, numOfPayloadWritten) + require.Equal(t, globalRegisterCount, numOfPayloadWritten) partialState, err := util.IsPayloadFilePartialState(payloadFileName) require.NoError(t, err) @@ -256,7 +288,7 @@ func TestPayloadFile(t *testing.T) { partialState, payloadsFromFile, err := util.ReadPayloadFile(zerolog.Nop(), payloadFileName) require.NoError(t, err) - require.Equal(t, 0, len(payloadsFromFile)) + require.Equal(t, globalRegisterCount, len(payloadsFromFile)) require.True(t, partialState) }) })