Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions boltstub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# limitations under the License.


import socket
import time
import traceback
from copy import deepcopy
Expand All @@ -25,7 +26,6 @@
TCPServer,
ThreadingMixIn,
)
from sys import stdout
from threading import (
Lock,
Thread,
Expand Down Expand Up @@ -56,18 +56,27 @@ class BoltStubServer(TCPServer):

timed_out = False

def __init__(self, *args, **kwargs):
super(BoltStubServer, self).__init__(*args, **kwargs)
_ipv6: bool

def __init__(self, *args, ipv6=False, **kwargs) -> None:
self._ipv6 = ipv6
if self._ipv6:
self.address_family = socket.AF_INET6
super().__init__(*args, **kwargs)

def handle_timeout(self):
self.timed_out = True

def server_activate(self):
super(BoltStubServer, self).server_activate()
def server_bind(self) -> None:
if self._ipv6:
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
super().server_bind()

def server_activate(self) -> None:
super().server_activate()
# Must be here, testkit waits for something to be written on stdout to
# know when the server is listening.
print("Listening")
stdout.flush()
print("Listening", flush=True)


class ThreadedBoltStubServer(ThreadingMixIn, BoltStubServer):
Expand All @@ -86,13 +95,14 @@ class BoltStubService:
def load(cls, *script_filenames, **kwargs):
return cls(*map(parse_file, script_filenames), **kwargs)

def __init__(self, script: Script, listen_addr=None, timeout=None):
def __init__(self, script: Script, listen_addr=None, timeout=None,
ipv6=False):
if listen_addr:
listen_addr = Address.parse(listen_addr)
else:
listen_addr = Address(("localhost", self.default_base_port))
self.host = listen_addr.host
self.address = Address((listen_addr.host, listen_addr.port_number))
self.address = Address(listen_addr)
self.script = script
self.exceptions = []
self.actors = []
Expand Down Expand Up @@ -153,7 +163,8 @@ def finish(self):
server_cls = ThreadedBoltStubServer
else:
server_cls = BoltStubServer
self.server = server_cls(self.address, BoltStubRequestHandler)
self.server = server_cls(self.address, BoltStubRequestHandler,
ipv6=ipv6)
self.server.timeout = timeout or self.default_timeout

def start(self):
Expand Down
6 changes: 5 additions & 1 deletion boltstub/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def _main():
"-v", "--verbose", action="store_true",
help="Show more detail about the client-server exchange."
)
parser.add_argument(
"-6", "--ipv6", action="store_true",
help="Force the server to only listen on IPv6 interfaces."
)
parser.add_argument("script", nargs="+")
parsed = parser.parse_args()

Expand All @@ -103,7 +107,7 @@ def _main():

scripts = map(parse_file, parsed.script)
service = BoltStubService(*scripts, listen_addr=parsed.listen_addr,
timeout=parsed.timeout)
timeout=parsed.timeout, ipv6=parsed.ipv6)

try:
service.start()
Expand Down
4 changes: 2 additions & 2 deletions boltstub/addressing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def parse(cls, s, default_host=None, default_port=None):
# IPv6
host, _, port = s[1:].rpartition("]")
return cls((host or default_host or "localhost",
port.lstrip(":") or default_port or 0,
int(port.lstrip(":")) or default_port or 0,
0, 0))
else:
# IPv4
host, _, port = s.partition(":")
return cls((host or default_host or "localhost",
port or default_port or 0))
int(port) or default_port or 0))
else:
raise TypeError("Address.parse requires a string argument")

Expand Down
10 changes: 5 additions & 5 deletions boltstub/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,20 @@ def __repr__(self):

class Line(str, abc.ABC):
def __new__(cls, line_number: int, raw_line, content: str):
obj = super(Line, cls).__new__(cls, raw_line)
obj = super().__new__(cls, raw_line)
obj.line_number = line_number
obj.content = content
return obj

def __str__(self):
return "({:3}) {}".format(self.line_number,
super(Line, self).__str__())
super().__str__())

def __repr__(self):
return "<{}>{}".format(self.__class__.__name__, self.__str__())

def __getnewargs__(self):
return self.line_number, super(Line, self).__str__(), self.content
return self.line_number, super().__str__(), self.content

@abc.abstractmethod
def canonical(self):
Expand Down Expand Up @@ -390,7 +390,7 @@ class ServerLine(MessageLine):
always_parse = False

