From 4f81b26b9f39da12f7fd26d9dbadc8b31c3eb8aa Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Tue, 29 Nov 2022 19:30:48 -0800 Subject: [PATCH 1/5] Refactor app access * Move logic out of RoundTrip and into ServeHTTP as a middleware before handing off to oxy forwarder * Move AWS signing service code into lib/utils/aws/signing.go --- lib/httplib/httplib.go | 48 ++++++ lib/kube/proxy/forwarder.go | 46 +----- lib/srv/app/aws/handler.go | 271 +++++++++++++------------------- lib/srv/app/aws/handler_test.go | 61 ++++--- lib/srv/app/azure/handler.go | 96 ++++++----- lib/srv/app/common/audit.go | 34 ++-- lib/srv/app/server.go | 34 ++-- lib/srv/app/session.go | 49 +++--- lib/srv/app/transport.go | 31 ++-- lib/utils/aws/aws.go | 45 ++++-- lib/utils/aws/signing.go | 194 +++++++++++++++++++++++ 11 files changed, 547 insertions(+), 362 deletions(-) create mode 100644 lib/utils/aws/signing.go diff --git a/lib/httplib/httplib.go b/lib/httplib/httplib.go index b1d1a61368a95..713243f8e8c2d 100644 --- a/lib/httplib/httplib.go +++ b/lib/httplib/httplib.go @@ -246,3 +246,51 @@ func SafeRedirect(w http.ResponseWriter, r *http.Request, redirectURL string) er http.Redirect(w, r, parsedURL.RequestURI(), http.StatusFound) return nil } + +// ResponseStatusRecorder is an http.ResponseWriter that records the response status code. +type ResponseStatusRecorder struct { + http.ResponseWriter + flusher http.Flusher + status int +} + +// NewResponseStatusRecorder makes and returns a ResponseStatusRecorder. +func NewResponseStatusRecorder(w http.ResponseWriter) *ResponseStatusRecorder { + rec := &ResponseStatusRecorder{ResponseWriter: w} + if flusher, ok := w.(http.Flusher); ok { + rec.flusher = flusher + } + return rec +} + +// WriteHeader sends an HTTP response header with the provided +// status code and save the status code in the recorder. +func (r *ResponseStatusRecorder) WriteHeader(status int) { + r.status = status + r.ResponseWriter.WriteHeader(status) +} + +// Flush optionally flushes the inner ResponseWriter if it supports that. +// Otherwise, Flush is a noop. +// +// Flush is optionally used by github.com/gravitational/oxy/forward to flush +// pending data on streaming HTTP responses (like streaming pod logs). +// +// Without this, oxy/forward will handle streaming responses by accumulating +// ~32kb of response in a buffer before flushing it. +func (r *ResponseStatusRecorder) Flush() { + if r.flusher != nil { + r.flusher.Flush() + } +} + +// Status returns the recorded status after WriteHeader is called, or StatusOK if WriteHeader hasn't been called +// explicitly. +func (r *ResponseStatusRecorder) Status() int { + // http.ResponseWriter implicitly sets StatusOK, if WriteHeader hasn't been + // explicitly called. + if r.status == 0 { + return http.StatusOK + } + return r.status +} diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 319b88879b080..8a55704e91e37 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -1553,7 +1553,7 @@ func (f *Forwarder) catchAll(ctx *authContext, w http.ResponseWriter, req *http. f.log.Errorf("Failed to set up forwarding headers: %v.", err) return nil, trace.Wrap(err) } - rw := newResponseStatusRecorder(w) + rw := httplib.NewResponseStatusRecorder(w) sess.forwarder.ServeHTTP(rw, req) if sess.noAuditEvents { @@ -1577,7 +1577,7 @@ func (f *Forwarder) catchAll(ctx *authContext, w http.ResponseWriter, req *http. }, RequestPath: req.URL.Path, Verb: req.Method, - ResponseCode: int32(rw.getStatus()), + ResponseCode: int32(rw.Status()), KubernetesClusterMetadata: ctx.eventClusterMeta(), } r := parseResourcePath(req.URL.Path) @@ -2109,45 +2109,3 @@ func (f *Forwarder) removeKubeDetails(name string) { } delete(f.clusterDetails, name) } - -type responseStatusRecorder struct { - http.ResponseWriter - flusher http.Flusher - status int -} - -func newResponseStatusRecorder(w http.ResponseWriter) *responseStatusRecorder { - rec := &responseStatusRecorder{ResponseWriter: w} - if flusher, ok := w.(http.Flusher); ok { - rec.flusher = flusher - } - return rec -} - -func (r *responseStatusRecorder) WriteHeader(status int) { - r.status = status - r.ResponseWriter.WriteHeader(status) -} - -// Flush optionally flushes the inner ResponseWriter if it supports that. -// Otherwise, Flush is a noop. -// -// Flush is optionally used by github.com/gravitational/oxy/forward to flush -// pending data on streaming HTTP responses (like streaming pod logs). -// -// Without this, oxy/forward will handle streaming responses by accumulating -// ~32kb of response in a buffer before flushing it. -func (r *responseStatusRecorder) Flush() { - if r.flusher != nil { - r.flusher.Flush() - } -} - -func (r *responseStatusRecorder) getStatus() int { - // http.ResponseWriter implicitly sets StatusOK, if WriteHeader hasn't been - // explicitly called. - if r.status == 0 { - return http.StatusOK - } - return r.status -} diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 7f1612396e75f..12188dabb6b0f 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -17,211 +17,171 @@ limitations under the License. package aws import ( - "bytes" - "io" "net/http" "net/url" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/endpoints" - awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/gravitational/oxy/forward" oxyutils "github.com/gravitational/oxy/utils" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" "github.com/sirupsen/logrus" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/srv/app/common" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) -// NewSigningService creates a new instance of SigningService. -func NewSigningService(config SigningServiceConfig) (*SigningService, error) { - if err := config.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - svc := &SigningService{ - SigningServiceConfig: config, - } - - fwd, err := forward.New( - forward.RoundTripper(svc), - forward.ErrorHandler(oxyutils.ErrorHandlerFunc(svc.formatForwardResponseError)), - forward.PassHostHeader(true), - ) - if err != nil { - return nil, trace.Wrap(err) - } - svc.Forwarder = fwd - return svc, nil -} - -// SigningService is an AWS CLI proxy service that signs AWS requests -// based on user identity. -type SigningService struct { - // SigningServiceConfig is the SigningService configuration. - SigningServiceConfig - - // Forwarder signs and forwards the request to AWS API. - *forward.Forwarder +// signerHandler is an http.Handler for signing and forwarding requests to AWS API. +type signerHandler struct { + // fwd is a Forwarder used to forward signed requests to AWS API. + fwd *forward.Forwarder + // AwsSignerHandlerConfig is the awsSignerHandler configuration. + SignerHandlerConfig } -// SigningServiceConfig is the SigningService configuration. -type SigningServiceConfig struct { +// SignerHandlerConfig is the awsSignerHandler configuration. +type SignerHandlerConfig struct { + // Log is a logger for the handler. + Log logrus.FieldLogger // RoundTripper is an http.RoundTripper instance used for requests. RoundTripper http.RoundTripper - // Log is the Logger. - Log logrus.FieldLogger - // Session is AWS session. - Session *awssession.Session - // Clock is used to override time in tests. - Clock clockwork.Clock - - // getSigningCredentials allows so set the function responsible for obtaining STS credentials. - // Used in tests to set static AWS credentials and skip API call. - getSigningCredentials getSigningCredentialsFunc + // SigningService is used to sign requests before forwarding them. + *awsutils.SigningService } -// CheckAndSetDefaults validates the SigningServiceConfig config. -func (s *SigningServiceConfig) CheckAndSetDefaults() error { - if s.RoundTripper == nil { - tr, err := defaults.Transport() - if err != nil { - return trace.Wrap(err) - } - s.RoundTripper = tr - } - if s.Clock == nil { - s.Clock = clockwork.NewRealClock() - } - if s.Log == nil { - s.Log = logrus.WithField(trace.Component, "aws:signer") +// CheckAndSetDefaults validates the AwsSignerHandlerConfig. +func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error { + if cfg.SigningService == nil { + return trace.BadParameter("missing SigningService") } - if s.Session == nil { - ses, err := awssession.NewSessionWithOptions(awssession.Options{ - SharedConfigState: awssession.SharedConfigEnable, - }) + if cfg.RoundTripper == nil { + tr, err := defaults.Transport() if err != nil { return trace.Wrap(err) } - s.Session = ses + cfg.RoundTripper = tr } - if s.getSigningCredentials == nil { - s.getSigningCredentials = getAWSCredentialsFromSTSAPI + if cfg.Log == nil { + cfg.Log = logrus.WithField(trace.Component, "aws:signer") } return nil } -// RoundTrip handles incoming requests and forwards them to the proper AWS API. -// Handling steps: -// 1) Decoded Authorization Header. Authorization Header example: -// -// Authorization: AWS4-HMAC-SHA256 -// Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request, -// SignedHeaders=host;range;x-amz-date, -// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024 -// -// 2. Extract credential section from credential Authorization Header. -// 3. Extract aws-region and aws-service from the credential section. -// 4. Build AWS API endpoint based on extracted aws-region and aws-service fields. -// Not that for endpoint resolving the https://github.com/aws/aws-sdk-go/aws/endpoints/endpoints.go -// package is used and when Amazon releases a new API the dependency update is needed. -// 5. Sign HTTP request. -// 6. Forward the signed HTTP request to the AWS API. -func (s *SigningService) RoundTrip(req *http.Request) (*http.Response, error) { - defer req.Body.Close() - sessionCtx, err := common.GetSessionContext(req) - if err != nil { +// NewAWSSignerHandler creates a new request handler for signing and forwarding requests to AWS API. +func NewAWSSignerHandler(config SignerHandlerConfig) (http.Handler, error) { + if err := config.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } - resolvedEndpoint, err := resolveEndpoint(req) - if err != nil { - return nil, trace.Wrap(err) + + handler := &signerHandler{ + SignerHandlerConfig: config, } - payload, err := awsutils.GetAndReplaceReqBody(req) + fwd, err := forward.New( + forward.RoundTripper(config.RoundTripper), + forward.ErrorHandler(oxyutils.ErrorHandlerFunc(handler.formatForwardResponseError)), + forward.PassHostHeader(true), + ) if err != nil { return nil, trace.Wrap(err) } - signedReq, err := s.prepareSignedRequest(req, payload, resolvedEndpoint, sessionCtx) - if err != nil { - return nil, trace.Wrap(err) + handler.fwd = fwd + return handler, nil +} + +// formatForwardResponseError converts an error to a status code and writes the code to a response. +func (s *signerHandler) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { + s.Log.WithError(err).Debugf("Failed to process request.") + common.SetTeleportAPIErrorHeader(rw, err) + + // Convert trace error type to HTTP and write response. + code := trace.ErrorToCode(err) + http.Error(rw, http.StatusText(code), code) +} + +// ServeHTTP handles incoming requests by signing them and then forwarding them to the proper AWS API. +func (s *signerHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if err := s.serveHTTP(w, req); err != nil { + s.formatForwardResponseError(w, req, err) + return } - resp, err := s.RoundTripper.RoundTrip(signedReq) +} + +// serveHTTP is a helper to simplify error handling in ServeHTTP. +func (s *signerHandler) serveHTTP(w http.ResponseWriter, req *http.Request) error { + sessCtx, err := common.GetSessionContext(req) if err != nil { - return nil, trace.Wrap(err) - } - // emit audit event with original request, but change the URL since we resolved and rewrote it. - signedReq.Body = io.NopCloser(bytes.NewReader(payload)) - if isDynamoDBEndpoint(resolvedEndpoint) { - err = sessionCtx.Audit.OnDynamoDBRequest(req.Context(), sessionCtx, signedReq, resp, resolvedEndpoint) - } else { - err = sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, signedReq, resp, resolvedEndpoint) + return trace.Wrap(err) } + + // It's important that we resolve the endpoint before modifying the request headers, + // as they may be needed to resolve the endpoint correctly. + re, err := resolveEndpoint(req) if err != nil { - s.Log.WithError(err).Warn("Failed to emit audit event.") + return trace.Wrap(err) } - return resp, nil -} -func (s *SigningService) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { - common.SetTeleportAPIErrorHeader(rw, err) - - switch trace.Unwrap(err).(type) { - case *trace.BadParameterError: - s.Log.Debugf("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusBadRequest) - case *trace.AccessDeniedError: - s.Log.Infof("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusForbidden) - default: - s.Log.Warnf("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusInternalServerError) + // rewrite headers before signing the request to avoid signature validation problems. + unsignedReq, err := rewriteRequest(req, re) + if err != nil { + return trace.Wrap(err) } -} -// prepareSignedRequest creates a new HTTP request and rewrites the header from the original request and returns a new -// HTTP request signed by STS AWS API. -func (s *SigningService) prepareSignedRequest(r *http.Request, payload []byte, re *endpoints.ResolvedEndpoint, sessionCtx *common.SessionContext) (*http.Request, error) { - url, err := urlForResolvedEndpoint(r, re) + signedReq, err := s.SignRequest(unsignedReq, + &awsutils.SigningCtx{ + SigningName: re.SigningName, + SigningRegion: re.SigningRegion, + Expiry: sessCtx.Identity.Expires, + SessionName: sessCtx.Identity.Username, + AWSRoleArn: sessCtx.Identity.RouteToApp.AWSRoleARN, + AWSExternalID: sessCtx.App.GetAWSExternalID(), + }) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) + } + recorder := httplib.NewResponseStatusRecorder(w) + s.fwd.ServeHTTP(recorder, signedReq) + + var auditErr error + if isDynamoDBEndpoint(re) { + auditErr = sessCtx.Audit.OnDynamoDBRequest(unsignedReq.Context(), sessCtx, unsignedReq, recorder.Status(), re) + } else { + auditErr = sessCtx.Audit.OnRequest(unsignedReq.Context(), sessCtx, unsignedReq, recorder.Status(), re) } - reqCopy, err := http.NewRequest(r.Method, url, bytes.NewReader(payload)) if err != nil { - return nil, trace.Wrap(err) + // log but don't return the error, because we already handed off request/response handling to the oxy forwarder. + s.Log.WithError(auditErr).Warn("Failed to emit audit event.") } - rewriteHeaders(r, reqCopy) - // Sign the copy of the request. - signer := awsutils.NewSigner(s.getSigningCredentials(s.Session, sessionCtx), re.SigningName) - _, err = signer.Sign(reqCopy, bytes.NewReader(payload), re.SigningName, re.SigningRegion, s.Clock.Now()) + return nil +} + +// rewriteRequest rewrites a request to remove Teleport reserved headers, sets the url, and sets host. +func rewriteRequest(r *http.Request, re *endpoints.ResolvedEndpoint) (*http.Request, error) { + // shallow copy request and make a deep copy for header modification. + outReq := &http.Request{} + *outReq = *r + outReq.Header = r.Header.Clone() + u, err := urlForResolvedEndpoint(r, re) if err != nil { return nil, trace.Wrap(err) } - return reqCopy, nil -} + outReq.URL = u + outReq.Host = u.Host -func rewriteHeaders(r *http.Request, reqCopy *http.Request) { - for key, values := range r.Header { + for key := range outReq.Header { // Remove Teleport app headers. - if common.IsReservedHeader(key) { - continue - } - for _, v := range values { - reqCopy.Header.Add(key, v) + if common.IsReservedHeader(key) || http.CanonicalHeaderKey(key) == "Content-Length" { + outReq.Header.Del(key) } } - reqCopy.Header.Del("Content-Length") + return outReq, nil } -// urlForResolvedEndpoint creates an URL based on input request and resolved endpoint. -func urlForResolvedEndpoint(r *http.Request, re *endpoints.ResolvedEndpoint) (string, error) { +// urlForResolvedEndpoint creates a URL based on input request and resolved endpoint. +func urlForResolvedEndpoint(r *http.Request, re *endpoints.ResolvedEndpoint) (*url.URL, error) { resolvedURL, err := url.Parse(re.URL) if err != nil { - return "", trace.Wrap(err) + return nil, trace.Wrap(err) } // Replaces scheme and host. Keeps original path etc. @@ -232,20 +192,5 @@ func urlForResolvedEndpoint(r *http.Request, re *endpoints.ResolvedEndpoint) (st if resolvedURL.Scheme != "" { clone.Scheme = resolvedURL.Scheme } - return clone.String(), nil -} - -type getSigningCredentialsFunc func(c client.ConfigProvider, sessionCtx *common.SessionContext) *credentials.Credentials - -func getAWSCredentialsFromSTSAPI(provider client.ConfigProvider, sessionCtx *common.SessionContext) *credentials.Credentials { - return stscreds.NewCredentials(provider, sessionCtx.Identity.RouteToApp.AWSRoleARN, - func(cred *stscreds.AssumeRoleProvider) { - cred.RoleSessionName = sessionCtx.Identity.Username - cred.Expiry.SetExpiration(sessionCtx.Identity.Expires, 0) - - if externalID := sessionCtx.App.GetAWSExternalID(); externalID != "" { - cred.ExternalID = aws.String(externalID) - } - }, - ) + return &clone, nil } diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index bd5ba3375aabd..2c05aa129be7b 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -23,7 +23,9 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "testing" + "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" @@ -40,7 +42,6 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/events/eventstest" "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" @@ -59,7 +60,8 @@ func s3Request(url string, provider client.ConfigProvider) error { func dynamoRequest(url string, provider client.ConfigProvider) error { dynamoClient := dynamodb.New(provider, &aws.Config{ - Endpoint: &url, + Endpoint: &url, + MaxRetries: aws.Int(0), }) _, err := dynamoClient.Scan(&dynamodb.ScanInput{ TableName: aws.String("test-table"), @@ -204,7 +206,7 @@ func TestAWSSignerHandler(t *testing.T) { } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - handler := func(writer http.ResponseWriter, request *http.Request) { + mockAWSHandler := func(writer http.ResponseWriter, request *http.Request) { require.Equal(t, tc.wantHost, request.Host) awsAuthHeader, err := awsutils.ParseSigV4(request.Header.Get(awsutils.AuthorizationHeader)) require.NoError(t, err) @@ -213,7 +215,7 @@ func TestAWSSignerHandler(t *testing.T) { require.Equal(t, tc.wantAuthCredService, awsAuthHeader.Service) } - suite := createSuite(t, handler, tc.app) + suite := createSuite(t, mockAWSHandler, tc.app, clockwork.NewRealClock()) err := tc.request(suite.URL, tc.awsClientSession) for _, assertFn := range tc.errAssertionFns { @@ -257,7 +259,7 @@ func TestURLForResolvedEndpoint(t *testing.T) { inputReq *http.Request inputResolvedEnpoint *endpoints.ResolvedEndpoint requireError require.ErrorAssertionFunc - expectURL string + expectURL *url.URL }{ { name: "bad resolved endpoint", @@ -273,7 +275,12 @@ func TestURLForResolvedEndpoint(t *testing.T) { inputResolvedEnpoint: &endpoints.ResolvedEndpoint{ URL: "https://local.test.com", }, - expectURL: "https://local.test.com/hello/world?aa=2", + expectURL: &url.URL{ + Scheme: "https", + Host: "local.test.com", + Path: "/hello/world", + RawQuery: "aa=2", + }, requireError: require.NoError, }, } @@ -295,7 +302,7 @@ func mustNewRequest(t *testing.T, method, url string, body io.Reader) *http.Requ return r } -func staticAWSCredentials(client.ConfigProvider, *common.SessionContext) *credentials.Credentials { +func staticAWSCredentials(client.ConfigProvider, time.Time, string, string, string) *credentials.Credentials { return credentials.NewStaticCredentials("AKIDl", "SECRET", "SESSION") } @@ -306,18 +313,34 @@ type suite struct { emitter *eventstest.ChannelEmitter } -func createSuite(t *testing.T, handler http.HandlerFunc, app types.Application) *suite { +func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Application, clock clockwork.Clock) *suite { emitter := eventstest.NewChannelEmitter(1) - user := auth.LocalUser{Username: "user"} + identity := tlsca.Identity{ + Username: "user", + Expires: clock.Now().Add(time.Hour), + RouteToApp: tlsca.RouteToApp{ + AWSRoleARN: "arn:aws:iam::123456789:role/test", + }, + } - awsAPIMock := httptest.NewUnstartedServer(handler) + awsAPIMock := httptest.NewUnstartedServer(mockAWSHandler) awsAPIMock.StartTLS() t.Cleanup(func() { awsAPIMock.Close() }) - svc, err := NewSigningService(SigningServiceConfig{ - getSigningCredentials: staticAWSCredentials, + svc, err := awsutils.NewSigningService(awsutils.SigningServiceConfig{ + GetSigningCredentials: staticAWSCredentials, + Clock: clock, + }) + require.NoError(t, err) + + audit, err := common.NewAudit(common.AuditConfig{ + Emitter: emitter, + }) + require.NoError(t, err) + signerHandler, err := NewAWSSignerHandler(SignerHandlerConfig{ + SigningService: svc, RoundTripper: &http.Transport{ TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, @@ -326,24 +349,18 @@ func createSuite(t *testing.T, handler http.HandlerFunc, app types.Application) return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) }, }, - Clock: clockwork.NewFakeClock(), - }) - require.NoError(t, err) - - audit, err := common.NewAudit(common.AuditConfig{ - Emitter: emitter, }) require.NoError(t, err) - mux := http.NewServeMux() mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { request = common.WithSessionContext(request, &common.SessionContext{ - Identity: &user.Identity, + Identity: &identity, App: app, Audit: audit, + ChunkID: "123abc", }) - svc.ServeHTTP(writer, request) + signerHandler.ServeHTTP(writer, request) }) server := httptest.NewServer(mux) @@ -353,7 +370,7 @@ func createSuite(t *testing.T, handler http.HandlerFunc, app types.Application) return &suite{ Server: server, - identity: &user.Identity, + identity: &identity, app: app, emitter: emitter, } diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 069674a1ab18b..0196ac8da630e 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -35,16 +35,17 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/azure" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) -// ForwarderConfig is the Forwarder configuration. -type ForwarderConfig struct { - // Client is an HTTP client instance used for HTTP calls. - Client *http.Client +// HandlerConfig is the configuration for an Azure app-access handler. +type HandlerConfig struct { + // RoundTripper is the underlying transport given to an oxy Forwarder. + RoundTripper http.RoundTripper // Log is the Logger. Log logrus.FieldLogger // Clock is used to override time in tests. @@ -54,16 +55,14 @@ type ForwarderConfig struct { getAccessToken getAccessTokenFunc } -// CheckAndSetDefaults validates the ForwarderConfig config. -func (s *ForwarderConfig) CheckAndSetDefaults() error { - if s.Client == nil { +// CheckAndSetDefaults validates the HandlerConfig. +func (s *HandlerConfig) CheckAndSetDefaults() error { + if s.RoundTripper == nil { tr, err := defaults.Transport() if err != nil { return trace.Wrap(err) } - s.Client = &http.Client{ - Transport: tr, - } + s.RoundTripper = tr } if s.Clock == nil { s.Clock = clockwork.NewRealClock() @@ -77,21 +76,21 @@ func (s *ForwarderConfig) CheckAndSetDefaults() error { return nil } -// Forwarder is an Azure CLI proxy service that forwards the requests to Azure API, but updates the authorization headers +// handler is an Azure CLI proxy service handler that forwards the requests to Azure API, but updates the authorization headers // based on user identity. -type Forwarder struct { - // ForwarderConfig is the Forwarder configuration. - ForwarderConfig +type handler struct { + // config is the handler configuration. + HandlerConfig - // Forwarder signs and forwards the request to Azure API. - *forward.Forwarder + // fwd is used to forward requests to Azure API after the handler has rewritten them. + fwd *forward.Forwarder // tokenCache caches access tokens. tokenCache *utils.FnCache } -// NewForwarder creates a new instance of Forwarder. -func NewForwarder(ctx context.Context, config ForwarderConfig) (*Forwarder, error) { +// NewAzureHandler creates a new instance of an http.Handler for Azure requests. +func NewAzureHandler(ctx context.Context, config HandlerConfig) (http.Handler, error) { if err := config.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -105,63 +104,62 @@ func NewForwarder(ctx context.Context, config ForwarderConfig) (*Forwarder, erro return nil, trace.Wrap(err) } - svc := &Forwarder{ - ForwarderConfig: config, - tokenCache: tokenCache, + svc := &handler{ + HandlerConfig: config, + tokenCache: tokenCache, } fwd, err := forward.New( - forward.RoundTripper(svc), + forward.RoundTripper(config.RoundTripper), forward.ErrorHandler(oxyutils.ErrorHandlerFunc(svc.formatForwardResponseError)), forward.PassHostHeader(true), ) if err != nil { return nil, trace.Wrap(err) } - svc.Forwarder = fwd + svc.fwd = fwd return svc, nil } // RoundTrip handles incoming requests and forwards them to the proper API. -func (s *Forwarder) RoundTrip(req *http.Request) (*http.Response, error) { +func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if err := s.serveHTTP(w, req); err != nil { + s.formatForwardResponseError(w, req, err) + return + } +} + +// serveHTTP is a helper to simplify error handling in ServeHTTP. +func (s *handler) serveHTTP(w http.ResponseWriter, req *http.Request) error { sessionCtx, err := common.GetSessionContext(req) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } fwdRequest, err := s.prepareForwardRequest(req, sessionCtx) if err != nil { - return nil, trace.Wrap(err) - } - resp, err := s.Client.Do(fwdRequest) - if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } + recorder := httplib.NewResponseStatusRecorder(w) + s.fwd.ServeHTTP(recorder, fwdRequest) - if err := sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, fwdRequest, resp, nil); err != nil { + if err := sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, fwdRequest, recorder.Status(), nil); err != nil { + // log but don't return the error, because we already handed off request/response handling to the oxy forwarder. s.Log.WithError(err).Warn("Failed to emit audit event.") } - - return resp, nil + return nil } -func (s *Forwarder) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { +func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { + s.Log.WithError(err).Debugf("Failed to process request.") common.SetTeleportAPIErrorHeader(rw, err) - switch trace.Unwrap(err).(type) { - case *trace.BadParameterError: - s.Log.Debugf("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusBadRequest) - case *trace.AccessDeniedError: - s.Log.Infof("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusForbidden) - default: - s.Log.Warnf("Failed to process request: %v.", err) - rw.WriteHeader(http.StatusInternalServerError) - } + // Convert trace error type to HTTP and write response. + code := trace.ErrorToCode(err) + http.Error(rw, http.StatusText(code), code) } // prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way. -func (s *Forwarder) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) { +func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) { forwardedHost := r.Header.Get("X-Forwarded-Host") if !azure.IsAzureEndpoint(forwardedHost) { return nil, trace.AccessDenied("%q is not an Azure endpoint", forwardedHost) @@ -206,7 +204,7 @@ func getPeerKey(certs []*x509.Certificate) (crypto.PublicKey, error) { } -func (s *Forwarder) replaceAuthHeaders(r *http.Request, sessionCtx *common.SessionContext, reqCopy *http.Request) error { +func (s *handler) replaceAuthHeaders(r *http.Request, sessionCtx *common.SessionContext, reqCopy *http.Request) error { auth := reqCopy.Header.Get("Authorization") if auth == "" { s.Log.Debugf("No Authorization header present, skipping replacement.") @@ -234,7 +232,7 @@ func (s *Forwarder) replaceAuthHeaders(r *http.Request, sessionCtx *common.Sessi return nil } -func (s *Forwarder) parseAuthHeader(token string, pubKey crypto.PublicKey) (*jwt.AzureTokenClaims, error) { +func (s *handler) parseAuthHeader(token string, pubKey crypto.PublicKey) (*jwt.AzureTokenClaims, error) { before, after, found := strings.Cut(token, " ") if !found { return nil, trace.BadParameter("Unable to parse auth header") @@ -279,7 +277,7 @@ type cacheKey struct { const getTokenTimeout = time.Second * 5 -func (s *Forwarder) getToken(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { +func (s *handler) getToken(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { key := cacheKey{managedIdentity, scope} timeoutCtx, cancel := context.WithTimeout(ctx, getTokenTimeout) diff --git a/lib/srv/app/common/audit.go b/lib/srv/app/common/audit.go index abe2d42c4234b..73736fc0a33de 100644 --- a/lib/srv/app/common/audit.go +++ b/lib/srv/app/common/audit.go @@ -39,11 +39,11 @@ type Audit interface { // OnSessionEnd is called when an app session ends. OnSessionEnd(ctx context.Context, serverID string, identity *tlsca.Identity, app types.Application) error // OnSessionChunk is called when a new session chunk is created. - OnSessionChunk(ctx context.Context, sessionCtx *SessionContext, serverID string) error + OnSessionChunk(ctx context.Context, serverID, chunkID string, identity *tlsca.Identity, app types.Application) error // OnRequest is called when an app request is sent during the session and a response is received. - OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error + OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error // OnDynamoDBRequest is called when app request for a DynamoDB API is sent and a response is received. - OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error + OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error // EmitEvent emits the provided audit event. EmitEvent(ctx context.Context, event apievents.AuditEvent) error } @@ -140,34 +140,34 @@ func (a *audit) OnSessionEnd(ctx context.Context, serverID string, identity *tls } // OnSessionChunk is called when a new session chunk is created. -func (a *audit) OnSessionChunk(ctx context.Context, sessionCtx *SessionContext, serverID string) error { +func (a *audit) OnSessionChunk(ctx context.Context, serverID, chunkID string, identity *tlsca.Identity, app types.Application) error { event := &apievents.AppSessionChunk{ Metadata: apievents.Metadata{ Type: events.AppSessionChunkEvent, Code: events.AppSessionChunkCode, - ClusterName: sessionCtx.Identity.RouteToApp.ClusterName, + ClusterName: identity.RouteToApp.ClusterName, }, ServerMetadata: apievents.ServerMetadata{ ServerID: serverID, ServerNamespace: apidefaults.Namespace, }, SessionMetadata: apievents.SessionMetadata{ - SessionID: sessionCtx.Identity.RouteToApp.SessionID, - WithMFA: sessionCtx.Identity.MFAVerified, + SessionID: identity.RouteToApp.SessionID, + WithMFA: identity.MFAVerified, }, - UserMetadata: sessionCtx.Identity.GetUserMetadata(), + UserMetadata: identity.GetUserMetadata(), AppMetadata: apievents.AppMetadata{ - AppURI: sessionCtx.App.GetURI(), - AppPublicAddr: sessionCtx.App.GetPublicAddr(), - AppName: sessionCtx.App.GetName(), + AppURI: app.GetURI(), + AppPublicAddr: app.GetPublicAddr(), + AppName: app.GetName(), }, - SessionChunkID: sessionCtx.ChunkID, + SessionChunkID: chunkID, } return trace.Wrap(a.EmitEvent(ctx, event)) } // OnRequest is called when an app request is sent during the session and a response is received. -func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error { +func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error { event := &apievents.AppSessionRequest{ Metadata: apievents.Metadata{ Type: events.AppSessionRequestEvent, @@ -177,14 +177,14 @@ func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req * Method: req.Method, Path: req.URL.Path, RawQuery: req.URL.RawQuery, - StatusCode: uint32(res.StatusCode), + StatusCode: uint32(code), AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), } return trace.Wrap(a.EmitEvent(ctx, event)) } // OnDynamoDBRequest is called when a DynamoDB app request is sent during the session. -func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, res *http.Response, re *endpoints.ResolvedEndpoint) error { +func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, statusCode int, re *endpoints.ResolvedEndpoint) error { // Try to read the body and JSON unmarshal it. // If this fails, we still want to emit the rest of the event info; the request event Body is nullable, so it's ok if body is left nil here. body, err := awsutils.UnmarshalRequestBody(req) @@ -193,7 +193,7 @@ func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContex } // get the API target from the request header, according to the API request format documentation: // https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Programming.LowLevelAPI.html#Programming.LowLevelAPI.RequestFormat - target := req.Header.Get(awsutils.TargetHeader) + target := req.Header.Get(awsutils.AmzTargetHeader) event := &apievents.AppSessionDynamoDBRequest{ Metadata: apievents.Metadata{ Type: events.AppSessionDynamoDBRequestEvent, @@ -203,7 +203,7 @@ func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContex AppMetadata: *MakeAppMetadata(sessionCtx.App), AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), SessionChunkID: sessionCtx.ChunkID, - StatusCode: uint32(res.StatusCode), + StatusCode: uint32(statusCode), Path: req.URL.Path, RawQuery: req.URL.RawQuery, Method: req.Method, diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 8bfce221e6090..bb50afbb31510 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -50,7 +50,7 @@ import ( "github.com/gravitational/teleport/lib/srv/app/common" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/teleport/lib/utils/aws" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) type appServerContextKey string @@ -211,8 +211,8 @@ type Server struct { cache *sessionChunkCache - awsSigner *appaws.SigningService - azureHandler *appazure.Forwarder + awsHandler http.Handler + azureHandler http.Handler // watcher monitors changes to application resources. watcher *services.AppWatcher @@ -263,14 +263,19 @@ func New(ctx context.Context, c *Config) (*Server, error) { } }() - awsSigner, err := appaws.NewSigningService(appaws.SigningServiceConfig{}) + awsSigner, err := awsutils.NewSigningService(awsutils.SigningServiceConfig{}) + if err != nil { + return nil, trace.Wrap(err) + } + awsHandler, err := appaws.NewAWSSignerHandler(appaws.SignerHandlerConfig{ + SigningService: awsSigner, + }) if err != nil { return nil, trace.Wrap(err) } - azureHandler, err := appazure.NewForwarder(closeContext, appazure.ForwarderConfig{}) + azureHandler, err := appazure.NewAzureHandler(closeContext, appazure.HandlerConfig{}) if err != nil { - closeFunc() return nil, trace.Wrap(err) } @@ -283,7 +288,7 @@ func New(ctx context.Context, c *Config) (*Server, error) { dynamicLabels: make(map[string]*labels.Dynamic), apps: make(map[string]types.Application), connAuth: make(map[net.Conn]error), - awsSigner: awsSigner, + awsHandler: awsHandler, azureHandler: azureHandler, monitoredApps: monitoredApps{ static: c.Apps, @@ -798,8 +803,8 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { // 4 algorithm. AWS CLI and AWS SDKs automatically use SigV4 for all // services that support it (All services expect Amazon SimpleDB but // this AWS service has been deprecated) - if aws.IsSignedByAWSSigV4(r) { - return s.serveSession(w, r, &identity, app, s.withAWSForwarder) + if awsutils.IsSignedByAWSSigV4(r) { + return s.serveSession(w, r, &identity, app, s.withAWSSigner) } // Request for AWS console access originated from Teleport Proxy WebUI @@ -812,7 +817,6 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { default: return s.serveSession(w, r, &identity, app, s.withJWTTokenForwarder) } - } // serveAWSWebConsole generates a sign-in URL for AWS management console and @@ -844,8 +848,16 @@ func (s *Server) serveSession(w http.ResponseWriter, r *http.Request, identity * } defer session.release() + // Create session context. + sessionCtx := &common.SessionContext{ + Identity: identity, + App: app, + ChunkID: session.id, + Audit: session.audit, + } + // Forward request to the target application. - session.fwd.ServeHTTP(w, common.WithSessionContext(r, session.sessionCtx)) + session.handler.ServeHTTP(w, common.WithSessionContext(r, sessionCtx)) return nil } diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index 5e85c8fc63e81..24215c4492063 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -19,6 +19,7 @@ package app import ( "context" "errors" + "net/http" "path/filepath" "sync" "time" @@ -49,7 +50,7 @@ const sessionChunkCloseTimeout = 1 * time.Hour var errSessionChunkAlreadyClosed = errors.New("session chunk already closed") -// sessionChunk holds an open request forwarder and audit log for an app session. +// sessionChunk holds an open request handler and stream closer for an app session. // // An app session is only bounded by the lifetime of the certificate in // the caller's identity, so we create sessionChunks to track and record @@ -63,12 +64,12 @@ type sessionChunk struct { closeC chan struct{} // id is the session chunk's uuid, which is used as the id of its session upload. id string - // fwd can rewrite and forward requests to the target application. - fwd *forward.Forwarder // streamCloser closes the session chunk stream. streamCloser utils.WriteContextCloser - // sessionCtx contains common context parameters for an App session. - sessionCtx *common.SessionContext + // audit is the session chunk audit logger. + audit common.Audit + // handler handles requests for this session chunk. + handler http.Handler // inflightCond protects and signals change of inflight inflightCond *sync.Cond @@ -112,25 +113,19 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, } // Create the stream writer that will write this chunk to the audit log. - streamWriter, err := s.newStreamWriter(app, sess.id) + streamWriter, err := s.newStreamWriter(sess.id) if err != nil { return nil, trace.Wrap(err) } sess.streamCloser = streamWriter - // Create session context. audit, err := common.NewAudit(common.AuditConfig{ Emitter: streamWriter, }) if err != nil { return nil, trace.Wrap(err) } - sess.sessionCtx = &common.SessionContext{ - Identity: identity, - App: app, - ChunkID: sess.id, - Audit: audit, - } + sess.audit = audit for _, opt := range opts { if err = opt(ctx, sess, identity, app); err != nil { @@ -141,13 +136,12 @@ func (s *Server) newSessionChunk(ctx context.Context, identity *tlsca.Identity, // Put the session chunk in the cache so that upcoming requests can use it for // 5 minutes or the time until the certificate expires, whichever comes first. ttl := utils.MinTTL(identity.Expires.Sub(s.c.Clock.Now()), 5*time.Minute) - err = s.cache.set(identity.RouteToApp.SessionID, sess, ttl) - if err != nil { + if err = s.cache.set(identity.RouteToApp.SessionID, sess, ttl); err != nil { return nil, trace.Wrap(err) } // only emit a session chunk if we didnt get an error making the new session chunk - if err := sess.sessionCtx.Audit.OnSessionChunk(ctx, sess.sessionCtx, s.c.HostID); err != nil { + if err := sess.audit.OnSessionChunk(ctx, s.c.HostID, sess.id, identity, app); err != nil { return nil, trace.Wrap(err) } return sess, nil @@ -178,21 +172,19 @@ func (s *Server) withJWTTokenForwarder(ctx context.Context, sess *sessionChunk, // Create a rewriting transport that will be used to forward requests. transport, err := newTransport(s.closeContext, &transportConfig{ - audit: sess.sessionCtx.Audit, app: app, publicPort: s.proxyPort, cipherSuites: s.c.CipherSuites, jwt: jwt, traits: traits, log: s.log, - user: identity.Username, }) if err != nil { return trace.Wrap(err) } delegate := forward.NewHeaderRewriter() - sess.fwd, err = forward.New( + fwd, err := forward.New( forward.FlushInterval(100*time.Millisecond), forward.RoundTripper(transport), forward.Logger(logrus.StandardLogger()), @@ -203,18 +195,18 @@ func (s *Server) withJWTTokenForwarder(ctx context.Context, sess *sessionChunk, if err != nil { return trace.Wrap(err) } + sess.handler = fwd return nil } -// withAWSForwarder is a sessionOpt that uses forwarder of the AWS signning -// service. -func (s *Server) withAWSForwarder(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { - sess.fwd = s.awsSigner.Forwarder +// withAWSSigner is a sessionOpt that uses an AWS signing service handler. +func (s *Server) withAWSSigner(_ context.Context, sess *sessionChunk, _ *tlsca.Identity, _ types.Application) error { + sess.handler = s.awsHandler return nil } func (s *Server) withAzureForwarder(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { - sess.fwd = s.azureHandler.Forwarder + sess.handler = s.azureHandler return nil } @@ -266,8 +258,7 @@ func (s *sessionChunk) close(ctx context.Context) error { s.inflightCond.L.Unlock() close(s.closeC) s.log.Debugf("Closed session chunk %s", s.id) - err := s.streamCloser.Close(ctx) - return trace.Wrap(err) + return trace.Wrap(s.streamCloser.Close(ctx)) } func (s *Server) closeSession(sess *sessionChunk) { @@ -279,7 +270,7 @@ func (s *Server) closeSession(sess *sessionChunk) { // newStreamWriter creates a session stream that will be used to record // requests that occur within this session chunk and upload the recording // to the Auth server. -func (s *Server) newStreamWriter(app types.Application, chunkID string) (events.StreamWriter, error) { +func (s *Server) newStreamWriter(chunkID string) (events.StreamWriter, error) { recConfig, err := s.c.AccessPoint.GetSessionRecordingConfig(s.closeContext) if err != nil { return nil, trace.Wrap(err) @@ -291,7 +282,7 @@ func (s *Server) newStreamWriter(app types.Application, chunkID string) (events. } // Create a sync or async streamer depending on configuration of cluster. - streamer, err := s.newStreamer(app, chunkID, recConfig) + streamer, err := s.newStreamer(chunkID, recConfig) if err != nil { return nil, trace.Wrap(err) } @@ -319,7 +310,7 @@ func (s *Server) newStreamWriter(app types.Application, chunkID string) (events. // of the server and the session, sync streamer sends the events // directly to the auth server and blocks if the events can not be received, // async streamer buffers the events to disk and uploads the events later -func (s *Server) newStreamer(app types.Application, chunkID string, recConfig types.SessionRecordingConfig) (events.Streamer, error) { +func (s *Server) newStreamer(chunkID string, recConfig types.SessionRecordingConfig) (events.Streamer, error) { if services.IsRecordSync(recConfig.GetMode()) { s.log.Debugf("Using sync streamer for session chunk %v.", chunkID) return s.c.AuthClient, nil diff --git a/lib/srv/app/transport.go b/lib/srv/app/transport.go index 5b98f64031907..622de1188b13e 100644 --- a/lib/srv/app/transport.go +++ b/lib/srv/app/transport.go @@ -45,17 +45,12 @@ type transportConfig struct { publicPort string cipherSuites []uint16 jwt string - audit common.Audit traits wrappers.Traits log logrus.FieldLogger - user string } // Check validates configuration. func (c *transportConfig) Check() error { - if c.audit == nil { - return trace.BadParameter("audit writer missing") - } if c.app == nil { return trace.BadParameter("app missing") } @@ -77,7 +72,7 @@ func (c *transportConfig) Check() error { type transport struct { closeContext context.Context - c *transportConfig + *transportConfig tr http.RoundTripper @@ -109,11 +104,11 @@ func newTransport(ctx context.Context, c *transportConfig) (*transport, error) { } return &transport{ - closeContext: ctx, - c: c, - uri: uri, - tr: tr, - ws: newWebsocketTransport(uri, tr.TLSClientConfig, c), + closeContext: ctx, + uri: uri, + tr: tr, + ws: newWebsocketTransport(uri, tr.TLSClientConfig, c), + transportConfig: c, }, nil } @@ -145,7 +140,7 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { return nil, trace.Wrap(err) } - sessionCtx, err := common.GetSessionContext(r) + sessCtx, err := common.GetSessionContext(r) if err != nil { return nil, trace.Wrap(err) } @@ -157,7 +152,7 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { } // Emit the event to the audit log. - if err := t.c.audit.OnRequest(t.closeContext, sessionCtx, r, resp, nil /*aws endpoint*/); err != nil { + if err := sessCtx.Audit.OnRequest(t.closeContext, sessCtx, r, resp.StatusCode, nil /*aws endpoint*/); err != nil { return nil, trace.Wrap(err) } @@ -176,7 +171,7 @@ func (t *transport) rewriteRequest(r *http.Request) error { r.URL.Host = t.uri.Host // Add headers from rewrite configuration. - rewriteHeaders(r, t.c) + rewriteHeaders(r, t.transportConfig) return nil } @@ -237,7 +232,7 @@ func (t *transport) needsPathRedirect(r *http.Request) (string, bool) { u := url.URL{ Scheme: "https", - Host: net.JoinHostPort(t.c.app.GetPublicAddr(), t.c.publicPort), + Host: net.JoinHostPort(t.app.GetPublicAddr(), t.publicPort), Path: uriPath, } return u.String(), true @@ -246,7 +241,7 @@ func (t *transport) needsPathRedirect(r *http.Request) (string, bool) { // rewriteResponse applies any rewriting rules to the response before returning it. func (t *transport) rewriteResponse(resp *http.Response) error { switch { - case t.c.app.GetRewrite() != nil && len(t.c.app.GetRewrite().Redirect) > 0: + case t.app.GetRewrite() != nil && len(t.app.GetRewrite().Redirect) > 0: err := t.rewriteRedirect(resp) if err != nil { return trace.Wrap(err) @@ -267,9 +262,9 @@ func (t *transport) rewriteRedirect(resp *http.Response) error { // If the redirect location is one of the hosts specified in the list of // redirects, rewrite the header. - if slices.Contains(t.c.app.GetRewrite().Redirect, host(u.Host)) { + if slices.Contains(t.app.GetRewrite().Redirect, host(u.Host)) { u.Scheme = "https" - u.Host = net.JoinHostPort(t.c.app.GetPublicAddr(), t.c.publicPort) + u.Host = net.JoinHostPort(t.app.GetPublicAddr(), t.publicPort) } resp.Header.Set("Location", u.String()) } diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index 020ec4b88ca4a..a1177b72796dc 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -19,6 +19,7 @@ package aws import ( "bytes" "context" + "fmt" "io" "net/http" "net/textproto" @@ -33,6 +34,7 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" "github.com/gravitational/teleport/lib/utils" ) @@ -55,10 +57,11 @@ const ( credentialAuthHeaderElem = "Credential" signedHeaderAuthHeaderElem = "SignedHeaders" signatureAuthHeaderElem = "Signature" - // TargetHeader is a header containing the API target. + + // AmzTargetHeader is a header containing the API target. // Format: target_version.operation // Example: DynamoDB_20120810.Scan - TargetHeader = "X-Amz-Target" + AmzTargetHeader = "X-Amz-Target" // AmzJSON1_0 is an AWS Content-Type header that indicates the media type is JSON. AmzJSON1_0 = "application/x-amz-json-1.0" // AmzJSON1_1 is an AWS Content-Type header that indicates the media type is JSON. @@ -239,17 +242,24 @@ func NewSigner(credentials *credentials.Credentials, signingServiceName string) return v4.NewSigner(credentials, options) } -// filterHeaders removes request headers that are not in the headers list. -func filterHeaders(r *http.Request, headers []string) { +// filterHeaders removes request headers that are not in the headers list and returns the removed header keys. +func filterHeaders(r *http.Request, headers []string) []string { + keep := make(map[string]struct{}) + for _, key := range headers { + keep[textproto.CanonicalMIMEHeaderKey(key)] = struct{}{} + } + + var removed []string out := make(http.Header) - for _, v := range headers { - ck := textproto.CanonicalMIMEHeaderKey(v) - val, ok := r.Header[ck] - if ok { - out[ck] = val + for key, vals := range r.Header { + if _, ok := keep[textproto.CanonicalMIMEHeaderKey(key)]; ok { + out[key] = vals + continue } + removed = append(removed, key) } r.Header = out + return removed } // FilterAWSRoles returns role ARNs from the provided list that belong to the @@ -360,3 +370,20 @@ func isJSON(contentType string) bool { return false } } + +// BuildRoleARN constructs a string AWS ARN from a username, region, and account ID. +func BuildRoleARN(username, region, accountID string) string { + if arn.IsARN(username) { + return username + } + resource := username + if !strings.Contains(resource, "/") { + resource = fmt.Sprintf("role/%s", username) + } + return arn.ARN{ + Partition: apiawsutils.GetPartitionFromRegion(region), + Service: "iam", + AccountID: accountID, + Resource: resource, + }.String() +} diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go new file mode 100644 index 0000000000000..96b670a4e0330 --- /dev/null +++ b/lib/utils/aws/signing.go @@ -0,0 +1,194 @@ +/* +Copyright 2021 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package aws + +import ( + "bytes" + "net/http" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/client" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + awssession "github.com/aws/aws-sdk-go/aws/session" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/lib/utils" +) + +// NewSigningService creates a new instance of SigningService. +func NewSigningService(config SigningServiceConfig) (*SigningService, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &SigningService{ + SigningServiceConfig: config, + }, nil +} + +// SigningService is an AWS CLI proxy service that signs AWS requests +// based on user identity. +type SigningService struct { + // SigningServiceConfig is the SigningService configuration. + SigningServiceConfig +} + +// SigningServiceConfig is the SigningService configuration. +type SigningServiceConfig struct { + // Session is AWS session. + Session *awssession.Session + // Clock is used to override time in tests. + Clock clockwork.Clock + // GetSigningCredentials allows to set the function responsible for obtaining STS credentials. + // Used in tests to set static AWS credentials and skip API call. + GetSigningCredentials GetSigningCredentialsFunc +} + +// CheckAndSetDefaults validates the SigningServiceConfig config. +func (s *SigningServiceConfig) CheckAndSetDefaults() error { + if s.Clock == nil { + s.Clock = clockwork.NewRealClock() + } + if s.Session == nil { + ses, err := awssession.NewSessionWithOptions(awssession.Options{ + SharedConfigState: awssession.SharedConfigEnable, + }) + if err != nil { + return trace.Wrap(err) + } + s.Session = ses + } + if s.GetSigningCredentials == nil { + s.GetSigningCredentials = GetAWSCredentialsFromSTSAPI + } + return nil +} + +// SigningCtx contains AWS SigV4 signing context parameters. +type SigningCtx struct { + SigningName string + SigningRegion string + Expiry time.Time + SessionName string + AWSRoleArn string + AWSExternalID string +} + +// Check checks signing context parameters. +func (sc *SigningCtx) Check(clock clockwork.Clock) error { + switch { + case sc.SigningName == "": + return trace.BadParameter("missing AWS signing name") + case sc.SigningRegion == "": + return trace.BadParameter("missing AWS signing region") + case sc.SessionName == "": + return trace.BadParameter("missing AWS session name") + case sc.AWSRoleArn == "": + return trace.BadParameter("missing AWS Role ARN") + case sc.Expiry.Before(clock.Now()): + return trace.BadParameter("AWS SigV4 expiry has already expired") + default: + return nil + } +} + +// SignRequest creates a new HTTP request and rewrites the header from the original request and returns a new +// HTTP request signed by STS AWS API. +// Signing steps: +// 1) Decode Authorization Header. Authorization Header example: +// +// Authorization: AWS4-HMAC-SHA256 +// Credential=AKIAIOSFODNN7EXAMPLE/20130524/us-east-1/s3/aws4_request, +// SignedHeaders=host;range;x-amz-date, +// Signature=fe5f80f77d5fa3beca038a248ff027d0445342fe2855ddc963176630326f1024 +// +// 2. Extract credential section from credential Authorization Header. +// 3. Extract aws-region and aws-service from the credential section. +// 4. Build AWS API endpoint based on extracted aws-region and aws-service fields. +// Not that for endpoint resolving the https://github.com/aws/aws-sdk-go/aws/endpoints/endpoints.go +// package is used and when Amazon releases a new API the dependency update is needed. +// 5. Sign HTTP request. +func (s *SigningService) SignRequest(req *http.Request, signCtx *SigningCtx) (*http.Request, error) { + if signCtx == nil { + return nil, trace.BadParameter("missing signing context") + } + if err := signCtx.Check(s.Clock); err != nil { + return nil, trace.Wrap(err) + } + payload, err := GetAndReplaceReqBody(req) + if err != nil { + return nil, trace.Wrap(err) + } + reqCopy, err := http.NewRequest(req.Method, req.URL.String(), bytes.NewReader(payload)) + if err != nil { + return nil, trace.Wrap(err) + } + reqCopy.Header = req.Header.Clone() + + unsignedHeaders := removeUnsignedHeaders(reqCopy) + credentials := s.GetSigningCredentials(s.Session, signCtx.Expiry, signCtx.SessionName, signCtx.AWSRoleArn, signCtx.AWSExternalID) + signer := NewSigner(credentials, signCtx.SigningName) + _, err = signer.Sign(reqCopy, bytes.NewReader(payload), signCtx.SigningName, signCtx.SigningRegion, s.Clock.Now()) + if err != nil { + return nil, trace.Wrap(err) + } + + // copy removed headers back to the request after signing it, but don't copy the old Authorization header. + copyHeaders(reqCopy, req, utils.RemoveFromSlice(unsignedHeaders, "Authorization")) + return reqCopy, nil +} + +// GetSigningCredentialsFunc allows to set the function responsible for obtaining STS credentials. +// Used in tests to set static AWS credentials and skip API call. +type GetSigningCredentialsFunc func(provider client.ConfigProvider, expiry time.Time, sessName, roleARN, externalID string) *credentials.Credentials + +// GetAWSCredentialsFromSTSAPI obtains STS credentials. +func GetAWSCredentialsFromSTSAPI(provider client.ConfigProvider, expiry time.Time, sessName, roleARN, externalID string) *credentials.Credentials { + return stscreds.NewCredentials(provider, roleARN, + func(cred *stscreds.AssumeRoleProvider) { + cred.RoleSessionName = sessName + cred.Expiry.SetExpiration(expiry, 0) + + if externalID != "" { + cred.ExternalID = aws.String(externalID) + } + }, + ) +} + +// removeUnsignedHeaders removes and returns header keys that are not included in SigV4 SignedHeaders. +// If the request is not already signed, then no headers are removed. +func removeUnsignedHeaders(reqCopy *http.Request) []string { + // check if the request is already signed. + authHeader := reqCopy.Header.Get("Authorization") + sig, err := ParseSigV4(authHeader) + if err != nil { + return nil + } + return filterHeaders(reqCopy, sig.SignedHeaders) +} + +// copyHeaders copies headers from src request to dst request, using a list of header keys to copy. +func copyHeaders(dst *http.Request, src *http.Request, keys []string) { + for _, k := range keys { + if vals, ok := src.Header[k]; ok { + dst.Header[k] = vals + } + } +} From d6f3f6dedc16bd8a6f9c48dd67b35acbe17ee4bf Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 16 Dec 2022 12:12:33 -0800 Subject: [PATCH 2/5] address feedback * check auditErr instead of err for logging * use app server close context for audit event emitting * add go doc comments. * refactor request rewriting to make the copy in a more robust way. * pass status code as uint32 rather than casting in audit emitter * clone request in signing service --- lib/srv/app/aws/handler.go | 32 ++++++++++++++++++-------------- lib/srv/app/aws/handler_test.go | 24 +++++++++++++----------- lib/srv/app/azure/handler.go | 3 ++- lib/srv/app/common/audit.go | 12 ++++++------ lib/srv/app/server.go | 2 +- lib/srv/app/transport.go | 3 ++- lib/utils/aws/signing.go | 23 +++++++++++++---------- 7 files changed, 55 insertions(+), 44 deletions(-) diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 12188dabb6b0f..0f70c9df9e950 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -17,6 +17,7 @@ limitations under the License. package aws import ( + "context" "net/http" "net/url" @@ -36,8 +37,10 @@ import ( type signerHandler struct { // fwd is a Forwarder used to forward signed requests to AWS API. fwd *forward.Forwarder - // AwsSignerHandlerConfig is the awsSignerHandler configuration. + // SignerHandlerConfig is the configuration for the handler. SignerHandlerConfig + // closeContext is the app server close context. + closeContext context.Context } // SignerHandlerConfig is the awsSignerHandler configuration. @@ -69,13 +72,14 @@ func (cfg *SignerHandlerConfig) CheckAndSetDefaults() error { } // NewAWSSignerHandler creates a new request handler for signing and forwarding requests to AWS API. -func NewAWSSignerHandler(config SignerHandlerConfig) (http.Handler, error) { +func NewAWSSignerHandler(ctx context.Context, config SignerHandlerConfig) (http.Handler, error) { if err := config.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } handler := &signerHandler{ SignerHandlerConfig: config, + closeContext: ctx, } fwd, err := forward.New( forward.RoundTripper(config.RoundTripper), @@ -122,12 +126,12 @@ func (s *signerHandler) serveHTTP(w http.ResponseWriter, req *http.Request) erro } // rewrite headers before signing the request to avoid signature validation problems. - unsignedReq, err := rewriteRequest(req, re) + unsignedReq, err := rewriteRequest(s.closeContext, req, re) if err != nil { return trace.Wrap(err) } - signedReq, err := s.SignRequest(unsignedReq, + signedReq, err := s.SignRequest(s.closeContext, unsignedReq, &awsutils.SigningCtx{ SigningName: re.SigningName, SigningRegion: re.SigningRegion, @@ -141,14 +145,15 @@ func (s *signerHandler) serveHTTP(w http.ResponseWriter, req *http.Request) erro } recorder := httplib.NewResponseStatusRecorder(w) s.fwd.ServeHTTP(recorder, signedReq) + status := uint32(recorder.Status()) var auditErr error if isDynamoDBEndpoint(re) { - auditErr = sessCtx.Audit.OnDynamoDBRequest(unsignedReq.Context(), sessCtx, unsignedReq, recorder.Status(), re) + auditErr = sessCtx.Audit.OnDynamoDBRequest(s.closeContext, sessCtx, unsignedReq, status, re) } else { - auditErr = sessCtx.Audit.OnRequest(unsignedReq.Context(), sessCtx, unsignedReq, recorder.Status(), re) + auditErr = sessCtx.Audit.OnRequest(s.closeContext, sessCtx, unsignedReq, status, re) } - if err != nil { + if auditErr != nil { // log but don't return the error, because we already handed off request/response handling to the oxy forwarder. s.Log.WithError(auditErr).Warn("Failed to emit audit event.") } @@ -156,17 +161,16 @@ func (s *signerHandler) serveHTTP(w http.ResponseWriter, req *http.Request) erro } // rewriteRequest rewrites a request to remove Teleport reserved headers, sets the url, and sets host. -func rewriteRequest(r *http.Request, re *endpoints.ResolvedEndpoint) (*http.Request, error) { - // shallow copy request and make a deep copy for header modification. - outReq := &http.Request{} - *outReq = *r - outReq.Header = r.Header.Clone() +func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.ResolvedEndpoint) (*http.Request, error) { u, err := urlForResolvedEndpoint(r, re) if err != nil { return nil, trace.Wrap(err) } - outReq.URL = u - outReq.Host = u.Host + outReq, err := http.NewRequestWithContext(ctx, r.Method, u.String(), r.Body) + if err != nil { + return nil, trace.Wrap(err) + } + outReq.Header = r.Header.Clone() for key := range outReq.Header { // Remove Teleport app headers. diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 2c05aa129be7b..62f177f24513e 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -215,7 +215,8 @@ func TestAWSSignerHandler(t *testing.T) { require.Equal(t, tc.wantAuthCredService, awsAuthHeader.Service) } - suite := createSuite(t, mockAWSHandler, tc.app, clockwork.NewRealClock()) + fakeClock := clockwork.NewFakeClock() + suite := createSuite(t, mockAWSHandler, tc.app, fakeClock) err := tc.request(suite.URL, tc.awsClientSession) for _, assertFn := range tc.errAssertionFns { @@ -339,17 +340,18 @@ func createSuite(t *testing.T, mockAWSHandler http.HandlerFunc, app types.Applic Emitter: emitter, }) require.NoError(t, err) - signerHandler, err := NewAWSSignerHandler(SignerHandlerConfig{ - SigningService: svc, - RoundTripper: &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: true, + signerHandler, err := NewAWSSignerHandler(context.Background(), + SignerHandlerConfig{ + SigningService: svc, + RoundTripper: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) + }, }, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return net.Dial(awsAPIMock.Listener.Addr().Network(), awsAPIMock.Listener.Addr().String()) - }, - }, - }) + }) require.NoError(t, err) mux := http.NewServeMux() mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 0196ac8da630e..b09f54f7852a6 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -141,8 +141,9 @@ func (s *handler) serveHTTP(w http.ResponseWriter, req *http.Request) error { } recorder := httplib.NewResponseStatusRecorder(w) s.fwd.ServeHTTP(recorder, fwdRequest) + status := uint32(recorder.Status()) - if err := sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, fwdRequest, recorder.Status(), nil); err != nil { + if err := sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, fwdRequest, status, nil); err != nil { // log but don't return the error, because we already handed off request/response handling to the oxy forwarder. s.Log.WithError(err).Warn("Failed to emit audit event.") } diff --git a/lib/srv/app/common/audit.go b/lib/srv/app/common/audit.go index 73736fc0a33de..197e63a12dae1 100644 --- a/lib/srv/app/common/audit.go +++ b/lib/srv/app/common/audit.go @@ -41,9 +41,9 @@ type Audit interface { // OnSessionChunk is called when a new session chunk is created. OnSessionChunk(ctx context.Context, serverID, chunkID string, identity *tlsca.Identity, app types.Application) error // OnRequest is called when an app request is sent during the session and a response is received. - OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error + OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, status uint32, re *endpoints.ResolvedEndpoint) error // OnDynamoDBRequest is called when app request for a DynamoDB API is sent and a response is received. - OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error + OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, status uint32, re *endpoints.ResolvedEndpoint) error // EmitEvent emits the provided audit event. EmitEvent(ctx context.Context, event apievents.AuditEvent) error } @@ -167,7 +167,7 @@ func (a *audit) OnSessionChunk(ctx context.Context, serverID, chunkID string, id } // OnRequest is called when an app request is sent during the session and a response is received. -func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, code int, re *endpoints.ResolvedEndpoint) error { +func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, status uint32, re *endpoints.ResolvedEndpoint) error { event := &apievents.AppSessionRequest{ Metadata: apievents.Metadata{ Type: events.AppSessionRequestEvent, @@ -177,14 +177,14 @@ func (a *audit) OnRequest(ctx context.Context, sessionCtx *SessionContext, req * Method: req.Method, Path: req.URL.Path, RawQuery: req.URL.RawQuery, - StatusCode: uint32(code), + StatusCode: status, AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), } return trace.Wrap(a.EmitEvent(ctx, event)) } // OnDynamoDBRequest is called when a DynamoDB app request is sent during the session. -func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, statusCode int, re *endpoints.ResolvedEndpoint) error { +func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContext, req *http.Request, status uint32, re *endpoints.ResolvedEndpoint) error { // Try to read the body and JSON unmarshal it. // If this fails, we still want to emit the rest of the event info; the request event Body is nullable, so it's ok if body is left nil here. body, err := awsutils.UnmarshalRequestBody(req) @@ -203,7 +203,7 @@ func (a *audit) OnDynamoDBRequest(ctx context.Context, sessionCtx *SessionContex AppMetadata: *MakeAppMetadata(sessionCtx.App), AWSRequestMetadata: *MakeAWSRequestMetadata(req, re), SessionChunkID: sessionCtx.ChunkID, - StatusCode: uint32(statusCode), + StatusCode: status, Path: req.URL.Path, RawQuery: req.URL.RawQuery, Method: req.Method, diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index bb50afbb31510..9df1838e2d54a 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -267,7 +267,7 @@ func New(ctx context.Context, c *Config) (*Server, error) { if err != nil { return nil, trace.Wrap(err) } - awsHandler, err := appaws.NewAWSSignerHandler(appaws.SignerHandlerConfig{ + awsHandler, err := appaws.NewAWSSignerHandler(closeContext, appaws.SignerHandlerConfig{ SigningService: awsSigner, }) if err != nil { diff --git a/lib/srv/app/transport.go b/lib/srv/app/transport.go index 622de1188b13e..e544a7ca122ac 100644 --- a/lib/srv/app/transport.go +++ b/lib/srv/app/transport.go @@ -150,9 +150,10 @@ func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { if err != nil { return nil, trace.Wrap(err) } + status := uint32(resp.StatusCode) // Emit the event to the audit log. - if err := sessCtx.Audit.OnRequest(t.closeContext, sessCtx, r, resp.StatusCode, nil /*aws endpoint*/); err != nil { + if err := sessCtx.Audit.OnRequest(t.closeContext, sessCtx, r, status, nil /*aws endpoint*/); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index 96b670a4e0330..b7cca750d1abc 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -18,6 +18,7 @@ package aws import ( "bytes" + "context" "net/http" "time" @@ -82,11 +83,17 @@ func (s *SigningServiceConfig) CheckAndSetDefaults() error { // SigningCtx contains AWS SigV4 signing context parameters. type SigningCtx struct { - SigningName string + // SigningName is the AWS signing service name. + SigningName string + // SigningRegion is the AWS region to sign a request for. SigningRegion string - Expiry time.Time - SessionName string - AWSRoleArn string + // Expiry is the expiration of the AWS credentials used to sign requests. + Expiry time.Time + // SessionName is role session name of AWS credentials used to sign requests. + SessionName string + // AWSRoleArn is the AWS ARN of the role to assume for signing requests. + AWSRoleArn string + // AWSExternalID is an optional external ID used when getting sts credentials. AWSExternalID string } @@ -124,7 +131,7 @@ func (sc *SigningCtx) Check(clock clockwork.Clock) error { // Not that for endpoint resolving the https://github.com/aws/aws-sdk-go/aws/endpoints/endpoints.go // package is used and when Amazon releases a new API the dependency update is needed. // 5. Sign HTTP request. -func (s *SigningService) SignRequest(req *http.Request, signCtx *SigningCtx) (*http.Request, error) { +func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*http.Request, error) { if signCtx == nil { return nil, trace.BadParameter("missing signing context") } @@ -135,11 +142,7 @@ func (s *SigningService) SignRequest(req *http.Request, signCtx *SigningCtx) (*h if err != nil { return nil, trace.Wrap(err) } - reqCopy, err := http.NewRequest(req.Method, req.URL.String(), bytes.NewReader(payload)) - if err != nil { - return nil, trace.Wrap(err) - } - reqCopy.Header = req.Header.Clone() + reqCopy := req.Clone(ctx) unsignedHeaders := removeUnsignedHeaders(reqCopy) credentials := s.GetSigningCredentials(s.Session, signCtx.Expiry, signCtx.SessionName, signCtx.AWSRoleArn, signCtx.AWSExternalID) From ab78f19505be0d32a22b1663133385555ddc9f39 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Mon, 19 Dec 2022 15:21:42 -0800 Subject: [PATCH 3/5] clone request instead of making a new request, and rewrite url to force https --- lib/srv/app/aws/handler.go | 16 ++++++++++------ lib/utils/aws/signing.go | 2 ++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 0f70c9df9e950..7a9c74d67dc71 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -18,6 +18,7 @@ package aws import ( "context" + "io" "net/http" "net/url" @@ -27,6 +28,7 @@ import ( "github.com/gravitational/trace" "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/srv/app/common" @@ -160,17 +162,19 @@ func (s *signerHandler) serveHTTP(w http.ResponseWriter, req *http.Request) erro return nil } -// rewriteRequest rewrites a request to remove Teleport reserved headers, sets the url, and sets host. +// rewriteRequest clones a request to remove Teleport reserved headers and rewrite the url. func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.ResolvedEndpoint) (*http.Request, error) { u, err := urlForResolvedEndpoint(r, re) if err != nil { return nil, trace.Wrap(err) } - outReq, err := http.NewRequestWithContext(ctx, r.Method, u.String(), r.Body) - if err != nil { - return nil, trace.Wrap(err) - } - outReq.Header = r.Header.Clone() + // force https + u.Scheme = "https" + + // clone the request for rewriting + outReq := r.Clone(ctx) + outReq.URL = u + outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) for key := range outReq.Header { // Remove Teleport app headers. diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index b7cca750d1abc..3e2ac701b7f34 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -19,6 +19,7 @@ package aws import ( "bytes" "context" + "io" "net/http" "time" @@ -143,6 +144,7 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig return nil, trace.Wrap(err) } reqCopy := req.Clone(ctx) + reqCopy.Body = io.NopCloser(req.Body) unsignedHeaders := removeUnsignedHeaders(reqCopy) credentials := s.GetSigningCredentials(s.Session, signCtx.Expiry, signCtx.SessionName, signCtx.AWSRoleArn, signCtx.AWSExternalID) From ff9a74735c04278bd1b26b4ec01c37f75dc9ebcd Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Wed, 21 Dec 2022 16:44:55 -0800 Subject: [PATCH 4/5] Update header handling The handlers for aws/azure were inside of an oxy/forward.Forwarder RoundTrip function but once moved outside of that we should not pass host header of the inbound request. * Set oxy forwarder to PassHostHeader=false to ensure the host header is the URL being sought. * Remove code that deleted forwarding headers previously, we should keep those (X-Forwarded-*). * Audit log the AWS Host sought rather than the incoming request Host header (prior behavior maintained, we just rewrite the request differently using Clone). --- lib/srv/app/aws/handler.go | 20 +++++++++----------- lib/srv/app/aws/handler_test.go | 5 ++++- lib/srv/app/azure/handler.go | 7 ++++--- lib/srv/app/common/audit.go | 2 +- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/lib/srv/app/aws/handler.go b/lib/srv/app/aws/handler.go index 7a9c74d67dc71..70b00060d298a 100644 --- a/lib/srv/app/aws/handler.go +++ b/lib/srv/app/aws/handler.go @@ -86,7 +86,9 @@ func NewAWSSignerHandler(ctx context.Context, config SignerHandlerConfig) (http. fwd, err := forward.New( forward.RoundTripper(config.RoundTripper), forward.ErrorHandler(oxyutils.ErrorHandlerFunc(handler.formatForwardResponseError)), - forward.PassHostHeader(true), + // Explicitly passing false here to be clear that we always want the host + // header to be the same as the outbound request's URL host. + forward.PassHostHeader(false), ) if err != nil { return nil, trace.Wrap(err) @@ -168,20 +170,16 @@ func rewriteRequest(ctx context.Context, r *http.Request, re *endpoints.Resolved if err != nil { return nil, trace.Wrap(err) } - // force https - u.Scheme = "https" // clone the request for rewriting outReq := r.Clone(ctx) - outReq.URL = u - outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) - - for key := range outReq.Header { - // Remove Teleport app headers. - if common.IsReservedHeader(key) || http.CanonicalHeaderKey(key) == "Content-Length" { - outReq.Header.Del(key) - } + if outReq.URL == nil { + outReq.URL = u + } else { + outReq.URL.Scheme = "https" + outReq.URL.Host = u.Host } + outReq.Body = io.NopCloser(io.LimitReader(r.Body, teleport.MaxHTTPRequestSize)) return outReq, nil } diff --git a/lib/srv/app/aws/handler_test.go b/lib/srv/app/aws/handler_test.go index 62f177f24513e..55c1366a36723 100644 --- a/lib/srv/app/aws/handler_test.go +++ b/lib/srv/app/aws/handler_test.go @@ -52,7 +52,9 @@ type makeRequest func(url string, provider client.ConfigProvider) error func s3Request(url string, provider client.ConfigProvider) error { s3Client := s3.New(provider, &aws.Config{ - Endpoint: &url, + Endpoint: &url, + MaxRetries: aws.Int(0), + HTTPClient: &http.Client{Timeout: 5 * time.Second}, }) _, err := s3Client.ListBuckets(&s3.ListBucketsInput{}) return err @@ -62,6 +64,7 @@ func dynamoRequest(url string, provider client.ConfigProvider) error { dynamoClient := dynamodb.New(provider, &aws.Config{ Endpoint: &url, MaxRetries: aws.Int(0), + HTTPClient: &http.Client{Timeout: 5 * time.Second}, }) _, err := dynamoClient.Scan(&dynamodb.ScanInput{ TableName: aws.String("test-table"), diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index b09f54f7852a6..39e14d4b298e3 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -112,7 +112,9 @@ func NewAzureHandler(ctx context.Context, config HandlerConfig) (http.Handler, e fwd, err := forward.New( forward.RoundTripper(config.RoundTripper), forward.ErrorHandler(oxyutils.ErrorHandlerFunc(svc.formatForwardResponseError)), - forward.PassHostHeader(true), + // Explicitly passing false here to be clear that we always want the host + // header to be the same as the outbound request's URL host. + forward.PassHostHeader(false), ) if err != nil { return nil, trace.Wrap(err) @@ -178,8 +180,7 @@ func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.Sess reqCopy.URL.Scheme = "https" reqCopy.URL.Host = forwardedHost - - copyHeaders(r, reqCopy) + reqCopy.Header = r.Header.Clone() err = s.replaceAuthHeaders(r, sessionCtx, reqCopy) if err != nil { diff --git a/lib/srv/app/common/audit.go b/lib/srv/app/common/audit.go index 197e63a12dae1..7a4ac0e63b6e0 100644 --- a/lib/srv/app/common/audit.go +++ b/lib/srv/app/common/audit.go @@ -236,6 +236,6 @@ func MakeAWSRequestMetadata(req *http.Request, awsEndpoint *endpoints.ResolvedEn return &apievents.AWSRequestMetadata{ AWSRegion: awsEndpoint.SigningRegion, AWSService: awsEndpoint.SigningName, - AWSHost: req.Host, + AWSHost: req.URL.Host, } } From 215d3dcda9d3b876ca91c4fdbb3b53898f02912a Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Thu, 22 Dec 2022 10:49:06 -0800 Subject: [PATCH 5/5] Remove obsolete header copying helper func --- lib/srv/app/azure/handler.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 39e14d4b298e3..9ff67a2877899 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -289,16 +289,3 @@ func (s *handler) getToken(ctx context.Context, managedIdentity string, scope st return s.getAccessToken(ctx, managedIdentity, scope) }) } - -func copyHeaders(r *http.Request, reqCopy *http.Request) { - for key, values := range r.Header { - // Remove Teleport app headers. - if common.IsReservedHeader(key) { - continue - } - - for _, v := range values { - reqCopy.Header.Add(key, v) - } - } -}