Contributors: 3
Author Tokens Token Proportion Commits Commit Proportion
Dmitry Safonov 12897 98.24% 2 50.00%
Gautam Menghani 228 1.74% 1 25.00%
Willem de Bruijn 3 0.02% 1 25.00%
Total 13128 4

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341
// SPDX-License-Identifier: GPL-2.0
/*
 * ipsec.c - Check xfrm on veth inside a net-ns.
 * Copyright (c) 2018 Dmitry Safonov
 */

#define _GNU_SOURCE

#include <arpa/inet.h>
#include <asm/types.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <linux/limits.h>
#include <linux/netlink.h>
#include <linux/random.h>
#include <linux/rtnetlink.h>
#include <linux/veth.h>
#include <linux/xfrm.h>
#include <netinet/in.h>
#include <net/if.h>
#include <sched.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <time.h>
#include <unistd.h>

#include "../kselftest.h"

#define printk(fmt, ...)						\
	ksft_print_msg("%d[%u] " fmt "\n", getpid(), __LINE__, ##__VA_ARGS__)

#define pr_err(fmt, ...)	printk(fmt ": %m", ##__VA_ARGS__)

#define BUILD_BUG_ON(condition) ((void)sizeof(char[1 - 2*!!(condition)]))

#define IPV4_STR_SZ	16	/* xxx.xxx.xxx.xxx is longest + \0 */
#define MAX_PAYLOAD	2048
#define XFRM_ALGO_KEY_BUF_SIZE	512
#define MAX_PROCESSES	(1 << 14) /* /16 mask divided by /30 subnets */
#define INADDR_A	((in_addr_t) 0x0a000000) /* 10.0.0.0 */
#define INADDR_B	((in_addr_t) 0xc0a80000) /* 192.168.0.0 */

/* /30 mask for one veth connection */
#define PREFIX_LEN	30
/*
 * Host numbers inside the /30 subnet number nr (presumably for the child
 * and grandchild ends of a veth pair - confirm against the callers).
 * The argument is parenthesized so that an expression argument such as
 * child_ip(i + 1) expands to 4 * (i + 1) + 1 and not 4 * i + 1 + 1.
 */
#define child_ip(nr)	(4 * (nr) + 1)
#define grchild_ip(nr)	(4 * (nr) + 2)

/* printf-style template for the names of the veth interfaces */
#define VETH_FMT	"ktst-%d"
/* NOTE(review): presumably the buffer size for a VETH_FMT name - confirm */
#define VETH_LEN	12

/* Number of entries in xfrm_key_entries[]; keep in sync with that table */
#define XFRM_ALGO_NR_KEYS 29

/* Network namespace descriptors, filled in by init_namespaces() */
static int nsfd_parent	= -1;
static int nsfd_childa	= -1;
static int nsfd_childb	= -1;
/* Used by do_ping() as the length of every ping payload */
static long page_size;

/*
 * ksft_cnt is static in kselftest, so isn't shared with children.
 * We have to send a test result back to parent and count there.
 * results_fd is a pipe with test feedback from children.
 */
static int results_fd[2];

/* Pause between two consecutive pings in do_ping(), in nanoseconds */
const unsigned int ping_delay_nsec	= 50 * 1000 * 1000;
/* Receive timeout of the ping socket; stored in timeval.tv_usec */
const unsigned int ping_timeout		= 300;
/* Number of pings attempted by one do_ping() call */
const unsigned int ping_count		= 100;
/* Minimum number of successful pings for do_ping() to succeed */
const unsigned int ping_success		= 80;

/* One row of the algorithm-name to key-length lookup table. */
struct xfrm_key_entry {
	char algo_name[35];	/* name compared by xfrm_fill_key() */
	int key_len;		/* value stored into the algo's alg_key_len */
};

/*
 * Key lengths for the algorithms known to the test, looked up by name in
 * xfrm_fill_key(). The values look like bit counts (e.g. 64 for
 * "cbc(des)", 128 for "hmac(md5)") - NOTE(review): confirm against the
 * xfrm UAPI definition of alg_key_len.
 */
struct xfrm_key_entry xfrm_key_entries[] = {
	{"digest_null", 0},
	{"ecb(cipher_null)", 0},
	{"cbc(des)", 64},
	{"hmac(md5)", 128},
	{"cmac(aes)", 128},
	{"xcbc(aes)", 128},
	{"cbc(cast5)", 128},
	{"cbc(serpent)", 128},
	{"hmac(sha1)", 160},
	{"hmac(rmd160)", 160},
	{"cbc(des3_ede)", 192},
	{"hmac(sha256)", 256},
	{"cbc(aes)", 256},
	{"cbc(camellia)", 256},
	{"cbc(twofish)", 256},
	{"rfc3686(ctr(aes))", 288},
	{"hmac(sha384)", 384},
	{"cbc(blowfish)", 448},
	{"hmac(sha512)", 512},
	{"rfc4106(gcm(aes))-128", 160},
	{"rfc4543(gcm(aes))-128", 160},
	{"rfc4309(ccm(aes))-128", 152},
	{"rfc4106(gcm(aes))-192", 224},
	{"rfc4543(gcm(aes))-192", 224},
	{"rfc4309(ccm(aes))-192", 216},
	{"rfc4106(gcm(aes))-256", 288},
	{"rfc4543(gcm(aes))-256", 288},
	{"rfc4309(ccm(aes))-256", 280},
	{"rfc7539(chacha20,poly1305)-128", 0}
};

/*
 * randomize_buffer() - fill a buffer with pseudo-random bytes from rand().
 * @buf:	destination buffer; no particular alignment is required
 * @buflen:	number of bytes to fill; 0 leaves @buf untouched
 *
 * One rand() value is consumed for every sizeof(int) bytes, plus one more
 * for a trailing partial word, of which only the leading bytes are used.
 * The values are stored with memcpy() on a byte pointer, which avoids
 * both arithmetic on void * (a GNU extension) and int stores through a
 * pointer that may be misaligned for the underlying object.
 */
static void randomize_buffer(void *buf, size_t buflen)
{
	unsigned char *dst = buf;

	while (buflen) {
		size_t chunk = buflen < sizeof(int) ? buflen : sizeof(int);
		int value = rand();

		memcpy(dst, &value, chunk);
		dst += chunk;
		buflen -= chunk;
	}
}

/*
 * unshare_open() - move the caller into a new network namespace and open it.
 *
 * Return: a descriptor for /proc/self/ns/net of the new namespace, or -1
 * after logging the error. A descriptor value of 0 is treated as failure.
 */
static int unshare_open(void)
{
	static const char ns_file[] = "/proc/self/ns/net";
	int nsfd;

	if (unshare(CLONE_NEWNET)) {
		pr_err("unshare()");
		return -1;
	}

	nsfd = open(ns_file, O_RDONLY);
	if (nsfd > 0)
		return nsfd;

	pr_err("open(%s)", ns_file);
	return -1;
}

/*
 * switch_ns() - enter the network namespace referred to by @fd.
 *
 * Return: 0 on success, -1 (after logging) when setns() fails.
 */
static int switch_ns(int fd)
{
	int rc = setns(fd, CLONE_NEWNET);

	if (!rc)
		return 0;

	pr_err("setns()");
	return -1;
}

/*
 * The whole test runs inside a freshly created parent net namespace, so
 * that little cleanup is needed on error paths.
 *
 * Creates the parent namespace, then the two child namespaces; after each
 * child is created the process switches back into the parent.
 *
 * Return: 0 on success, -1 if any namespace could not be created/entered.
 */
static int init_namespaces(void)
{
	int *child_fds[] = { &nsfd_childa, &nsfd_childb };
	size_t i;

	nsfd_parent = unshare_open();
	if (nsfd_parent <= 0)
		return -1;

	for (i = 0; i < sizeof(child_fds) / sizeof(child_fds[0]); i++) {
		*child_fds[i] = unshare_open();
		if (*child_fds[i] <= 0)
			return -1;

		/* unshare_open() left us in the child; go back to the parent */
		if (switch_ns(nsfd_parent))
			return -1;
	}

	return 0;
}

/*
 * netlink_sock() - open a netlink socket on first use.
 * @sock:	in/out socket descriptor; a value > 0 means "already open"
 * @seq_nr:	in/out netlink sequence number
 * @proto:	netlink protocol passed to socket()
 *
 * If *@sock is already open the sequence number is advanced and nothing
 * else happens. Otherwise a new close-on-exec raw netlink socket is
 * created and *@seq_nr is seeded with a random value.
 *
 * Return: 0 on success, -1 (after logging) if socket() fails.
 */
static int netlink_sock(int *sock, uint32_t *seq_nr, int proto)
{
	if (*sock > 0) {
		/* Bump the value; "seq_nr++" only moved the local pointer. */
		(*seq_nr)++;
		return 0;
	}

	*sock = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, proto);
	if (*sock <= 0) {
		pr_err("socket(AF_NETLINK)");
		return -1;
	}

	randomize_buffer(seq_nr, sizeof(*seq_nr));

	return 0;
}

/*
 * rtattr_hdr() - address at which the next attribute of @nh would start,
 * i.e. the message start plus its current length rounded up by RTA_ALIGN().
 */
static inline struct rtattr *rtattr_hdr(struct nlmsghdr *nh)
{
	char *msg_start = (char *)nh;
	size_t aligned_len = RTA_ALIGN(nh->nlmsg_len);

	return (struct rtattr *)(msg_start + aligned_len);
}

/*
 * rtattr_pack() - append one attribute to a netlink message.
 * @nh:		message being built; nlmsg_len is updated on success
 * @req_sz:	total size of the buffer that holds @nh
 * @rta_type:	attribute type
 * @payload:	attribute data; may be a null pointer when @size is 0
 * @size:	number of payload bytes
 *
 * Return: 0 on success, -1 (after logging) if the attribute would not fit
 * into @req_sz bytes; the message is left unchanged in that case.
 */
static int rtattr_pack(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* NLMSG_ALIGNTO == RTA_ALIGNTO, nlmsg_len already aligned */
	struct rtattr *attr = rtattr_hdr(nh);
	size_t nl_size = RTA_ALIGN(nh->nlmsg_len) + RTA_LENGTH(size);

	if (req_sz < nl_size) {
		printk("req buf is too small: %zu < %zu", req_sz, nl_size);
		return -1;
	}
	nh->nlmsg_len = nl_size;

	attr->rta_len = RTA_LENGTH(size);
	attr->rta_type = rta_type;
	/*
	 * Header-only attributes (see rtattr_begin()) come with a null
	 * payload; memcpy() must not be handed a null pointer even for a
	 * zero length, so skip the copy.
	 */
	if (size)
		memcpy(RTA_DATA(attr), payload, size);

	return 0;
}

/*
 * _rtattr_begin() - start a nested attribute that carries a payload.
 *
 * Packs the attribute with rtattr_pack() and hands back a pointer to its
 * header so that rtattr_end() can later fix up the length.
 *
 * Return: the attribute header, or a null pointer if packing failed.
 */
static struct rtattr *_rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type, const void *payload, size_t size)
{
	/* Must be sampled before rtattr_pack() grows the message. */
	struct rtattr *start = rtattr_hdr(nh);
	int failed = rtattr_pack(nh, req_sz, rta_type, payload, size);

	return failed ? NULL : start;
}

/*
 * rtattr_begin() - start a nested attribute that has no payload of its own.
 *
 * Return: the attribute header, or a null pointer if it did not fit.
 */
static inline struct rtattr *rtattr_begin(struct nlmsghdr *nh, size_t req_sz,
		unsigned short rta_type)
{
	const void *no_payload = 0;
	const size_t no_size = 0;

	return _rtattr_begin(nh, req_sz, rta_type, no_payload, no_size);
}

/*
 * rtattr_end() - close a nested attribute opened with rtattr_begin().
 *
 * Sets the attribute length so that it spans from @attr up to the current
 * end of the message @nh.
 */
static inline void rtattr_end(struct nlmsghdr *nh, struct rtattr *attr)
{
	const char *attr_start = (const char *)attr;
	const char *msg_end = (const char *)nh + nh->nlmsg_len;

	attr->rta_len = msg_end - attr_start;
}

