Contributors: 3
Author Tokens Token Proportion Commits Commit Proportion
Jordan Rife 1909 66.42% 2 50.00%
Martin KaFai Lau 964 33.54% 1 25.00%
Alexei Starovoitov 1 0.03% 1 25.00%
Total 2874 4


// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2024 Meta

#include <test_progs.h>
#include "network_helpers.h"
#include "sock_iter_batch.skel.h"

/* Name of the network namespace that test_sock_iter_batch() creates, enters
 * and deletes around the subtests.
 */
#define TEST_NS "sock_iter_batch_netns"

/* Number of sockets the "force a realloc" resume test starts with before it
 * doubles the bucket mid-iteration.
 * NOTE(review): presumably mirrors the iterator's initial batch capacity in
 * the kernel; confirm against the iterator implementation.
 */
static const int init_batch_size = 16;
/* Number of sockets started per port by start_reuseport_server() in do_test()
 * and in most of the resume tests.
 */
static const int nr_soreuse = 4;

/* One record as read from the iterator fd.
 *
 * Packed so that there is no padding after idx; readers compare the number of
 * bytes returned by read() against sizeof(struct iter_out).
 * NOTE(review): the layout presumably has to match what the BPF iterator
 * program writes; confirm against the BPF side.
 */
struct iter_out {
	/* do_test() uses this as an index into its fds[] array of buckets. */
	int idx;
	/* Compared against the value returned by socket_cookie(). */
	__u64 cookie;
} __packed;

/* Per-socket bookkeeping entry used to count how often a cookie was returned
 * by the iterator.
 */
struct sock_count {
	/* Socket cookie; insert() treats a zero cookie as a free slot. */
	__u64 cookie;
	/* Number of times insert() was called for this cookie. */
	int count;
};

/* Record one sighting of @cookie in @counts.
 *
 * The slot that already holds @cookie is reused when one exists; otherwise
 * the last free slot (cookie == 0) encountered during the scan is claimed.
 *
 * Returns the updated sighting count for the cookie, or -1 when there is
 * neither a matching nor a free slot.
 */
static int insert(__u64 cookie, struct sock_count counts[], int counts_len)
{
	struct sock_count *slot = NULL;
	int idx;

	for (idx = 0; idx < counts_len; idx++) {
		if (counts[idx].cookie == 0) {
			/* Free slot: remember it, but keep looking for a match. */
			slot = &counts[idx];
			continue;
		}
		if (counts[idx].cookie == cookie) {
			slot = &counts[idx];
			break;
		}
	}

	if (!slot)
		return -1;

	slot->cookie = cookie;
	return ++slot->count;
}

static int read_n(int iter_fd, int n, struct sock_count counts[],
		  int counts_len)
{
	struct iter_out out;
	int nread = 1;
	int i = 0;

	for (; nread > 0 && (n < 0 || i < n); i++) {
		nread = read(iter_fd, &out, sizeof(out));
		if (!nread || !ASSERT_EQ(nread, sizeof(out), "nread"))
			break;
		ASSERT_GE(insert(out.cookie, counts, counts_len), 0, "insert");
	}

	ASSERT_TRUE(n < 0 || i == n, "n < 0 || i == n");

	return i;
}

/* Return the SO_COOKIE value of socket @fd, or 0 if getsockopt() fails (the
 * failure is also reported through ASSERT_OK).
 */
static __u64 socket_cookie(int fd)
{
	__u64 value;
	socklen_t value_len = sizeof(value);
	int err;

	err = getsockopt(fd, SOL_SOCKET, SO_COOKIE, &value, &value_len);
	if (!ASSERT_OK(err, "getsockopt(SO_COOKIE)"))
		return 0;

	return value;
}

/* Tell whether the cookie of socket @fd appears anywhere in @counts.
 *
 * A socket whose cookie cannot be fetched (socket_cookie() returned 0) is
 * reported as not seen.
 */
static bool was_seen(int fd, struct sock_count counts[], int counts_len)
{
	__u64 wanted = socket_cookie(fd);
	int idx;

	if (!wanted)
		return false;

	for (idx = 0; idx < counts_len; idx++) {
		if (counts[idx].cookie == wanted)
			return true;
	}

	return false;
}

