diff --git a/snow/engine/snowman/block/blocktest/set_preference_vm.go b/snow/engine/snowman/block/blocktest/set_preference_vm.go new file mode 100644 index 000000000000..0c2944a76428 --- /dev/null +++ b/snow/engine/snowman/block/blocktest/set_preference_vm.go @@ -0,0 +1,42 @@ +// Copyright (C) 2019-2025, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package blocktest + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/snow/engine/snowman/block" +) + +var ( + errSetPreferenceWithContext = errors.New("unexpectedly called SetPreferenceWithContext") + + _ block.SetPreferenceWithContextChainVM = (*SetPreferenceVM)(nil) +) + +type SetPreferenceVM struct { + T *testing.T + + CantSetPreferenceWithContext bool + SetPreferenceWithContextF func(context.Context, ids.ID, *block.Context) error +} + +func (vm *SetPreferenceVM) Default(cant bool) { + vm.CantSetPreferenceWithContext = cant +} + +func (vm *SetPreferenceVM) SetPreferenceWithContext(ctx context.Context, id ids.ID, blockCtx *block.Context) error { + if vm.SetPreferenceWithContextF != nil { + return vm.SetPreferenceWithContextF(ctx, id, blockCtx) + } + if vm.CantSetPreferenceWithContext && vm.T != nil { + require.FailNow(vm.T, errSetPreferenceWithContext.Error()) + } + return errSetPreferenceWithContext +} diff --git a/vms/platformvm/block/executor/manager_test.go b/vms/platformvm/block/executor/manager_test.go index a0ff0882a911..eff32808a0c3 100644 --- a/vms/platformvm/block/executor/manager_test.go +++ b/vms/platformvm/block/executor/manager_test.go @@ -13,6 +13,8 @@ import ( "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/vms/platformvm/block" "github.com/ava-labs/avalanchego/vms/platformvm/state" + + snowmanblock "github.com/ava-labs/avalanchego/snow/engine/snowman/block" ) func TestGetBlock(t *testing.T) { @@ -85,3 +87,22 @@ func TestManagerSetPreference(t *testing.T) { manager.SetPreference(newPreference, nil) require.Equal(newPreference, manager.Preferred()) } + +func TestManagerSetPreferenceWithContext(t *testing.T) { + require := require.New(t) + + initialPreference := ids.GenerateTestID() + manager := &manager{ + preferred: initialPreference, + } + require.Equal(initialPreference, manager.Preferred()) + require.Nil(manager.preferredCtx) + + newPreference := ids.GenerateTestID() + newContext := &snowmanblock.Context{ + PChainHeight: 100, + } + manager.SetPreference(newPreference, newContext) + require.Equal(newPreference, manager.Preferred()) + require.Equal(newContext, manager.preferredCtx) +} diff --git a/vms/proposervm/vm_test.go b/vms/proposervm/vm_test.go index 99c3e5e2199d..2dbcfb88ecfd 100644 --- a/vms/proposervm/vm_test.go +++ b/vms/proposervm/vm_test.go @@ -45,13 +45,15 @@ import ( ) var ( - _ block.ChainVM = (*fullVM)(nil) - _ block.StateSyncableVM = (*fullVM)(nil) + _ block.ChainVM = (*fullVM)(nil) + _ block.StateSyncableVM = (*fullVM)(nil) + _ block.SetPreferenceWithContextChainVM = (*fullVM)(nil) ) type fullVM struct { *blocktest.VM *blocktest.StateSyncableVM + *blocktest.SetPreferenceVM } var ( @@ -100,37 +102,40 @@ func initTestProposerVM( VM: &blocktest.VM{ VM: enginetest.VM{ T: t, + InitializeF: func(context.Context, *snow.Context, database.Database, []byte, []byte, []byte, []*common.Fx, common.AppSender) error { + return nil + }, + }, + LastAcceptedF: snowmantest.MakeLastAcceptedBlockF( + []*snowmantest.Block{snowmantest.Genesis}, + ), + GetBlockF: func(_ context.Context, blkID ids.ID) (snowman.Block, error) { + switch blkID { + case snowmantest.GenesisID: + return snowmantest.Genesis, nil + default: + return nil, errUnknownBlock + } + }, + ParseBlockF: func(_ context.Context, b []byte) (snowman.Block, error) { + switch { + case bytes.Equal(b, snowmantest.GenesisBytes): + return snowmantest.Genesis, nil + default: + return nil, errUnknownBlock + } }, }, StateSyncableVM: &blocktest.StateSyncableVM{ T: t, }, + SetPreferenceVM: &blocktest.SetPreferenceVM{ + T: t, + }, } - - coreVM.InitializeF = func(context.Context, *snow.Context, database.Database, - []byte, []byte, []byte, - []*common.Fx, common.AppSender, - ) error { - return nil - } - coreVM.LastAcceptedF = snowmantest.MakeLastAcceptedBlockF( - []*snowmantest.Block{snowmantest.Genesis}, - ) - coreVM.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { - switch blkID { - case snowmantest.GenesisID: - return snowmantest.Genesis, nil - default: - return nil, errUnknownBlock - } - } - coreVM.ParseBlockF = func(_ context.Context, b []byte) (snowman.Block, error) { - switch { - case bytes.Equal(b, snowmantest.GenesisBytes): - return snowmantest.Genesis, nil - default: - return nil, errUnknownBlock - } + // Default to routing SetPreferenceWithContext to SetPreference + coreVM.SetPreferenceWithContextF = func(ctx context.Context, blkID ids.ID, _ *block.Context) error { + return coreVM.SetPreference(ctx, blkID) } var upgrades upgrade.Config @@ -155,38 +160,38 @@ func initTestProposerVM( valState := &validatorstest.State{ T: t, - } - valState.GetMinimumHeightF = func(context.Context) (uint64, error) { - return snowmantest.GenesisHeight, nil - } - valState.GetCurrentHeightF = func(context.Context) (uint64, error) { - return defaultPChainHeight, nil - } - valState.GetValidatorSetF = func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { - var ( - thisNode = proVM.ctx.NodeID - nodeID1 = ids.BuildTestNodeID([]byte{1}) - nodeID2 = ids.BuildTestNodeID([]byte{2}) - nodeID3 = ids.BuildTestNodeID([]byte{3}) - ) - return map[ids.NodeID]*validators.GetValidatorOutput{ - thisNode: { - NodeID: thisNode, - Weight: 10, - }, - nodeID1: { - NodeID: nodeID1, - Weight: 5, - }, - nodeID2: { - NodeID: nodeID2, - Weight: 6, - }, - nodeID3: { - NodeID: nodeID3, - Weight: 7, - }, - }, nil + GetMinimumHeightF: func(context.Context) (uint64, error) { + return snowmantest.GenesisHeight, nil + }, + GetCurrentHeightF: func(context.Context) (uint64, error) { + return defaultPChainHeight, nil + }, + GetValidatorSetF: func(context.Context, uint64, ids.ID) (map[ids.NodeID]*validators.GetValidatorOutput, error) { + var ( + thisNode = proVM.ctx.NodeID + nodeID1 = ids.BuildTestNodeID([]byte{1}) + nodeID2 = ids.BuildTestNodeID([]byte{2}) + nodeID3 = ids.BuildTestNodeID([]byte{3}) + ) + return map[ids.NodeID]*validators.GetValidatorOutput{ + thisNode: { + NodeID: thisNode, + Weight: 10, + }, + nodeID1: { + NodeID: nodeID1, + Weight: 5, + }, + nodeID2: { + NodeID: nodeID2, + Weight: 6, + }, + nodeID3: { + NodeID: nodeID3, + Weight: 7, + }, + }, nil + }, } ctx := snowtest.Context(t, ids.ID{1}) @@ -757,6 +762,138 @@ func TestPreFork_SetPreference(t *testing.T) { require.Equal(builtBlk.ID(), nextBlk.Parent()) } +// TestPostFork_SetPreference tests the SetPreference functionality after the fork +// when SetPreferenceWithContext may be called based on various conditions. +func TestPostFork_SetPreference(t *testing.T) { + // Helper to create a block with the given epoch, P-Chain height, and optional custom timestamp + createBlockWithEpoch := func(proVM *VM, epoch statelessblock.Epoch, blockPChainHeight uint64, customTimestamp ...time.Time) PostForkBlock { + coreBlk := snowmantest.BuildChild(snowmantest.Genesis) + + timestamp := coreBlk.Timestamp() + if len(customTimestamp) > 0 { + timestamp = customTimestamp[0] + } + + statelessBlk, err := statelessblock.BuildUnsigned( + snowmantest.GenesisID, + timestamp, + blockPChainHeight, + epoch, + coreBlk.Bytes(), + ) + require.NoError(t, err) + + return &postForkBlock{ + SignedBlock: statelessBlk, + postForkCommonComponents: postForkCommonComponents{ + vm: proVM, + innerBlk: coreBlk, + }, + } + } + + testErr := errors.New("test err") + + tests := []struct { + name string + hasSetPreferenceWithContext bool + epochPChainHeight *uint64 // nil = no epoch, otherwise epoch with this P-Chain height + sealEpoch bool // whether block timestamp should seal the epoch + expectSetPreferenceWithContext bool + expectedError error + }{ + { + name: "setPreferenceVM is nil - should call regular SetPreference", + hasSetPreferenceWithContext: false, + epochPChainHeight: &defaultPChainHeight, + expectSetPreferenceWithContext: false, + }, + { + name: "preferredEpoch is empty - should call regular SetPreference", + hasSetPreferenceWithContext: true, + epochPChainHeight: nil, // no epoch + expectSetPreferenceWithContext: false, + }, + { + name: "both conditions met - should call SetPreferenceWithContext", + hasSetPreferenceWithContext: true, + epochPChainHeight: &defaultPChainHeight, + expectSetPreferenceWithContext: true, + }, + { + name: "SetPreferenceWithContext returns error", + hasSetPreferenceWithContext: true, + epochPChainHeight: &defaultPChainHeight, + expectSetPreferenceWithContext: true, + expectedError: testErr, + }, + { + name: "epoch sealed - next epoch has different PChainHeight", + hasSetPreferenceWithContext: true, + epochPChainHeight: func() *uint64 { + h := defaultPChainHeight - 100 // older epoch height + return &h + }(), + sealEpoch: true, + expectSetPreferenceWithContext: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + require := require.New(t) + + coreVM, _, proVM, _ := initTestProposerVM(t, upgradetest.Latest, defaultPChainHeight) + defer func() { + require.NoError(proVM.Shutdown(context.Background())) + }() + + if test.hasSetPreferenceWithContext { + coreVM.SetPreferenceWithContextF = func(_ context.Context, _ ids.ID, blockContext *block.Context) error { + require.Equal(defaultPChainHeight, blockContext.PChainHeight) + return test.expectedError + } + } else { + proVM.setPreferenceVM = nil + coreVM.SetPreferenceF = func(context.Context, ids.ID) error { + return test.expectedError + } + } + + if test.expectSetPreferenceWithContext { + coreVM.CantSetPreference = true + coreVM.SetPreferenceF = nil + } else { + coreVM.CantSetPreferenceWithContext = true + coreVM.SetPreferenceWithContextF = nil + } + + // Create block based on test requirements + var epoch statelessblock.Epoch + if test.epochPChainHeight != nil { + epoch = statelessblock.Epoch{ + PChainHeight: *test.epochPChainHeight, + Number: 1, + StartTime: snowmantest.GenesisTimestamp.Unix(), + } + } + + var postForkBlk PostForkBlock + if test.sealEpoch { + // Create a block timestamp that seals the epoch + epochSealingTimestamp := snowmantest.GenesisTimestamp.Add(upgrade.Default.GraniteEpochDuration) + postForkBlk = createBlockWithEpoch(proVM, epoch, defaultPChainHeight, epochSealingTimestamp) + } else { + postForkBlk = createBlockWithEpoch(proVM, epoch, defaultPChainHeight) + } + + proVM.verifiedBlocks[postForkBlk.ID()] = postForkBlk + err := proVM.SetPreference(context.Background(), postForkBlk.ID()) + require.ErrorIs(err, test.expectedError) + }) + } +} + func TestExpiredBuildBlock(t *testing.T) { require := require.New(t)