Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions rueidisaside/aside.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ retry:
}

if rueidis.IsRedisNil(err) { // successfully set client id on the key as a lock
// attach TTL pointer to context for potential modification via OverrideCacheTTL
ctx = context.WithValue(ctx, ttlKey, &ttl)
if val, err = fn(ctx, key); err == nil {
err = setkey.Exec(ctx, c.client, []string{key}, []string{id, val, strconv.FormatInt(ttl.Milliseconds(), 10)}).Error()
}
Expand Down Expand Up @@ -222,6 +224,19 @@ func (c *Client) Close() {

const PlaceholderPrefix = "rueidisid:"

type ctxKey struct{}

var ttlKey = ctxKey{}

// OverrideCacheTTL sets a custom TTL for the cache entry being populated in the current context.
// It can be called in the callback function passed to CacheAsideClient.Get() to customize
// the TTL based on the data being cached.
func OverrideCacheTTL(ctx context.Context, ttl time.Duration) {
if p, ok := ctx.Value(ttlKey).(*time.Duration); ok {
*p = ttl
}
}

var (
delkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("DEL",KEYS[1]) else return 0 end`)
setkey = rueidis.NewLuaScript(`if redis.call("GET",KEYS[1]) == ARGV[1] then return redis.call("SET",KEYS[1],ARGV[2],"PX",ARGV[3]) else return 0 end`)
Expand Down
104 changes: 104 additions & 0 deletions rueidisaside/aside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,107 @@ func TestMultipleClientLL(t *testing.T) {
}
}
}

func TestOverrideCacheTTL(t *testing.T) {
client := makeClient(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())

val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
OverrideCacheTTL(ctx, time.Millisecond*300)
return "1", nil
})
if err != nil || val != "1" {
t.Fatal(err)
}

val, err = client.Get(context.Background(), time.Second*5, key, nil)
if err != nil || val != "1" {
t.Fatal(err)
}

time.Sleep(time.Millisecond * 400)
val, err = client.Get(context.Background(), time.Second*5, key, nil) // should miss
if !rueidis.IsRedisNil(err) {
t.Fatal("expected cache miss after overridden TTL expired")
}
}

func TestOverrideCacheTTLLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())

val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
OverrideCacheTTL(ctx, time.Millisecond*300)
return "1", nil
})
if err != nil || val != "1" {
t.Fatal(err)
}

val, err = client.Get(context.Background(), time.Second*5, key, nil)
if err != nil || val != "1" {
t.Fatal(err)
}

time.Sleep(time.Millisecond * 400)
val, err = client.Get(context.Background(), time.Second*5, key, nil) // should miss
if !rueidis.IsRedisNil(err) {
t.Fatal("expected cache miss after overridden TTL expired")
}
}

func TestOverrideCacheTTLNegativeCaching(t *testing.T) {
client := makeClient(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())

val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
OverrideCacheTTL(ctx, time.Millisecond*300)
return "NOT_FOUND", nil
})
if err != nil || val != "NOT_FOUND" {
t.Fatal(err)
}

val, err = client.Get(context.Background(), time.Second*5, key, nil)
if err != nil || val != "NOT_FOUND" {
t.Fatal(err)
}

time.Sleep(time.Millisecond * 400)
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
return "FOUND", nil
})
if err != nil || val != "FOUND" {
t.Fatal(err)
}
}

func TestOverrideCacheTTLNegativeCachingLL(t *testing.T) {
client := makeClientWithLuaLock(t, addr)
defer client.Close()
key := strconv.Itoa(rand.Int())

val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
OverrideCacheTTL(ctx, time.Millisecond*300)
return "NOT_FOUND", nil
})
if err != nil || val != "NOT_FOUND" {
t.Fatal(err)
}

val, err = client.Get(context.Background(), time.Second*5, key, nil)
if err != nil || val != "NOT_FOUND" {
t.Fatal(err)
}

time.Sleep(time.Millisecond * 400)
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (val string, err error) {
return "FOUND", nil
})
if err != nil || val != "FOUND" {
t.Fatal(err)
}
}
54 changes: 54 additions & 0 deletions rueidisaside/typed_aside_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,57 @@ func TestTypedCacheAsideClient_Del(t *testing.T) {
t.Fatal("expected function to be called because the value should be deleted")
}
}

func TestTypedCacheAsideClient_OverrideTTL(t *testing.T) {
baseClient := makeClient(t, addr)
t.Cleanup(baseClient.Close)

serializer := func(v *testStruct) (string, error) {
if v == nil {
return "nilTestStruct", nil
}
b, err := json.Marshal(v)
return string(b), err
}

deserializer := func(s string) (*testStruct, error) {
if s == "nilTestStruct" {
return nil, nil
}
var v testStruct
err := json.Unmarshal([]byte(s), &v)
return &v, err
}

client := NewTypedCacheAsideClient[testStruct](baseClient, serializer, deserializer)

t.Run("override ttl for negative caching", func(t *testing.T) {
key := randStr()

val, err := client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (*testStruct, error) {
OverrideCacheTTL(ctx, time.Millisecond*300)
return nil, nil
})
if err != nil || val != nil {
t.Fatalf("expected nil value, got %v, err: %v", val, err)
}

val, err = client.Get(context.Background(), time.Second*5, key, nil)
if err != nil || val != nil {
t.Fatalf("expected cached nil value, got %v, err: %v", val, err)
}

time.Sleep(time.Millisecond * 400)

found := &testStruct{ID: 42, Name: "found"}
val, err = client.Get(context.Background(), time.Second*5, key, func(ctx context.Context, key string) (*testStruct, error) {
return found, nil
})
if err != nil {
t.Fatal(err)
}
if val.ID != found.ID || val.Name != found.Name {
t.Fatalf("expected %v, got %v", found, val)
}
})
}
Loading