Skip to content

Commit

Permalink
raw: convert raw sockets to RCU
Browse files Browse the repository at this point in the history
Using rwlock in networking code is extremely risky.
Writers can starve if enough readers are constantly
grabbing the rwlock.

I thought rwlocks were at fault and sent this patch:

https://lkml.org/lkml/2022/6/17/272

But Peter and Linus essentially told me rwlock had to be unfair.

We need to get rid of rwlock in networking code.

Without this fix, the following script triggers soft lockups:

for i in {1..48}
do
 ping -f -n -q 127.0.0.1 &
 sleep 0.1
done

Fixes: 1da177e ("Linux-2.6.12-rc2")
Signed-off-by: Eric Dumazet <[email protected]>
Signed-off-by: David S. Miller <[email protected]>
  • Loading branch information
Eric Dumazet authored and davem330 committed Jun 19, 2022
1 parent ba44f81 commit 0daf07e
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 70 deletions.
11 changes: 10 additions & 1 deletion include/net/raw.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,18 @@ int raw_rcv(struct sock *, struct sk_buff *);

struct raw_hashinfo {
rwlock_t lock;
struct hlist_head ht[RAW_HTABLE_SIZE];
struct hlist_nulls_head ht[RAW_HTABLE_SIZE];
};

/*
 * raw_hashinfo_init - prepare a raw socket hash table for use.
 * @hashinfo: table to initialize.
 *
 * Initializes the table's rwlock, then initializes each of the
 * RAW_HTABLE_SIZE bucket heads, passing the bucket index as the second
 * argument to INIT_HLIST_NULLS_HEAD().
 */
static inline void raw_hashinfo_init(struct raw_hashinfo *hashinfo)
{
int i;

rwlock_init(&hashinfo->lock);
/* Every bucket head is initialized with its own index i. */
for (i = 0; i < RAW_HTABLE_SIZE; i++)
INIT_HLIST_NULLS_HEAD(&hashinfo->ht[i], i);
}

#ifdef CONFIG_PROC_FS
int raw_proc_init(void);
void raw_proc_exit(void);
Expand Down
1 change: 1 addition & 0 deletions include/net/rawv6.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#define _NET_RAWV6_H

#include <net/protocol.h>
#include <net/raw.h>

extern struct raw_hashinfo raw_v6_hashinfo;
bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
Expand Down
2 changes: 2 additions & 0 deletions net/ipv4/af_inet.c
Original file line number Diff line number Diff line change
Expand Up @@ -1929,6 +1929,8 @@ static int __init inet_init(void)

sock_skb_cb_check_size(sizeof(struct inet_skb_parm));

raw_hashinfo_init(&raw_v4_hashinfo);

rc = proto_register(&tcp_prot, 1);
if (rc)
goto out;
Expand Down
83 changes: 38 additions & 45 deletions net/ipv4/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,19 @@ struct raw_frag_vec {
int hlen;
};

struct raw_hashinfo raw_v4_hashinfo = {
.lock = __RW_LOCK_UNLOCKED(raw_v4_hashinfo.lock),
};
struct raw_hashinfo raw_v4_hashinfo;
EXPORT_SYMBOL_GPL(raw_v4_hashinfo);

int raw_hash_sk(struct sock *sk)
{
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;
struct hlist_head *head;
struct hlist_nulls_head *hlist;

head = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)];
hlist = &h->ht[inet_sk(sk)->inet_num & (RAW_HTABLE_SIZE - 1)];

write_lock_bh(&h->lock);
sk_add_node(sk, head);
hlist_nulls_add_head_rcu(&sk->sk_nulls_node, hlist);
sock_set_flag(sk, SOCK_RCU_FREE);
write_unlock_bh(&h->lock);
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);

Expand All @@ -111,7 +110,7 @@ void raw_unhash_sk(struct sock *sk)
struct raw_hashinfo *h = sk->sk_prot->h.raw_hash;

write_lock_bh(&h->lock);
if (sk_del_node_init(sk))
if (__sk_nulls_del_node_init_rcu(sk))
sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1);
write_unlock_bh(&h->lock);
}
Expand Down Expand Up @@ -164,17 +163,16 @@ static int icmp_filter(const struct sock *sk, const struct sk_buff *skb)
static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
{
struct net *net = dev_net(skb->dev);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int sdif = inet_sdif(skb);
int dif = inet_iif(skb);
struct hlist_head *head;
int delivered = 0;
struct sock *sk;

head = &raw_v4_hashinfo.ht[hash];
if (hlist_empty(head))
return 0;
read_lock(&raw_v4_hashinfo.lock);
sk_for_each(sk, head) {
hlist = &raw_v4_hashinfo.ht[hash];
rcu_read_lock();
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (!raw_v4_match(net, sk, iph->protocol,
iph->saddr, iph->daddr, dif, sdif))
continue;
Expand All @@ -189,7 +187,7 @@ static int raw_v4_input(struct sk_buff *skb, const struct iphdr *iph, int hash)
raw_rcv(sk, clone);
}
}
read_unlock(&raw_v4_hashinfo.lock);
rcu_read_unlock();
return delivered;
}

