/*
 * Contributors: 2
 *
 * Author           Tokens  Token Proportion  Commits  Commit Proportion
 * Samuel Holland   1569    95.03%            1        50.00%
 * Charlie Jenkins  82      4.97%             1        50.00%
 * Total            1651                      2
 */

// SPDX-License-Identifier: GPL-2.0-only

#include <errno.h>
#include <fcntl.h>
#include <setjmp.h>
#include <signal.h>
#include <stdbool.h>
#include <sys/prctl.h>
#include <sys/wait.h>
#include <unistd.h>

#include "../../kselftest.h"

/*
 * Fallback definitions, used only when the included headers do not already
 * provide them. They describe where the PMLEN field sits in the value
 * passed to and returned by the PR_{SET,GET}_TAGGED_ADDR_CTRL prctl()s
 * used throughout this file.
 */
#ifndef PR_PMLEN_SHIFT
#define PR_PMLEN_SHIFT			24
#endif
#ifndef PR_PMLEN_MASK
#define PR_PMLEN_MASK			(0x7fUL << PR_PMLEN_SHIFT)
#endif

/* File descriptor for /dev/zero, opened in main(); the ABI tests read() from it. */
static int dev_zero;

/* Pipe created in main(); the ABI tests write() user buffers into pipefd[1]. */
static int pipefd[2];

/* Jump target armed with sigsetjmp() and taken by sigsegv_handler(). */
static sigjmp_buf jmpbuf;

/*
 * SIGSEGV handler for the dereference tests: jump back to the most recent
 * sigsetjmp(jmpbuf, ...) call site, which then observes a return value of 1.
 */
static void sigsegv_handler(int sig)
{
	(void)sig;	/* the signal number is not needed */

	siglongjmp(jmpbuf, 1);
}

/*
 * Filled in by test_pmlen(): min_pmlen is the first nonzero PMLEN value the
 * kernel reported (it stays 0 if none was), max_pmlen is the largest PMLEN
 * value reported. Later tests use both as their PMLEN inputs.
 */
static int min_pmlen;
static int max_pmlen;

/* Return true when @pmlen is one of the values this test accepts: 0, 7 or 16. */
static inline bool valid_pmlen(int pmlen)
{
	switch (pmlen) {
	case 0:
	case 7:
	case 16:
		return true;
	default:
		return false;
	}
}

/*
 * Request every PMLEN value from 0 through 16 and inspect what the kernel
 * reports back.
 *
 * Each request always produces exactly three results, in this order:
 * "PR_GET_TAGGED_ADDR_CTRL", "constraint" (reported PMLEN >= requested) and
 * "validity" (reported PMLEN passes valid_pmlen()). Results that cannot be
 * evaluated because an earlier prctl() failed are reported as skipped.
 *
 * As a side effect, min_pmlen and max_pmlen are updated from the reported
 * values. The whole run is aborted if no nonzero PMLEN was ever reported.
 */
static void test_pmlen(void)
{
	ksft_print_msg("Testing available PMLEN values\n");

	for (int request = 0; request <= 16; request++) {
		int ctrl, granted;

		if (prctl(PR_SET_TAGGED_ADDR_CTRL, request << PR_PMLEN_SHIFT, 0, 0, 0)) {
			/* The request was refused, so there is nothing to read back. */
			ksft_test_result_skip("PMLEN=%d PR_GET_TAGGED_ADDR_CTRL\n", request);
			ksft_test_result_skip("PMLEN=%d constraint\n", request);
			ksft_test_result_skip("PMLEN=%d validity\n", request);
			continue;
		}

		ctrl = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
		ksft_test_result(ctrl >= 0, "PMLEN=%d PR_GET_TAGGED_ADDR_CTRL\n", request);
		if (ctrl < 0) {
			/* No control word to decode; skip the dependent checks. */
			ksft_test_result_skip("PMLEN=%d constraint\n", request);
			ksft_test_result_skip("PMLEN=%d validity\n", request);
			continue;
		}

		granted = (ctrl & PR_PMLEN_MASK) >> PR_PMLEN_SHIFT;
		ksft_test_result(granted >= request, "PMLEN=%d constraint\n", request);
		ksft_test_result(valid_pmlen(granted), "PMLEN=%d validity\n", request);

		/* Remember the first nonzero and the largest reported PMLEN. */
		if (!min_pmlen)
			min_pmlen = granted;
		if (granted > max_pmlen)
			max_pmlen = granted;
	}

	if (!max_pmlen)
		ksft_exit_fail_msg("Failed to enable pointer masking\n");
}