/*
 * veth_pack_peerb() - append the VETH_INFO_PEER attribute for the peer end.
 * @nh:		message being built
 * @req_sz:	size of the buffer that holds @nh
 * @peer:	interface name of the peer
 * @ns:		namespace descriptor packed as IFLA_NET_NS_FD
 *
 * The attribute carries a zeroed ifinfomsg followed by the peer's name and
 * namespace attributes.
 *
 * Return: 0 on success, -1 if the request buffer is too small.
 */
static int veth_pack_peerb(struct nlmsghdr *nh, size_t req_sz,
		const char *peer, int ns)
{
	struct rtattr *nest;
	struct ifinfomsg hdr;

	memset(&hdr, 0, sizeof(hdr));
	hdr.ifi_change	= 0xFFFFFFFF;
	hdr.ifi_family	= AF_UNSPEC;

	nest = _rtattr_begin(nh, req_sz, VETH_INFO_PEER, &hdr, sizeof(hdr));
	if (!nest)
		return -1;

	if (rtattr_pack(nh, req_sz, IFLA_IFNAME, peer, strlen(peer)) ||
	    rtattr_pack(nh, req_sz, IFLA_NET_NS_FD, &ns, sizeof(ns)))
		return -1;

	rtattr_end(nh, nest);
	return 0;
}

/*
 * netlink_check_answer() - read the ACK for the request just sent.
 *
 * Return: 0 when the kernel acknowledged without error, -1 if recv()
 * failed or the reply is not an NLMSG_ERROR message, otherwise the
 * (negative) error code carried by the reply.
 */
static int netlink_check_answer(int sock)
{
	struct nlmsgerror {
		struct nlmsghdr hdr;
		int error;
		struct nlmsghdr orig_msg;
	} answer;
	ssize_t got = recv(sock, &answer, sizeof(answer), 0);

	if (got < 0) {
		pr_err("recv()");
		return -1;
	}

	if (answer.hdr.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.hdr.nlmsg_type);
		return -1;
	}

	if (!answer.error)
		return 0;

	printk("NLMSG_ERROR: %d: %s",
		answer.error, strerror(-answer.error));
	return answer.error;
}

/*
 * veth_add() - create a veth pair over rtnetlink.
 * @sock:	netlink route socket
 * @seq:	sequence number for the request
 * @peera:	name of the first end
 * @ns_a:	namespace descriptor for the first end
 * @peerb:	name of the second end
 * @ns_b:	namespace descriptor for the second end
 *
 * Return: 0 on success, -1 on a local failure, otherwise the result of
 * netlink_check_answer().
 */
static int veth_add(int sock, uint32_t seq, const char *peera, int ns_a,
		const char *peerb, int ns_b)
{
	static const char kind[] = "veth";
	struct rtattr *linkinfo_nest, *data_nest;
	struct {
		struct nlmsghdr		nh;
		struct ifinfomsg	info;
		char			attrbuf[MAX_PAYLOAD];
	} req;
	const size_t req_sz = sizeof(req);

	memset(&req, 0, req_sz);
	req.info.ifi_family	= AF_UNSPEC;
	req.info.ifi_change	= 0xFFFFFFFF;
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_type	= RTM_NEWLINK;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK |
				  NLM_F_EXCL | NLM_F_CREATE;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));

	/* First end: name and target namespace */
	if (rtattr_pack(&req.nh, req_sz, IFLA_IFNAME, peera, strlen(peera)) ||
	    rtattr_pack(&req.nh, req_sz, IFLA_NET_NS_FD, &ns_a, sizeof(ns_a)))
		return -1;

	/* IFLA_LINKINFO { IFLA_INFO_KIND, IFLA_INFO_DATA { peer } } */
	linkinfo_nest = rtattr_begin(&req.nh, req_sz, IFLA_LINKINFO);
	if (!linkinfo_nest)
		return -1;

	/* The kind string is packed including its terminating NUL. */
	if (rtattr_pack(&req.nh, req_sz, IFLA_INFO_KIND, kind, sizeof(kind)))
		return -1;

	data_nest = rtattr_begin(&req.nh, req_sz, IFLA_INFO_DATA);
	if (!data_nest)
		return -1;

	if (veth_pack_peerb(&req.nh, req_sz, peerb, ns_b))
		return -1;

	rtattr_end(&req.nh, data_nest);
	rtattr_end(&req.nh, linkinfo_nest);

	if (send(sock, &req, req.nh.nlmsg_len, 0) >= 0)
		return netlink_check_answer(sock);

	pr_err("send()");
	return -1;
}

/*
 * ip4_addr_set() - assign an IPv4 address to an interface over rtnetlink.
 * @sock:	netlink route socket
 * @seq:	sequence number for the request
 * @intf:	interface name
 * @addr:	address, packed as both IFA_LOCAL and IFA_ADDRESS
 * @prefix:	prefix length
 *
 * Return: 0 on success, -1 on a local failure, otherwise the result of
 * netlink_check_answer().
 */
static int ip4_addr_set(int sock, uint32_t seq, const char *intf,
		struct in_addr addr, uint8_t prefix)
{
	struct {
		struct nlmsghdr		nh;
		struct ifaddrmsg	info;
		char			attrbuf[MAX_PAYLOAD];
	} req;
	const size_t req_sz = sizeof(req);

	memset(&req, 0, req_sz);
	req.info.ifa_family	= AF_INET;
	req.info.ifa_prefixlen	= prefix;
	req.info.ifa_index	= if_nametoindex(intf);
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_type	= RTM_NEWADDR;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK |
				  NLM_F_EXCL | NLM_F_CREATE;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));

#ifdef DEBUG
	{
		char addr_str[IPV4_STR_SZ] = {};

		strncpy(addr_str, inet_ntoa(addr), IPV4_STR_SZ - 1);

		printk("ip addr set %s", addr_str);
	}
#endif

	if (rtattr_pack(&req.nh, req_sz, IFA_LOCAL, &addr, sizeof(addr)) ||
	    rtattr_pack(&req.nh, req_sz, IFA_ADDRESS, &addr, sizeof(addr)))
		return -1;

	if (send(sock, &req, req.nh.nlmsg_len, 0) >= 0)
		return netlink_check_answer(sock);

	pr_err("send()");
	return -1;
}

/*
 * link_set_up() - bring an interface up over rtnetlink.
 * @sock:	netlink route socket
 * @seq:	sequence number for the request
 * @intf:	interface name
 *
 * Only the IFF_UP flag is changed: ifi_change is the mask of flags to
 * touch and ifi_flags carries their new values.
 *
 * Return: 0 on success, -1 if send() failed, otherwise the result of
 * netlink_check_answer().
 */
static int link_set_up(int sock, uint32_t seq, const char *intf)
{
	struct {
		struct nlmsghdr		nh;
		struct ifinfomsg	info;
		char			attrbuf[MAX_PAYLOAD];
	} req;

	memset(&req, 0, sizeof(req));
	req.info.ifi_family	= AF_UNSPEC;
	req.info.ifi_index	= if_nametoindex(intf);
	req.info.ifi_flags	= IFF_UP;
	req.info.ifi_change	= IFF_UP;
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_type	= RTM_NEWLINK;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));

	if (send(sock, &req, req.nh.nlmsg_len, 0) >= 0)
		return netlink_check_answer(sock);

	pr_err("send()");
	return -1;
}

/*
 * ip4_route_set() - add a /32 link-scope route to @dst via interface @intf.
 * @sock:	netlink route socket
 * @seq:	sequence number for the request
 * @intf:	output interface name
 * @src:	preferred source address (RTA_PREFSRC)
 * @dst:	destination address (RTA_DST)
 *
 * Return: 0 on success, -1 on a local failure, otherwise the result of
 * netlink_check_answer().
 */
static int ip4_route_set(int sock, uint32_t seq, const char *intf,
		struct in_addr src, struct in_addr dst)
{
	unsigned int oif = if_nametoindex(intf);
	struct {
		struct nlmsghdr	nh;
		struct rtmsg	rt;
		char		attrbuf[MAX_PAYLOAD];
	} req;
	const size_t req_sz = sizeof(req);

	memset(&req, 0, req_sz);
	req.rt.rtm_family	= AF_INET;
	req.rt.rtm_dst_len	= 32;
	req.rt.rtm_table	= RT_TABLE_MAIN;
	req.rt.rtm_protocol	= RTPROT_BOOT;
	req.rt.rtm_scope	= RT_SCOPE_LINK;
	req.rt.rtm_type		= RTN_UNICAST;
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_type	= RTM_NEWROUTE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK | NLM_F_CREATE;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.rt));

	if (rtattr_pack(&req.nh, req_sz, RTA_DST, &dst, sizeof(dst)) ||
	    rtattr_pack(&req.nh, req_sz, RTA_PREFSRC, &src, sizeof(src)) ||
	    rtattr_pack(&req.nh, req_sz, RTA_OIF, &oif, sizeof(oif)))
		return -1;

	if (send(sock, &req, req.nh.nlmsg_len, 0) >= 0)
		return netlink_check_answer(sock);

	pr_err("send()");
	return -1;
}

/*
 * tunnel_set_route() - set up addressing for one end of the tunnel.
 *
 * Assigns @tunsrc to "lo" and adds a route to @tundst through @veth.
 * *@route_seq is advanced once for every request issued.
 *
 * Return: 0 on success, -1 (after logging) on failure.
 */
static int tunnel_set_route(int route_sock, uint32_t *route_seq, char *veth,
		struct in_addr tunsrc, struct in_addr tundst)
{
	uint32_t seq;

	seq = (*route_seq)++;
	if (ip4_addr_set(route_sock, seq, "lo", tunsrc, PREFIX_LEN)) {
		printk("Failed to set ipv4 addr");
		return -1;
	}

	seq = (*route_seq)++;
	if (ip4_route_set(route_sock, seq, veth, tunsrc, tundst)) {
		printk("Failed to set ipv4 route");
		return -1;
	}

	return 0;
}

/*
 * init_child() - configure networking inside one child namespace.
 * @nsfd:	namespace to switch into (the switch is not undone here)
 * @veth:	name of the veth end living in that namespace
 * @src:	local host number, combined with INADDR_A and INADDR_B
 * @dst:	remote host number, combined with INADDR_A
 *
 * Assigns the INADDR_B address to @veth, brings the link up and sets up
 * the tunnel address/route via tunnel_set_route().
 *
 * Return: 0 on success, -1 (after logging) on failure.
 */
static int init_child(int nsfd, char *veth, unsigned int src, unsigned int dst)
{
	struct in_addr intsrc = inet_makeaddr(INADDR_B, src);
	struct in_addr tunsrc = inet_makeaddr(INADDR_A, src);
	struct in_addr tundst = inet_makeaddr(INADDR_A, dst);
	uint32_t route_seq;
	int route_sock = -1;
	int ret = -1;

	if (switch_ns(nsfd))
		return -1;

	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE)) {
		printk("Failed to open netlink route socket in child");
		return -1;
	}

	/* Each step runs only if the previous one succeeded. */
	if (ip4_addr_set(route_sock, route_seq++, veth, intsrc, PREFIX_LEN))
		printk("Failed to set ipv4 addr");
	else if (link_set_up(route_sock, route_seq++, veth))
		printk("Failed to bring up %s", veth);
	else if (tunnel_set_route(route_sock, &route_seq, veth, tunsrc, tundst))
		printk("Failed to add tunnel route on %s", veth);
	else
		ret = 0;

	close(route_sock);
	return ret;
}

/* Size of the algorithm name fields in struct xfrm_desc */
#define ALGO_LEN	64
/* Kind of test described by a struct xfrm_desc */
enum desc_type {
	CREATE_TUNNEL	= 0,
	ALLOCATE_SPI,
	MONITOR_ACQUIRE,
	EXPIRE_STATE,
	EXPIRE_POLICY,
	SPDINFO_ATTRS,
};
/*
 * Human-readable names, in the same order as enum desc_type, followed by
 * one empty string.
 */