Expand Down Expand Up @@ -265,25 +263,26 @@ static void raw_err(struct sock *sk, struct sk_buff *skb, u32 info)
void raw_icmp_error(struct sk_buff *skb, int protocol, u32 info)
{
struct net *net = dev_net(skb->dev);;
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int dif = skb->dev->ifindex;
int sdif = inet_sdif(skb);
struct hlist_head *head;
const struct iphdr *iph;
struct sock *sk;
int hash;

hash = protocol & (RAW_HTABLE_SIZE - 1);
head = &raw_v4_hashinfo.ht[hash];
hlist = &raw_v4_hashinfo.ht[hash];

read_lock(&raw_v4_hashinfo.lock);
sk_for_each(sk, head) {
rcu_read_lock();
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
iph = (const struct iphdr *)skb->data;
if (!raw_v4_match(net, sk, iph->protocol,
iph->saddr, iph->daddr, dif, sdif))
continue;
raw_err(sk, skb, info);
}
read_unlock(&raw_v4_hashinfo.lock);
rcu_read_unlock();
}

static int raw_rcv_skb(struct sock *sk, struct sk_buff *skb)
Expand Down Expand Up @@ -944,44 +943,41 @@ struct proto raw_prot = {
};

#ifdef CONFIG_PROC_FS
static struct sock *raw_get_first(struct seq_file *seq)
static struct sock *raw_get_first(struct seq_file *seq, int bucket)
{
struct sock *sk;
struct raw_hashinfo *h = pde_data(file_inode(seq->file));
struct raw_iter_state *state = raw_seq_private(seq);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk;

for (state->bucket = 0; state->bucket < RAW_HTABLE_SIZE;
for (state->bucket = bucket; state->bucket < RAW_HTABLE_SIZE;
++state->bucket) {
sk_for_each(sk, &h->ht[state->bucket])
hlist = &h->ht[state->bucket];
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (sock_net(sk) == seq_file_net(seq))
goto found;
return sk;
}
}
sk = NULL;
found:
return sk;
return NULL;
}

static struct sock *raw_get_next(struct seq_file *seq, struct sock *sk)
{
struct raw_hashinfo *h = pde_data(file_inode(seq->file));
struct raw_iter_state *state = raw_seq_private(seq);

do {
sk = sk_next(sk);
try_again:
;
sk = sk_nulls_next(sk);
} while (sk && sock_net(sk) != seq_file_net(seq));

if (!sk && ++state->bucket < RAW_HTABLE_SIZE) {
sk = sk_head(&h->ht[state->bucket]);
goto try_again;
}
if (!sk)
return raw_get_first(seq, state->bucket + 1);
return sk;
}

static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos)
{
struct sock *sk = raw_get_first(seq);
struct sock *sk = raw_get_first(seq, 0);

if (sk)
while (pos && (sk = raw_get_next(seq, sk)) != NULL)
Expand All @@ -990,11 +986,9 @@ static struct sock *raw_get_idx(struct seq_file *seq, loff_t pos)
}

void *raw_seq_start(struct seq_file *seq, loff_t *pos)
__acquires(&h->lock)
__acquires(RCU)
{
struct raw_hashinfo *h = pde_data(file_inode(seq->file));

read_lock(&h->lock);
rcu_read_lock();
return *pos ? raw_get_idx(seq, *pos - 1) : SEQ_START_TOKEN;
}
EXPORT_SYMBOL_GPL(raw_seq_start);
Expand All @@ -1004,7 +998,7 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos)
struct sock *sk;

if (v == SEQ_START_TOKEN)
sk = raw_get_first(seq);
sk = raw_get_first(seq, 0);
else
sk = raw_get_next(seq, v);
++*pos;
Expand All @@ -1013,11 +1007,9 @@ void *raw_seq_next(struct seq_file *seq, void *v, loff_t *pos)
EXPORT_SYMBOL_GPL(raw_seq_next);

void raw_seq_stop(struct seq_file *seq, void *v)
__releases(&h->lock)
__releases(RCU)
{
struct raw_hashinfo *h = pde_data(file_inode(seq->file));

read_unlock(&h->lock);
rcu_read_unlock();
}
EXPORT_SYMBOL_GPL(raw_seq_stop);