/*
 * Program the tagged address control word and verify it by reading it back.
 *
 * @pmlen:           value placed in the PMLEN field
 * @tagged_addr_abi: value placed in bit 0 of the control word
 *
 * Returns 0 when PR_GET_TAGGED_ADDR_CTRL reports exactly the value that was
 * set, -errno when either prctl() returned a negative value, and -ENODATA
 * when the value read back differs from the one requested.
 */
static int set_tagged_addr_ctrl(int pmlen, bool tagged_addr_abi)
{
	const int wanted = pmlen << PR_PMLEN_SHIFT | tagged_addr_abi;
	int res;

	res = prctl(PR_SET_TAGGED_ADDR_CTRL, wanted, 0, 0, 0);
	if (res == 0) {
		res = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0);
		if (res == wanted)
			return 0;
	}

	if (res < 0)
		return -errno;

	return -ENODATA;
}

/*
 * Check userspace loads/stores through tagged pointers for one PMLEN value.
 *
 * Reports exactly one result. The caller must have installed
 * sigsegv_handler() for SIGSEGV, because the checks rely on a faulting
 * access jumping back to the preceding sigsetjmp().
 *
 * @pmlen: PMLEN value to configure (with the tagged address ABI bit clear)
 */
static void test_dereference_pmlen(int pmlen)
{
	/* Target of the tagged accesses; volatile so every access really happens. */
	static volatile int i;
	volatile int *p;
	int ret;

	ret = set_tagged_addr_ctrl(pmlen, false);
	if (ret)
		return ksft_test_result_error("PMLEN=%d setup (%d)\n", pmlen, ret);

	i = pmlen;

	if (pmlen) {
		/*
		 * Set bit (__riscv_xlen - pmlen) in the address of i. The test
		 * expects accesses through this pointer to behave like accesses
		 * to i itself (presumably this is the lowest bit ignored for
		 * this PMLEN).
		 */
		p = (volatile int *)((uintptr_t)&i | 1UL << (__riscv_xlen - pmlen));

		/* These dereferences should succeed. */
		/* A nonzero sigsetjmp() return means sigsegv_handler() ran. */
		if (sigsetjmp(jmpbuf, 1))
			return ksft_test_result_fail("PMLEN=%d valid tag\n", pmlen);
		if (*p != pmlen)
			return ksft_test_result_fail("PMLEN=%d bad value\n", pmlen);
		++*p;
	}

	/*
	 * Set the bit one position lower, (__riscv_xlen - pmlen - 1); for
	 * pmlen == 0 this is the most significant address bit.
	 */
	p = (volatile int *)((uintptr_t)&i | 1UL << (__riscv_xlen - pmlen - 1));

	/* These dereferences should raise SIGSEGV. */
	if (sigsetjmp(jmpbuf, 1))
		return ksft_test_result_pass("PMLEN=%d dereference\n", pmlen);
	++*p;
	/* Only reached when the access above did not fault. */
	ksft_test_result_fail("PMLEN=%d invalid tag\n", pmlen);
}

/*
 * Run the userspace dereference checks for PMLEN 0, min_pmlen and max_pmlen
 * (one result each), with sigsegv_handler() installed for the duration and
 * the default SIGSEGV disposition restored afterwards.
 */
static void test_dereference(void)
{
	const int pmlens[] = { 0, min_pmlen, max_pmlen };

	ksft_print_msg("Testing userspace pointer dereference\n");

	signal(SIGSEGV, sigsegv_handler);

	for (unsigned int n = 0; n < ARRAY_SIZE(pmlens); n++)
		test_dereference_pmlen(pmlens[n]);

	signal(SIGSEGV, SIG_DFL);
}