const char *desc_name[] = {
	"create tunnel",
	"alloc spi",
	"monitor acquire",
	"expire state",
	"expire policy",
	"spdinfo attributes",
	""
};
/*
 * Description of one xfrm test case. Which algorithm names must be set
 * depends on proto; see the checks in xfrm_state_pack_algo().
 */
struct xfrm_desc {
	enum desc_type	type;
	uint8_t		proto;			/* IPPROTO_AH, IPPROTO_COMP or IPPROTO_ESP */
	char		a_algo[ALGO_LEN];	/* packed as XFRMA_ALG_AUTH */
	char		e_algo[ALGO_LEN];	/* packed as XFRMA_ALG_CRYPT */
	char		c_algo[ALGO_LEN];	/* packed as XFRMA_ALG_COMP */
	char		ae_algo[ALGO_LEN];	/* packed as XFRMA_ALG_AEAD */
	unsigned int	icv_len;		/* copied to alg_icv_len for AEAD */
	/* unsigned key_len; */
};

/*
 * Type tag of a struct test_desc message exchanged through the command
 * descriptors (see write_msg()/read_msg()). MSG_PING carries the ping
 * body; NOTE(review): the other types are handled outside this part of
 * the file - confirm their payloads there.
 */
enum msg_type {
	MSG_ACK		= 0,
	MSG_EXIT,
	MSG_PING,
	MSG_XFRM_PREPARE,
	MSG_XFRM_ADD,
	MSG_XFRM_DEL,
	MSG_XFRM_CLEANUP,
};

/*
 * Fixed-size message passed between processes by write_msg()/read_msg().
 * write_msg() asserts at build time that it is not larger than PIPE_BUF.
 */
struct test_desc {
	enum msg_type type;
	union {
		/* MSG_PING: where the sender of the message expects replies */
		struct {
			in_addr_t reply_ip;
			unsigned int port;
		} ping;
		struct xfrm_desc xfrm_desc;
	} body;
};

/* Record written to the results pipe by write_test_result(). */
struct test_result {
	struct xfrm_desc desc;	/* copy of the test description */
	unsigned int res;	/* result value passed by the caller */
};

/*
 * write_test_result() - report one test outcome through the results pipe.
 * @res:	result value
 * @d:		description of the test the result belongs to
 *
 * A failed or short write is logged and otherwise ignored.
 */
static void write_test_result(unsigned int res, struct xfrm_desc *d)
{
	struct test_result tr = {};
	ssize_t ret;

	tr.res = res;
	tr.desc = *d;

	ret = write(results_fd[1], &tr, sizeof(tr));
	if (ret == (ssize_t)sizeof(tr))
		return;

	pr_err("Failed to write the result in pipe %zd", ret);
}

/*
 * write_msg() - send one struct test_desc through @fd.
 * @fd:			descriptor to write to
 * @msg:		message to send
 * @exit_of_fail:	when true, exit(KSFT_FAIL) on any failure
 *
 * A failed write and a short write are reported separately, exactly once
 * each; errno is only printed for the failed write, where it is valid.
 * With @exit_of_fail false the function returns after logging and the
 * caller gets no indication of the failure.
 */
static void write_msg(int fd, struct test_desc *msg, bool exit_of_fail)
{
	ssize_t bytes;

	/* Make sure that write/read is atomic to a pipe */
	BUILD_BUG_ON(sizeof(struct test_desc) > PIPE_BUF);

	bytes = write(fd, msg, sizeof(*msg));
	if (bytes == (ssize_t)sizeof(*msg))
		return;

	if (bytes < 0)
		pr_err("write()");
	else
		printk("sent part of the message %zd/%zu", bytes, sizeof(*msg));

	if (exit_of_fail)
		exit(KSFT_FAIL);
}

/*
 * read_msg() - receive one struct test_desc from @fd.
 * @fd:			descriptor to read from
 * @msg:		destination; may be left untouched or partly filled
 *			when the read fails or is short
 * @exit_of_fail:	when true, exit(KSFT_FAIL) on any failure
 *
 * A failed read and a short read are reported separately, exactly once
 * each; errno is only printed for the failed read, where it is valid.
 * With @exit_of_fail false the function returns after logging, so callers
 * that care must validate *@msg themselves.
 */
static void read_msg(int fd, struct test_desc *msg, bool exit_of_fail)
{
	ssize_t bytes = read(fd, msg, sizeof(*msg));

	if (bytes == (ssize_t)sizeof(*msg))
		return;

	if (bytes < 0)
		pr_err("read()");
	else
		printk("got incomplete message %zd/%zu", bytes, sizeof(*msg));

	if (exit_of_fail)
		exit(KSFT_FAIL);
}

/*
 * udp_ping_init() - create the socket pair used for UDP pings.
 * @listen_ip:		address the receiving socket is bound to
 * @u_timeout:		receive timeout in microseconds (stored in tv_usec)
 * @server_port:	out: port the kernel picked for the receiving socket
 * @sock:		out: sock[0] receives, sock[1] is used for sending
 *
 * Return: 0 on success; -1 (after logging) on failure, in which case no
 * descriptor is left open.
 */
static int udp_ping_init(struct in_addr listen_ip, unsigned int u_timeout,
		unsigned int *server_port, int sock[2])
{
	struct sockaddr_in server;
	struct timeval t = { .tv_sec = 0, .tv_usec = u_timeout };
	socklen_t s_len = sizeof(server);

	sock[0] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[0] < 0) {
		pr_err("socket()");
		return -1;
	}

	/* Do not hand uninitialized bytes (sin_zero) to bind(). */
	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= 0;	/* let the kernel choose a port */
	memcpy(&server.sin_addr.s_addr, &listen_ip, sizeof(struct in_addr));

	if (bind(sock[0], (struct sockaddr *)&server, s_len)) {
		pr_err("bind()");
		goto err_close_server;
	}

	if (getsockname(sock[0], (struct sockaddr *)&server, &s_len)) {
		pr_err("getsockname()");
		goto err_close_server;
	}

	*server_port = ntohs(server.sin_port);

	if (setsockopt(sock[0], SOL_SOCKET, SO_RCVTIMEO, &t, sizeof(t))) {
		pr_err("setsockopt()");
		goto err_close_server;
	}

	sock[1] = socket(AF_INET, SOCK_DGRAM, 0);
	if (sock[1] < 0) {
		pr_err("socket()");
		goto err_close_server;
	}

	return 0;

err_close_server:
	close(sock[0]);
	return -1;
}

/*
 * udp_ping_send() - send one ping and wait for the echoed reply.
 * @sock:	sock[1] is used to send, sock[0] to receive
 * @dest_ip:	destination address (network byte order)
 * @port:	destination UDP port (host byte order)
 * @buf:	payload to send; the reply must match it byte for byte
 * @buf_len:	payload length in bytes
 *
 * Return: 0 if the reply matched, -1 on any send/receive failure, on a
 * receive timeout (not logged) or on a reply that differs from @buf.
 */
static int udp_ping_send(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* A byte buffer for the reply, not an array of buf_len pointers. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		/* EAGAIN is the receive timeout: a lost ping, not an error */
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	} else if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	} else if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	return 0;
}

/*
 * udp_ping_reply() - wait for one ping and echo it back.
 * @sock:	sock[0] is used to receive, sock[1] to send
 * @dest_ip:	address the echo is sent to (network byte order)
 * @port:	UDP port the echo is sent to (host byte order)
 * @buf:	expected payload; also the data that is sent back
 * @buf_len:	payload length in bytes
 *
 * Return: 0 if a matching ping was received and echoed, -1 on any
 * receive/send failure, on a receive timeout (not logged) or on a ping
 * that differs from @buf.
 */
static int udp_ping_reply(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len)
{
	struct sockaddr_in server;
	const struct sockaddr *dest_addr = (struct sockaddr *)&server;
	/* A byte buffer for the ping, not an array of buf_len pointers. */
	char sock_buf[buf_len];
	ssize_t r_bytes, s_bytes;

	memset(&server, 0, sizeof(server));
	server.sin_family	= AF_INET;
	server.sin_port		= htons(port);
	server.sin_addr.s_addr	= dest_ip;

	r_bytes = recv(sock[0], sock_buf, buf_len, 0);
	if (r_bytes < 0) {
		/* EAGAIN is the receive timeout: a lost ping, not an error */
		if (errno != EAGAIN)
			pr_err("recv()");
		return -1;
	}
	if (r_bytes == 0) { /* EOF */
		printk("EOF on reply to ping");
		return -1;
	}
	if ((size_t)r_bytes != buf_len || memcmp(buf, sock_buf, buf_len)) {
		printk("ping reply packet is corrupted %zd/%zu", r_bytes, buf_len);
		return -1;
	}

	s_bytes = sendto(sock[1], buf, buf_len, 0, dest_addr, sizeof(server));
	if (s_bytes < 0) {
		pr_err("sendto()");
		return -1;
	} else if ((size_t)s_bytes != buf_len) {
		printk("send part of the message: %zd/%zu", s_bytes, buf_len);
		return -1;
	}

	return 0;
}

/* One ping step: udp_ping_send() or udp_ping_reply(). */
typedef int (*ping_f)(int sock[2], in_addr_t dest_ip, unsigned int port,
		char *buf, size_t buf_len);

/*
 * do_ping() - run a series of UDP pings and judge the success rate.
 * @cmd_fd:	command descriptor used to exchange MSG_PING with the peer
 * @buf:	ping payload
 * @buf_len:	length of @buf in bytes
 * @from:	local address to listen on; announced to the peer
 * @init_side:	true if this side learns the destination from the peer's
 *		MSG_PING answer (overriding @d_port and @to)
 * @d_port:	destination port, used when @init_side is false
 * @to:		destination address, used when @init_side is false
 * @func:	ping step to run ping_count times
 *
 * Return: 0 if at least ping_success of the ping_count steps succeeded,
 * -1 otherwise or if the setup failed. The ping sockets are closed on
 * every path once they have been created.
 */
static int do_ping(int cmd_fd, char *buf, size_t buf_len, struct in_addr from,
		bool init_side, int d_port, in_addr_t to, ping_f func)
{
	struct test_desc msg;
	unsigned int s_port, i, ping_succeeded = 0;
	int ping_sock[2];
	char to_str[IPV4_STR_SZ] = {}, from_str[IPV4_STR_SZ] = {};
	struct in_addr to_addr;
	int ret = -1;

	if (udp_ping_init(from, ping_timeout, &s_port, ping_sock)) {
		printk("Failed to init ping");
		return -1;
	}

	memset(&msg, 0, sizeof(msg));
	msg.type		= MSG_PING;
	msg.body.ping.port	= s_port;
	memcpy(&msg.body.ping.reply_ip, &from, sizeof(from));

	write_msg(cmd_fd, &msg, 0);
	if (init_side) {
		/*
		 * The other end sends ip to ping. Wipe our own MSG_PING
		 * first: read_msg() does not report failures here, and a
		 * stale message would otherwise pass the type check.
		 */
		memset(&msg, 0, sizeof(msg));
		read_msg(cmd_fd, &msg, 0);
		if (msg.type != MSG_PING)
			goto out_close;
		to = msg.body.ping.reply_ip;
		d_port = msg.body.ping.port;
	}

	for (i = 0; i < ping_count ; i++) {
		struct timespec sleep_time = {
			.tv_sec = 0,
			.tv_nsec = ping_delay_nsec,
		};

		/* Use the caller's buffer length rather than a global. */
		ping_succeeded += !func(ping_sock, to, d_port, buf, buf_len);
		nanosleep(&sleep_time, 0);
	}

	/* inet_ntoa() returns a static buffer: copy before the next call */
	to_addr.s_addr = to;
	strncpy(to_str, inet_ntoa(to_addr), IPV4_STR_SZ - 1);
	strncpy(from_str, inet_ntoa(from), IPV4_STR_SZ - 1);

	if (ping_succeeded < ping_success) {
		printk("ping (%s) %s->%s failed %u/%u times",
			init_side ? "send" : "reply", from_str, to_str,
			ping_count - ping_succeeded, ping_count);
		goto out_close;
	}

#ifdef DEBUG
	printk("ping (%s) %s->%s succeeded %u/%u times",
		init_side ? "send" : "reply", from_str, to_str,
		ping_succeeded, ping_count);
#endif

	ret = 0;

out_close:
	close(ping_sock[0]);
	close(ping_sock[1]);
	return ret;
}

