diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index f832cce0a833d..d2ef33fd26c8b 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -32,7 +32,6 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/srv/app/common" @@ -60,6 +59,8 @@ type SignerHandlerConfig struct { *awsutils.SigningService // Clock is used to override time in tests. Clock clockwork.Clock + // MaxHTTPRequestBodySize is the limit on how big a request body can be. + MaxHTTPRequestBodySize int64 } // CheckAndSetDefaults validates the AwsSignerHandlerConfig. @@ -80,6 +81,12 @@ func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error { if cfg.Clock == nil { cfg.Clock = clockwork.NewRealClock() } + + // Limit HTTP request body size to 70MB, which matches AWS Lambda function + // zip file upload limit (50MB) after accounting for base64 encoding bloat. + if cfg.MaxHTTPRequestBodySize == 0 { + cfg.MaxHTTPRequestBodySize = 70 << 20 + } return nil } @@ -119,6 +126,7 @@ func (s *signerHandler) formatForwardResponseError(rw http.ResponseWriter, r *ht // ServeHTTP handles incoming requests by signing them and then forwarding them to the proper AWS API. func (s *signerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + req.Body = utils.MaxBytesReader(w, req.Body, s.MaxHTTPRequestBodySize) if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return @@ -230,7 +238,7 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved } outReq.Body = http.NoBody if r.Body != nil { - outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) + outReq.Body = r.Body } // need to rewrite the host header as well. 
The oxy forwarder will do this for us, // since we use the PassHostHeader(false) option, but if host is a signed header diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 89311e46db25d..69a1ae4043281 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -24,6 +24,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" @@ -34,6 +35,7 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/lambda" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" @@ -94,6 +96,39 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran return err } +// don't make tests generate huge requests just to test limiting the request +// size. Use a 1MB limit instead of the actual 70MB limit. +const maxTestHTTPRequestBodySize = 1 << 20 + +func maxSizeExceededRequest(url string, provider client.ConfigProvider, _ string) error { + // fake an upload that's too large + payload := strings.Repeat("x", maxTestHTTPRequestBodySize) + return lambdaRequestWithPayload(url, provider, payload) +} + +func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error { + // fake a zip file with 70% of the max limit. Lambda will base64 encode it, + // which bloats it up, and our proxy should still handle it. 
+ const size = (maxTestHTTPRequestBodySize * 7) / 10 + payload := strings.Repeat("x", size) + return lambdaRequestWithPayload(url, provider, payload) +} + +func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string) error { + lambdaClient := lambda.New(provider, &aws.Config{ + Endpoint: &url, + MaxRetries: aws.Int(0), + HTTPClient: &http.Client{ + Timeout: 5 * time.Second, + }, + }) + _, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{ + FunctionName: aws.String("fakeFunc"), + ZipFile: []byte(payload), + }) + return err +} + func assumeRoleRequest(requestDuration time.Duration) makeRequest { return func(url string, provider client.ConfigProvider, _ string) error { stsClient := sts.New(provider, &aws.Config{ @@ -289,6 +324,37 @@ func TestAWSSignerHandler(t *testing.T) { require.NoError, }, }, + { + name: "Lambda access", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: staticAWSCredentialsForClient, + Region: aws.String("us-east-1"), + })), + request: lambdaRequest, + wantHost: "lambda.us-east-1.amazonaws.com", + wantAuthCredKeyID: "AKIDl", + wantAuthCredService: "lambda", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, + }, + { + name: "Request exceeding max size", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: staticAWSCredentialsForClient, + Region: aws.String("us-east-1"), + })), + request: maxSizeExceededRequest, + errAssertionFns: []require.ErrorAssertionFunc{ + // TODO(gavin): change this to [http.StatusRequestEntityTooLarge] + // after updating [trace.ErrorToCode]. 
+ hasStatusCode(http.StatusTooManyRequests), + }, + }, { name: "AssumeRole success (shorter identity duration)", app: consoleApp, @@ -346,7 +412,9 @@ func TestAWSSignerHandler(t *testing.T) { }, } for _, tc := range tests { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() fakeClock := clockwork.NewFakeClock() mockAwsHandler := func(w http.ResponseWriter, r *http.Request) { // check that we got what the test case expects first. @@ -529,7 +597,8 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) }, }, - Clock: clock, + Clock: clock, + MaxHTTPRequestBodySize: maxTestHTTPRequestBodySize, }) require.NoError(t, err) mux := http.NewServeMux() diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 159d0d864d9e4..f269a587eeb19 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -32,6 +32,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/azure" "github.com/gravitational/teleport/lib/defaults" @@ -129,6 +130,9 @@ func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error // RoundTrip handles incoming requests and forwards them to the proper API. 
func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Body != nil { + req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize) + } if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return diff --git a/lib/srv/app/gcp/handler.go b/lib/srv/app/gcp/handler.go index 4b687ca8d3917..d4a457578957d 100644 --- a/lib/srv/app/gcp/handler.go +++ b/lib/srv/app/gcp/handler.go @@ -30,6 +30,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/utils/gcp" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" @@ -151,6 +152,9 @@ func newGCPHandler(ctx context.Context, config HandlerConfig) (*handler, error) // RoundTrip handles incoming requests and forwards them to the proper API. func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Body != nil { + req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize) + } if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 032aa0f735490..2d40babbdfb01 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -36,6 +36,7 @@ import ( "github.com/aws/aws-sdk-go/service/dynamodbstreams" "github.com/gravitational/trace" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" apiaws "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/cloud" @@ -167,6 +168,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws. if req.Body != nil { // make sure we close the incoming request's body. ignore any close error. 
defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) } re, err := e.resolveEndpoint(req) diff --git a/lib/srv/db/elasticsearch/engine.go b/lib/srv/db/elasticsearch/engine.go index 8c7bc38f920d6..7121189d4ce83 100644 --- a/lib/srv/db/elasticsearch/engine.go +++ b/lib/srv/db/elasticsearch/engine.go @@ -31,6 +31,7 @@ import ( elastic "github.com/elastic/go-elasticsearch/v8/typedapi/types" "github.com/gravitational/trace" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/events" @@ -146,6 +147,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio // process reads request from connected elasticsearch client, processes the requests/responses and send data back // to the client. func (e *Engine) process(ctx context.Context, sessionCtx *common.Session, req *http.Request, client *http.Client) error { + if req.Body != nil { + // make sure we close the incoming request's body. ignore any close error. 
+ defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) + } payload, err := utils.GetAndReplaceRequestBody(req) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index 555690d7cbae4..2618d493d2c6a 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -27,6 +27,7 @@ import ( "github.com/aws/aws-sdk-go/service/opensearchservice" "github.com/gravitational/trace" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/cloud" @@ -171,6 +172,11 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error // process reads request from connected OpenSearch client, processes the requests/responses and send data back // to the client. func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws.SigningService, req *http.Request) error { + if req.Body != nil { + // make sure we close the incoming request's body. ignore any close error. + defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) + } reqCopy, payload, err := e.rewriteRequest(ctx, req) if err != nil { return trace.Wrap(err) @@ -311,11 +317,14 @@ func (e *Engine) emitAuditEvent(req *http.Request, body []byte, statusCode uint3 // sendResponse sends the response back to the OpenSearch client. 
func (e *Engine) sendResponse(serverResponse *http.Response) error { + if serverResponse.Body != nil { + defer serverResponse.Body.Close() + serverResponse.Body = io.NopCloser(io.LimitReader(serverResponse.Body, teleport.MaxHTTPResponseSize)) + } payload, err := utils.GetAndReplaceResponseBody(serverResponse) if err != nil { return trace.Wrap(err) } - // serverResponse may be HTTP2 response, but we should reply with HTTP 1.1 clientResponse := &http.Response{ ProtoMajor: 1, diff --git a/lib/srv/db/snowflake/http.go b/lib/srv/db/snowflake/http.go index 669d21802ae59..2ecf2307ca82e 100644 --- a/lib/srv/db/snowflake/http.go +++ b/lib/srv/db/snowflake/http.go @@ -75,24 +75,24 @@ func readRequestBody(req *http.Request) ([]byte, error) { return nil, trace.Wrap(err) } - return maybeReadGzip(&req.Header, body) + return maybeReadGzip(&req.Header, body, teleport.MaxHTTPRequestSize) } func readResponseBody(resp *http.Response) ([]byte, error) { defer resp.Body.Close() - body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPRequestSize) + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) if err != nil { return nil, trace.Wrap(err) } - return maybeReadGzip(&resp.Header, body) + return maybeReadGzip(&resp.Header, body, teleport.MaxHTTPResponseSize) } // maybeReadGzip checks if the body is gzip encoded and returns decoded version. // To determine gzip encoding the beginning of body message is being checked // instead of HTTP header and the second one was less reliable during testing. -func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) { +func maybeReadGzip(headers *http.Header, body []byte, limit int64) ([]byte, error) { gzipMagic := []byte{0x1f, 0x8b, 0x08} // Check if the body is gzip encoded. 
Alternative here could check @@ -108,7 +108,7 @@ func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) { } defer bodyGZ.Close() - body, err = utils.ReadAtMost(bodyGZ, teleport.MaxHTTPRequestSize) + body, err = utils.ReadAtMost(bodyGZ, limit) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/utils/http.go b/lib/utils/http.go index e0717bf170c54..5b66cc422be8b 100644 --- a/lib/utils/http.go +++ b/lib/utils/http.go @@ -18,23 +18,25 @@ package utils import ( "bytes" + "errors" "io" "net/http" "github.com/gravitational/trace" - - "github.com/gravitational/teleport" ) // GetAndReplaceRequestBody returns the request body and replaces the drained -// body reader with io.NopCloser allowing for further body processing by http -// transport. +// body reader with an [io.NopCloser] allowing for further body processing by +// http transport. +// If memory exhaustion is a concern, it is the caller's responsibility to wrap +// the request body in an [io.LimitReader] prior to calling this function. func GetAndReplaceRequestBody(req *http.Request) ([]byte, error) { if req.Body == nil || req.Body == http.NoBody { return []byte{}, nil } - // req.Body is closed during tryDrainBody call. - payload, err := tryDrainBody(req.Body) + defer req.Body.Close() + + payload, err := io.ReadAll(req.Body) if err != nil { return nil, trace.Wrap(err) } @@ -45,13 +47,16 @@ func GetAndReplaceRequestBody(req *http.Request) ([]byte, error) { } // GetAndReplaceResponseBody returns the response body and replaces the drained -// body reader with io.NopCloser allowing for further body processing. +// body reader with [io.NopCloser] allowing for further body processing. +// If memory exhaustion is a concern, it is the caller's responsibility to wrap +// the response body in an [io.LimitReader] prior to calling this function. 
func GetAndReplaceResponseBody(response *http.Response) ([]byte, error) { if response.Body == nil { return []byte{}, nil } + defer response.Body.Close() - payload, err := tryDrainBody(response.Body) + payload, err := io.ReadAll(response.Body) if err != nil { return nil, trace.Wrap(err) } @@ -62,32 +67,20 @@ func GetAndReplaceResponseBody(response *http.Response) ([]byte, error) { // ReplaceRequestBody drains the old request body and replaces it with a new one. func ReplaceRequestBody(req *http.Request, newBody io.ReadCloser) error { - if _, err := tryDrainBody(req.Body); err != nil { - return trace.Wrap(err) + if req.Body != nil { + defer req.Body.Close() + // drain and discard the request body to allow connection reuse. + // No need to enforce a max request size, nor rely on callers to do so, + // since we do not buffer the entire request body. + _, err := io.Copy(io.Discard, req.Body) + if err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } } req.Body = newBody return nil } -// tryDrainBody tries to drain and close the body, returning the read bytes. -// It may fail to completely drain the body if the size of the body exceeds MaxHTTPRequestSize. -func tryDrainBody(b io.ReadCloser) (payload []byte, err error) { - if b == nil { - return nil, nil - } - defer func() { - if closeErr := b.Close(); closeErr != nil { - err = trace.NewAggregate(err, closeErr) - } - }() - payload, err = ReadAtMost(b, teleport.MaxHTTPRequestSize) - if err != nil { - err = trace.Wrap(err) - return - } - return -} - // RenameHeader moves all values from the old header key to the new header key. 
func RenameHeader(header http.Header, oldKey, newKey string) { if oldKey == newKey { @@ -157,3 +150,27 @@ func ChainHTTPMiddlewares(handler http.Handler, middlewares ...HTTPMiddleware) h func NoopHTTPMiddleware(next http.Handler) http.Handler { return next } + +// MaxBytesReader returns an [io.ReadCloser] that wraps an [http.MaxBytesReader] +// to act as a shim for converting from [http.MaxBytesError] to +// [ErrLimitReached]. +func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + return &maxBytesReader{ReadCloser: http.MaxBytesReader(w, r, n)} +} + +// maxBytesReader wraps an [http.MaxBytesReader] and converts any +// [http.MaxBytesError] to [ErrLimitReached]. +type maxBytesReader struct { + io.ReadCloser +} + +func (m *maxBytesReader) Read(p []byte) (int, error) { + n, err := m.ReadCloser.Read(p) + + // convert [http.MaxBytesError] to our limit error. + var mbErr *http.MaxBytesError + if errors.As(err, &mbErr) { + return n, ErrLimitReached + } + return n, err +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 7473e8e62badd..c03c5879fb84a 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -645,18 +645,34 @@ func StoreErrorOf(f func() error, err *error) { *err = trace.NewAggregate(*err, f()) } +// LimitReader returns a reader that limits bytes from r, and reports an error +// when limit bytes are read. +func LimitReader(r io.Reader, limit int64) io.Reader { + return &limitedReader{ + LimitedReader: &io.LimitedReader{R: r, N: limit}, + } +} + +// limitedReader wraps an [io.LimitedReader] that limits bytes read, and +// reports an error when the read limit is reached. +type limitedReader struct { + *io.LimitedReader +} + +func (l *limitedReader) Read(p []byte) (int, error) { + n, err := l.LimitedReader.Read(p) + if l.LimitedReader.N <= 0 { + return n, ErrLimitReached + } + return n, err +} + // ReadAtMost reads up to limit bytes from r, and reports an error // when limit bytes are read. 
func ReadAtMost(r io.Reader, limit int64) ([]byte, error) { - limitedReader := &io.LimitedReader{R: r, N: limit} + limitedReader := LimitReader(r, limit) data, err := io.ReadAll(limitedReader) - if err != nil { - return data, err - } - if limitedReader.N <= 0 { - return data, ErrLimitReached - } - return data, nil + return data, err } // HasPrefixAny determines if any of the string values have the given prefix. @@ -686,6 +702,9 @@ func ByteCount(b int64) string { } // ErrLimitReached means that the read limit is reached. +// +// TODO(gavin): this should be converted to a 413 StatusRequestEntityTooLarge +// in trace.ErrorToCode instead of 429 StatusTooManyRequests. var ErrLimitReached = &trace.LimitExceededError{Message: "the read limit is reached"} const (