Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions net/mptcp/protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ static u64 mptcp_wnd_end(const struct mptcp_sock *msk)

static const struct proto_ops *mptcp_fallback_tcp_ops(const struct sock *sk)
{
unsigned short family = READ_ONCE(sk->sk_family);

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
if (sk->sk_prot == &tcpv6_prot)
if (family == AF_INET6)
return &inet6_stream_ops;
#endif
WARN_ON_ONCE(sk->sk_prot != &tcp_prot);
WARN_ON_ONCE(family != AF_INET);
return &inet_stream_ops;
}

Expand Down
8 changes: 8 additions & 0 deletions net/mptcp/subflow.c
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,10 @@ void __init mptcp_subflow_init(void)
tcp_prot_override = tcp_prot;
tcp_prot_override.release_cb = tcp_release_cb_override;
tcp_prot_override.diag_destroy = tcp_abort_override;
#ifdef CONFIG_BPF_SYSCALL
/* Disable sockmap processing for subflows */
tcp_prot_override.psock_update_sk_prot = NULL;
#endif

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
/* In struct mptcp_subflow_request_sock, we assume the TCP request sock
Expand Down Expand Up @@ -2180,6 +2184,10 @@ void __init mptcp_subflow_init(void)
tcpv6_prot_override = tcpv6_prot;
tcpv6_prot_override.release_cb = tcp_release_cb_override;
tcpv6_prot_override.diag_destroy = tcp_abort_override;
#ifdef CONFIG_BPF_SYSCALL
/* Disable sockmap processing for subflows */
tcpv6_prot_override.psock_update_sk_prot = NULL;
#endif
#endif

mptcp_diag_subflow_init(&subflow_ulp_ops);
Expand Down
141 changes: 141 additions & 0 deletions tools/testing/selftests/bpf/prog_tests/mptcp.c
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
#include <netinet/in.h>
#include <test_progs.h>
#include <unistd.h>
#include <error.h>
#include "cgroup_helpers.h"
#include "network_helpers.h"
#include "socket_helpers.h"
#include "mptcp_sock.skel.h"
#include "mptcpify.skel.h"
#include "mptcp_subflow.skel.h"
#include "mptcp_sockmap.skel.h"

#define NS_TEST "mptcp_ns"
#define ADDR_1 "10.0.1.1"
Expand Down Expand Up @@ -436,6 +439,142 @@ static void test_subflow(void)
close(cgroup_fd);
}

/* Test sockmap on MPTCP server handling non-mp-capable clients. */
static void test_sockmap_with_mptcp_fallback(struct mptcp_sockmap *skel)
{
int listen_fd = -1, client_fd1 = -1, client_fd2 = -1;
int server_fd1 = -1, server_fd2 = -1, sent, recvd;
char snd[9] = "123456789";
char rcv[10];

/* start server with MPTCP enabled */
listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
if (!ASSERT_OK_FD(listen_fd, "sockmap-fb:start_mptcp_server"))
return;

skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd));
skel->bss->sk_index = 0;
/* create client without MPTCP enabled */
client_fd1 = connect_to_fd_opts(listen_fd, NULL);
if (!ASSERT_OK_FD(client_fd1, "sockmap-fb:connect_to_fd"))
goto end;

server_fd1 = xaccept_nonblock(listen_fd, NULL, NULL);
skel->bss->sk_index = 1;
client_fd2 = connect_to_fd_opts(listen_fd, NULL);
if (!ASSERT_OK_FD(client_fd2, "sockmap-fb:connect_to_fd"))
goto end;

server_fd2 = xaccept_nonblock(listen_fd, NULL, NULL);
/* test normal redirect behavior: data sent by client_fd1 can be
* received by client_fd2
*/
skel->bss->redirect_idx = 1;
sent = xsend(client_fd1, snd, sizeof(snd), 0);
if (!ASSERT_EQ(sent, sizeof(snd), "sockmap-fb:xsend(client_fd1)"))
goto end;

/* try to recv more bytes to avoid truncation check */
recvd = recv_timeout(client_fd2, rcv, sizeof(rcv), MSG_DONTWAIT, 2);
if (!ASSERT_EQ(recvd, sizeof(snd), "sockmap-fb:recv(client_fd2)"))
goto end;

end:
if (client_fd1 >= 0)
close(client_fd1);
if (client_fd2 >= 0)
close(client_fd2);
if (server_fd1 >= 0)
close(server_fd1);
if (server_fd2 >= 0)
close(server_fd2);
close(listen_fd);
}

