Skip to content

Commit d4050a2

Browse files
committed
Resolves gofiber#3072
Signed-off-by: brunodmartins <[email protected]>
1 parent 9ea7651 commit d4050a2

File tree

5 files changed

+187
-55
lines changed

5 files changed

+187
-55
lines changed

Diff for: middleware/cache/cache.go

+42-38
Original file line numberDiff line numberDiff line change
@@ -117,46 +117,49 @@ func New(config ...Config) fiber.Handler {
117117
// Get timestamp
118118
ts := atomic.LoadUint64(&timestamp)
119119

120-
// Invalidate cache if requested
121-
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) && e != nil {
122-
e.exp = ts - 1
123-
}
124-
125-
// Check if entry is expired
126-
if e.exp != 0 && ts >= e.exp {
127-
deleteKey(key)
128-
if cfg.MaxBytes > 0 {
129-
_, size := heap.remove(e.heapidx)
130-
storedBytes -= size
131-
}
132-
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
133-
// Separate body value to avoid msgp serialization
134-
// We can store raw bytes with Storage 👍
135-
if cfg.Storage != nil {
136-
e.body = manager.getRaw(key + "_body")
137-
}
138-
// Set response headers from cache
139-
c.Response().SetBodyRaw(e.body)
140-
c.Response().SetStatusCode(e.status)
141-
c.Response().Header.SetContentTypeBytes(e.ctype)
142-
if len(e.cencoding) > 0 {
143-
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
144-
}
145-
for k, v := range e.headers {
146-
c.Response().Header.SetBytesV(k, v)
120+
// Cache Entry not found
121+
if e != nil {
122+
// Invalidate cache if requested
123+
if cfg.CacheInvalidator != nil && cfg.CacheInvalidator(c) {
124+
e.exp = ts - 1
147125
}
148-
// Set Cache-Control header if enabled
149-
if cfg.CacheControl {
150-
maxAge := strconv.FormatUint(e.exp-ts, 10)
151-
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
152-
}
153-
154-
c.Set(cfg.CacheHeader, cacheHit)
155-
156-
mux.Unlock()
157126

158-
// Return response
159-
return nil
127+
// Check if entry is expired
128+
if e.exp != 0 && ts >= e.exp {
129+
deleteKey(key)
130+
if cfg.MaxBytes > 0 {
131+
_, size := heap.remove(e.heapidx)
132+
storedBytes -= size
133+
}
134+
} else if e.exp != 0 && !hasRequestDirective(c, noCache) {
135+
// Separate body value to avoid msgp serialization
136+
// We can store raw bytes with Storage 👍
137+
if cfg.Storage != nil {
138+
e.body = manager.getRaw(key + "_body")
139+
}
140+
// Set response headers from cache
141+
c.Response().SetBodyRaw(e.body)
142+
c.Response().SetStatusCode(e.status)
143+
c.Response().Header.SetContentTypeBytes(e.ctype)
144+
if len(e.cencoding) > 0 {
145+
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
146+
}
147+
for k, v := range e.headers {
148+
c.Response().Header.SetBytesV(k, v)
149+
}
150+
// Set Cache-Control header if enabled
151+
if cfg.CacheControl {
152+
maxAge := strconv.FormatUint(e.exp-ts, 10)
153+
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
154+
}
155+
156+
c.Set(cfg.CacheHeader, cacheHit)
157+
158+
mux.Unlock()
159+
160+
// Return response
161+
return nil
162+
}
160163
}
161164

162165
// make sure we're not blocking concurrent requests - do unlock
@@ -193,6 +196,7 @@ func New(config ...Config) fiber.Handler {
193196
}
194197
}
195198

199+
e = manager.acquire()
196200
// Cache response
197201
e.body = utils.CopyBytes(c.Response().Body())
198202
e.status = c.Response().StatusCode()

Diff for: middleware/cache/cache_test.go

+120-14
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ func Test_Cache_Expired(t *testing.T) {
4747
t.Parallel()
4848
app := fiber.New()
4949
app.Use(New(Config{Expiration: 2 * time.Second}))
50-
50+
count := 0
5151
app.Get("/", func(c fiber.Ctx) error {
52-
return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10))
52+
count++
53+
return c.SendString(strconv.Itoa(count))
5354
})
5455

