Skip to content
8 changes: 8 additions & 0 deletions src/core/overloaded.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Copyright 2022, DragonflyDB authors. All rights reserved.
// See LICENSE for licensing terms.
//
//

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this but the implementation is slightly different because it adds an extra overloads.
Also I am not sure if we can have the licence here ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can also replace the overloaded in other files wit this include?

Copy link
Contributor Author

@kostasrim kostasrim Aug 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's only used in dragonfly_connection.cc

template <class... Ts> struct Overloaded : Ts... {
    using Ts::operator()...;
   
    template <typename T, typename D> size_t operator()(const unique_ptr<T, D>& ptr) {
        return operator()(*ptr.get());
    }
 };

So I can't really replace it because:

  1. template in lambdas is C++20 and it wouldn't be wise to use auto instead of std::unique_ptr<T, D>
  2. I could inherit from my Overload but that would be exactly of the same verbosity + we would need a new deduction rule for this new derived class (therefore literally exactly the same code)

There are some things in C++20 I really love and would be useful in certain cases -- who knows maybe we consider it at some point in the future :D

template <class... Ts> struct Overloaded : Ts... { using Ts::operator()...; };

template <class... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;
6 changes: 3 additions & 3 deletions src/server/acl/acl_commands_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ inline const absl::flat_hash_map<std::string_view, uint32_t> CATEGORY_INDEX_TABL
{"SCRIPTING", SCRIPTING},
{"FT_SEARCH", FT_SEARCH},
{"THROTTLE", THROTTLE},
{"JSON", JSON}

};
{"JSON", JSON},
{"ALL", ALL},
{"NONE", NONE}};

// bit 0 at index 0
// bit 1 at index 1
Expand Down
115 changes: 110 additions & 5 deletions src/server/acl/acl_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,30 @@

#include "server/acl/acl_family.h"

#include <optional>
#include <variant>

#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "core/overloaded.h"
#include "facade/facade_types.h"
#include "server/acl/acl_commands_def.h"
#include "server/command_registry.h"
#include "server/conn_context.h"
#include "server/server_state.h"

namespace dfly::acl {

constexpr uint32_t kList = acl::ADMIN | acl::SLOW | acl::DANGEROUS;

static std::string AclToString(uint32_t acl_category) {
std::string tmp;

if (acl_category == acl::ALL) {
return "+@all";
return "+@ALL";
}

if (acl_category == acl::NONE) {
return "+@none";
return "+@NONE";
}

const std::string prefix = "+@";
Expand All @@ -32,7 +37,8 @@ static std::string AclToString(uint32_t acl_category) {
absl::StrAppend(&tmp, prefix, REVERSE_CATEGORY_INDEX_TABLE[step], postfix);
}
}
tmp.erase(tmp.size());

tmp.pop_back();

return tmp;
}
Expand Down Expand Up @@ -61,10 +67,107 @@ void AclFamily::List(CmdArgList args, ConnectionContext* cntx) {
}
}

