From 281e2f00463d7c7947cc37344ddab9d2dff3409d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Efe=20=C3=87etin?= Date: Thu, 8 Sep 2022 08:57:05 +0300 Subject: [PATCH] :sparkles: v3 (feature): merge Listen methods & ListenConfig (#1930) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * :sparkles: v3: new Start method for app * :sparkles: v3: new Start method for app * :sparkles: v3: new Start method for app * :sparkles: v3: new Start method for app * :sparkles: v3: new Start method for app * :sparkles: v3: new Start method for app * fix tests * improve graceful shutdown * update * Start -> Listen * rename test funcs. * Add Test_Listen_Graceful_Shutdown test. * add OnShutdownSuccess * fix tests * fix tests * split listen & listener * typo * Add retry logic to tests * Add retry logic to tests * Add retry logic to tests * Add retry logic to tests Co-authored-by: René Werner --- .github/workflows/test.yml | 6 +- app.go | 26 +-- app_test.go | 31 +-- client_test.go | 169 ++++++++++---- error.go | 13 +- hooks_test.go | 8 +- listen.go | 331 +++++++++++++++++----------- listen_test.go | 387 +++++++++++++++++++++++++-------- middleware/pprof/pprof_test.go | 8 +- middleware/proxy/proxy_test.go | 37 +++- prefork.go | 20 +- prefork_test.go | 14 +- 12 files changed, 714 insertions(+), 336 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ffd8b92709..1ffb54108b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -35,4 +35,8 @@ jobs: - name: Fetch Repository uses: actions/checkout@v3 - name: Run Test - run: go test ./... -v -race + uses: nick-fields/retry@v2 + with: + max_attempts: 3 + timeout_minutes: 15 + command: go test ./... -v -race diff --git a/app.go b/app.go index 7740d3aba6..861c321b68 100644 --- a/app.go +++ b/app.go @@ -124,11 +124,6 @@ type App struct { // Config is a struct holding the server settings. type Config struct { - // When set to true, this will spawn multiple Go processes listening on the same port. - // - // Default: false - Prefork bool `json:"prefork"` - // Enables the "Server: value" HTTP header. // // Default: "" @@ -270,11 +265,6 @@ type Config struct { // Default: false DisableHeaderNormalizing bool `json:"disable_header_normalizing"` - // When set to true, it will not print out the «Fiber» ASCII art and listening address. - // - // Default: false - DisableStartupMessage bool `json:"disable_startup_message"` - // This function allows to setup app name for the app // // Default: nil @@ -332,12 +322,6 @@ type Config struct { // Default: xml.Marshal XMLEncoder utils.XMLMarshal `json:"-"` - // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only) - // WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chose. - // - // Default: NetworkTCP4 - Network string - // If you find yourself behind some sort of proxy, like a load balancer, // then certain header information may be sent to you using special X-Forwarded-* headers or the Forwarded header. // For example, the Host HTTP header is usually used to return the requested host. @@ -374,10 +358,6 @@ type Config struct { // Default: false EnableIPValidation bool `json:"enable_ip_validation"` - // If set to true, will print all routes with their method, path and handler. - // Default: false - EnablePrintRoutes bool `json:"enable_print_routes"` - // You can define custom color scheme. They'll be used for startup message, route list and some middlewares. // // Optional. Default: DefaultColors @@ -533,9 +513,6 @@ func New(config ...Config) *App { if app.config.XMLEncoder == nil { app.config.XMLEncoder = xml.Marshal } - if app.config.Network == "" { - app.config.Network = NetworkTCP4 - } app.config.trustedProxiesMap = make(map[string]struct{}, len(app.config.TrustedProxies)) for _, ipAddress := range app.config.TrustedProxies { @@ -831,7 +808,8 @@ func (app *App) HandlersCount() uint32 { // Shutdown does not close keepalive connections so its recommended to set ReadTimeout to something else than 0. func (app *App) Shutdown() error { if app.hooks != nil { - defer app.hooks.executeOnShutdownHooks() + // TODO: check should be defered? + app.hooks.executeOnShutdownHooks() } app.mutex.Lock() diff --git a/app_test.go b/app_test.go index 31aa9c2789..1e14597346 100644 --- a/app_test.go +++ b/app_test.go @@ -599,16 +599,14 @@ func Test_App_New(t *testing.T) { func Test_App_Config(t *testing.T) { app := New(Config{ - DisableStartupMessage: true, + StrictRouting: true, }) - require.True(t, app.Config().DisableStartupMessage) + require.True(t, app.Config().StrictRouting) } func Test_App_Shutdown(t *testing.T) { t.Run("success", func(t *testing.T) { - app := New(Config{ - DisableStartupMessage: true, - }) + app := New() require.True(t, app.Shutdown() == nil) }) @@ -1098,7 +1096,6 @@ func Test_App_Deep_Group(t *testing.T) { // go test -run Test_App_Next_Method func Test_App_Next_Method(t *testing.T) { app := New() - app.config.DisableStartupMessage = true app.Use(func(c Ctx) error { require.Equal(t, MethodGet, c.Method()) @@ -1140,7 +1137,6 @@ func Test_NewError(t *testing.T) { // go test -run Test_Test_Timeout func Test_Test_Timeout(t *testing.T) { app := New() - app.config.DisableStartupMessage = true app.Get("/", testEmptyHandler) @@ -1166,7 +1162,6 @@ func (errorReader) Read([]byte) (int, error) { // go test -run Test_Test_DumpError func Test_Test_DumpError(t *testing.T) { app := New() - app.config.DisableStartupMessage = true app.Get("/", testEmptyHandler) @@ -1236,10 +1231,9 @@ func Test_App_HandlersCount(t *testing.T) { // go test -run Test_App_ReadTimeout func Test_App_ReadTimeout(t *testing.T) { app := New(Config{ - ReadTimeout: time.Nanosecond, - IdleTimeout: time.Minute, - DisableStartupMessage: true, - DisableKeepalive: true, + ReadTimeout: time.Nanosecond, + IdleTimeout: time.Minute, + DisableKeepalive: true, }) app.Get("/read-timeout", func(c Ctx) error { @@ -1266,14 +1260,12 @@ func Test_App_ReadTimeout(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.Listen(":4004")) + require.Nil(t, app.Listen(":4004", ListenConfig{DisableStartupMessage: true})) } // go test -run Test_App_BadRequest func Test_App_BadRequest(t *testing.T) { - app := New(Config{ - DisableStartupMessage: true, - }) + app := New() app.Get("/bad-request", func(c Ctx) error { return c.SendString("I should not be sent") @@ -1298,14 +1290,13 @@ func Test_App_BadRequest(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.Listen(":4005")) + require.Nil(t, app.Listen(":4005", ListenConfig{DisableStartupMessage: true})) } // go test -run Test_App_SmallReadBuffer func Test_App_SmallReadBuffer(t *testing.T) { app := New(Config{ - ReadBufferSize: 1, - DisableStartupMessage: true, + ReadBufferSize: 1, }) app.Get("/small-read-buffer", func(c Ctx) error { @@ -1322,7 +1313,7 @@ func Test_App_SmallReadBuffer(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.Listen(":4006")) + require.Nil(t, app.Listen(":4006", ListenConfig{DisableStartupMessage: true})) } func Test_App_Server(t *testing.T) { diff --git a/client_test.go b/client_test.go index 9b866dc027..e3d78dca6d 100644 --- a/client_test.go +++ b/client_test.go @@ -28,13 +28,17 @@ func Test_Client_Invalid_URL(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString(c.Host()) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") @@ -66,13 +70,17 @@ func Test_Client_Get(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString(c.Host()) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { a := Get("http://example.com") @@ -92,14 +100,17 @@ func Test_Client_Head(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Head("/", func(c Ctx) error { return c.SendStatus(StatusAccepted) }) - go func() { require.Nil(t, app.Listener(ln)) }() - + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { a := Head("http://example.com") @@ -118,14 +129,18 @@ func Test_Client_Post(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Post("/", func(c Ctx) error { return c.Status(StatusCreated). SendString(c.FormValue("foo")) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { args := AcquireArgs() @@ -152,13 +167,17 @@ func Test_Client_Put(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Put("/", func(c Ctx) error { return c.SendString(c.FormValue("foo")) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { args := AcquireArgs() @@ -185,13 +204,17 @@ func Test_Client_Patch(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Patch("/", func(c Ctx) error { return c.SendString(c.FormValue("foo")) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { args := AcquireArgs() @@ -218,14 +241,18 @@ func Test_Client_Delete(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Delete("/", func(c Ctx) error { return c.Status(StatusNoContent). SendString("deleted") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { args := AcquireArgs() @@ -249,13 +276,17 @@ func Test_Client_UserAgent(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.Send(c.Request().Header.UserAgent()) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() t.Run("default", func(t *testing.T) { for i := 0; i < 5; i++ { @@ -391,13 +422,17 @@ func Test_Client_Agent_Host(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString(c.Host()) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() a := Get("http://1.1.1.1:8080"). Host("example.com"). @@ -487,13 +522,17 @@ func Test_Client_Agent_Custom_Response(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString("custom") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { a := AcquireAgent() @@ -524,13 +563,17 @@ func Test_Client_Agent_Dest(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString("dest") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() t.Run("small dest", func(t *testing.T) { dest := []byte("de") @@ -592,9 +635,13 @@ func Test_Client_Agent_RetryIf(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() a := Post("http://example.com"). RetryIf(func(req *Request) bool { @@ -699,7 +746,7 @@ func Test_Client_Agent_MultipartForm(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Post("/", func(c Ctx) error { require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) @@ -711,7 +758,11 @@ func Test_Client_Agent_MultipartForm(t *testing.T) { return c.Send(c.Request().Body()) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() args := AcquireArgs() @@ -754,7 +805,7 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Post("/", func(c Ctx) error { require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) @@ -781,7 +832,11 @@ func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { return c.SendString("multipart form files") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() for i := 0; i < 5; i++ { ff := AcquireFormFile() @@ -885,14 +940,18 @@ func Test_Client_Agent_Timeout(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { time.Sleep(time.Millisecond * 200) return c.SendString("timeout") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() a := Get("http://example.com"). Timeout(time.Millisecond * 50) @@ -911,13 +970,17 @@ func Test_Client_Agent_Reuse(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString("reuse") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() a := Get("http://example.com"). Reuse() @@ -952,13 +1015,17 @@ func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { ln = tls.NewListener(ln, serverTLSConf) - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString("ignore tls") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() code, body, errs := Get("https://" + ln.Addr().String()). InsecureSkipVerify(). @@ -981,13 +1048,17 @@ func Test_Client_Agent_TLS(t *testing.T) { ln = tls.NewListener(ln, serverTLSConf) - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.SendString("tls") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() code, body, errs := Get("https://" + ln.Addr().String()). TLSConfig(clientTLSConf). @@ -1003,7 +1074,7 @@ func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { if c.Request().URI().QueryArgs().Has("foo") { @@ -1015,7 +1086,11 @@ func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { return c.SendString("redirect") }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() t.Run("success", func(t *testing.T) { a := Get("http://example.com?foo"). @@ -1049,7 +1124,7 @@ func Test_Client_Agent_Struct(t *testing.T) { ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", func(c Ctx) error { return c.JSON(data{true}) @@ -1059,7 +1134,11 @@ func Test_Client_Agent_Struct(t *testing.T) { return c.SendString(`{"success"`) }) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() t.Run("success", func(t *testing.T) { t.Parallel() @@ -1128,11 +1207,15 @@ func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), exce ln := fasthttputil.NewInmemoryListener() - app := New(Config{DisableStartupMessage: true}) + app := New() app.Get("/", handler) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, nil, app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + })) + }() c := 1 if len(count) > 0 { diff --git a/error.go b/error.go index 87a6af38c9..6b08acef15 100644 --- a/error.go +++ b/error.go @@ -2,19 +2,24 @@ package fiber import ( errors "encoding/json" - goErrors "errors" + stdErrors "errors" "github.com/gofiber/fiber/v3/internal/schema" ) +// Graceful shutdown errors +var ( + ErrGracefulTimeout = stdErrors.New("shutdown: graceful timeout has been reached, exiting") +) + // Range errors var ( - ErrRangeMalformed = goErrors.New("range: malformed range header string") - ErrRangeUnsatisfiable = goErrors.New("range: unsatisfiable range") + ErrRangeMalformed = stdErrors.New("range: malformed range header string") + ErrRangeUnsatisfiable = stdErrors.New("range: unsatisfiable range") ) // Binder errors -var ErrCustomBinderNotFound = goErrors.New("binder: custom binder not found, please be sure to enter the right name") +var ErrCustomBinderNotFound = stdErrors.New("binder: custom binder not found, please be sure to enter the right name") // gorilla/schema errors type ( diff --git a/hooks_test.go b/hooks_test.go index ca46993738..611dd93eb7 100644 --- a/hooks_test.go +++ b/hooks_test.go @@ -158,9 +158,7 @@ func Test_Hook_OnShutdown(t *testing.T) { func Test_Hook_OnListen(t *testing.T) { t.Parallel() - app := New(Config{ - DisableStartupMessage: true, - }) + app := New() buf := bytebufferpool.Get() defer bytebufferpool.Put(buf) @@ -176,8 +174,8 @@ func Test_Hook_OnListen(t *testing.T) { time.Sleep(1000 * time.Millisecond) require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.Listen(":9000")) + require.Nil(t, app.Listen(":9000", ListenConfig{DisableStartupMessage: true})) require.Equal(t, "ready", buf.String()) } @@ -198,5 +196,5 @@ func Test_Hook_OnHook(t *testing.T) { return nil }) - require.Nil(t, app.prefork(NetworkTCP4, ":3000", nil)) + require.Nil(t, app.prefork(":3000", nil, ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) } diff --git a/listen.go b/listen.go index 05f210cc60..d3ab216aae 100644 --- a/listen.go +++ b/listen.go @@ -5,10 +5,11 @@ package fiber import ( + "context" "crypto/tls" "crypto/x509" - "errors" "fmt" + "log" "net" "os" "path/filepath" @@ -23,96 +24,167 @@ import ( "github.com/mattn/go-isatty" ) -// Listener can be used to pass a custom listener. -func (app *App) Listener(ln net.Listener) error { - // prepare the server for the start - app.startupProcess() +// ListenConfig is a struct to customize startup of Fiber. +// +// TODO: Add timeout for graceful shutdown. +type ListenConfig struct { + // Known networks are "tcp", "tcp4" (IPv4-only), "tcp6" (IPv6-only) + // WARNING: When prefork is set to true, only "tcp4" and "tcp6" can be chose. + // + // Default: NetworkTCP4 + ListenerNetwork string `json:"listener_network"` + + // CertFile is a path of certficate file. + // If you want to use TLS, you have to enter this field. + // + // Default : "" + CertFile string `json:"cert_file"` + + // KeyFile is a path of certficate's private key. + // If you want to use TLS, you have to enter this field. + // + // Default : "" + CertKeyFile string `json:"cert_key_file"` + + // CertClientFile is a path of client certficate. + // If you want to use mTLS, you have to enter this field. + // + // Default : "" + CertClientFile string `json:"cert_client_file"` + + // GracefulContext is a field to shutdown Fiber by given context gracefully. + // + // Default: nil + GracefulContext context.Context `json:"graceful_context"` + + // TLSConfigFunc allows customizing tls.Config as you want. + // + // Default: nil + TLSConfigFunc func(tlsConfig *tls.Config) `json:"tls_config_func"` + + // ListenerFunc allows accessing and customizing net.Listener. + // + // Default: nil + ListenerAddrFunc func(addr net.Addr) `json:"listener_addr_func"` + + // BeforeServeFunc allows customizing and accessing fiber app before serving the app. + // + // Default: nil + BeforeServeFunc func(app *App) error `json:"before_serve_func"` + + // When set to true, it will not print out the «Fiber» ASCII art and listening address. + // + // Default: false + DisableStartupMessage bool `json:"disable_startup_message"` + + // When set to true, this will spawn multiple Go processes listening on the same port. + // + // Default: false + EnablePrefork bool `json:"enable_prefork"` + + // If set to true, will print all routes with their method, path and handler. + // + // Default: false + EnablePrintRoutes bool `json:"enable_print_routes"` + + // OnShutdownError allows to customize error behavior when to graceful shutdown server by given signal. + // + // Default: Print error with log.Fatalf() + OnShutdownError func(err error) + + // OnShutdownSuccess allows to customize success behavior when to graceful shutdown server by given signal. + // + // Default: nil + OnShutdownSuccess func() +} - // Print startup message - if !app.config.DisableStartupMessage { - app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "") +// listenConfigDefault is a function to set default values of ListenConfig. +func listenConfigDefault(config ...ListenConfig) ListenConfig { + if len(config) < 1 { + return ListenConfig{ + ListenerNetwork: NetworkTCP4, + OnShutdownError: func(err error) { + log.Fatalf("shutdown: %v", err) + }, + } } - // Print routes - if app.config.EnablePrintRoutes { - app.printRoutesMessage() + cfg := config[0] + if cfg.ListenerNetwork == "" { + cfg.ListenerNetwork = NetworkTCP4 } - // Prefork is not supported for custom listeners - if app.config.Prefork { - fmt.Println("[Warning] Prefork isn't supported for custom listeners.") + if cfg.OnShutdownError == nil { + cfg.OnShutdownError = func(err error) { + log.Fatalf("shutdown: %v", err) + } } - // Start listening - return app.server.Serve(ln) + return cfg } // Listen serves HTTP requests from the given addr. +// You should enter custom ListenConfig to customize startup. (TLS, mTLS, prefork...) // // app.Listen(":8080") // app.Listen("127.0.0.1:8080") -func (app *App) Listen(addr string) error { - // Start prefork - if app.config.Prefork { - return app.prefork(app.config.Network, addr, nil) - } - - // Setup listener - ln, err := net.Listen(app.config.Network, addr) - if err != nil { - return err - } +// app.Listen(":8080", ListenConfig{EnablePrefork: true}) +func (app *App) Listen(addr string, config ...ListenConfig) error { + cfg := listenConfigDefault(config...) + + // Configure TLS + var tlsConfig *tls.Config = nil + if cfg.CertFile != "" && cfg.CertKeyFile != "" { + cert, err := tls.LoadX509KeyPair(cfg.CertFile, cfg.CertKeyFile) + if err != nil { + return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", cfg.CertFile, cfg.CertKeyFile, err) + } - // prepare the server for the start - app.startupProcess() + tlsHandler := &TLSHandler{} + tlsConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + Certificates: []tls.Certificate{ + cert, + }, + GetCertificate: tlsHandler.GetClientInfo, + } - // Print startup message - if !app.config.DisableStartupMessage { - app.startupMessage(ln.Addr().String(), false, "") - } + if cfg.CertClientFile != "" { + clientCACert, err := os.ReadFile(filepath.Clean(cfg.CertClientFile)) + if err != nil { + return err + } - // Print routes - if app.config.EnablePrintRoutes { - app.printRoutesMessage() - } + clientCertPool := x509.NewCertPool() + clientCertPool.AppendCertsFromPEM(clientCACert) - // Start listening - return app.server.Serve(ln) -} + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = clientCertPool + } -// ListenTLS serves HTTPS requests from the given addr. -// certFile and keyFile are the paths to TLS certificate and key file: -// -// app.ListenTLS(":8080", "./cert.pem", "./cert.key") -func (app *App) ListenTLS(addr, certFile, keyFile string) error { - // Check for valid cert/key path - if len(certFile) == 0 || len(keyFile) == 0 { - return errors.New("tls: provide a valid cert or key path") + // Attach the tlsHandler to the config + app.SetTLSHandler(tlsHandler) } - // Set TLS config with handler - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) + if cfg.TLSConfigFunc != nil { + cfg.TLSConfigFunc(tlsConfig) } - tlsHandler := &TLSHandler{} - config := &tls.Config{ - MinVersion: tls.VersionTLS12, - Certificates: []tls.Certificate{ - cert, - }, - GetCertificate: tlsHandler.GetClientInfo, + // Graceful shutdown + if cfg.GracefulContext != nil { + ctx, cancel := context.WithCancel(cfg.GracefulContext) + defer cancel() + + go app.gracefulShutdown(ctx, cfg) } - // Prefork is supported - if app.config.Prefork { - return app.prefork(app.config.Network, addr, config) + // Start prefork + if cfg.EnablePrefork { + return app.prefork(addr, tlsConfig, cfg) } - // Setup listener - ln, err := net.Listen(app.config.Network, addr) - ln = tls.NewListener(ln, config) + // Configure Listener + ln, err := app.createListener(addr, tlsConfig, cfg) if err != nil { return err } @@ -120,89 +192,85 @@ func (app *App) ListenTLS(addr, certFile, keyFile string) error { // prepare the server for the start app.startupProcess() - // Print startup message - if !app.config.DisableStartupMessage { - app.startupMessage(ln.Addr().String(), true, "") - } + // Print startup message & routes + app.printMessages(cfg, ln) - // Print routes - if app.config.EnablePrintRoutes { - app.printRoutesMessage() + // Serve + if cfg.BeforeServeFunc != nil { + if err := cfg.BeforeServeFunc(app); err != nil { + return err + } } - // Attach the tlsHandler to the config - app.SetTLSHandler(tlsHandler) - - // Start listening return app.server.Serve(ln) } -// ListenMutualTLS serves HTTPS requests from the given addr. -// certFile, keyFile and clientCertFile are the paths to TLS certificate and key file: -// -// app.ListenMutualTLS(":8080", "./cert.pem", "./cert.key", "./client.pem") -func (app *App) ListenMutualTLS(addr, certFile, keyFile, clientCertFile string) error { - // Check for valid cert/key path - if len(certFile) == 0 || len(keyFile) == 0 { - return errors.New("tls: provide a valid cert or key path") - } +// Listener serves HTTP requests from the given listener. +// You should enter custom ListenConfig to customize startup. (prefork, startup message, graceful shutdown...) +func (app *App) Listener(ln net.Listener, config ...ListenConfig) error { + cfg := listenConfigDefault(config...) - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return fmt.Errorf("tls: cannot load TLS key pair from certFile=%q and keyFile=%q: %s", certFile, keyFile, err) + // Graceful shutdown + if cfg.GracefulContext != nil { + ctx, cancel := context.WithCancel(cfg.GracefulContext) + defer cancel() + + go app.gracefulShutdown(ctx, cfg) } - clientCACert, err := os.ReadFile(filepath.Clean(clientCertFile)) - if err != nil { - return err + // prepare the server for the start + app.startupProcess() + + // Print startup message & routes + app.printMessages(cfg, ln) + + // Serve + if cfg.BeforeServeFunc != nil { + if err := cfg.BeforeServeFunc(app); err != nil { + return err + } } - clientCertPool := x509.NewCertPool() - clientCertPool.AppendCertsFromPEM(clientCACert) - tlsHandler := &TLSHandler{} - config := &tls.Config{ - MinVersion: tls.VersionTLS12, - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: clientCertPool, - Certificates: []tls.Certificate{ - cert, - }, - GetCertificate: tlsHandler.GetClientInfo, + // Prefork is not supported for custom listeners + if cfg.EnablePrefork { + fmt.Println("[Warning] Prefork isn't supported for custom listeners.") } - // Prefork is supported - if app.config.Prefork { - return app.prefork(app.config.Network, addr, config) + return app.server.Serve(ln) +} + +// Create listener function. +func (app *App) createListener(addr string, tlsConfig *tls.Config, cfg ListenConfig) (net.Listener, error) { + var listener net.Listener + var err error + + if tlsConfig != nil { + listener, err = tls.Listen(cfg.ListenerNetwork, addr, tlsConfig) + } else { + listener, err = net.Listen(cfg.ListenerNetwork, addr) } - // Setup listener - ln, err := tls.Listen(app.config.Network, addr, config) - if err != nil { - return err + if cfg.ListenerAddrFunc != nil { + cfg.ListenerAddrFunc(listener.Addr()) } - // prepare the server for the start - app.startupProcess() + return listener, err +} +func (app *App) printMessages(cfg ListenConfig, ln net.Listener) { // Print startup message - if !app.config.DisableStartupMessage { - app.startupMessage(ln.Addr().String(), true, "") + if !cfg.DisableStartupMessage { + app.startupMessage(ln.Addr().String(), getTlsConfig(ln) != nil, "", cfg) } // Print routes - if app.config.EnablePrintRoutes { + if cfg.EnablePrintRoutes { app.printRoutesMessage() } - - // Attach the tlsHandler to the config - app.SetTLSHandler(tlsHandler) - - // Start listening - return app.server.Serve(ln) } // startupMessage prepares the startup message with the handler number, port, address and other information -func (app *App) startupMessage(addr string, tls bool, pids string) { +func (app *App) startupMessage(addr string, tls bool, pids string, cfg ListenConfig) { // ignore child processes if IsChild() { return @@ -259,7 +327,7 @@ func (app *App) startupMessage(addr string, tls bool, pids string) { host, port := parseAddr(addr) if host == "" { - if app.config.Network == NetworkTCP6 { + if cfg.ListenerNetwork == NetworkTCP6 { host = "[::1]" } else { host = "0.0.0.0" @@ -272,12 +340,12 @@ func (app *App) startupMessage(addr string, tls bool, pids string) { } isPrefork := "Disabled" - if app.config.Prefork { + if cfg.EnablePrefork { isPrefork = "Enabled" } procs := strconv.Itoa(runtime.GOMAXPROCS(0)) - if !app.config.Prefork { + if !cfg.EnablePrefork { procs = "1" } @@ -307,7 +375,7 @@ func (app *App) startupMessage(addr string, tls bool, pids string) { ) var childPidsLogo string - if app.config.Prefork { + if cfg.EnablePrefork { var childPidsTemplate string childPidsTemplate += "%s" childPidsTemplate += " ┌───────────────────────────────────────────────────┐\n%s" @@ -444,3 +512,16 @@ func (app *App) printRoutesMessage() { _ = w.Flush() } + +// shutdown goroutine +func (app *App) gracefulShutdown(ctx context.Context, cfg ListenConfig) { + <-ctx.Done() + + if err := app.Shutdown(); err != nil { + cfg.OnShutdownError(err) + } + + if success := cfg.OnShutdownSuccess; success != nil { + success() + } +} diff --git a/listen_test.go b/listen_test.go index a096d6a839..09ddc09630 100644 --- a/listen_test.go +++ b/listen_test.go @@ -1,15 +1,14 @@ -// ⚡️ Fiber is an Express inspired web framework written in Go with ☕️ -// 🤖 Github Repository: https://github.com/gofiber/fiber -// 📌 API Documentation: https://docs.gofiber.io - package fiber import ( "bytes" + "context" "crypto/tls" + "errors" "fmt" "io" "log" + "net" "os" "strings" "sync" @@ -20,9 +19,9 @@ import ( "github.com/valyala/fasthttp/fasthttputil" ) -// go test -run Test_App_Listen -func Test_App_Listen(t *testing.T) { - app := New(Config{DisableStartupMessage: true}) +// go test -run Test_Listen +func Test_Listen(t *testing.T) { + app := New() require.False(t, app.Listen(":99999") == nil) @@ -31,78 +30,181 @@ func Test_App_Listen(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.Listen(":4003")) + require.Nil(t, app.Listen(":4003", ListenConfig{DisableStartupMessage: true})) +} + +// go test -run Test_Listen_Graceful_Shutdown +func Test_Listen_Graceful_Shutdown(t *testing.T) { + var mu sync.Mutex + var shutdown bool + + app := New() + + app.Get("/", func(c Ctx) error { + return c.SendString(c.Hostname()) + }) + + ln := fasthttputil.NewInmemoryListener() + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + + err := app.Listener(ln, ListenConfig{ + DisableStartupMessage: true, + GracefulContext: ctx, + OnShutdownSuccess: func() { + mu.Lock() + shutdown = true + mu.Unlock() + }, + }) + + require.NoError(t, err) + }() + + testCases := []struct { + Time time.Duration + ExpectedBody string + ExpectedStatusCode int + ExceptedErrsLen int + }{ + {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0}, + {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1}, + } + + for _, tc := range testCases { + time.Sleep(tc.Time) + + a := Get("http://example.com") + a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() } + code, body, errs := a.String() + + require.Equal(t, tc.ExpectedStatusCode, code) + require.Equal(t, tc.ExpectedBody, body) + require.Equal(t, tc.ExceptedErrsLen, len(errs)) + } + + mu.Lock() + require.True(t, shutdown) + mu.Unlock() } -// go test -run Test_App_Listen_Prefork -func Test_App_Listen_Prefork(t *testing.T) { +// go test -run Test_Listen_Prefork +func Test_Listen_Prefork(t *testing.T) { testPreforkMaster = true - app := New(Config{DisableStartupMessage: true, Prefork: true}) + app := New() - require.Nil(t, app.Listen(":99999")) + require.Nil(t, app.Listen(":99999", ListenConfig{DisableStartupMessage: true, EnablePrefork: true})) } -// go test -run Test_App_ListenTLS -func Test_App_ListenTLS(t *testing.T) { +// go test -run Test_Listen_TLS +func Test_Listen_TLS(t *testing.T) { app := New() // invalid port - require.False(t, app.ListenTLS(":99999", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") == nil) - // missing perm/cert file - require.False(t, app.ListenTLS(":0", "", "./.github/testdata/ssl.key") == nil) + require.False(t, app.Listen(":99999", ListenConfig{ + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + }) == nil) go func() { time.Sleep(1000 * time.Millisecond) require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.ListenTLS(":0", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")) + require.Nil(t, app.Listen(":0", ListenConfig{ + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + })) + } -// go test -run Test_App_ListenTLS_Prefork -func Test_App_ListenTLS_Prefork(t *testing.T) { +// go test -run Test_Listen_TLS_Prefork +func Test_Listen_TLS_Prefork(t *testing.T) { testPreforkMaster = true - app := New(Config{DisableStartupMessage: true, Prefork: true}) + app := New() // invalid key file content - require.False(t, app.ListenTLS(":0", "./.github/testdata/ssl.pem", "./.github/testdata/template.tmpl") == nil) + require.False(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/template.tmpl", + }) == nil) + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) + }() + + require.Nil(t, app.Listen(":99999", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + })) - require.Nil(t, app.ListenTLS(":99999", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key")) } -// go test -run Test_App_ListenMutualTLS -func Test_App_ListenMutualTLS(t *testing.T) { +// go test -run Test_Listen_MutualTLS +func Test_Listen_MutualTLS(t *testing.T) { app := New() // invalid port - require.False(t, app.ListenMutualTLS(":99999", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key", "./.github/testdata/ca-chain.cert.pem") == nil) - // missing perm/cert file - require.False(t, app.ListenMutualTLS(":0", "", "./.github/testdata/ssl.key", "") == nil) + require.False(t, app.Listen(":99999", ListenConfig{ + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + CertClientFile: "./.github/testdata/ca-chain.cert.pem", + }) == nil) go func() { time.Sleep(1000 * time.Millisecond) require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.ListenMutualTLS(":0", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key", "./.github/testdata/ca-chain.cert.pem")) + require.Nil(t, app.Listen(":0", ListenConfig{ + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + CertClientFile: "./.github/testdata/ca-chain.cert.pem", + })) + } -// go test -run Test_App_ListenMutualTLS_Prefork -func Test_App_ListenMutualTLS_Prefork(t *testing.T) { +// go test -run Test_Listen_MutualTLS_Prefork +func Test_Listen_MutualTLS_Prefork(t *testing.T) { testPreforkMaster = true - app := New(Config{DisableStartupMessage: true, Prefork: true}) + app := New() // invalid key file content - require.False(t, app.ListenMutualTLS(":0", "./.github/testdata/ssl.pem", "./.github/testdata/template.html", "") == nil) + require.False(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/template.html", + CertClientFile: "./.github/testdata/ca-chain.cert.pem", + }) == nil) + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) + }() + + require.Nil(t, app.Listen(":99999", ListenConfig{ + DisableStartupMessage: true, + EnablePrefork: true, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + CertClientFile: "./.github/testdata/ca-chain.cert.pem", + })) - require.Nil(t, app.ListenMutualTLS(":99999", "./.github/testdata/ssl.pem", "./.github/testdata/ssl.key", "./.github/testdata/ca-chain.cert.pem")) } -// go test -run Test_App_Listener -func Test_App_Listener(t *testing.T) { +// go test -run Test_Listener +func Test_Listener(t *testing.T) { app := New() go func() { @@ -135,46 +237,117 @@ func Test_App_Listener_TLS_Listener(t *testing.T) { require.Nil(t, app.Listener(ln)) } -func captureOutput(f func()) string { - reader, writer, err := os.Pipe() - if err != nil { - panic(err) - } - stdout := os.Stdout - stderr := os.Stderr - defer func() { - os.Stdout = stdout - os.Stderr = stderr - log.SetOutput(os.Stderr) +// go test -run Test_Listen_TLSConfigFunc +func Test_Listen_TLSConfigFunc(t *testing.T) { + var callTLSConfig bool + app := New() + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) }() - os.Stdout = writer - os.Stderr = writer - log.SetOutput(writer) - out := make(chan string) - wg := new(sync.WaitGroup) - wg.Add(1) + + require.Nil(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + TLSConfigFunc: func(tlsConfig *tls.Config) { + callTLSConfig = true + }, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + })) + + require.True(t, callTLSConfig) +} + +// go test -run Test_Listen_ListenerAddrFunc +func Test_Listen_ListenerAddrFunc(t *testing.T) { + var network string + app := New() + go func() { - var buf bytes.Buffer - wg.Done() - _, err := io.Copy(&buf, reader) - if err != nil { - panic(err) - } - out <- buf.String() + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) }() - wg.Wait() - f() - err = writer.Close() - if err != nil { - panic(err) - } - return <-out + + require.Nil(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + network = addr.Network() + }, + CertFile: "./.github/testdata/ssl.pem", + CertKeyFile: "./.github/testdata/ssl.key", + })) + + require.Equal(t, "tcp", network) } -func Test_App_Master_Process_Show_Startup_Message(t *testing.T) { +// go test -run Test_Listen_BeforeServeFunc +func Test_Listen_BeforeServeFunc(t *testing.T) { + var handlers uint32 + app := New() + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) + }() + + require.Equal(t, errors.New("test"), app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + BeforeServeFunc: func(fiber *App) error { + handlers = fiber.HandlersCount() + + return errors.New("test") + }, + })) + + require.Equal(t, uint32(0), handlers) +} + +// go test -run Test_Listen_ListenerNetwork +func Test_Listen_ListenerNetwork(t *testing.T) { + var network string + app := New() + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) + }() + + require.Nil(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + ListenerNetwork: NetworkTCP6, + ListenerAddrFunc: func(addr net.Addr) { + network = addr.String() + }, + })) + + require.True(t, strings.Contains(network, "[::]:")) + + go func() { + time.Sleep(1000 * time.Millisecond) + require.Nil(t, app.Shutdown()) + }() + + require.Nil(t, app.Listen(":0", ListenConfig{ + DisableStartupMessage: true, + ListenerNetwork: NetworkTCP4, + ListenerAddrFunc: func(addr net.Addr) { + network = addr.String() + }, + })) + + require.True(t, strings.Contains(network, "0.0.0.0:")) +} + +// go test -run Test_Listen_Master_Process_Show_Startup_Message +func Test_Listen_Master_Process_Show_Startup_Message(t *testing.T) { + cfg := ListenConfig{ + EnablePrefork: true, + } + startupMessage := captureOutput(func() { - New(Config{Prefork: true}). - startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10)) + New(). + startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10), cfg) }) fmt.Println(startupMessage) require.True(t, strings.Contains(startupMessage, "https://127.0.0.1:3000")) @@ -184,41 +357,41 @@ func Test_App_Master_Process_Show_Startup_Message(t *testing.T) { require.True(t, strings.Contains(startupMessage, "Prefork ........ Enabled")) } -func Test_App_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) { - app := New(Config{Prefork: true, AppName: "Test App v1.0.1"}) - startupMessage := captureOutput(func() { - app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10)) - }) - fmt.Println(startupMessage) - require.Equal(t, "Test App v1.0.1", app.Config().AppName) - require.True(t, strings.Contains(startupMessage, app.Config().AppName)) -} +// go test -run Test_Listen_Master_Process_Show_Startup_MessageWithAppName +func Test_Listen_Master_Process_Show_Startup_MessageWithAppName(t *testing.T) { + cfg := ListenConfig{ + EnablePrefork: true, + } -func Test_App_Master_Process_Show_Startup_MessageWithAppNameNonAscii(t *testing.T) { - appName := "Serveur de vérification des données" - app := New(Config{Prefork: true, AppName: appName}) + app := New(Config{AppName: "Test App v3.0.0"}) startupMessage := captureOutput(func() { - app.startupMessage(":3000", false, "") + app.startupMessage(":3000", true, strings.Repeat(",11111,22222,33333,44444,55555,60000", 10), cfg) }) fmt.Println(startupMessage) - require.True(t, strings.Contains(startupMessage, "│ Serveur de vérification des données │")) + require.Equal(t, "Test App v3.0.0", app.Config().AppName) + require.True(t, strings.Contains(startupMessage, app.Config().AppName)) } -func Test_App_print_Route(t *testing.T) { - app := New(Config{EnablePrintRoutes: true}) +// go test -run Test_Listen_Print_Route +func Test_Listen_Print_Route(t *testing.T) { + app := New() app.Get("/", emptyHandler).Name("routeName") + printRoutesMessage := captureOutput(func() { app.printRoutesMessage() }) + fmt.Println(printRoutesMessage) + require.True(t, strings.Contains(printRoutesMessage, "GET")) require.True(t, strings.Contains(printRoutesMessage, "/")) require.True(t, strings.Contains(printRoutesMessage, "emptyHandler")) require.True(t, strings.Contains(printRoutesMessage, "routeName")) } -func Test_App_print_Route_with_group(t *testing.T) { - app := New(Config{EnablePrintRoutes: true}) +// go test -run Test_Listen_Print_Route_With_Group +func Test_Listen_Print_Route_With_Group(t *testing.T) { + app := New() app.Get("/", emptyHandler) v1 := app.Group("v1") @@ -230,6 +403,8 @@ func Test_App_print_Route_with_group(t *testing.T) { app.printRoutesMessage() }) + fmt.Println(printRoutesMessage) + require.True(t, strings.Contains(printRoutesMessage, "GET")) require.True(t, strings.Contains(printRoutesMessage, "/")) require.True(t, strings.Contains(printRoutesMessage, "emptyHandler")) @@ -240,6 +415,42 @@ func Test_App_print_Route_with_group(t *testing.T) { require.True(t, strings.Contains(printRoutesMessage, "/v1/test/fiber/*")) } +func captureOutput(f func()) string { + reader, writer, err := os.Pipe() + if err != nil { + panic(err) + } + stdout := os.Stdout + stderr := os.Stderr + defer func() { + os.Stdout = stdout + os.Stderr = stderr + log.SetOutput(os.Stderr) + }() + os.Stdout = writer + os.Stderr = writer + log.SetOutput(writer) + out := make(chan string) + wg := new(sync.WaitGroup) + wg.Add(1) + go func() { + var buf bytes.Buffer + wg.Done() + _, err := io.Copy(&buf, reader) + if err != nil { + panic(err) + } + out <- buf.String() + }() + wg.Wait() + f() + err = writer.Close() + if err != nil { + panic(err) + } + return <-out +} + func emptyHandler(_ Ctx) error { return nil } diff --git a/middleware/pprof/pprof_test.go b/middleware/pprof/pprof_test.go index b67663189d..eac8b3caec 100644 --- a/middleware/pprof/pprof_test.go +++ b/middleware/pprof/pprof_test.go @@ -11,7 +11,7 @@ import ( ) func Test_Non_Pprof_Path(t *testing.T) { - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Use(New()) @@ -29,7 +29,7 @@ func Test_Non_Pprof_Path(t *testing.T) { } func Test_Pprof_Index(t *testing.T) { - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Use(New()) @@ -48,7 +48,7 @@ func Test_Pprof_Index(t *testing.T) { } func Test_Pprof_Subs(t *testing.T) { - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Use(New()) @@ -75,7 +75,7 @@ func Test_Pprof_Subs(t *testing.T) { } func Test_Pprof_Other(t *testing.T) { - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Use(New()) diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 05005a928b..b1debeee08 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -19,18 +19,21 @@ import ( func createProxyTestServer(handler fiber.Handler, t *testing.T) (*fiber.App, string) { t.Helper() - target := fiber.New(fiber.Config{DisableStartupMessage: true}) + target := fiber.New() target.Get("/", handler) ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) + addr := ln.Addr().String() + go func() { - require.Nil(t, target.Listener(ln)) + require.Nil(t, target.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) }() time.Sleep(2 * time.Second) - addr := ln.Addr().String() return target, addr } @@ -77,7 +80,7 @@ func Test_Proxy(t *testing.T) { require.NoError(t, err) require.Equal(t, fiber.StatusTeapot, resp.StatusCode) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Use(Balancer(Config{Servers: []string{addr}})) @@ -100,7 +103,7 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { ln = tls.NewListener(ln, serverTLSConf) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Get("/tlsbalaner", func(c fiber.Ctx) error { return c.SendString("tls balancer") @@ -115,7 +118,11 @@ func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { TlsConfig: clientTLSConf, })) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() @@ -140,13 +147,17 @@ func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { proxyServerLn = tls.NewListener(proxyServerLn, proxyServerTLSConf) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() proxyAddr := proxyServerLn.Addr().String() app.Use(Forward("http://" + targetAddr)) - go func() { require.Nil(t, app.Listener(proxyServerLn)) }() + go func() { + require.Nil(t, app.Listener(proxyServerLn, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() code, body, errs := fiber.Get("https://" + proxyAddr). InsecureSkipVerify(). @@ -191,7 +202,7 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) { ln = tls.NewListener(ln, serverTLSConf) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Get("/tlsfwd", func(c fiber.Ctx) error { return c.SendString("tls forward") @@ -204,7 +215,11 @@ func Test_Proxy_Forward_WithTlsConfig(t *testing.T) { WithTlsConfig(clientTLSConf) app.Use(Forward("https://" + addr + "/tlsfwd")) - go func() { require.Nil(t, app.Listener(ln)) }() + go func() { + require.Nil(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String() @@ -372,7 +387,7 @@ func Test_Proxy_Do_HTTP_Prefix_URL(t *testing.T) { return c.SendString("hello world") }, t) - app := fiber.New(fiber.Config{DisableStartupMessage: true}) + app := fiber.New() app.Get("/*", func(c fiber.Ctx) error { path := c.OriginalURL() url := strings.TrimPrefix(path, "/") diff --git a/prefork.go b/prefork.go index b3049abc60..71db86b8d7 100644 --- a/prefork.go +++ b/prefork.go @@ -28,7 +28,7 @@ func IsChild() bool { } // prefork manages child processes to make use of the OS REUSEPORT or REUSEADDR feature -func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) { +func (app *App) prefork(addr string, tlsConfig *tls.Config, cfg ListenConfig) (err error) { // 👶 child process 👶 if IsChild() { // use 1 cpu core per child process @@ -36,8 +36,8 @@ func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) var ln net.Listener // Linux will use SO_REUSEPORT and Windows falls back to SO_REUSEADDR // Only tcp4 or tcp6 is supported when preforking, both are not supported - if ln, err = reuseport.Listen(network, addr); err != nil { - if !app.config.DisableStartupMessage { + if ln, err = reuseport.Listen(cfg.ListenerNetwork, addr); err != nil { + if !cfg.DisableStartupMessage { time.Sleep(100 * time.Millisecond) // avoid colliding with startup message } return fmt.Errorf("prefork: %v", err) @@ -53,6 +53,10 @@ func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) // prepare the server for the start app.startupProcess() + if cfg.ListenerAddrFunc != nil { + cfg.ListenerAddrFunc(ln.Addr()) + } + // listen for incoming connections return app.server.Serve(ln) } @@ -94,6 +98,7 @@ func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) cmd.Env = append(os.Environ(), fmt.Sprintf("%s=%s", envPreforkChildKey, envPreforkChildVal), ) + if err = cmd.Start(); err != nil { return fmt.Errorf("failed to start a child prefork process, error: %v", err) } @@ -119,8 +124,13 @@ func (app *App) prefork(network, addr string, tlsConfig *tls.Config) (err error) } // Print startup message - if !app.config.DisableStartupMessage { - app.startupMessage(addr, tlsConfig != nil, ","+strings.Join(pids, ",")) + if !cfg.DisableStartupMessage { + app.startupMessage(addr, tlsConfig != nil, ","+strings.Join(pids, ","), cfg) + } + + // Print routes + if cfg.EnablePrintRoutes { + app.printRoutesMessage() } // return error if child crashes diff --git a/prefork_test.go b/prefork_test.go index f6c04b9b2a..a1a680c1bd 100644 --- a/prefork_test.go +++ b/prefork_test.go @@ -23,7 +23,7 @@ func Test_App_Prefork_Child_Process(t *testing.T) { app := New() - err := app.prefork(NetworkTCP4, "invalid", nil) + err := app.prefork("invalid", nil, listenConfigDefault()) require.False(t, err == nil) go func() { @@ -31,7 +31,7 @@ func Test_App_Prefork_Child_Process(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.prefork(NetworkTCP6, "[::1]:", nil)) + require.Nil(t, app.prefork("[::1]:", nil, ListenConfig{ListenerNetwork: NetworkTCP6})) // Create tls certificate cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") @@ -45,7 +45,7 @@ func Test_App_Prefork_Child_Process(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.prefork(NetworkTCP4, "127.0.0.1:", config)) + require.Nil(t, app.prefork("127.0.0.1:", config, listenConfigDefault())) } func Test_App_Prefork_Master_Process(t *testing.T) { @@ -59,12 +59,14 @@ func Test_App_Prefork_Master_Process(t *testing.T) { require.Nil(t, app.Shutdown()) }() - require.Nil(t, app.prefork(NetworkTCP4, ":3000", nil)) + require.Nil(t, app.prefork(":3000", nil, listenConfigDefault())) dummyChildCmd = "invalid" - err := app.prefork(NetworkTCP4, "127.0.0.1:", nil) + err := app.prefork("127.0.0.1:", nil, listenConfigDefault()) require.False(t, err == nil) + + dummyChildCmd = "go" } func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) { @@ -79,7 +81,7 @@ func Test_App_Prefork_Child_Process_Never_Show_Startup_Message(t *testing.T) { os.Stdout = w - New().startupProcess().startupMessage(":3000", false, "") + New().startupProcess().startupMessage(":3000", false, "", listenConfigDefault()) require.Nil(t, w.Close())