Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions lib/srv/app/aws/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/httplib"
"github.com/gravitational/teleport/lib/srv/app/common"
Expand Down Expand Up @@ -60,6 +59,8 @@ type SignerHandlerConfig struct {
*awsutils.SigningService
// Clock is used to override time in tests.
Clock clockwork.Clock
// MaxHTTPRequestBodySize is the limit on how big a request body can be.
MaxHTTPRequestBodySize int64
}

// CheckAndSetDefaults validates the AwsSignerHandlerConfig.
Expand All @@ -80,6 +81,12 @@ func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error {
if cfg.Clock == nil {
cfg.Clock = clockwork.NewRealClock()
}

// Limit HTTP request body size to 70MB, which matches AWS Lambda function
// zip file upload limit (50MB) after accounting for base64 encoding bloat.
if cfg.MaxHTTPRequestBodySize == 0 {
cfg.MaxHTTPRequestBodySize = 70 << 20
}
return nil
}

Expand Down Expand Up @@ -119,6 +126,7 @@ func (s *signerHandler) formatForwardResponseError(rw http.ResponseWriter, r *ht

// ServeHTTP handles incoming requests by signing them and then forwarding them to the proper AWS API.
func (s *signerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
req.Body = utils.MaxBytesReader(w, req.Body, s.MaxHTTPRequestBodySize)
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down Expand Up @@ -230,7 +238,7 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved
}
outReq.Body = http.NoBody
if r.Body != nil {
outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize))
outReq.Body = r.Body
}
// need to rewrite the host header as well. The oxy forwarder will do this for us,
// since we use the PassHostHeader(false) option, but if host is a signed header
Expand Down
71 changes: 70 additions & 1 deletion lib/srv/app/aws/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"

Expand All @@ -34,6 +35,7 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/lambda"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/google/go-cmp/cmp"
Expand Down Expand Up @@ -94,6 +96,39 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran
return err
}

// dont make tests generate huge requests just to test limiting the request
// size. Use a 1MB limit instead of the actual 70MB limit.
const maxTestHTTPRequestBodySize = 1 << 20

func maxSizeExceededRequest(url string, provider client.ConfigProvider, _ string) error {
// fake an upload that's too large
payload := strings.Repeat("x", maxTestHTTPRequestBodySize)
return lambdaRequestWithPayload(url, provider, payload)
}

func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error {
// fake a zip file with 70% of the max limit. Lambda will base64 encode it,
// which bloats it up, and our proxy should still handle it.
const size = (maxTestHTTPRequestBodySize * 7) / 10
payload := strings.Repeat("x", size)
return lambdaRequestWithPayload(url, provider, payload)
}

func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string) error {
lambdaClient := lambda.New(provider, &aws.Config{
Endpoint: &url,
MaxRetries: aws.Int(0),
HTTPClient: &http.Client{
Timeout: 5 * time.Second,
},
})
_, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{
FunctionName: aws.String("fakeFunc"),
ZipFile: []byte(payload),
})
return err
}

func assumeRoleRequest(requestDuration time.Duration) makeRequest {
return func(url string, provider client.ConfigProvider, _ string) error {
stsClient := sts.New(provider, &aws.Config{
Expand Down Expand Up @@ -289,6 +324,37 @@ func TestAWSSignerHandler(t *testing.T) {
require.NoError,
},
},
{
name: "Lambda access",
app: consoleApp,
awsClientSession: session.Must(session.NewSession(&aws.Config{
Credentials: staticAWSCredentialsForClient,
Region: aws.String("us-east-1"),
})),
request: lambdaRequest,
wantHost: "lambda.us-east-1.amazonaws.com",
wantAuthCredKeyID: "AKIDl",
wantAuthCredService: "lambda",
wantAuthCredRegion: "us-east-1",
wantEventType: &events.AppSessionRequest{},
errAssertionFns: []require.ErrorAssertionFunc{
require.NoError,
},
},
{
name: "Request exceeding max size",
app: consoleApp,
awsClientSession: session.Must(session.NewSession(&aws.Config{
Credentials: staticAWSCredentialsForClient,
Region: aws.String("us-east-1"),
})),
request: maxSizeExceededRequest,
errAssertionFns: []require.ErrorAssertionFunc{
// TODO(gavin): change this to [http.StatusRequestEntityTooLarge]
// after updating [trace.ErrorToCode].
hasStatusCode(http.StatusTooManyRequests),
},
},
{
name: "AssumeRole success (shorter identity duration)",
app: consoleApp,
Expand Down Expand Up @@ -346,7 +412,9 @@ func TestAWSSignerHandler(t *testing.T) {
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fakeClock := clockwork.NewFakeClock()
mockAwsHandler := func(w http.ResponseWriter, r *http.Request) {
// check that we got what the test case expects first.
Expand Down Expand Up @@ -529,7 +597,8 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic
return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String())
},
},
Clock: clock,
Clock: clock,
MaxHTTPRequestBodySize: maxTestHTTPRequestBodySize,
})
require.NoError(t, err)
mux := http.NewServeMux()
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/app/azure/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/azure"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -129,6 +130,9 @@ func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error

// RoundTrip handles incoming requests and forwards them to the proper API.
func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Body != nil {
req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize)
}
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down
4 changes: 4 additions & 0 deletions lib/srv/app/gcp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/utils/gcp"
"github.com/gravitational/teleport/lib/cloud"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -151,6 +152,9 @@ func newGCPHandler(ctx context.Context, config HandlerConfig) (*handler, error)