/* Return the index of the first socket in @fds whose cookie is recorded in
 * @counts, or -1 if there is none.
 *
 * Note that @n is used as the length of both @fds and @counts.
 */
static int get_seen_socket(int *fds, struct sock_count counts[], int n)
{
	int idx;

	for (idx = 0; idx < n; idx++) {
		if (!was_seen(fds[idx], counts, n))
			continue;
		return idx;
	}

	return -1;
}

static int get_nth_socket(int *fds, int fds_len, struct bpf_link *link, int n)
{
	int i, nread, iter_fd;
	int nth_sock_idx = -1;
	struct iter_out out;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create"))
		return -1;

	for (; n >= 0; n--) {
		nread = read(iter_fd, &out, sizeof(out));
		if (!nread || !ASSERT_GE(nread, 1, "nread"))
			goto done;
	}

	for (i = 0; i < fds_len && nth_sock_idx < 0; i++)
		if (fds[i] >= 0 && socket_cookie(fds[i]) == out.cookie)
			nth_sock_idx = i;
done:
	close(iter_fd);
	return nth_sock_idx;
}

/* Return how many times the cookie of socket @fd was recorded in the first
 * @n entries of @counts, or 0 if it was never recorded or the cookie could
 * not be fetched.
 */
static int get_seen_count(int fd, struct sock_count counts[], int n)
{
	__u64 wanted = socket_cookie(fd);
	int idx;

	if (!wanted)
		return 0;

	for (idx = 0; idx < n; idx++) {
		if (counts[idx].cookie != wanted)
			continue;
		/* Keep scanning if the matching entry has a zero count. */
		if (counts[idx].count)
			return counts[idx].count;
	}

	return 0;
}

/* Assert that exactly @n of the still-open sockets in @fds were seen, and
 * that each of those was seen exactly once.
 *
 * Closed sockets (negative fd) and sockets that were never seen are ignored;
 * a socket seen more than once fails the "seen_cnt" assertion and is not
 * counted towards @n.
 */
static void check_n_were_seen_once(int *fds, int fds_len, int n,
				   struct sock_count counts[], int counts_len)
{
	int exactly_once = 0;
	int hits;
	int idx;

	for (idx = 0; idx < fds_len; idx++) {
		if (fds[idx] < 0)
			continue;

		hits = get_seen_count(fds[idx], counts, counts_len);
		if (!hits)
			continue;

		if (ASSERT_EQ(hits, 1, "seen_cnt"))
			exactly_once++;
	}

	ASSERT_EQ(exactly_once, n, "seen_once");
}

/* Resume scenario: close a socket the iterator has already returned.
 *
 * Reads all but one socket, closes one of the sockets already seen, then
 * drains the iterator and checks that the remaining open sockets were each
 * returned exactly once (the last one was not skipped, nothing repeated).
 *
 * @family, @sock_type, @addr, @port and @link are unused here; they are part
 * of the common test_case callback signature.
 */
static void remove_seen(int family, int sock_type, const char *addr, __u16 port,
			int *socks, int socks_len, struct sock_count *counts,
			int counts_len, struct bpf_link *link, int iter_fd)
{
	const int all_but_one = socks_len - 1;
	int victim;

	/* Stop one socket short of the end of the bucket. */
	read_n(iter_fd, all_but_one, counts, counts_len);
	check_n_were_seen_once(socks, socks_len, all_but_one, counts,
			       counts_len);

	/* Drop an already-visited socket from the bucket. */
	victim = get_seen_socket(socks, counts, counts_len);
	if (!ASSERT_GE(victim, 0, "close_idx"))
		return;
	close(socks[victim]);
	socks[victim] = -1;

	/* Drain the iterator. */
	read_n(iter_fd, -1, counts, counts_len);

	/* One socket is closed now, so all_but_one open sockets remain and
	 * every one of them must have been returned exactly once.
	 */
	check_n_were_seen_once(socks, socks_len, all_but_one, counts,
			       counts_len);
}

