diff --git a/comm.go b/comm.go index af0ba207..234572af 100644 --- a/comm.go +++ b/comm.go @@ -32,9 +32,16 @@ func (p *PubSub) getHelloPacket() *RPC { } for t := range subscriptions { + var requestPartial, supportsPartialMessages bool + if ts, ok := p.myTopics[t]; ok { + requestPartial = ts.requestPartialMessages + supportsPartialMessages = ts.supportsPartialMessages + } as := &pb.RPC_SubOpts{ - Topicid: proto.String(t), - Subscribe: proto.Bool(true), + Topicid: proto.String(t), + Subscribe: proto.Bool(true), + RequestsPartial: &requestPartial, + SupportsSendingPartial: &supportsPartialMessages, } rpc.Subscriptions = append(rpc.Subscriptions, as) } @@ -123,7 +130,7 @@ func (p *PubSub) notifyPeerDead(pid peer.ID) { } func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing *rpcQueue) { - s, err := p.host.NewStream(p.ctx, pid, p.rt.Protocols()...) + s, err := p.host.NewStream(ctx, pid, p.rt.Protocols()...) if err != nil { p.logger.Debug("error opening new stream to peer", "err", err, "peer", pid) @@ -135,11 +142,14 @@ func (p *PubSub) handleNewPeer(ctx context.Context, pid peer.ID, outgoing *rpcQu return } - go p.handleSendingMessages(ctx, s, outgoing) + firstMessage := make(chan *RPC, 1) + sCtx, cancel := context.WithCancel(ctx) + go p.handleSendingMessages(sCtx, s, outgoing, firstMessage) go p.handlePeerDead(s) select { - case p.newPeerStream <- s: + case p.newPeerStream <- peerOutgoingStream{Stream: s, FirstMessage: firstMessage, Cancel: cancel}: case <-ctx.Done(): + cancel() } } @@ -164,7 +174,7 @@ func (p *PubSub) handlePeerDead(s network.Stream) { p.notifyPeerDead(pid) } -func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) { +func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue, firstMessage chan *RPC) { writeRpc := func(rpc *RPC) error { size := uint64(rpc.Size()) @@ -177,6 +187,11 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s 
network.Stream, ou return err } + if err := s.SetWriteDeadline(time.Now().Add(time.Second * 30)); err != nil { + p.rpcLogger.Debug("failed to set write deadline", "peer", s.Conn().RemotePeer(), "err", err) + return err + } + _, err = s.Write(buf) if err != nil { p.rpcLogger.Debug("failed to send message", "peer", s.Conn().RemotePeer(), "rpc", rpc, "err", err) @@ -186,6 +201,21 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou return nil } + select { + case rpc := <-firstMessage: + if rpc.Size() > 0 { + err := writeRpc(rpc) + if err != nil { + s.Reset() + p.logger.Debug("error writing message to peer", "peer", s.Conn().RemotePeer(), "err", err) + return + } + } + case <-ctx.Done(): + s.Reset() + return + } + defer s.Close() for ctx.Err() == nil { rpc, err := outgoing.Pop(ctx) diff --git a/extensions.go b/extensions.go index 3a2c3f11..fc40b34c 100644 --- a/extensions.go +++ b/extensions.go @@ -1,12 +1,17 @@ package pubsub import ( + "errors" + "iter" + + "github.com/libp2p/go-libp2p-pubsub/partialmessages" pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" "github.com/libp2p/go-libp2p/core/peer" ) type PeerExtensions struct { - TestExtension bool + TestExtension bool + PartialMessages bool } type TestExtensionConfig struct { @@ -37,6 +42,7 @@ func peerExtensionsFromRPC(rpc *RPC) PeerExtensions { out := PeerExtensions{} if hasPeerExtensions(rpc) { out.TestExtension = rpc.Control.Extensions.GetTestExtension() + out.PartialMessages = rpc.Control.Extensions.GetPartialMessages() } return out } @@ -46,9 +52,19 @@ func (pe *PeerExtensions) ExtendRPC(rpc *RPC) *RPC { if rpc.Control == nil { rpc.Control = &pubsub_pb.ControlMessage{} } - rpc.Control.Extensions = &pubsub_pb.ControlExtensions{ - TestExtension: &pe.TestExtension, + if rpc.Control.Extensions == nil { + rpc.Control.Extensions = &pubsub_pb.ControlExtensions{} + } + rpc.Control.Extensions.TestExtension = &pe.TestExtension + } + if pe.PartialMessages { + if rpc.Control == nil { + 
rpc.Control = &pubsub_pb.ControlMessage{} + } + if rpc.Control.Extensions == nil { + rpc.Control.Extensions = &pubsub_pb.ControlExtensions{} } + rpc.Control.Extensions.PartialMessages = &pe.PartialMessages } return rpc } @@ -59,8 +75,9 @@ type extensionsState struct { sentExtensions map[peer.ID]struct{} reportMisbehavior func(peer.ID) sendRPC func(p peer.ID, r *RPC, urgent bool) + testExtension *testExtension - testExtension *testExtension + partialMessagesExtension *partialmessages.PartialMessageExtension } func newExtensionsState(myExtensions PeerExtensions, reportMisbehavior func(peer.ID), sendRPC func(peer.ID, *RPC, bool)) *extensionsState { @@ -132,14 +149,97 @@ func (es *extensionsState) extensionsAddPeer(id peer.ID) { if es.myExtensions.TestExtension && es.peerExtensions[id].TestExtension { es.testExtension.AddPeer(id) } + + if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages { + es.partialMessagesExtension.AddPeer(id) + } } // extensionsRemovePeer is always called after extensionsAddPeer. 
func (es *extensionsState) extensionsRemovePeer(id peer.ID) { + if es.myExtensions.PartialMessages && es.peerExtensions[id].PartialMessages { + es.partialMessagesExtension.RemovePeer(id) + } } func (es *extensionsState) extensionsHandleRPC(rpc *RPC) { if es.myExtensions.TestExtension && es.peerExtensions[rpc.from].TestExtension { es.testExtension.HandleRPC(rpc.from, rpc.TestExtension) } + + if es.myExtensions.PartialMessages && es.peerExtensions[rpc.from].PartialMessages && rpc.Partial != nil { + es.partialMessagesExtension.HandleRPC(rpc.from, rpc.Partial) + } } + +func (es *extensionsState) Heartbeat() { + if es.myExtensions.PartialMessages { + es.partialMessagesExtension.Heartbeat() + } +} + +func WithPartialMessagesExtension(pm *partialmessages.PartialMessageExtension) Option { + return func(ps *PubSub) error { + gs, ok := ps.rt.(*GossipSubRouter) + if !ok { + return errors.New("pubsub router is not gossipsub") + } + err := pm.Init(partialMessageRouter{gs}) + if err != nil { + return err + } + + gs.extensions.myExtensions.PartialMessages = true + gs.extensions.partialMessagesExtension = pm + return nil + } +} + +type partialMessageRouter struct { + gs *GossipSubRouter +} + +// PeerRequestsPartial implements partialmessages.Router. +func (r partialMessageRouter) PeerRequestsPartial(peer peer.ID, topic string) bool { + return r.gs.peerRequestsPartial(peer, topic) +} + +// MeshPeers implements partialmessages.Router. 
+func (r partialMessageRouter) MeshPeers(topic string) iter.Seq[peer.ID] { + return func(yield func(peer.ID) bool) { + peerSet, ok := r.gs.mesh[topic] + if !ok { + // Possibly a fanout topic + peerSet, ok = r.gs.fanout[topic] + if !ok { + return + } + } + + myTopicState := r.gs.p.myTopics[topic] + iRequestPartial := myTopicState != nil && myTopicState.requestPartialMessages + for peer := range peerSet { + if r.gs.extensions.peerExtensions[peer].PartialMessages { + peerSupportsPartial := r.gs.peerSupportsPartial(peer, topic) + peerRequestsPartial := r.gs.peerRequestsPartial(peer, topic) + if (iRequestPartial && peerSupportsPartial) || peerRequestsPartial { + // Peer supports partial messages + if !yield(peer) { + return + } + } + } + } + } +} + +// SendRPC implements partialmessages.Router. +func (r partialMessageRouter) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) { + r.gs.sendRPC(p, &RPC{ + RPC: pubsub_pb.RPC{ + Partial: rpc, + }, + }, urgent) +} + +var _ partialmessages.Router = partialMessageRouter{} diff --git a/floodsub_test.go b/floodsub_test.go index 451639ce..fcdb12c4 100644 --- a/floodsub_test.go +++ b/floodsub_test.go @@ -57,6 +57,12 @@ func denseConnect(t *testing.T, hosts []host.Host) { connectSome(t, hosts, 10) } +func ringConnect(t *testing.T, hosts []host.Host) { + for i := range hosts { + connect(t, hosts[i], hosts[(i+1)%len(hosts)]) + } +} + func connectSome(t *testing.T, hosts []host.Host, d int) { for i, a := range hosts { for j := 0; j < d; j++ { diff --git a/gossipsub.go b/gossipsub.go index c492ded9..8d1358d2 100644 --- a/gossipsub.go +++ b/gossipsub.go @@ -864,6 +864,12 @@ func (gs *GossipSubRouter) Preprocess(from peer.ID, msgs []*Message) { // We don't send IDONTWANT to the peer that sent us the messages continue } + if gs.peerRequestsPartial(p, topic) { + // Don't send IDONTWANT to peers that are using partial messages + // for this topic + continue + } + // send to only peers that support IDONTWANT if 
gs.feature(GossipSubFeatureIdontwant, gs.peers[p]) { idontwant := []*pb.ControlIDontWant{{MessageIDs: mids}} @@ -1375,6 +1381,10 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] { if pid == from || pid == peer.ID(msg.GetFrom()) { continue } + if gs.peerRequestsPartial(pid, topic) { + // The peer requested partial messages. We'll skip sending them full messages + continue + } if !yield(pid, out) { return @@ -1383,6 +1393,16 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] { } } +func (gs *GossipSubRouter) peerSupportsPartial(p peer.ID, topic string) bool { + peerStates, ok := gs.p.topics[topic] + return ok && gs.extensions.myExtensions.PartialMessages && peerStates[p].supportsPartial +} + +func (gs *GossipSubRouter) peerRequestsPartial(p peer.ID, topic string) bool { + peerStates, ok := gs.p.topics[topic] + return ok && gs.extensions.myExtensions.PartialMessages && peerStates[p].requestsPartial +} + func (gs *GossipSubRouter) Join(topic string) { gmap, ok := gs.mesh[topic] if ok { @@ -1833,6 +1853,8 @@ func (gs *GossipSubRouter) heartbeat() { // advance the message history window gs.mcache.Shift() + + gs.extensions.Heartbeat() } func (gs *GossipSubRouter) clearIHaveCounters() { diff --git a/gossipsub_test.go b/gossipsub_test.go index 0fca9d34..26d22792 100644 --- a/gossipsub_test.go +++ b/gossipsub_test.go @@ -5,6 +5,8 @@ import ( "context" crand "crypto/rand" "encoding/base64" + "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -22,6 +24,8 @@ import ( "time" "github.com/libp2p/go-libp2p-pubsub/internal/gologshim" + "github.com/libp2p/go-libp2p-pubsub/partialmessages" + "github.com/libp2p/go-libp2p-pubsub/partialmessages/bitmap" pb "github.com/libp2p/go-libp2p-pubsub/pb" "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-msgio" @@ -60,15 +64,6 @@ func getGossipsubs(ctx context.Context, hs []host.Host, opts ...Option) []*PubSu return psubs } -func getGossipsubsOptFn(ctx context.Context, hs []host.Host, 
optFn func(int, host.Host) []Option) []*PubSub { - var psubs []*PubSub - for i, h := range hs { - opts := optFn(i, h) - psubs = append(psubs, getGossipsub(ctx, h, opts...)) - } - return psubs -} - func TestGossipSubParamsValidate(t *testing.T) { params := DefaultGossipSubParams() params.Dhi = 1 @@ -90,6 +85,15 @@ func TestGossipSubBootstrapParamsValidate(t *testing.T) { } } +func getGossipsubsOptFn(ctx context.Context, hs []host.Host, optFn func(int, host.Host) []Option) []*PubSub { + var psubs []*PubSub + for i, h := range hs { + opts := optFn(i, h) + psubs = append(psubs, getGossipsub(ctx, h, opts...)) + } + return psubs +} + func TestSparseGossipsub(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -2213,6 +2217,58 @@ func TestGossipsubPeerScoreInspect(t *testing.T) { } } +func TestGossipsubPeerFeedback(t *testing.T) { + // this test exercises the code path sof peer score inspection + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + hosts := getDefaultHosts(t, 2) + + inspector := &mockPeerScoreInspector{} + psub1 := getGossipsub(ctx, hosts[0], + WithPeerScore( + &PeerScoreParams{ + Topics: map[string]*TopicScoreParams{ + "test": { + TopicWeight: 1, + TimeInMeshQuantum: time.Second, + FirstMessageDeliveriesWeight: 10, + FirstMessageDeliveriesDecay: 0.999, + FirstMessageDeliveriesCap: 100, + InvalidMessageDeliveriesWeight: -1, + InvalidMessageDeliveriesDecay: 0.9999, + }, + }, + AppSpecificScore: func(peer.ID) float64 { return 0 }, + DecayInterval: time.Second, + DecayToZero: 0.01, + }, + &PeerScoreThresholds{ + GossipThreshold: -1, + PublishThreshold: -10, + GraylistThreshold: -1000, + }), + WithPeerScoreInspect(inspector.inspect, time.Second)) + _ = getGossipsub(ctx, hosts[1]) + + connect(t, hosts[0], hosts[1]) + time.Sleep(500 * time.Millisecond) // Wait for nodes to connect + + var err error + err = errors.Join(err, psub1.PeerFeedback("test", hosts[1].ID(), PeerFeedbackUsefulMessage)) + err = 
errors.Join(err, psub1.PeerFeedback("test", hosts[1].ID(), PeerFeedbackUsefulMessage)) + err = errors.Join(err, psub1.PeerFeedback("test", hosts[1].ID(), PeerFeedbackUsefulMessage)) + if err != nil { + t.Fatal(err) + } + time.Sleep(500 * time.Millisecond) // Wait for feedback to be incorporated + + score2 := inspector.score(hosts[1].ID()) + if score2 < 9 { + t.Fatalf("expected score to be at least 9, instead got %f", score2) + } +} + func TestGossipsubPeerScoreResetTopicParams(t *testing.T) { // this test exercises the code path sof peer score inspection ctx, cancel := context.WithCancel(context.Background()) @@ -4394,3 +4450,786 @@ func TestTestExtension(t *testing.T) { t.Fatal("TestExtension not received") } } + +type minimalTestPartialMessage struct { + Group []byte + Parts [2][]byte +} + +func (m *minimalTestPartialMessage) complete() bool { + return len(m.Parts[0]) > 0 && len(m.Parts[1]) > 0 +} + +// PartsMetadata implements partialmessages.PartialMessage. +func (m *minimalTestPartialMessage) PartsMetadata() partialmessages.PartsMetadata { + out := make(bitmap.Bitmap, 1) + for i := range m.Parts { + if len(m.Parts[i]) > 0 { + out.Set(i) + } + } + return partialmessages.PartsMetadata(out) +} + +func (m *minimalTestPartialMessage) extendFromEncodedPartialMessage(_ peer.ID, data []byte) (extended bool) { + var temp minimalTestPartialMessage + json.Unmarshal(data, &temp) + for i := range m.Parts { + if len(temp.Parts[i]) > 0 && m.Parts[i] == nil { + extended = true + m.Parts[i] = temp.Parts[i] + } + } + return +} + +// onIncomingRPC handle an incoming rpc and will return a non-nil publish +// options if the caller should republish this partial message. 
+func (m *minimalTestPartialMessage) onIncomingRPC(from peer.ID, rpc *pb.PartialMessagesExtension) *partialmessages.PublishOptions { + var extended bool + if len(rpc.PartialMessage) > 0 { + extended = m.extendFromEncodedPartialMessage(from, rpc.PartialMessage) + } + + var publishOpts partialmessages.PublishOptions + + if !extended { + var peerHasUsefulData, iHaveUsefulData bool + // Only do these checks if we didn't extend our partial message. + // Since, otherwise, we simply publish again to all peers. + if len(rpc.PartsMetadata) > 0 { + iHave := m.PartsMetadata()[0] + iWant := ^iHave + + peerHas := rpc.PartsMetadata[0] + peerWants := ^peerHas + + iHaveUsefulData = iHave&peerWants != 0 + peerHasUsefulData = iWant&peerHas != 0 + } + if peerHasUsefulData || iHaveUsefulData { + publishOpts.PublishToPeers = []peer.ID{from} + } + } + + if extended || len(publishOpts.PublishToPeers) > 0 { + return &publishOpts + } + return nil +} + +// GroupID implements partialmessages.PartialMessage. +func (m *minimalTestPartialMessage) GroupID() []byte { + return m.Group +} + +func (m *minimalTestPartialMessage) PartialMessageBytes(peerPartsMetadata partialmessages.PartsMetadata) ([]byte, error) { + if len(peerPartsMetadata) == 0 { + return nil, errors.New("invalid metadata") + } + peerHas := bitmap.Bitmap(peerPartsMetadata) + + var temp minimalTestPartialMessage + temp.Group = m.Group + if !peerHas.Get(0) && m.Parts[0] != nil { + temp.Parts[0] = m.Parts[0] + } + if !peerHas.Get(1) && m.Parts[1] != nil { + temp.Parts[1] = m.Parts[1] + } + + if temp.Parts[0] == nil && temp.Parts[1] == nil { + return nil, nil + } + + b, err := json.Marshal(temp) + if err != nil { + return nil, err + } + return b, nil +} + +func (m *minimalTestPartialMessage) shouldRequest(_ peer.ID, peerHasMetadata []byte) bool { + if len(peerHasMetadata) == 0 { + return false + } + peerHas := peerHasMetadata[0] + iWant := ^m.PartsMetadata()[0] + return iWant&peerHas != 0 +} + +var _ partialmessages.Message = 
(*minimalTestPartialMessage)(nil) + +type minimalTestPartialMessageChecker struct { + fullMessage *minimalTestPartialMessage +} + +func (m *minimalTestPartialMessageChecker) MergePartsMetadata(left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata { + return partialmessages.MergeBitmap(left, right) +} + +// EmptyMessage implements partialmessages.InvariantChecker. +func (m *minimalTestPartialMessageChecker) EmptyMessage() *minimalTestPartialMessage { + return &minimalTestPartialMessage{ + Group: m.fullMessage.Group, + } +} + +// Equal implements partialmessages.InvariantChecker. +func (m *minimalTestPartialMessageChecker) Equal(a *minimalTestPartialMessage, b *minimalTestPartialMessage) bool { + return bytes.Equal(a.Group, b.Group) && (slices.CompareFunc(a.Parts[:], b.Parts[:], func(e1 []byte, e2 []byte) int { + return bytes.Compare(e1, e2) + }) == 0) +} + +// ExtendFromBytes implements partialmessages.InvariantChecker. +func (m *minimalTestPartialMessageChecker) ExtendFromBytes(a *minimalTestPartialMessage, data []byte) (*minimalTestPartialMessage, error) { + a.extendFromEncodedPartialMessage("", data) + return a, nil +} + +// FullMessage implements partialmessages.InvariantChecker. +func (m *minimalTestPartialMessageChecker) FullMessage() (*minimalTestPartialMessage, error) { + return m.fullMessage, nil +} + +// SplitIntoParts implements partialmessages.InvariantChecker. 
+func (m *minimalTestPartialMessageChecker) SplitIntoParts(in *minimalTestPartialMessage) ([]*minimalTestPartialMessage, error) { + var parts [2]*minimalTestPartialMessage + for i := range parts { + parts[i] = &minimalTestPartialMessage{ + Group: in.Group, + } + parts[i].Parts[i] = in.Parts[i] + } + return parts[:], nil +} + +func (m *minimalTestPartialMessageChecker) ShouldRequest(a *minimalTestPartialMessage, from peer.ID, partsMetadata []byte) bool { + return a.shouldRequest(from, partsMetadata) +} + +var _ partialmessages.InvariantChecker[*minimalTestPartialMessage] = (*minimalTestPartialMessageChecker)(nil) + +func TestMinimalPartialMessageImpl(t *testing.T) { + group := []byte("test-group") + full := &minimalTestPartialMessage{ + Group: group, + Parts: [2][]byte{ + []byte("Hello"), + []byte("World"), + }, + } + checker := &minimalTestPartialMessageChecker{ + fullMessage: full, + } + partialmessages.TestPartialMessageInvariants(t, checker) +} + +func TestPartialMessages(t *testing.T) { + topic := "test-topic" + const hostCount = 5 + hosts := getDefaultHosts(t, hostCount) + psubs := make([]*PubSub, 0, len(hosts)) + + gossipsubCtx, closeGossipsub := context.WithCancel(context.Background()) + go func() { + <-gossipsubCtx.Done() + for _, h := range hosts { + h.Close() + } + }() + + partialExt := make([]*partialmessages.PartialMessageExtension, hostCount) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) + + // A list of maps from topic+groupID to partialMessage. 
One map per peer + // var partialMessageStoreMu sync.Mutex + partialMessageStore := make([]map[string]*minimalTestPartialMessage, hostCount) + for i := range hostCount { + partialMessageStore[i] = make(map[string]*minimalTestPartialMessage) + } + + for i := range partialExt { + partialExt[i] = &partialmessages.PartialMessageExtension{ + Logger: logger.With("id", i), + ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + // No validation. Only for this test. In production you should + // have some basic fast rules here. + return nil + }, + MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata { + return partialmessages.MergeBitmap(left, right) + }, + OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + groupID := rpc.GroupID + pm, ok := partialMessageStore[i][topic+string(groupID)] + if !ok { + pm = &minimalTestPartialMessage{ + Group: groupID, + } + partialMessageStore[i][topic+string(groupID)] = pm + } + if publishOpts := pm.onIncomingRPC(from, rpc); publishOpts != nil { + go psubs[i].PublishPartialMessage(topic, pm, *publishOpts) + } + return nil + }, + } + } + + for i, h := range hosts { + psub := getGossipsub(gossipsubCtx, h, WithPartialMessagesExtension(partialExt[i])) + topic, err := psub.Join(topic, RequestPartialMessages()) + if err != nil { + t.Fatal(err) + } + _, err = topic.Subscribe() + if err != nil { + t.Fatal(err) + } + psubs = append(psubs, psub) + } + + denseConnect(t, hosts) + time.Sleep(2 * time.Second) + + group := []byte("test-group") + msg1 := &minimalTestPartialMessage{ + Group: group, + Parts: [2][]byte{ + []byte("Hello"), + []byte("World"), + }, + } + partialMessageStore[0][topic+string(group)] = msg1 + err := psubs[0].PublishPartialMessage(topic, msg1, partialmessages.PublishOptions{}) + if err != nil { + t.Fatal(err) + } + + time.Sleep(2 * time.Second) + + // Close gossipsub before we inspect the state to avoid race conditions + 
closeGossipsub() + time.Sleep(1 * time.Second) + + if len(partialMessageStore) != hostCount { + t.Errorf("One host is missing the partial message") + } + + for i, msgStore := range partialMessageStore { + if len(msgStore) == 0 { + t.Errorf("Host %d is missing the partial message", i) + } + for _, partialMessage := range msgStore { + if !partialMessage.complete() { + t.Errorf("expected complete message, but %v is incomplete", partialMessage) + } + } + } +} + +func TestPeerSupportsPartialMessages(t *testing.T) { + // N peers connected in a ring: + // peer 0 requests partial messages + // peer 1 does not support partial messages + // peer 2..N-1 support partial messages + // + // Peer 0 first requests a partial message by doing a partial publish with an + // empty (no parts) message. + // + // The rest of the peers then publish the full message. The peer that + // supports partial messages should have received the request from peer 0, + // and can sent the missing parts right away. + + topic := "test-topic" + const hostCount = 5 + hosts := getDefaultHosts(t, hostCount) + psubs := make([]*PubSub, 0, len(hosts)) + topics := make([]*Topic, 0, len(hosts)) + subs := make([]*Subscription, 0, len(hosts)) + + gossipsubCtx, closeGossipsub := context.WithCancel(context.Background()) + go func() { + <-gossipsubCtx.Done() + for _, h := range hosts { + h.Close() + } + }() + + partialExt := make([]*partialmessages.PartialMessageExtension, hostCount) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) + + // A list of maps from topic+groupID to partialMessage. 
One map per peer + // var partialMessageStoreMu sync.Mutex + partialMessageStore := make([]map[string]*minimalTestPartialMessage, hostCount) + for i := range hostCount { + partialMessageStore[i] = make(map[string]*minimalTestPartialMessage) + } + + for i := range partialExt { + partialExt[i] = &partialmessages.PartialMessageExtension{ + Logger: logger.With("id", i), + ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + // No validation. Only for this test. In production you should + // have some basic fast rules here. + return nil + }, + MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata { + return partialmessages.MergeBitmap(left, right) + }, + OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + if from == hosts[1].ID() { + panic("peer 1 does not support partial messages, so should not send a partial message RPC") + } + + if i == 0 && rpc.PartialMessage == nil { + // The first incoming rpc to the peer requesting a partial + // message should contain data since we made sure to send + // the request first. 
+ panic("expected to receive a partial message from a supporting peer") + } + + groupID := rpc.GroupID + pm, ok := partialMessageStore[i][topic+string(groupID)] + if !ok { + pm = &minimalTestPartialMessage{ + Group: groupID, + } + partialMessageStore[i][topic+string(groupID)] = pm + } + if publishOpts := pm.onIncomingRPC(from, rpc); publishOpts != nil { + go psubs[i].PublishPartialMessage(topic, pm, *publishOpts) + if pm.complete() { + encoded, _ := pm.PartialMessageBytes(partialmessages.PartsMetadata([]byte{0})) + go func() { + err := psubs[i].Publish(topic, encoded) + if err != nil { + panic(err) + } + }() + } + + } + return nil + }, + } + } + + for i, h := range hosts { + var topicOpts []TopicOpt + if i == 0 { + topicOpts = append(topicOpts, RequestPartialMessages()) + } else if i == 1 { + // The right neighbor doesn't support partial messages + } else { + topicOpts = append(topicOpts, SupportsPartialMessages()) + } + + psub := getGossipsub(gossipsubCtx, h, WithPartialMessagesExtension(partialExt[i])) + topic, err := psub.Join(topic, topicOpts...) + if err != nil { + t.Fatal(err) + } + sub, err := topic.Subscribe() + if err != nil { + t.Fatal(err) + } + psubs = append(psubs, psub) + topics = append(topics, topic) + subs = append(subs, sub) + } + + ringConnect(t, hosts) + time.Sleep(2 * time.Second) + + group := []byte("test-group") + emptyMsg := &minimalTestPartialMessage{ + Group: group, + } + fullMsg := &minimalTestPartialMessage{ + Group: group, + Parts: [2][]byte{ + []byte("Hello"), + []byte("World"), + }, + } + + for i := range hosts { + if i <= 1 { + continue + } + copiedMsg := *fullMsg + partialMessageStore[i][topic+string(group)] = &copiedMsg + } + + // Have the first host publish the empty partial message to send a partial + // message request to peers that support partial messages. 
+ partialMessageStore[0][topic+string(group)] = emptyMsg + // first host has no data + err := psubs[0].PublishPartialMessage(topic, emptyMsg, partialmessages.PublishOptions{}) + if err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second) + + for i := range hosts { + if i == 0 { + continue + } else { + if i != 1 { + err := psubs[i].PublishPartialMessage(topic, partialMessageStore[i][topic+string(group)], partialmessages.PublishOptions{}) + if err != nil { + t.Fatal(err) + } + } + encoded, err := fullMsg.PartialMessageBytes(partialmessages.PartsMetadata([]byte{0})) + if err != nil { + t.Fatal(err) + } + err = topics[i].Publish(context.Background(), encoded) + if err != nil { + t.Fatal(err) + } + } + } + + time.Sleep(2 * time.Second) + + // Close gossipsub before we inspect the state to avoid race conditions + closeGossipsub() + time.Sleep(1 * time.Second) + + if len(partialMessageStore) != hostCount { + t.Errorf("One host is missing the partial message") + } + + for i, msgStore := range partialMessageStore { + if i == 1 { + // Host 1 doesn't support partial messages + continue + } + if len(msgStore) == 0 { + t.Errorf("Host %d is missing the partial message", i) + } + for _, partialMessage := range msgStore { + if !partialMessage.complete() { + t.Errorf("expected complete message, but %v is incomplete at host %d", partialMessage, i) + } + } + } + + for i, sub := range subs { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := sub.Next(ctx) + if err != nil { + t.Errorf("failed to receive message: %v at host %d", err, i) + } + } +} + +func TestSkipPublishingToPeersRequestingPartialMessages(t *testing.T) { + topicName := "test-topic" + + // 3 hosts. + // hosts[0]: Publisher. Requests partial messages + // hosts[1]: Subscriber. Requests partial messages + // hosts[2]: Alternate publisher. Does not support partial messages. 
Only + // connected to hosts[0] + hosts := getDefaultHosts(t, 3) + + const hostsWithPartialMessageSupport = 2 + partialExt := make([]*partialmessages.PartialMessageExtension, hostsWithPartialMessageSupport) + // A list of maps from topic+groupID to partialMessage. One map per peer + partialMessageStore := make([]map[string]*minimalTestPartialMessage, hostsWithPartialMessageSupport) + for i := range hostsWithPartialMessageSupport { + partialMessageStore[i] = make(map[string]*minimalTestPartialMessage) + } + + // Only hosts with partial message support + psubs := make([]*PubSub, 0, len(hosts)-1) + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: slog.LevelDebug, + })) + + for i := range partialExt { + partialExt[i] = &partialmessages.PartialMessageExtension{ + Logger: logger, + ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + return nil + }, + MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata { + return partialmessages.MergeBitmap(left, right) + }, + OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + topicID := rpc.GetTopicID() + groupID := rpc.GetGroupID() + pm, ok := partialMessageStore[i][topicID+string(groupID)] + if !ok { + pm = &minimalTestPartialMessage{ + Group: groupID, + } + partialMessageStore[i][topicID+string(groupID)] = pm + } + if publishOpts := pm.onIncomingRPC(from, rpc); publishOpts != nil { + go psubs[i].PublishPartialMessage(topicID, pm, *publishOpts) + } + return nil + }, + } + } + + for i, h := range hosts[:2] { + psub := getGossipsub(context.Background(), h, WithPartialMessagesExtension(partialExt[i])) + psubs = append(psubs, psub) + } + + nonPartialPubsub := getGossipsub(context.Background(), hosts[2]) + + denseConnect(t, hosts[:2]) + time.Sleep(2 * time.Second) + + // Connect nonPartialPubsub to the publisher + connect(t, hosts[0], hosts[2]) + + var topics []*Topic + var subs []*Subscription + for 
_, psub := range psubs { + topic, err := psub.Join(topicName, RequestPartialMessages()) + if err != nil { + t.Fatal(err) + } + topics = append(topics, topic) + s, err := topic.Subscribe() + if err != nil { + t.Fatal(err) + } + subs = append(subs, s) + } + + topicForNonPartial, err := nonPartialPubsub.Join(topicName) + if err != nil { + t.Fatal(err) + } + + // Wait for subscriptions to propagate + time.Sleep(time.Second) + + topics[0].Publish(context.Background(), []byte("Hello")) + + // Publish from another peer, the publisher (psub[0]) should not forward this to psub[1]. + // The application has to handle the interaction of getting a standard + // gossipsub message and republishing it with partial messages. + topicForNonPartial.Publish(context.Background(), []byte("from non-partial")) + + recvdMessage := make(chan struct{}, 1) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + msg, err := subs[1].Next(ctx) + if err == context.Canceled { + return + } + if err != nil { + t.Log(err) + t.Fail() + return + } + t.Log("Received msg", string(msg.Data)) + recvdMessage <- struct{}{} + }() + + select { + case <-recvdMessage: + t.Fatal("Received message") + case <-time.After(2 * time.Second): + } +} + +func TestPairwiseInteractionWithPartialMessages(t *testing.T) { + type PartialMessageStatus int + const ( + NoPartialMessages PartialMessageStatus = iota + PeerSupportsPartialMessages + PeerRequestsPartialMessages + ) + + type TestCase struct { + hostSupport []PartialMessageStatus + publisherIdx int + } + + var tcs []TestCase + for _, a := range []PartialMessageStatus{NoPartialMessages, PeerSupportsPartialMessages, PeerRequestsPartialMessages} { + for _, b := range []PartialMessageStatus{NoPartialMessages, PeerSupportsPartialMessages, PeerRequestsPartialMessages} { + for i := range 2 { + tcs = append(tcs, TestCase{hostSupport: []PartialMessageStatus{a, b}, publisherIdx: i}) + } + } + } + + for _, tc := range tcs { + 
t.Run(fmt.Sprintf("Host Support: %v. Publisher: %d", tc.hostSupport, tc.publisherIdx), func(t *testing.T) { + topic := "test-topic" + hostCount := len(tc.hostSupport) + hosts := getDefaultHosts(t, hostCount) + topics := make([]*Topic, 0, len(hosts)) + psubs := make([]*PubSub, 0, len(hosts)) + + gossipsubCtx, closeGossipsub := context.WithCancel(context.Background()) + defer closeGossipsub() + go func() { + <-gossipsubCtx.Done() + for _, h := range hosts { + h.Close() + } + }() + + partialExt := make([]*partialmessages.PartialMessageExtension, hostCount) + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelDebug})) + + // A list of maps from topic+groupID to partialMessage. One map per peer + // var partialMessageStoreMu sync.Mutex + partialMessageStore := make([]map[string]*minimalTestPartialMessage, hostCount) + for i := range hostCount { + partialMessageStore[i] = make(map[string]*minimalTestPartialMessage) + } + + receivedMessage := make(chan struct{}, hostCount) + + for i := range partialExt { + if tc.hostSupport[i] == NoPartialMessages { + continue + } + partialExt[i] = &partialmessages.PartialMessageExtension{ + Logger: logger.With("id", i), + ValidateRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + // No validation. Only for this test. In production you should + // have some basic fast rules here. 
+ return nil + }, + MergePartsMetadata: func(_ string, left, right partialmessages.PartsMetadata) partialmessages.PartsMetadata { + return partialmessages.MergeBitmap(left, right) + }, + OnIncomingRPC: func(from peer.ID, rpc *pb.PartialMessagesExtension) error { + if tc.hostSupport[i] == PeerSupportsPartialMessages && len(rpc.PartialMessage) > 0 { + panic("This host should not have received partial message data") + } + + groupID := rpc.GroupID + pm, ok := partialMessageStore[i][topic+string(groupID)] + if !ok { + pm = &minimalTestPartialMessage{ + Group: groupID, + } + partialMessageStore[i][topic+string(groupID)] = pm + } + prevComplete := pm.complete() + if publishOpts := pm.onIncomingRPC(from, rpc); publishOpts != nil { + if !prevComplete && pm.complete() { + t.Log("host", i, "received partial message") + + receivedMessage <- struct{}{} + } + + go psubs[i].PublishPartialMessage(topic, pm, *publishOpts) + } + return nil + }, + } + } + + for i, h := range hosts { + var opts []Option + var topicOpts []TopicOpt + switch tc.hostSupport[i] { + case NoPartialMessages: + case PeerSupportsPartialMessages: + opts = append(opts, WithPartialMessagesExtension(partialExt[i])) + topicOpts = append(topicOpts, SupportsPartialMessages()) + case PeerRequestsPartialMessages: + opts = append(opts, WithPartialMessagesExtension(partialExt[i])) + topicOpts = append(topicOpts, RequestPartialMessages()) + } + + psub := getGossipsub(gossipsubCtx, h, opts...) + topic, err := psub.Join(topic, topicOpts...) 
+ if err != nil { + t.Fatal(err) + } + topics = append(topics, topic) + sub, err := topic.Subscribe() + if err != nil { + t.Fatal(err) + } + psubs = append(psubs, psub) + go func() { + _, err := sub.Next(gossipsubCtx) + if err == context.Canceled { + return + } + if err != nil { + panic(err) + } + + t.Log("host", i, "received message") + receivedMessage <- struct{}{} + }() + } + + denseConnect(t, hosts) + time.Sleep(time.Second) + + group := []byte("test-group") + msg1 := &minimalTestPartialMessage{ + Group: group, + Parts: [2][]byte{ + []byte("Hello"), + []byte("World"), + }, + } + + for i := range hostCount { + if i != tc.publisherIdx { + continue + } + + partialMessageStore[i][topic+string(group)] = msg1 + + encoded, err := msg1.PartialMessageBytes(partialmessages.PartsMetadata([]byte{0})) + if err != nil { + t.Fatal(err) + } + err = topics[i].Publish(context.Background(), encoded) + if err != nil { + t.Fatal(err) + } + + if tc.hostSupport[i] != NoPartialMessages { + err = psubs[i].PublishPartialMessage(topic, msg1, partialmessages.PublishOptions{}) + if err != nil { + t.Fatal(err) + } + } + } + + for range hostCount { + select { + case <-receivedMessage: + case <-time.After(time.Second): + t.Fatalf("At least one message was not received") + } + } + + select { + case <-receivedMessage: + t.Fatalf("An extra message was received") + case <-time.After(100 * time.Millisecond): + } + }) + } +} diff --git a/internal/merkle/example.go b/internal/merkle/example.go new file mode 100644 index 00000000..034f33c9 --- /dev/null +++ b/internal/merkle/example.go @@ -0,0 +1,79 @@ +// merkle is a minimal merkle tree commitment scheme to serve as an example and +// test for partial messages. 
// hash returns the SHA-256 digest of data.
func hash(data []byte) []byte {
	h := sha256.Sum256(data)
	return h[:]
}

