Skip to content

Commit

Permalink
proxy: support basic authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
menghanl committed Nov 5, 2018
1 parent 1b89e78 commit 51810ac
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 17 deletions.
37 changes: 26 additions & 11 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package grpc
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
Expand All @@ -30,14 +31,16 @@ import (
"net/url"
)

const proxyAuthHeaderKey = "Proxy-Authorization"

var (
// errDisabled indicates that proxy is disabled for the address.
errDisabled = errors.New("proxy is disabled for the address")
// The following variable will be overwritten in the tests.
httpProxyFromEnvironment = http.ProxyFromEnvironment
)

func mapAddress(ctx context.Context, address string) (string, error) {
func mapAddress(ctx context.Context, address string) (*url.URL, error) {
req := &http.Request{
URL: &url.URL{
Scheme: "https",
Expand All @@ -46,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) {
}
url, err := httpProxyFromEnvironment(req)
if err != nil {
return "", err
return nil, err
}
if url == nil {
return "", errDisabled
return nil, errDisabled
}
return url.Host, nil
return url, nil
}

// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
Expand All @@ -68,7 +71,12 @@ func (c *bufConn) Read(b []byte) (int, error) {
return c.r.Read(b)
}

func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
func basicAuth(username, password string) string {
auth := username + ":" + password
return base64.StdEncoding.EncodeToString([]byte(auth))
}

func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
defer func() {
if err != nil {
conn.Close()
Expand All @@ -77,9 +85,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_

req := (&http.Request{
Method: http.MethodConnect,
URL: &url.URL{Host: addr},
URL: &url.URL{Host: backendAddr},
Header: map[string][]string{"User-Agent": {grpcUA}},
})
if t := proxyURL.User; t != nil {
u := t.Username()
p, _ := t.Password()
req.Header.Add(proxyAuthHeaderKey, basicAuth(u, p))
}

if err := sendHTTPRequest(ctx, req, conn); err != nil {
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
Expand Down Expand Up @@ -107,22 +120,24 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
// provided dialer, does HTTP CONNECT handshake and returns the connection.
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
var skipHandshake bool
newAddr, err := mapAddress(ctx, addr)
var newAddr string
proxyURL, err := mapAddress(ctx, addr)
if err != nil {
if err != errDisabled {
return nil, err
}
skipHandshake = true
newAddr = addr
} else {
newAddr = proxyURL.Host
}

conn, err = dialer(ctx, newAddr)
if err != nil {
return
}
if !skipHandshake {
conn, err = doHTTPConnectHandshake(ctx, conn, addr)
if proxyURL != nil {
// proxy is disabled if proxyURL is nil.
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
}
return
}
Expand Down
65 changes: 59 additions & 6 deletions proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ package grpc

import (
"bufio"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -53,6 +55,8 @@ type proxyServer struct {
lis net.Listener
in net.Conn
out net.Conn

requestCheck func(*http.Request) error
}

func (p *proxyServer) run() {
Expand All @@ -67,11 +71,11 @@ func (p *proxyServer) run() {
p.t.Errorf("failed to read CONNECT req: %v", err)
return
}
if req.Method != http.MethodConnect || req.UserAgent() != grpcUA {
if err := p.requestCheck(req); err != nil {
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
resp.Write(p.in)
p.in.Close()
p.t.Errorf("get wrong CONNECT req: %+v", req)
p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
return
}

Expand All @@ -97,13 +101,17 @@ func (p *proxyServer) stop() {
}
}

func TestHTTPConnect(t *testing.T) {
func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
defer leakcheck.Check(t)
plis, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to listen: %v", err)
}
p := &proxyServer{t: t, lis: plis}
p := &proxyServer{
t: t,
lis: plis,
requestCheck: proxyReqCheck,
}
go p.run()
defer p.stop()

Expand All @@ -128,7 +136,7 @@ func TestHTTPConnect(t *testing.T) {

// Overwrite the function in the test and restore them in defer.
hpfe := func(req *http.Request) (*url.URL, error) {
return &url.URL{Host: plis.Addr().String()}, nil
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
}
defer overwrite(hpfe)()

Expand Down Expand Up @@ -157,6 +165,51 @@ func TestHTTPConnect(t *testing.T) {
}
}

func TestHTTPConnect(t *testing.T) {
testHTTPConnect(t,
func(in *url.URL) *url.URL {
return in
},
func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
if req.UserAgent() != grpcUA {
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
}
return nil
},
)
}

func TestHTTPConnectBasicAuth(t *testing.T) {
const (
user = "notAUser"
password = "notAPassword"
)
testHTTPConnect(t,
func(in *url.URL) *url.URL {
in.User = url.UserPassword(user, password)
return in
},
func(req *http.Request) error {
if req.Method != http.MethodConnect {
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
}
if req.UserAgent() != grpcUA {
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
}
wantProxyAuthStr := base64.StdEncoding.EncodeToString([]byte(user + ":" + password))
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
}
return nil
},
)
}

func TestMapAddressEnv(t *testing.T) {
defer leakcheck.Check(t)
// Overwrite the function in the test and restore them in defer.
Expand All @@ -176,7 +229,7 @@ func TestMapAddressEnv(t *testing.T) {
if err != nil {
t.Error(err)
}
if got != envProxyAddr {
if got.Host != envProxyAddr {
t.Errorf("want %v, got %v", envProxyAddr, got)
}
}

0 comments on commit 51810ac

Please sign in to comment.