def __new__(cls, *args, **kwargs):
obj = super(ServerLine, cls).__new__(cls, *args, **kwargs)
obj = super().__new__(cls, *args, **kwargs)
obj.command_match = re.match(r"^<(.+?)>(.*)$", obj.content)
obj.is_command = bool(obj.command_match)
if not obj.is_command:
Expand Down Expand Up @@ -623,7 +623,7 @@ def __init__(self, line: AutoLine, line_number: int):
# A: RESET
# This is to avoid ambiguity when it comes to `?:`, `*:`, and `+:`
# macros.
super(AutoBlock, self).__init__([line], line_number)
super().__init__([line], line_number)

def _consume(self, channel):
msg = channel.consume(self.lines[self.index].line_number)
Expand Down
6 changes: 3 additions & 3 deletions boltstub/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def bright_white(s):
class ColourFormatter(Formatter):

def format(self, record):
s = super(ColourFormatter, self).format(record)
s = super().format(record)
bits = s.split(" ", maxsplit=1)
bits[0] = bright_black(bits[0])
if record.levelno == CRITICAL:
Expand All @@ -121,13 +121,13 @@ def formatTime(self, record, datefmt=None): # noqa: N802
return f"{t}.{ms:03d}"


class Watcher(object):
class Watcher:
"""Log watcher for monitoring driver and protocol activity."""

handlers = {}

def __init__(self, logger_name):
super(Watcher, self).__init__()
super().__init__()
self.logger_name = logger_name
self.logger = getLogger(self.logger_name)
self.formatter = ColourFormatter("%(asctime)s %(message)s")
Expand Down
2 changes: 1 addition & 1 deletion boltstub/wiring.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def send(self, payload) -> int:
sendall = send


class Wire(object):
class Wire:
"""Buffered socket wrapper for reading and writing bytes."""

_closed = False
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _exit():
# address to be able to start services on the network that the driver
# connects to (stub server and TLS server).
for network in networks:
cmd = ["docker", "network", "create", network]
cmd = ["docker", "network", "create", "--ipv6", network]
print(cmd)
subprocess.run(cmd)

Expand Down
7 changes: 4 additions & 3 deletions tests/neo4j/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from nutkit.frontend import Driver
from nutkit.protocol import AuthorizationToken
from tests.shared import (
dns_resolve_single,
dns_resolve,
Potential,
TestkitTestCase,
)
Expand Down Expand Up @@ -56,9 +56,10 @@ def get_neo4j_host_and_port():
return host, port


def get_neo4j_resolved_host_and_port():
def get_neo4j_host_and_port_resolutions():
host, port = get_neo4j_host_and_port()
return dns_resolve_single(host), port
return [(resolved, port)
for resolved in dns_resolve(host, ipv4=True, ipv6=True)]


def get_neo4j_host_and_http_port():
Expand Down
17 changes: 12 additions & 5 deletions tests/neo4j/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
cluster_unsafe_test,
get_driver,
get_neo4j_host_and_port,
get_neo4j_resolved_host_and_port,
get_neo4j_host_and_port_resolutions,
get_server_info,
QueryBuilder,
requires_multi_db_support,
)
from tests.shared import TestkitTestCase
from tests.shared import (
format_address,
TestkitTestCase,
)


class TestSummary(TestkitTestCase):
Expand Down Expand Up @@ -120,9 +123,13 @@ def test_agent_string(self):
@cluster_unsafe_test # routing can lead us to another server (address)
def test_address(self):
summary = self.get_summary("RETURN 1 AS number")
self.assertTrue(summary.server_info.address in
["%s:%s" % get_neo4j_resolved_host_and_port(),
"%s:%s" % get_neo4j_host_and_port()])

expected = [
format_address(host, port)
for (host, port) in
(*get_neo4j_host_and_port_resolutions(), get_neo4j_host_and_port())
]
self.assertTrue(summary.server_info.address in expected)

