From b6b7771fbca0c5079f42b3441f948b5dd547f1e9 Mon Sep 17 00:00:00 2001 From: Henrique Dias Date: Thu, 7 Mar 2024 09:34:55 +0100 Subject: [PATCH] routing/http/server: add cache control (#584) --- CHANGELOG.md | 2 + routing/http/server/server.go | 60 ++++++++-- routing/http/server/server_test.go | 164 ++++++++++++++++++++++----- routing/http/types/json/responses.go | 12 ++ 4 files changed, 200 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3d148b03c..ecf1ffe2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ The following emojis are used to highlight certain changes: ### Added +* `routing/http/server` now adds `Cache-Control` HTTP header to GET requests: 15 seconds for empty responses, or 5 minutes for responses with providers. + ### Changed ### Removed diff --git a/routing/http/server/server.go b/routing/http/server/server.go index df93c57fd..b00f22012 100644 --- a/routing/http/server/server.go +++ b/routing/http/server/server.go @@ -9,7 +9,6 @@ import ( "io" "mime" "net/http" - "strconv" "strings" "time" @@ -402,15 +401,26 @@ func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) { return } + var remainingValidity int + // Include 'Expires' header with time when signature expiration happens + if validityType, err := record.ValidityType(); err == nil && validityType == ipns.ValidityEOL { + if validity, err := record.Validity(); err == nil { + w.Header().Set("Expires", validity.UTC().Format(http.TimeFormat)) + remainingValidity = int(time.Until(validity).Seconds()) + } + } else { + remainingValidity = int(ipns.DefaultRecordLifetime.Seconds()) + } if ttl, err := record.TTL(); err == nil { - w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(ttl.Seconds()))) + setCacheControl(w, int(ttl.Seconds()), remainingValidity) } else { - w.Header().Set("Cache-Control", "max-age=60") + setCacheControl(w, int(ipns.DefaultRecordTTL.Seconds()), remainingValidity) } + w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat)) - recordEtag := strconv.FormatUint(xxhash.Sum64(rawRecord), 32) - w.Header().Set("Etag", recordEtag) + w.Header().Set("Etag", fmt.Sprintf(`"%x"`, xxhash.Sum64(rawRecord))) w.Header().Set("Content-Type", mediaTypeIPNSRecord) + w.Header().Add("Vary", "Accept") w.Write(rawRecord) } @@ -462,8 +472,30 @@ func (s *server) PutIPNS(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func writeJSONResult(w http.ResponseWriter, method string, val any) { +var ( + // Rule-of-thumb Cache-Control policy is to work well with caching proxies and load balancers. + // If there are any results, cache on the client for longer, and hint any in-between caches to + // serve cached result and upddate cache in background as long we have + // result that is within Amino DHT expiration window + maxAgeWithResults = int((5 * time.Minute).Seconds()) // cache >0 results for longer + maxAgeWithoutResults = int((15 * time.Second).Seconds()) // cache no results briefly + maxStale = int((48 * time.Hour).Seconds()) // allow stale results as long within Amino DHT Expiration window +) + +func setCacheControl(w http.ResponseWriter, maxAge int, stale int) { + w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, stale-while-revalidate=%d, stale-if-error=%d", maxAge, stale, stale)) +} + +func writeJSONResult(w http.ResponseWriter, method string, val interface{ Length() int }) { w.Header().Add("Content-Type", mediaTypeJSON) + w.Header().Add("Vary", "Accept") + + if val.Length() > 0 { + setCacheControl(w, maxAgeWithResults, maxStale) + } else { + setCacheControl(w, maxAgeWithoutResults, maxStale) + } + w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat)) // keep the marshaling separate from the writing, so we can distinguish bugs (which surface as 500) // from transient network issues (which surface as transport errors) @@ -500,14 +532,17 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result defer resultIter.Close() w.Header().Set("Content-Type", mediaTypeNDJSON) - w.WriteHeader(http.StatusOK) + w.Header().Add("Vary", "Accept") + w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat)) + hasResults := false for resultIter.Next() { res := resultIter.Val() if res.Err != nil { logger.Errorw("ndjson iterator error", "Error", res.Err) return } + // don't use an encoder because we can't easily differentiate writer errors from encoding errors b, err := drjson.MarshalJSONBytes(res.Val) if err != nil { @@ -515,6 +550,12 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result return } + if !hasResults { + hasResults = true + // There's results, cache useful result for longer + setCacheControl(w, maxAgeWithResults, maxStale) + } + _, err = w.Write(b) if err != nil { logger.Warn("ndjson write error", "Error", err) @@ -531,4 +572,9 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result f.Flush() } } + + if !hasResults { + // There weren't results, cache for shorter + setCacheControl(w, maxAgeWithoutResults, maxStale) + } } diff --git a/routing/http/server/server_test.go b/routing/http/server/server_test.go index c767d8f2e..ea827000f 100644 --- a/routing/http/server/server_test.go +++ b/routing/http/server/server_test.go @@ -7,6 +7,8 @@ import ( "io" "net/http" "net/http/httptest" + "regexp" + "strconv" "testing" "time" @@ -47,6 +49,7 @@ func TestHeaders(t *testing.T) { require.Equal(t, 200, resp.StatusCode) header := resp.Header.Get("Content-Type") require.Equal(t, mediaTypeJSON, header) + require.Equal(t, "Accept", resp.Header.Get("Vary")) resp, err = http.Get(serverAddr + "/routing/v1/providers/" + "BAD_CID") require.NoError(t, err) @@ -66,6 +69,13 @@ func makePeerID(t *testing.T) (crypto.PrivKey, peer.ID) { return sk, pid } +func requireCloseToNow(t *testing.T, lastModified string) { + // inspecting fields like 'Last-Modified' is prone to one-off errors, we test with 1m buffer + lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified) + require.NoError(t, err) + require.WithinDuration(t, time.Now(), lastModifiedTime, 1*time.Minute) +} + func TestProviders(t *testing.T) { pidStr := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn" pid2Str := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz" @@ -79,25 +89,31 @@ func TestProviders(t *testing.T) { cid, err := cid.Decode(cidStr) require.NoError(t, err) - runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) { + runTest := func(t *testing.T, contentType string, empty bool, expectedStream bool, expectedBody string) { t.Parallel() - results := iter.FromSlice([]iter.Result[types.Record]{ - {Val: &types.PeerRecord{ - Schema: types.SchemaPeer, - ID: &pid, - Protocols: []string{"transport-bitswap"}, - Addrs: []types.Multiaddr{}, - }}, - //lint:ignore SA1019 // ignore staticcheck - {Val: &types.BitswapRecord{ + var results *iter.SliceIter[iter.Result[types.Record]] + + if empty { + results = iter.FromSlice([]iter.Result[types.Record]{}) + } else { + results = iter.FromSlice([]iter.Result[types.Record]{ + {Val: &types.PeerRecord{ + Schema: types.SchemaPeer, + ID: &pid, + Protocols: []string{"transport-bitswap"}, + Addrs: []types.Multiaddr{}, + }}, //lint:ignore SA1019 // ignore staticcheck - Schema: types.SchemaBitswap, - ID: &pid2, - Protocol: "transport-bitswap", - Addrs: []types.Multiaddr{}, - }}}, - ) + {Val: &types.BitswapRecord{ + //lint:ignore SA1019 // ignore staticcheck + Schema: types.SchemaBitswap, + ID: &pid2, + Protocol: "transport-bitswap", + Addrs: []types.Multiaddr{}, + }}}, + ) + } router := &mockContentRouter{} server := httptest.NewServer(Handler(router)) @@ -117,8 +133,16 @@ func TestProviders(t *testing.T) { resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, 200, resp.StatusCode) - header := resp.Header.Get("Content-Type") - require.Equal(t, contentType, header) + + require.Equal(t, contentType, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + + if empty { + require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + } else { + require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + } + requireCloseToNow(t, resp.Header.Get("Last-Modified")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -127,11 +151,19 @@ func TestProviders(t *testing.T) { } t.Run("JSON Response", func(t *testing.T) { - runTest(t, mediaTypeJSON, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`) + runTest(t, mediaTypeJSON, false, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`) + }) + + t.Run("Empty JSON Response", func(t *testing.T) { + runTest(t, mediaTypeJSON, true, false, `{"Providers":null}`) }) t.Run("NDJSON Response", func(t *testing.T) { - runTest(t, mediaTypeNDJSON, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n") + runTest(t, mediaTypeNDJSON, false, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n") + }) + + t.Run("Empty NDJSON Response", func(t *testing.T) { + runTest(t, mediaTypeNDJSON, true, true, "") }) } @@ -155,7 +187,26 @@ func TestPeers(t *testing.T) { require.Equal(t, 400, resp.StatusCode) }) - t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (JSON)", func(t *testing.T) { + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, JSON)", func(t *testing.T) { + t.Parallel() + + _, pid := makePeerID(t) + results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) + + router := &mockContentRouter{} + router.On("FindPeers", mock.Anything, pid, 20).Return(results, nil) + + resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String()) + require.Equal(t, 200, resp.StatusCode) + + require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) + }) + + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (JSON)", func(t *testing.T) { t.Parallel() _, pid := makePeerID(t) @@ -181,8 +232,11 @@ func TestPeers(t *testing.T) { resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID) require.Equal(t, 200, resp.StatusCode) - header := resp.Header.Get("Content-Type") - require.Equal(t, mediaTypeJSON, header) + require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -191,7 +245,26 @@ func TestPeers(t *testing.T) { require.Equal(t, expectedBody, string(body)) }) - t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (NDJSON)", func(t *testing.T) { + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, NDJSON)", func(t *testing.T) { + t.Parallel() + + _, pid := makePeerID(t) + results := iter.FromSlice([]iter.Result[*types.PeerRecord]{}) + + router := &mockContentRouter{} + router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil) + + resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String()) + require.Equal(t, 200, resp.StatusCode) + + require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) + }) + + t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (NDJSON)", func(t *testing.T) { t.Parallel() _, pid := makePeerID(t) @@ -217,8 +290,9 @@ func TestPeers(t *testing.T) { resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID) require.Equal(t, 200, resp.StatusCode) - header := resp.Header.Get("Content-Type") - require.Equal(t, mediaTypeNDJSON, header) + require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) + require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -254,6 +328,7 @@ func TestPeers(t *testing.T) { require.Equal(t, 200, resp.StatusCode) header := resp.Header.Get("Content-Type") + require.Equal(t, "Accept", resp.Header.Get("Vary")) require.Equal(t, mediaTypeJSON, header) body, err := io.ReadAll(resp.Body) @@ -290,6 +365,7 @@ func TestPeers(t *testing.T) { require.Equal(t, 200, resp.StatusCode) header := resp.Header.Get("Content-Type") + require.Equal(t, "Accept", resp.Header.Get("Vary")) require.Equal(t, mediaTypeNDJSON, header) body, err := io.ReadAll(resp.Body) @@ -306,10 +382,8 @@ func makeName(t *testing.T) (crypto.PrivKey, ipns.Name) { return sk, ipns.NameFromPeer(pid) } -func makeIPNSRecord(t *testing.T, cid cid.Cid, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) { +func makeIPNSRecord(t *testing.T, cid cid.Cid, eol time.Time, ttl time.Duration, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) { path := path.FromCid(cid) - eol := time.Now().Add(time.Hour * 48) - ttl := time.Second * 20 record, err := ipns.NewRecord(sk, path, 1, eol, ttl, opts...) require.NoError(t, err) @@ -339,7 +413,18 @@ func TestIPNS(t *testing.T) { runWithRecordOptions := func(t *testing.T, opts ...ipns.Option) { sk, name1 := makeName(t) - record1, rawRecord1 := makeIPNSRecord(t, cid1, sk) + now := time.Now() + eol := now.Add(24 * time.Hour * 7) // record valid for a week + ttl := 42 * time.Second // distinct TTL + record1, rawRecord1 := makeIPNSRecord(t, cid1, eol, ttl, sk) + + stringToDuration := func(s string) time.Duration { + seconds, err := strconv.Atoi(s) + if err != nil { + return 0 + } + return time.Duration(seconds) * time.Second + } _, name2 := makeName(t) @@ -355,8 +440,25 @@ func TestIPNS(t *testing.T) { resp := makeRequest(t, router, "/routing/v1/ipns/"+name1.String()) require.Equal(t, 200, resp.StatusCode) require.Equal(t, mediaTypeIPNSRecord, resp.Header.Get("Content-Type")) + require.Equal(t, "Accept", resp.Header.Get("Vary")) require.NotEmpty(t, resp.Header.Get("Etag")) - require.Equal(t, "max-age=20", resp.Header.Get("Cache-Control")) + + requireCloseToNow(t, resp.Header.Get("Last-Modified")) + + require.Contains(t, resp.Header.Get("Cache-Control"), "public, max-age=42") + + // expected "stale" values are int(time.Until(eol).Seconds()) + // but running test on slow machine may be off by a few seconds + // and we need to assert with some room for drift (1 minute just to not break any CI) + re := regexp.MustCompile(`(?:^|,\s*)(max-age|stale-while-revalidate|stale-if-error)=(\d+)`) + matches := re.FindAllStringSubmatch(resp.Header.Get("Cache-Control"), -1) + staleWhileRevalidate := stringToDuration(matches[1][2]) + staleWhileError := stringToDuration(matches[2][2]) + require.WithinDuration(t, eol, time.Now().Add(staleWhileRevalidate), 1*time.Minute) + require.WithinDuration(t, eol, time.Now().Add(staleWhileError), 1*time.Minute) + + // 'Expires' on IPNS result is expected to match EOL of IPNS Record with ValidityType=0 + require.Equal(t, eol.UTC().Format(http.TimeFormat), resp.Header.Get("Expires")) body, err := io.ReadAll(resp.Body) require.NoError(t, err) diff --git a/routing/http/types/json/responses.go b/routing/http/types/json/responses.go index cc687df48..d8f659ac5 100644 --- a/routing/http/types/json/responses.go +++ b/routing/http/types/json/responses.go @@ -11,11 +11,19 @@ type ProvidersResponse struct { Providers RecordsArray } +func (r ProvidersResponse) Length() int { + return len(r.Providers) +} + // PeersResponse is the result of a GET Peers request. type PeersResponse struct { Peers []*types.PeerRecord } +func (r PeersResponse) Length() int { + return len(r.Peers) +} + // RecordsArray is an array of [types.Record] type RecordsArray []types.Record @@ -65,6 +73,10 @@ type WriteProvidersResponse struct { ProvideResults []types.Record } +func (r WriteProvidersResponse) Length() int { + return len(r.ProvideResults) +} + func (r *WriteProvidersResponse) UnmarshalJSON(b []byte) error { var tempWPR struct{ ProvideResults []json.RawMessage } err := json.Unmarshal(b, &tempWPR)