Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add an API for getting all keys #11

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 51 additions & 10 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,31 @@ func (c *Client[T]) get(key string) (value T, exists, ignore, refresh bool) {
return val, exists, ignore, refresh
}

// Get looks up a single key in the cache. The second return value reports
// whether a usable entry was found; entries flagged as "ignore" by the shard
// are treated as misses.
func (c *Client[T]) Get(key string) (T, bool) {
	value, hit, ignored, _ := c.getShard(key).get(key)
	c.reportCacheHits(hit)
	if ignored {
		return value, false
	}
	return value, hit
}

func (c *Client[T]) GetMany(ids []string, keyFn KeyFn) map[string]T {
// GetMany retrieves multiple values from the cache.
func (c *Client[T]) GetMany(keys []string) map[string]T {
records := make(map[string]T, len(keys))
for _, key := range keys {
if value, ok := c.Get(key); ok {
records[key] = value
}
}
return records
}

// GetManyKeyFn follows the same API as GetFetchBatch and PassthroughBatch.
// You provide it with a slice of IDs and a keyFn, which is applied to create
// the cache key. The returned map uses the IDs as keys instead of the cache key.
// If you've used ScanKeys to retrieve the actual keys, you can retrieve the records
// using GetMany instead.
func (c *Client[T]) GetManyKeyFn(ids []string, keyFn KeyFn) map[string]T {
records := make(map[string]T, len(ids))
for _, id := range ids {
if value, ok := c.Get(keyFn(id)); ok {
Expand All @@ -170,29 +187,53 @@ func (c *Client[T]) GetMany(ids []string, keyFn KeyFn) map[string]T {
return records
}

// SetMissing writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) SetMissing(key string, value T, isMissingRecord bool) bool {
// Set writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) Set(key string, value T) bool {
shard := c.getShard(key)
return shard.set(key, value, isMissingRecord)
return shard.set(key, value, false)
}

// Set writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) Set(key string, value T) bool {
return c.SetMissing(key, value, false)
// StoreMissingRecord writes a single value to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) StoreMissingRecord(key string) bool {
shard := c.getShard(key)
return shard.set(key, *new(T), true)
}

// SetMany writes a map of key value pairs to the cache. Returns true if
// any of the writes triggered an eviction.
func (c *Client[T]) SetMany(records map[string]T) bool {
	evicted := false
	for k, v := range records {
		if c.Set(k, v) {
			evicted = true
		}
	}
	return evicted
}

// SetMany writes multiple values to the cache. Returns true if it triggered an eviction.
func (c *Client[T]) SetMany(records map[string]T, cacheKeyFn KeyFn) bool {
// SetManyKeyFn follows the same API as GetFetchBatch and PassThroughBatch. It
// takes a map of records where the keyFn is applied to each key in the map
// before it's stored in the cache.
func (c *Client[T]) SetManyKeyFn(records map[string]T, cacheKeyFn KeyFn) bool {
var triggeredEviction bool
for id, value := range records {
evicted := c.SetMissing(cacheKeyFn(id), value, false)
evicted := c.Set(cacheKeyFn(id), value)
if evicted {
triggeredEviction = true
}
}
return triggeredEviction
}

// ScanKeys returns a list of all keys in the cache, gathered shard by shard.
func (c *Client[T]) ScanKeys() []string {
	// Pre-size using the current entry count to avoid repeated growth.
	all := make([]string, 0, c.Size())
	for i := range c.shards {
		all = append(all, c.shards[i].keys()...)
	}
	return all
}

// Size returns the number of entries in the cache.
func (c *Client[T]) Size() int {
var sum int
Expand Down
73 changes: 72 additions & 1 deletion cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sturdyc_test

import (
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -208,11 +209,81 @@ func TestSetMany(t *testing.T) {
for i := 0; i < 10; i++ {
records[strconv.Itoa(i)] = i
}
c.SetMany(records, c.BatchKeyFn("key"))
c.SetMany(records)

if c.Size() != 10 {
t.Errorf("expected cache size to be 10, got %d", c.Size())
}

keys := c.ScanKeys()
if len(keys) != 10 {
t.Errorf("expected 10 keys, got %d", len(keys))
}
for _, key := range keys {
if _, ok := records[key]; !ok {
t.Errorf("expected key %s to be in the cache", key)
}
}
}

func TestSetManyKeyFn(t *testing.T) {
	t.Parallel()

	client := sturdyc.New[int](1000, 10, time.Hour, 5)

	if size := client.Size(); size != 0 {
		t.Errorf("expected cache size to be 0, got %d", size)
	}

	// Seed ten records keyed "0".."9".
	records := make(map[string]int, 10)
	for i := 0; i < 10; i++ {
		records[strconv.Itoa(i)] = i
	}
	client.SetManyKeyFn(records, client.BatchKeyFn("foo"))

	if size := client.Size(); size != 10 {
		t.Errorf("expected cache size to be 10, got %d", size)
	}

	// Every stored key must carry the "foo" prefix applied by the keyFn.
	keys := client.ScanKeys()
	if got := len(keys); got != 10 {
		t.Errorf("expected 10 keys, got %d", got)
	}
	for _, key := range keys {
		if !strings.HasPrefix(key, "foo") {
			t.Errorf("expected key %s to start with foo", key)
		}
	}
}

func TestGetMany(t *testing.T) {
	t.Parallel()

	c := sturdyc.New[int](1000, 10, time.Hour, 5)

	if c.Size() != 0 {
		t.Errorf("expected cache size to be 0, got %d", c.Size())
	}

	records := make(map[string]int, 10)
	for i := 0; i < 10; i++ {
		records[strconv.Itoa(i)] = i
	}
	c.SetMany(records)

	keys := make([]string, 0, len(records))
	for key := range records {
		keys = append(keys, key)
	}

	cacheHits := c.GetMany(keys)
	// BUG FIX: the original only ran the per-key checks when the length was
	// wrong, and never reported the length mismatch itself. Report the size
	// mismatch directly, then verify membership and values unconditionally.
	if len(cacheHits) != 10 {
		t.Errorf("expected 10 cache hits, got %d", len(cacheHits))
	}
	for key, want := range records {
		got, ok := cacheHits[key]
		if !ok {
			t.Errorf("expected key %s to be in the cache", key)
			continue
		}
		if got != want {
			t.Errorf("expected value %d for key %s, got %d", want, key, got)
		}
	}
}

func TestEvictsAndReturnsTheCorrectSize(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions inflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF

response, err := fn(ctx)
if err != nil && c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, *new(T), true)
c.StoreMissingRecord(key)
call.err = ErrMissingRecord
return
}
Expand All @@ -52,7 +52,7 @@ func makeCall[T, V any](ctx context.Context, c *Client[T], key string, fn FetchF

call.err = nil
call.val = res
c.SetMissing(key, res, false)
c.Set(key, res)
}

func callAndCache[V, T any](ctx context.Context, c *Client[T], key string, fn FetchFn[V]) (V, error) {
Expand Down Expand Up @@ -107,7 +107,7 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
if c.storeMissingRecords && len(response) < len(opts.ids) {
for _, id := range opts.ids {
if _, ok := response[id]; !ok {
c.SetMissing(opts.keyFn(id), *new(T), true)
c.StoreMissingRecord(opts.keyFn(id))
}
}
}
Expand All @@ -118,7 +118,7 @@ func makeBatchCall[T, V any](ctx context.Context, c *Client[T], opts makeBatchCa
if !ok {
continue
}
c.SetMissing(opts.keyFn(id), v, false)
c.Set(opts.keyFn(id), v)
opts.call.val[id] = v
}
}
Expand Down
2 changes: 1 addition & 1 deletion passthrough.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func (c *Client[T]) PassthroughBatch(ctx context.Context, ids []string, keyFn Ke
return res, nil
}

values := c.GetMany(ids, keyFn)
values := c.GetManyKeyFn(ids, keyFn)
if len(values) > 0 {
return values, nil
}
Expand Down
10 changes: 5 additions & 5 deletions refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ func (c *Client[T]) refresh(key string, fetchFn FetchFn[T]) {
response, err := fetchFn(context.Background())
if err != nil {
if c.storeMissingRecords && errors.Is(err, ErrStoreMissingRecord) {
c.SetMissing(key, response, true)
c.StoreMissingRecord(key)
}
if errors.Is(err, ErrDeleteRecord) {
c.Delete(key)
}
return
}
c.SetMissing(key, response, false)
c.Set(key, response)
}

func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn[T]) {
Expand All @@ -32,7 +32,7 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
// Check if any of the records have been deleted at the data source.
for _, id := range ids {
_, okCache, _, _ := c.get(keyFn(id))
v, okResponse := response[id]
_, okResponse := response[id]

if okResponse {
continue
Expand All @@ -43,12 +43,12 @@ func (c *Client[T]) refreshBatch(ids []string, keyFn KeyFn, fetchFn BatchFetchFn
}

if c.storeMissingRecords && !okResponse {
c.SetMissing(keyFn(id), v, true)
c.StoreMissingRecord(keyFn(id))
}
}

// Cache the refreshed records.
for id, record := range response {
c.SetMissing(keyFn(id), record, false)
c.Set(keyFn(id), record)
}
}
13 changes: 13 additions & 0 deletions shard.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,16 @@ func (s *shard[T]) delete(key string) {
defer s.Unlock()
delete(s.entries, key)
}

// keys returns every key in the shard whose entry has not yet expired.
// Expired entries are skipped but not removed here.
func (s *shard[T]) keys() []string {
	s.RLock()
	defer s.RUnlock()

	// Hoist the clock read out of the loop: it is loop-invariant, and using a
	// single timestamp judges every entry against the same expiry cutoff.
	now := s.clock.Now()
	keys := make([]string, 0, len(s.entries))
	for k, v := range s.entries {
		if now.After(v.expiresAt) {
			continue
		}
		keys = append(keys, k)
	}
	return keys
}
Loading