def _assert_counters(self, summary, nodes_created=0, nodes_deleted=0,
relationships_created=0, relationships_deleted=0,
Expand Down
29 changes: 22 additions & 7 deletions tests/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,36 @@ def pick_address(adapter_):
return ips


def dns_resolve(host_name):
_, _, ip_addresses = socket.gethostbyname_ex(host_name)
return ip_addresses
def dns_resolve(host_name, ipv4=True, ipv6=False):
return list({
address[0]
for family, _type, _proto, _canonname, address in socket.getaddrinfo(
host_name, None
)
if (
(family == socket.AF_INET and ipv4)
or (family == socket.AF_INET6 and ipv6)
)
})


def dns_resolve_single(host_name):
ips = dns_resolve(host_name)
def dns_resolve_single(host_name, ipv4=True, ipv6=False):
ips = dns_resolve(host_name, ipv4=ipv4, ipv6=ipv6)
if len(ips) != 1:
raise ValueError("%s resolved to %i instead of 1 IP address"
% (host_name, len(ips)))
return ips[0]


def get_dns_resolved_server_address(server):
return "%s:%i" % (dns_resolve_single(server.host), server.port)
def format_address(host, port):
if ":" in host:
return f"[{host}]:{port}"
return f"{host}:{port}"


def get_dns_resolved_server_address(server, ipv4=True, ipv6=False):
host = dns_resolve_single(server.host, ipv4=ipv4, ipv6=ipv6)
return format_address(host, server.port)


def driver_feature(*features):
Expand Down
44 changes: 28 additions & 16 deletions tests/stub/routing/_routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,7 @@
class RoutingBase(TestkitTestCase):
def setUp(self):
super().setUp()
self._routingServer1 = StubServer(9000)
self._routingServer2 = StubServer(9001)
self._routingServer3 = StubServer(9002)
self._readServer1 = StubServer(9010)
self._readServer2 = StubServer(9011)
self._readServer3 = StubServer(9012)
self._writeServer1 = StubServer(9020)
self._writeServer2 = StubServer(9021)
self._writeServer3 = StubServer(9022)
self._uri_template = "neo4j://%s:%d"
self._uri_template_with_context = \
self._uri_template + "?region=china&policy=my_policy"
self._uri_with_context = self._uri_template_with_context % (
self._routingServer1.host, self._routingServer1.port
)
self.set_up_servers()
self._auth = types.AuthorizationToken(
"basic", principal="p", credentials="c"
)
Expand All @@ -45,6 +31,26 @@ def tearDown(self):
self._writeServer3.reset()
super().tearDown()

def set_up_servers(self, ipv6=False):
self._routingServer1 = StubServer(9000, ipv6=ipv6)
self._routingServer2 = StubServer(9001, ipv6=ipv6)
self._routingServer3 = StubServer(9002, ipv6=ipv6)
self._readServer1 = StubServer(9010, ipv6=ipv6)
self._readServer2 = StubServer(9011, ipv6=ipv6)
self._readServer3 = StubServer(9012, ipv6=ipv6)
self._writeServer1 = StubServer(9020, ipv6=ipv6)
self._writeServer2 = StubServer(9021, ipv6=ipv6)
self._writeServer3 = StubServer(9022, ipv6=ipv6)
self._set_up_uris()

def _set_up_uris(self):
self._uri_template = "neo4j://%s"
self._uri_template_with_context = \
self._uri_template + "?region=china&policy=my_policy"
self._uri_with_context = self._uri_template_with_context % (
self._routingServer1.address
)

@property
@abstractmethod
def bolt_version(self):
Expand All @@ -60,9 +66,15 @@ def server_agent(self):
def adb(self):
pass

def get_vars(self, host=None):
def host_in_address(self, host=None):
if host is None:
host = self._routingServer1.host
if ":" in host:
host = f"[{host}]"
return host

def get_vars(self, host=None):
host = self.host_in_address(host)
v = {
"#VERSION#": self.bolt_version,
"#HOST#": host,
Expand Down
6 changes: 4 additions & 2 deletions tests/stub/routing/test_routing_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ class RoutingV3(RoutingV4x4):
server_agent = "Neo4j/3.5.0"

def get_vars(self, host=None):
if host is None:
host = self._routingServer1.host
host = self.host_in_address(host)
v = {
"#VERSION#": self.bolt_version,
"#HOST#": host,
Expand Down Expand Up @@ -82,3 +81,6 @@ def test_should_fail_on_empty_routing_response(self):

def test_should_drop_connections_failing_liveness_check(self):
super().test_should_drop_connections_failing_liveness_check()

def test_ipv6_read(self):
super().test_ipv6_read()
6 changes: 4 additions & 2 deletions tests/stub/routing/test_routing_v4x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ class RoutingV4x1(RoutingV4x4):
server_agent = "Neo4j/4.1.0"

def get_vars(self, host=None):
if host is None:
host = self._routingServer1.host
host = self.host_in_address(host)
v = {
"#VERSION#": self.bolt_version,
"#HOST#": host,
Expand Down Expand Up @@ -65,3 +64,6 @@ def test_should_pass_system_bookmark_when_getting_rt_for_multi_db(self):

def test_should_ignore_system_bookmark_when_getting_rt_for_multi_db(self):
pass

def test_ipv6_read(self):
super().test_ipv6_read()
Loading