// concat returns a freshly allocated concatenation of a and b.
//
// Using append(a, b...) here would be a bug: when a has spare capacity
// (for example, when it is a sub-slice of a caller-owned buffer), append
// writes b into a's backing array, silently corrupting the caller's
// leaves and any sibling slices that share that array — and therefore
// the computed root.
func concat(a, b []byte) []byte {
	out := make([]byte, 0, len(a)+len(b))
	out = append(out, a...)
	return append(out, b...)
}

// buildMerkleTree computes the Merkle tree layers bottom-up for root and
// proof generation. tree[0] is the leaf layer; the last layer holds the
// single root node. An odd node at the end of a layer is promoted to the
// next layer unchanged.
func buildMerkleTree(leaves [][]byte) [][][]byte {
	tree := [][][]byte{leaves}
	for len(tree[len(tree)-1]) > 1 {
		level := tree[len(tree)-1]
		var nextLevel [][]byte
		for i := 0; i < len(level); i += 2 {
			if i+1 == len(level) {
				// Odd node out: promote it without hashing.
				nextLevel = append(nextLevel, level[i])
			} else {
				nextLevel = append(nextLevel, hash(concat(level[i], level[i+1])))
			}
		}
		tree = append(tree, nextLevel)
	}
	return tree
}

// MerkleRoot returns the Merkle root commitment over leaves.
// It panics on an empty leaf set, matching the original behavior.
func MerkleRoot(leaves [][]byte) []byte {
	tree := buildMerkleTree(leaves)
	return tree[len(tree)-1][0]
}

