Skip to content

Commit

Permalink
feat: Add an API for getting all keys (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
viccon authored May 26, 2024
1 parent 3bc44ff commit bf36589
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 21 deletions.
61 changes: 51 additions & 10 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,31 @@ func (c *Client[T]) get(key string) (value T, exists, ignore, refresh bool) {
return val, exists, ignore, refresh
}

// Get retrieves a single value from the cache.
func (c *Client[T]) Get(key string) (T, bool) {
shard := c.getShard(key)
val, ok, ignore, _ := shard.get(key)
c.reportCacheHits(ok)
return val, ok && !ignore
}

func (c *Client[T]) GetMany(ids []string, keyFn KeyFn) map[string]T {
// GetMany retrieves multiple values from the cache.
func (c *Client[T]) GetMany(keys []string) map[string]T {
records := make(map[string]T, len(keys))
for _, key := range keys {
if value, ok := c.Get(key); ok {
records[key] = value
}
}
return records
}

// GetManyKeyFn follows the same API as GetFetchBatch and PassthroughBatch.
// You provide it with a slice of IDs and a keyFn, which is applied to create
// the cache key. The returned map uses the IDs as keys instead of the cache key.
// If you've used ScanKeys to retrieve the actual keys, you can retrieve the records
// using GetMany instead.
func (c *Client[T]) GetManyKeyFn(ids []string, keyFn KeyFn) map[string]T {
records := make(map[string]T, len(ids))
for _, id := range ids {
if value, ok := c.Get(keyFn(id)); ok {
Expand All @@ -170,29 +187,53 @@ func (c *Client[T]) GetMany(ids []string, keyFn KeyFn) map[string]T {
return records
}

// SetMissing writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) SetMissing(key string, value T, isMissingRecord bool) bool {
// Set writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) Set(key string, value T) bool {
shard := c.getShard(key)
return shard.set(key, value, isMissingRecord)
return shard.set(key, value, false)
}

// Set writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) Set(key string, value T) bool {
return c.SetMissing(key, value, false)
// StoreMissingRecord writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) StoreMissingRecord(key string) bool {
shard := c.getShard(key)
return shard.set(key, *new(T), true)
}

// SetMany writes a map of key value pairs to the cache.
func (c *Client[T]) SetMany(records map[string]T) bool {
var triggeredEviction bool
for key, value := range records {
evicted := c.Set(key, value)
if evicted {
triggeredEviction = true
}
}
return triggeredEviction
}

// SetMany writes multiple values to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) SetMany(records map[string]T, cacheKeyFn KeyFn) bool {
// SetManyKeyFn follows the same API as GetFetchBatch and PassThroughBatch. It
// takes a map of records where the keyFn is applied to each key in the map
// before it's stored in the cache.
func (c *Client[T]) SetManyKeyFn(records map[string]T, cacheKeyFn KeyFn) bool {
var triggeredEviction bool
for id, value := range records {
evicted := c.SetMissing(cacheKeyFn(id), value, false)
evicted := c.Set(cacheKeyFn(id), value)
if evicted {
triggeredEviction = true
}
}
return triggeredEviction
}

// ScanKeys returns a list of all keys in the cache.
func (c *Client[T]) ScanKeys() []string {
keys := make([]string, 0, c.Size())
for _, shard := range c.shards {
keys = append(keys, shard.keys()...)
}
return keys
}

// Size returns the number of entries in the cache.
func (c *Client[T]) Size() int {
var sum int
Expand Down
73 changes: 72 additions & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sturdyc_test

import (
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -208,11 +209,81 @@ func TestSetMany(t *testing.T) {
for i := 0; i < 10; i++ {
records[strconv.Itoa(i)] = i
}
c.SetMany(records, c.BatchKeyFn("key"))
c.SetMany(records)

if c.Size() != 10 {
t.Errorf("expected cache size to be 10, got %d", c.Size())
}

keys := c.ScanKeys()
if len(keys) != 10 {
t.Errorf("expected 10 keys, got %d", len(keys))
}
for _, key := range keys {
if _, ok := records[key]; !ok {
t.Errorf("expected key %s to be in the cache", key)
}
}
}

func TestSetManyKeyFn(t *testing.T) {
t.Parallel()

c := sturdyc.New[int](1000, 10, time.Hour, 5)

if c.Size() != 0 {
t.Errorf("expected cache size to be 0, got %d", c.Size())
}

records := make(map[string]int, 10)
for i := 0; i < 10; i++ {
records[strconv.Itoa(i)] = i
}
c.SetManyKeyFn(records, c.BatchKeyFn("foo"))

if c.Size() != 10 {
t.Errorf("expected cache size to be 10, got %d", c.Size())
}

keys := c.ScanKeys()
if len(keys) != 10 {
t.Errorf("expected 10 keys, got %d", len(keys))
}
for _, key := range keys {
if !strings.HasPrefix(key, "foo") {
t.Errorf("expected key %s to start with foo", key)
}
}
}

func TestGetMany(t *testing.T) {
t.Parallel()

c := sturdyc.New[int](1000, 10, time.Hour, 5)

if c.Size() != 0 {
t.Errorf("expected cache size to be 0, got %d", c.Size())
}

records := make(map[string]int, 10)
for i := 0; i < 10; i++ {
records[strconv.Itoa(i)] = i
}
c.SetMany(records)

keys := make([]string, 0, 10)
for key := range records {
keys = append(keys, key)
}

cacheHits := c.GetMany(keys)
if len(cacheHits) != 10 {
for key := range records {
if _, ok := cacheHits[key]; !ok {
t.Errorf("expected key %s to be in the cache", key)
}
}
}
}

func TestEvictsAndReturnsTheCorrectSize(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions inflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF

response, err := fn(ctx)
if err != nil && c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, *new(T), true)
c.StoreMissingRecord(key)
call.err = ErrMissingRecord
return
}
Expand All @@ -52,7 +52,7 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF

call.err = nil
call.val = res
c.SetMissing(key, res, false)
c.Set(key, res)
}

func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) {
Expand Down Expand Up @@ -107,7 +107,7 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
if c.storeMissingRecords && len(response) < len(opts.ids) {
for _, id := range opts.ids {
if _, ok := response[id]; !ok {
c.SetMissing(opts.keyFn(id), *new(T), true)
c.StoreMissingRecord(opts.keyFn(id))
}
}
}
Expand All @@ -118,7 +118,7 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
if !ok {
continue
}
c.SetMissing(opts.keyFn(id), v, false)
c.Set(opts.keyFn(id), v)
opts.call.val[id] = v
}
}
Expand Down
2 changes: 1 addition & 1 deletion passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (c *Client[T]) PassthroughBatch(ctx context.Context, ids []string, keyFn Ke
return res, nil
}

values := c.GetMany(ids, keyFn)
values := c.GetManyKeyFn(ids, keyFn)
if len(values) > 0 {
return values, nil
}
Expand Down
10 changes: 5 additions & 5 deletions refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ func (c *Client[T]) refresh(key string, fetchFn FetchFn[T]) {
response, err := fetchFn(context.Background())
if err != nil {
if c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, response, true)
c.StoreMissingRecord(key)
}
if errors.Is(err, ErrDeleteRecord) {
c.Delete(key)
}
return
}
c.SetMissing(key, response, false)
c.Set(key, response)
}

func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) {
Expand All @@ -32,7 +32,7 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
// Check if any of the records have been deleted at the data source.
for _, id := range ids {
_, okCache, _, _ := c.get(keyFn(id))
v, okResponse := response[id]
_, okResponse := response[id]

if okResponse {
continue
Expand All @@ -43,12 +43,12 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
}

if c.storeMissingRecords && !okResponse {
c.SetMissing(keyFn(id), v, true)
c.StoreMissingRecord(keyFn(id))
}
}

// Cache the refreshed records.
for id, record := range response {
c.SetMissing(keyFn(id), record, false)
c.Set(keyFn(id), record)
}
}
13 changes: 13 additions & 0 deletions shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,16 @@ func (s *shard[T]) delete(key string) {
defer s.Unlock()
delete(s.entries, key)
}

func (s *shard[T]) keys() []string {
s.RLock()
defer s.RUnlock()
keys := make([]string, 0, len(s.entries))
for k, v := range s.entries {
if s.clock.Now().After(v.expiresAt) {
continue
}
keys = append(keys, k)
}
return keys
}

0 comments on commit bf36589

Please sign in to comment.