diff --git a/cbor_gen.go b/cbor_gen.go new file mode 100644 index 00000000..adfad6f5 --- /dev/null +++ b/cbor_gen.go @@ -0,0 +1,122 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package f3 + +import ( + "fmt" + "io" + "math" + "sort" + + gpbft "github.com/filecoin-project/go-f3/gpbft" + cid "github.com/ipfs/go-cid" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf +var _ = cid.Undef +var _ = math.E +var _ = sort.Sort + +var lengthBufPartialGMessage = []byte{130} + +func (t *PartialGMessage) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write(lengthBufPartialGMessage); err != nil { + return err + } + + // t.GMessage (gpbft.GMessage) (struct) + if err := t.GMessage.MarshalCBOR(cw); err != nil { + return err + } + + // t.VoteValueKey (chainexchange.Key) (slice) + if len(t.VoteValueKey) > 32 { + return xerrors.Errorf("Byte array in field t.VoteValueKey was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajByteString, uint64(len(t.VoteValueKey))); err != nil { + return err + } + + if _, err := cw.Write(t.VoteValueKey); err != nil { + return err + } + + return nil +} + +func (t *PartialGMessage) UnmarshalCBOR(r io.Reader) (err error) { + *t = PartialGMessage{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajArray { + return fmt.Errorf("cbor input should be of type array") + } + + if extra != 2 { + return fmt.Errorf("cbor input had wrong number of fields") + } + + // t.GMessage (gpbft.GMessage) (struct) + + { + + b, err := cr.ReadByte() + if err != nil { + return err + } + if b != cbg.CborNull[0] { + if err := cr.UnreadByte(); err != nil { + return err + } + t.GMessage = new(gpbft.GMessage) + if err := t.GMessage.UnmarshalCBOR(cr); err != nil { + return xerrors.Errorf("unmarshaling t.GMessage pointer: %w", err) + } + } + + } + // t.VoteValueKey (chainexchange.Key) (slice) + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > 32 { + return fmt.Errorf("t.VoteValueKey: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.VoteValueKey = make([]uint8, extra) + } + + if _, err := io.ReadFull(cr, t.VoteValueKey); err != nil { + return err + } + + return nil +} diff --git a/chainexchange/pubsub.go b/chainexchange/pubsub.go index a16feb15..83f15ca5 100644 --- a/chainexchange/pubsub.go +++ b/chainexchange/pubsub.go @@ -32,11 +32,12 @@ type PubSubChainExchange struct { *options // mu guards access to chains and API calls. 
- mu sync.Mutex - chainsWanted map[uint64]*lru.Cache[string, *chainPortion] - chainsDiscovered map[uint64]*lru.Cache[string, *chainPortion] - topic *pubsub.Topic - stop func() error + mu sync.Mutex + chainsWanted map[uint64]*lru.Cache[string, *chainPortion] + chainsDiscovered map[uint64]*lru.Cache[string, *chainPortion] + pendingCacheAsWanted chan Message + topic *pubsub.Topic + stop func() error } func NewPubSubChainExchange(o ...Option) (*PubSubChainExchange, error) { @@ -45,9 +46,10 @@ func NewPubSubChainExchange(o ...Option) (*PubSubChainExchange, error) { return nil, err } return &PubSubChainExchange{ - options: opts, - chainsWanted: map[uint64]*lru.Cache[string, *chainPortion]{}, - chainsDiscovered: map[uint64]*lru.Cache[string, *chainPortion]{}, + options: opts, + chainsWanted: map[uint64]*lru.Cache[string, *chainPortion]{}, + chainsDiscovered: map[uint64]*lru.Cache[string, *chainPortion]{}, + pendingCacheAsWanted: make(chan Message, 100), // TODO: parameterise. }, nil } @@ -64,7 +66,9 @@ func (p *PubSubChainExchange) Start(ctx context.Context) error { } if p.topicScoreParams != nil { if err := p.topic.SetScoreParams(p.topicScoreParams); err != nil { - return fmt.Errorf("failed to set score params: %w", err) + // This can happen most likely due to router not supporting peer scoring. It's + // non-critical. Hence, the warning log. + log.Warnw("failed to set topic score params", "err", err) } } subscription, err := p.topic.Subscribe(pubsub.WithBufferSize(p.subscriptionBufferSize)) @@ -79,17 +83,31 @@ func (p *PubSubChainExchange) Start(ctx context.Context) error { for ctx.Err() == nil { msg, err := subscription.Next(ctx) if err != nil { - log.Debugw("failed to read nex message from subscription", "err", err) + log.Debugw("failed to read next message from subscription", "err", err) continue } cmsg := msg.ValidatorData.(Message) p.cacheAsDiscoveredChain(ctx, cmsg) } + log.Debug("Stopped reading messages from chainexchange subscription.") + }() + go func() { + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case cmsg := <-p.pendingCacheAsWanted: + p.cacheAsWantedChain(ctx, cmsg) + } + } + log.Debug("Stopped caching chains as wanted.") }() p.stop = func() error { cancel() subscription.Cancel() - return p.topic.Close() + _ = p.pubsub.UnregisterTopicValidator(p.topicName) + _ = p.topic.Close() + return nil } return nil } @@ -124,21 +142,18 @@ func (p *PubSubChainExchange) GetChainByInstance(ctx context.Context, instance u cacheKey := string(key) // Check wanted keys first. - p.mu.Lock() + wanted := p.getChainsWantedAt(instance) - p.mu.Unlock() if portion, found := wanted.Get(cacheKey); found && !portion.IsPlaceholder() { return portion.chain, true } // Check if the chain for the key is discovered. - p.mu.Lock() discovered := p.getChainsDiscoveredAt(instance) if portion, found := discovered.Get(cacheKey); found { // Add it to the wanted cache and remove it from the discovered cache. wanted.Add(cacheKey, portion) discovered.Remove(cacheKey) - p.mu.Unlock() chain := portion.chain if p.listener != nil { @@ -147,7 +162,6 @@ func (p *PubSubChainExchange) GetChainByInstance(ctx context.Context, instance u // TODO: Do we want to pull all the suffixes of the chain into wanted cache? return chain, true } - p.mu.Unlock() // Otherwise, add a placeholder for the wanted key as a way to prioritise its // retention via LRU recent-ness. 
@@ -156,6 +170,8 @@ func (p *PubSubChainExchange) GetChainByInstance(ctx context.Context, instance u } func (p *PubSubChainExchange) getChainsWantedAt(instance uint64) *lru.Cache[string, *chainPortion] { + p.mu.Lock() + defer p.mu.Unlock() wanted, exists := p.chainsWanted[instance] if !exists { wanted = p.newChainPortionCache(p.maxWantedChainsPerInstance) @@ -165,6 +181,8 @@ func (p *PubSubChainExchange) getChainsWantedAt(instance uint64) *lru.Cache[stri } func (p *PubSubChainExchange) getChainsDiscoveredAt(instance uint64) *lru.Cache[string, *chainPortion] { + p.mu.Lock() + defer p.mu.Unlock() discovered, exists := p.chainsDiscovered[instance] if !exists { discovered = p.newChainPortionCache(p.maxDiscoveredChainsPerInstance) @@ -208,8 +226,6 @@ func (p *PubSubChainExchange) validatePubSubMessage(_ context.Context, _ peer.ID } func (p *PubSubChainExchange) cacheAsDiscoveredChain(ctx context.Context, cmsg Message) { - p.mu.Lock() - defer p.mu.Unlock() wanted := p.getChainsDiscoveredAt(cmsg.Instance) discovered := p.getChainsDiscoveredAt(cmsg.Instance) @@ -245,7 +261,13 @@ func (p *PubSubChainExchange) cacheAsDiscoveredChain(ctx context.Context, cmsg M func (p *PubSubChainExchange) Broadcast(ctx context.Context, msg Message) error { // Optimistically cache the broadcast chain and all of its prefixes as wanted. - p.cacheAsWantedChain(ctx, msg) + select { + case p.pendingCacheAsWanted <- msg: + case <-ctx.Done(): + return ctx.Err() + default: + log.Warnw("Dropping wanted cache entry. Chain exchange is too slow to process chains as wanted", "msg", msg) + } // TODO: integrate zstd compression. var buf bytes.Buffer @@ -266,7 +288,6 @@ type discovery struct { func (p *PubSubChainExchange) cacheAsWantedChain(ctx context.Context, cmsg Message) { var notifications []discovery - p.mu.Lock() wanted := p.getChainsWantedAt(cmsg.Instance) for offset := len(cmsg.Chain); offset >= 0 && ctx.Err() == nil; offset-- { // TODO: Expose internals of merkle.go so that keys can be generated @@ -290,7 +311,6 @@ func (p *PubSubChainExchange) cacheAsWantedChain(ctx context.Context, cmsg Messa // been evicted from the cache or not. This should be cheap enough considering the // added complexity of tracking evictions relative to chain prefixes. } - p.mu.Unlock() // Notify the listener outside the lock. 
if p.listener != nil { diff --git a/chainexchange/pubsub_test.go b/chainexchange/pubsub_test.go index 424384f5..b6ce2ba3 100644 --- a/chainexchange/pubsub_test.go +++ b/chainexchange/pubsub_test.go @@ -2,6 +2,8 @@ package chainexchange_test import ( "context" + "slices" + "sync" "testing" "time" @@ -52,32 +54,38 @@ func TestPubSubChainExchange_Broadcast(t *testing.T) { chain, found := subject.GetChainByInstance(ctx, instance, key) require.False(t, found) require.Nil(t, chain) - require.Empty(t, testListener.notifications) + require.Empty(t, testListener.getNotifications()) require.NoError(t, subject.Broadcast(ctx, chainexchange.Message{ Instance: instance, Chain: ecChain, })) - chain, found = subject.GetChainByInstance(ctx, instance, key) - require.True(t, found) + require.Eventually(t, func() bool { + chain, found = subject.GetChainByInstance(ctx, instance, key) + return found + }, time.Second, 100*time.Millisecond) require.Equal(t, ecChain, chain) baseChain := ecChain.BaseChain() baseKey := subject.Key(baseChain) - chain, found = subject.GetChainByInstance(ctx, instance, baseKey) - require.True(t, found) + require.Eventually(t, func() bool { + chain, found = subject.GetChainByInstance(ctx, instance, baseKey) + return found + }, time.Second, 100*time.Millisecond) require.Equal(t, baseChain, chain) // Assert that we have received 2 notifications, because ecChain has 2 tipsets. // First should be the ecChain, second should be the baseChain. - require.Len(t, testListener.notifications, 2) - require.Equal(t, instance, testListener.notifications[1].instance) - require.Equal(t, baseKey, testListener.notifications[1].key) - require.Equal(t, baseChain, testListener.notifications[1].chain) - require.Equal(t, instance, testListener.notifications[0].instance) - require.Equal(t, key, testListener.notifications[0].key) - require.Equal(t, ecChain, testListener.notifications[0].chain) + + notifications := testListener.getNotifications() + require.Len(t, notifications, 2) + require.Equal(t, instance, notifications[1].instance) + require.Equal(t, baseKey, notifications[1].key) + require.Equal(t, baseChain, notifications[1].chain) + require.Equal(t, instance, notifications[0].instance) + require.Equal(t, key, notifications[0].key) + require.Equal(t, ecChain, notifications[0].chain) require.NoError(t, subject.Shutdown(ctx)) } @@ -88,13 +96,22 @@ type notification struct { chain gpbft.ECChain } type listener struct { + mu sync.Mutex notifications []notification } func (l *listener) NotifyChainDiscovered(_ context.Context, key chainexchange.Key, instance uint64, chain gpbft.ECChain) { + l.mu.Lock() + defer l.mu.Unlock() l.notifications = append(l.notifications, notification{key: key, instance: instance, chain: chain}) } +func (l *listener) getNotifications() []notification { + l.mu.Lock() + defer l.mu.Unlock() + return slices.Clone(l.notifications) +} + // TODO: Add more tests, specifically: // - validation // - discovery through other chainexchange instance diff --git a/f3_test.go b/f3_test.go index 69cd8b4a..c8b745ee 100644 --- a/f3_test.go +++ b/f3_test.go @@ -548,7 +548,7 @@ func (e *testEnv) waitForEpochFinalized(epoch int64) { } } return false - }, 30*time.Second) + }, 60*time.Second) } if head < epoch-100 { diff --git a/gen/main.go b/gen/main.go index 55a72f1a..c5212684 100644 --- a/gen/main.go +++ b/gen/main.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/filecoin-project/go-f3" "github.com/filecoin-project/go-f3/certexchange" "github.com/filecoin-project/go-f3/certs" 
"github.com/filecoin-project/go-f3/chainexchange" @@ -45,6 +46,11 @@ func main() { chainexchange.Message{}, ) }) + eg.Go(func() error { + return gen.WriteTupleEncodersToFile("../cbor_gen.go", "f3", + f3.PartialGMessage{}, + ) + }) if err := eg.Wait(); err != nil { fmt.Printf("Failed to complete cborg_gen: %v\n", err) os.Exit(1) diff --git a/host.go b/host.go index c7f8ed60..a5dd1efc 100644 --- a/host.go +++ b/host.go @@ -52,6 +52,7 @@ type gpbftRunner struct { inputs gpbftInputs msgEncoding gMessageEncoding + pmm *partialMessageManager } type roundPhase struct { @@ -141,6 +142,12 @@ func newRunner( } else { runner.msgEncoding = &cborGMessageEncoding{} } + + runner.pmm, err = newPartialMessageManager(runner.Progress, ps, m) + if err != nil { + return nil, fmt.Errorf("creating partial message manager: %w", err) + } + return runner, nil } @@ -156,6 +163,11 @@ func (h *gpbftRunner) Start(ctx context.Context) (_err error) { return err } + completedMessageQueue, err := h.pmm.Start(ctx) + if err != nil { + return err + } + finalityCertificates, unsubCerts := h.certStore.Subscribe() select { case c := <-finalityCertificates: @@ -193,7 +205,7 @@ func (h *gpbftRunner) Start(ctx context.Context) (_err error) { default: } - // Handle messages, finality certificates, and alarms + // Handle messages, completed messages, finality certificates, and alarms select { case c := <-finalityCertificates: if err := h.receiveCertificate(c); err != nil { @@ -219,6 +231,29 @@ func (h *gpbftRunner) Start(ctx context.Context) (_err error) { // errors. log.Errorf("error when processing message: %+v", err) } + case gmsg, ok := <-completedMessageQueue: + if !ok { + return fmt.Errorf("incoming completed message queue closed") + } + switch validatedMessage, err := h.participant.ValidateMessage(gmsg); { + case errors.Is(err, gpbft.ErrValidationInvalid): + log.Debugw("validation error while validating completed message", "err", err) + // TODO: Signal partial message manager to penalise sender, + // e.g. reduce the total number of messages stroed from sender? + case errors.Is(err, gpbft.ErrValidationTooOld): + // TODO: Signal partial message manager to drop the instance? + case errors.Is(err, gpbft.ErrValidationNotRelevant): + // TODO: Signal partial message manager to drop irrelevant messages? + case errors.Is(err, gpbft.ErrValidationNoCommittee): + log.Debugw("committee error while validating completed message", "err", err) + case err != nil: + log.Errorw("unknown error while validating completed message", "err", err) + default: + recordValidatedMessage(ctx, validatedMessage) + if err := h.participant.ReceiveMessage(validatedMessage); err != nil { + log.Errorw("error while processing completed message", "err", err) + } + } case <-h.runningCtx.Done(): return nil } @@ -452,7 +487,17 @@ func (h *gpbftRunner) BroadcastMessage(ctx context.Context, msg *gpbft.GMessage) if h.topic == nil { return pubsub.ErrTopicClosed } - encoded, err := h.msgEncoding.Encode(msg) + + if err := h.pmm.BroadcastChain(ctx, msg.Vote.Instance, msg.Vote.Value); err != nil { + // Silently log the error and continue. Partial message manager should take care of re-broadcast. 
+ log.Warnw("failed to broadcast chain", "instance", msg.Vote.Instance, "error", err) + } + + pmsg, err := h.pmm.toPartialGMessage(msg) + if err != nil { + return err + } + encoded, err := h.msgEncoding.Encode(pmsg) if err != nil { return fmt.Errorf("encoding GMessage for broadcast: %w", err) } @@ -472,7 +517,17 @@ func (h *gpbftRunner) rebroadcastMessage(msg *gpbft.GMessage) error { if h.topic == nil { return pubsub.ErrTopicClosed } - encoded, err := h.msgEncoding.Encode(msg) + + if err := h.pmm.BroadcastChain(h.runningCtx, msg.Vote.Instance, msg.Vote.Value); err != nil { + // Silently log the error and continue. Partial message manager should take care of re-broadcast. + log.Warnw("failed to rebroadcast chain", "instance", msg.Vote.Instance, "error", err) + } + + pmsg, err := h.pmm.toPartialGMessage(msg) + if err != nil { + return err + } + encoded, err := h.msgEncoding.Encode(pmsg) if err != nil { return fmt.Errorf("encoding GMessage for broadcast: %w", err) } @@ -489,12 +544,28 @@ func (h *gpbftRunner) validatePubsubMessage(ctx context.Context, _ peer.ID, msg recordValidationTime(ctx, start, _result) }(time.Now()) - gmsg, err := h.msgEncoding.Decode(msg.Data) + pgmsg, err := h.msgEncoding.Decode(msg.Data) if err != nil { log.Debugw("failed to decode message", "from", msg.GetFrom(), "err", err) return pubsub.ValidationReject } + + gmsg, completed := h.pmm.CompleteMessage(ctx, pgmsg) + if !completed { + // TODO: Partially validate the message because we can. To do this, however, + // message validator needs to be refactored to tolerate partial data. + // Hence, for now validation is postponed entirely until that refactor + // is done to accommodate partial messages. + // See: https://github.com/filecoin-project/go-f3/issues/813 + + // FIXME: must verify signature before buffering otherwise nodes can spoof the + // buffer with invalid messages on behalf of other peers as a censorship + // attack.
+ + msg.ValidatorData = pgmsg + return pubsub.ValidationAccept + } + switch validatedMessage, err := h.participant.ValidateMessage(gmsg); { case errors.Is(err, gpbft.ErrValidationInvalid): log.Debugf("validation error during validation: %+v", err) @@ -588,15 +659,18 @@ func (h *gpbftRunner) startPubsub() (<-chan gpbft.ValidatedMessage, error) { } return fmt.Errorf("pubsub message subscription returned an error: %w", err) } - gmsg, ok := msg.ValidatorData.(gpbft.ValidatedMessage) - if !ok { + + switch gmsg := msg.ValidatorData.(type) { + case gpbft.ValidatedMessage: + select { + case messageQueue <- gmsg: + case <-h.runningCtx.Done(): + return nil + } + case *PartialGMessage: + h.pmm.bufferPartialMessage(h.runningCtx, gmsg) + default: log.Errorf("invalid msgValidatorData: %+v", msg.ValidatorData) - continue - } - select { - case messageQueue <- gmsg: - case <-h.runningCtx.Done(): - return nil } } return nil @@ -632,18 +706,25 @@ func (h *gpbftHost) RequestRebroadcast(instant gpbft.Instant) error { } func (h *gpbftHost) GetProposal(instance uint64) (*gpbft.SupplementalData, gpbft.ECChain, error) { - return h.inputs.GetProposal(h.runningCtx, instance) + proposal, chain, err := h.inputs.GetProposal(h.runningCtx, instance) + if err == nil { + if err := h.pmm.BroadcastChain(h.runningCtx, instance, chain); err != nil { + log.Warnw("failed to broadcast chain", "instance", instance, "error", err) + } + } + return proposal, chain, err } func (h *gpbftHost) GetCommittee(instance uint64) (*gpbft.Committee, error) { return h.inputs.GetCommittee(h.runningCtx, instance) } -func (h *gpbftRunner) Stop(context.Context) error { +func (h *gpbftRunner) Stop(ctx context.Context) error { h.ctxCancel() return multierr.Combine( h.wal.Close(), h.errgrp.Wait(), + h.pmm.Shutdown(ctx), h.teardownPubsub(), ) } diff --git a/msg_encoding.go b/msg_encoding.go index 65b903b0..250ff05c 100644 --- a/msg_encoding.go +++ b/msg_encoding.go @@ -3,7 +3,6 @@ package f3 import ( "bytes" - "github.com/filecoin-project/go-f3/gpbft" "github.com/klauspost/compress/zstd" ) @@ -13,13 +12,13 @@ var ( ) type gMessageEncoding interface { - Encode(*gpbft.GMessage) ([]byte, error) - Decode([]byte) (*gpbft.GMessage, error) + Encode(message *PartialGMessage) ([]byte, error) + Decode([]byte) (*PartialGMessage, error) } type cborGMessageEncoding struct{} -func (c *cborGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { +func (c *cborGMessageEncoding) Encode(m *PartialGMessage) ([]byte, error) { var buf bytes.Buffer if err := m.MarshalCBOR(&buf); err != nil { return nil, err @@ -27,9 +26,9 @@ func (c *cborGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { return buf.Bytes(), nil } -func (c *cborGMessageEncoding) Decode(v []byte) (*gpbft.GMessage, error) { +func (c *cborGMessageEncoding) Decode(v []byte) (*PartialGMessage, error) { r := bytes.NewReader(v) - var msg gpbft.GMessage + var msg PartialGMessage if err := msg.UnmarshalCBOR(r); err != nil { return nil, err } @@ -57,7 +56,7 @@ func newZstdGMessageEncoding() (*zstdGMessageEncoding, error) { }, nil } -func (c *zstdGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { +func (c *zstdGMessageEncoding) Encode(m *PartialGMessage) ([]byte, error) { cborEncoded, err := c.cborEncoding.Encode(m) if err != nil { return nil, err @@ -66,7 +65,7 @@ func (c *zstdGMessageEncoding) Encode(m *gpbft.GMessage) ([]byte, error) { return compressed, err } -func (c *zstdGMessageEncoding) Decode(v []byte) (*gpbft.GMessage, error) { +func (c *zstdGMessageEncoding) Decode(v []byte) 
(*PartialGMessage, error) { cborEncoded, err := c.decompressor.DecodeAll(v, make([]byte, 0, len(v))) if err != nil { return nil, err diff --git a/msg_encoding_test.go b/msg_encoding_test.go index fb83e2e2..f2163934 100644 --- a/msg_encoding_test.go +++ b/msg_encoding_test.go @@ -16,7 +16,7 @@ const seed = 1413 func BenchmarkCborEncoding(b *testing.B) { rng := rand.New(rand.NewSource(seed)) encoder := &cborGMessageEncoding{} - msg := generateRandomGMessage(b, rng) + msg := generateRandomPartialGMessage(b, rng) b.ResetTimer() b.ReportAllocs() @@ -32,7 +32,7 @@ func BenchmarkCborEncoding(b *testing.B) { func BenchmarkCborDecoding(b *testing.B) { rng := rand.New(rand.NewSource(seed)) encoder := &cborGMessageEncoding{} - msg := generateRandomGMessage(b, rng) + msg := generateRandomPartialGMessage(b, rng) data, err := encoder.Encode(msg) require.NoError(b, err) @@ -52,7 +52,7 @@ func BenchmarkZstdEncoding(b *testing.B) { rng := rand.New(rand.NewSource(seed)) encoder, err := newZstdGMessageEncoding() require.NoError(b, err) - msg := generateRandomGMessage(b, rng) + msg := generateRandomPartialGMessage(b, rng) b.ResetTimer() b.ReportAllocs() @@ -69,7 +69,7 @@ func BenchmarkZstdDecoding(b *testing.B) { rng := rand.New(rand.NewSource(seed)) encoder, err := newZstdGMessageEncoding() require.NoError(b, err) - msg := generateRandomGMessage(b, rng) + msg := generateRandomPartialGMessage(b, rng) data, err := encoder.Encode(msg) require.NoError(b, err) @@ -85,6 +85,17 @@ func BenchmarkZstdDecoding(b *testing.B) { }) } +func generateRandomPartialGMessage(b *testing.B, rng *rand.Rand) *PartialGMessage { + var pgmsg PartialGMessage + pgmsg.GMessage = generateRandomGMessage(b, rng) + pgmsg.GMessage.Vote.Value = nil + if pgmsg.Justification != nil { + pgmsg.GMessage.Justification.Vote.Value = nil + } + pgmsg.VoteValueKey = generateRandomBytes(b, rng, 32) + return &pgmsg +} + func generateRandomGMessage(b *testing.B, rng *rand.Rand) *gpbft.GMessage { var maybeTicket []byte if rng.Float64() < 0.5 { diff --git a/partial_msg.go b/partial_msg.go new file mode 100644 index 00000000..24bcdf4f --- /dev/null +++ b/partial_msg.go @@ -0,0 +1,308 @@ +package f3 + +import ( + "context" + "fmt" + + "github.com/filecoin-project/go-f3/chainexchange" + "github.com/filecoin-project/go-f3/gpbft" + "github.com/filecoin-project/go-f3/manifest" + lru "github.com/hashicorp/golang-lru/v2" + pubsub "github.com/libp2p/go-libp2p-pubsub" +) + +var _ chainexchange.Listener = (*partialMessageManager)(nil) + +type PartialGMessage struct { + *gpbft.GMessage + VoteValueKey chainexchange.Key `cborgen:"maxlen=32"` +} + +type partialMessageKey struct { + sender gpbft.ActorID + instant gpbft.Instant +} + +type discoveredChain struct { + key chainexchange.Key + instance uint64 + chain gpbft.ECChain +} + +type partialMessageManager struct { + chainex *chainexchange.PubSubChainExchange + + // pmByInstance is a map of instance to a buffer of partial messages that are + // keyed by sender+instance+round+phase. + pmByInstance map[uint64]*lru.Cache[partialMessageKey, *PartialGMessage] + // pmkByInstanceByChainKey is used for an auxiliary lookup of all partial + // messages for a given vote value at an instance. + pmkByInstanceByChainKey map[uint64]map[string][]partialMessageKey + // pendingPartialMessages is a channel of partial messages that are pending to be buffered. + pendingPartialMessages chan *PartialGMessage + // pendingDiscoveredChains is a channel of chains discovered by chainexchange + // that are pending to be processed. 
+ pendingDiscoveredChains chan *discoveredChain + + stop func() +} + +func newPartialMessageManager(progress gpbft.Progress, ps *pubsub.PubSub, m *manifest.Manifest) (*partialMessageManager, error) { + pmm := &partialMessageManager{ + pmByInstance: make(map[uint64]*lru.Cache[partialMessageKey, *PartialGMessage]), + pmkByInstanceByChainKey: make(map[uint64]map[string][]partialMessageKey), + pendingDiscoveredChains: make(chan *discoveredChain, 100), // TODO: parameterize buffer size. + pendingPartialMessages: make(chan *PartialGMessage, 100), // TODO: parameterize buffer size. + } + var err error + pmm.chainex, err = chainexchange.NewPubSubChainExchange( + chainexchange.WithListener(pmm), + chainexchange.WithProgress(progress), + chainexchange.WithPubSub(ps), + chainexchange.WithMaxChainLength(m.ChainExchange.MaxChainLength), + chainexchange.WithMaxDiscoveredChainsPerInstance(m.ChainExchange.MaxDiscoveredChainsPerInstance), + chainexchange.WithMaxInstanceLookahead(m.ChainExchange.MaxInstanceLookahead), + chainexchange.WithMaxWantedChainsPerInstance(m.ChainExchange.MaxWantedChainsPerInstance), + chainexchange.WithSubscriptionBufferSize(m.ChainExchange.SubscriptionBufferSize), + chainexchange.WithTopicName(manifest.ChainExchangeTopicFromNetworkName(m.NetworkName)), + ) + if err != nil { + return nil, err + } + return pmm, nil +} + +func (pmm *partialMessageManager) Start(ctx context.Context) (<-chan *gpbft.GMessage, error) { + if err := pmm.chainex.Start(ctx); err != nil { + return nil, fmt.Errorf("starting chain exchange: %w", err) + } + + completedMessages := make(chan *gpbft.GMessage, 100) // TODO: parameterize buffer size. + ctx, pmm.stop = context.WithCancel(context.Background()) + go func() { + defer func() { + close(completedMessages) + log.Debugw("Partial message manager stopped.") + }() + + for ctx.Err() == nil { + select { + case <-ctx.Done(): + return + case discovered, ok := <-pmm.pendingDiscoveredChains: + if !ok { + return + } + partialMessageKeysAtInstance, found := pmm.pmkByInstanceByChainKey[discovered.instance] + if !found { + // There's no known instance with a partial message. Ignore the discovered chain. + // There's also no need to optimistically store them here. Because, chainexchange + // does this with safe caps on max future instances. + continue + } + chainkey := string(discovered.key) + partialMessageKeys, found := partialMessageKeysAtInstance[chainkey] + if !found { + // There's no known partial message at the instance for the discovered chain. + // Ignore the discovery for the same reason as above. 
+ continue + } + buffer := pmm.getOrInitPartialMessageBuffer(discovered.instance) + for _, messageKey := range partialMessageKeys { + if pgmsg, found := buffer.Get(messageKey); found { + pgmsg.Vote.Value = discovered.chain + inferJustificationVoteValue(pgmsg) + select { + case <-ctx.Done(): + return + case completedMessages <- pgmsg.GMessage: + default: + log.Warnw("Dropped completed message as the gpbft runner is too slow to consume them.", "msg", pgmsg.GMessage) + } + buffer.Remove(messageKey) + } + } + delete(partialMessageKeysAtInstance, chainkey) + case pgmsg, ok := <-pmm.pendingPartialMessages: + if !ok { + return + } + key := partialMessageKey{ + sender: pgmsg.Sender, + instant: gpbft.Instant{ + ID: pgmsg.Vote.Instance, + Round: pgmsg.Vote.Round, + Phase: pgmsg.Vote.Phase, + }, + } + buffer := pmm.getOrInitPartialMessageBuffer(pgmsg.Vote.Instance) + if found, _ := buffer.ContainsOrAdd(key, pgmsg); !found { + pmkByChainKey := pmm.pmkByInstanceByChainKey[pgmsg.Vote.Instance] + chainKey := string(pgmsg.VoteValueKey) + pmkByChainKey[chainKey] = append(pmkByChainKey[chainKey], key) + } + // TODO: Add equivocation metrics: check if the message is different and if so + // increment the equivocations counter tagged by phase. + // See: https://github.com/filecoin-project/go-f3/issues/812 + } + } + }() + return completedMessages, nil +} + +func (pmm *partialMessageManager) BroadcastChain(ctx context.Context, instance uint64, chain gpbft.ECChain) error { + if chain.IsZero() { + return nil + } + + // TODO: Implement an independent chain broadcast and rebroadcast heuristic. + // See: https://github.com/filecoin-project/go-f3/issues/814 + + cmsg := chainexchange.Message{Instance: instance, Chain: chain} + if err := pmm.chainex.Broadcast(ctx, cmsg); err != nil { + return fmt.Errorf("broadcasting chain: %w", err) + } + log.Debugw("broadcasted chain", "instance", instance, "chain", chain) + return nil +} + +func (pmm *partialMessageManager) toPartialGMessage(msg *gpbft.GMessage) (*PartialGMessage, error) { + msgCopy := *(msg) + pmsg := &PartialGMessage{ + GMessage: &msgCopy, + } + if !pmsg.Vote.Value.IsZero() { + pmsg.VoteValueKey = pmm.chainex.Key(pmsg.Vote.Value) + pmsg.Vote.Value = gpbft.ECChain{} + } + if msg.Justification != nil && !pmsg.Justification.Vote.Value.IsZero() { + justificationCopy := *(msg.Justification) + pmsg.Justification = &justificationCopy + // The justification vote value is either zero or the same as the vote value. + // Anything else is considered to be an invalid justification. In fact, for + // any given phase and round we can always infer: + // 1. whether a message should have a justification, and + // 2. if so, for what chain at which round. + // + // Therefore, we can entirely ignore the justification vote value as far as + // chainexchange is concerned and override it with a zero value. Upon arrival of + // a partial message, the receiver can always infer the justification chain from + // the message phase, round. In a case where justification is invalid, the + // signature won't match anyway, so it seems harmless to always infer the + // justification chain. + // + // In fact, it probably should have been omitted altogether at the time of + // protocol design.
+ pmsg.Justification.Vote.Value = gpbft.ECChain{} + } + return pmsg, nil +} + +func (pmm *partialMessageManager) NotifyChainDiscovered(ctx context.Context, key chainexchange.Key, instance uint64, chain gpbft.ECChain) { + discovery := &discoveredChain{key: key, instance: instance, chain: chain} + select { + case <-ctx.Done(): + return + case pmm.pendingDiscoveredChains <- discovery: + // TODO: add metrics + default: + // The message completion looks up the key on chain exchange anyway. The net + // effect of this is delay in delivering messages assuming they're re-broadcasted + // by GPBFT. This is probably the best we can do if the partial message manager + // is too slow. + log.Warnw("Dropped chain discovery notification as partial message manager is too slow.", "instance", instance, "chain", chain) + } +} + +func (pmm *partialMessageManager) bufferPartialMessage(ctx context.Context, msg *PartialGMessage) { + select { + case <-ctx.Done(): + return + case pmm.pendingPartialMessages <- msg: + // TODO: add metrics + default: + // Choosing to rely on GPBFT re-broadcast to compensate for a partial message + // being dropped. The key thing here is that partial message manager should never + // be too slow. If it is, then there are bigger problems to solve. Hence, the + // failsafe is to drop the message instead of halting further message processing. + log.Warnw("Dropped partial message as partial message manager is too slow.", "msg", msg) + } +} + +func (pmm *partialMessageManager) getOrInitPartialMessageBuffer(instance uint64) *lru.Cache[partialMessageKey, *PartialGMessage] { + buffer, found := pmm.pmByInstance[instance] + if !found { + // TODO: parameterize this in the manifest? + // Based on 5 phases, 2K network size at a couple of rounds plus some headroom. + const maxBufferedMessagesPerInstance = 25_000 + var err error + buffer, err = lru.New[partialMessageKey, *PartialGMessage](maxBufferedMessagesPerInstance) + if err != nil { + log.Fatalf("Failed to create buffer for instance %d: %s", instance, err) + panic(err) + } + pmm.pmByInstance[instance] = buffer + } + if _, ok := pmm.pmkByInstanceByChainKey[instance]; !ok { + pmm.pmkByInstanceByChainKey[instance] = make(map[string][]partialMessageKey) + } + return buffer +} + +func (pmm *partialMessageManager) CompleteMessage(ctx context.Context, pgmsg *PartialGMessage) (*gpbft.GMessage, bool) { + if pgmsg == nil { + // For sanity assert that the message isn't nil. + return nil, false + } + if pgmsg.VoteValueKey.IsZero() { + // A zero VoteValueKey indicates that there's no partial chain value, for + // example, COMMIT for bottom. Return the message as is. + return pgmsg.GMessage, true + } + + chain, found := pmm.chainex.GetChainByInstance(ctx, pgmsg.Vote.Instance, pgmsg.VoteValueKey) + if !found { + return nil, false + } + pgmsg.Vote.Value = chain + inferJustificationVoteValue(pgmsg) + return pgmsg.GMessage, true +} + +func inferJustificationVoteValue(pgmsg *PartialGMessage) { + // Infer what the value of justification should be based on the vote phase. A + // valid message with non-nil justification must justify the vote value chain + // at: + // * CONVERGE_PHASE, with justification of PREPARE. + // * PREPARE_PHASE, with justification of PREPARE. + // * COMMIT_PHASE, with justification of PREPARE. + // * DECIDE_PHASE, with justification of COMMIT. + // + // Future work should get rid of chains in justification entirely.
See: + // * https://github.com/filecoin-project/go-f3/issues/806 + if pgmsg.Justification != nil { + switch pgmsg.Vote.Phase { + case + gpbft.CONVERGE_PHASE, + gpbft.PREPARE_PHASE, + gpbft.COMMIT_PHASE: + if pgmsg.Justification.Vote.Phase == gpbft.PREPARE_PHASE { + pgmsg.Justification.Vote.Value = pgmsg.Vote.Value + } + case gpbft.DECIDE_PHASE: + if pgmsg.Justification.Vote.Phase == gpbft.COMMIT_PHASE { + pgmsg.Justification.Vote.Value = pgmsg.Vote.Value + } + default: + // The message must be invalid. But let the flow proceed and have the validator + // reject it. + } + } +} + +func (pmm *partialMessageManager) Shutdown(ctx context.Context) error { + if pmm.stop != nil { + pmm.stop() + } + return pmm.chainex.Shutdown(ctx) +}
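Note for reviewers: the partial-message buffer in partial_msg.go leans on golang-lru/v2's ContainsOrAdd so that the first partial message seen for a given sender/instance/round/phase is kept and later (possibly equivocating) duplicates are not stored, pending the equivocation-metrics TODO above. The following self-contained Go sketch only illustrates that ContainsOrAdd behaviour; the key struct, capacity, and values are hypothetical stand-ins and are not part of this change.

package main

import (
	"fmt"

	lru "github.com/hashicorp/golang-lru/v2"
)

// key mirrors the shape of partialMessageKey for illustration only.
type key struct{ sender, instance, round uint64 }

func main() {
	// The capacity stands in for maxBufferedMessagesPerInstance.
	buffer, err := lru.New[key, string](8)
	if err != nil {
		panic(err)
	}
	k := key{sender: 1, instance: 7, round: 0}

	// First message for this key: not contained yet, so it is added.
	found, _ := buffer.ContainsOrAdd(k, "first partial message")
	fmt.Println(found) // false

	// A later duplicate for the same key is ignored rather than overwriting the original.
	found, _ = buffer.ContainsOrAdd(k, "duplicate partial message")
	fmt.Println(found) // true

	v, _ := buffer.Get(k)
	fmt.Println(v) // first partial message
}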