diff --git a/protofsm/daemon_events.go b/protofsm/daemon_events.go new file mode 100644 index 00000000000..e5de0b69517 --- /dev/null +++ b/protofsm/daemon_events.go @@ -0,0 +1,122 @@ +package protofsm + +import ( + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwire" +) + +// DaemonEvent is a special event that can be emitted by a state transition +// function. A state machine can use this to perform side effects, such as +// sending a message to a peer, or broadcasting a transaction. +type DaemonEvent interface { + daemonSealed() +} + +// DaemonEventSet is a set of daemon events that can be emitted by a state +// transition. +type DaemonEventSet []DaemonEvent + +// DaemonEvents is a special type constraint that enumerates all the possible +// types of daemon events. +type DaemonEvents interface { + SendMsgEvent[any] | BroadcastTxn | RegisterSpend[any] | + RegisterConf[any] +} + +// SendPredicate is a function that returns true if the target message should +// sent. +type SendPredicate = func() bool + +// SendMsgEvent is a special event that can be emitted by a state transition +// that instructs the daemon to send the contained message to the target peer. +type SendMsgEvent[Event any] struct { + // TargetPeer is the peer to send the message to. + TargetPeer btcec.PublicKey + + // Msgs is the set of messages to send to the target peer. + Msgs []lnwire.Message + + // SendWhen implements a system for a conditional send once a special + // send predicate has been met. + // + // TODO(roasbeef): contrast with usage of OnCommitFlush, etc + SendWhen fn.Option[SendPredicate] + + // PostSendEvent is an optional event that is to be emitted after the + // message has been sent. If a SendWhen is specified, then this will + // only be executed after that returns true to unblock the send. + PostSendEvent fn.Option[Event] +} + +// daemonSealed indicates that this struct is a DaemonEvent instance. +func (s *SendMsgEvent[E]) daemonSealed() {} + +// BroadcastTxn indicates the target transaction should be broadcast to the +// network. +type BroadcastTxn struct { + // Tx is the transaction to broadcast. + Tx *wire.MsgTx + + // Label is an optional label to attach to the transaction. + Label string +} + +// daemonSealed indicates that this struct is a DaemonEvent instance. +func (b *BroadcastTxn) daemonSealed() {} + +// SpendMapper is a function that's used to map a spend notification to a +// custom state machine event. +type SpendMapper[Event any] func(*chainntnfs.SpendDetail) Event + +// RegisterSpend is used to request that a certain event is sent into the state +// machine once the specified outpoint has been spent. +type RegisterSpend[Event any] struct { + // OutPoint is the outpoint on chain to watch. + OutPoint wire.OutPoint + + // PkScript is the script that we expect to be spent along with the + // outpoint. + PkScript []byte + + // HeightHint is a value used to give the chain scanner a hint on how + // far back it needs to start its search. + HeightHint uint32 + + // PostSpendEvent is a special spend mapper, that if present, will be + // used to map the protofsm spend event to a custom event. + PostSpendEvent fn.Option[SpendMapper[Event]] +} + +// daemonSealed indicates that this struct is a DaemonEvent instance. +func (r *RegisterSpend[E]) daemonSealed() {} + +// RegisterConf is used to request that a certain event is sent into the state +// machien once the specified outpoint has been spent. +type RegisterConf[Event any] struct { + // Txid is the txid of the txn we want to watch the chain for. + Txid chainhash.Hash + + // PkScript is the script that we expect to be created along with the + // outpoint. + PkScript []byte + + // HeightHint is a value used to give the chain scanner a hint on how + // far back it needs to start its search. + HeightHint uint32 + + // NumConfs is the number of confirmations that the spending + // transaction needs to dispatch an event. + NumConfs fn.Option[uint32] + + // PostConfEvent is an event that's sent back to the requester once the + // transaction specified above has confirmed in the chain with + // sufficient depth. + PostConfEvent fn.Option[Event] +} + +// daemonSealed indicates that this struct is a DaemonEvent instance. +func (r *RegisterConf[E]) daemonSealed() {} diff --git a/protofsm/log.go b/protofsm/log.go new file mode 100644 index 00000000000..8ff9c1b62f2 --- /dev/null +++ b/protofsm/log.go @@ -0,0 +1,29 @@ +package protofsm + +import ( + "github.com/btcsuite/btclog" + "github.com/lightningnetwork/lnd/build" +) + +// log is a logger that is initialized with no output filters. This +// means the package will not perform any logging by default until the caller +// requests it. +var log btclog.Logger + +// The default amount of logging is none. +func init() { + UseLogger(build.NewSubLogger("PFSM", nil)) +} + +// DisableLog disables all library log output. Logging output is disabled +// by default until UseLogger is called. +func DisableLog() { + UseLogger(btclog.Disabled) +} + +// UseLogger uses a specified Logger to output package logging info. +// This should be used in preference to SetLogWriter if the caller is also +// using btclog. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go new file mode 100644 index 00000000000..b96d677e6bb --- /dev/null +++ b/protofsm/msg_mapper.go @@ -0,0 +1,15 @@ +package protofsm + +import ( + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwire" +) + +// MsgMapper is used to map incoming wire messages into a FSM event. This is +// useful to decouple the translation of an outside or wire message into an +// event type that can be understood by the FSM. +type MsgMapper[Event any] interface { + // MapMsg maps a wire message into a FSM event. If the message is not + // mappable, then an None is returned. + MapMsg(msg lnwire.Message) fn.Option[Event] +} diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go new file mode 100644 index 00000000000..ecbd7483478 --- /dev/null +++ b/protofsm/state_machine.go @@ -0,0 +1,670 @@ +package protofsm + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnutils" + "github.com/lightningnetwork/lnd/lnwire" +) + +const ( + // pollInterval is the interval at which we'll poll the SendWhen + // predicate if specified. + pollInterval = time.Millisecond * 100 +) + +// EmittedEvent is a special type that can be emitted by a state transition. +// This can container internal events which are to be routed back to the state, +// or external events which are to be sent to the daemon. +type EmittedEvent[Event any] struct { + // InternalEvent is an optional internal event that is to be routed + // back to the target state. This enables state to trigger one or many + // state transitions without a new external event. + InternalEvent fn.Option[[]Event] + + // ExternalEvent is an optional external event that is to be sent to + // the daemon for dispatch. Usually, this is some form of I/O. + ExternalEvents fn.Option[DaemonEventSet] +} + +// StateTransition is a state transition type. It denotes the next state to go +// to, and also the set of events to emit. +type StateTransition[Event any, Env Environment] struct { + // NextState is the next state to transition to. + NextState State[Event, Env] + + // NewEvents is the set of events to emit. + NewEvents fn.Option[EmittedEvent[Event]] +} + +// Environment is an abstract interface that represents the environment that +// the state machine will execute using. From the PoV of the main state machine +// executor, we just care about being able to clean up any resources that were +// allocated by the environment. +type Environment interface { + // Name returns the name of the environment. This is used to uniquely + // identify the environment of related state machines. + Name() string +} + +// State defines an abstract state along, namely its state transition function +// that takes as input an event and an environment, and returns a state +// transition (next state, and set of events to emit). As state can also either +// be terminal, or not, a terminal event causes state execution to halt. +type State[Event any, Env Environment] interface { + // ProcessEvent takes an event and an environment, and returns a new + // state transition. This will be iteratively called until either a + // terminal state is reached, or no further internal events are + // emitted. + ProcessEvent(event Event, env Env) (*StateTransition[Event, Env], error) + + // IsTerminal returns true if this state is terminal, and false + // otherwise. + IsTerminal() bool + + // TODO(roasbeef): also add state serialization? +} + +// DaemonAdapters is a set of methods that server as adapters to bridge the +// pure world of the FSM to the real world of the daemon. These will be used to +// do things like broadcast transactions, or send messages to peers. +type DaemonAdapters interface { + // SendMessages sends the target set of messages to the target peer. + SendMessages(btcec.PublicKey, []lnwire.Message) error + + // BroadcastTransaction broadcasts a transaction with the target label. + BroadcastTransaction(*wire.MsgTx, string) error + + // RegisterConfirmationsNtfn registers an intent to be notified once + // txid reaches numConfs confirmations. We also pass in the pkScript as + // the default light client instead needs to match on scripts created + // in the block. If a nil txid is passed in, then not only should we + // match on the script, but we should also dispatch once the + // transaction containing the script reaches numConfs confirmations. + // This can be useful in instances where we only know the script in + // advance, but not the transaction containing it. + // + // TODO(roasbeef): could abstract further? + RegisterConfirmationsNtfn(txid *chainhash.Hash, pkScript []byte, + numConfs, heightHint uint32, + opts ...chainntnfs.NotifierOption, + ) (*chainntnfs.ConfirmationEvent, error) + + // RegisterSpendNtfn registers an intent to be notified once the target + // outpoint is successfully spent within a transaction. The script that + // the outpoint creates must also be specified. This allows this + // interface to be implemented by BIP 158-like filtering. + RegisterSpendNtfn(outpoint *wire.OutPoint, pkScript []byte, + heightHint uint32) (*chainntnfs.SpendEvent, error) +} + +// stateQuery is used by outside callers to query the internal state of the +// state machine. +type stateQuery[Event any, Env Environment] struct { + // CurrentState is a channel that will be sent the current state of the + // state machine. + CurrentState chan State[Event, Env] +} + +// StateMachine represents an abstract FSM that is able to process new incoming +// events and drive a state machine to termination. This implementation uses +// type params to abstract over the types of events and environment. Events +// trigger new state transitions, that use the environment to perform some +// action. +// +// TODO(roasbeef): terminal check, daemon event execution, init? +type StateMachine[Event any, Env Environment] struct { + cfg StateMachineCfg[Event, Env] + + // events is the channel that will be used to send new events to the + // FSM. + events chan Event + + // newStateEvents is an EventDistributor that will be used to notify + // any relevant callers of new state transitions that occur. + newStateEvents *fn.EventDistributor[State[Event, Env]] + + // stateQuery is a channel that will be used by outside callers to + // query the internal state machine state. + stateQuery chan stateQuery[Event, Env] + + wg fn.GoroutineManager + quit chan struct{} + + startOnce sync.Once + stopOnce sync.Once +} + +// ErrorReporter is an interface that's used to report errors that occur during +// state machine execution. +type ErrorReporter interface { + // ReportError is a method that's used to report an error that occurred + // during state machine execution. + ReportError(err error) +} + +// StateMachineCfg is a configuration struct that's used to create a new state +// machine. +type StateMachineCfg[Event any, Env Environment] struct { + // ErrorReporter is used to report errors that occur during state + // transitions. + ErrorReporter ErrorReporter + + // Daemon is a set of adapters that will be used to bridge the FSM to + // the daemon. + Daemon DaemonAdapters + + // InitialState is the initial state of the state machine. + InitialState State[Event, Env] + + // Env is the environment that the state machine will use to execute. + Env Env + + // InitEvent is an optional event that will be sent to the state + // machine as if it was emitted at the onset of the state machine. This + // can be used to set up tracking state such as a txid confirmation + // event. + InitEvent fn.Option[DaemonEvent] + + // MsgMapper is an optional message mapper that can be used to map + // normal wire messages into FSM events. + MsgMapper fn.Option[MsgMapper[Event]] + + // CustomPollInterval is an optional custom poll interval that can be + // used to set a quicker interval for tests. + CustomPollInterval fn.Option[time.Duration] +} + +// NewStateMachine creates a new state machine given a set of daemon adapters, +// an initial state, an environment, and an event to process as if emitted at +// the onset of the state machine. Such an event can be used to set up tracking +// state such as a txid confirmation event. +func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:lll +) StateMachine[Event, Env] { + + return StateMachine[Event, Env]{ + cfg: cfg, + events: make(chan Event, 1), + stateQuery: make(chan stateQuery[Event, Env]), + wg: *fn.NewGoroutineManager(context.Background()), + newStateEvents: fn.NewEventDistributor[State[Event, Env]](), + quit: make(chan struct{}), + } +} + +// Start starts the state machine. This will spawn a goroutine that will drive +// the state machine to completion. +func (s *StateMachine[Event, Env]) Start() { + s.startOnce.Do(func() { + _ = s.wg.Go(func(ctx context.Context) { + s.driveMachine() + }) + }) +} + +// Stop stops the state machine. This will block until the state machine has +// reached a stopping point. +func (s *StateMachine[Event, Env]) Stop() { + s.stopOnce.Do(func() { + close(s.quit) + s.wg.Stop() + }) +} + +// SendEvent sends a new event to the state machine. +// +// TODO(roasbeef): bool if processed? +func (s *StateMachine[Event, Env]) SendEvent(event Event) { + log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(), + lnutils.SpewLogClosure(event), + ) + + select { + case s.events <- event: + case <-s.quit: + return + } +} + +// CanHandle returns true if the target message can be routed to the state +// machine. +func (s *StateMachine[Event, Env]) CanHandle(msg lnwire.Message) bool { + cfgMapper := s.cfg.MsgMapper + return fn.MapOptionZ(cfgMapper, func(mapper MsgMapper[Event]) bool { + return mapper.MapMsg(msg).IsSome() + }) +} + +// Name returns the name of the state machine's environment. +func (s *StateMachine[Event, Env]) Name() string { + return s.cfg.Env.Name() +} + +// SendMessage attempts to send a wire message to the state machine. If the +// message can be mapped using the default message mapper, then true is +// returned indicating that the message was processed. Otherwise, false is +// returned. +func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { + // If we have no message mapper, then return false as we can't process + // this message. + if !s.cfg.MsgMapper.IsSome() { + return false + } + + log.Debugf("FSM(%v): sending msg: %v", s.cfg.Env.Name(), + lnutils.SpewLogClosure(msg), + ) + + // Otherwise, try to map the message using the default message mapper. + // If we can't extract an event, then we'll return false to indicate + // that the message wasn't processed. + var processed bool + s.cfg.MsgMapper.WhenSome(func(mapper MsgMapper[Event]) { + event := mapper.MapMsg(msg) + + event.WhenSome(func(event Event) { + s.SendEvent(event) + + processed = true + }) + }) + + return processed +} + +// CurrentState returns the current state of the state machine. +func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) { + query := stateQuery[Event, Env]{ + CurrentState: make(chan State[Event, Env], 1), + } + + if !fn.SendOrQuit(s.stateQuery, query, s.quit) { + return nil, fmt.Errorf("state machine is shutting down") + } + + return fn.RecvOrTimeout(query.CurrentState, time.Second) +} + +// StateSubscriber represents an active subscription to be notified of new +// state transitions. +type StateSubscriber[E any, F Environment] *fn.EventReceiver[State[E, F]] + +// RegisterStateEvents registers a new event listener that will be notified of +// new state transitions. +func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[ + Event, Env] { + + subscriber := fn.NewEventReceiver[State[Event, Env]](10) + + // TODO(roasbeef): instead give the state and the input event? + + s.newStateEvents.RegisterSubscriber(subscriber) + + return subscriber +} + +// RemoveStateSub removes the target state subscriber from the set of active +// subscribers. +func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ + Event, Env]) { + + _ = s.newStateEvents.RemoveSubscriber(sub) +} + +// executeDaemonEvent executes a daemon event, which is a special type of event +// that can be emitted as part of the state transition function of the state +// machine. An error is returned if the type of event is unknown. +func (s *StateMachine[Event, Env]) executeDaemonEvent( + event DaemonEvent) error { + + switch daemonEvent := event.(type) { + // This is a send message event, so we'll send the event, and also mind + // any preconditions as well as post-send events. + case *SendMsgEvent[Event]: + sendAndCleanUp := func() error { + log.Debugf("FSM(%v): sending message to target(%x): "+ + "%v", s.cfg.Env.Name(), + daemonEvent.TargetPeer.SerializeCompressed(), + lnutils.SpewLogClosure(daemonEvent.Msgs), + ) + + err := s.cfg.Daemon.SendMessages( + daemonEvent.TargetPeer, daemonEvent.Msgs, + ) + if err != nil { + return fmt.Errorf("unable to send msgs: %w", + err) + } + + // If a post-send event was specified, then we'll funnel + // that back into the main state machine now as well. + return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll + return s.wg.Go(func(ctx context.Context) { + log.Debugf("FSM(%v): sending "+ + "post-send event: %v", + s.cfg.Env.Name(), + lnutils.SpewLogClosure(event), + ) + + s.SendEvent(event) + }) + }) + } + + // If this doesn't have a SendWhen predicate, then we can just + // send it off right away. + if !daemonEvent.SendWhen.IsSome() { + return sendAndCleanUp() + } + + // Otherwise, this has a SendWhen predicate, so we'll need + // launch a goroutine to poll the SendWhen, then send only once + // the predicate is true. + return s.wg.Go(func(ctx context.Context) { + predicateTicker := time.NewTicker( + s.cfg.CustomPollInterval.UnwrapOr(pollInterval), + ) + defer predicateTicker.Stop() + + log.Infof("FSM(%v): waiting for send predicate to "+ + "be true", s.cfg.Env.Name()) + + for { + select { + case <-predicateTicker.C: + canSend := fn.MapOptionZ( + daemonEvent.SendWhen, + func(pred SendPredicate) bool { + return pred() + }, + ) + + if canSend { + log.Infof("FSM(%v): send "+ + "active predicate", + s.cfg.Env.Name()) + + err := sendAndCleanUp() + if err != nil { + //nolint:lll + log.Errorf("FSM(%v): unable to send message: %v", err) + } + + return + } + + case <-ctx.Done(): + return + } + } + }) + + // If this is a broadcast transaction event, then we'll broadcast with + // the label attached. + case *BroadcastTxn: + log.Debugf("FSM(%v): broadcasting txn, txid=%v", + s.cfg.Env.Name(), daemonEvent.Tx.TxHash()) + + err := s.cfg.Daemon.BroadcastTransaction( + daemonEvent.Tx, daemonEvent.Label, + ) + if err != nil { + return fmt.Errorf("unable to broadcast txn: %w", err) + } + + return nil + + // The state machine has requested a new event to be sent once a + // transaction spending a specified outpoint has confirmed. + case *RegisterSpend[Event]: + log.Debugf("FSM(%v): registering spend: %v", s.cfg.Env.Name(), + daemonEvent.OutPoint) + + spendEvent, err := s.cfg.Daemon.RegisterSpendNtfn( + &daemonEvent.OutPoint, daemonEvent.PkScript, + daemonEvent.HeightHint, + ) + if err != nil { + return fmt.Errorf("unable to register spend: %w", err) + } + + return s.wg.Go(func(ctx context.Context) { + for { + select { + case spend, ok := <-spendEvent.Spend: + if !ok { + return + } + + // If there's a post-send event, then + // we'll send that into the current + // state now. + postSpend := daemonEvent.PostSpendEvent + postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:lll + customEvent := f(spend) + s.SendEvent(customEvent) + }) + + return + + case <-ctx.Done(): + return + } + } + }) + + // The state machine has requested a new event to be sent once a + // specified txid+pkScript pair has confirmed. + case *RegisterConf[Event]: + log.Debugf("FSM(%v): registering conf: %v", s.cfg.Env.Name(), + daemonEvent.Txid) + + numConfs := daemonEvent.NumConfs.UnwrapOr(1) + confEvent, err := s.cfg.Daemon.RegisterConfirmationsNtfn( + &daemonEvent.Txid, daemonEvent.PkScript, + numConfs, daemonEvent.HeightHint, + ) + if err != nil { + return fmt.Errorf("unable to register conf: %w", err) + } + + return s.wg.Go(func(ctx context.Context) { + for { + select { + case <-confEvent.Confirmed: + // If there's a post-conf event, then + // we'll send that into the current + // state now. + // + // TODO(roasbeef): refactor to + // dispatchAfterRecv w/ above + postConf := daemonEvent.PostConfEvent + postConf.WhenSome(func(e Event) { + s.SendEvent(e) + }) + + return + + case <-ctx.Done(): + return + } + } + }) + } + + return fmt.Errorf("unknown daemon event: %T", event) +} + +// applyEvents applies a new event to the state machine. This will continue +// until no further events are emitted by the state machine. Along the way, +// we'll also ensure to execute any daemon events that are emitted. +func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], + newEvent Event) (State[Event, Env], error) { + + log.Debugf("FSM(%v): applying new event", s.cfg.Env.Name(), + lnutils.SpewLogClosure(newEvent), + ) + eventQueue := fn.NewQueue(newEvent) + + // Given the next event to handle, we'll process the event, then add + // any new emitted internal events to our event queue. This continues + // until we reach a terminal state, or we run out of internal events to + // process. + // + //nolint:lll + for nextEvent := eventQueue.Dequeue(); nextEvent.IsSome(); nextEvent = eventQueue.Dequeue() { + err := fn.MapOptionZ(nextEvent, func(event Event) error { + log.Debugf("FSM(%v): processing event: %v", + s.cfg.Env.Name(), + lnutils.SpewLogClosure(event), + ) + + // Apply the state transition function of the current + // state given this new event and our existing env. + transition, err := currentState.ProcessEvent( + event, s.cfg.Env, + ) + if err != nil { + return err + } + + newEvents := transition.NewEvents + err = fn.MapOptionZ(newEvents, func(events EmittedEvent[Event]) error { //nolint:lll + // With the event processed, we'll process any + // new daemon events that were emitted as part + // of this new state transition. + // + //nolint:lll + err := fn.MapOptionZ(events.ExternalEvents, func(dEvents DaemonEventSet) error { + log.Debugf("FSM(%v): processing "+ + "daemon %v daemon events", + s.cfg.Env.Name(), len(dEvents)) + + for _, dEvent := range dEvents { + err := s.executeDaemonEvent( + dEvent, + ) + if err != nil { + return err + } + } + + return nil + }) + if err != nil { + return err + } + + // Next, we'll add any new emitted events to + // our event queue. + // + //nolint:lll + events.InternalEvent.WhenSome(func(es []Event) { + for _, inEvent := range es { + log.Debugf("FSM(%v): adding "+ + "new internal event "+ + "to queue: %v", + s.cfg.Env.Name(), + lnutils.SpewLogClosure( + inEvent, + ), + ) + + eventQueue.Enqueue(inEvent) + } + }) + + return nil + }) + if err != nil { + return err + } + + log.Infof("FSM(%v): state transition: from_state=%T, "+ + "to_state=%T", + s.cfg.Env.Name(), currentState, + transition.NextState) + + // With our events processed, we'll now update our + // internal state. + currentState = transition.NextState + + // Notify our subscribers of the new state transition. + // + // TODO(roasbeef): will only give us the outer state? + // * let FSMs choose which state to emit? + s.newStateEvents.NotifySubscribers(currentState) + + return nil + }) + if err != nil { + return currentState, err + } + } + + return currentState, nil +} + +// driveMachine is the main event loop of the state machine. It accepts any new +// incoming events, and then drives the state machine forward until it reaches +// a terminal state. +func (s *StateMachine[Event, Env]) driveMachine() { + log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name()) + + currentState := s.cfg.InitialState + + // Before we start, if we have an init daemon event specified, then + // we'll handle that now. + err := fn.MapOptionZ(s.cfg.InitEvent, func(event DaemonEvent) error { + return s.executeDaemonEvent(event) + }) + if err != nil { + log.Errorf("unable to execute init event: %w", err) + return + } + + // We just started driving the state machine, so we'll notify our + // subscribers of this starting state. + s.newStateEvents.NotifySubscribers(currentState) + + for { + select { + // We have a new external event, so we'll drive the state + // machine forward until we either run out of internal events, + // or we reach a terminal state. + case newEvent := <-s.events: + newState, err := s.applyEvents(currentState, newEvent) + if err != nil { + s.cfg.ErrorReporter.ReportError(err) + + log.Errorf("unable to apply event: %v", err) + + // An error occurred, so we'll tear down the + // entire state machine as we can't proceed. + go s.Stop() + + return + } + + currentState = newState + + // An outside caller is querying our state, so we'll return the + // latest state. + case stateQuery := <-s.stateQuery: + if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll + return + } + + case <-s.wg.Done(): + return + } + } +} diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go new file mode 100644 index 00000000000..0432f386b70 --- /dev/null +++ b/protofsm/state_machine_test.go @@ -0,0 +1,456 @@ +package protofsm + +import ( + "encoding/hex" + "fmt" + "sync/atomic" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +type dummyEvents interface { + dummy() +} + +type goToFin struct { +} + +func (g *goToFin) dummy() { +} + +type emitInternal struct { +} + +func (e *emitInternal) dummy() { +} + +type daemonEvents struct { +} + +func (s *daemonEvents) dummy() { +} + +type dummyEnv struct { + mock.Mock +} + +func (d *dummyEnv) Name() string { + return "test" +} + +type dummyStateStart struct { + canSend *atomic.Bool +} + +var ( + hexDecode = func(keyStr string) []byte { + keyBytes, _ := hex.DecodeString(keyStr) + return keyBytes + } + pub1, _ = btcec.ParsePubKey(hexDecode( + "02ec95e4e8ad994861b95fc5986eedaac24739e5ea3d0634db4c8ccd44cd" + + "a126ea", + )) + pub2, _ = btcec.ParsePubKey(hexDecode( + "0356167ba3e54ac542e86e906d4186aba9ca0b9df45001c62b753d33fe06" + + "f5b4e8", + )) +) + +func (d *dummyStateStart) ProcessEvent(event dummyEvents, env *dummyEnv, +) (*StateTransition[dummyEvents, *dummyEnv], error) { + + switch event.(type) { + case *goToFin: + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: &dummyStateFin{}, + }, nil + + // This state will loop back upon itself, but will also emit an event + // to head to the terminal state. + case *emitInternal: + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: &dummyStateStart{}, + NewEvents: fn.Some(EmittedEvent[dummyEvents]{ + InternalEvent: fn.Some( + []dummyEvents{&goToFin{}}, + ), + }), + }, nil + + // This state will proceed to the terminal state, but will emit all the + // possible daemon events. + case *daemonEvents: + // This send event can only succeed once the bool turns to + // true. After that, then we'll expect another event to take us + // to the final state. + sendEvent := &SendMsgEvent[dummyEvents]{ + TargetPeer: *pub1, + SendWhen: fn.Some(func() bool { + return d.canSend.Load() + }), + PostSendEvent: fn.Some(dummyEvents(&goToFin{})), + } + + // We'll also send out a normal send event that doesn't have + // any preconditions. + sendEvent2 := &SendMsgEvent[dummyEvents]{ + TargetPeer: *pub2, + } + + return &StateTransition[dummyEvents, *dummyEnv]{ + // We'll state in this state until the send succeeds + // based on our predicate. Then it'll transition to the + // final state. + NextState: &dummyStateStart{ + canSend: d.canSend, + }, + NewEvents: fn.Some(EmittedEvent[dummyEvents]{ + ExternalEvents: fn.Some(DaemonEventSet{ + sendEvent, sendEvent2, + &BroadcastTxn{ + Tx: &wire.MsgTx{}, + Label: "test", + }, + }), + }), + }, nil + } + + return nil, fmt.Errorf("unknown event: %T", event) +} + +func (d *dummyStateStart) IsTerminal() bool { + return false +} + +type dummyStateFin struct { +} + +func (d *dummyStateFin) ProcessEvent(event dummyEvents, env *dummyEnv, +) (*StateTransition[dummyEvents, *dummyEnv], error) { + + return &StateTransition[dummyEvents, *dummyEnv]{ + NextState: &dummyStateFin{}, + }, nil +} + +func (d *dummyStateFin) IsTerminal() bool { + return true +} + +func assertState[Event any, Env Environment](t *testing.T, + m *StateMachine[Event, Env], expectedState State[Event, Env]) { + + state, err := m.CurrentState() + require.NoError(t, err) + require.IsType(t, expectedState, state) +} + +func assertStateTransitions[Event any, Env Environment]( + t *testing.T, stateSub StateSubscriber[Event, Env], + expectedStates []State[Event, Env]) { + + for _, expectedState := range expectedStates { + newState := <-stateSub.NewItemCreated.ChanOut() + + require.IsType(t, expectedState, newState) + } +} + +type dummyAdapters struct { + mock.Mock + + confChan chan *chainntnfs.TxConfirmation + spendChan chan *chainntnfs.SpendDetail +} + +func newDaemonAdapters() *dummyAdapters { + return &dummyAdapters{ + confChan: make(chan *chainntnfs.TxConfirmation, 1), + spendChan: make(chan *chainntnfs.SpendDetail, 1), + } +} + +func (d *dummyAdapters) SendMessages(pub btcec.PublicKey, + msgs []lnwire.Message) error { + + args := d.Called(pub, msgs) + + return args.Error(0) +} + +func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx, + label string) error { + + args := d.Called(tx, label) + + return args.Error(0) +} + +func (d *dummyAdapters) RegisterConfirmationsNtfn(txid *chainhash.Hash, + pkScript []byte, numConfs, heightHint uint32, + opts ...chainntnfs.NotifierOption, +) (*chainntnfs.ConfirmationEvent, error) { + + args := d.Called(txid, pkScript, numConfs) + + err := args.Error(0) + + return &chainntnfs.ConfirmationEvent{ + Confirmed: d.confChan, + }, err +} + +func (d *dummyAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint, + pkScript []byte, heightHint uint32) (*chainntnfs.SpendEvent, error) { + + args := d.Called(outpoint, pkScript, heightHint) + + err := args.Error(0) + + return &chainntnfs.SpendEvent{ + Spend: d.spendChan, + }, err +} + +// TestStateMachineOnInitDaemonEvent tests that the state machine will properly +// execute any init-level daemon events passed into it. +func TestStateMachineOnInitDaemonEvent(t *testing.T) { + // First, we'll create our state machine given the env, and our + // starting state. + env := &dummyEnv{} + startingState := &dummyStateStart{} + + adapters := newDaemonAdapters() + + // We'll make an init event that'll send to a peer, then transition us + // to our terminal state. + initEvent := &SendMsgEvent[dummyEvents]{ + TargetPeer: *pub1, + PostSendEvent: fn.Some(dummyEvents(&goToFin{})), + } + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + InitEvent: fn.Some[DaemonEvent](initEvent), + } + stateMachine := NewStateMachine(cfg) + + // Before we start up the state machine, we'll assert that the send + // message adapter is called on start up. + adapters.On("SendMessages", *pub1, mock.Anything).Return(nil) + + stateMachine.Start() + defer stateMachine.Stop() + + // As we're triggering internal events, we'll also subscribe to the set + // of new states so we can assert as we go. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + // Assert that we go from the starting state to the final state. The + // state machine should now also be on the final terminal state. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateStart{}, &dummyStateFin{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + // We'll now assert that after the daemon was started, the send message + // adapter was called above as specified in the init event. + adapters.AssertExpectations(t) + env.AssertExpectations(t) +} + +// TestStateMachineInternalEvents tests that the state machine is able to add +// new internal events to the event queue for further processing during a state +// transition. +func TestStateMachineInternalEvents(t *testing.T) { + t.Parallel() + + // First, we'll create our state machine given the env, and our + // starting state. + env := &dummyEnv{} + startingState := &dummyStateStart{} + + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + InitEvent: fn.None[DaemonEvent](), + } + stateMachine := NewStateMachine(cfg) + stateMachine.Start() + defer stateMachine.Stop() + + // As we're triggering internal events, we'll also subscribe to the set + // of new states so we can assert as we go. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + // For this transition, we'll send in the emitInternal event, which'll + // send us back to the starting event, but emit an internal event. + stateMachine.SendEvent(&emitInternal{}) + + // We'll now also assert the path we took to get here to ensure the + // internal events were processed. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateStart{}, &dummyStateStart{}, &dummyStateFin{}, + } + assertStateTransitions( + t, stateSub, expectedStates, + ) + + // We should ultimately end up in the terminal state. + assertState[dummyEvents, *dummyEnv](t, &stateMachine, &dummyStateFin{}) + + // Make sure all the env expectations were met. + env.AssertExpectations(t) +} + +// TestStateMachineDaemonEvents tests that the state machine is able to process +// daemon emitted as part of the state transition process. +func TestStateMachineDaemonEvents(t *testing.T) { + t.Parallel() + + // First, we'll create our state machine given the env, and our + // starting state. + env := &dummyEnv{} + + var boolTrigger atomic.Bool + startingState := &dummyStateStart{ + canSend: &boolTrigger, + } + + adapters := newDaemonAdapters() + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + InitEvent: fn.None[DaemonEvent](), + } + stateMachine := NewStateMachine(cfg) + stateMachine.Start() + defer stateMachine.Stop() + + // As we're triggering internal events, we'll also subscribe to the set + // of new states so we can assert as we go. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + // As soon as we send in the daemon event, we expect the + // disable+broadcast events to be processed, as they are unconditional. + adapters.On( + "BroadcastTransaction", mock.Anything, mock.Anything, + ).Return(nil) + adapters.On("SendMessages", *pub2, mock.Anything).Return(nil) + + // We'll start off by sending in the daemon event, which'll trigger the + // state machine to execute the series of daemon events. + stateMachine.SendEvent(&daemonEvents{}) + + // We should transition back to the starting state now, after we + // started from the very same state. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateStart{}, &dummyStateStart{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + // At this point, we expect that the two methods above were called. + adapters.AssertExpectations(t) + + // However, we don't expect the SendMessages for the first peer target + // to be called yet, as the condition hasn't yet been met. + adapters.AssertNotCalled(t, "SendMessages", *pub1) + + // We'll now flip the bool to true, which should cause the SendMessages + // method to be called, and for us to transition to the final state. + boolTrigger.Store(true) + adapters.On("SendMessages", *pub1, mock.Anything).Return(nil) + + expectedStates = []State[dummyEvents, *dummyEnv]{&dummyStateFin{}} + assertStateTransitions(t, stateSub, expectedStates) + + adapters.AssertExpectations(t) + env.AssertExpectations(t) +} + +type dummyMsgMapper struct { + mock.Mock +} + +func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] { + args := d.Called(wireMsg) + + //nolint:forcetypeassert + return args.Get(0).(fn.Option[dummyEvents]) +} + +// TestStateMachineMsgMapper tests that given a message mapper, we can properly +// send in wire messages get mapped to FSM events. +func TestStateMachineMsgMapper(t *testing.T) { + // First, we'll create our state machine given the env, and our + // starting state. + env := &dummyEnv{} + startingState := &dummyStateStart{} + adapters := newDaemonAdapters() + + // We'll also provide a message mapper that only knows how to map a + // single wire message (error). + dummyMapper := &dummyMsgMapper{} + + // The only thing we know how to map is the error message, which'll + // terminate the state machine. + wireError := &lnwire.Error{} + initMsg := &lnwire.Init{} + dummyMapper.On("MapMsg", wireError).Return( + fn.Some(dummyEvents(&goToFin{})), + ) + dummyMapper.On("MapMsg", initMsg).Return(fn.None[dummyEvents]()) + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + MsgMapper: fn.Some[MsgMapper[dummyEvents]](dummyMapper), + } + stateMachine := NewStateMachine(cfg) + stateMachine.Start() + defer stateMachine.Stop() + + // As we're triggering internal events, we'll also subscribe to the set + // of new states so we can assert as we go. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + // First, we'll verify that the CanHandle method works as expected. + require.True(t, stateMachine.CanHandle(wireError)) + require.False(t, stateMachine.CanHandle(&lnwire.Init{})) + + // Next, we'll attempt to send the wire message into the state machine. + // We should transition to the final state. + require.True(t, stateMachine.SendMessage(wireError)) + + // We should transition to the final state. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateStart{}, &dummyStateFin{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + dummyMapper.AssertExpectations(t) + adapters.AssertExpectations(t) + env.AssertExpectations(t) +}