diff --git a/.gitmodules b/.gitmodules index 970565e8fa..4f6388aec2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -30,10 +30,6 @@ path = src/code.cloudfoundry.org/buildpackapplifecycle url = https://github.com/cloudfoundry/buildpackapplifecycle branch = main -[submodule "src/code.cloudfoundry.org/diego-ssh"] - path = src/code.cloudfoundry.org/diego-ssh - url = https://github.com/cloudfoundry/diego-ssh - branch = main [submodule "src/code.cloudfoundry.org/route-emitter"] path = src/code.cloudfoundry.org/route-emitter url = https://github.com/cloudfoundry/route-emitter diff --git a/src/code.cloudfoundry.org/diego-ssh b/src/code.cloudfoundry.org/diego-ssh deleted file mode 160000 index bd398c2f4b..0000000000 --- a/src/code.cloudfoundry.org/diego-ssh +++ /dev/null @@ -1 +0,0 @@ -Subproject commit bd398c2f4b8910db7383ee9bcf656eaec6bc50f9 diff --git a/src/code.cloudfoundry.org/diego-ssh/.gitignore b/src/code.cloudfoundry.org/diego-ssh/.gitignore new file mode 100644 index 0000000000..6278703006 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/.gitignore @@ -0,0 +1,8 @@ +*.coverprofile +*.exe +*.swp +*.test +*~ +.DS_Store +.idea +tags diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/authenticators_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/authenticators_suite_test.go new file mode 100644 index 0000000000..161ae48f51 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/authenticators_suite_test.go @@ -0,0 +1,13 @@ +package authenticators_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestAuthenticators(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Authenticators Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator.go new file mode 100644 index 0000000000..336ed4137a --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator.go @@ -0,0 +1,200 @@ +package authenticators + +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + + "code.cloudfoundry.org/lager/v3" + "github.com/golang-jwt/jwt/v4" + "golang.org/x/crypto/ssh" +) + +type CFAuthenticator struct { + logger lager.Logger + httpClient *http.Client + ccURL string + uaaTokenURL string + uaaPassword string + uaaUsername string + permissionsBuilder PermissionsBuilder +} + +type AppSSHResponse struct { + ProcessGuid string `json:"process_guid"` +} + +type UAAAuthTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` +} + +var CFUserRegex *regexp.Regexp = regexp.MustCompile(`cf:([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})/(\d+)`) + +func NewCFAuthenticator( + logger lager.Logger, + httpClient *http.Client, + ccURL string, + uaaTokenURL string, + uaaUsername string, + uaaPassword string, + permissionsBuilder PermissionsBuilder, +) *CFAuthenticator { + return &CFAuthenticator{ + logger: logger, + httpClient: httpClient, + ccURL: ccURL, + uaaTokenURL: uaaTokenURL, + uaaUsername: uaaUsername, + uaaPassword: uaaPassword, + permissionsBuilder: permissionsBuilder, + } +} + +func (cfa *CFAuthenticator) UserRegexp() *regexp.Regexp { + return CFUserRegex +} + +func (cfa *CFAuthenticator) Authenticate(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + logger := cfa.logger.Session("cf-authenticate") + logger.Info("authenticate-starting") + defer logger.Info("authenticate-finished") + + if !CFUserRegex.MatchString(metadata.User()) { + logger.Error("regex-match-fail", InvalidCredentialsErr) + return nil, InvalidCredentialsErr + } + + guidAndIndex := CFUserRegex.FindStringSubmatch(metadata.User()) + + appGuid := guidAndIndex[1] + + index, err := strconv.Atoi(guidAndIndex[2]) + if err != nil { + logger.Error("atoi-failed", err) + return nil, InvalidCredentialsErr + } + + cred, err := cfa.exchangeAccessCodeForToken(logger, string(password)) + if err != nil { + return nil, err + } + + parts := strings.Split(cred, " ") + if len(parts) != 2 { + return nil, AuthenticationFailedErr + } + tokenString := parts[1] + // When parsing the certificate validating the signature is not required and we don't readily have the + // certificate to validate the signature. This is just to parse the second information part of the token anyway. + token, _ := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { + return []byte("Doesntmatter"), nil + }) + + username, ok := token.Claims.(jwt.MapClaims)["user_name"].(string) + if !ok { + username = "unknown" + } + principal, ok := token.Claims.(jwt.MapClaims)["user_id"].(string) + if !ok { + principal = "unknown" + } + + logger = logger.WithData(lager.Data{ + "app": fmt.Sprintf("%s/%d", appGuid, index), + "principal": principal, + "username": username, + }) + + processGuid, err := cfa.checkAccess(logger, appGuid, index, string(cred)) + if err != nil { + return nil, err + } + + permissions, err := cfa.permissionsBuilder.Build(logger, processGuid, index, metadata) + if err != nil { + logger.Error("building-ssh-permissions-failed", err) + } + + logger.Info("app-access-success") + + return permissions, err +} + +func (cfa *CFAuthenticator) exchangeAccessCodeForToken(logger lager.Logger, code string) (string, error) { + logger = logger.Session("exchange-access-code-for-token") + + formValues := make(url.Values) + formValues.Set("grant_type", "authorization_code") + formValues.Set("code", code) + + req, err := http.NewRequest("POST", cfa.uaaTokenURL, strings.NewReader(formValues.Encode())) + if err != nil { + return "", err + } + + req.SetBasicAuth(cfa.uaaUsername, cfa.uaaPassword) + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := cfa.httpClient.Do(req) + if err != nil { + logger.Error("request-failed", err) + return "", AuthenticationFailedErr + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.Error("response-status-not-ok", AuthenticationFailedErr, lager.Data{ + "status-code": resp.StatusCode, + }) + return "", AuthenticationFailedErr + } + + var tokenResponse UAAAuthTokenResponse + err = json.NewDecoder(resp.Body).Decode(&tokenResponse) + if err != nil { + logger.Error("decode-token-response-failed", err) + return "", AuthenticationFailedErr + } + + return fmt.Sprintf("%s %s", tokenResponse.TokenType, tokenResponse.AccessToken), nil +} + +func (cfa *CFAuthenticator) checkAccess(logger lager.Logger, appGuid string, index int, token string) (string, error) { + path := fmt.Sprintf("%s/internal/apps/%s/ssh_access/%d", cfa.ccURL, appGuid, index) + + req, err := http.NewRequest("GET", path, nil) + if err != nil { + logger.Error("creating-request-failed", InvalidRequestErr) + return "", InvalidRequestErr + } + req.Header.Add("Authorization", token) + + resp, err := cfa.httpClient.Do(req) + if err != nil { + logger.Error("fetching-app-failed", err) + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + logger.Error("fetching-app-failed", FetchAppFailedErr, lager.Data{ + "StatusCode": resp.Status, + "ResponseBody": resp.Body, + }) + return "", FetchAppFailedErr + } + + var app AppSSHResponse + err = json.NewDecoder(resp.Body).Decode(&app) + if err != nil { + logger.Error("invalid-cc-response", err) + return "", InvalidCCResponse + } + + return app.ProcessGuid, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator_test.go new file mode 100644 index 0000000000..de59ee4998 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/cf_authenticator_test.go @@ -0,0 +1,309 @@ +package authenticators_test + +import ( + "math" + "net/http" + "net/url" + "regexp" + "strconv" + "time" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" + "github.com/onsi/gomega/ghttp" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("CFAuthenticator", func() { + var ( + authenticator *authenticators.CFAuthenticator + logger *lagertest.TestLogger + httpClient *http.Client + httpClientTimeout time.Duration + permissionsBuilder *fake_authenticators.FakePermissionsBuilder + + authenErr error + + metadata *fake_ssh.FakeConnMetadata + password []byte + + fakeCC *ghttp.Server + fakeUAA *ghttp.Server + ccURL string + uaaTokenURL string + uaaUsername string + uaaPassword string + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + httpClientTimeout = time.Second + httpClient = &http.Client{Timeout: httpClientTimeout} + + permissionsBuilder = &fake_authenticators.FakePermissionsBuilder{} + permissionsBuilder.BuildReturns(&ssh.Permissions{}, nil) + + metadata = &fake_ssh.FakeConnMetadata{} + + fakeCC = ghttp.NewServer() + ccURL = fakeCC.URL() + + fakeUAA = ghttp.NewServer() + u, err := url.Parse(fakeUAA.URL()) + Expect(err).NotTo(HaveOccurred()) + uaaUsername = "diego-ssh" + uaaPassword = "fake-diego-ssh-secret-$\"^&'" + + u.Path = "/oauth/token" + uaaTokenURL = u.String() + }) + + JustBeforeEach(func() { + authenticator = authenticators.NewCFAuthenticator(logger, httpClient, ccURL, uaaTokenURL, uaaUsername, uaaPassword, permissionsBuilder) + _, authenErr = authenticator.Authenticate(metadata, password) + }) + + Describe("UserRegexp", func() { + var regexp *regexp.Regexp + + BeforeEach(func() { + regexp = authenticator.UserRegexp() + }) + + It("matches cf:/ patterns", func() { + Expect(regexp.MatchString("cf:986fedf8-6b74-45af-827c-a4464e6aa05c/00")).To(BeTrue()) + Expect(regexp.MatchString("cf:986FEDF8-6B74-45AF-827C-A4464E6AA05C/00")).To(BeTrue()) + }) + + It("does not match other patterns", func() { + Expect(regexp.MatchString("cf:hhhhhhhh-6b74-45af-827c-a4464e6aa05c/00")).To(BeFalse()) + Expect(regexp.MatchString("cf:986fedf81-6b74-45af-827c-a4464e6aa05c/00")).To(BeFalse()) + Expect(regexp.MatchString("cf:986fedf8-6b74-45af-827c-a4464e6aa05c/")).To(BeFalse()) + Expect(regexp.MatchString("cf:guid/1")).To(BeFalse()) + Expect(regexp.MatchString("cf:/00")).To(BeFalse()) + Expect(regexp.MatchString("diego:guid/0")).To(BeFalse()) + Expect(regexp.MatchString("diego:guid/99")).To(BeFalse()) + Expect(regexp.MatchString("user@guid/0")).To(BeFalse()) + }) + }) + + Describe("Authenticate invalid token returned", func() { + const expectedOneTimeCode = "abc123" + + var ( + uaaTokenResponse *authenticators.UAAAuthTokenResponse + uaaTokenResponseCode int + ) + + BeforeEach(func() { + metadata.UserReturns("cf:1e051b88-a210-40b7-bcca-df645b24b634/1") + password = []byte(expectedOneTimeCode) + + uaaTokenResponseCode = http.StatusOK + uaaTokenResponse = &authenticators.UAAAuthTokenResponse{ + AccessToken: "is not right", + TokenType: "bearer", + } + + fakeUAA.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token"), + ghttp.VerifyBasicAuth("diego-ssh", "fake-diego-ssh-secret-$\"^&'"), + ghttp.VerifyFormKV("grant_type", "authorization_code"), + ghttp.VerifyFormKV("code", expectedOneTimeCode), + ghttp.RespondWithJSONEncodedPtr(&uaaTokenResponseCode, uaaTokenResponse), + ), + ) + }) + + It("logs the access to the container by the user", func() { + Expect(authenErr).To(Equal(authenticators.AuthenticationFailedErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Describe("Authenticate", func() { + const expectedOneTimeCode = "abc123" + + var ( + uaaTokenResponse *authenticators.UAAAuthTokenResponse + uaaTokenResponseCode int + + sshAccessResponse *authenticators.AppSSHResponse + sshAccessResponseCode int + ) + + BeforeEach(func() { + metadata.UserReturns("cf:1e051b88-a210-40b7-bcca-df645b24b634/1") + password = []byte(expectedOneTimeCode) + + uaaTokenResponseCode = http.StatusOK + uaaTokenResponse = &authenticators.UAAAuthTokenResponse{ + AccessToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiJmMGMyYWRkN2E5MDI0NTQyOWExZTdiMjNjZGVlZjkyZiIsInN1YiI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsInNjb3BlIjpbInJvdXRpbmcucm91dGVyX2dyb3Vwcy5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwib3BlbmlkIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiLCJzY2ltLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLmFkbWluIiwidWFhLnVzZXIiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsInJldl9zaWciOiJiMzUyMDU5ZiIsImlhdCI6MTQ3ODUxMzI3NywiZXhwIjoxNDc4NTEzODc3LCJpc3MiOiJodHRwczovL3VhYS5ib3NoLWxpdGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbInNjaW0iLCJjbG91ZF9jb250cm9sbGVyIiwicGFzc3dvcmQiLCJjZiIsInVhYSIsIm9wZW5pZCIsImRvcHBsZXIiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMiXX0.d8YS9HYM2QJ7f3xXjwHjZsGHCD2a4hM3tNQdGUQCJzT45YQkFZAJJDFIn4rai0YXJyswHmNT3K9pwKBzzcVzbe2HoMyI2HhCn3vW45OA7r55ATYmA88F1KkOtGitO_qi5NPhqDlQwg55kr6PzWAE84BXgWwivMXDDcwkyQosVYA", + TokenType: "bearer", + } + + fakeUAA.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token"), + ghttp.VerifyBasicAuth("diego-ssh", "fake-diego-ssh-secret-$\"^&'"), + ghttp.VerifyFormKV("grant_type", "authorization_code"), + ghttp.VerifyFormKV("code", expectedOneTimeCode), + ghttp.RespondWithJSONEncodedPtr(&uaaTokenResponseCode, uaaTokenResponse), + ), + ) + + sshAccessResponseCode = http.StatusOK + sshAccessResponse = &authenticators.AppSSHResponse{ + ProcessGuid: "app-guid-app-version", + } + + fakeCC.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("GET", "/internal/apps/1e051b88-a210-40b7-bcca-df645b24b634/ssh_access/1"), + ghttp.VerifyHeader(http.Header{"Authorization": []string{"bearer eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiJmMGMyYWRkN2E5MDI0NTQyOWExZTdiMjNjZGVlZjkyZiIsInN1YiI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsInNjb3BlIjpbInJvdXRpbmcucm91dGVyX2dyb3Vwcy5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwib3BlbmlkIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiLCJzY2ltLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLmFkbWluIiwidWFhLnVzZXIiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsInJldl9zaWciOiJiMzUyMDU5ZiIsImlhdCI6MTQ3ODUxMzI3NywiZXhwIjoxNDc4NTEzODc3LCJpc3MiOiJodHRwczovL3VhYS5ib3NoLWxpdGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbInNjaW0iLCJjbG91ZF9jb250cm9sbGVyIiwicGFzc3dvcmQiLCJjZiIsInVhYSIsIm9wZW5pZCIsImRvcHBsZXIiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMiXX0.d8YS9HYM2QJ7f3xXjwHjZsGHCD2a4hM3tNQdGUQCJzT45YQkFZAJJDFIn4rai0YXJyswHmNT3K9pwKBzzcVzbe2HoMyI2HhCn3vW45OA7r55ATYmA88F1KkOtGitO_qi5NPhqDlQwg55kr6PzWAE84BXgWwivMXDDcwkyQosVYA"}}), + ghttp.RespondWithJSONEncodedPtr(&sshAccessResponseCode, sshAccessResponse), + ), + ) + }) + + It("uses the client password as a one time code with the UAA", func() { + Expect(fakeUAA.ReceivedRequests()).To(HaveLen(1)) + }) + + It("fetches the app from CC using the bearer token", func() { + Expect(authenErr).NotTo(HaveOccurred()) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(1)) + }) + + It("builds permissions from the process guid of the app", func() { + Expect(permissionsBuilder.BuildCallCount()).To(Equal(1)) + + _, guid, index, metadata := permissionsBuilder.BuildArgsForCall(0) + Expect(guid).To(Equal("app-guid-app-version")) + Expect(index).To(Equal(1)) + Expect(metadata).To(Equal(metadata)) + }) + + It("logs the access to the container by the user", func() { + Eventually(logger).Should(gbytes.Say("test.cf-authenticate.app-access-success.*\"app\":\"1e051b88-a210-40b7-bcca-df645b24b634/1\".*\"principal\":\"36ba11ff-0f6a-4c50-ab34-6fbd286a643e\".*\"username\":\"admin\"")) + }) + + Context("when the token exchange fails", func() { + BeforeEach(func() { + uaaTokenResponseCode = http.StatusBadRequest + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.AuthenticationFailedErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the app guid is malformed", func() { + BeforeEach(func() { + metadata.UserReturns("cf:%X%FF/1") + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the index is not an integer", func() { + BeforeEach(func() { + metadata.UserReturns("cf:1e051b88-a210-40b7-bcca-df645b24b634/jim") + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the username is missing an index", func() { + BeforeEach(func() { + metadata.UserReturns("cf:1e051b88-a210-40b7-bcca-df645b24b634") + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the index is too big", func() { + BeforeEach(func() { + metadata.UserReturns("cf:1e051b88-a210-40b7-bcca-df645b24b634/" + strconv.FormatInt(int64(math.MaxInt64), 10) + "0") + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the cc ssh_access check returns a non-200 status code", func() { + BeforeEach(func() { + sshAccessResponseCode = http.StatusInternalServerError + sshAccessResponse = &authenticators.AppSSHResponse{} + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.FetchAppFailedErr)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(1)) + Eventually(logger).Should(gbytes.Say("test.cf-authenticate.fetching-app-failed.*\"app\":\"1e051b88-a210-40b7-bcca-df645b24b634/1\".*\"principal\":\"36ba11ff-0f6a-4c50-ab34-6fbd286a643e\".*\"username\":\"admin\"")) + }) + }) + + Context("when the cc ssh_access response cannot be parsed", func() { + BeforeEach(func() { + fakeCC.RouteToHandler("GET", "/internal/apps/1e051b88-a210-40b7-bcca-df645b24b634/ssh_access/1", ghttp.CombineHandlers( + ghttp.VerifyRequest("GET", "/internal/apps/1e051b88-a210-40b7-bcca-df645b24b634/ssh_access/1"), + ghttp.VerifyHeader(http.Header{"Authorization": []string{"bearer eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiJmMGMyYWRkN2E5MDI0NTQyOWExZTdiMjNjZGVlZjkyZiIsInN1YiI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsInNjb3BlIjpbInJvdXRpbmcucm91dGVyX2dyb3Vwcy5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwib3BlbmlkIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiLCJzY2ltLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLmFkbWluIiwidWFhLnVzZXIiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsInJldl9zaWciOiJiMzUyMDU5ZiIsImlhdCI6MTQ3ODUxMzI3NywiZXhwIjoxNDc4NTEzODc3LCJpc3MiOiJodHRwczovL3VhYS5ib3NoLWxpdGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbInNjaW0iLCJjbG91ZF9jb250cm9sbGVyIiwicGFzc3dvcmQiLCJjZiIsInVhYSIsIm9wZW5pZCIsImRvcHBsZXIiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMiXX0.d8YS9HYM2QJ7f3xXjwHjZsGHCD2a4hM3tNQdGUQCJzT45YQkFZAJJDFIn4rai0YXJyswHmNT3K9pwKBzzcVzbe2HoMyI2HhCn3vW45OA7r55ATYmA88F1KkOtGitO_qi5NPhqDlQwg55kr6PzWAE84BXgWwivMXDDcwkyQosVYA"}}), + ghttp.RespondWith(http.StatusOK, "{{"), + )) + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(Equal(authenticators.InvalidCCResponse)) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(1)) + }) + }) + + Context("the the cc ssh_access check times out", func() { + BeforeEach(func() { + ccTempClientTimeout := httpClientTimeout + fakeCC.RouteToHandler("GET", "/internal/apps/1e051b88-a210-40b7-bcca-df645b24b634/ssh_access/1", + func(w http.ResponseWriter, req *http.Request) { + time.Sleep(ccTempClientTimeout * 2) + w.Write([]byte(`[]`)) + }, + ) + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(BeAssignableToTypeOf(&url.Error{})) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(1)) + }) + }) + + Context("when the cc url is misconfigured", func() { + BeforeEach(func() { + ccURL = "http://%FF" + }) + + It("fails to authenticate", func() { + Expect(authenErr).To(HaveOccurred()) + Expect(fakeCC.ReceivedRequests()).To(HaveLen(0)) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator.go new file mode 100644 index 0000000000..fc7ba385b8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator.go @@ -0,0 +1,29 @@ +package authenticators + +import ( + "regexp" + + "golang.org/x/crypto/ssh" +) + +type CompositeAuthenticator struct { + authenticators map[*regexp.Regexp]PasswordAuthenticator +} + +func NewCompositeAuthenticator(passwordAuthenticators ...PasswordAuthenticator) *CompositeAuthenticator { + authenticators := map[*regexp.Regexp]PasswordAuthenticator{} + for _, a := range passwordAuthenticators { + authenticators[a.UserRegexp()] = a + } + return &CompositeAuthenticator{authenticators: authenticators} +} + +func (a *CompositeAuthenticator) Authenticate(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + for userRegexp, authenticator := range a.authenticators { + if userRegexp.MatchString(metadata.User()) { + return authenticator.Authenticate(metadata, password) + } + } + + return nil, InvalidCredentialsErr +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator_test.go new file mode 100644 index 0000000000..9c0bf5b982 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/composite_authenticator_test.go @@ -0,0 +1,136 @@ +package authenticators_test + +import ( + "errors" + "regexp" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("CompositeAuthenticator", func() { + Describe("Authenticate", func() { + var ( + authenticator *authenticators.CompositeAuthenticator + authens []authenticators.PasswordAuthenticator + metadata *fake_ssh.FakeConnMetadata + password []byte + ) + + BeforeEach(func() { + authens = []authenticators.PasswordAuthenticator{} + metadata = &fake_ssh.FakeConnMetadata{} + password = []byte{} + }) + + JustBeforeEach(func() { + authenticator = authenticators.NewCompositeAuthenticator(authens...) + }) + + Context("when no authenticators are specified", func() { + It("fails to authenticate", func() { + _, err := authenticator.Authenticate(metadata, password) + Expect(err).To(Equal(authenticators.InvalidCredentialsErr)) + }) + }) + + Context("when one or more authenticators are specified", func() { + var ( + authenticatorOne *fake_authenticators.FakePasswordAuthenticator + authenticatorTwo *fake_authenticators.FakePasswordAuthenticator + ) + + BeforeEach(func() { + authenticatorOne = &fake_authenticators.FakePasswordAuthenticator{} + authenticatorOne.UserRegexpReturns(regexp.MustCompile("one:.*")) + + authenticatorTwo = &fake_authenticators.FakePasswordAuthenticator{} + authenticatorTwo.UserRegexpReturns(regexp.MustCompile("two:.*")) + + authens = []authenticators.PasswordAuthenticator{ + authenticatorOne, + authenticatorTwo, + } + }) + + Context("and the users realm matches the first authenticator", func() { + BeforeEach(func() { + metadata.UserReturns("one:garbage") + }) + + Context("and the authenticator successfully authenticates", func() { + var permissions *ssh.Permissions + + BeforeEach(func() { + permissions = &ssh.Permissions{} + authenticatorOne.AuthenticateReturns(permissions, nil) + }) + + It("succeeds to authenticate", func() { + perms, err := authenticator.Authenticate(metadata, password) + + Expect(err).NotTo(HaveOccurred()) + Expect(perms).To(Equal(permissions)) + }) + + It("should provide the metadata to the authenticator", func() { + _, err := authenticator.Authenticate(metadata, password) + Expect(err).NotTo(HaveOccurred()) + m, p := authenticatorOne.AuthenticateArgsForCall(0) + + Expect(m).To(Equal(metadata)) + Expect(p).To(Equal(password)) + }) + }) + + Context("and the authenticator fails to authenticate", func() { + BeforeEach(func() { + authenticatorOne.AuthenticateReturns(nil, errors.New("boom")) + }) + + It("fails to authenticate", func() { + _, err := authenticator.Authenticate(metadata, password) + Expect(err).To(MatchError("boom")) + }) + }) + + It("does not attempt to authenticate with any other authenticators", func() { + authenticator.Authenticate(metadata, password) + Expect(authenticatorTwo.AuthenticateCallCount()).To(Equal(0)) + }) + }) + + Context("and the user realm is not valid", func() { + BeforeEach(func() { + metadata.UserReturns("one") + }) + + It("fails to authenticate", func() { + _, err := authenticator.Authenticate(metadata, password) + + Expect(err).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(authenticatorOne.AuthenticateCallCount()).To(Equal(0)) + Expect(authenticatorTwo.AuthenticateCallCount()).To(Equal(0)) + }) + }) + + Context("and the user realm does not match any authenticators", func() { + BeforeEach(func() { + metadata.UserReturns("jim:") + }) + + It("fails to authenticate", func() { + _, err := authenticator.Authenticate(metadata, password) + + Expect(err).To(Equal(authenticators.InvalidCredentialsErr)) + Expect(authenticatorOne.AuthenticateCallCount()).To(Equal(0)) + Expect(authenticatorTwo.AuthenticateCallCount()).To(Equal(0)) + }) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator.go new file mode 100644 index 0000000000..edcdc51b52 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator.go @@ -0,0 +1,65 @@ +package authenticators + +import ( + "bytes" + "regexp" + "strconv" + + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +var DiegoUserRegex *regexp.Regexp = regexp.MustCompile(`diego:([a-zA-Z0-9_-]+)/(\d+)`) + +type DiegoProxyAuthenticator struct { + logger lager.Logger + credentials []byte + permissionsBuilder PermissionsBuilder +} + +func NewDiegoProxyAuthenticator( + logger lager.Logger, + credentials []byte, + permissionsBuilder PermissionsBuilder, +) *DiegoProxyAuthenticator { + return &DiegoProxyAuthenticator{ + logger: logger, + credentials: credentials, + permissionsBuilder: permissionsBuilder, + } +} + +func (dpa *DiegoProxyAuthenticator) UserRegexp() *regexp.Regexp { + return DiegoUserRegex +} + +func (dpa *DiegoProxyAuthenticator) Authenticate(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + logger := dpa.logger.Session("diego-authenticate") + logger.Info("authenticate-starting") + defer logger.Info("authenticate-finished") + + if !DiegoUserRegex.MatchString(metadata.User()) { + logger.Error("regex-match-fail", InvalidDomainErr) + return nil, InvalidDomainErr + } + + if !bytes.Equal(dpa.credentials, password) { + logger.Error("invalid-credentials", InvalidCredentialsErr) + return nil, InvalidCredentialsErr + } + + guidAndIndex := DiegoUserRegex.FindStringSubmatch(metadata.User()) + + processGuid := guidAndIndex[1] + index, err := strconv.Atoi(guidAndIndex[2]) + if err != nil { + logger.Error("atoi-failed", err) + return nil, err + } + + permissions, err := dpa.permissionsBuilder.Build(logger, processGuid, index, metadata) + if err != nil { + logger.Error("building-ssh-permissions-failed", err) + } + return permissions, err +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator_test.go new file mode 100644 index 0000000000..e8ee0c5565 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/diego_proxy_authenticator_test.go @@ -0,0 +1,114 @@ +package authenticators_test + +import ( + "regexp" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + "code.cloudfoundry.org/lager/v3/lagertest" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("DiegoProxyAuthenticator", func() { + var ( + logger *lagertest.TestLogger + credentials []byte + permissionsBuilder *fake_authenticators.FakePermissionsBuilder + authenticator *authenticators.DiegoProxyAuthenticator + metadata *fake_ssh.FakeConnMetadata + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + credentials = []byte("some-user:some-password") + permissionsBuilder = &fake_authenticators.FakePermissionsBuilder{} + permissionsBuilder.BuildReturns(&ssh.Permissions{}, nil) + authenticator = authenticators.NewDiegoProxyAuthenticator(logger, credentials, permissionsBuilder) + + metadata = &fake_ssh.FakeConnMetadata{} + }) + + Describe("Authenticate", func() { + var ( + password []byte + authErr error + ) + + BeforeEach(func() { + password = []byte{} + }) + + JustBeforeEach(func() { + _, authErr = authenticator.Authenticate(metadata, password) + }) + + Context("when the user name matches the user regex and valid credentials are provided", func() { + BeforeEach(func() { + metadata.UserReturns("diego:some-guid/0") + password = []byte("some-user:some-password") + }) + + It("authenticates the password against the provided user:password", func() { + Expect(authErr).NotTo(HaveOccurred()) + }) + + It("builds permissions for the requested process", func() { + Expect(permissionsBuilder.BuildCallCount()).To(Equal(1)) + _, guid, index, metadata := permissionsBuilder.BuildArgsForCall(0) + Expect(guid).To(Equal("some-guid")) + Expect(index).To(Equal(0)) + Expect(metadata).To(Equal(metadata)) + }) + }) + + Context("when the user name doesn't match the user regex", func() { + BeforeEach(func() { + metadata.UserReturns("dora:some-guid") + }) + + It("fails the authentication", func() { + Expect(authErr).To(MatchError("Invalid authentication domain")) + }) + }) + + Context("when the password doesn't match the provided credentials", func() { + BeforeEach(func() { + metadata.UserReturns("diego:some-guid/0") + password = []byte("cf-user:cf-password") + }) + + It("fails the authentication", func() { + Expect(authErr).To(MatchError("Invalid credentials")) + }) + }) + }) + + Describe("UserRegexp", func() { + var regexp *regexp.Regexp + + BeforeEach(func() { + regexp = authenticator.UserRegexp() + }) + + It("matches diego patterns", func() { + Expect(regexp.MatchString("diego:guid/0")).To(BeTrue()) + Expect(regexp.MatchString("diego:123-abc-def/00")).To(BeTrue()) + Expect(regexp.MatchString("diego:guid/99")).To(BeTrue()) + }) + + It("does not match other patterns", func() { + Expect(regexp.MatchString("diego:some+guid/99")).To(BeFalse()) + Expect(regexp.MatchString("diego:..\\/something/99")).To(BeFalse()) + Expect(regexp.MatchString("diego:guid/")).To(BeFalse()) + Expect(regexp.MatchString("diego:00")).To(BeFalse()) + Expect(regexp.MatchString("diego:/00")).To(BeFalse()) + Expect(regexp.MatchString("cf:guid/0")).To(BeFalse()) + Expect(regexp.MatchString("cf:guid/99")).To(BeFalse()) + Expect(regexp.MatchString("user@guid/0")).To(BeFalse()) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/errors.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/errors.go new file mode 100644 index 0000000000..09cb054388 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/errors.go @@ -0,0 +1,14 @@ +package authenticators + +import "errors" + +var AuthenticationFailedErr = errors.New("Authentication failed") +var FetchAppFailedErr = errors.New("Fetching application data failed") +var InvalidCCResponse = errors.New("Invalid response from Cloud Controller") +var InvalidCredentialsErr error = errors.New("Invalid credentials") +var InvalidDomainErr error = errors.New("Invalid authentication domain") +var InvalidRequestErr = errors.New("CloudController URL Invalid") +var InvalidUserFormatErr = errors.New("Invalid user format") +var NotDiegoErr = errors.New("Diego Not Enabled") +var RouteNotFoundErr error = errors.New("SSH routing info not found") +var SSHDisabledErr = errors.New("SSH Disabled") diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_password_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_password_authenticator.go new file mode 100644 index 0000000000..245b332e6d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_password_authenticator.go @@ -0,0 +1,190 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_authenticators + +import ( + "regexp" + "sync" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "golang.org/x/crypto/ssh" +) + +type FakePasswordAuthenticator struct { + AuthenticateStub func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) + authenticateMutex sync.RWMutex + authenticateArgsForCall []struct { + arg1 ssh.ConnMetadata + arg2 []byte + } + authenticateReturns struct { + result1 *ssh.Permissions + result2 error + } + authenticateReturnsOnCall map[int]struct { + result1 *ssh.Permissions + result2 error + } + UserRegexpStub func() *regexp.Regexp + userRegexpMutex sync.RWMutex + userRegexpArgsForCall []struct { + } + userRegexpReturns struct { + result1 *regexp.Regexp + } + userRegexpReturnsOnCall map[int]struct { + result1 *regexp.Regexp + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakePasswordAuthenticator) Authenticate(arg1 ssh.ConnMetadata, arg2 []byte) (*ssh.Permissions, error) { + var arg2Copy []byte + if arg2 != nil { + arg2Copy = make([]byte, len(arg2)) + copy(arg2Copy, arg2) + } + fake.authenticateMutex.Lock() + ret, specificReturn := fake.authenticateReturnsOnCall[len(fake.authenticateArgsForCall)] + fake.authenticateArgsForCall = append(fake.authenticateArgsForCall, struct { + arg1 ssh.ConnMetadata + arg2 []byte + }{arg1, arg2Copy}) + fake.recordInvocation("Authenticate", []interface{}{arg1, arg2Copy}) + authenticateStubCopy := fake.AuthenticateStub + fake.authenticateMutex.Unlock() + if authenticateStubCopy != nil { + return authenticateStubCopy(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.authenticateReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakePasswordAuthenticator) AuthenticateCallCount() int { + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + return len(fake.authenticateArgsForCall) +} + +func (fake *FakePasswordAuthenticator) AuthenticateCalls(stub func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error)) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = stub +} + +func (fake *FakePasswordAuthenticator) AuthenticateArgsForCall(i int) (ssh.ConnMetadata, []byte) { + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + argsForCall := fake.authenticateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakePasswordAuthenticator) AuthenticateReturns(result1 *ssh.Permissions, result2 error) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = nil + fake.authenticateReturns = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePasswordAuthenticator) AuthenticateReturnsOnCall(i int, result1 *ssh.Permissions, result2 error) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = nil + if fake.authenticateReturnsOnCall == nil { + fake.authenticateReturnsOnCall = make(map[int]struct { + result1 *ssh.Permissions + result2 error + }) + } + fake.authenticateReturnsOnCall[i] = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePasswordAuthenticator) UserRegexp() *regexp.Regexp { + fake.userRegexpMutex.Lock() + ret, specificReturn := fake.userRegexpReturnsOnCall[len(fake.userRegexpArgsForCall)] + fake.userRegexpArgsForCall = append(fake.userRegexpArgsForCall, struct { + }{}) + fake.recordInvocation("UserRegexp", []interface{}{}) + userRegexpStubCopy := fake.UserRegexpStub + fake.userRegexpMutex.Unlock() + if userRegexpStubCopy != nil { + return userRegexpStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.userRegexpReturns + return fakeReturns.result1 +} + +func (fake *FakePasswordAuthenticator) UserRegexpCallCount() int { + fake.userRegexpMutex.RLock() + defer fake.userRegexpMutex.RUnlock() + return len(fake.userRegexpArgsForCall) +} + +func (fake *FakePasswordAuthenticator) UserRegexpCalls(stub func() *regexp.Regexp) { + fake.userRegexpMutex.Lock() + defer fake.userRegexpMutex.Unlock() + fake.UserRegexpStub = stub +} + +func (fake *FakePasswordAuthenticator) UserRegexpReturns(result1 *regexp.Regexp) { + fake.userRegexpMutex.Lock() + defer fake.userRegexpMutex.Unlock() + fake.UserRegexpStub = nil + fake.userRegexpReturns = struct { + result1 *regexp.Regexp + }{result1} +} + +func (fake *FakePasswordAuthenticator) UserRegexpReturnsOnCall(i int, result1 *regexp.Regexp) { + fake.userRegexpMutex.Lock() + defer fake.userRegexpMutex.Unlock() + fake.UserRegexpStub = nil + if fake.userRegexpReturnsOnCall == nil { + fake.userRegexpReturnsOnCall = make(map[int]struct { + result1 *regexp.Regexp + }) + } + fake.userRegexpReturnsOnCall[i] = struct { + result1 *regexp.Regexp + }{result1} +} + +func (fake *FakePasswordAuthenticator) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + fake.userRegexpMutex.RLock() + defer fake.userRegexpMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakePasswordAuthenticator) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ authenticators.PasswordAuthenticator = new(FakePasswordAuthenticator) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_permissions_builder.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_permissions_builder.go new file mode 100644 index 0000000000..f227ee61e8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_permissions_builder.go @@ -0,0 +1,124 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_authenticators + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type FakePermissionsBuilder struct { + BuildStub func(lager.Logger, string, int, ssh.ConnMetadata) (*ssh.Permissions, error) + buildMutex sync.RWMutex + buildArgsForCall []struct { + arg1 lager.Logger + arg2 string + arg3 int + arg4 ssh.ConnMetadata + } + buildReturns struct { + result1 *ssh.Permissions + result2 error + } + buildReturnsOnCall map[int]struct { + result1 *ssh.Permissions + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakePermissionsBuilder) Build(arg1 lager.Logger, arg2 string, arg3 int, arg4 ssh.ConnMetadata) (*ssh.Permissions, error) { + fake.buildMutex.Lock() + ret, specificReturn := fake.buildReturnsOnCall[len(fake.buildArgsForCall)] + fake.buildArgsForCall = append(fake.buildArgsForCall, struct { + arg1 lager.Logger + arg2 string + arg3 int + arg4 ssh.ConnMetadata + }{arg1, arg2, arg3, arg4}) + fake.recordInvocation("Build", []interface{}{arg1, arg2, arg3, arg4}) + buildStubCopy := fake.BuildStub + fake.buildMutex.Unlock() + if buildStubCopy != nil { + return buildStubCopy(arg1, arg2, arg3, arg4) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.buildReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakePermissionsBuilder) BuildCallCount() int { + fake.buildMutex.RLock() + defer fake.buildMutex.RUnlock() + return len(fake.buildArgsForCall) +} + +func (fake *FakePermissionsBuilder) BuildCalls(stub func(lager.Logger, string, int, ssh.ConnMetadata) (*ssh.Permissions, error)) { + fake.buildMutex.Lock() + defer fake.buildMutex.Unlock() + fake.BuildStub = stub +} + +func (fake *FakePermissionsBuilder) BuildArgsForCall(i int) (lager.Logger, string, int, ssh.ConnMetadata) { + fake.buildMutex.RLock() + defer fake.buildMutex.RUnlock() + argsForCall := fake.buildArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakePermissionsBuilder) BuildReturns(result1 *ssh.Permissions, result2 error) { + fake.buildMutex.Lock() + defer fake.buildMutex.Unlock() + fake.BuildStub = nil + fake.buildReturns = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePermissionsBuilder) BuildReturnsOnCall(i int, result1 *ssh.Permissions, result2 error) { + fake.buildMutex.Lock() + defer fake.buildMutex.Unlock() + fake.BuildStub = nil + if fake.buildReturnsOnCall == nil { + fake.buildReturnsOnCall = make(map[int]struct { + result1 *ssh.Permissions + result2 error + }) + } + fake.buildReturnsOnCall[i] = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePermissionsBuilder) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.buildMutex.RLock() + defer fake.buildMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakePermissionsBuilder) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ authenticators.PermissionsBuilder = new(FakePermissionsBuilder) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_public_key_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_public_key_authenticator.go new file mode 100644 index 0000000000..1a16b6b204 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/fake_public_key_authenticator.go @@ -0,0 +1,184 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_authenticators + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/authenticators" + "golang.org/x/crypto/ssh" +) + +type FakePublicKeyAuthenticator struct { + AuthenticateStub func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) + authenticateMutex sync.RWMutex + authenticateArgsForCall []struct { + arg1 ssh.ConnMetadata + arg2 ssh.PublicKey + } + authenticateReturns struct { + result1 *ssh.Permissions + result2 error + } + authenticateReturnsOnCall map[int]struct { + result1 *ssh.Permissions + result2 error + } + PublicKeyStub func() ssh.PublicKey + publicKeyMutex sync.RWMutex + publicKeyArgsForCall []struct { + } + publicKeyReturns struct { + result1 ssh.PublicKey + } + publicKeyReturnsOnCall map[int]struct { + result1 ssh.PublicKey + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakePublicKeyAuthenticator) Authenticate(arg1 ssh.ConnMetadata, arg2 ssh.PublicKey) (*ssh.Permissions, error) { + fake.authenticateMutex.Lock() + ret, specificReturn := fake.authenticateReturnsOnCall[len(fake.authenticateArgsForCall)] + fake.authenticateArgsForCall = append(fake.authenticateArgsForCall, struct { + arg1 ssh.ConnMetadata + arg2 ssh.PublicKey + }{arg1, arg2}) + fake.recordInvocation("Authenticate", []interface{}{arg1, arg2}) + authenticateStubCopy := fake.AuthenticateStub + fake.authenticateMutex.Unlock() + if authenticateStubCopy != nil { + return authenticateStubCopy(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.authenticateReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakePublicKeyAuthenticator) AuthenticateCallCount() int { + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + return len(fake.authenticateArgsForCall) +} + +func (fake *FakePublicKeyAuthenticator) AuthenticateCalls(stub func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error)) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = stub +} + +func (fake *FakePublicKeyAuthenticator) AuthenticateArgsForCall(i int) (ssh.ConnMetadata, ssh.PublicKey) { + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + argsForCall := fake.authenticateArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakePublicKeyAuthenticator) AuthenticateReturns(result1 *ssh.Permissions, result2 error) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = nil + fake.authenticateReturns = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePublicKeyAuthenticator) AuthenticateReturnsOnCall(i int, result1 *ssh.Permissions, result2 error) { + fake.authenticateMutex.Lock() + defer fake.authenticateMutex.Unlock() + fake.AuthenticateStub = nil + if fake.authenticateReturnsOnCall == nil { + fake.authenticateReturnsOnCall = make(map[int]struct { + result1 *ssh.Permissions + result2 error + }) + } + fake.authenticateReturnsOnCall[i] = struct { + result1 *ssh.Permissions + result2 error + }{result1, result2} +} + +func (fake *FakePublicKeyAuthenticator) PublicKey() ssh.PublicKey { + fake.publicKeyMutex.Lock() + ret, specificReturn := fake.publicKeyReturnsOnCall[len(fake.publicKeyArgsForCall)] + fake.publicKeyArgsForCall = append(fake.publicKeyArgsForCall, struct { + }{}) + fake.recordInvocation("PublicKey", []interface{}{}) + publicKeyStubCopy := fake.PublicKeyStub + fake.publicKeyMutex.Unlock() + if publicKeyStubCopy != nil { + return publicKeyStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.publicKeyReturns + return fakeReturns.result1 +} + +func (fake *FakePublicKeyAuthenticator) PublicKeyCallCount() int { + fake.publicKeyMutex.RLock() + defer fake.publicKeyMutex.RUnlock() + return len(fake.publicKeyArgsForCall) +} + +func (fake *FakePublicKeyAuthenticator) PublicKeyCalls(stub func() ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = stub +} + +func (fake *FakePublicKeyAuthenticator) PublicKeyReturns(result1 ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = nil + fake.publicKeyReturns = struct { + result1 ssh.PublicKey + }{result1} +} + +func (fake *FakePublicKeyAuthenticator) PublicKeyReturnsOnCall(i int, result1 ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = nil + if fake.publicKeyReturnsOnCall == nil { + fake.publicKeyReturnsOnCall = make(map[int]struct { + result1 ssh.PublicKey + }) + } + fake.publicKeyReturnsOnCall[i] = struct { + result1 ssh.PublicKey + }{result1} +} + +func (fake *FakePublicKeyAuthenticator) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.authenticateMutex.RLock() + defer fake.authenticateMutex.RUnlock() + fake.publicKeyMutex.RLock() + defer fake.publicKeyMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakePublicKeyAuthenticator) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ authenticators.PublicKeyAuthenticator = new(FakePublicKeyAuthenticator) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/package.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/package.go new file mode 100644 index 0000000000..f1991780ae --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators/package.go @@ -0,0 +1 @@ +package fake_authenticators // import "code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators" diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/package.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/package.go new file mode 100644 index 0000000000..4aad9b94b1 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/package.go @@ -0,0 +1 @@ +package authenticators // import "code.cloudfoundry.org/diego-ssh/authenticators" diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder.go new file mode 100644 index 0000000000..646a38c7c4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder.go @@ -0,0 +1,165 @@ +package authenticators + +import ( + "encoding/json" + "fmt" + + "code.cloudfoundry.org/bbs" + "code.cloudfoundry.org/bbs/models" + "code.cloudfoundry.org/diego-ssh/proxy" + "code.cloudfoundry.org/diego-ssh/routes" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type permissionsBuilder struct { + bbsClient bbs.InternalClient + useDirectInstanceAddr bool +} + +func NewPermissionsBuilder(bbsClient bbs.InternalClient, useDirectInstanceAddr bool) PermissionsBuilder { + return &permissionsBuilder{ + bbsClient: bbsClient, + useDirectInstanceAddr: useDirectInstanceAddr, + } +} + +func (pb *permissionsBuilder) Build(logger lager.Logger, processGuid string, index int, metadata ssh.ConnMetadata) (*ssh.Permissions, error) { + ind := int32(index) + filter := models.ActualLRPFilter{ + ProcessGuid: processGuid, + Index: &ind, + } + actualLRPs, err := pb.bbsClient.ActualLRPs(logger, "", filter) + if err != nil { + return nil, err + } else if len(actualLRPs) > 1 { + return nil, fmt.Errorf("multiple matching ActualLRP for ProcessGuid: %s, Index: %d", processGuid, ind) + } else if len(actualLRPs) == 0 { + return nil, fmt.Errorf("no matching ActualLRP for ProcessGuid: %s, Index: %d", processGuid, ind) + } + + desired, err := pb.bbsClient.DesiredLRPByProcessGuid(logger, "", processGuid) + if err != nil { + return nil, err + } + + sshRoute, err := getRoutingInfo(desired) + if err != nil { + return nil, err + } + + logMessage := fmt.Sprintf("Successful remote access by %s", metadata.RemoteAddr().String()) + + return pb.createPermissions(sshRoute, actualLRPs[0], desired, logMessage) +} + +func (pb *permissionsBuilder) createPermissions( + sshRoute *routes.SSHRoute, + actual *models.ActualLRP, + desired *models.DesiredLRP, + logMessage string, +) (*ssh.Permissions, error) { + var targetConfig *proxy.TargetConfig + + for _, mapping := range actual.Ports { + if mapping.ContainerPort == sshRoute.ContainerPort { + address := actual.Address + port := mapping.HostPort + var useInstanceAddr bool + switch actual.PreferredAddress { + case models.ActualLRPNetInfo_PreferredAddressInstance: + useInstanceAddr = true + case models.ActualLRPNetInfo_PreferredAddressHost: + useInstanceAddr = false + case models.ActualLRPNetInfo_PreferredAddressUnknown: + useInstanceAddr = pb.useDirectInstanceAddr + } + if useInstanceAddr { + address = actual.InstanceAddress + port = mapping.ContainerPort + } + + tlsAddress := "" + if mapping.HostTlsProxyPort > 0 { + tlsAddress = fmt.Sprintf("%s:%d", actual.Address, mapping.HostTlsProxyPort) + } + + if useInstanceAddr && mapping.ContainerTlsProxyPort > 0 { + tlsAddress = fmt.Sprintf("%s:%d", actual.InstanceAddress, mapping.ContainerTlsProxyPort) + } + + targetConfig = &proxy.TargetConfig{ + Address: fmt.Sprintf("%s:%d", address, port), + TLSAddress: tlsAddress, + ServerCertDomainSAN: actual.ActualLRPInstanceKey.InstanceGuid, + HostFingerprint: sshRoute.HostFingerprint, + User: sshRoute.User, + Password: sshRoute.Password, + PrivateKey: sshRoute.PrivateKey, + } + break + } + } + + if targetConfig == nil { + return &ssh.Permissions{}, nil + } + + targetConfigJson, err := json.Marshal(targetConfig) + if err != nil { + return nil, err + } + + if len(desired.MetricTags) == 0 { + desired.MetricTags = map[string]*models.MetricTagValue{} + } + if _, ok := desired.MetricTags["source_id"]; !ok { + desired.MetricTags["source_id"] = &models.MetricTagValue{Static: desired.LogGuid} + } + if _, ok := desired.MetricTags["instance_id"]; !ok { + desired.MetricTags["instance_id"] = &models.MetricTagValue{Dynamic: models.MetricTagDynamicValueIndex} + } + + tags, err := models.ConvertMetricTags(desired.MetricTags, map[models.MetricTagValue_DynamicValue]interface{}{ + models.MetricTagDynamicValueIndex: int32(actual.Index), + models.MetricTagDynamicValueInstanceGuid: actual.ActualLRPInstanceKey.InstanceGuid, + }) + if err != nil { + return nil, err + } + + logMessageJson, err := json.Marshal(proxy.LogMessage{ + Message: logMessage, + Tags: tags, + }) + if err != nil { + return nil, err + } + + return &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + "log-message": string(logMessageJson), + }, + }, nil +} + +func getRoutingInfo(desired *models.DesiredLRP) (*routes.SSHRoute, error) { + if desired.Routes == nil { + return nil, RouteNotFoundErr + } + + rawMessage := (*desired.Routes)[routes.DIEGO_SSH] + if rawMessage == nil { + return nil, RouteNotFoundErr + } + + var sshRoute routes.SSHRoute + err := json.Unmarshal(*rawMessage, &sshRoute) + if err != nil { + return nil, err + } + + return &sshRoute, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder_test.go new file mode 100644 index 0000000000..6740b0d896 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/permissions_builder_test.go @@ -0,0 +1,331 @@ +package authenticators_test + +import ( + "encoding/json" + "net" + + "code.cloudfoundry.org/bbs/fake_bbs" + "code.cloudfoundry.org/bbs/models" + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/routes" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("PermissionsBuilder", func() { + Describe("Build", func() { + var ( + logger *lagertest.TestLogger + expectedRoute routes.SSHRoute + desiredLRP *models.DesiredLRP + actualLRP *models.ActualLRP + bbsClient *fake_bbs.FakeInternalClient + metadata *fake_ssh.FakeConnMetadata + + permissionsBuilder authenticators.PermissionsBuilder + permissions *ssh.Permissions + buildErr error + processGuid string + index int + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + expectedRoute = routes.SSHRoute{ + ContainerPort: 1111, + PrivateKey: "fake-pem-encoded-key", + HostFingerprint: "host-fingerprint", + User: "user", + Password: "password", + } + + diegoSSHRoutePayload, err := json.Marshal(expectedRoute) + Expect(err).NotTo(HaveOccurred()) + + diegoSSHRouteMessage := json.RawMessage(diegoSSHRoutePayload) + + desiredLRP = &models.DesiredLRP{ + ProcessGuid: "some-guid", + Instances: 2, + Routes: &models.Routes{ + routes.DIEGO_SSH: &diegoSSHRouteMessage, + }, + LogGuid: "log-guid", + MetricTags: map[string]*models.MetricTagValue{ + "some_static_key": &models.MetricTagValue{Static: "some_value"}, + "some_dynamic_key": &models.MetricTagValue{Dynamic: models.MetricTagDynamicValueIndex}, + "some_other_dynamic_key": &models.MetricTagValue{Dynamic: models.MetricTagDynamicValueInstanceGuid}, + }, + } + + actualLRP = &models.ActualLRP{ + ActualLRPKey: models.NewActualLRPKey("some-guid", 1, "some-domain"), + ActualLRPInstanceKey: models.NewActualLRPInstanceKey("some-instance-guid", "some-cell-id"), + ActualLRPNetInfo: models.NewActualLRPNetInfo("1.2.3.4", "2.2.2.2", models.ActualLRPNetInfo_PreferredAddressUnknown, models.NewPortMappingWithTLSProxy(3333, 1111, 2222, 4444)), + } + + bbsClient = new(fake_bbs.FakeInternalClient) + bbsClient.ActualLRPsReturns([]*models.ActualLRP{actualLRP}, nil) + bbsClient.DesiredLRPByProcessGuidReturns(desiredLRP, nil) + + permissionsBuilder = authenticators.NewPermissionsBuilder(bbsClient, false) + + remoteAddr, err := net.ResolveIPAddr("ip", "1.1.1.1") + Expect(err).NotTo(HaveOccurred()) + metadata = &fake_ssh.FakeConnMetadata{} + metadata.RemoteAddrReturns(remoteAddr) + + processGuid = "some-guid" + index = 1 + }) + + JustBeforeEach(func() { + permissions, buildErr = permissionsBuilder.Build(logger, processGuid, index, metadata) + }) + + It("gets information about the desired lrp referenced in the username", func() { + Expect(bbsClient.DesiredLRPByProcessGuidCallCount()).To(Equal(1)) + _, traceId, guid := bbsClient.DesiredLRPByProcessGuidArgsForCall(0) + Expect(traceId).To(BeEmpty()) + Expect(guid).To(Equal("some-guid")) + }) + + It("gets information about the the actual lrp from the username", func() { + Expect(bbsClient.ActualLRPsCallCount()).To(Equal(1)) + + _, traceId, filter := bbsClient.ActualLRPsArgsForCall(0) + Expect(traceId).To(BeEmpty()) + Expect(filter.ProcessGuid).To(Equal("some-guid")) + Expect(*filter.Index).To(BeEquivalentTo(1)) + }) + + Context("ssh-proxy's connect-to-instance-address and rep's advertise-preference-for-instance-address interaction", func() { + var preferredAddress models.ActualLRPNetInfo_PreferredAddress + var connectToInstanceAddress bool + + JustBeforeEach(func() { + actualLRP.ActualLRPNetInfo = + models.NewActualLRPNetInfo("external-ip", "instance-address", preferredAddress, models.NewPortMappingWithTLSProxy(3333, 1111, 2222, 4444)) + + permissionsBuilder = authenticators.NewPermissionsBuilder(bbsClient, connectToInstanceAddress) + permissions, buildErr = permissionsBuilder.Build(logger, processGuid, index, metadata) + }) + + Context("when ssh-proxy is configured to connect to instance address, not Diego cell (external) address", func() { + BeforeEach(func() { + connectToInstanceAddress = true + }) + + Context("when the rep advertises preference for instance address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressInstance + }) + + It("saves the instance address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"instance-address:1111"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"instance-address:4444"`)) + }) + }) + + Context("when the rep advertises preference for host address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressHost + }) + + It("saves the Diego cell (external) address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"external-ip:3333"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"external-ip:2222"`)) + }) + }) + + Context("when the rep does not have preferrence for address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressUnknown + }) + + It("saves the instance address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"instance-address:1111"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"instance-address:4444"`)) + }) + }) + }) + + Context("when ssh-proxy is NOT configured to connect to instance address", func() { + BeforeEach(func() { + connectToInstanceAddress = false + }) + + Context("when the rep advertises preference for instance address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressInstance + }) + + It("saves the instance address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"instance-address:1111"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"instance-address:4444"`)) + }) + }) + + Context("when the rep does NOT advertise preference for instance address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressHost + }) + + It("saves the Diego cell (external) address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"external-ip:3333"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"external-ip:2222"`)) + }) + }) + + Context("when the rep does not have preferrence for address", func() { + BeforeEach(func() { + preferredAddress = models.ActualLRPNetInfo_PreferredAddressUnknown + }) + + It("saves the Diego cell (external) address in the critical options of the permissions", func() { + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"address":"external-ip:3333"`)) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(ContainSubstring(`"tls_address":"external-ip:2222"`)) + }) + }) + }) + }) + + Context("when the tls port isn't set", func() { + BeforeEach(func() { + actualLRP.ActualLRPNetInfo = + models.NewActualLRPNetInfo("1.2.3.4", "2.2.2.2", models.ActualLRPNetInfo_PreferredAddressUnknown, models.NewPortMapping(3333, 1111)) + }) + + It("does not include a tls address in the permissions", func() { + expectedConfig := `{ + "address": "1.2.3.4:3333", + "tls_address": "", + "server_cert_domain_san": "some-instance-guid", + "host_fingerprint": "host-fingerprint", + "private_key": "fake-pem-encoded-key", + "user": "user", + "password": "password" + }` + + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(MatchJSON(expectedConfig)) + }) + }) + + It("saves container information in the critical options of the permissions", func() { + expectedConfig := `{ + "address": "1.2.3.4:3333", + "tls_address": "1.2.3.4:2222", + "server_cert_domain_san": "some-instance-guid", + "host_fingerprint": "host-fingerprint", + "private_key": "fake-pem-encoded-key", + "user": "user", + "password": "password" + }` + + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["proxy-target-config"]).To(MatchJSON(expectedConfig)) + }) + + It("saves log message information in the critical options of the permissions", func() { + expectedConfig := `{ + "tags": { + "some_static_key": "some_value", + "some_dynamic_key": "1", + "some_other_dynamic_key": "some-instance-guid", + "source_id": "log-guid", + "instance_id": "1" + }, + "message": "Successful remote access by 1.1.1.1" + }` + + Expect(permissions).NotTo(BeNil()) + Expect(permissions.CriticalOptions).NotTo(BeNil()) + Expect(permissions.CriticalOptions["log-message"]).To(MatchJSON(expectedConfig)) + }) + + Context("when getting the desired LRP information fails", func() { + BeforeEach(func() { + bbsClient.DesiredLRPByProcessGuidReturns(nil, &models.Error{}) + }) + + It("returns the error", func() { + Expect(buildErr).To(Equal(&models.Error{})) + }) + }) + + Context("when getting the actual LRP information fails", func() { + BeforeEach(func() { + bbsClient.ActualLRPsReturns(nil, &models.Error{}) + }) + + It("returns the error", func() { + Expect(buildErr).To(Equal(&models.Error{})) + }) + }) + + Context("when the container port cannot be found", func() { + BeforeEach(func() { + actualLRP.Ports = []*models.PortMapping{} + bbsClient.ActualLRPsReturns([]*models.ActualLRP{actualLRP}, nil) + }) + + It("returns an empty permission reference", func() { + Expect(permissions).To(Equal(&ssh.Permissions{})) + }) + }) + + Context("when the desired LRP does not include routes", func() { + BeforeEach(func() { + desiredLRP.Routes = nil + bbsClient.DesiredLRPByProcessGuidReturns(desiredLRP, nil) + }) + + It("fails the authentication", func() { + Expect(buildErr).To(Equal(authenticators.RouteNotFoundErr)) + }) + }) + + Context("when the desired LRP does not include an SSH route", func() { + BeforeEach(func() { + r := *desiredLRP.Routes + delete(r, routes.DIEGO_SSH) + bbsClient.DesiredLRPByProcessGuidReturns(desiredLRP, nil) + }) + + It("fails the authentication", func() { + Expect(buildErr).To(Equal(authenticators.RouteNotFoundErr)) + }) + }) + + Context("when the ssh route fails to unmarshal", func() { + BeforeEach(func() { + message := json.RawMessage([]byte(`{,:`)) + (*desiredLRP.Routes)[routes.DIEGO_SSH] = &message + bbsClient.DesiredLRPByProcessGuidReturns(desiredLRP, nil) + }) + + It("fails the authentication", func() { + Expect(buildErr).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator.go new file mode 100644 index 0000000000..d6d7698fa6 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator.go @@ -0,0 +1,32 @@ +package authenticators + +import ( + "bytes" + "errors" + + "golang.org/x/crypto/ssh" +) + +type publicKeyAuthenticator struct { + publicKey ssh.PublicKey + marshaledPublicKey []byte +} + +func NewPublicKeyAuthenticator(publicKey ssh.PublicKey) PublicKeyAuthenticator { + return &publicKeyAuthenticator{ + publicKey: publicKey, + marshaledPublicKey: publicKey.Marshal(), + } +} + +func (a *publicKeyAuthenticator) PublicKey() ssh.PublicKey { + return a.publicKey +} + +func (a *publicKeyAuthenticator) Authenticate(conn ssh.ConnMetadata, publicKey ssh.PublicKey) (*ssh.Permissions, error) { + if bytes.Equal(publicKey.Marshal(), a.marshaledPublicKey) { + return &ssh.Permissions{}, nil + } + + return nil, errors.New("authentication failed") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator_test.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator_test.go new file mode 100644 index 0000000000..97b19b6d83 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/public_key_authenticator_test.go @@ -0,0 +1,72 @@ +package authenticators_test + +import ( + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/keys" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("PublicKeyAuthenticator", func() { + var ( + publicKey ssh.PublicKey + + authenticator authenticators.PublicKeyAuthenticator + + metadata *fake_ssh.FakeConnMetadata + clientKey ssh.PublicKey + + permissions *ssh.Permissions + authnError error + ) + + BeforeEach(func() { + keyPair, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + publicKey = keyPair.PublicKey() + + authenticator = authenticators.NewPublicKeyAuthenticator(publicKey) + + metadata = &fake_ssh.FakeConnMetadata{} + clientKey = publicKey + }) + + JustBeforeEach(func() { + permissions, authnError = authenticator.Authenticate(metadata, clientKey) + }) + + It("creates an authenticator", func() { + Expect(authenticator).NotTo(BeNil()) + Expect(authenticator.PublicKey()).To(Equal(publicKey)) + }) + + Describe("Authenticate", func() { + BeforeEach(func() { + clientKey = publicKey + }) + + Context("when the public key matches", func() { + It("does not return an error", func() { + Expect(authnError).NotTo(HaveOccurred()) + Expect(permissions).NotTo(BeNil()) + }) + }) + + Context("when the public key does not match", func() { + BeforeEach(func() { + fakeKey := &fake_ssh.FakePublicKey{} + fakeKey.MarshalReturns([]byte("go-away-alice")) + clientKey = fakeKey + }) + + It("fails the authentication", func() { + Expect(authnError).To(HaveOccurred()) + Expect(permissions).To(BeNil()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/authenticators/types.go b/src/code.cloudfoundry.org/diego-ssh/authenticators/types.go new file mode 100644 index 0000000000..3ad657d657 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/authenticators/types.go @@ -0,0 +1,26 @@ +package authenticators + +import ( + "regexp" + + "code.cloudfoundry.org/lager/v3" + + "golang.org/x/crypto/ssh" +) + +//go:generate counterfeiter -o fake_authenticators/fake_public_key_authenticator.go . PublicKeyAuthenticator +type PublicKeyAuthenticator interface { + Authenticate(metadata ssh.ConnMetadata, publicKey ssh.PublicKey) (*ssh.Permissions, error) + PublicKey() ssh.PublicKey +} + +//go:generate counterfeiter -o fake_authenticators/fake_password_authenticator.go . PasswordAuthenticator +type PasswordAuthenticator interface { + UserRegexp() *regexp.Regexp + Authenticate(metadata ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) +} + +//go:generate counterfeiter -o fake_authenticators/fake_permissions_builder.go . PermissionsBuilder +type PermissionsBuilder interface { + Build(logger lager.Logger, processGuid string, index int, metadata ssh.ConnMetadata) (*ssh.Permissions, error) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/bin/test.bash b/src/code.cloudfoundry.org/diego-ssh/bin/test.bash new file mode 100755 index 0000000000..59a4faad67 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/bin/test.bash @@ -0,0 +1,10 @@ +#!/bin/bash + +set -eu +set -o pipefail + +# shellcheck disable=SC2068 +# Double-quoting array expansion here causes ginkgo to fail +args=${@} +# run integration and store package in serial +go run github.com/onsi/ginkgo/v2/ginkgo $(echo $args | sed 's/-p //g') diff --git a/src/code.cloudfoundry.org/diego-ssh/bin/test.ps1 b/src/code.cloudfoundry.org/diego-ssh/bin/test.ps1 new file mode 100644 index 0000000000..aea3aaca4f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/bin/test.ps1 @@ -0,0 +1,19 @@ +$ErrorActionPreference = "Stop"; +trap { $host.SetShouldExit(1) } + +Write-Host "Downloading winpty DLL" +Add-Type -AssemblyName System.IO.Compression, System.IO.Compression.FileSystem +$WINPTY_DIR = "C:\winpty" +$env:WINPTY_DLL_DIR="$WINPTY_DIR\x64\bin" +if(!(Test-Path -Path $env:WINPTY_DLL_DIR )) { + New-Item -ItemType directory -Path $WINPTY_DIR -Force + (New-Object System.Net.WebClient).DownloadFile('https://github.com/rprichard/winpty/releases/download/0.4.3/winpty-0.4.3-msvc2015.zip', "$WINPTY_DIR\winpty.zip") + [System.IO.Compression.ZipFile]::ExtractToDirectory("$WINPTY_DIR\winpty.zip", "$WINPTY_DIR") +} + +Debug "$(gci env:* | sort-object name | Out-String)" + +Invoke-Expression "go run github.com/onsi/ginkgo/v2/ginkgo $args" +if ($LastExitCode -ne 0) { + throw "tests failed" +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/.gitignore b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/.gitignore new file mode 100644 index 0000000000..e7cf9d35a9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/.gitignore @@ -0,0 +1 @@ +ssh-proxy diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config.go new file mode 100644 index 0000000000..2f80058bf2 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config.go @@ -0,0 +1,106 @@ +package config + +import ( + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "os" + + "code.cloudfoundry.org/debugserver" + loggingclient "code.cloudfoundry.org/diego-logging-client" + "code.cloudfoundry.org/durationjson" + "code.cloudfoundry.org/lager/v3/lagerflags" + "code.cloudfoundry.org/tlsconfig" +) + +type SSHProxyConfig struct { + lagerflags.LagerConfig + debugserver.DebugServerConfig + Address string `json:"address,omitempty"` + HealthCheckAddress string `json:"health_check_address,omitempty"` + DisableHealthCheckServer bool `json:"disable_health_check_server,omitempty"` + HostKey string `json:"host_key"` + BBSAddress string `json:"bbs_address"` + CCAPIURL string `json:"cc_api_url"` + CCAPICACert string `json:"cc_api_ca_cert"` + UAATokenURL string `json:"uaa_token_url"` + UAAPassword string `json:"uaa_password"` + UAAUsername string `json:"uaa_username"` + UAACACert string `json:"uaa_ca_cert"` + SkipCertVerify bool `json:"skip_cert_verify"` + EnableCFAuth bool `json:"enable_cf_auth"` + EnableDiegoAuth bool `json:"enable_diego_auth"` + DiegoCredentials string `json:"diego_credentials"` + BBSCACert string `json:"bbs_ca_cert"` + BBSClientCert string `json:"bbs_client_cert"` + BBSClientKey string `json:"bbs_client_key"` + BBSClientSessionCacheSize int `json:"bbs_client_session_cache_size"` + BBSMaxIdleConnsPerHost int `json:"bbs_max_idle_conns_per_host"` + AllowedCiphers string `json:"allowed_ciphers"` + AllowedMACs string `json:"allowed_macs"` + AllowedKeyExchanges string `json:"allowed_key_exchanges"` + LoggregatorConfig loggingclient.Config `json:"loggregator"` + CommunicationTimeout durationjson.Duration `json:"communication_timeout,omitempty"` + IdleConnectionTimeout durationjson.Duration `json:"idle_connection_timeout,omitempty"` + ConnectToInstanceAddress bool `json:"connect_to_instance_address"` + + BackendsTLSEnabled bool `json:"backends_tls_enabled,omitempty"` + BackendsTLSCACerts string `json:"backends_tls_ca_certificates,omitempty"` + BackendsTLSClientCert string `json:"backends_tls_client_certificate,omitempty"` + BackendsTLSClientKey string `json:"backends_tls_client_private_key,omitempty"` +} + +func NewSSHProxyConfig(configPath string) (SSHProxyConfig, error) { + proxyConfig := SSHProxyConfig{} + + configFile, err := os.Open(configPath) + if err != nil { + return SSHProxyConfig{}, err + } + + defer configFile.Close() + + decoder := json.NewDecoder(configFile) + + err = decoder.Decode(&proxyConfig) + if err != nil { + return SSHProxyConfig{}, err + } + + return proxyConfig, nil +} + +func (c SSHProxyConfig) BackendsTLSConfig() (*tls.Config, error) { + if !c.BackendsTLSEnabled { + return nil, nil + } + + if c.BackendsTLSCACerts == "" { + return nil, errors.New("backend tls ca certificates must be specified if backend TLS is enabled") + } + + rootCAs := x509.NewCertPool() + ca, err := os.ReadFile(c.BackendsTLSCACerts) + if err != nil { + return nil, err + } + + ok := rootCAs.AppendCertsFromPEM(ca) + if !ok { + return nil, errors.New("Failed to parse backends_tls_ca_certificates") + } + + config := &tls.Config{ + RootCAs: rootCAs, + } + + if c.BackendsTLSClientCert == "" || c.BackendsTLSClientKey == "" { + return config, nil + } + + return tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(c.BackendsTLSClientCert, c.BackendsTLSClientKey), + ).Client(tlsconfig.WithAuthority(rootCAs)) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_suite_test.go new file mode 100644 index 0000000000..4d4d240f67 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_suite_test.go @@ -0,0 +1,13 @@ +package config_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestConfig(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Config Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_test.go new file mode 100644 index 0000000000..86eb238bf7 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/config_test.go @@ -0,0 +1,365 @@ +package config_test + +import ( + "crypto/tls" + "os" + "time" + + "code.cloudfoundry.org/debugserver" + "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config" + "code.cloudfoundry.org/durationjson" + "code.cloudfoundry.org/inigo/helpers/certauthority" + "code.cloudfoundry.org/lager/v3/lagerflags" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("SSHProxyConfig", func() { + Describe("#NewSSHProxyConfig", func() { + var configFilePath, configData string + + BeforeEach(func() { + configData = `{ + "address": "1.1.1.1", + "health_check_address": "2.2.2.2", + "disable_health_check_server": true, + "host_key": "I am a host key.", + "bbs_address": "3.3.3.3", + "cc_api_url": "4.4.4.4", + "cc_api_ca_cert": "I am a cc ca cert.", + "uaa_token_url": "5.5.5.5", + "uaa_password": "uaa-password", + "uaa_username": "uaa-username", + "skip_cert_verify": true, + "communication_timeout": "5s", + "enable_cf_auth": true, + "enable_diego_auth": true, + "diego_credentials": "diego-password", + "bbs_ca_cert": "I am a bbs ca cert.", + "bbs_client_cert": "I am a bbs client cert.", + "bbs_client_key": "I am a bbs client key.", + "bbs_client_session_cache_size": 10, + "bbs_max_idle_conns_per_host": 20, + "allowed_ciphers": "cipher1,cipher2,cipher3", + "allowed_macs": "mac1,mac2,mac3", + "allowed_key_exchanges": "exchange1,exchange2,exchange3", + "log_level": "debug", + "debug_address": "5.5.5.5:9090", + "connect_to_instance_address": true, + "idle_connection_timeout": "5ms", + + "backends_tls_enabled": true, + "backends_tls_ca_certificates": "./some_filepath/ca.crt", + "backends_tls_client_certificate": "./some_filepath/client.crt", + "backends_tls_client_private_key": "./some_filepath/client.key" + }` + }) + + JustBeforeEach(func() { + configFile, err := os.CreateTemp("", "ssh-proxy-config") + Expect(err).NotTo(HaveOccurred()) + + n, err := configFile.WriteString(configData) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(len(configData))) + + err = configFile.Close() + Expect(err).NotTo(HaveOccurred()) + + configFilePath = configFile.Name() + }) + + AfterEach(func() { + err := os.RemoveAll(configFilePath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("correctly parses the config file", func() { + proxyConfig, err := config.NewSSHProxyConfig(configFilePath) + Expect(err).NotTo(HaveOccurred()) + + Expect(proxyConfig).To(Equal(config.SSHProxyConfig{ + Address: "1.1.1.1", + HealthCheckAddress: "2.2.2.2", + DisableHealthCheckServer: true, + HostKey: "I am a host key.", + BBSAddress: "3.3.3.3", + CCAPIURL: "4.4.4.4", + CCAPICACert: "I am a cc ca cert.", + UAATokenURL: "5.5.5.5", + UAAPassword: "uaa-password", + UAAUsername: "uaa-username", + SkipCertVerify: true, + CommunicationTimeout: durationjson.Duration(5 * time.Second), + EnableCFAuth: true, + EnableDiegoAuth: true, + DiegoCredentials: "diego-password", + BBSCACert: "I am a bbs ca cert.", + BBSClientCert: "I am a bbs client cert.", + BBSClientKey: "I am a bbs client key.", + BBSClientSessionCacheSize: 10, + BBSMaxIdleConnsPerHost: 20, + AllowedCiphers: "cipher1,cipher2,cipher3", + AllowedMACs: "mac1,mac2,mac3", + AllowedKeyExchanges: "exchange1,exchange2,exchange3", + ConnectToInstanceAddress: true, + IdleConnectionTimeout: durationjson.Duration(5 * time.Millisecond), + LagerConfig: lagerflags.LagerConfig{ + LogLevel: lagerflags.DEBUG, + }, + DebugServerConfig: debugserver.DebugServerConfig{ + DebugAddress: "5.5.5.5:9090", + }, + + BackendsTLSEnabled: true, + BackendsTLSCACerts: "./some_filepath/ca.crt", + BackendsTLSClientCert: "./some_filepath/client.crt", + BackendsTLSClientKey: "./some_filepath/client.key", + })) + }) + + Context("when the file does not exist", func() { + It("returns an error", func() { + _, err := config.NewSSHProxyConfig("foobar") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when the file does not contain valid json", func() { + BeforeEach(func() { + configData = "{{" + }) + + It("returns an error", func() { + _, err := config.NewSSHProxyConfig(configFilePath) + Expect(err).To(HaveOccurred()) + }) + + Context("because the communication_timeout is not valid", func() { + BeforeEach(func() { + configData = `{"communication_timeout": 4234342342}` + }) + + It("returns an error", func() { + _, err := config.NewSSHProxyConfig(configFilePath) + Expect(err).To(HaveOccurred()) + }) + }) + }) + }) + + Describe("#BackendsTLSConfig", func() { + var ( + sshProxyConfig config.SSHProxyConfig + tlsConfig *tls.Config + getConfigErr error + ca certauthority.CertAuthority + certDepoDir string + ) + + JustBeforeEach(func() { + tlsConfig, getConfigErr = sshProxyConfig.BackendsTLSConfig() + }) + + BeforeEach(func() { + var err error + + certDepoDir, err = os.MkdirTemp("", "") + Expect(err).NotTo(HaveOccurred()) + + ca, err = certauthority.NewCertAuthority(certDepoDir, "ssh-proxy-ca") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(certDepoDir)).To(Succeed()) + }) + + Context("when backends tls is disabled", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSEnabled = false + }) + + It("returns an empty tls config", func() { + Expect(getConfigErr).ToNot(HaveOccurred()) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when backends tls is enabled", func() { + BeforeEach(func() { + _, serverCAFile := ca.CAAndKey() + + clientKeyFile, clientCertFile, err := ca.GenerateSelfSignedCertAndKey("client", []string{}, false) + Expect(err).NotTo(HaveOccurred()) + + sshProxyConfig.BackendsTLSEnabled = true + sshProxyConfig.BackendsTLSCACerts = serverCAFile + sshProxyConfig.BackendsTLSClientCert = clientCertFile + sshProxyConfig.BackendsTLSClientKey = clientKeyFile + }) + + It("returns a tls config", func() { + Expect(getConfigErr).ToNot(HaveOccurred()) + //lint:ignore SA1019 - ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool. + Expect(len(tlsConfig.RootCAs.Subjects())).To(BeNumerically(">", 0)) + Expect(len(tlsConfig.Certificates)).To(BeNumerically(">", 0)) + }) + + Context("when the CA cert file is NOT provided", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSCACerts = "" + }) + + It("returns an error", func() { + Expect(getConfigErr).To(MatchError(ContainSubstring("backend tls ca certificates must be specified"))) + }) + }) + + Context("when the CA cert file is provided but unreadable", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSCACerts = "non-existent-path/ca.crt" + }) + + It("returns an error", func() { + Expect(getConfigErr).To(HaveOccurred()) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when the CA cert is not valid PEM encoded", func() { + var invalidCAPath string + + BeforeEach(func() { + invalidCA, err := os.CreateTemp("", "invalid-ca.crt") + Expect(err).NotTo(HaveOccurred()) + + _, err = invalidCA.WriteString("invalid PEM") + Expect(err).NotTo(HaveOccurred()) + + err = invalidCA.Close() + Expect(err).NotTo(HaveOccurred()) + + invalidCAPath = invalidCA.Name() + sshProxyConfig.BackendsTLSCACerts = invalidCAPath + }) + + AfterEach(func() { + err := os.Remove(invalidCAPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("returns an error", func() { + Expect(getConfigErr).To(MatchError("Failed to parse backends_tls_ca_certificates")) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when the client cert file is NOT provided", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSClientCert = "" + }) + + It("should NOT set the client certificate in the TLS config", func() { + Expect(getConfigErr).ToNot(HaveOccurred()) + Expect(tlsConfig.Certificates).To(HaveLen(0)) + //lint:ignore SA1019 - ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool. + Expect(len(tlsConfig.RootCAs.Subjects())).To(BeNumerically(">", 0)) + }) + }) + + Context("when the client key file is NOT provided", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSClientKey = "" + }) + + It("should NOT set the client certificate in the TLS config", func() { + Expect(getConfigErr).ToNot(HaveOccurred()) + Expect(tlsConfig.Certificates).To(HaveLen(0)) + //lint:ignore SA1019 - ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool. + Expect(len(tlsConfig.RootCAs.Subjects())).To(BeNumerically(">", 0)) + }) + }) + + Context("when the client cert file and the key file are both provided", func() { + Context("when the client cert file cannot be read", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSClientCert = "non-existant-path/client.crt" + }) + + It("returns an error", func() { + Expect(getConfigErr).To(HaveOccurred()) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when the client key file cannot be read", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSClientKey = "non-existant-path/client.key" + }) + + It("returns an error", func() { + Expect(getConfigErr).To(HaveOccurred()) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when the client certificate is not valid PEM encoded", func() { + var invalidCertPath string + + BeforeEach(func() { + invalidCert, err := os.CreateTemp("", "invalid-cert.crt") + Expect(err).NotTo(HaveOccurred()) + + _, err = invalidCert.WriteString("invalid PEM") + Expect(err).NotTo(HaveOccurred()) + + err = invalidCert.Close() + Expect(err).NotTo(HaveOccurred()) + + invalidCertPath = invalidCert.Name() + sshProxyConfig.BackendsTLSClientCert = invalidCertPath + }) + + AfterEach(func() { + err := os.Remove(invalidCertPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("returns an error", func() { + Expect(getConfigErr).To(MatchError(ContainSubstring("failed to load keypair"))) + Expect(tlsConfig).To(BeNil()) + }) + }) + + Context("when the client key is not valid PEM encoded", func() { + var invalidKeyPath string + + BeforeEach(func() { + invalidKey, err := os.CreateTemp("", "invalid-key.key") + Expect(err).NotTo(HaveOccurred()) + + _, err = invalidKey.WriteString("invalid PEM") + Expect(err).NotTo(HaveOccurred()) + + err = invalidKey.Close() + Expect(err).NotTo(HaveOccurred()) + + invalidKeyPath = invalidKey.Name() + sshProxyConfig.BackendsTLSClientKey = invalidKeyPath + }) + + AfterEach(func() { + err := os.Remove(invalidKeyPath) + Expect(err).NotTo(HaveOccurred()) + }) + + It("returns an error", func() { + Expect(getConfigErr).To(MatchError(ContainSubstring("failed to load keypair"))) + Expect(tlsConfig).To(BeNil()) + }) + }) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/package.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/package.go new file mode 100644 index 0000000000..bbdab51451 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config/package.go @@ -0,0 +1 @@ +package config // import "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config" diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main.go new file mode 100644 index 0000000000..9f6fb38508 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main.go @@ -0,0 +1,252 @@ +package main + +import ( + "errors" + "flag" + "net/url" + "os" + "strings" + "time" + + "code.cloudfoundry.org/bbs" + "code.cloudfoundry.org/debugserver" + loggingclient "code.cloudfoundry.org/diego-logging-client" + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config" + "code.cloudfoundry.org/diego-ssh/healthcheck" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/proxy" + "code.cloudfoundry.org/diego-ssh/server" + "code.cloudfoundry.org/go-loggregator/v9/runtimeemitter" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagerflags" + "github.com/tedsuo/ifrit" + "github.com/tedsuo/ifrit/grouper" + "github.com/tedsuo/ifrit/http_server" + "github.com/tedsuo/ifrit/sigmon" + "golang.org/x/crypto/ssh" +) + +var configPath = flag.String( + "config", + "", + "Path to SSH Proxy config.", +) + +func main() { + debugserver.AddFlags(flag.CommandLine) + flag.Parse() + + sshProxyConfig, err := config.NewSSHProxyConfig(*configPath) + if err != nil { + logger, _ := lagerflags.New("ssh-proxy") + logger.Fatal("failed-to-parse-config", err) + } + + logger, reconfigurableSink := lagerflags.NewFromConfig("ssh-proxy", sshProxyConfig.LagerConfig) + + metronClient, err := initializeMetron(logger, sshProxyConfig) + if err != nil { + logger.Error("failed-to-initialize-metron-client", err) + os.Exit(1) + } + + proxySSHServerConfig, err := configureProxy(logger, sshProxyConfig) + if err != nil { + logger.Error("configure-failed", err) + os.Exit(1) + } + + tlsConfig, err := sshProxyConfig.BackendsTLSConfig() + if err != nil { + logger.Error("failed-to-get-tls-config", err) + os.Exit(1) + } + sshProxy := proxy.New(logger, proxySSHServerConfig, metronClient, tlsConfig) + server := server.NewServer(logger, sshProxyConfig.Address, sshProxy, time.Duration(sshProxyConfig.IdleConnectionTimeout)) + + healthCheckHandler := healthcheck.NewHandler(logger) + + members := grouper.Members{ + {Name: "ssh-proxy", Runner: server}, + } + + if !sshProxyConfig.DisableHealthCheckServer { + httpServer := http_server.New(sshProxyConfig.HealthCheckAddress, healthCheckHandler) + members = append(members, grouper.Member{Name: "healthcheck", Runner: httpServer}) + } + + if sshProxyConfig.DebugAddress != "" { + members = append(grouper.Members{{ + Name: "debug-server", Runner: debugserver.Runner(sshProxyConfig.DebugAddress, reconfigurableSink), + }}, members...) + } + + group := grouper.NewOrdered(os.Interrupt, members) + monitor := ifrit.Invoke(sigmon.New(group)) + + logger.Info("started") + + err = <-monitor.Wait() + if err != nil { + logger.Error("exited-with-failure", err) + os.Exit(1) + } + + logger.Info("exited") + os.Exit(0) +} + +func configureProxy(logger lager.Logger, sshProxyConfig config.SSHProxyConfig) (*ssh.ServerConfig, error) { + if sshProxyConfig.BBSAddress == "" { + err := errors.New("bbsAddress is required") + logger.Fatal("bbs-address-required", err) + } + + url, err := url.Parse(sshProxyConfig.BBSAddress) + if err != nil { + logger.Fatal("failed-to-parse-bbs-address", err) + } + + bbsClient := initializeBBSClient(logger, sshProxyConfig) + permissionsBuilder := authenticators.NewPermissionsBuilder(bbsClient, sshProxyConfig.ConnectToInstanceAddress) + + authens := []authenticators.PasswordAuthenticator{} + + if sshProxyConfig.EnableDiegoAuth { + diegoAuthenticator := authenticators.NewDiegoProxyAuthenticator(logger, []byte(sshProxyConfig.DiegoCredentials), permissionsBuilder) + authens = append(authens, diegoAuthenticator) + } + + if sshProxyConfig.EnableCFAuth { + if sshProxyConfig.CCAPIURL == "" { + return nil, errors.New("ccAPIURL is required for Cloud Foundry authentication") + } + + _, err = url.Parse(sshProxyConfig.CCAPIURL) + if err != nil { + return nil, err + } + + if sshProxyConfig.UAAPassword == "" { + return nil, errors.New("UAA password is required for Cloud Foundry authentication") + } + + if sshProxyConfig.UAAUsername == "" { + return nil, errors.New("UAA username is required for Cloud Foundry authentication") + } + + if sshProxyConfig.UAATokenURL == "" { + return nil, errors.New("uaaTokenURL is required for Cloud Foundry authentication") + } + + _, err = url.Parse(sshProxyConfig.UAATokenURL) + if err != nil { + return nil, err + } + + client, err := helpers.NewHTTPSClient(sshProxyConfig.SkipCertVerify, []string{sshProxyConfig.UAACACert, sshProxyConfig.CCAPICACert}, time.Duration(sshProxyConfig.CommunicationTimeout)) + if err != nil { + return nil, err + } + + cfAuthenticator := authenticators.NewCFAuthenticator( + logger, + client, + sshProxyConfig.CCAPIURL, + sshProxyConfig.UAATokenURL, + sshProxyConfig.UAAUsername, + sshProxyConfig.UAAPassword, + permissionsBuilder, + ) + authens = append(authens, cfAuthenticator) + } + + authenticator := authenticators.NewCompositeAuthenticator(authens...) + + sshConfig := &ssh.ServerConfig{ + ServerVersion: "SSH-2.0-diego-ssh-proxy", + PasswordCallback: authenticator.Authenticate, + AuthLogCallback: func(cmd ssh.ConnMetadata, method string, err error) { + if err != nil { + logger.Error("authentication-failed", err, lager.Data{"user": cmd.User()}) + } else { + logger.Info("authentication-attempted", lager.Data{"user": cmd.User()}) + } + }, + } + + sshConfig.SetDefaults() + + if sshProxyConfig.HostKey == "" { + err := errors.New("hostKey is required") + logger.Fatal("host-key-required", err) + } + + key, err := parsePrivateKey(logger, sshProxyConfig.HostKey) + if err != nil { + logger.Fatal("failed-to-parse-host-key", err) + } + + sshConfig.AddHostKey(key) + + if sshProxyConfig.AllowedCiphers != "" { + sshConfig.Config.Ciphers = strings.Split(sshProxyConfig.AllowedCiphers, ",") + } else { + sshConfig.Config.Ciphers = []string{"aes128-gcm@openssh.com", "aes256-ctr", "aes192-ctr", "aes128-ctr"} + } + + if sshProxyConfig.AllowedMACs != "" { + sshConfig.Config.MACs = strings.Split(sshProxyConfig.AllowedMACs, ",") + } else { + sshConfig.Config.MACs = []string{"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256"} + } + + if sshProxyConfig.AllowedKeyExchanges != "" { + sshConfig.Config.KeyExchanges = strings.Split(sshProxyConfig.AllowedKeyExchanges, ",") + } else { + sshConfig.Config.KeyExchanges = []string{"curve25519-sha256@libssh.org"} + } + + return sshConfig, err +} + +func parsePrivateKey(logger lager.Logger, encodedKey string) (ssh.Signer, error) { + key, err := ssh.ParsePrivateKey([]byte(encodedKey)) + if err != nil { + logger.Error("failed-to-parse-private-key", err) + return nil, err + } + return key, nil +} + +func initializeBBSClient(logger lager.Logger, sshProxyConfig config.SSHProxyConfig) bbs.InternalClient { + bbsClient, err := bbs.NewClientWithConfig(bbs.ClientConfig{ + URL: sshProxyConfig.BBSAddress, + IsTLS: true, + CAFile: sshProxyConfig.BBSCACert, + CertFile: sshProxyConfig.BBSClientCert, + KeyFile: sshProxyConfig.BBSClientKey, + ClientSessionCacheSize: sshProxyConfig.BBSClientSessionCacheSize, + MaxIdleConnsPerHost: sshProxyConfig.BBSMaxIdleConnsPerHost, + RequestTimeout: time.Duration(sshProxyConfig.CommunicationTimeout), + }) + if err != nil { + logger.Fatal("Failed to configure secure BBS client", err) + } + return bbsClient +} + +func initializeMetron(logger lager.Logger, locketConfig config.SSHProxyConfig) (loggingclient.IngressClient, error) { + client, err := loggingclient.NewIngressClient(locketConfig.LoggregatorConfig) + if err != nil { + return nil, err + } + + if locketConfig.LoggregatorConfig.UseV2API { + emitter := runtimeemitter.NewV1(client) + go emitter.Run() + } + + return client, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_suite_test.go new file mode 100644 index 0000000000..1b8b6cd04c --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_suite_test.go @@ -0,0 +1,134 @@ +package main_test + +import ( + "encoding/json" + "fmt" + "runtime" + "testing" + "time" + + "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + "code.cloudfoundry.org/diego-ssh/keys" + "code.cloudfoundry.org/inigo/helpers/portauthority" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gexec" + "github.com/tedsuo/ifrit" + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" +) + +var ( + sshProxyPath string + sshdPath string + sshdProcess ifrit.Process + + sshdPort uint16 + sshdTLSPort uint16 + sshdContainerPort uint16 + sshdContainerTLSPort uint16 + sshProxyPort uint16 + healthCheckProxyPort uint16 + + sshdAddress string + + hostKeyPem string + privateKeyPem string + publicAuthorizedKey string + + portAllocator portauthority.PortAllocator +) + +func TestSSHProxy(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "SSH Proxy Suite") +} + +var _ = SynchronizedBeforeSuite(func() []byte { + sshProxy, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy", "-race") + Expect(err).NotTo(HaveOccurred()) + + sshd, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/sshd", "-race") + Expect(err).NotTo(HaveOccurred()) + + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + privateKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + payload, err := json.Marshal(map[string]string{ + "ssh-proxy": sshProxy, + "sshd": sshd, + "host-key": hostKey.PEMEncodedPrivateKey(), + "private-key": privateKey.PEMEncodedPrivateKey(), + "authorized-key": privateKey.AuthorizedKey(), + }) + + Expect(err).NotTo(HaveOccurred()) + + return payload +}, func(payload []byte) { + context := map[string]string{} + + err := json.Unmarshal(payload, &context) + Expect(err).NotTo(HaveOccurred()) + + hostKeyPem = context["host-key"] + privateKeyPem = context["private-key"] + publicAuthorizedKey = context["authorized-key"] + + node := GinkgoParallelProcess() + startPort := 1070*node + 10 + portRange := 1000 + endPort := startPort + portRange + + portAllocator, err = portauthority.New(startPort, endPort) + Expect(err).NotTo(HaveOccurred()) + + sshdPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + + sshdContainerPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + sshdPath = context["sshd"] + + sshdTLSPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + + sshdContainerTLSPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + sshdPath = context["sshd"] + + sshProxyPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + sshProxyPath = context["ssh-proxy"] + + healthCheckProxyPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) +}) + +var _ = BeforeEach(func() { + + if runtime.GOOS == "windows" { + Skip("SSH not supported on Windows, and SSH proxy never runs on Windows anyway") + } + + sshdAddress = fmt.Sprintf("127.0.0.1:%d", sshdPort) + sshdArgs := testrunner.Args{ + Address: sshdAddress, + HostKey: hostKeyPem, + AuthorizedKey: publicAuthorizedKey, + } + + runner := testrunner.New(sshdPath, sshdArgs) + sshdProcess = ifrit.Invoke(runner) +}) + +var _ = AfterEach(func() { + ginkgomon.Kill(sshdProcess, 5*time.Second) +}) + +var _ = SynchronizedAfterSuite(func() { +}, func() { + gexec.CleanupBuildArtifacts() +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_test.go new file mode 100644 index 0000000000..75bf1e7592 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/main_test.go @@ -0,0 +1,1206 @@ +package main_test + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "code.cloudfoundry.org/bbs/models" + "code.cloudfoundry.org/diego-logging-client/testhelpers" + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/config" + "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner" + sshdtestrunner "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/routes" + "code.cloudfoundry.org/durationjson" + "code.cloudfoundry.org/go-loggregator/v9/rpc/loggregator_v2" + "code.cloudfoundry.org/inigo/helpers/certauthority" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagerflags" + "code.cloudfoundry.org/lager/v3/lagertest" + "code.cloudfoundry.org/tlsconfig" + "github.com/gogo/protobuf/proto" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" + "github.com/onsi/gomega/gexec" + "github.com/onsi/gomega/ghttp" + "github.com/tedsuo/ifrit" + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("SSH proxy", Serial, func() { + var ( + fakeBBS *ghttp.Server + fakeUAA *ghttp.Server + fakeCC *ghttp.Server + runner ifrit.Runner + process ifrit.Process + sshProxyConfig *config.SSHProxyConfig + sshProxyConfigPath string + certDepoDir string + ca certauthority.CertAuthority + + address string + healthCheckAddress string + diegoCredentials string + hostKeyFingerprint string + expectedGetActualLRPRequest *models.ActualLRPsRequest + actualLRPsResponse *models.ActualLRPsResponse + getDesiredLRPRequest *models.DesiredLRPByProcessGuidRequest + desiredLRPResponse *models.DesiredLRPResponse + + processGuid string + clientConfig *ssh.ClientConfig + ) + + BeforeEach(func() { + var err error + certDepoDir, err = os.MkdirTemp("", "ssh-proxy-certs-") + Expect(err).NotTo(HaveOccurred()) + + ca, err = certauthority.NewCertAuthority(certDepoDir, "ssh-proxy-ca") + Expect(err).NotTo(HaveOccurred()) + + serverKeyFile, serverCertFile, err := ca.GenerateSelfSignedCertAndKey("server", []string{}, false) + Expect(err).NotTo(HaveOccurred()) + _, serverCAFile := ca.CAAndKey() + + fakeBBS = ghttp.NewUnstartedServer() + fakeBBS.HTTPTestServer.TLS, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(serverCertFile, serverKeyFile), + ).Server(tlsconfig.WithClientAuthenticationFromFile(serverCAFile)) + Expect(err).NotTo(HaveOccurred()) + fakeBBS.HTTPTestServer.StartTLS() + + fakeUAA = ghttp.NewUnstartedServer() + fakeUAA.HTTPTestServer.TLS, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(serverCertFile, serverKeyFile), + ).Server(tlsconfig.WithClientAuthenticationFromFile(serverCAFile)) + Expect(err).NotTo(HaveOccurred()) + fakeUAA.HTTPTestServer.TLS.ClientAuth = tls.NoClientCert + fakeUAA.HTTPTestServer.StartTLS() + + fakeCC = ghttp.NewUnstartedServer() + fakeCC.HTTPTestServer.TLS, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(serverCertFile, serverKeyFile), + ).Server(tlsconfig.WithClientAuthenticationFromFile(serverCAFile)) + Expect(err).NotTo(HaveOccurred()) + fakeCC.HTTPTestServer.TLS.ClientAuth = tls.NoClientCert + fakeCC.HTTPTestServer.StartTLS() + + privateKey, err := ssh.ParsePrivateKey([]byte(hostKeyPem)) + Expect(err).NotTo(HaveOccurred()) + hostKeyFingerprint = helpers.MD5Fingerprint(privateKey.PublicKey()) + + address = fmt.Sprintf("127.0.0.1:%d", sshProxyPort) + healthCheckAddress = fmt.Sprintf("127.0.0.1:%d", healthCheckProxyPort) + diegoCredentials = "some-creds" + processGuid = "app-guid-app-version" + + u, err := url.Parse(fakeUAA.URL()) + Expect(err).NotTo(HaveOccurred()) + + u.Path = "/oauth/token" + + sshProxyConfig = &config.SSHProxyConfig{} + sshProxyConfig.Address = address + sshProxyConfig.HealthCheckAddress = healthCheckAddress + sshProxyConfig.BBSAddress = fakeBBS.URL() + sshProxyConfig.BBSCACert = serverCAFile + sshProxyConfig.BBSClientCert = serverCertFile + sshProxyConfig.BBSClientKey = serverKeyFile + sshProxyConfig.CCAPIURL = fakeCC.URL() + sshProxyConfig.CCAPICACert = serverCAFile + sshProxyConfig.DiegoCredentials = diegoCredentials + sshProxyConfig.EnableCFAuth = true + sshProxyConfig.EnableDiegoAuth = true + sshProxyConfig.HostKey = hostKeyPem + sshProxyConfig.SkipCertVerify = false + sshProxyConfig.UAATokenURL = u.String() + sshProxyConfig.UAAPassword = "password1" + sshProxyConfig.UAAUsername = "amandaplease" + sshProxyConfig.UAACACert = serverCAFile + sshProxyConfig.IdleConnectionTimeout = durationjson.Duration(500 * time.Millisecond) + sshProxyConfig.CommunicationTimeout = durationjson.Duration(10 * time.Second) + sshProxyConfig.ConnectToInstanceAddress = false + sshProxyConfig.LagerConfig = lagerflags.DefaultLagerConfig() + + expectedGetActualLRPRequest = &models.ActualLRPsRequest{ + ProcessGuid: processGuid, + OptionalIndex: &models.ActualLRPsRequest_Index{Index: 99}, + } + + actualLRPsResponse = &models.ActualLRPsResponse{ + Error: nil, + ActualLrps: []*models.ActualLRP{ + &models.ActualLRP{ + ActualLRPKey: models.NewActualLRPKey(processGuid, 99, "some-domain"), + ActualLRPInstanceKey: models.NewActualLRPInstanceKey("some-instance-guid", "some-cell-id"), + ActualLRPNetInfo: models.NewActualLRPNetInfo("127.0.0.1", "127.0.0.1", models.ActualLRPNetInfo_PreferredAddressUnknown, models.NewPortMappingWithTLSProxy(uint32(sshdPort), uint32(sshdContainerPort), uint32(sshdTLSPort), uint32(sshdContainerTLSPort))), + }, + }, + } + + getDesiredLRPRequest = &models.DesiredLRPByProcessGuidRequest{ + ProcessGuid: processGuid, + } + + sshRoute, err := json.Marshal(routes.SSHRoute{ + ContainerPort: uint32(sshdContainerPort), + PrivateKey: privateKeyPem, + HostFingerprint: hostKeyFingerprint, + }) + Expect(err).NotTo(HaveOccurred()) + + sshRouteMessage := json.RawMessage(sshRoute) + desiredLRPResponse = &models.DesiredLRPResponse{ + Error: nil, + DesiredLrp: &models.DesiredLRP{ + ProcessGuid: processGuid, + Instances: 100, + Routes: &models.Routes{routes.DIEGO_SSH: &sshRouteMessage}, + }, + } + + clientConfig = &ssh.ClientConfig{} + }) + + JustBeforeEach(func() { + fakeBBS.RouteToHandler("POST", "/v1/actual_lrps/list", ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/v1/actual_lrps/list"), + VerifyProto(expectedGetActualLRPRequest), + RespondWithProto(actualLRPsResponse), + )) + fakeBBS.RouteToHandler("POST", "/v1/desired_lrps/get_by_process_guid.r3", ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/v1/desired_lrps/get_by_process_guid.r3"), + VerifyProto(getDesiredLRPRequest), + RespondWithProto(desiredLRPResponse), + )) + + configData, err := json.Marshal(&sshProxyConfig) + Expect(err).NotTo(HaveOccurred()) + + configFile, err := os.CreateTemp("", "ssh-proxy-config") + Expect(err).NotTo(HaveOccurred()) + + n, err := configFile.Write(configData) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(len(configData))) + + sshProxyConfigPath = configFile.Name() + + err = configFile.Close() + Expect(err).NotTo(HaveOccurred()) + + runner = testrunner.New(sshProxyPath, sshProxyConfigPath) + process = ifrit.Invoke(runner) + }) + + AfterEach(func() { + ginkgomon.Kill(process, 3*time.Second) + + err := os.RemoveAll(sshProxyConfigPath) + Expect(err).NotTo(HaveOccurred()) + + Expect(os.RemoveAll(certDepoDir)).To(Succeed()) + + fakeBBS.Close() + fakeUAA.Close() + fakeCC.Close() + }) + + Describe("argument validation", func() { + Context("when the host key is not provided", func() { + BeforeEach(func() { + sshProxyConfig.HostKey = "" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("hostKey is required")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when an ill-formed host key is provided", func() { + BeforeEach(func() { + sshProxyConfig.HostKey = "host-key" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("failed-to-parse-host-key")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when the BBS address is missing", func() { + BeforeEach(func() { + sshProxyConfig.BBSAddress = "" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("bbsAddress is required")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when the BBS address cannot be parsed", func() { + BeforeEach(func() { + sshProxyConfig.BBSAddress = ":://goober-swallow#yuck" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("failed-to-parse-bbs-address")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when CF authentication is enabled", func() { + BeforeEach(func() { + sshProxyConfig.EnableCFAuth = true + }) + + Context("when the cc URL is missing", func() { + BeforeEach(func() { + sshProxyConfig.CCAPIURL = "" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("ccAPIURL is required for Cloud Foundry authentication")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when the cc URL cannot be parsed", func() { + BeforeEach(func() { + sshProxyConfig.CCAPIURL = ":://goober-swallow#yuck" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("configure-failed")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when cc ca cert does not exist", func() { + BeforeEach(func() { + sshProxyConfig.CCAPICACert = "doesnotexist" + }) + + It("exits with an error", func() { + Expect(runner).To(gbytes.Say("failed to read ca cert")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when the uaa URL is missing", func() { + BeforeEach(func() { + sshProxyConfig.UAATokenURL = "" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("uaaTokenURL is required for Cloud Foundry authentication")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when the UAA password is missing", func() { + BeforeEach(func() { + sshProxyConfig.UAAPassword = "" + }) + + It("exits with an error", func() { + Expect(runner).To(gbytes.Say("UAA password is required for Cloud Foundry authentication")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when the UAA username is missing", func() { + BeforeEach(func() { + sshProxyConfig.UAAUsername = "" + }) + + It("exits with an error", func() { + Expect(runner).To(gbytes.Say("UAA username is required for Cloud Foundry authentication")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when the UAA URL cannot be parsed", func() { + BeforeEach(func() { + sshProxyConfig.UAATokenURL = ":://spitting#nickles" + }) + + It("reports the problem and terminates", func() { + Expect(runner).To(gbytes.Say("configure-failed")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + + Context("when UAA ca cert does not exist", func() { + BeforeEach(func() { + sshProxyConfig.UAACACert = "doesnotexist" + }) + + It("exits with an error", func() { + Expect(runner).To(gbytes.Say("failed to read ca cert")) + Expect(runner).To(gexec.Exit(1)) + }) + }) + }) + }) + + It("presents the correct host key", func() { + var handshakeHostKey ssh.PublicKey + _, err := ssh.Dial("tcp", address, &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ssh.Password("")}, + HostKeyCallback: func(_ string, _ net.Addr, key ssh.PublicKey) error { + handshakeHostKey = key + return errors.New("Short-circuit the handshake") + }, + }) + Expect(err).To(HaveOccurred()) + + proxyHostKey, err := ssh.ParsePrivateKey([]byte(hostKeyPem)) + Expect(err).NotTo(HaveOccurred()) + Expect(proxyHostKey.PublicKey().Marshal()).To(Equal(handshakeHostKey.Marshal())) + }) + + Describe("Disabled http healthcheck server", func() { + BeforeEach(func() { + sshProxyConfig.DisableHealthCheckServer = true + }) + + It("is not running the healthcheck process", func() { + req, err := http.NewRequest("GET", "http://"+healthCheckAddress, nil) + Expect(err).NotTo(HaveOccurred()) + _, err = http.DefaultClient.Do(req) + e, ok := err.(net.Error) + Expect(ok).To(BeTrue()) + Expect(e.Error()).To(MatchRegexp(".*connection refused")) + }) + }) + + Describe("http healthcheck server", func() { + var ( + method, path string + resp *http.Response + ) + + JustBeforeEach(func() { + req, err := http.NewRequest(method, "http://"+healthCheckAddress+path, nil) + Expect(err).NotTo(HaveOccurred()) + resp, err = http.DefaultClient.Do(req) + Expect(err).NotTo(HaveOccurred()) + }) + + Context("valid requests", func() { + BeforeEach(func() { + method = "GET" + path = "/" + }) + + It("returns 200", func() { + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + }) + }) + + Context("invalid requests", func() { + Context("invalid method", func() { + BeforeEach(func() { + method = "POST" + path = "/" + }) + + It("returns 405", func() { + Expect(resp.StatusCode).To(Equal(http.StatusMethodNotAllowed)) + }) + }) + + Context("invalid path", func() { + BeforeEach(func() { + method = "GET" + path = "/foo/bar" + }) + + It("returns 404", func() { + Expect(resp.StatusCode).To(Equal(http.StatusNotFound)) + }) + }) + }) + }) + + Describe("attempting authentication without a realm", func() { + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("fails the authentication", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Describe("attempting authentication with an unknown realm", func() { + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: "goo:" + processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("fails the authentication", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Describe("authenticating with the diego realm", func() { + var ( + intermediaryTLSConfig *tls.Config + intermediaryListener net.Listener + connectedToTLS chan struct{} + forwardServer *forwardTLSServer + ) + + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: "diego:" + processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + serverKeyFile, serverCertFile, err := ca.GenerateSelfSignedCertAndKey("server", []string{"some-instance-guid"}, false) + Expect(err).NotTo(HaveOccurred()) + _, serverCAFile := ca.CAAndKey() + + intermediaryTLSConfig, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(serverCertFile, serverKeyFile), + ).Server(tlsconfig.WithClientAuthenticationFromFile(serverCAFile)) + Expect(err).NotTo(HaveOccurred()) + + intermediaryListener, err = tls.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", sshdTLSPort), intermediaryTLSConfig) + Expect(err).NotTo(HaveOccurred()) + + connectedToTLS = make(chan struct{}, 1) + logger := lagertest.NewTestLogger("ssh-proxy-test") + forwardServer = NewForwardTLSServer(logger, intermediaryListener, sshdAddress) + }) + + JustBeforeEach(func() { + go forwardServer.Start(connectedToTLS) + }) + + AfterEach(func() { + forwardServer.Stop() + close(connectedToTLS) + }) + + It("acquires the desired and actual LRP info from the BBS", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(2)) + }) + + It("connects to the target daemon", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + Expect(string(output)).To(Equal("hello")) + }) + + Context("when a tls intermediary is configured", func() { + Context("when ssh-proxy is configured to connect to the intermediary", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSEnabled = true + }) + + Context("when the tls handshake is via non-MTLS", func() { + BeforeEach(func() { + _, serverCAFile := ca.CAAndKey() + sshProxyConfig.BackendsTLSCACerts = serverCAFile + + intermediaryTLSConfig.ClientAuth = tls.NoClientCert + }) + + It("connects to the target daemon using tls", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + Eventually(connectedToTLS).Should(Receive()) + + _, err = client.NewSession() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when the tls handshake is via MTLS", func() { + BeforeEach(func() { + serverKeyFile, serverCertFile, err := ca.GenerateSelfSignedCertAndKey("server", []string{}, false) + Expect(err).NotTo(HaveOccurred()) + _, serverCAFile := ca.CAAndKey() + + sshProxyConfig.BackendsTLSCACerts = serverCAFile + sshProxyConfig.BackendsTLSClientCert = serverCertFile + sshProxyConfig.BackendsTLSClientKey = serverKeyFile + + intermediaryTLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + }) + + It("connects to the target daemon using MTLS", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + Eventually(connectedToTLS).Should(Receive()) + + _, err = client.NewSession() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when connecting using TLS fails", func() { + BeforeEach(func() { + // force TLS handshake to fail + otherCA, err := certauthority.NewCertAuthority(certDepoDir, "other_server_ca") + Expect(err).NotTo(HaveOccurred()) + + _, otherCAFile := otherCA.CAAndKey() + + sshProxyConfig.BackendsTLSCACerts = otherCAFile + + intermediaryTLSConfig.ClientAuth = tls.NoClientCert + }) + + It("connects to the daemon without TLS", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Consistently(connectedToTLS).ShouldNot(Receive()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + Expect(string(output)).To(Equal("hello")) + }) + }) + }) + + Context("when ssh-proxy is NOT configured to connect to the intermediary", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSEnabled = false + }) + + It("connects to the target daemon without using tls", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Consistently(connectedToTLS).ShouldNot(Receive()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + Expect(string(output)).To(Equal("hello")) + }) + }) + }) + + Context("when there is NO tls intermediary configured", func() { + BeforeEach(func() { + intermediaryListener.Close() + }) + + Context("when ssh-proxy is configured to connect to a tls intermediary", func() { + BeforeEach(func() { + sshProxyConfig.BackendsTLSEnabled = true + _, serverCAFile := ca.CAAndKey() + sshProxyConfig.BackendsTLSCACerts = serverCAFile + }) + + It("connects to the daemon without using tls and logs appropriately", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Consistently(connectedToTLS).ShouldNot(Receive()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + Expect(string(output)).To(Equal("hello")) + }) + }) + }) + + It("identifies itself as a Diego SSH proxy server", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(client.Conn.ServerVersion())).To(Equal("SSH-2.0-diego-ssh-proxy")) + }) + + Context("when dealing with an idle connection", func() { + It("eventually times out", func() { + client, err := net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + + errs := make(chan error) + go func() { + defer GinkgoRecover() + for { + bs := make([]byte, 10) + _, err := client.Read(bs) + errs <- err + } + }() + Eventually(errs).Should(Receive(MatchError("EOF"))) + }) + }) + + Context("metrics", func() { + var ( + testMetricsChan = make(chan *loggregator_v2.Envelope, 10) + signalMetricsChan = make(chan struct{}) + testIngressServer *testhelpers.TestIngressServer + ) + + BeforeEach(func() { + serverKeyFile, serverCertFile, err := ca.GenerateSelfSignedCertAndKey("metron", []string{"metron"}, false) + Expect(err).NotTo(HaveOccurred()) + _, serverCAFile := ca.CAAndKey() + + testIngressServer, err = testhelpers.NewTestIngressServer(serverCertFile, serverKeyFile, serverCAFile) + Expect(err).NotTo(HaveOccurred()) + + receiversChan := testIngressServer.Receivers() + Expect(testIngressServer.Start()).To(Succeed()) + port, err := strconv.Atoi(strings.TrimPrefix(testIngressServer.Addr(), "127.0.0.1:")) + Expect(err).NotTo(HaveOccurred()) + sshProxyConfig.LoggregatorConfig.BatchFlushInterval = 10 * time.Millisecond + sshProxyConfig.LoggregatorConfig.BatchMaxSize = 1 + sshProxyConfig.LoggregatorConfig.APIPort = port + sshProxyConfig.LoggregatorConfig.UseV2API = true + sshProxyConfig.LoggregatorConfig.CACertPath = serverCAFile + sshProxyConfig.LoggregatorConfig.KeyPath = serverKeyFile + sshProxyConfig.LoggregatorConfig.CertPath = serverCertFile + + testMetricsChan, signalMetricsChan = testhelpers.TestMetricChan(receiversChan) + }) + + AfterEach(func() { + testIngressServer.Stop() + close(signalMetricsChan) + }) + + Context("when the loggregator server isn't up", func() { + BeforeEach(func() { + testIngressServer.Stop() + }) + + It("exits with non-zero status code", func() { + Eventually(process.Wait()).Should(Receive(HaveOccurred())) + }) + }) + + Context("when the loggregator agent is up", func() { + JustBeforeEach(func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + _, err = client.NewSession() + Expect(err).NotTo(HaveOccurred()) + }) + + Context("when using loggregator v2 api", func() { + BeforeEach(func() { + sshProxyConfig.LoggregatorConfig.UseV2API = true + }) + + It("emits the number of current ssh-connections", func() { + Eventually(testMetricsChan).Should(Receive(testhelpers.MatchV2MetricAndValue(testhelpers.MetricAndValue{Name: "ssh-connections", Value: int32(1)}))) + }) + }) + + Context("when not using the loggregator v2 api", func() { + BeforeEach(func() { + sshProxyConfig.LoggregatorConfig.UseV2API = false + }) + + It("doesn't emit any metrics", func() { + Consistently(testMetricsChan).ShouldNot(Receive()) + }) + }) + }) + }) + + Context("when the proxy provides an unsupported cipher algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "unsupported" + }) + + It("rejects the cipher algorithm", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: no common algorithm for client to server cipher"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the proxy provides the default cipher algorithms", func() { + BeforeEach(func() { + clientConfig.Ciphers = []string{"arcfour128"} + }) + + It("errors when the client doesn't provide any of the algorithms: 'aes128-gcm@openssh.com', 'aes128-gcm@openssh.com', 'aes256-ctr', 'aes192-ctr', 'aes128-ctr'", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError("ssh: handshake failed: ssh: no common algorithm for client to server cipher; we offered: [arcfour128], peer offered: [aes128-gcm@openssh.com aes256-ctr aes192-ctr aes128-ctr]")) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the proxy provides a supported cipher algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "aes128-ctr,aes256-ctr" + clientConfig = &ssh.ClientConfig{ + User: "diego:" + processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("allows a client to complete a handshake", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when the proxy provides an unsupported MAC algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedMACs = "unsupported" + }) + + Context("and the cipher is an AEAD cipher", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "aes128-gcm@openssh.com" + }) + + It("does not error", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + }) + }) + + Context("and the cipher is not an AEAD cipher", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "aes256-ctr" + }) + + It("rejects the MAC algorithm", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: no common algorithm for client to server MAC"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + }) + + Context("when the proxy provides a supported MAC algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedMACs = "hmac-sha2-256,hmac-sha1" + clientConfig = &ssh.ClientConfig{ + User: "diego:" + processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("allows a client to complete a handshake", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when the proxy provides the default MAC algorithm", func() { + BeforeEach(func() { + clientConfig.MACs = []string{"hmac-sha1"} + }) + + Context("and the cipher is an AEAD cipher", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "aes128-gcm@openssh.com" + }) + + It("does not error", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + }) + }) + + Context("and the cipher is not an AEAD cipher", func() { + BeforeEach(func() { + sshProxyConfig.AllowedCiphers = "aes256-ctr" + }) + + It("errors when the client doesn't provide one of the algorithms: 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-256'", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError("ssh: handshake failed: ssh: no common algorithm for client to server MAC; we offered: [hmac-sha1], peer offered: [hmac-sha2-256-etm@openssh.com hmac-sha2-256]")) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + }) + + Context("when the proxy provides an unsupported key exchange algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedKeyExchanges = "unsupported" + }) + + It("rejects the key exchange algorithm", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: no common algorithm for key exchange"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when the proxy provides a supported key exchange algorithm", func() { + BeforeEach(func() { + sshProxyConfig.AllowedKeyExchanges = "curve25519-sha256@libssh.org,ecdh-sha2-nistp384,diffie-hellman-group14-sha1" + clientConfig = &ssh.ClientConfig{ + User: "diego:" + processGuid + "/99", + Auth: []ssh.AuthMethod{ssh.Password(diegoCredentials)}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("allows a client to complete a handshake", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when the proxy provides the default KeyExchange algorithm", func() { + BeforeEach(func() { + clientConfig.KeyExchanges = []string{"diffie-hellman-group14-sha1"} + }) + + It("errors when the client doesn't provide the algorithm: 'curve25519-sha256@libssh.org'", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError("ssh: handshake failed: ssh: no common algorithm for key exchange; we offered: [diffie-hellman-group14-sha1 ext-info-c kex-strict-c-v00@openssh.com], peer offered: [curve25519-sha256@libssh.org kex-strict-s-v00@openssh.com]")) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + + Context("when a non-existent process guid is used", func() { + BeforeEach(func() { + clientConfig.User = "diego:bad-process-guid/999" + expectedGetActualLRPRequest = &models.ActualLRPsRequest{ + ProcessGuid: "bad-process-guid", + OptionalIndex: &models.ActualLRPsRequest_Index{Index: 999}, + } + actualLRPsResponse = &models.ActualLRPsResponse{ + Error: models.ErrResourceNotFound, + } + }) + + It("attempts to acquire the lrp info from the BBS", func() { + _, _ = ssh.Dial("tcp", address, clientConfig) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(1)) + }) + + It("fails the authentication", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed"))) + }) + }) + + Context("when invalid credentials are presented", func() { + BeforeEach(func() { + clientConfig.Auth = []ssh.AuthMethod{ + ssh.Password("bogus-password"), + } + }) + + It("fails diego authentication when the wrong credentials are used", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed"))) + }) + }) + + Context("and the enableDiegoAuth flag is set to false", func() { + BeforeEach(func() { + sshProxyConfig.EnableDiegoAuth = false + }) + + It("fails the authentication", func() { + _, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed"))) + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(0)) + }) + }) + }) + + Describe("authenticating with the cf realm with a one time code", Serial, func() { + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: "cf:60f0f26e-86b3-4487-8f19-9e94f848f3d2/99", + Auth: []ssh.AuthMethod{ssh.Password("abc123")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + fakeUAA.RouteToHandler("POST", "/oauth/token", ghttp.CombineHandlers( + ghttp.VerifyRequest("POST", "/oauth/token"), + ghttp.VerifyBasicAuth("amandaplease", "password1"), + ghttp.VerifyContentType("application/x-www-form-urlencoded"), + ghttp.VerifyFormKV("grant_type", "authorization_code"), + ghttp.VerifyFormKV("code", "abc123"), + ghttp.RespondWithJSONEncoded(http.StatusOK, authenticators.UAAAuthTokenResponse{ + AccessToken: "eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiJmMGMyYWRkN2E5MDI0NTQyOWExZTdiMjNjZGVlZjkyZiIsInN1YiI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsInNjb3BlIjpbInJvdXRpbmcucm91dGVyX2dyb3Vwcy5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwib3BlbmlkIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiLCJzY2ltLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLmFkbWluIiwidWFhLnVzZXIiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsInJldl9zaWciOiJiMzUyMDU5ZiIsImlhdCI6MTQ3ODUxMzI3NywiZXhwIjoxNDc4NTEzODc3LCJpc3MiOiJodHRwczovL3VhYS5ib3NoLWxpdGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbInNjaW0iLCJjbG91ZF9jb250cm9sbGVyIiwicGFzc3dvcmQiLCJjZiIsInVhYSIsIm9wZW5pZCIsImRvcHBsZXIiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMiXX0.d8YS9HYM2QJ7f3xXjwHjZsGHCD2a4hM3tNQdGUQCJzT45YQkFZAJJDFIn4rai0YXJyswHmNT3K9pwKBzzcVzbe2HoMyI2HhCn3vW45OA7r55ATYmA88F1KkOtGitO_qi5NPhqDlQwg55kr6PzWAE84BXgWwivMXDDcwkyQosVYA", + TokenType: "bearer", + }), + )) + + fakeCC.RouteToHandler("GET", "/internal/apps/60f0f26e-86b3-4487-8f19-9e94f848f3d2/ssh_access/99", ghttp.CombineHandlers( + ghttp.VerifyRequest("GET", "/internal/apps/60f0f26e-86b3-4487-8f19-9e94f848f3d2/ssh_access/99"), + ghttp.VerifyHeader(http.Header{"Authorization": []string{"bearer eyJhbGciOiJSUzI1NiIsImtpZCI6ImxlZ2FjeS10b2tlbi1rZXkiLCJ0eXAiOiJKV1QifQ.eyJqdGkiOiJmMGMyYWRkN2E5MDI0NTQyOWExZTdiMjNjZGVlZjkyZiIsInN1YiI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsInNjb3BlIjpbInJvdXRpbmcucm91dGVyX2dyb3Vwcy5yZWFkIiwiY2xvdWRfY29udHJvbGxlci5yZWFkIiwicGFzc3dvcmQud3JpdGUiLCJjbG91ZF9jb250cm9sbGVyLndyaXRlIiwib3BlbmlkIiwicm91dGluZy5yb3V0ZXJfZ3JvdXBzLndyaXRlIiwiZG9wcGxlci5maXJlaG9zZSIsInNjaW0ud3JpdGUiLCJzY2ltLnJlYWQiLCJjbG91ZF9jb250cm9sbGVyLmFkbWluIiwidWFhLnVzZXIiXSwiY2xpZW50X2lkIjoiY2YiLCJjaWQiOiJjZiIsImF6cCI6ImNmIiwiZ3JhbnRfdHlwZSI6InBhc3N3b3JkIiwidXNlcl9pZCI6IjM2YmExMWZmLTBmNmEtNGM1MC1hYjM0LTZmYmQyODZhNjQzZSIsIm9yaWdpbiI6InVhYSIsInVzZXJfbmFtZSI6ImFkbWluIiwiZW1haWwiOiJhZG1pbiIsInJldl9zaWciOiJiMzUyMDU5ZiIsImlhdCI6MTQ3ODUxMzI3NywiZXhwIjoxNDc4NTEzODc3LCJpc3MiOiJodHRwczovL3VhYS5ib3NoLWxpdGUuY29tL29hdXRoL3Rva2VuIiwiemlkIjoidWFhIiwiYXVkIjpbInNjaW0iLCJjbG91ZF9jb250cm9sbGVyIiwicGFzc3dvcmQiLCJjZiIsInVhYSIsIm9wZW5pZCIsImRvcHBsZXIiLCJyb3V0aW5nLnJvdXRlcl9ncm91cHMiXX0.d8YS9HYM2QJ7f3xXjwHjZsGHCD2a4hM3tNQdGUQCJzT45YQkFZAJJDFIn4rai0YXJyswHmNT3K9pwKBzzcVzbe2HoMyI2HhCn3vW45OA7r55ATYmA88F1KkOtGitO_qi5NPhqDlQwg55kr6PzWAE84BXgWwivMXDDcwkyQosVYA"}}), + ghttp.RespondWithJSONEncoded(http.StatusOK, authenticators.AppSSHResponse{ + ProcessGuid: processGuid, + }), + )) + }) + + It("provides the access code to the UAA and and gets an access token", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeUAA.ReceivedRequests()).To(HaveLen(1)) + }) + + It("provides a bearer token to the CC and gets the process guid", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeCC.ReceivedRequests()).To(HaveLen(1)) + }) + + It("acquires the lrp info from the BBS using the process guid from the CC", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + err = client.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(fakeBBS.ReceivedRequests()).To(HaveLen(2)) + }) + + It("connects to the target daemon", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(output)).To(Equal("hello")) + }) + + Context("when the proxy is configured to use direct instance address", Serial, func() { + BeforeEach(func() { + sshProxyConfig.ConnectToInstanceAddress = true + + ginkgomon.Kill(sshdProcess) + sshdArgs := sshdtestrunner.Args{ + Address: fmt.Sprintf("127.0.0.1:%d", uint32(sshdContainerPort)), + HostKey: hostKeyPem, + AuthorizedKey: publicAuthorizedKey, + } + + runner := sshdtestrunner.New(sshdPath, sshdArgs) + sshdProcess = ifrit.Invoke(runner) + }) + + It("connects to the target daemon", func() { + client, err := ssh.Dial("tcp", address, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + output, err := session.Output("echo -n hello") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(output)).To(Equal("hello")) + }) + }) + }) +}) + +func VerifyProto(expected proto.Message) http.HandlerFunc { + return ghttp.CombineHandlers( + ghttp.VerifyContentType("application/x-protobuf"), + + func(w http.ResponseWriter, req *http.Request) { + defer GinkgoRecover() + body, err := io.ReadAll(req.Body) + Expect(err).ToNot(HaveOccurred()) + + err = req.Body.Close() + Expect(err).NotTo(HaveOccurred()) + + expectedType := reflect.TypeOf(expected) + actualValuePtr := reflect.New(expectedType.Elem()) + + actual, ok := actualValuePtr.Interface().(proto.Message) + Expect(ok).To(BeTrue()) + + err = proto.Unmarshal(body, actual) + Expect(err).ToNot(HaveOccurred()) + + Expect(actual).To(Equal(expected), "ProtoBuf Mismatch") + }, + ) +} + +func RespondWithProto(message proto.Message) http.HandlerFunc { + data, err := proto.Marshal(message) + Expect(err).ToNot(HaveOccurred()) + + var headers = make(http.Header) + headers["Content-Type"] = []string{"application/x-protobuf"} + return ghttp.RespondWith(200, string(data), headers) +} + +type forwardTLSServer struct { + logger lager.Logger + proxy net.Listener + stopCh chan struct{} + address string +} + +func NewForwardTLSServer(logger lager.Logger, proxy net.Listener, address string) *forwardTLSServer { + return &forwardTLSServer{ + logger: logger.Session("forward-tls-server"), + proxy: proxy, + address: address, + stopCh: make(chan struct{}), + } +} + +func (s *forwardTLSServer) Start(onConnectionReceived chan struct{}) error { + for { + select { + case <-s.stopCh: + return nil + default: + conn, err := s.proxy.Accept() + if err != nil { + select { + case <-s.stopCh: + return nil + default: + s.logger.Error("failed-to-receive-connection", err) + return err + } + } + + tlsConn := conn.(*tls.Conn) + err = tlsConn.Handshake() + if err != nil { + select { + case <-s.stopCh: + return nil + default: + s.logger.Error("failed-to-tls-handshake", err) + return err + } + } + + if onConnectionReceived != nil { + onConnectionReceived <- struct{}{} + } + + proxyConn, err := net.Dial("tcp", s.address) + if err != nil { + select { + case <-s.stopCh: + return nil + default: + s.logger.Error("failed-to-dial", err) + return err + } + } + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + _, _ = io.Copy(conn, proxyConn) + wg.Done() + }() + + go func() { + _, _ = io.Copy(proxyConn, conn) + wg.Done() + }() + + wg.Wait() + } + } +} + +func (s *forwardTLSServer) Stop() { + close(s.stopCh) + s.proxy.Close() +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/package.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/package.go new file mode 100644 index 0000000000..64092cc4f5 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/package.go @@ -0,0 +1 @@ +package main // import "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy" diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/package.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/package.go new file mode 100644 index 0000000000..7d7fa00617 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/package.go @@ -0,0 +1 @@ +package testrunner // import "code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner" diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/runner.go b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/runner.go new file mode 100644 index 0000000000..8fffdab5ee --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/ssh-proxy/testrunner/runner.go @@ -0,0 +1,18 @@ +package testrunner + +import ( + "os/exec" + "time" + + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" +) + +func New(binPath string, configPath string) *ginkgomon.Runner { + return ginkgomon.New(ginkgomon.Config{ + Name: "ssh-proxy", + AnsiColorCode: "1;95m", + StartCheck: "ssh-proxy.started", + StartCheckTimeout: 10 * time.Second, + Command: exec.Command(binPath, "-config="+configPath), + }) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_test.go new file mode 100644 index 0000000000..b3d149c3a0 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_test.go @@ -0,0 +1,25 @@ +//go:build !windows2012R2 + +package main_test + +import ( + "fmt" + + "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gexec" + "github.com/tedsuo/ifrit" +) + +func buildSshd() string { + sshd, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/sshd", "-race") + Expect(err).NotTo(HaveOccurred()) + return sshd +} + +func startSshd(sshdPath string, args testrunner.Args, address string, port int) (ifrit.Runner, ifrit.Process) { + args.Address = fmt.Sprintf("%s:%d", address, port) + runner := testrunner.New(sshdPath, args) + process := ifrit.Invoke(runner) + return runner, process +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_windows2012R2_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_windows2012R2_test.go new file mode 100644 index 0000000000..8f756281fb --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/helpers_windows2012R2_test.go @@ -0,0 +1,30 @@ +//go:build windows2012R2 + +package main_test + +import ( + "fmt" + "os" + + "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gexec" + "github.com/tedsuo/ifrit" +) + +func buildSshd() string { + sshd, err := gexec.Build("code.cloudfoundry.org/diego-ssh/cmd/sshd", "-race", "-tags", "windows2012R2") + Expect(err).NotTo(HaveOccurred()) + return sshd +} + +func startSshd(sshdPath string, args testrunner.Args, address string, port int) (ifrit.Runner, ifrit.Process) { + args.Address = fmt.Sprintf("%s:2222", address) + runner := testrunner.New(sshdPath, args) + runner.Command.Env = append( + os.Environ(), + fmt.Sprintf(`CF_INSTANCE_PORTS=[{"external":%d,"internal":%d}]`, port, 2222), + ) + process := ifrit.Invoke(runner) + return runner, process +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main.go new file mode 100644 index 0000000000..e7113417de --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main.go @@ -0,0 +1,303 @@ +package main + +import ( + "errors" + "flag" + "fmt" + "net" + "os" + "runtime" + "strings" + "syscall" + "time" + + "code.cloudfoundry.org/debugserver" + "code.cloudfoundry.org/diego-ssh/authenticators" + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest" + "code.cloudfoundry.org/diego-ssh/keys" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagerflags" + "github.com/tedsuo/ifrit" + "github.com/tedsuo/ifrit/grouper" + "github.com/tedsuo/ifrit/sigmon" + "golang.org/x/crypto/ssh" +) + +var address = flag.String( + "address", + "127.0.0.1:2222", + "listen address for ssh daemon", +) + +var hostKey = flag.String( + "hostKey", + "", + "PEM encoded RSA host key", +) + +var authorizedKey = flag.String( + "authorizedKey", + "", + "Public key in the OpenSSH authorized_keys format", +) + +var allowUnauthenticatedClients = flag.Bool( + "allowUnauthenticatedClients", + false, + "Allow access to unauthenticated clients", +) + +var inheritDaemonEnv = flag.Bool( + "inheritDaemonEnv", + false, + "Inherit daemon's environment", +) + +var allowedCiphers = flag.String( + "allowedCiphers", + "", + "Limit cipher algorithms to those provided (comma separated)", +) + +var allowedMACs = flag.String( + "allowedMACs", + "", + "Limit MAC algorithms to those provided (comma separated)", +) + +var allowedKeyExchanges = flag.String( + "allowedKeyExchanges", + "", + "Limit key exchanges algorithms to those provided (comma separated)", +) + +var hostKeyPEM string +var authorizedKeyValue string + +func runServer() error { + debugserver.AddFlags(flag.CommandLine) + lagerflags.AddFlags(flag.CommandLine) + flag.Parse() + exec := false + + logger, reconfigurableSink := lagerflags.New("sshd") + + hostKeyPEM = os.Getenv("SSHD_HOSTKEY") + if hostKeyPEM != "" { + authorizedKeyValue = os.Getenv("SSHD_AUTHKEY") + + // unset the variables so child processes don't inherit them + os.Unsetenv("SSHD_HOSTKEY") + os.Unsetenv("SSHD_AUTHKEY") + } else { + hostKeyPEM = *hostKey + if hostKeyPEM == "" { + var err error + hostKeyPEM, err = generateNewHostKey() + if err != nil { + logger.Error("failed-to-generate-host-key", err) + return err + } + } + authorizedKeyValue = *authorizedKey + exec = true + } + + if exec && runtime.GOOS != "windows" { + err := os.Setenv("SSHD_HOSTKEY", hostKeyPEM) + if err != nil { + logger.Error("failed-to-set-environment-variable", err, lager.Data{"environment-variable": "SSHD_HOSTKEY"}) + return err + } + + err = os.Setenv("SSHD_AUTHKEY", authorizedKeyValue) + if err != nil { + logger.Error("failed-to-set-environment-variable", err, lager.Data{"environment-variable": "SSHD_AUTHKEY"}) + return err + } + + logLevel := "info" + flag.CommandLine.Lookup("logLevel") + logLevelFlag := flag.CommandLine.Lookup("logLevel") + if logLevelFlag != nil { + logLevel = logLevelFlag.Value.String() + } + + runtime.GOMAXPROCS(1) + err = syscall.Exec(os.Args[0], []string{ + os.Args[0], + fmt.Sprintf("--allowedKeyExchanges=%s", *allowedKeyExchanges), + fmt.Sprintf("--address=%s", *address), + fmt.Sprintf("--allowUnauthenticatedClients=%t", *allowUnauthenticatedClients), + fmt.Sprintf("--inheritDaemonEnv=%t", *inheritDaemonEnv), + fmt.Sprintf("--allowedCiphers=%s", *allowedCiphers), + fmt.Sprintf("--allowedMACs=%s", *allowedMACs), + fmt.Sprintf("--logLevel=%s", logLevel), + fmt.Sprintf("--debugAddr=%s", debugserver.DebugAddress(flag.CommandLine)), + }, os.Environ()) + if err != nil { + logger.Error("failed-exec", err) + return err + } + } + + serverConfig, err := configure(logger) + if err != nil { + logger.Error("configure-failed", err) + return err + } + + runner := handlers.NewCommandRunner() + shellLocator := handlers.NewShellLocator() + dialer := &net.Dialer{} + + sshDaemon := daemon.New( + logger, + serverConfig, + map[string]handlers.GlobalRequestHandler{ + globalrequest.TCPIPForward: new(globalrequest.TCPIPForwardHandler), + globalrequest.CancelTCPIPForward: new(globalrequest.CancelTCPIPForwardHandler), + }, + map[string]handlers.NewChannelHandler{ + "session": handlers.NewSessionChannelHandler(runner, shellLocator, getDaemonEnvironment(), 15*time.Second), + "direct-tcpip": handlers.NewDirectTcpipChannelHandler(dialer), + }, + ) + server, err := createServer(logger, *address, sshDaemon) + if err != nil { + logger.Error("create-server-failure", err) + return err + } + + members := grouper.Members{ + {Name: "sshd", Runner: server}, + } + + if dbgAddr := debugserver.DebugAddress(flag.CommandLine); dbgAddr != "" { + members = append(grouper.Members{ + {Name: "debug-server", Runner: debugserver.Runner(dbgAddr, reconfigurableSink)}, + }, members...) + } + + group := grouper.NewOrdered(os.Interrupt, members) + monitor := ifrit.Invoke(sigmon.New(group)) + + logger.Info("started") + + if err := <-monitor.Wait(); err != nil { + logger.Error("exited-with-failure", err) + return err + } + + logger.Info("exited") + return nil +} + +func main() { + if err := runServer(); err != nil { + os.Exit(1) + } +} + +func getDaemonEnvironment() map[string]string { + daemonEnv := map[string]string{} + + if *inheritDaemonEnv { + envs := os.Environ() + for _, env := range envs { + nvp := strings.SplitN(env, "=", 2) + // account for windows "Path" environment variable! + if len(nvp) == 2 && strings.ToUpper(nvp[0]) != "PATH" { + daemonEnv[nvp[0]] = nvp[1] + } + } + } + return daemonEnv +} + +func configure(logger lager.Logger) (*ssh.ServerConfig, error) { + errorStrings := []string{} + sshConfig := &ssh.ServerConfig{ServerVersion: "SSH-2.0-diego-sshd"} + sshConfig.SetDefaults() + + key, err := acquireHostKey(logger) + if err != nil { + logger.Error("failed-to-acquire-host-key", err) + errorStrings = append(errorStrings, err.Error()) + } + + sshConfig.AddHostKey(key) + sshConfig.NoClientAuth = *allowUnauthenticatedClients + + if authorizedKeyValue == "" && !*allowUnauthenticatedClients { + logger.Error("authorized-key-required", nil) + errorStrings = append(errorStrings, "Public user key is required") + } + + if authorizedKeyValue != "" { + decodedPublicKey, err := decodeAuthorizedKey(logger) + if err == nil { + authenticator := authenticators.NewPublicKeyAuthenticator(decodedPublicKey) + sshConfig.PublicKeyCallback = authenticator.Authenticate + } else { + errorStrings = append(errorStrings, err.Error()) + } + } + + if *allowedCiphers != "" { + sshConfig.Config.Ciphers = strings.Split(*allowedCiphers, ",") + } else { + sshConfig.Config.Ciphers = []string{"aes128-gcm@openssh.com", "aes256-ctr", "aes192-ctr", "aes128-ctr"} + } + + if *allowedMACs != "" { + sshConfig.Config.MACs = strings.Split(*allowedMACs, ",") + } else { + sshConfig.Config.MACs = []string{"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256"} + } + + if *allowedKeyExchanges != "" { + sshConfig.Config.KeyExchanges = strings.Split(*allowedKeyExchanges, ",") + } else { + sshConfig.Config.KeyExchanges = []string{"curve25519-sha256@libssh.org"} + } + + err = nil + if len(errorStrings) > 0 { + err = errors.New(strings.Join(errorStrings, ", ")) + } + + return sshConfig, err +} + +func decodeAuthorizedKey(logger lager.Logger) (ssh.PublicKey, error) { + publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(authorizedKeyValue)) + return publicKey, err +} + +func acquireHostKey(logger lager.Logger) (ssh.Signer, error) { + var encoded []byte + if hostKeyPEM == "" { + return nil, errors.New("empty-host-key") + } else { + encoded = []byte(hostKeyPEM) + } + + key, err := ssh.ParsePrivateKey(encoded) + if err != nil { + logger.Error("failed-to-parse-host-key", err) + return nil, err + } + return key, nil +} + +func generateNewHostKey() (string, error) { + hostKeyPair, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + + if err != nil { + return "", err + } + return hostKeyPair.PEMEncodedPrivateKey(), nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port.go new file mode 100644 index 0000000000..4fb57658f9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port.go @@ -0,0 +1,18 @@ +//go:build !windows2012R2 + +package main + +import ( + "time" + + "code.cloudfoundry.org/diego-ssh/server" + "code.cloudfoundry.org/lager/v3" +) + +func createServer( + logger lager.Logger, + address string, + sshDaemon server.ConnectionHandler, +) (*server.Server, error) { + return server.NewServer(logger, address, sshDaemon, 5*time.Minute), nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port_windows2012R2.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port_windows2012R2.go new file mode 100644 index 0000000000..66364ffca1 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_port_windows2012R2.go @@ -0,0 +1,41 @@ +//go:build windows2012R2 + +package main + +import ( + "encoding/json" + "net" + "os" + "strconv" + "strings" + "time" + + "code.cloudfoundry.org/diego-ssh/server" + "code.cloudfoundry.org/lager/v3" +) + +type PortMapping struct { + Internal int `json:"internal"` + External int `json:"external"` +} + +func createServer( + logger lager.Logger, + address string, + sshDaemon server.ConnectionHandler, +) (*server.Server, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + jsonPortMappings := os.Getenv("CF_INSTANCE_PORTS") + var portMappings []PortMapping + json.Unmarshal([]byte(jsonPortMappings), &portMappings) + for _, mapping := range portMappings { + if strconv.Itoa(mapping.Internal) == port { + port = strconv.Itoa(mapping.External) + } + } + address = strings.Join([]string{host, port}, ":") + return server.NewServer(logger, address, sshDaemon, 5*time.Minute), err +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_suite_test.go new file mode 100644 index 0000000000..48592f46c4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_suite_test.go @@ -0,0 +1,84 @@ +package main_test + +import ( + "encoding/json" + "os" + "runtime" + + "code.cloudfoundry.org/diego-ssh/keys" + "code.cloudfoundry.org/inigo/helpers/portauthority" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gexec" + + "testing" +) + +var ( + sshdPath string + + sshdPort uint16 + hostKeyPem string + privateKeyPem string + publicAuthorizedKey string + + portAllocator portauthority.PortAllocator +) + +func TestSSHDaemon(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Sshd Suite") +} + +var _ = SynchronizedBeforeSuite(func() []byte { + if runtime.GOOS == "windows" { + if os.Getenv("WINPTY_DLL_DIR") == "" { + Fail("Missing WINPTY_DLL_DIR environment variable") + } + } + sshd := buildSshd() + + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + privateKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + payload, err := json.Marshal(map[string]string{ + "sshd": sshd, + "host-key": hostKey.PEMEncodedPrivateKey(), + "private-key": privateKey.PEMEncodedPrivateKey(), + "authorized-key": privateKey.AuthorizedKey(), + }) + + Expect(err).NotTo(HaveOccurred()) + + return payload +}, func(payload []byte) { + context := map[string]string{} + + err := json.Unmarshal(payload, &context) + Expect(err).NotTo(HaveOccurred()) + + hostKeyPem = context["host-key"] + privateKeyPem = context["private-key"] + publicAuthorizedKey = context["authorized-key"] + + node := GinkgoParallelProcess() + startPort := 1070 * node + portRange := 1000 + endPort := startPort + portRange + + portAllocator, err = portauthority.New(startPort, endPort) + Expect(err).NotTo(HaveOccurred()) + + sshdPort, err = portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + + sshdPath = context["sshd"] +}) + +var _ = SynchronizedAfterSuite(func() { +}, func() { + gexec.CleanupBuildArtifacts() +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_test.go new file mode 100644 index 0000000000..8cf6767f3d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_test.go @@ -0,0 +1,829 @@ +//go:build !windows2012R2 + +package main_test + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "regexp" + "runtime" + "strconv" + "strings" + "sync" + "time" + + "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + "github.com/tedsuo/ifrit" + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" + "github.com/onsi/gomega/gexec" + "github.com/onsi/gomega/ghttp" +) + +var _ = Describe("SSH daemon", func() { + var ( + runner ifrit.Runner + process ifrit.Process + + address string + hostKey string + privateKey string + authorizedKey string + + allowedCiphers string + allowedMACs string + allowedKeyExchanges string + + allowUnauthenticatedClients bool + inheritDaemonEnv bool + ) + + BeforeEach(func() { + hostKey = hostKeyPem + privateKey = privateKeyPem + authorizedKey = publicAuthorizedKey + + allowedCiphers = "" + allowedMACs = "" + allowedKeyExchanges = "" + + allowUnauthenticatedClients = false + inheritDaemonEnv = false + address = fmt.Sprintf("127.0.0.1:%d", sshdPort) + }) + + JustBeforeEach(func() { + args := testrunner.Args{ + HostKey: string(hostKey), + AuthorizedKey: string(authorizedKey), + + AllowedCiphers: string(allowedCiphers), + AllowedMACs: string(allowedMACs), + AllowedKeyExchanges: string(allowedKeyExchanges), + + AllowUnauthenticatedClients: allowUnauthenticatedClients, + InheritDaemonEnv: inheritDaemonEnv, + } + + runner, process = startSshd(sshdPath, args, "127.0.0.1", int(sshdPort)) + }) + + AfterEach(func() { + ginkgomon.Kill(process, 3*time.Second) + }) + + Describe("argument validation", func() { + Context("when an ill-formed host key is provided", func() { + BeforeEach(func() { + hostKey = "host-key" + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say("failed-to-parse-host-key")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when an ill-formed authorized key is provided", func() { + BeforeEach(func() { + authorizedKey = "authorized-key" + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say(`configure-failed.*ssh: no key found`)) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("the authorized key is not provided", func() { + BeforeEach(func() { + authorizedKey = "" + }) + + Context("and allowUnauthenticatedClients is not true", func() { + BeforeEach(func() { + allowUnauthenticatedClients = false + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say("authorized-key-required")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("and allowUnauthenticatedClients is true", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + }) + + It("starts normally", func() { + Expect(process).NotTo(BeNil()) + }) + }) + }) + }) + + Describe("env variable validation", func() { + Context("when an ill-formed host key is provided", func() { + BeforeEach(func() { + hostKey = "invalid-host-key" + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say("failed-to-parse-host-key")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("when an ill-formed authorized key is provided", func() { + BeforeEach(func() { + authorizedKey = "invalid-authorized-key" + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say(`configure-failed.*ssh: no key found`)) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("the authorized key is not provided", func() { + BeforeEach(func() { + authorizedKey = "" + }) + + Context("and allowUnauthenticatedClients is not true", func() { + BeforeEach(func() { + allowUnauthenticatedClients = false + }) + + It("reports and dies", func() { + Expect(runner).To(gbytes.Say("authorized-key-required")) + Expect(runner).NotTo(gexec.Exit(0)) + }) + }) + + Context("and allowUnauthenticatedClients is true", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + }) + + It("starts normally", func() { + Expect(process).NotTo(BeNil()) + }) + }) + }) + + Context("when the hostKey is provided as an env variable", func() { + var ( + client *ssh.Client + clientConfig *ssh.ClientConfig + handshakeHostKey ssh.PublicKey + ) + + JustBeforeEach(func() { + Expect(process).NotTo(BeNil()) + client, _ = ssh.Dial("tcp", address, clientConfig) + }) + + AfterEach(func() { + if client != nil { + client.Close() + } + os.Unsetenv("SSHD_HOSTKEY") + }) + + BeforeEach(func() { + hostKey = "host-key" + os.Setenv("SSHD_HOSTKEY", hostKeyPem) + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + handshakeHostKey = key + return nil + }, + } + }) + + It("uses the hostKey from the environment", func() { + sshHostKey, err := ssh.ParsePrivateKey([]byte(hostKeyPem)) + Expect(err).NotTo(HaveOccurred()) + + sshPublicHostKey := sshHostKey.PublicKey() + Expect(sshPublicHostKey.Marshal()).To(Equal(handshakeHostKey.Marshal())) + }) + }) + }) + + Describe("daemon execution", func() { + var ( + client *ssh.Client + dialErr error + clientConfig *ssh.ClientConfig + ) + + JustBeforeEach(func() { + Expect(process).NotTo(BeNil()) + client, dialErr = ssh.Dial("tcp", address, clientConfig) + }) + + AfterEach(func() { + if client != nil { + client.Close() + } + }) + + var ItDoesNotExposeSensitiveInformation = func() { + It("does not expose the key on the command line", func() { + if runtime.GOOS == "windows" { + Skip("no fork/exec on windows") + } + + pid := runner.(*ginkgomon.Runner).Command.Process.Pid + command := exec.Command("ps", "-fp", strconv.Itoa(pid)) + session, err := gexec.Start(command, GinkgoWriter, GinkgoWriter) + Expect(err).NotTo(HaveOccurred()) + Eventually(session).Should(gexec.Exit(0)) + keyRegex := regexp.QuoteMeta(authorizedKey[:len(authorizedKey)-1]) + Expect(session.Out).NotTo(gbytes.Say(keyRegex)) + }) + } + + Context("when a host key is not specified", func() { + BeforeEach(func() { + hostKey = "" + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("generates one internally", func() { + Expect(process).NotTo(BeNil()) + + Expect(client).NotTo(BeNil()) + Expect(dialErr).NotTo(HaveOccurred()) + }) + + ItDoesNotExposeSensitiveInformation() + }) + + Context("when a host key is specified", func() { + var handshakeHostKey ssh.PublicKey + + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { + handshakeHostKey = key + return nil + }, + } + }) + + It("uses the host key provided on the command line", func() { + sshHostKey, err := ssh.ParsePrivateKey([]byte(hostKeyPem)) + Expect(err).NotTo(HaveOccurred()) + + sshPublicHostKey := sshHostKey.PublicKey() + Expect(sshPublicHostKey.Marshal()).To(Equal(handshakeHostKey.Marshal())) + }) + + ItDoesNotExposeSensitiveInformation() + }) + + Context("when unauthenticated clients are not allowed", func() { + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("rejects the client handshake", func() { + Expect(dialErr).To(MatchError(ContainSubstring("ssh: handshake failed"))) + }) + + Context("and client has a valid private key", func() { + BeforeEach(func() { + key, err := ssh.ParsePrivateKey([]byte(privateKey)) + Expect(err).NotTo(HaveOccurred()) + + clientConfig = &ssh.ClientConfig{ + User: os.Getenv("USER"), + Auth: []ssh.AuthMethod{ + ssh.PublicKeys(key), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("can complete a handshake with the daemon", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + }) + }) + + Context("when the daemon allows unauthenticated clients", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("allows a client without credentials to complete a handshake", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + + }) + + Context("when the daemon provides an unsupported cipher algorithm", func() { + BeforeEach(func() { + allowedCiphers = "unsupported" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("rejects the cipher algorithm", func() { + Expect(dialErr).To(MatchError(ContainSubstring("ssh: no common algorithm for client to server cipher"))) + Expect(client).To(BeNil()) + }) + }) + + Context("when the daemon provides a supported cipher algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + allowedCiphers = "aes128-ctr,aes256-ctr" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("allows a client to complete a handshake", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + }) + + Context("when the daemon provides an unsupported cipher algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + clientConfig.Ciphers = []string{"arcfour128"} + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("errors when the client doesn't provide one of the algorithm: 'aes128-gcm@openssh.com', 'aes256-ctr', 'aes192-ctr', 'aes128-ctr'", func() { + Expect(dialErr).To(MatchError("ssh: handshake failed: ssh: no common algorithm for client to server cipher; we offered: [arcfour128], peer offered: [aes128-gcm@openssh.com aes256-ctr aes192-ctr aes128-ctr]")) + Expect(client).To(BeNil()) + }) + }) + + Context("when the daemon provides an unsupported MAC algorithm", func() { + BeforeEach(func() { + allowedMACs = "unsupported" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("rejects the MAC algorithm", func() { + Expect(dialErr).To(MatchError(ContainSubstring("no supported methods remain"))) + Expect(client).To(BeNil()) + }) + }) + + Context("when the daemon provides a supported MAC algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + allowedMACs = "hmac-sha2-256,hmac-sha1" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("allows a client to complete a handshake", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + }) + + Context("when the daemon provides an unsupported MAC algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + clientConfig.MACs = []string{"hmac-sha1"} + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + Context("and the cipher is an AEAD cipher", func() { + BeforeEach(func() { + allowedCiphers = "aes128-gcm@openssh.com" + }) + + It("does not return an error", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + }) + + Context("and the cipher is not an AEAD cipher", func() { + BeforeEach(func() { + allowedCiphers = "aes128-ctr" + }) + + It("errors when the client doesn't provide one of the algorithms: 'hmac-sha2-256-etm@openssh.com', 'hmac-sha2-256'", func() { + Expect(dialErr).To(MatchError("ssh: handshake failed: ssh: no common algorithm for client to server MAC; we offered: [hmac-sha1], peer offered: [hmac-sha2-256-etm@openssh.com hmac-sha2-256]")) + Expect(client).To(BeNil()) + }) + }) + }) + + Context("when the daemon provides an unsupported key exchange algorithm by the proxy", func() { + BeforeEach(func() { + allowedKeyExchanges = "unsupported" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("rejects the key exchange algorithm", func() { + Expect(dialErr).To(MatchError(ContainSubstring("ssh: no common algorithm for key exchange"))) + Expect(client).To(BeNil()) + }) + }) + + Context("when the daemon provides a supported key exchange algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + allowedKeyExchanges = "curve25519-sha256@libssh.org,ecdh-sha2-nistp384,diffie-hellman-group14-sha1" + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("allows a client to complete a handshake", func() { + Expect(dialErr).NotTo(HaveOccurred()) + Expect(client).NotTo(BeNil()) + }) + }) + + Context("when the daemon provides an unsupported KeyExchange algorithm", func() { + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + clientConfig.KeyExchanges = []string{"diffie-hellman-group14-sha1"} + }) + + It("starts the daemon", func() { + Expect(process).NotTo(BeNil()) + }) + + It("errors when the client doesn't provide the algorithm: 'curve25519-sha256@libssh.org'", func() { + Expect(dialErr).To(MatchError("ssh: handshake failed: ssh: no common algorithm for key exchange; we offered: [diffie-hellman-group14-sha1 ext-info-c kex-strict-c-v00@openssh.com], peer offered: [curve25519-sha256@libssh.org kex-strict-s-v00@openssh.com]")) + Expect(client).To(BeNil()) + }) + }) + }) + + Describe("SSH features", func() { + var clientConfig *ssh.ClientConfig + var client *ssh.Client + + BeforeEach(func() { + allowUnauthenticatedClients = true + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + JustBeforeEach(func() { + Expect(process).NotTo(BeNil()) + + var dialErr error + client, dialErr = ssh.Dial("tcp", address, clientConfig) + Expect(dialErr).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + client.Close() + }) + + Context("when a client connects", func() { + It("identifies itself as a diego-ssh server", func() { + Expect(string(client.Conn.ServerVersion())).To(Equal("SSH-2.0-diego-sshd")) + }) + }) + + Context("when a client requests the execution of a command", func() { + It("runs the command", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + var cmd string + if runtime.GOOS == "windows" { + cmd = "echo Hello There!" + } else { + cmd = "/bin/echo -n 'Hello There!'" + } + + result, err := session.Output(cmd) + Expect(err).NotTo(HaveOccurred()) + + Expect(strings.TrimSpace(string(result))).To(Equal(strings.TrimSpace("Hello There!"))) + }) + }) + + Context("when a client requests a shell", func() { + Context("when inherit daemon env is enabled", func() { + BeforeEach(func() { + inheritDaemonEnv = true + os.Setenv("TEST", "FOO") + os.Setenv("PATH", os.Getenv("PATH")+":/tmp") + }) + + It("creates a shell environment", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + stdout := &bytes.Buffer{} + + session.Stdin = strings.NewReader(envVarCmd("ENV_VAR")) + session.Stdout = stdout + + session.Setenv("ENV_VAR", "env_var_value") + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).To(ContainSubstring("env_var_value")) + }) + + It("inherits daemon's environment excluding PATH", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + stdout := &bytes.Buffer{} + + session.Stdin = strings.NewReader(envVarCmd("TEST")) + session.Stdout = stdout + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).To(ContainSubstring("FOO")) + }) + + It("does not inherit the daemon's PATH", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + stdout := &bytes.Buffer{} + + session.Stdin = strings.NewReader(envVarCmd("PATH")) + session.Stdout = stdout + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + Expect(stdout.String()).NotTo(ContainSubstring("/tmp")) + }) + }) + + Context("when inherit daemon env is disabled", func() { + BeforeEach(func() { + inheritDaemonEnv = false + os.Setenv("TEST", "FOO") + }) + + It("creates a shell environment", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + stdout := &bytes.Buffer{} + + session.Stdin = strings.NewReader(envVarCmd("ENV_VAR")) + session.Stdout = stdout + + session.Setenv("ENV_VAR", "env_var_value") + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).To(ContainSubstring("env_var_value")) + }) + + It("does not inherits daemon's environment", func() { + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + + stdout := &bytes.Buffer{} + + session.Stdin = strings.NewReader(envVarCmd("TEST")) + session.Stdout = stdout + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).NotTo(ContainSubstring("FOO")) + }) + }) + }) + + Context("when a client requests a local port forward", func() { + var server *ghttp.Server + BeforeEach(func() { + server = ghttp.NewServer() + server.AppendHandlers( + ghttp.CombineHandlers( + ghttp.VerifyRequest("GET", "/"), + ghttp.RespondWith(http.StatusOK, "hi from jim\n"), + ), + ) + }) + + It("forwards the local port to the target from the server side", func() { + lconn, err := client.Dial("tcp", server.Addr()) + Expect(err).NotTo(HaveOccurred()) + + transport := &http.Transport{ + Dial: func(network, addr string) (net.Conn, error) { + return lconn, nil + }, + } + client := &http.Client{Transport: transport} + + resp, err := client.Get("http://127.0.0.1/") + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + reader := bufio.NewReader(resp.Body) + line, err := reader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(line).To(ContainSubstring("hi from jim")) + }) + }) + + Context("when a client requests a remote port forward", func() { + var ( + server *ghttp.Server + ln net.Listener + ) + + BeforeEach(func() { + server = ghttp.NewServer() + server.AppendHandlers( + ghttp.RespondWith(http.StatusOK, "hello from the other side\n"), + ) + }) + + JustBeforeEach(func() { + var err error + ln, err = client.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + }) + + It("forwards the remote port from server side to the target", func() { + go func() { + for { + conn, err := ln.Accept() + if err != nil { + return + } + + proxyConn, err := net.Dial("tcp", server.Addr()) + if err != nil { + return + } + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + _, _ = io.Copy(conn, proxyConn) + wg.Done() + }() + + go func() { + _, _ = io.Copy(proxyConn, conn) + wg.Done() + }() + + wg.Wait() + } + }() + + resp, err := http.Get(fmt.Sprintf("http://%s", ln.Addr())) + Expect(err).NotTo(HaveOccurred()) + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + reader := bufio.NewReader(resp.Body) + line, err := reader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(line).To(ContainSubstring("hello from the other side")) + }) + + Context("when the connection is closed", func() { + JustBeforeEach(func() { + Expect(client.Close()).To(Succeed()) + }) + + It("closes the listeners associated with this conn", func() { + Eventually(func() error { + _, err := http.Get(fmt.Sprintf("http://%s", ln.Addr())) + return err + }).Should(MatchError(ContainSubstring("refused"))) + }) + }) + + Context("when the listener is closed", func() { + JustBeforeEach(func() { + Expect(ln.Close()).To(Succeed()) + }) + + It("responds with a connection refused error to clients", func() { + Eventually(func() error { + _, err := http.Get(fmt.Sprintf("http://%s", ln.Addr())) + return err + }).Should(MatchError(ContainSubstring("refused"))) + }) + }) + }) + }) +}) + +func envVarCmd(envVar string) string { + if runtime.GOOS == "windows" { + return "echo %" + envVar + "%\r\n" + } + + return fmt.Sprintf("/bin/echo -n $%s", envVar) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_windows2012R2_test.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_windows2012R2_test.go new file mode 100644 index 0000000000..0b080f63a3 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/main_windows2012R2_test.go @@ -0,0 +1,77 @@ +//go:build windows2012R2 + +package main_test + +import ( + "fmt" + "time" + + "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" + + "github.com/tedsuo/ifrit" + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/ghttp" +) + +var _ = Describe("SSH daemon", func() { + Describe("SSH features", func() { + var ( + process ifrit.Process + address string + clientConfig *ssh.ClientConfig + client *ssh.Client + ) + + BeforeEach(func() { + args := testrunner.Args{ + HostKey: string(privateKeyPem), + AuthorizedKey: string(publicAuthorizedKey), + + AllowUnauthenticatedClients: true, + InheritDaemonEnv: false, + } + address = fmt.Sprintf("127.0.0.1:%d", sshdPort) + _, process = startSshd(sshdPath, args, "127.0.0.1", sshdPort) + clientConfig = &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + Expect(process).NotTo(BeNil()) + + var dialErr error + client, dialErr = ssh.Dial("tcp", address, clientConfig) + Expect(dialErr).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + ginkgomon.Kill(process, 3*time.Second) + client.Close() + }) + + Context("when a client requests the execution of a command", func() { + It("runs the command", func() { + _, err := client.NewSession() + Expect(err).To(MatchError(ContainSubstring("not supported"))) + }) + }) + + Context("when a client requests a local port forward", func() { + var server *ghttp.Server + BeforeEach(func() { + server = ghttp.NewServer() + }) + + It("forwards the local port to the target from the server side", func() { + _, err := client.Dial("tcp", server.Addr()) + Expect(err).To(MatchError(ContainSubstring("unknown channel type"))) + }) + + It("server should not receive any connections", func() { + Expect(server.ReceivedRequests()).To(BeEmpty()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/package.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/package.go new file mode 100644 index 0000000000..2bd3ab8990 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/package.go @@ -0,0 +1 @@ +package main // import "code.cloudfoundry.org/diego-ssh/cmd/sshd" diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/package.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/package.go new file mode 100644 index 0000000000..f973ab0799 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/package.go @@ -0,0 +1 @@ +package testrunner // import "code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner" diff --git a/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/runner.go b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/runner.go new file mode 100644 index 0000000000..12da62464d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/cmd/sshd/testrunner/runner.go @@ -0,0 +1,43 @@ +package testrunner + +import ( + "os/exec" + "strconv" + "time" + + ginkgomon "github.com/tedsuo/ifrit/ginkgomon_v2" +) + +type Args struct { + Address string + HostKey string + AuthorizedKey string + AllowedCiphers string + AllowedMACs string + AllowedKeyExchanges string + AllowUnauthenticatedClients bool + InheritDaemonEnv bool +} + +func (args Args) ArgSlice() []string { + return []string{ + "-address=" + args.Address, + "-hostKey=" + args.HostKey, + "-authorizedKey=" + args.AuthorizedKey, + "-allowedCiphers=" + args.AllowedCiphers, + "-allowedMACs=" + args.AllowedMACs, + "-allowedKeyExchanges=" + args.AllowedKeyExchanges, + "-allowUnauthenticatedClients=" + strconv.FormatBool(args.AllowUnauthenticatedClients), + "-inheritDaemonEnv=" + strconv.FormatBool(args.InheritDaemonEnv), + } +} + +func New(binPath string, args Args) *ginkgomon.Runner { + return ginkgomon.New(ginkgomon.Config{ + Name: "sshd", + AnsiColorCode: "1;96m", + StartCheck: "sshd.started", + StartCheckTimeout: 10 * time.Second, + Command: exec.Command(binPath, args.ArgSlice()...), + }) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/daemon/daemon.go b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon.go new file mode 100644 index 0000000000..a914d5c18a --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon.go @@ -0,0 +1,105 @@ +package daemon + +import ( + "net" + + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type Daemon struct { + logger lager.Logger + serverConfig *ssh.ServerConfig + globalRequestHandlers map[string]handlers.GlobalRequestHandler + newChannelHandlers map[string]handlers.NewChannelHandler +} + +func New( + logger lager.Logger, + serverConfig *ssh.ServerConfig, + globalRequestHandlers map[string]handlers.GlobalRequestHandler, + newChannelHandlers map[string]handlers.NewChannelHandler, +) *Daemon { + return &Daemon{ + logger: logger, + serverConfig: serverConfig, + globalRequestHandlers: globalRequestHandlers, + newChannelHandlers: newChannelHandlers, + } +} + +func (d *Daemon) HandleConnection(netConn net.Conn) { + logger := d.logger.Session("handle-connection") + + logger.Info("started") + defer logger.Info("completed") + defer netConn.Close() + + serverConn, serverChannels, serverRequests, err := ssh.NewServerConn(netConn, d.serverConfig) + if err != nil { + logger.Error("handshake-failed", err) + return + } + + lnStore := helpers.NewListenerStore() + go d.handleGlobalRequests(logger, serverRequests, serverConn, lnStore) + go d.handleNewChannels(logger, serverChannels) + + err = serverConn.Wait() + if err != nil { + logger.Debug("failed-to-wait-for-server", lager.Data{"error": err}) + } + lnStore.RemoveAll() +} + +func (d *Daemon) handleGlobalRequests(logger lager.Logger, requests <-chan *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) { + logger = logger.Session("handle-global-requests") + logger.Info("starting") + defer logger.Info("finished") + + for req := range requests { + logger.Debug("request", lager.Data{ + "request-type": req.Type, + "want-reply": req.WantReply, + }) + + handler, ok := d.globalRequestHandlers[req.Type] + if ok { + handler.HandleRequest(logger, req, conn, lnStore) + continue + } + + if req.WantReply { + err := req.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + } +} + +func (d *Daemon) handleNewChannels(logger lager.Logger, newChannelRequests <-chan ssh.NewChannel) { + logger = logger.Session("handle-new-channels") + logger.Info("starting") + defer logger.Info("finished") + + for newChannel := range newChannelRequests { + logger.Info("new-channel", lager.Data{ + "channelType": newChannel.ChannelType(), + "extraData": newChannel.ExtraData(), + }) + + if handler, ok := d.newChannelHandlers[newChannel.ChannelType()]; ok { + go handler.HandleNewChannel(logger, newChannel) + continue + } + + logger.Info("rejecting-channel", lager.Data{"reason": "unkonwn-channel-type"}) + err := newChannel.Reject(ssh.UnknownChannelType, newChannel.ChannelType()) + if err != nil { + logger.Debug("failed-to-reject", lager.Data{"error": err}) + } + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_suite_test.go new file mode 100644 index 0000000000..bbef83e2e4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_suite_test.go @@ -0,0 +1,25 @@ +package daemon_test + +import ( + "code.cloudfoundry.org/diego-ssh/keys" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" + + "testing" +) + +var TestHostKey ssh.Signer + +func TestDaemon(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Daemon Suite") +} + +var _ = BeforeSuite(func() { + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + + Expect(err).NotTo(HaveOccurred()) + + TestHostKey = hostKey.PrivateKey() +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_test.go b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_test.go new file mode 100644 index 0000000000..f4394dd267 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/daemon/daemon_test.go @@ -0,0 +1,251 @@ +package daemon_test + +import ( + "errors" + "net" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fake_handlers" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Daemon", func() { + var ( + logger lager.Logger + sshd *daemon.Daemon + + serverSSHConfig *ssh.ServerConfig + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + }) + + Describe("HandleConnection", func() { + var fakeConn *fake_net.FakeConn + + Context("when the function returns", func() { + BeforeEach(func() { + fakeConn = &fake_net.FakeConn{} + fakeConn.ReadReturns(0, errors.New("oops")) + + sshd = daemon.New(logger, serverSSHConfig, nil, nil) + }) + + It("closes the connection", func() { + sshd.HandleConnection(fakeConn) + Expect(fakeConn.CloseCallCount()).To(BeNumerically(">=", 1)) + }) + }) + + Context("when an ssh client connects", func() { + var ( + serverNetConn net.Conn + clientNetConn net.Conn + + clientConn ssh.Conn + clientChannels <-chan ssh.NewChannel + clientRequests <-chan *ssh.Request + clientConnErr error + + client *ssh.Client + ) + + BeforeEach(func() { + serverSSHConfig.PasswordCallback = func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + return nil, nil + } + + serverNetConn, clientNetConn = test_helpers.Pipe() + + clientConfig := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("secret"), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + sshd = daemon.New(logger, serverSSHConfig, nil, nil) + go sshd.HandleConnection(serverNetConn) + + clientConn, clientChannels, clientRequests, clientConnErr = ssh.NewClientConn(clientNetConn, "0.0.0.0", clientConfig) + Expect(clientConnErr).NotTo(HaveOccurred()) + + client = ssh.NewClient(clientConn, clientChannels, clientRequests) + }) + + AfterEach(func() { + if client != nil { + client.Close() + } + }) + + It("performs a handshake", func() { + Expect(clientConnErr).NotTo(HaveOccurred()) + }) + }) + }) + + Describe("handleGlobalRequests", func() { + var ( + globalRequestHandlers map[string]handlers.GlobalRequestHandler + + fakeHandler *fake_handlers.FakeGlobalRequestHandler + client *ssh.Client + ) + + BeforeEach(func() { + fakeHandler = &fake_handlers.FakeGlobalRequestHandler{} + globalRequestHandlers = map[string]handlers.GlobalRequestHandler{ + "known-handler": fakeHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, globalRequestHandlers, nil) + go sshd.HandleConnection(serverNetConn) + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + client.Close() + }) + + Context("when a global request is recevied", func() { + var ( + accepted bool + requestErr error + + name string + wantReply bool + ) + + JustBeforeEach(func() { + accepted, _, requestErr = client.SendRequest(name, wantReply, []byte("payload")) + }) + + Context("and there is an associated handler", func() { + BeforeEach(func() { + name = "known-handler" + wantReply = true + + fakeHandler.HandleRequestStub = func(logger lager.Logger, request *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) { + request.Reply(true, []byte("response")) + } + }) + + It("calls the handler to handle the request", func() { + Eventually(fakeHandler.HandleRequestCallCount).Should(Equal(1)) + }) + + It("does not reject the request as unknown", func() { + Expect(requestErr).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + }) + }) + + Context("and there is not an associated handler", func() { + Context("when WantReply is true", func() { + BeforeEach(func() { + name = "unknown-handler" + wantReply = true + }) + + It("rejects the request", func() { + Expect(requestErr).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + }) + }) + }) + + Describe("handleNewChannels", func() { + var newChannelHandlers map[string]handlers.NewChannelHandler + var fakeHandler *fake_handlers.FakeNewChannelHandler + var client *ssh.Client + + BeforeEach(func() { + fakeHandler = &fake_handlers.FakeNewChannelHandler{} + newChannelHandlers = map[string]handlers.NewChannelHandler{ + "known-channel-type": fakeHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + go sshd.HandleConnection(serverNetConn) + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + client.Close() + }) + + Context("when a new channel request is received", func() { + var ( + channelType string + + openError error + ) + + JustBeforeEach(func() { + _, _, openError = client.OpenChannel(channelType, []byte("extra-data")) + }) + + Context("and there is an associated handler", func() { + BeforeEach(func() { + channelType = "known-channel-type" + + fakeHandler.HandleNewChannelStub = func(logger lager.Logger, newChannel ssh.NewChannel) { + ch, _, err := newChannel.Accept() + Expect(err).NotTo(HaveOccurred()) + ch.Close() + } + }) + + It("calls the handler to process the new channel request", func() { + Expect(fakeHandler.HandleNewChannelCallCount()).To(Equal(1)) + + logger, actualChannel := fakeHandler.HandleNewChannelArgsForCall(0) + Expect(logger).NotTo(BeNil()) + + Expect(actualChannel.ChannelType()).To(Equal("known-channel-type")) + Expect(actualChannel.ExtraData()).To(Equal([]byte("extra-data"))) + }) + }) + + Context("and there is not an associated handler", func() { + BeforeEach(func() { + channelType = "unknown-channel-type" + }) + + It("rejects the new channel request", func() { + Expect(openError).To(HaveOccurred()) + + channelError, ok := openError.(*ssh.OpenChannelError) + Expect(ok).To(BeTrue()) + + Expect(channelError.Reason).To(Equal(ssh.UnknownChannelType)) + Expect(channelError.Message).To(Equal("unknown-channel-type")) + }) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/daemon/package.go b/src/code.cloudfoundry.org/diego-ssh/daemon/package.go new file mode 100644 index 0000000000..30d0815d4f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/daemon/package.go @@ -0,0 +1 @@ +package daemon // import "code.cloudfoundry.org/diego-ssh/daemon" diff --git a/src/code.cloudfoundry.org/diego-ssh/docs/010-proxy.md b/src/code.cloudfoundry.org/diego-ssh/docs/010-proxy.md new file mode 100644 index 0000000000..d19bf082a9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/docs/010-proxy.md @@ -0,0 +1,172 @@ +--- +title: SSH Proxy +expires_at : never +tags: [diego-release, diego-ssh] +--- + +## Proxy + +The ssh proxy hosts the user-accessible ssh endpoint and is responsible for +authentication, policy enforcement, and access controls in the context of +Cloud Foundry. After a user has successfully authenticated with the proxy, the +proxy will attempt to locate the target container and create an ssh session to +a daemon running inside the container. After both sessions have been +established, the proxy will manage the communication between the user's ssh +client and the container's ssh daemon. + +### Proxy Authentication + +Clients authenticate with the proxy using a specially formed user name that +describes the authentication domain and target container and a password that +contains the appropriate credentials for the domain. + +The proxy currently supports authentication against a `diego` domain and a +`cf` domain. Each authentication domain can be enabled independently via +command line arguments. + +#### Diego via custom credentials + +For Diego, the user is of the form `diego:`_process-guid_/_index_ and the +password must hold the configured credentials. + +Client example: +``` +$ ssh -p 2222 'diego:my-process-guid/1'@ssh.bosh-lite.com +$ scp -P 2222 -oUser='diego:ssh-process-guid/0' my-local-file.json ssh.bosh-lite.com:my-remote-file.json +``` + +The credentials checked by the proxy are configurable via the +`--diegoCredentials` flag. The password provided by the client to the proxy +must match what is present in the flag for successful authentication. + +This support is enabled with the `--enableDiegoAuth` flag. + +#### Cloud Foundry via Cloud Controller and UAA + +For Cloud Foundry, the user is of the form `cf:`_app-guid_/_instance_ and the +password must be an authorization code that the ssh proxy server can exchange +for an authorization token. The SSH proxy must be configured to use an OAuth +client id that has been defined in the UAA. The client id used by the proxy +must be advertised in the `/v2/info` endpoint under the `app_ssh_oauth_client` +key. Please see the [UAA][non-standard-oauth-auth-code] documentation for +details on how to allocate an authorization code. + +The proxy will contact the Cloud Controller as the user to determine if the +policy allows the user to access application containers via SSH. + +Client example: +``` +$ curl -k -v -H "Authorization: $(cf oauth-token | tail -1)" \ + https://uaa.bosh-lite.com/oauth/authorize \ + --data-urlencode "client_id=$(cf curl /v2/info | jq -r .app_ssh_oauth_client)" \ + --data-urlencode 'response_type=code' 2>&1 | \ + grep Location: | \ + cut -f2 -d'?' | \ + cut -f2 -d'=' | \ + pbcopy # paste authoriztion code when prompted for password +``` +or, with the Cloud Foundry `cf` [command line interface][cli]; +``` +$ cf ssh-code | pbcopy # paste authorization code when prompted for password +``` + +The authorization code can then be used as the password: + +``` +$ ssh -p 2222 cf:$(cf app app-name --guid)/0@ssh.bosh-lite.com +$ scp -P 2222 -oUser=cf:$(cf app app-name --guid)/0 my-local-file.json ssh.bosh-lite.com:my-remote-file.json +$ sftp -P 2222 cf:$(cf app app-name --guid)/0@ssh.bosh-lite.com +``` + +The Cloud Foundry `cf` [command line interface][cli] (v6.13 and newer) can +also be used to access an interactive shell in an application container: +``` +$ cf ssh app-name +$ cf ssh app-name -i 3 # access the container hosting index 3 of the app +``` + +This support is enabled with the `--enableCFAuth` flag. + +### Daemon discovery + +To be accessible via the SSH proxy, containers must host an ssh daemon, expose +it via a mapped port, and advertise the port in a `diego-ssh` route. The proxy +will fail end user authentication if the target LRP or a route is not found. + +```json + "routes": { + "diego-ssh": { "container_port": 2222 } + } +``` + +The [CC-Bridge][bridge] components of Diego will generate the appropriate LRP +definitions for Cloud Foundry applications which reflect the policies that are +in effect. + +### Proxy to Container Authentication + +When the proxy attempts to handshake with the SSH daemon inside the target +container, it will use the information associated with the `diego-ssh` key in +the LRP routes. + +#### `container_port` [required] +`container_port` indicates which port inside the container the ssh daemon is +listening on. The proxy will attempt to connect to host side mapping of this +port after authenticating the client. + +#### `host_fingerprint` [optional] +When present, `host_fingerprint` declares the expected fingerprint of the SSH +daemon's host public key. When the fingerprint of the actual target's host key +does not match the expected fingerprint, the connection is terminated. The +fingerprint should only contain the hex string generated by `ssh-keygen -l`. + +#### `user` [optional] +`user` declares the user ID to use during authentication with the container's +SSH daemon. While it's not a required part of the routing data, it is required +for password authentication and may be required for public key authentication. + +#### `password` [optional] +`password` declares the password to use during password authentication with +the container's ssh daemon. + +#### `private_key` [optional] +`private_key` declares the private key to use when authenticating with the +container's SSH daemon. If present, the key must be a PEM encoded RSA or DSA +public key. + +##### Example LRP +```json +{ + "process_guid": "ssh-process-guid", + "domain": "ssh-experiments", + "rootfs": "preloaded:cflinuxfs3", + "instances": 1, + "start_timeout": 30, + "setup": { + "download": { + "artifact": "diego-sshd", + "from": "http://file-server.service.cf.internal:8080/v1/static/diego-sshd/diego-sshd.tgz", + "to": "/tmp", + "cache_key": "diego-sshd" + } + }, + "action": { + "run": { + "path": "/tmp/diego-sshd", + "args": [ + "-address=0.0.0.0:2222", + "-authorizedKey=ssh-rsa ..." + ], + "env": [], + "resource_limits": {} + } + }, + "ports": [ 2222 ], + "routes": { + "diego-ssh": { + "container_port": 2222, + "private_key": "PEM encoded PKCS#1 private key" + } + } +} +``` diff --git a/src/code.cloudfoundry.org/diego-ssh/docs/020-ssh-daemon.md b/src/code.cloudfoundry.org/diego-ssh/docs/020-ssh-daemon.md new file mode 100644 index 0000000000..8cb3c613c3 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/docs/020-ssh-daemon.md @@ -0,0 +1,27 @@ +--- +title: SSH Daemon +expires_at : never +tags: [diego-release, diego-ssh] +--- + +## SSH Daemon + +The ssh daemon is a lightweight implementation that is built around go's ssh +library. It supports command execution, interactive shells, local port +forwarding, scp, and sftp. The daemon is self-contained and has no +dependencies on the container root file system. + +The daemon is focused on delivering basic access to application instances in +Cloud Foundry. It is intended to run as an unprivileged process and +interactive shells and commands will run as the daemon user. The daemon only +supports one authorized key is not intended to support multiple users. + +The daemon can be made available on a file server and Diego LRPs that +want to use it can include a download action to acquire the binary and a run +action to start it. Cloud Foundry applications will download the daemon as +part of the lifecycle bundle. + +[bridge]: https://github.com/cloudfoundry/diego-design-notes#cc-bridge-components +[cflinuxfs3]: https://github.com/cloudfoundry/cflinuxfs3 +[cli]: https://github.com/cloudfoundry/cli +[non-standard-oauth-auth-code]: https://github.com/cloudfoundry/uaa/blob/master/docs/UAA-APIs.rst#api-authorization-requests-code-get-oauth-authorize-non-standard-oauth-authorize diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/command_runner.go b/src/code.cloudfoundry.org/diego-ssh/handlers/command_runner.go new file mode 100644 index 0000000000..5ca249819f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/command_runner.go @@ -0,0 +1,24 @@ +package handlers + +import ( + "os/exec" + "syscall" +) + +type commandRunner struct{} + +func NewCommandRunner() Runner { + return &commandRunner{} +} + +func (commandRunner) Start(cmd *exec.Cmd) error { + return cmd.Start() +} + +func (commandRunner) Wait(cmd *exec.Cmd) error { + return cmd.Wait() +} + +func (commandRunner) Signal(cmd *exec.Cmd, signal syscall.Signal) error { + return cmd.Process.Signal(signal) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler.go new file mode 100644 index 0000000000..bf7ff51840 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler.go @@ -0,0 +1,98 @@ +package handlers + +import ( + "fmt" + "net" + "sync" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type DirectTcpipChannelHandler struct { + dialer Dialer +} + +func NewDirectTcpipChannelHandler(dialer Dialer) *DirectTcpipChannelHandler { + return &DirectTcpipChannelHandler{ + dialer: dialer, + } +} + +func (handler *DirectTcpipChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { + logger = logger.Session("directtcip-handle-new-channel") + logger.Debug("starting") + defer logger.Debug("complete") + + // RFC 4254 Section 7.1 + type channelOpenDirectTcpipMsg struct { + TargetAddr string + TargetPort uint32 + OriginAddr string + OriginPort uint32 + } + var directTcpipMessage channelOpenDirectTcpipMsg + + err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipMessage) + if err != nil { + logger.Error("failed-unmarshalling-ssh-message", err) + err := newChannel.Reject(ssh.ConnectionFailed, "Failed to parse open channel message") + if err != nil { + logger.Debug("failed-to-reject", lager.Data{"error": err}) + } + return + } + + destination := fmt.Sprintf("%s:%d", directTcpipMessage.TargetAddr, directTcpipMessage.TargetPort) + logger.Debug("dialing-connection", lager.Data{"destination": destination}) + + conn, err := handler.dialer.Dial("tcp", destination) + if err != nil { + logger.Error("failed-connecting-to-target", err) + err := newChannel.Reject(ssh.ConnectionFailed, err.Error()) + if err != nil { + logger.Debug("failed-to-reject", lager.Data{"error": err}) + } + return + } + defer conn.Close() + + logger.Debug("dialed-connection", lager.Data{"destintation": destination}) + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error("failed-to-accept-channel", err) + err := newChannel.Reject(ssh.ConnectionFailed, err.Error()) + if err != nil { + logger.Debug("failed-to-reject", lager.Data{"error": err}) + } + return + } + defer channel.Close() + + go ssh.DiscardRequests(requests) + + wg := &sync.WaitGroup{} + + wg.Add(2) + + logger.Debug("copying-channel-data") + go helpers.CopyAndClose(logger.Session("to-target"), wg, conn, channel, + func() { + err := conn.(*net.TCPConn).CloseWrite() + if err != nil { + logger.Debug("failed-to-close-connection", lager.Data{"error": err}) + } + }, + ) + go helpers.CopyAndClose(logger.Session("to-channel"), wg, channel, conn, + func() { + err := channel.CloseWrite() + if err != nil { + logger.Debug("failed-to-close-channel", lager.Data{"error": err}) + } + }, + ) + + wg.Wait() +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler_test.go new file mode 100644 index 0000000000..f56443612a --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/direct_tcpip_channel_handler_test.go @@ -0,0 +1,214 @@ +package handlers_test + +import ( + "bufio" + "errors" + "io" + "net" + "strconv" + "time" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fake_handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fakes" + "code.cloudfoundry.org/diego-ssh/server" + fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("DirectTcpipChannelHandler", func() { + var ( + sshd *daemon.Daemon + client *ssh.Client + + logger *lagertest.TestLogger + serverSSHConfig *ssh.ServerConfig + + handler *fake_handlers.FakeNewChannelHandler + testHandler *handlers.DirectTcpipChannelHandler + testDialer *fakes.FakeDialer + + echoHandler *fake_server.FakeConnectionHandler + echoServer *server.Server + echoAddress string + + handleConnFinished chan struct{} + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + echoHandler = &fake_server.FakeConnectionHandler{} + echoHandler.HandleConnectionStub = func(conn net.Conn) { + io.Copy(conn, conn) + conn.Close() + } + + echoListener, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + echoAddress = echoListener.Addr().String() + + echoServer = server.NewServer(logger.Session("echo"), "", echoHandler, 500*time.Millisecond) + echoServer.SetListener(echoListener) + go echoServer.Serve() + + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + testDialer = &fakes.FakeDialer{} + testDialer.DialStub = net.Dial + + testHandler = handlers.NewDirectTcpipChannelHandler(testDialer) + + handler = &fake_handlers.FakeNewChannelHandler{} + handler.HandleNewChannelStub = testHandler.HandleNewChannel + + newChannelHandlers := map[string]handlers.NewChannelHandler{ + "direct-tcpip": handler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + + handleConnFinished = make(chan struct{}) + go func() { + sshd.HandleConnection(serverNetConn) + close(handleConnFinished) + }() + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + client.Close() + echoServer.Shutdown() + Eventually(handleConnFinished).Should(BeClosed()) + }) + + Context("when a session is opened", func() { + var conn net.Conn + + JustBeforeEach(func() { + var dialErr error + conn, dialErr = client.Dial("tcp", echoAddress) + Expect(dialErr).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + conn.Close() + }) + + It("dials the the target from the remote end", func() { + Expect(testDialer.DialCallCount()).To(Equal(1)) + + net, addr := testDialer.DialArgsForCall(0) + Expect(net).To(Equal("tcp")) + Expect(addr).To(Equal(echoAddress)) + }) + + It("copies data between the local and target connections", func() { + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + writer.WriteString("Hello, World!\n") + writer.Flush() + + data, err := reader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + + Expect(data).To(Equal("Hello, World!\n")) + }) + + Describe("channel close coordination", func() { + var completed chan struct{} + + BeforeEach(func() { + completed = make(chan struct{}, 1) + handler.HandleNewChannelStub = func(logger lager.Logger, newChannel ssh.NewChannel) { + testHandler.HandleNewChannel(logger, newChannel) + completed <- struct{}{} + } + }) + + AfterEach(func() { + close(completed) + }) + + Context("when the client connection closes", func() { + It("the handler returns", func() { + Consistently(completed).ShouldNot(Receive()) + conn.Close() + Eventually(completed).Should(Receive()) + }) + }) + }) + }) + + Context("when the direct-tcpip extra data fails to unmarshal", func() { + It("rejects the open channel request", func() { + _, _, err := client.OpenChannel("direct-tcpip", ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).To(Equal(&ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "Failed to parse open channel message", + })) + }) + }) + + Context("when dialing the target fails", func() { + BeforeEach(func() { + testDialer.DialStub = func(net, addr string) (net.Conn, error) { + return nil, errors.New("woops") + } + }) + + It("rejects the open channel request", func() { + _, err := client.Dial("tcp", echoAddress) + Expect(err).To(Equal(&ssh.OpenChannelError{ + Reason: ssh.ConnectionFailed, + Message: "woops", + })) + + }) + }) + + Context("when an out of band request is sent across the channel", func() { + type channelOpenDirectTcpipMsg struct { + TargetAddr string + TargetPort uint32 + OriginAddr string + OriginPort uint32 + } + var directTcpipMessage channelOpenDirectTcpipMsg + + BeforeEach(func() { + addr, port, err := net.SplitHostPort(echoAddress) + Expect(err).NotTo(HaveOccurred()) + + p, err := strconv.ParseUint(port, 10, 16) + Expect(err).NotTo(HaveOccurred()) + + directTcpipMessage.TargetAddr = addr + directTcpipMessage.TargetPort = uint32(p) + }) + + It("rejects the requests", func() { + channel, _, err := client.OpenChannel("direct-tcpip", ssh.Marshal(directTcpipMessage)) + Expect(err).NotTo(HaveOccurred()) + + accepted, err := channel.SendRequest("something", true, nil) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + + channel.Close() + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_global_request_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_global_request_handler.go new file mode 100644 index 0000000000..b4331e3661 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_global_request_handler.go @@ -0,0 +1,85 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_handlers + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type FakeGlobalRequestHandler struct { + HandleRequestStub func(lager.Logger, *ssh.Request, ssh.Conn, *helpers.ListenerStore) + handleRequestMutex sync.RWMutex + handleRequestArgsForCall []struct { + arg1 lager.Logger + arg2 *ssh.Request + arg3 ssh.Conn + arg4 *helpers.ListenerStore + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeGlobalRequestHandler) HandleRequest(arg1 lager.Logger, arg2 *ssh.Request, arg3 ssh.Conn, arg4 *helpers.ListenerStore) { + fake.handleRequestMutex.Lock() + fake.handleRequestArgsForCall = append(fake.handleRequestArgsForCall, struct { + arg1 lager.Logger + arg2 *ssh.Request + arg3 ssh.Conn + arg4 *helpers.ListenerStore + }{arg1, arg2, arg3, arg4}) + fake.recordInvocation("HandleRequest", []interface{}{arg1, arg2, arg3, arg4}) + handleRequestStubCopy := fake.HandleRequestStub + fake.handleRequestMutex.Unlock() + if handleRequestStubCopy != nil { + handleRequestStubCopy(arg1, arg2, arg3, arg4) + } +} + +func (fake *FakeGlobalRequestHandler) HandleRequestCallCount() int { + fake.handleRequestMutex.RLock() + defer fake.handleRequestMutex.RUnlock() + return len(fake.handleRequestArgsForCall) +} + +func (fake *FakeGlobalRequestHandler) HandleRequestCalls(stub func(lager.Logger, *ssh.Request, ssh.Conn, *helpers.ListenerStore)) { + fake.handleRequestMutex.Lock() + defer fake.handleRequestMutex.Unlock() + fake.HandleRequestStub = stub +} + +func (fake *FakeGlobalRequestHandler) HandleRequestArgsForCall(i int) (lager.Logger, *ssh.Request, ssh.Conn, *helpers.ListenerStore) { + fake.handleRequestMutex.RLock() + defer fake.handleRequestMutex.RUnlock() + argsForCall := fake.handleRequestArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2, argsForCall.arg3, argsForCall.arg4 +} + +func (fake *FakeGlobalRequestHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.handleRequestMutex.RLock() + defer fake.handleRequestMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeGlobalRequestHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ handlers.GlobalRequestHandler = new(FakeGlobalRequestHandler) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_new_channel_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_new_channel_handler.go new file mode 100644 index 0000000000..e593752934 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/fake_new_channel_handler.go @@ -0,0 +1,80 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_handlers + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type FakeNewChannelHandler struct { + HandleNewChannelStub func(lager.Logger, ssh.NewChannel) + handleNewChannelMutex sync.RWMutex + handleNewChannelArgsForCall []struct { + arg1 lager.Logger + arg2 ssh.NewChannel + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeNewChannelHandler) HandleNewChannel(arg1 lager.Logger, arg2 ssh.NewChannel) { + fake.handleNewChannelMutex.Lock() + fake.handleNewChannelArgsForCall = append(fake.handleNewChannelArgsForCall, struct { + arg1 lager.Logger + arg2 ssh.NewChannel + }{arg1, arg2}) + fake.recordInvocation("HandleNewChannel", []interface{}{arg1, arg2}) + handleNewChannelStubCopy := fake.HandleNewChannelStub + fake.handleNewChannelMutex.Unlock() + if handleNewChannelStubCopy != nil { + handleNewChannelStubCopy(arg1, arg2) + } +} + +func (fake *FakeNewChannelHandler) HandleNewChannelCallCount() int { + fake.handleNewChannelMutex.RLock() + defer fake.handleNewChannelMutex.RUnlock() + return len(fake.handleNewChannelArgsForCall) +} + +func (fake *FakeNewChannelHandler) HandleNewChannelCalls(stub func(lager.Logger, ssh.NewChannel)) { + fake.handleNewChannelMutex.Lock() + defer fake.handleNewChannelMutex.Unlock() + fake.HandleNewChannelStub = stub +} + +func (fake *FakeNewChannelHandler) HandleNewChannelArgsForCall(i int) (lager.Logger, ssh.NewChannel) { + fake.handleNewChannelMutex.RLock() + defer fake.handleNewChannelMutex.RUnlock() + argsForCall := fake.handleNewChannelArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeNewChannelHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.handleNewChannelMutex.RLock() + defer fake.handleNewChannelMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeNewChannelHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ handlers.NewChannelHandler = new(FakeNewChannelHandler) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/package.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/package.go new file mode 100644 index 0000000000..346bab2fab --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fake_handlers/package.go @@ -0,0 +1 @@ +package fake_handlers // import "code.cloudfoundry.org/diego-ssh/handlers/fake_handlers" diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_dialer.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_dialer.go new file mode 100644 index 0000000000..0b9677dd6f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_dialer.go @@ -0,0 +1,119 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "net" + "sync" + + "code.cloudfoundry.org/diego-ssh/handlers" +) + +type FakeDialer struct { + DialStub func(string, string) (net.Conn, error) + dialMutex sync.RWMutex + dialArgsForCall []struct { + arg1 string + arg2 string + } + dialReturns struct { + result1 net.Conn + result2 error + } + dialReturnsOnCall map[int]struct { + result1 net.Conn + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeDialer) Dial(arg1 string, arg2 string) (net.Conn, error) { + fake.dialMutex.Lock() + ret, specificReturn := fake.dialReturnsOnCall[len(fake.dialArgsForCall)] + fake.dialArgsForCall = append(fake.dialArgsForCall, struct { + arg1 string + arg2 string + }{arg1, arg2}) + fake.recordInvocation("Dial", []interface{}{arg1, arg2}) + dialStubCopy := fake.DialStub + fake.dialMutex.Unlock() + if dialStubCopy != nil { + return dialStubCopy(arg1, arg2) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.dialReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeDialer) DialCallCount() int { + fake.dialMutex.RLock() + defer fake.dialMutex.RUnlock() + return len(fake.dialArgsForCall) +} + +func (fake *FakeDialer) DialCalls(stub func(string, string) (net.Conn, error)) { + fake.dialMutex.Lock() + defer fake.dialMutex.Unlock() + fake.DialStub = stub +} + +func (fake *FakeDialer) DialArgsForCall(i int) (string, string) { + fake.dialMutex.RLock() + defer fake.dialMutex.RUnlock() + argsForCall := fake.dialArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeDialer) DialReturns(result1 net.Conn, result2 error) { + fake.dialMutex.Lock() + defer fake.dialMutex.Unlock() + fake.DialStub = nil + fake.dialReturns = struct { + result1 net.Conn + result2 error + }{result1, result2} +} + +func (fake *FakeDialer) DialReturnsOnCall(i int, result1 net.Conn, result2 error) { + fake.dialMutex.Lock() + defer fake.dialMutex.Unlock() + fake.DialStub = nil + if fake.dialReturnsOnCall == nil { + fake.dialReturnsOnCall = make(map[int]struct { + result1 net.Conn + result2 error + }) + } + fake.dialReturnsOnCall[i] = struct { + result1 net.Conn + result2 error + }{result1, result2} +} + +func (fake *FakeDialer) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.dialMutex.RLock() + defer fake.dialMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeDialer) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ handlers.Dialer = new(FakeDialer) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_runner.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_runner.go new file mode 100644 index 0000000000..ee40e79443 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_runner.go @@ -0,0 +1,263 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "os/exec" + "sync" + "syscall" + + "code.cloudfoundry.org/diego-ssh/handlers" +) + +type FakeRunner struct { + SignalStub func(*exec.Cmd, syscall.Signal) error + signalMutex sync.RWMutex + signalArgsForCall []struct { + arg1 *exec.Cmd + arg2 syscall.Signal + } + signalReturns struct { + result1 error + } + signalReturnsOnCall map[int]struct { + result1 error + } + StartStub func(*exec.Cmd) error + startMutex sync.RWMutex + startArgsForCall []struct { + arg1 *exec.Cmd + } + startReturns struct { + result1 error + } + startReturnsOnCall map[int]struct { + result1 error + } + WaitStub func(*exec.Cmd) error + waitMutex sync.RWMutex + waitArgsForCall []struct { + arg1 *exec.Cmd + } + waitReturns struct { + result1 error + } + waitReturnsOnCall map[int]struct { + result1 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeRunner) Signal(arg1 *exec.Cmd, arg2 syscall.Signal) error { + fake.signalMutex.Lock() + ret, specificReturn := fake.signalReturnsOnCall[len(fake.signalArgsForCall)] + fake.signalArgsForCall = append(fake.signalArgsForCall, struct { + arg1 *exec.Cmd + arg2 syscall.Signal + }{arg1, arg2}) + fake.recordInvocation("Signal", []interface{}{arg1, arg2}) + signalStubCopy := fake.SignalStub + fake.signalMutex.Unlock() + if signalStubCopy != nil { + return signalStubCopy(arg1, arg2) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.signalReturns + return fakeReturns.result1 +} + +func (fake *FakeRunner) SignalCallCount() int { + fake.signalMutex.RLock() + defer fake.signalMutex.RUnlock() + return len(fake.signalArgsForCall) +} + +func (fake *FakeRunner) SignalCalls(stub func(*exec.Cmd, syscall.Signal) error) { + fake.signalMutex.Lock() + defer fake.signalMutex.Unlock() + fake.SignalStub = stub +} + +func (fake *FakeRunner) SignalArgsForCall(i int) (*exec.Cmd, syscall.Signal) { + fake.signalMutex.RLock() + defer fake.signalMutex.RUnlock() + argsForCall := fake.signalArgsForCall[i] + return argsForCall.arg1, argsForCall.arg2 +} + +func (fake *FakeRunner) SignalReturns(result1 error) { + fake.signalMutex.Lock() + defer fake.signalMutex.Unlock() + fake.SignalStub = nil + fake.signalReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) SignalReturnsOnCall(i int, result1 error) { + fake.signalMutex.Lock() + defer fake.signalMutex.Unlock() + fake.SignalStub = nil + if fake.signalReturnsOnCall == nil { + fake.signalReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.signalReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) Start(arg1 *exec.Cmd) error { + fake.startMutex.Lock() + ret, specificReturn := fake.startReturnsOnCall[len(fake.startArgsForCall)] + fake.startArgsForCall = append(fake.startArgsForCall, struct { + arg1 *exec.Cmd + }{arg1}) + fake.recordInvocation("Start", []interface{}{arg1}) + startStubCopy := fake.StartStub + fake.startMutex.Unlock() + if startStubCopy != nil { + return startStubCopy(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.startReturns + return fakeReturns.result1 +} + +func (fake *FakeRunner) StartCallCount() int { + fake.startMutex.RLock() + defer fake.startMutex.RUnlock() + return len(fake.startArgsForCall) +} + +func (fake *FakeRunner) StartCalls(stub func(*exec.Cmd) error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = stub +} + +func (fake *FakeRunner) StartArgsForCall(i int) *exec.Cmd { + fake.startMutex.RLock() + defer fake.startMutex.RUnlock() + argsForCall := fake.startArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRunner) StartReturns(result1 error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = nil + fake.startReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) StartReturnsOnCall(i int, result1 error) { + fake.startMutex.Lock() + defer fake.startMutex.Unlock() + fake.StartStub = nil + if fake.startReturnsOnCall == nil { + fake.startReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.startReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) Wait(arg1 *exec.Cmd) error { + fake.waitMutex.Lock() + ret, specificReturn := fake.waitReturnsOnCall[len(fake.waitArgsForCall)] + fake.waitArgsForCall = append(fake.waitArgsForCall, struct { + arg1 *exec.Cmd + }{arg1}) + fake.recordInvocation("Wait", []interface{}{arg1}) + waitStubCopy := fake.WaitStub + fake.waitMutex.Unlock() + if waitStubCopy != nil { + return waitStubCopy(arg1) + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.waitReturns + return fakeReturns.result1 +} + +func (fake *FakeRunner) WaitCallCount() int { + fake.waitMutex.RLock() + defer fake.waitMutex.RUnlock() + return len(fake.waitArgsForCall) +} + +func (fake *FakeRunner) WaitCalls(stub func(*exec.Cmd) error) { + fake.waitMutex.Lock() + defer fake.waitMutex.Unlock() + fake.WaitStub = stub +} + +func (fake *FakeRunner) WaitArgsForCall(i int) *exec.Cmd { + fake.waitMutex.RLock() + defer fake.waitMutex.RUnlock() + argsForCall := fake.waitArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeRunner) WaitReturns(result1 error) { + fake.waitMutex.Lock() + defer fake.waitMutex.Unlock() + fake.WaitStub = nil + fake.waitReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) WaitReturnsOnCall(i int, result1 error) { + fake.waitMutex.Lock() + defer fake.waitMutex.Unlock() + fake.WaitStub = nil + if fake.waitReturnsOnCall == nil { + fake.waitReturnsOnCall = make(map[int]struct { + result1 error + }) + } + fake.waitReturnsOnCall[i] = struct { + result1 error + }{result1} +} + +func (fake *FakeRunner) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.signalMutex.RLock() + defer fake.signalMutex.RUnlock() + fake.startMutex.RLock() + defer fake.startMutex.RUnlock() + fake.waitMutex.RLock() + defer fake.waitMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeRunner) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ handlers.Runner = new(FakeRunner) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_shell_locator.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_shell_locator.go new file mode 100644 index 0000000000..cbb5376c2d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/fake_shell_locator.go @@ -0,0 +1,102 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/handlers" +) + +type FakeShellLocator struct { + ShellPathStub func() string + shellPathMutex sync.RWMutex + shellPathArgsForCall []struct { + } + shellPathReturns struct { + result1 string + } + shellPathReturnsOnCall map[int]struct { + result1 string + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeShellLocator) ShellPath() string { + fake.shellPathMutex.Lock() + ret, specificReturn := fake.shellPathReturnsOnCall[len(fake.shellPathArgsForCall)] + fake.shellPathArgsForCall = append(fake.shellPathArgsForCall, struct { + }{}) + fake.recordInvocation("ShellPath", []interface{}{}) + shellPathStubCopy := fake.ShellPathStub + fake.shellPathMutex.Unlock() + if shellPathStubCopy != nil { + return shellPathStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.shellPathReturns + return fakeReturns.result1 +} + +func (fake *FakeShellLocator) ShellPathCallCount() int { + fake.shellPathMutex.RLock() + defer fake.shellPathMutex.RUnlock() + return len(fake.shellPathArgsForCall) +} + +func (fake *FakeShellLocator) ShellPathCalls(stub func() string) { + fake.shellPathMutex.Lock() + defer fake.shellPathMutex.Unlock() + fake.ShellPathStub = stub +} + +func (fake *FakeShellLocator) ShellPathReturns(result1 string) { + fake.shellPathMutex.Lock() + defer fake.shellPathMutex.Unlock() + fake.ShellPathStub = nil + fake.shellPathReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeShellLocator) ShellPathReturnsOnCall(i int, result1 string) { + fake.shellPathMutex.Lock() + defer fake.shellPathMutex.Unlock() + fake.ShellPathStub = nil + if fake.shellPathReturnsOnCall == nil { + fake.shellPathReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.shellPathReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeShellLocator) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.shellPathMutex.RLock() + defer fake.shellPathMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeShellLocator) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ handlers.ShellLocator = new(FakeShellLocator) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/package.go b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/package.go new file mode 100644 index 0000000000..485ea57e4e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/fakes/package.go @@ -0,0 +1 @@ +package fakes // import "code.cloudfoundry.org/diego-ssh/handlers/fakes" diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler.go new file mode 100644 index 0000000000..bc43b878a3 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler.go @@ -0,0 +1,52 @@ +package globalrequest + +import ( + "net" + "strconv" + + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +const CancelTCPIPForward = "cancel-tcpip-forward" + +type CancelTCPIPForwardHandler struct{} + +func (h *CancelTCPIPForwardHandler) HandleRequest(logger lager.Logger, request *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) { + logger = logger.Session("cancel-tcpip-forward", lager.Data{ + "type": request.Type, + "want-reply": request.WantReply, + }) + logger.Info("start") + defer logger.Info("done") + + var tcpipForwardMessage internal.TCPIPForwardRequest + + err := ssh.Unmarshal(request.Payload, &tcpipForwardMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + + address := net.JoinHostPort(tcpipForwardMessage.Address, strconv.FormatUint(uint64(tcpipForwardMessage.Port), 10)) + + logger.Info("recieved-payload", lager.Data{ + "message-address": tcpipForwardMessage.Address, + "message-port": tcpipForwardMessage.Port, + "listen-address": address, + }) + + if err = lnStore.RemoveListener(address); err != nil { + logger.Error("failed-to-cancel", err) + _ = request.Reply(false, nil) + return + } + + logger.Info("successfully-canceled-tcpip-forward") + _ = request.Reply(true, nil) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler_test.go new file mode 100644 index 0000000000..e3364d50e5 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/cancel_tcpip_forward_handler_test.go @@ -0,0 +1,117 @@ +package globalrequest_test + +import ( + "net" + "strconv" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest" + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3/lagertest" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("CancelTcpipForwardHandler", func() { + var ( + sshClient *ssh.Client + logger *lagertest.TestLogger + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("tcpip-forward-test") + + globalRequestHandlers := map[string]handlers.GlobalRequestHandler{ + globalrequest.TCPIPForward: new(globalrequest.TCPIPForwardHandler), + globalrequest.CancelTCPIPForward: new(globalrequest.CancelTCPIPForwardHandler), + } + + serverSSHConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + sshd := daemon.New(logger, serverSSHConfig, globalRequestHandlers, nil) + + serverNetConn, clientNetConn := test_helpers.Pipe() + go sshd.HandleConnection(serverNetConn) + sshClient = test_helpers.NewClient(clientNetConn, nil) + }) + + Context("when the request is invalid", func() { + It("rejects the request", func() { + payload := ssh.Marshal(struct { + port uint16 + }{ + port: 10, + }) + ok, _, err := sshClient.SendRequest("cancel-tcpip-forward", true, payload) + Expect(err).NotTo(HaveOccurred()) + Expect(ok).NotTo(BeTrue()) + }) + }) + + Context("when the listener isn't found", func() { + It("rejects the request", func() { + payload := ssh.Marshal(internal.TCPIPForwardRequest{ + Address: "127.0.0.1", + Port: 9090, + }) + ok, _, err := sshClient.SendRequest("cancel-tcpip-forward", true, payload) + Expect(err).NotTo(HaveOccurred()) + Expect(ok).NotTo(BeTrue()) + }) + }) + + Context("when the listener exists", func() { + var ( + ok bool + err error + address string + ) + + BeforeEach(func() { + addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + ln, err := sshClient.ListenTCP(addr) + Expect(err).NotTo(HaveOccurred()) + + _, portStr, err := net.SplitHostPort(ln.Addr().String()) + Expect(err).NotTo(HaveOccurred()) + + address = "127.0.0.1:" + portStr + port, err := strconv.Atoi(portStr) + Expect(err).NotTo(HaveOccurred()) + + _, err = net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + + payload := ssh.Marshal(internal.TCPIPForwardRequest{ + Address: "127.0.0.1", + Port: uint32(port), + }) + ok, _, err = sshClient.SendRequest("cancel-tcpip-forward", true, payload) + Expect(err).ToNot(HaveOccurred()) + }) + + It("successfully process the request", func() { + Expect(err).NotTo(HaveOccurred()) + Expect(ok).To(BeTrue()) + }) + + It("stops listening to the port", func() { + // the reason for the eventually instead of Expect is that a Close + // doesn't guarantee that the linux socket is actually closed. See + // https://github.com/golang/go/issues/10527 and build failures in + // https://diego.ci.cf-app.com/teams/main/pipelines/main/jobs/units-common/builds/1207 + Eventually(func() error { + _, err := net.Dial("tcp", address) + return err + }).Should(MatchError(ContainSubstring("refused"))) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/globalrequest_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/globalrequest_suite_test.go new file mode 100644 index 0000000000..51021bda8e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/globalrequest_suite_test.go @@ -0,0 +1,32 @@ +package globalrequest_test + +import ( + "os" + "runtime" + "testing" + + "code.cloudfoundry.org/diego-ssh/keys" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" +) + +var TestHostKey ssh.Signer + +func TestGlobalRequest(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "GlobalRequest Suite") +} + +var _ = BeforeSuite(func() { + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + TestHostKey = hostKey.PrivateKey() + + if runtime.GOOS == "windows" { + if os.Getenv("WINPTY_DLL_DIR") == "" { + Fail("Missing WINPTY_DLL_DIR environment variable") + } + } +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/package.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/package.go new file mode 100644 index 0000000000..19b8203767 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/package.go @@ -0,0 +1 @@ +package internal // import "code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal" diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/tcpip_forward_messages.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/tcpip_forward_messages.go new file mode 100644 index 0000000000..612b503445 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal/tcpip_forward_messages.go @@ -0,0 +1,10 @@ +package internal + +type TCPIPForwardRequest struct { + Address string + Port uint32 +} + +type TCPIPForwardResponse struct { + Port uint32 +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/package.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/package.go new file mode 100644 index 0000000000..c4df0a9da7 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/package.go @@ -0,0 +1 @@ +package globalrequest // import "code.cloudfoundry.org/diego-ssh/handlers/globalrequest" diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler.go new file mode 100644 index 0000000000..5285342a39 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler.go @@ -0,0 +1,148 @@ +package globalrequest + +import ( + "net" + "strconv" + "sync" + + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest/internal" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +const TCPIPForward = "tcpip-forward" + +type TCPIPForwardHandler struct{} + +func (h *TCPIPForwardHandler) HandleRequest(logger lager.Logger, request *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) { + logger = logger.Session("tcpip-forward", lager.Data{ + "type": request.Type, + "want-reply": request.WantReply, + }) + logger.Info("start") + defer logger.Info("done") + + var tcpipForwardMessage internal.TCPIPForwardRequest + err := ssh.Unmarshal(request.Payload, &tcpipForwardMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + + address := net.JoinHostPort(tcpipForwardMessage.Address, strconv.Itoa(int(tcpipForwardMessage.Port))) + + logger.Info("new-tcpip-forward", lager.Data{ + "message-address": tcpipForwardMessage.Address, + "message-port": tcpipForwardMessage.Port, + "listen-address": address, + }) + + listener, err := net.Listen("tcp", address) + if err != nil { + logger.Error("failed-to-listen", err) + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + return + } + + var listenerAddr string + var listenerPort uint32 + if addr, ok := listener.Addr().(*net.TCPAddr); ok { + address = addr.String() + listenerAddr = addr.IP.String() + listenerPort = uint32(addr.Port) + } + logger.Info("actual-listener-address", lager.Data{ + "addr": listenerAddr, + "port": listenerPort, + }) + + lnStore.AddListener(address, listener) + + go h.forwardAcceptLoop(listener, logger, conn, tcpipForwardMessage.Address, listenerPort) + + var tcpipForwardResponse internal.TCPIPForwardResponse + tcpipForwardResponse.Port = listenerPort + + var replyPayload []byte + + if tcpipForwardMessage.Port == 0 { + // See RFC 4254, section 7.1 + replyPayload = ssh.Marshal(tcpipForwardResponse) + } + + // Reply() will only send something when WantReply is true + _ = request.Reply(true, replyPayload) +} + +// See RFC 4254, section 7.2 +type forwardedTCPPayload struct { + Addr string + Port uint32 + OriginAddr string + OriginPort uint32 +} + +func (h *TCPIPForwardHandler) forwardAcceptLoop(listener net.Listener, logger lager.Logger, sshConn ssh.Conn, lnAddr string, lnPort uint32) { + logger = logger.Session("forward-accept-loop") + logger.Info("start") + defer logger.Info("done") + + defer listener.Close() + for { + conn, err := listener.Accept() + if err != nil { + logger.Error("failed-to-accept", err) + return + } + + go func(conn net.Conn) { + defer conn.Close() + + payload := forwardedTCPPayload{ + Addr: lnAddr, + Port: lnPort, + } + if addr, ok := conn.RemoteAddr().(*net.TCPAddr); ok { + payload.OriginAddr = addr.IP.String() + payload.OriginPort = uint32(addr.Port) + } + + channel, requests, err := sshConn.OpenChannel("forwarded-tcpip", ssh.Marshal(payload)) + if err != nil { + logger.Error("failed-to-open channel", err) + return + } + defer channel.Close() + + logger.Info("opened-channel", lager.Data{"payload": payload}) + go ssh.DiscardRequests(requests) + + var wg sync.WaitGroup + wg.Add(2) + + go helpers.CopyAndClose(logger.Session("to-target"), &wg, conn, channel, func() { + err := conn.Close() + if err != nil { + logger.Debug("failed-to-close-connection", lager.Data{"error": err}) + } + }) + go helpers.CopyAndClose(logger.Session("to-channel"), &wg, channel, conn, func() { + err := channel.CloseWrite() + if err != nil { + logger.Debug("failed-to-close-channel", lager.Data{"error": err}) + } + }) + + wg.Wait() + }(conn) + + logger.Info("accepted-connection", lager.Data{"Address": listener.Addr().String()}) + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler_test.go new file mode 100644 index 0000000000..db088da332 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/globalrequest/tcpip_forward_handler_test.go @@ -0,0 +1,167 @@ +package globalrequest_test + +import ( + "bufio" + "fmt" + "io" + "net" + + "golang.org/x/crypto/ssh" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/globalrequest" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + "code.cloudfoundry.org/localip" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TCPIPForward Handler", func() { + var ( + remoteAddress string + sshClient *ssh.Client + logger *lagertest.TestLogger + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("tcpip-forward-test") + + remotePort, err := localip.LocalPort() + Expect(err).NotTo(HaveOccurred()) + remoteAddress = fmt.Sprintf("127.0.0.1:%d", remotePort) + + globalRequestHandlers := map[string]handlers.GlobalRequestHandler{ + globalrequest.TCPIPForward: new(globalrequest.TCPIPForwardHandler), + globalrequest.CancelTCPIPForward: new(globalrequest.CancelTCPIPForwardHandler), + } + + serverSSHConfig := &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + sshd := daemon.New(logger, serverSSHConfig, globalRequestHandlers, nil) + + serverNetConn, clientNetConn := test_helpers.Pipe() + go sshd.HandleConnection(serverNetConn) + sshClient = test_helpers.NewClient(clientNetConn, nil) + }) + + testTCPIPForwardAndReturnConn := func(remoteAddr string) net.Conn { + remoteConn, err := net.Dial("tcp", remoteAddr) + Expect(err).NotTo(HaveOccurred()) + + expectedMsg := "hello\n" + _, err = fmt.Fprint(remoteConn, expectedMsg) + Expect(err).NotTo(HaveOccurred()) + r := bufio.NewReader(remoteConn) + l, err := r.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(l).To(Equal(expectedMsg)) + return remoteConn + } + + testTCPIPForward := func(remoteAddr string) { + conn := testTCPIPForwardAndReturnConn(remoteAddr) + Expect(conn.Close()).To(Succeed()) + } + + It("listens for multiple connections on the interface/port specified", func() { + listener, err := sshClient.Listen("tcp", remoteAddress) + Expect(err).NotTo(HaveOccurred()) + + defer listener.Close() + go ServeListener(listener, logger.Session("local")) + + conn1 := testTCPIPForwardAndReturnConn(remoteAddress) + conn2 := testTCPIPForwardAndReturnConn(remoteAddress) + + Expect(conn1.Close()).To(Succeed()) + Expect(conn2.Close()).To(Succeed()) + }) + + It("allows the requester to ask for connections to be forwarded from an unused port", func() { + listener, err := sshClient.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + defer listener.Close() + go ServeListener(listener, logger.Session("local")) + + testTCPIPForward(listener.Addr().String()) + }) + + It("allows the requester to ask for connections to be forwarded from all interfaces", func() { + listener, err := sshClient.Listen("tcp", "0.0.0.0:0") + Expect(err).NotTo(HaveOccurred()) + defer listener.Close() + go ServeListener(listener, logger.Session("local")) + + testTCPIPForward(listener.Addr().String()) + }) + + It("can listen again after cancelling the request", func() { + listener, err := sshClient.Listen("tcp", remoteAddress) + Expect(err).NotTo(HaveOccurred()) + Expect(listener.Close()).To(Succeed()) + + listener, err = sshClient.Listen("tcp", remoteAddress) + Expect(err).NotTo(HaveOccurred()) + + defer listener.Close() + go ServeListener(listener, logger.Session("local")) + + testTCPIPForward(remoteAddress) + }) + + Context("when listener cannot be created", func() { + var ( + ln net.Listener + ) + + BeforeEach(func() { + var err error + ln, err = net.Listen("tcp", ":0") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(ln.Close()).To(Succeed()) + }) + + It("reject the request", func() { + _, err := sshClient.Listen("tcp", ln.Addr().String()) + Expect(err).To(HaveOccurred()) + }) + }) +}) + +func ServeListener(ln net.Listener, logger lager.Logger) { + for { + conn, err := ln.Accept() + if err != nil { + logger.Error("listener-failed-to-accept", err) + return + } + + go func() { + defer conn.Close() + defer GinkgoRecover() + + for { + r := bufio.NewReader(conn) + l, err := r.ReadString('\n') + if err == io.EOF { + return + } + n, err := conn.Write([]byte(l)) + if err != nil { + logger.Error("server-sent-message-error", err) + return + } + logger.Info("server-sent-message-success", lager.Data{"bytes-sent": n}) + } + }() + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/handlers_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/handlers_suite_test.go new file mode 100644 index 0000000000..5d3c989580 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/handlers_suite_test.go @@ -0,0 +1,33 @@ +package handlers_test + +import ( + "os" + "runtime" + + "code.cloudfoundry.org/diego-ssh/keys" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" + + "testing" +) + +var TestHostKey ssh.Signer + +func TestHandlers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Handlers Suite") +} + +var _ = BeforeSuite(func() { + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + TestHostKey = hostKey.PrivateKey() + + if runtime.GOOS == "windows" { + if os.Getenv("WINPTY_DLL_DIR") == "" { + Fail("Missing WINPTY_DLL_DIR environment variable") + } + } +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/package.go b/src/code.cloudfoundry.org/diego-ssh/handlers/package.go new file mode 100644 index 0000000000..b13596528b --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/package.go @@ -0,0 +1 @@ +package handlers // import "code.cloudfoundry.org/diego-ssh/handlers" diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler.go new file mode 100644 index 0000000000..89799ced2e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler.go @@ -0,0 +1,751 @@ +//go:build !windows && !windows2012R2 + +package handlers + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "os" + "os/exec" + "regexp" + "sync" + "syscall" + "time" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/signals" + "code.cloudfoundry.org/diego-ssh/termcodes" + "code.cloudfoundry.org/lager/v3" + "github.com/kr/pty" + "github.com/moby/term" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var scpRegex = regexp.MustCompile(`^\s*scp($|\s+)`) + +type SessionChannelHandler struct { + runner Runner + shellLocator ShellLocator + defaultEnv map[string]string + keepalive time.Duration +} + +func NewSessionChannelHandler( + runner Runner, + shellLocator ShellLocator, + defaultEnv map[string]string, + keepalive time.Duration, +) *SessionChannelHandler { + return &SessionChannelHandler{ + runner: runner, + shellLocator: shellLocator, + defaultEnv: defaultEnv, + keepalive: keepalive, + } +} + +func (handler *SessionChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error("handle-new-session-channel-failed", err) + return + } + + handler.newSession(logger, channel, handler.keepalive).serviceRequests(requests) +} + +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +type session struct { + logger lager.Logger + complete bool + keepaliveDuration time.Duration + keepaliveStopCh chan struct{} + + shellPath string + runner Runner + channel ssh.Channel + + sync.Mutex + env map[string]string + command *exec.Cmd + + wg sync.WaitGroup + allocPty bool + ptyRequest ptyRequestMsg + + ptyMaster *os.File +} + +func (handler *SessionChannelHandler) newSession(logger lager.Logger, channel ssh.Channel, keepalive time.Duration) *session { + return &session{ + logger: logger.Session("session-channel"), + keepaliveDuration: keepalive, + runner: handler.runner, + shellPath: handler.shellLocator.ShellPath(), + channel: channel, + env: handler.defaultEnv, + } +} + +func (sess *session) serviceRequests(requests <-chan *ssh.Request) { + logger := sess.logger + logger.Info("starting") + defer logger.Info("finished") + + defer sess.destroy() + + for req := range requests { + sess.logger.Info("received-request", lager.Data{"type": req.Type}) + switch req.Type { + case "env": + sess.handleEnvironmentRequest(req) + case "signal": + sess.handleSignalRequest(req) + case "pty-req": + sess.handlePtyRequest(req) + case "window-change": + sess.handleWindowChangeRequest(req) + case "exec": + sess.handleExecRequest(req) + case "shell": + sess.handleShellRequest(req) + case "subsystem": + sess.handleSubsystemRequest(req) + default: + if req.WantReply { + err := req.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + } + } +} + +func (sess *session) handleEnvironmentRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-environment-request") + + type envMsg struct { + Name string + Value string + } + var envMessage envMsg + + err := ssh.Unmarshal(request.Payload, &envMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + err := request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + return + } + + sess.Lock() + sess.env[envMessage.Name] = envMessage.Value + sess.Unlock() + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } +} + +func (sess *session) handleSignalRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-signal-request") + + type signalMsg struct { + Signal string + } + var signalMessage signalMsg + + err := ssh.Unmarshal(request.Payload, &signalMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + err := request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + sess.Lock() + defer sess.Unlock() + + cmd := sess.command + + if cmd != nil { + signal := signals.SyscallSignals[ssh.Signal(signalMessage.Signal)] + err := sess.runner.Signal(cmd, signal) + if err != nil { + logger.Error("process-signal-failed", err) + } + } + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } +} + +func (sess *session) handlePtyRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-pty-request") + + var ptyRequestMessage ptyRequestMsg + + err := ssh.Unmarshal(request.Payload, &ptyRequestMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + err := request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + sess.Lock() + defer sess.Unlock() + + sess.allocPty = true + sess.ptyRequest = ptyRequestMessage + sess.env["TERM"] = ptyRequestMessage.Term + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } +} + +func (sess *session) handleWindowChangeRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-window-change") + + type windowChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + var windowChangeMessage windowChangeMsg + + err := ssh.Unmarshal(request.Payload, &windowChangeMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + err := request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + sess.Lock() + defer sess.Unlock() + + if sess.allocPty { + sess.ptyRequest.Columns = windowChangeMessage.Columns + sess.ptyRequest.Rows = windowChangeMessage.Rows + } + + if sess.ptyMaster != nil { + err = setWindowSize(logger, sess.ptyMaster, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + if err != nil { + logger.Error("failed-to-set-window-size", err) + } + } + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } +} + +func (sess *session) handleExecRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-exec-request") + + type execMsg struct { + Command string + } + var execMessage execMsg + + err := ssh.Unmarshal(request.Payload, &execMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + err := request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + if scpRegex.MatchString(execMessage.Command) { + logger.Info("handling-scp-command", lager.Data{"Command": execMessage.Command}) + sess.executeSCP(execMessage.Command, request) + } else { + sess.executeShell(request, "-c", execMessage.Command) + } +} + +func (sess *session) handleShellRequest(request *ssh.Request) { + sess.executeShell(request) +} + +func (sess *session) handleSubsystemRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-subsystem-request") + logger.Info("starting") + defer logger.Info("finished") + + type subsysMsg struct { + Subsystem string + } + var subsystemMessage subsysMsg + + err := ssh.Unmarshal(request.Payload, &subsystemMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + if subsystemMessage.Subsystem != "sftp" { + logger.Info("unsupported-subsystem", lager.Data{"subsystem": subsystemMessage.Subsystem}) + if request.WantReply { + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + lagerWriter := helpers.NewLagerWriter(logger.Session("sftp-server")) + sftpServer, err := sftp.NewServer(sess.channel, sftp.WithDebug(lagerWriter)) + if err != nil { + logger.Error("sftp-new-server-failed", err) + if request.WantReply { + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + if request.WantReply { + err = request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + + logger.Info("starting-server") + go func() { + defer sess.destroy() + err = sftpServer.Serve() + if err != nil { + logger.Error("sftp-serve-error", err) + } + }() +} + +func (sess *session) executeShell(request *ssh.Request, args ...string) { + logger := sess.logger.Session("execute-shell") + + sess.Lock() + cmd, err := sess.createCommand(args...) + if err != nil { + sess.Unlock() + logger.Error("failed-to-create-command", err) + if request.WantReply { + err = request.Reply(false, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + return + } + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + + } + + if sess.allocPty { + err = sess.runWithPty(cmd) + } else { + err = sess.run(cmd) + } + + sess.Unlock() + + if err != nil { + sess.sendExitMessage(err) + sess.destroy() + return + } + + go func() { + err := sess.wait(cmd) + sess.sendExitMessage(err) + sess.destroy() + }() +} + +func (sess *session) createCommand(args ...string) (*exec.Cmd, error) { + if sess.command != nil { + return nil, errors.New("command already started") + } + + cmd := exec.Command(sess.shellPath, args...) + cmd.Env = sess.environment() + sess.command = cmd + + return cmd, nil +} + +func (sess *session) environment() []string { + env := []string{} + + env = append(env, "PATH=/bin:/usr/bin") + env = append(env, "LANG=en_US.UTF8") + + for k, v := range sess.env { + if k != "HOME" && k != "USER" { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + } + + env = append(env, fmt.Sprintf("HOME=%s", os.Getenv("HOME"))) + env = append(env, fmt.Sprintf("USER=%s", os.Getenv("USER"))) + + return env +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Error string + Lang string +} + +func (sess *session) sendExitMessage(err error) { + logger := sess.logger.Session("send-exit-message") + logger.Info("started") + defer logger.Info("finished") + + if err != nil { + logger.Error("building-exit-message-from-error", err) + } + + if err == nil { + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitStatusMsg{})) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + exitError, ok := err.(*exec.ExitError) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + if waitStatus.Signaled() { + exitMessage := exitSignalMsg{ + Signal: string(signals.SSHSignals[waitStatus.Signal()]), + CoreDumped: waitStatus.CoreDump(), + } + _, sendErr := sess.channel.SendRequest("exit-signal", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + exitMessage := exitStatusMsg{Status: uint32(waitStatus.ExitStatus())} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} + +func setWindowSize(logger lager.Logger, pseudoTty *os.File, columns, rows uint32) error { + logger.Info("new-size", lager.Data{"columns": columns, "rows": rows}) + return term.SetWinsize(pseudoTty.Fd(), &term.Winsize{ + Width: uint16(columns), + Height: uint16(rows), + }) +} + +func setTerminalAttributes(logger lager.Logger, pseudoTty *os.File, modelist string) { + reader := bytes.NewReader([]byte(modelist)) + + for { + var opcode uint8 + var value uint32 + + err := binary.Read(reader, binary.BigEndian, &opcode) + if err != nil { + logger.Error("failed-to-read-modelist-opcode", err) + break + } + + if opcode == 0 || opcode >= 160 { + break + } + + err = binary.Read(reader, binary.BigEndian, &value) + if err != nil { + logger.Error("failed-to-read-modelist-value", err) + break + } + + logger.Info("set-terminal-attribute", lager.Data{ + "opcode": opcode, + "value": fmt.Sprintf("%x", value), + }) + + termios, err := termcodes.GetAttr(pseudoTty) + if err != nil { + logger.Error("failed-to-get-terminal-attrs", err) + continue + } + + setter, ok := termcodes.TermAttrSetters[opcode] + if !ok { + logger.Error("failed-to-find-setter-for-opcode", errors.New("opcode-not-found"), lager.Data{ + "opcode": opcode, + }) + continue + } + + err = setter.Set(pseudoTty, termios, value) + if err != nil { + logger.Error("failed-to-set-terminal-attrs", err, lager.Data{ + "opcode": opcode, + "value": fmt.Sprintf("%x", value), + }) + continue + } + } +} + +func (sess *session) run(command *exec.Cmd) error { + logger := sess.logger.Session("run") + + command.Stdout = sess.channel + command.Stderr = sess.channel.Stderr() + + stdin, err := command.StdinPipe() + if err != nil { + return err + } + + go helpers.CopyAndClose(logger.Session("to-stdin"), nil, stdin, sess.channel, func() { + err := stdin.Close() + if err != nil { + logger.Debug("failed-to-close-stdin", lager.Data{"error": err}) + } + }) + + return sess.runner.Start(command) +} + +func (sess *session) runWithPty(command *exec.Cmd) error { + logger := sess.logger.Session("run-with-pty") + + ptyMaster, ptySlave, err := pty.Open() + if err != nil { + logger.Error("failed-to-open-pty", err) + return err + } + + sess.ptyMaster = ptyMaster + defer ptySlave.Close() + + command.Stdout = ptySlave + command.Stdin = ptySlave + command.Stderr = ptySlave + + command.SysProcAttr = &syscall.SysProcAttr{ + Setctty: true, + Setsid: true, + } + + setTerminalAttributes(logger, ptyMaster, sess.ptyRequest.Modelist) + err = setWindowSize(logger, ptyMaster, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + if err != nil { + logger.Debug("failed-to-set-session-window-size", lager.Data{"error": err}) + } + + sess.wg.Add(1) + go helpers.Copy(logger.Session("to-pty"), nil, ptyMaster, sess.channel) + go func() { + helpers.Copy(logger.Session("from-pty"), &sess.wg, sess.channel, ptyMaster) + err := sess.channel.CloseWrite() + if err != nil { + logger.Debug("failed-to-close-session-channel-writer", lager.Data{"error": err}) + } + }() + + err = sess.runner.Start(command) + if err == nil { + sess.keepaliveStopCh = make(chan struct{}) + go sess.keepalive(command, sess.keepaliveStopCh) + } + return err +} + +func (sess *session) keepalive(command *exec.Cmd, stopCh chan struct{}) { + logger := sess.logger.Session("keepalive") + + ticker := time.NewTicker(sess.keepaliveDuration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, err := sess.channel.SendRequest("keepalive@cloudfoundry.org", true, nil) + logger.Info("keepalive", lager.Data{"success": err == nil}) + + if err != nil { + err = sess.runner.Signal(command, syscall.SIGHUP) + logger.Info("process-signaled", lager.Data{"error": err}) + return + } + case <-stopCh: + return + } + } +} + +func (sess *session) wait(command *exec.Cmd) error { + logger := sess.logger.Session("wait") + logger.Info("started") + defer logger.Info("done") + return sess.runner.Wait(command) +} + +func (sess *session) destroy() { + logger := sess.logger.Session("destroy") + logger.Info("started") + defer logger.Info("done") + + sess.Lock() + defer sess.Unlock() + + if sess.complete { + return + } + + sess.complete = true + sess.wg.Wait() + + if sess.channel != nil { + err := sess.channel.Close() + if err != nil { + logger.Debug("failed-to-close-session-channel", lager.Data{"error": err}) + } + } + + if sess.ptyMaster != nil { + err := sess.ptyMaster.Close() + if err != nil { + logger.Debug("failed-to-close-session-pty-master", lager.Data{"error": err}) + } + sess.ptyMaster = nil + } + + if sess.keepaliveStopCh != nil { + close(sess.keepaliveStopCh) + } +} + +func (sess *session) executeSCP(command string, request *ssh.Request) { + logger := sess.logger.Session("execute-scp") + + if request.WantReply { + err := request.Reply(true, nil) + if err != nil { + logger.Debug("failed-replying", lager.Data{"error": err}) + } + + } + + copier, err := scp.NewFromCommand(command, sess.channel, sess.channel, sess.channel.Stderr(), logger) + if err == nil { + err = copier.Copy() + } + + sess.sendSCPExitMessage(err) + sess.destroy() +} + +func (sess *session) sendSCPExitMessage(err error) { + logger := sess.logger.Session("send-scp-exit-message") + logger.Info("started") + defer logger.Info("finished") + + var exitMessage exitStatusMsg + if err != nil { + logger.Error("building-scp-exit-message-from-error", err) + exitMessage = exitStatusMsg{Status: 1} + } + + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_test.go new file mode 100644 index 0000000000..7a65b9b227 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_test.go @@ -0,0 +1,1023 @@ +//go:build !windows && !windows2012R2 + +package handlers_test + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "time" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("SessionChannelHandler", func() { + var ( + sshd *daemon.Daemon + client *ssh.Client + + logger *lagertest.TestLogger + serverSSHConfig *ssh.ServerConfig + + runner *fakes.FakeRunner + shellLocator *fakes.FakeShellLocator + sessionChannelHandler *handlers.SessionChannelHandler + + newChannelHandlers map[string]handlers.NewChannelHandler + defaultEnv map[string]string + connectionFinished chan struct{} + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + runner = &fakes.FakeRunner{} + realRunner := handlers.NewCommandRunner() + runner.StartStub = realRunner.Start + runner.WaitStub = realRunner.Wait + runner.SignalStub = realRunner.Signal + + shellLocator = &fakes.FakeShellLocator{} + shellLocator.ShellPathReturns("/bin/sh") + + defaultEnv = map[string]string{} + defaultEnv["TEST"] = "FOO" + + sessionChannelHandler = handlers.NewSessionChannelHandler(runner, shellLocator, defaultEnv, time.Second) + + newChannelHandlers = map[string]handlers.NewChannelHandler{ + "session": sessionChannelHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + connectionFinished = make(chan struct{}) + go func() { + sshd.HandleConnection(serverNetConn) + close(connectionFinished) + }() + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + if client != nil { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + } + Eventually(connectionFinished).Should(BeClosed()) + }) + + Context("when a session is opened", func() { + var session *ssh.Session + + BeforeEach(func() { + var sessionErr error + session, sessionErr = client.NewSession() + + Expect(sessionErr).NotTo(HaveOccurred()) + }) + + It("can use the session to execute a command with stdout and stderr", func() { + stdout, err := session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + stderr, err := session.StderrPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("/bin/echo -n Hello; /bin/echo -n Goodbye >&2") + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := io.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + Expect(stdoutBytes).To(Equal([]byte("Hello"))) + + stderrBytes, err := io.ReadAll(stderr) + Expect(err).NotTo(HaveOccurred()) + Expect(stderrBytes).To(Equal([]byte("Goodbye"))) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("ls") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + Describe("scp", func() { + var ( + sourceDir, generatedTextFile, targetDir string + err error + stdin io.WriteCloser + stdout io.Reader + fileContents []byte + ) + + BeforeEach(func() { + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + sourceDir, err = os.MkdirTemp("", "scp-source") + Expect(err).NotTo(HaveOccurred()) + + fileContents = []byte("---\nthis is a simple file\n\n") + generatedTextFile = filepath.Join(sourceDir, "textfile.txt") + + err = os.WriteFile(generatedTextFile, fileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + targetDir, err = os.MkdirTemp("", "scp-target") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(sourceDir)).To(Succeed()) + Expect(os.RemoveAll(targetDir)).To(Succeed()) + }) + + It("properly copies using the secure copier", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := session.Run(fmt.Sprintf("scp -v -t %s", targetDir)) + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + expectedFileInfo, err := os.Stat(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte(fmt.Sprintf("C0664 %d textfile.txt\n", expectedFileInfo.Size()))) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + actualFilePath := filepath.Join(targetDir, filepath.Base(generatedTextFile)) + actualFileInfo, err := os.Stat(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualFileInfo.Mode()).To(Equal(expectedFileInfo.Mode())) + Expect(actualFileInfo.Size()).To(Equal(expectedFileInfo.Size())) + + actualContents, err := os.ReadFile(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + expectedContents, err := os.ReadFile(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualContents).To(Equal(expectedContents)) + + Eventually(done).Should(BeClosed()) + }) + + It("properly fails when secure copying fails", func() { + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- session.Run(fmt.Sprintf("scp -v -t %s", targetDir)) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write([]byte("BOGUS PROTOCOL MESSAGE\n")) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{1})) + + err = <-errCh + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + + It("properly fails when incorrect arguments are supplied", func() { + err := session.Run("scp -v -t /tmp/foo /tmp/bar") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + }) + + Describe("the shell locator", func() { + BeforeEach(func() { + err := session.Run("true") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the shell locator to find the default shell path", func() { + Expect(shellLocator.ShellPathCallCount()).To(Equal(1)) + + cmd := runner.StartArgsForCall(0) + Expect(cmd.Path).To(Equal("/bin/sh")) + }) + }) + + Context("when stdin is provided by the client", func() { + BeforeEach(func() { + session.Stdin = strings.NewReader("Hello") + }) + + It("can use the session to execute a command that reads it", func() { + result, err := session.Output("cat") + Expect(err).NotTo(HaveOccurred()) + Expect(string(result)).To(Equal("Hello")) + }) + }) + + Context("when the command exits with a non-zero value", func() { + It("it preserve the exit code", func() { + err := session.Run("exit 3") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(3)) + }) + }) + + Context("when a signal is sent across the session", func() { + Context("before a command has been run", func() { + BeforeEach(func() { + err := session.Signal(ssh.SIGTERM) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not prevent the command from running", func() { + result, err := session.Output("/bin/echo -n 'still kicking'") + Expect(err).NotTo(HaveOccurred()) + Expect(string(result)).To(Equal("still kicking")) + }) + }) + + Context("while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start("trap 'echo Caught SIGUSR1' USR1; echo trapped; cat") + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("trapped")) + + Eventually(runner.StartCallCount).Should(Equal(1)) + }) + + It("delivers the signal to the process", func() { + err := session.Signal(ssh.SIGUSR1) + Expect(err).NotTo(HaveOccurred()) + + Eventually(runner.SignalCallCount).Should(Equal(1)) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := io.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + Expect(stdoutBytes).To(ContainSubstring("Caught SIGUSR1")) + }) + + It("exits with an exit-signal response", func() { + err := session.Signal(ssh.SIGUSR2) + Expect(err).NotTo(HaveOccurred()) + + Eventually(runner.SignalCallCount).Should(Equal(1)) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.Signal()).To(Equal("USR2")) + }) + }) + }) + + Context("when running a command without an explicit environemnt", func() { + It("does not inherit daemon's environment", func() { + os.Setenv("DAEMON_ENV", "daemon_env_value") + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).NotTo(ContainSubstring("DAEMON_ENV=daemon_env_value")) + }) + + It("includes a default environment excluding PATH", func() { + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("PATH=/bin:/usr/bin")) + Expect(result).To(ContainSubstring("LANG=en_US.UTF8")) + Expect(result).To(ContainSubstring("TEST=FOO")) + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + }) + + Context("when environment variables are requested", func() { + Context("before starting the command", func() { + It("runs the command with the specified environment", func() { + err := session.Setenv("ENV1", "value1") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV2", "value2") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=value1")) + Expect(result).To(ContainSubstring("ENV2=value2")) + }) + + It("uses the value last specified", func() { + err := session.Setenv("ENV1", "original") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV1", "updated") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=updated")) + }) + + It("can override PATH and LANG", func() { + err := session.Setenv("PATH", "/bin:/usr/local/bin:/sbin") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("LANG", "en_UK.UTF8") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("PATH=/bin:/usr/local/bin:/sbin")) + Expect(result).To(ContainSubstring("LANG=en_UK.UTF8")) + }) + + It("cannot override HOME and USER", func() { + err := session.Setenv("HOME", "/some/other/home") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("USER", "not-a-user") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + + It("can override default env variables", func() { + err := session.Setenv("TEST", "BAR") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("/usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("TEST=BAR")) + }) + }) + + Context("after starting the command", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start("cat && /usr/bin/env") + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + err := session.Setenv("ENV3", "value3") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := io.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdoutBytes).NotTo(ContainSubstring("ENV3")) + }) + }) + }) + + Context("when a pty request is received", func() { + var terminalModes ssh.TerminalModes + + BeforeEach(func() { + terminalModes = ssh.TerminalModes{} + }) + + JustBeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, terminalModes) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should allocate a tty for the session", func() { + result, err := session.Output("tty") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).NotTo(ContainSubstring("not a tty")) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("ls") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + It("terminates the shell when the stdin closes", func() { + waitCh := make(chan error, 1) + waitStartedCh := make(chan struct{}, 1) + waitStub := runner.WaitStub + runner.WaitStub = func(command *exec.Cmd) error { + close(waitStartedCh) + err := waitStub(command) + waitCh <- err + return err + } + + err := session.Shell() + Expect(err).NotTo(HaveOccurred()) + + Eventually(waitStartedCh).Should(BeClosed()) + + err = client.Conn.Close() + client = nil + Expect(err).NotTo(HaveOccurred()) + session.Wait() + Eventually(waitCh, 3).Should(Receive(MatchError("signal: hangup"))) + }) + + It("should set the terminal type", func() { + result, err := session.Output(`/bin/echo -n "$TERM"`) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(Equal("vt100")) + }) + + It("sets the correct window size for the terminal", func() { + result, err := session.Output("stty size") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("43 80")) + }) + + Context("when control character mappings are specified in TerminalModes", func() { + BeforeEach(func() { + // Swap CTRL-Z (suspend) with CTRL-D (eof) + terminalModes[ssh.VEOF] = 26 + terminalModes[ssh.VSUSP] = 4 + }) + + It("honors the control character changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring("susp = ^D")) + Expect(string(result)).To(ContainSubstring("eof = ^Z")) + }) + }) + + Context("when an unrecognized terminal mode is specified", func() { + BeforeEach(func() { + terminalModes[42] = 1 + }) + + It("ignores it", func() { + errCh := make(chan error) + go func() { + defer GinkgoRecover() + + result, err := session.Output("echo -n hi") + Expect(string(result)).To(Equal("hi")) + errCh <- err + }() + var err error + Eventually(errCh).Should(Receive(&err)) + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when input modes are specified in TerminalModes", func() { + BeforeEach(func() { + terminalModes[ssh.IGNPAR] = 1 + terminalModes[ssh.IXON] = 0 + terminalModes[ssh.IXANY] = 0 + }) + + It("honors the input mode changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(" ignpar")) + Expect(string(result)).To(ContainSubstring(" -ixon")) + Expect(string(result)).To(ContainSubstring(" -ixany")) + }) + }) + + // Looks like there are some issues with terminal attributes on Linux. + // These need further investigation there. + Context("when local modes are specified in TerminalModes", func() { + BeforeEach(func() { + terminalModes[ssh.IEXTEN] = 0 + terminalModes[ssh.ECHOCTL] = 1 + }) + + It("honors the local mode changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(" -iexten")) + Expect(string(result)).To(MatchRegexp("[^-]echoctl")) + }) + }) + + Context("when output modes are specified in TerminalModes", func() { + BeforeEach(func() { + terminalModes[ssh.ONLCR] = 0 + }) + + It("honors the output mode changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(" -onlcr")) + }) + + if runtime.GOOS == "linux" { + Context("on linux", func() { + BeforeEach(func() { + terminalModes[ssh.ONLRET] = 1 + }) + + It("honors the output mode changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(" onlret")) + }) + }) + } + }) + + Context("when control character modes are specified in TerminalModes", func() { + BeforeEach(func() { + terminalModes[ssh.PARODD] = 0 + }) + + It("honors the control mode changes", func() { + result, err := session.Output("stty -a") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(" -parodd")) + }) + }) + + Context("when an interactive command is executed", func() { + var stdin io.WriteCloser + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + }) + + It("terminates the session when the shell exits", func() { + err := session.Start("/bin/sh") + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte("exit\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + }) + + Context("when a window change request is received", func() { + type winChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + + var result []byte + + Context("before a pty is allocated", func() { + BeforeEach(func() { + _, err := session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + err = session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("stty size") + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + Expect(result).To(ContainSubstring("43 80")) + }) + }) + + Context("after a pty is allocated", func() { + BeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + _, err = session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("stty size") + Expect(err).NotTo(HaveOccurred()) + }) + + It("changes the the size of the terminal", func() { + Expect(result).To(ContainSubstring("50 132")) + }) + }) + }) + + Context("after executing a command", func() { + BeforeEach(func() { + err := session.Run("true") + Expect(err).NotTo(HaveOccurred()) + }) + + It("the session is no longer usable", func() { + _, err := session.SendRequest("exec", true, ssh.Marshal(struct{ Command string }{Command: "true"})) + Expect(err).To(HaveOccurred()) + + _, err = session.SendRequest("bogus", true, nil) + Expect(err).To(HaveOccurred()) + + err = session.Setenv("foo", "bar") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when an interactive shell is requested", func() { + var stdin io.WriteCloser + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + }) + + It("starts the shell with the runner", func() { + Eventually(runner.StartCallCount).Should(Equal(1)) + + command := runner.StartArgsForCall(0) + Expect(command.Path).To(Equal("/bin/sh")) + Expect(command.Args).To(ConsistOf("/bin/sh")) + }) + + It("terminates the session when the shell exits", func() { + _, err := stdin.Write([]byte("exit\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("and a command is provided", func() { + BeforeEach(func() { + err := session.Run("true") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the provided runner to start the command", func() { + Expect(runner.StartCallCount()).To(Equal(1)) + Expect(runner.WaitCallCount()).To(Equal(1)) + }) + + It("passes the correct command to the runner", func() { + command := runner.StartArgsForCall(0) + Expect(command.Path).To(Equal("/bin/sh")) + Expect(command.Args).To(ConsistOf("/bin/sh", "-c", "true")) + }) + + It("passes the same command to Start and Wait", func() { + command := runner.StartArgsForCall(0) + Expect(runner.WaitArgsForCall(0)).To(Equal(command)) + }) + }) + + Context("when executing an invalid command", func() { + It("returns an exit error with a non-zero exit status", func() { + err := session.Run("not-a-command") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + + Context("when starting the command fails", func() { + BeforeEach(func() { + runner.StartReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + + Context("when waiting on the command fails", func() { + BeforeEach(func() { + runner.WaitReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + }) + + Context("when an unknown request type is sent", func() { + var accepted bool + + BeforeEach(func() { + var err error + accepted, err = session.SendRequest("unknown-request-type", true, []byte("payload")) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("/bin/echo -n Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(response).To(Equal([]byte("Hello"))) + }) + }) + + Context("when an unknown subsystem is requested", func() { + var accepted bool + + BeforeEach(func() { + type subsysMsg struct{ Subsystem string } + + var err error + accepted, err = session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "unknown"})) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("/bin/echo -n Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(response).To(Equal([]byte("Hello"))) + }) + }) + }) + + Context("when the sftp subystem is requested", func() { + It("accepts the request", func() { + type subsysMsg struct{ Subsystem string } + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + defer session.Close() + + accepted, err := session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "sftp"})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + }) + + It("starts an sftp server in write mode", func() { + tempDir, err := os.MkdirTemp("", "sftp") + Expect(err).NotTo(HaveOccurred()) + defer os.RemoveAll(tempDir) + + sftp, err := sftp.NewClient(client) + Expect(err).NotTo(HaveOccurred()) + defer sftp.Close() + + By("creating the file") + target := filepath.Join(tempDir, "textfile.txt") + file, err := sftp.Create(target) + Expect(err).NotTo(HaveOccurred()) + + fileContents := []byte("---\nthis is a simple file\n\n") + _, err = file.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(os.ReadFile(target)).To(Equal(fileContents)) + + By("reading the file") + file, err = sftp.Open(target) + Expect(err).NotTo(HaveOccurred()) + + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(file) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(buffer.Bytes()).To(Equal(fileContents)) + + By("removing the file") + err = sftp.Remove(target) + Expect(err).NotTo(HaveOccurred()) + + _, err = os.Stat(target) + Expect(err).To(HaveOccurred()) + Expect(os.IsNotExist(err)).To(BeTrue()) + }) + }) + + Describe("invalid session channel requests", func() { + var channel ssh.Channel + var requests <-chan *ssh.Request + + BeforeEach(func() { + var err error + channel, requests, err = client.OpenChannel("session", nil) + Expect(err).NotTo(HaveOccurred()) + + go ssh.DiscardRequests(requests) + }) + + AfterEach(func() { + if channel != nil { + channel.Close() + } + }) + + Context("when an exec request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("exec", true, ssh.Marshal(struct{ Bogus uint32 }{Bogus: 1138})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when an env request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("env", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a signal request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("signal", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a pty request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("pty-req", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a window change request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("window-change", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a subsystem request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("subsystem", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2.go new file mode 100644 index 0000000000..fb561ef551 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2.go @@ -0,0 +1,31 @@ +//go:build windows2012R2 + +package handlers + +import ( + "time" + + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +type SessionChannelHandler struct { +} + +func NewSessionChannelHandler( + runner Runner, + shellLocator ShellLocator, + defaultEnv map[string]string, + keepalive time.Duration, +) *SessionChannelHandler { + return &SessionChannelHandler{} +} + +func (handler *SessionChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { + err := newChannel.Reject(ssh.Prohibited, "SSH is not supported on windows2012R2 cells") + if err != nil { + logger.Error("handle-new-session-channel-failed", err) + } + + return +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2_test.go new file mode 100644 index 0000000000..198a96ca83 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2012R2_test.go @@ -0,0 +1,77 @@ +//go:build windows2012R2 + +package handlers_test + +import ( + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("SessionChannelHandler", func() { + var ( + sshd *daemon.Daemon + client *ssh.Client + + logger *lagertest.TestLogger + serverSSHConfig *ssh.ServerConfig + + runner *fakes.FakeRunner + shellLocator *fakes.FakeShellLocator + sessionChannelHandler *handlers.SessionChannelHandler + + newChannelHandlers map[string]handlers.NewChannelHandler + defaultEnv map[string]string + connectionFinished chan struct{} + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + runner = &fakes.FakeRunner{} + shellLocator = &fakes.FakeShellLocator{} + defaultEnv = map[string]string{} + + sessionChannelHandler = handlers.NewSessionChannelHandler() + + newChannelHandlers = map[string]handlers.NewChannelHandler{ + "session": sessionChannelHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + connectionFinished = make(chan struct{}) + go func() { + sshd.HandleConnection(serverNetConn) + close(connectionFinished) + }() + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + if client != nil { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + } + Eventually(connectionFinished).Should(BeClosed()) + }) + + Context("when a session is opened", func() { + It("doesn't accept sessions", func() { + _, sessionErr := client.NewSession() + + Expect(sessionErr).To(MatchError(ContainSubstring("not supported on windows"))) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016.go new file mode 100644 index 0000000000..1c2c6dcdbd --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016.go @@ -0,0 +1,639 @@ +//go:build windows && !windows2012R2 + +package handlers + +import ( + "errors" + "fmt" + "os" + "os/exec" + "regexp" + "sync" + "syscall" + "time" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/signals" + "code.cloudfoundry.org/diego-ssh/winpty" + "code.cloudfoundry.org/lager/v3" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var scpRegex = regexp.MustCompile(`^\s*scp($|\s+)`) + +type SessionChannelHandler struct { + runner Runner + shellLocator ShellLocator + defaultEnv map[string]string + keepalive time.Duration + winPTYDLLDir string +} + +func NewSessionChannelHandler( + runner Runner, + shellLocator ShellLocator, + defaultEnv map[string]string, + keepalive time.Duration, +) *SessionChannelHandler { + winPTYDLLDir := os.Getenv("WINPTY_DLL_DIR") + return &SessionChannelHandler{ + runner: runner, + shellLocator: shellLocator, + defaultEnv: defaultEnv, + keepalive: keepalive, + winPTYDLLDir: winPTYDLLDir, + } +} + +func (handler *SessionChannelHandler) HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) { + channel, requests, err := newChannel.Accept() + if err != nil { + logger.Error("handle-new-session-channel-failed", err) + return + } + + handler.newSession(logger, channel, handler.keepalive).serviceRequests(requests) +} + +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +type session struct { + logger lager.Logger + complete bool + keepaliveDuration time.Duration + keepaliveStopCh chan struct{} + + shellPath string + runner Runner + channel ssh.Channel + + sync.Mutex + env map[string]string + command *exec.Cmd + + wg sync.WaitGroup + allocPty bool + ptyRequest ptyRequestMsg + + winpty *winpty.WinPTY + winPTYDLLDir string +} + +func (handler *SessionChannelHandler) newSession(logger lager.Logger, channel ssh.Channel, keepalive time.Duration) *session { + return &session{ + logger: logger.Session("session-channel"), + keepaliveDuration: keepalive, + runner: handler.runner, + shellPath: handler.shellLocator.ShellPath(), + channel: channel, + env: handler.defaultEnv, + winPTYDLLDir: handler.winPTYDLLDir, + } +} + +func (sess *session) serviceRequests(requests <-chan *ssh.Request) { + logger := sess.logger + logger.Info("starting") + defer logger.Info("finished") + + defer sess.destroy() + + for req := range requests { + sess.logger.Info("received-request", lager.Data{"type": req.Type}) + switch req.Type { + case "env": + sess.handleEnvironmentRequest(req) + case "signal": + sess.handleSignalRequest(req) + case "pty-req": + sess.handlePtyRequest(req) + case "window-change": + sess.handleWindowChangeRequest(req) + case "exec": + sess.handleExecRequest(req) + case "shell": + sess.handleShellRequest(req) + case "subsystem": + sess.handleSubsystemRequest(req) + default: + if req.WantReply { + req.Reply(false, nil) + } + } + } +} + +func (sess *session) handleEnvironmentRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-environment-request") + + type envMsg struct { + Name string + Value string + } + var envMessage envMsg + + err := ssh.Unmarshal(request.Payload, &envMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + request.Reply(false, nil) + return + } + + sess.Lock() + sess.env[envMessage.Name] = envMessage.Value + sess.Unlock() + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleSignalRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-signal-request") + + type signalMsg struct { + Signal string + } + var signalMessage signalMsg + + err := ssh.Unmarshal(request.Payload, &signalMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + cmd := sess.command + + if cmd != nil { + var err error + signal := signals.SyscallSignals[ssh.Signal(signalMessage.Signal)] + if sess.winpty != nil { + err = sess.winpty.Signal(signal) + } else { + err = sess.runner.Signal(cmd, signal) + } + if err != nil { + logger.Error("process-signal-failed", err) + } + } + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handlePtyRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-pty-request") + + var ptyRequestMessage ptyRequestMsg + + err := ssh.Unmarshal(request.Payload, &ptyRequestMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + sess.allocPty = true + sess.winpty, err = winpty.New(sess.winPTYDLLDir) + if err != nil { + logger.Error("couldn't intialize winpty.dll", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.ptyRequest = ptyRequestMessage + sess.env["TERM"] = ptyRequestMessage.Term + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleWindowChangeRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-window-change") + + type windowChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + var windowChangeMessage windowChangeMsg + + err := ssh.Unmarshal(request.Payload, &windowChangeMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + sess.Lock() + defer sess.Unlock() + + if sess.allocPty { + sess.ptyRequest.Columns = windowChangeMessage.Columns + sess.ptyRequest.Rows = windowChangeMessage.Rows + } + + if sess.winpty != nil { + err = setWindowSize(logger, sess.winpty, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + if err != nil { + logger.Error("failed-to-set-window-size", err) + } + } + + if request.WantReply { + request.Reply(true, nil) + } +} + +func (sess *session) handleExecRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-exec-request") + + type execMsg struct { + Command string + } + var execMessage execMsg + + err := ssh.Unmarshal(request.Payload, &execMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if scpRegex.MatchString(execMessage.Command) { + logger.Info("handling-scp-command", lager.Data{"Command": execMessage.Command}) + sess.executeSCP(execMessage.Command, request) + } else { + sess.executeShell(request, "/c", execMessage.Command) + } +} + +func (sess *session) handleShellRequest(request *ssh.Request) { + sess.executeShell(request) +} + +func (sess *session) handleSubsystemRequest(request *ssh.Request) { + logger := sess.logger.Session("handle-subsystem-request") + logger.Info("starting") + defer logger.Info("finished") + + type subsysMsg struct { + Subsystem string + } + var subsystemMessage subsysMsg + + err := ssh.Unmarshal(request.Payload, &subsystemMessage) + if err != nil { + logger.Error("unmarshal-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if subsystemMessage.Subsystem != "sftp" { + logger.Info("unsupported-subsystem", lager.Data{"subsystem": subsystemMessage.Subsystem}) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + lagerWriter := helpers.NewLagerWriter(logger.Session("sftp-server")) + sftpServer, err := sftp.NewServer(sess.channel, sftp.WithDebug(lagerWriter)) + if err != nil { + logger.Error("sftp-new-server-failed", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if request.WantReply { + request.Reply(true, nil) + } + + logger.Info("starting-server") + go func() { + defer sess.destroy() + err = sftpServer.Serve() + if err != nil { + logger.Error("sftp-serve-error", err) + } + }() +} + +func (sess *session) executeShell(request *ssh.Request, args ...string) { + logger := sess.logger.Session("execute-shell") + + sess.Lock() + cmd, err := sess.createCommand(args...) + if err != nil { + sess.Unlock() + logger.Error("failed-to-create-command", err) + if request.WantReply { + request.Reply(false, nil) + } + return + } + + if request.WantReply { + request.Reply(true, nil) + } + + if sess.allocPty { + err = sess.runWithPty(cmd) + } else { + err = sess.run(cmd) + } + + sess.Unlock() + + if err != nil { + sess.sendExitMessage(err) + sess.destroy() + return + } + + go func() { + err := sess.wait(cmd) + sess.sendExitMessage(err) + sess.destroy() + }() +} + +func (sess *session) createCommand(args ...string) (*exec.Cmd, error) { + if sess.command != nil { + return nil, errors.New("command already started") + } + + cmd := exec.Command(sess.shellPath, args...) + cmd.Env = sess.environment() + sess.command = cmd + + return cmd, nil +} + +func (sess *session) environment() []string { + env := []string{} + + env = append(env, `PATH=C:\Windows\system32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`) + env = append(env, "LANG=en_US.UTF8") + + for k, v := range sess.env { + if k != "HOME" && k != "USER" { + env = append(env, fmt.Sprintf("%s=%s", k, v)) + } + } + + env = append(env, fmt.Sprintf("HOME=%s", os.Getenv("HOME"))) + env = append(env, fmt.Sprintf("USER=%s", os.Getenv("USER"))) + + return env +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Error string + Lang string +} + +func (sess *session) sendExitMessage(err error) { + logger := sess.logger.Session("send-exit-message") + logger.Info("started") + defer logger.Info("finished") + + if err != nil { + logger.Error("building-exit-message-from-error", err) + } + + if err == nil { + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitStatusMsg{})) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + var exitCode uint32 + winptyError, ok := err.(*winpty.ExitError) + if ok { + exitCode = winptyError.WaitStatus.ExitCode + } else { + exitError, ok := err.(*exec.ExitError) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + waitStatus, ok := exitError.Sys().(syscall.WaitStatus) + if !ok { + exitMessage := exitStatusMsg{Status: 255} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + if waitStatus.Signaled() { + exitMessage := exitSignalMsg{ + Signal: string(signals.SSHSignals[waitStatus.Signal()]), + CoreDumped: waitStatus.CoreDump(), + } + _, sendErr := sess.channel.SendRequest("exit-signal", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } + return + } + + exitCode = uint32(waitStatus.ExitStatus()) + } + + exitMessage := exitStatusMsg{Status: exitCode} + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} + +func setWindowSize(logger lager.Logger, pty *winpty.WinPTY, columns, rows uint32) error { + logger.Info("new-size", lager.Data{"columns": columns, "rows": rows}) + return pty.SetWinsize(columns, rows) +} + +func (sess *session) run(command *exec.Cmd) error { + logger := sess.logger.Session("run") + + command.Stdout = sess.channel + command.Stderr = sess.channel.Stderr() + + stdin, err := command.StdinPipe() + if err != nil { + return err + } + + go helpers.CopyAndClose(logger.Session("to-stdin"), nil, stdin, sess.channel, func() { stdin.Close() }) + + return sess.runner.Start(command) +} + +func (sess *session) runWithPty(command *exec.Cmd) error { + var err error + logger := sess.logger.Session("run") + + if err := sess.winpty.Open(); err != nil { + logger.Error("failed-to-open-pty", err) + return err + } + + setWindowSize(logger, sess.winpty, sess.ptyRequest.Columns, sess.ptyRequest.Rows) + + sess.wg.Add(1) + go helpers.Copy(logger.Session("to-pty"), nil, sess.winpty.StdIn, sess.channel) + go func() { + helpers.Copy(logger.Session("from-pty-out"), &sess.wg, sess.channel, sess.winpty.StdOut) + sess.channel.CloseWrite() + }() + + err = sess.winpty.Run(command) + if err == nil { + sess.keepaliveStopCh = make(chan struct{}) + go sess.keepalive(sess.keepaliveStopCh) + } + return err +} + +func (sess *session) keepalive(stopCh chan struct{}) { + logger := sess.logger.Session("keepalive") + + ticker := time.NewTicker(sess.keepaliveDuration) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, err := sess.channel.SendRequest("keepalive@cloudfoundry.org", true, nil) + logger.Info("keepalive", lager.Data{"success": err == nil}) + + if err != nil { + err = sess.winpty.Signal(syscall.SIGINT) + logger.Info("process-signaled", lager.Data{"error": err}) + return + } + case <-stopCh: + return + } + } +} + +func (sess *session) wait(command *exec.Cmd) error { + logger := sess.logger.Session("wait") + logger.Info("started") + defer logger.Info("done") + if sess.allocPty { + return sess.winpty.Wait() + } else { + return sess.runner.Wait(command) + } +} + +func (sess *session) destroy() { + logger := sess.logger.Session("destroy") + logger.Info("started") + defer logger.Info("done") + + sess.Lock() + defer sess.Unlock() + + if sess.complete { + return + } + + sess.complete = true + sess.wg.Wait() + + if sess.channel != nil { + sess.channel.Close() + } + + if sess.winpty != nil { + sess.winpty.Close() + sess.winpty = nil + } + + if sess.keepaliveStopCh != nil { + close(sess.keepaliveStopCh) + } +} + +func (sess *session) executeSCP(command string, request *ssh.Request) { + logger := sess.logger.Session("execute-scp") + + if request.WantReply { + request.Reply(true, nil) + } + + copier, err := scp.NewFromCommand(command, sess.channel, sess.channel, sess.channel.Stderr(), logger) + if err == nil { + err = copier.Copy() + } + + sess.sendSCPExitMessage(err) + sess.destroy() +} + +func (sess *session) sendSCPExitMessage(err error) { + logger := sess.logger.Session("send-scp-exit-message") + logger.Info("started") + defer logger.Info("finished") + + var exitMessage exitStatusMsg + if err != nil { + logger.Error("building-scp-exit-message-from-error", err) + exitMessage = exitStatusMsg{Status: 1} + } + + _, sendErr := sess.channel.SendRequest("exit-status", false, ssh.Marshal(exitMessage)) + if sendErr != nil { + logger.Error("send-exit-status-failed", sendErr) + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016_test.go b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016_test.go new file mode 100644 index 0000000000..26e5c4c440 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/session_channel_handler_windows2016_test.go @@ -0,0 +1,997 @@ +//go:build windows && !windows2012R2 + +package handlers_test + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/pkg/sftp" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("SessionChannelHandler", func() { + type command struct { + path string + args []string + } + + var ( + sshd *daemon.Daemon + client *ssh.Client + + commandsRan chan command + + logger *lagertest.TestLogger + serverSSHConfig *ssh.ServerConfig + + runner *fakes.FakeRunner + shellLocator *fakes.FakeShellLocator + sessionChannelHandler *handlers.SessionChannelHandler + + newChannelHandlers map[string]handlers.NewChannelHandler + defaultEnv map[string]string + connectionFinished chan struct{} + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + serverSSHConfig = &ssh.ServerConfig{ + NoClientAuth: true, + } + serverSSHConfig.AddHostKey(TestHostKey) + + commandsRan = make(chan command, 10) + + runner = &fakes.FakeRunner{} + realRunner := handlers.NewCommandRunner() + runner.StartStub = func(cmd *exec.Cmd) error { + commandsRan <- command{ + path: strings.ToLower(cmd.Path), + args: cmd.Args, + } + return realRunner.Start(cmd) + } + runner.WaitStub = realRunner.Wait + runner.SignalStub = realRunner.Signal + + shellLocator = &fakes.FakeShellLocator{} + shellLocator.ShellPathReturns("cmd.exe") + + defaultEnv = map[string]string{} + for _, env := range os.Environ() { + k := strings.Split(env, "=")[0] + v := strings.Split(env, "=")[1] + defaultEnv[k] = v + } + defaultEnv["TEST"] = "FOO" + + delete(defaultEnv, "Path") + delete(defaultEnv, "PATH") + + sessionChannelHandler = handlers.NewSessionChannelHandler(runner, shellLocator, defaultEnv, time.Second) + + newChannelHandlers = map[string]handlers.NewChannelHandler{ + "session": sessionChannelHandler, + } + + serverNetConn, clientNetConn := test_helpers.Pipe() + + sshd = daemon.New(logger, serverSSHConfig, nil, newChannelHandlers) + connectionFinished = make(chan struct{}) + go func() { + sshd.HandleConnection(serverNetConn) + close(connectionFinished) + }() + + client = test_helpers.NewClient(clientNetConn, nil) + }) + + AfterEach(func() { + if client != nil { + err := client.Close() + Expect(err).NotTo(HaveOccurred()) + } + Eventually(connectionFinished).Should(BeClosed()) + }) + + Context("when a session is opened", func() { + var session *ssh.Session + + BeforeEach(func() { + var sessionErr error + session, sessionErr = client.NewSession() + + Expect(sessionErr).NotTo(HaveOccurred()) + }) + + It("can use the session to execute a command with stdout and stderr", func() { + stdout, err := session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + stderr, err := session.StderrPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("echo Hello && echo Goodbye 1>&2") + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := io.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + Expect(string(stdoutBytes)).To(ContainSubstring("Hello")) + Expect(string(stdoutBytes)).NotTo(ContainSubstring("Goodbye")) + + stderrBytes, err := io.ReadAll(stderr) + Expect(err).NotTo(HaveOccurred()) + Expect(string(stderrBytes)).To(ContainSubstring("Goodbye")) + Expect(string(stderrBytes)).NotTo(ContainSubstring("Hello")) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("dir") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + Describe("scp", func() { + var ( + sourceDir, generatedTextFile, targetDir string + err error + stdin io.WriteCloser + stdout io.Reader + fileContents []byte + ) + + BeforeEach(func() { + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + sourceDir, err = os.MkdirTemp("", "scp-source") + Expect(err).NotTo(HaveOccurred()) + + fileContents = []byte("---\nthis is a simple file\n\n") + generatedTextFile = filepath.Join(sourceDir, "textfile.txt") + + err = os.WriteFile(generatedTextFile, fileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + targetDir, err = os.MkdirTemp("", "scp-target") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(sourceDir)).To(Succeed()) + Expect(os.RemoveAll(targetDir)).To(Succeed()) + }) + + It("properly copies using the secure copier", func() { + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := session.Run(fmt.Sprintf("scp -v -t %s", strings.Replace(targetDir, `\`, `\\`, -1))) + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + expectedFileInfo, err := os.Stat(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte(fmt.Sprintf("C0664 %d textfile.txt\n", expectedFileInfo.Size()))) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + actualFilePath := filepath.Join(targetDir, filepath.Base(generatedTextFile)) + actualFileInfo, err := os.Stat(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualFileInfo.Mode()).To(Equal(expectedFileInfo.Mode())) + Expect(actualFileInfo.Size()).To(Equal(expectedFileInfo.Size())) + + actualContents, err := os.ReadFile(actualFilePath) + Expect(err).NotTo(HaveOccurred()) + + expectedContents, err := os.ReadFile(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualContents).To(Equal(expectedContents)) + + Eventually(done).Should(BeClosed()) + }) + + It("properly fails when secure copying fails", func() { + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- session.Run(fmt.Sprintf("scp -v -t %s", strings.Replace(targetDir, `\`, `\\`, -1))) + }() + + confirmation := make([]byte, 1) + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{0})) + + _, err = stdin.Write([]byte("BOGUS PROTOCOL MESSAGE\n")) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(confirmation).To(Equal([]byte{1})) + + err = <-errCh + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + + It("properly fails when incorrect arguments are supplied", func() { + err := session.Run("scp -v -t /tmp/foo /tmp/bar") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(1)) + }) + }) + + Describe("the shell locator", func() { + BeforeEach(func() { + err := session.Run("exit 0") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the shell locator to find the default shell path", func() { + Expect(shellLocator.ShellPathCallCount()).To(Equal(1)) + + Eventually(commandsRan).Should(Receive(Equal(command{ + path: "c:\\windows\\system32\\cmd.exe", + args: []string{"cmd.exe", "/c", "exit 0"}, + }))) + }) + }) + + Context("when stdin is provided by the client", func() { + BeforeEach(func() { + session.Stdin = strings.NewReader("Hello") + }) + + It("can use the session to execute a command that reads it", func() { + result, err := session.Output(`findstr x*`) + + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(result))).To(Equal("Hello")) + }) + }) + + Context("when the command exits with a non-zero value", func() { + It("it preserve the exit code", func() { + err := session.Run("exit 3") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).To(Equal(3)) + }) + }) + + Context("when SIGKILL is sent across the session", func() { + Context("before a command has been run", func() { + BeforeEach(func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not prevent the command from running", func() { + result, err := session.Output("echo still kicking") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(result))).To(Equal(strings.TrimSpace("still kicking"))) + }) + }) + + Context("while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("Microsoft Windows")) + + Eventually(runner.StartCallCount).Should(Equal(1)) + }) + + It("is sent to the process", func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + + Eventually(runner.SignalCallCount).Should(Equal(1)) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).To(HaveOccurred()) + Expect(err.(*ssh.ExitError).ExitStatus()).To(Equal(1)) + }) + }) + }) + + Context("when running a command without an explicit environemnt", func() { + It("does not inherit daemon's environment", func() { + os.Setenv("DAEMON_ENV", "daemon_env_value") + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).NotTo(ContainSubstring("DAEMON_ENV=daemon_env_value")) + os.Unsetenv("DAEMON_ENV") + }) + + It("includes a default environment excluding PATH", func() { + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring(`PATH=C:\Windows\system32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`)) + Expect(result).To(ContainSubstring("LANG=en_US.UTF8")) + Expect(result).To(ContainSubstring("TEST=FOO")) + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + }) + + Context("when environment variables are requested", func() { + Context("before starting the command", func() { + It("runs the command with the specified environment", func() { + err := session.Setenv("ENV1", "value1") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV2", "value2") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=value1")) + Expect(result).To(ContainSubstring("ENV2=value2")) + }) + + It("uses the value last specified", func() { + err := session.Setenv("ENV1", "original") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("ENV1", "updated") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("ENV1=updated")) + }) + + It("can override PATH and LANG", func() { + err := session.Setenv("PATH", "/bin:/usr/local/bin:/sbin") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("LANG", "en_UK.UTF8") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("PATH=/bin:/usr/local/bin:/sbin")) + Expect(result).To(ContainSubstring("LANG=en_UK.UTF8")) + }) + + It("cannot override HOME and USER", func() { + err := session.Setenv("HOME", "/some/other/home") + Expect(err).NotTo(HaveOccurred()) + + err = session.Setenv("USER", "not-a-user") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring(fmt.Sprintf("HOME=%s", os.Getenv("HOME")))) + Expect(result).To(ContainSubstring(fmt.Sprintf("USER=%s", os.Getenv("USER")))) + }) + + It("can override default env variables", func() { + err := session.Setenv("TEST", "BAR") + Expect(err).NotTo(HaveOccurred()) + + result, err := session.Output("set") + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(ContainSubstring("TEST=BAR")) + }) + }) + + Context("after starting the command", func() { + var stdin io.WriteCloser + var stdout io.Reader + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start(`findstr x* & set`) + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + err := session.Setenv("ENV3", "value3") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + + stdoutBytes, err := io.ReadAll(stdout) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(stdoutBytes)).NotTo(ContainSubstring("ENV3")) + }) + }) + }) + + Context("when a pty request is received", func() { + var terminalModes ssh.TerminalModes + + BeforeEach(func() { + terminalModes = ssh.TerminalModes{} + }) + + JustBeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, terminalModes) + Expect(err).NotTo(HaveOccurred()) + }) + + It("should allocate a console for the session", func() { + result, err := session.Output("timeout 1 2>nul >nul & if errorlevel 1 (echo redirect) else (echo console)") + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring("console")) + }) + + It("returns when the process exits", func() { + stdin, err := session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Run("dir") + Expect(err).NotTo(HaveOccurred()) + + stdin.Close() + }) + + It("terminates the shell when the stdin closes", func() { + err := session.Shell() + Expect(err).NotTo(HaveOccurred()) + time.Sleep(1 * time.Second) + + err = client.Conn.Close() + client = nil + Expect(err).NotTo(HaveOccurred()) + err = session.Wait() + Expect(err.Error()).To(Equal("wait: remote command exited without exit status or exit signal")) + }) + + It("should set the terminal type", func() { + result, err := session.Output(`echo %TERM%`) + Expect(err).NotTo(HaveOccurred()) + + Expect(string(result)).To(ContainSubstring("vt100")) + }) + + It("sets the correct window size for the terminal", func() { + result, err := session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(ContainSubstring("43 80")) + }) + + Context("when an interactive command is executed", func() { + var stdin io.WriteCloser + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + }) + + It("terminates the session when the shell exits", func() { + err := session.Start("cmd.exe") + Expect(err).NotTo(HaveOccurred()) + + _, err = stdin.Write([]byte("exit\r\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + Expect(err).NotTo(HaveOccurred()) + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("when a signal is sent across the session", func() { + Context("before a command has been run", func() { + BeforeEach(func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not prevent the command from running", func() { + result, err := session.Output("echo still kicking") + Expect(err).NotTo(HaveOccurred()) + Expect(string(result)).To(ContainSubstring("still kicking")) + }) + }) + + Context("SIGKILL is sent while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("Microsoft Windows")) + }) + + It("kills the process", func() { + err := session.Signal(ssh.SIGKILL) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).To(HaveOccurred()) + Expect(err.(*ssh.ExitError).ExitStatus()).To(Equal(1)) + }) + }) + + Context("SIGINT is sent while a command is running", func() { + var stdin io.WriteCloser + var stdout io.Reader + + JustBeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + stdout, err = session.StdoutPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Start("ping -t 127.0.0.1 & echo goodbye") + Expect(err).NotTo(HaveOccurred()) + + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("127.0.0.1")) + }) + + It("the process is interrupted", func() { + err := session.Signal(ssh.SIGINT) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + resultCh := make(chan error) + go func() { + resultCh <- session.Wait() + }() + Eventually(resultCh).Should(Receive(BeNil())) + reader := bufio.NewReader(stdout) + Eventually(reader.ReadLine).Should(ContainSubstring("goodbye")) + }) + }) + }) + }) + + Context("when a window change request is received", func() { + type winChangeMsg struct { + Columns uint32 + Rows uint32 + WidthPx uint32 + HeightPx uint32 + } + + var result []byte + + Context("before a pty is allocated", func() { + BeforeEach(func() { + _, err := session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + err = session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + }) + + It("ignores the request", func() { + Expect(result).To(ContainSubstring("43 80")) + }) + }) + + Context("after a pty is allocated", func() { + BeforeEach(func() { + err := session.RequestPty("vt100", 43, 80, ssh.TerminalModes{}) + Expect(err).NotTo(HaveOccurred()) + + _, err = session.SendRequest("window-change", false, ssh.Marshal(winChangeMsg{ + Rows: 50, + Columns: 132, + })) + Expect(err).NotTo(HaveOccurred()) + + result, err = session.Output("powershell.exe -command $w = $host.ui.rawui.WindowSize.Width; $h = $host.ui.rawui.WindowSize.Height; echo \"$h $w\"") + Expect(err).NotTo(HaveOccurred()) + }) + + It("changes the the size of the terminal", func() { + Expect(result).To(ContainSubstring("50 132")) + }) + }) + }) + + Context("after executing a command", func() { + BeforeEach(func() { + err := session.Run("exit") + Expect(err).NotTo(HaveOccurred()) + }) + + It("the session is no longer usable", func() { + _, err := session.SendRequest("exec", true, ssh.Marshal(struct{ Command string }{Command: "exit"})) + Expect(err).To(HaveOccurred()) + + _, err = session.SendRequest("bogus", true, nil) + Expect(err).To(HaveOccurred()) + + err = session.Setenv("foo", "bar") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when an interactive shell is requested", func() { + var stdin io.WriteCloser + + BeforeEach(func() { + var err error + stdin, err = session.StdinPipe() + Expect(err).NotTo(HaveOccurred()) + + err = session.Shell() + Expect(err).NotTo(HaveOccurred()) + }) + + It("starts the shell with the runner", func() { + Eventually(runner.StartCallCount).Should(Equal(1)) + + Eventually(commandsRan).Should(Receive(Equal(command{ + path: "c:\\windows\\system32\\cmd.exe", + args: []string{"cmd.exe"}, + }))) + }) + + It("terminates the session when the shell exits", func() { + _, err := stdin.Write([]byte("exit\n")) + Expect(err).NotTo(HaveOccurred()) + + err = stdin.Close() + if err != nil { + Expect(err).To(Equal(io.EOF), "expected no error or ignorable EOF error") + } + + err = session.Wait() + Expect(err).NotTo(HaveOccurred()) + }) + }) + + Context("and a command is provided", func() { + BeforeEach(func() { + err := session.Run("exit") + Expect(err).NotTo(HaveOccurred()) + }) + + It("uses the provided runner to start the command", func() { + Expect(runner.StartCallCount()).To(Equal(1)) + Expect(runner.WaitCallCount()).To(Equal(1)) + }) + + It("passes the correct command to the runner", func() { + Eventually(commandsRan).Should(Receive(Equal(command{ + path: "c:\\windows\\system32\\cmd.exe", + args: []string{"cmd.exe", "/c", "exit"}, + }))) + }) + + It("passes the same command to Start and Wait", func() { + command := runner.StartArgsForCall(0) + Expect(runner.WaitArgsForCall(0)).To(Equal(command)) + }) + }) + + Context("when executing an invalid command", func() { + It("returns an exit error with a non-zero exit status", func() { + err := session.Run("not-a-command") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + + Context("when starting the command fails", func() { + BeforeEach(func() { + runner.StartReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + + Context("when waiting on the command fails", func() { + BeforeEach(func() { + runner.WaitReturns(errors.New("oops")) + }) + + It("returns an exit status message with a non-zero status", func() { + err := session.Run("true") + Expect(err).To(HaveOccurred()) + + exitErr, ok := err.(*ssh.ExitError) + Expect(ok).To(BeTrue()) + Expect(exitErr.ExitStatus()).NotTo(Equal(0)) + }) + }) + }) + + Context("when an unknown request type is sent", func() { + var accepted bool + + BeforeEach(func() { + var err error + accepted, err = session.SendRequest("unknown-request-type", true, []byte("payload")) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("echo Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(response))).To(Equal("Hello")) + }) + }) + + Context("when an unknown subsystem is requested", func() { + var accepted bool + + BeforeEach(func() { + type subsysMsg struct{ Subsystem string } + + var err error + accepted, err = session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "unknown"})) + Expect(err).NotTo(HaveOccurred()) + }) + + It("rejects the request", func() { + Expect(accepted).To(BeFalse()) + }) + + It("does not terminate the session", func() { + response, err := session.Output("echo Hello") + Expect(err).NotTo(HaveOccurred()) + Expect(strings.TrimSpace(string(response))).To(Equal("Hello")) + }) + }) + }) + + Context("when the sftp subystem is requested", func() { + It("accepts the request", func() { + type subsysMsg struct{ Subsystem string } + session, err := client.NewSession() + Expect(err).NotTo(HaveOccurred()) + defer session.Close() + + accepted, err := session.SendRequest("subsystem", true, ssh.Marshal(subsysMsg{Subsystem: "sftp"})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + }) + + It("starts an sftp server in write mode", func() { + tempDir, err := os.MkdirTemp("", "sftp") + Expect(err).NotTo(HaveOccurred()) + defer os.RemoveAll(tempDir) + + sftp, err := sftp.NewClient(client) + Expect(err).NotTo(HaveOccurred()) + defer sftp.Close() + + By("creating the file") + target := filepath.Join(tempDir, "textfile.txt") + file, err := sftp.Create(target) + Expect(err).NotTo(HaveOccurred()) + + fileContents := []byte("---\nthis is a simple file\n\n") + _, err = file.Write(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(os.ReadFile(target)).To(Equal(fileContents)) + + By("reading the file") + file, err = sftp.Open(target) + Expect(err).NotTo(HaveOccurred()) + + buffer := &bytes.Buffer{} + _, err = buffer.ReadFrom(file) + Expect(err).NotTo(HaveOccurred()) + + err = file.Close() + Expect(err).NotTo(HaveOccurred()) + + Expect(buffer.Bytes()).To(Equal(fileContents)) + + By("removing the file") + err = sftp.Remove(target) + Expect(err).NotTo(HaveOccurred()) + + _, err = os.Stat(target) + Expect(err).To(HaveOccurred()) + Expect(os.IsNotExist(err)).To(BeTrue()) + }) + }) + + Describe("invalid session channel requests", func() { + var channel ssh.Channel + var requests <-chan *ssh.Request + + BeforeEach(func() { + var err error + channel, requests, err = client.OpenChannel("session", nil) + Expect(err).NotTo(HaveOccurred()) + + go ssh.DiscardRequests(requests) + }) + + AfterEach(func() { + if channel != nil { + channel.Close() + } + }) + + Context("when an exec request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("exec", true, ssh.Marshal(struct{ Bogus uint32 }{Bogus: 1138})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when an env request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("env", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a signal request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("signal", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a pty request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("pty-req", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a window change request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("window-change", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + + Context("when a subsystem request fails to unmarshal", func() { + It("rejects the request", func() { + accepted, err := channel.SendRequest("subsystem", true, ssh.Marshal(struct{ Bogus int }{Bogus: 1234})) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeFalse()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator.go b/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator.go new file mode 100644 index 0000000000..e4526dccc8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator.go @@ -0,0 +1,21 @@ +//go:build !windows + +package handlers + +import "os/exec" + +type shellLocator struct{} + +func NewShellLocator() ShellLocator { + return &shellLocator{} +} + +func (shellLocator) ShellPath() string { + for _, shell := range []string{"/bin/bash", "/usr/local/bin/bash", "/bin/sh", "bash", "sh"} { + if path, err := exec.LookPath(shell); err == nil { + return path + } + } + + return "/bin/sh" +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator_windows.go b/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator_windows.go new file mode 100644 index 0000000000..6de7b1b4f0 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/shell_locator_windows.go @@ -0,0 +1,21 @@ +//go:build windows + +package handlers + +import "os/exec" + +type shellLocator struct{} + +func NewShellLocator() ShellLocator { + return &shellLocator{} +} + +func (shellLocator) ShellPath() string { + for _, shell := range []string{"cmd.exe"} { + if path, err := exec.LookPath(shell); err == nil { + return path + } + } + + return "cmd.exe" +} diff --git a/src/code.cloudfoundry.org/diego-ssh/handlers/types.go b/src/code.cloudfoundry.org/diego-ssh/handlers/types.go new file mode 100644 index 0000000000..433eaa29bd --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/handlers/types.go @@ -0,0 +1,38 @@ +package handlers + +import ( + "net" + "os/exec" + "syscall" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +//go:generate counterfeiter -o fakes/fake_dialer.go . Dialer +type Dialer interface { + Dial(net, addr string) (net.Conn, error) +} + +//go:generate counterfeiter -o fake_handlers/fake_global_request_handler.go . GlobalRequestHandler +type GlobalRequestHandler interface { + HandleRequest(logger lager.Logger, request *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) +} + +//go:generate counterfeiter -o fake_handlers/fake_new_channel_handler.go . NewChannelHandler +type NewChannelHandler interface { + HandleNewChannel(logger lager.Logger, newChannel ssh.NewChannel) +} + +//go:generate counterfeiter -o fakes/fake_runner.go . Runner +type Runner interface { + Start(cmd *exec.Cmd) error + Wait(cmd *exec.Cmd) error + Signal(cmd *exec.Cmd, signal syscall.Signal) error +} + +//go:generate counterfeiter -o fakes/fake_shell_locator.go . ShellLocator +type ShellLocator interface { + ShellPath() string +} diff --git a/src/code.cloudfoundry.org/diego-ssh/healthcheck/healthcheck_handler.go b/src/code.cloudfoundry.org/diego-ssh/healthcheck/healthcheck_handler.go new file mode 100644 index 0000000000..879b9ec7d1 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/healthcheck/healthcheck_handler.go @@ -0,0 +1,38 @@ +package healthcheck + +import ( + "net/http" + + "github.com/tedsuo/rata" + + "code.cloudfoundry.org/lager/v3" +) + +type HealthCheckHandler struct { + logger lager.Logger +} + +func NewHandler(logger lager.Logger) http.Handler { + routes := rata.Routes{ + {Name: "HealthCheck", Method: "GET", Path: "/"}, + } + + logger = logger.Session("healthcheck") + + actions := map[string]http.Handler{ + "HealthCheck": &HealthCheckHandler{logger: logger}, + } + + handler, err := rata.NewRouter(routes, actions) + if err != nil { + panic(err) + } + + return handler +} + +func (h *HealthCheckHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { + h.logger.Debug("started") + defer h.logger.Debug("finished") + writer.WriteHeader(http.StatusOK) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/healthcheck/package.go b/src/code.cloudfoundry.org/diego-ssh/healthcheck/package.go new file mode 100644 index 0000000000..318319a058 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/healthcheck/package.go @@ -0,0 +1 @@ +package healthcheck // import "code.cloudfoundry.org/diego-ssh/healthcheck" diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/copy.go b/src/code.cloudfoundry.org/diego-ssh/helpers/copy.go new file mode 100644 index 0000000000..a56ccf537e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/copy.go @@ -0,0 +1,45 @@ +package helpers + +import ( + "io" + "sync" + + "code.cloudfoundry.org/lager/v3" +) + +func Copy(logger lager.Logger, wg *sync.WaitGroup, dest io.Writer, src io.Reader) { + logger = logger.Session("copy") + logger.Info("started") + defer func() { + if wg != nil { + wg.Done() + } + }() + + n, err := io.Copy(dest, src) + if err != nil { + logger.Error("copy-error", err) + } + + logger.Info("completed", lager.Data{"bytes-copied": n}) +} + +func CopyAndClose(logger lager.Logger, wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader, closeFunc func()) { + logger = logger.Session("copy-and-close") + logger.Info("started") + + defer func() { + closeFunc() + + if wg != nil { + wg.Done() + } + }() + + n, err := io.Copy(dest, src) + if err != nil { + logger.Error("copy-error", err) + } + + logger.Info("completed", lager.Data{"bytes-copied": n}) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/copy_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/copy_test.go new file mode 100644 index 0000000000..cfcf36491f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/copy_test.go @@ -0,0 +1,94 @@ +package helpers_test + +import ( + "io" + "strings" + "sync" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Copy", func() { + var logger lager.Logger + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + }) + + Describe("Copy", func() { + var reader io.Reader + var fakeWriter *fake_io.FakeWriter + var wg *sync.WaitGroup + + BeforeEach(func() { + reader = strings.NewReader("message") + fakeWriter = &fake_io.FakeWriter{} + wg = nil + }) + + JustBeforeEach(func() { + helpers.Copy(logger, wg, fakeWriter, reader) + }) + + It("copies from source to target", func() { + Expect(fakeWriter.WriteCallCount()).To(Equal(1)) + Expect(string(fakeWriter.WriteArgsForCall(0))).To(Equal("message")) + }) + + Context("when a wait group is provided", func() { + BeforeEach(func() { + wg = &sync.WaitGroup{} + wg.Add(1) + }) + + It("calls done before returning", func() { + wg.Wait() + }) + }) + }) + + Describe("CopyAndClose", func() { + var reader io.Reader + var fakeWriteCloser *fake_io.FakeWriteCloser + var wg *sync.WaitGroup + + BeforeEach(func() { + reader = strings.NewReader("message") + fakeWriteCloser = &fake_io.FakeWriteCloser{} + wg = nil + }) + + JustBeforeEach(func() { + closeFunc := func() { + fakeWriteCloser.Close() + } + helpers.CopyAndClose(logger, wg, fakeWriteCloser, reader, closeFunc) + }) + + It("copies from source to target", func() { + Expect(fakeWriteCloser.WriteCallCount()).To(Equal(1)) + Expect(string(fakeWriteCloser.WriteArgsForCall(0))).To(Equal("message")) + }) + + It("it calls the close function when complete", func() { + Expect(fakeWriteCloser.CloseCallCount()).To(Equal(1)) + }) + + Context("when a wait group is provided", func() { + BeforeEach(func() { + wg = &sync.WaitGroup{} + wg.Add(1) + }) + + It("calls done before returning", func() { + wg.Wait() + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint.go b/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint.go new file mode 100644 index 0000000000..17b9fbf5c9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint.go @@ -0,0 +1,33 @@ +package helpers + +import ( + "crypto/md5" + "crypto/sha1" + "fmt" + "strings" + + "golang.org/x/crypto/ssh" +) + +const MD5_FINGERPRINT_LENGTH = 47 +const SHA1_FINGERPRINT_LENGTH = 59 +const SHA256_FINGERPRINT_LENGTH = 44 //unpadded base64 + +func MD5Fingerprint(key ssh.PublicKey) string { + md5sum := md5.Sum(key.Marshal()) + return colonize(fmt.Sprintf("% x", md5sum)) +} + +func SHA256Fingerprint(key ssh.PublicKey) string { + value := ssh.FingerprintSHA256(key) + return strings.TrimPrefix(value, "SHA256:") +} + +func SHA1Fingerprint(key ssh.PublicKey) string { + sha1sum := sha1.Sum(key.Marshal()) + return colonize(fmt.Sprintf("% x", sha1sum)) +} + +func colonize(s string) string { + return strings.Replace(s, " ", ":", -1) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint_test.go new file mode 100644 index 0000000000..7bd8a18f7f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/fingerprint_test.go @@ -0,0 +1,100 @@ +package helpers_test + +import ( + "fmt" + "unicode/utf8" + + "code.cloudfoundry.org/diego-ssh/helpers" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "golang.org/x/crypto/ssh" +) + +const ( + TestPrivateKeyPem = `-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEAx0y65jB977anY39jzB7AkojdAyqiADG4BTcXmKIy7w/GY/bi +Aq/AcO/SVAsq1iJ+SAmiQXt1K6kL8wUGtxlB1D+Ze0d0jw05Ep+O5rRF1dMDFUsA +0yrgfsUfryl7XOl9LmE1PKLinKExooZrfJTSqW1oRjHpIWZMtJj25glDtrfz+Wyd +6wPYH1CgOdQjSAORWYrhb9xYPzIxG5XMeMyO5xtL3sLEsdM1iROIl9rOu1qem1n+ +Xs/z+03tpIoQc4gjVvbykYdZQ5mvCgbxtPRVy7EmYSbQflmMu3TWZmT20ZWDCHzl +3eDdFMOpqRaRFg/TjlNjH20TQbruKumk7ldogQIDAQABAoIBABujCEfjcZNMQOoL +QEuN+CZZ1EwcHVrpihsvCJah524/QcOa+LxmoskGeKQu6EHJhrl2nIl4FUd4qa+J +guThG7/TEfWGcyNjMgbjGW3kkcqU+Fh7jiG6UGdD7qDbn7/CoRlNYZSHAeW2dKuU ++FLOUGguQ8d4JFv9U6W3kIVVw44StVMkQwh0TB9kh7yzeHrpVddaMPzVZUmCWm2Y +NPN5EZq96DmmcEQC7Gktj7kPgC5UWcc8wF2Xy74sZb3RKOeyc5e7ddMDLbNI5STr +iRT3Fg+bhWQhhMUQfvD9KSh/9IK0OGu/3SSb9WeEzMUdh5mho1IsERugaXsVlne8 +6JWW7gECgYEAzTTSJDRm8CiBSa4sn5KzLOHvn3YfSC91aERjrbZuVDdmVMJvhpLw +JW6/5zmz7X7Hr3mwHBSj+rS4/rIoEVvTjWrJm6GUSXXPwRwoJedbK67FU0MxBMzt +iqi+qBHdsKRhdrlM3W9RryGkcS1AkK+6B2Feu3GVGUQDz6G/yaTJ8UcCgYEA+KGh +D+PtdAd3s1sdAJlRuS4kCXCLbO/5EWfMMHVaewebpGs8bZnW4cpFaGR7zXd9Emkk +QuZWE7L44SNQrirECtGcu3zEKx1grYo+2jYoLYexiwOf6UEMWJEExLS475EDgmUJ +7Fy5tt2mwwV2GBXZfTHuQLOo9Zxjsf3NAKAZ+/cCgYBV+nKtrrMOnroE6BBUT7/4 +5zViJ7jVouTbagQlrZEuggPDMbBOv1QVKwEG3Ztwv7Tk5eSO72sBSSVVucml9EaA +MyUDq0CZQt5oN+bucrA1bkXJLBbmvwIsHaW8f7fWIhmgB+WXxeOAsGTY8q/hr28P +VpG9kcp5ypCaN1hHIV9nUwKBgQDKcUBlYd8MJLBwV3XL8Qq7zzgEf6Dm+JZCd9Oo +eUVM+6rdO3ueei6e9kWBdJ/hcrNh9D5UQpw/ufAv0MN2rNenP3lwp2xK9sarRu9a +WdJpEB2d5TulfxOAYcQSLlyOo/LJj19/FxkYLm4ESUQY5GGMMMWf5Sljow0B9nef +VL0TjQKBgG9/w5XpX7K8nnUVGgYuEhbBj7lel2Ad7wjqwxuqDxi3jqVvuIR7VYeg +feuxbZkmphtEOKtaVDSWxGbNXbuN8H9eQqsGhK1Xcn/FxKVu7k+9GYyqeOwhjaqy +HbXzxBM4Ki0l1kaUjDVKjz3fsIq9Pl/lBoKYAmDvkK4xoxcs05ws +-----END RSA PRIVATE KEY-----` + + ExpectedMD5Fingerprint = `24:2e:53:c3:72:4f:25:b8:72:29:2d:e3:56:63:4b:c8` + ExpectedSHA1Fingerprint = `8b:d1:ce:b8:3a:f0:37:7f:56:9e:33:1a:72:4b:32:5a:bc:9d:3b:49` + ExpectedSHA256Fingerprint = `x+EcRzt7EfVuXTxnFt01lkxabPULguUgpvcpo52/Puc=` +) + +var _ = Describe("Fingerprint", func() { + var publicKey ssh.PublicKey + var fingerprint string + + BeforeEach(func() { + privateKey, err := ssh.ParsePrivateKey([]byte(TestPrivateKeyPem)) + Expect(err).NotTo(HaveOccurred()) + + publicKey = privateKey.PublicKey() + }) + + Describe("MD5 Fingerprint", func() { + BeforeEach(func() { + fingerprint = helpers.MD5Fingerprint(publicKey) + }) + + It("should have the correct length", func() { + Expect(utf8.RuneCountInString(fingerprint)).To(Equal(helpers.MD5_FINGERPRINT_LENGTH)) + }) + + It("should match the expected fingerprint", func() { + Expect(fingerprint).To(Equal(ExpectedMD5Fingerprint)) + }) + }) + + Describe("SHA1 Fingerprint", func() { + BeforeEach(func() { + fingerprint = helpers.SHA1Fingerprint(publicKey) + }) + + It("should have the correct length", func() { + Expect(utf8.RuneCountInString(fingerprint)).To(Equal(helpers.SHA1_FINGERPRINT_LENGTH)) + }) + + It("should match the expected fingerprint", func() { + Expect(fingerprint).To(Equal(ExpectedSHA1Fingerprint)) + }) + }) + + Describe("SHA256 Fingerprint", func() { + BeforeEach(func() { + fingerprint = fmt.Sprintf("%s=", helpers.SHA256Fingerprint(publicKey)) + }) + + It("should have the correct length", func() { + Expect(utf8.RuneCountInString(fingerprint)).To(Equal(helpers.SHA256_FINGERPRINT_LENGTH)) + }) + + It("should match the expected fingerprint", func() { + Expect(fingerprint).To(Equal(ExpectedSHA256Fingerprint)) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/helpers_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/helpers_suite_test.go new file mode 100644 index 0000000000..96cf17caab --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/helpers_suite_test.go @@ -0,0 +1,13 @@ +package helpers_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestHelpers(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Helpers Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/http_client.go b/src/code.cloudfoundry.org/diego-ssh/helpers/http_client.go new file mode 100644 index 0000000000..f2ac4cdfdf --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/http_client.go @@ -0,0 +1,44 @@ +package helpers + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "net" + "net/http" + "os" + "time" +) + +func NewHTTPSClient(insecureSkipVerify bool, caCertFiles []string, communicationTimeout time.Duration) (*http.Client, error) { + dialer := &net.Dialer{ + Timeout: 5 * time.Second, + KeepAlive: 30 * time.Second, + } + + tlsConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify} + + caCertPool := x509.NewCertPool() + for _, caCertFile := range caCertFiles { + if caCertFile != "" { + certBytes, err := os.ReadFile(caCertFile) + if err != nil { + return nil, fmt.Errorf("failed to read ca cert file: %s", err.Error()) + } + + if ok := caCertPool.AppendCertsFromPEM(certBytes); !ok { + return nil, errors.New("Unable to load caCert") + } + } + } + tlsConfig.RootCAs = caCertPool + + return &http.Client{ + Transport: &http.Transport{ + Dial: dialer.Dial, + TLSClientConfig: tlsConfig, + }, + Timeout: communicationTimeout, + }, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/http_client_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/http_client_test.go new file mode 100644 index 0000000000..6787f6ed17 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/http_client_test.go @@ -0,0 +1,121 @@ +package helpers_test + +import ( + "crypto/x509" + "net/http" + "os" + "time" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/inigo/helpers/certauthority" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("NewHTTPSClient", func() { + var ( + caCertFiles []string + insecureSkipVerify bool + timeout time.Duration + ) + + BeforeEach(func() { + caCertFiles = []string{} + }) + + It("sets InsecureSkipVerify on the TLS config", func() { + client, err := helpers.NewHTTPSClient(true, caCertFiles, timeout) + Expect(err).NotTo(HaveOccurred()) + httpTrans, ok := client.Transport.(*http.Transport) + Expect(ok).To(BeTrue()) + Expect(httpTrans.TLSClientConfig.InsecureSkipVerify).To(BeTrue()) + }) + + It("sets the client timeout", func() { + client, err := helpers.NewHTTPSClient(insecureSkipVerify, caCertFiles, 5*time.Second) + Expect(err).NotTo(HaveOccurred()) + Expect(client.Timeout).To(Equal(5 * time.Second)) + }) + + Context("when a list of ca Cert files is provided", func() { + var ( + certDepotDir string + ca certauthority.CertAuthority + ) + + BeforeEach(func() { + var err error + certDepotDir, err = os.MkdirTemp("", "cert-depot-dir") + Expect(err).NotTo(HaveOccurred()) + + ca, err = certauthority.NewCertAuthority(certDepotDir, "one") + Expect(err).NotTo(HaveOccurred()) + + _, cert := ca.CAAndKey() + caCertFiles = []string{cert} + }) + + AfterEach(func() { + Expect(os.RemoveAll(certDepotDir)).To(Succeed()) + }) + + It("sets the RootCAs with a pool consisting of those CAs", func() { + expectedPool := x509.NewCertPool() + for _, caCert := range caCertFiles { + certBytes, err2 := os.ReadFile(caCert) + Expect(err2).NotTo(HaveOccurred()) + + Expect(expectedPool.AppendCertsFromPEM(certBytes)).To(BeTrue()) + } + + client, err := helpers.NewHTTPSClient(insecureSkipVerify, caCertFiles, timeout) + Expect(err).NotTo(HaveOccurred()) + httpTrans, ok := client.Transport.(*http.Transport) + Expect(ok).To(BeTrue()) + + caPool := httpTrans.TLSClientConfig.RootCAs + + //lint:ignore SA1019 - ignoring tlsCert.RootCAs.Subjects is deprecated ERR because cert does not come from SystemCertPool. + Expect(expectedPool.Subjects()).To(Equal(caPool.Subjects())) + }) + + Context("when an invalid tls cert is provided", func() { + var invalidCertPath string + + BeforeEach(func() { + invalidCert, err := os.CreateTemp("", "invalid-cert-") + Expect(err).NotTo(HaveOccurred()) + + invalidCertPath = invalidCert.Name() + + Expect(invalidCert.Close()).To(Succeed()) + + Expect(os.WriteFile(invalidCertPath, []byte("not valid pem"), 0644)).To(Succeed()) + + caCertFiles = append(caCertFiles, invalidCertPath) + }) + + AfterEach(func() { + Expect(os.Remove(invalidCertPath)).To(Succeed()) + }) + + It("returns an error", func() { + _, err := helpers.NewHTTPSClient(insecureSkipVerify, caCertFiles, timeout) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("Unable to load caCert")) + }) + }) + + Context("when the UAA tls cert does not exist", func() { + BeforeEach(func() { + caCertFiles = append(caCertFiles, "doesntexist") + }) + + It("returns an error", func() { + _, err := helpers.NewHTTPSClient(insecureSkipVerify, caCertFiles, timeout) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to read ca cert file")) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer.go b/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer.go new file mode 100644 index 0000000000..55116bb6d8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer.go @@ -0,0 +1,22 @@ +package helpers + +import ( + "io" + + "code.cloudfoundry.org/lager/v3" +) + +type lagerWriter struct { + logger lager.Logger +} + +// NewLagerWriter wraps a Writer around a lager.Logger +// Log messages will be written at the specified log level +func NewLagerWriter(logger lager.Logger) io.Writer { + return &lagerWriter{logger: logger} +} + +func (lw *lagerWriter) Write(p []byte) (int, error) { + lw.logger.Info("write", lager.Data{"payload": string(p)}) + return len(p), nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer_test.go new file mode 100644 index 0000000000..07afcd6df5 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/lager_writer_test.go @@ -0,0 +1,38 @@ +package helpers_test + +import ( + "io" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("LagerWriter", func() { + var logger *lagertest.TestLogger + var lagerWriter io.Writer + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + lagerWriter = helpers.NewLagerWriter(logger) + }) + + It("writes the payload as lager.Data", func() { + payload := []byte("Hello, world!\n") + + n, err := lagerWriter.Write(payload) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(len(payload))) + + Expect(logger.Logs()).To(HaveLen(1)) + + log := logger.Logs()[0] + Expect(log.Source).To(Equal("test")) + Expect(log.Message).To(Equal("test.write")) + Expect(log.LogLevel).To(Equal(lager.INFO)) + Expect(log.Data).To(Equal(lager.Data{"payload": string(payload)})) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store.go b/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store.go new file mode 100644 index 0000000000..371e913a47 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store.go @@ -0,0 +1,54 @@ +package helpers + +import ( + "fmt" + "net" + "sync" +) + +type ListenerStore struct { + store map[string]net.Listener + lock sync.Mutex +} + +func NewListenerStore() *ListenerStore { + return &ListenerStore{ + store: make(map[string]net.Listener), + } +} + +func (t *ListenerStore) AddListener(addr string, ln net.Listener) { + t.lock.Lock() + t.store[addr] = ln + t.lock.Unlock() +} + +func (t *ListenerStore) RemoveListener(addr string) error { + t.lock.Lock() + defer t.lock.Unlock() + if ln, ok := t.store[addr]; ok { + delete(t.store, addr) + return ln.Close() + } + return fmt.Errorf("RemoveListener error: addr %s doesn't exist", addr) +} + +func (t *ListenerStore) ListAll() []string { + t.lock.Lock() + a := make([]string, 0, len(t.store)) + for k := range t.store { + a = append(a, k) + } + t.lock.Unlock() + return a +} + +func (t *ListenerStore) RemoveAll() { + t.lock.Lock() + for k, ln := range t.store { + delete(t.store, k) + // #nosec G104 - no logging here, ignoring close errors + ln.Close() + } + t.lock.Unlock() +} diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store_test.go b/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store_test.go new file mode 100644 index 0000000000..d63a5b3bda --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/listener_store_test.go @@ -0,0 +1,80 @@ +package helpers_test + +import ( + "fmt" + "sync" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("ListenerStore", func() { + var lnStore *helpers.ListenerStore + JustBeforeEach(func() { + lnStore = helpers.NewListenerStore() + }) + + It("concurrently adds and removes", func() { + addrs := make(chan string, 100) + ln := &fake_net.FakeListener{} + + wg := sync.WaitGroup{} + for i := 0; i < 20; i++ { + wg.Add(1) + go func(i int) { + addr := fmt.Sprintf("127.0.0.1:%d", 8080+i) + defer wg.Done() + lnStore.AddListener(addr, ln) + addrs <- addr + }(i) + } + + for i := 0; i < 20; i++ { + wg.Add(1) + go func() { + addr := <-addrs + defer GinkgoRecover() + defer wg.Done() + err := lnStore.RemoveListener(addr) + Expect(err).ToNot(HaveOccurred()) + }() + } + wg.Wait() + + Expect(lnStore.ListAll()).To(HaveLen(0)) + }) + + Describe("RemoveListener", func() { + It("closes listeners when it removes them", func() { + ln := &fake_net.FakeListener{} + addr := "127.0.0.1:8080" + lnStore.AddListener(addr, ln) + lnStore.RemoveListener(addr) + Expect(ln.CloseCallCount()).To(Equal(1)) + }) + + It("errors if the requested listener does not exist", func() { + err := lnStore.RemoveListener("127.0.0.1:12345") + Expect(err).To(MatchError("RemoveListener error: addr 127.0.0.1:12345 doesn't exist")) + }) + }) + + Describe("RemoveAll", func() { + It("removes and closes all listeners", func() { + ln1 := &fake_net.FakeListener{} + addr1 := "127.0.0.1:8080" + ln2 := &fake_net.FakeListener{} + addr2 := "127.0.0.1:8081" + + lnStore.AddListener(addr1, ln1) + lnStore.AddListener(addr2, ln2) + + lnStore.RemoveAll() + Expect(lnStore.ListAll()).To(HaveLen(0)) + Expect(ln1.CloseCallCount()).To(Equal(1)) + Expect(ln2.CloseCallCount()).To(Equal(1)) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/helpers/package.go b/src/code.cloudfoundry.org/diego-ssh/helpers/package.go new file mode 100644 index 0000000000..a196145688 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/helpers/package.go @@ -0,0 +1 @@ +package helpers // import "code.cloudfoundry.org/diego-ssh/helpers" diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_key_pair.go b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_key_pair.go new file mode 100644 index 0000000000..f699166176 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_key_pair.go @@ -0,0 +1,363 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_keys + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/keys" + "golang.org/x/crypto/ssh" +) + +type FakeKeyPair struct { + AuthorizedKeyStub func() string + authorizedKeyMutex sync.RWMutex + authorizedKeyArgsForCall []struct { + } + authorizedKeyReturns struct { + result1 string + } + authorizedKeyReturnsOnCall map[int]struct { + result1 string + } + FingerprintStub func() string + fingerprintMutex sync.RWMutex + fingerprintArgsForCall []struct { + } + fingerprintReturns struct { + result1 string + } + fingerprintReturnsOnCall map[int]struct { + result1 string + } + PEMEncodedPrivateKeyStub func() string + pEMEncodedPrivateKeyMutex sync.RWMutex + pEMEncodedPrivateKeyArgsForCall []struct { + } + pEMEncodedPrivateKeyReturns struct { + result1 string + } + pEMEncodedPrivateKeyReturnsOnCall map[int]struct { + result1 string + } + PrivateKeyStub func() ssh.Signer + privateKeyMutex sync.RWMutex + privateKeyArgsForCall []struct { + } + privateKeyReturns struct { + result1 ssh.Signer + } + privateKeyReturnsOnCall map[int]struct { + result1 ssh.Signer + } + PublicKeyStub func() ssh.PublicKey + publicKeyMutex sync.RWMutex + publicKeyArgsForCall []struct { + } + publicKeyReturns struct { + result1 ssh.PublicKey + } + publicKeyReturnsOnCall map[int]struct { + result1 ssh.PublicKey + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeKeyPair) AuthorizedKey() string { + fake.authorizedKeyMutex.Lock() + ret, specificReturn := fake.authorizedKeyReturnsOnCall[len(fake.authorizedKeyArgsForCall)] + fake.authorizedKeyArgsForCall = append(fake.authorizedKeyArgsForCall, struct { + }{}) + fake.recordInvocation("AuthorizedKey", []interface{}{}) + authorizedKeyStubCopy := fake.AuthorizedKeyStub + fake.authorizedKeyMutex.Unlock() + if authorizedKeyStubCopy != nil { + return authorizedKeyStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.authorizedKeyReturns + return fakeReturns.result1 +} + +func (fake *FakeKeyPair) AuthorizedKeyCallCount() int { + fake.authorizedKeyMutex.RLock() + defer fake.authorizedKeyMutex.RUnlock() + return len(fake.authorizedKeyArgsForCall) +} + +func (fake *FakeKeyPair) AuthorizedKeyCalls(stub func() string) { + fake.authorizedKeyMutex.Lock() + defer fake.authorizedKeyMutex.Unlock() + fake.AuthorizedKeyStub = stub +} + +func (fake *FakeKeyPair) AuthorizedKeyReturns(result1 string) { + fake.authorizedKeyMutex.Lock() + defer fake.authorizedKeyMutex.Unlock() + fake.AuthorizedKeyStub = nil + fake.authorizedKeyReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) AuthorizedKeyReturnsOnCall(i int, result1 string) { + fake.authorizedKeyMutex.Lock() + defer fake.authorizedKeyMutex.Unlock() + fake.AuthorizedKeyStub = nil + if fake.authorizedKeyReturnsOnCall == nil { + fake.authorizedKeyReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.authorizedKeyReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) Fingerprint() string { + fake.fingerprintMutex.Lock() + ret, specificReturn := fake.fingerprintReturnsOnCall[len(fake.fingerprintArgsForCall)] + fake.fingerprintArgsForCall = append(fake.fingerprintArgsForCall, struct { + }{}) + fake.recordInvocation("Fingerprint", []interface{}{}) + fingerprintStubCopy := fake.FingerprintStub + fake.fingerprintMutex.Unlock() + if fingerprintStubCopy != nil { + return fingerprintStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.fingerprintReturns + return fakeReturns.result1 +} + +func (fake *FakeKeyPair) FingerprintCallCount() int { + fake.fingerprintMutex.RLock() + defer fake.fingerprintMutex.RUnlock() + return len(fake.fingerprintArgsForCall) +} + +func (fake *FakeKeyPair) FingerprintCalls(stub func() string) { + fake.fingerprintMutex.Lock() + defer fake.fingerprintMutex.Unlock() + fake.FingerprintStub = stub +} + +func (fake *FakeKeyPair) FingerprintReturns(result1 string) { + fake.fingerprintMutex.Lock() + defer fake.fingerprintMutex.Unlock() + fake.FingerprintStub = nil + fake.fingerprintReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) FingerprintReturnsOnCall(i int, result1 string) { + fake.fingerprintMutex.Lock() + defer fake.fingerprintMutex.Unlock() + fake.FingerprintStub = nil + if fake.fingerprintReturnsOnCall == nil { + fake.fingerprintReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.fingerprintReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) PEMEncodedPrivateKey() string { + fake.pEMEncodedPrivateKeyMutex.Lock() + ret, specificReturn := fake.pEMEncodedPrivateKeyReturnsOnCall[len(fake.pEMEncodedPrivateKeyArgsForCall)] + fake.pEMEncodedPrivateKeyArgsForCall = append(fake.pEMEncodedPrivateKeyArgsForCall, struct { + }{}) + fake.recordInvocation("PEMEncodedPrivateKey", []interface{}{}) + pEMEncodedPrivateKeyStubCopy := fake.PEMEncodedPrivateKeyStub + fake.pEMEncodedPrivateKeyMutex.Unlock() + if pEMEncodedPrivateKeyStubCopy != nil { + return pEMEncodedPrivateKeyStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.pEMEncodedPrivateKeyReturns + return fakeReturns.result1 +} + +func (fake *FakeKeyPair) PEMEncodedPrivateKeyCallCount() int { + fake.pEMEncodedPrivateKeyMutex.RLock() + defer fake.pEMEncodedPrivateKeyMutex.RUnlock() + return len(fake.pEMEncodedPrivateKeyArgsForCall) +} + +func (fake *FakeKeyPair) PEMEncodedPrivateKeyCalls(stub func() string) { + fake.pEMEncodedPrivateKeyMutex.Lock() + defer fake.pEMEncodedPrivateKeyMutex.Unlock() + fake.PEMEncodedPrivateKeyStub = stub +} + +func (fake *FakeKeyPair) PEMEncodedPrivateKeyReturns(result1 string) { + fake.pEMEncodedPrivateKeyMutex.Lock() + defer fake.pEMEncodedPrivateKeyMutex.Unlock() + fake.PEMEncodedPrivateKeyStub = nil + fake.pEMEncodedPrivateKeyReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) PEMEncodedPrivateKeyReturnsOnCall(i int, result1 string) { + fake.pEMEncodedPrivateKeyMutex.Lock() + defer fake.pEMEncodedPrivateKeyMutex.Unlock() + fake.PEMEncodedPrivateKeyStub = nil + if fake.pEMEncodedPrivateKeyReturnsOnCall == nil { + fake.pEMEncodedPrivateKeyReturnsOnCall = make(map[int]struct { + result1 string + }) + } + fake.pEMEncodedPrivateKeyReturnsOnCall[i] = struct { + result1 string + }{result1} +} + +func (fake *FakeKeyPair) PrivateKey() ssh.Signer { + fake.privateKeyMutex.Lock() + ret, specificReturn := fake.privateKeyReturnsOnCall[len(fake.privateKeyArgsForCall)] + fake.privateKeyArgsForCall = append(fake.privateKeyArgsForCall, struct { + }{}) + fake.recordInvocation("PrivateKey", []interface{}{}) + privateKeyStubCopy := fake.PrivateKeyStub + fake.privateKeyMutex.Unlock() + if privateKeyStubCopy != nil { + return privateKeyStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.privateKeyReturns + return fakeReturns.result1 +} + +func (fake *FakeKeyPair) PrivateKeyCallCount() int { + fake.privateKeyMutex.RLock() + defer fake.privateKeyMutex.RUnlock() + return len(fake.privateKeyArgsForCall) +} + +func (fake *FakeKeyPair) PrivateKeyCalls(stub func() ssh.Signer) { + fake.privateKeyMutex.Lock() + defer fake.privateKeyMutex.Unlock() + fake.PrivateKeyStub = stub +} + +func (fake *FakeKeyPair) PrivateKeyReturns(result1 ssh.Signer) { + fake.privateKeyMutex.Lock() + defer fake.privateKeyMutex.Unlock() + fake.PrivateKeyStub = nil + fake.privateKeyReturns = struct { + result1 ssh.Signer + }{result1} +} + +func (fake *FakeKeyPair) PrivateKeyReturnsOnCall(i int, result1 ssh.Signer) { + fake.privateKeyMutex.Lock() + defer fake.privateKeyMutex.Unlock() + fake.PrivateKeyStub = nil + if fake.privateKeyReturnsOnCall == nil { + fake.privateKeyReturnsOnCall = make(map[int]struct { + result1 ssh.Signer + }) + } + fake.privateKeyReturnsOnCall[i] = struct { + result1 ssh.Signer + }{result1} +} + +func (fake *FakeKeyPair) PublicKey() ssh.PublicKey { + fake.publicKeyMutex.Lock() + ret, specificReturn := fake.publicKeyReturnsOnCall[len(fake.publicKeyArgsForCall)] + fake.publicKeyArgsForCall = append(fake.publicKeyArgsForCall, struct { + }{}) + fake.recordInvocation("PublicKey", []interface{}{}) + publicKeyStubCopy := fake.PublicKeyStub + fake.publicKeyMutex.Unlock() + if publicKeyStubCopy != nil { + return publicKeyStubCopy() + } + if specificReturn { + return ret.result1 + } + fakeReturns := fake.publicKeyReturns + return fakeReturns.result1 +} + +func (fake *FakeKeyPair) PublicKeyCallCount() int { + fake.publicKeyMutex.RLock() + defer fake.publicKeyMutex.RUnlock() + return len(fake.publicKeyArgsForCall) +} + +func (fake *FakeKeyPair) PublicKeyCalls(stub func() ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = stub +} + +func (fake *FakeKeyPair) PublicKeyReturns(result1 ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = nil + fake.publicKeyReturns = struct { + result1 ssh.PublicKey + }{result1} +} + +func (fake *FakeKeyPair) PublicKeyReturnsOnCall(i int, result1 ssh.PublicKey) { + fake.publicKeyMutex.Lock() + defer fake.publicKeyMutex.Unlock() + fake.PublicKeyStub = nil + if fake.publicKeyReturnsOnCall == nil { + fake.publicKeyReturnsOnCall = make(map[int]struct { + result1 ssh.PublicKey + }) + } + fake.publicKeyReturnsOnCall[i] = struct { + result1 ssh.PublicKey + }{result1} +} + +func (fake *FakeKeyPair) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.authorizedKeyMutex.RLock() + defer fake.authorizedKeyMutex.RUnlock() + fake.fingerprintMutex.RLock() + defer fake.fingerprintMutex.RUnlock() + fake.pEMEncodedPrivateKeyMutex.RLock() + defer fake.pEMEncodedPrivateKeyMutex.RUnlock() + fake.privateKeyMutex.RLock() + defer fake.privateKeyMutex.RUnlock() + fake.publicKeyMutex.RLock() + defer fake.publicKeyMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeKeyPair) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ keys.KeyPair = new(FakeKeyPair) diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_ssh_key_factory.go b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_ssh_key_factory.go new file mode 100644 index 0000000000..c57c3b95b7 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/fake_ssh_key_factory.go @@ -0,0 +1,116 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fake_keys + +import ( + "sync" + + "code.cloudfoundry.org/diego-ssh/keys" +) + +type FakeSSHKeyFactory struct { + NewKeyPairStub func(int) (keys.KeyPair, error) + newKeyPairMutex sync.RWMutex + newKeyPairArgsForCall []struct { + arg1 int + } + newKeyPairReturns struct { + result1 keys.KeyPair + result2 error + } + newKeyPairReturnsOnCall map[int]struct { + result1 keys.KeyPair + result2 error + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeSSHKeyFactory) NewKeyPair(arg1 int) (keys.KeyPair, error) { + fake.newKeyPairMutex.Lock() + ret, specificReturn := fake.newKeyPairReturnsOnCall[len(fake.newKeyPairArgsForCall)] + fake.newKeyPairArgsForCall = append(fake.newKeyPairArgsForCall, struct { + arg1 int + }{arg1}) + fake.recordInvocation("NewKeyPair", []interface{}{arg1}) + newKeyPairStubCopy := fake.NewKeyPairStub + fake.newKeyPairMutex.Unlock() + if newKeyPairStubCopy != nil { + return newKeyPairStubCopy(arg1) + } + if specificReturn { + return ret.result1, ret.result2 + } + fakeReturns := fake.newKeyPairReturns + return fakeReturns.result1, fakeReturns.result2 +} + +func (fake *FakeSSHKeyFactory) NewKeyPairCallCount() int { + fake.newKeyPairMutex.RLock() + defer fake.newKeyPairMutex.RUnlock() + return len(fake.newKeyPairArgsForCall) +} + +func (fake *FakeSSHKeyFactory) NewKeyPairCalls(stub func(int) (keys.KeyPair, error)) { + fake.newKeyPairMutex.Lock() + defer fake.newKeyPairMutex.Unlock() + fake.NewKeyPairStub = stub +} + +func (fake *FakeSSHKeyFactory) NewKeyPairArgsForCall(i int) int { + fake.newKeyPairMutex.RLock() + defer fake.newKeyPairMutex.RUnlock() + argsForCall := fake.newKeyPairArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeSSHKeyFactory) NewKeyPairReturns(result1 keys.KeyPair, result2 error) { + fake.newKeyPairMutex.Lock() + defer fake.newKeyPairMutex.Unlock() + fake.NewKeyPairStub = nil + fake.newKeyPairReturns = struct { + result1 keys.KeyPair + result2 error + }{result1, result2} +} + +func (fake *FakeSSHKeyFactory) NewKeyPairReturnsOnCall(i int, result1 keys.KeyPair, result2 error) { + fake.newKeyPairMutex.Lock() + defer fake.newKeyPairMutex.Unlock() + fake.NewKeyPairStub = nil + if fake.newKeyPairReturnsOnCall == nil { + fake.newKeyPairReturnsOnCall = make(map[int]struct { + result1 keys.KeyPair + result2 error + }) + } + fake.newKeyPairReturnsOnCall[i] = struct { + result1 keys.KeyPair + result2 error + }{result1, result2} +} + +func (fake *FakeSSHKeyFactory) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.newKeyPairMutex.RLock() + defer fake.newKeyPairMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeSSHKeyFactory) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ keys.SSHKeyFactory = new(FakeSSHKeyFactory) diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/package.go b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/package.go new file mode 100644 index 0000000000..4e8835a0d2 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/fake_keys/package.go @@ -0,0 +1 @@ +package fake_keys // import "code.cloudfoundry.org/diego-ssh/keys/fake_keys" diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/keys_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/keys/keys_suite_test.go new file mode 100644 index 0000000000..4fc046f2e2 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/keys_suite_test.go @@ -0,0 +1,13 @@ +package keys_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestKeys(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Keys Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/package.go b/src/code.cloudfoundry.org/diego-ssh/keys/package.go new file mode 100644 index 0000000000..371a2f987d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/package.go @@ -0,0 +1 @@ +package keys // import "code.cloudfoundry.org/diego-ssh/keys" diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/rsa.go b/src/code.cloudfoundry.org/diego-ssh/keys/rsa.go new file mode 100644 index 0000000000..aef7a2d318 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/rsa.go @@ -0,0 +1,87 @@ +package keys + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + + "code.cloudfoundry.org/diego-ssh/helpers" + "golang.org/x/crypto/ssh" +) + +//go:generate counterfeiter -o fake_keys/fake_key_pair.go . KeyPair +type KeyPair interface { + PrivateKey() ssh.Signer + PEMEncodedPrivateKey() string + + PublicKey() ssh.PublicKey + Fingerprint() string + AuthorizedKey() string +} + +//go:generate counterfeiter -o fake_keys/fake_ssh_key_factory.go . SSHKeyFactory +type SSHKeyFactory interface { + NewKeyPair(bits int) (KeyPair, error) +} + +var RSAKeyPairFactory SSHKeyFactory = &keyPairFactory{} + +type keyPairFactory struct{} + +func (r *keyPairFactory) NewKeyPair(bits int) (KeyPair, error) { + return newRSA(bits) +} + +type rsaKeyPair struct { + encodedPrivateKey string + privateKey ssh.Signer +} + +func newRSA(bits int) (KeyPair, error) { + key, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, err + } + + err = key.Validate() + if err != nil { + return nil, err + } + + encodedPrivateKey := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Headers: nil, + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + + privateKey, err := ssh.ParsePrivateKey(encodedPrivateKey) + if err != nil { + return nil, err + } + + return &rsaKeyPair{ + encodedPrivateKey: string(encodedPrivateKey), + privateKey: privateKey, + }, nil +} + +func (k *rsaKeyPair) PrivateKey() ssh.Signer { + return k.privateKey +} + +func (k *rsaKeyPair) PEMEncodedPrivateKey() string { + return k.encodedPrivateKey +} + +func (k *rsaKeyPair) PublicKey() ssh.PublicKey { + return k.privateKey.PublicKey() +} + +func (k *rsaKeyPair) Fingerprint() string { + return helpers.MD5Fingerprint(k.PublicKey()) +} + +func (k *rsaKeyPair) AuthorizedKey() string { + return string(ssh.MarshalAuthorizedKey(k.PublicKey())) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/keys/rsa_test.go b/src/code.cloudfoundry.org/diego-ssh/keys/rsa_test.go new file mode 100644 index 0000000000..839cf56c3d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/keys/rsa_test.go @@ -0,0 +1,94 @@ +package keys_test + +import ( + "crypto/x509" + "encoding/pem" + + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/keys" + "golang.org/x/crypto/ssh" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("RSA", func() { + var keyPair keys.KeyPair + var bits int + + BeforeEach(func() { + bits = 1024 + }) + + JustBeforeEach(func() { + var err error + keyPair, err = keys.RSAKeyPairFactory.NewKeyPair(bits) + Expect(err).NotTo(HaveOccurred()) + }) + + Describe("PrivateKey", func() { + It("returns the ssh private key associted with the public key", func() { + Expect(keyPair.PrivateKey()).NotTo(BeNil()) + Expect(keyPair.PrivateKey().PublicKey()).To(Equal(keyPair.PublicKey())) + }) + + Context("when creating a 1024 bit key", func() { + BeforeEach(func() { + bits = 1024 + }) + + It("the private key is 1024 bits", func() { + block, _ := pem.Decode([]byte(keyPair.PEMEncodedPrivateKey())) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + Expect(err).NotTo(HaveOccurred()) + + Expect(key.N.BitLen()).To(Equal(1024)) + }) + }) + + Context("when creating a 2048 bit key", func() { + BeforeEach(func() { + bits = 2048 + }) + + It("the private key is 2048 bits", func() { + block, _ := pem.Decode([]byte(keyPair.PEMEncodedPrivateKey())) + key, err := x509.ParsePKCS1PrivateKey(block.Bytes) + Expect(err).NotTo(HaveOccurred()) + + Expect(key.N.BitLen()).To(Equal(2048)) + }) + }) + }) + + Describe("PEMEncodedPrivateKey", func() { + It("correctly represents the private key", func() { + privateKey, err := ssh.ParsePrivateKey([]byte(keyPair.PEMEncodedPrivateKey())) + Expect(err).NotTo(HaveOccurred()) + + Expect(privateKey.PublicKey().Marshal()).To(Equal(keyPair.PublicKey().Marshal())) + }) + }) + + Describe("PublicKey", func() { + It("equals the public key associated with the private key", func() { + Expect(keyPair.PrivateKey().PublicKey().Marshal()).To(Equal(keyPair.PublicKey().Marshal())) + }) + }) + + Describe("Fingerprint", func() { + It("equals the MD5 fingerprint of the public key", func() { + expectedFingerprint := helpers.MD5Fingerprint(keyPair.PublicKey()) + + Expect(keyPair.Fingerprint()).To(Equal(expectedFingerprint)) + }) + }) + + Describe("AuthorizedKey", func() { + It("equals the authorized key formatted public key", func() { + expectedAuthorizedKey := string(ssh.MarshalAuthorizedKey(keyPair.PublicKey())) + + Expect(keyPair.AuthorizedKey()).To(Equal(expectedAuthorizedKey)) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/proxy/package.go b/src/code.cloudfoundry.org/diego-ssh/proxy/package.go new file mode 100644 index 0000000000..2a6e7068fb --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/proxy/package.go @@ -0,0 +1 @@ +package proxy // import "code.cloudfoundry.org/diego-ssh/proxy" diff --git a/src/code.cloudfoundry.org/diego-ssh/proxy/proxy.go b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy.go new file mode 100644 index 0000000000..9bd3c8cce7 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy.go @@ -0,0 +1,441 @@ +package proxy + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "net" + "strings" + "sync" + "unicode/utf8" + + loggingclient "code.cloudfoundry.org/diego-logging-client" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/lager/v3" + "golang.org/x/crypto/ssh" +) + +const ( + sshConnectionsMetric = "ssh-connections" +) + +type Waiter interface { + Wait() error +} + +type TargetConfig struct { + Address string `json:"address"` + TLSAddress string `json:"tls_address"` + ServerCertDomainSAN string `json:"server_cert_domain_san"` + HostFingerprint string `json:"host_fingerprint"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + PrivateKey string `json:"private_key,omitempty"` +} + +type LogMessage struct { + Message string `json:"message"` + Tags map[string]string `json:"tags"` +} + +type Proxy struct { + logger lager.Logger + serverConfig *ssh.ServerConfig + + connectionLock *sync.Mutex + connections int + metronClient loggingclient.IngressClient + + tlsConfig *tls.Config +} + +func New( + logger lager.Logger, + serverConfig *ssh.ServerConfig, + metronClient loggingclient.IngressClient, + tlsConfig *tls.Config, +) *Proxy { + return &Proxy{ + logger: logger, + serverConfig: serverConfig, + connectionLock: &sync.Mutex{}, + metronClient: metronClient, + tlsConfig: tlsConfig, + } +} + +func (p *Proxy) HandleConnection(netConn net.Conn) { + logger := p.logger.Session("handle-connection") + defer netConn.Close() + + serverConn, serverChannels, serverRequests, err := ssh.NewServerConn(netConn, p.serverConfig) + if err != nil { + return + } + defer serverConn.Close() + + clientConn, clientChannels, clientRequests, err := NewClientConn(logger, serverConn.Permissions, p.tlsConfig) + if err != nil { + return + } + + logMessage := extractLogMessage(logger, serverConn.Permissions) + + defer func() { + if logMessage != nil { + endMessage := fmt.Sprintf("Remote access ended for %s", serverConn.RemoteAddr().String()) + err = p.metronClient.SendAppLog(endMessage, "SSH", logMessage.Tags) + if err != nil { + logger.Debug("failed-to-send-appLog", lager.Data{"error": err}) + } + } + err = clientConn.Close() + if err != nil { + logger.Debug("failed-to-close-connection", lager.Data{"error": err}) + } + }() + + if logMessage != nil { + err = p.metronClient.SendAppLog(logMessage.Message, "SSH", logMessage.Tags) + if err != nil { + logger.Debug("failed-to-send-appLog", lager.Data{"error": err}) + } + } + + fromClientLogger := logger.Session("from-client") + fromDaemonLogger := logger.Session("from-daemon") + + go ProxyGlobalRequests(fromClientLogger, clientConn, serverRequests) + go ProxyGlobalRequests(fromDaemonLogger, serverConn, clientRequests) + + go ProxyChannels(fromClientLogger, clientConn, serverChannels) + go ProxyChannels(fromDaemonLogger, serverConn, clientChannels) + + p.connectionLock.Lock() + p.connections++ + err = p.metronClient.SendMetric(sshConnectionsMetric, p.connections) + if err != nil { + logger.Error("failed-to-send-ssh-connections-metric", err) + } + p.connectionLock.Unlock() + + defer func() { + p.emitConnectionClosing(logger) + }() + + Wait(logger, serverConn, clientConn) +} + +func (p *Proxy) emitConnectionClosing(logger lager.Logger) { + p.connectionLock.Lock() + p.connections-- + err := p.metronClient.SendMetric(sshConnectionsMetric, p.connections) + p.connectionLock.Unlock() + + if err != nil { + logger.Error("failed-to-send-ssh-connections-metric", err) + } +} + +func extractLogMessage(logger lager.Logger, perms *ssh.Permissions) *LogMessage { + logMessageJson := perms.CriticalOptions["log-message"] + if logMessageJson == "" { + return nil + } + + logMessage := &LogMessage{} + err := json.Unmarshal([]byte(logMessageJson), logMessage) + if err != nil { + logger.Error("json-unmarshal-failed", err) + return nil + } + + return logMessage +} + +func ProxyGlobalRequests(logger lager.Logger, conn ssh.Conn, reqs <-chan *ssh.Request) { + logger = logger.Session("proxy-global-requests") + + logger.Info("started") + defer logger.Info("completed") + + for req := range reqs { + logger.Info("request", lager.Data{ + "type": req.Type, + "wantReply": req.WantReply, + "payload": req.Payload, + }) + + success, reply, err := conn.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + logger.Error("send-request-failed", err) + continue + } + + if req.WantReply { + err = req.Reply(success, reply) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + } +} + +func ProxyChannels(logger lager.Logger, conn ssh.Conn, channels <-chan ssh.NewChannel) { + logger = logger.Session("proxy-channels") + + logger.Info("started") + defer func() { + logger.Info("completed") + err := conn.Close() + logger.Debug("failed-to-close-channel", lager.Data{"error": err}) + }() + + for newChannel := range channels { + handleNewChannel(logger, conn, newChannel) + } +} + +func handleNewChannel(logger lager.Logger, conn ssh.Conn, newChannel ssh.NewChannel) { + logger.Info("new-channel", lager.Data{ + "channelType": newChannel.ChannelType(), + "extraData": newChannel.ExtraData(), + }) + + logger.Debug("openning-channel-to-daemon") + + targetChan, targetReqs, err := conn.OpenChannel(newChannel.ChannelType(), newChannel.ExtraData()) + if err != nil { + logger.Error("failed-to-open-channel", err) + if openErr, ok := err.(*ssh.OpenChannelError); ok { + err = newChannel.Reject(openErr.Reason, openErr.Message) + if err != nil { + logger.Debug("failed-to-reject-channel-creation", lager.Data{"error": err}) + } + } else { + err = newChannel.Reject(ssh.ConnectionFailed, err.Error()) + if err != nil { + logger.Debug("failed-to-reject-channel-creation", lager.Data{"error": err}) + } + } + return + } + logger.Debug("opened-channel-to-daemon") + + sourceChan, sourceReqs, err := newChannel.Accept() + if err != nil { + err = targetChan.Close() + if err != nil { + logger.Debug("failed-to-close-target-chan", lager.Data{"error": err}) + } + return + } + logger.Debug("accepted-channel-from-client") + + toTargetLogger := logger.Session("to-target") + toSourceLogger := logger.Session("to-source") + + targetWg := &sync.WaitGroup{} + sourceWg := &sync.WaitGroup{} + + targetWg.Add(2) + go helpers.Copy(toTargetLogger.Session("stdout"), targetWg, targetChan, sourceChan) + go helpers.Copy(toTargetLogger.Session("stderr"), targetWg, targetChan.Stderr(), sourceChan.Stderr()) + go func() { + targetWg.Wait() + err := targetChan.CloseWrite() + if err != nil { + logger.Debug("failed-to-close-write-target-chan", lager.Data{"error": err}) + } + }() + + sourceWg.Add(2) + go helpers.Copy(toSourceLogger.Session("stdout"), sourceWg, sourceChan, targetChan) + go helpers.Copy(toSourceLogger.Session("stderr"), sourceWg, sourceChan.Stderr(), targetChan.Stderr()) + go func() { + sourceWg.Wait() + err := sourceChan.CloseWrite() + if err != nil { + logger.Debug("failed-to-close-write-source-chan", lager.Data{"error": err}) + } + }() + + go ProxyRequests(toTargetLogger, newChannel.ChannelType(), sourceReqs, targetChan, targetWg) + go ProxyRequests(toSourceLogger, newChannel.ChannelType(), targetReqs, sourceChan, sourceWg) +} + +func ProxyRequests(logger lager.Logger, channelType string, reqs <-chan *ssh.Request, channel ssh.Channel, wg *sync.WaitGroup) { + logger = logger.Session("proxy-requests", lager.Data{ + "channel-type": channelType, + }) + + logger.Info("started") + defer func() { + logger.Info("completed") + wg.Wait() + err := channel.Close() + if err != nil { + logger.Debug("failed-to-close-channel", lager.Data{"error": err}) + } + }() + + for req := range reqs { + logger.Info("request", lager.Data{ + "type": req.Type, + "wantReply": req.WantReply, + "payload": req.Payload, + }) + success, err := channel.SendRequest(req.Type, req.WantReply, req.Payload) + if err != nil { + logger.Error("send-request-failed", err) + continue + } + + if req.WantReply { + err = req.Reply(success, nil) + if err != nil { + logger.Debug("failed-to-reply", lager.Data{"error": err}) + } + } + + if req.Type == "exit-status" { + return + } + } +} + +func Wait(logger lager.Logger, waiters ...Waiter) { + wg := &sync.WaitGroup{} + for _, waiter := range waiters { + wg.Add(1) + go func(waiter Waiter) { + err := waiter.Wait() + if err != nil { + logger.Debug("failed-to-wait", lager.Data{"error": err}) + } + wg.Done() + }(waiter) + } + wg.Wait() +} + +func NewClientConn(logger lager.Logger, permissions *ssh.Permissions, tlsConfig *tls.Config) (ssh.Conn, <-chan ssh.NewChannel, <-chan *ssh.Request, error) { + if permissions == nil || permissions.CriticalOptions == nil { + err := errors.New("Invalid permissions from authentication") + logger.Error("permissions-and-critical-options-required", err) + return nil, nil, nil, err + } + + targetConfigJson := permissions.CriticalOptions["proxy-target-config"] + logger = logger.Session("new-client-conn") + + logger.Debug("creating-client-connection", lager.Data{ + "proxy-target-config": targetConfigJson, + }) + + var targetConfig TargetConfig + err := json.Unmarshal([]byte(permissions.CriticalOptions["proxy-target-config"]), &targetConfig) + if err != nil { + logger.Error("unmarshal-failed", err) + return nil, nil, nil, err + } + + dialer := func() (net.Conn, error) { + tlsConfig := tlsConfigWithServerName(tlsConfig, targetConfig.ServerCertDomainSAN) + if tlsConfig != nil && targetConfig.TLSAddress != "" { + nConn, err := tls.Dial("tcp", targetConfig.TLSAddress, tlsConfig) + if err == nil { + return nConn, nil + } + + logger.Error("tls-dial-failed", err, lager.Data{ + "tcp_address": targetConfig.TLSAddress, + "server_cert_domain_san": targetConfig.ServerCertDomainSAN, + }) + } + + nConn, err := net.Dial("tcp", targetConfig.Address) + if err != nil { + logger.Error("dial-failed", err, lager.Data{ + "address": targetConfig.Address, + }) + return nil, err + } + + return nConn, nil + } + + nConn, err := dialer() + if err != nil { + return nil, nil, nil, err + } + + logger.Info("connected-to-backend", lager.Data{ + "backend-address": nConn.RemoteAddr().String(), + }) + + clientConfig := &ssh.ClientConfig{} + + if targetConfig.User != "" { + clientConfig.User = targetConfig.User + } + + if targetConfig.PrivateKey != "" { + key, err := ssh.ParsePrivateKey([]byte(targetConfig.PrivateKey)) + if err != nil { + logger.Error("parsing-key-failed", err) + return nil, nil, nil, err + } + clientConfig.Auth = append(clientConfig.Auth, ssh.PublicKeys(key)) + } + + if targetConfig.User != "" && targetConfig.Password != "" { + clientConfig.Auth = append(clientConfig.Auth, ssh.Password(targetConfig.Password)) + } + + if targetConfig.HostFingerprint != "" { + clientConfig.HostKeyCallback = func(hostname string, remote net.Addr, key ssh.PublicKey) error { + expectedFingerprint := targetConfig.HostFingerprint + + var actualFingerprint string + switch utf8.RuneCountInString(expectedFingerprint) { + case helpers.MD5_FINGERPRINT_LENGTH: + actualFingerprint = helpers.MD5Fingerprint(key) + case helpers.SHA1_FINGERPRINT_LENGTH: + actualFingerprint = helpers.SHA1Fingerprint(key) + case helpers.SHA256_FINGERPRINT_LENGTH: + actualFingerprint = helpers.SHA256Fingerprint(key) + //sshkey ruby gem pads the base64 output, but golang implementation returns unpaded + expectedFingerprint = strings.TrimRight(expectedFingerprint, "=") + } + + if expectedFingerprint != actualFingerprint { + err := errors.New("Host fingerprint mismatch") + logger.Error("host-key-fingerprint-mismatch", err) + return err + } + + return nil + } + } else { + clientConfig.HostKeyCallback = ssh.InsecureIgnoreHostKey() + } + + conn, ch, req, err := ssh.NewClientConn(nConn, targetConfig.Address, clientConfig) + if err != nil { + logger.Error("handshake-failed", err) + return nil, nil, nil, err + } + + return conn, ch, req, nil +} + +func tlsConfigWithServerName(original *tls.Config, serverName string) *tls.Config { + if original == nil { + return nil + } + new := original.Clone() + new.ServerName = serverName + return new +} diff --git a/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_suite_test.go new file mode 100644 index 0000000000..593b8ba05f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_suite_test.go @@ -0,0 +1,34 @@ +package proxy_test + +import ( + "code.cloudfoundry.org/diego-ssh/keys" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "golang.org/x/crypto/ssh" + + "testing" +) + +var ( + TestHostKey ssh.Signer + + TestPrivatePem string + TestPublicAuthorizedKey string +) + +func TestProxy(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Proxy Suite") +} + +var _ = BeforeSuite(func() { + hostKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + privateKey, err := keys.RSAKeyPairFactory.NewKeyPair(1024) + Expect(err).NotTo(HaveOccurred()) + + TestHostKey = hostKey.PrivateKey() + TestPrivatePem = privateKey.PEMEncodedPrivateKey() + TestPublicAuthorizedKey = privateKey.AuthorizedKey() +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_test.go b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_test.go new file mode 100644 index 0000000000..5974393041 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/proxy/proxy_test.go @@ -0,0 +1,1566 @@ +package proxy_test + +import ( + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "os" + "sync" + "time" + + mfakes "code.cloudfoundry.org/diego-logging-client/testhelpers" + "code.cloudfoundry.org/diego-ssh/authenticators/fake_authenticators" + "code.cloudfoundry.org/diego-ssh/daemon" + "code.cloudfoundry.org/diego-ssh/handlers" + "code.cloudfoundry.org/diego-ssh/handlers/fake_handlers" + "code.cloudfoundry.org/diego-ssh/helpers" + "code.cloudfoundry.org/diego-ssh/proxy" + "code.cloudfoundry.org/diego-ssh/server" + server_fakes "code.cloudfoundry.org/diego-ssh/server/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" + loggregator "code.cloudfoundry.org/go-loggregator/v9" + "code.cloudfoundry.org/inigo/helpers/certauthority" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + "code.cloudfoundry.org/tlsconfig" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" + "golang.org/x/crypto/ssh" +) + +var _ = Describe("Proxy", func() { + var ( + logger *lagertest.TestLogger + fakeMetronClient *mfakes.FakeIngressClient + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + }) + + Describe("HandleConnection", func() { + var ( + proxyAuthenticator *fake_authenticators.FakePasswordAuthenticator + proxySSHConfig *ssh.ServerConfig + sshProxy *proxy.Proxy + + daemonTargetConfig proxy.TargetConfig + daemonAuthenticator *fake_authenticators.FakePasswordAuthenticator + daemonSSHConfig *ssh.ServerConfig + daemonGlobalRequestHandlers map[string]handlers.GlobalRequestHandler + daemonNewChannelHandlers map[string]handlers.NewChannelHandler + sshDaemon *daemon.Daemon + + proxyListener net.Listener + sshdListener net.Listener + + proxyAddress string + daemonAddress string + + proxyServer *server.Server + sshdServer *server.Server + + proxyDone chan struct{} + daemonDone chan struct{} + ) + + BeforeEach(func() { + proxyDone = make(chan struct{}) + daemonDone = make(chan struct{}) + + fakeMetronClient = &mfakes.FakeIngressClient{} + + proxyAuthenticator = &fake_authenticators.FakePasswordAuthenticator{} + + proxySSHConfig = &ssh.ServerConfig{} + proxySSHConfig.PasswordCallback = proxyAuthenticator.Authenticate + proxySSHConfig.AddHostKey(TestHostKey) + + daemonAuthenticator = &fake_authenticators.FakePasswordAuthenticator{} + daemonAuthenticator.AuthenticateReturns(&ssh.Permissions{}, nil) + + daemonSSHConfig = &ssh.ServerConfig{} + daemonSSHConfig.PasswordCallback = daemonAuthenticator.Authenticate + daemonSSHConfig.AddHostKey(TestHostKey) + daemonGlobalRequestHandlers = map[string]handlers.GlobalRequestHandler{} + daemonNewChannelHandlers = map[string]handlers.NewChannelHandler{} + + var err error + proxyListener, err = net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + proxyAddress = proxyListener.Addr().String() + + sshdListener, err = net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + daemonAddress = sshdListener.Addr().String() + + daemonTargetConfig = proxy.TargetConfig{ + Address: daemonAddress, + HostFingerprint: helpers.MD5Fingerprint(TestHostKey.PublicKey()), + User: "some-user", + Password: "fake-some-password", + } + + targetConfigJson, err := json.Marshal(daemonTargetConfig) + Expect(err).NotTo(HaveOccurred()) + + logMessageJson, err := json.Marshal(proxy.LogMessage{ + Message: "a-message", + Tags: map[string]string{ + "instance_id": "1", + "source_id": "a-guid", + }, + }) + Expect(err).NotTo(HaveOccurred()) + + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + "log-message": string(logMessageJson), + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + JustBeforeEach(func() { + sshProxy = proxy.New(logger.Session("proxy"), proxySSHConfig, fakeMetronClient, nil) + proxyServer = server.NewServer(logger.Session("proxy-server"), "", sshProxy, 500*time.Millisecond) + proxyServer.SetListener(proxyListener) + go func() { + proxyServer.Serve() + close(proxyDone) + }() + + sshDaemon = daemon.New(logger.Session("sshd"), daemonSSHConfig, daemonGlobalRequestHandlers, daemonNewChannelHandlers) + sshdServer = server.NewServer(logger.Session("sshd-server"), "", sshDaemon, 500*time.Millisecond) + sshdServer.SetListener(sshdListener) + go func() { + sshdServer.Serve() + close(daemonDone) + }() + }) + + AfterEach(func() { + proxyServer.Shutdown() + sshdServer.Shutdown() + + Eventually(proxyDone).Should(BeClosed()) + Eventually(daemonDone).Should(BeClosed()) + }) + + Context("when a new connection arrives", func() { + var clientConfig *ssh.ClientConfig + + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: "diego:some-instance-guid", + Auth: []ssh.AuthMethod{ + ssh.Password("diego-user:diego-password"), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + }) + + It("performs a handshake with the client using the proxy server config", func() { + _, err := ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Expect(proxyAuthenticator.AuthenticateCallCount()).To(Equal(1)) + + metadata, password := proxyAuthenticator.AuthenticateArgsForCall(0) + Expect(metadata.User()).To(Equal("diego:some-instance-guid")) + Expect(string(password)).To(Equal("diego-user:diego-password")) + }) + + Context("when the handshake fails", func() { + BeforeEach(func() { + proxyAuthenticator.AuthenticateReturns(nil, errors.New("go away")) + }) + + JustBeforeEach(func() { + _, err := ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).To(MatchError(ContainSubstring("ssh: handshake failed: ssh: unable to authenticate"))) + }) + + It("does not attempt to authenticate with the daemon", func() { + Expect(daemonAuthenticator.AuthenticateCallCount()).To(Equal(0)) + }) + }) + + Context("when the client handshake is successful", func() { + var client *ssh.Client + + JustBeforeEach(func() { + var err error + client, err = ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + }) + + It("handshakes with the target using the provided configuration", func() { + Eventually(daemonAuthenticator.AuthenticateCallCount).Should(Equal(1)) + + metadata, password := daemonAuthenticator.AuthenticateArgsForCall(0) + Expect(metadata.User()).To(Equal("some-user")) + Expect(string(password)).To(Equal("fake-some-password")) + }) + + Context("metron", func() { + It("emits a successful log message on behalf of the lrp", func() { + Eventually(fakeMetronClient.SendAppLogCallCount).Should(Equal(1)) + message, sourceType, tags := fakeMetronClient.SendAppLogArgsForCall(0) + Expect(message).To(Equal("a-message")) + Expect(sourceType).To(Equal("SSH")) + Expect(tags["source_id"]).To(Equal("a-guid")) + Expect(tags["instance_id"]).To(Equal("1")) + }) + }) + + Context("when the target contains a host fingerprint", func() { + Context("when the fingerprint is an md5 hash", func() { + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + HostFingerprint: helpers.MD5Fingerprint(TestHostKey.PublicKey()), + User: "some-user", + Password: "fake-some-password", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + It("handshakes with the target using the provided configuration", func() { + Eventually(daemonAuthenticator.AuthenticateCallCount).Should(Equal(1)) + }) + }) + + Context("when the host fingerprint is a sha1 hash", func() { + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + HostFingerprint: helpers.SHA1Fingerprint(TestHostKey.PublicKey()), + User: "some-user", + Password: "fake-some-password", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + It("handshakes with the target using the provided configuration", func() { + Eventually(daemonAuthenticator.AuthenticateCallCount).Should(Equal(1)) + }) + }) + + Context("when the host fingerprint is a sha256 hash", func() { + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + HostFingerprint: fmt.Sprintf("%s=", helpers.SHA256Fingerprint(TestHostKey.PublicKey())), + User: "some-user", + Password: "fake-some-password", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + It("handshakes with the target using the provided configuration", func() { + Eventually(daemonAuthenticator.AuthenticateCallCount).Should(Equal(1)) + }) + }) + + Context("when the actual host fingerpreint does not match the expected fingerprint", func() { + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + HostFingerprint: "bogus-fingerprint", + User: "some-user", + Password: "fake-some-password", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + It("does not attempt authentication with the target", func() { + Consistently(daemonAuthenticator.AuthenticateCallCount).Should(Equal(0)) + }) + + It("closes the connection", func() { + Eventually(client.Wait).Should(Equal(io.EOF)) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say(`host-key-fingerprint-mismatch`)) + }) + }) + }) + + Context("when the target address is unreachable", func() { + BeforeEach(func() { + permissions := &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": `{"address": "0.0.0.0:0"}`, + }, + } + proxyAuthenticator.AuthenticateReturns(permissions, nil) + }) + + It("closes the connection", func() { + Eventually(client.Wait).Should(Equal(io.EOF)) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say(`new-client-conn.dial-failed.*0\.0\.0\.0:0`)) + }) + }) + + Context("when the handshake fails", func() { + BeforeEach(func() { + daemonAuthenticator.AuthenticateReturns(nil, errors.New("go away")) + }) + + It("closes the connection", func() { + Eventually(client.Wait).Should(Equal(io.EOF)) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say(`new-client-conn.handshake-failed`)) + }) + }) + }) + + Context("when HandleConnection returns", func() { + var fakeServerConnection *fake_net.FakeConn + + BeforeEach(func() { + proxySSHConfig.NoClientAuth = true + daemonSSHConfig.NoClientAuth = true + }) + + JustBeforeEach(func() { + clientNetConn, serverNetConn := test_helpers.Pipe() + + fakeServerConnection = &fake_net.FakeConn{} + fakeServerConnection.ReadStub = serverNetConn.Read + fakeServerConnection.WriteStub = serverNetConn.Write + fakeServerConnection.CloseStub = serverNetConn.Close + + go sshProxy.HandleConnection(fakeServerConnection) + + clientConn, clientChannels, clientRequests, err := ssh.NewClientConn(clientNetConn, "0.0.0.0", clientConfig) + Expect(err).NotTo(HaveOccurred()) + + client := ssh.NewClient(clientConn, clientChannels, clientRequests) + client.Close() + }) + + It("ensures the network connection is closed", func() { + Eventually(fakeServerConnection.CloseCallCount).Should(BeNumerically(">=", 1)) + }) + }) + }) + + Context("after both handshakes have been performed", func() { + var clientConfig *ssh.ClientConfig + + BeforeEach(func() { + clientConfig = &ssh.ClientConfig{ + User: "diego:some-instance-guid", + Auth: []ssh.AuthMethod{ + ssh.Password("diego-user:diego-password"), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + daemonSSHConfig.NoClientAuth = true + }) + + Describe("client requests to target", func() { + var client *ssh.Client + + JustBeforeEach(func() { + var err error + client, err = ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + client.Close() + }) + + Context("when the client sends a global request", func() { + var globalRequestHandler *fake_handlers.FakeGlobalRequestHandler + + BeforeEach(func() { + globalRequestHandler = &fake_handlers.FakeGlobalRequestHandler{} + globalRequestHandler.HandleRequestStub = func(logger lager.Logger, request *ssh.Request, conn ssh.Conn, lnStore *helpers.ListenerStore) { + request.Reply(true, []byte("response-payload")) + } + daemonGlobalRequestHandlers["test-global-request"] = globalRequestHandler + }) + + It("gets forwarded to the daemon and the response comes back", func() { + accepted, response, err := client.SendRequest("test-global-request", true, []byte("request-payload")) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + Expect(response).To(Equal([]byte("response-payload"))) + + Expect(globalRequestHandler.HandleRequestCallCount()).To(Equal(1)) + + _, request, _, _ := globalRequestHandler.HandleRequestArgsForCall(0) + Expect(request.Type).To(Equal("test-global-request")) + Expect(request.WantReply).To(BeTrue()) + Expect(request.Payload).To(Equal([]byte("request-payload"))) + }) + }) + + Context("when the client requests a new channel", func() { + var newChannelHandler *fake_handlers.FakeNewChannelHandler + + BeforeEach(func() { + newChannelHandler = &fake_handlers.FakeNewChannelHandler{} + newChannelHandler.HandleNewChannelStub = func(logger lager.Logger, newChannel ssh.NewChannel) { + newChannel.Reject(ssh.Prohibited, "not now") + } + daemonNewChannelHandlers["test"] = newChannelHandler + }) + + It("gets forwarded to the daemon", func() { + _, _, err := client.OpenChannel("test", nil) + Expect(err).To(Equal(&ssh.OpenChannelError{Reason: ssh.Prohibited, Message: "not now"})) + }) + }) + }) + + Describe("target requests to client", func() { + var ( + connectionHandler *server_fakes.FakeConnectionHandler + + target *server.Server + targetDone chan struct{} + listener net.Listener + targetAddress string + + clientChannels <-chan ssh.NewChannel + clientRequests <-chan *ssh.Request + ) + + BeforeEach(func() { + var err error + listener, err = net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + targetAddress = listener.Addr().String() + + connectionHandler = &server_fakes.FakeConnectionHandler{} + targetDone = make(chan struct{}) + }) + + JustBeforeEach(func() { + target = server.NewServer(logger.Session("target"), "", connectionHandler, 500*time.Millisecond) + target.SetListener(listener) + go func() { + target.Serve() + close(targetDone) + }() + + clientConfig := &ssh.ClientConfig{HostKeyCallback: ssh.InsecureIgnoreHostKey()} + clientNetConn, err := net.Dial("tcp", targetAddress) + Expect(err).ToNot(HaveOccurred()) + _, clientChannels, clientRequests, err = ssh.NewClientConn(clientNetConn, "0.0.0.0", clientConfig) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + target.Shutdown() + Eventually(targetDone).Should(BeClosed()) + }) + + Context("when the target sends a global request", func() { + var handleConnDone chan struct{} + + BeforeEach(func() { + handleConnDone = make(chan struct{}) + connectionHandler.HandleConnectionStub = func(conn net.Conn) { + defer GinkgoRecover() + defer func() { + handleConnDone <- struct{}{} + }() + + serverConn, _, _, err := ssh.NewServerConn(conn, daemonSSHConfig) + Expect(err).NotTo(HaveOccurred()) + + accepted, response, err := serverConn.SendRequest("test", true, []byte("test-data")) + Expect(err).NotTo(HaveOccurred()) + Expect(accepted).To(BeTrue()) + Expect(response).To(Equal([]byte("response-data"))) + + serverConn.Close() + } + }) + + AfterEach(func() { + close(handleConnDone) + }) + + It("gets forwarded to the client", func() { + var req *ssh.Request + Eventually(clientRequests).Should(Receive(&req)) + + req.Reply(true, []byte("response-data")) + + Eventually(handleConnDone).Should(Receive()) + }) + }) + + Context("when the target requests a new channel", func() { + var done chan struct{} + + BeforeEach(func() { + done = make(chan struct{}) + + connectionHandler.HandleConnectionStub = func(conn net.Conn) { + defer GinkgoRecover() + + serverConn, _, _, err := ssh.NewServerConn(conn, daemonSSHConfig) + Expect(err).NotTo(HaveOccurred()) + + channel, requests, err := serverConn.OpenChannel("test-channel", []byte("extra-data")) + Expect(err).NotTo(HaveOccurred()) + Expect(channel).NotTo(BeNil()) + Expect(requests).NotTo(BeClosed()) + + channel.Write([]byte("hello")) + + channelResponse := make([]byte, 7) + channel.Read(channelResponse) + Expect(string(channelResponse)).To(Equal("goodbye")) + + channel.Close() + serverConn.Close() + + close(done) + } + }) + + AfterEach(func() { + Eventually(done).Should(BeClosed()) + }) + + It("gets forwarded to the client", func() { + var newChannel ssh.NewChannel + Eventually(clientChannels).Should(Receive(&newChannel)) + + Expect(newChannel.ChannelType()).To(Equal("test-channel")) + Expect(newChannel.ExtraData()).To(Equal([]byte("extra-data"))) + + channel, requests, err := newChannel.Accept() + Expect(err).NotTo(HaveOccurred()) + Expect(channel).NotTo(BeNil()) + Expect(requests).NotTo(BeClosed()) + + channelRequest := make([]byte, 5) + channel.Read(channelRequest) + Expect(string(channelRequest)).To(Equal("hello")) + + channel.Write([]byte("goodbye")) + channel.Close() + }) + }) + }) + + Describe("connection metrics", func() { + type metric struct { + name string + value int + } + + var ( + metricChan chan metric + ) + + BeforeEach(func() { + metricChan = make(chan metric, 2) + + fakeMetronClient.SendMetricStub = func(name string, value int, opts ...loggregator.EmitGaugeOption) error { + metricChan <- metric{name: name, value: value} + return nil + } + }) + + Context("when a connection is received", func() { + It("emit a metric for the total number of connections", func() { + _, err := ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Eventually(metricChan).Should(Receive(Equal(metric{ + name: "ssh-connections", + value: 1, + }))) + + _, err = ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + Eventually(metricChan).Should(Receive(Equal(metric{ + name: "ssh-connections", + value: 2, + }))) + }) + }) + + Context("when a connection is closed", func() { + It("emit a metric for the total number of connections", func() { + conn, err := ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + Eventually(metricChan).Should(Receive(Equal(metric{ + name: "ssh-connections", + value: 1, + }))) + + conn.Close() + Eventually(metricChan).Should(Receive(Equal(metric{ + name: "ssh-connections", + value: 0, + }))) + }) + }) + }) + + Describe("app logs", func() { + Context("when a connection is closed", func() { + It("logs that the connection has been closed", func() { + conn, err := ssh.Dial("tcp", proxyAddress, clientConfig) + Expect(err).NotTo(HaveOccurred()) + + conn.Close() + + Eventually(fakeMetronClient.SendAppLogCallCount).Should(Equal(2)) + message, sourceType, tags := fakeMetronClient.SendAppLogArgsForCall(1) + Expect(message).To(ContainSubstring("Remote access ended for")) + Expect(sourceType).To(Equal("SSH")) + Expect(tags["source_id"]).To(Equal("a-guid")) + Expect(tags["instance_id"]).To(Equal("1")) + }) + }) + }) + }) + }) + + Describe("ProxyGlobalRequests", func() { + var ( + sshConn *fake_ssh.FakeConn + reqChan chan *ssh.Request + + done chan struct{} + ) + + BeforeEach(func() { + sshConn = &fake_ssh.FakeConn{} + reqChan = make(chan *ssh.Request, 2) + done = make(chan struct{}, 1) + }) + + JustBeforeEach(func() { + go func(done chan<- struct{}) { + proxy.ProxyGlobalRequests(logger, sshConn, reqChan) + done <- struct{}{} + }(done) + }) + + Context("when a request is received", func() { + BeforeEach(func() { + request := &ssh.Request{Type: "test", WantReply: false, Payload: []byte("test-data")} + reqChan <- request + reqChan <- request + }) + + AfterEach(func() { + close(reqChan) + }) + + It("forwards requests from the channel to the connection", func() { + Eventually(sshConn.SendRequestCallCount).Should(Equal(2)) + Consistently(sshConn.SendRequestCallCount).Should(Equal(2)) + + reqType, wantReply, payload := sshConn.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + + reqType, wantReply, payload = sshConn.SendRequestArgsForCall(1) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + }) + + Context("when SendRequest fails", func() { + BeforeEach(func() { + callCount := 0 + sshConn.SendRequestStub = func(rt string, wr bool, p []byte) (bool, []byte, error) { + callCount++ + if callCount == 1 { + return false, nil, errors.New("woops") + } + return true, nil, nil + } + + reqChan <- &ssh.Request{} + reqChan <- &ssh.Request{} + }) + + AfterEach(func() { + close(reqChan) + }) + + It("continues processing requests", func() { + Eventually(sshConn.SendRequestCallCount).Should(Equal(2)) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say(`send-request-failed.*woops`)) + }) + }) + + Context("when the request channel closes", func() { + JustBeforeEach(func() { + Consistently(reqChan).ShouldNot(BeClosed()) + close(reqChan) + }) + + It("returns gracefully", func() { + Eventually(done).Should(Receive()) + }) + }) + }) + + Describe("ProxyChannels", func() { + var ( + targetConn *fake_ssh.FakeConn + newChanChan chan ssh.NewChannel + + newChan *fake_ssh.FakeNewChannel + sourceChannel *fake_ssh.FakeChannel + sourceReqChan chan *ssh.Request + sourceStderr *fake_ssh.FakeChannel + + targetChannel *fake_ssh.FakeChannel + targetReqChan chan *ssh.Request + targetStderr *fake_ssh.FakeChannel + + done chan struct{} + ) + + BeforeEach(func() { + targetConn = &fake_ssh.FakeConn{} + newChanChan = make(chan ssh.NewChannel, 1) + + newChan = &fake_ssh.FakeNewChannel{} + sourceChannel = &fake_ssh.FakeChannel{} + sourceReqChan = make(chan *ssh.Request, 2) + sourceStderr = &fake_ssh.FakeChannel{} + sourceStderr.ReadReturns(0, io.EOF) + sourceChannel.StderrReturns(sourceStderr) + + targetChannel = &fake_ssh.FakeChannel{} + targetReqChan = make(chan *ssh.Request, 2) + targetStderr = &fake_ssh.FakeChannel{} + targetStderr.ReadReturns(0, io.EOF) + targetChannel.StderrReturns(targetStderr) + + done = make(chan struct{}, 1) + }) + + JustBeforeEach(func() { + go func(done chan<- struct{}) { + proxy.ProxyChannels(logger, targetConn, newChanChan) + done <- struct{}{} + }(done) + }) + + Context("when a new channel is opened by the client", func() { + BeforeEach(func() { + sourceChannel.ReadReturns(0, io.EOF) + targetChannel.ReadReturns(0, io.EOF) + + newChan.ChannelTypeReturns("test") + newChan.ExtraDataReturns([]byte("extra-data")) + newChan.AcceptReturns(sourceChannel, sourceReqChan, nil) + + targetConn.OpenChannelReturns(targetChannel, targetReqChan, nil) + + newChanChan <- newChan + }) + + AfterEach(func() { + close(newChanChan) + }) + + It("forwards the NewChannel request to the target", func() { + Eventually(targetConn.OpenChannelCallCount).Should(Equal(1)) + Consistently(targetConn.OpenChannelCallCount).Should(Equal(1)) + + channelType, extraData := targetConn.OpenChannelArgsForCall(0) + Expect(channelType).To(Equal("test")) + Expect(extraData).To(Equal([]byte("extra-data"))) + }) + + Context("when the target accepts the connection", func() { + It("accepts the source request", func() { + Eventually(newChan.AcceptCallCount).Should(Equal(1)) + }) + + Context("when the source channel has data available", func() { + BeforeEach(func() { + sourceChannel.ReadStub = func(dest []byte) (int, error) { + if cap(dest) >= 3 { + copy(dest, []byte("abc")) + return 3, io.EOF + } + return 0, io.EOF + } + sourceStderr.ReadStub = func(dest []byte) (int, error) { + if cap(dest) >= 3 { + copy(dest, []byte("xyz")) + return 3, io.EOF + } + return 0, io.EOF + } + }) + + It("copies the source channel to the target channel", func() { + Eventually(targetChannel.WriteCallCount).ShouldNot(Equal(0)) + + data := targetChannel.WriteArgsForCall(0) + Expect(data).To(Equal([]byte("abc"))) + + }) + + It("copies the source stderr to the target stderr", func() { + Eventually(targetStderr.WriteCallCount).ShouldNot(Equal(0)) + + data := targetStderr.WriteArgsForCall(0) + Expect(data).To(Equal([]byte("xyz"))) + }) + }) + + Context("when the target channel has data available", func() { + BeforeEach(func() { + targetChannel.ReadStub = func(dest []byte) (int, error) { + if cap(dest) >= 3 { + copy(dest, []byte("xyz")) + return 3, io.EOF + } + return 0, io.EOF + } + targetStderr.ReadStub = func(dest []byte) (int, error) { + if cap(dest) >= 3 { + copy(dest, []byte("abc")) + return 3, io.EOF + } + return 0, io.EOF + } + }) + + It("copies the target channel to the source channel", func() { + Eventually(sourceChannel.WriteCallCount).ShouldNot(Equal(0)) + + data := sourceChannel.WriteArgsForCall(0) + Expect(data).To(Equal([]byte("xyz"))) + + }) + + It("copies the target stderr to the source stderr", func() { + Eventually(sourceStderr.WriteCallCount).ShouldNot(Equal(0)) + + data := sourceStderr.WriteArgsForCall(0) + Expect(data).To(Equal([]byte("abc"))) + }) + }) + + Context("when the source channel closes", func() { + BeforeEach(func() { + sourceChannel.ReadReturns(0, io.EOF) + }) + + It("closes the target channel", func() { + Eventually(sourceChannel.ReadCallCount).Should(Equal(1)) + Eventually(targetChannel.CloseWriteCallCount).Should(Equal(1)) + }) + }) + + Context("when the target channel closes", func() { + BeforeEach(func() { + targetChannel.ReadReturns(0, io.EOF) + }) + + It("closes the source channel", func() { + Eventually(sourceChannel.ReadCallCount).Should(Equal(1)) + Eventually(targetChannel.CloseWriteCallCount).Should(Equal(1)) + }) + }) + + Context("when out of band requests are received on the source channel", func() { + BeforeEach(func() { + request := &ssh.Request{Type: "test", WantReply: false, Payload: []byte("test-data")} + sourceReqChan <- request + }) + + It("forwards the request to the target channel", func() { + Eventually(targetChannel.SendRequestCallCount).Should(Equal(1)) + + reqType, wantReply, payload := targetChannel.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + }) + + Context("when out of band requests are received from the target channel", func() { + BeforeEach(func() { + request := &ssh.Request{Type: "test", WantReply: false, Payload: []byte("test-data")} + targetReqChan <- request + }) + + It("forwards the request to the target channel", func() { + Eventually(sourceChannel.SendRequestCallCount).Should(Equal(1)) + + reqType, wantReply, payload := sourceChannel.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + }) + }) + + Context("when the target rejects the connection", func() { + BeforeEach(func() { + openError := &ssh.OpenChannelError{ + Reason: ssh.Prohibited, + Message: "go away", + } + targetConn.OpenChannelReturns(nil, nil, openError) + }) + + It("rejects the source request with the upstream error", func() { + Eventually(newChan.RejectCallCount).Should(Equal(1)) + + reason, message := newChan.RejectArgsForCall(0) + Expect(reason).To(Equal(ssh.Prohibited)) + Expect(message).To(Equal("go away")) + }) + + It("continues processing new channel requests", func() { + newChanChan <- newChan + Eventually(newChan.RejectCallCount).Should(Equal(2)) + }) + }) + + Context("when openning a channel failsfails", func() { + BeforeEach(func() { + targetConn.OpenChannelReturns(nil, nil, errors.New("woops")) + }) + + It("rejects the source request with a connection failed reason", func() { + Eventually(newChan.RejectCallCount).Should(Equal(1)) + + reason, message := newChan.RejectArgsForCall(0) + Expect(reason).To(Equal(ssh.ConnectionFailed)) + Expect(message).To(Equal("woops")) + }) + + It("continues processing new channel requests", func() { + newChanChan <- newChan + Eventually(newChan.RejectCallCount).Should(Equal(2)) + }) + }) + }) + + Context("when the new channel channel closes", func() { + JustBeforeEach(func() { + Consistently(newChanChan).ShouldNot(BeClosed()) + close(newChanChan) + }) + + It("returns gracefully", func() { + Eventually(done).Should(Receive()) + }) + }) + }) + + Describe("ProxyRequests", func() { + var ( + channel *fake_ssh.FakeChannel + reqChan chan *ssh.Request + + wg *sync.WaitGroup + done chan struct{} + ) + + BeforeEach(func() { + wg = &sync.WaitGroup{} + channel = &fake_ssh.FakeChannel{} + reqChan = make(chan *ssh.Request, 2) + done = make(chan struct{}, 1) + }) + + JustBeforeEach(func() { + go func(done chan<- struct{}) { + proxy.ProxyRequests(logger, "test", reqChan, channel, wg) + done <- struct{}{} + }(done) + }) + + Context("when a request is received", func() { + BeforeEach(func() { + request := &ssh.Request{Type: "test", WantReply: false, Payload: []byte("test-data")} + reqChan <- request + reqChan <- request + }) + + AfterEach(func() { + close(reqChan) + }) + + It("forwards requests from the channel to the connection", func() { + Eventually(channel.SendRequestCallCount).Should(Equal(2)) + Consistently(channel.SendRequestCallCount).Should(Equal(2)) + + reqType, wantReply, payload := channel.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + + reqType, wantReply, payload = channel.SendRequestArgsForCall(1) + Expect(reqType).To(Equal("test")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + }) + + Context("when SendRequest fails", func() { + BeforeEach(func() { + callCount := 0 + channel.SendRequestStub = func(rt string, wr bool, p []byte) (bool, error) { + callCount++ + if callCount == 1 { + return false, errors.New("woops") + } + return true, nil + } + + reqChan <- &ssh.Request{} + reqChan <- &ssh.Request{} + }) + + AfterEach(func() { + close(reqChan) + }) + + It("continues processing requests", func() { + Eventually(channel.SendRequestCallCount).Should(Equal(2)) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say(`send-request-failed.*woops`)) + }) + }) + + Context("when the request channel closes", func() { + JustBeforeEach(func() { + Consistently(reqChan).ShouldNot(BeClosed()) + close(reqChan) + }) + + It("returns gracefully", func() { + Eventually(done).Should(Receive()) + }) + }) + + Context("when an exit-status request is received", func() { + BeforeEach(func() { + request := &ssh.Request{Type: "exit-status", WantReply: false, Payload: []byte("test-data")} + reqChan <- request + reqChan <- request + }) + + AfterEach(func() { + close(reqChan) + }) + + It("does not handle extra requests", func() { + Eventually(channel.SendRequestCallCount).Should(Equal(1)) + Consistently(channel.SendRequestCallCount).Should(Equal(1)) + + Eventually(channel.CloseCallCount).Should(Equal(1)) + reqType, wantReply, payload := channel.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("exit-status")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + + Context("when there is a wait group", func() { + BeforeEach(func() { + wg.Add(1) + }) + + It("exits when the waitgroup is done", func() { + Eventually(channel.SendRequestCallCount).Should(Equal(1)) + Consistently(channel.SendRequestCallCount).Should(Equal(1)) + + Consistently(channel.CloseCallCount).Should(Equal(0)) + wg.Done() + Eventually(channel.CloseCallCount).Should(Equal(1)) + + reqType, wantReply, payload := channel.SendRequestArgsForCall(0) + Expect(reqType).To(Equal("exit-status")) + Expect(wantReply).To(BeFalse()) + Expect(payload).To(Equal([]byte("test-data"))) + }) + }) + }) + }) + + Describe("NewClientConn", func() { + var ( + permissions *ssh.Permissions + + daemonSSHConfig *ssh.ServerConfig + sshDaemon *daemon.Daemon + sshdListener net.Listener + sshdServer *server.Server + + newClientConnErr error + tlsCfg *tls.Config + ) + + BeforeEach(func() { + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{}, + } + + daemonSSHConfig = &ssh.ServerConfig{} + daemonSSHConfig.AddHostKey(TestHostKey) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + + sshdListener = listener + }) + + JustBeforeEach(func() { + sshDaemon = daemon.New(logger.Session("sshd"), daemonSSHConfig, nil, nil) + sshdServer = server.NewServer(logger, "127.0.0.1:0", sshDaemon, 500*time.Millisecond) + sshdServer.SetListener(sshdListener) + go sshdServer.Serve() + + _, _, _, newClientConnErr = proxy.NewClientConn(logger, permissions, tlsCfg) + }) + + AfterEach(func() { + sshdServer.Shutdown() + }) + + Context("when permissions is nil", func() { + BeforeEach(func() { + permissions = nil + }) + + It("returns an error", func() { + Expect(newClientConnErr).To(HaveOccurred()) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say("permissions-and-critical-options-required")) + }) + }) + + Context("when permissions.CriticalOptions is nil", func() { + BeforeEach(func() { + permissions.CriticalOptions = nil + }) + + It("returns an error", func() { + Expect(newClientConnErr).To(HaveOccurred()) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say("permissions-and-critical-options-required")) + }) + }) + + Context("when the config is missing", func() { + BeforeEach(func() { + delete(permissions.CriticalOptions, "proxy-target-config") + }) + + It("returns an error", func() { + Expect(newClientConnErr).To(HaveOccurred()) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say("unmarshal-failed")) + }) + }) + + Context("when the config fails to unmarshal", func() { + BeforeEach(func() { + permissions.CriticalOptions["proxy-target-config"] = "{ this_is: invalid json" + }) + + It("returns an error", func() { + Expect(newClientConnErr).To(HaveOccurred()) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say("unmarshal-failed")) + }) + }) + + Context("when the address in the config is bad", func() { + BeforeEach(func() { + permissions.CriticalOptions["proxy-target-config"] = `{ "address": "0.0.0.0:0" }` + }) + + It("returns an error", func() { + Expect(newClientConnErr).To(HaveOccurred()) + }) + + It("logs the failure", func() { + Eventually(logger).Should(gbytes.Say("dial-failed")) + }) + }) + + Context("when tls config is passed in", func() { + var certDepoDir string + + BeforeEach(func() { + var err error + certDepoDir, err = os.MkdirTemp("", "") + Expect(err).NotTo(HaveOccurred()) + + ca, err := certauthority.NewCertAuthority(certDepoDir, "server_ca") + Expect(err).NotTo(HaveOccurred()) + + serverKeyFile, serverCertFile, err := ca.GenerateSelfSignedCertAndKey("server", []string{"some-instance-guid"}, false) + Expect(err).NotTo(HaveOccurred()) + _, serverCAFile := ca.CAAndKey() + + tlsCfg, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentityFromFile(serverCertFile, serverKeyFile), + ).Client(tlsconfig.WithAuthorityFromFile(serverCAFile)) // used for a client connection to the tls proxy address + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + Expect(os.RemoveAll(certDepoDir)).To(Succeed()) + }) + + Context("and the tls address is not available", func() { + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + TLSAddress: "", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + }) + + It("does not log any errors", func() { + Consistently(logger).ShouldNot(gbytes.Say("tls-dial-failed")) + }) + }) + + Context("and the tls address is available", func() { + var ( + onConnectionReceived chan struct{} + ) + + BeforeEach(func() { + onConnectionReceived = make(chan struct{}, 10) + intermediaryListener, err := tls.Listen("tcp", "127.0.0.1:0", tlsCfg) + Expect(err).ToNot(HaveOccurred()) + go forwardTLSConn(sshdListener.Addr().String(), intermediaryListener, onConnectionReceived) + + targetConfigJSON, err := json.Marshal(proxy.TargetConfig{ + Address: "", + TLSAddress: intermediaryListener.Addr().String(), + ServerCertDomainSAN: "some-instance-guid", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJSON), + }, + } + }) + + It("connects successfully", func() { + Eventually(onConnectionReceived).Should(Receive()) + Eventually(logger).Should(gbytes.Say("connected-to-backend")) + }) + }) + }) + + Context("when the config contains a user and password", func() { + var passwordAuthenticator *fake_authenticators.FakePasswordAuthenticator + + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + User: "my-user", + Password: "my-password", + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + + passwordAuthenticator = &fake_authenticators.FakePasswordAuthenticator{} + daemonSSHConfig.PasswordCallback = passwordAuthenticator.Authenticate + }) + + It("uses the user and password for authentication", func() { + Expect(passwordAuthenticator.AuthenticateCallCount()).To(Equal(1)) + + metadata, password := passwordAuthenticator.AuthenticateArgsForCall(0) + Expect(metadata.User()).To(Equal("my-user")) + Expect(string(password)).To(Equal("my-password")) + }) + }) + + Context("when the config contains a public key", func() { + var publicKeyAuthenticator *fake_authenticators.FakePublicKeyAuthenticator + + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + PrivateKey: TestPrivatePem, + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + + publicKeyAuthenticator = &fake_authenticators.FakePublicKeyAuthenticator{} + publicKeyAuthenticator.AuthenticateReturns(&ssh.Permissions{}, nil) + daemonSSHConfig.PublicKeyCallback = publicKeyAuthenticator.Authenticate + }) + + It("will use the public key for authentication", func() { + expectedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(TestPublicAuthorizedKey)) + Expect(err).NotTo(HaveOccurred()) + + Expect(publicKeyAuthenticator.AuthenticateCallCount()).To(Equal(1)) + + _, actualKey := publicKeyAuthenticator.AuthenticateArgsForCall(0) + Expect(actualKey.Marshal()).To(Equal(expectedKey.Marshal())) + }) + }) + + Context("when the config contains a user and a public key", func() { + var publicKeyAuthenticator *fake_authenticators.FakePublicKeyAuthenticator + + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + User: "my-user", + PrivateKey: TestPrivatePem, + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + + publicKeyAuthenticator = &fake_authenticators.FakePublicKeyAuthenticator{} + publicKeyAuthenticator.AuthenticateReturns(&ssh.Permissions{}, nil) + daemonSSHConfig.PublicKeyCallback = publicKeyAuthenticator.Authenticate + }) + + It("will use the user and public key for authentication", func() { + expectedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(TestPublicAuthorizedKey)) + Expect(err).NotTo(HaveOccurred()) + + Expect(publicKeyAuthenticator.AuthenticateCallCount()).To(Equal(1)) + + metadata, actualKey := publicKeyAuthenticator.AuthenticateArgsForCall(0) + Expect(metadata.User()).To(Equal("my-user")) + Expect(actualKey.Marshal()).To(Equal(expectedKey.Marshal())) + }) + }) + + Context("when the config contains a user, password, a public key", func() { + var publicKeyAuthenticator *fake_authenticators.FakePublicKeyAuthenticator + var passwordAuthenticator *fake_authenticators.FakePasswordAuthenticator + + BeforeEach(func() { + targetConfigJson, err := json.Marshal(proxy.TargetConfig{ + Address: sshdListener.Addr().String(), + User: "my-user", + Password: "my-password", + PrivateKey: TestPrivatePem, + }) + Expect(err).NotTo(HaveOccurred()) + + permissions = &ssh.Permissions{ + CriticalOptions: map[string]string{ + "proxy-target-config": string(targetConfigJson), + }, + } + + passwordAuthenticator = &fake_authenticators.FakePasswordAuthenticator{} + daemonSSHConfig.PasswordCallback = passwordAuthenticator.Authenticate + + publicKeyAuthenticator = &fake_authenticators.FakePublicKeyAuthenticator{} + publicKeyAuthenticator.AuthenticateReturns(&ssh.Permissions{}, nil) + daemonSSHConfig.PublicKeyCallback = publicKeyAuthenticator.Authenticate + }) + + It("will attempt to use the public key for authentication before the password", func() { + Expect(publicKeyAuthenticator.AuthenticateCallCount()).To(Equal(1)) + Expect(passwordAuthenticator.AuthenticateCallCount()).To(Equal(0)) + }) + + Context("when public key authentication fails", func() { + BeforeEach(func() { + passwordAuthenticator.AuthenticateReturns(&ssh.Permissions{}, nil) + publicKeyAuthenticator.AuthenticateReturns(nil, errors.New("go away")) + }) + + It("will fall back to password authentication", func() { + Expect(publicKeyAuthenticator.AuthenticateCallCount()).To(Equal(1)) + Expect(passwordAuthenticator.AuthenticateCallCount()).To(Equal(1)) + }) + }) + }) + }) + + Describe("Wait", func() { + var ( + waitChans []chan struct{} + waiters []proxy.Waiter + + done chan struct{} + ) + + BeforeEach(func() { + for i := 0; i < 3; i++ { + idx := i + waitChans = append(waitChans, make(chan struct{})) + + conn := &fake_ssh.FakeConn{} + conn.WaitStub = func() error { + <-waitChans[idx] + return nil + } + waiters = append(waiters, conn) + } + + done = make(chan struct{}, 1) + }) + + JustBeforeEach(func() { + go func(done chan<- struct{}) { + proxy.Wait(logger, waiters...) + done <- struct{}{} + }(done) + }) + + It("waits for all Waiters to finish", func() { + Consistently(done).ShouldNot(Receive()) + close(waitChans[0]) + + Consistently(done).ShouldNot(Receive()) + close(waitChans[1]) + + Consistently(done).ShouldNot(Receive()) + close(waitChans[2]) + + Eventually(done).Should(Receive()) + }) + }) +}) + +func forwardTLSConn(serverAddress string, proxy net.Listener, onConnectionReceived chan struct{}) { + for { + conn, err := proxy.Accept() + if err != nil { + return + } + + tlsConn := conn.(*tls.Conn) + err = tlsConn.Handshake() + if err != nil { + return + } + + defer tlsConn.Close() + + if onConnectionReceived != nil { + onConnectionReceived <- struct{}{} + } + + proxyConn, err := net.Dial("tcp", serverAddress) + if err != nil { + return + } + + wg := sync.WaitGroup{} + wg.Add(2) + + go func() { + _, _ = io.Copy(conn, proxyConn) + tlsConn.CloseWrite() + wg.Done() + }() + + go func() { + _, _ = io.Copy(proxyConn, conn) + wg.Done() + }() + + wg.Wait() + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh.go b/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh.go new file mode 100644 index 0000000000..3f33e52950 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh.go @@ -0,0 +1,11 @@ +package routes + +const DIEGO_SSH = "diego-ssh" + +type SSHRoute struct { + ContainerPort uint32 `json:"container_port"` + HostFingerprint string `json:"host_fingerprint,omitempty"` + User string `json:"user,omitempty"` + Password string `json:"password,omitempty"` + PrivateKey string `json:"private_key,omitempty"` +} diff --git a/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh_test.go b/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh_test.go new file mode 100644 index 0000000000..8f466f2d91 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/routes/diego_ssh_test.go @@ -0,0 +1,83 @@ +package routes_test + +import ( + "encoding/json" + + "code.cloudfoundry.org/diego-ssh/routes" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Diego SSH Route", func() { + var route routes.SSHRoute + + BeforeEach(func() { + route = routes.SSHRoute{ + ContainerPort: 2222, + HostFingerprint: "my-key-fingerprint", + User: "user", + Password: "password", + PrivateKey: "FAKE_PEM_ENCODED_KEY", + } + }) + + Describe("JSON Marshalling", func() { + Context("when the user and password are missing", func() { + var expectedJson string + + BeforeEach(func() { + route.User = "" + route.Password = "" + + expectedJson = `{ + "container_port": 2222, + "host_fingerprint": "my-key-fingerprint", + "private_key": "FAKE_PEM_ENCODED_KEY" + }` + }) + + It("marshals the structure correctly", func() { + payload, err := json.Marshal(route) + Expect(err).NotTo(HaveOccurred()) + + Expect(payload).To(MatchJSON(expectedJson)) + }) + }) + + Context("when the private key and host fingerprint are empty", func() { + var expectedJson string + + BeforeEach(func() { + route.PrivateKey = "" + route.HostFingerprint = "" + + expectedJson = `{ + "container_port": 2222, + "user": "user", + "password": "password" + }` + }) + + It("marshals the structure correctly", func() { + payload, err := json.Marshal(route) + Expect(err).NotTo(HaveOccurred()) + + Expect(payload).To(MatchJSON(expectedJson)) + }) + }) + }) + + Describe("Round Trip Marshalling", func() { + It("successfully marshals and unmarshals", func() { + payload, err := json.Marshal(route) + Expect(err).NotTo(HaveOccurred()) + + var result routes.SSHRoute + err = json.Unmarshal(payload, &result) + Expect(err).NotTo(HaveOccurred()) + + Expect(result).To(Equal(route)) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/routes/models_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/routes/models_suite_test.go new file mode 100644 index 0000000000..c44ccee176 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/routes/models_suite_test.go @@ -0,0 +1,13 @@ +package routes_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestRoutes(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Routes Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/routes/package.go b/src/code.cloudfoundry.org/diego-ssh/routes/package.go new file mode 100644 index 0000000000..31a307de02 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/routes/package.go @@ -0,0 +1 @@ +package routes // import "code.cloudfoundry.org/diego-ssh/routes" diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time.go b/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time.go new file mode 100644 index 0000000000..afb329e399 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time.go @@ -0,0 +1,19 @@ +//go:build !windows + +package atime + +import ( + "errors" + "os" + "time" +) + +func AccessTime(fileInfo os.FileInfo) (time.Time, error) { + if fileInfo == nil || fileInfo.Sys() == nil { + return time.Time{}, errors.New("underlying file information unavailable") + } + + timespec := accessTimespec(fileInfo) + + return time.Unix(int64(timespec.Sec), int64(timespec.Nsec)), nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time_windows.go b/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time_windows.go new file mode 100644 index 0000000000..bba6f315dc --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/atime/access_time_windows.go @@ -0,0 +1,20 @@ +//go:build windows + +package atime + +import ( + "errors" + "os" + "syscall" + "time" +) + +func AccessTime(fileInfo os.FileInfo) (time.Time, error) { + if fileInfo == nil || fileInfo.Sys() == nil { + return time.Time{}, errors.New("underlying file information unavailable") + } + + accessTime := fileInfo.Sys().(*syscall.Win32FileAttributeData).LastAccessTime.Nanoseconds() + + return time.Unix(int64(accessTime/1e9), int64(accessTime%1e9)), nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_darwin.go b/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_darwin.go new file mode 100644 index 0000000000..0d46b5bf26 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_darwin.go @@ -0,0 +1,12 @@ +//go:build darwin + +package atime + +import ( + "os" + "syscall" +) + +func accessTimespec(fileInfo os.FileInfo) syscall.Timespec { + return fileInfo.Sys().(*syscall.Stat_t).Atimespec +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_linux.go b/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_linux.go new file mode 100644 index 0000000000..68310cca4b --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/atime/atime_linux.go @@ -0,0 +1,12 @@ +//go:build linux + +package atime + +import ( + "os" + "syscall" +) + +func accessTimespec(fileInfo os.FileInfo) syscall.Timespec { + return fileInfo.Sys().(*syscall.Stat_t).Atim +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/atime/package.go b/src/code.cloudfoundry.org/diego-ssh/scp/atime/package.go new file mode 100644 index 0000000000..f60afa4468 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/atime/package.go @@ -0,0 +1 @@ +package atime // import "code.cloudfoundry.org/diego-ssh/scp/atime" diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/directory.go b/src/code.cloudfoundry.org/diego-ssh/scp/directory.go new file mode 100644 index 0000000000..99bdc79371 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/directory.go @@ -0,0 +1,182 @@ +package scp + +import ( + "fmt" + "os" + "path/filepath" + "strconv" +) + +func (s *secureCopy) SendDirectory(dir string, dirInfo os.FileInfo) error { + return s.sendDirectory(dir, dirInfo) +} + +func (s *secureCopy) ReceiveDirectory(dir string, timeMessage *TimeMessage) error { + messageType, err := s.session.readByte() + if err != nil { + return err + } + + if messageType != byte('D') { + return fmt.Errorf("unexpected message type: %c", messageType) + } + + dirModeString, err := s.session.readString(SPACE) + if err != nil { + return err + } + + dirMode, err := strconv.ParseUint(dirModeString, 8, 32) + if err != nil { + return err + } + + // Length field is ignored + _, err = s.session.readString(SPACE) + if err != nil { + return err + } + + dirName, err := s.session.readString(NEWLINE) + if err != nil { + return err + } + + err = s.session.sendConfirmation() + if err != nil { + return err + } + + targetPath := filepath.Join(dir, dirName) + _, err = os.Stat(dir) + if os.IsNotExist(err) { + targetPath = dir + } + + targetInfo, err := os.Stat(targetPath) + if err != nil { + if !os.IsNotExist(err) { + return err + } + + err = os.Mkdir(targetPath, os.FileMode(dirMode)) + if err != nil { + return err + } + } else if !targetInfo.Mode().IsDir() { + return fmt.Errorf("target exists and is not a directory: %q", dirName) + } + + err = s.processDirectoryMessages(targetPath) + if err != nil { + return err + } + + message, err := s.session.readString(NEWLINE) + if err != nil { + return err + } + if message != "E" { + return fmt.Errorf("unexpected message type: %c", messageType) + } + + if timeMessage != nil && s.session.preserveTimesAndMode { + err := os.Chtimes(targetPath, timeMessage.accessTime, timeMessage.modificationTime) + if err != nil { + return err + } + } + + err = s.session.sendConfirmation() + if err != nil { + return err + } + + return nil +} + +func (s *secureCopy) processDirectoryMessages(dirPath string) error { + for { + messageType, err := s.session.peekByte() + if err != nil { + return err + } + + var timeMessage *TimeMessage + if messageType == 'T' && s.session.preserveTimesAndMode { + timeMessage = &TimeMessage{} + err := timeMessage.Receive(s.session) + if err != nil { + return err + } + + messageType, err = s.session.peekByte() + if err != nil { + return err + } + } + + switch messageType { + case 'D': + err := s.ReceiveDirectory(dirPath, timeMessage) + if err != nil { + return err + } + case 'C': + err := s.ReceiveFile(dirPath, true, timeMessage) + if err != nil { + return err + } + case 'E': + if timeMessage != nil { + return fmt.Errorf("unexpected message type: %c", messageType) + } + return nil + default: + return fmt.Errorf("unexpected message type: %c", messageType) + } + } +} + +func (s *secureCopy) sendDirectory(dirname string, directoryInfo os.FileInfo) error { + if s.session.preserveTimesAndMode { + timeMessage := NewTimeMessage(directoryInfo) + err := timeMessage.Send(s.session) + if err != nil { + return err + } + } + + _, err := fmt.Fprintf(s.session.stdout, "D%.4o 0 %s\n", directoryInfo.Mode()&07777, directoryInfo.Name()) + if err != nil { + return err + } + + err = s.session.awaitConfirmation() + if err != nil { + return err + } + + fileInfos, err := os.ReadDir(dirname) + if err != nil { + return err + } + + for _, fileInfo := range fileInfos { + source := filepath.Join(dirname, fileInfo.Name()) + // #nosec G104 - we intentionally ignore this error + s.send(source, s.session.logger.Session("send-directory")) + } + + _, err = fmt.Fprintf(s.session.stdout, "E\n") + if err != nil { + return err + } + + err = s.session.awaitConfirmation() + if err != nil { + return err + } + + return nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/directory_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/directory_test.go new file mode 100644 index 0000000000..67039403a4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/directory_test.go @@ -0,0 +1,585 @@ +package scp_test + +import ( + "bytes" + "io" + "os" + "path/filepath" + "time" + + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/scp/atime" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Directory Message", func() { + var ( + tempDir string + logger *lagertest.TestLogger + err error + + copier TestCopier + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + tempDir, err = os.MkdirTemp("", "scp") + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + newTestCopier := func(stdin io.Reader, stdout io.Writer, stderr io.Writer, preserveTimeAndMode bool) TestCopier { + options := &scp.Options{ + PreserveTimesAndMode: preserveTimeAndMode, + } + secureCopier, ok := scp.New(options, stdin, stdout, stderr, logger).(TestCopier) + Expect(ok).To(BeTrue()) + return secureCopier + } + + Context("when sending an empty directory to an scp sink", func() { + var ( + emptySubdir string + emptyDirInfo os.FileInfo + ) + + BeforeEach(func() { + emptySubdir = filepath.Join(tempDir, "empty-dir") + err := os.Mkdir(emptySubdir, os.FileMode(0775)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(emptySubdir, 0775) + Expect(err).NotTo(HaveOccurred()) + + modificationTime := time.Unix(123456789, 12345678) + accessTime := time.Unix(987654321, 987654321) + err = os.Chtimes(emptySubdir, accessTime, modificationTime) + Expect(err).NotTo(HaveOccurred()) + + emptyDirInfo, err = os.Stat(emptySubdir) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sends the directory start and end messages", func() { + stdin := bytes.NewReader([]byte{0, 0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.SendDirectory(emptySubdir, emptyDirInfo) + Expect(err).NotTo(HaveOccurred()) + + dirMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(dirMessage).To(Equal("D0775 0 empty-dir\n")) + + endMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(endMessage).To(Equal("E\n")) + }) + + It("waits for confirmation of each message", func() { + stdin := &fake_io.FakeReader{} + stdout := &fake_io.FakeWriter{} + stdoutBuffer := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdout.WriteStub = stdoutBuffer.Write + stdin.ReadStub = func(buffer []byte) (int, error) { + if stdin.ReadCallCount() == 1 { + dMessage, err := stdoutBuffer.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(dMessage).To(Equal("D0775 0 empty-dir\n")) + Expect(stdoutBuffer.Len()).To(Equal(0)) + } else { + eMessage, err := stdoutBuffer.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(eMessage).To(Equal("E\n")) + Expect(stdoutBuffer.Len()).To(Equal(0)) + } + + buffer[0] = 0 + return 1, nil + } + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.SendDirectory(emptySubdir, emptyDirInfo) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not return before the end message is confirmed", func() { + stdin, pw := io.Pipe() + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + errCh := make(chan error, 1) + go func() { + copier = newTestCopier(stdin, stdout, stderr, false) + errCh <- copier.SendDirectory(emptySubdir, emptyDirInfo) + }() + + Consistently(errCh).ShouldNot(Receive()) + + pw.Write([]byte{0}) + Consistently(errCh).ShouldNot(Receive()) + + pw.Write([]byte{0}) + Eventually(errCh).Should(Receive(BeNil())) + }) + + Context("when the directory cannot be opened", func() { + BeforeEach(func() { + emptySubdir = filepath.Join(emptySubdir, "non-existent-dir") + }) + + It("returns an error", func() { + stdin := bytes.NewReader([]byte{0, 0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + copier = newTestCopier(stdin, stdout, stderr, false) + + err := copier.SendDirectory(emptySubdir, emptyDirInfo) + Expect(err).To(MatchError(MatchRegexp("no such file or directory"))) + }) + }) + + Context("when preserving time stamps", func() { + It("sends the time information before the file message", func() { + stdin := bytes.NewReader([]byte{0, 0, 0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + copier = newTestCopier(stdin, stdout, stderr, true) + err := copier.SendDirectory(emptySubdir, emptyDirInfo) + Expect(err).NotTo(HaveOccurred()) + + tMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(tMessage).To(Equal("T123456789 0 987654321 0\n")) + + dMessage, err := stdout.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(dMessage).To(Equal("D0775 0 empty-dir\n")) + + eMessage, err := stdout.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(eMessage).To(Equal("E\n")) + }) + }) + }) + + Context("when sending an directory that contains files and directories", func() { + var ( + dirInfo os.FileInfo + subdir string + subdirFile string + tempFile string + ) + + BeforeEach(func() { + tempFile = filepath.Join(tempDir, "tempfile.txt") + err := os.WriteFile(tempFile, []byte("temporary-file-contents\n"), os.FileMode(0644)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(tempFile, 0644) + Expect(err).NotTo(HaveOccurred()) + + subdir = filepath.Join(tempDir, "subdir") + err = os.Mkdir(subdir, os.FileMode(0700)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(subdir, 0700) + Expect(err).NotTo(HaveOccurred()) + + subdirFile = filepath.Join(subdir, "subdir-file.txt") + err = os.WriteFile(subdirFile, []byte("subdir-file-contents\n"), os.FileMode(0644)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(subdirFile, 0644) + Expect(err).NotTo(HaveOccurred()) + + emptySubdir := filepath.Join(tempDir, "empty-dir") + err = os.Mkdir(emptySubdir, os.FileMode(0775)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(emptySubdir, 0775) + Expect(err).NotTo(HaveOccurred()) + + modificationTime := time.Unix(123456789, 12345678) + accessTime := time.Unix(987654321, 987654321) + err = os.Chtimes(emptySubdir, accessTime, modificationTime) + Expect(err).NotTo(HaveOccurred()) + + dirInfo, err = os.Stat(tempDir) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sends the correct messages", func() { + stdin := bytes.NewReader(bytes.Repeat([]byte{0}, 10)) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.SendDirectory(tempDir, dirInfo) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.ReadString('\n')).To(Equal("D0700 0 " + filepath.Base(tempDir) + "\n")) + Expect(stdout.ReadString('\n')).To(Equal("D0775 0 empty-dir\n")) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + Expect(stdout.ReadString('\n')).To(Equal("D0700 0 subdir\n")) + Expect(stdout.ReadString('\n')).To(Equal("C0644 21 subdir-file.txt\n")) + Expect(stdout.ReadString('\n')).To(Equal("subdir-file-contents\n")) + Expect(stdout.ReadByte()).To(BeEquivalentTo(0)) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + Expect(stdout.ReadString('\n')).To(Equal("C0644 24 tempfile.txt\n")) + Expect(stdout.ReadString('\n')).To(Equal("temporary-file-contents\n")) + Expect(stdout.ReadByte()).To(BeEquivalentTo(0)) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + }) + + Context("when sending a file fails", func() { + BeforeEach(func() { + subdirFile2 := filepath.Join(subdir, "does-not-exist.link") + os.Symlink(filepath.Join(subdir, "does-not-exist"), subdirFile2) + Expect(err).NotTo(HaveOccurred()) + }) + + It("continues to send the other files", func() { + stdin := bytes.NewReader(bytes.Repeat([]byte{0}, 10)) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + copier = newTestCopier(stdin, stdout, stderr, false) + err = copier.SendDirectory(tempDir, dirInfo) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.ReadString('\n')).To(Equal("D0700 0 " + filepath.Base(tempDir) + "\n")) + Expect(stdout.ReadString('\n')).To(Equal("D0775 0 empty-dir\n")) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + Expect(stdout.ReadString('\n')).To(Equal("D0700 0 subdir\n")) + Expect(stdout.ReadByte()).To(BeEquivalentTo(1)) + Expect(stdout.ReadString('\n')).To(ContainSubstring("no such file or directory")) + Expect(stdout.ReadString('\n')).To(Equal("C0644 21 subdir-file.txt\n")) + Expect(stdout.ReadString('\n')).To(Equal("subdir-file-contents\n")) + Expect(stdout.ReadByte()).To(BeEquivalentTo(0)) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + Expect(stdout.ReadString('\n')).To(Equal("C0644 24 tempfile.txt\n")) + Expect(stdout.ReadString('\n')).To(Equal("temporary-file-contents\n")) + Expect(stdout.ReadByte()).To(BeEquivalentTo(0)) + Expect(stdout.ReadString('\n')).To(Equal("E\n")) + }) + }) + }) + + Context("when receiving a directory from an scp source", func() { + It("populates the directory with the received contents", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0700 0 received-dir\n") + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + stdin.WriteString("D0700 0 subdir\n") + stdin.WriteString("C0644 21 subdir-file.txt\n") + stdin.WriteString("subdir-file-contents\n") + stdin.WriteByte(0) + stdin.WriteString("E\n") + stdin.WriteString("C0600 24 tempfile.txt\n") + stdin.WriteString("temporary-file-contents\n") + stdin.WriteByte(0) + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).NotTo(HaveOccurred()) + + Expect(filepath.Join(tempDir, "received-dir")).To(BeADirectory()) + Expect(filepath.Join(tempDir, "received-dir", "empty-dir")).To(BeADirectory()) + Expect(filepath.Join(tempDir, "received-dir", "subdir")).To(BeADirectory()) + Expect(filepath.Join(tempDir, "received-dir", "subdir", "subdir-file.txt")).To(BeARegularFile()) + Expect(filepath.Join(tempDir, "received-dir", "tempfile.txt")).To(BeARegularFile()) + + info, err := os.Stat(filepath.Join(tempDir, "received-dir")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0700))) + + info, err = os.Stat(filepath.Join(tempDir, "received-dir", "empty-dir")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0755))) + + info, err = os.Stat(filepath.Join(tempDir, "received-dir", "subdir")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0700))) + + info, err = os.Stat(filepath.Join(tempDir, "received-dir", "subdir", "subdir-file.txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0644))) + + info, err = os.Stat(filepath.Join(tempDir, "received-dir", "tempfile.txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0600))) + + contents, err := os.ReadFile(filepath.Join(tempDir, "received-dir", "subdir", "subdir-file.txt")) + Expect(err).ToNot(HaveOccurred()) + Expect(contents).To(BeEquivalentTo("subdir-file-contents\n")) + + contents, err = os.ReadFile(filepath.Join(tempDir, "received-dir", "tempfile.txt")) + Expect(err).ToNot(HaveOccurred()) + Expect(contents).To(BeEquivalentTo("temporary-file-contents\n")) + }) + + Context("when preserving time stamps", func() { + It("restores the access time and modification time", func() { + timeStdin := &bytes.Buffer{} + timeStdout := &bytes.Buffer{} + timeStderr := &bytes.Buffer{} + + timeStdin.WriteString("T123456789 0 987654321 0\n") + timeSession := scp.NewSession(timeStdin, timeStdout, timeStderr, true, logger) + + timeMessage := &scp.TimeMessage{} + timeMessage.Receive(timeSession) + + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, true) + err := copier.ReceiveDirectory(tempDir, timeMessage) + Expect(err).NotTo(HaveOccurred()) + + Expect(filepath.Join(tempDir, "empty-dir")).To(BeADirectory()) + + info, err := os.Stat(filepath.Join(tempDir, "empty-dir")) + Expect(err).NotTo(HaveOccurred()) + + accessTime, err := atime.AccessTime(info) + Expect(err).NotTo(HaveOccurred()) + + Expect(info.ModTime()).To(Equal(time.Unix(123456789, 0))) + Expect(accessTime).To(Equal(time.Unix(987654321, 0))) + }) + }) + + Context("when the message is not a directory message", func() { + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).To(MatchError("unexpected message type: C")) + }) + }) + + Context("when the directory mode is not octal", func() { + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0999 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).To(MatchError(`strconv.ParseUint: parsing "0999": invalid syntax`)) + }) + }) + + Context("when the ignored length field is not sent", func() { + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when the directory end message is not sent", func() { + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 empty-dir\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).To(Equal(io.EOF)) + }) + }) + + Context("when creating the target directory fails", func() { + var targetDir string + + BeforeEach(func() { + targetDir = filepath.Join(tempDir, "non-existent-dir", "target") + }) + + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(targetDir, nil) + Expect(err).To(MatchError(MatchRegexp("no such file or directory"))) + }) + }) + + Context("when the target directory does not exist", func() { + Context("but the target enclosing directory does", func() { + var targetDir string + + BeforeEach(func() { + targetDir = filepath.Join(tempDir, "target") + err := os.Mkdir(targetDir, os.FileMode(0777)) + Expect(err).NotTo(HaveOccurred()) + }) + + It("makes the new target directory and populates it with the sources contents", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0700 0 subdir\n") + stdin.WriteString("C0644 21 subdir-file.txt\n") + stdin.WriteString("subdir-file-contents\n") + stdin.WriteByte(0) + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(filepath.Join(targetDir, "newdir"), nil) + Expect(err).NotTo(HaveOccurred()) + + info, err := os.Stat(filepath.Join(tempDir, "target", "newdir")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0700))) + + info, err = os.Stat(filepath.Join(tempDir, "target", "newdir", "subdir-file.txt")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0644))) + + contents, err := os.ReadFile(filepath.Join(tempDir, "target", "newdir", "subdir-file.txt")) + Expect(err).ToNot(HaveOccurred()) + Expect(contents).To(BeEquivalentTo("subdir-file-contents\n")) + }) + }) + + Context("and the enclosing target directory does not exist", func() { + var targetDir string + + BeforeEach(func() { + targetDir = filepath.Join(tempDir, "target") + err := os.Mkdir(targetDir, os.FileMode(0777)) + Expect(err).NotTo(HaveOccurred()) + }) + + It("fails", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0700 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(filepath.Join(targetDir, "newdir", "newer-dir"), nil) + Expect(err).To(HaveOccurred()) + }) + }) + }) + + Context("when the target directory already exists", func() { + BeforeEach(func() { + dir := filepath.Join(tempDir, "empty-dir") + err := os.Mkdir(dir, os.FileMode(0775)) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(dir, 0775) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not raise an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not change the permissions", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).NotTo(HaveOccurred()) + + info, err := os.Stat(filepath.Join(tempDir, "empty-dir")) + Expect(err).NotTo(HaveOccurred()) + Expect(info.Mode() & 0777).To(Equal(os.FileMode(0775))) + }) + }) + + Context("when the target directory is really a file", func() { + BeforeEach(func() { + target := filepath.Join(tempDir, "empty-dir") + err := os.WriteFile(target, []byte("ego existo!"), 0660) + Expect(err).NotTo(HaveOccurred()) + + err = os.Chmod(target, 0660) + Expect(err).NotTo(HaveOccurred()) + }) + + It("raises an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("D0755 0 empty-dir\n") + stdin.WriteString("E\n") + + copier = newTestCopier(stdin, stdout, stderr, false) + err := copier.ReceiveDirectory(tempDir, nil) + Expect(err).To(HaveOccurred()) + + Expect(filepath.Join(tempDir, "empty-dir")).To(BeARegularFile()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/file.go b/src/code.cloudfoundry.org/diego-ssh/scp/file.go new file mode 100644 index 0000000000..27bfe8dcba --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/file.go @@ -0,0 +1,145 @@ +package scp + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strconv" + + "code.cloudfoundry.org/lager/v3" +) + +func (s *secureCopy) SendFile(file *os.File, fileInfo os.FileInfo) error { + logger := s.session.logger.Session("send-file") + if fileInfo.IsDir() { + return errors.New("cannot send a directory") + } + + if s.session.preserveTimesAndMode { + timeMessage := NewTimeMessage(fileInfo) + err := timeMessage.Send(s.session) + if err != nil { + return err + } + } + + _, err := fmt.Fprintf(s.session.stdout, "C%.4o %d %s\n", fileInfo.Mode()&07777, fileInfo.Size(), fileInfo.Name()) + if err != nil { + return err + } + + err = s.session.awaitConfirmation() + if err != nil { + return err + } + + bytesSent, err := io.CopyN(s.session.stdout, file, fileInfo.Size()) + if err != nil { + return err + } + + err = s.session.sendConfirmation() + if err != nil { + return err + } + + logger.Info("awaiting-contents-confirmation", lager.Data{"File Size": fileInfo.Size(), "Bytes Sent": bytesSent}) + err = s.session.awaitConfirmation() + if err != nil { + logger.Error("failed-contents-confirmation", err) + return err + } + logger.Info("recieved-contents-confirmation") + + return nil +} + +func (s *secureCopy) ReceiveFile(path string, pathIsDir bool, timeMessage *TimeMessage) error { + messageType, err := s.session.readByte() + if err != nil { + return err + } + + if messageType != byte('C') { + return fmt.Errorf("unexpected message type: %c", messageType) + } + + fileModeString, err := s.session.readString(SPACE) + if err != nil { + return err + } + + fileMode, err := strconv.ParseUint(fileModeString, 8, 32) + if err != nil { + return err + } + + lengthString, err := s.session.readString(SPACE) + if err != nil { + return err + } + + length, err := strconv.ParseInt(lengthString, 10, 64) + if err != nil { + return err + } + + fileName, err := s.session.readString(NEWLINE) + if err != nil { + return err + } + + err = s.session.sendConfirmation() + if err != nil { + return err + } + + targetPath := path + if pathIsDir { + targetPath = filepath.Join(path, fileName) + } + + targetFile, err := os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(fileMode)) + if err != nil { + return err + } + + _, err = io.CopyN(targetFile, s.session.stdin, length) + if err != nil { + return err + } + + err = targetFile.Close() + if err != nil { + return err + } + + if s.session.preserveTimesAndMode { + err := os.Chmod(targetPath, os.FileMode(fileMode)) + if err != nil { + return err + } + } + + // OpenSSH does not check the flag + if timeMessage != nil { + err := os.Chtimes(targetPath, timeMessage.accessTime, timeMessage.modificationTime) + if err != nil { + return err + } + } + + err = s.session.awaitConfirmation() + if err != nil { + return err + } + + err = s.session.sendConfirmation() + if err != nil { + return err + } + + return nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/file_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/file_test.go new file mode 100644 index 0000000000..638eb03e5b --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/file_test.go @@ -0,0 +1,668 @@ +package scp_test + +import ( + "bytes" + "crypto/rand" + "errors" + "io" + "os" + "path/filepath" + "time" + + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/scp/atime" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("File Message", func() { + var ( + tempDir string + tempFile string + + logger *lagertest.TestLogger + testCopier TestCopier + ) + + newTestCopier := func(stdin io.Reader, stdout io.Writer, stderr io.Writer, preserveTimeAndMode bool) TestCopier { + options := &scp.Options{ + PreserveTimesAndMode: preserveTimeAndMode, + } + secureCopier, ok := scp.New(options, stdin, stdout, stderr, logger).(TestCopier) + Expect(ok).To(BeTrue()) + return secureCopier + } + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + var err error + tempDir, err = os.MkdirTemp("", "scp") + Expect(err).NotTo(HaveOccurred()) + + fileContents := make([]byte, 1024) + tempFile = filepath.Join(tempDir, "binary.dat") + + _, err = rand.Read(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = os.WriteFile(tempFile, fileContents, 0640) + Expect(err).NotTo(HaveOccurred()) + + modificationTime := time.Unix(123456789, 12345678) + accessTime := time.Unix(987654321, 987654321) + err = os.Chtimes(tempFile, accessTime, modificationTime) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + Context("when sending the file to an scp sink", func() { + var ( + file *os.File + fileInfo os.FileInfo + err error + ) + + BeforeEach(func() { + file, err = os.Open(tempFile) + Expect(err).NotTo(HaveOccurred()) + + fileInfo, err = file.Stat() + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + file.Close() + }) + + It("sends the file message and contents to the sink", func() { + stdin := bytes.NewReader([]byte{0, 0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.SendFile(file, fileInfo) + Expect(err).NotTo(HaveOccurred()) + + cMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(cMessage).To(Equal("C0640 1024 binary.dat\n")) + + contents := make([]byte, 1024) + n, err := stdout.Read(contents) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(1024)) + + expectedContents, err := os.ReadFile(tempFile) + Expect(err).NotTo(HaveOccurred()) + Expect(contents).To(Equal(expectedContents)) + + confirmation := make([]byte, 1) + bytesRead, err := stdout.Read(confirmation) + Expect(err).NotTo(HaveOccurred()) + Expect(bytesRead).To(Equal(1)) + Expect(confirmation).To(Equal([]byte{0})) + }) + + It("waits for confirmation before sending the file contents", func() { + stdin := &fake_io.FakeReader{} + stdout := &fake_io.FakeWriter{} + stdoutBuffer := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdout.WriteStub = stdoutBuffer.Write + stdin.ReadStub = func(buffer []byte) (int, error) { + if stdin.ReadCallCount() == 1 { + cMessage, err := stdoutBuffer.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(cMessage).To(Equal("C0640 1024 binary.dat\n")) + Expect(stdoutBuffer.Len()).To(Equal(0)) + } else { + contents := make([]byte, 1024) + n, err := stdoutBuffer.Read(contents) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(1024)) + + expectedContents, err := os.ReadFile(tempFile) + Expect(err).NotTo(HaveOccurred()) + Expect(contents).To(Equal(expectedContents)) + } + + buffer[0] = 0 + return 1, nil + } + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.SendFile(file, fileInfo) + Expect(err).NotTo(HaveOccurred()) + }) + + It("does not return before the contents are confirmed", func() { + stdin, pw := io.Pipe() + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + errCh := make(chan error, 1) + go func() { + testCopier = newTestCopier(stdin, stdout, stderr, false) + errCh <- testCopier.SendFile(file, fileInfo) + }() + + Consistently(errCh).ShouldNot(Receive()) + + pw.Write([]byte{0}) + Consistently(errCh).ShouldNot(Receive()) + + pw.Write([]byte{0}) + Eventually(errCh).Should(Receive(BeNil())) + }) + + Context("when preserving time stamps", func() { + It("sends the time information before the file message", func() { + stdin := bytes.NewReader([]byte{0, 0, 0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + testCopier = newTestCopier(stdin, stdout, stderr, true) + err := testCopier.SendFile(file, fileInfo) + Expect(err).NotTo(HaveOccurred()) + + tMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(tMessage).To(Equal("T123456789 0 987654321 0\n")) + + cMessage, err := stdout.ReadString('\n') + Expect(err).ToNot(HaveOccurred()) + Expect(cMessage).To(Equal("C0640 1024 binary.dat\n")) + + contents := make([]byte, 1024) + n, err := stdout.Read(contents) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(1024)) + + expectedContents, err := os.ReadFile(tempFile) + Expect(err).NotTo(HaveOccurred()) + Expect(contents).To(Equal(expectedContents)) + }) + }) + + Context("when copy encounters a short read", func() { + It("returns with an error", func() { + stdin := bytes.NewReader([]byte{0, 0}) + stdout := &fake_io.FakeWriter{} + stderr := &bytes.Buffer{} + + stdout.WriteStub = func(buffer []byte) (int, error) { + f, err := os.OpenFile(tempFile, os.O_RDWR, 0640) + Expect(err).NotTo(HaveOccurred()) + + err = f.Truncate(512) + Expect(err).NotTo(HaveOccurred()) + + err = f.Close() + Expect(err).NotTo(HaveOccurred()) + + return len(buffer), nil + } + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.SendFile(file, fileInfo) + Expect(err).To(Equal(io.EOF)) + }) + }) + + Context("when the sink responds with a warning", func() { + var stdin, stdout, stderr *bytes.Buffer + + BeforeEach(func() { + stdin = &bytes.Buffer{} + stdout = &bytes.Buffer{} + stderr = &bytes.Buffer{} + + testCopier = newTestCopier(stdin, stdout, stderr, false) + + stdin.WriteByte(1) + stdin.WriteString("Danger!\n") + + stdin.WriteByte(0) + }) + + It("returns without an error", func() { + err := testCopier.SendFile(file, fileInfo) + Expect(err).NotTo(HaveOccurred()) + }) + + It("writes the message to stderr", func() { + testCopier.SendFile(file, fileInfo) + Expect(stderr.String()).To(Equal("Danger!")) + }) + }) + + Context("when the sink responds with a warning", func() { + var stdin, stdout, stderr *bytes.Buffer + + BeforeEach(func() { + stdin = &bytes.Buffer{} + stdout = &bytes.Buffer{} + stderr = &bytes.Buffer{} + + testCopier = newTestCopier(stdin, stdout, stderr, false) + + stdin.WriteByte(2) + stdin.WriteString("oops...\n") + + stdin.WriteByte(0) + }) + + It("returns with an error", func() { + err := testCopier.SendFile(file, fileInfo) + Expect(err).To(MatchError("oops...")) + }) + }) + + Context("when the sink responds with an invalid acknowledgement", func() { + var stdin, stdout, stderr *bytes.Buffer + + BeforeEach(func() { + stdin = &bytes.Buffer{} + stdout = &bytes.Buffer{} + stderr = &bytes.Buffer{} + + testCopier = newTestCopier(stdin, stdout, stderr, false) + + stdin.WriteByte('a') + }) + + It("returns with an error", func() { + err := testCopier.SendFile(file, fileInfo) + Expect(err).To(MatchError("invalid acknowledgement identifier: 61")) + }) + }) + + Context("when the file is a directory", func() { + It("fails and returns an error", func() { + dir, err := os.Open(tempDir) + Expect(err).NotTo(HaveOccurred()) + + dirInfo, err := dir.Stat() + Expect(err).NotTo(HaveOccurred()) + + testCopier = newTestCopier(nil, nil, nil, false) + Expect(testCopier.SendFile(dir, dirInfo)).To(HaveOccurred()) + }) + }) + + Context("when sending the confirmation fails", func() { + It("returns an error", func() { + stdin := bytes.NewReader([]byte{0, 0}) + stdout := &fake_io.FakeWriter{} + stderr := &bytes.Buffer{} + + stdout.WriteStub = func(buffer []byte) (int, error) { + if stdout.WriteCallCount() == 3 { + return 0, errors.New("BOOM") + } else { + return len(buffer), nil + } + } + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.SendFile(file, fileInfo) + Expect(err).To(HaveOccurred()) + }) + }) + }) + + Context("when receiving a file message from an scp source", func() { + It("creates the file with the received contents", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).NotTo(HaveOccurred()) + + Expect(os.ReadFile(tempFile)).To(BeEquivalentTo("hello")) + }) + + It("sends a confirmation after each message is received", func() { + stdinBuffer := &bytes.Buffer{} + stdin := &fake_io.FakeReader{} + stdout := &fake_io.FakeWriter{} + stderr := &bytes.Buffer{} + + stdinBuffer.WriteString("C0640 5 hello.txt\n") + stdinBuffer.WriteString("hello") + + stdin.ReadStub = func(buffer []byte) (int, error) { + b, err := stdinBuffer.ReadByte() + if err != nil { + return 0, err + } + buffer[0] = b + return 1, nil + } + + stdout.WriteStub = func(message []byte) (int, error) { + if stdout.WriteCallCount() == 1 { + Expect(stdin.ReadCallCount()).To(BeNumerically(">", 0)) + Expect(stdinBuffer.Len()).To(Equal(len("hello"))) + stdinBuffer.WriteByte(0) + + Expect(message).To(HaveLen(1)) + Expect(message[0]).To(BeEquivalentTo(0)) + } else { + Expect(stdinBuffer.Len()).To(Equal(0)) + + Expect(message).To(HaveLen(1)) + Expect(message[0]).To(BeEquivalentTo(0)) + } + + return 1, nil + } + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.WriteCallCount()).To(Equal(2)) + + Expect(os.ReadFile(tempFile)).To(BeEquivalentTo("hello")) + }) + + It("sets the permissions of the file", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0444 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + testCopier = newTestCopier(stdin, stdout, stderr, true) + err := testCopier.ReceiveFile(tempDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + fileInfo, err := os.Stat(filepath.Join(tempDir, "hello.txt")) + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.Mode()).To(Equal(os.FileMode(0444))) + }) + + It("sets the timestamp of the file if present", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0444 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + tempFileInfo, err := os.Stat(tempFile) + Expect(err).NotTo(HaveOccurred()) + timestamp := scp.NewTimeMessage(tempFileInfo) + + testCopier = newTestCopier(stdin, stdout, stderr, true) + err = testCopier.ReceiveFile(tempDir, true, timestamp) + Expect(err).NotTo(HaveOccurred()) + + fileInfo, err := os.Stat(filepath.Join(tempDir, "hello.txt")) + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.ModTime()).To(Equal(tempFileInfo.ModTime())) + + fileAtime, err := atime.AccessTime(fileInfo) + Expect(err).NotTo(HaveOccurred()) + + tempAtime, err := atime.AccessTime(tempFileInfo) + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.ModTime()).To(Equal(tempFileInfo.ModTime())) + Expect(fileAtime).To(Equal(tempAtime)) + }) + + It("waits for a confirmation that the file has been sent", func() { + stdin, pw := io.Pipe() + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + errCh := make(chan error, 1) + go func() { + testCopier = newTestCopier(stdin, stdout, stderr, false) + errCh <- testCopier.ReceiveFile(tempFile, false, nil) + }() + + pw.Write([]byte("C0640 5 hello.txt\n")) + pw.Write([]byte("hello")) + + Consistently(errCh).ShouldNot(Receive()) + + pw.Write([]byte{0}) + Eventually(errCh).Should(Receive(BeNil())) + }) + + Context("when preserving time stamps and mode", func() { + It("restores the access time and modification time", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0444 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + tempFileInfo, err := os.Stat(tempFile) + Expect(err).NotTo(HaveOccurred()) + + timeMessage := scp.NewTimeMessage(tempFileInfo) + + testCopier = newTestCopier(stdin, stdout, stderr, true) + err = testCopier.ReceiveFile(tempFile, false, timeMessage) + Expect(err).NotTo(HaveOccurred()) + + fileInfo, err := os.Stat(tempFile) + Expect(err).NotTo(HaveOccurred()) + + fileAccessTime, err := atime.AccessTime(fileInfo) + Expect(err).NotTo(HaveOccurred()) + + expectedAccessTime, err := atime.AccessTime(tempFileInfo) + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.ModTime()).To(Equal(tempFileInfo.ModTime())) + Expect(fileAccessTime).To(Equal(expectedAccessTime)) + }) + }) + + Context("when the file already exists", func() { + var ( + preserveTimestampsAndMode bool + target string + ) + + BeforeEach(func() { + preserveTimestampsAndMode = false + }) + + JustBeforeEach(func() { + target = filepath.Join(tempDir, "hello.txt") + err := os.WriteFile(target, []byte("goodbye"), 0600) + Expect(err).NotTo(HaveOccurred()) + + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + testCopier = newTestCopier(stdin, stdout, stderr, preserveTimestampsAndMode) + err = testCopier.ReceiveFile(target, false, nil) + Expect(err).NotTo(HaveOccurred()) + }) + + It("replaces the file with the received contents", func() { + Expect(os.ReadFile(target)).To(BeEquivalentTo("hello")) + }) + + It("does not change the permissions of the file", func() { + file, err := os.Open(target) + Expect(err).NotTo(HaveOccurred()) + defer file.Close() + + fileInfo, err := file.Stat() + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.Mode()).To(Equal(os.FileMode(0600 & 0777))) + }) + + Context("and preserving mode is set", func() { + BeforeEach(func() { + preserveTimestampsAndMode = true + }) + + It("changes permissions of the file", func() { + file, err := os.Open(target) + Expect(err).NotTo(HaveOccurred()) + defer file.Close() + + fileInfo, err := file.Stat() + Expect(err).NotTo(HaveOccurred()) + + Expect(fileInfo.Mode()).To(Equal(os.FileMode(0640))) + }) + }) + }) + + Context("when opening the target file fails", func() { + BeforeEach(func() { + tempFile = filepath.Join(tempDir, "some-dir", "does-not-exists.txt") + }) + + It("fails with an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 5 hello.txt\n") + stdin.WriteString("hello") + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(MatchError(MatchRegexp("no such file or directory"))) + }) + }) + + Context("when the message is not a file message", func() { + It("fails with an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("c0640 5 hello.txt\n") + stdin.WriteString("hello") + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(MatchError(`unexpected message type: c`)) + }) + }) + + Context("when the file length field is not a number", func() { + It("fails with an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 five hello.txt\n") + stdin.WriteString("hello") + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(MatchError(`strconv.ParseInt: parsing "five": invalid syntax`)) + }) + }) + + Context("when the file mode field is not an octal number", func() { + It("fails with an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0999 5 hello.txt\n") + stdin.WriteString("hello") + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(MatchError(`strconv.ParseUint: parsing "0999": invalid syntax`)) + }) + }) + + Context("when the source does not send enough data for the file", func() { + BeforeEach(func() { + target := filepath.Join(tempDir, "hello.txt") + err := os.WriteFile(target, []byte("h"), 0660) + Expect(err).NotTo(HaveOccurred()) + }) + + It("fails with an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 512 hello.txt\n") + stdin.WriteString("hello") + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(Equal(io.EOF)) + }) + }) + + Context("when the target is a directory", func() { + It("copies the file into the directory", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.WriteString("C0640 5 hello.txt\n") + stdin.WriteString("hello") + stdin.WriteByte(0) + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + Expect(os.ReadFile(filepath.Join(tempDir, "hello.txt"))).To(BeEquivalentTo("hello")) + }) + }) + + Context("when the confirmation of the file fails", func() { + It("returns an error", func() { + stdin := &bytes.Buffer{} + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + + stdin.Write([]byte("C0640 5 hello.txt\n")) + stdin.Write([]byte("hello")) + stdin.Write([]byte{2}) + stdin.Write([]byte("BOOM\n")) + + testCopier = newTestCopier(stdin, stdout, stderr, false) + err := testCopier.ReceiveFile(tempFile, false, nil) + Expect(err).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser.go b/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser.go new file mode 100644 index 0000000000..d9d945f2bb --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser.go @@ -0,0 +1,101 @@ +package scp + +import ( + "errors" + + "github.com/google/shlex" + "github.com/pborman/getopt" +) + +type Options struct { + SourceMode bool + TargetMode bool + TargetIsDirectory bool + Verbose bool + PreserveTimesAndMode bool + Recursive bool + Quiet bool + + Sources []string + Target string +} + +func ParseCommand(command string) ([]string, error) { + args, err := shlex.Split(command) + if err != nil { + return []string{}, err + } + return args, err +} + +func ParseFlags(args []string) (*Options, error) { + cmd := args[0] + + if cmd != "scp" { + return nil, errors.New("Usage: call scp") + } + + opts := getopt.New() + + targetMode := opts.Bool('t', "", "Sets target mode for scp") + opts.Lookup('t').SetOptional() + + sourceMode := opts.Bool('f', "", "Sets source mode for scp") + opts.Lookup('f').SetOptional() + + targetIsDirectory := opts.Bool('d', "", "Indicates that the target is a directory") + opts.Lookup('d').SetOptional() + + verbose := opts.Bool('v', "", "Indicates that the command should be run in verbose mode") + opts.Lookup('v').SetOptional() + + preserveTimesAndMode := opts.Bool('p', "", "Indicates that scp should preserve timestamps and mode of files/directories transferred") + opts.Lookup('p').SetOptional() + + recursive := opts.Bool('r', "", "Indicates a recursive transfer, must be set if source is a directory") + opts.Lookup('r').SetOptional() + + // showprogress option is not used but can be provided + quiet := opts.Bool('q', "", "Indicates that the user wishes to run in quiet mode") + opts.Lookup('q').SetOptional() + + err := opts.Getopt(args, nil) + if err != nil { + return nil, err + } + + if *targetMode == *sourceMode { + return nil, errors.New("Must specify either target mode(-t) or source mode(-f) at a time") + } + + var sources []string + var target string + + if *sourceMode { + if len(opts.Args()) < 1 { + return nil, errors.New("Must specify at least one source in source mode") + } + + sources = opts.Args() + } + + if *targetMode { + if len(opts.Args()) != 1 { + return nil, errors.New("Must specify one target in target mode") + } + + target = opts.Args()[0] + } + + return &Options{ + TargetMode: *targetMode, + SourceMode: *sourceMode, + TargetIsDirectory: *targetIsDirectory, + Verbose: *verbose, + PreserveTimesAndMode: *preserveTimesAndMode, + Recursive: *recursive, + Quiet: *quiet, + Sources: sources, + Target: target, + }, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser_test.go new file mode 100644 index 0000000000..29618015c1 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/flag_parser_test.go @@ -0,0 +1,203 @@ +package scp_test + +import ( + "code.cloudfoundry.org/diego-ssh/scp" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("FlagParser", func() { + Describe("ParseFlags", func() { + Context("when invalid flags are specified", func() { + It("returns an error", func() { + _, err := scp.ParseFlags([]string{"scp", "-xxx", "/tmp/foo"}) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when unix style command concatenated args are used", func() { + It("parses command line flags and returns Options", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-tdvprq", "/tmp/foo"}) + Expect(err).NotTo(HaveOccurred()) + + Expect(scpOptions.TargetMode).To(BeTrue()) + Expect(scpOptions.SourceMode).To(BeFalse()) + Expect(scpOptions.TargetIsDirectory).To(BeTrue()) + Expect(scpOptions.Verbose).To(BeTrue()) + Expect(scpOptions.PreserveTimesAndMode).To(BeTrue()) + Expect(scpOptions.Recursive).To(BeTrue()) + Expect(scpOptions.Quiet).To(BeTrue()) + Expect(scpOptions.Target).To(Equal("/tmp/foo")) + }) + }) + + Context("when separate flags arguments are used", func() { + It("parses command line flags and returns Options", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-t", "-d", "-v", "-p", "-r", "/tmp/foo"}) + Expect(err).NotTo(HaveOccurred()) + + Expect(scpOptions.TargetMode).To(BeTrue()) + Expect(scpOptions.SourceMode).To(BeFalse()) + Expect(scpOptions.TargetIsDirectory).To(BeTrue()) + Expect(scpOptions.Verbose).To(BeTrue()) + Expect(scpOptions.PreserveTimesAndMode).To(BeTrue()) + Expect(scpOptions.Recursive).To(BeTrue()) + Expect(scpOptions.Target).To(Equal("/tmp/foo")) + }) + }) + + Context("when source mode is specified", func() { + It("returns Options with SourceMode enabled", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-f", "/tmp/foo"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.SourceMode).To(BeTrue()) + }) + + It("does not allow TargetMode to be enabled", func() { + _, err := scp.ParseFlags([]string{"scp", "-ft"}) + Expect(err).To(HaveOccurred()) + }) + + Context("Arguments", func() { + It("populates the Sources with following arguments", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-f", "/foo/bar", "/baz/buzz"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.Sources).To(Equal([]string{"/foo/bar", "/baz/buzz"})) + }) + + It("returns an empty string for Target", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-f", "/foo/bar", "/baz/buzz"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.Target).To(BeEmpty()) + }) + + Context("when no argument is provided", func() { + It("returns an error", func() { + _, err := scp.ParseFlags([]string{"scp", "-f"}) + Expect(err).To(MatchError("Must specify at least one source in source mode")) + }) + }) + }) + }) + + Context("when target mode is specified", func() { + It("returns Options with TargetMode enabled", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-t", "/tmp/foo"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.TargetMode).To(BeTrue()) + }) + + It("does not allow SourceMode to be enabled", func() { + _, err := scp.ParseFlags([]string{"scp", "-tf"}) + Expect(err).To(HaveOccurred()) + }) + + Context("Arguments", func() { + It("populates the Target with the argument", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-t", "/foo/bar"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.Target).To(Equal("/foo/bar")) + }) + + It("returns an empty array for Sources", func() { + scpOptions, err := scp.ParseFlags([]string{"scp", "-t", "/foo/bar"}) + Expect(err).NotTo(HaveOccurred()) + Expect(scpOptions.Sources).To(BeEmpty()) + }) + + Context("when no argument is provided", func() { + It("returns an error", func() { + _, err := scp.ParseFlags([]string{"scp", "-t"}) + Expect(err).To(MatchError("Must specify one target in target mode")) + }) + }) + + Context("when more than one argument is provided", func() { + It("returns an error", func() { + _, err := scp.ParseFlags([]string{"scp", "-t", "/foo/bar", "/baz/buzz"}) + Expect(err).To(MatchError("Must specify one target in target mode")) + }) + }) + }) + }) + + Context("when neither target or source mode is specified", func() { + It("does not allow this", func() { + _, err := scp.ParseFlags([]string{"scp", ""}) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when the command is not scp", func() { + It("returns an error", func() { + _, err := scp.ParseFlags([]string{"foobar", ""}) + Expect(err).To(HaveOccurred()) + }) + }) + }) + + Describe("ParseCommand", func() { + var ( + command string + args []string + err error + ) + + BeforeEach(func() { + command = "scp -v -f source" + }) + + JustBeforeEach(func() { + args, err = scp.ParseCommand(command) + }) + + It("returns an string slice from an scp command", func() { + Expect(err).NotTo(HaveOccurred()) + Expect(args).To(Equal([]string{"scp", "-v", "-f", "source"})) + }) + + Context("when the shell lexer returns an error", func() { + BeforeEach(func() { + command = "scp -v -f source\\" + }) + + It("returns an error", func() { + Expect(err).To(HaveOccurred()) + Expect(args).To(BeEmpty()) + }) + }) + + Context("when the command string contains escaped spaces as parts of filenames", func() { + BeforeEach(func() { + command = "scp -v -f source\\ file" + }) + + It("correctly captures the path as a single argument", func() { + Expect(err).NotTo(HaveOccurred()) + Expect(args).To(Equal([]string{"scp", "-v", "-f", "source file"})) + }) + }) + + Context("when an argument is in quotes", func() { + BeforeEach(func() { + command = "scp -v -f \"source\"" + }) + + It("correctly captures the path as a single argument", func() { + Expect(err).NotTo(HaveOccurred()) + Expect(args).To(Equal([]string{"scp", "-v", "-f", "source"})) + }) + }) + + Context("when the command contains unexpected whitespace", func() { + BeforeEach(func() { + command = "scp -v -f source" + }) + + It("correctly strips excess whitespace", func() { + Expect(err).NotTo(HaveOccurred()) + Expect(args).To(Equal([]string{"scp", "-v", "-f", "source"})) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/package.go b/src/code.cloudfoundry.org/diego-ssh/scp/package.go new file mode 100644 index 0000000000..573cc5c7eb --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/package.go @@ -0,0 +1 @@ +package scp // import "code.cloudfoundry.org/diego-ssh/scp" diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/scp.go b/src/code.cloudfoundry.org/diego-ssh/scp/scp.go new file mode 100644 index 0000000000..260eef6997 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/scp.go @@ -0,0 +1,223 @@ +package scp + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + + "code.cloudfoundry.org/lager/v3" +) + +type SecureCopier interface { + Copy() error +} + +type secureCopy struct { + options *Options + session *Session +} + +func New(options *Options, stdin io.Reader, stdout io.Writer, stderr io.Writer, logger lager.Logger) SecureCopier { + session := NewSession(stdin, stdout, stderr, options.PreserveTimesAndMode, logger) + + return &secureCopy{ + options: options, + session: session, + } +} + +func NewFromCommand(command string, stdin io.Reader, stdout io.Writer, stderr io.Writer, logger lager.Logger) (SecureCopier, error) { + cmd, err := ParseCommand(command) + if err != nil { + return nil, err + } + + options, err := ParseFlags(cmd) + if err != nil { + return nil, err + } + + return New(options, stdin, stdout, stderr, logger), nil +} + +func (s *secureCopy) Copy() error { + if s.options.SourceMode { + var lastErr error + logger := s.session.logger.Session("source-mode") + + logger.Info("started") + defer logger.Info("finished") + + logger.Debug("awaiting-connection-confirmation") + err := s.session.awaitConfirmation() + if err != nil { + logger.Error("failed-confirmation", err) + return err + } + logger.Debug("received-connection-confirmation") + + for _, sourceGlob := range s.options.Sources { + logger.Debug("evaluating-glob", lager.Data{"Source Glob": sourceGlob}) + sources, err := filepath.Glob(sourceGlob) + if err != nil || len(sources) == 0 { + logger.Debug("failed-matching-glob", lager.Data{"Source Glob": sourceGlob}) + sources = []string{sourceGlob} + } + + for _, source := range sources { + logger.Debug("sending-source", lager.Data{"Source": source}) + + sourceInfo, err := os.Stat(source) + if err != nil { + logger.Error("failed-to-stat", err) + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + lastErr = err + continue + } + + if sourceInfo.IsDir() && !s.options.Recursive { + err = fmt.Errorf("%s: not a regular file", sourceInfo.Name()) + logger.Error("sending-non-recursive-directory-failed", err) + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + lastErr = err + continue + } + + err = s.send(source, logger) + if err != nil { + logger.Error("failed-sending-source", err, lager.Data{"Source": source}) + lastErr = err + continue + } + logger.Debug("sent-source", lager.Data{"Source": source}) + } + } + + return lastErr + } + + if s.options.TargetMode { + logger := s.session.logger.Session("target-mode") + + logger.Info("started") + defer logger.Info("finished") + + targetIsDir := false + targetInfo, err := os.Stat(s.options.Target) + if err == nil { + targetIsDir = targetInfo.IsDir() + } + + if s.options.TargetIsDirectory { + if !targetIsDir { + err = errors.New("target is not a directory") + logger.Error("failed-target-directory-validation", err) + return err + } + } + + err = s.session.sendConfirmation() + if err != nil { + logger.Error("failed-sending-confirmation", err) + return err + } + + for { + var timeMessage *TimeMessage + + var err error + messageType, err := s.session.peekByte() + if err == io.EOF { + return nil + } + + if messageType == 'T' { + timeMessage = &TimeMessage{} + err := timeMessage.Receive(s.session) + if err != nil { + logger.Error("failed-receiving-time-message", err) + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + return err + } + + messageType, err = s.session.peekByte() + if err == io.EOF { + return nil + } + } + + if messageType == 'C' { + s.session.logger.Info("receiving-file", lager.Data{"Message Type": messageType}) + err = s.ReceiveFile(s.options.Target, targetIsDir, timeMessage) + } else if messageType == 'D' { + err = s.ReceiveDirectory(s.options.Target, timeMessage) + } else { + err = fmt.Errorf("unexpected message type: %c", messageType) + logger.Error("unexpected-message", err) + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + return err + } + + if err != nil { + logger.Error("failed-receiving-message", err) + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + return err + } + } + } + + return nil +} + +func (s *secureCopy) send(source string, logger lager.Logger) error { + var err error + + defer func() { + if err != nil { + sendErr := s.session.sendError(err.Error()) + if sendErr != nil { + logger.Session("copy").Debug("failed-sending-send-error", lager.Data{"error": sendErr}) + } + } + }() + + file, err := os.Open(source) + if err != nil { + return err + } + defer file.Close() + + fileInfo, err := file.Stat() + if err != nil { + return err + } + + if !fileInfo.IsDir() { + err = s.SendFile(file, fileInfo) + } else { + err = s.SendDirectory(file.Name(), fileInfo) + } + + if err != nil { + return err + } + + return err +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/scp_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/scp_suite_test.go new file mode 100644 index 0000000000..e9d5c1b26d --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/scp_suite_test.go @@ -0,0 +1,20 @@ +package scp_test + +import ( + "runtime" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestScp(t *testing.T) { + RegisterFailHandler(Fail) + BeforeEach(func() { + if runtime.GOOS == "windows" { + Skip("scp isn't supported on windows") + } + }) + RunSpecs(t, "Scp Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/scp_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/scp_test.go new file mode 100644 index 0000000000..622f2edcca --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/scp_test.go @@ -0,0 +1,764 @@ +package scp_test + +import ( + "bufio" + "crypto/rand" + "fmt" + "io" + "os" + "path/filepath" + "time" + + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/scp/atime" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +type TestCopier interface { + scp.SecureCopier + + SendDirectory(dir string, dirInfo os.FileInfo) error + ReceiveDirectory(dir string, timeStampMessage *scp.TimeMessage) error + + SendFile(file *os.File, fileInfo os.FileInfo) error + ReceiveFile(path string, pathIsDir bool, timeMessage *scp.TimeMessage) error +} + +var fileInfos map[string]os.FileInfo + +var _ = Describe("scp", func() { + var ( + stdin, stdoutSource io.ReadCloser + stdinSource, stdout io.WriteCloser + stderr io.Writer + + sourceDir string + sourceDirInfo os.FileInfo + targetDir string + nestedTempDir string + nestedTempDirInfo os.FileInfo + generatedTextFile string + generatedTextFileInfo os.FileInfo + generatedNestedTextFile string + generatedNestedTextFileInfo os.FileInfo + generatedBinaryFile string + generatedBinaryFileInfo os.FileInfo + + secureCopier scp.SecureCopier + logger *lagertest.TestLogger + + testCopier TestCopier + ) + + newTestCopier := func(stdin io.Reader, stdout io.Writer, stderr io.Writer, preserveTimeAndMode bool) TestCopier { + options := &scp.Options{ + PreserveTimesAndMode: preserveTimeAndMode, + } + secureCopier, ok := scp.New(options, stdin, stdout, stderr, logger).(TestCopier) + Expect(ok).To(BeTrue()) + return secureCopier + } + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + fileInfos = make(map[string]os.FileInfo) + + stdin, stdinSource = io.Pipe() + stdoutSource, stdout = io.Pipe() + stderr = io.Discard + + var err error + sourceDir, err = os.MkdirTemp("", "scp-source") + Expect(err).NotTo(HaveOccurred()) + + fileContents := []byte("---\nthis is a simple file\n\n") + generatedTextFile = filepath.Join(sourceDir, "textfile.txt") + + err = os.WriteFile(generatedTextFile, fileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + fileContents = make([]byte, 1024) + generatedBinaryFile = filepath.Join(sourceDir, "binary.dat") + + _, err = rand.Read(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = os.WriteFile(generatedBinaryFile, fileContents, 0400) + Expect(err).NotTo(HaveOccurred()) + + nestedTempDir, err = os.MkdirTemp(sourceDir, "nested") + Expect(err).NotTo(HaveOccurred()) + + nestedFileContents := []byte("---\nthis is a simple nested file\n\n") + generatedNestedTextFile = filepath.Join(nestedTempDir, "nested-textfile.txt") + + err = os.WriteFile(generatedNestedTextFile, nestedFileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + // save off file infos + sourceDirInfo, err = os.Stat(sourceDir) + Expect(err).NotTo(HaveOccurred()) + fileInfos[sourceDir] = sourceDirInfo + + generatedTextFileInfo, err = os.Stat(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + fileInfos[generatedTextFile] = generatedTextFileInfo + + generatedBinaryFileInfo, err = os.Stat(generatedBinaryFile) + Expect(err).NotTo(HaveOccurred()) + fileInfos[generatedBinaryFile] = generatedBinaryFileInfo + + nestedTempDirInfo, err = os.Stat(nestedTempDir) + Expect(err).NotTo(HaveOccurred()) + fileInfos[nestedTempDir] = nestedTempDirInfo + + generatedNestedTextFileInfo, err = os.Stat(generatedNestedTextFile) + Expect(err).NotTo(HaveOccurred()) + fileInfos[generatedNestedTextFile] = generatedNestedTextFileInfo + + targetDir, err = os.MkdirTemp("", "scp-target") + Expect(err).NotTo(HaveOccurred()) + + secureCopier = nil + }) + + AfterEach(func() { + os.RemoveAll(sourceDir) + os.RemoveAll(targetDir) + }) + + Context("source mode", func() { + var preserveTimestamps bool + Context("when no files are requested", func() { + It("fails construct the copier", func() { + _, err := scp.NewFromCommand("scp -f", stdin, stdout, stderr, logger) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when files are requested", func() { + var sourceFileInfo os.FileInfo + + BeforeEach(func() { + preserveTimestamps = false + }) + + Context("when the requested file exists", func() { + JustBeforeEach(func() { + var err error + + command := fmt.Sprintf("scp -f %s", generatedTextFile) + if preserveTimestamps { + command = fmt.Sprintf("scp -fp %s", generatedTextFile) + } + + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done := make(chan struct{}) + go func() { + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + _, err = stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + session := scp.NewSession(stdoutSource, stdinSource, nil, preserveTimestamps, logger) + + var timestampMessage *scp.TimeMessage + if preserveTimestamps { + timestampMessage = &scp.TimeMessage{} + err = timestampMessage.Receive(session) + Expect(err).NotTo(HaveOccurred()) + } + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + err = testCopier.ReceiveFile(targetDir, true, timestampMessage) + Expect(err).NotTo(HaveOccurred()) + Eventually(done).Should(BeClosed()) + + sourceFileInfo, err = os.Stat(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sends the file", func() { + compareFile(filepath.Join(targetDir, sourceFileInfo.Name()), generatedTextFile, preserveTimestamps) + }) + + Context("when -p (preserve times) is specified", func() { + BeforeEach(func() { + preserveTimestamps = true + }) + + It("sends the timestamp information before the file", func() { + compareFile(filepath.Join(targetDir, sourceFileInfo.Name()), generatedTextFile, preserveTimestamps) + }) + }) + }) + + Context("when the requested file does not exist", func() { + BeforeEach(func() { + os.RemoveAll(generatedTextFile) + }) + + It("returns an error and continues sending", func() { + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + + command := fmt.Sprintf("scp -f %s %s", generatedTextFile, generatedBinaryFile) + secureCopier, err := scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + errCh := make(chan error) + go func() { + errCh <- secureCopier.Copy() + }() + + _, err = stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + stdoutReader := bufio.NewReader(stdoutSource) + + errCode, err := stdoutReader.ReadByte() + Expect(err).NotTo(HaveOccurred()) + Expect(errCode).To(BeEquivalentTo(1)) + + errMessage, err := stdoutReader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(errMessage).To(ContainSubstring("no such file or directory")) + + err = testCopier.ReceiveFile(targetDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + Eventually(errCh).Should(Receive(HaveOccurred())) + + compareFile(filepath.Join(targetDir, "binary.dat"), generatedBinaryFile, false) + }) + }) + }) + + Context("when a directory is requested", func() { + Context("when the -r (recursive) flag is not specified", func() { + BeforeEach(func() { + var err error + command := fmt.Sprintf("scp -f %s %s", sourceDir, generatedTextFile) + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + }) + + It("returns an error and continues sending sources", func() { + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + + errCh := make(chan error) + go func() { + errCh <- secureCopier.Copy() + }() + + _, err := stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + stdoutReader := bufio.NewReader(stdoutSource) + + errCode, err := stdoutReader.ReadByte() + Expect(err).NotTo(HaveOccurred()) + Expect(errCode).To(BeEquivalentTo(1)) + + errMessage, err := stdoutReader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(errMessage).To(ContainSubstring("not a regular file")) + + err = testCopier.ReceiveFile(targetDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + Eventually(errCh).Should(Receive(HaveOccurred())) + compareFile(filepath.Join(targetDir, "textfile.txt"), generatedTextFile, false) + }) + }) + + Context("when the -r (recursive) flag is specified", func() { + var sourceDirInfo os.FileInfo + + BeforeEach(func() { + preserveTimestamps = false + }) + + JustBeforeEach(func() { + var err error + + command := fmt.Sprintf("scp -rf %s", sourceDir) + if preserveTimestamps { + command = fmt.Sprintf("scp -rfp %s", sourceDir) + } + + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done := make(chan struct{}) + go func() { + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + _, err = stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + session := scp.NewSession(stdoutSource, stdinSource, nil, preserveTimestamps, logger) + + timestampMessage := &scp.TimeMessage{} + if preserveTimestamps { + err = timestampMessage.Receive(session) + Expect(err).NotTo(HaveOccurred()) + } + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + err = testCopier.ReceiveDirectory(targetDir, timestampMessage) + Expect(err).NotTo(HaveOccurred()) + Eventually(done).Should(BeClosed()) + + sourceDirInfo, err = os.Stat(sourceDir) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sends the directory and all the files", func() { + compareDir(filepath.Join(targetDir, sourceDirInfo.Name()), sourceDir, preserveTimestamps) + }) + + Context("when the -p is specified", func() { + BeforeEach(func() { + preserveTimestamps = true + }) + + It("sends timestamp information before files and directories", func() { + compareDir(filepath.Join(targetDir, sourceDirInfo.Name()), sourceDir, preserveTimestamps) + }) + }) + }) + }) + + Context("when a glob is requested", func() { + var ( + command string + ) + + BeforeEach(func() { + command = fmt.Sprintf("scp -f %s/[bt]*", sourceDir) + }) + + Context("when the glob is valid", func() { + + JustBeforeEach(func() { + var err error + + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done := make(chan struct{}) + go func() { + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + _, err = stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, false) + + // Receive File 1 + err = testCopier.ReceiveFile(targetDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + // Receive File 2 + err = testCopier.ReceiveFile(targetDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + Eventually(done).Should(BeClosed()) + }) + + It("properly matches the glob against a single filename", func() { + compareFile(filepath.Join(targetDir, "textfile.txt"), generatedTextFile, false) + compareFile(filepath.Join(targetDir, "binary.dat"), generatedBinaryFile, false) + }) + }) + + Context("when the glob does not match any sources", func() { + var generatedBadGlobFile string + + JustBeforeEach(func() { + fileContents := []byte("---\nthis is a bad glob file\n\n") + + err := os.WriteFile(generatedBadGlobFile, fileContents, 0664) + Expect(err).NotTo(HaveOccurred()) + + generatedBadGlobFileInfo, err := os.Stat(generatedBadGlobFile) + Expect(err).NotTo(HaveOccurred()) + fileInfos[generatedBadGlobFile] = generatedBadGlobFileInfo + + command = fmt.Sprintf("scp -f %s", generatedBadGlobFile) + secureCopier, err := scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done := make(chan struct{}) + go func() { + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + _, err = stdinSource.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, false) + + // Receive File 1 + err = testCopier.ReceiveFile(targetDir, true, nil) + Expect(err).NotTo(HaveOccurred()) + + Eventually(done).Should(BeClosed()) + }) + + Context("because it is malformed", func() { + BeforeEach(func() { + generatedBadGlobFile = filepath.Join(sourceDir, "[") + }) + + It("attempts to match the glob literally", func() { + compareFile(filepath.Join(targetDir, "["), generatedBadGlobFile, false) + }) + }) + + Context("because nothing matches the glob", func() { + BeforeEach(func() { + generatedBadGlobFile = filepath.Join(sourceDir, "[a].txt") + }) + + It("attempts to match the glob literally", func() { + compareFile(filepath.Join(targetDir, "[a].txt"), generatedBadGlobFile, false) + }) + }) + }) + }) + }) + + Context("target mode", func() { + Context("when no target is specified", func() { + It("fails construct the copier", func() { + _, err := scp.NewFromCommand("scp -t", stdin, stdout, stderr, logger) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when multiple targets are specified", func() { + It("fails construct the copier", func() { + _, err := scp.NewFromCommand("scp -t a b", stdin, stdout, stderr, logger) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when the target is not a directory", func() { + Context("and the target is specified as a directory", func() { + It("fails when the target does not exist", func() { + secureCopier, err := scp.NewFromCommand("scp -td bogus", stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + err = secureCopier.Copy() + Expect(err).To(HaveOccurred()) + }) + + It("fails when the target is not a directory", func() { + tempFile, err := os.CreateTemp(targetDir, "target") + Expect(err).NotTo(HaveOccurred()) + + secureCopier, err := scp.NewFromCommand("scp -td "+tempFile.Name(), stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + err = secureCopier.Copy() + Expect(err).To(HaveOccurred()) + }) + }) + }) + + Context("when a file is specified as the target", func() { + var ( + targetFile string + preserveTimestamps bool + ) + + BeforeEach(func() { + preserveTimestamps = false + targetFile = filepath.Join(targetDir, "targetFile") + }) + + JustBeforeEach(func() { + var err error + + args := "-t" + if preserveTimestamps { + args += "p" + } + command := fmt.Sprintf("scp %s %s", args, targetFile) + + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done := make(chan struct{}) + go func() { + defer GinkgoRecover() + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + bytes := make([]byte, 1) + _, err = stdoutSource.Read(bytes) + Expect(err).NotTo(HaveOccurred()) + + textFile, err := os.Open(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + textFileInfo, err := textFile.Stat() + Expect(err).NotTo(HaveOccurred()) + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + err = testCopier.SendFile(textFile, textFileInfo) + Expect(err).NotTo(HaveOccurred()) + stdinSource.Close() + Eventually(done).Should(BeClosed()) + + _, err = os.Stat(targetFile) + Expect(err).NotTo(HaveOccurred()) + }) + + It("allows a file to be sent", func() { + compareFile(targetFile, generatedTextFile, preserveTimestamps) + }) + + Context("when preserving timestamps and mode", func() { + BeforeEach(func() { + preserveTimestamps = true + }) + + It("sets the mode and timestamp", func() { + compareFile(targetFile, generatedTextFile, preserveTimestamps) + }) + + Context("when the target file exists", func() { + BeforeEach(func() { + err := os.WriteFile(targetFile, []byte{'a'}, 0640) + Expect(err).NotTo(HaveOccurred()) + + modificationTime := time.Unix(123456789, 12345678) + accessTime := time.Unix(987654321, 987654321) + err = os.Chtimes(targetFile, accessTime, modificationTime) + Expect(err).NotTo(HaveOccurred()) + }) + + It("sets the mode and timestamp", func() { + targetFileInfo, err := os.Stat(targetFile) + Expect(err).NotTo(HaveOccurred()) + compareFileInfo(targetFileInfo, generatedTextFileInfo, preserveTimestamps) + }) + }) + }) + }) + + Context("when a directory is specified as the target", func() { + var ( + dir string + preserveTimestamps bool + targetIsDirectory bool + done chan struct{} + ) + + BeforeEach(func() { + dir = targetDir + preserveTimestamps = false + targetIsDirectory = false + }) + + JustBeforeEach(func() { + var err error + + args := "-t" + if preserveTimestamps { + args += "p" + } + if targetIsDirectory { + args += "d" + } + command := fmt.Sprintf("scp %s %s", args, dir) + + secureCopier, err = scp.NewFromCommand(command, stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + done = make(chan struct{}) + go func() { + defer GinkgoRecover() + err := secureCopier.Copy() + Expect(err).NotTo(HaveOccurred()) + close(done) + }() + + bytes := make([]byte, 1) + _, err = stdoutSource.Read(bytes) + Expect(err).NotTo(HaveOccurred()) + + scp.NewSession(stdoutSource, stdinSource, nil, preserveTimestamps, logger) + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + }) + + Context("and a file is sent", func() { + var file *os.File + var err error + + JustBeforeEach(func() { + file, err = os.Open(generatedTextFile) + Expect(err).NotTo(HaveOccurred()) + + fileInfo, err := file.Stat() + Expect(err).NotTo(HaveOccurred()) + + testCopier = newTestCopier(stdoutSource, stdinSource, nil, preserveTimestamps) + err = testCopier.SendFile(file, fileInfo) + Expect(err).NotTo(HaveOccurred()) + + stdinSource.Close() + Eventually(done).Should(BeClosed()) + }) + + It("copies the file and its contents into the target", func() { + compareFile(filepath.Join(dir, filepath.Base(file.Name())), generatedTextFile, preserveTimestamps) + }) + }) + + Context("and a directory is sent", func() { + JustBeforeEach(func() { + sourceDirInfo, err := os.Stat(sourceDir) + Expect(err).NotTo(HaveOccurred()) + + err = testCopier.SendDirectory(sourceDir, sourceDirInfo) + Expect(err).NotTo(HaveOccurred()) + stdinSource.Close() + Eventually(done).Should(BeClosed()) + }) + + It("receives the directory and its content", func() { + compareDir(filepath.Join(dir, filepath.Base(sourceDir)), sourceDir, preserveTimestamps) + }) + + Context("when the target directory does not exist but its parent directory does", func() { + BeforeEach(func() { + dir = filepath.Join(targetDir, "newdir") + }) + + It("makes the target directory and populates with the source directories contents", func() { + compareDir(dir, sourceDir, preserveTimestamps) + }) + }) + }) + }) + + Context("when an unknown message type is sent", func() { + It("returns an error", func() { + secureCopier, err := scp.NewFromCommand("scp -t /tmp/foo", stdin, stdout, stderr, logger) + Expect(err).NotTo(HaveOccurred()) + + errCh := make(chan error) + go func() { + defer GinkgoRecover() + errCh <- secureCopier.Copy() + }() + + bytes := make([]byte, 1) + _, err = stdoutSource.Read(bytes) + Expect(err).NotTo(HaveOccurred()) + + _, err = stdinSource.Write([]byte("F this protocol message does not exist")) + Expect(err).NotTo(HaveOccurred()) + + stdoutReader := bufio.NewReader(stdoutSource) + + errCode, err := stdoutReader.ReadByte() + Expect(err).NotTo(HaveOccurred()) + Expect(errCode).To(BeEquivalentTo(1)) + + errMessage, err := stdoutReader.ReadString('\n') + Expect(err).NotTo(HaveOccurred()) + Expect(errMessage).To(ContainSubstring("unexpected message type: F")) + + Eventually(errCh).Should(Receive(HaveOccurred())) + }) + }) + }) +}) + +func compareDir(actualDir, expectedDir string, compareTimestamps bool) { + actualDirInfo, err := os.Stat(actualDir) + Expect(err).NotTo(HaveOccurred()) + + expectedDirInfo, ok := fileInfos[expectedDir] + Expect(ok).To(BeTrue()) + + Expect(actualDirInfo.Mode()).To(Equal(expectedDirInfo.Mode())) + if compareTimestamps { + compareTimestampsFromInfo(actualDirInfo, expectedDirInfo) + } + + actualFiles, err := os.ReadDir(actualDir) + Expect(err).NotTo(HaveOccurred()) + + expectedFiles, err := os.ReadDir(expectedDir) + Expect(err).NotTo(HaveOccurred()) + + Expect(len(actualFiles)).To(Equal(len(expectedFiles))) + for i, actualFile := range actualFiles { + expectedFile := expectedFiles[i] + if actualFile.IsDir() { + compareDir(filepath.Join(actualDir, actualFile.Name()), filepath.Join(expectedDir, expectedFile.Name()), compareTimestamps) + } else { + compareFile(filepath.Join(actualDir, actualFile.Name()), filepath.Join(expectedDir, expectedFile.Name()), compareTimestamps) + } + } +} + +func compareFile(actualFile, expectedFile string, compareTimestamps bool) { + actualFileInfo, err := os.Stat(actualFile) + Expect(err).NotTo(HaveOccurred()) + + expectedFileInfo, ok := fileInfos[expectedFile] + Expect(ok).To(BeTrue()) + + compareFileInfo(actualFileInfo, expectedFileInfo, compareTimestamps) + + actualContents, err := os.ReadFile(actualFile) + Expect(err).NotTo(HaveOccurred()) + + expectedContents, err := os.ReadFile(expectedFile) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualContents).To(Equal(expectedContents)) +} + +func compareFileInfo(actualFileInfo os.FileInfo, expectedFileInfo os.FileInfo, compareTimestamps bool) { + Expect(actualFileInfo.Mode()).To(Equal(expectedFileInfo.Mode())) + Expect(actualFileInfo.Size()).To(Equal(expectedFileInfo.Size())) + if compareTimestamps { + compareTimestampsFromInfo(actualFileInfo, expectedFileInfo) + } +} + +func compareTimestampsFromInfo(actualInfo, expectedInfo os.FileInfo) { + actualAccessTime, err := atime.AccessTime(actualInfo) + Expect(err).NotTo(HaveOccurred()) + + expectedAccessTime, err := atime.AccessTime(expectedInfo) + Expect(err).NotTo(HaveOccurred()) + + Expect(actualInfo.ModTime().Unix()).To(Equal(expectedInfo.ModTime().Unix())) + Expect(actualAccessTime.Unix()).To(Equal(expectedAccessTime.Unix())) +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/session.go b/src/code.cloudfoundry.org/diego-ssh/scp/session.go new file mode 100644 index 0000000000..fa62e8d7a9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/session.go @@ -0,0 +1,125 @@ +package scp + +import ( + "bufio" + "errors" + "fmt" + "io" + "strings" + + "code.cloudfoundry.org/lager/v3" +) + +const ( + NEWLINE = "\n" + SPACE = " " +) + +type Session struct { + stdin *bufio.Reader + stdout io.Writer + stderr io.Writer + + preserveTimesAndMode bool + + logger lager.Logger +} + +func NewSession(stdin io.Reader, stdout io.Writer, stderr io.Writer, preserveTimesAndMode bool, logger lager.Logger) *Session { + return &Session{ + stdin: bufio.NewReader(stdin), + stdout: stdout, + stderr: stderr, + preserveTimesAndMode: preserveTimesAndMode, + logger: logger.Session("scp-session"), + } +} + +func (sess *Session) sendConfirmation() error { + _, err := sess.stdout.Write([]byte{0}) + return err +} + +func (sess *Session) sendError(message string) error { + _, err := sess.stdout.Write([]byte{1}) + if err != nil { + return err + } + + _, err = fmt.Fprintf(sess.stdout, "scp: %s\n", message) + if err != nil { + return err + } + + return nil +} + +func (sess *Session) awaitConfirmation() error { + ackType, err := sess.readByte() + if err != nil { + return err + } + + switch ackType { + case 0: + case 1: + message, err := sess.readString(NEWLINE) + if err != nil { + return err + } + fmt.Fprint(sess.stderr, message) + case 2: + message, err := sess.readString(NEWLINE) + if err != nil { + return err + } + return errors.New(message) + default: + return fmt.Errorf("invalid acknowledgement identifier: %x", ackType) + } + + return nil +} + +func (sess *Session) readString(delim string) (string, error) { + message, err := sess.stdin.ReadString(delim[0]) + if err != nil { + return "", err + } + + return strings.TrimSuffix(message, delim), nil +} + +func (sess *Session) readByte() (byte, error) { + message := make([]byte, 1) + + var n int + var err error + for n == 0 && err == nil { + n, err = sess.stdin.Read(message) + } + + if err != nil { + return 0, err + } + + if n != 1 { + return 0, errors.New("read failed") + } + + return message[0], nil +} + +func (sess *Session) peekByte() (byte, error) { + b, err := sess.readByte() + if err != nil { + return b, err + } + + err = sess.stdin.UnreadByte() + if err != nil { + return b, err + } + + return b, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/time.go b/src/code.cloudfoundry.org/diego-ssh/scp/time.go new file mode 100644 index 0000000000..274eb960e6 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/time.go @@ -0,0 +1,96 @@ +package scp + +import ( + "fmt" + "os" + "strconv" + "time" + + "code.cloudfoundry.org/diego-ssh/scp/atime" +) + +type TimeMessage struct { + modificationTime time.Time + accessTime time.Time +} + +func NewTimeMessage(fileInfo os.FileInfo) *TimeMessage { + accessTime, err := atime.AccessTime(fileInfo) + if err != nil { + accessTime = time.Unix(0, 0) + } + + return &TimeMessage{ + modificationTime: fileInfo.ModTime(), + accessTime: accessTime, + } +} + +func (tm *TimeMessage) ModificationTime() time.Time { + return tm.modificationTime +} + +func (tm *TimeMessage) AccessTime() time.Time { + return tm.accessTime +} + +func (tm *TimeMessage) Send(session *Session) error { + _, err := fmt.Fprintf(session.stdout, "T%d 0 %d 0\n", tm.modificationTime.Unix(), tm.accessTime.Unix()) + if err != nil { + return err + } + + return session.awaitConfirmation() +} + +func (tm *TimeMessage) Receive(session *Session) error { + messageType, err := session.readByte() + if err != nil { + return err + } + + if messageType != byte('T') { + return fmt.Errorf("unexpected message type: %c", messageType) + } + + modTimeString, err := session.readString(SPACE) + if err != nil { + return err + } + + modTimeSeconds, err := strconv.ParseUint(modTimeString, 10, 64) + if err != nil { + return err + } + + tm.modificationTime = time.Unix(int64(modTimeSeconds), 0) + + _, err = session.readString(SPACE) + if err != nil { + return err + } + + accessTimeString, err := session.readString(SPACE) + if err != nil { + return err + } + + accessTimeSeconds, err := strconv.ParseUint(accessTimeString, 10, 64) + if err != nil { + return err + } + + tm.accessTime = time.Unix(int64(accessTimeSeconds), 0) + + _, err = session.readString(NEWLINE) + if err != nil { + return err + } + + err = session.sendConfirmation() + if err != nil { + return err + } + + return nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/scp/time_test.go b/src/code.cloudfoundry.org/diego-ssh/scp/time_test.go new file mode 100644 index 0000000000..ab62eb12b3 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/scp/time_test.go @@ -0,0 +1,277 @@ +package scp_test + +import ( + "bytes" + "crypto/rand" + "io" + "os" + "path/filepath" + "strings" + "time" + + "code.cloudfoundry.org/diego-ssh/scp" + "code.cloudfoundry.org/diego-ssh/scp/atime" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" + "code.cloudfoundry.org/lager/v3/lagertest" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("TimeMessage", func() { + var ( + tempDir string + tempFile string + + logger *lagertest.TestLogger + ) + + BeforeEach(func() { + logger = lagertest.NewTestLogger("test") + + var err error + tempDir, err = os.MkdirTemp("", "scp") + Expect(err).NotTo(HaveOccurred()) + + fileContents := make([]byte, 1024) + tempFile = filepath.Join(tempDir, "binary.dat") + + _, err = rand.Read(fileContents) + Expect(err).NotTo(HaveOccurred()) + + err = os.WriteFile(tempFile, fileContents, 0400) + Expect(err).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + os.RemoveAll(tempDir) + }) + + Context("when creating a time message from file information", func() { + var ( + timeMessage *scp.TimeMessage + + expectedModificationTime time.Time + expectedAccessTime time.Time + ) + + BeforeEach(func() { + fileInfo, ferr := os.Stat(tempFile) + Expect(ferr).NotTo(HaveOccurred()) + + expectedAccessTime, ferr = atime.AccessTime(fileInfo) + Expect(ferr).NotTo(HaveOccurred()) + + expectedModificationTime = fileInfo.ModTime() + + timeMessage = scp.NewTimeMessage(fileInfo) + }) + + It("acquires the correct modification time", func() { + Expect(timeMessage.ModificationTime()).To(Equal(expectedModificationTime)) + }) + + It("acquires the correct access time", func() { + Expect(timeMessage.AccessTime()).To(Equal(expectedAccessTime)) + }) + }) + + Context("when sending the time information to an scp sink", func() { + var timeMessage *scp.TimeMessage + + BeforeEach(func() { + modificationTime := time.Unix(123456789, 12345678) + accessTime := time.Unix(987654321, 987654321) + os.Chtimes(tempFile, accessTime, modificationTime) + + fileInfo, ferr := os.Stat(tempFile) + Expect(ferr).NotTo(HaveOccurred()) + + timeMessage = scp.NewTimeMessage(fileInfo) + }) + + It("sends the message with the appropriate times", func() { + stdin := bytes.NewReader([]byte{0}) + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + err := timeMessage.Send(session) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).To(Equal("T123456789 0 987654321 0\n")) + }) + + It("writes the message before waiting for an acknowledgement", func() { + stdin := &fake_io.FakeReader{} + stdout := &fake_io.FakeWriter{} + stdoutBuffer := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + stdout.WriteStub = stdoutBuffer.Write + stdin.ReadStub = func(buffer []byte) (int, error) { + Expect(stdout.WriteCallCount()).To(BeNumerically(">", 0)) + + buffer[0] = 0 + return 1, nil + } + + err := timeMessage.Send(session) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdin.ReadCallCount()).To(BeNumerically(">", 0)) + Expect(stdoutBuffer.String()).To(Equal("T123456789 0 987654321 0\n")) + }) + + It("does not return before the message is confirmed", func() { + stdin, pw := io.Pipe() + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + errCh := make(chan error, 1) + go func() { + errCh <- timeMessage.Send(session) + }() + + Consistently(errCh).ShouldNot(Receive(HaveOccurred())) + + n, err := pw.Write([]byte{0}) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(1)) + + Expect(stdout.String()).To(Equal("T123456789 0 987654321 0\n")) + }) + + Context("when the sink responds with a warning", func() { + var stdin, stdout, stderr *bytes.Buffer + var session *scp.Session + + BeforeEach(func() { + stdin = &bytes.Buffer{} + stdout = &bytes.Buffer{} + stderr = &bytes.Buffer{} + + session = scp.NewSession(stdin, stdout, stderr, true, logger) + + stdin.WriteByte(1) + stdin.WriteString("Danger!\n") + }) + + It("returns without an error", func() { + err := timeMessage.Send(session) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.String()).To(Equal("T123456789 0 987654321 0\n")) + }) + + It("writes the message to stderr", func() { + timeMessage.Send(session) + Expect(stderr.String()).To(Equal("Danger!")) + }) + }) + + Context("when the sink responds with an error ", func() { + var stdin, stdout, stderr *bytes.Buffer + var session *scp.Session + + BeforeEach(func() { + stdin = &bytes.Buffer{} + stdout = &bytes.Buffer{} + stderr = &bytes.Buffer{} + + session = scp.NewSession(stdin, stdout, stderr, true, logger) + + stdin.WriteByte(2) + stdin.WriteString("oops...\n") + }) + + It("returns with an error", func() { + err := timeMessage.Send(session) + Expect(err).To(MatchError("oops...")) + }) + }) + }) + + Context("when receiving a time message from an scp source", func() { + var timeMessage *scp.TimeMessage + + BeforeEach(func() { + timeMessage = &scp.TimeMessage{} + }) + + It("creates a time message with the appropriate information", func() { + stdin := strings.NewReader("T123456789 0 987654321 0\n") + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + err := timeMessage.Receive(session) + Expect(err).NotTo(HaveOccurred()) + + Expect(timeMessage.ModificationTime()).To(Equal(time.Unix(123456789, 0))) + Expect(timeMessage.AccessTime()).To(Equal(time.Unix(987654321, 0))) + }) + + It("sends a confirmation after the message is received", func() { + reader := strings.NewReader("T123456789 0 987654321 0\n") + stdin := &fake_io.FakeReader{} + stdout := &fake_io.FakeWriter{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + stdin.ReadStub = reader.Read + stdout.WriteStub = func(message []byte) (int, error) { + Expect(stdin.ReadCallCount()).To(BeNumerically(">", 0)) + Expect(reader.Len()).To(Equal(0)) + + Expect(message).To(HaveLen(1)) + Expect(message[0]).To(BeEquivalentTo(0)) + + return 1, nil + } + + err := timeMessage.Receive(session) + Expect(err).NotTo(HaveOccurred()) + + Expect(stdout.WriteCallCount()).To(BeNumerically(">", 0)) + Expect(stdout.WriteArgsForCall(0)).To(Equal([]byte{0})) + }) + + Context("when the message is not a time message", func() { + It("fails with an error", func() { + stdin := strings.NewReader("$123456789 0 987654321 0\n") + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + err := timeMessage.Receive(session) + Expect(err).To(MatchError("unexpected message type: $")) + }) + }) + + Context("when the modification time field is not a number", func() { + It("fails with an error", func() { + stdin := strings.NewReader("Tmodification 0 987654321 0\n") + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + err := timeMessage.Receive(session) + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when the access time field is not a number", func() { + It("fails with an error", func() { + stdin := strings.NewReader("T123456789 0 access 0\n") + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + session := scp.NewSession(stdin, stdout, stderr, true, logger) + + err := timeMessage.Receive(session) + Expect(err).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/server/conn_handler.go b/src/code.cloudfoundry.org/diego-ssh/server/conn_handler.go new file mode 100644 index 0000000000..001fa11686 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/conn_handler.go @@ -0,0 +1,79 @@ +package server + +import ( + "net" + "sync" + "sync/atomic" +) + +type serverState int32 + +const ( + stateDefault = int32(iota) + stateStopped +) + +func (s *serverState) StopOnce() bool { + return atomic.CompareAndSwapInt32((*int32)(s), stateDefault, stateStopped) +} + +func (s *serverState) Stopped() bool { + return atomic.LoadInt32((*int32)(s)) == stateStopped +} + +type connHandler struct { + store map[net.Conn]struct{} + mu sync.Mutex + wg sync.WaitGroup + state serverState +} + +func (s *connHandler) remove(conn net.Conn) { + s.mu.Lock() + delete(s.store, conn) + s.wg.Done() + s.mu.Unlock() +} + +func (s *connHandler) handle(handler ConnectionHandler, conn net.Conn) { + defer s.remove(conn) + handler.HandleConnection(conn) +} + +func (s *connHandler) Handle(handler ConnectionHandler, conn net.Conn) { + // fast exit: don't attempt to acquire the mutex or + // handle the conn if shutdown + if s.state.Stopped() { + // #nosec G104 - ignore errors when closing SSH connections so we don't spam our logs during a DoS + conn.Close() + return + } + s.mu.Lock() + defer s.mu.Unlock() + // recheck the state now that we've locked the mutex + // as we may have been blocked on call to Shutdown() + if s.state.Stopped() { + // #nosec G104 - ignore errors when closing SSH connections so we don't spam our logs during a DoS + conn.Close() + return + } + // lazily initialize the store + if s.store == nil { + s.store = make(map[net.Conn]struct{}) + } + s.store[conn] = struct{}{} + s.wg.Add(1) + go s.handle(handler, conn) +} + +func (s *connHandler) Shutdown() { + if s.state.StopOnce() { + s.mu.Lock() + for c := range s.store { + // #nosec G104 - ignore errors when closing SSH connections so we don't spam our logs during a DoS + c.Close() + } + s.mu.Unlock() + s.wg.Wait() + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/server/fakes/fake_connection_handler.go b/src/code.cloudfoundry.org/diego-ssh/server/fakes/fake_connection_handler.go new file mode 100644 index 0000000000..6be5d05bb4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/fakes/fake_connection_handler.go @@ -0,0 +1,77 @@ +// Code generated by counterfeiter. DO NOT EDIT. +package fakes + +import ( + "net" + "sync" + + "code.cloudfoundry.org/diego-ssh/server" +) + +type FakeConnectionHandler struct { + HandleConnectionStub func(net.Conn) + handleConnectionMutex sync.RWMutex + handleConnectionArgsForCall []struct { + arg1 net.Conn + } + invocations map[string][][]interface{} + invocationsMutex sync.RWMutex +} + +func (fake *FakeConnectionHandler) HandleConnection(arg1 net.Conn) { + fake.handleConnectionMutex.Lock() + fake.handleConnectionArgsForCall = append(fake.handleConnectionArgsForCall, struct { + arg1 net.Conn + }{arg1}) + fake.recordInvocation("HandleConnection", []interface{}{arg1}) + handleConnectionStubCopy := fake.HandleConnectionStub + fake.handleConnectionMutex.Unlock() + if handleConnectionStubCopy != nil { + handleConnectionStubCopy(arg1) + } +} + +func (fake *FakeConnectionHandler) HandleConnectionCallCount() int { + fake.handleConnectionMutex.RLock() + defer fake.handleConnectionMutex.RUnlock() + return len(fake.handleConnectionArgsForCall) +} + +func (fake *FakeConnectionHandler) HandleConnectionCalls(stub func(net.Conn)) { + fake.handleConnectionMutex.Lock() + defer fake.handleConnectionMutex.Unlock() + fake.HandleConnectionStub = stub +} + +func (fake *FakeConnectionHandler) HandleConnectionArgsForCall(i int) net.Conn { + fake.handleConnectionMutex.RLock() + defer fake.handleConnectionMutex.RUnlock() + argsForCall := fake.handleConnectionArgsForCall[i] + return argsForCall.arg1 +} + +func (fake *FakeConnectionHandler) Invocations() map[string][][]interface{} { + fake.invocationsMutex.RLock() + defer fake.invocationsMutex.RUnlock() + fake.handleConnectionMutex.RLock() + defer fake.handleConnectionMutex.RUnlock() + copiedInvocations := map[string][][]interface{}{} + for key, value := range fake.invocations { + copiedInvocations[key] = value + } + return copiedInvocations +} + +func (fake *FakeConnectionHandler) recordInvocation(key string, args []interface{}) { + fake.invocationsMutex.Lock() + defer fake.invocationsMutex.Unlock() + if fake.invocations == nil { + fake.invocations = map[string][][]interface{}{} + } + if fake.invocations[key] == nil { + fake.invocations[key] = [][]interface{}{} + } + fake.invocations[key] = append(fake.invocations[key], args) +} + +var _ server.ConnectionHandler = new(FakeConnectionHandler) diff --git a/src/code.cloudfoundry.org/diego-ssh/server/fakes/package.go b/src/code.cloudfoundry.org/diego-ssh/server/fakes/package.go new file mode 100644 index 0000000000..97d5f1c665 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/fakes/package.go @@ -0,0 +1 @@ +package fakes // import "code.cloudfoundry.org/diego-ssh/server/fakes" diff --git a/src/code.cloudfoundry.org/diego-ssh/server/package.go b/src/code.cloudfoundry.org/diego-ssh/server/package.go new file mode 100644 index 0000000000..10696671dc --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/package.go @@ -0,0 +1 @@ +package server // import "code.cloudfoundry.org/diego-ssh/server" diff --git a/src/code.cloudfoundry.org/diego-ssh/server/server.go b/src/code.cloudfoundry.org/diego-ssh/server/server.go new file mode 100644 index 0000000000..75236a90b9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/server.go @@ -0,0 +1,139 @@ +package server + +import ( + "errors" + "net" + "os" + "sync" + "time" + + "code.cloudfoundry.org/lager/v3" +) + +//go:generate counterfeiter -o fakes/fake_connection_handler.go . ConnectionHandler +type ConnectionHandler interface { + HandleConnection(net.Conn) +} + +type Server struct { + logger lager.Logger + listenAddress string + connectionHandler ConnectionHandler + listener net.Listener + mutex *sync.Mutex + state serverState + idleConnTimeout time.Duration + store connHandler +} + +func NewServer( + logger lager.Logger, + listenAddress string, + connectionHandler ConnectionHandler, + idleConnTimeout time.Duration, +) *Server { + return &Server{ + logger: logger.Session("server"), + listenAddress: listenAddress, + connectionHandler: connectionHandler, + mutex: &sync.Mutex{}, + idleConnTimeout: idleConnTimeout, + } +} + +func (s *Server) Run(signals <-chan os.Signal, ready chan<- struct{}) error { + listener, err := net.Listen("tcp", s.listenAddress) + if err != nil { + return err + } + + s.SetListener(listener) + go s.Serve() + + close(ready) + + <-signals + s.Shutdown() + + return nil +} + +func (s *Server) Shutdown() { + if s.state.StopOnce() { + s.logger.Info("stopping-server") + err := s.listener.Close() + if err != nil { + s.logger.Error("listener-failed-to-close", err) + } + s.store.Shutdown() + } +} + +func (s *Server) IsStopping() bool { return s.state.Stopped() } + +func (s *Server) SetListener(listener net.Listener) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.listener != nil { + err := errors.New("Listener has already been set") + s.logger.Error("listener-already-set", err) + } + + s.listener = listener +} + +func (s *Server) ListenAddr() (net.Addr, error) { + if s.listener == nil { + return nil, errors.New("No listener") + } + + return s.listener.Addr(), nil +} + +type idleTimeoutConn struct { + Timeout time.Duration + net.Conn +} + +func (c *idleTimeoutConn) Read(b []byte) (n int, err error) { + if err = c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { + return + } + return c.Conn.Read(b) +} + +func (c *idleTimeoutConn) Write(b []byte) (n int, err error) { + if err = c.Conn.SetDeadline(time.Now().Add(c.Timeout)); err != nil { + return + } + return c.Conn.Write(b) +} + +func (s *Server) Serve() { + logger := s.logger.Session("serve") + defer s.listener.Close() + + for { + netConn, err := s.listener.Accept() + if s.idleConnTimeout > 0 { + netConn = &idleTimeoutConn{s.idleConnTimeout, netConn} + } + if err != nil { + //lint:ignore SA1019 - http.Server still uses this logic, and they dont want to update it because its scary. Following their lead. + if netErr, ok := err.(net.Error); ok && netErr.Temporary() { + logger.Error("accept-temporary-error", netErr) + time.Sleep(100 * time.Millisecond) + continue + } + + if s.IsStopping() { + break + } + + logger.Error("accept-failed", err) + return + } + s.store.Handle(s.connectionHandler, netConn) + } +} diff --git a/src/code.cloudfoundry.org/diego-ssh/server/server_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/server/server_suite_test.go new file mode 100644 index 0000000000..f9ad369843 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/server_suite_test.go @@ -0,0 +1,26 @@ +package server_test + +import ( + "code.cloudfoundry.org/inigo/helpers/portauthority" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +var portAllocator portauthority.PortAllocator +var _ = BeforeSuite(func() { + node := GinkgoParallelProcess() + startPort := 1070 * node + portRange := 1000 + endPort := startPort + portRange + + var err error + portAllocator, err = portauthority.New(startPort, endPort) + Expect(err).NotTo(HaveOccurred()) +}) + +func TestServer(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Server Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/server/server_test.go b/src/code.cloudfoundry.org/diego-ssh/server/server_test.go new file mode 100644 index 0000000000..c5c1870904 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/server/server_test.go @@ -0,0 +1,244 @@ +package server_test + +import ( + "errors" + "fmt" + "net" + "os" + "time" + + "code.cloudfoundry.org/diego-ssh/server" + "code.cloudfoundry.org/diego-ssh/server/fakes" + "code.cloudfoundry.org/diego-ssh/test_helpers" + "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" + "code.cloudfoundry.org/lager/v3" + "code.cloudfoundry.org/lager/v3/lagertest" + "github.com/tedsuo/ifrit" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "github.com/onsi/gomega/gbytes" +) + +var _ = Describe("Server", func() { + var ( + logger lager.Logger + srv *server.Server + + handler *fakes.FakeConnectionHandler + + address string + ) + + BeforeEach(func() { + port, err := portAllocator.ClaimPorts(1) + Expect(err).NotTo(HaveOccurred()) + + handler = &fakes.FakeConnectionHandler{} + address = fmt.Sprintf("127.0.0.1:%d", port) + logger = lagertest.NewTestLogger("test") + }) + + Describe("Run", func() { + var process ifrit.Process + + BeforeEach(func() { + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + process = ifrit.Invoke(srv) + }) + + AfterEach(func() { + process.Signal(os.Interrupt) + Eventually(process.Wait()).Should(Receive()) + }) + + It("accepts connections on the specified address", func() { + _, err := net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + }) + + Context("when a second client connects", func() { + JustBeforeEach(func() { + _, err := net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + }) + + It("accepts the new connection", func() { + _, err := net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + }) + }) + }) + + Describe("SetListener", func() { + var fakeListener *fake_net.FakeListener + + BeforeEach(func() { + fakeListener = &fake_net.FakeListener{} + + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + srv.SetListener(fakeListener) + }) + + Context("when a listener has already been set", func() { + It("logs", func() { + listener := &fake_net.FakeListener{} + srv.SetListener(listener) + + Expect(logger.(*lagertest.TestLogger).Logs()[0].Message).To(ContainSubstring("listener-already-set")) + }) + }) + }) + + Describe("Serve", func() { + var fakeListener *fake_net.FakeListener + var fakeConn *fake_net.FakeConn + + BeforeEach(func() { + fakeListener = &fake_net.FakeListener{} + fakeConn = &fake_net.FakeConn{} + + connectionCh := make(chan net.Conn, 1) + connectionCh <- fakeConn + + fakeListener.AcceptStub = func() (net.Conn, error) { + cx := connectionCh + select { + case conn := <-cx: + return conn, nil + default: + return nil, errors.New("fail") + } + } + }) + + JustBeforeEach(func() { + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + srv.SetListener(fakeListener) + srv.Serve() + }) + + It("accepts inbound connections", func() { + Expect(fakeListener.AcceptCallCount()).To(Equal(2)) + }) + + It("passes the connection to the connection handler", func() { + Eventually(handler.HandleConnectionCallCount).Should(Equal(1)) + conn := handler.HandleConnectionArgsForCall(0) + conn.Read(nil) + Expect(fakeConn.ReadCallCount()).To(Equal(1)) + }) + + It("sets a deadline on the connection", func() { + Eventually(handler.HandleConnectionCallCount).Should(Equal(1)) + conn := handler.HandleConnectionArgsForCall(0) + conn.Read(nil) + Expect(fakeConn.SetDeadlineCallCount()).To(Equal(1)) + t := fakeConn.SetDeadlineArgsForCall(0) + Expect(time.Until(t)).To(BeNumerically("<=", 500*time.Millisecond)) + }) + + Context("when accept returns a permanent error", func() { + BeforeEach(func() { + fakeListener.AcceptReturns(nil, errors.New("oops")) + }) + + It("closes the listener", func() { + Expect(fakeListener.CloseCallCount()).To(Equal(1)) + }) + }) + + Context("when accept returns a temporary error", func() { + var timeCh chan time.Time + + BeforeEach(func() { + timeCh = make(chan time.Time, 3) + + fakeListener.AcceptStub = func() (net.Conn, error) { + timeCh := timeCh + select { + case timeCh <- time.Now(): + return nil, test_helpers.NewTestNetError(false, true) + default: + close(timeCh) + return nil, test_helpers.NewTestNetError(false, false) + } + } + }) + + It("retries the accept after a short delay", func() { + Expect(timeCh).To(HaveLen(3)) + + times := make([]time.Time, 0) + for t := range timeCh { + times = append(times, t) + } + + Expect(times[1]).To(BeTemporally("~", times[0].Add(100*time.Millisecond), 20*time.Millisecond)) + Expect(times[2]).To(BeTemporally("~", times[1].Add(100*time.Millisecond), 20*time.Millisecond)) + }) + }) + }) + + Describe("Shutdown", func() { + var fakeListener *fake_net.FakeListener + + BeforeEach(func() { + fakeListener = &fake_net.FakeListener{} + + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + srv.SetListener(fakeListener) + }) + + Context("when the server is shutdown", func() { + BeforeEach(func() { + srv.Shutdown() + }) + + It("closes the listener", func() { + Expect(fakeListener.CloseCallCount()).To(Equal(1)) + }) + + It("marks the server as stopping", func() { + Expect(srv.IsStopping()).To(BeTrue()) + }) + + It("does not log an accept failure", func() { + Eventually(func() error { + _, err := net.Dial("tcp", address) + return err + }).Should(HaveOccurred()) + Consistently(logger).ShouldNot(gbytes.Say("test.serve.accept-failed")) + }) + }) + }) + + Describe("ListenAddr", func() { + var listener net.Listener + BeforeEach(func() { + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + }) + + Context("when the server has a listener", func() { + BeforeEach(func() { + var err error + listener, err = net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + + srv = server.NewServer(logger, address, handler, 500*time.Millisecond) + srv.SetListener(listener) + }) + + It("returns the address reported by the listener", func() { + Expect(srv.ListenAddr()).To(Equal(listener.Addr())) + }) + }) + + Context("when the server does not have a listener", func() { + It("returns an error", func() { + _, err := srv.ListenAddr() + Expect(err).To(HaveOccurred()) + }) + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/signals/package.go b/src/code.cloudfoundry.org/diego-ssh/signals/package.go new file mode 100644 index 0000000000..eb4a7c9530 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/signals/package.go @@ -0,0 +1 @@ +package signals // import "code.cloudfoundry.org/diego-ssh/signals" diff --git a/src/code.cloudfoundry.org/diego-ssh/signals/signals.go b/src/code.cloudfoundry.org/diego-ssh/signals/signals.go new file mode 100644 index 0000000000..8d60dea806 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/signals/signals.go @@ -0,0 +1,41 @@ +//go:build !windows + +package signals + +import ( + "syscall" + + "golang.org/x/crypto/ssh" +) + +var SyscallSignals = map[ssh.Signal]syscall.Signal{ + ssh.SIGABRT: syscall.SIGABRT, + ssh.SIGALRM: syscall.SIGALRM, + ssh.SIGFPE: syscall.SIGFPE, + ssh.SIGHUP: syscall.SIGHUP, + ssh.SIGILL: syscall.SIGILL, + ssh.SIGINT: syscall.SIGINT, + ssh.SIGKILL: syscall.SIGKILL, + ssh.SIGPIPE: syscall.SIGPIPE, + ssh.SIGQUIT: syscall.SIGQUIT, + ssh.SIGSEGV: syscall.SIGSEGV, + ssh.SIGTERM: syscall.SIGTERM, + ssh.SIGUSR1: syscall.SIGUSR1, + ssh.SIGUSR2: syscall.SIGUSR2, +} + +var SSHSignals = map[syscall.Signal]ssh.Signal{ + syscall.SIGABRT: ssh.SIGABRT, + syscall.SIGALRM: ssh.SIGALRM, + syscall.SIGFPE: ssh.SIGFPE, + syscall.SIGHUP: ssh.SIGHUP, + syscall.SIGILL: ssh.SIGILL, + syscall.SIGINT: ssh.SIGINT, + syscall.SIGKILL: ssh.SIGKILL, + syscall.SIGPIPE: ssh.SIGPIPE, + syscall.SIGQUIT: ssh.SIGQUIT, + syscall.SIGSEGV: ssh.SIGSEGV, + syscall.SIGTERM: ssh.SIGTERM, + syscall.SIGUSR1: ssh.SIGUSR1, + syscall.SIGUSR2: ssh.SIGUSR2, +} diff --git a/src/code.cloudfoundry.org/diego-ssh/signals/signals_suite_test.go b/src/code.cloudfoundry.org/diego-ssh/signals/signals_suite_test.go new file mode 100644 index 0000000000..8ba847af7e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/signals/signals_suite_test.go @@ -0,0 +1,13 @@ +package signals_test + +import ( + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "testing" +) + +func TestSignals(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Signals Suite") +} diff --git a/src/code.cloudfoundry.org/diego-ssh/signals/signals_test.go b/src/code.cloudfoundry.org/diego-ssh/signals/signals_test.go new file mode 100644 index 0000000000..38742217a9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/signals/signals_test.go @@ -0,0 +1,26 @@ +package signals_test + +import ( + "code.cloudfoundry.org/diego-ssh/signals" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("Signals", func() { + Describe("Signal Mapping", func() { + It("should have the same length map", func() { + Expect(signals.SyscallSignals).To(HaveLen(len(signals.SSHSignals))) + }) + + It("has the correct mapping", func() { + for k, v := range signals.SyscallSignals { + Expect(k).To(Equal(signals.SSHSignals[v])) + } + + for k, v := range signals.SSHSignals { + Expect(k).To(Equal(signals.SyscallSignals[v])) + } + }) + }) +}) diff --git a/src/code.cloudfoundry.org/diego-ssh/signals/signals_windows.go b/src/code.cloudfoundry.org/diego-ssh/signals/signals_windows.go new file mode 100644 index 0000000000..d0d46b2ba4 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/signals/signals_windows.go @@ -0,0 +1,37 @@ +//go:build windows + +package signals + +import ( + "syscall" + + "golang.org/x/crypto/ssh" +) + +var SyscallSignals = map[ssh.Signal]syscall.Signal{ + ssh.SIGABRT: syscall.SIGABRT, + ssh.SIGALRM: syscall.SIGALRM, + ssh.SIGFPE: syscall.SIGFPE, + ssh.SIGHUP: syscall.SIGHUP, + ssh.SIGILL: syscall.SIGILL, + ssh.SIGINT: syscall.SIGINT, + ssh.SIGKILL: syscall.SIGKILL, + ssh.SIGPIPE: syscall.SIGPIPE, + ssh.SIGQUIT: syscall.SIGQUIT, + ssh.SIGSEGV: syscall.SIGSEGV, + ssh.SIGTERM: syscall.SIGTERM, +} + +var SSHSignals = map[syscall.Signal]ssh.Signal{ + syscall.SIGABRT: ssh.SIGABRT, + syscall.SIGALRM: ssh.SIGALRM, + syscall.SIGFPE: ssh.SIGFPE, + syscall.SIGHUP: ssh.SIGHUP, + syscall.SIGILL: ssh.SIGILL, + syscall.SIGINT: ssh.SIGINT, + syscall.SIGKILL: ssh.SIGKILL, + syscall.SIGPIPE: ssh.SIGPIPE, + syscall.SIGQUIT: ssh.SIGQUIT, + syscall.SIGSEGV: ssh.SIGSEGV, + syscall.SIGTERM: ssh.SIGTERM, +} diff --git a/src/code.cloudfoundry.org/diego-ssh/termcodes/package.go b/src/code.cloudfoundry.org/diego-ssh/termcodes/package.go new file mode 100644 index 0000000000..3f8147c598 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/termcodes/package.go @@ -0,0 +1 @@ +package termcodes // import "code.cloudfoundry.org/diego-ssh/termcodes" diff --git a/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes.go b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes.go new file mode 100644 index 0000000000..a17634324b --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes.go @@ -0,0 +1,152 @@ +//go:build !windows + +package termcodes + +import ( + "os" + "syscall" + + "golang.org/x/crypto/ssh" +) + +// struct termios { +// tcflag_t c_iflag; /* input modes */ +// tcflag_t c_oflag; /* output modes */ +// tcflag_t c_cflag; /* control modes */ +// tcflag_t c_lflag; /* local modes */ +// cc_t c_cc[NCCS]; /* special characters */ +// speed_t c_ispeed; +// speed_t c_ospeed; +// }; + +type Setter interface { + Set(pty *os.File, termios *syscall.Termios, value uint32) error +} + +var TermAttrSetters map[uint8]Setter = map[uint8]Setter{ + ssh.VINTR: &ccSetter{Character: syscall.VINTR}, + ssh.VQUIT: &ccSetter{Character: syscall.VQUIT}, + ssh.VERASE: &ccSetter{Character: syscall.VERASE}, + ssh.VKILL: &ccSetter{Character: syscall.VKILL}, + ssh.VEOF: &ccSetter{Character: syscall.VEOF}, + ssh.VEOL: &ccSetter{Character: syscall.VEOL}, + ssh.VEOL2: &ccSetter{Character: syscall.VEOL2}, + ssh.VSTART: &ccSetter{Character: syscall.VSTART}, + ssh.VSTOP: &ccSetter{Character: syscall.VSTOP}, + ssh.VSUSP: &ccSetter{Character: syscall.VSUSP}, + ssh.VDSUSP: &nopSetter{}, + ssh.VREPRINT: &ccSetter{Character: syscall.VREPRINT}, + ssh.VWERASE: &ccSetter{Character: syscall.VWERASE}, + ssh.VLNEXT: &ccSetter{Character: syscall.VLNEXT}, + ssh.VFLUSH: &nopSetter{}, + ssh.VSWTCH: &nopSetter{}, + ssh.VSTATUS: &nopSetter{}, + ssh.VDISCARD: &ccSetter{Character: syscall.VDISCARD}, + + // Input modes + ssh.IGNPAR: &iflagSetter{Flag: syscall.IGNPAR}, + ssh.PARMRK: &iflagSetter{Flag: syscall.PARMRK}, + ssh.INPCK: &iflagSetter{Flag: syscall.INPCK}, + ssh.ISTRIP: &iflagSetter{Flag: syscall.ISTRIP}, + ssh.INLCR: &iflagSetter{Flag: syscall.INLCR}, + ssh.IGNCR: &iflagSetter{Flag: syscall.IGNCR}, + ssh.ICRNL: &iflagSetter{Flag: syscall.ICRNL}, + ssh.IUCLC: &nopSetter{}, + ssh.IXON: &iflagSetter{Flag: syscall.IXON}, + ssh.IXANY: &iflagSetter{Flag: syscall.IXANY}, + ssh.IXOFF: &iflagSetter{Flag: syscall.IXOFF}, + ssh.IMAXBEL: &iflagSetter{Flag: syscall.IMAXBEL}, + + // Local modes + ssh.ISIG: &lflagSetter{Flag: syscall.ISIG}, + ssh.ICANON: &lflagSetter{Flag: syscall.ICANON}, + ssh.XCASE: &nopSetter{}, + ssh.ECHO: &lflagSetter{Flag: syscall.ECHO}, + ssh.ECHOE: &lflagSetter{Flag: syscall.ECHOE}, + ssh.ECHOK: &lflagSetter{Flag: syscall.ECHOK}, + ssh.ECHONL: &lflagSetter{Flag: syscall.ECHONL}, + ssh.NOFLSH: &lflagSetter{Flag: syscall.NOFLSH}, + ssh.TOSTOP: &lflagSetter{Flag: syscall.TOSTOP}, + ssh.IEXTEN: &lflagSetter{Flag: syscall.IEXTEN}, + ssh.ECHOCTL: &lflagSetter{Flag: syscall.ECHOCTL}, + ssh.ECHOKE: &lflagSetter{Flag: syscall.ECHOKE}, + ssh.PENDIN: &lflagSetter{Flag: syscall.PENDIN}, + + // Output modes + ssh.OPOST: &oflagSetter{Flag: syscall.OPOST}, + ssh.OLCUC: &nopSetter{}, + ssh.ONLCR: &oflagSetter{Flag: syscall.ONLCR}, + ssh.OCRNL: &oflagSetter{Flag: syscall.OCRNL}, + ssh.ONOCR: &oflagSetter{Flag: syscall.ONOCR}, + ssh.ONLRET: &oflagSetter{Flag: syscall.ONLRET}, + + // Control modes + ssh.CS7: &cflagSetter{Flag: syscall.CS7}, + ssh.CS8: &cflagSetter{Flag: syscall.CS8}, + ssh.PARENB: &cflagSetter{Flag: syscall.PARENB}, + ssh.PARODD: &cflagSetter{Flag: syscall.PARODD}, + + // Baud rates (ignore) + ssh.TTY_OP_ISPEED: &nopSetter{}, + ssh.TTY_OP_OSPEED: &nopSetter{}, +} + +type nopSetter struct{} + +type ccSetter struct { + Character uint8 +} + +func (cc *ccSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + termios.Cc[cc.Character] = byte(value) + return SetAttr(pty, termios) +} + +func (i *iflagSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + if value == 0 { + termios.Iflag &^= i.Flag + } else { + termios.Iflag |= i.Flag + } + return SetAttr(pty, termios) +} + +func (l *lflagSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + if value == 0 { + termios.Lflag &^= l.Flag + } else { + termios.Lflag |= l.Flag + } + return SetAttr(pty, termios) +} + +func (o *oflagSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + if value == 0 { + termios.Oflag &^= o.Flag + } else { + termios.Oflag |= o.Flag + } + + return SetAttr(pty, termios) +} + +func (c *cflagSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + switch c.Flag { + // CSIZE is a field + case syscall.CS7, syscall.CS8: + termios.Cflag &^= syscall.CSIZE + termios.Cflag |= c.Flag + default: + if value == 0 { + termios.Cflag &^= c.Flag + } else { + termios.Cflag |= c.Flag + } + } + + return SetAttr(pty, termios) +} + +func (n *nopSetter) Set(pty *os.File, termios *syscall.Termios, value uint32) error { + return nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_darwin.go b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_darwin.go new file mode 100644 index 0000000000..176dee99c9 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_darwin.go @@ -0,0 +1,45 @@ +//go:build darwin + +package termcodes + +import ( + "os" + "syscall" + "unsafe" +) + +type iflagSetter struct { + Flag uint64 +} + +type lflagSetter struct { + Flag uint64 +} + +type oflagSetter struct { + Flag uint64 +} + +type cflagSetter struct { + Flag uint64 +} + +func SetAttr(tty *os.File, termios *syscall.Termios) error { + r, _, e := syscall.Syscall(syscall.SYS_IOCTL, tty.Fd(), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios))) + if r != 0 { + return os.NewSyscallError("SYS_IOCTL", e) + } + + return nil +} + +func GetAttr(tty *os.File) (*syscall.Termios, error) { + termios := &syscall.Termios{} + + r, _, e := syscall.Syscall(syscall.SYS_IOCTL, tty.Fd(), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios))) + if r != 0 { + return nil, os.NewSyscallError("SYS_IOCTL", e) + } + + return termios, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_linux.go b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_linux.go new file mode 100644 index 0000000000..ae37c49ea8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/termcodes/termcodes_linux.go @@ -0,0 +1,45 @@ +//go:build linux + +package termcodes + +import ( + "os" + "syscall" + "unsafe" +) + +type iflagSetter struct { + Flag uint32 +} + +type lflagSetter struct { + Flag uint32 +} + +type oflagSetter struct { + Flag uint32 +} + +type cflagSetter struct { + Flag uint32 +} + +func SetAttr(tty *os.File, termios *syscall.Termios) error { + r, _, e := syscall.Syscall(syscall.SYS_IOCTL, tty.Fd(), syscall.TCSETS, uintptr(unsafe.Pointer(termios))) + if r != 0 { + return os.NewSyscallError("SYS_IOCTL", e) + } + + return nil +} + +func GetAttr(tty *os.File) (*syscall.Termios, error) { + termios := &syscall.Termios{} + + r, _, e := syscall.Syscall(syscall.SYS_IOCTL, tty.Fd(), syscall.TCGETS, uintptr(unsafe.Pointer(termios))) + if r != 0 { + return nil, os.NewSyscallError("SYS_IOCTL", e) + } + + return termios, nil +} diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_read_closer.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_read_closer.go new file mode 100644 index 0000000000..2bbb45ac36 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_read_closer.go @@ -0,0 +1,84 @@ +// This file was generated by counterfeiter +package fake_io + +import ( + "io" + "sync" +) + +type FakeReadCloser struct { + ReadStub func(p []byte) (n int, err error) + readMutex sync.RWMutex + readArgsForCall []struct { + p []byte + } + readReturns struct { + result1 int + result2 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } +} + +func (fake *FakeReadCloser) Read(p []byte) (n int, err error) { + fake.readMutex.Lock() + fake.readArgsForCall = append(fake.readArgsForCall, struct { + p []byte + }{p}) + fake.readMutex.Unlock() + if fake.ReadStub != nil { + return fake.ReadStub(p) + } else { + return fake.readReturns.result1, fake.readReturns.result2 + } +} + +func (fake *FakeReadCloser) ReadCallCount() int { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return len(fake.readArgsForCall) +} + +func (fake *FakeReadCloser) ReadArgsForCall(i int) []byte { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return fake.readArgsForCall[i].p +} + +func (fake *FakeReadCloser) ReadReturns(result1 int, result2 error) { + fake.ReadStub = nil + fake.readReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeReadCloser) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeReadCloser) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeReadCloser) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +var _ io.ReadCloser = new(FakeReadCloser) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_reader.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_reader.go new file mode 100644 index 0000000000..a4f17c4cde --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_reader.go @@ -0,0 +1,54 @@ +// This file was generated by counterfeiter +package fake_io + +import ( + "io" + "sync" +) + +type FakeReader struct { + ReadStub func(p []byte) (n int, err error) + readMutex sync.RWMutex + readArgsForCall []struct { + p []byte + } + readReturns struct { + result1 int + result2 error + } +} + +func (fake *FakeReader) Read(p []byte) (n int, err error) { + fake.readMutex.Lock() + fake.readArgsForCall = append(fake.readArgsForCall, struct { + p []byte + }{p}) + fake.readMutex.Unlock() + if fake.ReadStub != nil { + return fake.ReadStub(p) + } else { + return fake.readReturns.result1, fake.readReturns.result2 + } +} + +func (fake *FakeReader) ReadCallCount() int { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return len(fake.readArgsForCall) +} + +func (fake *FakeReader) ReadArgsForCall(i int) []byte { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return fake.readArgsForCall[i].p +} + +func (fake *FakeReader) ReadReturns(result1 int, result2 error) { + fake.ReadStub = nil + fake.readReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +var _ io.Reader = new(FakeReader) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_write_closer.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_write_closer.go new file mode 100644 index 0000000000..2c64804714 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_write_closer.go @@ -0,0 +1,84 @@ +// This file was generated by counterfeiter +package fake_io + +import ( + "io" + "sync" +) + +type FakeWriteCloser struct { + WriteStub func(p []byte) (n int, err error) + writeMutex sync.RWMutex + writeArgsForCall []struct { + p []byte + } + writeReturns struct { + result1 int + result2 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } +} + +func (fake *FakeWriteCloser) Write(p []byte) (n int, err error) { + fake.writeMutex.Lock() + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + p []byte + }{p}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + return fake.WriteStub(p) + } else { + return fake.writeReturns.result1, fake.writeReturns.result2 + } +} + +func (fake *FakeWriteCloser) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeWriteCloser) WriteArgsForCall(i int) []byte { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return fake.writeArgsForCall[i].p +} + +func (fake *FakeWriteCloser) WriteReturns(result1 int, result2 error) { + fake.WriteStub = nil + fake.writeReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeWriteCloser) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeWriteCloser) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeWriteCloser) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +var _ io.WriteCloser = new(FakeWriteCloser) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_writer.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_writer.go new file mode 100644 index 0000000000..54c2f047f8 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/fake_writer.go @@ -0,0 +1,54 @@ +// This file was generated by counterfeiter +package fake_io + +import ( + "io" + "sync" +) + +type FakeWriter struct { + WriteStub func(p []byte) (n int, err error) + writeMutex sync.RWMutex + writeArgsForCall []struct { + p []byte + } + writeReturns struct { + result1 int + result2 error + } +} + +func (fake *FakeWriter) Write(p []byte) (n int, err error) { + fake.writeMutex.Lock() + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + p []byte + }{p}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + return fake.WriteStub(p) + } else { + return fake.writeReturns.result1, fake.writeReturns.result2 + } +} + +func (fake *FakeWriter) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeWriter) WriteArgsForCall(i int) []byte { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return fake.writeArgsForCall[i].p +} + +func (fake *FakeWriter) WriteReturns(result1 int, result2 error) { + fake.WriteStub = nil + fake.writeReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +var _ io.Writer = new(FakeWriter) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/package.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/package.go new file mode 100644 index 0000000000..bec2f37464 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_io/package.go @@ -0,0 +1 @@ +package fake_io // import "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_conn.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_conn.go new file mode 100644 index 0000000000..d5f7d893f0 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_conn.go @@ -0,0 +1,307 @@ +// This file was generated by counterfeiter +package fake_net + +import ( + "net" + "sync" + "time" +) + +type FakeConn struct { + ReadStub func(b []byte) (n int, err error) + readMutex sync.RWMutex + readArgsForCall []struct { + b []byte + } + readReturns struct { + result1 int + result2 error + } + WriteStub func(b []byte) (n int, err error) + writeMutex sync.RWMutex + writeArgsForCall []struct { + b []byte + } + writeReturns struct { + result1 int + result2 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } + LocalAddrStub func() net.Addr + localAddrMutex sync.RWMutex + localAddrArgsForCall []struct{} + localAddrReturns struct { + result1 net.Addr + } + RemoteAddrStub func() net.Addr + remoteAddrMutex sync.RWMutex + remoteAddrArgsForCall []struct{} + remoteAddrReturns struct { + result1 net.Addr + } + SetDeadlineStub func(t time.Time) error + setDeadlineMutex sync.RWMutex + setDeadlineArgsForCall []struct { + t time.Time + } + setDeadlineReturns struct { + result1 error + } + SetReadDeadlineStub func(t time.Time) error + setReadDeadlineMutex sync.RWMutex + setReadDeadlineArgsForCall []struct { + t time.Time + } + setReadDeadlineReturns struct { + result1 error + } + SetWriteDeadlineStub func(t time.Time) error + setWriteDeadlineMutex sync.RWMutex + setWriteDeadlineArgsForCall []struct { + t time.Time + } + setWriteDeadlineReturns struct { + result1 error + } +} + +func (fake *FakeConn) Read(b []byte) (n int, err error) { + fake.readMutex.Lock() + fake.readArgsForCall = append(fake.readArgsForCall, struct { + b []byte + }{b}) + fake.readMutex.Unlock() + if fake.ReadStub != nil { + return fake.ReadStub(b) + } else { + return fake.readReturns.result1, fake.readReturns.result2 + } +} + +func (fake *FakeConn) ReadCallCount() int { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return len(fake.readArgsForCall) +} + +func (fake *FakeConn) ReadArgsForCall(i int) []byte { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return fake.readArgsForCall[i].b +} + +func (fake *FakeConn) ReadReturns(result1 int, result2 error) { + fake.ReadStub = nil + fake.readReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeConn) Write(b []byte) (n int, err error) { + fake.writeMutex.Lock() + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + b []byte + }{b}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + return fake.WriteStub(b) + } else { + return fake.writeReturns.result1, fake.writeReturns.result2 + } +} + +func (fake *FakeConn) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeConn) WriteArgsForCall(i int) []byte { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return fake.writeArgsForCall[i].b +} + +func (fake *FakeConn) WriteReturns(result1 int, result2 error) { + fake.WriteStub = nil + fake.writeReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeConn) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeConn) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeConn) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeConn) LocalAddr() net.Addr { + fake.localAddrMutex.Lock() + fake.localAddrArgsForCall = append(fake.localAddrArgsForCall, struct{}{}) + fake.localAddrMutex.Unlock() + if fake.LocalAddrStub != nil { + return fake.LocalAddrStub() + } else { + return fake.localAddrReturns.result1 + } +} + +func (fake *FakeConn) LocalAddrCallCount() int { + fake.localAddrMutex.RLock() + defer fake.localAddrMutex.RUnlock() + return len(fake.localAddrArgsForCall) +} + +func (fake *FakeConn) LocalAddrReturns(result1 net.Addr) { + fake.LocalAddrStub = nil + fake.localAddrReturns = struct { + result1 net.Addr + }{result1} +} + +func (fake *FakeConn) RemoteAddr() net.Addr { + fake.remoteAddrMutex.Lock() + fake.remoteAddrArgsForCall = append(fake.remoteAddrArgsForCall, struct{}{}) + fake.remoteAddrMutex.Unlock() + if fake.RemoteAddrStub != nil { + return fake.RemoteAddrStub() + } else { + return fake.remoteAddrReturns.result1 + } +} + +func (fake *FakeConn) RemoteAddrCallCount() int { + fake.remoteAddrMutex.RLock() + defer fake.remoteAddrMutex.RUnlock() + return len(fake.remoteAddrArgsForCall) +} + +func (fake *FakeConn) RemoteAddrReturns(result1 net.Addr) { + fake.RemoteAddrStub = nil + fake.remoteAddrReturns = struct { + result1 net.Addr + }{result1} +} + +func (fake *FakeConn) SetDeadline(t time.Time) error { + fake.setDeadlineMutex.Lock() + fake.setDeadlineArgsForCall = append(fake.setDeadlineArgsForCall, struct { + t time.Time + }{t}) + fake.setDeadlineMutex.Unlock() + if fake.SetDeadlineStub != nil { + return fake.SetDeadlineStub(t) + } else { + return fake.setDeadlineReturns.result1 + } +} + +func (fake *FakeConn) SetDeadlineCallCount() int { + fake.setDeadlineMutex.RLock() + defer fake.setDeadlineMutex.RUnlock() + return len(fake.setDeadlineArgsForCall) +} + +func (fake *FakeConn) SetDeadlineArgsForCall(i int) time.Time { + fake.setDeadlineMutex.RLock() + defer fake.setDeadlineMutex.RUnlock() + return fake.setDeadlineArgsForCall[i].t +} + +func (fake *FakeConn) SetDeadlineReturns(result1 error) { + fake.SetDeadlineStub = nil + fake.setDeadlineReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeConn) SetReadDeadline(t time.Time) error { + fake.setReadDeadlineMutex.Lock() + fake.setReadDeadlineArgsForCall = append(fake.setReadDeadlineArgsForCall, struct { + t time.Time + }{t}) + fake.setReadDeadlineMutex.Unlock() + if fake.SetReadDeadlineStub != nil { + return fake.SetReadDeadlineStub(t) + } else { + return fake.setReadDeadlineReturns.result1 + } +} + +func (fake *FakeConn) SetReadDeadlineCallCount() int { + fake.setReadDeadlineMutex.RLock() + defer fake.setReadDeadlineMutex.RUnlock() + return len(fake.setReadDeadlineArgsForCall) +} + +func (fake *FakeConn) SetReadDeadlineArgsForCall(i int) time.Time { + fake.setReadDeadlineMutex.RLock() + defer fake.setReadDeadlineMutex.RUnlock() + return fake.setReadDeadlineArgsForCall[i].t +} + +func (fake *FakeConn) SetReadDeadlineReturns(result1 error) { + fake.SetReadDeadlineStub = nil + fake.setReadDeadlineReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeConn) SetWriteDeadline(t time.Time) error { + fake.setWriteDeadlineMutex.Lock() + fake.setWriteDeadlineArgsForCall = append(fake.setWriteDeadlineArgsForCall, struct { + t time.Time + }{t}) + fake.setWriteDeadlineMutex.Unlock() + if fake.SetWriteDeadlineStub != nil { + return fake.SetWriteDeadlineStub(t) + } else { + return fake.setWriteDeadlineReturns.result1 + } +} + +func (fake *FakeConn) SetWriteDeadlineCallCount() int { + fake.setWriteDeadlineMutex.RLock() + defer fake.setWriteDeadlineMutex.RUnlock() + return len(fake.setWriteDeadlineArgsForCall) +} + +func (fake *FakeConn) SetWriteDeadlineArgsForCall(i int) time.Time { + fake.setWriteDeadlineMutex.RLock() + defer fake.setWriteDeadlineMutex.RUnlock() + return fake.setWriteDeadlineArgsForCall[i].t +} + +func (fake *FakeConn) SetWriteDeadlineReturns(result1 error) { + fake.SetWriteDeadlineStub = nil + fake.setWriteDeadlineReturns = struct { + result1 error + }{result1} +} + +var _ net.Conn = new(FakeConn) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_listener.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_listener.go new file mode 100644 index 0000000000..964de24e1f --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/fake_listener.go @@ -0,0 +1,104 @@ +// This file was generated by counterfeiter +package fake_net + +import ( + "net" + "sync" +) + +type FakeListener struct { + AcceptStub func() (c net.Conn, err error) + acceptMutex sync.RWMutex + acceptArgsForCall []struct{} + acceptReturns struct { + result1 net.Conn + result2 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } + AddrStub func() net.Addr + addrMutex sync.RWMutex + addrArgsForCall []struct{} + addrReturns struct { + result1 net.Addr + } +} + +func (fake *FakeListener) Accept() (c net.Conn, err error) { + fake.acceptMutex.Lock() + fake.acceptArgsForCall = append(fake.acceptArgsForCall, struct{}{}) + fake.acceptMutex.Unlock() + if fake.AcceptStub != nil { + return fake.AcceptStub() + } else { + return fake.acceptReturns.result1, fake.acceptReturns.result2 + } +} + +func (fake *FakeListener) AcceptCallCount() int { + fake.acceptMutex.RLock() + defer fake.acceptMutex.RUnlock() + return len(fake.acceptArgsForCall) +} + +func (fake *FakeListener) AcceptReturns(result1 net.Conn, result2 error) { + fake.AcceptStub = nil + fake.acceptReturns = struct { + result1 net.Conn + result2 error + }{result1, result2} +} + +func (fake *FakeListener) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeListener) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeListener) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeListener) Addr() net.Addr { + fake.addrMutex.Lock() + fake.addrArgsForCall = append(fake.addrArgsForCall, struct{}{}) + fake.addrMutex.Unlock() + if fake.AddrStub != nil { + return fake.AddrStub() + } else { + return fake.addrReturns.result1 + } +} + +func (fake *FakeListener) AddrCallCount() int { + fake.addrMutex.RLock() + defer fake.addrMutex.RUnlock() + return len(fake.addrArgsForCall) +} + +func (fake *FakeListener) AddrReturns(result1 net.Addr) { + fake.AddrStub = nil + fake.addrReturns = struct { + result1 net.Addr + }{result1} +} + +var _ net.Listener = new(FakeListener) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/package.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/package.go new file mode 100644 index 0000000000..52c7e5fc0c --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_net/package.go @@ -0,0 +1 @@ +package fake_net // import "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_channel.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_channel.go new file mode 100644 index 0000000000..87e79fdb84 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_channel.go @@ -0,0 +1,234 @@ +// This file was generated by counterfeiter +package fake_ssh + +import ( + "io" + "sync" + + "golang.org/x/crypto/ssh" +) + +type FakeChannel struct { + ReadStub func(data []byte) (int, error) + readMutex sync.RWMutex + readArgsForCall []struct { + data []byte + } + readReturns struct { + result1 int + result2 error + } + WriteStub func(data []byte) (int, error) + writeMutex sync.RWMutex + writeArgsForCall []struct { + data []byte + } + writeReturns struct { + result1 int + result2 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } + CloseWriteStub func() error + closeWriteMutex sync.RWMutex + closeWriteArgsForCall []struct{} + closeWriteReturns struct { + result1 error + } + SendRequestStub func(name string, wantReply bool, payload []byte) (bool, error) + sendRequestMutex sync.RWMutex + sendRequestArgsForCall []struct { + name string + wantReply bool + payload []byte + } + sendRequestReturns struct { + result1 bool + result2 error + } + StderrStub func() io.ReadWriter + stderrMutex sync.RWMutex + stderrArgsForCall []struct{} + stderrReturns struct { + result1 io.ReadWriter + } +} + +func (fake *FakeChannel) Read(data []byte) (int, error) { + fake.readMutex.Lock() + fake.readArgsForCall = append(fake.readArgsForCall, struct { + data []byte + }{data}) + fake.readMutex.Unlock() + if fake.ReadStub != nil { + return fake.ReadStub(data) + } else { + return fake.readReturns.result1, fake.readReturns.result2 + } +} + +func (fake *FakeChannel) ReadCallCount() int { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return len(fake.readArgsForCall) +} + +func (fake *FakeChannel) ReadArgsForCall(i int) []byte { + fake.readMutex.RLock() + defer fake.readMutex.RUnlock() + return fake.readArgsForCall[i].data +} + +func (fake *FakeChannel) ReadReturns(result1 int, result2 error) { + fake.ReadStub = nil + fake.readReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeChannel) Write(data []byte) (int, error) { + fake.writeMutex.Lock() + fake.writeArgsForCall = append(fake.writeArgsForCall, struct { + data []byte + }{data}) + fake.writeMutex.Unlock() + if fake.WriteStub != nil { + return fake.WriteStub(data) + } else { + return fake.writeReturns.result1, fake.writeReturns.result2 + } +} + +func (fake *FakeChannel) WriteCallCount() int { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return len(fake.writeArgsForCall) +} + +func (fake *FakeChannel) WriteArgsForCall(i int) []byte { + fake.writeMutex.RLock() + defer fake.writeMutex.RUnlock() + return fake.writeArgsForCall[i].data +} + +func (fake *FakeChannel) WriteReturns(result1 int, result2 error) { + fake.WriteStub = nil + fake.writeReturns = struct { + result1 int + result2 error + }{result1, result2} +} + +func (fake *FakeChannel) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeChannel) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeChannel) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeChannel) CloseWrite() error { + fake.closeWriteMutex.Lock() + fake.closeWriteArgsForCall = append(fake.closeWriteArgsForCall, struct{}{}) + fake.closeWriteMutex.Unlock() + if fake.CloseWriteStub != nil { + return fake.CloseWriteStub() + } else { + return fake.closeWriteReturns.result1 + } +} + +func (fake *FakeChannel) CloseWriteCallCount() int { + fake.closeWriteMutex.RLock() + defer fake.closeWriteMutex.RUnlock() + return len(fake.closeWriteArgsForCall) +} + +func (fake *FakeChannel) CloseWriteReturns(result1 error) { + fake.CloseWriteStub = nil + fake.closeWriteReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeChannel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + fake.sendRequestMutex.Lock() + fake.sendRequestArgsForCall = append(fake.sendRequestArgsForCall, struct { + name string + wantReply bool + payload []byte + }{name, wantReply, payload}) + fake.sendRequestMutex.Unlock() + if fake.SendRequestStub != nil { + return fake.SendRequestStub(name, wantReply, payload) + } else { + return fake.sendRequestReturns.result1, fake.sendRequestReturns.result2 + } +} + +func (fake *FakeChannel) SendRequestCallCount() int { + fake.sendRequestMutex.RLock() + defer fake.sendRequestMutex.RUnlock() + return len(fake.sendRequestArgsForCall) +} + +func (fake *FakeChannel) SendRequestArgsForCall(i int) (string, bool, []byte) { + fake.sendRequestMutex.RLock() + defer fake.sendRequestMutex.RUnlock() + return fake.sendRequestArgsForCall[i].name, fake.sendRequestArgsForCall[i].wantReply, fake.sendRequestArgsForCall[i].payload +} + +func (fake *FakeChannel) SendRequestReturns(result1 bool, result2 error) { + fake.SendRequestStub = nil + fake.sendRequestReturns = struct { + result1 bool + result2 error + }{result1, result2} +} + +func (fake *FakeChannel) Stderr() io.ReadWriter { + fake.stderrMutex.Lock() + fake.stderrArgsForCall = append(fake.stderrArgsForCall, struct{}{}) + fake.stderrMutex.Unlock() + if fake.StderrStub != nil { + return fake.StderrStub() + } else { + return fake.stderrReturns.result1 + } +} + +func (fake *FakeChannel) StderrCallCount() int { + fake.stderrMutex.RLock() + defer fake.stderrMutex.RUnlock() + return len(fake.stderrArgsForCall) +} + +func (fake *FakeChannel) StderrReturns(result1 io.ReadWriter) { + fake.StderrStub = nil + fake.stderrReturns = struct { + result1 io.ReadWriter + }{result1} +} + +var _ ssh.Channel = new(FakeChannel) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn.go new file mode 100644 index 0000000000..12793e8711 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn.go @@ -0,0 +1,348 @@ +// This file was generated by counterfeiter +package fake_ssh + +import ( + "net" + "sync" + + "golang.org/x/crypto/ssh" +) + +type FakeConn struct { + UserStub func() string + userMutex sync.RWMutex + userArgsForCall []struct{} + userReturns struct { + result1 string + } + SessionIDStub func() []byte + sessionIDMutex sync.RWMutex + sessionIDArgsForCall []struct{} + sessionIDReturns struct { + result1 []byte + } + ClientVersionStub func() []byte + clientVersionMutex sync.RWMutex + clientVersionArgsForCall []struct{} + clientVersionReturns struct { + result1 []byte + } + ServerVersionStub func() []byte + serverVersionMutex sync.RWMutex + serverVersionArgsForCall []struct{} + serverVersionReturns struct { + result1 []byte + } + RemoteAddrStub func() net.Addr + remoteAddrMutex sync.RWMutex + remoteAddrArgsForCall []struct{} + remoteAddrReturns struct { + result1 net.Addr + } + LocalAddrStub func() net.Addr + localAddrMutex sync.RWMutex + localAddrArgsForCall []struct{} + localAddrReturns struct { + result1 net.Addr + } + SendRequestStub func(name string, wantReply bool, payload []byte) (bool, []byte, error) + sendRequestMutex sync.RWMutex + sendRequestArgsForCall []struct { + name string + wantReply bool + payload []byte + } + sendRequestReturns struct { + result1 bool + result2 []byte + result3 error + } + OpenChannelStub func(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) + openChannelMutex sync.RWMutex + openChannelArgsForCall []struct { + name string + data []byte + } + openChannelReturns struct { + result1 ssh.Channel + result2 <-chan *ssh.Request + result3 error + } + CloseStub func() error + closeMutex sync.RWMutex + closeArgsForCall []struct{} + closeReturns struct { + result1 error + } + WaitStub func() error + waitMutex sync.RWMutex + waitArgsForCall []struct{} + waitReturns struct { + result1 error + } +} + +func (fake *FakeConn) User() string { + fake.userMutex.Lock() + fake.userArgsForCall = append(fake.userArgsForCall, struct{}{}) + fake.userMutex.Unlock() + if fake.UserStub != nil { + return fake.UserStub() + } else { + return fake.userReturns.result1 + } +} + +func (fake *FakeConn) UserCallCount() int { + fake.userMutex.RLock() + defer fake.userMutex.RUnlock() + return len(fake.userArgsForCall) +} + +func (fake *FakeConn) UserReturns(result1 string) { + fake.UserStub = nil + fake.userReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeConn) SessionID() []byte { + fake.sessionIDMutex.Lock() + fake.sessionIDArgsForCall = append(fake.sessionIDArgsForCall, struct{}{}) + fake.sessionIDMutex.Unlock() + if fake.SessionIDStub != nil { + return fake.SessionIDStub() + } else { + return fake.sessionIDReturns.result1 + } +} + +func (fake *FakeConn) SessionIDCallCount() int { + fake.sessionIDMutex.RLock() + defer fake.sessionIDMutex.RUnlock() + return len(fake.sessionIDArgsForCall) +} + +func (fake *FakeConn) SessionIDReturns(result1 []byte) { + fake.SessionIDStub = nil + fake.sessionIDReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConn) ClientVersion() []byte { + fake.clientVersionMutex.Lock() + fake.clientVersionArgsForCall = append(fake.clientVersionArgsForCall, struct{}{}) + fake.clientVersionMutex.Unlock() + if fake.ClientVersionStub != nil { + return fake.ClientVersionStub() + } else { + return fake.clientVersionReturns.result1 + } +} + +func (fake *FakeConn) ClientVersionCallCount() int { + fake.clientVersionMutex.RLock() + defer fake.clientVersionMutex.RUnlock() + return len(fake.clientVersionArgsForCall) +} + +func (fake *FakeConn) ClientVersionReturns(result1 []byte) { + fake.ClientVersionStub = nil + fake.clientVersionReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConn) ServerVersion() []byte { + fake.serverVersionMutex.Lock() + fake.serverVersionArgsForCall = append(fake.serverVersionArgsForCall, struct{}{}) + fake.serverVersionMutex.Unlock() + if fake.ServerVersionStub != nil { + return fake.ServerVersionStub() + } else { + return fake.serverVersionReturns.result1 + } +} + +func (fake *FakeConn) ServerVersionCallCount() int { + fake.serverVersionMutex.RLock() + defer fake.serverVersionMutex.RUnlock() + return len(fake.serverVersionArgsForCall) +} + +func (fake *FakeConn) ServerVersionReturns(result1 []byte) { + fake.ServerVersionStub = nil + fake.serverVersionReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConn) RemoteAddr() net.Addr { + fake.remoteAddrMutex.Lock() + fake.remoteAddrArgsForCall = append(fake.remoteAddrArgsForCall, struct{}{}) + fake.remoteAddrMutex.Unlock() + if fake.RemoteAddrStub != nil { + return fake.RemoteAddrStub() + } else { + return fake.remoteAddrReturns.result1 + } +} + +func (fake *FakeConn) RemoteAddrCallCount() int { + fake.remoteAddrMutex.RLock() + defer fake.remoteAddrMutex.RUnlock() + return len(fake.remoteAddrArgsForCall) +} + +func (fake *FakeConn) RemoteAddrReturns(result1 net.Addr) { + fake.RemoteAddrStub = nil + fake.remoteAddrReturns = struct { + result1 net.Addr + }{result1} +} + +func (fake *FakeConn) LocalAddr() net.Addr { + fake.localAddrMutex.Lock() + fake.localAddrArgsForCall = append(fake.localAddrArgsForCall, struct{}{}) + fake.localAddrMutex.Unlock() + if fake.LocalAddrStub != nil { + return fake.LocalAddrStub() + } else { + return fake.localAddrReturns.result1 + } +} + +func (fake *FakeConn) LocalAddrCallCount() int { + fake.localAddrMutex.RLock() + defer fake.localAddrMutex.RUnlock() + return len(fake.localAddrArgsForCall) +} + +func (fake *FakeConn) LocalAddrReturns(result1 net.Addr) { + fake.LocalAddrStub = nil + fake.localAddrReturns = struct { + result1 net.Addr + }{result1} +} + +func (fake *FakeConn) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + fake.sendRequestMutex.Lock() + fake.sendRequestArgsForCall = append(fake.sendRequestArgsForCall, struct { + name string + wantReply bool + payload []byte + }{name, wantReply, payload}) + fake.sendRequestMutex.Unlock() + if fake.SendRequestStub != nil { + return fake.SendRequestStub(name, wantReply, payload) + } else { + return fake.sendRequestReturns.result1, fake.sendRequestReturns.result2, fake.sendRequestReturns.result3 + } +} + +func (fake *FakeConn) SendRequestCallCount() int { + fake.sendRequestMutex.RLock() + defer fake.sendRequestMutex.RUnlock() + return len(fake.sendRequestArgsForCall) +} + +func (fake *FakeConn) SendRequestArgsForCall(i int) (string, bool, []byte) { + fake.sendRequestMutex.RLock() + defer fake.sendRequestMutex.RUnlock() + return fake.sendRequestArgsForCall[i].name, fake.sendRequestArgsForCall[i].wantReply, fake.sendRequestArgsForCall[i].payload +} + +func (fake *FakeConn) SendRequestReturns(result1 bool, result2 []byte, result3 error) { + fake.SendRequestStub = nil + fake.sendRequestReturns = struct { + result1 bool + result2 []byte + result3 error + }{result1, result2, result3} +} + +func (fake *FakeConn) OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error) { + fake.openChannelMutex.Lock() + fake.openChannelArgsForCall = append(fake.openChannelArgsForCall, struct { + name string + data []byte + }{name, data}) + fake.openChannelMutex.Unlock() + if fake.OpenChannelStub != nil { + return fake.OpenChannelStub(name, data) + } else { + return fake.openChannelReturns.result1, fake.openChannelReturns.result2, fake.openChannelReturns.result3 + } +} + +func (fake *FakeConn) OpenChannelCallCount() int { + fake.openChannelMutex.RLock() + defer fake.openChannelMutex.RUnlock() + return len(fake.openChannelArgsForCall) +} + +func (fake *FakeConn) OpenChannelArgsForCall(i int) (string, []byte) { + fake.openChannelMutex.RLock() + defer fake.openChannelMutex.RUnlock() + return fake.openChannelArgsForCall[i].name, fake.openChannelArgsForCall[i].data +} + +func (fake *FakeConn) OpenChannelReturns(result1 ssh.Channel, result2 <-chan *ssh.Request, result3 error) { + fake.OpenChannelStub = nil + fake.openChannelReturns = struct { + result1 ssh.Channel + result2 <-chan *ssh.Request + result3 error + }{result1, result2, result3} +} + +func (fake *FakeConn) Close() error { + fake.closeMutex.Lock() + fake.closeArgsForCall = append(fake.closeArgsForCall, struct{}{}) + fake.closeMutex.Unlock() + if fake.CloseStub != nil { + return fake.CloseStub() + } else { + return fake.closeReturns.result1 + } +} + +func (fake *FakeConn) CloseCallCount() int { + fake.closeMutex.RLock() + defer fake.closeMutex.RUnlock() + return len(fake.closeArgsForCall) +} + +func (fake *FakeConn) CloseReturns(result1 error) { + fake.CloseStub = nil + fake.closeReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeConn) Wait() error { + fake.waitMutex.Lock() + fake.waitArgsForCall = append(fake.waitArgsForCall, struct{}{}) + fake.waitMutex.Unlock() + if fake.WaitStub != nil { + return fake.WaitStub() + } else { + return fake.waitReturns.result1 + } +} + +func (fake *FakeConn) WaitCallCount() int { + fake.waitMutex.RLock() + defer fake.waitMutex.RUnlock() + return len(fake.waitArgsForCall) +} + +func (fake *FakeConn) WaitReturns(result1 error) { + fake.WaitStub = nil + fake.waitReturns = struct { + result1 error + }{result1} +} + +var _ ssh.Conn = new(FakeConn) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn_metadata.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn_metadata.go new file mode 100644 index 0000000000..be091a8d7c --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_conn_metadata.go @@ -0,0 +1,194 @@ +// This file was generated by counterfeiter +package fake_ssh + +import ( + "net" + "sync" + + "golang.org/x/crypto/ssh" +) + +type FakeConnMetadata struct { + UserStub func() string + userMutex sync.RWMutex + userArgsForCall []struct{} + userReturns struct { + result1 string + } + SessionIDStub func() []byte + sessionIDMutex sync.RWMutex + sessionIDArgsForCall []struct{} + sessionIDReturns struct { + result1 []byte + } + ClientVersionStub func() []byte + clientVersionMutex sync.RWMutex + clientVersionArgsForCall []struct{} + clientVersionReturns struct { + result1 []byte + } + ServerVersionStub func() []byte + serverVersionMutex sync.RWMutex + serverVersionArgsForCall []struct{} + serverVersionReturns struct { + result1 []byte + } + RemoteAddrStub func() net.Addr + remoteAddrMutex sync.RWMutex + remoteAddrArgsForCall []struct{} + remoteAddrReturns struct { + result1 net.Addr + } + LocalAddrStub func() net.Addr + localAddrMutex sync.RWMutex + localAddrArgsForCall []struct{} + localAddrReturns struct { + result1 net.Addr + } +} + +func (fake *FakeConnMetadata) User() string { + fake.userMutex.Lock() + fake.userArgsForCall = append(fake.userArgsForCall, struct{}{}) + fake.userMutex.Unlock() + if fake.UserStub != nil { + return fake.UserStub() + } else { + return fake.userReturns.result1 + } +} + +func (fake *FakeConnMetadata) UserCallCount() int { + fake.userMutex.RLock() + defer fake.userMutex.RUnlock() + return len(fake.userArgsForCall) +} + +func (fake *FakeConnMetadata) UserReturns(result1 string) { + fake.UserStub = nil + fake.userReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeConnMetadata) SessionID() []byte { + fake.sessionIDMutex.Lock() + fake.sessionIDArgsForCall = append(fake.sessionIDArgsForCall, struct{}{}) + fake.sessionIDMutex.Unlock() + if fake.SessionIDStub != nil { + return fake.SessionIDStub() + } else { + return fake.sessionIDReturns.result1 + } +} + +func (fake *FakeConnMetadata) SessionIDCallCount() int { + fake.sessionIDMutex.RLock() + defer fake.sessionIDMutex.RUnlock() + return len(fake.sessionIDArgsForCall) +} + +func (fake *FakeConnMetadata) SessionIDReturns(result1 []byte) { + fake.SessionIDStub = nil + fake.sessionIDReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConnMetadata) ClientVersion() []byte { + fake.clientVersionMutex.Lock() + fake.clientVersionArgsForCall = append(fake.clientVersionArgsForCall, struct{}{}) + fake.clientVersionMutex.Unlock() + if fake.ClientVersionStub != nil { + return fake.ClientVersionStub() + } else { + return fake.clientVersionReturns.result1 + } +} + +func (fake *FakeConnMetadata) ClientVersionCallCount() int { + fake.clientVersionMutex.RLock() + defer fake.clientVersionMutex.RUnlock() + return len(fake.clientVersionArgsForCall) +} + +func (fake *FakeConnMetadata) ClientVersionReturns(result1 []byte) { + fake.ClientVersionStub = nil + fake.clientVersionReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConnMetadata) ServerVersion() []byte { + fake.serverVersionMutex.Lock() + fake.serverVersionArgsForCall = append(fake.serverVersionArgsForCall, struct{}{}) + fake.serverVersionMutex.Unlock() + if fake.ServerVersionStub != nil { + return fake.ServerVersionStub() + } else { + return fake.serverVersionReturns.result1 + } +} + +func (fake *FakeConnMetadata) ServerVersionCallCount() int { + fake.serverVersionMutex.RLock() + defer fake.serverVersionMutex.RUnlock() + return len(fake.serverVersionArgsForCall) +} + +func (fake *FakeConnMetadata) ServerVersionReturns(result1 []byte) { + fake.ServerVersionStub = nil + fake.serverVersionReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakeConnMetadata) RemoteAddr() net.Addr { + fake.remoteAddrMutex.Lock() + fake.remoteAddrArgsForCall = append(fake.remoteAddrArgsForCall, struct{}{}) + fake.remoteAddrMutex.Unlock() + if fake.RemoteAddrStub != nil { + return fake.RemoteAddrStub() + } else { + return fake.remoteAddrReturns.result1 + } +} + +func (fake *FakeConnMetadata) RemoteAddrCallCount() int { + fake.remoteAddrMutex.RLock() + defer fake.remoteAddrMutex.RUnlock() + return len(fake.remoteAddrArgsForCall) +} + +func (fake *FakeConnMetadata) RemoteAddrReturns(result1 net.Addr) { + fake.RemoteAddrStub = nil + fake.remoteAddrReturns = struct { + result1 net.Addr + }{result1} +} + +func (fake *FakeConnMetadata) LocalAddr() net.Addr { + fake.localAddrMutex.Lock() + fake.localAddrArgsForCall = append(fake.localAddrArgsForCall, struct{}{}) + fake.localAddrMutex.Unlock() + if fake.LocalAddrStub != nil { + return fake.LocalAddrStub() + } else { + return fake.localAddrReturns.result1 + } +} + +func (fake *FakeConnMetadata) LocalAddrCallCount() int { + fake.localAddrMutex.RLock() + defer fake.localAddrMutex.RUnlock() + return len(fake.localAddrArgsForCall) +} + +func (fake *FakeConnMetadata) LocalAddrReturns(result1 net.Addr) { + fake.LocalAddrStub = nil + fake.localAddrReturns = struct { + result1 net.Addr + }{result1} +} + +var _ ssh.ConnMetadata = new(FakeConnMetadata) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_new_channel.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_new_channel.go new file mode 100644 index 0000000000..924fcdc6ae --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_new_channel.go @@ -0,0 +1,149 @@ +// This file was generated by counterfeiter +package fake_ssh + +import ( + "sync" + + "golang.org/x/crypto/ssh" +) + +type FakeNewChannel struct { + AcceptStub func() (ssh.Channel, <-chan *ssh.Request, error) + acceptMutex sync.RWMutex + acceptArgsForCall []struct{} + acceptReturns struct { + result1 ssh.Channel + result2 <-chan *ssh.Request + result3 error + } + RejectStub func(reason ssh.RejectionReason, message string) error + rejectMutex sync.RWMutex + rejectArgsForCall []struct { + reason ssh.RejectionReason + message string + } + rejectReturns struct { + result1 error + } + ChannelTypeStub func() string + channelTypeMutex sync.RWMutex + channelTypeArgsForCall []struct{} + channelTypeReturns struct { + result1 string + } + ExtraDataStub func() []byte + extraDataMutex sync.RWMutex + extraDataArgsForCall []struct{} + extraDataReturns struct { + result1 []byte + } +} + +func (fake *FakeNewChannel) Accept() (ssh.Channel, <-chan *ssh.Request, error) { + fake.acceptMutex.Lock() + fake.acceptArgsForCall = append(fake.acceptArgsForCall, struct{}{}) + fake.acceptMutex.Unlock() + if fake.AcceptStub != nil { + return fake.AcceptStub() + } else { + return fake.acceptReturns.result1, fake.acceptReturns.result2, fake.acceptReturns.result3 + } +} + +func (fake *FakeNewChannel) AcceptCallCount() int { + fake.acceptMutex.RLock() + defer fake.acceptMutex.RUnlock() + return len(fake.acceptArgsForCall) +} + +func (fake *FakeNewChannel) AcceptReturns(result1 ssh.Channel, result2 <-chan *ssh.Request, result3 error) { + fake.AcceptStub = nil + fake.acceptReturns = struct { + result1 ssh.Channel + result2 <-chan *ssh.Request + result3 error + }{result1, result2, result3} +} + +func (fake *FakeNewChannel) Reject(reason ssh.RejectionReason, message string) error { + fake.rejectMutex.Lock() + fake.rejectArgsForCall = append(fake.rejectArgsForCall, struct { + reason ssh.RejectionReason + message string + }{reason, message}) + fake.rejectMutex.Unlock() + if fake.RejectStub != nil { + return fake.RejectStub(reason, message) + } else { + return fake.rejectReturns.result1 + } +} + +func (fake *FakeNewChannel) RejectCallCount() int { + fake.rejectMutex.RLock() + defer fake.rejectMutex.RUnlock() + return len(fake.rejectArgsForCall) +} + +func (fake *FakeNewChannel) RejectArgsForCall(i int) (ssh.RejectionReason, string) { + fake.rejectMutex.RLock() + defer fake.rejectMutex.RUnlock() + return fake.rejectArgsForCall[i].reason, fake.rejectArgsForCall[i].message +} + +func (fake *FakeNewChannel) RejectReturns(result1 error) { + fake.RejectStub = nil + fake.rejectReturns = struct { + result1 error + }{result1} +} + +func (fake *FakeNewChannel) ChannelType() string { + fake.channelTypeMutex.Lock() + fake.channelTypeArgsForCall = append(fake.channelTypeArgsForCall, struct{}{}) + fake.channelTypeMutex.Unlock() + if fake.ChannelTypeStub != nil { + return fake.ChannelTypeStub() + } else { + return fake.channelTypeReturns.result1 + } +} + +func (fake *FakeNewChannel) ChannelTypeCallCount() int { + fake.channelTypeMutex.RLock() + defer fake.channelTypeMutex.RUnlock() + return len(fake.channelTypeArgsForCall) +} + +func (fake *FakeNewChannel) ChannelTypeReturns(result1 string) { + fake.ChannelTypeStub = nil + fake.channelTypeReturns = struct { + result1 string + }{result1} +} + +func (fake *FakeNewChannel) ExtraData() []byte { + fake.extraDataMutex.Lock() + fake.extraDataArgsForCall = append(fake.extraDataArgsForCall, struct{}{}) + fake.extraDataMutex.Unlock() + if fake.ExtraDataStub != nil { + return fake.ExtraDataStub() + } else { + return fake.extraDataReturns.result1 + } +} + +func (fake *FakeNewChannel) ExtraDataCallCount() int { + fake.extraDataMutex.RLock() + defer fake.extraDataMutex.RUnlock() + return len(fake.extraDataArgsForCall) +} + +func (fake *FakeNewChannel) ExtraDataReturns(result1 []byte) { + fake.ExtraDataStub = nil + fake.extraDataReturns = struct { + result1 []byte + }{result1} +} + +var _ ssh.NewChannel = new(FakeNewChannel) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_public_key.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_public_key.go new file mode 100644 index 0000000000..5467ccd043 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/fake_public_key.go @@ -0,0 +1,115 @@ +// This file was generated by counterfeiter +package fake_ssh + +import ( + "sync" + + "golang.org/x/crypto/ssh" +) + +type FakePublicKey struct { + TypeStub func() string + typeMutex sync.RWMutex + typeArgsForCall []struct{} + typeReturns struct { + result1 string + } + MarshalStub func() []byte + marshalMutex sync.RWMutex + marshalArgsForCall []struct{} + marshalReturns struct { + result1 []byte + } + VerifyStub func(data []byte, sig *ssh.Signature) error + verifyMutex sync.RWMutex + verifyArgsForCall []struct { + data []byte + sig *ssh.Signature + } + verifyReturns struct { + result1 error + } +} + +func (fake *FakePublicKey) Type() string { + fake.typeMutex.Lock() + fake.typeArgsForCall = append(fake.typeArgsForCall, struct{}{}) + fake.typeMutex.Unlock() + if fake.TypeStub != nil { + return fake.TypeStub() + } else { + return fake.typeReturns.result1 + } +} + +func (fake *FakePublicKey) TypeCallCount() int { + fake.typeMutex.RLock() + defer fake.typeMutex.RUnlock() + return len(fake.typeArgsForCall) +} + +func (fake *FakePublicKey) TypeReturns(result1 string) { + fake.TypeStub = nil + fake.typeReturns = struct { + result1 string + }{result1} +} + +func (fake *FakePublicKey) Marshal() []byte { + fake.marshalMutex.Lock() + fake.marshalArgsForCall = append(fake.marshalArgsForCall, struct{}{}) + fake.marshalMutex.Unlock() + if fake.MarshalStub != nil { + return fake.MarshalStub() + } else { + return fake.marshalReturns.result1 + } +} + +func (fake *FakePublicKey) MarshalCallCount() int { + fake.marshalMutex.RLock() + defer fake.marshalMutex.RUnlock() + return len(fake.marshalArgsForCall) +} + +func (fake *FakePublicKey) MarshalReturns(result1 []byte) { + fake.MarshalStub = nil + fake.marshalReturns = struct { + result1 []byte + }{result1} +} + +func (fake *FakePublicKey) Verify(data []byte, sig *ssh.Signature) error { + fake.verifyMutex.Lock() + fake.verifyArgsForCall = append(fake.verifyArgsForCall, struct { + data []byte + sig *ssh.Signature + }{data, sig}) + fake.verifyMutex.Unlock() + if fake.VerifyStub != nil { + return fake.VerifyStub(data, sig) + } else { + return fake.verifyReturns.result1 + } +} + +func (fake *FakePublicKey) VerifyCallCount() int { + fake.verifyMutex.RLock() + defer fake.verifyMutex.RUnlock() + return len(fake.verifyArgsForCall) +} + +func (fake *FakePublicKey) VerifyArgsForCall(i int) ([]byte, *ssh.Signature) { + fake.verifyMutex.RLock() + defer fake.verifyMutex.RUnlock() + return fake.verifyArgsForCall[i].data, fake.verifyArgsForCall[i].sig +} + +func (fake *FakePublicKey) VerifyReturns(result1 error) { + fake.VerifyStub = nil + fake.verifyReturns = struct { + result1 error + }{result1} +} + +var _ ssh.PublicKey = new(FakePublicKey) diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/package.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/package.go new file mode 100644 index 0000000000..cec937c06e --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh/package.go @@ -0,0 +1 @@ +package fake_ssh // import "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/package.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/package.go new file mode 100644 index 0000000000..8e730a12a2 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/package.go @@ -0,0 +1 @@ +package test_helpers // import "code.cloudfoundry.org/diego-ssh/test_helpers" diff --git a/src/code.cloudfoundry.org/diego-ssh/test_helpers/test_helpers.go b/src/code.cloudfoundry.org/diego-ssh/test_helpers/test_helpers.go new file mode 100644 index 0000000000..ad024502f1 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/test_helpers/test_helpers.go @@ -0,0 +1,75 @@ +package test_helpers + +import ( + "net" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "golang.org/x/crypto/ssh" +) + +func WaitFor(f func() error) error { + ch := make(chan error) + go func() { + err := f() + ch <- err + }() + var err error + Eventually(ch, 10).Should(Receive(&err)) + return err +} + +func Pipe() (net.Conn, net.Conn) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + Expect(err).NotTo(HaveOccurred()) + + address := listener.Addr().String() + + serverConnCh := make(chan net.Conn, 1) + go func(serverConnCh chan net.Conn, listener net.Listener) { + defer GinkgoRecover() + conn, err := listener.Accept() + Expect(err).NotTo(HaveOccurred()) + + serverConnCh <- conn + }(serverConnCh, listener) + + clientConn, err := net.Dial("tcp", address) + Expect(err).NotTo(HaveOccurred()) + + return <-serverConnCh, clientConn +} + +func NewClient(clientNetConn net.Conn, clientConfig *ssh.ClientConfig) *ssh.Client { + if clientConfig == nil { + clientConfig = &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("secret"), + }, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + } + + clientConn, clientChannels, clientRequests, clientConnErr := ssh.NewClientConn(clientNetConn, "0.0.0.0", clientConfig) + Expect(clientConnErr).NotTo(HaveOccurred()) + + return ssh.NewClient(clientConn, clientChannels, clientRequests) +} + +type TestNetError struct { + timeout bool + temporary bool +} + +func NewTestNetError(timeout, temporary bool) *TestNetError { + return &TestNetError{ + timeout: timeout, + temporary: temporary, + } +} + +func (e *TestNetError) Error() string { return "test error" } +func (e *TestNetError) Timeout() bool { return e.timeout } +func (e *TestNetError) Temporary() bool { return e.temporary } diff --git a/src/code.cloudfoundry.org/diego-ssh/winpty/package.go b/src/code.cloudfoundry.org/diego-ssh/winpty/package.go new file mode 100644 index 0000000000..58640a8f82 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/winpty/package.go @@ -0,0 +1 @@ +package winpty // import "code.cloudfoundry.org/diego-ssh/winpty" diff --git a/src/code.cloudfoundry.org/diego-ssh/winpty/winpty.go b/src/code.cloudfoundry.org/diego-ssh/winpty/winpty.go new file mode 100644 index 0000000000..38760571b7 --- /dev/null +++ b/src/code.cloudfoundry.org/diego-ssh/winpty/winpty.go @@ -0,0 +1,368 @@ +//go:build windows + +package winpty + +import ( + "errors" + "fmt" + "math" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" + "syscall" + "unicode/utf16" + "unsafe" + + "golang.org/x/sys/windows" +) + +var ( + winpty *windows.DLL + winpty_config_new *windows.Proc + winpty_config_free *windows.Proc + winpty_error_free *windows.Proc + winpty_error_msg *windows.Proc + winpty_open *windows.Proc + winpty_free *windows.Proc + winpty_conin_name *windows.Proc + winpty_conout_name *windows.Proc + winpty_spawn_config_new *windows.Proc + winpty_spawn_config_free *windows.Proc + winpty_spawn *windows.Proc + winpty_set_size *windows.Proc +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + terminateProcess = kernel32.NewProc("TerminateProcess") +) + +type WinPTY struct { + StdIn *os.File + StdOut *os.File + + winPTYHandle uintptr + childHandle uintptr +} + +const ( + WINPTY_SPAWN_FLAG_AUTO_SHUTDOWN = uint64(1) +) + +func New(winPTYDLLDir string) (*WinPTY, error) { + var err error + winpty, err = windows.LoadDLL(filepath.Join(winPTYDLLDir, "winpty.dll")) + if err != nil { + return nil, err + } + winpty_config_new, err = winpty.FindProc("winpty_config_new") + if err != nil { + return nil, err + } + winpty_config_free, err = winpty.FindProc("winpty_config_free") + if err != nil { + return nil, err + } + winpty_error_free, err = winpty.FindProc("winpty_error_free") + if err != nil { + return nil, err + } + winpty_error_msg, err = winpty.FindProc("winpty_error_msg") + if err != nil { + return nil, err + } + winpty_open, err = winpty.FindProc("winpty_open") + if err != nil { + return nil, err + } + winpty_free, err = winpty.FindProc("winpty_free") + if err != nil { + return nil, err + } + winpty_conin_name, err = winpty.FindProc("winpty_conin_name") + if err != nil { + return nil, err + } + winpty_conout_name, err = winpty.FindProc("winpty_conout_name") + if err != nil { + return nil, err + } + winpty_spawn_config_new, err = winpty.FindProc("winpty_spawn_config_new") + if err != nil { + return nil, err + } + winpty_spawn_config_free, err = winpty.FindProc("winpty_spawn_config_free") + if err != nil { + return nil, err + } + winpty_spawn, err = winpty.FindProc("winpty_spawn") + if err != nil { + return nil, err + } + winpty_set_size, err = winpty.FindProc("winpty_set_size") + if err != nil { + return nil, err + } + + var errorPtr uintptr + defer winpty_error_free.Call(errorPtr) + agentCfg, _, _ := winpty_config_new.Call(uintptr(0), uintptr(unsafe.Pointer(&errorPtr))) + if agentCfg == 0 { + return nil, fmt.Errorf("unable to create agent config: %s", winPTYErrorMessage(errorPtr)) + } + + winPTYHandle, _, _ := winpty_open.Call(agentCfg, uintptr(unsafe.Pointer(&errorPtr))) + if winPTYHandle == 0 { + return nil, fmt.Errorf("unable to launch WinPTY agent: %s", winPTYErrorMessage(errorPtr)) + } + winpty_config_free.Call(agentCfg) + + return &WinPTY{ + winPTYHandle: winPTYHandle, + }, nil +} + +// unsafeExternPointer converts a uintptr address known to be a valid pointer +// external to the Go heap — such as one returned by the mmap system call — to +// an unsafe.Pointer, without triggering the unsafeptr vet warning. +// Taken from https://go-review.googlesource.com/c/sys/+/465235 +func unsafeExternPointer(addr uintptr) unsafe.Pointer { + // Converting a uintptr directly to an unsafe.Pointer triggers a vet warning, + // because a uintptr cannot safely hold a pointer to the Go heap. (Because a + // uintptr may hold an integer, uintptr values are not traced during garbage + // collection and are not updated during stack resizing.) + // + // However, if we know that the address is not owned by the Go heap, it does + // not need to be traced by the GC and cannot be implicitly relocated. + // We silence the unsafeptr warning by converting a pointer-to-uintptr to + // a pointer-to-pointer. + return *(*unsafe.Pointer)(unsafe.Pointer(&addr)) +} + +func (w *WinPTY) Open() error { + if w.winPTYHandle == 0 { + return errors.New("winpty dll not initialized") + } + + stdinName, _, err := winpty_conin_name.Call(w.winPTYHandle) + if stdinName == 0 { + return fmt.Errorf("unable to get stdin pipe name: %s", err.Error()) + } + + stdoutName, _, err := winpty_conout_name.Call(w.winPTYHandle) + if stdoutName == 0 { + return fmt.Errorf("unable to get stdout pipe name: %s", err.Error()) + } + + stdinHandle, err := syscall.CreateFile((*uint16)(unsafeExternPointer(stdinName)), syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, 0, 0) + if err != nil { + return fmt.Errorf("unable to open stdin pipe: %s", err.Error()) + } + + stdoutHandle, err := syscall.CreateFile((*uint16)(unsafeExternPointer(stdoutName)), syscall.GENERIC_READ, 0, nil, syscall.OPEN_EXISTING, 0, 0) + if err != nil { + return fmt.Errorf("unable to open stdout pipe: %s", err.Error()) + } + + w.StdIn = os.NewFile(uintptr(stdinHandle), "stdin") + w.StdOut = os.NewFile(uintptr(stdoutHandle), "stdout") + return nil +} + +func (w *WinPTY) Run(cmd *exec.Cmd) error { + escaped := makeCmdLine(append([]string{cmd.Path}, cmd.Args...)) + cmdLineStr, err := syscall.UTF16PtrFromString(escaped) + if err != nil { + w.StdOut.Close() + return fmt.Errorf("failed to convert cmd (%s) to pointer: %s", escaped, err.Error()) + } + + env := "" + for _, val := range cmd.Env { + env += (val + "\x00") + } + + var envPtr *uint16 + if env != "" { + envPtr = &utf16.Encode([]rune(env))[0] + } + + var errorPtr uintptr + defer winpty_error_free.Call(errorPtr) + spawnCfg, _, _ := winpty_spawn_config_new.Call( + uintptr(uint64(WINPTY_SPAWN_FLAG_AUTO_SHUTDOWN)), + uintptr(0), + uintptr(unsafe.Pointer(cmdLineStr)), + uintptr(0), + uintptr(unsafe.Pointer(envPtr)), + uintptr(unsafe.Pointer(&errorPtr))) + if spawnCfg == 0 { + w.StdOut.Close() + return fmt.Errorf("unable to create process config: %s", winPTYErrorMessage(errorPtr)) + } + + var createProcessErr uint32 + // we ignore err here because Windows is Windows, and we generate everything based off of spawnRet + spawnRet, _, _ := winpty_spawn.Call(w.winPTYHandle, + spawnCfg, + uintptr(unsafe.Pointer(&w.childHandle)), + uintptr(0), + uintptr(unsafe.Pointer(&createProcessErr)), + uintptr(unsafe.Pointer(&errorPtr))) + winpty_spawn_config_free.Call(spawnCfg) + if spawnRet == 0 { + w.StdOut.Close() + return fmt.Errorf("unable to spawn process: %s: %s", winPTYErrorMessage(errorPtr), windowsErrorMessage(createProcessErr)) + } + + return nil +} + +func (w *WinPTY) Wait() error { + _, err := syscall.WaitForSingleObject(syscall.Handle(w.childHandle), math.MaxUint32) + if err != nil { + return fmt.Errorf("unable to wait for child process: %s", err.Error()) + } + + var exitCode uint32 + err = syscall.GetExitCodeProcess(syscall.Handle(w.childHandle), &exitCode) + if err != nil { + return fmt.Errorf("couldn't get child exit code: %s", err.Error()) + } + + if exitCode != 0 { + return &ExitError{WaitStatus: syscall.WaitStatus{ExitCode: exitCode}} + } + + return nil +} + +type ExitError struct { + WaitStatus syscall.WaitStatus +} + +func (ee *ExitError) Error() string { + return fmt.Sprintf("exit code %d", ee.WaitStatus.ExitCode) +} + +func (w *WinPTY) Close() { + if w.winPTYHandle == 0 { + return + } + + winpty_free.Call(w.winPTYHandle) + + if w.StdIn != nil { + w.StdIn.Close() + } + + if w.StdOut != nil { + w.StdOut.Close() + } + + if w.childHandle != 0 { + syscall.CloseHandle(syscall.Handle(w.childHandle)) + } +} + +func (w *WinPTY) SetWinsize(columns, rows uint32) error { + if columns == 0 || rows == 0 { + return nil + } + ret, _, err := winpty_set_size.Call(w.winPTYHandle, uintptr(columns), uintptr(rows), uintptr(0)) + if ret == 0 { + return fmt.Errorf("failed to set window size: %s", err.Error()) + } + return nil +} + +func (w *WinPTY) Signal(sig syscall.Signal) error { + if sig == syscall.SIGINT { + return w.sendCtrlC() + } else if sig == syscall.SIGKILL { + return w.terminateChild() + } + + return syscall.Errno(syscall.EWINDOWS) +} + +func (w *WinPTY) sendCtrlC() error { + if w.childHandle == 0 { + return nil + } + + // 0x03 is Ctrl+C + // this tells the agent to generate Ctrl+C in the child process + // https://github.com/rprichard/winpty/blob/4978cf94b6ea48e38eea3146bd0d23210f87aa89/src/agent/ConsoleInput.cc#L387 + _, err := w.StdIn.Write([]byte{0x03}) + if err != nil { + return fmt.Errorf("couldn't send ctrl+c to child: %s", err.Error()) + } + return nil +} + +func (w *WinPTY) terminateChild() error { + if w.childHandle == 0 { + return nil + } + ret, _, err := terminateProcess.Call(w.childHandle, 1) + if ret == 0 { + return fmt.Errorf("failed to terminate child process: %s", err.Error()) + } + return nil +} + +func winPTYErrorMessage(ptr uintptr) string { + msgPtr, _, err := winpty_error_msg.Call(ptr) + if msgPtr == 0 { + return fmt.Sprintf("unknown error, couldn't convert: %s", err.Error()) + } + + out := make([]uint16, 0) + p := unsafeExternPointer(msgPtr) + + for { + val := *(*uint16)(p) + if val == 0 { + break + } + + out = append(out, val) + p = unsafe.Pointer(uintptr(p) + unsafe.Sizeof(uint16(0))) + } + return string(utf16.Decode(out)) +} + +func windowsErrorMessage(code uint32) string { + flags := uint32(windows.FORMAT_MESSAGE_FROM_SYSTEM | windows.FORMAT_MESSAGE_IGNORE_INSERTS) + langId := uint32(windows.SUBLANG_ENGLISH_US)<<10 | uint32(windows.LANG_ENGLISH) + buf := make([]uint16, 512) + + _, err := windows.FormatMessage(flags, uintptr(0), code, langId, buf, nil) + if err != nil { + return fmt.Sprintf("0x%x", code) + } + return strings.TrimSpace(syscall.UTF16ToString(buf)) +} + +func makeCmdLine(args []string) string { + if len(args) > 0 { + args[0] = filepath.Clean(args[0]) + base := filepath.Base(args[0]) + match, _ := regexp.MatchString(`\.[a-zA-Z]{3}$`, base) + if !match { + args[0] += ".exe" + } + } + var s string + for _, v := range args { + if s != "" { + s += " " + } + s += syscall.EscapeArg(v) + } + + return s +}