/* Resume scenario: close the socket the iterator would return next.
 *
 * Reads a single socket, closes whichever socket a fresh iterator reports in
 * second position, then drains the original iterator and checks that every
 * remaining open socket was returned exactly once.
 *
 * @family, @sock_type, @addr and @port are unused here; they are part of the
 * common test_case callback signature.
 */
static void remove_unseen(int family, int sock_type, const char *addr,
			  __u16 port, int *socks, int socks_len,
			  struct sock_count *counts, int counts_len,
			  struct bpf_link *link, int iter_fd)
{
	int victim;

	/* Visit exactly one socket and confirm it is one of ours. */
	read_n(iter_fd, 1, counts, counts_len);
	check_n_were_seen_once(socks, socks_len, 1, counts, counts_len);

	/* Remove the socket that follows it in the bucket, so that resuming
	 * has to skip past the first cookie that was remembered.
	 */
	victim = get_nth_socket(socks, socks_len, link, 1);
	if (!ASSERT_GE(victim, 0, "close_idx"))
		return;
	close(socks[victim]);
	socks[victim] = -1;

	/* Drain the iterator. */
	read_n(iter_fd, -1, counts, counts_len);

	/* The survivors must each show up once, and the socket visited before
	 * the close must not be repeated.
	 */
	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
			       counts_len);
}

/* Resume scenario: close every socket the iterator has not returned yet.
 *
 * Reads a single socket, then repeatedly closes whichever socket a fresh
 * iterator reports in second position until socks_len - 1 sockets have been
 * closed, and finally checks that the original iterator returns nothing
 * more.
 *
 * @family, @sock_type, @addr and @port are unused here; they are part of the
 * common test_case callback signature.
 */
static void remove_all(int family, int sock_type, const char *addr,
		       __u16 port, int *socks, int socks_len,
		       struct sock_count *counts, int counts_len,
		       struct bpf_link *link, int iter_fd)
{
	int victim;
	int left;

	/* Visit exactly one socket and confirm it is one of ours. */
	read_n(iter_fd, 1, counts, counts_len);
	check_n_were_seen_once(socks, socks_len, 1, counts, counts_len);

	/* Close everything else, exhausting the list of saved cookies so the
	 * next read exits without putting any socket into the batch.
	 */
	for (left = socks_len - 1; left > 0; left--) {
		victim = get_nth_socket(socks, socks_len, link, 1);
		if (!ASSERT_GE(victim, 0, "close_idx"))
			return;
		close(socks[victim]);
		socks[victim] = -1;
	}

	/* Nothing further may be returned. */
	ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n");
}

/* Resume scenario: add sockets to the bucket in the middle of iteration.
 *
 * Reads all but one of the original sockets, starts @socks_len additional
 * reuseport sockets on the same address and port, drains the iterator and
 * checks that each original socket was returned exactly once.
 *
 * @link is unused here; it is part of the common test_case callback
 * signature.
 */
static void add_some(int family, int sock_type, const char *addr, __u16 port,
		     int *socks, int socks_len, struct sock_count *counts,
		     int counts_len, struct bpf_link *link, int iter_fd)
{
	int *extra_socks;

	/* Stop one socket short of the end of the bucket. */
	read_n(iter_fd, socks_len - 1, counts, counts_len);
	check_n_were_seen_once(socks, socks_len, socks_len - 1, counts,
			       counts_len);

	/* Double the bucket's population. */
	extra_socks = start_reuseport_server(family, sock_type, addr, port, 0,
					     socks_len);
	if (ASSERT_OK_PTR(extra_socks, "start_reuseport_server")) {
		/* Drain the iterator; every original socket must have been
		 * returned exactly once.
		 */
		read_n(iter_fd, -1, counts, counts_len);
		check_n_were_seen_once(socks, socks_len, socks_len, counts,
				       counts_len);
	}

	free_fds(extra_socks, socks_len);
}

