From b7e8554ee3924148e0c6a1991f1c70d00f3dd336 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Tue, 30 May 2023 00:42:50 +0100 Subject: [PATCH] http request decide protocol based on ALPN --- .../request/roundtripper/httprt/httprt.go | 22 +- .../internet/request/stereotype/meek/meek.go | 3 +- transport/internet/security/connprop.go | 5 + transport/internet/tls/tls.go | 7 + transport/internet/tls/utls/utls.go | 13 +- .../internet/transportcommon/httpDialer.go | 215 ++++++++++++++++++ 6 files changed, 254 insertions(+), 11 deletions(-) create mode 100644 transport/internet/security/connprop.go create mode 100644 transport/internet/transportcommon/httpDialer.go diff --git a/transport/internet/request/roundtripper/httprt/httprt.go b/transport/internet/request/roundtripper/httprt/httprt.go index a4968befc9b..83dc76c8657 100644 --- a/transport/internet/request/roundtripper/httprt/httprt.go +++ b/transport/internet/request/roundtripper/httprt/httprt.go @@ -7,10 +7,12 @@ import ( "context" "encoding/base64" "io" + gonet "net" "net/http" - "github.com/v2fly/v2ray-core/v5/common" + "github.com/v2fly/v2ray-core/v5/transport/internet/transportcommon" + "github.com/v2fly/v2ray-core/v5/common" "github.com/v2fly/v2ray-core/v5/common/net" "github.com/v2fly/v2ray-core/v5/transport/internet/request" ) @@ -25,20 +27,22 @@ type httpTripperClient struct { assembly request.TransportClientAssembly } +type unimplementedBackDrop struct { +} + +func (u unimplementedBackDrop) RoundTrip(r *http.Request) (*http.Response, error) { + return nil, newError("unimplemented") +} + func (h *httpTripperClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) { h.assembly = assembly } func (h *httpTripperClient) RoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption) (resp request.Response, err error) { if h.httpRTT == nil { - h.httpRTT = &http.Transport{ - DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { - return h.assembly.AutoImplDialer().Dial(ctx) - }, - DialTLSContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { - return h.assembly.AutoImplDialer().Dial(ctx) - }, - } + h.httpRTT = transportcommon.NewALPNAwareHTTPRoundTripper(ctx, func(ctx context.Context, addr string) (gonet.Conn, error) { + return h.assembly.AutoImplDialer().Dial(ctx) + }, unimplementedBackDrop{}) } connectionTagStr := base64.RawURLEncoding.EncodeToString(req.ConnectionTag) diff --git a/transport/internet/request/stereotype/meek/meek.go b/transport/internet/request/stereotype/meek/meek.go index db232e5908c..69b6fe4cf13 100644 --- a/transport/internet/request/stereotype/meek/meek.go +++ b/transport/internet/request/stereotype/meek/meek.go @@ -38,7 +38,8 @@ func meekDial(ctx context.Context, dest net.Destination, streamSettings *interne } httprtSetting := &httprt.ClientConfig{Http: &httprt.HTTPConfig{ UrlPrefix: meekSetting.Url, - }} + }, + } request := &assembly.Config{ Assembler: serial.ToTypedMessage(simpleAssembler), Roundtripper: serial.ToTypedMessage(httprtSetting), diff --git a/transport/internet/security/connprop.go b/transport/internet/security/connprop.go new file mode 100644 index 00000000000..13c962c9826 --- /dev/null +++ b/transport/internet/security/connprop.go @@ -0,0 +1,5 @@ +package security + +type ConnectionApplicationProtocol interface { + GetConnectionApplicationProtocol() (string, error) +} diff --git a/transport/internet/tls/tls.go b/transport/internet/tls/tls.go index 671cabbf948..389315316da 100644 --- a/transport/internet/tls/tls.go +++ b/transport/internet/tls/tls.go @@ -17,6 +17,13 @@ type Conn struct { *tls.Conn } +func (c *Conn) GetConnectionApplicationProtocol() (string, error) { + if err := c.Handshake(); err != nil { + return "", err + } + return c.ConnectionState().NegotiatedProtocol, nil +} + func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error { mb = buf.Compact(mb) mb, err := buf.WriteMultiBuffer(c, mb) diff --git a/transport/internet/tls/utls/utls.go b/transport/internet/tls/utls/utls.go index 29349fbdac0..be4b42e0607 100644 --- a/transport/internet/tls/utls/utls.go +++ b/transport/internet/tls/utls/utls.go @@ -90,7 +90,18 @@ func (e Engine) Client(conn net.Conn, opts ...security.Option) (security.Conn, e if err != nil { return nil, newError("unable to finish utls handshake").Base(err) } - return utlsClientConn, nil + return uTLSClientConnection{utlsClientConn}, nil +} + +type uTLSClientConnection struct { + *utls.UConn +} + +func (u uTLSClientConnection) GetConnectionApplicationProtocol() (string, error) { + if err := u.Handshake(); err != nil { + return "", err + } + return u.ConnectionState().NegotiatedProtocol, nil } func uTLSConfigFromTLSConfig(config *systls.Config) (*utls.Config, error) { // nolint: unparam diff --git a/transport/internet/transportcommon/httpDialer.go b/transport/internet/transportcommon/httpDialer.go new file mode 100644 index 00000000000..17dd161e7b1 --- /dev/null +++ b/transport/internet/transportcommon/httpDialer.go @@ -0,0 +1,215 @@ +package transportcommon + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "sync" + "time" + + "github.com/v2fly/v2ray-core/v5/transport/internet/security" + "golang.org/x/net/http2" +) + +type DialerFunc func(ctx context.Context, addr string) (net.Conn, error) + +// NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with +// an alternative version of TLS implementation. +func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc, + backdropTransport http.RoundTripper) http.RoundTripper { + rtImpl := &alpnAwareHTTPRoundTripperImpl{ + connectWithH1: map[string]bool{}, + backdropTransport: backdropTransport, + pendingConn: map[pendingConnKey]*unclaimedConnection{}, + dialer: dialer, + ctx: ctx, + } + rtImpl.init() + return rtImpl +} + +type alpnAwareHTTPRoundTripperImpl struct { + accessConnectWithH1 sync.Mutex + connectWithH1 map[string]bool + + httpsH1Transport http.RoundTripper + httpsH2Transport http.RoundTripper + backdropTransport http.RoundTripper + + accessDialingConnection sync.Mutex + pendingConn map[pendingConnKey]*unclaimedConnection + + ctx context.Context + dialer DialerFunc +} + +type pendingConnKey struct { + isH2 bool + dest string +} + +var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN") +var errEAGAINTooMany = errors.New("incorrect ALPN negotiated") +var errExpired = errors.New("connection have expired") + +func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) { + if req.URL.Scheme != "https" { + return r.backdropTransport.RoundTrip(req) + } + for retryCount := 0; retryCount < 5; retryCount++ { + effectivePort := req.URL.Port() + if effectivePort == "" { + effectivePort = "443" + } + if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) { + resp, err := r.httpsH1Transport.RoundTrip(req) + if errors.Is(err, errEAGAIN) { + continue + } + return resp, err + } + resp, err := r.httpsH2Transport.RoundTrip(req) + if errors.Is(err, errEAGAIN) { + continue + } + return resp, err + } + return nil, errEAGAINTooMany +} + +func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool { + r.accessConnectWithH1.Lock() + defer r.accessConnectWithH1.Unlock() + if value, set := r.connectWithH1[domainName]; set { + return value + } + return false +} + +func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) { + r.accessConnectWithH1.Lock() + defer r.accessConnectWithH1.Unlock() + r.connectWithH1[domainName] = true +} + +func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) { + r.accessConnectWithH1.Lock() + defer r.accessConnectWithH1.Unlock() + r.connectWithH1[domainName] = false +} + +func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey { + return pendingConnKey{isH2: alpnIsH2, dest: dest} +} + +func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) { + connId := getPendingConnectionID(addr, alpnIsH2) + r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute) +} +func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn { + connId := getPendingConnectionID(addr, alpnIsH2) + if conn, ok := r.pendingConn[connId]; ok { + delete(r.pendingConn, connId) + if claimedConnection, err := conn.claimConnection(); err == nil { + return claimedConnection + } + } + return nil +} +func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) { + r.accessDialingConnection.Lock() + defer r.accessDialingConnection.Unlock() + + if r.getShouldConnectWithH1(addr) == expectedH2 { + return nil, errEAGAIN + } + + //Get a cached connection if possible to reduce preflight connection closed without sending data + if gconn := r.getConn(addr, expectedH2); gconn != nil { + return gconn, nil + } + + conn, err := r.dialTLS(ctx, addr) + if err != nil { + return nil, err + } + + protocol := "" + if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok { + connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol() + if err != nil { + return nil, newError("failed to get connection ALPN").Base(err).AtWarning() + } + protocol = connectionALPN + } + + protocolIsH2 := protocol == http2.NextProtoTLS + + if protocolIsH2 == expectedH2 { + return conn, err + } + + r.putConn(addr, protocolIsH2, conn) + + if protocolIsH2 { + r.clearShouldConnectWithH1(addr) + } else { + r.setShouldConnectWithH1(addr) + } + + return nil, errEAGAIN +} + +func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) { + return r.dialer(r.ctx, addr) +} + +func (r *alpnAwareHTTPRoundTripperImpl) init() { + r.httpsH2Transport = &http2.Transport{ + DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { + return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true) + }, + } + r.httpsH1Transport = &http.Transport{ + DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false) + }, + } +} + +func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection { + c := &unclaimedConnection{ + Conn: conn, + } + time.AfterFunc(expireTime, c.tick) + return c +} + +type unclaimedConnection struct { + net.Conn + claimed bool + access sync.Mutex +} + +func (c *unclaimedConnection) claimConnection() (net.Conn, error) { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + return c.Conn, nil + } + return nil, errExpired +} + +func (c *unclaimedConnection) tick() { + c.access.Lock() + defer c.access.Unlock() + if !c.claimed { + c.claimed = true + c.Conn.Close() + c.Conn = nil + } +}