Skip to content

Commit

Permalink
routing/http/server: add cache control (#584)
Browse files Browse the repository at this point in the history
  • Loading branch information
hacdias authored Mar 7, 2024
1 parent 97e347e commit b6b7771
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 38 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The following emojis are used to highlight certain changes:

### Added

* `routing/http/server` now adds `Cache-Control` HTTP header to GET requests: 15 seconds for empty responses, or 5 minutes for responses with providers.

### Changed

### Removed
Expand Down
60 changes: 53 additions & 7 deletions routing/http/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"mime"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -402,15 +401,26 @@ func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) {
return
}

var remainingValidity int
// Include 'Expires' header with time when signature expiration happens
if validityType, err := record.ValidityType(); err == nil && validityType == ipns.ValidityEOL {
if validity, err := record.Validity(); err == nil {
w.Header().Set("Expires", validity.UTC().Format(http.TimeFormat))
remainingValidity = int(time.Until(validity).Seconds())
}
} else {
remainingValidity = int(ipns.DefaultRecordLifetime.Seconds())
}
if ttl, err := record.TTL(); err == nil {
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(ttl.Seconds())))
setCacheControl(w, int(ttl.Seconds()), remainingValidity)
} else {
w.Header().Set("Cache-Control", "max-age=60")
setCacheControl(w, int(ipns.DefaultRecordTTL.Seconds()), remainingValidity)
}
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))

recordEtag := strconv.FormatUint(xxhash.Sum64(rawRecord), 32)
w.Header().Set("Etag", recordEtag)
w.Header().Set("Etag", fmt.Sprintf(`"%x"`, xxhash.Sum64(rawRecord)))
w.Header().Set("Content-Type", mediaTypeIPNSRecord)
w.Header().Add("Vary", "Accept")
w.Write(rawRecord)
}

Expand Down Expand Up @@ -462,8 +472,30 @@ func (s *server) PutIPNS(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

func writeJSONResult(w http.ResponseWriter, method string, val any) {
var (
// Rule-of-thumb Cache-Control policy is to work well with caching proxies and load balancers.
// If there are any results, cache on the client for longer, and hint any in-between caches to
// serve cached result and upddate cache in background as long we have
// result that is within Amino DHT expiration window
maxAgeWithResults = int((5 * time.Minute).Seconds()) // cache >0 results for longer
maxAgeWithoutResults = int((15 * time.Second).Seconds()) // cache no results briefly
maxStale = int((48 * time.Hour).Seconds()) // allow stale results as long within Amino DHT Expiration window
)

func setCacheControl(w http.ResponseWriter, maxAge int, stale int) {
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, stale-while-revalidate=%d, stale-if-error=%d", maxAge, stale, stale))
}

func writeJSONResult(w http.ResponseWriter, method string, val interface{ Length() int }) {
w.Header().Add("Content-Type", mediaTypeJSON)
w.Header().Add("Vary", "Accept")

if val.Length() > 0 {
setCacheControl(w, maxAgeWithResults, maxStale)
} else {
setCacheControl(w, maxAgeWithoutResults, maxStale)
}
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))

// keep the marshaling separate from the writing, so we can distinguish bugs (which surface as 500)
// from transient network issues (which surface as transport errors)
Expand Down Expand Up @@ -500,21 +532,30 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result
defer resultIter.Close()

w.Header().Set("Content-Type", mediaTypeNDJSON)
w.WriteHeader(http.StatusOK)
w.Header().Add("Vary", "Accept")
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))

hasResults := false
for resultIter.Next() {
res := resultIter.Val()
if res.Err != nil {
logger.Errorw("ndjson iterator error", "Error", res.Err)
return
}

// don't use an encoder because we can't easily differentiate writer errors from encoding errors
b, err := drjson.MarshalJSONBytes(res.Val)
if err != nil {
logger.Errorw("ndjson marshal error", "Error", err)
return
}

if !hasResults {
hasResults = true
// There's results, cache useful result for longer
setCacheControl(w, maxAgeWithResults, maxStale)
}

_, err = w.Write(b)
if err != nil {
logger.Warn("ndjson write error", "Error", err)
Expand All @@ -531,4 +572,9 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result
f.Flush()
}
}

if !hasResults {
// There weren't results, cache for shorter
setCacheControl(w, maxAgeWithoutResults, maxStale)
}
}
164 changes: 133 additions & 31 deletions routing/http/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"io"
"net/http"
"net/http/httptest"
"regexp"
"strconv"
"testing"
"time"