/* Resume scenario: grow the bucket after the first read so that the next
 * read needs a larger batch.
 *
 * Reads a single socket, starts @socks_len additional reuseport sockets on
 * the same address and port, drains the iterator and checks that each
 * original socket was returned exactly once.
 *
 * @link is unused here; it is part of the common test_case callback
 * signature.
 */
static void force_realloc(int family, int sock_type, const char *addr,
			  __u16 port, int *socks, int socks_len,
			  struct sock_count *counts, int counts_len,
			  struct bpf_link *link, int iter_fd)
{
	int *extra_socks;

	/* A single read is enough to initialize the batch. */
	read_n(iter_fd, 1, counts, counts_len);

	/* Double the bucket's population to force a realloc on the next
	 * read.
	 */
	extra_socks = start_reuseport_server(family, sock_type, addr, port, 0,
					     socks_len);
	if (ASSERT_OK_PTR(extra_socks, "start_reuseport_server")) {
		/* Drain the iterator; every socket of the first set must have
		 * been returned exactly once.
		 */
		read_n(iter_fd, -1, counts, counts_len);
		check_n_were_seen_once(socks, socks_len, socks_len, counts,
				       counts_len);
	}

	free_fds(extra_socks, socks_len);
}

/* Description of one resume subtest run by do_resume_test(). */
struct test_case {
	/* Scenario callback. do_resume_test() passes the initial sockets as
	 * socks/socks_len (init_socks entries), the zeroed bookkeeping array
	 * as counts/counts_len (max_socks entries), the attached iterator
	 * link and an iterator fd created from it.
	 */
	void (*test)(int family, int sock_type, const char *addr, __u16 port,
		     int *socks, int socks_len, struct sock_count *counts,
		     int counts_len, struct bpf_link *link, int iter_fd);
	/* Subtest name handed to test__start_subtest(). */
	const char *description;
	/* Number of reuseport sockets started before the scenario runs. */
	int init_socks;
	/* Number of entries allocated for the counts array. */
	int max_socks;
	/* Socket type passed to start_reuseport_server(); SOCK_STREAM selects
	 * the TCP iterator program, anything else the UDP one.
	 */
	int sock_type;
	/* Address family; AF_INET6 binds to "::1", otherwise "127.0.0.1". */
	int family;
};

/* Resume scenarios executed by do_resume_tests(), one subtest per entry. */
static struct test_case resume_tests[] = {
	{
		.description = "udp: resume after removing a seen socket",
		.test = remove_seen,
		.family = AF_INET6,
		.sock_type = SOCK_DGRAM,
		.init_socks = nr_soreuse,
		.max_socks = nr_soreuse,
	},
	{
		.description = "udp: resume after removing one unseen socket",
		.test = remove_unseen,
		.family = AF_INET6,
		.sock_type = SOCK_DGRAM,
		.init_socks = nr_soreuse,
		.max_socks = nr_soreuse,
	},
	{
		.description = "udp: resume after removing all unseen sockets",
		.test = remove_all,
		.family = AF_INET6,
		.sock_type = SOCK_DGRAM,
		.init_socks = nr_soreuse,
		.max_socks = nr_soreuse,
	},
	{
		.description = "udp: resume after adding a few sockets",
		.test = add_some,
		/* Use AF_INET so that new sockets are added to the head of the
		 * bucket's list.
		 */
		.family = AF_INET,
		.sock_type = SOCK_DGRAM,
		.init_socks = nr_soreuse,
		.max_socks = nr_soreuse,
	},
	{
		.description = "udp: force a realloc to occur",
		.test = force_realloc,
		/* Use AF_INET6 so that new sockets are added to the tail of the
		 * bucket's list, needing to be added to the next batch to force
		 * a realloc.
		 */
		.family = AF_INET6,
		.sock_type = SOCK_DGRAM,
		.init_socks = init_batch_size,
		.max_socks = init_batch_size * 2,
	},
};

/* Run a single resume scenario.
 *
 * Allocates the zeroed counts array (tc->max_socks entries), starts
 * tc->init_socks reuseport sockets on a fixed port, loads the skeleton with
 * that port and address family, attaches the TCP or UDP iterator program
 * depending on tc->sock_type, creates an iterator fd and hands everything to
 * tc->test(). All resources are released at the done label whether or not
 * setup succeeded.
 */
