Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/core/bloom.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <absl/numeric/bits.h>
#include <xxhash.h>

#include <algorithm>
#include <cmath>

#include "base/logging.h"
Expand Down
22 changes: 3 additions & 19 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,6 @@ struct Connection::Shutdown {
}
};

// Builds a pub/sub message. `buf` holds the channel name followed by the message body;
// `channel_len` and `message_len` are the lengths of those two parts inside `buf`.
// `pattern` and `buf` are taken by value and moved into the members.
Connection::PubMessage::PubMessage(string pattern, shared_ptr<char[]> buf, size_t channel_len,
                                   size_t message_len)
    : pattern(std::move(pattern)),
      buf(std::move(buf)),
      channel_len(channel_len),
      message_len(message_len) {
}

// Returns a view of the channel name: the first `channel_len` bytes of `buf`.
string_view Connection::PubMessage::Channel() const {
  const char* channel_begin = buf.get();
  return string_view(channel_begin, channel_len);
}

// Returns a view of the message body: the `message_len` bytes that follow the channel name
// inside `buf`.
string_view Connection::PubMessage::Message() const {
  const char* message_begin = buf.get() + channel_len;
  return string_view(message_begin, message_len);
}

void Connection::PipelineMessage::SetArgs(const RespVec& args) {
auto* next = storage.data();
for (size_t i = 0; i < args.size(); ++i) {
Expand Down Expand Up @@ -361,7 +345,7 @@ size_t Connection::PipelineMessage::StorageCapacity() const {
size_t Connection::MessageHandle::UsedMemory() const {
struct MessageSize {
size_t operator()(const PubMessagePtr& msg) {
return sizeof(PubMessage) + (msg->channel_len + msg->message_len);
return sizeof(PubMessage) + (msg->channel.size() + msg->message.size());
}
size_t operator()(const PipelineMessagePtr& msg) {
return sizeof(PipelineMessage) + msg->args.capacity() * sizeof(MutableSlice) +
Expand Down Expand Up @@ -449,8 +433,8 @@ void Connection::DispatchOperations::operator()(const PubMessage& pub_msg) {
arr[i++] = "pmessage";
arr[i++] = pub_msg.pattern;
}
arr[i++] = pub_msg.Channel();
arr[i++] = pub_msg.Message();
arr[i++] = pub_msg.channel;
arr[i++] = pub_msg.message;
rbuilder->SendStringArr(absl::Span<string_view>{arr.data(), i},
RedisReplyBuilder::CollectionType::PUSH);
}
Expand Down
12 changes: 3 additions & 9 deletions src/facade/dragonfly_connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,9 @@ class Connection : public util::Connection {

// PubSub message, either incoming message for active subscription or reply for new subscription.
struct PubMessage {
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<char[]> buf; // stores channel name and message
size_t channel_len, message_len; // lengths in buf

std::string_view Channel() const;
std::string_view Message() const;

PubMessage(std::string pattern, std::shared_ptr<char[]> buf, size_t channel_len,
size_t message_len);
std::string pattern{}; // non-empty for pattern subscriber
std::shared_ptr<char[]> buf; // stores channel name and message
std::string_view channel, message; // channel and message parts from buf
};

// Pipeline message, accumulated Redis command to be executed.
Expand Down
2 changes: 1 addition & 1 deletion src/facade/reply_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ void RedisReplyBuilder::SendMGetResponse(MGetResponse resp) {

void RedisReplyBuilder::SendSimpleStrArr(StrSpan arr) {
string res = absl::StrCat("*", arr.Size(), kCRLF);
for (std::string_view str : arr)
for (string_view str : arr)
StrAppend(&res, "+", str, kCRLF);

SendRaw(res);
Expand Down
4 changes: 2 additions & 2 deletions src/server/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_library(dfly_transaction db_slice.cc malloc_stats.cc blocking_controller.cc
common.cc journal/journal.cc journal/types.cc journal/journal_slice.cc
server_state.cc table.cc top_keys.cc transaction.cc tx_base.cc
serializer_commons.cc journal/serializer.cc journal/executor.cc journal/streamer.cc
${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc)
${TX_LINUX_SRCS} acl/acl_log.cc slowlog.cc channel_store.cc)

SET(DF_SEARCH_SRCS search/search_family.cc search/doc_index.cc search/doc_accessors.cc
search/aggregator.cc)
Expand All @@ -43,7 +43,7 @@ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux")
cxx_test(tiered_storage_test dfly_test_lib LABELS DFLY)
endif()

add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc channel_store.cc
add_library(dragonfly_lib bloom_family.cc engine_shard_set.cc
config_registry.cc conn_context.cc debugcmd.cc dflycmd.cc
generic_family.cc hset_family.cc http_api.cc json_family.cc
list_family.cc main_service.cc memory_cmd.cc rdb_load.cc rdb_save.cc replica.cc
Expand Down
57 changes: 57 additions & 0 deletions src/server/channel_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ bool Matches(string_view pattern, string_view channel) {
return stringmatchlen(pattern.data(), pattern.size(), channel.data(), channel.size(), 0) == 1;
}

// Build functor for sending messages to connection
// Build a functor that sends every message of `messages` on `channel` to one connection.
//
// The channel name and all messages are copied back to back into a single shared buffer, so
// all queued PubMessages reference the same allocation. The returned functor is self-contained:
// it only refers to that buffer and never to the caller's `channel` / `messages` storage.
// This matters because the functor is handed to a dispatch call and may run after the caller
// has released or cleared the original arguments.
//
// The functor takes the target connection and the matching pattern (empty for non-pattern
// subscribers) and queues one PubMessage per message.
auto BuildSender(string_view channel, facade::ArgRange messages) {
  // Sum in size_t: an int accumulator would truncate or overflow on large payloads.
  size_t messages_size =
      accumulate(messages.begin(), messages.end(), size_t{0},
                 [](size_t sum, string_view str) { return sum + str.size(); });

  auto buf = shared_ptr<char[]>{new char[channel.size() + messages_size]};

  // Empty views may carry a null data pointer, which must not be passed to memcpy.
  if (!channel.empty())
    memcpy(buf.get(), channel.data(), channel.size());

  // Views of each message inside buf, so the functor never has to re-read `messages`.
  vector<string_view> views;
  char* ptr = buf.get() + channel.size();
  for (string_view message : messages) {
    if (!message.empty())
      memcpy(ptr, message.data(), message.size());
    views.emplace_back(ptr, message.size());
    ptr += message.size();
  }

  return [channel_len = channel.size(), buf = std::move(buf), views = std::move(views)](
             facade::Connection* conn, string pattern) {
    string_view channel_view{buf.get(), channel_len};
    for (size_t i = 0; i < views.size(); ++i) {
      // Every message needs the pattern; only the last one may steal it, the others get a copy.
      const bool is_last = (i + 1 == views.size());
      string msg_pattern = is_last ? std::move(pattern) : pattern;
      conn->SendPubMessageAsync({std::move(msg_pattern), buf, channel_view, views[i]});
    }
  };
}

} // namespace

bool ChannelStore::Subscriber::ByThread(const Subscriber& lhs, const Subscriber& rhs) {
Expand Down Expand Up @@ -95,6 +119,39 @@ void ChannelStore::Destroy() {

ChannelStore::ControlBlock ChannelStore::control_block;

unsigned ChannelStore::SendMessages(std::string_view channel, facade::ArgRange messages) const {
vector<Subscriber> subscribers = FetchSubscribers(channel);
if (subscribers.empty())
return 0;

// Make sure none of the threads publish buffer limits is reached. We don't reserve memory ahead
// and don't prevent the buffer from possibly filling, but the approach is good enough for
// limiting fast producers. Most importantly, we can use DispatchBrief below as we block here
optional<uint32_t> last_thread;
for (auto& sub : subscribers) {
DCHECK_LE(last_thread.value_or(0), sub.Thread());
if (last_thread && *last_thread == sub.Thread()) // skip same thread
continue;

if (sub.EnsureMemoryBudget()) // Invalid pointers are skipped
last_thread = sub.Thread();
}

auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
auto cb = [subscribers_ptr, send = BuildSender(channel, messages)](unsigned idx, auto*) {
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
ChannelStore::Subscriber::ByThreadId);
while (it != subscribers_ptr->end() && it->Thread() == idx) {
if (auto* ptr = it->Get(); ptr)
send(ptr, it->pattern);
it++;
}
};
shard_set->pool()->DispatchBrief(std::move(cb));

return subscribers_ptr->size();
}

vector<ChannelStore::Subscriber> ChannelStore::FetchSubscribers(string_view channel) const {
vector<Subscriber> res;

Expand Down
3 changes: 3 additions & 0 deletions src/server/channel_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class ChannelStore {

ChannelStore();

// Send messages to channel, block on connection backpressure
unsigned SendMessages(std::string_view channel, facade::ArgRange messages) const;

// Fetch all subscribers for channel, including matching patterns.
std::vector<Subscriber> FetchSubscribers(std::string_view channel) const;

Expand Down
38 changes: 26 additions & 12 deletions src/server/db_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "base/flags.h"
#include "base/logging.h"
#include "generic_family.h"
#include "server/channel_store.h"
#include "server/cluster/cluster_defs.h"
#include "server/engine_shard_set.h"
#include "server/error.h"
Expand All @@ -33,6 +34,8 @@ ABSL_FLAG(double, table_growth_margin, 0.4,
"Prevents table from growing if number of free slots x average object size x this ratio "
"is larger than memory budget.");

// When enabled, expired keys are collected and published as pub/sub messages on the
// "__keyevent@<db>__:expired" channel. In notify-keyspace-events notation the "__keyevent@"
// prefix is 'E' and expired events are 'x'.
ABSL_FLAG(bool, expiration_keyspace_events, false,
          "Send keyspace events for expiration: expired keys are published to "
          "__keyevent@<db>__:expired (roughly equivalent to 'Ex' in notify-keyspace-events)");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To what setting of notify-keyspace-events is it equivalent?
https://redis.io/docs/latest/develop/use/keyspace-notifications/#configuration

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it's configured with its own flag syntax, so I'm not sure whether we should take it over, just keep the name with a boolean flag, or use a differently named flag.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My question is: what should the notify-keyspace-events setting in valkey be to be roughly equivalent to what you implemented here? How do you explain this feature to a valkey user?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Setting true is equivalent to using Ke in notify-keyspace-events

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, please add it to the help string

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about 'x' ? we do not handle it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eviction? No, we don't, but we can if needed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'x' is expiry and 'e' is eviction. So which one do we support? :)
Is it Ke or Kx?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, Kx then, I forgot the naming 😅


namespace dfly {

using namespace std;
Expand Down Expand Up @@ -204,9 +207,7 @@ unsigned PrimeEvictionPolicy::Evict(const PrimeTable::HotspotBuckets& eb, PrimeT

// log the evicted keys to journal.
if (auto journal = db_slice_->shard_owner()->journal(); journal) {
ArgSlice delete_args(&key, 1);
journal->RecordEntry(0, journal::Op::EXPIRED, cntx_.db_index, 1, cluster::KeySlot(key),
Payload("DEL", delete_args), false);
RecordExpiry(cntx_.db_index, key);
}

db_slice_->PerformDeletion(DbSlice::Iterator(last_slot_it, StringOrView::FromView(key)), table);
Expand Down Expand Up @@ -268,6 +269,7 @@ DbSlice::DbSlice(uint32_t index, bool caching_mode, EngineShard* owner)
CreateDb(0);
expire_base_[0] = expire_base_[1] = 0;
soft_budget_limit_ = (0.3 * max_memory_limit / shard_set->size());
expired_keys_events_recording_ = GetFlag(FLAGS_expiration_keyspace_events);
}

DbSlice::~DbSlice() {
Expand Down Expand Up @@ -1047,11 +1049,15 @@ DbSlice::PrimeItAndExp DbSlice::ExpireIfNeeded(const Context& cntx, PrimeIterato
<< ", expire table size: " << db->expire.size()
<< ", prime table size: " << db->prime.size() << util::fb2::GetStacktrace();
}

// Replicate expiry
if (auto journal = owner_->journal(); journal) {
RecordExpiry(cntx.db_index, key);
}

if (expired_keys_events_recording_)
db->expired_keys_events_.emplace_back(key);

auto obj_type = it->second.ObjType();
if (doc_del_cb_ && (obj_type == OBJ_JSON || obj_type == OBJ_HASH)) {
doc_del_cb_(key, cntx, it->second);
Expand Down Expand Up @@ -1157,6 +1163,13 @@ auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteEx
}
}

// Send and clear accumulated expired key events
if (auto& events = db_arr_[cntx.db_index]->expired_keys_events_; !events.empty()) {
ChannelStore* store = ServerState::tlocal()->channel_store();
store->SendMessages(absl::StrCat("__keyevent@", cntx.db_index, "__:expired"), events);
events.clear();
}

return result;
}

Expand Down Expand Up @@ -1185,6 +1198,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
string tmp;
int32_t starting_segment_id = rand() % num_segments;
size_t used_memory_before = owner_->UsedMemory();

bool record_keys = owner_->journal() != nullptr || expired_keys_events_recording_;
vector<string> keys_to_journal;

{
Expand Down Expand Up @@ -1213,9 +1228,8 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
if (lt.Find(LockTag(key)).has_value())
continue;

if (auto journal = owner_->journal(); journal) {
keys_to_journal.push_back(string(key));
}
if (record_keys)
keys_to_journal.emplace_back(key);

PerformDeletion(Iterator(evict_it, StringOrView::FromView(key)), db_table.get());
++evicted;
Expand All @@ -1233,12 +1247,12 @@ void DbSlice::FreeMemWithEvictionStep(DbIndex db_ind, size_t increase_goal_bytes
finish:
// send the deletion to the replicas.
// fiber preemption could happen in this phase.
if (auto journal = owner_->journal(); journal) {
for (string_view key : keys_to_journal) {
ArgSlice delete_args(&key, 1);
journal->RecordEntry(0, journal::Op::EXPIRED, db_ind, 1, cluster::KeySlot(key),
Payload("DEL", delete_args), false);
}
for (string_view key : keys_to_journal) {
if (auto journal = owner_->journal(); journal)
RecordExpiry(db_ind, key);

if (expired_keys_events_recording_)
db_table->expired_keys_events_.emplace_back(key);
}

auto time_finish = absl::GetCurrentTimeNanos();
Expand Down
3 changes: 3 additions & 0 deletions src/server/db_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ class DbSlice {
// Registered by shard indices on when first document index is created.
DocDeletionCallback doc_del_cb_;

// Record whenever a key expired to DbTable::expired_keys_events_ for keyspace notifications
bool expired_keys_events_recording_ = true;

struct Hash {
size_t operator()(const facade::Connection::WeakRef& c) const {
return std::hash<uint32_t>()(c.GetClientId());
Expand Down
4 changes: 2 additions & 2 deletions src/server/dragonfly_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ TEST_F(DflyEngineTest, PSubscribe) {
ASSERT_EQ(1, SubscriberMessagesLen("IO1"));

const auto& msg = GetPublishedMessage("IO1", 0);
EXPECT_EQ("foo", msg.Message());
EXPECT_EQ("ab", msg.Channel());
EXPECT_EQ("foo", msg.message);
EXPECT_EQ("ab", msg.channel);
EXPECT_EQ("a*", msg.pattern);
}

Expand Down
43 changes: 2 additions & 41 deletions src/server/main_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2243,49 +2243,10 @@ void Service::Exec(CmdArgList args, ConnectionContext* cntx) {

void Service::Publish(CmdArgList args, ConnectionContext* cntx) {
string_view channel = ArgS(args, 0);
string_view msg = ArgS(args, 1);
string_view messages[] = {ArgS(args, 1)};

auto* cs = ServerState::tlocal()->channel_store();
vector<ChannelStore::Subscriber> subscribers = cs->FetchSubscribers(channel);
int num_published = subscribers.size();
if (!subscribers.empty()) {
// Make sure neither of the threads limits is reached.
// This check actually doesn't reserve any memory ahead and doesn't prevent the buffer
// from eventually filling up, especially if multiple clients are unblocked simultaneously,
// but is generally good enough to limit too fast producers.
// Most importantly, this approach allows not blocking and not awaiting in the dispatch below,
// thus not adding any overhead to backpressure checks.
optional<uint32_t> last_thread;
for (auto& sub : subscribers) {
DCHECK_LE(last_thread.value_or(0), sub.Thread());
if (last_thread && *last_thread == sub.Thread()) // skip same thread
continue;

if (sub.EnsureMemoryBudget()) // Invalid pointers are skipped
last_thread = sub.Thread();
}

auto subscribers_ptr = make_shared<decltype(subscribers)>(std::move(subscribers));
auto buf = shared_ptr<char[]>{new char[channel.size() + msg.size()]};
memcpy(buf.get(), channel.data(), channel.size());
memcpy(buf.get() + channel.size(), msg.data(), msg.size());

auto cb = [subscribers_ptr, buf, channel, msg](unsigned idx, util::ProactorBase*) {
auto it = lower_bound(subscribers_ptr->begin(), subscribers_ptr->end(), idx,
ChannelStore::Subscriber::ByThreadId);

while (it != subscribers_ptr->end() && it->Thread() == idx) {
if (auto* ptr = it->Get(); ptr) {
ptr->SendPubMessageAsync(
{std::move(it->pattern), std::move(buf), channel.size(), msg.size()});
}
it++;
}
};
shard_set->pool()->DispatchBrief(std::move(cb));
}

cntx->SendLong(num_published);
cntx->SendLong(cs->SendMessages(channel, messages));
}

void Service::Subscribe(CmdArgList args, ConnectionContext* cntx) {
Expand Down
3 changes: 3 additions & 0 deletions src/server/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ struct DbTable : boost::intrusive_ref_counter<DbTable, boost::thread_unsafe_coun
// Stores a list of dependant connections for each watched key.
absl::flat_hash_map<std::string, std::vector<ConnectionState::ExecInfo*>> watched_keys;

// Keyspace notifications: list of expired keys since last batch of messages was published.
mutable std::vector<std::string> expired_keys_events_;

mutable DbTableStats stats;
std::vector<SlotStats> slots_stats;
ExpireTable::Cursor expire_cursor;
Expand Down
26 changes: 26 additions & 0 deletions tests/dragonfly/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,32 @@ async def subscribe_worker():
await async_pool.disconnect()


@dfly_args({"expiration_keyspace_events": "true"})
async def test_keyspace_events(async_client: aioredis.Redis):
    """Each key that expires must be announced on the __keyevent@0__:expired channel."""
    subscriber = async_client.pubsub()
    await subscriber.subscribe("__keyevent@0__:expired")

    # Staggered TTLs: 300ms for k10 up to 690ms for k49.
    indices = range(10, 50)
    keys = [f"k{i}" for i in indices]
    for i, key in zip(indices, keys):
        await async_client.set(key, "X", px=200 + i * 10)

    # We don't support immediate expiration:
    # keys += ['immediate']
    # await async_client.set(keys[-1], 'Y', exat=123)  # expired 50 years ago

    # Collect messages until one event per key has arrived, ignoring the subscribe confirmation.
    received = []
    async for message in subscriber.listen():
        if message["type"] != "subscribe":
            received.append(message)
            if len(received) >= len(keys):
                break

    assert {msg["data"] for msg in received} == set(keys)


async def test_big_command(df_server, size=8 * 1024):
reader, writer = await asyncio.open_connection("127.0.0.1", df_server.port)

Expand Down