Skip to content

Commit 0e1b7b1

Browse files
authored
Wait for the server to stop in AfterEach (#57)
Log all server errors
1 parent ddb0135 commit 0e1b7b1

File tree

1 file changed

+79
-30
lines changed

1 file changed

+79
-30
lines changed

cmd/ssh-proxy/main_test.go

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ import (
2727
"code.cloudfoundry.org/durationjson"
2828
"code.cloudfoundry.org/go-loggregator/v9/rpc/loggregator_v2"
2929
"code.cloudfoundry.org/inigo/helpers/certauthority"
30+
"code.cloudfoundry.org/lager/v3"
3031
"code.cloudfoundry.org/lager/v3/lagerflags"
32+
"code.cloudfoundry.org/lager/v3/lagertest"
3133
"code.cloudfoundry.org/tlsconfig"
3234
"github.com/gogo/protobuf/proto"
3335
. "github.com/onsi/ginkgo/v2"
@@ -482,6 +484,7 @@ var _ = Describe("SSH proxy", func() {
482484
intermediaryTLSConfig *tls.Config
483485
intermediaryListener net.Listener
484486
connectedToTLS chan struct{}
487+
forwardServer *forwardTLSServer
485488
)
486489

487490
BeforeEach(func() {
@@ -505,14 +508,16 @@ var _ = Describe("SSH proxy", func() {
505508
Expect(err).NotTo(HaveOccurred())
506509

507510
connectedToTLS = make(chan struct{}, 1)
511+
logger := lagertest.NewTestLogger("ssh-proxy-test")
512+
forwardServer = NewForwardTLSServer(logger, intermediaryListener, sshdAddress)
508513
})
509514

510515
JustBeforeEach(func() {
511-
go forwardTLSConn(sshdAddress, intermediaryListener, connectedToTLS)
516+
go forwardServer.Start(connectedToTLS)
512517
})
513518

514519
AfterEach(func() {
515-
intermediaryListener.Close()
520+
forwardServer.Stop()
516521
close(connectedToTLS)
517522
})
518523

@@ -1117,41 +1122,85 @@ func RespondWithProto(message proto.Message) http.HandlerFunc {
11171122
return ghttp.RespondWith(200, string(data), headers)
11181123
}
11191124

1120-
func forwardTLSConn(serverAddress string, proxy net.Listener, onConnectionReceived chan struct{}) {
1125+
type forwardTLSServer struct {
1126+
logger lager.Logger
1127+
proxy net.Listener
1128+
stopCh chan struct{}
1129+
address string
1130+
}
1131+
1132+
func NewForwardTLSServer(logger lager.Logger, proxy net.Listener, address string) *forwardTLSServer {
1133+
return &forwardTLSServer{
1134+
logger: logger.Session("forward-tls-server"),
1135+
proxy: proxy,
1136+
address: address,
1137+
stopCh: make(chan struct{}),
1138+
}
1139+
}
1140+
1141+
func (s *forwardTLSServer) Start(onConnectionReceived chan struct{}) error {
11211142
for {
1122-
conn, err := proxy.Accept()
1123-
if err != nil {
1124-
return
1125-
}
1143+
select {
1144+
case <-s.stopCh:
1145+
return nil
1146+
default:
1147+
conn, err := s.proxy.Accept()
1148+
if err != nil {
1149+
select {
1150+
case <-s.stopCh:
1151+
return nil
1152+
default:
1153+
s.logger.Error("failed-to-receive-connection", err)
1154+
return err
1155+
}
1156+
}
11261157

1127-
tlsConn := conn.(*tls.Conn)
1128-
err = tlsConn.Handshake()
1129-
if err != nil {
1130-
return
1131-
}
1158+
tlsConn := conn.(*tls.Conn)
1159+
err = tlsConn.Handshake()
1160+
if err != nil {
1161+
select {
1162+
case <-s.stopCh:
1163+
return nil
1164+
default:
1165+
s.logger.Error("failed-to-tls-handshake", err)
1166+
return err
1167+
}
1168+
}
11321169

1133-
if onConnectionReceived != nil {
1134-
onConnectionReceived <- struct{}{}
1135-
}
1170+
if onConnectionReceived != nil {
1171+
onConnectionReceived <- struct{}{}
1172+
}
11361173

1137-
proxyConn, err := net.Dial("tcp", serverAddress)
1138-
if err != nil {
1139-
return
1140-
}
1174+
proxyConn, err := net.Dial("tcp", s.address)
1175+
if err != nil {
1176+
select {
1177+
case <-s.stopCh:
1178+
return nil
1179+
default:
1180+
s.logger.Error("failed-to-dial", err)
1181+
return err
1182+
}
1183+
}
11411184

1142-
wg := sync.WaitGroup{}
1143-
wg.Add(2)
1185+
wg := sync.WaitGroup{}
1186+
wg.Add(2)
11441187

1145-
go func() {
1146-
_, _ = io.Copy(conn, proxyConn)
1147-
wg.Done()
1148-
}()
1188+
go func() {
1189+
_, _ = io.Copy(conn, proxyConn)
1190+
wg.Done()
1191+
}()
11491192

1150-
go func() {
1151-
_, _ = io.Copy(proxyConn, conn)
1152-
wg.Done()
1153-
}()
1193+
go func() {
1194+
_, _ = io.Copy(proxyConn, conn)
1195+
wg.Done()
1196+
}()
11541197

1155-
wg.Wait()
1198+
wg.Wait()
1199+
}
11561200
}
11571201
}
1202+
1203+
func (s *forwardTLSServer) Stop() {
1204+
close(s.stopCh)
1205+
s.proxy.Close()
1206+
}

0 commit comments

Comments
 (0)