Skip to content

Commit 4ca8d75

Browse files
committed
Resolves gofiber#3072
Signed-off-by: brunodmartins <[email protected]>
1 parent 9ea7651 commit 4ca8d75

File tree

5 files changed

+152
-41
lines changed

5 files changed

+152
-41
lines changed

Diff for: middleware/cache/cache.go

+42-38
Original file line numberDiff line numberDiff line change
@@ -117,46 +117,49 @@ func New(config ...Config) fiber.Handler {
117117
// Get timestamp
118118
ts := atomic.LoadUint64(&timestamp)
119119

120-
// Invalidate cache if requested
121-
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) && e != nil {
122-
e.exp = ts - 1
123-
}
124-
125-
// Check if entry is expired
126-
if e.exp != 0 && ts >= e.exp {
127-
deleteKey(key)
128-
if cfg.MaxBytes > 0 {
129-
_, size := heap.remove(e.heapidx)
130-
storedBytes -= size
131-
}
132-
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
133-
// Separate body value to avoid msgp serialization
134-
// We can store raw bytes with Storage 👍
135-
if cfg.Storage != nil {
136-
e.body = manager.getRaw(key + "_body")
137-
}
138-
// Set response headers from cache
139-
c.Response().SetBodyRaw(e.body)
140-
c.Response().SetStatusCode(e.status)
141-
c.Response().Header.SetContentTypeBytes(e.ctype)
142-
if len(e.cencoding) > 0 {
143-
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
144-
}
145-
for k, v := range e.headers {
146-
c.Response().Header.SetBytesV(k, v)
120+
// Cache Entry not found
121+
if e != nil {
122+
// Invalidate cache if requested
123+
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) {
124+
e.exp = ts - 1
147125
}
148-
// Set Cache-Control header if enabled
149-
if cfg.CacheControl {
150-
maxAge := strconv.FormatUint(e.exp-ts, 10)
151-
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
152-
}
153-
154-
c.Set(cfg.CacheHeader, cacheHit)
155-
156-
mux.Unlock()
157126

158-
// Return response
159-
return nil
127+
// Check if entry is expired
128+
if e.exp != 0 && ts >= e.exp {
129+
deleteKey(key)
130+
if cfg.MaxBytes > 0 {
131+
_, size := heap.remove(e.heapidx)
132+
storedBytes -= size
133+
}
134+
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
135+
// Separate body value to avoid msgp serialization
136+
// We can store raw bytes with Storage 👍
137+
if cfg.Storage != nil {
138+
e.body = manager.getRaw(key + "_body")
139+
}
140+
// Set response headers from cache
141+
c.Response().SetBodyRaw(e.body)
142+
c.Response().SetStatusCode(e.status)
143+
c.Response().Header.SetContentTypeBytes(e.ctype)
144+
if len(e.cencoding) > 0 {
145+
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
146+
}
147+
for k, v := range e.headers {
148+
c.Response().Header.SetBytesV(k, v)
149+
}
150+
// Set Cache-Control header if enabled
151+
if cfg.CacheControl {
152+
maxAge := strconv.FormatUint(e.exp-ts, 10)
153+
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
154+
}
155+
156+
c.Set(cfg.CacheHeader, cacheHit)
157+
158+
mux.Unlock()
159+
160+
// Return response
161+
return nil
162+
}
160163
}
161164

162165
// make sure we're not blocking concurrent requests - do unlock
@@ -193,6 +196,7 @@ func New(config ...Config) fiber.Handler {
193196
}
194197
}
195198

199+
e = manager.acquire()
196200
// Cache response
197201
e.body = utils.CopyBytes(c.Response().Body())
198202
e.status = c.Response().StatusCode()

Diff for: middleware/cache/cache_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,91 @@ func Test_CacheInvalidation(t *testing.T) {
731731
require.NotEqual(t, body, bodyInvalidate)
732732
}
733733

