Skip to content

Commit 490616b

Browse files
committed
feat(rueidisaside): add OverrideCacheTTL for custom TTL per cache entry
1 parent 0a9e664 commit 490616b

File tree

3 files changed

+173
-0
lines changed

3 files changed

+173
-0
lines changed

rueidisaside/aside.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ retry:
163163
}
164164

165165
if rueidis.IsRedisNil(err) { // successfully set client id on the key as a lock
166+
// attach TTL pointer to context for potential modification via OverrideCacheTTL
167+
ctx = context.WithValue(ctx, ttlKey, &ttl)
166168
if val, err = fn(ctx, key); err == nil {
167169
err = setkey.Exec(ctx, c.client, []string{key}, []string{id, val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error()
168170
}
@@ -222,6 +224,19 @@ func (c *Client) Close() {
222224

223225
const PlaceholderPrefix = "rueidisid:"
224226

227+
type ctxKey struct{}
228+
229+
var ttlKey = ctxKey{}
230+
231+
// OverrideCacheTTL sets a custom TTL for the cache entry being populated in the current context.
232+
// It can be called in the callback function passed to CacheAsideClient.Get() to customize
233+
// the TTL based on the data being cached.
234+
func OverrideCacheTTL(ctx context.Context, ttl time.Duration) {
235+
if p, ok := ctx.Value(ttlKey).(*time.Duration); ok {
236+
*p = ttl
237+
}
238+
}
239+
225240
var (
226241
delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`)
227242
setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`)

rueidisaside/aside_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,3 +512,107 @@ func TestMultipleClientLL(t *testing.T) {
512512
}
513513
}
514514
}
515+
516+
func TestOverrideCacheTTL(t *testing.T) {
517+
client := makeClient(t, addr)
518+
defer client.Close()
519+
key := strconv.Itoa(rand.Int())
520+
521+
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
522+
OverrideCacheTTL(ctx, time.Millisecond*300)
523+
return "1", nil
524+
})
525+
if err != nil || val != "1" {
526+
t.Fatal(err)
527+
}
528+
529+
val, err = client.Get(context.Background(), time.Second*5, key, nil)
530+
if err != nil || val != "1" {
531+
t.Fatal(err)
532+
}
533+
534+
time.Sleep(time.Millisecond * 400)
535+
val, err = client.Get(context.Background(), time.Second*5, key, nil) // should miss
536+
if !rueidis.IsRedisNil(err) {
537+
t.Fatal("expected cache miss after overridden TTL expired")
538+
}
539+
}
540+
541+
func TestOverrideCacheTTLLL(t *testing.T) {
542+
client := makeClientWithLuaLock(t, addr)
543+
defer client.Close()
544+
key := strconv.Itoa(rand.Int())
545+
546+
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
547+
OverrideCacheTTL(ctx, time.Millisecond*300)
548+
return "1", nil
549+
})
550+
if err != nil || val != "1" {
551+
t.Fatal(err)
552+
}
553+
554+
val, err = client.Get(context.Background(), time.Second*5, key, nil)
555+
if err != nil || val != "1" {
556+
t.Fatal(err)
557+
}
558+
559+
time.Sleep(time.Millisecond * 400)
560+
val, err = client.Get(context.Background(), time.Second*5, key, nil) // should miss
561+
if !rueidis.IsRedisNil(err) {
562+
t.Fatal("expected cache miss after overridden TTL expired")
563+
}
564+
}
565+
566+
func TestOverrideCacheTTLNegativeCaching(t *testing.T) {
567+
client := makeClient(t, addr)
568+
defer client.Close()
569+
key := strconv.Itoa(rand.Int())
570+
571+
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
572+
OverrideCacheTTL(ctx, time.Millisecond*300)
573+
return "NOT_FOUND", nil
574+
})
575+
if err != nil || val != "NOT_FOUND" {
576+
t.Fatal(err)
577+
}
578+
579+
val, err = client.Get(context.Background(), time.Second*5, key, nil)
580+
if err != nil || val != "NOT_FOUND" {
581+
t.Fatal(err)
582+
}
583+
584+
time.Sleep(time.Millisecond * 400)
585+
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
586+
return "FOUND", nil
587+
})
588+
if err != nil || val != "FOUND" {
589+
t.Fatal(err)
590+
}
591+
}
592+
593+
func TestOverrideCacheTTLNegativeCachingLL(t *testing.T) {
594+
client := makeClientWithLuaLock(t, addr)
595+
defer client.Close()
596+
key := strconv.Itoa(rand.Int())
597+
598+
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
599+
OverrideCacheTTL(ctx, time.Millisecond*300)
600+
return "NOT_FOUND", nil
601+
})
602+
if err != nil || val != "NOT_FOUND" {
603+
t.Fatal(err)
604+
}
605+
606+
val, err = client.Get(context.Background(), time.Second*5, key, nil)
607+
if err != nil || val != "NOT_FOUND" {
608+
t.Fatal(err)
609+
}
610+
611+
time.Sleep(time.Millisecond * 400)
612+
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
613+
return "FOUND", nil
614+
})
615+
if err != nil || val != "FOUND" {
616+
t.Fatal(err)
617+
}
618+
}

rueidisaside/typed_aside_test.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,57 @@ func TestTypedCacheAsideClient_Del(t *testing.T) {
163163
t.Fatal("expected function to be called because the value should be deleted")
164164
}
165165
}
166+
167+
func TestTypedCacheAsideClient_OverrideTTL(t *testing.T) {
168+
baseClient := makeClient(t, addr)
169+
t.Cleanup(baseClient.Close)
170+
171+
serializer := func(v *testStruct) (string, error) {
172+
if v == nil {
173+
return "nilTestStruct", nil
174+
}
175+
b, err := json.Marshal(v)
176+
return string(b), err
177+
}
178+
179+
deserializer := func(s string) (*testStruct, error) {
180+
if s == "nilTestStruct" {
181+
return nil, nil
182+
}
183+
var v testStruct
184+
err := json.Unmarshal([]byte(s), &v)
185+
return &v, err
186+
}
187+
188+
client := NewTypedCacheAsideClient[testStruct](baseClient, serializer, deserializer)
189+
190+
t.Run("override ttl for negative caching", func(t *testing.T) {
191+
key := randStr()
192+
193+
val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (*testStruct, error) {
194+
OverrideCacheTTL(ctx, time.Millisecond*300)
195+
return nil, nil
196+
})
197+
if err != nil || val != nil {
198+
t.Fatalf("expected nil value, got %v, err: %v", val, err)
199+
}
200+
201+
val, err = client.Get(context.Background(), time.Second*5, key, nil)
202+
if err != nil || val != nil {
203+
t.Fatalf("expected cached nil value, got %v, err: %v", val, err)
204+
}
205+
206+
time.Sleep(time.Millisecond * 400)
207+
208+
found := &testStruct{ID: 42, Name: "found"}
209+
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (*testStruct, error) {
210+
return found, nil
211+
})
212+
if err != nil {
213+
t.Fatal(err)
214+
}
215+
if val.ID != found.ID || val.Name != found.Name {
216+
t.Fatalf("expected %v, got %v", found, val)
217+
}
218+
})
219+
}

0 commit comments

Comments
 (0)