diff --git a/plugin/evm/test_validator_state.go b/plugin/evm/test_validator_state.go index 0c4c383457..fdf33bc2a2 100644 --- a/plugin/evm/test_validator_state.go +++ b/plugin/evm/test_validator_state.go @@ -2,6 +2,7 @@ package evm import ( "context" + "fmt" "time" "github.com/ava-labs/avalanchego/ids" @@ -49,27 +50,27 @@ func NewTestValidatorState(pState validators.State) TestValidatorState { } } -func (s *testValidatorState) RecordValidator(nodeID ids.NodeID, startTime, setWeightNonce uint64) { - s.recordedValidators[nodeID] = recordedValidator{ +func (t *testValidatorState) RecordValidator(nodeID ids.NodeID, startTime, setWeightNonce uint64) { + t.recordedValidators[nodeID] = recordedValidator{ StartTime: startTime, SetWeightNonce: setWeightNonce, IsActive: true, } } -func (s *testValidatorState) GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*ValidatorOutput, error) { - currentPHeight, err := s.GetCurrentHeight(ctx) +func (t *testValidatorState) GetCurrentValidatorSet(ctx context.Context, subnetID ids.ID) (map[ids.ID]*ValidatorOutput, error) { + currentPHeight, err := t.GetCurrentHeight(ctx) if err != nil { return nil, err } - validatorSet, err := s.GetValidatorSet(ctx, currentPHeight, subnetID) + validatorSet, err := t.GetValidatorSet(ctx, currentPHeight, subnetID) if err != nil { return nil, err } output := make(map[ids.ID]*ValidatorOutput, len(validatorSet)) for key, value := range validatorSet { startTime, isActive, setWeightNonce := DefaultStartTime, DefaultIsActive, DefaultSetWeightNonce - if recordedValidator, ok := s.recordedValidators[key]; ok { + if recordedValidator, ok := t.recordedValidators[key]; ok { startTime = recordedValidator.StartTime isActive = recordedValidator.IsActive setWeightNonce = recordedValidator.SetWeightNonce @@ -91,3 +92,15 @@ func (s *testValidatorState) GetCurrentValidatorSet(ctx context.Context, subnetI } return output, nil } + +func GetValidatorIDs(ctx context.Context, t TestValidatorState, subnetID ids.ID) ([]ids.NodeID, error) { + currentValidatorSet, err := t.GetCurrentValidatorSet(ctx, subnetID) + if err != nil { + return nil, fmt.Errorf("failed to get current validator set: %w", err) + } + validatorIDs := make([]ids.NodeID, 0, len(currentValidatorSet)) + for _, validator := range currentValidatorSet { + validatorIDs = append(validatorIDs, validator.NodeID) + } + return validatorIDs, nil +} diff --git a/plugin/evm/vm.go b/plugin/evm/vm.go index 89a1339ea3..6041c79cb9 100644 --- a/plugin/evm/vm.go +++ b/plugin/evm/vm.go @@ -76,6 +76,7 @@ import ( "github.com/ava-labs/avalanchego/utils/profiler" "github.com/ava-labs/avalanchego/utils/timer/mockable" "github.com/ava-labs/avalanchego/utils/units" + "github.com/ava-labs/avalanchego/version" "github.com/ava-labs/avalanchego/vms/components/chain" commonEng "github.com/ava-labs/avalanchego/snow/engine/common" @@ -706,13 +707,9 @@ func (vm *VM) onNormalOperationsStarted() error { vm.cancel = cancel vm.validatorState = NewTestValidatorState(vm.ctx.ValidatorState) - validatorsOutput, err := vm.validatorState.GetCurrentValidatorSet(ctx, vm.ctx.SubnetID) + vdrIDs, err := GetValidatorIDs(ctx, vm.validatorState, vm.ctx.SubnetID) if err != nil { - return fmt.Errorf("failed to get current validator set: %w", err) - } - var vdrIDs []ids.NodeID - for _, output := range validatorsOutput { - vdrIDs = append(vdrIDs, output.NodeID) + return fmt.Errorf("failed to get validator IDs: %w", err) } if err := vm.uptimeManager.StartTracking(vdrIDs); err != nil { return err @@ -844,7 +841,7 @@ func (vm *VM) setCrossChainAppRequestHandler() { } // Shutdown implements the snowman.ChainVM interface -func (vm *VM) Shutdown(context.Context) error { +func (vm *VM) Shutdown(ctx context.Context) error { if vm.ctx == nil { return nil } @@ -859,6 +856,20 @@ func (vm *VM) Shutdown(context.Context) error { vm.eth.Stop() log.Info("Ethereum backend stop completed") vm.shutdownWg.Wait() + if vm.bootstrapped.Get() { + // TODO: is this the correct ctx to use? + vdrIDs, err := GetValidatorIDs(ctx, vm.validatorState, vm.ctx.SubnetID) + if err != nil { + return fmt.Errorf("failed to get validator IDs: %w", err) + } + if err := vm.uptimeManager.StopTracking(vdrIDs); err != nil { + return err + } + // TODO: persist validator state + // if err := vm.state.Commit(); err != nil { + // return err + // } + } log.Info("Subnet-EVM Shutdown completed") return nil } @@ -1187,3 +1198,21 @@ func attachEthService(handler *rpc.Server, apis []rpc.API, names []string) error return nil } + +func (vm *VM) Connected(ctx context.Context, nodeID ids.NodeID, version *version.Application) error { + if err := vm.uptimeManager.Connect(nodeID); err != nil { + return err + } + return vm.Network.Connected(ctx, nodeID, version) +} + +func (vm *VM) Disconnected(ctx context.Context, nodeID ids.NodeID) error { + if err := vm.uptimeManager.Disconnect(nodeID); err != nil { + return err + } + // TODO: persist state here after disconnect + // if err := vm.state.Commit(); err != nil { + // return err + // } + return vm.Network.Disconnected(ctx, nodeID) +}