diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 7ce54fc717b7a..1022a82838894 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -30,7 +30,6 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/httplib/reverseproxy" @@ -59,6 +58,8 @@ type SignerHandlerConfig struct { *awsutils.SigningService // Clock is used to override time in tests. Clock clockwork.Clock + // MaxHTTPRequestBodySize is the limit on how big a request body can be. + MaxHTTPRequestBodySize int64 } // CheckAndSetDefaults validates the AwsSignerHandlerConfig. @@ -79,6 +80,12 @@ func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error { if cfg.Clock == nil { cfg.Clock = clockwork.NewRealClock() } + + // Limit HTTP request body size to 70MB, which matches AWS Lambda function + // zip file upload limit (50MB) after accounting for base64 encoding bloat. + if cfg.MaxHTTPRequestBodySize == 0 { + cfg.MaxHTTPRequestBodySize = 70 << 20 + } return nil } @@ -115,6 +122,7 @@ func (s *signerHandler) formatForwardResponseError(rw http.ResponseWriter, r *ht // ServeHTTP handles incoming requests by signing them and then forwarding them to the proper AWS API. func (s *signerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + req.Body = utils.MaxBytesReader(w, req.Body, s.MaxHTTPRequestBodySize) if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return @@ -226,7 +234,7 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved } outReq.Body = http.NoBody if r.Body != nil { - outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) + outReq.Body = r.Body } // need to rewrite the host header as well. The oxy forwarder will do this for us, // since we use the PassHostHeader(false) option, but if host is a signed header diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 966e4177ec368..ba1a391642a89 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -24,6 +24,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "strings" "testing" "time" @@ -34,6 +35,7 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/dynamodb" + "github.com/aws/aws-sdk-go/service/lambda" "github.com/aws/aws-sdk-go/service/s3" "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" @@ -99,6 +101,39 @@ func dynamoRequestWithTransport(url string, provider client.ConfigProvider, tran return err } +// dont make tests generate huge requests just to test limiting the request +// size. Use a 1MB limit instead of the actual 70MB limit. +const maxTestHTTPRequestBodySize = 1 << 20 + +func maxSizeExceededRequest(url string, provider client.ConfigProvider, _ string) error { + // fake an upload that's too large + payload := strings.Repeat("x", maxTestHTTPRequestBodySize) + return lambdaRequestWithPayload(url, provider, payload) +} + +func lambdaRequest(url string, provider client.ConfigProvider, awsHost string) error { + // fake a zip file with 70% of the max limit. Lambda will base64 encode it, + // which bloats it up, and our proxy should still handle it. + const size = (maxTestHTTPRequestBodySize * 7) / 10 + payload := strings.Repeat("x", size) + return lambdaRequestWithPayload(url, provider, payload) +} + +func lambdaRequestWithPayload(url string, provider client.ConfigProvider, payload string) error { + lambdaClient := lambda.New(provider, &aws.Config{ + Endpoint: &url, + MaxRetries: aws.Int(0), + HTTPClient: &http.Client{ + Timeout: 5 * time.Second, + }, + }) + _, err := lambdaClient.UpdateFunctionCode(&lambda.UpdateFunctionCodeInput{ + FunctionName: aws.String("fakeFunc"), + ZipFile: []byte(payload), + }) + return err +} + func assumeRoleRequest(requestDuration time.Duration) makeRequest { return func(url string, provider client.ConfigProvider, _ string) error { stsClient := sts.New(provider, &aws.Config{ @@ -294,6 +329,37 @@ func TestAWSSignerHandler(t *testing.T) { require.NoError, }, }, + { + name: "Lambda access", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: staticAWSCredentialsForClient, + Region: aws.String("us-east-1"), + })), + request: lambdaRequest, + wantHost: "lambda.us-east-1.amazonaws.com", + wantAuthCredKeyID: "AKIDl", + wantAuthCredService: "lambda", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, + }, + { + name: "Request exceeding max size", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: staticAWSCredentialsForClient, + Region: aws.String("us-east-1"), + })), + request: maxSizeExceededRequest, + errAssertionFns: []require.ErrorAssertionFunc{ + // TODO(gavin): change this to [http.StatusRequestEntityTooLarge] + // after updating [trace.ErrorToCode]. + hasStatusCode(http.StatusTooManyRequests), + }, + }, { name: "AssumeRole success (shorter identity duration)", app: consoleApp, @@ -351,7 +417,9 @@ func TestAWSSignerHandler(t *testing.T) { }, } for _, tc := range tests { + tc := tc t.Run(tc.name, func(t *testing.T) { + t.Parallel() fakeClock := clockwork.NewFakeClock() mockAwsHandler := func(w http.ResponseWriter, r *http.Request) { // check that we got what the test case expects first. @@ -537,7 +605,8 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) }, }, - Clock: clock, + Clock: clock, + MaxHTTPRequestBodySize: maxTestHTTPRequestBodySize, }) require.NoError(t, err) mux := http.NewServeMux() diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 002a6e8e03591..20685eb02b92a 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -30,6 +30,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/azure" "github.com/gravitational/teleport/lib/defaults" @@ -123,6 +124,9 @@ func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error // RoundTrip handles incoming requests and forwards them to the proper API. func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Body != nil { + req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize) + } if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return diff --git a/lib/srv/app/gcp/handler.go b/lib/srv/app/gcp/handler.go index 89b7e0bdb24a8..ac0cb4cc3f3d9 100644 --- a/lib/srv/app/gcp/handler.go +++ b/lib/srv/app/gcp/handler.go @@ -28,6 +28,7 @@ import ( "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/utils/gcp" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" @@ -144,6 +145,9 @@ func newGCPHandler(ctx context.Context, config HandlerConfig) (*handler, error) // RoundTrip handles incoming requests and forwards them to the proper API. func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if req.Body != nil { + req.Body = utils.MaxBytesReader(w, req.Body, teleport.MaxHTTPRequestSize) + } if err := s.serveHTTP(w, req); err != nil { s.formatForwardResponseError(w, req, err) return diff --git a/lib/srv/db/clickhouse/engine_http.go b/lib/srv/db/clickhouse/engine_http.go index 954398e999e0b..70f99aeb162e8 100644 --- a/lib/srv/db/clickhouse/engine_http.go +++ b/lib/srv/db/clickhouse/engine_http.go @@ -32,6 +32,7 @@ import ( "github.com/andybalholm/brotli" "github.com/gravitational/trace" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/srv/db/common" "github.com/gravitational/teleport/lib/utils" @@ -57,31 +58,44 @@ func (e *Engine) handleHTTPConnection(ctx context.Context, sessionCtx *common.Se if err != nil { return trace.Wrap(err) } - query, err := getQuery(req) - if err != nil { + if err := e.handleRequest(req, sessionCtx, tr); err != nil { return trace.Wrap(err) } + } +} - queryEvent := common.Query{ - Query: query, - Parameters: []string{fmt.Sprintf("url=%s", req.URL.String())}, - } +func (e *Engine) handleRequest(req *http.Request, sessionCtx *common.Session, tr *http.Transport) error { + if req.Body != nil { + // we have to close the request body since [http.Server] didn't serve it + // up for us. + defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) + } + query, err := getQuery(req) + if err != nil { + return trace.Wrap(err) + } - e.Audit.OnQuery(e.Context, sessionCtx, queryEvent) + queryEvent := common.Query{ + Query: query, + Parameters: []string{fmt.Sprintf("url=%s", req.URL.String())}, + } - if err := e.handleRequest(req, sessionCtx); err != nil { - return trace.Wrap(err) - } + e.Audit.OnQuery(e.Context, sessionCtx, queryEvent) - resp, err := tr.RoundTrip(req) - if err != nil { - return trace.Wrap(err) - } + if err := e.rewriteRequest(req, sessionCtx); err != nil { + return trace.Wrap(err) + } - if err := e.writeResp(resp); err != nil { - return trace.Wrap(err) - } + resp, err := tr.RoundTrip(req) + if err != nil { + return trace.Wrap(err) + } + + if err := e.writeResp(resp); err != nil { + return trace.Wrap(err) } + return nil } func handleCompression(body []byte, compression string) ([]byte, error) { @@ -153,7 +167,7 @@ func (e *Engine) writeResp(resp *http.Response) error { return nil } -func (e *Engine) handleRequest(req *http.Request, sessionCtx *common.Session) error { +func (e *Engine) rewriteRequest(req *http.Request, sessionCtx *common.Session) error { uri, err := url.Parse(sessionCtx.Database.GetURI()) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index 5cfb1a449c992..459a53f1946c0 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -37,6 +37,7 @@ import ( "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" apiaws "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/cloud" @@ -180,6 +181,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws. if req.Body != nil { // make sure we close the incoming request's body. ignore any close error. defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) } re, err := e.resolveEndpoint(req) diff --git a/lib/srv/db/elasticsearch/engine.go b/lib/srv/db/elasticsearch/engine.go index 3a94b02781c12..b558b3f29c9ea 100644 --- a/lib/srv/db/elasticsearch/engine.go +++ b/lib/srv/db/elasticsearch/engine.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/events" @@ -156,6 +157,11 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio func (e *Engine) process(ctx context.Context, sessionCtx *common.Session, req *http.Request, client *http.Client, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error { msgFromClient.Inc() + if req.Body != nil { + // make sure we close the incoming request's body. ignore any close error. + defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) + } payload, err := utils.GetAndReplaceRequestBody(req) if err != nil { return trace.Wrap(err) diff --git a/lib/srv/db/opensearch/engine.go b/lib/srv/db/opensearch/engine.go index 986ac81657cd0..c25db69407a0b 100644 --- a/lib/srv/db/opensearch/engine.go +++ b/lib/srv/db/opensearch/engine.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" + "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/cloud" @@ -186,6 +187,11 @@ func (e *Engine) HandleConnection(ctx context.Context, _ *common.Session) error func (e *Engine) process(ctx context.Context, tr *http.Transport, signer *libaws.SigningService, req *http.Request, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error { msgFromClient.Inc() + if req.Body != nil { + // make sure we close the incoming request's body. ignore any close error. + defer req.Body.Close() + req.Body = io.NopCloser(utils.LimitReader(req.Body, teleport.MaxHTTPRequestSize)) + } reqCopy, payload, err := e.rewriteRequest(ctx, req) if err != nil { return trace.Wrap(err) @@ -328,11 +334,14 @@ func (e *Engine) emitAuditEvent(req *http.Request, body []byte, statusCode uint3 // sendResponse sends the response back to the OpenSearch client. func (e *Engine) sendResponse(serverResponse *http.Response) error { + if serverResponse.Body != nil { + defer serverResponse.Body.Close() + serverResponse.Body = io.NopCloser(io.LimitReader(serverResponse.Body, teleport.MaxHTTPResponseSize)) + } payload, err := utils.GetAndReplaceResponseBody(serverResponse) if err != nil { return trace.Wrap(err) } - // serverResponse may be HTTP2 response, but we should reply with HTTP 1.1 clientResponse := &http.Response{ ProtoMajor: 1, diff --git a/lib/srv/db/snowflake/http.go b/lib/srv/db/snowflake/http.go index 669d21802ae59..2ecf2307ca82e 100644 --- a/lib/srv/db/snowflake/http.go +++ b/lib/srv/db/snowflake/http.go @@ -75,24 +75,24 @@ func readRequestBody(req *http.Request) ([]byte, error) { return nil, trace.Wrap(err) } - return maybeReadGzip(&req.Header, body) + return maybeReadGzip(&req.Header, body, teleport.MaxHTTPRequestSize) } func readResponseBody(resp *http.Response) ([]byte, error) { defer resp.Body.Close() - body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPRequestSize) + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) if err != nil { return nil, trace.Wrap(err) } - return maybeReadGzip(&resp.Header, body) + return maybeReadGzip(&resp.Header, body, teleport.MaxHTTPResponseSize) } // maybeReadGzip checks if the body is gzip encoded and returns decoded version. // To determine gzip encoding the beginning of body message is being checked // instead of HTTP header and the second one was less reliable during testing. -func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) { +func maybeReadGzip(headers *http.Header, body []byte, limit int64) ([]byte, error) { gzipMagic := []byte{0x1f, 0x8b, 0x08} // Check if the body is gzip encoded. Alternative here could check @@ -108,7 +108,7 @@ func maybeReadGzip(headers *http.Header, body []byte) ([]byte, error) { } defer bodyGZ.Close() - body, err = utils.ReadAtMost(bodyGZ, teleport.MaxHTTPRequestSize) + body, err = utils.ReadAtMost(bodyGZ, limit) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/utils/http.go b/lib/utils/http.go index e0717bf170c54..5b66cc422be8b 100644 --- a/lib/utils/http.go +++ b/lib/utils/http.go @@ -18,23 +18,25 @@ package utils import ( "bytes" + "errors" "io" "net/http" "github.com/gravitational/trace" - - "github.com/gravitational/teleport" ) // GetAndReplaceRequestBody returns the request body and replaces the drained -// body reader with io.NopCloser allowing for further body processing by http -// transport. +// body reader with an [io.NopCloser] allowing for further body processing by +// http transport. +// If memory exhaustion is a concern, it is the caller's responsibility to wrap +// the request body in an [io.LimitReader] prior to calling this function. func GetAndReplaceRequestBody(req *http.Request) ([]byte, error) { if req.Body == nil || req.Body == http.NoBody { return []byte{}, nil } - // req.Body is closed during tryDrainBody call. - payload, err := tryDrainBody(req.Body) + defer req.Body.Close() + + payload, err := io.ReadAll(req.Body) if err != nil { return nil, trace.Wrap(err) } @@ -45,13 +47,16 @@ func GetAndReplaceRequestBody(req *http.Request) ([]byte, error) { } // GetAndReplaceResponseBody returns the response body and replaces the drained -// body reader with io.NopCloser allowing for further body processing. +// body reader with [io.NopCloser] allowing for further body processing. +// If memory exhaustion is a concern, it is the caller's responsibility to wrap +// the response body in an [io.LimitReader] prior to calling this function. func GetAndReplaceResponseBody(response *http.Response) ([]byte, error) { if response.Body == nil { return []byte{}, nil } + defer response.Body.Close() - payload, err := tryDrainBody(response.Body) + payload, err := io.ReadAll(response.Body) if err != nil { return nil, trace.Wrap(err) } @@ -62,32 +67,20 @@ func GetAndReplaceResponseBody(response *http.Response) ([]byte, error) { // ReplaceRequestBody drains the old request body and replaces it with a new one. func ReplaceRequestBody(req *http.Request, newBody io.ReadCloser) error { - if _, err := tryDrainBody(req.Body); err != nil { - return trace.Wrap(err) + if req.Body != nil { + defer req.Body.Close() + // drain and discard the request body to allow connection reuse. + // No need to enforce a max request size, nor rely on callers to do so, + // since we do not buffer the entire request body. + _, err := io.Copy(io.Discard, req.Body) + if err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } } req.Body = newBody return nil } -// tryDrainBody tries to drain and close the body, returning the read bytes. -// It may fail to completely drain the body if the size of the body exceeds MaxHTTPRequestSize. -func tryDrainBody(b io.ReadCloser) (payload []byte, err error) { - if b == nil { - return nil, nil - } - defer func() { - if closeErr := b.Close(); closeErr != nil { - err = trace.NewAggregate(err, closeErr) - } - }() - payload, err = ReadAtMost(b, teleport.MaxHTTPRequestSize) - if err != nil { - err = trace.Wrap(err) - return - } - return -} - // RenameHeader moves all values from the old header key to the new header key. func RenameHeader(header http.Header, oldKey, newKey string) { if oldKey == newKey { @@ -157,3 +150,27 @@ func ChainHTTPMiddlewares(handler http.Handler, middlewares ...HTTPMiddleware) h func NoopHTTPMiddleware(next http.Handler) http.Handler { return next } + +// MaxBytesReader returns an [io.ReadCloser] that wraps an [http.MaxBytesReader] +// to act as a shim for converting from [http.MaxBytesError] to +// [ErrLimitReached]. +func MaxBytesReader(w http.ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { + return &maxBytesReader{ReadCloser: http.MaxBytesReader(w, r, n)} +} + +// maxBytesReader wraps an [http.MaxBytesReader] and converts any +// [http.MaxBytesError] to [ErrLimitReached]. +type maxBytesReader struct { + io.ReadCloser +} + +func (m *maxBytesReader) Read(p []byte) (int, error) { + n, err := m.ReadCloser.Read(p) + + // convert [http.MaxBytesError] to our limit error. + var mbErr *http.MaxBytesError + if errors.As(err, &mbErr) { + return n, ErrLimitReached + } + return n, err +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 1763377888bfc..b73ad43d912ff 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -653,18 +653,34 @@ func StoreErrorOf(f func() error, err *error) { *err = trace.NewAggregate(*err, f()) } +// LimitReader returns a reader that limits bytes from r, and reports an error +// when limit bytes are read. +func LimitReader(r io.Reader, limit int64) io.Reader { + return &limitedReader{ + LimitedReader: &io.LimitedReader{R: r, N: limit}, + } +} + +// limitedReader wraps an [io.LimitedReader] that limits bytes read, and +// reports an error when the read limit is reached. +type limitedReader struct { + *io.LimitedReader +} + +func (l *limitedReader) Read(p []byte) (int, error) { + n, err := l.LimitedReader.Read(p) + if l.LimitedReader.N <= 0 { + return n, ErrLimitReached + } + return n, err +} + // ReadAtMost reads up to limit bytes from r, and reports an error // when limit bytes are read. func ReadAtMost(r io.Reader, limit int64) ([]byte, error) { - limitedReader := &io.LimitedReader{R: r, N: limit} + limitedReader := LimitReader(r, limit) data, err := io.ReadAll(limitedReader) - if err != nil { - return data, err - } - if limitedReader.N <= 0 { - return data, ErrLimitReached - } - return data, nil + return data, err } // HasPrefixAny determines if any of the string values have the given prefix. @@ -694,6 +710,9 @@ func ByteCount(b int64) string { } // ErrLimitReached means that the read limit is reached. +// +// TODO(gavin): this should be converted to a 413 StatusRequestEntityTooLarge +// in trace.ErrorToCode instead of 429 StatusTooManyRequests. var ErrLimitReached = &trace.LimitExceededError{Message: "the read limit is reached"} const (