Expand Down Expand Up @@ -1079,6 +1071,7 @@ static __net_initdata struct pernet_operations raw_net_ops = {

/*
 * Register raw_net_ops via register_pernet_subsys() and return that
 * call's result unchanged.
 */
int __init raw_proc_init(void)
{

return register_pernet_subsys(&raw_net_ops);
}

Expand Down
22 changes: 13 additions & 9 deletions net/ipv4/raw_diag.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,32 @@ static bool raw_lookup(struct net *net, struct sock *sk,
static struct sock *raw_sock_get(struct net *net, const struct inet_diag_req_v2 *r)
{
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk;
int slot;

if (IS_ERR(hashinfo))
return ERR_CAST(hashinfo);

read_lock(&hashinfo->lock);
rcu_read_lock();
for (slot = 0; slot < RAW_HTABLE_SIZE; slot++) {
sk_for_each(sk, &hashinfo->ht[slot]) {
hlist = &hashinfo->ht[slot];
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
if (raw_lookup(net, sk, r)) {
/*
* Grab it and keep until we fill
* diag meaage to be reported, so
* diag message to be reported, so
* caller should call sock_put then.
* We can do that because we're keeping
* hashinfo->lock here.
*/
sock_hold(sk);
goto out_unlock;
if (refcount_inc_not_zero(&sk->sk_refcnt))
goto out_unlock;
}
}
}
sk = ERR_PTR(-ENOENT);
out_unlock:
read_unlock(&hashinfo->lock);
rcu_read_unlock();

return sk;
}
Expand Down Expand Up @@ -141,6 +142,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
struct raw_hashinfo *hashinfo = raw_get_hashinfo(r);
struct net *net = sock_net(skb->sk);
struct inet_diag_dump_data *cb_data;
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
int num, s_num, slot, s_slot;
struct sock *sk = NULL;
struct nlattr *bc;
Expand All @@ -157,7 +160,8 @@ static void raw_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
for (slot = s_slot; slot < RAW_HTABLE_SIZE; s_num = 0, slot++) {
num = 0;

sk_for_each(sk, &hashinfo->ht[slot]) {
hlist = &hashinfo->ht[slot];
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
struct inet_sock *inet = inet_sk(sk);

if (!net_eq(sock_net(sk), net))
Expand Down
3 changes: 3 additions & 0 deletions net/ipv6/af_inet6.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
#include <net/compat.h>
#include <net/xfrm.h>
#include <net/ioam6.h>
#include <net/rawv6.h>

#include <linux/uaccess.h>
#include <linux/mroute6.h>
Expand Down Expand Up @@ -1073,6 +1074,8 @@ static int __init inet6_init(void)
goto out;
}

raw_hashinfo_init(&raw_v6_hashinfo);

err = proto_register(&tcpv6_prot, 1);
if (err)
goto out;
Expand Down
28 changes: 13 additions & 15 deletions net/ipv6/raw.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,7 @@

#define ICMPV6_HDRLEN 4 /* ICMPv6 header, RFC 4443 Section 2.1 */

struct raw_hashinfo raw_v6_hashinfo = {
.lock = __RW_LOCK_UNLOCKED(raw_v6_hashinfo.lock),
};
struct raw_hashinfo raw_v6_hashinfo;
EXPORT_SYMBOL_GPL(raw_v6_hashinfo);

bool raw_v6_match(struct net *net, struct sock *sk, unsigned short num,
Expand Down Expand Up @@ -143,9 +141,10 @@ EXPORT_SYMBOL(rawv6_mh_filter_unregister);
static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
{
struct net *net = dev_net(skb->dev);
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
const struct in6_addr *saddr;
const struct in6_addr *daddr;
struct hlist_head *head;
struct sock *sk;
bool delivered = false;
__u8 hash;
Expand All @@ -154,11 +153,9 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
daddr = saddr + 1;

hash = nexthdr & (RAW_HTABLE_SIZE - 1);
head = &raw_v6_hashinfo.ht[hash];
if (hlist_empty(head))
return false;
read_lock(&raw_v6_hashinfo.lock);
sk_for_each(sk, head) {
hlist = &raw_v6_hashinfo.ht[hash];
rcu_read_lock();
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
int filtered;

if (!raw_v6_match(net, sk, nexthdr, daddr, saddr,
Expand Down Expand Up @@ -203,7 +200,7 @@ static bool ipv6_raw_deliver(struct sk_buff *skb, int nexthdr)
}
}
}
read_unlock(&raw_v6_hashinfo.lock);
rcu_read_unlock();
return delivered;
}

Expand Down Expand Up @@ -337,14 +334,15 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
{
const struct in6_addr *saddr, *daddr;
struct net *net = dev_net(skb->dev);
struct hlist_head *head;
struct hlist_nulls_head *hlist;
struct hlist_nulls_node *hnode;
struct sock *sk;
int hash;

hash = nexthdr & (RAW_HTABLE_SIZE - 1);
head = &raw_v6_hashinfo.ht[hash];
read_lock(&raw_v6_hashinfo.lock);
sk_for_each(sk, head) {
hlist = &raw_v6_hashinfo.ht[hash];
rcu_read_lock();
hlist_nulls_for_each_entry(sk, hnode, hlist, sk_nulls_node) {
/* Note: ipv6_hdr(skb) != skb->data */
const struct ipv6hdr *ip6h = (const struct ipv6hdr *)skb->data;
saddr = &ip6h->saddr;
Expand All @@ -355,7 +353,7 @@ void raw6_icmp_error(struct sk_buff *skb, int nexthdr,
continue;
rawv6_err(sk, skb, NULL, type, code, inner_offset, info);
}
read_unlock(&raw_v6_hashinfo.lock);
rcu_read_unlock();
}

static inline int rawv6_rcv_skb(struct sock *sk, struct sk_buff *skb)
Expand Down

0 comments on commit 0daf07e

Please sign in to comment.