diff --git a/libp2p_test.go b/libp2p_test.go index 7ca7850854..a5803add4d 100644 --- a/libp2p_test.go +++ b/libp2p_test.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "errors" "fmt" + "io" "net" "net/netip" "regexp" @@ -587,3 +588,70 @@ func TestWebRTCReuseAddrWithQUIC(t *testing.T) { require.Contains(t, h1.Addrs()[0].String(), "quic-v1") }) } + +func TestUseCorrectTransportForDialOut(t *testing.T) { + listAddrOrder := [][]string{ + {"/ip4/127.0.0.1/udp/0/quic-v1", "/ip4/127.0.0.1/udp/0/quic-v1/webtransport"}, + {"/ip4/127.0.0.1/udp/0/quic-v1/webtransport", "/ip4/127.0.0.1/udp/0/quic-v1"}, + {"/ip4/0.0.0.0/udp/0/quic-v1", "/ip4/0.0.0.0/udp/0/quic-v1/webtransport"}, + {"/ip4/0.0.0.0/udp/0/quic-v1/webtransport", "/ip4/0.0.0.0/udp/0/quic-v1"}, + } + for _, order := range listAddrOrder { + h1, err := New(ListenAddrStrings(order...), Transport(quic.NewTransport), Transport(webtransport.New)) + require.NoError(t, err) + t.Cleanup(func() { + h1.Close() + }) + + go func() { + h1.SetStreamHandler("/echo-port", func(s network.Stream) { + m := s.Conn().RemoteMultiaddr() + v, err := m.ValueForProtocol(ma.P_UDP) + if err != nil { + s.Reset() + return + } + s.Write([]byte(v)) + s.Close() + }) + }() + + for _, addr := range h1.Addrs() { + t.Run("order "+strings.Join(order, ",")+" Dial to "+addr.String(), func(t *testing.T) { + h2, err := New(ListenAddrStrings( + "/ip4/0.0.0.0/udp/0/quic-v1", + "/ip4/0.0.0.0/udp/0/quic-v1/webtransport", + ), Transport(quic.NewTransport), Transport(webtransport.New)) + require.NoError(t, err) + defer h2.Close() + t.Log("H2 Addrs", h2.Addrs()) + var myExpectedDialOutAddr ma.Multiaddr + addrIsWT, _ := webtransport.IsWebtransportMultiaddr(addr) + isLocal := func(a ma.Multiaddr) bool { + return strings.Contains(a.String(), "127.0.0.1") + } + addrIsLocal := isLocal(addr) + for _, a := range h2.Addrs() { + aIsWT, _ := webtransport.IsWebtransportMultiaddr(a) + if addrIsWT == aIsWT && isLocal(a) == addrIsLocal { + myExpectedDialOutAddr = a + break + } + } + + err = h2.Connect(context.Background(), peer.AddrInfo{ID: h1.ID(), Addrs: []ma.Multiaddr{addr}}) + require.NoError(t, err) + + s, err := h2.NewStream(context.Background(), h1.ID(), "/echo-port") + require.NoError(t, err) + + port, err := io.ReadAll(s) + require.NoError(t, err) + + myExpectedPort, err := myExpectedDialOutAddr.ValueForProtocol(ma.P_UDP) + require.NoError(t, err) + require.Equal(t, myExpectedPort, string(port)) + }) + } + } +} diff --git a/p2p/transport/quic/transport.go b/p2p/transport/quic/transport.go index 04b5e4d6fe..f0862d22e9 100644 --- a/p2p/transport/quic/transport.go +++ b/p2p/transport/quic/transport.go @@ -136,6 +136,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee } tlsConf, keyCh := t.identity.ConfigForPeer(p) + ctx = quicreuse.WithAssociation(ctx, t) pconn, err := t.connManager.DialQUIC(ctx, raddr, tlsConf, t.allowWindowIncrease) if err != nil { return nil, err @@ -196,7 +197,7 @@ func (t *transport) holePunch(ctx context.Context, raddr ma.Multiaddr, p peer.ID if err != nil { return nil, err } - tr, err := t.connManager.TransportForDial(network, addr) + tr, err := t.connManager.TransportWithAssociationForDial(t, network, addr) if err != nil { return nil, err } @@ -313,7 +314,7 @@ func (t *transport) Listen(addr ma.Multiaddr) (tpt.Listener, error) { return nil, fmt.Errorf("can't listen on quic version %v, underlying listener doesn't support it", version) } } else { - ln, err := t.connManager.ListenQUIC(addr, &tlsConf, t.allowWindowIncrease) + ln, err := t.connManager.ListenQUICAndAssociate(t, addr, &tlsConf, t.allowWindowIncrease) if err != nil { return nil, err } diff --git a/p2p/transport/quicreuse/connmgr.go b/p2p/transport/quicreuse/connmgr.go index 01e1bcda5e..8e4c61b0bc 100644 --- a/p2p/transport/quicreuse/connmgr.go +++ b/p2p/transport/quicreuse/connmgr.go @@ -102,6 +102,11 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) { } func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) { + return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease) +} + +// ListenQUICAndAssociate returns a QUIC listener and associates the underlying transport with the given association. +func (c *ConnManager) ListenQUICAndAssociate(association any, addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) { netw, host, err := manet.DialArgs(addr) if err != nil { return nil, err @@ -117,7 +122,7 @@ func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWi key := laddr.String() entry, ok := c.quicListeners[key] if !ok { - tr, err := c.transportForListen(netw, laddr) + tr, err := c.transportForListen(association, netw, laddr) if err != nil { return nil, err } @@ -176,13 +181,18 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set") } -func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) { +func (c *ConnManager) transportForListen(association any, network string, laddr *net.UDPAddr) (refCountedQuicTransport, error) { if c.enableReuseport { reuse, err := c.getReuse(network) if err != nil { return nil, err } - return reuse.TransportForListen(network, laddr) + tr, err := reuse.TransportForListen(network, laddr) + if err != nil { + return nil, err + } + tr.associate(association) + return tr, nil } conn, err := net.ListenUDP(network, laddr) @@ -199,6 +209,14 @@ func (c *ConnManager) transportForListen(network string, laddr *net.UDPAddr) (re }, nil } +type associationKey struct{} + +// WithAssociation returns a new context with the given association. Used in +// DialQUIC to prefer a transport that has the given association. +func WithAssociation(ctx context.Context, association any) context.Context { + return context.WithValue(ctx, associationKey{}, association) +} + func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (quic.Connection, error) { naddr, v, err := FromQuicMultiaddr(raddr) if err != nil { @@ -219,7 +237,12 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf return nil, errors.New("unknown QUIC version") } - tr, err := c.TransportForDial(netw, naddr) + var tr refCountedQuicTransport + if association := ctx.Value(associationKey{}); association != nil { + tr, err = c.TransportWithAssociationForDial(association, netw, naddr) + } else { + tr, err = c.TransportForDial(netw, naddr) + } if err != nil { return nil, err } @@ -232,12 +255,17 @@ func (c *ConnManager) DialQUIC(ctx context.Context, raddr ma.Multiaddr, tlsConf } func (c *ConnManager) TransportForDial(network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) { + return c.TransportWithAssociationForDial(nil, network, raddr) +} + +// TransportWithAssociationForDial returns a QUIC transport for dialing, preferring a transport with the given association. +func (c *ConnManager) TransportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (refCountedQuicTransport, error) { if c.enableReuseport { reuse, err := c.getReuse(network) if err != nil { return nil, err } - return reuse.TransportForDial(network, raddr) + return reuse.transportWithAssociationForDial(association, network, raddr) } var laddr *net.UDPAddr diff --git a/p2p/transport/quicreuse/connmgr_test.go b/p2p/transport/quicreuse/connmgr_test.go index f3576a3905..8e7da2cd7f 100644 --- a/p2p/transport/quicreuse/connmgr_test.go +++ b/p2p/transport/quicreuse/connmgr_test.go @@ -61,8 +61,6 @@ func testListenOnSameProto(t *testing.T, enableReuseport bool) { const alpn = "proto" - var tlsConf tls.Config - tlsConf.NextProtos = []string{alpn} ln1, err := cm.ListenQUIC(ma.StringCast("/ip4/127.0.0.1/udp/0/quic-v1"), &tls.Config{NextProtos: []string{alpn}}, nil) require.NoError(t, err) defer ln1.Close() @@ -96,7 +94,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) { _, err = cm.ListenQUIC(raddr, &tls.Config{NextProtos: []string{"proto"}}, nil) require.NoError(t, err) - quicTr, err := cm.transportForListen(netw, naddr) + quicTr, err := cm.transportForListen(nil, netw, naddr) require.NoError(t, err) defer quicTr.Close() if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok { diff --git a/p2p/transport/quicreuse/reuse.go b/p2p/transport/quicreuse/reuse.go index dc2b33b853..c6fc611331 100644 --- a/p2p/transport/quicreuse/reuse.go +++ b/p2p/transport/quicreuse/reuse.go @@ -69,6 +69,36 @@ type refcountedTransport struct { mutex sync.Mutex refCount int unusedSince time.Time + + assocations map[any]struct{} +} + +// associate an arbitrary value with this transport. +// This lets us "tag" the refcountedTransport when listening so we can use it +// later for dialing. Necessary for holepunching and learning about our own +// observed listening address. +func (c *refcountedTransport) associate(a any) { + if a == nil { + return + } + c.mutex.Lock() + defer c.mutex.Unlock() + if c.assocations == nil { + c.assocations = make(map[any]struct{}) + } + c.assocations[a] = struct{}{} +} + +// hasAssociation returns true if the transport has the given association. +// If it is a nil association, it will always return true. +func (c *refcountedTransport) hasAssociation(a any) bool { + if a == nil { + return true + } + c.mutex.Lock() + defer c.mutex.Unlock() + _, ok := c.assocations[a] + return ok } func (c *refcountedTransport) IncreaseCount() { @@ -204,7 +234,7 @@ func (r *reuse) gc() { } } -func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcountedTransport, error) { +func (r *reuse) transportWithAssociationForDial(association any, network string, raddr *net.UDPAddr) (*refcountedTransport, error) { var ip *net.IP // Only bother looking up the source address if we actually _have_ non 0.0.0.0 listeners. @@ -224,7 +254,7 @@ func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcounte r.mutex.Lock() defer r.mutex.Unlock() - tr, err := r.transportForDialLocked(network, ip) + tr, err := r.transportForDialLocked(association, network, ip) if err != nil { return nil, err } @@ -232,21 +262,26 @@ func (r *reuse) TransportForDial(network string, raddr *net.UDPAddr) (*refcounte return tr, nil } -func (r *reuse) transportForDialLocked(network string, source *net.IP) (*refcountedTransport, error) { +func (r *reuse) transportForDialLocked(association any, network string, source *net.IP) (*refcountedTransport, error) { if source != nil { // We already have at least one suitable transport... if trs, ok := r.unicast[source.String()]; ok { - // ... we don't care which port we're dialing from. Just use the first. + // Prefer a transport that has the given association. We want to + // reuse the transport the association used for listening. for _, tr := range trs { - return tr, nil + if tr.hasAssociation(association) { + return tr, nil + } } } } // Use a transport listening on 0.0.0.0 (or ::). - // Again, we don't care about the port number. + // Again, prefer a transport that has the given association. for _, tr := range r.globalListeners { - return tr, nil + if tr.hasAssociation(association) { + return tr, nil + } } // Use a transport we've previously dialed from diff --git a/p2p/transport/quicreuse/reuse_test.go b/p2p/transport/quicreuse/reuse_test.go index b463094720..1da32b5e24 100644 --- a/p2p/transport/quicreuse/reuse_test.go +++ b/p2p/transport/quicreuse/reuse_test.go @@ -91,7 +91,7 @@ func TestReuseCreateNewGlobalConnOnDial(t *testing.T) { addr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") require.NoError(t, err) - conn, err := reuse.TransportForDial("udp4", addr) + conn, err := reuse.transportWithAssociationForDial(nil, "udp4", addr) require.NoError(t, err) require.Equal(t, 1, conn.GetCount()) laddr := conn.LocalAddr().(*net.UDPAddr) @@ -111,7 +111,7 @@ func TestReuseConnectionWhenDialing(t *testing.T) { // dial raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") require.NoError(t, err) - conn, err := reuse.TransportForDial("udp4", raddr) + conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) require.Equal(t, 2, conn.GetCount()) } @@ -122,7 +122,7 @@ func TestReuseConnectionWhenListening(t *testing.T) { raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") require.NoError(t, err) - tr, err := reuse.TransportForDial("udp4", raddr) + tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) laddr := &net.UDPAddr{IP: net.IPv4zero, Port: tr.LocalAddr().(*net.UDPAddr).Port} lconn, err := reuse.TransportForListen("udp4", laddr) @@ -138,7 +138,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) { // dial any address raddr, err := net.ResolveUDPAddr("udp4", "1.1.1.1:1234") require.NoError(t, err) - rTr, err := reuse.TransportForDial("udp4", raddr) + rTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) // open a listener @@ -149,7 +149,7 @@ func TestReuseConnectionWhenDialBeforeListen(t *testing.T) { // new dials should go via the listener connection raddr, err = net.ResolveUDPAddr("udp4", "1.1.1.1:1235") require.NoError(t, err) - tr, err := reuse.TransportForDial("udp4", raddr) + tr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) require.Equal(t, lTr, tr) require.Equal(t, 2, tr.GetCount()) @@ -183,7 +183,7 @@ func TestReuseListenOnSpecificInterface(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, lconn.GetCount()) // dial - conn, err := reuse.TransportForDial("udp4", raddr) + conn, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) require.Equal(t, 1, conn.GetCount()) } @@ -214,7 +214,7 @@ func TestReuseGarbageCollect(t *testing.T) { raddr, err := net.ResolveUDPAddr("udp4", "1.2.3.4:1234") require.NoError(t, err) - dTr, err := reuse.TransportForDial("udp4", raddr) + dTr, err := reuse.transportWithAssociationForDial(nil, "udp4", raddr) require.NoError(t, err) require.Equal(t, 1, dTr.GetCount()) diff --git a/p2p/transport/webtransport/transport.go b/p2p/transport/webtransport/transport.go index ef8551d60f..acb40f0b89 100644 --- a/p2p/transport/webtransport/transport.go +++ b/p2p/transport/webtransport/transport.go @@ -207,6 +207,7 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string return verifyRawCerts(rawCerts, certHashes) } } + ctx = quicreuse.WithAssociation(ctx, t) conn, err := t.connManager.DialQUIC(ctx, addr, tlsConf, t.allowWindowIncrease) if err != nil { return nil, nil, err @@ -331,7 +332,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { } tlsConf.NextProtos = append(tlsConf.NextProtos, http3.NextProtoH3) - ln, err := t.connManager.ListenQUIC(laddr, tlsConf, t.allowWindowIncrease) + ln, err := t.connManager.ListenQUICAndAssociate(t, laddr, tlsConf, t.allowWindowIncrease) if err != nil { return nil, err }