From 53950b4c6e4821b3289612f9b505226d526a1059 Mon Sep 17 00:00:00 2001 From: Fred Heinecke Date: Fri, 13 Dec 2024 00:16:41 -0600 Subject: [PATCH 1/2] Add support for HTTP 'Accept' header to '/webapi/auth/export' endpoint Signed-off-by: Fred Heinecke --- lib/web/apiserver_test.go | 299 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 299 insertions(+) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index ca6f1272e1763..6ad44e73f8c70 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -11375,3 +11375,302 @@ func Test_setEntitlementsWithLegacyLogic(t *testing.T) { }) } } + +func Test_encodeTextResponse(t *testing.T) { + t.Parallel() + + baselineTestBody := "@cert-authority platform.teleport.sh,*.platform.teleport.sh ssh-rsa" + jsonTestResponseBody, err := json.Marshal(baselineTestBody) + require.NoError(t, err) + + testCases := []struct { + providedBody string + expectedBody []byte + acceptHeader []string + expectedContentTypeHeader string + errFunc require.ErrorAssertionFunc + }{ + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/plain"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/plain;q=0.8"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/*"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/plain", "application/json"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: jsonTestResponseBody, + acceptHeader: []string{"application/json"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: baselineTestBody, + expectedBody: jsonTestResponseBody, + acceptHeader: []string{"application/json", "text/plain"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: baselineTestBody, + expectedBody: jsonTestResponseBody, + acceptHeader: []string{"unsupported", "application/json"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"application/scim+json"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"application/*"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"*/*"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"*"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/"}, // Malformed + expectedContentTypeHeader: "text/plain", + errFunc: require.Error, + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"unsupportedtype"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: baselineTestBody, + expectedBody: jsonTestResponseBody, + acceptHeader: []string{"application/json, unsupportedtype"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: baselineTestBody, + expectedBody: []byte(baselineTestBody), + acceptHeader: []string{"text/plain, application/json"}, + expectedContentTypeHeader: "text/plain", + }, + } + + for _, testCase := range testCases { + // Setup + req, err := http.NewRequest("", "", nil) + require.NoError(t, err) + req.Header["Accept"] = testCase.acceptHeader + + // Test + responseBodyBytes, contentType, err := encodeTextResponse(testCase.providedBody, req) + + // Verify results + if testCase.errFunc == nil { + testCase.errFunc = require.NoError + } + testCase.errFunc(t, err) + + require.ElementsMatch(t, testCase.expectedBody, responseBodyBytes) + + if len(contentType) > len(testCase.expectedContentTypeHeader) { + // Handle the case where other parameters are returned (such as charset), which _requires_ + // a semicolon + require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) + } else { + require.Equal(t, testCase.expectedContentTypeHeader, contentType) + } + } +} + +func Test_serveErrorResponse(t *testing.T) { + t.Parallel() + + testCases := []struct { + err error + expectedBody []byte + acceptHeader []string + expectedContentTypeHeader string + }{ + { + err: fmt.Errorf("error message"), + expectedBody: []byte("error message"), + expectedContentTypeHeader: "text/plain", + }, + { + err: fmt.Errorf("error message"), + expectedBody: []byte("\"error message\""), + acceptHeader: []string{"application/json"}, + expectedContentTypeHeader: "application/json", + }, + { + err: fmt.Errorf("error message"), + expectedBody: []byte("\"error message\""), + acceptHeader: []string{"application/json", "text/plain"}, + expectedContentTypeHeader: "application/json", + }, + { + err: fmt.Errorf("error message"), + expectedBody: []byte("error message"), + acceptHeader: []string{"application/unsupported", "text/plain"}, + expectedContentTypeHeader: "text/plain", + }, + { + err: fmt.Errorf("error message"), + expectedBody: []byte("error message"), + acceptHeader: []string{"malformed/", "text/plain"}, + expectedContentTypeHeader: "text/plain", + }, + } + + for _, testCase := range testCases { + // Setup + req, err := http.NewRequest("", "", nil) + require.NoError(t, err) + req.Header["Accept"] = testCase.acceptHeader + + respWriter := httptest.NewRecorder() + + // Test + (&Handler{}).serveErrorResponse(testCase.err, respWriter, req) + + // Verify results + require.ElementsMatch(t, testCase.expectedBody, respWriter.Body.Bytes()) + + contentTypes := respWriter.Header()["Content-Type"] + require.Len(t, contentTypes, 1) + + contentType := contentTypes[0] + if len(contentType) > len(testCase.expectedContentTypeHeader) { + // Handle the case where other parameters are returned (such as charset), which _requires_ + // a semicolon + require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) + } else { + require.Equal(t, testCase.expectedContentTypeHeader, contentType) + } + } +} + +func Test_serveTextResponse(t *testing.T) { + t.Parallel() + + testCases := []struct { + providedBody string + expectedBody []byte + acceptHeader []string + expectedContentTypeHeader string + errFunc require.ErrorAssertionFunc + shouldNotSendResponse bool + }{ + { + providedBody: "test message", + expectedBody: []byte("test message"), + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: "test message", + expectedBody: []byte("\"test message\""), + acceptHeader: []string{"application/json"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: "test message", + expectedBody: []byte("\"test message\""), + acceptHeader: []string{"application/json", "text/plain"}, + expectedContentTypeHeader: "application/json", + }, + { + providedBody: "test message", + expectedBody: []byte("test message"), + acceptHeader: []string{"application/unsupported", "text/plain"}, + expectedContentTypeHeader: "text/plain", + }, + { + providedBody: "test message", + acceptHeader: []string{"malformed/", "text/plain"}, + errFunc: require.Error, + shouldNotSendResponse: true, + }, + { + providedBody: "test message", + acceptHeader: []string{"malformed/"}, + errFunc: require.Error, + shouldNotSendResponse: true, + }, + } + + for _, testCase := range testCases { + // Setup + req, err := http.NewRequest("", "", nil) + require.NoError(t, err) + req.Header["Accept"] = testCase.acceptHeader + + respWriter := httptest.NewRecorder() + respWriter.Code = 0 + + // Test + err = serveTextResponse(testCase.providedBody, respWriter, req) + + // Verify results + if testCase.errFunc == nil { + testCase.errFunc = require.NoError + } + testCase.errFunc(t, err) + + expectedStatusCode := http.StatusOK + if testCase.shouldNotSendResponse { + expectedStatusCode = 0 + } + require.Equal(t, expectedStatusCode, respWriter.Code) + + require.ElementsMatch(t, testCase.expectedBody, respWriter.Body.Bytes()) + + if testCase.shouldNotSendResponse { + continue + } + + contentTypes := respWriter.Header()["Content-Type"] + require.Len(t, contentTypes, 1) + + contentType := contentTypes[0] + if len(contentType) > len(testCase.expectedContentTypeHeader) { + // Handle the case where other parameters are returned (such as charset), which _requires_ + // a semicolon + require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) + } else { + require.Equal(t, testCase.expectedContentTypeHeader, contentType) + } + } +} From 9e42a9fb9c059734c8ef796cb41780d23e669dea Mon Sep 17 00:00:00 2001 From: Fred Heinecke Date: Wed, 19 Feb 2025 14:08:10 -0600 Subject: [PATCH 2/2] Switch PR to use new `format` parameter --- lib/client/ca_export.go | 2 +- lib/web/apiserver_test.go | 299 -------------------------------------- lib/web/ca_export.go | 39 ++++- lib/web/ca_export_test.go | 30 ++++ 4 files changed, 62 insertions(+), 308 deletions(-) diff --git a/lib/client/ca_export.go b/lib/client/ca_export.go index ad5de7e995e34..98bfe9d34da6d 100644 --- a/lib/client/ca_export.go +++ b/lib/client/ca_export.go @@ -59,7 +59,7 @@ type ExportedAuthority struct { // Data is the output of the exported authority. // May be an SSH authorized key, an SSH known hosts entry, a DER or a PEM, // depending on the type of the exported authority. - Data []byte + Data []byte `json:"data"` } // ExportAllAuthorities exports public keys of all authorities of a particular diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 6ad44e73f8c70..ca6f1272e1763 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -11375,302 +11375,3 @@ func Test_setEntitlementsWithLegacyLogic(t *testing.T) { }) } } - -func Test_encodeTextResponse(t *testing.T) { - t.Parallel() - - baselineTestBody := "@cert-authority platform.teleport.sh,*.platform.teleport.sh ssh-rsa" - jsonTestResponseBody, err := json.Marshal(baselineTestBody) - require.NoError(t, err) - - testCases := []struct { - providedBody string - expectedBody []byte - acceptHeader []string - expectedContentTypeHeader string - errFunc require.ErrorAssertionFunc - }{ - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/plain"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/plain;q=0.8"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/*"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/plain", "application/json"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: jsonTestResponseBody, - acceptHeader: []string{"application/json"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: baselineTestBody, - expectedBody: jsonTestResponseBody, - acceptHeader: []string{"application/json", "text/plain"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: baselineTestBody, - expectedBody: jsonTestResponseBody, - acceptHeader: []string{"unsupported", "application/json"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"application/scim+json"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"application/*"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"*/*"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"*"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/"}, // Malformed - expectedContentTypeHeader: "text/plain", - errFunc: require.Error, - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"unsupportedtype"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: baselineTestBody, - expectedBody: jsonTestResponseBody, - acceptHeader: []string{"application/json, unsupportedtype"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: baselineTestBody, - expectedBody: []byte(baselineTestBody), - acceptHeader: []string{"text/plain, application/json"}, - expectedContentTypeHeader: "text/plain", - }, - } - - for _, testCase := range testCases { - // Setup - req, err := http.NewRequest("", "", nil) - require.NoError(t, err) - req.Header["Accept"] = testCase.acceptHeader - - // Test - responseBodyBytes, contentType, err := encodeTextResponse(testCase.providedBody, req) - - // Verify results - if testCase.errFunc == nil { - testCase.errFunc = require.NoError - } - testCase.errFunc(t, err) - - require.ElementsMatch(t, testCase.expectedBody, responseBodyBytes) - - if len(contentType) > len(testCase.expectedContentTypeHeader) { - // Handle the case where other parameters are returned (such as charset), which _requires_ - // a semicolon - require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) - } else { - require.Equal(t, testCase.expectedContentTypeHeader, contentType) - } - } -} - -func Test_serveErrorResponse(t *testing.T) { - t.Parallel() - - testCases := []struct { - err error - expectedBody []byte - acceptHeader []string - expectedContentTypeHeader string - }{ - { - err: fmt.Errorf("error message"), - expectedBody: []byte("error message"), - expectedContentTypeHeader: "text/plain", - }, - { - err: fmt.Errorf("error message"), - expectedBody: []byte("\"error message\""), - acceptHeader: []string{"application/json"}, - expectedContentTypeHeader: "application/json", - }, - { - err: fmt.Errorf("error message"), - expectedBody: []byte("\"error message\""), - acceptHeader: []string{"application/json", "text/plain"}, - expectedContentTypeHeader: "application/json", - }, - { - err: fmt.Errorf("error message"), - expectedBody: []byte("error message"), - acceptHeader: []string{"application/unsupported", "text/plain"}, - expectedContentTypeHeader: "text/plain", - }, - { - err: fmt.Errorf("error message"), - expectedBody: []byte("error message"), - acceptHeader: []string{"malformed/", "text/plain"}, - expectedContentTypeHeader: "text/plain", - }, - } - - for _, testCase := range testCases { - // Setup - req, err := http.NewRequest("", "", nil) - require.NoError(t, err) - req.Header["Accept"] = testCase.acceptHeader - - respWriter := httptest.NewRecorder() - - // Test - (&Handler{}).serveErrorResponse(testCase.err, respWriter, req) - - // Verify results - require.ElementsMatch(t, testCase.expectedBody, respWriter.Body.Bytes()) - - contentTypes := respWriter.Header()["Content-Type"] - require.Len(t, contentTypes, 1) - - contentType := contentTypes[0] - if len(contentType) > len(testCase.expectedContentTypeHeader) { - // Handle the case where other parameters are returned (such as charset), which _requires_ - // a semicolon - require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) - } else { - require.Equal(t, testCase.expectedContentTypeHeader, contentType) - } - } -} - -func Test_serveTextResponse(t *testing.T) { - t.Parallel() - - testCases := []struct { - providedBody string - expectedBody []byte - acceptHeader []string - expectedContentTypeHeader string - errFunc require.ErrorAssertionFunc - shouldNotSendResponse bool - }{ - { - providedBody: "test message", - expectedBody: []byte("test message"), - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: "test message", - expectedBody: []byte("\"test message\""), - acceptHeader: []string{"application/json"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: "test message", - expectedBody: []byte("\"test message\""), - acceptHeader: []string{"application/json", "text/plain"}, - expectedContentTypeHeader: "application/json", - }, - { - providedBody: "test message", - expectedBody: []byte("test message"), - acceptHeader: []string{"application/unsupported", "text/plain"}, - expectedContentTypeHeader: "text/plain", - }, - { - providedBody: "test message", - acceptHeader: []string{"malformed/", "text/plain"}, - errFunc: require.Error, - shouldNotSendResponse: true, - }, - { - providedBody: "test message", - acceptHeader: []string{"malformed/"}, - errFunc: require.Error, - shouldNotSendResponse: true, - }, - } - - for _, testCase := range testCases { - // Setup - req, err := http.NewRequest("", "", nil) - require.NoError(t, err) - req.Header["Accept"] = testCase.acceptHeader - - respWriter := httptest.NewRecorder() - respWriter.Code = 0 - - // Test - err = serveTextResponse(testCase.providedBody, respWriter, req) - - // Verify results - if testCase.errFunc == nil { - testCase.errFunc = require.NoError - } - testCase.errFunc(t, err) - - expectedStatusCode := http.StatusOK - if testCase.shouldNotSendResponse { - expectedStatusCode = 0 - } - require.Equal(t, expectedStatusCode, respWriter.Code) - - require.ElementsMatch(t, testCase.expectedBody, respWriter.Body.Bytes()) - - if testCase.shouldNotSendResponse { - continue - } - - contentTypes := respWriter.Header()["Content-Type"] - require.Len(t, contentTypes, 1) - - contentType := contentTypes[0] - if len(contentType) > len(testCase.expectedContentTypeHeader) { - // Handle the case where other parameters are returned (such as charset), which _requires_ - // a semicolon - require.True(t, strings.HasPrefix(contentType, testCase.expectedContentTypeHeader+";")) - } else { - require.Equal(t, testCase.expectedContentTypeHeader, contentType) - } - } -} diff --git a/lib/web/ca_export.go b/lib/web/ca_export.go index 99e548f84ece1..eab9c894f6b35 100644 --- a/lib/web/ca_export.go +++ b/lib/web/ca_export.go @@ -19,6 +19,7 @@ package web import ( "archive/zip" "bytes" + "encoding/json" "fmt" "net/http" "time" @@ -52,12 +53,6 @@ func (h *Handler) authExportPublicError(w http.ResponseWriter, r *http.Request, query := r.URL.Query() caType := query.Get("type") // validated by ExportAllAuthorities - format := query.Get("format") - - const formatZip = "zip" - if format != "" && format != formatZip { - return trace.BadParameter("unsupported format %q", format) - } ctx := r.Context() authorities, err := client.ExportAllAuthorities( @@ -72,11 +67,23 @@ func (h *Handler) authExportPublicError(w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - if format == formatZip { + format := query.Get("format") + + const formatZip = "zip" + const formatJSON = "json" + switch format { + case "": + break + case formatZip: return h.authExportPublicZip(w, r, authorities) + case formatJSON: + return h.authExportPublicJSON(w, r, authorities) + default: + return trace.BadParameter("unsupported format %q", format) } + if l := len(authorities); l > 1 { - return trace.BadParameter("found %d authorities to export, use format=%s to export all", l, formatZip) + return trace.BadParameter("found %d authorities to export, use format=%s or format=%s to export all", l, formatZip, formatJSON) } // ServeContent sets the correct headers: Content-Type, Content-Length and Accept-Ranges. @@ -119,3 +126,19 @@ func (h *Handler) authExportPublicZip( http.ServeContent(w, r, zipName, now, bytes.NewReader(out.Bytes())) return nil } + +func (h *Handler) authExportPublicJSON( + w http.ResponseWriter, + r *http.Request, + authorities []*client.ExportedAuthority, +) error { + marshalledAuthorities, err := json.Marshal(authorities) + if err != nil { + return trace.Wrap(err, "failed to JSON marshal authorities") + } + + // File name is not critical here. It is only used by `ServeContent` to determine the value of the + // `Content-Type` header. + http.ServeContent(w, r, "export.json", time.Now(), bytes.NewReader(marshalledAuthorities)) + return nil +} diff --git a/lib/web/ca_export_test.go b/lib/web/ca_export_test.go index 258b9f44f4910..8963d750cbf0e 100644 --- a/lib/web/ca_export_test.go +++ b/lib/web/ca_export_test.go @@ -22,6 +22,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "encoding/json" "encoding/pem" "fmt" "io" @@ -33,6 +34,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/client" ) func TestAuthExport(t *testing.T) { @@ -90,6 +93,22 @@ func TestAuthExport(t *testing.T) { validateFormatZip(t, body, wantCAFiles, validateTLSCertificatePEMFunc) } + validateFormatJSON := func( + t *testing.T, + body []byte, + wantCAFiles int, + validateCAFile func(t *testing.T, contents []byte), + ) { + var authorities []client.ExportedAuthority + err := json.Unmarshal(body, &authorities) + require.NoError(t, err) + assert.Len(t, authorities, wantCAFiles) + + for _, authority := range authorities { + validateCAFile(t, authority.Data) + } + } + ctx := context.Background() for _, tt := range []struct { @@ -215,6 +234,17 @@ func TestAuthExport(t *testing.T) { validateFormatZipPEM(t, b, 1 /* wantCAFiles */) }, }, + { + name: "format=json", + params: url.Values{ + "type": []string{"db-client"}, + "format": []string{"json"}, + }, + expectedStatus: http.StatusOK, + assertBody: func(t *testing.T, b []byte) { + validateFormatJSON(t, b, 1, validateTLSCertificatePEMFunc) + }, + }, } { t.Run(tt.name, func(t *testing.T) { t.Parallel()