@@ -27,7 +27,9 @@ import (
2727 "code.cloudfoundry.org/durationjson"
2828 "code.cloudfoundry.org/go-loggregator/v9/rpc/loggregator_v2"
2929 "code.cloudfoundry.org/inigo/helpers/certauthority"
30+ "code.cloudfoundry.org/lager/v3"
3031 "code.cloudfoundry.org/lager/v3/lagerflags"
32+ "code.cloudfoundry.org/lager/v3/lagertest"
3133 "code.cloudfoundry.org/tlsconfig"
3234 "github.com/gogo/protobuf/proto"
3335 . "github.com/onsi/ginkgo/v2"
@@ -482,6 +484,7 @@ var _ = Describe("SSH proxy", func() {
482484 intermediaryTLSConfig * tls.Config
483485 intermediaryListener net.Listener
484486 connectedToTLS chan struct {}
487+ forwardServer * forwardTLSServer
485488 )
486489
487490 BeforeEach (func () {
@@ -505,14 +508,16 @@ var _ = Describe("SSH proxy", func() {
505508 Expect (err ).NotTo (HaveOccurred ())
506509
507510 connectedToTLS = make (chan struct {}, 1 )
511+ logger := lagertest .NewTestLogger ("ssh-proxy-test" )
512+ forwardServer = NewForwardTLSServer (logger , intermediaryListener , sshdAddress )
508513 })
509514
510515 JustBeforeEach (func () {
511- go forwardTLSConn ( sshdAddress , intermediaryListener , connectedToTLS )
516+ go forwardServer . Start ( connectedToTLS )
512517 })
513518
514519 AfterEach (func () {
515- intermediaryListener . Close ()
520+ forwardServer . Stop ()
516521 close (connectedToTLS )
517522 })
518523
@@ -1117,41 +1122,85 @@ func RespondWithProto(message proto.Message) http.HandlerFunc {
11171122 return ghttp .RespondWith (200 , string (data ), headers )
11181123}
11191124
1120- func forwardTLSConn (serverAddress string , proxy net.Listener , onConnectionReceived chan struct {}) {
1125+ type forwardTLSServer struct {
1126+ logger lager.Logger
1127+ proxy net.Listener
1128+ stopCh chan struct {}
1129+ address string
1130+ }
1131+
1132+ func NewForwardTLSServer (logger lager.Logger , proxy net.Listener , address string ) * forwardTLSServer {
1133+ return & forwardTLSServer {
1134+ logger : logger .Session ("forward-tls-server" ),
1135+ proxy : proxy ,
1136+ address : address ,
1137+ stopCh : make (chan struct {}),
1138+ }
1139+ }
1140+
1141+ func (s * forwardTLSServer ) Start (onConnectionReceived chan struct {}) error {
11211142 for {
1122- conn , err := proxy .Accept ()
1123- if err != nil {
1124- return
1125- }
1143+ select {
1144+ case <- s .stopCh :
1145+ return nil
1146+ default :
1147+ conn , err := s .proxy .Accept ()
1148+ if err != nil {
1149+ select {
1150+ case <- s .stopCh :
1151+ return nil
1152+ default :
1153+ s .logger .Error ("failed-to-receive-connection" , err )
1154+ return err
1155+ }
1156+ }
11261157
1127- tlsConn := conn .(* tls.Conn )
1128- err = tlsConn .Handshake ()
1129- if err != nil {
1130- return
1131- }
1158+ tlsConn := conn .(* tls.Conn )
1159+ err = tlsConn .Handshake ()
1160+ if err != nil {
1161+ select {
1162+ case <- s .stopCh :
1163+ return nil
1164+ default :
1165+ s .logger .Error ("failed-to-tls-handshake" , err )
1166+ return err
1167+ }
1168+ }
11321169
1133- if onConnectionReceived != nil {
1134- onConnectionReceived <- struct {}{}
1135- }
1170+ if onConnectionReceived != nil {
1171+ onConnectionReceived <- struct {}{}
1172+ }
11361173
1137- proxyConn , err := net .Dial ("tcp" , serverAddress )
1138- if err != nil {
1139- return
1140- }
1174+ proxyConn , err := net .Dial ("tcp" , s .address )
1175+ if err != nil {
1176+ select {
1177+ case <- s .stopCh :
1178+ return nil
1179+ default :
1180+ s .logger .Error ("failed-to-dial" , err )
1181+ return err
1182+ }
1183+ }
11411184
1142- wg := sync.WaitGroup {}
1143- wg .Add (2 )
1185+ wg := sync.WaitGroup {}
1186+ wg .Add (2 )
11441187
1145- go func () {
1146- _ , _ = io .Copy (conn , proxyConn )
1147- wg .Done ()
1148- }()
1188+ go func () {
1189+ _ , _ = io .Copy (conn , proxyConn )
1190+ wg .Done ()
1191+ }()
11491192
1150- go func () {
1151- _ , _ = io .Copy (proxyConn , conn )
1152- wg .Done ()
1153- }()
1193+ go func () {
1194+ _ , _ = io .Copy (proxyConn , conn )
1195+ wg .Done ()
1196+ }()
11541197
1155- wg .Wait ()
1198+ wg .Wait ()
1199+ }
11561200 }
11571201}
1202+
1203+ func (s * forwardTLSServer ) Stop () {
1204+ close (s .stopCh )
1205+ s .proxy .Close ()
1206+ }
0 commit comments