namespace {

std::optional<std::string> MaybeParsePassword(std::string_view command) {
if (command[0] != '>') {
return {};
}

return {std::string(command.substr(1))};
}

std::optional<bool> MaybeParseStatus(std::string_view command) {
if (command == "ON") {
return true;
}
if (command == "OFF") {
return false;
}
return {};
}

using OptCat = std::optional<uint32_t>;

// bool == true if +
// bool == false if -
std::pair<OptCat, bool> MaybeParseAclCategory(std::string_view command) {
if (absl::StartsWith(command, "+@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, true};
}

if (absl::StartsWith(command, "-@")) {
auto res = CATEGORY_INDEX_TABLE.find(command.substr(2));
if (res == CATEGORY_INDEX_TABLE.end()) {
return {};
}
return {res->second, false};
}

return {};
}

using facade::ErrorReply;

std::variant<User::UpdateRequest, ErrorReply> ParseAclSetUser(CmdArgList args) {
User::UpdateRequest req;

for (auto arg : args) {
ToUpper(&arg);
const auto command = facade::ToSV(arg);
if (auto pass = MaybeParsePassword(command); pass) {
if (req.password) {
return ErrorReply("Only one password is allowed");
}
req.password = std::move(pass);
continue;
}

if (auto status = MaybeParseStatus(command); status) {
if (req.is_active) {
return ErrorReply("Multiple ON/OFF are not allowed");
}
req.is_active = *status;
continue;
}

auto [cat, add] = MaybeParseAclCategory(command);
if (!cat) {
return ErrorReply(absl::StrCat("Unrecognized paramter", command));
}

auto* acl_field = add ? &req.plus_acl_categories : &req.minus_acl_categories;
*acl_field = acl_field->value_or(0) | *cat;
}

return req;
}

} // namespace

void AclFamily::SetUser(CmdArgList args, ConnectionContext* cntx) {
std::string_view username = facade::ToSV(args[0]);
auto req = ParseAclSetUser(args.subspan(1));
auto error_case = [cntx](ErrorReply&& error) { (*cntx)->SendError(error); };
auto update_case = [username, cntx](User::UpdateRequest&& req) {
ServerState::tlocal()->user_registry->MaybeAddAndUpdate(username, std::move(req));
(*cntx)->SendOk();
};

std::visit(Overloaded{error_case, update_case}, std::move(req));
}

using CI = dfly::CommandId;

#define HFUNC(x) SetHandler(&AclFamily::x)

constexpr uint32_t kList = acl::ADMIN | acl::SLOW | acl::DANGEROUS;
constexpr uint32_t kSetUser = acl::ADMIN | acl::SLOW | acl::DANGEROUS;

// We can't implement the ACL commands and its respective subcommands LIST, CAT, etc
// the usual way, (that is, one command called ACL which then dispatches to the subcommand
// based on the secocond argument) because each of the subcommands has different ACL
Expand All @@ -76,6 +179,8 @@ void AclFamily::Register(dfly::CommandRegistry* registry) {
*registry << CI{"ACL", CO::NOSCRIPT | CO::LOADING, 0, 0, 0, 0, acl::kList}.HFUNC(Acl);
*registry << CI{"ACL LIST", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, 1, 0, 0, 0, acl::kList}.HFUNC(
List);
*registry << CI{"ACL SETUSER", CO::ADMIN | CO::NOSCRIPT | CO::LOADING, -2, 0, 0, 0, acl::kSetUser}
.HFUNC(SetUser);
}

#undef HFUNC
Expand Down
1 change: 1 addition & 0 deletions src/server/acl/acl_family.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class AclFamily {
private:
static void Acl(CmdArgList args, ConnectionContext* cntx);
static void List(CmdArgList args, ConnectionContext* cntx);
static void SetUser(CmdArgList args, ConnectionContext* cntx);
};

} // namespace acl
Expand Down
36 changes: 35 additions & 1 deletion tests/dragonfly/acl_family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,38 @@ async def test_acl_list_default_user(async_client):

result = await async_client.execute_command("ACL LIST")
assert 1 == len(result)
assert "user default on nopass +@all" == result[0]
assert "user default on nopass +@ALL" == result[0]


@pytest.mark.asyncio
async def test_acl_setuser(async_client):
# Bad input
with pytest.raises(redis.exceptions.ResponseError):
await async_client.execute_command("ACL SETUSER")

await async_client.execute_command("ACL SETUSER kostas")
result = await async_client.execute_command("ACL LIST")
assert 2 == len(result)
assert "user kostas off nopass +@NONE" in result

await async_client.execute_command("ACL SETUSER kostas ON")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@NONE" in result

await async_client.execute_command("ACL SETUSER kostas +@list +@string +@admin")
result = await async_client.execute_command("ACL LIST")
# TODO consider printing to lowercase
assert "user kostas on nopass +@LIST +@STRING +@ADMIN" in result

await async_client.execute_command("ACL SETUSER kostas -@list -@admin")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@STRING" in result

# mix and match
await async_client.execute_command("ACL SETUSER kostas +@list -@string")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@LIST" in result

await async_client.execute_command("ACL SETUSER kostas +@all")
result = await async_client.execute_command("ACL LIST")
assert "user kostas on nopass +@ALL" in result