Skip to content

Commit 8ddae58

Browse files
committed
proxy: support basic authentication
1 parent 1b89e78 commit 8ddae58

File tree

2 files changed

+85
-17
lines changed

2 files changed

+85
-17
lines changed

Diff for: proxy.go

+26-11
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package grpc
2121
import (
2222
"bufio"
2323
"context"
24+
"encoding/base64"
2425
"errors"
2526
"fmt"
2627
"io"
@@ -30,14 +31,16 @@ import (
3031
"net/url"
3132
)
3233

34+
const proxyAuthHeaderKey = "Proxy-Authorization"
35+
3336
var (
3437
// errDisabled indicates that proxy is disabled for the address.
3538
errDisabled = errors.New("proxy is disabled for the address")
3639
// The following variable will be overwritten in the tests.
3740
httpProxyFromEnvironment = http.ProxyFromEnvironment
3841
)
3942

40-
func mapAddress(ctx context.Context, address string) (string, error) {
43+
func mapAddress(ctx context.Context, address string) (*url.URL, error) {
4144
req := &http.Request{
4245
URL: &url.URL{
4346
Scheme: "https",
@@ -46,12 +49,12 @@ func mapAddress(ctx context.Context, address string) (string, error) {
4649
}
4750
url, err := httpProxyFromEnvironment(req)
4851
if err != nil {
49-
return "", err
52+
return nil, err
5053
}
5154
if url == nil {
52-
return "", errDisabled
55+
return nil, errDisabled
5356
}
54-
return url.Host, nil
57+
return url, nil
5558
}
5659

5760
// To read a response from a net.Conn, http.ReadResponse() takes a bufio.Reader.
@@ -68,7 +71,12 @@ func (c *bufConn) Read(b []byte) (int, error) {
6871
return c.r.Read(b)
6972
}
7073

71-
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_ net.Conn, err error) {
74+
func basicAuth(username, password string) string {
75+
auth := username + ":" + password
76+
return base64.StdEncoding.EncodeToString([]byte(auth))
77+
}
78+
79+
func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr string, proxyURL *url.URL) (_ net.Conn, err error) {
7280
defer func() {
7381
if err != nil {
7482
conn.Close()
@@ -77,9 +85,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
7785

7886
req := (&http.Request{
7987
Method: http.MethodConnect,
80-
URL: &url.URL{Host: addr},
88+
URL: &url.URL{Host: backendAddr},
8189
Header: map[string][]string{"User-Agent": {grpcUA}},
8290
})
91+
if t := proxyURL.User; t != nil {
92+
u := t.Username()
93+
p, _ := t.Password()
94+
req.Header.Add(proxyAuthHeaderKey, basicAuth(u, p))
95+
}
8396

8497
if err := sendHTTPRequest(ctx, req, conn); err != nil {
8598
return nil, fmt.Errorf("failed to write the HTTP request: %v", err)
@@ -107,22 +120,24 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, addr string) (_
107120
// provided dialer, does HTTP CONNECT handshake and returns the connection.
108121
func newProxyDialer(dialer func(context.Context, string) (net.Conn, error)) func(context.Context, string) (net.Conn, error) {
109122
return func(ctx context.Context, addr string) (conn net.Conn, err error) {
110-
var skipHandshake bool
111-
newAddr, err := mapAddress(ctx, addr)
123+
var newAddr string
124+
proxyURL, err := mapAddress(ctx, addr)
112125
if err != nil {
113126
if err != errDisabled {
114127
return nil, err
115128
}
116-
skipHandshake = true
117129
newAddr = addr
130+
} else {
131+
newAddr = proxyURL.Host
118132
}
119133

120134
conn, err = dialer(ctx, newAddr)
121135
if err != nil {
122136
return
123137
}
124-
if !skipHandshake {
125-
conn, err = doHTTPConnectHandshake(ctx, conn, addr)
138+
if proxyURL != nil {
139+
// proxy is disabled if proxyURL is nil.
140+
conn, err = doHTTPConnectHandshake(ctx, conn, addr, proxyURL)
126141
}
127142
return
128143
}

Diff for: proxy_test.go

+59-6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ package grpc
2222

2323
import (
2424
"bufio"
25+
"encoding/base64"
26+
"fmt"
2527
"io"
2628
"net"
2729
"net/http"
@@ -53,6 +55,8 @@ type proxyServer struct {
5355
lis net.Listener
5456
in net.Conn
5557
out net.Conn
58+
59+
requestCheck func(*http.Request) error
5660
}
5761

5862
func (p *proxyServer) run() {
@@ -67,11 +71,11 @@ func (p *proxyServer) run() {
6771
p.t.Errorf("failed to read CONNECT req: %v", err)
6872
return
6973
}
70-
if req.Method != http.MethodConnect || req.UserAgent() != grpcUA {
74+
if err := p.requestCheck(req); err != nil {
7175
resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
7276
resp.Write(p.in)
7377
p.in.Close()
74-
p.t.Errorf("get wrong CONNECT req: %+v", req)
78+
p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
7579
return
7680
}
7781

@@ -97,13 +101,17 @@ func (p *proxyServer) stop() {
97101
}
98102
}
99103

100-
func TestHTTPConnect(t *testing.T) {
104+
func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
101105
defer leakcheck.Check(t)
102106
plis, err := net.Listen("tcp", "localhost:0")
103107
if err != nil {
104108
t.Fatalf("failed to listen: %v", err)
105109
}
106-
p := &proxyServer{t: t, lis: plis}
110+
p := &proxyServer{
111+
t: t,
112+
lis: plis,
113+
requestCheck: proxyReqCheck,
114+
}
107115
go p.run()
108116
defer p.stop()
109117

@@ -128,7 +136,7 @@ func TestHTTPConnect(t *testing.T) {
128136

129137
// Overwrite the function in the test and restore them in defer.
130138
hpfe := func(req *http.Request) (*url.URL, error) {
131-
return &url.URL{Host: plis.Addr().String()}, nil
139+
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
132140
}
133141
defer overwrite(hpfe)()
134142

@@ -157,6 +165,51 @@ func TestHTTPConnect(t *testing.T) {
157165
}
158166
}
159167

168+
func TestHTTPConnect(t *testing.T) {
169+
testHTTPConnect(t,
170+
func(in *url.URL) *url.URL {
171+
return in
172+
},
173+
func(req *http.Request) error {
174+
if req.Method != http.MethodConnect {
175+
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
176+
}
177+
if req.UserAgent() != grpcUA {
178+
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
179+
}
180+
return nil
181+
},
182+
)
183+
}
184+
185+
func TestHTTPConnectBasicAuth(t *testing.T) {
186+
const (
187+
user = "notAUser"
188+
password = "notAPassword"
189+
)
190+
testHTTPConnect(t,
191+
func(in *url.URL) *url.URL {
192+
in.User = url.UserPassword(user, password)
193+
return in
194+
},
195+
func(req *http.Request) error {
196+
if req.Method != http.MethodConnect {
197+
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
198+
}
199+
if req.UserAgent() != grpcUA {
200+
return fmt.Errorf("unexpect user agent %q, want %q", req.UserAgent(), grpcUA)
201+
}
202+
wantProxyAuthStr := base64.StdEncoding.EncodeToString([]byte(user + ":" + password))
203+
if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
204+
gotDecoded, _ := base64.StdEncoding.DecodeString(got)
205+
wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
206+
return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
207+
}
208+
return nil
209+
},
210+
)
211+
}
212+
160213
func TestMapAddressEnv(t *testing.T) {
161214
defer leakcheck.Check(t)
162215
// Overwrite the function in the test and restore them in defer.
@@ -176,7 +229,7 @@ func TestMapAddressEnv(t *testing.T) {
176229
if err != nil {
177230
t.Error(err)
178231
}
179-
if got != envProxyAddr {
232+
if got.Host != envProxyAddr {
180233
t.Errorf("want %v, got %v", envProxyAddr, got)
181234
}
182235
}

0 commit comments

Comments
 (0)