5556
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -86,9 +87,10 @@ func Test_Cache(t *testing.T) {
8687
app := fiber.New()
8788
app.Use(New())
8889

90+
count := 0
8991
app.Get("/", func(c fiber.Ctx) error {
90-
now := strconv.FormatInt(time.Now().UnixNano(), 10)
91-
return c.SendString(now)
92+
count++
93+
return c.SendString(strconv.Itoa(count))
9294
})
9395

9496
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
@@ -305,9 +307,10 @@ func Test_Cache_Invalid_Expiration(t *testing.T) {
305307
cache := New(Config{Expiration: 0 * time.Second})
306308
app.Use(cache)
307309

310+
count := 0
308311
app.Get("/", func(c fiber.Ctx) error {
309-
now := strconv.FormatInt(time.Now().UnixNano(), 10)
310-
return c.SendString(now)
312+
count++
313+
return c.SendString(strconv.Itoa(count))
311314
})
312315

313316
req := httptest.NewRequest(fiber.MethodGet, "/", nil)
@@ -414,8 +417,10 @@ func Test_Cache_NothingToCache(t *testing.T) {
414417

415418
app.Use(New(Config{Expiration: -(time.Second * 1)}))
416419

420+
count := 0
417421
app.Get("/", func(c fiber.Ctx) error {
418-
return c.SendString(time.Now().String())
422+
count++
423+
return c.SendString(strconv.Itoa(count))
419424
})
420425

421426
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -447,12 +452,16 @@ func Test_Cache_CustomNext(t *testing.T) {
447452
CacheControl: true,
448453
}))
449454

455+
count := 0
450456
app.Get("/", func(c fiber.Ctx) error {
451-
return c.SendString(time.Now().String())
457+
count++
458+
return c.SendString(strconv.Itoa(count))
452459
})
453460

461+
errorCount := 0
454462
app.Get("/error", func(c fiber.Ctx) error {
455-
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
463+
errorCount++
464+
return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(errorCount))
456465
})
457466

458467
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -508,9 +517,11 @@ func Test_CustomExpiration(t *testing.T) {
508517
return time.Second * time.Duration(newCacheTime)
509518
}}))
510519

520+
count := 0
511521
app.Get("/", func(c fiber.Ctx) error {
522+
count++
512523
c.Response().Header.Add("Cache-Time", "1")
513-
return c.SendString(strconv.FormatInt(time.Now().UnixNano(), 10))
524+
return c.SendString(strconv.Itoa(count))
514525
})
515526

516527
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -588,8 +599,11 @@ func Test_CacheHeader(t *testing.T) {
588599
return c.SendString(fiber.Query[string](c, "cache"))
589600
})
590601

602+
count := 0
591603
app.Get("/error", func(c fiber.Ctx) error {
592-
return c.Status(fiber.StatusInternalServerError).SendString(time.Now().String())
604+
count++
605+
c.Response().Header.Add("Cache-Time", "1")
606+
return c.Status(fiber.StatusInternalServerError).SendString(strconv.Itoa(count))
593607
})
594608

595609
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -615,10 +629,13 @@ func Test_Cache_WithHead(t *testing.T) {
615629
app := fiber.New()
616630
app.Use(New())
617631

632+
count := 0
618633
handler := func(c fiber.Ctx) error {
619-
now := strconv.FormatInt(time.Now().UnixNano(), 10)
620-
return c.SendString(now)
634+
count++
635+
c.Response().Header.Add("Cache-Time", "1")
636+
return c.SendString(strconv.Itoa(count))
621637
}
638+
622639
app.Route("/").Get(handler).Head(handler)
623640

624641
req := httptest.NewRequest(fiber.MethodHead, "/", nil)
@@ -708,8 +725,10 @@ func Test_CacheInvalidation(t *testing.T) {
708725
},
709726
}))
710727

728+
count := 0
711729
app.Get("/", func(c fiber.Ctx) error {
712-
return c.SendString(time.Now().String())
730+
count++
731+
return c.SendString(strconv.Itoa(count))
713732
})
714733

715734
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
@@ -731,6 +750,93 @@ func Test_CacheInvalidation(t *testing.T) {
731750
require.NotEqual(t, body, bodyInvalidate)
732751
}
733752

