@@ -2,8 +2,8 @@ package proxy
22
33import (
44 "bufio"
5- "context"
65 "crypto/tls"
6+ "errors"
77 "fmt"
88 "io"
99 "log/slog"
@@ -12,7 +12,7 @@ import (
1212 "net/url"
1313 "strings"
1414 "sync"
15- "time "
15+ "sync/atomic "
1616
1717 "github.com/coder/boundary/audit"
1818 "github.com/coder/boundary/rules"
@@ -25,8 +25,9 @@ type Server struct {
2525 logger * slog.Logger
2626 tlsConfig * tls.Config
2727 httpPort int
28+ started atomic.Bool
2829
29- httpServer * http. Server
30+ listener net. Listener
3031}
3132
3233// Config holds configuration for the proxy server
@@ -50,64 +51,70 @@ func NewProxyServer(config Config) *Server {
5051}
5152
5253// Start starts the HTTP proxy server with TLS termination capability
53- func (p * Server ) Start (ctx context.Context ) error {
54- // Create HTTP server with TLS termination capability
55- p .httpServer = & http.Server {
56- Addr : fmt .Sprintf (":%d" , p .httpPort ),
57- Handler : http .HandlerFunc (p .handleHTTPWithTLSTermination ),
54+ func (p * Server ) Start () error {
55+ if p .isStarted () {
56+ return nil
57+ }
58+
59+ p .logger .Info ("Starting HTTP proxy with TLS termination" , "port" , p .httpPort )
60+ var err error
61+ p .listener , err = net .Listen ("tcp" , fmt .Sprintf (":%d" , p .httpPort ))
62+ if err != nil {
63+ p .logger .Error ("Failed to create HTTP listener" , "error" , err )
64+ return err
5865 }
5966
67+ p .started .Store (true )
68+
6069 // Start HTTP server with custom listener for TLS detection
6170 go func () {
62- p .logger .Info ("Starting HTTP proxy with TLS termination" , "port" , p .httpPort )
63- listener , err := net .Listen ("tcp" , fmt .Sprintf (":%d" , p .httpPort ))
64- if err != nil {
65- p .logger .Error ("Failed to create HTTP listener" , "error" , err )
66- return
67- }
68-
6971 for {
70- conn , err := listener .Accept ()
72+ conn , err := p .listener .Accept ()
73+ if err != nil && errors .Is (err , net .ErrClosed ) && p .isStopped () {
74+ return
75+ }
7176 if err != nil {
72- select {
73- case <- ctx .Done ():
74- err = listener .Close ()
75- if err != nil {
76- p .logger .Error ("Failed to close listener" , "error" , err )
77- }
78- return
79- default :
80- p .logger .Error ("Failed to accept connection" , "error" , err )
81- continue
82- }
77+ p .logger .Error ("Failed to accept connection" , "error" , err )
78+ continue
8379 }
8480
8581 // Handle connection with TLS detection
8682 go p .handleConnectionWithTLSDetection (conn )
8783 }
8884 }()
8985
90- // Wait for context cancellation
91- <- ctx .Done ()
92- return p .Stop ()
86+ return nil
9387}
9488
9589// Stops proxy server
9690func (p * Server ) Stop () error {
97- ctx , cancel := context .WithTimeout (context .Background (), 5 * time .Second )
98- defer cancel ()
91+ if p .isStopped () {
92+ return nil
93+ }
94+ p .started .Store (false )
9995
100- var httpErr error
101- if p . httpServer != nil {
102- httpErr = p . httpServer . Shutdown ( ctx )
96+ if p . listener == nil {
97+ p . logger . Error ( "unexpected nil listener" )
98+ return errors . New ( "unexpected nil listener" )
10399 }
104100
105- if httpErr != nil {
106- return httpErr
101+ err := p .listener .Close ()
102+ if err != nil {
103+ p .logger .Error ("Failed to close listener" , "error" , err )
104+ return err
107105 }
106+
108107 return nil
109108}
110109
110+ func (p * Server ) isStarted () bool {
111+ return p .started .Load ()
112+ }
113+
114+ func (p * Server ) isStopped () bool {
115+ return ! p .started .Load ()
116+ }
117+
111118// handleHTTP handles regular HTTP requests and CONNECT tunneling
112119func (p * Server ) handleHTTP (w http.ResponseWriter , r * http.Request ) {
113120 p .logger .Debug ("handleHTTP called" , "method" , r .Method , "url" , r .URL .String (), "host" , r .Host )
@@ -479,13 +486,6 @@ func (p *Server) handleConnectionWithTLSDetection(conn net.Conn) {
479486 }
480487}
481488
482- // handleHTTPWithTLSTermination is the main handler (currently just delegates to regular HTTP)
483- func (p * Server ) handleHTTPWithTLSTermination (w http.ResponseWriter , r * http.Request ) {
484- // This handler is not used when we do custom connection handling
485- // All traffic goes through handleConnectionWithTLSDetection
486- p .handleHTTP (w , r )
487- }
488-
489489// connectionWrapper lets us "unread" the peeked byte
490490type connectionWrapper struct {
491491 net.Conn
0 commit comments