diff --git a/proxy.go b/proxy.go index 5547ea8d8f8c..5b72dfbe6221 100644 --- a/proxy.go +++ b/proxy.go @@ -21,6 +21,7 @@ package grpc import ( "bufio" "context" + "encoding/base64" "errors" "fmt" "io" @@ -30,6 +31,8 @@ import ( "net/url" ) +const proxyAuthHeaderKey = "Proxy-Authorization" + var ( // errDisabled indicates that proxy is disabled for the address. errDisabled = errors.New("proxy is disabled for the address") @@ -37,7 +40,7 @@ var ( httpProxyFromEnvironment = http.ProxyFromEnvironment ) -func mapAddress(ctx context.Context, address string) (string, error) { +func mapAddress(ctx context.Context, address string) (*url.URL, error) { req := &http.Request{ URL: &url.URL{ Scheme: "https", @@ -46,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) { } url, err := httpProxyFromEnvironment(req) if err != nil { - return "", err + return nil, err } if url == nil { - return "", errDisabled + return nil, errDisabled } - return url.Host, nil + return url, nil } // To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader. @@ -68,7 +71,12 @@ func (c *bufConn) Read(b []byte) (int, error) { return c.r.Read(b) } -func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) { +func basicAuth(username, password string) string { + auth := username + ":" + password + return base64.StdEncoding.EncodeToString([]byte(auth)) +} + +func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) { defer func() { if err != nil { conn.Close() @@ -77,9 +85,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ req := (&http.Request{ Method: http.MethodConnect, - URL: &url.URL{Host: addr}, + URL: &url.URL{Host: backendAddr}, Header: map[string][]string{"User-Agent": {grpcUA}}, }) + if t := proxyURL.User; t != nil { + u := t.Username() + p, _ := t.Password() + req.Header.Add(proxyAuthHeaderKey, basicAuth(u, p)) + } if err := sendHTTPRequest(ctx, req, conn); err != nil { return nil, fmt.Errorf("failed to write the HTTP request: %v", err) @@ -107,22 +120,24 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ // provided dialer, does HTTP CONNECT handshake and returns the connection. func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) { return func(ctx context.Context, addr string) (conn net.Conn, err error) { - var skipHandshake bool - newAddr, err := mapAddress(ctx, addr) + var newAddr string + proxyURL, err := mapAddress(ctx, addr) if err != nil { if err != errDisabled { return nil, err } - skipHandshake = true newAddr = addr + } else { + newAddr = proxyURL.Host } conn, err = dialer(ctx, newAddr) if err != nil { return } - if !skipHandshake { - conn, err = doHTTPConnectHandshake(ctx, conn, addr) + if proxyURL != nil { + // proxy is disabled if proxyURL is nil. + conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL) } return } diff --git a/proxy_test.go b/proxy_test.go index 7183ba342554..9efba48580aa 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -22,6 +22,8 @@ package grpc import ( "bufio" + "encoding/base64" + "fmt" "io" "net" "net/http" @@ -53,6 +55,8 @@ type proxyServer struct { lis net.Listener in net.Conn out net.Conn + + requestCheck func(*http.Request) error } func (p *proxyServer) run() { @@ -67,11 +71,11 @@ func (p *proxyServer) run() { p.t.Errorf("failed to read CONNECT req: %v", err) return } - if req.Method != http.MethodConnect || req.UserAgent() != grpcUA { + if err := p.requestCheck(req); err != nil { resp := http.Response{StatusCode: http.StatusMethodNotAllowed} resp.Write(p.in) p.in.Close() - p.t.Errorf("get wrong CONNECT req: %+v", req) + p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err) return } @@ -97,13 +101,17 @@ func (p *proxyServer) stop() { } } -func TestHTTPConnect(t *testing.T) { +func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) { defer leakcheck.Check(t) plis, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("failed to listen: %v", err) } - p := &proxyServer{t: t, lis: plis} + p := &proxyServer{ + t: t, + lis: plis, + requestCheck: proxyReqCheck, + } go p.run() defer p.stop() @@ -128,7 +136,7 @@ func TestHTTPConnect(t *testing.T) { // Overwrite the function in the test and restore them in defer. hpfe := func(req *http.Request) (*url.URL, error) { - return &url.URL{Host: plis.Addr().String()}, nil + return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil } defer overwrite(hpfe)() @@ -157,6 +165,51 @@ func TestHTTPConnect(t *testing.T) { } } +func TestHTTPConnect(t *testing.T) { + testHTTPConnect(t, + func(in *url.URL) *url.URL { + return in + }, + func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + if req.UserAgent() != grpcUA { + return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA) + } + return nil + }, + ) +} + +func TestHTTPConnectBasicAuth(t *testing.T) { + const ( + user = "notAUser" + password = "notAPassword" + ) + testHTTPConnect(t, + func(in *url.URL) *url.URL { + in.User = url.UserPassword(user, password) + return in + }, + func(req *http.Request) error { + if req.Method != http.MethodConnect { + return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect) + } + if req.UserAgent() != grpcUA { + return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA) + } + wantProxyAuthStr := base64.StdEncoding.EncodeToString([]byte(user + ":" + password)) + if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr { + gotDecoded, _ := base64.StdEncoding.DecodeString(got) + wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr) + return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded) + } + return nil + }, + ) +} + func TestMapAddressEnv(t *testing.T) { defer leakcheck.Check(t) // Overwrite the function in the test and restore them in defer. @@ -176,7 +229,7 @@ func TestMapAddressEnv(t *testing.T) { if err != nil { t.Error(err) } - if got != envProxyAddr { + if got.Host != envProxyAddr { t.Errorf("want %v, got %v", envProxyAddr, got) } }