static void do_resume_test(struct test_case *tc)
{
	struct sock_iter_batch *skel = NULL;
	static const __u16 port = 10001;
	struct bpf_link *link = NULL;
	struct sock_count *counts;
	int err, iter_fd = -1;
	const char *addr;
	int *fds = NULL;
	int local_port;

	/* calloc() zeroes the array, so every slot starts out free. */
	counts = calloc(tc->max_socks, sizeof(*counts));
	if (!ASSERT_OK_PTR(counts, "counts"))
		goto done;
	skel = sock_iter_batch__open();
	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
		goto done;

	/* Prepare a bucket of sockets in the kernel hashtable */
	addr = tc->family == AF_INET6 ? "::1" : "127.0.0.1";
	fds = start_reuseport_server(tc->family, tc->sock_type, addr, port, 0,
				     tc->init_socks);
	if (!ASSERT_OK_PTR(fds, "start_reuseport_server"))
		goto done;
	local_port = get_socket_local_port(*fds);
	if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
		goto done;
	/* The read-only data is filled in before the skeleton is loaded. */
	skel->rodata->ports[0] = ntohs(local_port);
	skel->rodata->sf = tc->family;

	err = sock_iter_batch__load(skel);
	if (!ASSERT_OK(err, "sock_iter_batch__load"))
		goto done;

	link = bpf_program__attach_iter(tc->sock_type == SOCK_STREAM ?
					skel->progs.iter_tcp_soreuse :
					skel->progs.iter_udp_soreuse,
					NULL);
	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
		goto done;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create"))
		goto done;

	tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks,
		 counts, tc->max_socks, link, iter_fd);
done:
	/* Scenarios mark sockets they closed themselves with -1 in fds. */
	free(counts);
	free_fds(fds, tc->init_socks);
	if (iter_fd >= 0)
		close(iter_fd);
	bpf_link__destroy(link);
	sock_iter_batch__destroy(skel);
}

static void do_resume_tests(void)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(resume_tests); i++) {
		if (test__start_subtest(resume_tests[i].description)) {
			do_resume_test(&resume_tests[i]);
		}
	}
}

/* Iterate over two buckets of reuseport sockets through one iterator fd,
 * closing all sockets of the first bucket after it was only partially read.
 *
 * @sock_type: SOCK_STREAM attaches the iter_tcp_soreuse program, any other
 *	       value attaches iter_udp_soreuse.
 * @onebyone:  when true every read() asks for a single struct iter_out,
 *	       otherwise for the whole amount expected from the bucket.
 *
 * All sockets, the iterator fd, the link and the skeleton are released at
 * the done label on every path.
 */