/*
 * SIGSEGV handler for the re-executed child: terminate with status 42, the
 * value test_fork_exec() looks for in the parent.
 *
 * _exit() is used rather than exit() because exit() is not async-signal-safe
 * (it runs atexit handlers and flushes stdio) and must not be called from a
 * signal handler.
 */
static void execve_child_sigsegv_handler(int sig)
{
	(void)sig;

	_exit(42);
}

/*
 * Body of the re-executed child; main() dispatches here when argv[0] is an
 * empty string.
 *
 * Loads through a pointer to a static int with bit (__riscv_xlen - 7) set.
 * The load is expected to fault, in which case
 * execve_child_sigsegv_handler() ends the process with status 42. If it
 * does not fault, the loaded value is returned to main().
 */
static int execve_child(void)
{
	static volatile int i;
	const uintptr_t tagged = (uintptr_t)&i | 1UL << (__riscv_xlen - 7);
	volatile int *p = (volatile int *)tagged;

	signal(SIGSEGV, execve_child_sigsegv_handler);

	/* This dereference should raise SIGSEGV. */
	return *p;
}

/* Reap child @pid and report whether it exited normally with status 42. */
static bool child_exited_with_42(pid_t pid)
{
	int status = 0;

	if (waitpid(pid, &status, 0) != pid)
		return false;

	return WIFEXITED(status) && WEXITSTATUS(status) == 42;
}

/*
 * Check how the pointer masking setting behaves across fork() and execve().
 *
 * Always reports two results:
 *  - "dereference after fork": a forked child loads 42 through a pointer
 *    tagged for min_pmlen and exits with the loaded value.
 *  - "dereference after fork+exec": a forked child re-executes this binary
 *    with an empty argv[0], which makes main() run execve_child(); that
 *    child is expected to fault and exit with 42 from its SIGSEGV handler.
 */
static void test_fork_exec(void)
{
	pid_t pid;
	int ret;

	ksft_print_msg("Testing fork/exec behavior\n");

	ret = set_tagged_addr_ctrl(min_pmlen, false);
	if (ret) {
		/* Two results are planned for this test; account for both. */
		ksft_test_result_error("setup (%d)\n", ret);
		ksft_test_result_skip("dereference after fork+exec\n");
		return;
	}

	pid = fork();
	if (pid < 0) {
		ksft_perror("fork");
		ksft_test_result_fail("dereference after fork\n");
	} else if (pid == 0) {
		static volatile int i = 42;
		volatile int *p;

		p = (volatile int *)((uintptr_t)&i | 1UL << (__riscv_xlen - min_pmlen));

		/* This dereference should succeed. */
		_exit(*p);
	} else {
		ksft_test_result(child_exited_with_42(pid), "dereference after fork\n");
	}

	pid = fork();
	if (pid < 0) {
		ksft_perror("fork");
		ksft_test_result_fail("dereference after fork+exec\n");
	} else if (pid == 0) {
		/* Will call execve_child(). */
		execve("/proc/self/exe", (char *const []) { "", NULL }, NULL);

		/*
		 * Only reached when execve() failed. Terminate here so the
		 * child does not return and run the remaining tests itself.
		 */
		ksft_perror("execve");
		_exit(1);
	} else {
		ksft_test_result(child_exited_with_42(pid), "dereference after fork+exec\n");
	}
}

/*
 * Write @count bytes from @buf to @fd at offset 0 and report any problem.
 *
 * @msg prefixes the diagnostic. Returns true only when all @count bytes
 * were written.
 */
static bool pwrite_wrapper(int fd, void *buf, size_t count, const char *msg)
{
	/* Keep the full ssize_t result; an int would truncate it. */
	ssize_t ret = pwrite(fd, buf, count, 0);

	if (ret < 0) {
		ksft_perror(msg);
		return false;
	}

	if ((size_t)ret != count) {
		/* A short write does not set errno, so perror would mislead. */
		ksft_print_msg("%s: short write (%zd of %zu bytes)\n", msg, ret, count);
		return false;
	}

	return true;
}