// ProofStep is a Merkle proof element: the sibling hash plus a direction
// flag (IsLeft is true when the sibling sits to the left of the node).
type ProofStep struct {
	Hash   []byte
	IsLeft bool
}

// MerkleProof generates the Merkle proof for the leaf at index.
func MerkleProof(leaves [][]byte, index int) []ProofStep {
	tree := buildMerkleTree(leaves)
	proof := []ProofStep{}
	for level := 0; level < len(tree)-1; level++ {
		if index%2 == 0 {
			siblingIndex := index + 1
			if siblingIndex >= len(tree[level]) {
				// No sibling: the node was promoted as-is, nothing to prove
				// at this level.
				index /= 2
				continue
			}
			proof = append(proof, ProofStep{Hash: tree[level][siblingIndex], IsLeft: false})
		} else {
			proof = append(proof, ProofStep{Hash: tree[level][index-1], IsLeft: true})
		}
		index /= 2
	}
	return proof
}

// VerifyProof reports whether leaf hashes up to root via the proof steps.
func VerifyProof(leaf, root []byte, proof []ProofStep) bool {
	computedHash := leaf
	for _, step := range proof {
		if step.IsLeft {
			computedHash = hash(concat(step.Hash, computedHash))
		} else {
			computedHash = hash(concat(computedHash, step.Hash))
		}
	}
	return bytes.Equal(computedHash, root)
}
file mode 100644 index 00000000..f014d177 --- /dev/null +++ b/partialmessages/bitmap/bitmap.go @@ -0,0 +1,84 @@ +package bitmap + +import ( + "math/bits" +) + +type Bitmap []byte + +func NewBitmapWithOnesCount(n int) Bitmap { + b := make(Bitmap, (n+7)/8) + for i := range n { + b.Set(i) + } + return b +} + +func Merge(left, right Bitmap) Bitmap { + out := make(Bitmap, max(len(left), len(right))) + out.Or(left) + out.Or(right) + return out +} + +func (b Bitmap) IsZero() bool { + for i := range b { + if b[i] != 0 { + return false + } + } + return true +} + +func (b Bitmap) OnesCount() int { + var count int + for i := range b { + count += bits.OnesCount8(b[i]) + } + return count +} + +func (b Bitmap) Set(index int) { + for len(b)*8 <= index { + b = append(b, 0) + } + b[index/8] |= 1 << (uint(index) % 8) +} + +func (b Bitmap) Get(index int) bool { + if index >= len(b)*8 { + return false + } + return b[index/8]&(1<<(uint(index%8))) != 0 +} + +func (b Bitmap) Clear(index int) { + if index >= len(b)*8 { + return + } + b[index/8] &^= 1 << (uint(index) % 8) +} + +func (b Bitmap) And(other Bitmap) { + for i := range min(len(b), len(other)) { + b[i] &= other[i] + } +} + +func (b Bitmap) Or(other Bitmap) { + for i := range min(len(b), len(other)) { + b[i] |= other[i] + } +} + +func (b Bitmap) Xor(other Bitmap) { + for i := range min(len(b), len(other)) { + b[i] ^= other[i] + } +} + +func (b Bitmap) Flip() { + for i := range b { + b[i] ^= 0xff + } +} diff --git a/partialmessages/invariants.go b/partialmessages/invariants.go new file mode 100644 index 00000000..633f7d34 --- /dev/null +++ b/partialmessages/invariants.go @@ -0,0 +1,290 @@ +package partialmessages + +import ( + "bytes" + "testing" + + "github.com/libp2p/go-libp2p/core/peer" +) + +// InvariantChecker is a test tool to test an implementation of PartialMessage +// upholds its invariants. Use this in your application's tests to validate your +// PartialMessage implementation. 
+type InvariantChecker[P Message] interface { + // SplitIntoParts returns a list of partial messages where there are no + // overlaps between messages, and the sum of all messages is the original + // partial message. + SplitIntoParts(in P) ([]P, error) + + // FullMessage should return a complete partial message. + FullMessage() (P, error) + + // EmptyMessage should return a empty partial message. + EmptyMessage() P + + // ExtendFromBytes extends a from data and returns the extended partial + // message or an error. An implementation may mutate a and return a. + ExtendFromBytes(a P, data []byte) (P, error) + + // ShouldRequest should return true if peer has something we're interested + // in (as determined from the parts metadata) + ShouldRequest(a P, from peer.ID, partsMetadata []byte) bool + + MergePartsMetadata(left, right PartsMetadata) PartsMetadata + + Equal(a, b P) bool +} + +func TestPartialMessageInvariants[P Message](t *testing.T, checker InvariantChecker[P]) { + extend := func(a, b P) (P, error) { + emptyParts := checker.EmptyMessage().PartsMetadata() + encodedB, err := b.PartialMessageBytes(emptyParts) + if err != nil { + var out P + return out, err + } + return checker.ExtendFromBytes(a, encodedB) + } + + t.Run("A empty message should not return a nil slice for parts metadata (should encode a request)", func(t *testing.T) { + empty := checker.EmptyMessage() + b := empty.PartsMetadata() + if b == nil { + t.Errorf("did not expect empty slice") + } + }) + t.Run("Splitting a full message, and then recombining it yields the original message", func(t *testing.T) { + fullMessage, err := checker.FullMessage() + if err != nil { + t.Fatal(err) + } + + parts, err := checker.SplitIntoParts(fullMessage) + if err != nil { + t.Fatal(err) + } + + recombined := checker.EmptyMessage() + for _, part := range parts { + b, err := part.PartialMessageBytes(recombined.PartsMetadata()) + if err != nil { + t.Fatal(err) + } + recombined, err = checker.ExtendFromBytes(recombined, 
b) + if err != nil { + t.Fatal(err) + } + } + + if !checker.Equal(fullMessage, recombined) { + t.Errorf("Expected %v, got %v", fullMessage, recombined) + } + }) + + t.Run("Empty message requesting parts it doesn't have returns nil response", func(t *testing.T) { + emptyMessage := checker.EmptyMessage() + + // Get metadata representing all parts from the empty message + emptyMsgPartsMeta := emptyMessage.PartsMetadata() + + // Empty message should not be able to fulfill any request + response, err := emptyMessage.PartialMessageBytes(emptyMsgPartsMeta) + if err != nil { + t.Fatal(err) + } + rest := checker.MergePartsMetadata(emptyMsgPartsMeta, emptyMessage.PartsMetadata()) + + if len(response) != 0 { + t.Error("Empty message should return nil response when requesting parts it doesn't have") + } + + // The rest should be the same as the original request since nothing was fulfilled + if len(rest) == 0 && len(emptyMsgPartsMeta) > 0 { + t.Error("Empty message should return the full request as 'rest' when it cannot fulfill anything") + } + }) + + t.Run("Partial fulfillment returns correct rest and can be completed by another message", func(t *testing.T) { + fullMessage, err := checker.FullMessage() + if err != nil { + t.Fatal(err) + } + + parts, err := checker.SplitIntoParts(fullMessage) + if err != nil { + t.Fatal(err) + } + + // Skip this test if we can't split into at least 2 parts + if len(parts) < 2 { + t.Skip("Cannot test partial fulfillment with less than 2 parts") + } + + // Get metadata representing all parts needed + emptyMessage := checker.EmptyMessage() + emptyMsgPartsMeta := emptyMessage.PartsMetadata() + + // Request all parts from the partial message + response1, err := parts[0].PartialMessageBytes(emptyMsgPartsMeta) + if err != nil { + t.Fatal(err) + } + rest1 := checker.MergePartsMetadata(emptyMsgPartsMeta, parts[0].PartsMetadata()) + + // Should get some response since partial message has at least one part + if len(response1) == 0 { + t.Error("Partial 
message should return some data when it has parts to fulfill") + } + + // Rest should be non-zero and different from original request since something was fulfilled + if len(rest1) == 0 { + t.Fatal("Rest should be non-zero when partial fulfillment occurred") + } + if bytes.Equal(rest1, emptyMsgPartsMeta) { + t.Fatalf("Rest should be different from original request since partial fulfillment occurred") + } + + // Create another partial message with the remaining parts + remainingPartial := checker.EmptyMessage() + for i := 1; i < len(parts); i++ { + remainingPartial, err = extend(remainingPartial, parts[i]) + if err != nil { + t.Fatal(err) + } + } + + // The remaining partial message should be able to fulfill the "rest" request + response2, err := remainingPartial.PartialMessageBytes(rest1) + if err != nil { + t.Fatal(err) + } + rest2 := checker.MergePartsMetadata(rest1, remainingPartial.PartsMetadata()) + + // response2 should be non-empty since we have remaining parts to fulfill + if len(response2) == 0 { + t.Error("Response2 should be non-empty when fulfilling remaining parts") + } + + // After fulfilling the rest, the metadata should be the same as full + if !bytes.Equal(rest2, fullMessage.PartsMetadata()) { + t.Errorf("After fulfilling all parts, the parts metadata should be the same as the full message, saw %v", rest2) + } + + // Combine both responses and verify we can reconstruct the full message + reconstructed := checker.EmptyMessage() + reconstructed, err = checker.ExtendFromBytes(reconstructed, response1) + if err != nil { + t.Fatal(err) + } + if len(response2) > 0 { + reconstructed, err = checker.ExtendFromBytes(reconstructed, response2) + if err != nil { + t.Fatal(err) + } + } + + // The reconstructed message should be equivalent to the full message + if !checker.Equal(fullMessage, reconstructed) { + t.Errorf("Reconstructed message from partial responses should equal full message") + } + }) + + t.Run("PartialMessageBytesFromMetadata with empty metadata 
requests all parts", func(t *testing.T) { + fullMessage, err := checker.FullMessage() + if err != nil { + t.Fatal(err) + } + + // Request with empty metadata should return all available parts + emptyMeta := checker.EmptyMessage().PartsMetadata() + response, err := fullMessage.PartialMessageBytes(emptyMeta) + rest := checker.MergePartsMetadata(emptyMeta, fullMessage.PartsMetadata()) + if err != nil { + t.Fatal(err) + } + + // Should get some response from a full message + if len(response) == 0 { + t.Error("Full message should return data when requested with empty metadata") + } + + // Should have no remaining parts since full message can fulfill everything + if !bytes.Equal(rest, fullMessage.PartsMetadata()) { + t.Error("Full message should have no remaining parts when fulfilling empty metadata request") + } + }) + + t.Run("Available parts, missing parts, and partial message bytes consistency", func(t *testing.T) { + fullMessage, err := checker.FullMessage() + if err != nil { + t.Fatal(err) + } + + // Get the available parts + fullMsgPartsMeta := fullMessage.PartsMetadata() + + // Assert available parts is non-zero length + if len(fullMsgPartsMeta) == 0 { + t.Error("Full message should have non-zero available parts") + } + + // Split the full message into parts + parts, err := checker.SplitIntoParts(fullMessage) + if err != nil { + t.Fatal(err) + } + + var partialMessageResponses [][]byte + + // Test each part and empty message + testMessages := make([]P, len(parts)+1) + copy(testMessages, parts) + testMessages[len(parts)] = checker.EmptyMessage() + + for i, testMsg := range testMessages { + // Assert that ShouldRequest returns true for the available parts + if !checker.ShouldRequest(testMsg, "", fullMsgPartsMeta) { + t.Errorf("Message %d should request the available parts", i) + } + + // Get the MissingParts() and have the full message fulfill the request + msgPartsMeta := testMsg.PartsMetadata() + + response, err := fullMessage.PartialMessageBytes(msgPartsMeta) + 
rest := checker.MergePartsMetadata(msgPartsMeta, fullMessage.PartsMetadata()) + if err != nil { + t.Fatal(err) + } + + // Assert that the rest is nil + if !bytes.Equal(rest, fullMessage.PartsMetadata()) { + t.Errorf("rest should be equal to fullMessage.PartsMetadata() for message %d", i) + } + + // Store each partial message bytes + if len(response) > 0 { + partialMessageResponses = append(partialMessageResponses, response) + } + + // Call ExtendFromEncodedPartialMessage + testMsg, err = checker.ExtendFromBytes(testMsg, response) + if err != nil { + t.Fatal(err) + } + + // Assert the extended form is now equal to the full message + if !checker.Equal(fullMessage, testMsg) { + t.Errorf("Extended message %d should equal full message", i) + } + } + + // Assert that none of the partial message bytes are equal to each other. + for i := range partialMessageResponses { + for j := i + 1; j < len(partialMessageResponses); j++ { + if bytes.Equal(partialMessageResponses[i], partialMessageResponses[j]) { + t.Errorf("Partial message bytes %d and %d should not be equal", i, j) + } + } + } + }) +} diff --git a/partialmessages/partialmsgs.go b/partialmessages/partialmsgs.go new file mode 100644 index 00000000..37cdd285 --- /dev/null +++ b/partialmessages/partialmsgs.go @@ -0,0 +1,398 @@ +package partialmessages + +import ( + "bytes" + "errors" + "iter" + "log/slog" + "slices" + + "github.com/libp2p/go-libp2p-pubsub/partialmessages/bitmap" + pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/peer" +) + +// TODO: Add gossip fallback (pick random connected peers and send ihave/iwants) + +const minGroupTTL = 3 + +// defaultPeerInitiatedGroupLimitPerTopic limits the total number (per topic) of +// *partialMessageStatePerTopicGroup we create in response to a incoming RPC. +// This only applies to groups that we haven't published for yet. 
+const defaultPeerInitiatedGroupLimitPerTopic = 255 + +const defaultPeerInitiatedGroupLimitPerTopicPerPeer = 8 + +// PartsMetadata returns metadata about the parts this partial message +// contains and, possibly implicitly, the parts it wants. +type PartsMetadata []byte + +// Message is a message that can be broken up into parts. It can be +// complete, partially complete, or empty. It is up to the application to define +// how a message is split into parts and recombined, as well as how missing and +// available parts are represented. +// +// It is passed to Gossipsub with a PublishPartialMessage method call. Gossipsub +// keeps a reference to this object, so the implementation should either not +// mutate this object in a separate goroutine after handing it to Gossipsub, or +// take care to make the object thread safe. +type Message interface { + GroupID() []byte + + // PartialMessageBytes takes in the opaque request metadata and + // returns a encoded partial message that fulfills as much of the request as + // possible. It also returns a opaque request metadata representing the + // parts it could not fulfill. This MUST be empty if the implementation could + // fulfill the whole request. + // + // An empty metadata should be treated the same as a request for all parts. + // + // If the Partial Message is empty, the implementation MUST return: + // nil, metadata, nil. + PartialMessageBytes(partsMetadata PartsMetadata) (msg []byte, _ error) + + PartsMetadata() PartsMetadata +} + +// peerState is the state we keep per peer. Used to make Publish +// Idempotent. 
+type peerState struct { + // The parts metadata the peer has sent us + partsMetadata PartsMetadata + // The parts metadata this node has sent to the peer + sentPartsMetadata PartsMetadata +} + +func (ps *peerState) IsZero() bool { + return ps.partsMetadata == nil && ps.sentPartsMetadata == nil +} + +type partialMessageStatePerGroupPerTopic struct { + peerState map[peer.ID]*peerState + groupTTL int + initiatedBy peer.ID // zero value if we initiated the group +} + +func newPartialMessageStatePerTopicGroup(groupTTL int) *partialMessageStatePerGroupPerTopic { + return &partialMessageStatePerGroupPerTopic{ + peerState: make(map[peer.ID]*peerState), + groupTTL: max(groupTTL, minGroupTTL), + } +} + +func (s *partialMessageStatePerGroupPerTopic) remotePeerInitiated() bool { + return s.initiatedBy != "" +} + +func (s *partialMessageStatePerGroupPerTopic) clearPeerMetadata(peerID peer.ID) { + if peerState, ok := s.peerState[peerID]; ok { + peerState.partsMetadata = nil + if peerState.IsZero() { + delete(s.peerState, peerID) + } + } +} + +// MergeBitmap is a helper function for merging parts metadata if they are a +// bitmap. +func MergeBitmap(left, right PartsMetadata) PartsMetadata { + return PartsMetadata(bitmap.Merge(bitmap.Bitmap(left), bitmap.Bitmap(right))) +} + +type PartialMessageExtension struct { + Logger *slog.Logger + + MergePartsMetadata func(topic string, left, right PartsMetadata) PartsMetadata + + // OnIncomingRPC is called whenever we receive an encoded + // partial message from a peer. This func MUST be fast and non-blocking. + // If you need to do slow work, dispatch the work to your own goroutine. + OnIncomingRPC func(from peer.ID, rpc *pb.PartialMessagesExtension) error + + // ValidateRPC should be a fast function that performs some + // basic sanity checks on incoming RPC. For example: + // - Is this a known topic? + // - Is the groupID well formed per application semantics? 
+ // - If this is a PartialIHAVE/PartialIWant, is the request metadata within + // expected bounds? + ValidateRPC func(from peer.ID, rpc *pb.PartialMessagesExtension) error + + // PeerInitiatedGroupLimitPerTopic limits the number of Group states all + // peers can initialize per topic. A group state is initialized by a peer if + // the peer's message marks the first time we've seen a group id. + PeerInitiatedGroupLimitPerTopic int + + // PeerInitiatedGroupLimitPerTopicPerPeer limits the number of Group states + // a single peer can initialize per topic. A group state is initialized by a + // peer if the peer's message marks the first time we've seen a group id. + PeerInitiatedGroupLimitPerTopicPerPeer int + + // GroupTTLByHeatbeat is how many heartbeats we store Group state for after + // publishing a partial message for the group. + GroupTTLByHeatbeat int + + // map topic -> map[group]partialMessageStatePerGroupPerTopic + statePerTopicPerGroup map[string]map[string]*partialMessageStatePerGroupPerTopic + + // map[topic]counter + peerInitiatedGroupCounter map[string]*peerInitiatedGroupCounterState + + router Router +} + +type PublishOptions struct { + // PublishToPeers limits the publishing to only the specified peers. + // If nil, will use the topic's mesh peers. 
+ PublishToPeers []peer.ID + // EagerPush is data that will be eagerly pushed to peers in a PartialMessage + EagerPush []byte +} + +type Router interface { + SendRPC(p peer.ID, r *pb.PartialMessagesExtension, urgent bool) + MeshPeers(topic string) iter.Seq[peer.ID] + PeerRequestsPartial(peer peer.ID, topic string) bool +} + +func (e *PartialMessageExtension) groupState(topic string, groupID []byte, peerInitiated bool, from peer.ID) (*partialMessageStatePerGroupPerTopic, error) { + statePerTopic, ok := e.statePerTopicPerGroup[topic] + if !ok { + statePerTopic = make(map[string]*partialMessageStatePerGroupPerTopic) + e.statePerTopicPerGroup[topic] = statePerTopic + } + if _, ok := e.peerInitiatedGroupCounter[topic]; !ok { + e.peerInitiatedGroupCounter[topic] = &peerInitiatedGroupCounterState{} + } + state, ok := statePerTopic[string(groupID)] + if !ok { + if peerInitiated { + err := e.peerInitiatedGroupCounter[topic].Inc(e.PeerInitiatedGroupLimitPerTopic, e.PeerInitiatedGroupLimitPerTopicPerPeer, from) + if err != nil { + return nil, err + } + } + + state = newPartialMessageStatePerTopicGroup(e.GroupTTLByHeatbeat) + statePerTopic[string(groupID)] = state + state.initiatedBy = from + } + if !peerInitiated && state.remotePeerInitiated() { + // We've tried to initiate this state as well, so it's no longer peer initiated. 
+ e.peerInitiatedGroupCounter[topic].Dec(state.initiatedBy) + state.initiatedBy = "" + } + return state, nil +} + +func (e *PartialMessageExtension) Init(router Router) error { + e.router = router + if e.Logger == nil { + return errors.New("field Logger must be set") + } + if e.ValidateRPC == nil { + return errors.New("field ValidateRPC must be set") + } + if e.OnIncomingRPC == nil { + return errors.New("field OnIncomingRPC must be set") + } + if e.MergePartsMetadata == nil { + return errors.New("field MergePartsMetadata must be set") + } + + if e.PeerInitiatedGroupLimitPerTopic == 0 { + e.PeerInitiatedGroupLimitPerTopic = defaultPeerInitiatedGroupLimitPerTopic + } + if e.PeerInitiatedGroupLimitPerTopicPerPeer == 0 { + e.PeerInitiatedGroupLimitPerTopicPerPeer = defaultPeerInitiatedGroupLimitPerTopicPerPeer + } + + e.statePerTopicPerGroup = make(map[string]map[string]*partialMessageStatePerGroupPerTopic) + e.peerInitiatedGroupCounter = make(map[string]*peerInitiatedGroupCounterState) + + return nil +} + +func (e *PartialMessageExtension) PublishPartial(topic string, partial Message, opts PublishOptions) error { + groupID := partial.GroupID() + myPartsMeta := partial.PartsMetadata() + + state, err := e.groupState(topic, groupID, false, "") + if err != nil { + return err + } + + state.groupTTL = max(e.GroupTTLByHeatbeat, minGroupTTL) + + var peers iter.Seq[peer.ID] + if len(opts.PublishToPeers) > 0 { + peers = slices.Values(opts.PublishToPeers) + } else { + peers = e.router.MeshPeers(topic) + } + for p := range peers { + log := e.Logger.With("peer", p) + requestedPartial := e.router.PeerRequestsPartial(p, topic) + + var rpc pb.PartialMessagesExtension + var sendRPC bool + var inResponseToIWant bool + + pState, peerStateOk := state.peerState[p] + if !peerStateOk { + pState = &peerState{} + state.peerState[p] = pState + } + + // Try to fulfill any wants from the peer + if requestedPartial && pState.partsMetadata != nil { + // This peer has previously asked for a certain 
part. We'll give + // them what we can. + pm, err := partial.PartialMessageBytes(pState.partsMetadata) + if err != nil { + log.Warn("partial message extension failed to get partial message bytes", "error", err) + // Possibly a bad request, we'll delete the request as we will likely error next time we try to handle it + state.clearPeerMetadata(p) + continue + } + pState.partsMetadata = e.MergePartsMetadata(topic, pState.partsMetadata, myPartsMeta) + if len(pm) > 0 { + log.Debug("Respond to peer's IWant") + sendRPC = true + rpc.PartialMessage = pm + inResponseToIWant = true + } + } + + // Only send the eager push to the peer if: + // - we didn't reply to an explicit request + // - we have something to eager push + if requestedPartial && !inResponseToIWant && len(opts.EagerPush) > 0 { + log.Debug("Eager pushing") + sendRPC = true + rpc.PartialMessage = opts.EagerPush + } + + // Only send parts metadata if it was different then before + if pState.sentPartsMetadata == nil || !bytes.Equal(myPartsMeta, pState.sentPartsMetadata) { + log.Debug("Including parts metadata") + sendRPC = true + pState.sentPartsMetadata = myPartsMeta + rpc.PartsMetadata = myPartsMeta + } + + if sendRPC { + rpc.TopicID = &topic + rpc.GroupID = groupID + e.sendRPC(p, &rpc) + } + } + + return nil +} + +func (e *PartialMessageExtension) AddPeer(id peer.ID) { +} + +func (e *PartialMessageExtension) RemovePeer(id peer.ID) { + for topic, statePerTopic := range e.statePerTopicPerGroup { + for _, state := range statePerTopic { + delete(state.peerState, id) + } + if ctr, ok := e.peerInitiatedGroupCounter[topic]; ok { + ctr.RemovePeer(id) + } + } +} + +func (e *PartialMessageExtension) Heartbeat() { + for topic, statePerTopic := range e.statePerTopicPerGroup { + for group, s := range statePerTopic { + if s.groupTTL == 0 { + delete(statePerTopic, group) + if len(statePerTopic) == 0 { + delete(e.statePerTopicPerGroup, topic) + } + if s.remotePeerInitiated() { + 
e.peerInitiatedGroupCounter[topic].Dec(s.initiatedBy) + } + } else { + s.groupTTL-- + } + } + } +} + +func (e *PartialMessageExtension) sendRPC(to peer.ID, rpc *pb.PartialMessagesExtension) { + e.Logger.Debug("Sending RPC", "to", to, "rpc", rpc) + e.router.SendRPC(to, rpc, false) +} + +func (e *PartialMessageExtension) HandleRPC(from peer.ID, rpc *pb.PartialMessagesExtension) error { + if rpc == nil { + return nil + } + if err := e.ValidateRPC(from, rpc); err != nil { + return err + } + e.Logger.Debug("Received RPC", "from", from, "rpc", rpc) + topic := rpc.GetTopicID() + groupID := rpc.GroupID + + state, err := e.groupState(topic, groupID, true, from) + if err != nil { + return err + } + + if rpc.PartsMetadata != nil { + pState, ok := state.peerState[from] + if !ok { + pState = &peerState{} + state.peerState[from] = pState + } + pState.partsMetadata = e.MergePartsMetadata(rpc.GetTopicID(), pState.partsMetadata, rpc.PartsMetadata) + } + + return e.OnIncomingRPC(from, rpc) +} + +type peerInitiatedGroupCounterState struct { + // total number of peer initiated groups + total int + // number of groups initiated per peer + perPeer map[peer.ID]int +} + +var errPeerInitiatedGroupTotalLimitReached = errors.New("too many peer initiated group states") +var errPeerInitiatedGroupLimitReached = errors.New("too many peer initiated group states for this peer") + +func (ctr *peerInitiatedGroupCounterState) Inc(totalLimit int, peerLimit int, id peer.ID) error { + if ctr.total >= totalLimit { + return errPeerInitiatedGroupTotalLimitReached + } + if ctr.perPeer == nil { + ctr.perPeer = make(map[peer.ID]int) + } + if ctr.perPeer[id] >= peerLimit { + return errPeerInitiatedGroupLimitReached + } + ctr.total++ + ctr.perPeer[id]++ + return nil +} + +func (ctr *peerInitiatedGroupCounterState) Dec(id peer.ID) { + if _, ok := ctr.perPeer[id]; ok { + ctr.total-- + ctr.perPeer[id]-- + if ctr.perPeer[id] == 0 { + delete(ctr.perPeer, id) + } + } +} + +func (ctr *peerInitiatedGroupCounterState) 
RemovePeer(id peer.ID) { + if n, ok := ctr.perPeer[id]; ok { + ctr.total -= n + delete(ctr.perPeer, id) + } +} diff --git a/partialmessages/partialmsgs_test.go b/partialmessages/partialmsgs_test.go new file mode 100644 index 00000000..2b8f47c4 --- /dev/null +++ b/partialmessages/partialmsgs_test.go @@ -0,0 +1,1223 @@ +package partialmessages + +import ( + "bytes" + cryptorand "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "iter" + "log/slog" + "math/big" + "math/rand" + "reflect" + "testing" + + "github.com/libp2p/go-libp2p-pubsub/internal/merkle" + "github.com/libp2p/go-libp2p-pubsub/partialmessages/bitmap" + pubsub_pb "github.com/libp2p/go-libp2p-pubsub/pb" + "github.com/libp2p/go-libp2p/core/peer" +) + +// testRouter implements the Router interface for testing +type testRouter struct { + sendRPC func(p peer.ID, r *pubsub_pb.PartialMessagesExtension, urgent bool) + meshPeers func(topic string) iter.Seq[peer.ID] +} + +// PeerRequestsPartial implements Router. +func (r *testRouter) PeerRequestsPartial(peer peer.ID, topic string) bool { + return true +} + +func (r *testRouter) SendRPC(p peer.ID, rpc *pubsub_pb.PartialMessagesExtension, urgent bool) { + r.sendRPC(p, rpc, urgent) +} + +func (r *testRouter) MeshPeers(topic string) iter.Seq[peer.ID] { + return r.meshPeers(topic) +} + +type rpcWithFrom struct { + from peer.ID + rpc *pubsub_pb.PartialMessagesExtension +} + +type mockNetworkPartialMessages struct { + t *testing.T + pendingMsgs map[peer.ID][]rpcWithFrom + + allSentMsgs map[peer.ID][]rpcWithFrom + + handlers map[peer.ID]*PartialMessageExtension +} + +func (m *mockNetworkPartialMessages) addPeers() { + for a := range m.handlers { + for b := range m.handlers { + if a == b { + continue + } + m.handlers[a].AddPeer(b) + m.handlers[b].AddPeer(a) + } + } +} + +func (m *mockNetworkPartialMessages) removePeers() { + for a := range m.handlers { + for b := range m.handlers { + if a == b { + continue + } + m.handlers[a].RemovePeer(b) + 
m.handlers[b].RemovePeer(a) + } + } + + // assert that there are no leaked peerInitiatedGroupCountPerTopics + for _, h := range m.handlers { + for topic, ctr := range h.peerInitiatedGroupCounter { + if ctr.total != 0 { + m.t.Errorf("unexpected peerInitiatedGroupCountPerTopic for topic %s: %d", topic, ctr.total) + } + for _, v := range ctr.perPeer { + if v != 0 { + m.t.Errorf("unexpected peerInitiatedGroupCountPerTopic for topic %s: %d", topic, v) + } + } + } + } +} + +func (m *mockNetworkPartialMessages) handleRPCs() bool { + for id, h := range m.handlers { + if len(m.pendingMsgs[id]) > 0 { + var rpc rpcWithFrom + rpc, m.pendingMsgs[id] = m.pendingMsgs[id][0], m.pendingMsgs[id][1:] + h.HandleRPC(rpc.from, rpc.rpc) + } + } + moreLeft := false + for id := range m.handlers { + if len(m.pendingMsgs[id]) > 0 { + moreLeft = true + break + } + } + return moreLeft +} + +func (m *mockNetworkPartialMessages) sendRPC(from, to peer.ID, rpc *pubsub_pb.PartialMessagesExtension, _ bool) { + if to == "" { + panic("empty peer ID") + } + // fmt.Printf("Sending RPC from %s to %s: %+v\n", from, to, rpc) + m.pendingMsgs[to] = append(m.pendingMsgs[to], rpcWithFrom{from, rpc}) + m.allSentMsgs[to] = append(m.allSentMsgs[to], rpcWithFrom{from, rpc}) +} + +const testPartialMessageLeaves = 8 + +// testPartialMessage represents a partial message where parts can be verified +// by a merkle tree commitment. By convention, there are +// `testPartialMessageLeaves` parts. 
+type testPartialMessage struct { + Commitment []byte + Parts [testPartialMessageLeaves][]byte + Proofs [testPartialMessageLeaves][]merkle.ProofStep + + republish func(*testPartialMessage, []byte) + onErr func(error) +} + +func (pm *testPartialMessage) complete() bool { + for _, part := range pm.Parts { + if len(part) == 0 { + return false + } + } + return true +} + +// AvailableParts returns a bitmap of available parts +func (pm *testPartialMessage) PartsMetadata() PartsMetadata { + out := bitmap.NewBitmapWithOnesCount(testPartialMessageLeaves) + for i, part := range pm.Parts { + if len(part) == 0 { + out.Clear(i) + } + } + return PartsMetadata(out) +} + +func (pm *testPartialMessage) extendFromEncodedPartialMessage(_ peer.ID, data []byte) (extended bool) { + if len(data) == 0 { + return + } + var decoded testPartialMessage + if err := json.Unmarshal(data, &decoded); err != nil { + pm.onErr(err) + return + } + + // Verify + if !bytes.Equal(pm.Commitment, decoded.Commitment) { + pm.onErr(errors.New("commitment mismatch")) + return + } + + for i, part := range decoded.Parts { + if len(pm.Parts[i]) > 0 { + continue + } + if len(part) == 0 { + continue + } + proof := decoded.Proofs[i] + if len(proof) == 0 { + continue + } + if !merkle.VerifyProof(part, pm.Commitment, proof) { + pm.onErr(errors.New("proof verification failed")) + return + } + + pm.Parts[i] = part + pm.Proofs[i] = proof + extended = true + } + + nonEmptyParts := 0 + for i := range pm.Parts { + if len(pm.Parts[i]) > 0 { + nonEmptyParts++ + } + } + + return +} + +// GroupID implements PartialMessage. 
+func (pm *testPartialMessage) GroupID() []byte { + return pm.Commitment +} + +func (pm *testPartialMessage) shouldRequest(partsMetadata []byte) bool { + var myParts big.Int + myParts.SetBytes(pm.PartsMetadata()) + var zero big.Int + + var peerHas big.Int + peerHas.SetBytes(partsMetadata) + + var iWant big.Int + iWant.Xor(&myParts, &peerHas) + iWant.And(&iWant, &peerHas) + + return iWant.Cmp(&zero) != 0 +} + +// PartialMessageBytes implements PartialMessage. +func (pm *testPartialMessage) PartialMessageBytes(metadata PartsMetadata) ([]byte, error) { + peerHas := bitmap.Bitmap(metadata) + + var added bool + var tempMessage testPartialMessage + tempMessage.Commitment = pm.Commitment + for i := range pm.Parts { + if peerHas.Get(i) { + continue + } + + if len(pm.Parts[i]) == 0 { + // We can't fulfill this part + continue + } + + tempMessage.Parts[i] = pm.Parts[i] + tempMessage.Proofs[i] = pm.Proofs[i] + added = true + } + + if !added { + return nil, nil + } + + b, err := json.Marshal(tempMessage) + if err != nil { + return nil, err + } + + return b, nil +} + +type testPartialMessageChecker struct { + fullMessage *testPartialMessage +} + +func (t *testPartialMessageChecker) MergePartsMetadata(left, right PartsMetadata) PartsMetadata { + return MergeBitmap(left, right) +} + +// EmptyMessage implements InvariantChecker. +func (t *testPartialMessageChecker) EmptyMessage() *testPartialMessage { + return &testPartialMessage{Commitment: t.fullMessage.Commitment, republish: func(pm *testPartialMessage, _ []byte) {}} +} + +// Equal implements InvariantChecker. 
+func (t *testPartialMessageChecker) Equal(a, b *testPartialMessage) bool { + if !bytes.Equal(a.Commitment, b.Commitment) { + return false + } + if len(a.Parts) != len(b.Parts) { + return false + } + for i := range a.Parts { + if !bytes.Equal(a.Parts[i], b.Parts[i]) { + return false + } + aProof := a.Proofs[i] + bProof := b.Proofs[i] + for j := range aProof { + if !bytes.Equal(aProof[j].Hash, bProof[j].Hash) { + return false + } + if aProof[j].IsLeft != bProof[j].IsLeft { + return false + } + } + } + return true +} + +// ExtendFromBytes implements InvariantChecker. +func (t *testPartialMessageChecker) ExtendFromBytes(a *testPartialMessage, data []byte) (*testPartialMessage, error) { + var err error + a.onErr = func(e error) { + err = e + } + a.extendFromEncodedPartialMessage("", data) + a.onErr = nil + if err != nil { + return nil, err + } + return a, nil +} +func (t *testPartialMessageChecker) ShouldRequest(a *testPartialMessage, from peer.ID, partsMetadata []byte) bool { + return a.shouldRequest(partsMetadata) +} + +// FullMessage implements InvariantChecker. +func (t *testPartialMessageChecker) FullMessage() (*testPartialMessage, error) { + return t.fullMessage, nil +} + +// SplitIntoParts implements InvariantChecker. 
+func (t *testPartialMessageChecker) SplitIntoParts(in *testPartialMessage) ([]*testPartialMessage, error) { + parts := make([]*testPartialMessage, len(in.Parts)) + for i := range in.Parts { + p := &testPartialMessage{ + Commitment: in.Commitment, + republish: in.republish, + } + p.Parts[i] = in.Parts[i] + p.Proofs[i] = in.Proofs[i] + parts[i] = p + } + return parts, nil +} + +func TestExamplePartialMessageImpl(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + // Create a dummy extension for the test + dummyExt := &PartialMessageExtension{} + full, err := newFullTestMessage(rand, dummyExt, "test-topic") + if err != nil { + t.Fatal(err) + } + full.republish = func(pm *testPartialMessage, _ []byte) {} + + invariant := &testPartialMessageChecker{ + fullMessage: full, + } + TestPartialMessageInvariants(t, invariant) +} + +var _ InvariantChecker[*testPartialMessage] = (*testPartialMessageChecker)(nil) +var _ Message = (*testPartialMessage)(nil) + +type testPeers struct { + peers []peer.ID + handlers []*PartialMessageExtension + network *mockNetworkPartialMessages + // Track partial messages per peer per topic per group + partialMessages map[peer.ID]map[string]map[string]*testPartialMessage +} + +func createPeers(t *testing.T, topic string, n int) *testPeers { + nw := &mockNetworkPartialMessages{ + t: t, + pendingMsgs: make(map[peer.ID][]rpcWithFrom), + allSentMsgs: make(map[peer.ID][]rpcWithFrom), + handlers: make(map[peer.ID]*PartialMessageExtension), + } + + peers := make([]peer.ID, n) + handlers := make([]*PartialMessageExtension, n) + + // Create peer IDs + for i := range n { + peers[i] = peer.ID(fmt.Sprintf("%d", i+1)) + } + + // Create testPeers structure first + testPeers := &testPeers{ + peers: peers, + handlers: handlers, + network: nw, + partialMessages: make(map[peer.ID]map[string]map[string]*testPartialMessage), + } + + // Initialize partial message tracking for each peer + for _, peerID := range peers { + testPeers.partialMessages[peerID] = 
make(map[string]map[string]*testPartialMessage) + } + + // Create handlers for each peer + for i := range n { + currentPeer := peers[i] + + // Create router for this peer + router := &testRouter{ + sendRPC: func(p peer.ID, r *pubsub_pb.PartialMessagesExtension, urgent bool) { + nw.sendRPC(currentPeer, p, r, urgent) + }, + meshPeers: func(topic string) iter.Seq[peer.ID] { + return func(yield func(peer.ID) bool) { + // Yield all other peers + for j, otherPeer := range peers { + if j != i { + if !yield(otherPeer) { + return + } + } + } + } + }, + } + + var handler *PartialMessageExtension + // Create handler + handler = &PartialMessageExtension{ + Logger: slog.Default().With("id", i), + MergePartsMetadata: func(_ string, left, right PartsMetadata) PartsMetadata { + return MergeBitmap(left, right) + }, + OnIncomingRPC: func(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + // Handle incoming partial message data - use testPeers to track state + // Get or create the partial message for this topic/group + if testPeers.partialMessages[currentPeer][topic] == nil { + testPeers.partialMessages[currentPeer][topic] = make(map[string]*testPartialMessage) + } + + groupID := rpc.GroupID + groupKey := string(groupID) + pm := testPeers.partialMessages[currentPeer][topic][groupKey] + if pm == nil { + pm = &testPartialMessage{ + Commitment: groupID, + republish: func(pm *testPartialMessage, _ []byte) { + handlers[i].PublishPartial(topic, pm, PublishOptions{}) + }, + } + testPeers.partialMessages[currentPeer][topic][groupKey] = pm + } + + // Extend the partial message with the incoming data + recvdNewData := pm.extendFromEncodedPartialMessage(from, rpc.PartialMessage) + + if recvdNewData { + // Publish to all peers our new data. 
+ // We'll request and fulfill any + handler.PublishPartial(topic, pm, PublishOptions{}) + return nil + } + + var zeroInt big.Int + + peerHasUsefulData := pm.shouldRequest(rpc.PartsMetadata) + + var iHave big.Int + iHave.SetBytes(pm.PartsMetadata()) + + var peerHas big.Int + peerHas.SetBytes(rpc.PartsMetadata) + + var peerWants big.Int + peerWants.Xor(&iHave, &peerHas) + peerWants.And(&peerWants, &iHave) + + weHaveUsefulData := peerWants.Cmp(&zeroInt) != 0 + + if weHaveUsefulData || peerHasUsefulData { + // This peer has something we want or we can provide + // something to them. Call publish partial just for them. + handler.PublishPartial(topic, pm, PublishOptions{ + PublishToPeers: []peer.ID{from}, + }) + } + return nil + }, + ValidateRPC: func(_ peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + if len(rpc.PartsMetadata) > 1024 { + return errors.New("metadata too large") + } + return nil + }, + GroupTTLByHeatbeat: 5, + } + handler.Init(router) + + handlers[i] = handler + nw.handlers[currentPeer] = handler + } + + t.Cleanup(func() { + testPeers.cleanup(t) + }) + + return testPeers +} + +// Helper method to get or create a partial message for a peer +func (tp *testPeers) getOrCreatePartialMessage(peerIndex int, topic string, groupID []byte) *testPartialMessage { + peerID := tp.peers[peerIndex] + if tp.partialMessages[peerID][topic] == nil { + tp.partialMessages[peerID][topic] = make(map[string]*testPartialMessage) + } + + groupKey := string(groupID) + pm := tp.partialMessages[peerID][topic][groupKey] + if pm == nil { + handler := tp.handlers[peerIndex] + pm = &testPartialMessage{ + Commitment: groupID, + republish: func(pm *testPartialMessage, _ []byte) { + handler.PublishPartial(topic, pm, PublishOptions{}) + }, + } + tp.partialMessages[peerID][topic][groupKey] = pm + } + return pm +} + +func (tp *testPeers) cleanup(t *testing.T) { + // Assert no more state is left + for range 10 { + for _, h := range tp.handlers { + h.Heartbeat() + } + } + for _, h := 
range tp.handlers { + if len(h.statePerTopicPerGroup) != 0 { + t.Fatal("handlers should have cleaned up all their state") + } + } + + // Assert no empty RPCs + for _, msgs := range tp.network.allSentMsgs { + for _, msg := range msgs { + if msg.rpc.Size() == 0 { + t.Fatal("empty message") + } + } + } +} + +// Helper function to register a message with the test framework +func (tp *testPeers) registerMessage(peerIndex int, topic string, msg *testPartialMessage) { + peerID := tp.peers[peerIndex] + if tp.partialMessages[peerID][topic] == nil { + tp.partialMessages[peerID][topic] = make(map[string]*testPartialMessage) + } + tp.partialMessages[peerID][topic][string(msg.GroupID())] = msg +} + +func newFullTestMessage(r io.Reader, ext *PartialMessageExtension, topic string) (*testPartialMessage, error) { + out := &testPartialMessage{} + for i := range out.Parts { + out.Parts[i] = make([]byte, 8) + if _, err := io.ReadFull(r, out.Parts[i]); err != nil { + return nil, err + } + } + out.Commitment = merkle.MerkleRoot(out.Parts[:]) + for i := range out.Parts { + out.Proofs[i] = merkle.MerkleProof(out.Parts[:], i) + } + out.republish = func(pm *testPartialMessage, _ []byte) { + ext.PublishPartial(topic, pm, PublishOptions{}) + } + return out, nil +} + +func newEmptyTestMessage(commitment []byte, ext *PartialMessageExtension, topic string) *testPartialMessage { + return &testPartialMessage{ + Commitment: commitment, + republish: func(pm *testPartialMessage, _ []byte) { + ext.PublishPartial(topic, pm, PublishOptions{}) + }, + } +} + +func TestPartialMessages(t *testing.T) { + topic := "test-topic" + rand := rand.New(rand.NewSource(0)) + // For debugging: + // slog.SetLogLoggerLevel(slog.LevelDebug) + + t.Run("h1 has all the data. 
h2 requests it", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + h1Msg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + peers.registerMessage(0, topic, h1Msg) + + // h1 knows the full message + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + peers.registerMessage(0, topic, h1Msg) + + // h2 only knows the group id + h2Msg := newEmptyTestMessage(h1Msg.Commitment, peers.handlers[1], topic) + peers.registerMessage(1, topic, h2Msg) + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert h2 has the full message + if !h2Msg.complete() { + t.Fatal("h2 should have the full message", h2Msg.PartsMetadata()) + } + }) + + t.Run("h1 has all the data and eager pushes. h2 has the next message also eager pushes", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + h1Msg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + + msgBytes, err := h1Msg.PartialMessageBytes(nil) + if err != nil { + t.Fatal(err) + } + + // h1 knows the full message and eager pushes + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{ + EagerPush: msgBytes, + }) + + // h2 will receive partial message data through OnIncomingRPC + // We can access it through our tracking system + lastPartialMessageh2 := peers.getOrCreatePartialMessage(1, topic, h1Msg.Commitment) + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert h2 has the full message + if !lastPartialMessageh2.complete() { + t.Fatal("h2 should have the full message") + } + + h2Msg, err := newFullTestMessage(rand, peers.handlers[1], topic) + if err != nil { + t.Fatal(err) + } + msgBytes, err = h2Msg.PartialMessageBytes(nil) + if err != nil { + t.Fatal(err) + } + + // h2 knows the 
full message and eager pushes + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{ + EagerPush: msgBytes, + }) + + // h1 will receive partial message data through OnIncomingRPC + // We can access it through our tracking system + lastPartialMessageh1 := peers.getOrCreatePartialMessage(0, topic, h2Msg.Commitment) + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert h1 has the full message + if !lastPartialMessageh1.complete() { + t.Fatal("h1 should have the full message") + } + }) + + t.Run("h1 has all the data. h2 doesn't know anything", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + h1Msg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + peers.registerMessage(0, topic, h1Msg) + + // h1 knows the full message + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // h2 should now have the partial message after receiving data + h2Msg := peers.getOrCreatePartialMessage(1, topic, h1Msg.Commitment) + + // Assert h2 has the full message + if !h2Msg.complete() { + t.Fatal("h2 should have the full message") + } + }) + + t.Run("h1 has all the data. 
h2 has some of it", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + h1Msg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + + // h1 knows the full message + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + peers.registerMessage(0, topic, h1Msg) + + // h2 only knows part of it + h2Msg := newEmptyTestMessage(h1Msg.Commitment, peers.handlers[1], topic) + for i := range h2Msg.Parts { + if i%2 == 0 { + h2Msg.Parts[i] = h1Msg.Parts[i] + h2Msg.Proofs[i] = h1Msg.Proofs[i] + } + } + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + peers.registerMessage(1, topic, h2Msg) + + emptyMsg := &testPartialMessage{} + emptyMetadata := emptyMsg.PartsMetadata() + if bytes.Equal(peers.network.pendingMsgs[peers.peers[0]][0].rpc.PartsMetadata, emptyMetadata) { + t.Fatal("h2 request should not be the same as an empty message") + } + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert that h2 only sent two parts-metadata updates + count := 0 + for _, rpc := range peers.network.allSentMsgs[peers.peers[0]] { + if rpc.rpc.PartsMetadata != nil { + count++ + } + } + if count != 2 { + t.Fatal("h2 should only have sent two parts updates") + } + + // Assert h2 has the full message + if !h2Msg.complete() { + t.Fatal("h2 should have the full message") + } + }) + + t.Run("h1 has half the data.
h2 has the other half of it", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + fullMsg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + h1Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[0], topic) + for i := range fullMsg.Parts { + if i%2 == 0 { + h1Msg.Parts[i] = fullMsg.Parts[i] + h1Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // h2 only knows part of it + h2Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[1], topic) + for i := range h2Msg.Parts { + if i%2 == 1 { + h2Msg.Parts[i] = fullMsg.Parts[i] + h2Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // h1 knows half + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + peers.registerMessage(0, topic, h1Msg) + // h2 knows the other half + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + peers.registerMessage(1, topic, h2Msg) + + emptyMsg := &testPartialMessage{} + emptyMetadata := emptyMsg.PartsMetadata() + if bytes.Equal(peers.network.pendingMsgs[peers.peers[0]][0].rpc.PartsMetadata, emptyMetadata) { + t.Fatal("h2 metadata should not be the same as an empty message's metadata") + } + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert that h2 only sent two parts-metadata updates + count := 0 + for _, rpc := range peers.network.allSentMsgs[peers.peers[0]] { + if len(rpc.rpc.PartsMetadata) > 0 { + count++ + } + } + if count != 2 { + t.Fatal("h2 should only have sent two parts updates") + } + + // Assert h2 has the full message + if !h2Msg.complete() { + t.Fatal("h2 should have the full message") + } + }) + + t.Run("h1 and h2 have the same half of data.
No partial messages should be sent", func(t *testing.T) { + peers := createPeers(t, topic, 2) + peers.network.addPeers() + defer peers.network.removePeers() + + fullMsg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + h1Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[0], topic) + for i := range fullMsg.Parts { + if i%2 == 0 { + h1Msg.Parts[i] = fullMsg.Parts[i] + h1Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // h2 only knows part of it (same as h1) + h2Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[1], topic) + for i := range h2Msg.Parts { + if i%2 == 0 { + h2Msg.Parts[i] = fullMsg.Parts[i] + h2Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // h1 knows half + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + // h2 knows the same half + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + + // Handle all RPCs + for peers.network.handleRPCs() { + } + + // Assert that no peer sent a partial message + count := 0 + for _, rpcs := range peers.network.allSentMsgs { + for _, rpc := range rpcs { + if len(rpc.rpc.PartialMessage) > 0 { + count++ + } + } + } + if count > 0 { + t.Fatal("No partial messages should have been sent") + } + }) + + t.Run("three peers with distributed partial data", func(t *testing.T) { + peers := createPeers(t, topic, 3) + peers.network.addPeers() + defer peers.network.removePeers() + + fullMsg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + + // Peer 1 has parts 0, 3, 6 + h1Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[0], topic) + for i := range fullMsg.Parts { + if i%3 == 0 { + h1Msg.Parts[i] = fullMsg.Parts[i] + h1Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // Peer 2 has parts 1, 4, 7 + h2Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[1], topic) + for i := range fullMsg.Parts { + if i%3 == 1 { + h2Msg.Parts[i] = fullMsg.Parts[i] + h2Msg.Proofs[i] = 
fullMsg.Proofs[i] + } + } + + // Peer 3 has parts 2, 5 + h3Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[2], topic) + for i := range fullMsg.Parts { + if i%3 == 2 { + h3Msg.Parts[i] = fullMsg.Parts[i] + h3Msg.Proofs[i] = fullMsg.Proofs[i] + } + } + + // All peers publish their partial messages + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{}) + peers.registerMessage(0, topic, h1Msg) + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + peers.registerMessage(1, topic, h2Msg) + peers.handlers[2].PublishPartial(topic, h3Msg, PublishOptions{}) + peers.registerMessage(2, topic, h3Msg) + + // Handle all RPCs until convergence + for peers.network.handleRPCs() { + } + + // Assert all peers have the full message + for i := range peers.handlers { + var msg *testPartialMessage + switch i { + case 0: + msg = h1Msg + case 1: + msg = h2Msg + case 2: + msg = h3Msg + } + + if !msg.complete() { + t.Fatalf("peer %d should have the full message", i+1) + } + } + }) + t.Run("three peers. peer 1 has all data and eager pushes. 
Receivers eager push new content", func(t *testing.T) { + peers := createPeers(t, topic, 3) + peers.network.addPeers() + defer peers.network.removePeers() + + fullMsg, err := newFullTestMessage(rand, peers.handlers[0], topic) + if err != nil { + t.Fatal(err) + } + + // Peer 1 has all the data + h1Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[0], topic) + for i := range fullMsg.Parts { + h1Msg.Parts[i] = fullMsg.Parts[i] + h1Msg.Proofs[i] = fullMsg.Proofs[i] + } + h1MsgEncoded, err := h1Msg.PartialMessageBytes(nil) + if err != nil { + t.Fatal(err) + } + + // Peer 2 has no parts + h2Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[1], topic) + + // Eagerly push new data to peers + h2Msg.republish = func(pm *testPartialMessage, newData []byte) { + peers.handlers[1].PublishPartial(topic, pm, PublishOptions{EagerPush: newData}) + } + // Peer 3 has no parts + h3Msg := newEmptyTestMessage(fullMsg.Commitment, peers.handlers[2], topic) + // Eagerly push new data to peers + h3Msg.republish = func(pm *testPartialMessage, newData []byte) { + peers.handlers[2].PublishPartial(topic, pm, PublishOptions{EagerPush: newData}) + } + // All peers publish their partial messages + peers.handlers[0].PublishPartial(topic, h1Msg, PublishOptions{EagerPush: h1MsgEncoded}) + peers.registerMessage(0, topic, h1Msg) + peers.handlers[1].PublishPartial(topic, h2Msg, PublishOptions{}) + peers.registerMessage(1, topic, h2Msg) + peers.handlers[2].PublishPartial(topic, h3Msg, PublishOptions{}) + peers.registerMessage(2, topic, h3Msg) + + // Handle all RPCs until convergence + for peers.network.handleRPCs() { + } + + // Assert all peers have the full message + for i := range peers.handlers { + var msg *testPartialMessage + switch i { + case 0: + msg = h1Msg + case 1: + msg = h2Msg + case 2: + msg = h3Msg + } + + if !msg.complete() { + t.Fatalf("peer %d should have the full message", i+1) + } + } + }) +} + +func TestPeerInitiatedCounter(t *testing.T) { + // 
slog.SetLogLoggerLevel(slog.LevelDebug) + topic := "test-topic" + randParts := func() []byte { + buf := make([]byte, 8) + _, _ = cryptorand.Read(buf) + return buf + } + handler := PartialMessageExtension{ + Logger: slog.Default(), + MergePartsMetadata: func(topic string, left, right PartsMetadata) PartsMetadata { + return left + }, + OnIncomingRPC: func(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + // Ignore for this test + return nil + }, + ValidateRPC: func(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + return nil + }, + PeerInitiatedGroupLimitPerTopic: 4, + PeerInitiatedGroupLimitPerTopicPerPeer: 2, + } + router := testRouter{ + sendRPC: func(p peer.ID, r *pubsub_pb.PartialMessagesExtension, urgent bool) {}, + meshPeers: func(topic string) iter.Seq[peer.ID] { return func(yield func(peer.ID) bool) {} }, + } + handler.Init(&router) + + err := handler.HandleRPC("1", &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte("1"), + PartsMetadata: randParts(), + }) + if err != nil { + t.Fatal(err) + } + + assertCounts := func(expectedTotal int, expectedMap map[peer.ID]int) { + t.Helper() + if handler.peerInitiatedGroupCounter[topic].total != expectedTotal { + t.Fatal() + } + if !reflect.DeepEqual(handler.peerInitiatedGroupCounter[topic].perPeer, expectedMap) { + t.Fatal() + } + } + + assertCounts(1, map[peer.ID]int{"1": 1}) + + err = handler.HandleRPC("1", &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte("2"), + PartsMetadata: randParts(), + }) + if err != nil { + t.Fatal(err) + } + + assertCounts(2, map[peer.ID]int{"1": 2}) + + err = handler.HandleRPC("1", &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte("3"), + PartsMetadata: randParts(), + }) + if err != errPeerInitiatedGroupLimitReached { + t.Fatal(err) + } + + assertCounts(2, map[peer.ID]int{"1": 2}) + + // Two peers publish new groups + for id := range 2 { + otherPeer := fmt.Sprintf("peer%d", id) + err = 
handler.HandleRPC(peer.ID(otherPeer), &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte(otherPeer), + PartsMetadata: randParts(), + }) + if err != nil { + t.Fatal(err) + } + } + + // The third one goes over the limit + otherPeer := fmt.Sprintf("peer%d", 3) + err = handler.HandleRPC(peer.ID(otherPeer), &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte(otherPeer), + PartsMetadata: randParts(), + }) + if err != errPeerInitiatedGroupTotalLimitReached { + t.Fatal(err) + } + + // All peers go away, and the counts should be back to 0 + handler.RemovePeer("1") + for id := range 2 { + otherPeer := fmt.Sprintf("peer%d", id) + handler.RemovePeer(peer.ID(otherPeer)) + } + + assertCounts(0, map[peer.ID]int{}) + + // Test heartbeat cleanup + err = handler.HandleRPC("1", &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: []byte("4"), + PartsMetadata: randParts(), + }) + if err != nil { + t.Fatal(err) + } + assertCounts(1, map[peer.ID]int{"1": 1}) + + for range minGroupTTL + 1 { + handler.Heartbeat() + } + assertCounts(0, map[peer.ID]int{}) +} + +func FuzzPeerInitiatedCounter(f *testing.F) { + topic := "test-topic" + f.Fuzz(func(t *testing.T, script []byte, totalLimit uint8, peerLimit uint8) { + // This fuzzer works like a simple interpreter. It interprets the script as bytecode.
+ if len(script) == 0 { + return + } + if totalLimit == 0 { + totalLimit = uint8(defaultPeerInitiatedGroupLimitPerTopic) + } + if peerLimit == 0 { + peerLimit = uint8(defaultPeerInitiatedGroupLimitPerTopicPerPeer) + } + + handler := PartialMessageExtension{ + Logger: slog.Default(), + MergePartsMetadata: func(topic string, left, right PartsMetadata) PartsMetadata { + return left + }, + OnIncomingRPC: func(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + // Ignore for this test + return nil + }, + ValidateRPC: func(from peer.ID, rpc *pubsub_pb.PartialMessagesExtension) error { + return nil + }, + GroupTTLByHeatbeat: minGroupTTL, + PeerInitiatedGroupLimitPerTopic: int(totalLimit), + PeerInitiatedGroupLimitPerTopicPerPeer: int(peerLimit), + } + router := testRouter{ + sendRPC: func(p peer.ID, r *pubsub_pb.PartialMessagesExtension, urgent bool) {}, + meshPeers: func(topic string) iter.Seq[peer.ID] { return func(yield func(peer.ID) bool) {} }, + } + handler.Init(&router) + partsMetadata := []byte{0, 0, 0, 0} + + expectedTotal := 0 + expectedPeercounts := map[peer.ID]int{} + + for i := 0; len(script) > 0; i++ { + switch script[0] % 3 { + case 0: // handle new rpc + script = script[1:] + if len(script) < 1 { + return + } + otherPeer := peer.ID(fmt.Sprintf("%d", script[0])) + script = script[1:] + + var expectNoError bool + if expectedTotal < int(totalLimit) && expectedPeercounts[peer.ID(otherPeer)] < int(peerLimit) { + expectedPeercounts[peer.ID(otherPeer)]++ + expectedTotal++ + expectNoError = true + } + + err := handler.HandleRPC(peer.ID(otherPeer), &pubsub_pb.PartialMessagesExtension{ + TopicID: &topic, + GroupID: fmt.Appendf(nil, "%d", i), + PartsMetadata: partsMetadata, + }) + if expectNoError && err != nil { + t.Fatalf("unexpected error: %v", err) + } + case 1: // remove peer + script = script[1:] + if len(script) < 1 { + return + } + otherPeer := peer.ID(fmt.Sprintf("%d", script[0])) + script = script[1:] + + if 
expectedPeercounts[peer.ID(otherPeer)] > 0 { + expectedTotal -= expectedPeercounts[peer.ID(otherPeer)] + delete(expectedPeercounts, peer.ID(otherPeer)) + } + + handler.RemovePeer(peer.ID(otherPeer)) + case 2: // heartbeat until everything is cleared + script = script[1:] + for range handler.GroupTTLByHeatbeat + 1 { + handler.Heartbeat() + } + expectedTotal = 0 + expectedPeercounts = map[peer.ID]int{} + default: + // no-op + script = script[1:] + } + + ctr, ok := handler.peerInitiatedGroupCounter[topic] + if ok { + + if expectedTotal != ctr.total { + t.Fatalf("expected total %d, got %d", expectedTotal, ctr.total) + } + if !reflect.DeepEqual(expectedPeercounts, ctr.perPeer) { + t.Fatalf("expected peer counts %v, got %v", expectedPeercounts, ctr.perPeer) + } + } else { + if expectedTotal != 0 { + t.Fatalf("expected total %d, got %d", expectedTotal, 0) + } + if !reflect.DeepEqual(expectedPeercounts, map[peer.ID]int{}) { + t.Fatalf("expected peer counts %v, got %v", expectedPeercounts, map[peer.ID]int{}) + } + } + } + }) +} diff --git a/pb/rpc.pb.go b/pb/rpc.pb.go index cd05d8eb..110b1e9e 100644 --- a/pb/rpc.pb.go +++ b/pb/rpc.pb.go @@ -23,9 +23,10 @@ var _ = math.Inf const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package type RPC struct { - Subscriptions []*RPC_SubOpts `protobuf:"bytes,1,rep,name=subscriptions" json:"subscriptions,omitempty"` - Publish []*Message `protobuf:"bytes,2,rep,name=publish" json:"publish,omitempty"` - Control *ControlMessage `protobuf:"bytes,3,opt,name=control" json:"control,omitempty"` + Subscriptions []*RPC_SubOpts `protobuf:"bytes,1,rep,name=subscriptions" json:"subscriptions,omitempty"` + Publish []*Message `protobuf:"bytes,2,rep,name=publish" json:"publish,omitempty"` + Control *ControlMessage `protobuf:"bytes,3,opt,name=control" json:"control,omitempty"` + Partial *PartialMessagesExtension `protobuf:"bytes,10,opt,name=partial" json:"partial,omitempty"` // Experimental Extensions should register their messages 
here. They // must use field numbers larger than 0x200000 to be encoded with at least 4 // bytes @@ -89,6 +90,13 @@ func (m *RPC) GetControl() *ControlMessage { return nil } +func (m *RPC) GetPartial() *PartialMessagesExtension { + if m != nil { + return m.Partial + } + return nil +} + func (m *RPC) GetTestExtension() *TestExtension { if m != nil { return m.TestExtension @@ -97,11 +105,18 @@ func (m *RPC) GetTestExtension() *TestExtension { } type RPC_SubOpts struct { - Subscribe *bool `protobuf:"varint,1,opt,name=subscribe" json:"subscribe,omitempty"` - Topicid *string `protobuf:"bytes,2,opt,name=topicid" json:"topicid,omitempty"` - XXX_NoUnkeyedLiteral struct{} `json:"-"` - XXX_unrecognized []byte `json:"-"` - XXX_sizecache int32 `json:"-"` + Subscribe *bool `protobuf:"varint,1,opt,name=subscribe" json:"subscribe,omitempty"` + Topicid *string `protobuf:"bytes,2,opt,name=topicid" json:"topicid,omitempty"` + // Used with Partial Messages extension. + // If set, signals to the receiver that the sender prefers partial messages. + RequestsPartial *bool `protobuf:"varint,3,opt,name=requestsPartial" json:"requestsPartial,omitempty"` + // If set, signals to the receiver that the sender supports sending partial + // messages on this topic. If requestsPartial is true, this is assumed to be + // true. 
+ SupportsSendingPartial *bool `protobuf:"varint,4,opt,name=supportsSendingPartial" json:"supportsSendingPartial,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } func (m *RPC_SubOpts) Reset() { *m = RPC_SubOpts{} } @@ -151,6 +166,20 @@ func (m *RPC_SubOpts) GetTopicid() string { return "" } +func (m *RPC_SubOpts) GetRequestsPartial() bool { + if m != nil && m.RequestsPartial != nil { + return *m.RequestsPartial + } + return false +} + +func (m *RPC_SubOpts) GetSupportsSendingPartial() bool { + if m != nil && m.SupportsSendingPartial != nil { + return *m.SupportsSendingPartial + } + return false +} + type Message struct { From []byte `protobuf:"bytes,1,opt,name=from" json:"from,omitempty"` Data []byte `protobuf:"bytes,2,opt,name=data" json:"data,omitempty"` @@ -327,7 +356,8 @@ func (m *ControlMessage) GetExtensions() *ControlExtensions { type ControlIHave struct { TopicID *string `protobuf:"bytes,1,opt,name=topicID" json:"topicID,omitempty"` - // implementors from other languages should use bytes here - go protobuf emits invalid utf8 strings + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings MessageIDs []string `protobuf:"bytes,2,rep,name=messageIDs" json:"messageIDs,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` @@ -382,7 +412,8 @@ func (m *ControlIHave) GetMessageIDs() []string { } type ControlIWant struct { - // implementors from other languages should use bytes here - go protobuf emits invalid utf8 strings + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings MessageIDs []string `protobuf:"bytes,1,rep,name=messageIDs" json:"messageIDs,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` @@ -540,7 +571,8 @@ func (m *ControlPrune) GetBackoff() uint64 { } type ControlIDontWant struct { - // implementors 
from other languages should use bytes here - go protobuf emits invalid utf8 strings + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings MessageIDs []string `protobuf:"bytes,1,rep,name=messageIDs" json:"messageIDs,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` @@ -588,6 +620,7 @@ func (m *ControlIDontWant) GetMessageIDs() []string { } type ControlExtensions struct { + PartialMessages *bool `protobuf:"varint,10,opt,name=partialMessages" json:"partialMessages,omitempty"` // Experimental extensions must use field numbers larger than 0x200000 to be // encoded with 4 bytes TestExtension *bool `protobuf:"varint,6492434,opt,name=testExtension" json:"testExtension,omitempty"` @@ -629,6 +662,13 @@ func (m *ControlExtensions) XXX_DiscardUnknown() { var xxx_messageInfo_ControlExtensions proto.InternalMessageInfo +func (m *ControlExtensions) GetPartialMessages() bool { + if m != nil && m.PartialMessages != nil { + return *m.PartialMessages + } + return false +} + func (m *ControlExtensions) GetTestExtension() bool { if m != nil && m.TestExtension != nil { return *m.TestExtension @@ -730,6 +770,79 @@ func (m *TestExtension) XXX_DiscardUnknown() { var xxx_messageInfo_TestExtension proto.InternalMessageInfo +type PartialMessagesExtension struct { + TopicID *string `protobuf:"bytes,1,opt,name=topicID" json:"topicID,omitempty"` + GroupID []byte `protobuf:"bytes,2,opt,name=groupID" json:"groupID,omitempty"` + // An encoded partial message + PartialMessage []byte `protobuf:"bytes,3,opt,name=partialMessage" json:"partialMessage,omitempty"` + // An encoded representation of the parts a peer has and wants. 
+ PartsMetadata []byte `protobuf:"bytes,4,opt,name=partsMetadata" json:"partsMetadata,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *PartialMessagesExtension) Reset() { *m = PartialMessagesExtension{} } +func (m *PartialMessagesExtension) String() string { return proto.CompactTextString(m) } +func (*PartialMessagesExtension) ProtoMessage() {} +func (*PartialMessagesExtension) Descriptor() ([]byte, []int) { + return fileDescriptor_77a6da22d6a3feb1, []int{11} +} +func (m *PartialMessagesExtension) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *PartialMessagesExtension) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_PartialMessagesExtension.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *PartialMessagesExtension) XXX_Merge(src proto.Message) { + xxx_messageInfo_PartialMessagesExtension.Merge(m, src) +} +func (m *PartialMessagesExtension) XXX_Size() int { + return m.Size() +} +func (m *PartialMessagesExtension) XXX_DiscardUnknown() { + xxx_messageInfo_PartialMessagesExtension.DiscardUnknown(m) +} + +var xxx_messageInfo_PartialMessagesExtension proto.InternalMessageInfo + +func (m *PartialMessagesExtension) GetTopicID() string { + if m != nil && m.TopicID != nil { + return *m.TopicID + } + return "" +} + +func (m *PartialMessagesExtension) GetGroupID() []byte { + if m != nil { + return m.GroupID + } + return nil +} + +func (m *PartialMessagesExtension) GetPartialMessage() []byte { + if m != nil { + return m.PartialMessage + } + return nil +} + +func (m *PartialMessagesExtension) GetPartsMetadata() []byte { + if m != nil { + return m.PartsMetadata + } + return nil +} + func init() { proto.RegisterType((*RPC)(nil), "pubsub.pb.RPC") proto.RegisterType((*RPC_SubOpts)(nil), 
"pubsub.pb.RPC.SubOpts") @@ -743,49 +856,58 @@ func init() { proto.RegisterType((*ControlExtensions)(nil), "pubsub.pb.ControlExtensions") proto.RegisterType((*PeerInfo)(nil), "pubsub.pb.PeerInfo") proto.RegisterType((*TestExtension)(nil), "pubsub.pb.TestExtension") + proto.RegisterType((*PartialMessagesExtension)(nil), "pubsub.pb.PartialMessagesExtension") } func init() { proto.RegisterFile("rpc.proto", fileDescriptor_77a6da22d6a3feb1) } var fileDescriptor_77a6da22d6a3feb1 = []byte{ - // 583 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x94, 0xcd, 0x8e, 0xd3, 0x3e, - 0x14, 0xc5, 0x95, 0x7e, 0x4c, 0x9b, 0xdb, 0xf4, 0xff, 0x2f, 0x06, 0x0d, 0x06, 0x46, 0x55, 0x95, - 0x0d, 0x05, 0x41, 0x16, 0x65, 0x85, 0xd4, 0xcd, 0xd0, 0x22, 0xa6, 0x0b, 0xa0, 0x32, 0x48, 0xac, - 0x93, 0xd4, 0xed, 0x44, 0x33, 0xb5, 0x83, 0xed, 0x0c, 0xf0, 0x0e, 0xb0, 0xe1, 0x11, 0x58, 0xf3, - 0x1a, 0x48, 0x2c, 0x79, 0x04, 0xd4, 0x27, 0x41, 0x76, 0x3e, 0x9a, 0x36, 0x53, 0xd8, 0xd9, 0xd7, - 0xbf, 0xe3, 0x1c, 0x9f, 0x6b, 0x07, 0x6c, 0x11, 0x87, 0x5e, 0x2c, 0xb8, 0xe2, 0xc8, 0x8e, 0x93, - 0x40, 0x26, 0x81, 0x17, 0x07, 0xee, 0xf7, 0x1a, 0xd4, 0xc9, 0x7c, 0x82, 0xc6, 0xd0, 0x95, 0x49, - 0x20, 0x43, 0x11, 0xc5, 0x2a, 0xe2, 0x4c, 0x62, 0x6b, 0x50, 0x1f, 0x76, 0x46, 0xc7, 0x5e, 0x81, - 0x7a, 0x64, 0x3e, 0xf1, 0xde, 0x24, 0xc1, 0xeb, 0x58, 0x49, 0xb2, 0x0b, 0xa3, 0x47, 0xd0, 0x8a, - 0x93, 0xe0, 0x32, 0x92, 0xe7, 0xb8, 0x66, 0x74, 0xa8, 0xa4, 0x7b, 0x49, 0xa5, 0xf4, 0x57, 0x94, - 0xe4, 0x08, 0x7a, 0x02, 0xad, 0x90, 0x33, 0x25, 0xf8, 0x25, 0xae, 0x0f, 0xac, 0x61, 0x67, 0x74, - 0xa7, 0x44, 0x4f, 0xd2, 0x95, 0x42, 0x94, 0x91, 0xe8, 0x14, 0xba, 0x8a, 0x4a, 0xf5, 0xfc, 0xa3, - 0xa2, 0x4c, 0x46, 0x9c, 0xe1, 0xaf, 0xdf, 0x3e, 0xa7, 0x6a, 0x5c, 0x52, 0xbf, 0x2d, 0x23, 0x64, - 0x57, 0x71, 0xf7, 0x14, 0x5a, 0x99, 0x7f, 0x74, 0x02, 0x76, 0x76, 0x82, 0x80, 0x62, 0x6b, 0x60, - 0x0d, 0xdb, 0x64, 0x5b, 0x40, 0x18, 0x5a, 0x8a, 0xc7, 0x51, 0x18, 0x2d, 0x70, 0x6d, 0x60, 
0x0d, - 0x6d, 0x92, 0x4f, 0xdd, 0x2f, 0x16, 0xb4, 0x32, 0x6b, 0x08, 0x41, 0x63, 0x29, 0xf8, 0xda, 0xc8, - 0x1d, 0x62, 0xc6, 0xba, 0xb6, 0xf0, 0x95, 0x6f, 0x64, 0x0e, 0x31, 0x63, 0x74, 0x0b, 0x9a, 0x92, - 0xbe, 0x67, 0xdc, 0x1c, 0xd6, 0x21, 0xe9, 0x44, 0x57, 0xcd, 0xa6, 0xb8, 0x61, 0xbe, 0x90, 0x4e, - 0x8c, 0xaf, 0x68, 0xc5, 0x7c, 0x95, 0x08, 0x8a, 0x9b, 0x86, 0xdf, 0x16, 0x50, 0x0f, 0xea, 0x17, - 0xf4, 0x13, 0x3e, 0x32, 0x75, 0x3d, 0x74, 0x7f, 0xd4, 0xe0, 0xbf, 0xdd, 0xc4, 0xd0, 0x63, 0x68, - 0x46, 0xe7, 0xfe, 0x15, 0xcd, 0x3a, 0x78, 0xbb, 0x9a, 0xed, 0xec, 0xcc, 0xbf, 0xa2, 0x24, 0xa5, - 0x0c, 0xfe, 0xc1, 0x67, 0x2a, 0x6b, 0xdc, 0x75, 0xf8, 0x3b, 0x9f, 0x29, 0x92, 0x52, 0x1a, 0x5f, - 0x09, 0x7f, 0xa9, 0x70, 0xfd, 0x10, 0xfe, 0x42, 0x2f, 0x93, 0x94, 0xd2, 0x78, 0x2c, 0x12, 0x46, - 0x71, 0xe3, 0x10, 0x3e, 0xd7, 0xcb, 0x24, 0xa5, 0xd0, 0x53, 0xb0, 0xa3, 0x05, 0x67, 0xca, 0x18, - 0x6a, 0x1a, 0xc9, 0xbd, 0x6b, 0x0c, 0x4d, 0x39, 0x53, 0xc6, 0xd4, 0x96, 0x46, 0x63, 0x00, 0x9a, - 0x77, 0x5a, 0x9a, 0x88, 0x3a, 0xa3, 0x93, 0xaa, 0xb6, 0xb8, 0x0d, 0x92, 0x94, 0x78, 0xf7, 0x0c, - 0x9c, 0x72, 0x38, 0xc5, 0x0d, 0x98, 0x4d, 0x4d, 0x7b, 0xf3, 0x1b, 0x30, 0x9b, 0xa2, 0x3e, 0xc0, - 0x3a, 0x4d, 0x7a, 0x36, 0x95, 0x26, 0x34, 0x9b, 0x94, 0x2a, 0xae, 0xb7, 0xdd, 0x49, 0x5b, 0xdc, - 0xe3, 0xad, 0x0a, 0x3f, 0x2c, 0x78, 0x13, 0xdc, 0xe1, 0x2f, 0xbb, 0xeb, 0x82, 0x34, 0x99, 0xfd, - 0xc5, 0xe3, 0x03, 0x68, 0xc6, 0x94, 0x0a, 0x99, 0xf5, 0xf4, 0x66, 0x29, 0x86, 0x39, 0xa5, 0x62, - 0xc6, 0x96, 0x9c, 0xa4, 0x84, 0xde, 0x24, 0xf0, 0xc3, 0x0b, 0xbe, 0x5c, 0x9a, 0xeb, 0xd9, 0x20, - 0xf9, 0xd4, 0x1d, 0x41, 0x6f, 0x3f, 0xef, 0x7f, 0x1e, 0x66, 0x0c, 0x37, 0x2a, 0x39, 0xa3, 0xfb, - 0x07, 0x5e, 0x6e, 0x7b, 0xef, 0x7d, 0xba, 0xaf, 0xa0, 0x9d, 0xdb, 0x43, 0xc7, 0x70, 0xa4, 0x0d, - 0x66, 0x67, 0x73, 0x48, 0x36, 0x43, 0x0f, 0xa1, 0xa7, 0xdf, 0x03, 0x5d, 0x68, 0x92, 0xd0, 0x90, - 0x8b, 0x45, 0xf6, 0xd8, 0x2a, 0x75, 0xf7, 0x7f, 0xe8, 0xee, 0xfc, 0x0f, 0x9e, 0x39, 0x3f, 0x37, - 0x7d, 0xeb, 0xd7, 0xa6, 0x6f, 
0xfd, 0xde, 0xf4, 0xad, 0x3f, 0x01, 0x00, 0x00, 0xff, 0xff, 0xa2, - 0x64, 0xfc, 0x1b, 0x11, 0x05, 0x00, 0x00, + // 706 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x84, 0x94, 0xc1, 0x6e, 0x13, 0x3b, + 0x14, 0x86, 0x35, 0x4d, 0xd2, 0x64, 0x4e, 0x93, 0xb6, 0xd7, 0xf7, 0xaa, 0xd7, 0xf7, 0x52, 0x45, + 0xd1, 0x80, 0x20, 0x20, 0xc8, 0x22, 0x48, 0x48, 0x48, 0x65, 0x01, 0x0d, 0xa2, 0x59, 0x14, 0x22, + 0x17, 0x89, 0xf5, 0x4c, 0xe2, 0xa4, 0xa3, 0xb6, 0xb6, 0x6b, 0x7b, 0x0a, 0xbc, 0x03, 0x6c, 0x58, + 0xb3, 0x82, 0x67, 0x41, 0x62, 0x85, 0x78, 0x04, 0xd4, 0x27, 0x41, 0xf6, 0x78, 0x26, 0x33, 0x49, + 0x53, 0x76, 0xe3, 0xe3, 0xef, 0x1f, 0x9f, 0xf3, 0xfb, 0xf8, 0x80, 0x2f, 0xc5, 0xb8, 0x27, 0x24, + 0xd7, 0x1c, 0xf9, 0x22, 0x89, 0x54, 0x12, 0xf5, 0x44, 0x14, 0xfc, 0xa8, 0x40, 0x85, 0x8c, 0xf6, + 0xd1, 0x1e, 0xb4, 0x54, 0x12, 0xa9, 0xb1, 0x8c, 0x85, 0x8e, 0x39, 0x53, 0xd8, 0xeb, 0x54, 0xba, + 0x1b, 0xfd, 0x9d, 0x5e, 0x8e, 0xf6, 0xc8, 0x68, 0xbf, 0x77, 0x94, 0x44, 0xaf, 0x84, 0x56, 0xa4, + 0x0c, 0xa3, 0xfb, 0x50, 0x17, 0x49, 0x74, 0x1a, 0xab, 0x63, 0xbc, 0x66, 0x75, 0xa8, 0xa0, 0x3b, + 0xa4, 0x4a, 0x85, 0x33, 0x4a, 0x32, 0x04, 0x3d, 0x84, 0xfa, 0x98, 0x33, 0x2d, 0xf9, 0x29, 0xae, + 0x74, 0xbc, 0xee, 0x46, 0xff, 0xbf, 0x02, 0xbd, 0x9f, 0xee, 0xe4, 0x22, 0x47, 0xa2, 0x27, 0x50, + 0x17, 0xa1, 0xd4, 0x71, 0x78, 0x8a, 0xc1, 0x8a, 0x6e, 0x16, 0x44, 0xa3, 0x74, 0xc7, 0x89, 0xd4, + 0xf3, 0x77, 0x9a, 0x32, 0x15, 0x73, 0x46, 0x32, 0x0d, 0x7a, 0x0a, 0x2d, 0x4d, 0x95, 0xce, 0x77, + 0xf0, 0xa7, 0x2f, 0x1f, 0xd2, 0xc3, 0x71, 0xe1, 0x3f, 0xaf, 0x8b, 0x08, 0x29, 0x2b, 0xfe, 0xff, + 0xea, 0x41, 0xdd, 0xd5, 0x8f, 0x76, 0xc1, 0x77, 0x0e, 0x44, 0x14, 0x7b, 0x1d, 0xaf, 0xdb, 0x20, + 0xf3, 0x00, 0xc2, 0x50, 0xd7, 0x5c, 0xc4, 0xe3, 0x78, 0x82, 0xd7, 0x3a, 0x5e, 0xd7, 0x27, 0xd9, + 0x12, 0x75, 0x61, 0x4b, 0xd2, 0xf3, 0x84, 0x2a, 0xad, 0x5c, 0xce, 0xd6, 0x82, 0x06, 0x59, 0x0c, + 0xa3, 0x47, 0xb0, 0xa3, 0x12, 0x21, 0xb8, 0xd4, 0xea, 0x88, 0xb2, 0x49, 
0xcc, 0x66, 0x99, 0xa0, + 0x6a, 0x05, 0x2b, 0x76, 0x83, 0x8f, 0x1e, 0xd4, 0x9d, 0x0f, 0x08, 0x41, 0x75, 0x2a, 0xf9, 0x99, + 0x4d, 0xb0, 0x49, 0xec, 0xb7, 0x89, 0x4d, 0x42, 0x1d, 0xda, 0xc4, 0x9a, 0xc4, 0x7e, 0xa3, 0x7f, + 0xa0, 0xa6, 0xe8, 0x39, 0xe3, 0x36, 0x97, 0x26, 0x49, 0x17, 0x26, 0x6a, 0xd3, 0xb6, 0x07, 0xfa, + 0x24, 0x5d, 0xd8, 0xca, 0xe3, 0x19, 0x0b, 0x75, 0x22, 0x29, 0xae, 0x59, 0x7e, 0x1e, 0x40, 0xdb, + 0x50, 0x39, 0xa1, 0xef, 0xf1, 0xba, 0x8d, 0x9b, 0xcf, 0xe0, 0xdb, 0x1a, 0x6c, 0x96, 0xef, 0x14, + 0x3d, 0x80, 0x5a, 0x7c, 0x1c, 0x5e, 0x50, 0xd7, 0x63, 0xff, 0x2e, 0xdf, 0xfe, 0xf0, 0x20, 0xbc, + 0xa0, 0x24, 0xa5, 0x2c, 0xfe, 0x36, 0x64, 0xda, 0xb5, 0xd6, 0x55, 0xf8, 0x9b, 0x90, 0x69, 0x92, + 0x52, 0x06, 0x9f, 0xc9, 0x70, 0xaa, 0x71, 0x65, 0x15, 0xfe, 0xc2, 0x6c, 0x93, 0x94, 0x32, 0xb8, + 0x90, 0x09, 0xa3, 0xb8, 0xba, 0x0a, 0x1f, 0x99, 0x6d, 0x92, 0x52, 0xe8, 0x31, 0xf8, 0xf1, 0x84, + 0x33, 0x6d, 0x13, 0xaa, 0x59, 0xc9, 0x8d, 0x2b, 0x12, 0x1a, 0x70, 0xa6, 0x6d, 0x52, 0x73, 0x1a, + 0xed, 0x01, 0xd0, 0xac, 0x99, 0x94, 0xb5, 0x68, 0xa3, 0xbf, 0xbb, 0xac, 0xcd, 0x1b, 0x4e, 0x91, + 0x02, 0x1f, 0x1c, 0x40, 0xb3, 0x68, 0x4e, 0xde, 0x63, 0xc3, 0x81, 0xbd, 0xde, 0xac, 0xc7, 0x86, + 0x03, 0xd4, 0x06, 0x38, 0x4b, 0x9d, 0x1e, 0x0e, 0x94, 0x35, 0xcd, 0x27, 0x85, 0x48, 0xd0, 0x9b, + 0xff, 0xc9, 0xa4, 0xb8, 0xc0, 0x7b, 0x4b, 0x7c, 0x37, 0xe7, 0xad, 0x71, 0xab, 0x4f, 0x0e, 0xce, + 0x72, 0xd2, 0x7a, 0x76, 0x4d, 0x8e, 0x77, 0xa1, 0x26, 0x28, 0x95, 0xca, 0xdd, 0xe9, 0xdf, 0xc5, + 0xb7, 0x4c, 0xa9, 0x1c, 0xb2, 0x29, 0x27, 0x29, 0x61, 0x7e, 0x12, 0x85, 0xe3, 0x13, 0x3e, 0x9d, + 0xda, 0xf6, 0xac, 0x92, 0x6c, 0x19, 0xf4, 0x61, 0x7b, 0xd1, 0xef, 0x3f, 0x16, 0x33, 0x85, 0xbf, + 0x96, 0x7c, 0x36, 0xaf, 0x52, 0x94, 0x27, 0x88, 0x9d, 0x31, 0x0d, 0xb2, 0x18, 0x46, 0x77, 0x56, + 0x8c, 0x91, 0xc6, 0xc2, 0xb0, 0x08, 0x5e, 0x42, 0x23, 0x2b, 0x04, 0xed, 0xc0, 0xba, 0x29, 0xc5, + 0xb9, 0xd0, 0x24, 0x6e, 0x85, 0xee, 0xc1, 0xb6, 0x79, 0x39, 0x74, 0x62, 0x48, 0x42, 0xc7, 0x5c, + 0x4e, 0xdc, 
0xb3, 0x5c, 0x8a, 0x07, 0x5b, 0xd0, 0x2a, 0x0d, 0xa7, 0xe0, 0xb3, 0x07, 0x78, 0xd5, + 0xd8, 0xbb, 0xc6, 0x78, 0x0c, 0xf5, 0x99, 0xe4, 0x89, 0x18, 0x0e, 0xdc, 0x51, 0xd9, 0x12, 0xdd, + 0x86, 0xcd, 0x72, 0xb5, 0x6e, 0x1a, 0x2c, 0x44, 0xd1, 0x2d, 0x68, 0x99, 0x88, 0x3a, 0xa4, 0x3a, + 0xb4, 0x93, 0xa4, 0x6a, 0xb1, 0x72, 0xf0, 0x59, 0xf3, 0xfb, 0x65, 0xdb, 0xfb, 0x79, 0xd9, 0xf6, + 0x7e, 0x5d, 0xb6, 0xbd, 0xdf, 0x01, 0x00, 0x00, 0xff, 0xff, 0xa0, 0x3f, 0x9d, 0x94, 0x7c, 0x06, + 0x00, 0x00, } func (m *RPC) Marshal() (dAtA []byte, err error) { @@ -830,6 +952,18 @@ func (m *RPC) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x92 } + if m.Partial != nil { + { + size, err := m.Partial.MarshalToSizedBuffer(dAtA[:i]) + if err != nil { + return 0, err + } + i -= size + i = encodeVarintRpc(dAtA, i, uint64(size)) + } + i-- + dAtA[i] = 0x52 + } if m.Control != nil { { size, err := m.Control.MarshalToSizedBuffer(dAtA[:i]) @@ -897,6 +1031,26 @@ func (m *RPC_SubOpts) MarshalToSizedBuffer(dAtA []byte) (int, error) { i -= len(m.XXX_unrecognized) copy(dAtA[i:], m.XXX_unrecognized) } + if m.SupportsSendingPartial != nil { + i-- + if *m.SupportsSendingPartial { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x20 + } + if m.RequestsPartial != nil { + i-- + if *m.RequestsPartial { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x18 + } if m.Topicid != nil { i -= len(*m.Topicid) copy(dAtA[i:], *m.Topicid) @@ -1337,6 +1491,16 @@ func (m *ControlExtensions) MarshalToSizedBuffer(dAtA []byte) (int, error) { i-- dAtA[i] = 0x90 } + if m.PartialMessages != nil { + i-- + if *m.PartialMessages { + dAtA[i] = 1 + } else { + dAtA[i] = 0 + } + i-- + dAtA[i] = 0x50 + } return len(dAtA) - i, nil } @@ -1408,6 +1572,61 @@ func (m *TestExtension) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *PartialMessagesExtension) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := 
m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *PartialMessagesExtension) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *PartialMessagesExtension) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.XXX_unrecognized != nil { + i -= len(m.XXX_unrecognized) + copy(dAtA[i:], m.XXX_unrecognized) + } + if m.PartsMetadata != nil { + i -= len(m.PartsMetadata) + copy(dAtA[i:], m.PartsMetadata) + i = encodeVarintRpc(dAtA, i, uint64(len(m.PartsMetadata))) + i-- + dAtA[i] = 0x22 + } + if m.PartialMessage != nil { + i -= len(m.PartialMessage) + copy(dAtA[i:], m.PartialMessage) + i = encodeVarintRpc(dAtA, i, uint64(len(m.PartialMessage))) + i-- + dAtA[i] = 0x1a + } + if m.GroupID != nil { + i -= len(m.GroupID) + copy(dAtA[i:], m.GroupID) + i = encodeVarintRpc(dAtA, i, uint64(len(m.GroupID))) + i-- + dAtA[i] = 0x12 + } + if m.TopicID != nil { + i -= len(*m.TopicID) + copy(dAtA[i:], *m.TopicID) + i = encodeVarintRpc(dAtA, i, uint64(len(*m.TopicID))) + i-- + dAtA[i] = 0xa + } + return len(dAtA) - i, nil +} + func encodeVarintRpc(dAtA []byte, offset int, v uint64) int { offset -= sovRpc(v) base := offset @@ -1441,6 +1660,10 @@ func (m *RPC) Size() (n int) { l = m.Control.Size() n += 1 + l + sovRpc(uint64(l)) } + if m.Partial != nil { + l = m.Partial.Size() + n += 1 + l + sovRpc(uint64(l)) + } if m.TestExtension != nil { l = m.TestExtension.Size() n += 4 + l + sovRpc(uint64(l)) @@ -1464,6 +1687,12 @@ func (m *RPC_SubOpts) Size() (n int) { l = len(*m.Topicid) n += 1 + l + sovRpc(uint64(l)) } + if m.RequestsPartial != nil { + n += 2 + } + if m.SupportsSendingPartial != nil { + n += 2 + } if m.XXX_unrecognized != nil { n += len(m.XXX_unrecognized) } @@ -1657,6 +1886,9 @@ func (m *ControlExtensions) Size() (n int) { } var l int _ = l + if m.PartialMessages != nil { + n += 2 + } if m.TestExtension != nil { 
n += 5 } @@ -1698,6 +1930,34 @@ func (m *TestExtension) Size() (n int) { return n } +func (m *PartialMessagesExtension) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.TopicID != nil { + l = len(*m.TopicID) + n += 1 + l + sovRpc(uint64(l)) + } + if m.GroupID != nil { + l = len(m.GroupID) + n += 1 + l + sovRpc(uint64(l)) + } + if m.PartialMessage != nil { + l = len(m.PartialMessage) + n += 1 + l + sovRpc(uint64(l)) + } + if m.PartsMetadata != nil { + l = len(m.PartsMetadata) + n += 1 + l + sovRpc(uint64(l)) + } + if m.XXX_unrecognized != nil { + n += len(m.XXX_unrecognized) + } + return n +} + func sovRpc(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -1837,6 +2097,42 @@ func (m *RPC) Unmarshal(dAtA []byte) error { return err } iNdEx = postIndex + case 10: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field Partial", wireType) + } + var msglen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + msglen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if msglen < 0 { + return ErrInvalidLengthRpc + } + postIndex := iNdEx + msglen + if postIndex < 0 { + return ErrInvalidLengthRpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + if m.Partial == nil { + m.Partial = &PartialMessagesExtension{} + } + if err := m.Partial.Unmarshal(dAtA[iNdEx:postIndex]); err != nil { + return err + } + iNdEx = postIndex case 6492434: if wireType != 2 { return fmt.Errorf("proto: wrong wireType = %d for field TestExtension", wireType) @@ -1978,6 +2274,48 @@ func (m *RPC_SubOpts) Unmarshal(dAtA []byte) error { s := string(dAtA[iNdEx:postIndex]) m.Topicid = &s iNdEx = postIndex + case 3: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field RequestsPartial", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return 
ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + b := bool(v != 0) + m.RequestsPartial = &b + case 4: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field SupportsSendingPartial", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + b := bool(v != 0) + m.SupportsSendingPartial = &b default: iNdEx = preIndex skippy, err := skipRpc(dAtA[iNdEx:]) @@ -3044,6 +3382,27 @@ func (m *ControlExtensions) Unmarshal(dAtA []byte) error { return fmt.Errorf("proto: ControlExtensions: illegal tag %d (wire type %d)", fieldNum, wire) } switch fieldNum { + case 10: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field PartialMessages", wireType) + } + var v int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + v |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + b := bool(v != 0) + m.PartialMessages = &b case 6492434: if wireType != 0 { return fmt.Errorf("proto: wrong wireType = %d for field TestExtension", wireType) @@ -3257,6 +3616,192 @@ func (m *TestExtension) Unmarshal(dAtA []byte) error { } return nil } +func (m *PartialMessagesExtension) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return 
fmt.Errorf("proto: PartialMessagesExtension: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: PartialMessagesExtension: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field TopicID", wireType) + } + var stringLen uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + stringLen |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + intStringLen := int(stringLen) + if intStringLen < 0 { + return ErrInvalidLengthRpc + } + postIndex := iNdEx + intStringLen + if postIndex < 0 { + return ErrInvalidLengthRpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + s := string(dAtA[iNdEx:postIndex]) + m.TopicID = &s + iNdEx = postIndex + case 2: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field GroupID", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRpc + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.GroupID = append(m.GroupID[:0], dAtA[iNdEx:postIndex]...) 
+ if m.GroupID == nil { + m.GroupID = []byte{} + } + iNdEx = postIndex + case 3: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PartialMessage", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRpc + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PartialMessage = append(m.PartialMessage[:0], dAtA[iNdEx:postIndex]...) + if m.PartialMessage == nil { + m.PartialMessage = []byte{} + } + iNdEx = postIndex + case 4: + if wireType != 2 { + return fmt.Errorf("proto: wrong wireType = %d for field PartsMetadata", wireType) + } + var byteLen int + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowRpc + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + byteLen |= int(b&0x7F) << shift + if b < 0x80 { + break + } + } + if byteLen < 0 { + return ErrInvalidLengthRpc + } + postIndex := iNdEx + byteLen + if postIndex < 0 { + return ErrInvalidLengthRpc + } + if postIndex > l { + return io.ErrUnexpectedEOF + } + m.PartsMetadata = append(m.PartsMetadata[:0], dAtA[iNdEx:postIndex]...) + if m.PartsMetadata == nil { + m.PartsMetadata = []byte{} + } + iNdEx = postIndex + default: + iNdEx = preIndex + skippy, err := skipRpc(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthRpc + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + m.XXX_unrecognized = append(m.XXX_unrecognized, dAtA[iNdEx:iNdEx+skippy]...) 
+ iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipRpc(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/pb/rpc.proto b/pb/rpc.proto index 4f3833c1..d2c95b05 100644 --- a/pb/rpc.proto +++ b/pb/rpc.proto @@ -3,81 +3,102 @@ syntax = "proto2"; package pubsub.pb; message RPC { - repeated SubOpts subscriptions = 1; - repeated Message publish = 2; + repeated SubOpts subscriptions = 1; + repeated Message publish = 2; - message SubOpts { - optional bool subscribe = 1; // subscribe or unsubcribe - optional string topicid = 2; - } + message SubOpts { + optional bool subscribe = 1; // subscribe or unsubcribe + optional string topicid = 2; - optional ControlMessage control = 3; + // Used with Partial Messages extension. + // If set, signals to the receiver that the sender prefers partial messages. + optional bool requestsPartial = 3; + // If set, signals to the receiver that the sender supports sending partial + // messages on this topic. If requestsPartial is true, this is assumed to be + // true. + optional bool supportsSendingPartial = 4; + } - // Canonical Extensions should register their messages here. + optional ControlMessage control = 3; + + optional PartialMessagesExtension partial = 10; + + // Canonical Extensions should register their messages here. + + // Experimental Extensions should register their messages here. They + // must use field numbers larger than 0x200000 to be encoded with at least 4 + // bytes + optional TestExtension testExtension = 6492434; - // Experimental Extensions should register their messages here. 
They - // must use field numbers larger than 0x200000 to be encoded with at least 4 - // bytes - optional TestExtension testExtension = 6492434; } message Message { - optional bytes from = 1; - optional bytes data = 2; - optional bytes seqno = 3; - optional string topic = 4; - optional bytes signature = 5; - optional bytes key = 6; + optional bytes from = 1; + optional bytes data = 2; + optional bytes seqno = 3; + optional string topic = 4; + optional bytes signature = 5; + optional bytes key = 6; } message ControlMessage { - repeated ControlIHave ihave = 1; - repeated ControlIWant iwant = 2; - repeated ControlGraft graft = 3; - repeated ControlPrune prune = 4; - repeated ControlIDontWant idontwant = 5; - optional ControlExtensions extensions = 6; + repeated ControlIHave ihave = 1; + repeated ControlIWant iwant = 2; + repeated ControlGraft graft = 3; + repeated ControlPrune prune = 4; + repeated ControlIDontWant idontwant = 5; + optional ControlExtensions extensions = 6; } message ControlIHave { - optional string topicID = 1; - // implementors from other languages should use bytes here - go protobuf emits invalid utf8 strings - repeated string messageIDs = 2; + optional string topicID = 1; + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings + repeated string messageIDs = 2; } message ControlIWant { - // implementors from other languages should use bytes here - go protobuf emits invalid utf8 strings - repeated string messageIDs = 1; + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings + repeated string messageIDs = 1; } -message ControlGraft { - optional string topicID = 1; -} +message ControlGraft { optional string topicID = 1; } message ControlPrune { - optional string topicID = 1; - repeated PeerInfo peers = 2; - optional uint64 backoff = 3; + optional string topicID = 1; + repeated PeerInfo peers = 2; + optional uint64 backoff = 3; } message 
ControlIDontWant { - // implementors from other languages should use bytes here - go protobuf emits invalid utf8 strings - repeated string messageIDs = 1; + // implementors from other languages should use bytes here - go protobuf emits + // invalid utf8 strings + repeated string messageIDs = 1; } message ControlExtensions { - // Initially empty. Future extensions will be added here along with a - // reference to their specification. + optional bool partialMessages = 10; - // Experimental extensions must use field numbers larger than 0x200000 to be - // encoded with 4 bytes - optional bool testExtension = 6492434; + // Experimental extensions must use field numbers larger than 0x200000 to be + // encoded with 4 bytes + optional bool testExtension = 6492434; } - message PeerInfo { - optional bytes peerID = 1; - optional bytes signedPeerRecord = 2; + optional bytes peerID = 1; + optional bytes signedPeerRecord = 2; } -message TestExtension {} \ No newline at end of file +message TestExtension {} + +message PartialMessagesExtension { + optional string topicID = 1; + optional bytes groupID = 2; + + // An encoded partial message + optional bytes partialMessage = 3; + + // An encoded representation of the parts a peer has and wants. + optional bytes partsMetadata = 4; +} diff --git a/pb/slog.go b/pb/slog.go index 1ce92c52..ac9fd78f 100644 --- a/pb/slog.go +++ b/pb/slog.go @@ -29,3 +29,24 @@ func (m *RPC) LogValue() slog.Value { } return slog.GroupValue(fields...) 
} + +var _ slog.LogValuer = (*PartialMessagesExtension)(nil) + +func (e *PartialMessagesExtension) LogValue() slog.Value { + fields := make([]slog.Attr, 0, 4) + fields = append(fields, slog.String("topic", e.GetTopicID())) + fields = append(fields, slog.Any("groupID", e.GetGroupID())) + + // Message + if e.PartialMessage != nil { + fields = append(fields, slog.Group( + "message", + slog.Any("dataLen", len(e.PartialMessage)), + )) + } + + if e.PartsMetadata != nil { + fields = append(fields, slog.Any("partsMetadata", e.PartsMetadata)) + } + return slog.GroupValue(fields...) +} diff --git a/pubsub.go b/pubsub.go index 03ada8ef..01bd73cd 100644 --- a/pubsub.go +++ b/pubsub.go @@ -14,6 +14,7 @@ import ( "time" "github.com/libp2p/go-libp2p-pubsub/internal/gologshim" + "github.com/libp2p/go-libp2p-pubsub/partialmessages" pb "github.com/libp2p/go-libp2p-pubsub/pb" "github.com/libp2p/go-libp2p-pubsub/timecache" @@ -44,6 +45,17 @@ var ( type ProtocolMatchFn = func(protocol.ID) func(protocol.ID) bool +type peerTopicState struct { + requestsPartial bool + supportsPartial bool +} + +type peerOutgoingStream struct { + network.Stream + FirstMessage chan *RPC + Cancel context.CancelFunc +} + // PubSub is the implementation of the pubsub system. 
type PubSub struct { // atomic counter for seqnos @@ -110,7 +122,7 @@ type PubSub struct { newPeersPend map[peer.ID]struct{} // a notification channel for new outoging peer streams - newPeerStream chan network.Stream + newPeerStream chan peerOutgoingStream // a notification channel for errors opening new peer streams newPeerError chan peer.ID @@ -133,7 +145,7 @@ type PubSub struct { myTopics map[string]*Topic // topics tracks which topics each of our peers are subscribed to - topics map[string]map[peer.ID]struct{} + topics map[string]map[peer.ID]peerTopicState // sendMsg handles messages that have been validated sendMsg chan *Message @@ -141,6 +153,8 @@ type PubSub struct { // sendMessageBatch publishes a batch of messages sendMessageBatch chan messageBatchAndPublishOptions + sendPartialMsg chan publishPartialMessageReq + // addVal handles validator registration requests addVal chan *addValReq @@ -262,6 +276,35 @@ type RPC struct { from peer.ID } +// LogValue implements slog.LogValuer. +func (rpc *RPC) LogValue() slog.Value { + // Messages + msgs := make([]any, 0, len(rpc.Publish)) + for _, msg := range rpc.Publish { + msgs = append(msgs, slog.Group( + "message", + slog.Any("topic", msg.Topic), + slog.Any("dataPrefix", msg.Data[0:min(len(msg.Data), 32)]), + slog.Any("dataLen", len(msg.Data)), + )) + } + + fields := make([]slog.Attr, 0, len(msgs)+3) + if len(msgs) > 0 { + fields = append(fields, slog.Group("publish", msgs...)) + } + if rpc.Control != nil { + fields = append(fields, slog.Any("control", rpc.Control)) + } + if rpc.Subscriptions != nil { + fields = append(fields, slog.Any("subscriptions", rpc.Subscriptions)) + } + if rpc.Partial != nil { + fields = append(fields, slog.Any("Partial", rpc.Partial)) + } + return slog.GroupValue(fields...) +} + // split splits the given RPC If a sub RPC is too large and can't be split // further (e.g. Message data is bigger than the RPC limit), then it will be // returned as an oversized RPC. 
The caller should filter out oversized RPCs. @@ -480,7 +523,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option incoming: make(chan *RPC, 32), newPeers: make(chan struct{}, 1), newPeersPend: make(map[peer.ID]struct{}), - newPeerStream: make(chan network.Stream), + newPeerStream: make(chan peerOutgoingStream), newPeerError: make(chan peer.ID), peerDead: make(chan struct{}, 1), peerDeadPend: make(map[peer.ID]struct{}), @@ -495,13 +538,14 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option getTopics: make(chan *topicReq), sendMsg: make(chan *Message, 32), sendMessageBatch: make(chan messageBatchAndPublishOptions, 1), + sendPartialMsg: make(chan publishPartialMessageReq, 1), addVal: make(chan *addValReq), rmVal: make(chan *rmValReq), eval: make(chan func()), myTopics: make(map[string]*Topic), mySubs: make(map[string]map[*Subscription]struct{}), myRelays: make(map[string]int), - topics: make(map[string]map[peer.ID]struct{}), + topics: make(map[string]map[peer.ID]peerTopicState), peers: make(map[peer.ID]*rpcQueue), inboundStreams: make(map[peer.ID]network.Stream), blacklist: NewMapBlacklist(), @@ -815,6 +859,7 @@ func (p *PubSub) processLoop(ctx context.Context) { q, ok := p.peers[pid] if !ok { p.logger.Warn("new stream for unknown peer", "peer", pid) + s.Cancel() s.Reset() continue } @@ -823,13 +868,14 @@ func (p *PubSub) processLoop(ctx context.Context) { p.logger.Warn("closing stream for blacklisted peer", "peer", pid) q.Close() delete(p.peers, pid) + s.Cancel() s.Reset() continue } helloPacket := p.getHelloPacket() helloPacket = p.rt.AddPeer(pid, s.Protocol(), helloPacket) - q.Push(helloPacket, true) + s.FirstMessage <- helloPacket case pid := <-p.newPeerError: delete(p.peers, pid) @@ -881,6 +927,9 @@ func (p *PubSub) processLoop(ctx context.Context) { case batchAndOpts := <-p.sendMessageBatch: p.publishMessageBatch(batchAndOpts) + case req := <-p.sendPartialMsg: + p.publishPartialMessage(req) + case 
req := <-p.addVal: p.val.AddValidator(req) @@ -1148,9 +1197,19 @@ func (p *PubSub) handleRemoveRelay(topic string) { // announce announces whether or not this node is interested in a given topic // Only called from processLoop. func (p *PubSub) announce(topic string, sub bool) { + var requestPartialMessages bool + var supportsPartialMessages bool + if sub { + if t, ok := p.myTopics[topic]; ok { + requestPartialMessages = t.requestPartialMessages + supportsPartialMessages = t.supportsPartialMessages + } + } subopt := &pb.RPC_SubOpts{ - Topicid: &topic, - Subscribe: &sub, + Topicid: &topic, + Subscribe: &sub, + RequestsPartial: &requestPartialMessages, + SupportsSendingPartial: &supportsPartialMessages, } out := rpcWithSubs(subopt) @@ -1192,9 +1251,19 @@ func (p *PubSub) doAnnounceRetry(pid peer.ID, topic string, sub bool) { return } + var requestPartialMessages bool + var supportsPartialMessages bool + if sub { + if t, ok := p.myTopics[topic]; ok { + requestPartialMessages = t.requestPartialMessages + supportsPartialMessages = t.supportsPartialMessages + } + } subopt := &pb.RPC_SubOpts{ - Topicid: &topic, - Subscribe: &sub, + Topicid: &topic, + Subscribe: &sub, + RequestsPartial: &requestPartialMessages, + SupportsSendingPartial: &supportsPartialMessages, } out := rpcWithSubs(subopt) @@ -1294,12 +1363,19 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { if subopt.GetSubscribe() { tmap, ok := p.topics[t] if !ok { - tmap = make(map[peer.ID]struct{}) + tmap = make(map[peer.ID]peerTopicState) p.topics[t] = tmap } - if _, ok = tmap[rpc.from]; !ok { - tmap[rpc.from] = struct{}{} + pts := peerTopicState{ + requestsPartial: subopt.GetRequestsPartial(), + // If the peer requested partial, they support it by default + supportsPartial: subopt.GetRequestsPartial() || subopt.GetSupportsSendingPartial(), + } + _, seenBefore := tmap[rpc.from] + tmap[rpc.from] = pts + if !seenBefore { + tmap[rpc.from] = pts if topic, ok := p.myTopics[t]; ok { peer := rpc.from 
topic.sendNotification(PeerEvent{PeerJoin, peer}) @@ -1338,7 +1414,7 @@ func (p *PubSub) handleIncomingRPC(rpc *RPC) { continue } - msg := &Message{pmsg, "", rpc.from, nil, false} + msg := &Message{Message: pmsg, ID: "", ReceivedFrom: rpc.from, ValidatorData: nil, Local: false} if p.shouldPush(msg) { toPush = append(toPush, msg) } @@ -1477,7 +1553,51 @@ type rmTopicReq struct { resp chan error } -type TopicOptions struct{} +type TopicOptions struct { + SkipPublishingToPartialMessageCapablePeers bool +} + +// RequestPartialMessages requests that peers, if they support it, send us +// partial messages on this topic. +// +// If a peer does not support partial messages, this has no effect, and the peer +// will continue sending us full messages. +// +// It is an error to use this option if partial messages are not enabled. +// This option implies `SupportsPartialMessages`. +func RequestPartialMessages() TopicOpt { + return func(t *Topic) error { + gs, ok := t.p.rt.(*GossipSubRouter) + if !ok { + return errors.New("partial messages only supported by gossipsub") + } + + if !gs.extensions.myExtensions.PartialMessages { + return errors.New("partial messages are not enabled") + } + t.requestPartialMessages = true + t.supportsPartialMessages = true + return nil + } +} + +// SupportsPartialMessages signals to other peers that you will send partial +// message metadata and fulfill their partial message request, but you will not +// request partial messages. 
+func SupportsPartialMessages() TopicOpt { + return func(t *Topic) error { + gs, ok := t.p.rt.(*GossipSubRouter) + if !ok { + return errors.New("partial messages only supported by gossipsub") + } + + if !gs.extensions.myExtensions.PartialMessages { + return errors.New("partial messages are not enabled") + } + t.supportsPartialMessages = true + return nil + } +} type TopicOpt func(t *Topic) error @@ -1604,6 +1724,37 @@ func (p *PubSub) Publish(topic string, data []byte, opts ...PubOpt) error { return t.Publish(context.TODO(), data, opts...) } +type PeerFeedbackKind int + +const ( + PeerFeedbackUsefulMessage PeerFeedbackKind = iota + PeerFeedbackInvalidMessage +) + +// PeerFeedback lets applications inform GossipSub's peer scorer about the +// performance of a peer's message. This is useful if the application is using +// partial messages, because the application handles merging parts. +func (p *PubSub) PeerFeedback(topic string, peer peer.ID, kind PeerFeedbackKind) error { + gs, ok := p.rt.(*GossipSubRouter) + if !ok { + return errors.New("peer feedback is only supported by GossipSub") + } + p.eval <- func() { + if gs.score == nil { + return + } + gs.score.Lock() + defer gs.score.Unlock() + switch kind { + case PeerFeedbackUsefulMessage: + gs.score.markFirstMessageDelivery(peer, topic) + case PeerFeedbackInvalidMessage: + gs.score.markInvalidMessageDelivery(peer, topic) + } + } + return nil +} + // PublishBatch publishes a batch of messages. This only works for routers that // implement the BatchPublisher interface. 
// @@ -1635,6 +1786,43 @@ func (p *PubSub) PublishBatch(batch *MessageBatch, opts ...BatchPubOpt) error { return nil } +type publishPartialMessageReq struct { + topic string + partialMessage partialmessages.Message + opts partialmessages.PublishOptions + errCh chan error +} + +func (p *PubSub) PublishPartialMessage(topic string, partialMessage partialmessages.Message, opts partialmessages.PublishOptions) error { + errCh := make(chan error, 1) + select { + case p.sendPartialMsg <- publishPartialMessageReq{ + topic: topic, + partialMessage: partialMessage, + opts: opts, + errCh: errCh, + }: + case <-p.ctx.Done(): + return p.ctx.Err() + } + return <-errCh +} + +func (p *PubSub) publishPartialMessage(req publishPartialMessageReq) { + rt, ok := p.rt.(*GossipSubRouter) + if !ok { + req.errCh <- errors.New("partial publishing is only supported by the GossipSub router") + return + } + + if rt.extensions.partialMessagesExtension == nil { + req.errCh <- errors.New("partial publishing is not enabled") + return + } + + req.errCh <- rt.extensions.partialMessagesExtension.PublishPartial(req.topic, req.partialMessage, req.opts) +} + func (p *PubSub) nextSeqno() []byte { seqno := make([]byte, 8) counter := atomic.AddUint64(&p.counter, 1) diff --git a/score.go b/score.go index feb05465..7f5e384b 100644 --- a/score.go +++ b/score.go @@ -709,7 +709,7 @@ func (ps *peerScore) DeliverMessage(msg *Message) { ps.Lock() defer ps.Unlock() - ps.markFirstMessageDelivery(msg.ReceivedFrom, msg) + ps.markFirstMessageDelivery(msg.ReceivedFrom, msg.GetTopic()) drec := ps.deliveries.getRecord(ps.idGen.ID(msg)) @@ -746,7 +746,7 @@ func (ps *peerScore) RejectMessage(msg *Message, reason string) { case RejectUnexpectedAuthInfo: fallthrough case RejectSelfOrigin: - ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg) + ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg.GetTopic()) return // we ignore those messages, so do nothing. 
@@ -789,9 +789,9 @@ func (ps *peerScore) RejectMessage(msg *Message, reason string) { // mark the message as invalid and penalize peers that have already forwarded it. drec.status = deliveryInvalid - ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg) + ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg.GetTopic()) for p := range drec.peers { - ps.markInvalidMessageDelivery(p, msg) + ps.markInvalidMessageDelivery(p, msg.GetTopic()) } // release the delivery time tracking map to free some memory early @@ -823,7 +823,7 @@ func (ps *peerScore) DuplicateMessage(msg *Message) { case deliveryInvalid: // we no longer track delivery time - ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg) + ps.markInvalidMessageDelivery(msg.ReceivedFrom, msg.GetTopic()) case deliveryThrottled: // the message was throttled; do nothing (we don't know if it was valid) @@ -904,13 +904,12 @@ func (pstats *peerStats) getTopicStats(topic string, params *PeerScoreParams) (* // markInvalidMessageDelivery increments the "invalid message deliveries" // counter for all scored topics the message is published in. -func (ps *peerScore) markInvalidMessageDelivery(p peer.ID, msg *Message) { +func (ps *peerScore) markInvalidMessageDelivery(p peer.ID, topic string) { pstats, ok := ps.peerStats[p] if !ok { return } - topic := msg.GetTopic() tstats, ok := pstats.getTopicStats(topic, ps.params) if !ok { return @@ -922,13 +921,12 @@ func (ps *peerScore) markInvalidMessageDelivery(p peer.ID, msg *Message) { // markFirstMessageDelivery increments the "first message deliveries" counter // for all scored topics the message is published in, as well as the "mesh // message deliveries" counter, if the peer is in the mesh for the topic. 
-func (ps *peerScore) markFirstMessageDelivery(p peer.ID, msg *Message) { +func (ps *peerScore) markFirstMessageDelivery(p peer.ID, topic string) { pstats, ok := ps.peerStats[p] if !ok { return } - topic := msg.GetTopic() tstats, ok := pstats.getTopicStats(topic, ps.params) if !ok { return diff --git a/topic.go b/topic.go index dd094eae..7966c651 100644 --- a/topic.go +++ b/topic.go @@ -32,6 +32,9 @@ type Topic struct { mux sync.RWMutex closed bool + + requestPartialMessages bool + supportsPartialMessages bool } // String returns the topic associated with t @@ -348,7 +351,14 @@ func (t *Topic) validate(ctx context.Context, data []byte, opts ...PubOpt) (*Mes } } - msg := &Message{m, "", t.p.host.ID(), pub.validatorData, pub.local} + msg := &Message{ + Message: m, + ID: "", + ReceivedFrom: t.p.host.ID(), + ValidatorData: pub.validatorData, + Local: pub.local, + } + select { case t.p.eval <- func() { t.p.rt.Preprocess(t.p.host.ID(), []*Message{msg})