// RoundTrip handles incoming requests and forwards them to the proper API.
func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
if req.Body != nil {
req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize)
}
if err := s.serveHTTP(w, req); err != nil {
s.formatForwardResponseError(w, req, err)
return
Expand Down
2 changes: 2 additions & 0 deletions lib/srv/db/dynamodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
"github.com/aws/aws-sdk-go/service/dynamodbstreams"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
apiaws "github.com/gravitational/teleport/api/utils/aws"
"github.com/gravitational/teleport/lib/cloud"
Expand Down Expand Up @@ -167,6 +168,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws.
if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}

re, err := e.resolveEndpoint(req)
Expand Down
6 changes: 6 additions & 0 deletions lib/srv/db/elasticsearch/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
elastic "github.com/elastic/go-elasticsearch/v8/typedapi/types"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/events"
Expand Down Expand Up @@ -146,6 +147,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
// process reads request from connected elasticsearch client, processes the requests/responses and send data back
// to the client.
func (e *Engine) process(ctx context.Context, sessionCtx *common.Session, req *http.Request, client *http.Client) error {
if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}
payload, err := utils.GetAndReplaceRequestBody(req)
if err != nil {
return trace.Wrap(err)
Expand Down
11 changes: 10 additions & 1 deletion lib/srv/db/opensearch/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/aws/aws-sdk-go/service/opensearchservice"
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/cloud"
Expand Down Expand Up @@ -171,6 +172,11 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error
// process reads request from connected OpenSearch client, processes the requests/responses and send data back
// to the client.
func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws.SigningService, req *http.Request) error {
if req.Body != nil {
// make sure we close the incoming request's body. ignore any close error.
defer req.Body.Close()
req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize))
}
reqCopy, payload, err := e.rewriteRequest(ctx, req)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -311,11 +317,14 @@ func (e *Engine) emitAuditEvent(req *http.Request, body []byte, statusCode uint3

// sendResponse sends the response back to the OpenSearch client.
func (e *Engine) sendResponse(serverResponse *http.Response) error {
if serverResponse.Body != nil {
defer serverResponse.Body.Close()
serverResponse.Body = io.NopCloser(io.LimitReader(serverResponse.Body, teleport.MaxHTTPResponseSize))
}
payload, err := utils.GetAndReplaceResponseBody(serverResponse)
if err != nil {
return trace.Wrap(err)
}

// serverResponse may be HTTP2 response, but we should reply with HTTP 1.1
clientResponse := &http.Response{
ProtoMajor: 1,
Expand Down
10 changes: 5 additions & 5 deletions lib/srv/db/snowflake/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,24 @@ func readRequestBody(req *http.Request) ([]byte, error) {
return nil, trace.Wrap(err)
}

return maybeReadGzip(&req.Header, body)
return maybeReadGzip(&req.Header, body, teleport.MaxHTTPRequestSize)
}

func readResponseBody(resp *http.Response) ([]byte, error) {
defer resp.Body.Close()

body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPRequestSize)
body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
if err != nil {
return nil, trace.Wrap(err)
}

return maybeReadGzip(&resp.Header, body)
return maybeReadGzip(&resp.Header, body, teleport.MaxHTTPResponseSize)
}

// maybeReadGzip checks if the body is gzip encoded and returns decoded version.
// To determine gzip encoding the beginning of body message is being checked
// instead of HTTP header and the second one was less reliable during testing.
func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) {
func maybeReadGzip(headers *http.Header, body []byte, limit int64) ([]byte, error) {
gzipMagic := []byte{0x1f, 0x8b, 0x08}

// Check if the body is gzip encoded. Alternative here could check
Expand All @@ -108,7 +108,7 @@ func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) {
}
defer bodyGZ.Close()

body, err = utils.ReadAtMost(bodyGZ, teleport.MaxHTTPRequestSize)
body, err = utils.ReadAtMost(bodyGZ, limit)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading