diff --git a/caddyhttp/push/handler.go b/caddyhttp/push/handler.go index 532e7a2b08e..6e0c8ee3794 100644 --- a/caddyhttp/push/handler.go +++ b/caddyhttp/push/handler.go @@ -21,20 +21,16 @@ func (h Middleware) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, erro return h.Next.ServeHTTP(w, r) } - // Serve file first - code, err := h.Next.ServeHTTP(w, r) - - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() - } + headers := h.filterProxiedHeaders(r.Header) + // Push first outer: for _, rule := range h.Rules { if httpserver.Path(r.URL.Path).Matches(rule.Path) { for _, resource := range rule.Resources { pushErr := pusher.Push(resource.Path, &http.PushOptions{ Method: resource.Method, - Header: resource.Header, + Header: h.mergeHeaders(headers, resource.Header), }) if pushErr != nil { // If we cannot push (either not supported or concurrent streams are full - break) @@ -44,14 +40,17 @@ outer: } } + // Serve later + code, err := h.Next.ServeHTTP(w, r) + if links, exists := w.Header()["Link"]; exists { - h.pushLinks(pusher, links) + h.servePreloadLinks(pusher, headers, links) } return code, err } -func (h Middleware) pushLinks(pusher http.Pusher, links []string) { +func (h Middleware) servePreloadLinks(pusher http.Pusher, headers http.Header, links []string) { outer: for _, link := range links { parts := strings.Split(link, ";") @@ -62,9 +61,51 @@ outer: target := strings.TrimSuffix(strings.TrimPrefix(parts[0], "<"), ">") - err := pusher.Push(target, &http.PushOptions{Method: http.MethodGet}) + err := pusher.Push(target, &http.PushOptions{ + Method: http.MethodGet, + Header: headers, + }) + if err != nil { break outer } } } + +func (h Middleware) mergeHeaders(l, r http.Header) http.Header { + + out := http.Header{} + + for k, v := range l { + out[k] = v + } + + for k, vv := range r { + for _, v := range vv { + out.Add(k, v) + } + } + + return out +} + +func (h Middleware) filterProxiedHeaders(headers http.Header) http.Header { + + filter := http.Header{} + + for _, header := range proxiedHeaders { + if val, ok := headers[header]; ok { + filter[header] = val + } + } + + return filter +} + +var proxiedHeaders = []string{ + "Accept-Encoding", + "Accept-Language", + "Cache-Control", + "Host", + "User-Agent", +} diff --git a/caddyhttp/push/handler_test.go b/caddyhttp/push/handler_test.go index ad680245972..343ab7460eb 100644 --- a/caddyhttp/push/handler_test.go +++ b/caddyhttp/push/handler_test.go @@ -61,7 +61,51 @@ func TestMiddlewareWillPushResources(t *testing.T) { "/index2.css": { Method: http.MethodGet, - Header: nil, + Header: http.Header{}, + }, + } + + comparePushedResources(t, expectedPushedResources, pushingWriter.pushed) +} + +func TestMiddlewareWillPushResourcesWithMergedHeaders(t *testing.T) { + + // given + request, err := http.NewRequest(http.MethodGet, "/index.html", nil) + request.Header = http.Header{"Accept-Encoding": []string{"br"}, "Invalid-Header": []string{"Should be filter out"}} + writer := httptest.NewRecorder() + + if err != nil { + t.Fatalf("Could not create HTTP request: %v", err) + } + + middleware := Middleware{ + Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { + return 0, nil + }), + Rules: []Rule{ + {Path: "/index.html", Resources: []Resource{ + {Path: "/index.css", Method: http.MethodHead, Header: http.Header{"Test": []string{"Value"}}}, + {Path: "/index2.css", Method: http.MethodGet}, + }}, + }, + } + + pushingWriter := &MockedPusher{ResponseWriter: writer} + + // when + middleware.ServeHTTP(pushingWriter, request) + + // then + expectedPushedResources := map[string]*http.PushOptions{ + "/index.css": { + Method: http.MethodHead, + Header: http.Header{"Test": []string{"Value"}, "Accept-Encoding": []string{"br"}}, + }, + + "/index2.css": { + Method: http.MethodGet, + Header: http.Header{"Accept-Encoding": []string{"br"}}, }, } @@ -169,7 +213,7 @@ func TestMiddlewareWillNotPushResources(t *testing.T) { // then if err2 != nil { - t.Errorf("Should not return error") + t.Error("Should not return error") } } @@ -201,21 +245,21 @@ func TestMiddlewareShouldInterceptLinkHeader(t *testing.T) { // then if err2 != nil { - t.Errorf("Should not return error") + t.Error("Should not return error") } expectedPushedResources := map[string]*http.PushOptions{ "/index.css": { Method: http.MethodGet, - Header: nil, + Header: http.Header{}, }, "/index2.css": { Method: http.MethodGet, - Header: nil, + Header: http.Header{}, }, "/index3.css": { Method: http.MethodGet, - Header: nil, + Header: http.Header{}, }, } @@ -224,7 +268,10 @@ func TestMiddlewareShouldInterceptLinkHeader(t *testing.T) { func TestMiddlewareShouldInterceptLinkHeaderPusherError(t *testing.T) { // given + expectedHeaders := http.Header{"Accept-Encoding": []string{"br"}} request, err := http.NewRequest(http.MethodGet, "/index.html", nil) + request.Header = http.Header{"Accept-Encoding": []string{"br"}, "Invalid-Header": []string{"Should be filter out"}} + writer := httptest.NewRecorder() if err != nil { @@ -247,13 +294,13 @@ func TestMiddlewareShouldInterceptLinkHeaderPusherError(t *testing.T) { // then if err2 != nil { - t.Errorf("Should not return error") + t.Error("Should not return error") } expectedPushedResources := map[string]*http.PushOptions{ "/index.css": { Method: http.MethodGet, - Header: nil, + Header: expectedHeaders, }, } @@ -273,7 +320,7 @@ func comparePushedResources(t *testing.T, expected, actual map[string]*http.Push } if !reflect.DeepEqual(expectedTarget.Header, actualTarget.Header) { - t.Errorf("Expected %s resource push headers to be %v, actual: %v", target, expectedTarget.Header, actualTarget.Header) + t.Errorf("Expected %s resource push headers to be %+v, actual: %+v", target, expectedTarget.Header, actualTarget.Header) } } else { t.Errorf("Expected %s to be pushed", target)