/*
 * xfrm_fill_key() - look up an algorithm's key length and generate a key.
 * @name:	algorithm name to look up in xfrm_key_entries[]
 * @buf:	buffer that receives the random key material
 * @buf_len:	size of @buf in bytes
 * @key_len:	in/out: set to the table value when @name is found; left
 *		unchanged for names that are not in the table
 *
 * The number of random bytes written equals the numeric value of
 * *@key_len, which is also the value compared against @buf_len.
 *
 * Return: 0 on success, -1 (after logging) if *@key_len exceeds @buf_len.
 */
static int xfrm_fill_key(char *name, char *buf,
		size_t buf_len, unsigned int *key_len)
{
	/* Derived from the table itself so the two cannot get out of sync */
	const size_t nr_entries =
		sizeof(xfrm_key_entries) / sizeof(xfrm_key_entries[0]);
	size_t i;

	for (i = 0; i < nr_entries; i++) {
		if (strncmp(name, xfrm_key_entries[i].algo_name, ALGO_LEN) == 0) {
			*key_len = xfrm_key_entries[i].key_len;
			break;
		}
	}

	if (*key_len > buf_len) {
		printk("Can't pack a key - too big for buffer");
		return -1;
	}

	randomize_buffer(buf, *key_len);

	return 0;
}

/*
 * xfrm_state_pack_algo() - append the algorithm attribute(s) for @desc.
 * @nh:		netlink message being built
 * @req_sz:	size of the buffer that holds @nh
 * @desc:	test description; which names must be set depends on proto:
 *		  IPPROTO_AH   - a_algo only
 *		  IPPROTO_COMP - c_algo only
 *		  IPPROTO_ESP  - either ae_algo, or both a_algo and e_algo
 *
 * Every attribute is packed as the algorithm header followed by a key
 * buffer of XFRM_ALGO_KEY_BUF_SIZE bytes filled by xfrm_fill_key(). For
 * ESP without AEAD two attributes are packed: XFRMA_ALG_CRYPT first, then
 * XFRMA_ALG_AUTH, reusing the same local structure.
 *
 * Return: 0 on success, -1 (after logging) on an inconsistent @desc, a
 * key that does not fit, or a request buffer that is too small.
 */
static int xfrm_state_pack_algo(struct nlmsghdr *nh, size_t req_sz,
		struct xfrm_desc *desc)
{
	struct {
		union {
			struct xfrm_algo	alg;
			struct xfrm_algo_aead	aead;
			struct xfrm_algo_auth	auth;
		} u;
		char buf[XFRM_ALGO_KEY_BUF_SIZE];
	} alg = {};
	size_t alen, elen, clen, aelen;
	unsigned short type;

	alen = strlen(desc->a_algo);
	elen = strlen(desc->e_algo);
	clen = strlen(desc->c_algo);
	aelen = strlen(desc->ae_algo);

	/* Verify desc */
	switch (desc->proto) {
	case IPPROTO_AH:
		if (!alen || elen || clen || aelen) {
			printk("BUG: buggy ah desc");
			return -1;
		}
		strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
		if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
				sizeof(alg.buf), &alg.u.alg.alg_key_len))
			return -1;
		type = XFRMA_ALG_AUTH;
		break;
	case IPPROTO_COMP:
		if (!clen || elen || alen || aelen) {
			printk("BUG: buggy comp desc");
			return -1;
		}
		strncpy(alg.u.alg.alg_name, desc->c_algo, ALGO_LEN - 1);
		if (xfrm_fill_key(desc->c_algo, alg.u.alg.alg_key,
				sizeof(alg.buf), &alg.u.alg.alg_key_len))
			return -1;
		type = XFRMA_ALG_COMP;
		break;
	case IPPROTO_ESP:
		/*
		 * Exactly one of "auth + enc" and "aead" must be given.
		 * Both sides are reduced to 0/1 before the XOR; XOR-ing
		 * with the raw string length let a desc with both set slip
		 * through for any aelen other than 1.
		 */
		if (!((alen && elen) ^ !!aelen) || clen) {
			printk("BUG: buggy esp desc");
			return -1;
		}
		if (aelen) {
			alg.u.aead.alg_icv_len = desc->icv_len;
			strncpy(alg.u.aead.alg_name, desc->ae_algo, ALGO_LEN - 1);
			if (xfrm_fill_key(desc->ae_algo, alg.u.aead.alg_key,
						sizeof(alg.buf), &alg.u.aead.alg_key_len))
				return -1;
			type = XFRMA_ALG_AEAD;
		} else {

			strncpy(alg.u.alg.alg_name, desc->e_algo, ALGO_LEN - 1);
			type = XFRMA_ALG_CRYPT;
			if (xfrm_fill_key(desc->e_algo, alg.u.alg.alg_key,
						sizeof(alg.buf), &alg.u.alg.alg_key_len))
				return -1;
			if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
				return -1;

			/*
			 * Reuse alg for the auth attribute. strncpy() pads
			 * with NULs, so the previous name is fully replaced;
			 * the last byte stays NUL from the zero-init above.
			 */
			strncpy(alg.u.alg.alg_name, desc->a_algo, ALGO_LEN - 1);
			type = XFRMA_ALG_AUTH;
			if (xfrm_fill_key(desc->a_algo, alg.u.alg.alg_key,
						sizeof(alg.buf), &alg.u.alg.alg_key_len))
				return -1;
		}
		break;
	default:
		printk("BUG: unknown proto in desc");
		return -1;
	}

	if (rtattr_pack(nh, req_sz, type, &alg, sizeof(alg)))
		return -1;

	return 0;
}

/*
 * gen_spi() - derive an SPI from an address: the local (host) part of
 * @src as returned by inet_lnaof(), converted to network byte order.
 */
static inline uint32_t gen_spi(struct in_addr src)
{
	in_addr_t host_part = inet_lnaof(src);

	return htonl(host_part);
}

/*
 * xfrm_state_add() - install one tunnel-mode IPv4 xfrm state (SA).
 * @xfrm_sock:	netlink xfrm socket
 * @seq:	sequence number for the request
 * @spi:	SPI stored in the state id
 * @src:	source address (selector and state)
 * @dst:	destination address (selector and state id)
 * @desc:	protocol and algorithms, see xfrm_state_pack_algo()
 *
 * Return: 0 on success, -1 on a local failure, otherwise the result of
 * netlink_check_answer().
 */
static int xfrm_state_add(int xfrm_sock, uint32_t seq, uint32_t spi,
		struct in_addr src, struct in_addr dst,
		struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr		nh;
		struct xfrm_usersa_info	info;
		char			attrbuf[MAX_PAYLOAD];
	} req;
	struct xfrm_usersa_info *sa = &req.info;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_type	= XFRM_MSG_NEWSA;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));

	/* Selector: both endpoints with the /30 prefix */
	sa->sel.family		= AF_INET;
	sa->sel.prefixlen_d	= PREFIX_LEN;
	sa->sel.prefixlen_s	= PREFIX_LEN;
	memcpy(&sa->sel.saddr, &src, sizeof(src));
	memcpy(&sa->sel.daddr, &dst, sizeof(dst));

	/* State id. Note: zero-spi cannot be deleted */
	sa->id.proto	= desc->proto;
	sa->id.spi	= spi;
	memcpy(&sa->id.daddr, &dst, sizeof(dst));

	memcpy(&sa->saddr, &src, sizeof(src));

	/* Lifetime configuration: no byte or packet limits */
	sa->lft.soft_byte_limit		= XFRM_INF;
	sa->lft.hard_byte_limit		= XFRM_INF;
	sa->lft.soft_packet_limit	= XFRM_INF;
	sa->lft.hard_packet_limit	= XFRM_INF;

	sa->family	= AF_INET;
	sa->mode	= XFRM_MODE_TUNNEL;

	if (xfrm_state_pack_algo(&req.nh, sizeof(req), desc))
		return -1;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) >= 0)
		return netlink_check_answer(xfrm_sock);

	pr_err("send()");
	return -1;
}

/*
 * xfrm_usersa_found() - does a dumped state match what xfrm_state_add()
 * would have installed for the given parameters?
 *
 * Compares selector, id, source address, lifetime limits, family and
 * mode. The algorithms are not compared.
 *
 * Return: true on a full match, false otherwise.
 */
static bool xfrm_usersa_found(struct xfrm_usersa_info *info, uint32_t spi,
		struct in_addr src, struct in_addr dst,
		struct xfrm_desc *desc)
{
	/* Selector addresses */
	if (memcmp(&info->sel.daddr, &dst, sizeof(dst)) != 0 ||
	    memcmp(&info->sel.saddr, &src, sizeof(src)) != 0)
		return false;

	/* Selector family and prefixes */
	if (!(info->sel.family == AF_INET &&
	      info->sel.prefixlen_d == PREFIX_LEN &&
	      info->sel.prefixlen_s == PREFIX_LEN))
		return false;

	/* State id */
	if (!(info->id.spi == spi && info->id.proto == desc->proto))
		return false;

	if (memcmp(&info->id.daddr, &dst, sizeof(dst)) != 0 ||
	    memcmp(&info->saddr, &src, sizeof(src)) != 0)
		return false;

	/* Lifetime: all four limits must be unlimited */
	if (!(info->lft.soft_byte_limit == XFRM_INF &&
	      info->lft.hard_byte_limit == XFRM_INF &&
	      info->lft.soft_packet_limit == XFRM_INF &&
	      info->lft.hard_packet_limit == XFRM_INF))
		return false;

	/* XXX: check xfrm algo, see xfrm_state_pack_algo(). */

	return info->family == AF_INET && info->mode == XFRM_MODE_TUNNEL;
}

/*
 * Dump SAs with XFRM_MSG_GETSA (NLM_F_DUMP), filtered by source address,
 * and verify that the state described by (spi, src, dst, desc) is listed.
 *
 * Only the leading netlink header of each received datagram is examined.
 *
 * Returns 0 when a matching state was seen before NLMSG_DONE; -1 on a
 * send()/recv() failure, on NLMSG_ERROR, or when nothing matched.
 */
static int xfrm_state_check(int xfrm_sock, uint32_t seq, uint32_t spi,
		struct in_addr src, struct in_addr dst,
		struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr		nh;
		char			attrbuf[MAX_PAYLOAD];
	} req;
	struct {
		struct nlmsghdr		nh;
		union {
			struct xfrm_usersa_info	info;
			int error;
		};
		char			attrbuf[MAX_PAYLOAD];
	} answer;
	struct xfrm_address_filter filter = {};
	bool matched = false;

	/*
	 * Add dump filter by source address as there may be other tunnels
	 * in this netns (if tests run in parallel).
	 */
	memcpy(&filter.saddr, &src, sizeof(src));
	filter.splen = 0x1f;	/* 0xffffffff mask see addr_match() */
	filter.family = AF_INET;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_DUMP;
	req.nh.nlmsg_type	= XFRM_MSG_GETSA;
	req.nh.nlmsg_len	= NLMSG_LENGTH(0);

	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_ADDRESS_FILTER,
				&filter, sizeof(filter)))
		return -1;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return -1;
	}

	/* Consume replies until the dump terminates or fails. */
	for (;;) {
		if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
			pr_err("recv()");
			return -1;
		}

		switch (answer.nh.nlmsg_type) {
		case NLMSG_ERROR:
			printk("NLMSG_ERROR: %d: %s",
				answer.error, strerror(-answer.error));
			return -1;
		case NLMSG_DONE:
			if (!matched) {
				printk("didn't find allocated xfrm state in dump");
				return -1;
			}
			return 0;
		case XFRM_MSG_NEWSA:
			if (xfrm_usersa_found(&answer.info, spi, src, dst, desc))
				matched = true;
			break;
		default:
			/* Unrelated message type: keep reading. */
			break;
		}
	}
}

