Contributors: 11
Author Tokens Token Proportion Commits Commit Proportion
Jakub Kiciński 1239 91.98% 7 26.92%
Linus Torvalds (pre-git) 39 2.90% 5 19.23%
Eric Dumazet 20 1.48% 4 15.38%
Américo Wang 13 0.97% 1 3.85%
David S. Miller 9 0.67% 1 3.85%
Hideaki Yoshifuji / 吉藤英明 8 0.59% 2 7.69%
Daniel Borkmann 6 0.45% 1 3.85%
Arnaldo Carvalho de Melo 4 0.30% 2 7.69%
Adrian Bunk 4 0.30% 1 3.85%
Steffen Klassert 3 0.22% 1 3.85%
Herbert Xu 2 0.15% 1 3.85%
Total 1347 26


// SPDX-License-Identifier: GPL-2.0-only

#include <linux/file.h>
#include <linux/net.h>
#include <linux/rcupdate.h>
#include <linux/tcp.h>

#include <net/ip.h>
#include <net/psp.h>
#include "psp.h"

struct psp_dev *psp_dev_get_for_sock(struct sock *sk)
{
	struct psp_dev *psd = NULL;
	struct dst_entry *dst;

	rcu_read_lock();
	dst = __sk_dst_get(sk);
	if (dst) {
		psd = rcu_dereference(dst_dev_rcu(dst)->psp_dev);
		if (psd && !psp_dev_tryget(psd))
			psd = NULL;
	}
	rcu_read_unlock();

	return psd;
}

static struct sk_buff *
psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb)
{
	struct psp_assoc *pas;
	bool good;

	rcu_read_lock();
	pas = psp_skb_get_assoc_rcu(skb);
	good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd;
	rcu_read_unlock();
	if (!good) {
		sk_skb_reason_drop(sk, skb, SKB_DROP_REASON_PSP_OUTPUT);
		return NULL;
	}

	return skb;
}

struct psp_assoc *psp_assoc_create(struct psp_dev *psd)
{
	struct psp_assoc *pas;

	lockdep_assert_held(&psd->lock);

	pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc),
		      GFP_KERNEL_ACCOUNT);
	if (!pas)
		return NULL;

	pas->psd = psd;
	pas->dev_id = psd->id;
	pas->generation = psd->generation;
	psp_dev_get(psd);
	refcount_set(&pas->refcnt, 1);

	list_add_tail(&pas->assocs_list, &psd->active_assocs);

	return pas;
}

static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas)
{
	struct psp_dev *psd = pas->psd;
	size_t sz;

	lockdep_assert_held(&psd->lock);

	sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc);
	return kmemdup(pas, sz, GFP_KERNEL);
}

static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas,
			      struct netlink_ext_ack *extack)
{
	return psd->ops->tx_key_add(psd, pas, extack);
}

void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas)
{
	if (pas->tx.spi)
		psd->ops->tx_key_del(psd, pas);
	list_del(&pas->assocs_list);
}

static void psp_assoc_free(struct work_struct *work)
{
	struct psp_assoc *pas = container_of(work, struct psp_assoc, work);
	struct psp_dev *psd = pas->psd;

	mutex_lock(&psd->lock);
	if (psd->ops)
		psp_dev_tx_key_del(psd, pas);
	mutex_unlock(&psd->lock);
	psp_dev_put(psd);
	kfree(pas);
}

static void psp_assoc_free_queue(struct rcu_head *head)
{
	struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu);

	INIT_WORK(&pas->work, psp_assoc_free);
	schedule_work(&pas->work);
}

/**
 * psp_assoc_put() - release a reference on a PSP association
 * @pas: association to release
 */
void psp_assoc_put(struct psp_assoc *pas)
{
	if (pas && refcount_dec_and_test(&pas->refcnt))
		call_rcu(&pas->rcu, psp_assoc_free_queue);
}

void psp_sk_assoc_free(struct sock *sk)
{
	struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1);

	rcu_assign_pointer(sk->psp_assoc, NULL);
	psp_assoc_put(pas);
}

int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas,
			  struct psp_key_parsed *key,
			  struct netlink_ext_ack *extack)
{
	int err;

	memcpy(&pas->rx, key, sizeof(*key));

	lock_sock(sk);

