diff --git a/api/types/database_test.go b/api/types/database_test.go index 7e8b9aa0b837c..769fd05bc9607 100644 --- a/api/types/database_test.go +++ b/api/types/database_test.go @@ -28,7 +28,7 @@ import ( // TestDatabaseRDSEndpoint verifies AWS info is correctly populated // based on the RDS endpoint. func TestDatabaseRDSEndpoint(t *testing.T) { - isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) } diff --git a/api/utils/aws/identifiers.go b/api/utils/aws/identifiers.go index 2094b6e65e032..330fb8319478f 100644 --- a/api/utils/aws/identifiers.go +++ b/api/utils/aws/identifiers.go @@ -17,6 +17,10 @@ limitations under the License. package aws import ( + "regexp" + "strings" + "unicode" + "github.com/gravitational/trace" ) @@ -35,3 +39,38 @@ func IsValidAccountID(accountID string) error { return nil } + +// matchRoleName is a regex that matches against AWS IAM Role Names. +var matchRoleName = regexp.MustCompile(`^[\w+=,.@-]+$`).MatchString + +// IsValidIAMRoleName checks whether the role name is a valid AWS IAM Role identifier. +// +// > Length Constraints: Minimum length of 1. Maximum length of 64. +// > Pattern: [\w+=,.@-]+ +// https://docs.aws.amazon.com/IAM/latest/APIReference/API_CreateRole.html +func IsValidIAMRoleName(roleName string) error { + if len(roleName) == 0 || len(roleName) > 64 || !matchRoleName(roleName) { + return trace.BadParameter("role is invalid") + } + + return nil +} + +// IsValidRegion ensures the region looks to be valid. +// It does not do a full validation, because AWS doesn't provide documentation for that. +// However, they usually only have the following chars: [a-z0-9\-] +func IsValidRegion(region string) error { + indexNotFound := -1 + + if len(region) == 0 { + return trace.BadParameter("region is invalid") + } + + if strings.IndexFunc(region, func(r rune) bool { + return !(unicode.IsDigit(r) || unicode.IsLetter(r) || r == '-') + }) == indexNotFound { + return nil + } + + return trace.BadParameter("region is invalid") +} diff --git a/api/utils/aws/identifiers_test.go b/api/utils/aws/identifiers_test.go index 8d33fa9b9eae3..f5718beef88bd 100644 --- a/api/utils/aws/identifiers_test.go +++ b/api/utils/aws/identifiers_test.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "strings" "testing" "github.com/gravitational/trace" @@ -24,7 +25,7 @@ import ( ) func TestIsValidAccountID(t *testing.T) { - isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) } @@ -74,3 +75,97 @@ func TestIsValidAccountID(t *testing.T) { }) } } + +func TestIsValidIAMRoleName(t *testing.T) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + for _, tt := range []struct { + name string + role string + errCheck require.ErrorAssertionFunc + }{ + { + name: "valid", + role: "valid", + errCheck: require.NoError, + }, + { + name: "valid with numbers", + role: "00VALID11", + errCheck: require.NoError, + }, + { + name: "only one symbol", + role: "_", + errCheck: require.NoError, + }, + { + name: "all symbols", + role: "Test+1=2,3.4@5-6_7", + errCheck: require.NoError, + }, + { + name: "empty", + role: "", + errCheck: isBadParamErrFn, + }, + { + name: "too large", + role: strings.Repeat("r", 65), + errCheck: isBadParamErrFn, + }, + { + name: "invalid symbols", + role: "role/admin", + errCheck: isBadParamErrFn, + }, + } { + t.Run(tt.name, func(t *testing.T) { + tt.errCheck(t, IsValidIAMRoleName(tt.role)) + }) + } +} + +func TestIsValidRegion(t *testing.T) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + for _, tt := range []struct { + name string + region string + errCheck require.ErrorAssertionFunc + }{ + { + name: "us region", + region: "us-east-1", + errCheck: require.NoError, + }, + { + name: "eu region", + region: "eu-west-1", + errCheck: require.NoError, + }, + { + name: "us gov", + region: "us-gov-east-1", + errCheck: require.NoError, + }, + { + name: "empty", + region: "", + errCheck: isBadParamErrFn, + }, + { + name: "symbols", + region: "us@east-1", + errCheck: isBadParamErrFn, + }, + } { + t.Run(tt.name, func(t *testing.T) { + tt.errCheck(t, IsValidRegion(tt.region)) + }) + } +} diff --git a/lib/integrations/awsoidc/deployservice_test.go b/lib/integrations/awsoidc/deployservice_test.go index 1453ff11b9c01..647090d34e158 100644 --- a/lib/integrations/awsoidc/deployservice_test.go +++ b/lib/integrations/awsoidc/deployservice_test.go @@ -28,7 +28,7 @@ import ( ) func TestDeployServiceRequest(t *testing.T) { - isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index dc66d30739a2b..0f6fee43f318d 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -743,6 +743,7 @@ func (h *Handler) bindDefaultEndpoints() { // AWS OIDC Integration Actions h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/databases", h.WithClusterAuth(h.awsOIDCListDatabases)) h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/deployservice", h.WithClusterAuth(h.awsOIDCDeployService)) + h.GET("/webapi/scripts/integrations/configure/deployservice-iam.sh", h.WithLimiter(h.awsOIDCConfigureDeployServiceIAM)) // AWS OIDC Integration specific endpoints: // Unauthenticated access to OpenID Configuration - used for AWS OIDC IdP integration diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 7fe849979bcdb..cd1f0ed04097f 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -15,16 +15,20 @@ package web import ( "context" + "fmt" "net/http" + "strings" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/web/scripts/oneoff" "github.com/gravitational/teleport/lib/web/ui" ) @@ -160,3 +164,64 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p ServiceDashboardURL: deployServiceResp.ServiceDashboardURL, }, nil } + +// awsOIDCConfigureDeployServiceIAM returns a script that configures the required IAM permissions to enable the usage of DeployService action. +func (h *Handler) awsOIDCConfigureDeployServiceIAM(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { + ctx := r.Context() + + queryParams := r.URL.Query() + + clusterName, err := h.GetProxyClient().GetDomainName(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + integrationName := queryParams.Get("integrationName") + if len(integrationName) == 0 { + return nil, trace.BadParameter("missing integrationName param") + } + + // Ensure the IntegrationName is valid. + _, err = h.GetProxyClient().GetIntegration(ctx, integrationName) + // NotFound error is ignored to prevent disclosure of whether the integration exists in a public/no-auth endpoint. + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err) + } + + awsRegion := queryParams.Get("awsRegion") + if err := aws.IsValidRegion(awsRegion); err != nil { + return nil, trace.BadParameter("invalid awsRegion") + } + + role := queryParams.Get("role") + if err := aws.IsValidIAMRoleName(role); err != nil { + return nil, trace.BadParameter("invalid role") + } + + taskRole := queryParams.Get("taskRole") + if err := aws.IsValidIAMRoleName(taskRole); err != nil { + return nil, trace.BadParameter("invalid taskRole") + } + + // The script must execute the following command: + // teleport integration configure deployservice-iam + argsList := []string{ + "integration", "configure", "deployservice-iam", + fmt.Sprintf(`--cluster="%s"`, clusterName), + fmt.Sprintf(`--name="%s"`, integrationName), + fmt.Sprintf(`--aws-region="%s"`, awsRegion), + fmt.Sprintf(`--role="%s"`, role), + fmt.Sprintf(`--task-role="%s"`, taskRole), + } + script, err := oneoff.BuildScript(oneoff.OneOffScriptParams{ + TeleportArgs: strings.Join(argsList, " "), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + httplib.SetScriptHeaders(w.Header()) + fmt.Fprint(w, script) + + return nil, trace.Wrap(err) +} diff --git a/lib/web/integrations_awsoidc_test.go b/lib/web/integrations_awsoidc_test.go new file mode 100644 index 0000000000000..3fbc06987b9e5 --- /dev/null +++ b/lib/web/integrations_awsoidc_test.go @@ -0,0 +1,146 @@ +/* +Copyright 2023 Gravitational, Inc. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package web + +import ( + "context" + "fmt" + "net/url" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" +) + +func TestBuildDeployServiceConfigureIAMScript(t *testing.T) { + isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { + require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) + } + + ctx := context.Background() + env := newWebPack(t, 1) + + // Unauthenticated client for script downloading. + publicClt := env.proxies[0].newClient(t) + pathVars := []string{ + "webapi", + "scripts", + "integrations", + "configure", + "deployservice-iam.sh", + } + endpoint := publicClt.Endpoint(pathVars...) + + tests := []struct { + name string + reqRelativeURL string + reqQuery url.Values + errCheck require.ErrorAssertionFunc + expectedTeleportArgs string + }{ + { + name: "valid", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "role": []string{"myRole"}, + "taskRole": []string{"taskRole"}, + "integrationName": []string{"myintegration"}, + }, + errCheck: require.NoError, + expectedTeleportArgs: "integration configure deployservice-iam " + + `--cluster="localhost" ` + + `--name="myintegration" ` + + `--aws-region="us-east-1" ` + + `--role="myRole" ` + + `--task-role="taskRole"`, + }, + { + name: "valid with symbols in role", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "role": []string{"Test+1=2,3.4@5-6_7"}, + "taskRole": []string{"taskRole"}, + "integrationName": []string{"myintegration"}, + }, + errCheck: require.NoError, + expectedTeleportArgs: "integration configure deployservice-iam " + + `--cluster="localhost" ` + + `--name="myintegration" ` + + `--aws-region="us-east-1" ` + + `--role="Test+1=2,3.4@5-6_7" ` + + `--task-role="taskRole"`, + }, + { + name: "missing aws-region", + reqQuery: url.Values{ + "role": []string{"myRole"}, + "taskRole": []string{"taskRole"}, + "integrationName": []string{"myintegration"}, + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing role", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "taskRole": []string{"taskRole"}, + "integrationName": []string{"myintegration"}, + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing task role", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "role": []string{"role"}, + "integrationName": []string{"myintegration"}, + }, + errCheck: isBadParamErrFn, + }, + { + name: "missing integration name", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "role": []string{"role"}, + "taskRole": []string{"taskRole"}, + }, + errCheck: isBadParamErrFn, + }, + { + name: "trying to inject escape sequence into query params", + reqQuery: url.Values{ + "awsRegion": []string{"us-east-1"}, + "role": []string{"role"}, + "taskRole": []string{"taskRole"}, + "integrationName": []string{"'; rm -rf /tmp/dir; echo '"}, + }, + errCheck: isBadParamErrFn, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + resp, err := publicClt.Get(ctx, endpoint, tc.reqQuery) + tc.errCheck(t, err) + if err != nil { + return + } + + require.Contains(t, string(resp.Bytes()), + fmt.Sprintf("teleportArgs='%s'\n", tc.expectedTeleportArgs), + ) + }) + } +} diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index 5bd30a60f5a58..374c60f082a5f 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -210,7 +210,7 @@ func (h *Handler) createTokenHandle(w http.ResponseWriter, r *http.Request, para } func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) { - scripts.SetScriptHeaders(w.Header()) + httplib.SetScriptHeaders(w.Header()) settings := scriptSettings{ token: params.ByName("token"), @@ -236,7 +236,7 @@ func (h *Handler) getNodeJoinScriptHandle(w http.ResponseWriter, r *http.Request } func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) { - scripts.SetScriptHeaders(w.Header()) + httplib.SetScriptHeaders(w.Header()) queryValues := r.URL.Query() name, err := url.QueryUnescape(queryValues.Get("name")) @@ -278,7 +278,7 @@ func (h *Handler) getAppJoinScriptHandle(w http.ResponseWriter, r *http.Request, } func (h *Handler) getDatabaseJoinScriptHandle(w http.ResponseWriter, r *http.Request, params httprouter.Params) (interface{}, error) { - scripts.SetScriptHeaders(w.Header()) + httplib.SetScriptHeaders(w.Header()) settings := scriptSettings{ token: params.ByName("token"), diff --git a/lib/web/scripts/install_node.go b/lib/web/scripts/install_node.go index 91ed32487f8d5..64cbd0ac88776 100644 --- a/lib/web/scripts/install_node.go +++ b/lib/web/scripts/install_node.go @@ -19,7 +19,6 @@ package scripts import ( _ "embed" "fmt" - "net/http" "sort" "strings" "text/template" @@ -31,11 +30,6 @@ import ( "github.com/gravitational/teleport/api/utils" ) -// SetScriptHeaders sets response headers to plain text. -func SetScriptHeaders(h http.Header) { - h.Set("Content-Type", "text/plain") -} - // ErrorBashScript is used to display friendly error message when // there is an error prepping the actual script. var ErrorBashScript = []byte(`