Skip to content

Commit

Permalink
http request decide protocol based on ALPN
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaokangwang committed May 30, 2023
1 parent 0e519b9 commit b7e8554
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 11 deletions.
22 changes: 13 additions & 9 deletions transport/internet/request/roundtripper/httprt/httprt.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ import (
"context"
"encoding/base64"
"io"
gonet "net"

This comment has been minimized.

"net/http"

"github.com/v2fly/v2ray-core/v5/common"
"github.com/v2fly/v2ray-core/v5/transport/internet/transportcommon"

"github.com/v2fly/v2ray-core/v5/common"
"github.com/v2fly/v2ray-core/v5/common/net"
"github.com/v2fly/v2ray-core/v5/transport/internet/request"
)
Expand All @@ -25,20 +27,22 @@ type httpTripperClient struct {
assembly request.TransportClientAssembly
}

type unimplementedBackDrop struct {
}

func (u unimplementedBackDrop) RoundTrip(r *http.Request) (*http.Response, error) {
return nil, newError("unimplemented")
}

func (h *httpTripperClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
h.assembly = assembly
}

func (h *httpTripperClient) RoundTrip(ctx context.Context, req request.Request, opts ...request.RoundTripperOption) (resp request.Response, err error) {
if h.httpRTT == nil {
h.httpRTT = &http.Transport{
DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
return h.assembly.AutoImplDialer().Dial(ctx)
},
DialTLSContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
return h.assembly.AutoImplDialer().Dial(ctx)
},
}
h.httpRTT = transportcommon.NewALPNAwareHTTPRoundTripper(ctx, func(ctx context.Context, addr string) (gonet.Conn, error) {
return h.assembly.AutoImplDialer().Dial(ctx)
}, unimplementedBackDrop{})
}

connectionTagStr := base64.RawURLEncoding.EncodeToString(req.ConnectionTag)
Expand Down
3 changes: 2 additions & 1 deletion transport/internet/request/stereotype/meek/meek.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ func meekDial(ctx context.Context, dest net.Destination, streamSettings *interne
}
httprtSetting := &httprt.ClientConfig{Http: &httprt.HTTPConfig{
UrlPrefix: meekSetting.Url,
}}
},
}
request := &assembly.Config{
Assembler: serial.ToTypedMessage(simpleAssembler),
Roundtripper: serial.ToTypedMessage(httprtSetting),
Expand Down
5 changes: 5 additions & 0 deletions transport/internet/security/connprop.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package security

type ConnectionApplicationProtocol interface {
GetConnectionApplicationProtocol() (string, error)
}
7 changes: 7 additions & 0 deletions transport/internet/tls/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ type Conn struct {
*tls.Conn
}

func (c *Conn) GetConnectionApplicationProtocol() (string, error) {
if err := c.Handshake(); err != nil {
return "", err
}
return c.ConnectionState().NegotiatedProtocol, nil
}

func (c *Conn) WriteMultiBuffer(mb buf.MultiBuffer) error {
mb = buf.Compact(mb)
mb, err := buf.WriteMultiBuffer(c, mb)
Expand Down
13 changes: 12 additions & 1 deletion transport/internet/tls/utls/utls.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,18 @@ func (e Engine) Client(conn net.Conn, opts ...security.Option) (security.Conn, e
if err != nil {
return nil, newError("unable to finish utls handshake").Base(err)
}
return utlsClientConn, nil
return uTLSClientConnection{utlsClientConn}, nil
}

type uTLSClientConnection struct {
*utls.UConn
}

func (u uTLSClientConnection) GetConnectionApplicationProtocol() (string, error) {
if err := u.Handshake(); err != nil {
return "", err
}
return u.ConnectionState().NegotiatedProtocol, nil
}

func uTLSConfigFromTLSConfig(config *systls.Config) (*utls.Config, error) { // nolint: unparam
Expand Down
215 changes: 215 additions & 0 deletions transport/internet/transportcommon/httpDialer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
package transportcommon

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"time"

"github.com/v2fly/v2ray-core/v5/transport/internet/security"
"golang.org/x/net/http2"
)

type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)

// NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with
// an alternative version of TLS implementation.
func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
backdropTransport http.RoundTripper) http.RoundTripper {
rtImpl := &alpnAwareHTTPRoundTripperImpl{
connectWithH1: map[string]bool{},
backdropTransport: backdropTransport,
pendingConn: map[pendingConnKey]*unclaimedConnection{},
dialer: dialer,
ctx: ctx,
}
rtImpl.init()
return rtImpl
}