/*
 * Install the two SAs of a tunnel (src -> dst and dst -> src) and confirm
 * that both appear in an XFRM_MSG_GETSA dump.
 *
 * Each netlink request consumes one sequence number from *seq.
 * tunsrc and tundst are accepted but not used here.
 *
 * Returns 0 on success, -1 if adding or checking either state failed.
 */
static int xfrm_set(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst,
		struct xfrm_desc *desc)
{
	int fwd_err, rev_err;

	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc)) {
		printk("Failed to add xfrm state");
		return -1;
	}

	/* Check dumps for XFRM_MSG_GETSA; both directions are always probed. */
	fwd_err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc);
	rev_err = xfrm_state_check(xfrm_sock, (*seq)++, gen_spi(src), dst, src, desc);
	if (fwd_err || rev_err) {
		printk("Failed to check xfrm state");
		return -1;
	}

	return 0;
}

/*
 * Add an IPv4 tunnel-mode policy with XFRM_MSG_NEWPOLICY.
 *
 * The policy selector and the attached XFRMA_TMPL template are both built
 * from src/dst; the template additionally carries spi and proto and
 * accepts any algorithm (all-ones aalgos/ealgos/calgos masks).
 * tunsrc and tundst are part of the signature but are not used.
 *
 * Returns -1 if the attribute does not fit or send() fails, otherwise the
 * result of netlink_check_answer().
 */
static int xfrm_policy_add(int xfrm_sock, uint32_t seq, uint32_t spi,
		struct in_addr src, struct in_addr dst, uint8_t dir,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	struct {
		struct nlmsghdr			nh;
		struct xfrm_userpolicy_info	info;
		char				attrbuf[MAX_PAYLOAD];
	} req;
	struct xfrm_user_tmpl tmpl;

	memset(&req, 0, sizeof(req));
	memset(&tmpl, 0, sizeof(tmpl));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.info));
	req.nh.nlmsg_type	= XFRM_MSG_NEWPOLICY;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= seq;

	/* Fill selector; size each copy by the object actually copied. */
	memcpy(&req.info.sel.daddr, &dst, sizeof(dst));
	memcpy(&req.info.sel.saddr, &src, sizeof(src));
	req.info.sel.family		= AF_INET;
	req.info.sel.prefixlen_d	= PREFIX_LEN;
	req.info.sel.prefixlen_s	= PREFIX_LEN;

	/* Fill lifetime_cfg */
	req.info.lft.soft_byte_limit	= XFRM_INF;
	req.info.lft.hard_byte_limit	= XFRM_INF;
	req.info.lft.soft_packet_limit	= XFRM_INF;
	req.info.lft.hard_packet_limit	= XFRM_INF;

	req.info.dir = dir;

	/* Fill tmpl */
	memcpy(&tmpl.id.daddr, &dst, sizeof(dst));
	/* Note: zero-spi cannot be deleted */
	tmpl.id.spi = spi;
	tmpl.id.proto	= proto;
	tmpl.family	= AF_INET;
	memcpy(&tmpl.saddr, &src, sizeof(src));
	tmpl.mode	= XFRM_MODE_TUNNEL;
	tmpl.aalgos = (~(uint32_t)0);
	tmpl.ealgos = (~(uint32_t)0);
	tmpl.calgos = (~(uint32_t)0);

	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &tmpl, sizeof(tmpl)))
		return -1;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return -1;
	}

	return netlink_check_answer(xfrm_sock);
}

/*
 * Install the pair of policies for a tunnel: an XFRM_POLICY_OUT policy
 * for src -> dst and an XFRM_POLICY_IN policy for dst -> src.
 *
 * Each request takes one sequence number from *seq.
 *
 * Returns 0 on success, -1 as soon as either policy cannot be added.
 */
static int xfrm_prepare(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int err;

	/* Outbound direction. */
	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
			      XFRM_POLICY_OUT, tunsrc, tundst, proto);
	if (err != 0) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	/* Inbound direction: selector addresses swapped. */
	err = xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), dst, src,
			      XFRM_POLICY_IN, tunsrc, tundst, proto);
	if (err != 0) {
		printk("Failed to add xfrm policy");
		return -1;
	}

	return 0;
}

/*
 * Delete the IPv4 policy identified by selector src/dst and direction dir
 * using XFRM_MSG_DELPOLICY.
 *
 * tunsrc and tundst are part of the signature but are not used.
 *
 * Returns -1 if send() fails, otherwise the result of
 * netlink_check_answer().
 */
static int xfrm_policy_del(int xfrm_sock, uint32_t seq,
		struct in_addr src, struct in_addr dst, uint8_t dir,
		struct in_addr tunsrc, struct in_addr tundst)
{
	struct {
		struct nlmsghdr			nh;
		struct xfrm_userpolicy_id	id;
		char				attrbuf[MAX_PAYLOAD];
	} req;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));
	req.nh.nlmsg_type	= XFRM_MSG_DELPOLICY;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= seq;

	/* Fill id; size each copy by the object actually copied. */
	memcpy(&req.id.sel.daddr, &dst, sizeof(dst));
	memcpy(&req.id.sel.saddr, &src, sizeof(src));
	req.id.sel.family		= AF_INET;
	req.id.sel.prefixlen_d		= PREFIX_LEN;
	req.id.sel.prefixlen_s		= PREFIX_LEN;
	req.id.dir = dir;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return -1;
	}

	return netlink_check_answer(xfrm_sock);
}

/*
 * Remove the two policies of a tunnel: the XFRM_POLICY_OUT policy for
 * src -> dst and the XFRM_POLICY_IN policy for dst -> src.
 *
 * Each request takes one sequence number from *seq.
 *
 * Returns 0 on success, -1 as soon as either deletion fails (the second
 * policy is not attempted if the first one fails).
 */
static int xfrm_cleanup(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst)
{
	if (xfrm_policy_del(xfrm_sock, (*seq)++, src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst)) {
		/* This path deletes, so report a delete failure. */
		printk("Failed to delete xfrm policy");
		return -1;
	}

	if (xfrm_policy_del(xfrm_sock, (*seq)++, dst, src,
				XFRM_POLICY_IN, tunsrc, tundst)) {
		printk("Failed to delete xfrm policy");
		return -1;
	}

	return 0;
}

/*
 * Delete one IPv4 SA with XFRM_MSG_DELSA.
 *
 * The SA is identified by (dst, spi, proto); the source address is passed
 * along in an XFRMA_SRCADDR attribute.
 *
 * Returns -1 if the attribute does not fit or send() fails, otherwise the
 * result of netlink_check_answer().
 */
static int xfrm_state_del(int xfrm_sock, uint32_t seq, uint32_t spi,
		struct in_addr src, struct in_addr dst, uint8_t proto)
{
	struct {
		struct nlmsghdr		nh;
		struct xfrm_usersa_id	id;
		char			attrbuf[MAX_PAYLOAD];
	} req;
	xfrm_address_t saddr = {};

	/* IPv4 address occupies the start of the zeroed xfrm_address_t. */
	memcpy(&saddr, &src, sizeof(src));

	memset(&req, 0, sizeof(req));

	/* Netlink header. */
	req.nh.nlmsg_seq	= seq;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_type	= XFRM_MSG_DELSA;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.id));

	/* SA id. Note: zero-spi cannot be deleted */
	req.id.spi		= spi;
	req.id.proto		= proto;
	req.id.family		= AF_INET;
	memcpy(&req.id.daddr, &dst, sizeof(dst));

	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SRCADDR,
			&saddr, sizeof(saddr)) != 0)
		return -1;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return -1;
	}

	return netlink_check_answer(xfrm_sock);
}

/*
 * Remove both SAs of a tunnel: first src -> dst, then dst -> src.
 *
 * Each request takes one sequence number from *seq.
 * tunsrc and tundst are accepted but not used here.
 *
 * Returns 0 on success, -1 as soon as either removal fails.
 */
static int xfrm_delete(int xfrm_sock, uint32_t *seq,
		struct in_addr src, struct in_addr dst,
		struct in_addr tunsrc, struct in_addr tundst, uint8_t proto)
{
	int err;

	err = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), src, dst, proto);
	if (err != 0) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	err = xfrm_state_del(xfrm_sock, (*seq)++, gen_spi(src), dst, src, proto);
	if (err != 0) {
		printk("Failed to remove xfrm state");
		return -1;
	}

	return 0;
}

/*
 * Ask the kernel to allocate exactly one SPI (min == max == spi) with
 * XFRM_MSG_ALLOCSPI and check the reply.
 *
 * The request is sent without NLM_F_ACK, so a single reply is read:
 *  - XFRM_MSG_NEWSA: pass only if the returned SPI equals the request;
 *  - NLMSG_ERROR: pass only if the carried error code is 0;
 *  - anything else: fail.
 *
 * Returns KSFT_PASS or KSFT_FAIL.
 */
static int xfrm_state_allocspi(int xfrm_sock, uint32_t *seq,
		uint32_t spi, uint8_t proto)
{
	struct {
		struct nlmsghdr			nh;
		struct xfrm_userspi_info	spi;
	} req;
	struct {
		struct nlmsghdr			nh;
		union {
			struct xfrm_usersa_info	info;
			int error;
		};
	} answer;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.spi));
	req.nh.nlmsg_type	= XFRM_MSG_ALLOCSPI;
	req.nh.nlmsg_flags	= NLM_F_REQUEST;
	req.nh.nlmsg_seq	= (*seq)++;

	req.spi.info.family	= AF_INET;
	req.spi.min		= spi;
	req.spi.max		= spi;
	req.spi.info.id.proto	= proto;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return KSFT_FAIL;
	}

	if (recv(xfrm_sock, &answer, sizeof(answer), 0) < 0) {
		pr_err("recv()");
		return KSFT_FAIL;
	} else if (answer.nh.nlmsg_type == XFRM_MSG_NEWSA) {
		/*
		 * The reply SPI is converted to host order before the
		 * comparison, so ntohl() is the matching helper
		 * (NOTE(review): assumes id.spi is big-endian on the wire).
		 */
		uint32_t new_spi = ntohl(answer.info.id.spi);

		if (new_spi != spi) {
			printk("allocated spi is different from requested: %#x != %#x",
					new_spi, spi);
			return KSFT_FAIL;
		}
		return KSFT_PASS;
	} else if (answer.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)answer.nh.nlmsg_type);
		return KSFT_FAIL;
	}

	printk("NLMSG_ERROR: %d: %s", answer.error, strerror(-answer.error));
	return (answer.error) ? KSFT_FAIL : KSFT_PASS;
}

static int netlink_sock_bind(int *sock, uint32_t *seq, int proto, uint32_t groups)
{
	struct sockaddr_nl snl = {};
	socklen_t addr_len;
	int ret = -1;

	snl.nl_family = AF_NETLINK;
	snl.nl_groups = groups;

	if (netlink_sock(sock, seq, proto)) {
		printk("Failed to open xfrm netlink socket");
		return -1;
	}

	if (bind(*sock, (struct sockaddr *)&snl, sizeof(snl)) < 0) {
		pr_err("bind()");
		goto out_close;
	}

	addr_len = sizeof(snl);
	if (getsockname(*sock, (struct sockaddr *)&snl, &addr_len) < 0) {
		pr_err("getsockname()");
		goto out_close;
	}
	if (addr_len != sizeof(snl)) {
		printk("Wrong address length %d", addr_len);
		goto out_close;
	}
	if (snl.nl_family != AF_NETLINK) {
		printk("Wrong address family %d", snl.nl_family);
		goto out_close;
	}
	return 0;

out_close:
	close(*sock);
	return ret;
}