/* Test sockmap rejection of MPTCP sockets - both server and client sides. */
static void test_sockmap_reject_mptcp(struct mptcp_sockmap *skel)
{
int listen_fd = -1, server_fd = -1, client_fd1 = -1;
int err, zero = 0;

/* start server with MPTCP enabled */
listen_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
if (!ASSERT_OK_FD(listen_fd, "start_mptcp_server"))
return;

skel->bss->trace_port = ntohs(get_socket_local_port(listen_fd));
skel->bss->sk_index = 0;
/* create client with MPTCP enabled */
client_fd1 = connect_to_fd(listen_fd, 0);
if (!ASSERT_OK_FD(client_fd1, "connect_to_fd client_fd1"))
goto end;

/* bpf_sock_map_update() called from sockops should reject MPTCP sk */
if (!ASSERT_EQ(skel->bss->helper_ret, -EOPNOTSUPP, "should reject"))
goto end;

server_fd = xaccept_nonblock(listen_fd, NULL, NULL);
err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map),
&zero, &server_fd, BPF_NOEXIST);
if (!ASSERT_EQ(err, -EOPNOTSUPP, "server should be disallowed"))
goto end;

/* MPTCP client should also be disallowed */
err = bpf_map_update_elem(bpf_map__fd(skel->maps.sock_map),
&zero, &client_fd1, BPF_NOEXIST);
if (!ASSERT_EQ(err, -EOPNOTSUPP, "client should be disallowed"))
goto end;
end:
if (client_fd1 >= 0)
close(client_fd1);
if (server_fd >= 0)
close(server_fd);
close(listen_fd);
}

static void test_mptcp_sockmap(void)
{
struct mptcp_sockmap *skel;
struct netns_obj *netns;
int cgroup_fd, err;

cgroup_fd = test__join_cgroup("/mptcp_sockmap");
if (!ASSERT_OK_FD(cgroup_fd, "join_cgroup: mptcp_sockmap"))
return;

skel = mptcp_sockmap__open_and_load();
if (!ASSERT_OK_PTR(skel, "skel_open_load: mptcp_sockmap"))
goto close_cgroup;

skel->links.mptcp_sockmap_inject =
bpf_program__attach_cgroup(skel->progs.mptcp_sockmap_inject, cgroup_fd);
if (!ASSERT_OK_PTR(skel->links.mptcp_sockmap_inject, "attach sockmap"))
goto skel_destroy;

err = bpf_prog_attach(bpf_program__fd(skel->progs.mptcp_sockmap_redirect),
bpf_map__fd(skel->maps.sock_map),
BPF_SK_SKB_STREAM_VERDICT, 0);
if (!ASSERT_OK(err, "bpf_prog_attach stream verdict"))
goto skel_destroy;

netns = netns_new(NS_TEST, true);
if (!ASSERT_OK_PTR(netns, "netns_new: mptcp_sockmap"))
goto skel_destroy;

if (endpoint_init("subflow") < 0)
goto close_netns;

test_sockmap_with_mptcp_fallback(skel);
test_sockmap_reject_mptcp(skel);

close_netns:
netns_free(netns);
skel_destroy:
mptcp_sockmap__destroy(skel);
close_cgroup:
close(cgroup_fd);
}

void test_mptcp(void)
{
if (test__start_subtest("base"))
Expand All @@ -444,4 +583,6 @@ void test_mptcp(void)
test_mptcpify();
if (test__start_subtest("subflow"))
test_subflow();
if (test__start_subtest("sockmap"))
test_mptcp_sockmap();
}
43 changes: 43 additions & 0 deletions tools/testing/selftests/bpf/progs/mptcp_sockmap.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// SPDX-License-Identifier: GPL-2.0

#include "bpf_tracing_net.h"

char _license[] SEC("license") = "GPL";

int sk_index;
int redirect_idx;
int trace_port;
int helper_ret;
struct {
__uint(type, BPF_MAP_TYPE_SOCKMAP);
__uint(key_size, sizeof(__u32));
__uint(value_size, sizeof(__u32));
__uint(max_entries, 100);
} sock_map SEC(".maps");

SEC("sockops")
int mptcp_sockmap_inject(struct bpf_sock_ops *skops)
{
struct bpf_sock *sk;

/* only accept specified connection */
if (skops->local_port != trace_port ||
skops->op != BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB)
return 1;

sk = skops->sk;
if (!sk)
return 1;

/* update sk handler */
helper_ret = bpf_sock_map_update(skops, &sock_map, &sk_index, BPF_NOEXIST);

return 1;
}

SEC("sk_skb/stream_verdict")
int mptcp_sockmap_redirect(struct __sk_buff *skb)
{
/* redirect skb to the sk under sock_map[redirect_idx] */
return bpf_sk_redirect_map(skb, &sock_map, redirect_idx, 0);
}
Loading