type alpnAwareHTTPRoundTripperImpl struct {
accessConnectWithH1 sync.Mutex
connectWithH1 map[string]bool

httpsH1Transport http.RoundTripper
httpsH2Transport http.RoundTripper
backdropTransport http.RoundTripper

accessDialingConnection sync.Mutex
pendingConn map[pendingConnKey]*unclaimedConnection

ctx context.Context
dialer DialerFunc
}

type pendingConnKey struct {
isH2 bool
dest string
}

var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
var errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
var errExpired = errors.New("connection have expired")

func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
return r.backdropTransport.RoundTrip(req)
}
for retryCount := 0; retryCount < 5; retryCount++ {
effectivePort := req.URL.Port()
if effectivePort == "" {
effectivePort = "443"
}
if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
resp, err := r.httpsH1Transport.RoundTrip(req)
if errors.Is(err, errEAGAIN) {
continue
}
return resp, err
}
resp, err := r.httpsH2Transport.RoundTrip(req)
if errors.Is(err, errEAGAIN) {
continue
}
return resp, err
}
return nil, errEAGAINTooMany
}

func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
r.accessConnectWithH1.Lock()
defer r.accessConnectWithH1.Unlock()
if value, set := r.connectWithH1[domainName]; set {
return value
}
return false
}

func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
r.accessConnectWithH1.Lock()
defer r.accessConnectWithH1.Unlock()
r.connectWithH1[domainName] = true
}

func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
r.accessConnectWithH1.Lock()
defer r.accessConnectWithH1.Unlock()
r.connectWithH1[domainName] = false
}

func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
return pendingConnKey{isH2: alpnIsH2, dest: dest}
}

func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
connId := getPendingConnectionID(addr, alpnIsH2)
r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute)
}
func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
connId := getPendingConnectionID(addr, alpnIsH2)
if conn, ok := r.pendingConn[connId]; ok {
delete(r.pendingConn, connId)
if claimedConnection, err := conn.claimConnection(); err == nil {
return claimedConnection
}
}
return nil
}
func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
r.accessDialingConnection.Lock()
defer r.accessDialingConnection.Unlock()

if r.getShouldConnectWithH1(addr) == expectedH2 {
return nil, errEAGAIN
}

//Get a cached connection if possible to reduce preflight connection closed without sending data
if gconn := r.getConn(addr, expectedH2); gconn != nil {
return gconn, nil
}

conn, err := r.dialTLS(ctx, addr)
if err != nil {
return nil, err
}

protocol := ""
if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol()
if err != nil {
return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
}
protocol = connectionALPN
}

protocolIsH2 := protocol == http2.NextProtoTLS

if protocolIsH2 == expectedH2 {
return conn, err
}

r.putConn(addr, protocolIsH2, conn)

if protocolIsH2 {
r.clearShouldConnectWithH1(addr)
} else {
r.setShouldConnectWithH1(addr)
}

return nil, errEAGAIN
}

func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
return r.dialer(r.ctx, addr)
}

func (r *alpnAwareHTTPRoundTripperImpl) init() {
r.httpsH2Transport = &http2.Transport{
DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
},
}
r.httpsH1Transport = &http.Transport{
DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
},
}
}

func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
c := &unclaimedConnection{
Conn: conn,
}
time.AfterFunc(expireTime, c.tick)
return c
}

type unclaimedConnection struct {
net.Conn
claimed bool
access sync.Mutex
}

func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
c.access.Lock()
defer c.access.Unlock()
if !c.claimed {
c.claimed = true
return c.Conn, nil
}
return nil, errExpired
}

func (c *unclaimedConnection) tick() {
c.access.Lock()
defer c.access.Unlock()
if !c.claimed {
c.claimed = true
c.Conn.Close()
c.Conn = nil
}
}

0 comments on commit b7e8554

Please sign in to comment.