static int xfrm_monitor_acquire(int xfrm_sock, uint32_t *seq, unsigned int nr)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_acquire acq;
			int error;
		};
		char attrbuf[MAX_PAYLOAD];
	} req;
	struct xfrm_user_tmpl xfrm_tmpl = {};
	int xfrm_listen = -1, ret = KSFT_FAIL;
	uint32_t seq_listen;

	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_ACQUIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.acq));
	req.nh.nlmsg_type	= XFRM_MSG_ACQUIRE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	req.acq.policy.sel.family	= AF_INET;
	req.acq.aalgos	= 0xfeed;
	req.acq.ealgos	= 0xbaad;
	req.acq.calgos	= 0xbabe;

	xfrm_tmpl.family = AF_INET;
	xfrm_tmpl.id.proto = IPPROTO_ESP;
	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_TMPL, &xfrm_tmpl, sizeof(xfrm_tmpl)))
		goto out_close;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		ret = req.error;
		goto out_close;
	}

	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	if (req.acq.aalgos != 0xfeed || req.acq.ealgos != 0xbaad
			|| req.acq.calgos != 0xbabe) {
		printk("xfrm_user_acquire has changed  %x %x %x",
				req.acq.aalgos, req.acq.ealgos, req.acq.calgos);
		goto out_close;
	}

	ret = KSFT_PASS;
out_close:
	close(xfrm_listen);
	return ret;
}

/*
 * Add an SA, send XFRM_MSG_EXPIRE for it with hard = 0xff, and verify
 * that a listener bound to XFRMNLGRP_EXPIRE receives an expire message
 * whose hard field is 0x1.
 *
 * Returns KSFT_PASS on success, KSFT_FAIL on most failures, or the
 * (non-zero) netlink error code when the kernel rejects the request.
 */
static int xfrm_expire_state(int xfrm_sock, uint32_t *seq,
		unsigned int nr, struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_expire expire;
			int error;
		};
	} req;
	struct in_addr src = inet_makeaddr(INADDR_B, child_ip(nr));
	struct in_addr dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
	int result = KSFT_FAIL;
	int listen_fd = -1;
	uint32_t seq_listen;

	if (xfrm_state_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst, desc)) {
		printk("Failed to add xfrm state");
		return KSFT_FAIL;
	}

	if (netlink_sock_bind(&listen_fd, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_seq	= (*seq)++;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_type	= XFRM_MSG_EXPIRE;
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));

	/* Identify the SA that was just added. */
	req.expire.state.family		= AF_INET;
	req.expire.state.id.proto	= desc->proto;
	req.expire.state.id.spi		= gen_spi(src);
	memcpy(&req.expire.state.id.daddr, &dst, sizeof(dst));
	req.expire.hard			= 0xff;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	/* The ack overwrites the request buffer. */
	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}
	if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}
	if (req.error != 0) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		result = req.error;
		goto out_close;
	}

	/* Notification delivered to the XFRMNLGRP_EXPIRE listener. */
	if (recv(listen_fd, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	if (req.expire.hard != 0x1) {
		printk("expire.hard is not set: %x", req.expire.hard);
		goto out_close;
	}

	result = KSFT_PASS;
out_close:
	close(listen_fd);
	return result;
}

/*
 * Add an outbound policy, send XFRM_MSG_POLEXPIRE for it with
 * hard = 0xff, and verify that a listener bound to XFRMNLGRP_EXPIRE
 * receives a policy-expire message whose hard field is 0x1.
 *
 * Returns KSFT_PASS on success, KSFT_FAIL on most failures, or the
 * (non-zero) netlink error code when the kernel rejects the request.
 */
static int xfrm_expire_policy(int xfrm_sock, uint32_t *seq,
		unsigned int nr, struct xfrm_desc *desc)
{
	struct {
		struct nlmsghdr nh;
		union {
			struct xfrm_user_polexpire expire;
			int error;
		};
	} req;
	struct in_addr src, dst, tunsrc, tundst;
	int xfrm_listen = -1, ret = KSFT_FAIL;
	uint32_t seq_listen;

	src = inet_makeaddr(INADDR_B, child_ip(nr));
	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));

	if (xfrm_policy_add(xfrm_sock, (*seq)++, gen_spi(src), src, dst,
				XFRM_POLICY_OUT, tunsrc, tundst, desc->proto)) {
		printk("Failed to add xfrm policy");
		return KSFT_FAIL;
	}

	if (netlink_sock_bind(&xfrm_listen, &seq_listen, NETLINK_XFRM, XFRMNLGRP_EXPIRE))
		return KSFT_FAIL;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.expire));
	req.nh.nlmsg_type	= XFRM_MSG_POLEXPIRE;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	/* Fill selector; size each copy by the object actually copied. */
	memcpy(&req.expire.pol.sel.daddr, &dst, sizeof(dst));
	memcpy(&req.expire.pol.sel.saddr, &src, sizeof(src));
	req.expire.pol.sel.family	= AF_INET;
	req.expire.pol.sel.prefixlen_d	= PREFIX_LEN;
	req.expire.pol.sel.prefixlen_s	= PREFIX_LEN;
	req.expire.pol.dir		= XFRM_POLICY_OUT;
	req.expire.hard			= 0xff;

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		goto out_close;
	}

	/* The ack overwrites the request buffer. */
	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		goto out_close;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		ret = req.error;
		goto out_close;
	}

	/* Notification delivered to the XFRMNLGRP_EXPIRE listener. */
	if (recv(xfrm_listen, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		goto out_close;
	}

	if (req.expire.hard != 0x1) {
		printk("expire.hard is not set: %x", req.expire.hard);
		goto out_close;
	}

	ret = KSFT_PASS;
out_close:
	close(xfrm_listen);
	return ret;
}

static int xfrm_spdinfo_set_thresh(int xfrm_sock, uint32_t *seq,
		unsigned thresh4_l, unsigned thresh4_r,
		unsigned thresh6_l, unsigned thresh6_r,
		bool add_bad_attr)

{
	struct {
		struct nlmsghdr		nh;
		union {
			uint32_t	unused;
			int		error;
		};
		char			attrbuf[MAX_PAYLOAD];
	} req;
	struct xfrmu_spdhthresh thresh;

	memset(&req, 0, sizeof(req));
	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
	req.nh.nlmsg_type	= XFRM_MSG_NEWSPDINFO;
	req.nh.nlmsg_flags	= NLM_F_REQUEST | NLM_F_ACK;
	req.nh.nlmsg_seq	= (*seq)++;

	thresh.lbits = thresh4_l;
	thresh.rbits = thresh4_r;
	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV4_HTHRESH, &thresh, sizeof(thresh)))
		return -1;

	thresh.lbits = thresh6_l;
	thresh.rbits = thresh6_r;
	if (rtattr_pack(&req.nh, sizeof(req), XFRMA_SPD_IPV6_HTHRESH, &thresh, sizeof(thresh)))
		return -1;

	if (add_bad_attr) {
		BUILD_BUG_ON(XFRMA_IF_ID <= XFRMA_SPD_MAX + 1);
		if (rtattr_pack(&req.nh, sizeof(req), XFRMA_IF_ID, NULL, 0)) {
			pr_err("adding attribute failed: no space");
			return -1;
		}
	}

	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return -1;
	}

	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		return -1;
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		return -1;
	}

	if (req.error) {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		return -1;
	}

	return 0;
}

static int xfrm_spdinfo_attrs(int xfrm_sock, uint32_t *seq)
{
	struct {
		struct nlmsghdr			nh;
		union {
			uint32_t	unused;
			int		error;
		};
		char			attrbuf[MAX_PAYLOAD];
	} req;

	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 31, 120, 16, false)) {
		pr_err("Can't set SPD HTHRESH");
		return KSFT_FAIL;
	}

	memset(&req, 0, sizeof(req));

	req.nh.nlmsg_len	= NLMSG_LENGTH(sizeof(req.unused));
	req.nh.nlmsg_type	= XFRM_MSG_GETSPDINFO;
	req.nh.nlmsg_flags	= NLM_F_REQUEST;
	req.nh.nlmsg_seq	= (*seq)++;
	if (send(xfrm_sock, &req, req.nh.nlmsg_len, 0) < 0) {
		pr_err("send()");
		return KSFT_FAIL;
	}

	if (recv(xfrm_sock, &req, sizeof(req), 0) < 0) {
		pr_err("recv()");
		return KSFT_FAIL;
	} else if (req.nh.nlmsg_type == XFRM_MSG_NEWSPDINFO) {
		size_t len = NLMSG_PAYLOAD(&req.nh, sizeof(req.unused));
		struct rtattr *attr = (void *)req.attrbuf;
		int got_thresh = 0;

		for (; RTA_OK(attr, len); attr = RTA_NEXT(attr, len)) {
			if (attr->rta_type == XFRMA_SPD_IPV4_HTHRESH) {
				struct xfrmu_spdhthresh *t = RTA_DATA(attr);

				got_thresh++;
				if (t->lbits != 32 || t->rbits != 31) {
					pr_err("thresh differ: %u, %u",
							t->lbits, t->rbits);
					return KSFT_FAIL;
				}
			}
			if (attr->rta_type == XFRMA_SPD_IPV6_HTHRESH) {
				struct xfrmu_spdhthresh *t = RTA_DATA(attr);

				got_thresh++;
				if (t->lbits != 120 || t->rbits != 16) {
					pr_err("thresh differ: %u, %u",
							t->lbits, t->rbits);
					return KSFT_FAIL;
				}
			}
		}
		if (got_thresh != 2) {
			pr_err("only %d thresh returned by XFRM_MSG_GETSPDINFO", got_thresh);
			return KSFT_FAIL;
		}
	} else if (req.nh.nlmsg_type != NLMSG_ERROR) {
		printk("expected NLMSG_ERROR, got %d", (int)req.nh.nlmsg_type);
		return KSFT_FAIL;
	} else {
		printk("NLMSG_ERROR: %d: %s", req.error, strerror(-req.error));
		return -1;
	}

	/* Restore the default */
	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, false)) {
		pr_err("Can't restore SPD HTHRESH");
		return KSFT_FAIL;
	}

	/*
	 * At this moment xfrm uses nlmsg_parse_deprecated(), which
	 * implies NL_VALIDATE_LIBERAL - ignoring attributes with
	 * (type > maxtype). nla_parse_depricated_strict() would enforce
	 * it. Or even stricter nla_parse().
	 * Right now it's not expected to fail, but to be ignored.
	 */
	if (xfrm_spdinfo_set_thresh(xfrm_sock, seq, 32, 32, 128, 128, true))
		return KSFT_PASS;

	return KSFT_PASS;
}

/*
 * Run one tunnel test from the child side.
 *
 * Sequence:
 *  1. UDP ping towards the plain (INADDR_B) source address, no xfrm;
 *  2. announce MSG_XFRM_PREPARE on cmd_fd and add the local policies;
 *  3. announce MSG_XFRM_ADD on cmd_fd and add/check the local states;
 *  4. UDP ping again, this time with the tunnel (INADDR_A) address;
 *  5. announce MSG_XFRM_DEL and remove the local states;
 *  6. announce MSG_XFRM_CLEANUP and remove the local policies.
 *
 * Every announcement carries a copy of *desc in msg.body.xfrm_desc.
 * A failure in step 2 skips to step 6; a failure in step 3 or 4 skips
 * to step 5.
 *
 * Returns KSFT_PASS only if all steps, including teardown, succeeded;
 * KSFT_FAIL otherwise.
 */