Expand Down Expand Up @@ -47,6 +49,7 @@ func TestHeaders(t *testing.T) {
require.Equal(t, 200, resp.StatusCode)
header := resp.Header.Get("Content-Type")
require.Equal(t, mediaTypeJSON, header)
require.Equal(t, "Accept", resp.Header.Get("Vary"))

resp, err = http.Get(serverAddr + "/routing/v1/providers/" + "BAD_CID")
require.NoError(t, err)
Expand All @@ -66,6 +69,13 @@ func makePeerID(t *testing.T) (crypto.PrivKey, peer.ID) {
return sk, pid
}

func requireCloseToNow(t *testing.T, lastModified string) {
// inspecting fields like 'Last-Modified' is prone to one-off errors, we test with 1m buffer
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
require.NoError(t, err)
require.WithinDuration(t, time.Now(), lastModifiedTime, 1*time.Minute)
}

func TestProviders(t *testing.T) {
pidStr := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn"
pid2Str := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"
Expand All @@ -79,25 +89,31 @@ func TestProviders(t *testing.T) {
cid, err := cid.Decode(cidStr)
require.NoError(t, err)

runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) {
runTest := func(t *testing.T, contentType string, empty bool, expectedStream bool, expectedBody string) {
t.Parallel()

results := iter.FromSlice([]iter.Result[types.Record]{
{Val: &types.PeerRecord{
Schema: types.SchemaPeer,
ID: &pid,
Protocols: []string{"transport-bitswap"},
Addrs: []types.Multiaddr{},
}},
//lint:ignore SA1019 // ignore staticcheck
{Val: &types.BitswapRecord{
var results *iter.SliceIter[iter.Result[types.Record]]

if empty {
results = iter.FromSlice([]iter.Result[types.Record]{})
} else {
results = iter.FromSlice([]iter.Result[types.Record]{
{Val: &types.PeerRecord{
Schema: types.SchemaPeer,
ID: &pid,
Protocols: []string{"transport-bitswap"},
Addrs: []types.Multiaddr{},
}},
//lint:ignore SA1019 // ignore staticcheck
Schema: types.SchemaBitswap,
ID: &pid2,
Protocol: "transport-bitswap",
Addrs: []types.Multiaddr{},
}}},
)
{Val: &types.BitswapRecord{
//lint:ignore SA1019 // ignore staticcheck
Schema: types.SchemaBitswap,
ID: &pid2,
Protocol: "transport-bitswap",
Addrs: []types.Multiaddr{},
}}},
)
}

router := &mockContentRouter{}
server := httptest.NewServer(Handler(router))
Expand All @@ -117,8 +133,16 @@ func TestProviders(t *testing.T) {
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
require.Equal(t, 200, resp.StatusCode)
header := resp.Header.Get("Content-Type")
require.Equal(t, contentType, header)

require.Equal(t, contentType, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))

if empty {
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
} else {
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
}
requireCloseToNow(t, resp.Header.Get("Last-Modified"))

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
Expand All @@ -127,11 +151,19 @@ func TestProviders(t *testing.T) {
}

t.Run("JSON Response", func(t *testing.T) {
runTest(t, mediaTypeJSON, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`)
runTest(t, mediaTypeJSON, false, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`)
})

t.Run("Empty JSON Response", func(t *testing.T) {
runTest(t, mediaTypeJSON, true, false, `{"Providers":null}`)
})

t.Run("NDJSON Response", func(t *testing.T) {
runTest(t, mediaTypeNDJSON, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n")
runTest(t, mediaTypeNDJSON, false, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n")
})

t.Run("Empty NDJSON Response", func(t *testing.T) {
runTest(t, mediaTypeNDJSON, true, true, "")
})
}

Expand All @@ -155,7 +187,26 @@ func TestPeers(t *testing.T) {
require.Equal(t, 400, resp.StatusCode)
})

t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (JSON)", func(t *testing.T) {
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, JSON)", func(t *testing.T) {
t.Parallel()

_, pid := makePeerID(t)
results := iter.FromSlice([]iter.Result[*types.PeerRecord]{})

router := &mockContentRouter{}
router.On("FindPeers", mock.Anything, pid, 20).Return(results, nil)

resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String())
require.Equal(t, 200, resp.StatusCode)

require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))

requireCloseToNow(t, resp.Header.Get("Last-Modified"))
})

t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (JSON)", func(t *testing.T) {
t.Parallel()

_, pid := makePeerID(t)
Expand All @@ -181,8 +232,11 @@ func TestPeers(t *testing.T) {
resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID)
require.Equal(t, 200, resp.StatusCode)

header := resp.Header.Get("Content-Type")
require.Equal(t, mediaTypeJSON, header)
require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))