753+
func Test_CacheInvalidation_noCacheEntry(t *testing.T) {
754+
t.Parallel()
755+
t.Run("Cache Invalidator should not be called if no cache entry exist ", func(t *testing.T) {
756+
t.Parallel()
757+
app := fiber.New()
758+
cacheInvalidatorExecuted := false
759+
app.Use(New(Config{
760+
CacheControl: true,
761+
CacheInvalidator: func(c fiber.Ctx) bool {
762+
cacheInvalidatorExecuted = true
763+
return fiber.Query[bool](c, "invalidate")
764+
},
765+
MaxBytes: 10 * 1024 * 1024,
766+
}))
767+
_, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
768+
require.NoError(t, err)
769+
require.False(t, cacheInvalidatorExecuted)
770+
})
771+
}
772+
773+
func Test_CacheInvalidation_removeFromHeap(t *testing.T) {
774+
t.Parallel()
775+
t.Run("Invalidate and remove from the heap", func(t *testing.T) {
776+
t.Parallel()
777+
app := fiber.New()
778+
app.Use(New(Config{
779+
CacheControl: true,
780+
CacheInvalidator: func(c fiber.Ctx) bool {
781+
return fiber.Query[bool](c, "invalidate")
782+
},
783+
MaxBytes: 10 * 1024 * 1024,
784+
}))
785+
786+
count := 0
787+
app.Get("/", func(c fiber.Ctx) error {
788+
count++
789+
return c.SendString(strconv.Itoa(count))
790+
})
791+
792+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
793+
require.NoError(t, err)
794+
body, err := io.ReadAll(resp.Body)
795+
require.NoError(t, err)
796+
797+
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
798+
require.NoError(t, err)
799+
bodyCached, err := io.ReadAll(respCached.Body)
800+
require.NoError(t, err)
801+
require.True(t, bytes.Equal(body, bodyCached))
802+
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
803+
804+
respInvalidate, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/?invalidate=true", nil))
805+
require.NoError(t, err)
806+
bodyInvalidate, err := io.ReadAll(respInvalidate.Body)
807+
require.NoError(t, err)
808+
require.NotEqual(t, body, bodyInvalidate)
809+
})
810+
}
811+
812+
func Test_CacheStorage_CustomHeaders(t *testing.T) {
813+
t.Parallel()
814+
app := fiber.New()
815+
app.Use(New(Config{
816+
CacheControl: true,
817+
Storage: memory.New(),
818+
MaxBytes: 10 * 1024 * 1024,
819+
}))
820+
821+
app.Get("/", func(c fiber.Ctx) error {
822+
c.Response().Header.Set("Content-Type", "text/xml")
823+
c.Response().Header.Set("Content-Encoding", "utf8")
824+
return c.Send([]byte("<xml><value>Test</value></xml>"))
825+
})
826+
827+
resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
828+
require.NoError(t, err)
829+
body, err := io.ReadAll(resp.Body)
830+
require.NoError(t, err)
831+
832+
respCached, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", nil))
833+
require.NoError(t, err)
834+
bodyCached, err := io.ReadAll(respCached.Body)
835+
require.NoError(t, err)
836+
require.True(t, bytes.Equal(body, bodyCached))
837+
require.NotEmpty(t, respCached.Header.Get(fiber.HeaderCacheControl))
838+
}
839+
734840
// Because time points are updated once every X milliseconds, entries in tests can often have
735841
// equal expiration times and thus be in an random order. This closure hands out increasing
736842
// time intervals to maintain strong ascending order of expiration

Diff for: middleware/cache/heap.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type heapEntry struct {
1515
// elements in constant time. It does so by handing out special indices
1616
// and tracking entry movement.
1717
//
18-
// indexdedHeap is used for quickly finding entries with the lowest
18+
// indexedHeap is used for quickly finding entries with the lowest
1919
// expiration timestamp and deleting arbitrary entries.
2020
type indexedHeap struct {
2121
// Slice the heap is built on

Diff for: middleware/cache/manager.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,7 @@ func (m *manager) get(key string) *item {
8383
return it
8484
}
8585
if it, _ = m.memory.Get(key).(*item); it == nil { //nolint:errcheck // We store nothing else in the pool
86-
it = m.acquire()
87-
return it
86+
return nil
8887
}
8988
return it
9089
}

Diff for: middleware/cache/manager_test.go

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package cache
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/gofiber/utils/v2"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
func Test_manager_get(t *testing.T) {
12+
cacheManager := newManager(nil)
13+
t.Run("Item not found in cache", func(t *testing.T) {
14+
assert.Nil(t, cacheManager.get(utils.UUID()))
15+
})
16+
t.Run("Item found in cache", func(t *testing.T) {
17+
id := utils.UUID()
18+
cacheItem := cacheManager.acquire()
19+
cacheItem.body = []byte("test-body")
20+
cacheManager.set(id, cacheItem, 10*time.Second)
21+
assert.NotNil(t, cacheManager.get(id))
22+
})
23+
}

0 commit comments

Comments
 (0)