	if (psp_sk_assoc(sk)) {
		NL_SET_ERR_MSG(extack, "Socket already has PSP state");
		err = -EBUSY;
		goto exit_unlock;
	}

	refcount_inc(&pas->refcnt);
	rcu_assign_pointer(sk->psp_assoc, pas);
	err = 0;

exit_unlock:
	release_sock(sk);

	return err;
}

static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas)
{
	struct psp_skb_ext *pse;
	struct sk_buff *skb;

	skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) {
		pse = skb_ext_find(skb, SKB_EXT_PSP);
		if (!psp_pse_matches_pas(pse, pas))
			return -EBUSY;
	}

	skb_queue_walk(&sk->sk_receive_queue, skb) {
		pse = skb_ext_find(skb, SKB_EXT_PSP);
		if (!psp_pse_matches_pas(pse, pas))
			return -EBUSY;
	}
	return 0;
}

int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd,
			  u32 version, struct psp_key_parsed *key,
			  struct netlink_ext_ack *extack)
{
	struct inet_connection_sock *icsk;
	struct psp_assoc *pas, *dummy;
	int err;

	lock_sock(sk);

	pas = psp_sk_assoc(sk);
	if (!pas) {
		NL_SET_ERR_MSG(extack, "Socket has no Rx key");
		err = -EINVAL;
		goto exit_unlock;
	}
	if (pas->psd != psd) {
		NL_SET_ERR_MSG(extack, "Rx key from different device");
		err = -EINVAL;
		goto exit_unlock;
	}
	if (pas->version != version) {
		NL_SET_ERR_MSG(extack,
			       "PSP version mismatch with existing state");
		err = -EINVAL;
		goto exit_unlock;
	}
	if (pas->tx.spi) {
		NL_SET_ERR_MSG(extack, "Tx key already set");
		err = -EBUSY;
		goto exit_unlock;
	}

	err = psp_sock_recv_queue_check(sk, pas);
	if (err) {
		NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue");
		goto exit_unlock;
	}

	/* Pass a fake association to drivers to make sure they don't
	 * try to store pointers to it. For re-keying we'll need to
	 * re-allocate the assoc structures.
	 */
	dummy = psp_assoc_dummy(pas);
	if (!dummy) {
		err = -ENOMEM;
		goto exit_unlock;
	}

	memcpy(&dummy->tx, key, sizeof(*key));
	err = psp_dev_tx_key_add(psd, dummy, extack);
	if (err)
		goto exit_free_dummy;

	memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc);
	memcpy(&pas->tx, key, sizeof(*key));

	WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit);
	tcp_write_collapse_fence(sk);
	pas->upgrade_seq = tcp_sk(sk)->rcv_nxt;

	icsk = inet_csk(sk);
	icsk->icsk_ext_hdr_len += psp_sk_overhead(sk);
	icsk->icsk_sync_mss(sk, icsk->icsk_pmtu_cookie);

exit_free_dummy:
	kfree(dummy);
exit_unlock:
	release_sock(sk);
	return err;
}

void psp_assocs_key_rotated(struct psp_dev *psd)
{
	struct psp_assoc *pas, *next;

	/* Mark the stale associations as invalid, they will no longer
	 * be able to Rx any traffic.
	 */
	list_for_each_entry_safe(pas, next, &psd->prev_assocs, assocs_list)
		pas->generation |= ~PSP_GEN_VALID_MASK;
	list_splice_init(&psd->prev_assocs, &psd->stale_assocs);
	list_splice_init(&psd->active_assocs, &psd->prev_assocs);

	/* TODO: we should inform the sockets that got shut down */
}

void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk)
{
	struct psp_assoc *pas = psp_sk_assoc(sk);

	if (pas)
		refcount_inc(&pas->refcnt);
	rcu_assign_pointer(tw->psp_assoc, pas);
	tw->tw_validate_xmit_skb = psp_validate_xmit;
}

void psp_twsk_assoc_free(struct inet_timewait_sock *tw)
{
	struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1);

	rcu_assign_pointer(tw->psp_assoc, NULL);
	psp_assoc_put(pas);
}

void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb)
{
	struct psp_assoc *pas;

	rcu_read_lock();
	pas = psp_sk_get_assoc_rcu(sk);
	if (pas && pas->tx.spi)
		skb->decrypted = 1;
	rcu_read_unlock();
}
EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);