diff --git a/go.mod b/go.mod index bb0f456ff7..573393404c 100644 --- a/go.mod +++ b/go.mod @@ -48,14 +48,14 @@ require ( github.com/multiformats/go-varint v0.0.7 github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 github.com/pion/datachannel v1.5.10 - github.com/pion/ice/v4 v4.0.6 + github.com/pion/ice/v4 v4.0.8 github.com/pion/logging v0.2.3 - github.com/pion/sctp v1.8.36 + github.com/pion/sctp v1.8.37 github.com/pion/stun v0.6.1 github.com/pion/webrtc/v4 v4.0.10 - github.com/prometheus/client_golang v1.21.0 + github.com/prometheus/client_golang v1.21.1 github.com/prometheus/client_model v0.6.1 - github.com/quic-go/quic-go v0.50.0 + github.com/quic-go/quic-go v0.50.1 github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 github.com/raulk/go-watchdog v1.3.0 github.com/stretchr/testify v1.10.0 diff --git a/go.sum b/go.sum index 0df1c6267f..867ec2c0a4 100644 --- a/go.sum +++ b/go.sum @@ -276,8 +276,8 @@ github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= -github.com/pion/ice/v4 v4.0.6 h1:jmM9HwI9lfetQV/39uD0nY4y++XZNPhvzIPCb8EwxUM= -github.com/pion/ice/v4 v4.0.6/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= +github.com/pion/ice/v4 v4.0.8 h1:ajNx0idNG+S+v9Phu4LSn2cs8JEfTsA1/tEjkkAVpFY= +github.com/pion/ice/v4 v4.0.8/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= @@ -291,8 +291,8 @@ github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= github.com/pion/rtp v1.8.11 h1:17xjnY5WO5hgO6SD3/NTIUPvSFw/PbLsIJyz1r1yNIk= github.com/pion/rtp v1.8.11/go.mod h1:8uMBJj32Pa1wwx8Fuv/AsFhn8jsgw+3rUC2PfoBZ8p4= -github.com/pion/sctp v1.8.36 h1:owNudmnz1xmhfYje5L/FCav3V9wpPRePHle3Zi+P+M0= -github.com/pion/sctp v1.8.36/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= +github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs= +github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= github.com/pion/sdp/v3 v3.0.10 h1:6MChLE/1xYB+CjumMw+gZ9ufp2DPApuVSnDT8t5MIgA= github.com/pion/sdp/v3 v3.0.10/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= @@ -317,8 +317,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA= -github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= @@ -330,8 +330,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo= -github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= +github.com/quic-go/quic-go v0.50.1 h1:unsgjFIUqW8a2oopkY7YNONpV1gYND6Nt9hnt1PN94Q= +github.com/quic-go/quic-go v0.50.1/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw= github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk= diff --git a/p2p/http/libp2phttp.go b/p2p/http/libp2phttp.go index 26f4825d10..8691ed1073 100644 --- a/p2p/http/libp2phttp.go +++ b/p2p/http/libp2phttp.go @@ -45,6 +45,10 @@ const LegacyWellKnownProtocols = "/.well-known/libp2p" const peerMetadataLimit = 8 << 10 // 8KB const peerMetadataLRUSize = 256 // How many different peer's metadata to keep in our LRU cache +// defaultNewStreamTimeout is the default value for new stream establishing timeout. +// It is the same value as basic_host.DefaultNegotiationTimeout +var defaultNewStreamTimeout = 10 * time.Second + type clientPeerIDContextKey struct{} type serverPeerIDContextKey struct{} @@ -496,7 +500,16 @@ func (rt *streamRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) }) } - s, err := rt.h.NewStream(r.Context(), rt.server, ProtocolIDForMultistreamSelect) + // If r.Context() timeout is greater than DefaultNewStreamTimeout + // use DefaultNewStreamTimeout for new stream negotiation. + newStreamCtx := r.Context() + if deadline, ok := newStreamCtx.Deadline(); !ok || deadline.After(time.Now().Add(defaultNewStreamTimeout)) { + var cancel context.CancelFunc + newStreamCtx, cancel = context.WithTimeout(context.Background(), defaultNewStreamTimeout) + defer cancel() + } + + s, err := rt.h.NewStream(newStreamCtx, rt.server, ProtocolIDForMultistreamSelect) if err != nil { return nil, err } diff --git a/p2p/net/conngater/conngater.go b/p2p/net/conngater/conngater.go index bcc927f184..8c9084c5a0 100644 --- a/p2p/net/conngater/conngater.go +++ b/p2p/net/conngater/conngater.go @@ -72,7 +72,7 @@ func (cg *BasicConnectionGater) loadRules(ctx context.Context) error { for r := range res.Next() { if r.Error != nil { log.Errorf("query result error: %s", r.Error) - return err + return r.Error } p := peer.ID(r.Entry.Value) @@ -89,7 +89,7 @@ func (cg *BasicConnectionGater) loadRules(ctx context.Context) error { for r := range res.Next() { if r.Error != nil { log.Errorf("query result error: %s", r.Error) - return err + return r.Error } ip := net.IP(r.Entry.Value) @@ -106,7 +106,7 @@ func (cg *BasicConnectionGater) loadRules(ctx context.Context) error { for r := range res.Next() { if r.Error != nil { log.Errorf("query result error: %s", r.Error) - return err + return r.Error } ipnetStr := string(r.Entry.Value) diff --git a/p2p/protocol/autonatv2/server.go b/p2p/protocol/autonatv2/server.go index 05e0bdd9fd..dafac4dfe2 100644 --- a/p2p/protocol/autonatv2/server.go +++ b/p2p/protocol/autonatv2/server.go @@ -30,7 +30,7 @@ var ( errDialDataRefused = errors.New("dial data refused") ) -type dataRequestPolicyFunc = func(s network.Stream, dialAddr ma.Multiaddr) bool +type dataRequestPolicyFunc = func(observedAddr, dialAddr ma.Multiaddr) bool type EventDialRequestCompleted struct { Error error @@ -212,7 +212,7 @@ func (as *server) serveDialRequest(s network.Stream) EventDialRequestCompleted { nonce := msg.GetDialRequest().Nonce - isDialDataRequired := as.dialDataRequestPolicy(s, dialAddr) + isDialDataRequired := as.dialDataRequestPolicy(s.Conn().RemoteMultiaddr(), dialAddr) if isDialDataRequired && !as.limiter.AcceptDialDataRequest(p) { msg = pb.Message{ Msg: &pb.Message_DialResponse{ @@ -517,11 +517,11 @@ func (r *rateLimiter) Close() { // amplificationAttackPrevention is a dialDataRequestPolicy which requests data when the peer's observed // IP address is different from the dial back IP address -func amplificationAttackPrevention(s network.Stream, dialAddr ma.Multiaddr) bool { - connIP, err := manet.ToIP(s.Conn().RemoteMultiaddr()) +func amplificationAttackPrevention(observedAddr, dialAddr ma.Multiaddr) bool { + observedIP, err := manet.ToIP(observedAddr) if err != nil { return true } - dialIP, _ := manet.ToIP(s.Conn().LocalMultiaddr()) // must be an IP multiaddr - return !connIP.Equal(dialIP) + dialIP, _ := manet.ToIP(dialAddr) // must be an IP multiaddr + return !observedIP.Equal(dialIP) } diff --git a/p2p/protocol/autonatv2/server_test.go b/p2p/protocol/autonatv2/server_test.go index 6664f1c397..f92396fb16 100644 --- a/p2p/protocol/autonatv2/server_test.go +++ b/p2p/protocol/autonatv2/server_test.go @@ -143,7 +143,7 @@ func TestServerDataRequest(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -197,7 +197,7 @@ func TestServerMaxConcurrentRequestsPerPeer(t *testing.T) { doneChan := make(chan struct{}) an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( // stall all allowed requests - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { <-doneChan return true }), @@ -255,7 +255,7 @@ func TestServerDataRequestJitter(t *testing.T) { dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) // ask for dial data for quic address an := newAutoNAT(t, dialer, allowPrivateAddrs, withDataRequestPolicy( - func(s network.Stream, dialAddr ma.Multiaddr) bool { + func(_, dialAddr ma.Multiaddr) bool { if _, err := dialAddr.ValueForProtocol(ma.P_QUIC_V1); err == nil { return true } @@ -520,6 +520,76 @@ func TestReadDialData(t *testing.T) { } } +func TestServerDataRequestWithAmplificationAttackPrevention(t *testing.T) { + // server will skip all tcp addresses + dialer := bhost.NewBlankHost(swarmt.GenSwarm(t, swarmt.OptDisableTCP)) + // ask for dial data for quic address + an := newAutoNAT(t, dialer, allowPrivateAddrs, + WithServerRateLimit(10, 10, 10, 2), + withAmplificationAttackPreventionDialWait(0), + ) + defer an.Close() + defer an.host.Close() + + c := newAutoNAT(t, nil, allowPrivateAddrs) + defer c.Close() + defer c.host.Close() + + idAndWait(t, c, an) + + err := c.host.Network().Listen(ma.StringCast("/ip6/::1/udp/0/quic-v1")) + if err != nil { + // machine doesn't have ipv6 + t.Skip("skipping test because machine doesn't have ipv6") + } + + var quicv4Addr ma.Multiaddr + var quicv6Addr ma.Multiaddr + for _, a := range c.host.Addrs() { + if _, err := a.ValueForProtocol(ma.P_QUIC_V1); err == nil { + if _, err := a.ValueForProtocol(ma.P_IP4); err == nil { + quicv4Addr = a + } else { + quicv6Addr = a + } + } + } + res, err := c.GetReachability(context.Background(), []Request{{Addr: quicv4Addr, SendDialData: false}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv4Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) + + // ipv6 address should require dial data + _, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: false}}) + require.Error(t, err) + require.ErrorContains(t, err, "invalid dial data request: low priority addr") + + // ipv6 address should work fine with dial data + res, err = c.GetReachability(context.Background(), []Request{{Addr: quicv6Addr, SendDialData: true}}) + require.NoError(t, err) + require.Equal(t, Result{ + Addr: quicv6Addr, + Reachability: network.ReachabilityPublic, + Status: pb.DialStatus_OK, + }, res) +} + +func TestDefaultAmplificationAttackPrevention(t *testing.T) { + q1 := ma.StringCast("/ip4/1.2.3.4/udp/1234/quic-v1") + q2 := ma.StringCast("/ip4/1.2.3.4/udp/1235/quic-v1") + t1 := ma.StringCast("/ip4/1.2.3.4/tcp/1234") + + require.False(t, amplificationAttackPrevention(q1, q1)) + require.False(t, amplificationAttackPrevention(q1, q2)) + require.False(t, amplificationAttackPrevention(q1, t1)) + + t2 := ma.StringCast("/ip4/1.1.1.1/tcp/1235") // different IP + require.True(t, amplificationAttackPrevention(q2, t2)) +} + func FuzzServerDialRequest(f *testing.F) { a := newAutoNAT(f, nil, allowPrivateAddrs, WithServerRateLimit(math.MaxInt32, math.MaxInt32, math.MaxInt32, 2)) c := newAutoNAT(f, nil) diff --git a/p2p/protocol/holepunch/holepunch_test.go b/p2p/protocol/holepunch/holepunch_test.go index 28e6122914..a763880711 100644 --- a/p2p/protocol/holepunch/holepunch_test.go +++ b/p2p/protocol/holepunch/holepunch_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "sync" "sync/atomic" "testing" @@ -682,3 +683,98 @@ func SetLegacyBehavior(legacyBehavior bool) holepunch.Option { return nil } } + +// TestEndToEndSimConnectQUICReuse tests that hole punching works if we are +// reusing the same port for QUIC and WebTransport, and when we have multiple +// QUIC listeners on different ports. +// +// If this tests fails or is flaky it may be because: +// - The quicreuse logic (and association logic) is not returning the appropriate transport for holepunching. +// - The ordering of listeners is unexpected (remember the swarm will sort the listeners with `.ListenOrder()`). +func TestEndToEndSimConnectQUICReuse(t *testing.T) { + h1tr := &mockEventTracer{} + h2tr := &mockEventTracer{} + + router := &simconn.SimpleFirewallRouter{} + relay := MustNewHost(t, + quicSimConn(true, router), + libp2p.ListenAddrs(ma.StringCast("/ip4/1.2.0.1/udp/8000/quic-v1")), + libp2p.DisableRelay(), + libp2p.ResourceManager(&network.NullResourceManager{}), + libp2p.WithFxOption(fx.Invoke(func(h host.Host) { + // Setup relay service + _, err := relayv2.New(h) + require.NoError(t, err) + })), + ) + + // We return addrs of quic on port 8001 and circuit. + // This lets us listen on other ports for QUIC in order to confuse the quicreuse logic during hole punching. + onlyQuicOnPort8001AndCircuit := func(addrs []ma.Multiaddr) []ma.Multiaddr { + return slices.DeleteFunc(addrs, func(a ma.Multiaddr) bool { + _, err := a.ValueForProtocol(ma.P_CIRCUIT) + isCircuit := err == nil + if isCircuit { + return false + } + _, err = a.ValueForProtocol(ma.P_QUIC_V1) + isQuic := err == nil + if !isQuic { + return true + } + port, err := a.ValueForProtocol(ma.P_UDP) + if err != nil { + return true + } + isPort8001 := port == "8001" + return !isPort8001 + }) + } + + h1 := MustNewHost(t, + quicSimConn(false, router), + libp2p.EnableHolePunching(holepunch.WithTracer(h1tr), holepunch.DirectDialTimeout(100*time.Millisecond)), + libp2p.ListenAddrs(ma.StringCast("/ip4/2.2.0.1/udp/8001/quic-v1/webtransport")), + libp2p.ResourceManager(&network.NullResourceManager{}), + libp2p.AddrsFactory(onlyQuicOnPort8001AndCircuit), + libp2p.ForceReachabilityPrivate(), + ) + // Listen on quic *after* listening on webtransport. + // This is to test that the quicreuse logic is not returning the wrong transport. + // See: https://github.com/libp2p/go-libp2p/issues/3165#issuecomment-2700126706 for details. + h1.Network().Listen( + ma.StringCast("/ip4/2.2.0.1/udp/8001/quic-v1"), + ma.StringCast("/ip4/2.2.0.1/udp/9001/quic-v1"), + ) + + h2 := MustNewHost(t, + quicSimConn(false, router), + libp2p.ListenAddrs( + ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1/webtransport"), + ), + libp2p.ResourceManager(&network.NullResourceManager{}), + connectToRelay(&relay), + libp2p.EnableHolePunching(holepunch.WithTracer(h2tr), holepunch.DirectDialTimeout(100*time.Millisecond)), + libp2p.AddrsFactory(onlyQuicOnPort8001AndCircuit), + libp2p.ForceReachabilityPrivate(), + ) + // Listen on quic after listening on webtransport. + h2.Network().Listen( + ma.StringCast("/ip4/2.2.0.2/udp/8001/quic-v1"), + ma.StringCast("/ip4/2.2.0.2/udp/9001/quic-v1"), + ) + + defer h1.Close() + defer h2.Close() + defer relay.Close() + + // Wait for holepunch service to start + waitForHolePunchingSvcActive(t, h1) + waitForHolePunchingSvcActive(t, h2) + + learnAddrs(h1, h2) + pingAtoB(t, h1, h2) + + // wait till a direct connection is complete + ensureDirectConn(t, h1, h2) +} diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index bf9ba22bf2..c9e3088b5e 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "fmt" "io" "net" "sync" @@ -190,6 +191,18 @@ func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, } key = tr.LocalAddr().String() entry = quicListenerEntry{ln: ln} + } else if c.enableReuseport && association != nil { + reuse, err := c.getReuse(netw) + if err != nil { + return nil, fmt.Errorf("reuse error: %w", err) + } + err = reuse.AssertTransportExists(entry.ln.transport) + if err != nil { + return nil, fmt.Errorf("reuse assert transport failed: %w", err) + } + if tr, ok := entry.ln.transport.(*refcountedTransport); ok { + tr.associate(association) + } } l, err := entry.ln.Add(tlsConf, allowWindowIncrease, func() { c.onListenerClosed(key) }) if err != nil { diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index 51646bac98..d128119dab 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -315,3 +315,59 @@ func TestExternalTransport(t *testing.T) { t.Fatal("doneWithTr not closed") } } + +func TestAssociate(t *testing.T) { + testAssociate := func(lnAddr1, lnAddr2 ma.Multiaddr, dialAddr *net.UDPAddr) { + cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{}) + require.NoError(t, err) + defer cm.Close() + + lp2pTLS := &tls.Config{NextProtos: []string{"libp2p"}} + assoc1 := "test-1" + ln1, err := cm.ListenQUICAndAssociate(assoc1, lnAddr1, lp2pTLS, nil) + require.NoError(t, err) + defer ln1.Close() + addrs := ln1.Multiaddrs() + require.Len(t, addrs, 1) + + addr := addrs[0] + assoc2 := "test-2" + h3TLS := &tls.Config{NextProtos: []string{"h3"}} + ln2, err := cm.ListenQUICAndAssociate(assoc2, addr, h3TLS, nil) + require.NoError(t, err) + defer ln2.Close() + + tr1, err := cm.TransportWithAssociationForDial(assoc1, "udp4", dialAddr) + require.NoError(t, err) + defer tr1.Close() + require.Equal(t, tr1.LocalAddr().String(), ln1.Addr().String()) + + tr2, err := cm.TransportWithAssociationForDial(assoc2, "udp4", dialAddr) + require.NoError(t, err) + defer tr2.Close() + require.Equal(t, tr2.LocalAddr().String(), ln2.Addr().String()) + + ln3, err := cm.ListenQUICAndAssociate(assoc1, lnAddr2, lp2pTLS, nil) + require.NoError(t, err) + defer ln3.Close() + + // an unused association should also return the same transport + // association is only a preference for a specific transport, not an exclusion criteria + tr3, err := cm.TransportWithAssociationForDial("unused", "udp4", dialAddr) + require.NoError(t, err) + defer tr3.Close() + require.Contains(t, []string{ln2.Addr().String(), ln3.Addr().String()}, tr3.LocalAddr().String()) + } + + t.Run("MultipleUnspecifiedListeners", func(t *testing.T) { + testAssociate(ma.StringCast("/ip4/0.0.0.0/udp/0/quic-v1"), + ma.StringCast("/ip4/0.0.0.0/udp/0/quic-v1"), + &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 1}) + }) + t.Run("MultipleSpecificListeners", func(t *testing.T) { + testAssociate(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), + ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), + &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1}, + ) + }) +} diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index 20a6260b82..6d0098e33d 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -303,6 +303,10 @@ func (r *reuse) transportForDialLocked(association any, network string, source * return tr, nil } } + // We don't have a transport with the association, use any one + for _, tr := range trs { + return tr, nil + } } } @@ -313,6 +317,10 @@ func (r *reuse) transportForDialLocked(association any, network string, source * return tr, nil } } + // We don't have a transport with the association, use any one + for _, tr := range r.globalListeners { + return tr, nil + } // Use a transport we've previously dialed from for _, tr := range r.globalDialers { @@ -360,6 +368,33 @@ func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error return nil } +func (r *reuse) AssertTransportExists(tr refCountedQuicTransport) error { + t, ok := tr.(*refcountedTransport) + if !ok { + return fmt.Errorf("invalid transport type: expected: *refcountedTransport, got: %T", tr) + } + laddr := t.LocalAddr().(*net.UDPAddr) + if laddr.IP.IsUnspecified() { + if lt, ok := r.globalListeners[laddr.Port]; ok { + if lt == t { + return nil + } + return errors.New("two global listeners on the same port") + } + return errors.New("transport not found") + } + if m, ok := r.unicast[laddr.IP.String()]; ok { + if lt, ok := m[laddr.Port]; ok { + if lt == t { + return nil + } + return errors.New("two unicast listeners on same ip:port") + } + return errors.New("transport not found") + } + return errors.New("transport not found") +} + func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) { r.mutex.Lock() defer r.mutex.Unlock() diff --git a/p2p/transport/webrtc/udpmux/mux.go b/p2p/transport/webrtc/udpmux/mux.go index f01facae32..76d68d8e89 100644 --- a/p2p/transport/webrtc/udpmux/mux.go +++ b/p2p/transport/webrtc/udpmux/mux.go @@ -271,12 +271,13 @@ func (mux *UDPMux) RemoveConnByUfrag(ufrag string) { for _, isIPv6 := range [...]bool{true, false} { key := ufragConnKey{ufrag: ufrag, isIPv6: isIPv6} - if _, ok := mux.ufragMap[key]; ok { + if conn, ok := mux.ufragMap[key]; ok { delete(mux.ufragMap, key) for _, addr := range mux.ufragAddrMap[key] { delete(mux.addrMap, addr.String()) } delete(mux.ufragAddrMap, key) + conn.close() } } } @@ -293,7 +294,7 @@ func (mux *UDPMux) getOrCreateConn(ufrag string, isIPv6 bool, _ *UDPMux, addr ne return false, conn } - conn := newMuxedConnection(mux, func() { mux.RemoveConnByUfrag(ufrag) }) + conn := newMuxedConnection(mux, ufrag) mux.ufragMap[key] = conn mux.addrMap[addr.String()] = conn mux.ufragAddrMap[key] = append(mux.ufragAddrMap[key], addr) diff --git a/p2p/transport/webrtc/udpmux/mux_test.go b/p2p/transport/webrtc/udpmux/mux_test.go index 298e5c920f..b75f3e8302 100644 --- a/p2p/transport/webrtc/udpmux/mux_test.go +++ b/p2p/transport/webrtc/udpmux/mux_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/pion/stun" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -246,3 +247,28 @@ func TestMuxedConnection(t *testing.T) { } require.Empty(t, addrUfragMap) } + +func TestRemovingUfragClosesConn(t *testing.T) { + c := newPacketConn(t) + m := NewUDPMux(c) + m.Start() + defer m.Close() + remoteAddr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} + conn, err := m.GetConn("a", remoteAddr) + require.NoError(t, err) + defer conn.Close() + + connClosed := make(chan bool) + go func() { + _, _, err := conn.ReadFrom(make([]byte, 100)) + assert.ErrorIs(t, err, context.Canceled) + close(connClosed) + }() + require.NoError(t, err) + m.RemoveConnByUfrag("a") + select { + case <-connClosed: + case <-time.After(1 * time.Second): + t.Fatalf("expected the connection to be closed") + } +} diff --git a/p2p/transport/webrtc/udpmux/muxed_connection.go b/p2p/transport/webrtc/udpmux/muxed_connection.go index 2af5d33253..84d30c84b1 100644 --- a/p2p/transport/webrtc/udpmux/muxed_connection.go +++ b/p2p/transport/webrtc/udpmux/muxed_connection.go @@ -23,31 +23,29 @@ const queueLen = 128 // from which this connection (indexed by ufrag) received // data. type muxedConnection struct { - ctx context.Context - cancel context.CancelFunc - onClose func() - queue chan packet - mux *UDPMux + ctx context.Context + cancel context.CancelFunc + queue chan packet + mux *UDPMux + ufrag string } var _ net.PacketConn = &muxedConnection{} -func newMuxedConnection(mux *UDPMux, onClose func()) *muxedConnection { +func newMuxedConnection(mux *UDPMux, ufrag string) *muxedConnection { ctx, cancel := context.WithCancel(mux.ctx) return &muxedConnection{ - ctx: ctx, - cancel: cancel, - queue: make(chan packet, queueLen), - onClose: onClose, - mux: mux, + ctx: ctx, + cancel: cancel, + queue: make(chan packet, queueLen), + mux: mux, + ufrag: ufrag, } } func (c *muxedConnection) Push(buf []byte, addr net.Addr) error { - select { - case <-c.ctx.Done(): + if c.ctx.Err() != nil { return errors.New("closed") - default: } select { case c.queue <- packet{buf: buf, addr: addr}: @@ -76,12 +74,21 @@ func (c *muxedConnection) WriteTo(p []byte, addr net.Addr) (n int, err error) { } func (c *muxedConnection) Close() error { - select { - case <-c.ctx.Done(): + if c.ctx.Err() != nil { return nil - default: } - c.onClose() + // mux calls close to actually close the connection + // + // Removing the connection from the mux or closing the connection + // must trigger the other. + // Doing this here ensures we don't need to call both RemoveConnByUfrag + // and close on all code paths. + c.mux.RemoveConnByUfrag(c.ufrag) + return nil +} + +// closes the connection. Must only be called by the mux. +func (c *muxedConnection) close() { c.cancel() // drain the packet queue for { @@ -89,7 +96,7 @@ func (c *muxedConnection) Close() error { case p := <-c.queue: pool.Put(p.buf) default: - return nil + return } } } diff --git a/test-plans/go.mod b/test-plans/go.mod index 75153bfefe..ef6c40a72b 100644 --- a/test-plans/go.mod +++ b/test-plans/go.mod @@ -66,14 +66,14 @@ require ( github.com/pion/datachannel v1.5.10 // indirect github.com/pion/dtls/v2 v2.2.12 // indirect github.com/pion/dtls/v3 v3.0.4 // indirect - github.com/pion/ice/v4 v4.0.6 // indirect + github.com/pion/ice/v4 v4.0.8 // indirect github.com/pion/interceptor v0.1.37 // indirect github.com/pion/logging v0.2.3 // indirect github.com/pion/mdns/v2 v2.0.7 // indirect github.com/pion/randutil v0.1.0 // indirect github.com/pion/rtcp v1.2.15 // indirect github.com/pion/rtp v1.8.11 // indirect - github.com/pion/sctp v1.8.36 // indirect + github.com/pion/sctp v1.8.37 // indirect github.com/pion/sdp/v3 v3.0.10 // indirect github.com/pion/srtp/v3 v3.0.4 // indirect github.com/pion/stun v0.6.1 // indirect @@ -83,12 +83,12 @@ require ( github.com/pion/turn/v4 v4.0.0 // indirect github.com/pion/webrtc/v4 v4.0.10 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/prometheus/client_golang v1.21.0 // indirect + github.com/prometheus/client_golang v1.21.1 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.62.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/quic-go/qpack v0.5.1 // indirect - github.com/quic-go/quic-go v0.50.0 // indirect + github.com/quic-go/quic-go v0.50.1 // indirect github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 // indirect github.com/raulk/go-watchdog v1.3.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect diff --git a/test-plans/go.sum b/test-plans/go.sum index 43ce4641ea..4654e55106 100644 --- a/test-plans/go.sum +++ b/test-plans/go.sum @@ -222,8 +222,8 @@ github.com/pion/dtls/v2 v2.2.12 h1:KP7H5/c1EiVAAKUmXyCzPiQe5+bCJrpOeKg/L05dunk= github.com/pion/dtls/v2 v2.2.12/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= github.com/pion/dtls/v3 v3.0.4 h1:44CZekewMzfrn9pmGrj5BNnTMDCFwr+6sLH+cCuLM7U= github.com/pion/dtls/v3 v3.0.4/go.mod h1:R373CsjxWqNPf6MEkfdy3aSe9niZvL/JaKlGeFphtMg= -github.com/pion/ice/v4 v4.0.6 h1:jmM9HwI9lfetQV/39uD0nY4y++XZNPhvzIPCb8EwxUM= -github.com/pion/ice/v4 v4.0.6/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= +github.com/pion/ice/v4 v4.0.8 h1:ajNx0idNG+S+v9Phu4LSn2cs8JEfTsA1/tEjkkAVpFY= +github.com/pion/ice/v4 v4.0.8/go.mod h1:y3M18aPhIxLlcO/4dn9X8LzLLSma84cx6emMSu14FGw= github.com/pion/interceptor v0.1.37 h1:aRA8Zpab/wE7/c0O3fh1PqY0AJI3fCSEM5lRWJVorwI= github.com/pion/interceptor v0.1.37/go.mod h1:JzxbJ4umVTlZAf+/utHzNesY8tmRkM2lVmkS82TTj8Y= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= @@ -237,8 +237,8 @@ github.com/pion/rtcp v1.2.15 h1:LZQi2JbdipLOj4eBjK4wlVoQWfrZbh3Q6eHtWtJBZBo= github.com/pion/rtcp v1.2.15/go.mod h1:jlGuAjHMEXwMUHK78RgX0UmEJFV4zUKOFHR7OP+D3D0= github.com/pion/rtp v1.8.11 h1:17xjnY5WO5hgO6SD3/NTIUPvSFw/PbLsIJyz1r1yNIk= github.com/pion/rtp v1.8.11/go.mod h1:8uMBJj32Pa1wwx8Fuv/AsFhn8jsgw+3rUC2PfoBZ8p4= -github.com/pion/sctp v1.8.36 h1:owNudmnz1xmhfYje5L/FCav3V9wpPRePHle3Zi+P+M0= -github.com/pion/sctp v1.8.36/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= +github.com/pion/sctp v1.8.37 h1:ZDmGPtRPX9mKCiVXtMbTWybFw3z/hVKAZgU81wcOrqs= +github.com/pion/sctp v1.8.37/go.mod h1:cNiLdchXra8fHQwmIoqw0MbLLMs+f7uQ+dGMG2gWebE= github.com/pion/sdp/v3 v3.0.10 h1:6MChLE/1xYB+CjumMw+gZ9ufp2DPApuVSnDT8t5MIgA= github.com/pion/sdp/v3 v3.0.10/go.mod h1:88GMahN5xnScv1hIMTqLdu/cOcUkj6a9ytbncwMCq2E= github.com/pion/srtp/v3 v3.0.4 h1:2Z6vDVxzrX3UHEgrUyIGM4rRouoC7v+NiF1IHtp9B5M= @@ -263,8 +263,8 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= -github.com/prometheus/client_golang v1.21.0 h1:DIsaGmiaBkSangBgMtWdNfxbMNdku5IK6iNhrEqWvdA= -github.com/prometheus/client_golang v1.21.0/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= @@ -276,8 +276,8 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= -github.com/quic-go/quic-go v0.50.0 h1:3H/ld1pa3CYhkcc20TPIyG1bNsdhn9qZBGN3b9/UyUo= -github.com/quic-go/quic-go v0.50.0/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= +github.com/quic-go/quic-go v0.50.1 h1:unsgjFIUqW8a2oopkY7YNONpV1gYND6Nt9hnt1PN94Q= +github.com/quic-go/quic-go v0.50.1/go.mod h1:Vim6OmUvlYdwBhXP9ZVrtGmCMWa3wEqhq3NgYrI8b4E= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66 h1:4WFk6u3sOT6pLa1kQ50ZVdm8BQFgJNA117cepZxtLIg= github.com/quic-go/webtransport-go v0.8.1-0.20241018022711-4ac2c9250e66/go.mod h1:Vp72IJajgeOL6ddqrAhmp7IM9zbTcgkQxD/YdxrVwMw= github.com/raulk/go-watchdog v1.3.0 h1:oUmdlHxdkXRJlwfG0O9omj8ukerm8MEQavSiDTEtBsk= diff --git a/version.json b/version.json index d17f08210f..d452b313fe 100644 --- a/version.json +++ b/version.json @@ -1,3 +1,3 @@ { - "version": "v0.41.0" + "version": "v0.41.1" }