diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f9dbd1d5c..38ea69e1f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ * [preloader] Parse klog flags. * [CTFE] Add a /log.v3.json endpoint to help satisfy a requirement of the Chrome CT Log Policy by @robstradling in https://github.com/google/certificate-transparency-go/pull/1703 * [preloader] add continuous mode. - +* [CTFE] Enforce max request body size using `http.MaxBytesHandler`. ## v1.3.2 ### Misc diff --git a/trillian/ctfe/ct_server/main.go b/trillian/ctfe/ct_server/main.go index 1410faf29d..44885bc9b2 100644 --- a/trillian/ctfe/ct_server/main.go +++ b/trillian/ctfe/ct_server/main.go @@ -88,6 +88,7 @@ var ( cacheSize = flag.Int("cache_size", -1, "Size parameter set to 0 makes cache of unlimited size") cacheTTL = flag.Duration("cache_ttl", -1*time.Second, "Providing 0 TTL turns expiring off") trillianTLSCACertFile = flag.String("trillian_tls_ca_cert_file", "", "CA certificate file to use for secure connections with Trillian server") + maxCertChainSize = flag.Int64("max_cert_chain_size", 512000, "Maximum size of certificate chain in bytes for add-chain and add-pre-chain endpoints (default: 512000 bytes = 500KB)") ) const unknownRemoteUser = "UNKNOWN_REMOTE" @@ -307,7 +308,7 @@ func main() { go func() { mux := http.NewServeMux() mux.Handle("/metrics", promhttp.Handler()) - metricsServer := http.Server{Addr: metricsAt, Handler: mux} + metricsServer := http.Server{Addr: metricsAt, Handler: mux, MaxHeaderBytes: 128 * 1024} err := metricsServer.ListenAndServe() klog.Warningf("Metrics server exited: %v", err) }() @@ -336,9 +337,9 @@ func main() { Certificates: []tls.Certificate{cert}, MinVersion: tls.VersionTLS12, } - srv = http.Server{Addr: *httpEndpoint, Handler: handler, TLSConfig: tlsConfig} + srv = http.Server{Addr: *httpEndpoint, Handler: handler, TLSConfig: tlsConfig, MaxHeaderBytes: 128 * 1024} } else { - srv = http.Server{Addr: *httpEndpoint, Handler: handler} + srv = http.Server{Addr: *httpEndpoint, Handler: handler, MaxHeaderBytes: 128 * 1024} } if *httpIdleTimeout > 0 { srv.IdleTimeout = *httpIdleTimeout @@ -448,7 +449,12 @@ func setupAndRegister(ctx context.Context, client trillian.TrillianLogClient, de return nil, err } for path, handler := range inst.Handlers { - mux.Handle(lhp+path, handler) + if strings.HasSuffix(path, "/add-chain") || strings.HasSuffix(path, "/add-pre-chain") { + klog.Infof("Applying MaxBytesHandler to %s with limit %d bytes", lhp+path, *maxCertChainSize) + mux.Handle(lhp+path, http.MaxBytesHandler(handler, *maxCertChainSize)) + } else { + mux.Handle(lhp+path, handler) + } } return inst, nil } diff --git a/trillian/ctfe/handlers.go b/trillian/ctfe/handlers.go index 23c5894b84..e6dbf34f22 100644 --- a/trillian/ctfe/handlers.go +++ b/trillian/ctfe/handlers.go @@ -405,6 +405,10 @@ func (li *logInfo) buildLeaf(ctx context.Context, chain []*x509.Certificate, mer func ParseBodyAsJSONChain(r *http.Request) (ct.AddChainRequest, error) { body, err := io.ReadAll(r.Body) if err != nil { + if mbe, ok := err.(*http.MaxBytesError); ok { + klog.V(1).Infof("Request body exceeds %d-byte limit", mbe.Limit) + return ct.AddChainRequest{}, fmt.Errorf("certificate chain exceeds %d-byte limit: %w", mbe.Limit, err) + } klog.V(1).Infof("Failed to read request body: %v", err) return ct.AddChainRequest{}, err } @@ -461,6 +465,10 @@ func addChainInternal(ctx context.Context, li *logInfo, w http.ResponseWriter, r // Check the contents of the request and convert to slice of certificates. addChainReq, err := ParseBodyAsJSONChain(r) if err != nil { + var maxBytesErr *http.MaxBytesError + if errors.As(err, &maxBytesErr) { + return http.StatusRequestEntityTooLarge, fmt.Errorf("%s: %v", li.LogPrefix, err) + } return http.StatusBadRequest, fmt.Errorf("%s: failed to parse add-chain body: %s", li.LogPrefix, err) } // Log the DERs now because they might not parse as valid X.509. diff --git a/trillian/ctfe/handlers_test.go b/trillian/ctfe/handlers_test.go index a0c1241219..abba8a86c8 100644 --- a/trillian/ctfe/handlers_test.go +++ b/trillian/ctfe/handlers_test.go @@ -240,24 +240,36 @@ func TestGetHandlersRejectPost(t *testing.T) { func TestPostHandlersFailure(t *testing.T) { var tests = []struct { descr string - body io.Reader + body func() io.Reader want int }{ - {"nil", nil, http.StatusBadRequest}, - {"''", strings.NewReader(""), http.StatusBadRequest}, - {"malformed-json", strings.NewReader("{ !$%^& not valid json "), http.StatusBadRequest}, - {"empty-chain", strings.NewReader(`{ "chain": [] }`), http.StatusBadRequest}, - {"wrong-chain", strings.NewReader(`{ "chain": [ "test" ] }`), http.StatusBadRequest}, + {"nil", func() io.Reader { return nil }, http.StatusBadRequest}, + {"''", func() io.Reader { return strings.NewReader("") }, http.StatusBadRequest}, + {"malformed-json", func() io.Reader { return strings.NewReader("{ !$%^& not valid json ") }, http.StatusBadRequest}, + {"empty-chain", func() io.Reader { return strings.NewReader(`{ "chain": [] }`) }, http.StatusBadRequest}, + {"wrong-chain", func() io.Reader { return strings.NewReader(`{ "chain": [ "test" ] }`) }, http.StatusBadRequest}, + {"too-large-body", func() io.Reader { + return strings.NewReader(fmt.Sprintf(`{ "chain": [ "%s" ] }`, strings.Repeat("A", 600000))) + }, http.StatusRequestEntityTooLarge}, } info := setupTest(t, []string{cttestonly.FakeCACertPEM}, nil) defer info.mockCtrl.Finish() + maxCertChainSize := int64(500 * 1024) for path, handler := range info.postHandlers() { t.Run(path, func(t *testing.T) { - s := httptest.NewServer(handler) + var wrappedHandler http.Handler + if path == "add-chain" || path == "add-pre-chain" { + wrappedHandler = http.MaxBytesHandler(http.Handler(handler), maxCertChainSize) + } else { + wrappedHandler = handler + } + + s := httptest.NewServer(wrappedHandler) + defer s.Close() for _, test := range tests { - resp, err := http.Post(s.URL+"/ct/v1/"+path, "application/json", test.body) + resp, err := http.Post(s.URL+"/ct/v1/"+path, "application/json", test.body()) if err != nil { t.Errorf("http.Post(%s,%s)=(_,%q); want (_,nil)", path, test.descr, err) continue