diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index a15e5f96b7..c72fa5fd88 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -1,6 +1,7 @@ package proxy import ( + "bytes" "crypto/tls" "fmt" "net/url" @@ -121,6 +122,13 @@ func Do(c *fiber.Ctx, addr string) error { req := c.Request() res := c.Response() req.SetRequestURI(addr) + // NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https. + // issue reference: + // https://github.com/gofiber/fiber/issues/1762 + if scheme := getScheme(utils.UnsafeBytes(addr)); len(scheme) > 0 { + req.URI().SetSchemeBytes(scheme) + } + req.Header.Del(fiber.HeaderConnection) if err := client.Do(req, res); err != nil { return err @@ -128,3 +136,11 @@ func Do(c *fiber.Ctx, addr string) error { res.Header.Del(fiber.HeaderConnection) return nil } + +func getScheme(uri []byte) []byte { + i := bytes.IndexByte(uri, '/') + if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' { + return nil + } + return uri[:i-1] +} diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index d7665c2dbd..f03f54d3f8 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -122,6 +122,40 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { utils.AssertEqual(t, "tls balancer", body) } +// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http +func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { + t.Parallel() + + _, targetAddr := createProxyTestServer(func(c *fiber.Ctx) error { + return c.SendString("hello from target") + }, t) + + proxyServerTLSConf, _, err := tlstest.GetTLSConfigs() + utils.AssertEqual(t, nil, err) + + proxyServerLn, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + utils.AssertEqual(t, nil, err) + + proxyServerLn = tls.NewListener(proxyServerLn, proxyServerTLSConf) + + app := fiber.New(fiber.Config{DisableStartupMessage: true}) + + proxyAddr := proxyServerLn.Addr().String() + + app.Use(Forward("http://" + targetAddr)) + + go func() { utils.AssertEqual(t, nil, app.Listener(proxyServerLn)) }() + + code, body, errs := fiber.Get("https://" + proxyAddr). + InsecureSkipVerify(). + Timeout(5 * time.Second). + String() + + utils.AssertEqual(t, 0, len(errs)) + utils.AssertEqual(t, fiber.StatusOK, code) + utils.AssertEqual(t, "hello from target", body) +} + // go test -run Test_Proxy_Forward func Test_Proxy_Forward(t *testing.T) { t.Parallel()