Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions router/core/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -49,6 +51,7 @@ type server struct {
state atomic.Pointer[serverState]
healthcheck health.Checker
baseURL string
listener net.Listener // Pre-bound listener for synchronous port check
}

type httpServerOptions struct {
Expand All @@ -64,7 +67,13 @@ type httpServerOptions struct {
healthCheckPath string
}

func newServer(opts *httpServerOptions) *server {
func newServer(opts *httpServerOptions) (*server, error) {
// Bind the port synchronously to detect port conflicts immediately
listener, err := net.Listen("tcp", opts.addr)
if err != nil {
return nil, fmt.Errorf("failed to bind to address %s: %w", opts.addr, err)
}

httpServer := &http.Server{
Addr: opts.addr,
ReadTimeout: 60 * time.Second,
Expand All @@ -89,6 +98,7 @@ func newServer(opts *httpServerOptions) *server {
mu: sync.RWMutex{},
healthcheck: opts.healthcheck,
baseURL: opts.baseURL,
listener: listener, // Store the pre-bound listener
}

// Store the initial state with health check mux (graphServer nil until first config)
Expand All @@ -104,7 +114,7 @@ func newServer(opts *httpServerOptions) *server {
n.state.Load().mux.ServeHTTP(w, r)
})

return n
return n, nil
}

func (s *server) HealthChecks() health.Checker {
Expand Down Expand Up @@ -140,15 +150,17 @@ func (s *server) SwapGraphServer(ctx context.Context, svr *graphServer) {
}
}

// listenAndServe starts the server and blocks until the server is shutdown.
// listenAndServe starts the server using the pre-bound listener and blocks until shutdown.
// This method is called in a goroutine; the port was already bound in newServer().
func (s *server) listenAndServe() error {
if s.tlsConfig != nil && s.tlsConfig.Enabled {
// Leave the cert and key empty to use the default ones
if err := s.httpServer.ListenAndServeTLS("", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
// Use TLS with the pre-bound listener
if err := s.httpServer.ServeTLS(s.listener, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
} else {
if err := s.httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
// Use plain HTTP with the pre-bound listener
if err := s.httpServer.Serve(s.listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}
Expand Down
97 changes: 97 additions & 0 deletions router/core/http_server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package core

import (
"net"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
"github.com/wundergraph/cosmo/router/pkg/health"
"go.uber.org/zap"
)

func TestNewServer_PortBindingError(t *testing.T) {
// Bind a port first
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

// Get the address that was bound
addr := listener.Addr().String()

// Try to create a server on the same port - this should fail immediately
logger := zap.NewNop()
hc := health.New(&health.Options{Logger: logger})

_, err = newServer(&httpServerOptions{
addr: addr,
logger: logger,
healthcheck: hc,
baseURL: "http://" + addr,
maxHeaderBytes: 1024,
healthCheckPath: "/health",
livenessCheckPath: "/health/live",
readinessCheckPath: "/health/ready",
})

// Should return an error immediately, not succeed
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to bind to address")
}

func TestNewServer_PortBindingSuccess(t *testing.T) {
// Find an available port
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := listener.Addr().String()
listener.Close() // Close it so we can use it

// Try to create a server on the available port - this should succeed
logger := zap.NewNop()
hc := health.New(&health.Options{Logger: logger})

server, err := newServer(&httpServerOptions{
addr: addr,
logger: logger,
healthcheck: hc,
baseURL: "http://" + addr,
maxHeaderBytes: 1024,
healthCheckPath: "/health",
livenessCheckPath: "/health/live",
readinessCheckPath: "/health/ready",
})

// Should succeed
assert.NoError(t, err)
assert.NotNil(t, server)

// Clean up
if server != nil {
server.Shutdown(t.Context())
}
}

func TestRouter_Start_PortBindingError(t *testing.T) {
// Bind a port first
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer listener.Close()

// Get the address that was bound
addr := listener.Addr().String()

// Create a router with static config that uses the already-bound port
router, err := NewRouter(
WithStaticExecutionConfig(&nodev1.RouterConfig{
Version: "1.0.0",
}),
WithListenerAddr(addr),
)
require.NoError(t, err)

// Try to start the router - should fail immediately with port binding error
err = router.Start(t.Context())
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to create server")
}
12 changes: 10 additions & 2 deletions router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,8 @@ func (r *Router) NewServer(ctx context.Context) (Server, error) {
return nil, fmt.Errorf("failed to bootstrap application: %w", err)
}

r.httpServer = newServer(&httpServerOptions{
var err error
r.httpServer, err = newServer(&httpServerOptions{
addr: r.listenAddr,
logger: r.logger,
tlsConfig: r.tlsConfig,
Expand All @@ -774,6 +775,9 @@ func (r *Router) NewServer(ctx context.Context) (Server, error) {
readinessCheckPath: r.readinessCheckPath,
healthCheckPath: r.healthCheckPath,
})
if err != nil {
return nil, fmt.Errorf("failed to create server: %w", err)
}

r.configureUsageTracking(ctx)

Expand Down Expand Up @@ -1376,7 +1380,8 @@ func (r *Router) Start(ctx context.Context) error {

r.trackRouterConfigUsage()

r.httpServer = newServer(&httpServerOptions{
var err error
r.httpServer, err = newServer(&httpServerOptions{
addr: r.listenAddr,
logger: r.logger,
tlsConfig: r.tlsConfig,
Expand All @@ -1388,6 +1393,9 @@ func (r *Router) Start(ctx context.Context) error {
readinessCheckPath: r.readinessCheckPath,
healthCheckPath: r.healthCheckPath,
})
if err != nil {
return fmt.Errorf("failed to create server: %w", err)
}

if r.reloadPersistentState == nil {
// This is only applicable for tests since we do not call here via the supervisor
Expand Down
Loading