static void do_test(int sock_type, bool onebyone)
{
	int err, i, nread, to_read, total_read, iter_fd = -1;
	struct iter_out outputs[nr_soreuse];
	struct bpf_link *link = NULL;
	struct sock_iter_batch *skel;
	int first_idx, second_idx;
	int *fds[2] = {};

	skel = sock_iter_batch__open();
	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
		return;

	/* Prepare 2 buckets of sockets in the kernel hashtable */
	for (i = 0; i < ARRAY_SIZE(fds); i++) {
		int local_port;

		fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
						nr_soreuse);
		if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
			goto done;
		local_port = get_socket_local_port(*fds[i]);
		if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
			goto done;
		skel->rodata->ports[i] = ntohs(local_port);
	}
	skel->rodata->sf = AF_INET6;

	err = sock_iter_batch__load(skel);
	if (!ASSERT_OK(err, "sock_iter_batch__load"))
		goto done;

	link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
					skel->progs.iter_tcp_soreuse :
					skel->progs.iter_udp_soreuse,
					NULL);
	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
		goto done;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
		goto done;

	/* Test reading a bucket (either from fds[0] or fds[1]).
	 * Only read "nr_soreuse - 1" number of sockets
	 * from a bucket and leave one socket out from
	 * that bucket on purpose.
	 */
	to_read = (nr_soreuse - 1) * sizeof(*outputs);
	total_read = 0;
	first_idx = -1;
	do {
		nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read);
		if (nread <= 0 || nread % sizeof(*outputs))
			break;
		total_read += nread;

		/* The first record decides which bucket is being read. */
		if (first_idx == -1)
			first_idx = outputs[0].idx;
		for (i = 0; i < nread / sizeof(*outputs); i++)
			ASSERT_EQ(outputs[i].idx, first_idx, "first_idx");
	} while (total_read < to_read);
	ASSERT_EQ(nread, onebyone ? sizeof(*outputs) : to_read, "nread");
	ASSERT_EQ(total_read, to_read, "total_read");

	/* first_idx is still -1 when the very first read failed. Bail out
	 * instead of indexing fds[] out of bounds and handing a garbage
	 * pointer to free_fds().
	 */
	if (!ASSERT_TRUE(first_idx >= 0 && first_idx < (int)ARRAY_SIZE(fds),
			 "first_idx in range"))
		goto done;

	/* Close every socket of the partially read bucket; clearing the slot
	 * keeps the cleanup loop below from freeing it a second time.
	 */
	free_fds(fds[first_idx], nr_soreuse);
	fds[first_idx] = NULL;

	/* Read the "whole" second bucket */
	to_read = nr_soreuse * sizeof(*outputs);
	total_read = 0;
	second_idx = !first_idx;
	do {
		nread = read(iter_fd, outputs, onebyone ? sizeof(*outputs) : to_read);
		if (nread <= 0 || nread % sizeof(*outputs))
			break;
		total_read += nread;

		for (i = 0; i < nread / sizeof(*outputs); i++)
			ASSERT_EQ(outputs[i].idx, second_idx, "second_idx");
	} while (total_read <= to_read);
	/* "<=" above keeps reading until the iterator reports end of file. */
	ASSERT_EQ(nread, 0, "nread");
	/* Both so_reuseport ports should be in different buckets, so
	 * total_read must equal to the expected to_read.
	 *
	 * For a very unlikely case, both ports collide at the same bucket,
	 * the bucket offset (i.e. 3) will be skipped and it cannot
	 * expect the to_read number of bytes.
	 */
	if (skel->bss->bucket[0] != skel->bss->bucket[1])
		ASSERT_EQ(total_read, to_read, "total_read");

done:
	for (i = 0; i < ARRAY_SIZE(fds); i++)
		free_fds(fds[i], nr_soreuse);
	/* Close the iterator only if it was actually created. */
	if (iter_fd >= 0)
		close(iter_fd);
	bpf_link__destroy(link);
	sock_iter_batch__destroy(skel);
}

/* Test entry point.
 *
 * Recreates the TEST_NS network namespace, brings up its loopback device and
 * enters it, then runs the "tcp" and "udp" batching subtests (each in both
 * one-record-per-read and bulk-read mode) followed by the resume subtests.
 * The namespace is deleted again at the done label, which the SYS() calls
 * are also given as their bail-out label.
 */
void test_sock_iter_batch(void)
{
	static const struct {
		const char *name;
		int sock_type;
	} batch_tests[] = {
		{ "tcp", SOCK_STREAM },
		{ "udp", SOCK_DGRAM },
	};
	struct nstoken *nstoken = NULL;
	int idx;

	/* Remove any leftover namespace from an earlier run. */
	SYS_NOFAIL("ip netns del " TEST_NS);
	SYS(done, "ip netns add %s", TEST_NS);
	SYS(done, "ip -net %s link set dev lo up", TEST_NS);

	nstoken = open_netns(TEST_NS);
	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
		goto done;

	for (idx = 0; idx < ARRAY_SIZE(batch_tests); idx++) {
		if (!test__start_subtest(batch_tests[idx].name))
			continue;
		do_test(batch_tests[idx].sock_type, true);
		do_test(batch_tests[idx].sock_type, false);
	}
	do_resume_tests();
	close_netns(nstoken);

done:
	SYS_NOFAIL("ip netns del " TEST_NS);
}