Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion api/types/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
// TestDatabaseRDSEndpoint verifies AWS info is correctly populated
// based on the RDS endpoint.
func TestDatabaseRDSEndpoint(t *testing.T) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
}

Expand Down
39 changes: 39 additions & 0 deletions api/utils/aws/identifiers.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ limitations under the License.
package aws

import (
"regexp"
"strings"
"unicode"

"github.com/gravitational/trace"
)

Expand All @@ -35,3 +39,38 @@ func IsValidAccountID(accountID string) error {

return nil
}

// matchRoleName is a regex that matches against AWS IAM Role Names.
var matchRoleName = regexp.MustCompile(`^[\w+=,.@-]+$`).MatchString

// IsValidIAMRoleName checks whether the role name is a valid AWS IAM Role identifier.
//
// > Length Constraints: Minimum length of 1. Maximum length of 64.
// > Pattern: [\w+=,.@-]+
// https://docs.aws.amazon.com/IAM/latest/APIReference/API_CreateRole.html
func IsValidIAMRoleName(roleName string) error {
if len(roleName) == 0 || len(roleName) > 64 || !matchRoleName(roleName) {
return trace.BadParameter("role is invalid")
}

return nil
}

// IsValidRegion ensures the region looks to be valid.
// It does not do a full validation, because AWS doesn't provide documentation for that.
// However, they usually only have the following chars: [a-z0-9\-]
func IsValidRegion(region string) error {
indexNotFound := -1

if len(region) == 0 {
return trace.BadParameter("region is invalid")
}

if strings.IndexFunc(region, func(r rune) bool {
return !(unicode.IsDigit(r) || unicode.IsLetter(r) || r == '-')
}) == indexNotFound {
return nil
}

return trace.BadParameter("region is invalid")
}
97 changes: 96 additions & 1 deletion api/utils/aws/identifiers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@ limitations under the License.
package aws

import (
"strings"
"testing"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
)

func TestIsValidAccountID(t *testing.T) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
}

Expand Down Expand Up @@ -74,3 +75,97 @@ func TestIsValidAccountID(t *testing.T) {
})
}
}

func TestIsValidIAMRoleName(t *testing.T) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
}

for _, tt := range []struct {
name string
role string
errCheck require.ErrorAssertionFunc
}{
{
name: "valid",
role: "valid",
errCheck: require.NoError,
},
{
name: "valid with numbers",
role: "00VALID11",
errCheck: require.NoError,
},
{
name: "only one symbol",
role: "_",
errCheck: require.NoError,
},
{
name: "all symbols",
role: "Test+1=2,3.4@5-6_7",
errCheck: require.NoError,
},
{
name: "empty",
role: "",
errCheck: isBadParamErrFn,
},
{
name: "too large",
role: strings.Repeat("r", 65),
errCheck: isBadParamErrFn,
},
{
name: "invalid symbols",
role: "role/admin",
errCheck: isBadParamErrFn,
},
} {
t.Run(tt.name, func(t *testing.T) {
tt.errCheck(t, IsValidIAMRoleName(tt.role))
})
}
}

func TestIsValidRegion(t *testing.T) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
}

for _, tt := range []struct {
name string
region string
errCheck require.ErrorAssertionFunc
}{
{
name: "us region",
region: "us-east-1",
errCheck: require.NoError,
},
{
name: "eu region",
region: "eu-west-1",
errCheck: require.NoError,
},
{
name: "us gov",
region: "us-gov-east-1",
errCheck: require.NoError,
},
{
name: "empty",
region: "",
errCheck: isBadParamErrFn,
},
{
name: "symbols",
region: "us@east-1",
errCheck: isBadParamErrFn,
},
} {
t.Run(tt.name, func(t *testing.T) {
tt.errCheck(t, IsValidRegion(tt.region))
})
}
}
2 changes: 1 addition & 1 deletion lib/integrations/awsoidc/deployservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

func TestDeployServiceRequest(t *testing.T) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...interface{}) {
isBadParamErrFn := func(tt require.TestingT, err error, i ...any) {
require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
}

Expand Down
1 change: 1 addition & 0 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,7 @@ func (h *Handler) bindDefaultEndpoints() {
// AWS OIDC Integration Actions
h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/databases", h.WithClusterAuth(h.awsOIDCListDatabases))
h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/deployservice", h.WithClusterAuth(h.awsOIDCDeployService))
h.GET("/webapi/scripts/integrations/configure/deployservice-iam.sh", h.WithLimiter(h.awsOIDCConfigureDeployServiceIAM))

// AWS OIDC Integration specific endpoints:
// Unauthenticated access to OpenID Configuration - used for AWS OIDC IdP integration
Expand Down
65 changes: 65 additions & 0 deletions lib/web/integrations_awsoidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,20 @@ package web

import (
"context"
"fmt"
"net/http"
"strings"

"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/integrations/awsoidc"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/web/scripts/oneoff"
"github.com/gravitational/teleport/lib/web/ui"
)

Expand Down Expand Up @@ -160,3 +164,64 @@ func (h *Handler) awsOIDCDeployService(w http.ResponseWriter, r *http.Request, p
ServiceDashboardURL: deployServiceResp.ServiceDashboardURL,
}, nil
}

// awsOIDCConfigureDeployServiceIAM returns a script that configures the required IAM permissions to enable the usage of DeployService action.
func (h *Handler) awsOIDCConfigureDeployServiceIAM(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) {
ctx := r.Context()

queryParams := r.URL.Query()

clusterName, err := h.GetProxyClient().GetDomainName(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

integrationName := queryParams.Get("integrationName")
if len(integrationName) == 0 {
return nil, trace.BadParameter("missing integrationName param")
}

// Ensure the IntegrationName is valid.
_, err = h.GetProxyClient().GetIntegration(ctx, integrationName)
// NotFound error is ignored to prevent disclosure of whether the integration exists in a public/no-auth endpoint.
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err)
}

awsRegion := queryParams.Get("awsRegion")
if err := aws.IsValidRegion(awsRegion); err != nil {
return nil, trace.BadParameter("invalid awsRegion")
}

role := queryParams.Get("role")
if err := aws.IsValidIAMRoleName(role); err != nil {
return nil, trace.BadParameter("invalid role")
}

taskRole := queryParams.Get("taskRole")
if err := aws.IsValidIAMRoleName(taskRole); err != nil {
return nil, trace.BadParameter("invalid taskRole")
}

// The script must execute the following command:
// teleport integration configure deployservice-iam
argsList := []string{
"integration", "configure", "deployservice-iam",
fmt.Sprintf(`--cluster="%s"`, clusterName),
fmt.Sprintf(`--name="%s"`, integrationName),
fmt.Sprintf(`--aws-region="%s"`, awsRegion),
fmt.Sprintf(`--role="%s"`, role),
fmt.Sprintf(`--task-role="%s"`, taskRole),
}
script, err := oneoff.BuildScript(oneoff.OneOffScriptParams{
TeleportArgs: strings.Join(argsList, " "),
})
if err != nil {
return nil, trace.Wrap(err)
}

httplib.SetScriptHeaders(w.Header())
fmt.Fprint(w, script)

return nil, trace.Wrap(err)
}
Loading