Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions src/facade/dragonfly_connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ ABSL_FLAG(std::string, admin_bind, "",
// Per its description: memory budget, in bytes, for the request cache of each IO thread.
// The default value is 1ULL << 26.
ABSL_FLAG(std::uint64_t, request_cache_limit, 1ULL << 26, // 64MB
"Amount of memory to use for request cache in bytes - per IO thread.");

// Defaults to false. Connection::HandleRequests() reads this flag when a TLS context is set:
// if the flag is true and IsAdmin() holds for the connection, no TlsSocket is created and no
// handshake is performed, so the plain socket is used as the peer.
ABSL_FLAG(bool, no_tls_on_admin_port, false, "Allow non-tls connections on admin port");

using namespace util;
using namespace std;
using nonstd::make_unexpected;
Expand Down Expand Up @@ -300,22 +302,24 @@ void Connection::HandleRequests() {

auto remote_ep = lsb->RemoteEndpoint();

FiberSocketBase* peer = socket_.get();
#ifdef DFLY_USE_SSL
unique_ptr<tls::TlsSocket> tls_sock;
if (ctx_) {
tls_sock.reset(new tls::TlsSocket(socket_.get()));
tls_sock->InitSSL(ctx_);

FiberSocketBase::AcceptResult aresult = tls_sock->Accept();
if (!aresult) {
LOG(WARNING) << "Error handshaking " << aresult.error().message();
return;
const bool no_tls_on_admin_port = absl::GetFlag(FLAGS_no_tls_on_admin_port);
if (!(IsAdmin() && no_tls_on_admin_port)) {
tls_sock.reset(new tls::TlsSocket(socket_.get()));
tls_sock->InitSSL(ctx_);
FiberSocketBase::AcceptResult aresult = tls_sock->Accept();

if (!aresult) {
LOG(WARNING) << "Error handshaking " << aresult.error().message();
return;
}
peer = tls_sock.get();
VLOG(1) << "TLS handshake succeeded";
}
VLOG(1) << "TLS handshake succeeded";
}
FiberSocketBase* peer = tls_sock ? (FiberSocketBase*)tls_sock.get() : socket_.get();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why was this removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I simplified it a little bit, it gets initialized at the top :)

Copy link
Contributor

@adiholden adiholden Jul 5, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am missing something, but in the old flow tls_sock is assigned to peer if tls_sock is not null. I don't see that you update peer with tls_sock in the new flow.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is it working?
I would expect the test you wrote to fail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I messed this up while refactoring, and a HUGE thumbs up for catching this. It works because we only test the admin port and not the TLS connection itself -- we have 0 tests for that (although my other PR addresses this and it would have been caught there).

#else
FiberSocketBase* peer = socket_.get();
#endif

io::Result<bool> http_res{false};
Expand All @@ -335,7 +339,6 @@ void Connection::HandleRequests() {
http_conn.ReleaseSocket();
} else {
cc_.reset(service_->CreateContext(peer, this));

auto* us = static_cast<LinuxSocketBase*>(socket_.get());
if (breaker_cb_) {
break_poll_id_ =
Expand Down
5 changes: 5 additions & 0 deletions tests/dragonfly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
import subprocess
import aiohttp
import os
from prometheus_client.parser import text_string_to_metric_families
from redis.asyncio import Redis as RedisClient

Expand Down Expand Up @@ -159,6 +160,10 @@ def stop_all(self):
def __str__(self):
return f"Factory({self.args})"

@property
def dfly_path(self):
    """
    Return, as a string, the directory component of ``self.params.path``.

    The directory is derived from the configured path instead of being spelled
    out, so no binary name is hard-coded here (this was raised in review).
    Note that the result carries no trailing path separator.
    """
    binary_location = self.params.path
    containing_dir = os.path.dirname(binary_location)
    return str(containing_dir)


def dfly_args(*args):
""" Used to define a singular set of arguments for dragonfly test """
Expand Down
43 changes: 43 additions & 0 deletions tests/dragonfly/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import redis
import pymemcache
import random
import subprocess

from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -225,3 +226,45 @@ def port_picker():
@pytest.fixture(scope="class")
def memcached_connection(df_server: DflyInstance):
return pymemcache.Client(f"localhost:{df_server.mc_port}")


@pytest.fixture(scope="session")
def gen_tls_cert(df_factory: DflyInstanceFactory):
    """
    Generate a throw-away CA and a CA-signed server key/certificate with openssl.

    All files are written into the directory reported by ``df_factory.dfly_path``.

    Returns:
        ``(tls_server_key_file_name, tls_server_cert_file_name)`` -- bare file
        names, not full paths; consumers join them with ``df_factory.dfly_path``.

    Raises:
        subprocess.CalledProcessError: if any of the openssl steps exits non-zero.
    """
    # NOTE(review): a reviewer suggested returning the full key/certificate paths
    # and turning these two names into constants. The returned shape is kept as-is
    # here so that existing consumers of this fixture keep working.
    tls_server_key_file_name = "df-key.pem"
    tls_server_cert_file_name = "df-cert.pem"

    # dfly_path is produced by os.path.dirname() and has no trailing separator.
    # Joining through Path puts the files inside that directory; plain string
    # concatenation would fuse the directory name with the file name.
    out_dir = Path(df_factory.dfly_path)

    # NOTE(review): the trailing "[email protected]" component of both subjects
    # looks like an obfuscated e-mail address. openssl expects every component to
    # be of the form type=value (e.g. emailAddress=...); confirm against upstream.
    ca_subject = "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=AcmeStudios/CN=Gr/[email protected]"
    server_subject = "/C=GR/ST=SKG/L=Thessaloniki/O=KK/OU=Comp/CN=Gr/[email protected]"

    # Step 1
    # Generate the CA (certificate authority) key and self-signed certificate.
    # In production, the CA should be provided by a third party authority.
    # The certificate expires in one day and the key is not encrypted (-nodes).
    # X.509 format for the certificate.
    ca_key = str(out_dir / "ca-key.pem")
    ca_cert = str(out_dir / "ca-cert.pem")
    # Argument lists (no shell) keep paths with spaces intact; check=True makes a
    # failed openssl call abort the fixture instead of being silently ignored.
    subprocess.run(
        ["openssl", "req", "-x509", "-newkey", "rsa:4096", "-days", "1", "-nodes",
         "-keyout", ca_key, "-out", ca_cert, "-subj", ca_subject],
        check=True)

    # Step 2
    # Generate Dragonfly's private key and certificate signing request (CSR).
    tls_server_key = str(out_dir / tls_server_key_file_name)
    tls_server_req = str(out_dir / "df-req.pem")
    subprocess.run(
        ["openssl", "req", "-newkey", "rsa:4096", "-nodes",
         "-keyout", tls_server_key, "-out", tls_server_req, "-subj", server_subject],
        check=True)

    # Step 3
    # Use the CA's private key to sign Dragonfly's CSR and get back the signed certificate.
    tls_server_cert = str(out_dir / tls_server_cert_file_name)
    subprocess.run(
        ["openssl", "x509", "-req", "-in", tls_server_req, "-days", "1",
         "-CA", ca_cert, "-CAkey", ca_key, "-CAcreateserial", "-out", tls_server_cert],
        check=True)

    return tls_server_key_file_name, tls_server_cert_file_name


@pytest.fixture(scope="session")
def with_tls_args(df_factory: DflyInstanceFactory, gen_tls_cert):
    """
    Build the keyword arguments that tests pass to ``df_local_factory.create``
    to start an instance with the generated key/certificate and with
    ``no_tls_on_admin_port`` set to ``"true"``.

    The key/certificate paths are resolved against ``df_factory.dfly_path``,
    the same directory ``gen_tls_cert`` writes into.
    """
    tls_server_key_file_name, tls_server_cert_file_name = gen_tls_cert
    cert_dir = Path(df_factory.dfly_path)
    return {
        "tls": "",
        "tls_key_file": str(cert_dir / tls_server_key_file_name),
        "tls_cert_file": str(cert_dir / tls_server_cert_file_name),
        "no_tls_on_admin_port": "true",
    }
19 changes: 19 additions & 0 deletions tests/dragonfly/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest
import asyncio
from redis import asyncio as aioredis
from redis.exceptions import ConnectionError as redis_conn_error
import async_timeout

from . import DflyInstance, dfly_args
Expand Down Expand Up @@ -415,3 +416,21 @@ async def test_large_cmd(async_client: aioredis.Redis):

res = await async_client.mget([f"key{i}" for i in range(MAX_ARR_SIZE)])
assert len(res) == MAX_ARR_SIZE


@pytest.mark.asyncio
async def test_reject_non_tls_connections_on_tls_master(with_tls_args, df_local_factory):
    """
    Start an instance with the TLS arguments and check that a plain (non-TLS)
    client fails on the regular port but can still PING the admin port.
    """
    master = df_local_factory.create(admin_port=1111, port=1211, **with_tls_args)
    master.start()

    # Try to connect on master without admin port. This should fail.
    client = aioredis.Redis(port=master.port)
    try:
        await client.execute_command("DBSIZE")
    except redis_conn_error:
        pass
    else:
        # Raising a plain string is itself a TypeError in Python 3, which would
        # hide the message; pytest.fail reports it as a regular test failure.
        pytest.fail("Non tls connection connected on master with tls. This should NOT happen")

    # Try to connect on master on admin port
    client = aioredis.Redis(port=master.admin_port)
    assert await client.ping()
29 changes: 29 additions & 0 deletions tests/dragonfly/replication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import logging

BASE_PORT = 1111
ADMIN_PORT = 1211

DISCONNECT_CRASH_FULL_SYNC = 0
DISCONNECT_CRASH_STABLE_SYNC = 1
Expand Down Expand Up @@ -1216,3 +1217,31 @@ async def seed():

assert await c_master.execute_command("role") == [b'master', [[b'127.0.0.1', bytes(str(replica.port), 'ascii'), b'stable_sync']]]
assert await c_replica.execute_command("role") == [b'replica', b'localhost', bytes(str(master.port), 'ascii'), b'stable_sync']


# Each case is a pair:
#   (number of master threads, number of threads for the replica)
replication_cases = [(8, 8)]


@pytest.mark.asyncio
@pytest.mark.parametrize("t_master, t_replica", replication_cases)
async def test_no_tls_on_admin_port(df_local_factory, df_seeder_factory, t_master, t_replica, with_tls_args):
    """
    Replicate between two instances started with the TLS arguments, talking to
    both of them only through their admin ports with plain (non-TLS) clients.
    """
    # Step 1: start the master with the TLS arguments and fill it through the admin port.
    master = df_local_factory.create(
        admin_port=ADMIN_PORT, **with_tls_args, port=BASE_PORT, proactor_threads=t_master)
    master.start()
    c_master = aioredis.Redis(port=master.admin_port)
    await c_master.execute_command("DEBUG POPULATE 100")
    master_key_count = await c_master.execute_command("DBSIZE")
    assert master_key_count == 100

    # Step 2: start the replica and point it at the master's admin port.
    replica = df_local_factory.create(
        admin_port=ADMIN_PORT + 1, **with_tls_args, port=BASE_PORT + 1, proactor_threads=t_replica)
    replica.start()
    c_replica = aioredis.Redis(port=replica.admin_port)
    replicaof_reply = await c_replica.execute_command(f"REPLICAOF localhost {master.admin_port}")
    assert replicaof_reply == b"OK"
    await check_all_replicas_finished([c_replica], c_master)

    # Step 3: the replica must hold as many keys as were populated on the master.
    replica_key_count = await c_replica.execute_command("DBSIZE")
    assert replica_key_count == 100