diff --git a/lib/events/api.go b/lib/events/api.go index 1290ae7e2845a..f86834676c030 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -395,6 +395,10 @@ const ( // AppSessionRequestEvent is an HTTP request and response. AppSessionRequestEvent = "app.session.request" + // AppSessionDynamoDBRequestEvent is emitted when DynamoDB client sends + // a request via app access session. + AppSessionDynamoDBRequestEvent = "app.session.dynamodb.request" + // DatabaseCreateEvent is emitted when a database resource is created. DatabaseCreateEvent = "db.create" // DatabaseUpdateEvent is emitted when a database resource is updated. diff --git a/lib/events/codes.go b/lib/events/codes.go index 5a490f89e39cf..fa56ec10b8cee 100644 --- a/lib/events/codes.go +++ b/lib/events/codes.go @@ -108,6 +108,8 @@ const ( AppSessionEndCode = "T2011I" // SessionRecordingAccessCode is the session recording view data event code. SessionRecordingAccessCode = "T2012I" + // AppSessionDynamoDBRequestCode is the application request/response code. + AppSessionDynamoDBRequestCode = "T2013I" // AppCreateCode is the app.create event code. AppCreateCode = "TAP03I" diff --git a/lib/events/dynamic.go b/lib/events/dynamic.go index b822656e2753a..46037d09353e1 100644 --- a/lib/events/dynamic.go +++ b/lib/events/dynamic.go @@ -151,6 +151,8 @@ func FromEventFields(fields EventFields) (events.AuditEvent, error) { e = &events.AppSessionChunk{} case AppSessionRequestEvent: e = &events.AppSessionRequest{} + case AppSessionDynamoDBRequestEvent: + e = &events.AppSessionDynamoDBRequest{} case AppCreateEvent: e = &events.AppCreate{} case AppUpdateEvent: diff --git a/lib/service/service.go b/lib/service/service.go index c3bebec8eef3e..5c2cbde6d2fc1 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -4295,9 +4295,9 @@ func (process *TeleportProcess) initApps() { return trace.Wrap(err) } - ok := false + shouldSkipCleanup := false defer func() { - if !ok { + if !shouldSkipCleanup { warnOnErr(conn.Close(), log) } }() @@ -4431,6 +4431,12 @@ func (process *TeleportProcess) initApps() { proxyGetter := reversetunnel.NewConnectedProxyGetter() + defer func() { + if !shouldSkipCleanup { + warnOnErr(asyncEmitter.Close(), log) + } + }() + appServer, err := app.New(process.ExitContext(), &app.Config{ Clock: process.Config.Clock, DataDir: process.Config.DataDir, @@ -4456,7 +4462,7 @@ func (process *TeleportProcess) initApps() { } defer func() { - if !ok { + if !shouldSkipCleanup { warnOnErr(appServer.Close(), log) } }() @@ -4494,13 +4500,16 @@ func (process *TeleportProcess) initApps() { log.Infof("All applications successfully started.") // Cancel deferred cleanup actions, because we're going - // to regsiter an OnExit handler to take care of it - ok = true + // to register an OnExit handler to take care of it + shouldSkipCleanup = true // Execute this when process is asked to exit. process.OnExit("apps.stop", func(payload interface{}) { log.Infof("Shutting down.") warnOnErr(appServer.Close(), log) + if asyncEmitter != nil { + warnOnErr(asyncEmitter.Close(), log) + } agentPool.Stop() warnOnErr(asyncEmitter.Close(), log) warnOnErr(conn.Close(), log) diff --git a/lib/srv/app/aws/endpoints.go b/lib/srv/app/aws/endpoints.go index c5823f95b2503..d227e7f225e94 100644 --- a/lib/srv/app/aws/endpoints.go +++ b/lib/srv/app/aws/endpoints.go @@ -124,6 +124,22 @@ func endpointsIDFromSigningName(signingName string) string { return signingName } +func isDynamoDBEndpoint(re *endpoints.ResolvedEndpoint) bool { + // Some clients may sign some services with upper case letters. We use all + // lower cases in our mapping. + signingName := strings.ToLower(re.SigningName) + _, ok := dynamoDBSigningNames[signingName] + return ok +} + +// dynamoDBSigningNames is a set of signing names used for DynamoDB APIs. +var dynamoDBSigningNames = map[string]struct{}{ + // signing name for dynamodb and dynamodbstreams API. + "dynamodb": {}, + // signing name for dynamodb accelerator API. + "dax": {}, +} + // signingNameToEndpointsID is a map of AWS services' signing names to their // endpoints IDs. // diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 214df3e43eb8a..0fdfc1f95630a 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -18,7 +18,7 @@ package aws import ( "bytes" - "context" + "io" "net/http" "net/url" @@ -29,14 +29,12 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/gravitational/oxy/forward" - "github.com/gravitational/oxy/utils" + oxyutils "github.com/gravitational/oxy/utils" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" - apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/srv/app/common" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) @@ -52,7 +50,7 @@ func NewSigningService(config SigningServiceConfig) (*SigningService, error) { fwd, err := forward.New( forward.RoundTripper(svc), - forward.ErrorHandler(utils.ErrorHandlerFunc(svc.formatForwardResponseError)), + forward.ErrorHandler(oxyutils.ErrorHandlerFunc(svc.formatForwardResponseError)), forward.PassHostHeader(true), ) if err != nil { @@ -74,8 +72,8 @@ type SigningService struct { // SigningServiceConfig is the SigningService configuration. type SigningServiceConfig struct { - // Client is an HTTP client instance used for HTTP calls. - Client *http.Client + // RoundTripper is an http.RoundTripper instance used for requests. + RoundTripper http.RoundTripper // Log is the Logger. Log logrus.FieldLogger // Session is AWS session. @@ -90,14 +88,12 @@ type SigningServiceConfig struct { // CheckAndSetDefaults validates the SigningServiceConfig config. func (s *SigningServiceConfig) CheckAndSetDefaults() error { - if s.Client == nil { + if s.RoundTripper == nil { tr, err := defaults.Transport() if err != nil { return trace.Wrap(err) } - s.Client = &http.Client{ - Transport: tr, - } + s.RoundTripper = tr } if s.Clock == nil { s.Clock = clockwork.NewRealClock() @@ -137,6 +133,7 @@ func (s *SigningServiceConfig) CheckAndSetDefaults() error { // 5. Sign HTTP request. // 6. Forward the signed HTTP request to the AWS API. func (s *SigningService) RoundTrip(req *http.Request) (*http.Response, error) { + defer req.Body.Close() sessionCtx, err := common.GetSessionContext(req) if err != nil { return nil, trace.Wrap(err) @@ -145,46 +142,31 @@ func (s *SigningService) RoundTrip(req *http.Request) (*http.Response, error) { if err != nil { return nil, trace.Wrap(err) } - signedReq, err := s.prepareSignedRequest(req, resolvedEndpoint, sessionCtx) + payload, err := awsutils.GetAndReplaceReqBody(req) if err != nil { return nil, trace.Wrap(err) } - resp, err := s.Client.Do(signedReq) + signedReq, err := s.prepareSignedRequest(req, payload, resolvedEndpoint, sessionCtx) if err != nil { return nil, trace.Wrap(err) } - - if err := s.emitAuditEvent(req.Context(), signedReq, resp, sessionCtx, resolvedEndpoint); err != nil { + resp, err := s.RoundTripper.RoundTrip(signedReq) + if err != nil { + return nil, trace.Wrap(err) + } + // emit audit event with original request, but change the URL since we resolved and rewrote it. + signedReq.Body = io.NopCloser(bytes.NewReader(payload)) + if isDynamoDBEndpoint(resolvedEndpoint) { + err = sessionCtx.Audit.OnDynamoDBRequest(req.Context(), sessionCtx, signedReq, resp, resolvedEndpoint) + } else { + err = sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, signedReq, resp, resolvedEndpoint) + } + if err != nil { s.Log.WithError(err).Warn("Failed to emit audit event.") } return resp, nil } -// emitAuditEvent writes details of the AWS request to audit stream. -func (s *SigningService) emitAuditEvent(ctx context.Context, req *http.Request, resp *http.Response, sessionCtx *common.SessionContext, endpoint *endpoints.ResolvedEndpoint) error { - event := &apievents.AppSessionRequest{ - Metadata: apievents.Metadata{ - Type: events.AppSessionRequestEvent, - Code: events.AppSessionRequestCode, - }, - Method: req.Method, - Path: req.URL.Path, - RawQuery: req.URL.RawQuery, - StatusCode: uint32(resp.StatusCode), - AppMetadata: apievents.AppMetadata{ - AppURI: sessionCtx.App.GetURI(), - AppPublicAddr: sessionCtx.App.GetPublicAddr(), - AppName: sessionCtx.App.GetName(), - }, - AWSRequestMetadata: apievents.AWSRequestMetadata{ - AWSRegion: endpoint.SigningRegion, - AWSService: endpoint.SigningName, - AWSHost: req.Host, - }, - } - return trace.Wrap(sessionCtx.Emitter.EmitAuditEvent(ctx, event)) -} - func (s *SigningService) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { switch trace.Unwrap(err).(type) { case *trace.BadParameterError: @@ -201,15 +183,11 @@ func (s *SigningService) formatForwardResponseError(rw http.ResponseWriter, r *h // prepareSignedRequest creates a new HTTP request and rewrites the header from the original request and returns a new // HTTP request signed by STS AWS API. -func (s *SigningService) prepareSignedRequest(r *http.Request, re *endpoints.ResolvedEndpoint, sessionCtx *common.SessionContext) (*http.Request, error) { +func (s *SigningService) prepareSignedRequest(r *http.Request, payload []byte, re *endpoints.ResolvedEndpoint, sessionCtx *common.SessionContext) (*http.Request, error) { url, err := urlForResolvedEndpoint(r, re) if err != nil { return nil, trace.Wrap(err) } - payload, err := awsutils.GetAndReplaceReqBody(r) - if err != nil { - return nil, trace.Wrap(err) - } reqCopy, err := http.NewRequest(r.Method, url, bytes.NewReader(payload)) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 0e36c3f2cc2d0..bd5ba3375aabd 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -31,7 +31,9 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/s3" + "github.com/google/go-cmp/cmp" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -45,39 +47,59 @@ import ( awsutils "github.com/gravitational/teleport/lib/utils/aws" ) -// TestAWSSignerHandler test the AWS SigningService APP handler logic with mocked STS signing credentials. -func TestAWSSignerHandler(t *testing.T) { - type check func(t *testing.T, resp *s3.ListBucketsOutput, err error) - checks := func(chs ...check) []check { return chs } +type makeRequest func(url string, provider client.ConfigProvider) error - hasNoErr := func() check { - return func(t *testing.T, resp *s3.ListBucketsOutput, err error) { - require.NoError(t, err) - } - } +func s3Request(url string, provider client.ConfigProvider) error { + s3Client := s3.New(provider, &aws.Config{ + Endpoint: &url, + }) + _, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) + return err +} - hasStatusCode := func(wantStatusCode int) check { - return func(t *testing.T, resp *s3.ListBucketsOutput, err error) { - require.Error(t, err) - apiErr, ok := err.(awserr.RequestFailure) - if !ok { - t.Errorf("invalid error type: %T", err) - } - require.Equal(t, wantStatusCode, apiErr.StatusCode()) - } +func dynamoRequest(url string, provider client.ConfigProvider) error { + dynamoClient := dynamodb.New(provider, &aws.Config{ + Endpoint: &url, + }) + _, err := dynamoClient.Scan(&dynamodb.ScanInput{ + TableName: aws.String("test-table"), + }) + return err +} + +func hasStatusCode(wantStatusCode int) require.ErrorAssertionFunc { + return func(t require.TestingT, err error, msgAndArgs ...interface{}) { + var apiErr awserr.RequestFailure + require.ErrorAs(t, err, &apiErr, msgAndArgs...) + require.Equal(t, wantStatusCode, apiErr.StatusCode(), msgAndArgs...) } +} + +// TestAWSSignerHandler test the AWS SigningService APP handler logic with mocked STS signing credentials. +func TestAWSSignerHandler(t *testing.T) { + consoleApp, err := types.NewAppV3(types.Metadata{ + Name: "awsconsole", + }, types.AppSpecV3{ + URI: constants.AWSConsoleURL, + PublicAddr: "test.local", + }) + require.NoError(t, err) tests := []struct { name string + app types.Application awsClientSession *session.Session + request makeRequest wantHost string wantAuthCredService string wantAuthCredRegion string wantAuthCredKeyID string - checks []check + wantEventType events.AuditEvent + errAssertionFns []require.ErrorAssertionFunc }{ { name: "s3 access", + app: consoleApp, awsClientSession: session.Must(session.NewSession(&aws.Config{ Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ AccessKeyID: "fakeClientKeyID", @@ -85,16 +107,19 @@ func TestAWSSignerHandler(t *testing.T) { }}), Region: aws.String("us-west-2"), })), + request: s3Request, wantHost: "s3.us-west-2.amazonaws.com", wantAuthCredKeyID: "AKIDl", wantAuthCredService: "s3", wantAuthCredRegion: "us-west-2", - checks: checks( - hasNoErr(), - ), + wantEventType: &events.AppSessionRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, }, { name: "s3 access with different region", + app: consoleApp, awsClientSession: session.Must(session.NewSession(&aws.Config{ Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ AccessKeyID: "fakeClientKeyID", @@ -102,23 +127,79 @@ func TestAWSSignerHandler(t *testing.T) { }}), Region: aws.String("us-west-1"), })), + request: s3Request, wantHost: "s3.us-west-1.amazonaws.com", wantAuthCredKeyID: "AKIDl", wantAuthCredService: "s3", wantAuthCredRegion: "us-west-1", - checks: checks( - hasNoErr(), - ), + wantEventType: &events.AppSessionRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, }, { name: "s3 access missing credentials", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: credentials.AnonymousCredentials, + Region: aws.String("us-west-1"), + })), + request: s3Request, + errAssertionFns: []require.ErrorAssertionFunc{ + hasStatusCode(http.StatusBadRequest), + }, + }, + { + name: "DynamoDB access", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ + AccessKeyID: "fakeClientKeyID", + SecretAccessKey: "fakeClientSecret", + }}), + Region: aws.String("us-east-1"), + })), + request: dynamoRequest, + wantHost: "dynamodb.us-east-1.amazonaws.com", + wantAuthCredKeyID: "AKIDl", + wantAuthCredService: "dynamodb", + wantAuthCredRegion: "us-east-1", + wantEventType: &events.AppSessionDynamoDBRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, + }, + { + name: "DynamoDB access with different region", + app: consoleApp, + awsClientSession: session.Must(session.NewSession(&aws.Config{ + Credentials: credentials.NewCredentials(&credentials.StaticProvider{Value: credentials.Value{ + AccessKeyID: "fakeClientKeyID", + SecretAccessKey: "fakeClientSecret", + }}), + Region: aws.String("us-west-1"), + })), + request: dynamoRequest, + wantHost: "dynamodb.us-west-1.amazonaws.com", + wantAuthCredKeyID: "AKIDl", + wantAuthCredService: "dynamodb", + wantAuthCredRegion: "us-west-1", + wantEventType: &events.AppSessionDynamoDBRequest{}, + errAssertionFns: []require.ErrorAssertionFunc{ + require.NoError, + }, + }, + { + name: "DynamoDB access missing credentials", + app: consoleApp, awsClientSession: session.Must(session.NewSession(&aws.Config{ Credentials: credentials.AnonymousCredentials, Region: aws.String("us-west-1"), })), - checks: checks( + request: dynamoRequest, + errAssertionFns: []require.ErrorAssertionFunc{ hasStatusCode(http.StatusBadRequest), - ), + }, }, } for _, tc := range tests { @@ -132,14 +213,11 @@ func TestAWSSignerHandler(t *testing.T) { require.Equal(t, tc.wantAuthCredService, awsAuthHeader.Service) } - suite := createSuite(t, handler) + suite := createSuite(t, handler, tc.app) - s3Client := s3.New(tc.awsClientSession, &aws.Config{ - Endpoint: &suite.URL, - }) - resp, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) - for _, check := range tc.checks { - check(t, resp, err) + err := tc.request(suite.URL, tc.awsClientSession) + for _, assertFn := range tc.errAssertionFns { + assertFn(t, err) } // Validate audit event. @@ -147,11 +225,25 @@ func TestAWSSignerHandler(t *testing.T) { require.Len(t, suite.emitter.C(), 1) event := <-suite.emitter.C() - appSessionEvent, ok := event.(*events.AppSessionRequest) - require.True(t, ok) - require.Equal(t, tc.wantHost, appSessionEvent.AWSHost) - require.Equal(t, tc.wantAuthCredService, appSessionEvent.AWSService) - require.Equal(t, tc.wantAuthCredRegion, appSessionEvent.AWSRegion) + switch appSessionEvent := event.(type) { + case *events.AppSessionDynamoDBRequest: + _, ok := tc.wantEventType.(*events.AppSessionDynamoDBRequest) + require.True(t, ok, "unexpected event type: wanted %T but got %T", tc.wantEventType, appSessionEvent) + require.Equal(t, tc.wantHost, appSessionEvent.AWSHost) + require.Equal(t, tc.wantAuthCredService, appSessionEvent.AWSService) + require.Equal(t, tc.wantAuthCredRegion, appSessionEvent.AWSRegion) + j, err := appSessionEvent.Body.MarshalJSON() + require.NoError(t, err) + require.Empty(t, cmp.Diff(`{"TableName":"test-table"}`, string(j))) + case *events.AppSessionRequest: + _, ok := tc.wantEventType.(*events.AppSessionRequest) + require.True(t, ok, "unexpected event type: wanted %T but got %T", tc.wantEventType, appSessionEvent) + require.Equal(t, tc.wantHost, appSessionEvent.AWSHost) + require.Equal(t, tc.wantAuthCredService, appSessionEvent.AWSService) + require.Equal(t, tc.wantAuthCredRegion, appSessionEvent.AWSRegion) + default: + require.FailNow(t, "wrong event type", "unexpected event type: wanted %T but got %T", tc.wantEventType, appSessionEvent) + } } else { require.Len(t, suite.emitter.C(), 0) } @@ -214,16 +306,9 @@ type suite struct { emitter *eventstest.ChannelEmitter } -func createSuite(t *testing.T, handler http.HandlerFunc) *suite { +func createSuite(t *testing.T, handler http.HandlerFunc, app types.Application) *suite { emitter := eventstest.NewChannelEmitter(1) user := auth.LocalUser{Username: "user"} - app, err := types.NewAppV3(types.Metadata{ - Name: "awsconsole", - }, types.AppSpecV3{ - URI: constants.AWSConsoleURL, - PublicAddr: "test.local", - }) - require.NoError(t, err) awsAPIMock := httptest.NewUnstartedServer(handler) awsAPIMock.StartTLS() @@ -231,8 +316,9 @@ func createSuite(t *testing.T, handler http.HandlerFunc) *suite { awsAPIMock.Close() }) - client := &http.Client{ - Transport: &http.Transport{ + svc, err := NewSigningService(SigningServiceConfig{ + getSigningCredentials: staticAWSCredentials, + RoundTripper: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, @@ -240,12 +326,12 @@ func createSuite(t *testing.T, handler http.HandlerFunc) *suite { return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) }, }, - } + Clock: clockwork.NewFakeClock(), + }) + require.NoError(t, err) - svc, err := NewSigningService(SigningServiceConfig{ - getSigningCredentials: staticAWSCredentials, - Client: client, - Clock: clockwork.NewFakeClock(), + audit, err := common.NewAudit(common.AuditConfig{ + Emitter: emitter, }) require.NoError(t, err) @@ -254,7 +340,7 @@ func createSuite(t *testing.T, handler http.HandlerFunc) *suite { request = common.WithSessionContext(request, &common.SessionContext{ Identity: &user.Identity, App: app, - Emitter: emitter, + Audit: audit, }) svc.ServeHTTP(writer, request) diff --git a/lib/srv/app/common/audit.go b/lib/srv/app/common/audit.go new file mode 100644 index 0000000000000..abe2d42c4234b --- /dev/null +++ b/lib/srv/app/common/audit.go @@ -0,0 +1,241 @@ +/* +Copyright 2022 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package common + +import ( + "context" + "net/http" + + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/tlsca" + awsutils "github.com/gravitational/teleport/lib/utils/aws" +) + +// Audit defines an interface for app access audit events logger. +type Audit interface { + // OnSessionStart is called when new app session starts. + OnSessionStart(ctx context.Context, serverID string, identity *tlsca.Identity, app types.Application) error + // OnSessionEnd is called when an app session ends. + OnSessionEnd(ctx context.Context, serverID string, identity *tlsca.Identity, app types.Application) error + // OnSessionChunk is called when a new session chunk is created. + OnSessionChunk(ctx context.Context, sessionCtx *SessionContext, serverID string) error + // OnRequest is called when an app request is sent during the session and a response is received. + OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error + // OnDynamoDBRequest is called when app request for a DynamoDB API is sent and a response is received. + OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error + // EmitEvent emits the provided audit event. + EmitEvent(ctx context.Context, event apievents.AuditEvent) error +} + +// AuditConfig is the audit events emitter configuration. +type AuditConfig struct { + // Emitter is used to emit audit events. + Emitter apievents.Emitter +} + +// Check validates the config. +func (c *AuditConfig) Check() error { + if c.Emitter == nil { + return trace.BadParameter("missing Emitter") + } + return nil +} + +// audit provides methods for emitting app access audit events. +type audit struct { + // cfg is the audit events emitter configuration. + cfg AuditConfig + // log is used for logging + log logrus.FieldLogger +} + +// NewAudit returns a new instance of the audit events emitter. +func NewAudit(config AuditConfig) (Audit, error) { + if err := config.Check(); err != nil { + return nil, trace.Wrap(err) + } + return &audit{ + cfg: config, + log: logrus.WithField(trace.Component, "app:audit"), + }, nil +} + +// OnSessionStart is called when new app session starts. +func (a *audit) OnSessionStart(ctx context.Context, serverID string, identity *tlsca.Identity, app types.Application) error { + event := &apievents.AppSessionStart{ + Metadata: apievents.Metadata{ + Type: events.AppSessionStartEvent, + Code: events.AppSessionStartCode, + ClusterName: identity.RouteToApp.ClusterName, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: serverID, + ServerNamespace: apidefaults.Namespace, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: identity.RouteToApp.SessionID, + WithMFA: identity.MFAVerified, + }, + UserMetadata: identity.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + RemoteAddr: identity.ClientIP, + }, + AppMetadata: apievents.AppMetadata{ + AppURI: app.GetURI(), + AppPublicAddr: app.GetPublicAddr(), + AppName: app.GetName(), + }, + } + return trace.Wrap(a.EmitEvent(ctx, event)) +} + +// OnSessionEnd is called when an app session ends. +func (a *audit) OnSessionEnd(ctx context.Context, serverID string, identity *tlsca.Identity, app types.Application) error { + event := &apievents.AppSessionEnd{ + Metadata: apievents.Metadata{ + Type: events.AppSessionEndEvent, + Code: events.AppSessionEndCode, + ClusterName: identity.RouteToApp.ClusterName, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: serverID, + ServerNamespace: apidefaults.Namespace, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: identity.RouteToApp.SessionID, + WithMFA: identity.MFAVerified, + }, + UserMetadata: identity.GetUserMetadata(), + ConnectionMetadata: apievents.ConnectionMetadata{ + RemoteAddr: identity.ClientIP, + }, + AppMetadata: apievents.AppMetadata{ + AppURI: app.GetURI(), + AppPublicAddr: app.GetPublicAddr(), + AppName: app.GetName(), + }, + } + return trace.Wrap(a.EmitEvent(ctx, event)) +} + +// OnSessionChunk is called when a new session chunk is created. +func (a *audit) OnSessionChunk(ctx context.Context, sessionCtx *SessionContext, serverID string) error { + event := &apievents.AppSessionChunk{ + Metadata: apievents.Metadata{ + Type: events.AppSessionChunkEvent, + Code: events.AppSessionChunkCode, + ClusterName: sessionCtx.Identity.RouteToApp.ClusterName, + }, + ServerMetadata: apievents.ServerMetadata{ + ServerID: serverID, + ServerNamespace: apidefaults.Namespace, + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: sessionCtx.Identity.RouteToApp.SessionID, + WithMFA: sessionCtx.Identity.MFAVerified, + }, + UserMetadata: sessionCtx.Identity.GetUserMetadata(), + AppMetadata: apievents.AppMetadata{ + AppURI: sessionCtx.App.GetURI(), + AppPublicAddr: sessionCtx.App.GetPublicAddr(), + AppName: sessionCtx.App.GetName(), + }, + SessionChunkID: sessionCtx.ChunkID, + } + return trace.Wrap(a.EmitEvent(ctx, event)) +} + +// OnRequest is called when an app request is sent during the session and a response is received. +func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error { + event := &apievents.AppSessionRequest{ + Metadata: apievents.Metadata{ + Type: events.AppSessionRequestEvent, + Code: events.AppSessionRequestCode, + }, + AppMetadata: *MakeAppMetadata(sessionCtx.App), + Method: req.Method, + Path: req.URL.Path, + RawQuery: req.URL.RawQuery, + StatusCode: uint32(res.StatusCode), + AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), + } + return trace.Wrap(a.EmitEvent(ctx, event)) +} + +// OnDynamoDBRequest is called when a DynamoDB app request is sent during the session. +func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error { + // Try to read the body and JSON unmarshal it. + // If this fails, we still want to emit the rest of the event info; the request event Body is nullable, so it's ok if body is left nil here. + body, err := awsutils.UnmarshalRequestBody(req) + if err != nil { + a.log.WithError(err).Warn("Failed to read request body as JSON, omitting the body from the audit event.") + } + // get the API target from the request header, according to the API request format documentation: + // https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html#Programming.LowLevelAPI.RequestFormat + target := req.Header.Get(awsutils.TargetHeader) + event := &apievents.AppSessionDynamoDBRequest{ + Metadata: apievents.Metadata{ + Type: events.AppSessionDynamoDBRequestEvent, + Code: events.AppSessionDynamoDBRequestCode, + }, + UserMetadata: sessionCtx.Identity.GetUserMetadata(), + AppMetadata: *MakeAppMetadata(sessionCtx.App), + AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), + SessionChunkID: sessionCtx.ChunkID, + StatusCode: uint32(res.StatusCode), + Path: req.URL.Path, + RawQuery: req.URL.RawQuery, + Method: req.Method, + Target: target, + Body: body, + } + return trace.Wrap(a.EmitEvent(ctx, event)) +} + +// EmitEvent emits the provided audit event. +func (a *audit) EmitEvent(ctx context.Context, event apievents.AuditEvent) error { + return trace.Wrap(a.cfg.Emitter.EmitAuditEvent(ctx, event)) +} + +// MakeAppMetadata returns common server metadata for database session. +func MakeAppMetadata(app types.Application) *apievents.AppMetadata { + return &apievents.AppMetadata{ + AppURI: app.GetURI(), + AppPublicAddr: app.GetPublicAddr(), + AppName: app.GetName(), + } +} + +// MakeAWSRequestMetadata is a helper to build AWSRequestMetadata from the provided request and endpoint. +// If the aws endpoint is nil, returns an empty request metadata. +func MakeAWSRequestMetadata(req *http.Request, awsEndpoint *endpoints.ResolvedEndpoint) *apievents.AWSRequestMetadata { + if awsEndpoint == nil { + return &apievents.AWSRequestMetadata{} + } + return &apievents.AWSRequestMetadata{ + AWSRegion: awsEndpoint.SigningRegion, + AWSService: awsEndpoint.SigningName, + AWSHost: req.Host, + } +} diff --git a/lib/srv/app/common/session.go b/lib/srv/app/common/session.go index f8fc367287b7a..cf1f0e692f78e 100644 --- a/lib/srv/app/common/session.go +++ b/lib/srv/app/common/session.go @@ -23,7 +23,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/tlsca" ) @@ -33,8 +32,10 @@ type SessionContext struct { Identity *tlsca.Identity // App is the requested identity. App types.Application - // Emitter is the audit log emitter. - Emitter events.Emitter + // ChunkID is the session chunk's uuid. + ChunkID string + // Audit is used to emit audit events for the session. + Audit Audit } // WithSessionContext adds session context to provided request. diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index c056243347433..20049ee7f0110 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -286,7 +286,11 @@ func New(ctx context.Context, c *Config) (*Server, error) { s.httpServer = s.newHTTPServer() // TCP server will handle TCP applications. - s.tcpServer = s.newTCPServer() + tcpServer, err := s.newTCPServer() + if err != nil { + return nil, trace.Wrap(err) + } + s.tcpServer = tcpServer // Create a new session cache, this holds sessions that can be used to // forward requests. @@ -795,7 +799,7 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { identity := authCtx.Identity.GetIdentity() switch { case app.IsAWSConsole(): - // Requests from AWS applications are singed by AWS Signature Version + // Requests from AWS applications are signed by AWS Signature Version // 4 algorithm. AWS CLI and AWS SDKs automatically use SigV4 for all // services that support it (All services expect Amazon SimpleDB but // this AWS service has been deprecated) @@ -842,15 +846,8 @@ func (s *Server) serveSession(w http.ResponseWriter, r *http.Request, identity * } defer session.release() - // Create session context. - sessionCtx := &common.SessionContext{ - Identity: identity, - App: app, - Emitter: session.streamWriter, - } - // Forward request to the target application. - session.fwd.ServeHTTP(w, common.WithSessionContext(r, sessionCtx)) + session.fwd.ServeHTTP(w, common.WithSessionContext(r, session.sessionCtx)) return nil } @@ -987,12 +984,18 @@ func (s *Server) newHTTPServer() *http.Server { } // newTCPServer creates a server that proxies TCP applications. -func (s *Server) newTCPServer() *tcpServer { - return &tcpServer{ - authClient: s.c.AuthClient, - hostID: s.c.HostID, - log: s.log, +func (s *Server) newTCPServer() (*tcpServer, error) { + audit, err := common.NewAudit(common.AuditConfig{ + Emitter: s.c.Emitter, + }) + if err != nil { + return nil, trace.Wrap(err) } + return &tcpServer{ + audit: audit, + hostID: s.c.HostID, + log: s.log, + }, nil } // getProxyPort tries to figure out the address the proxy is running at. diff --git a/lib/srv/app/server_test.go b/lib/srv/app/server_test.go index 3193c93fea1d5..af788b1696d7f 100644 --- a/lib/srv/app/server_test.go +++ b/lib/srv/app/server_test.go @@ -529,16 +529,41 @@ func TestRequestAuditEvents(t *testing.T) { require.NoError(t, err) var requestEventsReceived atomic.Uint64 + var chunkEventsReceived atomic.Uint64 serverStreamer, err := events.NewCallbackStreamer(events.CallbackStreamerConfig{ Inner: events.NewDiscardEmitter(), OnEmitAuditEvent: func(_ context.Context, _ libsession.ID, event apievents.AuditEvent) error { - if event.GetType() == events.AppSessionRequestEvent { + switch event.GetType() { + case events.AppSessionChunkEvent: + chunkEventsReceived.Add(1) + expectedEvent := &apievents.AppSessionChunk{ + Metadata: apievents.Metadata{ + Type: events.AppSessionChunkEvent, + Code: events.AppSessionChunkCode, + ClusterName: "root.example.com", + Index: 0, + }, + AppMetadata: apievents.AppMetadata{ + AppURI: app.Spec.URI, + AppPublicAddr: app.Spec.PublicAddr, + AppName: app.Metadata.Name, + }, + } + require.Empty(t, cmp.Diff( + expectedEvent, + event, + cmpopts.IgnoreTypes(apievents.ServerMetadata{}, apievents.SessionMetadata{}, apievents.UserMetadata{}, apievents.ConnectionMetadata{}), + cmpopts.IgnoreFields(apievents.Metadata{}, "ID", "ClusterName", "Time"), + cmpopts.IgnoreFields(apievents.AppSessionChunk{}, "SessionChunkID"), + )) + case events.AppSessionRequestEvent: requestEventsReceived.Add(1) - expectedEvent := &apievents.AppSessionRequest{ Metadata: apievents.Metadata{ - Type: events.AppSessionRequestEvent, - Code: events.AppSessionRequestCode, + Type: events.AppSessionRequestEvent, + Code: events.AppSessionRequestCode, + ClusterName: "root.example.com", + Index: 1, }, AppMetadata: apievents.AppMetadata{ AppURI: app.Spec.URI, @@ -570,10 +595,14 @@ func TestRequestAuditEvents(t *testing.T) { // make a request to generate events. s.checkHTTPResponse(t, s.clientCertificate, func(_ *http.Response) { + // wait until chunk events are generated before closing the server. + require.Eventually(t, func() bool { + return chunkEventsReceived.Load() == 1 + }, 500*time.Millisecond, 50*time.Millisecond, "app.session.chunk event not generated") // wait until request events are generated before closing the server. require.Eventually(t, func() bool { return requestEventsReceived.Load() == 1 - }, 500*time.Millisecond, 50*time.Millisecond, "app.request event not generated") + }, 500*time.Millisecond, 50*time.Millisecond, "app.session.request event not generated") }) searchEvents, _, err := s.authServer.AuditLog.SearchEvents(time.Time{}, time.Now().Add(time.Minute), "", []string{events.AppSessionChunkEvent}, 10, types.EventOrderDescending, "") diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index f7f718592e334..f8ca846c8ec86 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" @@ -66,8 +65,10 @@ type sessionChunk struct { id string // fwd can rewrite and forward requests to the target application. fwd *forward.Forwarder - // streamWriter can emit events to the audit log. - streamWriter events.StreamWriter + // streamCloser closes the session chunk stream. + streamCloser utils.WriteContextCloser + // sessionCtx contains common context parameters for an App session. + sessionCtx *common.SessionContext // inflightCond protects and signals change of inflight inflightCond *sync.Cond @@ -111,11 +112,25 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, } // Create the stream writer that will write this chunk to the audit log. - var err error - sess.streamWriter, err = s.newStreamWriter(identity, app, sess.id) + streamWriter, err := s.newStreamWriter(app, sess.id) if err != nil { return nil, trace.Wrap(err) } + sess.streamCloser = streamWriter + + // Create session context. + audit, err := common.NewAudit(common.AuditConfig{ + Emitter: streamWriter, + }) + if err != nil { + return nil, trace.Wrap(err) + } + sess.sessionCtx = &common.SessionContext{ + Identity: identity, + App: app, + ChunkID: sess.id, + Audit: audit, + } for _, opt := range opts { if err = opt(ctx, sess, identity, app); err != nil { @@ -131,6 +146,10 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, return nil, trace.Wrap(err) } + // only emit a session chunk if we didnt get an error making the new session chunk + if err := sess.sessionCtx.Audit.OnSessionChunk(ctx, sess.sessionCtx, s.c.HostID); err != nil { + return nil, trace.Wrap(err) + } return sess, nil } @@ -159,7 +178,7 @@ func (s *Server) withJWTTokenForwarder(ctx context.Context, sess *sessionChunk, // Create a rewriting transport that will be used to forward requests. transport, err := newTransport(s.closeContext, &transportConfig{ - w: sess.streamWriter, + audit: sess.sessionCtx.Audit, app: app, publicPort: s.proxyPort, cipherSuites: s.c.CipherSuites, @@ -242,7 +261,7 @@ func (s *sessionChunk) close(ctx context.Context) error { s.inflightCond.L.Unlock() close(s.closeC) s.log.Debugf("Closed session chunk %s", s.id) - err := s.streamWriter.Close(ctx) + err := s.streamCloser.Close(ctx) return trace.Wrap(err) } @@ -255,7 +274,7 @@ func (s *Server) closeSession(sess *sessionChunk) { // newStreamWriter creates a session stream that will be used to record // requests that occur within this session chunk and upload the recording // to the Auth server. -func (s *Server) newStreamWriter(identity *tlsca.Identity, app types.Application, chunkID string) (events.StreamWriter, error) { +func (s *Server) newStreamWriter(app types.Application, chunkID string) (events.StreamWriter, error) { recConfig, err := s.c.AccessPoint.GetSessionRecordingConfig(s.closeContext) if err != nil { return nil, trace.Wrap(err) @@ -267,7 +286,7 @@ func (s *Server) newStreamWriter(identity *tlsca.Identity, app types.Application } // Create a sync or async streamer depending on configuration of cluster. - streamer, err := s.newStreamer(s.closeContext, chunkID, recConfig) + streamer, err := s.newStreamer(app, chunkID, recConfig) if err != nil { return nil, trace.Wrap(err) } @@ -288,33 +307,6 @@ func (s *Server) newStreamWriter(identity *tlsca.Identity, app types.Application return nil, trace.Wrap(err) } - // Emit an event to the Audit Log that a new session chunk has been created. - appSessionChunkEvent := &apievents.AppSessionChunk{ - Metadata: apievents.Metadata{ - Type: events.AppSessionChunkEvent, - Code: events.AppSessionChunkCode, - ClusterName: identity.RouteToApp.ClusterName, - }, - ServerMetadata: apievents.ServerMetadata{ - ServerID: s.c.HostID, - ServerNamespace: apidefaults.Namespace, - }, - SessionMetadata: apievents.SessionMetadata{ - SessionID: identity.RouteToApp.SessionID, - WithMFA: identity.MFAVerified, - }, - UserMetadata: identity.GetUserMetadata(), - AppMetadata: apievents.AppMetadata{ - AppURI: app.GetURI(), - AppPublicAddr: app.GetPublicAddr(), - AppName: app.GetName(), - }, - SessionChunkID: chunkID, - } - if err := s.c.AuthClient.EmitAuditEvent(s.closeContext, appSessionChunkEvent); err != nil { - return nil, trace.Wrap(err) - } - return streamWriter, nil } @@ -322,7 +314,7 @@ func (s *Server) newStreamWriter(identity *tlsca.Identity, app types.Application // of the server and the session, sync streamer sends the events // directly to the auth server and blocks if the events can not be received, // async streamer buffers the events to disk and uploads the events later -func (s *Server) newStreamer(ctx context.Context, chunkID string, recConfig types.SessionRecordingConfig) (events.Streamer, error) { +func (s *Server) newStreamer(app types.Application, chunkID string, recConfig types.SessionRecordingConfig) (events.Streamer, error) { if services.IsRecordSync(recConfig.GetMode()) { s.log.Debugf("Using sync streamer for session chunk %v.", chunkID) return s.c.AuthClient, nil @@ -337,7 +329,7 @@ func (s *Server) newStreamer(ctx context.Context, chunkID string, recConfig type if err != nil { return nil, trace.Wrap(err) } - return fileStreamer, nil + return events.NewTeeStreamer(fileStreamer, s.c.Emitter), nil } // createTracker creates a new session tracker for the session chunk. diff --git a/lib/srv/app/session_test.go b/lib/srv/app/session_test.go index 6408ed6f87d4a..f5e2bec0d6a89 100644 --- a/lib/srv/app/session_test.go +++ b/lib/srv/app/session_test.go @@ -34,7 +34,7 @@ func newSessionChunk(timeout time.Duration) *sessionChunk { inflightCond: sync.NewCond(&sync.Mutex{}), closeTimeout: timeout, log: logrus.NewEntry(logrus.StandardLogger()), - streamWriter: &events.DiscardStream{}, + streamCloser: &events.DiscardStream{}, } } diff --git a/lib/srv/app/tcpserver.go b/lib/srv/app/tcpserver.go index 5793b1b426cef..944882aba2366 100644 --- a/lib/srv/app/tcpserver.go +++ b/lib/srv/app/tcpserver.go @@ -25,17 +25,15 @@ import ( apidefaults "github.com/gravitational/teleport/api/defaults" apitypes "github.com/gravitational/teleport/api/types" - apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) type tcpServer struct { - authClient *auth.Client - hostID string - log logrus.FieldLogger + audit common.Audit + hostID string + log logrus.FieldLogger } // handleConnection handles connection from a TCP application. @@ -54,13 +52,11 @@ func (s *tcpServer) handleConnection(ctx context.Context, clientConn net.Conn, i if err != nil { return trace.Wrap(err) } - err = s.emitStartEvent(ctx, identity, app) - if err != nil { + if err := s.audit.OnSessionStart(ctx, s.hostID, identity, app); err != nil { return trace.Wrap(err) } defer func() { - err = s.emitEndEvent(ctx, identity, app) - if err != nil { + if err := s.audit.OnSessionEnd(ctx, s.hostID, identity, app); err != nil { s.log.WithError(err).Warnf("Failed to emit session end event for app %v.", app.GetName()) } }() @@ -70,57 +66,3 @@ func (s *tcpServer) handleConnection(ctx context.Context, clientConn net.Conn, i } return nil } - -func (s *tcpServer) emitStartEvent(ctx context.Context, identity *tlsca.Identity, app apitypes.Application) error { - return s.authClient.EmitAuditEvent(ctx, &apievents.AppSessionStart{ - Metadata: apievents.Metadata{ - Type: events.AppSessionStartEvent, - Code: events.AppSessionStartCode, - ClusterName: identity.RouteToApp.ClusterName, - }, - ServerMetadata: apievents.ServerMetadata{ - ServerID: s.hostID, - ServerNamespace: apidefaults.Namespace, - }, - SessionMetadata: apievents.SessionMetadata{ - SessionID: identity.RouteToApp.SessionID, - WithMFA: identity.MFAVerified, - }, - UserMetadata: identity.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - RemoteAddr: identity.ClientIP, - }, - AppMetadata: apievents.AppMetadata{ - AppURI: app.GetURI(), - AppPublicAddr: app.GetPublicAddr(), - AppName: app.GetName(), - }, - }) -} - -func (s *tcpServer) emitEndEvent(ctx context.Context, identity *tlsca.Identity, app apitypes.Application) error { - return s.authClient.EmitAuditEvent(ctx, &apievents.AppSessionEnd{ - Metadata: apievents.Metadata{ - Type: events.AppSessionEndEvent, - Code: events.AppSessionEndCode, - ClusterName: identity.RouteToApp.ClusterName, - }, - ServerMetadata: apievents.ServerMetadata{ - ServerID: s.hostID, - ServerNamespace: apidefaults.Namespace, - }, - SessionMetadata: apievents.SessionMetadata{ - SessionID: identity.RouteToApp.SessionID, - WithMFA: identity.MFAVerified, - }, - UserMetadata: identity.GetUserMetadata(), - ConnectionMetadata: apievents.ConnectionMetadata{ - RemoteAddr: identity.ClientIP, - }, - AppMetadata: apievents.AppMetadata{ - AppURI: app.GetURI(), - AppPublicAddr: app.GetPublicAddr(), - AppName: app.GetName(), - }, - }) -} diff --git a/lib/srv/app/transport.go b/lib/srv/app/transport.go index c10f1460c2c16..27814f6d6ed52 100644 --- a/lib/srv/app/transport.go +++ b/lib/srv/app/transport.go @@ -30,12 +30,10 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" - apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib" "github.com/gravitational/teleport/lib/defaults" - "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" @@ -47,7 +45,7 @@ type transportConfig struct { publicPort string cipherSuites []uint16 jwt string - w events.StreamWriter + audit common.Audit traits wrappers.Traits log logrus.FieldLogger user string @@ -55,8 +53,8 @@ type transportConfig struct { // Check validates configuration. func (c *transportConfig) Check() error { - if c.w == nil { - return trace.BadParameter("stream writer missing") + if c.audit == nil { + return trace.BadParameter("audit writer missing") } if c.app == nil { return trace.BadParameter("app missing") @@ -147,6 +145,11 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { return nil, trace.Wrap(err) } + sessionCtx, err := common.GetSessionContext(r) + if err != nil { + return nil, trace.Wrap(err) + } + // Forward the request to the target application and emit an audit event. resp, err := t.tr.RoundTrip(r) if err != nil { @@ -154,7 +157,7 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { } // Emit the event to the audit log. - if err := t.emitAuditEvent(r, resp); err != nil { + if err := t.c.audit.OnRequest(t.closeContext, sessionCtx, r, resp, nil /*aws endpoint*/); err != nil { return nil, trace.Wrap(err) } @@ -273,29 +276,6 @@ func (t *transport) rewriteRedirect(resp *http.Response) error { return nil } -// emitAuditEvent writes the request and response to audit stream. -func (t *transport) emitAuditEvent(req *http.Request, resp *http.Response) error { - appSessionRequestEvent := &apievents.AppSessionRequest{ - Metadata: apievents.Metadata{ - Type: events.AppSessionRequestEvent, - Code: events.AppSessionRequestCode, - }, - Method: req.Method, - Path: req.URL.Path, - RawQuery: req.URL.RawQuery, - StatusCode: uint32(resp.StatusCode), - AppMetadata: apievents.AppMetadata{ - AppURI: t.c.app.GetURI(), - AppPublicAddr: t.c.app.GetPublicAddr(), - AppName: t.c.app.GetName(), - }, - } - if err := t.c.w.EmitAuditEvent(t.closeContext, appSessionRequestEvent); err != nil { - return trace.Wrap(err) - } - return nil -} - // configureTLS creates and configures a *tls.Config that will be used for // mutual authentication. func configureTLS(c *transportConfig) (*tls.Config, error) { diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index 39367083e7b57..020ec4b88ca4a 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -30,6 +30,10 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/gravitational/trace" + + "github.com/gravitational/teleport" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/utils" ) const ( @@ -51,6 +55,14 @@ const ( credentialAuthHeaderElem = "Credential" signedHeaderAuthHeaderElem = "SignedHeaders" signatureAuthHeaderElem = "Signature" + // TargetHeader is a header containing the API target. + // Format: target_version.operation + // Example: DynamoDB_20120810.Scan + TargetHeader = "X-Amz-Target" + // AmzJSON1_0 is an AWS Content-Type header that indicates the media type is JSON. + AmzJSON1_0 = "application/x-amz-json-1.0" + // AmzJSON1_1 is an AWS Content-Type header that indicates the media type is JSON. + AmzJSON1_1 = "application/x-amz-json-1.1" ) // SigV4 contains parsed content of the AWS Authorization header. @@ -129,8 +141,8 @@ func GetAndReplaceReqBody(req *http.Request) ([]byte, error) { if req.Body == nil || req.Body == http.NoBody { return []byte{}, nil } - // req.Body is closed during drainBody call. - payload, err := drainBody(req.Body) + // req.Body is closed during tryDrainBody call. + payload, err := tryDrainBody(req.Body) if err != nil { return nil, trace.Wrap(err) } @@ -140,16 +152,20 @@ func GetAndReplaceReqBody(req *http.Request) ([]byte, error) { return payload, nil } -// drainBody drains the body, close the reader and returns the read bytes. -func drainBody(b io.ReadCloser) ([]byte, error) { - payload, err := io.ReadAll(b) +// tryDrainBody tries to drain and close the body, returning the read bytes. +// It may fail to completely drain the body if the size of the body exceeds MaxHTTPRequestSize. +func tryDrainBody(b io.ReadCloser) (payload []byte, err error) { + defer func() { + if closeErr := b.Close(); closeErr != nil { + err = trace.NewAggregate(err, closeErr) + } + }() + payload, err = utils.ReadAtMost(b, teleport.MaxHTTPRequestSize) if err != nil { - return nil, trace.Wrap(err) - } - if err = b.Close(); err != nil { - return nil, trace.Wrap(err) + err = trace.Wrap(err) + return } - return payload, nil + return } // VerifyAWSSignature verifies the request signature ensuring that the request originates from tsh aws command execution @@ -307,3 +323,40 @@ func (roles Roles) FindRolesByName(name string) (result Roles) { } return } + +// UnmarshalRequestBody reads and unmarshals a JSON request body into a protobuf Struct wrapper. +// If the request is not a recognized AWS JSON media type, or the body cannot be read, or the body +// is not valid JSON, then this function returns a nil value and an error. +// The protobuf Struct wrapper is useful for serializing JSON into a protobuf, because otherwise when the +// protobuf is marshaled it will re-marshall a JSON string field with escape characters or base64 encode +// a []byte field. +// Examples showing differences: +// - JSON string in proto: `{"Table": "some-table"}` --marshal to JSON--> `"{\"Table\": \"some-table\"}"` +// - bytes in proto: []byte --marshal to JSON--> `eyJUYWJsZSI6ICJzb21lLXRhYmxlIn0K` (base64 encoded) +// - *Struct in proto: *Struct --marshal to JSON--> `{"Table": "some-table"}` (unescaped JSON) +func UnmarshalRequestBody(req *http.Request) (*apievents.Struct, error) { + contentType := req.Header.Get("Content-Type") + if !isJSON(contentType) { + return nil, trace.BadParameter("invalid JSON request Content-Type: %q", contentType) + } + jsonBody, err := GetAndReplaceReqBody(req) + if err != nil { + return nil, trace.Wrap(err) + } + s := &apievents.Struct{} + if err := s.UnmarshalJSON(jsonBody); err != nil { + return nil, trace.Wrap(err) + } + return s, nil +} + +// isJSON returns true if the Content-Type is recognized as standard JSON or any non-standard +// Amazon Content-Type header that indicates JSON media type. +func isJSON(contentType string) bool { + switch contentType { + case "application/json", AmzJSON1_0, AmzJSON1_1: + return true + default: + return false + } +}