734+
func Test_CacheInvalidation_noCacheEntry(t *testing.T) {
735+
t.Parallel()
736+
t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) {
737+
t.Parallel()
738+
app := fiber.New()
739+
cacheInvalidatorExecuted := false
740+
app.Use(New(Config{
741+
CacheControl: true,
742+
CacheInvalidator: func(c fiber.Ctx) bool {
743+
cacheInvalidatorExecuted = true
744+
return fiber.Query[bool](c, "invalidate")
745+
},
746+
MaxBytes: 10 * 1024 * 1024,
747+
}))
748+
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
749+
require.NoError(t, err)
750+
require.False(t, cacheInvalidatorExecuted)
751+
})
752+
}
753+
754+
func Test_CacheInvalidation_removeFromHeap(t *testing.T) {
755+
t.Parallel()
756+
t.Run("Invalidate and remove from the heap", func(t *testing.T) {
757+
t.Parallel()
758+
app := fiber.New()
759+
app.Use(New(Config{
760+
CacheControl: true,
761+
CacheInvalidator: func(c fiber.Ctx) bool {
762+
return fiber.Query[bool](c, "invalidate")
763+
},
764+
MaxBytes: 10 * 1024 * 1024,
765+
}))
766+
767+
app.Get("/", func(c fiber.Ctx) error {
768+
return c.SendString(time.Now().String())
769+
})
770+
771+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
772+
require.NoError(t, err)
773+
body, err := io.ReadAll(resp.Body)
774+
require.NoError(t, err)
775+
776+
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
777+
require.NoError(t, err)
778+
bodyCached, err := io.ReadAll(respCached.Body)
779+
require.NoError(t, err)
780+
require.True(t, bytes.Equal(body, bodyCached))
781+
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
782+
783+
respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
784+
require.NoError(t, err)
785+
bodyInvalidate, err := io.ReadAll(respInvalidate.Body)
786+
require.NoError(t, err)
787+
require.NotEqual(t, body, bodyInvalidate)
788+
})
789+
}
790+
791+
func Test_CacheStorage_CustomHeaders(t *testing.T) {
792+
t.Parallel()
793+
app := fiber.New()
794+
app.Use(New(Config{
795+
CacheControl: true,
796+
Storage: memory.New(),
797+
MaxBytes: 10 * 1024 * 1024,
798+
}))
799+
800+
app.Get("/", func(c fiber.Ctx) error {
801+
c.Response().Header.Set("Content-Type", "text/xml")
802+
c.Response().Header.Set("Content-Encoding", "utf8")
803+
return c.Send([]byte("<xml><value>Test</value></xml>"))
804+
})
805+
806+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
807+
require.NoError(t, err)
808+
body, err := io.ReadAll(resp.Body)
809+
require.NoError(t, err)
810+
811+
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
812+
require.NoError(t, err)
813+
bodyCached, err := io.ReadAll(respCached.Body)
814+
require.NoError(t, err)
815+
require.True(t, bytes.Equal(body, bodyCached))
816+
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
817+
}
818+
734819
// Because time points are updated once every X milliseconds, entries in tests can often have
735820
// equal expiration times and thus be in an random order. This closure hands out increasing
736821
// time intervals to maintain strong ascending order of expiration

Diff for: middleware/cache/heap.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type heapEntry struct {
1515
// elements in constant time. It does so by handing out special indices
1616
// and tracking entry movement.
1717
//
18-
// indexdedHeap is used for quickly finding entries with the lowest
18+
// indexedHeap is used for quickly finding entries with the lowest
1919
// expiration timestamp and deleting arbitrary entries.
2020
type indexedHeap struct {
2121
// Slice the heap is built on

Diff for: middleware/cache/manager.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ func (m *manager) get(key string) *item {
8383
return it
8484
}
8585
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
86-
it = m.acquire()
87-
return it
86+
return nil
8887
}
8988
return it
9089
}

Diff for: middleware/cache/manager_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package cache
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/gofiber/utils/v2"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func Test_manager_get(t *testing.T) {
12+
cacheManager := newManager(nil)
13+
t.Run("Item not found in cache", func(t *testing.T) {
14+
assert.Nil(t, cacheManager.get(utils.UUID()))
15+
})
16+
t.Run("Item found in cache", func(t *testing.T) {
17+
id := utils.UUID()
18+
cacheItem := cacheManager.acquire()
19+
cacheItem.body = []byte("test-body")
20+
cacheManager.set(id, cacheItem, 10*time.Second)
21+
assert.NotNil(t, cacheManager.get(id))
22+
})
23+
}

0 commit comments

Comments
 (0)