/*
 * Check that /proc/sys/abi/tagged_addr_disabled gates the tagged address ABI.
 *
 * Always reports two results:
 *  - after writing '1', enabling the ABI with min_pmlen must fail with
 *    -EINVAL ("sysctl disabled");
 *  - after writing '0', enabling the ABI with min_pmlen must succeed
 *    ("sysctl enabled").
 * Both are skipped when the sysctl file cannot be opened for writing, and a
 * result is failed when the corresponding write to the file fails.
 */
static void test_tagged_addr_abi_sysctl(void)
{
	static const char err_pwrite_msg[] =
		"failed to write to /proc/sys/abi/tagged_addr_disabled\n";
	char value;
	int fd;

	ksft_print_msg("Testing tagged address ABI sysctl\n");

	fd = open("/proc/sys/abi/tagged_addr_disabled", O_WRONLY);
	if (fd < 0) {
		ksft_test_result_skip("failed to open sysctl file\n");
		ksft_test_result_skip("failed to open sysctl file\n");
		return;
	}

	/*
	 * The message is passed through "%s" instead of being used as the
	 * format string itself, so it can never be interpreted as one.
	 */
	value = '1';
	if (!pwrite_wrapper(fd, &value, 1, "write '1'"))
		ksft_test_result_fail("%s", err_pwrite_msg);
	else
		ksft_test_result(set_tagged_addr_ctrl(min_pmlen, true) == -EINVAL,
				 "sysctl disabled\n");

	value = '0';
	if (!pwrite_wrapper(fd, &value, 1, "write '0'"))
		ksft_test_result_fail("%s", err_pwrite_msg);
	else
		ksft_test_result(set_tagged_addr_ctrl(min_pmlen, true) == 0,
				 "sysctl enabled\n");

	/* Best-effort reset of this process to PMLEN 0 with the ABI bit clear. */
	set_tagged_addr_ctrl(0, false);

	close(fd);
}

/*
 * Check how the kernel treats tagged user pointers passed to write() and
 * read() for one PMLEN value. Reports exactly one result.
 *
 * For a nonzero @pmlen:
 *  - with the ABI bit clear, write()/read() on a pointer with bit
 *    (__riscv_xlen - pmlen) set must fail with EFAULT and leave the buffer
 *    untouched;
 *  - with the ABI bit set, the same calls must transfer sizeof(int) bytes,
 *    and the read from /dev/zero must clear the buffer.
 * For @pmlen == 0, enabling the ABI must fail with -EINVAL.
 *
 * In both cases a pointer with bit (__riscv_xlen - pmlen - 1) set must then
 * be rejected with EFAULT and the buffer must keep its value.
 */