static int child_serv(int xfrm_sock, uint32_t *seq,
		unsigned int nr, int cmd_fd, void *buf, struct xfrm_desc *desc)
{
	struct in_addr src, dst, tunsrc, tundst;
	struct test_desc msg;
	int ret = KSFT_FAIL;

	/* Per-process addresses derived from nr. */
	src = inet_makeaddr(INADDR_B, child_ip(nr));
	dst = inet_makeaddr(INADDR_B, grchild_ip(nr));
	tunsrc = inet_makeaddr(INADDR_A, child_ip(nr));
	tundst = inet_makeaddr(INADDR_A, grchild_ip(nr));

	/* UDP pinging without xfrm */
	if (do_ping(cmd_fd, buf, page_size, src, true, 0, 0, udp_ping_send)) {
		printk("ping failed before setting xfrm");
		return KSFT_FAIL;
	}

	/* Announce the policy step on cmd_fd, then do it locally. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_PREPARE;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);

	if (xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
		printk("failed to prepare xfrm");
		goto cleanup;
	}

	/* Announce the state step on cmd_fd, then do it locally. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_ADD;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);
	if (xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc)) {
		printk("failed to set xfrm");
		goto delete;
	}

	/* UDP pinging with xfrm tunnel */
	if (do_ping(cmd_fd, buf, page_size, tunsrc,
				true, 0, 0, udp_ping_send)) {
		printk("ping failed for xfrm");
		goto delete;
	}

	ret = KSFT_PASS;
delete:
	/* xfrm delete */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_DEL;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);

	/* A teardown failure turns an otherwise passing run into a failure. */
	if (xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst, desc->proto)) {
		printk("failed ping to remove xfrm");
		ret = KSFT_FAIL;
	}

cleanup:
	/* Policies are removed last, mirroring the order they were added. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_XFRM_CLEANUP;
	memcpy(&msg.body.xfrm_desc, desc, sizeof(*desc));
	write_msg(cmd_fd, &msg, 1);
	if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst)) {
		printk("failed ping to cleanup xfrm");
		ret = KSFT_FAIL;
	}
	return ret;
}

/*
 * Main loop of the child process.
 *
 * Switches to the child-A namespace, opens an xfrm netlink socket and
 * performs a MSG_ACK handshake over cmd_fd. It then reads test
 * descriptors from test_desc_fd until EOF, runs the test each one
 * selects, and reports every result with write_test_result().
 * On EOF it sends MSG_EXIT on cmd_fd.
 *
 * Never returns: exits with KSFT_PASS after EOF, or KSFT_FAIL on a setup
 * error, a short read, or an unknown descriptor type.
 */
static int child_f(unsigned int nr, int test_desc_fd, int cmd_fd, void *buf)
{
	struct xfrm_desc desc;
	struct test_desc msg;
	int xfrm_sock = -1;
	ssize_t received;
	uint32_t seq;

	if (switch_ns(nsfd_childa))
		exit(KSFT_FAIL);

	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
		printk("Failed to open xfrm netlink socket");
		exit(KSFT_FAIL);
	}

	/* Check that seq sock is ready, just for sure. */
	memset(&msg, 0, sizeof(msg));
	msg.type = MSG_ACK;
	write_msg(cmd_fd, &msg, 1);
	read_msg(cmd_fd, &msg, 1);
	if (msg.type != MSG_ACK) {
		printk("Ack failed");
		exit(KSFT_FAIL);
	}

	/* A zero-length read means the plan writer closed the pipe. */
	while ((received = read(test_desc_fd, &desc, sizeof(desc))) != 0) {
		int ret;

		if (received != sizeof(desc)) {
			pr_err("read() returned %zd", received);
			exit(KSFT_FAIL);
		}

		switch (desc.type) {
		case CREATE_TUNNEL:
			ret = child_serv(xfrm_sock, &seq, nr,
					 cmd_fd, buf, &desc);
			break;
		case ALLOCATE_SPI:
			ret = xfrm_state_allocspi(xfrm_sock, &seq,
						  -1, desc.proto);
			break;
		case MONITOR_ACQUIRE:
			ret = xfrm_monitor_acquire(xfrm_sock, &seq, nr);
			break;
		case EXPIRE_STATE:
			ret = xfrm_expire_state(xfrm_sock, &seq, nr, &desc);
			break;
		case EXPIRE_POLICY:
			ret = xfrm_expire_policy(xfrm_sock, &seq, nr, &desc);
			break;
		case SPDINFO_ATTRS:
			ret = xfrm_spdinfo_attrs(xfrm_sock, &seq);
			break;
		default:
			printk("Unknown desc type %d", desc.type);
			exit(KSFT_FAIL);
		}

		write_test_result(ret, &desc);
	}

	close(xfrm_sock);

	/* Send MSG_EXIT over cmd_fd before exiting. */
	msg.type = MSG_EXIT;
	write_msg(cmd_fd, &msg, 1);
	exit(KSFT_PASS);
}

/*
 * Handle one command message in the grandchild.
 *
 * Addresses are computed from nr with the child/grandchild roles swapped
 * relative to child_serv(), so that "src" is this side.
 *
 *  MSG_EXIT          exit(KSFT_PASS)
 *  MSG_ACK           echo the message back on cmd_fd
 *  MSG_PING          answer a UDP ping; the tunnel source address is used
 *                    when the requested reply_ip differs from dst
 *  MSG_XFRM_PREPARE  add policies
 *  MSG_XFRM_ADD      add and check states
 *  MSG_XFRM_DEL      remove states
 *  MSG_XFRM_CLEANUP  remove policies
 *
 * A failing PREPARE/ADD/DEL step triggers xfrm_cleanup() before the
 * failure is logged. Failures are only logged, never propagated.
 */
static void grand_child_serv(unsigned int nr, int cmd_fd, void *buf,
		struct test_desc *msg, int xfrm_sock, uint32_t *seq)
{
	struct xfrm_desc *desc = &msg->body.xfrm_desc;
	struct in_addr src = inet_makeaddr(INADDR_B, grchild_ip(nr));
	struct in_addr dst = inet_makeaddr(INADDR_B, child_ip(nr));
	struct in_addr tunsrc = inet_makeaddr(INADDR_A, grchild_ip(nr));
	struct in_addr tundst = inet_makeaddr(INADDR_A, child_ip(nr));
	struct in_addr ping_from;
	bool tun_reply;

	switch (msg->type) {
	case MSG_EXIT:
		exit(KSFT_PASS);
	case MSG_ACK:
		write_msg(cmd_fd, msg, 1);
		break;
	case MSG_PING:
		tun_reply = memcmp(&dst, &msg->body.ping.reply_ip,
				   sizeof(in_addr_t)) != 0;
		if (tun_reply)
			ping_from = tunsrc;
		else
			ping_from = src;
		/* UDP pinging without xfrm */
		if (do_ping(cmd_fd, buf, page_size, ping_from,
				false, msg->body.ping.port,
				msg->body.ping.reply_ip, udp_ping_reply)) {
			printk("ping failed before setting xfrm");
		}
		break;
	case MSG_XFRM_PREPARE:
		if (!xfrm_prepare(xfrm_sock, seq, src, dst, tunsrc, tundst,
				  desc->proto))
			break;
		xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
		printk("failed to prepare xfrm");
		break;
	case MSG_XFRM_ADD:
		if (!xfrm_set(xfrm_sock, seq, src, dst, tunsrc, tundst, desc))
			break;
		xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
		printk("failed to set xfrm");
		break;
	case MSG_XFRM_DEL:
		if (!xfrm_delete(xfrm_sock, seq, src, dst, tunsrc, tundst,
				 desc->proto))
			break;
		xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst);
		printk("failed to remove xfrm");
		break;
	case MSG_XFRM_CLEANUP:
		if (xfrm_cleanup(xfrm_sock, seq, src, dst, tunsrc, tundst))
			printk("failed to cleanup xfrm");
		break;
	default:
		printk("got unknown msg type %d", msg->type);
	}
}

/*
 * Main loop of the grandchild process.
 *
 * Switches to the child-B namespace, opens an xfrm netlink socket and
 * then serves command messages from cmd_fd forever; the loop is left
 * only when grand_child_serv() calls exit().
 *
 * Never returns: exits with KSFT_FAIL on a setup error.
 */
static int grand_child_f(unsigned int nr, int cmd_fd, void *buf)
{
	struct test_desc msg;
	uint32_t seq;
	int xfrm_sock = -1;

	if (switch_ns(nsfd_childb))
		exit(KSFT_FAIL);

	if (netlink_sock(&xfrm_sock, &seq, NETLINK_XFRM)) {
		printk("Failed to open xfrm netlink socket");
		exit(KSFT_FAIL);
	}

	for (;;) {
		read_msg(cmd_fd, &msg, 1);
		grand_child_serv(nr, cmd_fd, buf, &msg, xfrm_sock, &seq);
	}

	/* Not reached: the loop above has no exit other than exit(). */
	close(xfrm_sock);
	exit(KSFT_FAIL);
}

/*
 * Set up both namespaces for process pair nr and fork the child and
 * grandchild that run the tests.
 *
 * In the original caller (selftest parent) this returns the result of
 * switch_ns(nsfd_parent), or -1 if init_child() or the first fork()
 * failed.
 *
 * In the forked child, any setup failure returns -1 to the caller's code
 * path; otherwise control is handed to child_f() (parent of the second
 * fork) or grand_child_f() (child of the second fork), neither of which
 * returns normally.
 */
static int start_child(unsigned int nr, char *veth, int test_desc_fd[2])
{
	int cmd_sock[2];
	void *data_map;
	pid_t child;

	/* Configure both ends; each gets the other's address as its peer. */
	if (init_child(nsfd_childa, veth, child_ip(nr), grchild_ip(nr)))
		return -1;

	if (init_child(nsfd_childb, veth, grchild_ip(nr), child_ip(nr)))
		return -1;

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* in parent - selftest */
		return switch_ns(nsfd_parent);
	}

	/* The child only reads descriptors; drop the write end. */
	if (close(test_desc_fd[1])) {
		pr_err("close()");
		return -1;
	}

	/* child */
	/* One shared anonymous page, inherited by the grandchild below. */
	data_map = mmap(0, page_size, PROT_READ | PROT_WRITE,
			MAP_SHARED | MAP_ANONYMOUS, -1, 0);
	if (data_map == MAP_FAILED) {
		pr_err("mmap()");
		return -1;
	}

	randomize_buffer(data_map, page_size);

	/* Command channel between child and grandchild. */
	if (socketpair(PF_LOCAL, SOCK_SEQPACKET, 0, cmd_sock)) {
		pr_err("socketpair()");
		return -1;
	}

	child = fork();
	if (child < 0) {
		pr_err("fork()");
		return -1;
	} else if (child) {
		/* Child keeps cmd_sock[1] and closes the grandchild's end. */
		if (close(cmd_sock[0])) {
			pr_err("close()");
			return -1;
		}
		return child_f(nr, test_desc_fd[0], cmd_sock[1], data_map);
	}
	/* Grandchild keeps cmd_sock[0] and closes the child's end. */
	if (close(cmd_sock[1])) {
		pr_err("close()");
		return -1;
	}
	return grand_child_f(nr, cmd_sock[0], data_map);
}

/* Print the command-line synopsis and terminate with KSFT_FAIL. */
static void exit_usage(char **argv)
{
	const char *prog = argv[0];

	printk("Usage: %s [nr_process]", prog);
	exit(KSFT_FAIL);
}

/*
 * Write one test descriptor to test_desc_fd.
 *
 * Returns 0 if write() accepted the whole descriptor, -1 (after logging
 * the write() result) on an error or a short write.
 */
static int __write_desc(int test_desc_fd, struct xfrm_desc *desc)
{
	ssize_t ret;

	ret = write(test_desc_fd, desc, sizeof(*desc));

	if (ret == sizeof(*desc))
		return 0;

	/* ssize_t needs %zd; %ld is wrong wherever ssize_t is not long. */
	pr_err("Writing test's desc failed %zd", ret);

	return -1;
}

static int write_desc(int proto, int test_desc_fd,
		char *a, char *e, char *c, char *ae)
{
	struct xfrm_desc desc = {};

	desc.type = CREATE_TUNNEL;
	desc.proto = proto;

	if (a)
		strncpy(desc.a_algo, a, ALGO_LEN - 1);
	if (e)
		strncpy(desc.e_algo, e, ALGO_LEN - 1);
	if (c)
		strncpy(desc.c_algo, c, ALGO_LEN - 1);
	if (ae)
		strncpy(desc.ae_algo, ae, ALGO_LEN - 1);

	return __write_desc(test_desc_fd, &desc);
}

/* Protocols for which write_test_plan() emits tunnel tests, in order. */
int proto_list[] = { IPPROTO_AH, IPPROTO_COMP, IPPROTO_ESP };
/*
 * Algorithm names passed as the "a" argument of write_desc(): used alone
 * for IPPROTO_AH and paired with every e_list entry for IPPROTO_ESP.
 */