requireCloseToNow(t, resp.Header.Get("Last-Modified"))

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
Expand All @@ -191,7 +245,26 @@ func TestPeers(t *testing.T) {
require.Equal(t, expectedBody, string(body))
})

t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (NDJSON)", func(t *testing.T) {
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, NDJSON)", func(t *testing.T) {
t.Parallel()

_, pid := makePeerID(t)
results := iter.FromSlice([]iter.Result[*types.PeerRecord]{})

router := &mockContentRouter{}
router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil)

resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String())
require.Equal(t, 200, resp.StatusCode)

require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))

requireCloseToNow(t, resp.Header.Get("Last-Modified"))
})

t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (NDJSON)", func(t *testing.T) {
t.Parallel()

_, pid := makePeerID(t)
Expand All @@ -217,8 +290,9 @@ func TestPeers(t *testing.T) {
resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID)
require.Equal(t, 200, resp.StatusCode)

header := resp.Header.Get("Content-Type")
require.Equal(t, mediaTypeNDJSON, header)
require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down Expand Up @@ -254,6 +328,7 @@ func TestPeers(t *testing.T) {
require.Equal(t, 200, resp.StatusCode)

header := resp.Header.Get("Content-Type")
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, mediaTypeJSON, header)

body, err := io.ReadAll(resp.Body)
Expand Down Expand Up @@ -290,6 +365,7 @@ func TestPeers(t *testing.T) {
require.Equal(t, 200, resp.StatusCode)

header := resp.Header.Get("Content-Type")
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.Equal(t, mediaTypeNDJSON, header)

body, err := io.ReadAll(resp.Body)
Expand All @@ -306,10 +382,8 @@ func makeName(t *testing.T) (crypto.PrivKey, ipns.Name) {
return sk, ipns.NameFromPeer(pid)
}

func makeIPNSRecord(t *testing.T, cid cid.Cid, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) {
func makeIPNSRecord(t *testing.T, cid cid.Cid, eol time.Time, ttl time.Duration, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) {
path := path.FromCid(cid)
eol := time.Now().Add(time.Hour * 48)
ttl := time.Second * 20

record, err := ipns.NewRecord(sk, path, 1, eol, ttl, opts...)
require.NoError(t, err)
Expand Down Expand Up @@ -339,7 +413,18 @@ func TestIPNS(t *testing.T) {

runWithRecordOptions := func(t *testing.T, opts ...ipns.Option) {
sk, name1 := makeName(t)
record1, rawRecord1 := makeIPNSRecord(t, cid1, sk)
now := time.Now()
eol := now.Add(24 * time.Hour * 7) // record valid for a week
ttl := 42 * time.Second // distinct TTL
record1, rawRecord1 := makeIPNSRecord(t, cid1, eol, ttl, sk)

stringToDuration := func(s string) time.Duration {
seconds, err := strconv.Atoi(s)
if err != nil {
return 0
}
return time.Duration(seconds) * time.Second
}

_, name2 := makeName(t)

Expand All @@ -355,8 +440,25 @@ func TestIPNS(t *testing.T) {
resp := makeRequest(t, router, "/routing/v1/ipns/"+name1.String())
require.Equal(t, 200, resp.StatusCode)
require.Equal(t, mediaTypeIPNSRecord, resp.Header.Get("Content-Type"))
require.Equal(t, "Accept", resp.Header.Get("Vary"))
require.NotEmpty(t, resp.Header.Get("Etag"))
require.Equal(t, "max-age=20", resp.Header.Get("Cache-Control"))

requireCloseToNow(t, resp.Header.Get("Last-Modified"))

require.Contains(t, resp.Header.Get("Cache-Control"), "public, max-age=42")

// expected "stale" values are int(time.Until(eol).Seconds())
// but running test on slow machine may be off by a few seconds
// and we need to assert with some room for drift (1 minute just to not break any CI)
re := regexp.MustCompile(`(?:^|,\s*)(max-age|stale-while-revalidate|stale-if-error)=(\d+)`)
matches := re.FindAllStringSubmatch(resp.Header.Get("Cache-Control"), -1)
staleWhileRevalidate := stringToDuration(matches[1][2])
staleWhileError := stringToDuration(matches[2][2])
require.WithinDuration(t, eol, time.Now().Add(staleWhileRevalidate), 1*time.Minute)
require.WithinDuration(t, eol, time.Now().Add(staleWhileError), 1*time.Minute)

// 'Expires' on IPNS result is expected to match EOL of IPNS Record with ValidityType=0
require.Equal(t, eol.UTC().Format(http.TimeFormat), resp.Header.Get("Expires"))

body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
Expand Down
Loading

0 comments on commit b6b7771

Please sign in to comment.