static void test_tagged_addr_abi_pmlen(int pmlen)
{
	const int pattern = ~pmlen;
	int word = pattern;
	int *tagged;
	int ret;

	if (!pmlen) {
		/* The tagged address ABI cannot be enabled when PMLEN == 0. */
		ret = set_tagged_addr_ctrl(pmlen, true);
		if (ret != -EINVAL) {
			ksft_test_result_error("PMLEN=%d ABI setup (%d)\n",
					       pmlen, ret);
			return;
		}
	} else {
		tagged = (int *)((uintptr_t)&word | 1UL << (__riscv_xlen - pmlen));

		/* Phase 1: ABI bit clear, the kernel must reject the pointer. */
		ret = set_tagged_addr_ctrl(pmlen, false);
		if (ret) {
			ksft_test_result_error("PMLEN=%d ABI disabled setup (%d)\n",
					       pmlen, ret);
			return;
		}

		ret = write(pipefd[1], tagged, sizeof(*tagged));
		if (!(ret < 0 && errno == EFAULT)) {
			ksft_test_result_fail("PMLEN=%d ABI disabled write\n", pmlen);
			return;
		}

		ret = read(dev_zero, tagged, sizeof(*tagged));
		if (!(ret < 0 && errno == EFAULT)) {
			ksft_test_result_fail("PMLEN=%d ABI disabled read\n", pmlen);
			return;
		}

		if (word != pattern) {
			ksft_test_result_fail("PMLEN=%d ABI disabled value\n", pmlen);
			return;
		}

		/* Phase 2: ABI bit set, the kernel must accept the pointer. */
		ret = set_tagged_addr_ctrl(pmlen, true);
		if (ret) {
			ksft_test_result_error("PMLEN=%d ABI enabled setup (%d)\n",
					       pmlen, ret);
			return;
		}

		ret = write(pipefd[1], tagged, sizeof(*tagged));
		if (ret != sizeof(*tagged)) {
			ksft_test_result_fail("PMLEN=%d ABI enabled write\n", pmlen);
			return;
		}

		ret = read(dev_zero, tagged, sizeof(*tagged));
		if (ret != sizeof(*tagged)) {
			ksft_test_result_fail("PMLEN=%d ABI enabled read\n", pmlen);
			return;
		}

		/* The read from /dev/zero must have cleared the buffer. */
		if (word) {
			ksft_test_result_fail("PMLEN=%d ABI enabled value\n", pmlen);
			return;
		}

		word = pattern;
	}

	/* Final phase: a tag one bit lower must always be rejected. */
	tagged = (int *)((uintptr_t)&word | 1UL << (__riscv_xlen - pmlen - 1));

	ret = write(pipefd[1], tagged, sizeof(*tagged));
	if (!(ret < 0 && errno == EFAULT)) {
		ksft_test_result_fail("PMLEN=%d invalid tag write (%d)\n", pmlen, errno);
		return;
	}

	ret = read(dev_zero, tagged, sizeof(*tagged));
	if (!(ret < 0 && errno == EFAULT)) {
		ksft_test_result_fail("PMLEN=%d invalid tag read\n", pmlen);
		return;
	}

	if (word != pattern) {
		ksft_test_result_fail("PMLEN=%d invalid tag value\n", pmlen);
		return;
	}

	ksft_test_result_pass("PMLEN=%d tagged address ABI\n", pmlen);
}

/*
 * Run the tagged address ABI checks for PMLEN 0, min_pmlen and max_pmlen,
 * in that order (one result each).
 */
static void test_tagged_addr_abi(void)
{
	const int pmlens[] = { 0, min_pmlen, max_pmlen };

	ksft_print_msg("Testing tagged address ABI\n");

	for (unsigned int n = 0; n < ARRAY_SIZE(pmlens); n++)
		test_tagged_addr_abi_pmlen(pmlens[n]);
}

/* One test group: how many results it reports and the function that runs it. */
struct test_info {
	unsigned int nr_tests;
	void (*test_fn)(void);
};

/* Test groups in execution order; main() sums nr_tests to build the plan. */
static struct test_info tests[] = {
	/* 17 requested PMLEN values (0 through 16), three results each. */
	{ .nr_tests = 17 * 3,	.test_fn = test_pmlen },
	{ .nr_tests = 3,	.test_fn = test_dereference },
	{ .nr_tests = 2,	.test_fn = test_fork_exec },
	{ .nr_tests = 2,	.test_fn = test_tagged_addr_abi_sysctl },
	{ .nr_tests = 3,	.test_fn = test_tagged_addr_abi },
};

int main(int argc, char **argv)
{
	unsigned int plan = 0;
	int ret;

	/* Check if this is the child process after execve(). */
	if (!argv[0][0])
		return execve_child();

	dev_zero = open("/dev/zero", O_RDWR);
	if (dev_zero < 0)
		return 1;

	/* Write to a pipe so the kernel must dereference the buffer pointer. */
	ret = pipe(pipefd);
	if (ret)
		return 1;

	ksft_print_header();

	for (int i = 0; i < ARRAY_SIZE(tests); i++)
		plan += tests[i].nr_tests;

	ksft_set_plan(plan);

	for (int i = 0; i < ARRAY_SIZE(tests); i++)
		tests[i].test_fn();

	ksft_finished();
}