char *ah_list[] = {
	"digest_null", "hmac(md5)", "hmac(sha1)", "hmac(sha256)",
	"hmac(sha384)", "hmac(sha512)", "hmac(rmd160)",
	"xcbc(aes)", "cmac(aes)"
};
/* Algorithm names passed as the "c" argument for IPPROTO_COMP. */
char *comp_list[] = {
	"deflate",
#if 0
	/* No compression backend implementation */
	"lzs", "lzjh"
#endif
};
/* Algorithm names passed as the "e" argument for IPPROTO_ESP. */
char *e_list[] = {
	"ecb(cipher_null)", "cbc(des)", "cbc(des3_ede)", "cbc(cast5)",
	"cbc(blowfish)", "cbc(aes)", "cbc(serpent)", "cbc(camellia)",
	"cbc(twofish)", "rfc3686(ctr(aes))"
};
/*
 * Algorithm names passed as the "ae" argument for IPPROTO_ESP.
 * Currently empty: every entry is compiled out.
 */
char *ae_list[] = {
#if 0
	/* not implemented */
	"rfc4106(gcm(aes))", "rfc4309(ccm(aes))", "rfc4543(gcm(aes))",
	"rfc7539esp(chacha20,poly1305)"
#endif
};

/*
 * Number of descriptors write_proto_plan() emits over all of proto_list:
 * one per ah_list entry (AH), one per comp_list entry (COMP), and for ESP
 * one per (ah_list, e_list) pair plus one per ae_list entry.
 */
const unsigned int proto_plan = ARRAY_SIZE(ah_list) + ARRAY_SIZE(comp_list) \
				+ (ARRAY_SIZE(ah_list) * ARRAY_SIZE(e_list)) \
				+ ARRAY_SIZE(ae_list);

/*
 * Write all CREATE_TUNNEL descriptors for one protocol to fd:
 *  - IPPROTO_AH:   one per ah_list entry;
 *  - IPPROTO_COMP: one per comp_list entry;
 *  - IPPROTO_ESP:  one per (ah_list, e_list) pair, then one per ae_list
 *                  entry.
 *
 * Returns 0 on success, -1 if a write fails or proto is not one of the
 * three above.
 */
static int write_proto_plan(int fd, int proto)
{
	unsigned int i;

	switch (proto) {
	case IPPROTO_AH:
		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
			if (write_desc(proto, fd, ah_list[i], 0, 0, 0))
				return -1;
		}
		break;
	case IPPROTO_COMP:
		for (i = 0; i < ARRAY_SIZE(comp_list); i++) {
			if (write_desc(proto, fd, 0, 0, comp_list[i], 0))
				return -1;
		}
		break;
	case IPPROTO_ESP:
		for (i = 0; i < ARRAY_SIZE(ah_list); i++) {
			/* Unsigned, like i: ARRAY_SIZE() is an unsigned count. */
			unsigned int j;

			for (j = 0; j < ARRAY_SIZE(e_list); j++) {
				if (write_desc(proto, fd, ah_list[i],
							e_list[j], 0, 0))
					return -1;
			}
		}
		for (i = 0; i < ARRAY_SIZE(ae_list); i++) {
			if (write_desc(proto, fd, 0, 0, 0, ae_list[i]))
				return -1;
		}
		break;
	default:
		printk("BUG: Specified unknown proto %d", proto);
		return -1;
	}

	return 0;
}

/*
 * Some structures in xfrm uapi header differ in size between
 * 64-bit and 32-bit ABI:
 *
 *             32-bit UABI               |            64-bit UABI
 *  -------------------------------------|-------------------------------------
 *   sizeof(xfrm_usersa_info)     = 220  |  sizeof(xfrm_usersa_info)     = 224
 *   sizeof(xfrm_userpolicy_info) = 164  |  sizeof(xfrm_userpolicy_info) = 168
 *   sizeof(xfrm_userspi_info)    = 228  |  sizeof(xfrm_userspi_info)    = 232
 *   sizeof(xfrm_user_acquire)    = 276  |  sizeof(xfrm_user_acquire)    = 280
 *   sizeof(xfrm_user_expire)     = 224  |  sizeof(xfrm_user_expire)     = 232
 *   sizeof(xfrm_user_polexpire)  = 168  |  sizeof(xfrm_user_polexpire)  = 176
 *
 * Check the structures affected by the UABI difference.
 * Also, check translation for xfrm_set_spdinfo: it has its own attributes,
 * which need to be correctly copied, but not translated.
 */
const unsigned int compat_plan = 5;
/*
 * Write one descriptor per compat test type to test_desc_fd, in the
 * order listed in compat_types[]. All of them share proto IPPROTO_AH and
 * the first ah_list algorithm; only the type differs.
 *
 * Returns 0 on success, -1 as soon as a write fails.
 */
static int write_compat_struct_tests(int test_desc_fd)
{
	static const int compat_types[] = {
		ALLOCATE_SPI,
		MONITOR_ACQUIRE,
		EXPIRE_STATE,
		EXPIRE_POLICY,
		SPDINFO_ATTRS,
	};
	struct xfrm_desc desc = {};
	unsigned int i;

	desc.proto = IPPROTO_AH;
	strncpy(desc.a_algo, ah_list[0], ALGO_LEN - 1);

	for (i = 0; i < ARRAY_SIZE(compat_types); i++) {
		desc.type = compat_types[i];
		if (__write_desc(test_desc_fd, &desc))
			return -1;
	}

	return 0;
}

/*
 * Fork a helper process that writes the whole test plan to test_desc_fd:
 * first the compat struct tests, then the tunnel tests for every
 * protocol in proto_list.
 *
 * In the caller: closes its copy of test_desc_fd (a close failure is
 * only logged) and returns 0, or returns -1 if fork() failed.
 * In the helper: never returns; exits with KSFT_PASS, or KSFT_FAIL if any
 * descriptor could not be written.
 */
static int write_test_plan(int test_desc_fd)
{
	unsigned int idx;
	pid_t writer = fork();

	if (writer < 0) {
		pr_err("fork()");
		return -1;
	}

	if (writer > 0) {
		/* Parent: the helper owns the write end from now on. */
		if (close(test_desc_fd) != 0)
			printk("close(): %m");
		return 0;
	}

	/* Helper process. */
	if (write_compat_struct_tests(test_desc_fd) != 0)
		exit(KSFT_FAIL);

	for (idx = 0; idx < ARRAY_SIZE(proto_list); idx++) {
		if (write_proto_plan(test_desc_fd, proto_list[idx]) != 0)
			exit(KSFT_FAIL);
	}

	exit(KSFT_PASS);
}

/*
 * Reap every child process and summarise their exit statuses.
 *
 * Returns KSFT_FAIL if wait() fails for a reason other than ECHILD, if
 * any child did not exit normally, or if any child exited with
 * KSFT_FAIL; otherwise KSFT_PASS once no children remain.
 */
static int children_cleanup(void)
{
	unsigned ret = KSFT_PASS;

	for (;;) {
		int status;
		pid_t pid = wait(&status);

		if (pid < 0) {
			/* ECHILD: nothing left to wait for. */
			if (errno == ECHILD)
				break;
			pr_err("wait()");
			return KSFT_FAIL;
		}

		if (!WIFEXITED(status) || WEXITSTATUS(status) == KSFT_FAIL)
			ret = KSFT_FAIL;
	}

	return ret;
}

/* Signature shared by the ksft result reporters used below. */
typedef void (*print_res)(const char *, ...);

/*
 * Read test results from results_fd[0] until EOF and report each one
 * through the kselftest pass/fail reporter, together with the fields of
 * the descriptor that produced it.
 *
 * Any result other than KSFT_PASS is reported as a failure.
 *
 * Returns KSFT_PASS if every result passed, KSFT_FAIL if any failed or
 * if a read returned anything other than a whole record.
 */
static int check_results(void)
{
	struct test_result tr = {};
	struct xfrm_desc *d = &tr.desc;
	int overall = KSFT_PASS;
	ssize_t received;

	while ((received = read(results_fd[0], &tr, sizeof(tr))) != 0) {
		print_res report;

		if (received != sizeof(tr)) {
			pr_err("read() returned %zd", received);
			return KSFT_FAIL;
		}

		if (tr.res == KSFT_PASS) {
			report = ksft_test_result_pass;
		} else {
			report = ksft_test_result_fail;
			overall = KSFT_FAIL;
		}

		report(" %s: [%u, '%s', '%s', '%s', '%s', %u]\n",
		       desc_name[d->type], (unsigned int)d->proto, d->a_algo,
		       d->e_algo, d->c_algo, d->ae_algo, d->icv_len);
	}

	return overall;
}

/*
 * Selftest entry point.
 *
 * Usage: prog [nr_process], with nr_process in [1; MAX_PROCESSES]
 * (default 1).
 *
 * Creates the descriptor and result pipes and the namespaces, then for
 * each process index adds a veth pair and starts a child/grandchild
 * pair. Afterwards it writes the test plan, collects the results and
 * reaps all children.
 *
 * Environment setup failures before any child is started are reported as
 * a skip; later failures are reported with ksft_exit_fail_msg(). The
 * process exits with KSFT_FAIL if children_cleanup() reports a failure,
 * otherwise with the value returned by check_results().
 */
int main(int argc, char **argv)
{
	long nr_process = 1;
	int route_sock = -1, ret = KSFT_SKIP;
	int test_desc_fd[2];
	uint32_t route_seq;
	unsigned int i;

	if (argc > 2)
		exit_usage(argv);

	if (argc > 1) {
		char *endptr;

		errno = 0;
		nr_process = strtol(argv[1], &endptr, 10);
		if ((errno == ERANGE && (nr_process == LONG_MAX || nr_process == LONG_MIN))
				|| (errno != 0 && nr_process == 0)
				|| (endptr == argv[1]) || (*endptr != '\0')) {
			printk("Failed to parse [nr_process]");
			exit_usage(argv);
		}

		if (nr_process > MAX_PROCESSES || nr_process < 1) {
			printk("nr_process should be between [1; %u]",
					MAX_PROCESSES);
			exit_usage(argv);
		}
	}

	srand(time(NULL));
	page_size = sysconf(_SC_PAGESIZE);
	if (page_size < 1)
		ksft_exit_skip("sysconf(): %m\n");

	if (pipe2(test_desc_fd, O_DIRECT) < 0)
		ksft_exit_skip("pipe(): %m\n");

	if (pipe2(results_fd, O_DIRECT) < 0)
		ksft_exit_skip("pipe(): %m\n");

	if (init_namespaces())
		ksft_exit_skip("Failed to create namespaces\n");

	if (netlink_sock(&route_sock, &route_seq, NETLINK_ROUTE))
		ksft_exit_skip("Failed to open netlink route socket\n");

	/*
	 * The fail messages below end in "\n", like the skip messages
	 * above, so the bail-out line is terminated in the output.
	 */
	for (i = 0; i < nr_process; i++) {
		char veth[VETH_LEN];

		snprintf(veth, VETH_LEN, VETH_FMT, i);

		if (veth_add(route_sock, route_seq++, veth, nsfd_childa, veth, nsfd_childb)) {
			close(route_sock);
			ksft_exit_fail_msg("Failed to create veth device\n");
		}

		if (start_child(i, veth, test_desc_fd)) {
			close(route_sock);
			ksft_exit_fail_msg("Child %u failed to start\n", i);
		}
	}

	/* Close the route socket, the descriptor read end and the result write end. */
	if (close(route_sock) || close(test_desc_fd[0]) || close(results_fd[1]))
		ksft_exit_fail_msg("close(): %m\n");

	ksft_set_plan(proto_plan + compat_plan);

	if (write_test_plan(test_desc_fd[1]))
		ksft_exit_fail_msg("Failed to write test plan to pipe\n");

	ret = check_results();

	if (children_cleanup() == KSFT_FAIL)
		exit(KSFT_FAIL);

	exit(ret);
}