soreuseport: pass skb to secondary UDP socket lookup
[cascardo/linux.git] / net / ipv4 / udp.c
index 0c7b0e6..3a66731 100644 (file)
 #include <trace/events/skb.h>
 #include <net/busy_poll.h>
 #include "udp_impl.h"
+#include <net/sock_reuseport.h>
 
 struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
@@ -137,7 +138,8 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
                               unsigned long *bitmap,
                               struct sock *sk,
                               int (*saddr_comp)(const struct sock *sk1,
-                                                const struct sock *sk2),
+                                                const struct sock *sk2,
+                                                bool match_wildcard),
                               unsigned int log)
 {
        struct sock *sk2;
@@ -152,8 +154,9 @@ static int udp_lib_lport_inuse(struct net *net, __u16 num,
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
                    (!sk2->sk_reuseport || !sk->sk_reuseport ||
+                    rcu_access_pointer(sk->sk_reuseport_cb) ||
                     !uid_eq(uid, sock_i_uid(sk2))) &&
-                   saddr_comp(sk, sk2)) {
+                   saddr_comp(sk, sk2, true)) {
                        if (!bitmap)
                                return 1;
                        __set_bit(udp_sk(sk2)->udp_port_hash >> log, bitmap);
@@ -170,7 +173,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
                                struct udp_hslot *hslot2,
                                struct sock *sk,
                                int (*saddr_comp)(const struct sock *sk1,
-                                                 const struct sock *sk2))
+                                                 const struct sock *sk2,
+                                                 bool match_wildcard))
 {
        struct sock *sk2;
        struct hlist_nulls_node *node;
@@ -186,8 +190,9 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
                    (!sk2->sk_bound_dev_if || !sk->sk_bound_dev_if ||
                     sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
                    (!sk2->sk_reuseport || !sk->sk_reuseport ||
+                    rcu_access_pointer(sk->sk_reuseport_cb) ||
                     !uid_eq(uid, sock_i_uid(sk2))) &&
-                   saddr_comp(sk, sk2)) {
+                   saddr_comp(sk, sk2, true)) {
                        res = 1;
                        break;
                }
@@ -196,6 +201,35 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
        return res;
 }
 
+static int udp_reuseport_add_sock(struct sock *sk, struct udp_hslot *hslot,
+                                 int (*saddr_same)(const struct sock *sk1,
+                                                   const struct sock *sk2,
+                                                   bool match_wildcard))
+{
+       struct net *net = sock_net(sk);
+       struct hlist_nulls_node *node;
+       kuid_t uid = sock_i_uid(sk);
+       struct sock *sk2;
+
+       sk_nulls_for_each(sk2, node, &hslot->head) {
+               if (net_eq(sock_net(sk2), net) &&
+                   sk2 != sk &&
+                   sk2->sk_family == sk->sk_family &&
+                   ipv6_only_sock(sk2) == ipv6_only_sock(sk) &&
+                   (udp_sk(sk2)->udp_port_hash == udp_sk(sk)->udp_port_hash) &&
+                   (sk2->sk_bound_dev_if == sk->sk_bound_dev_if) &&
+                   sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) &&
+                   (*saddr_same)(sk, sk2, false)) {
+                       return reuseport_add_sock(sk, sk2);
+               }
+       }
+
+       /* Initial allocation may have already happened via setsockopt */
+       if (!rcu_access_pointer(sk->sk_reuseport_cb))
+               return reuseport_alloc(sk);
+       return 0;
+}
+
 /**
  *  udp_lib_get_port  -  UDP/-Lite port lookup for IPv4 and IPv6
  *
@@ -207,7 +241,8 @@ static int udp_lib_lport_inuse2(struct net *net, __u16 num,
  */
 int udp_lib_get_port(struct sock *sk, unsigned short snum,
                     int (*saddr_comp)(const struct sock *sk1,
-                                      const struct sock *sk2),
+                                      const struct sock *sk2,
+                                      bool match_wildcard),
                     unsigned int hash2_nulladdr)
 {
        struct udp_hslot *hslot, *hslot2;
@@ -290,6 +325,14 @@ found:
        udp_sk(sk)->udp_port_hash = snum;
        udp_sk(sk)->udp_portaddr_hash ^= snum;
        if (sk_unhashed(sk)) {
+               if (sk->sk_reuseport &&
+                   udp_reuseport_add_sock(sk, hslot, saddr_comp)) {
+                       inet_sk(sk)->inet_num = 0;
+                       udp_sk(sk)->udp_port_hash = 0;
+                       udp_sk(sk)->udp_portaddr_hash ^= snum;
+                       goto fail_unlock;
+               }
+
                sk_nulls_add_node_rcu(sk, &hslot->head);
                hslot->count++;
                sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1);
@@ -309,13 +352,22 @@ fail:
 }
 EXPORT_SYMBOL(udp_lib_get_port);
 
-static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2)
+/* match_wildcard == true:  0.0.0.0 equals to any IPv4 addresses
+ * match_wildcard == false: addresses must be exactly the same, i.e.
+ *                          0.0.0.0 only equals to 0.0.0.0
+ */
+static int ipv4_rcv_saddr_equal(const struct sock *sk1, const struct sock *sk2,
+                               bool match_wildcard)
 {
        struct inet_sock *inet1 = inet_sk(sk1), *inet2 = inet_sk(sk2);
 
-       return  (!ipv6_only_sock(sk2)  &&
-                (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr ||
-                  inet1->inet_rcv_saddr == inet2->inet_rcv_saddr));
+       if (!ipv6_only_sock(sk2)) {
+               if (inet1->inet_rcv_saddr == inet2->inet_rcv_saddr)
+                       return 1;
+               if (!inet1->inet_rcv_saddr || !inet2->inet_rcv_saddr)
+                       return match_wildcard;
+       }
+       return 0;
 }
 
 static u32 udp4_portaddr_hash(const struct net *net, __be32 saddr,
@@ -441,7 +493,8 @@ static u32 udp_ehashfn(const struct net *net, const __be32 laddr,
 static struct sock *udp4_lib_lookup2(struct net *net,
                __be32 saddr, __be16 sport,
                __be32 daddr, unsigned int hnum, int dif,
-               struct udp_hslot *hslot2, unsigned int slot2)
+               struct udp_hslot *hslot2, unsigned int slot2,
+               struct sk_buff *skb)
 {
        struct sock *sk, *result;
        struct hlist_nulls_node *node;
@@ -459,8 +512,15 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
+                               struct sock *sk2;
                                hash = udp_ehashfn(net, daddr, hnum,
                                                   saddr, sport);
+                               sk2 = reuseport_select_sock(sk, hash, skb,
+                                                           sizeof(struct udphdr));
+                               if (sk2) {
+                                       result = sk2;
+                                       goto found;
+                               }
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -478,6 +538,7 @@ begin:
        if (get_nulls_value(node) != slot2)
                goto begin;
        if (result) {
+found:
                if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
                        result = NULL;
                else if (unlikely(compute_score2(result, net, saddr, sport,
@@ -494,7 +555,7 @@ begin:
  */
 struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
                __be16 sport, __be32 daddr, __be16 dport,
-               int dif, struct udp_table *udptable)
+               int dif, struct udp_table *udptable, struct sk_buff *skb)
 {
        struct sock *sk, *result;
        struct hlist_nulls_node *node;
@@ -514,7 +575,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 
                result = udp4_lib_lookup2(net, saddr, sport,
                                          daddr, hnum, dif,
-                                         hslot2, slot2);
+                                         hslot2, slot2, skb);
                if (!result) {
                        hash2 = udp4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
                        slot2 = hash2 & udptable->mask;
@@ -524,7 +585,7 @@ struct sock *__udp4_lib_lookup(struct net *net, __be32 saddr,
 
                        result = udp4_lib_lookup2(net, saddr, sport,
                                                  htonl(INADDR_ANY), hnum, dif,
-                                                 hslot2, slot2);
+                                                 hslot2, slot2, skb);
                }
                rcu_read_unlock();
                return result;
@@ -540,8 +601,15 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
+                               struct sock *sk2;
                                hash = udp_ehashfn(net, daddr, hnum,
                                                   saddr, sport);
+                               sk2 = reuseport_select_sock(sk, hash, skb,
+                                                       sizeof(struct udphdr));
+                               if (sk2) {
+                                       result = sk2;
+                                       goto found;
+                               }
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -560,6 +628,7 @@ begin:
                goto begin;
 
        if (result) {
+found:
                if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
                        result = NULL;
                else if (unlikely(compute_score(result, net, saddr, hnum, sport,
@@ -581,13 +650,14 @@ static inline struct sock *__udp4_lib_lookup_skb(struct sk_buff *skb,
 
        return __udp4_lib_lookup(dev_net(skb_dst(skb)->dev), iph->saddr, sport,
                                 iph->daddr, dport, inet_iif(skb),
-                                udptable);
+                                udptable, skb);
 }
 
 struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
                             __be32 daddr, __be16 dport, int dif)
 {
-       return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table);
+       return __udp4_lib_lookup(net, saddr, sport, daddr, dport, dif,
+                                &udp_table, NULL);
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 
@@ -635,7 +705,8 @@ void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
        struct net *net = dev_net(skb->dev);
 
        sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
-                       iph->saddr, uh->source, skb->dev->ifindex, udptable);
+                       iph->saddr, uh->source, skb->dev->ifindex, udptable,
+                       NULL);
        if (!sk) {
                ICMP_INC_STATS_BH(net, ICMP_MIB_INERRORS);
                return; /* No socket for error */
@@ -772,7 +843,8 @@ void udp_set_csum(bool nocheck, struct sk_buff *skb,
        else if (skb_is_gso(skb))
                uh->check = ~udp_v4_check(len, saddr, daddr, 0);
        else if (skb_dst(skb) && skb_dst(skb)->dev &&
-                (skb_dst(skb)->dev->features & NETIF_F_V4_CSUM)) {
+                (skb_dst(skb)->dev->features &
+                 (NETIF_F_IP_CSUM | NETIF_F_HW_CSUM))) {
 
                BUG_ON(skb->ip_summed == CHECKSUM_PARTIAL);
 
@@ -1270,6 +1342,7 @@ int udp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int noblock,
        int peeked, off = 0;
        int err;
        int is_udplite = IS_UDPLITE(sk);
+       bool checksum_valid = false;
        bool slow;
 
        if (flags & MSG_ERRQUEUE)
@@ -1295,11 +1368,12 @@ try_again:
         */
 
        if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) {
-               if (udp_lib_checksum_complete(skb))
+               checksum_valid = !udp_lib_checksum_complete(skb);
+               if (!checksum_valid)
                        goto csum_copy_err;
        }
 
-       if (skb_csum_unnecessary(skb))
+       if (checksum_valid || skb_csum_unnecessary(skb))
                err = skb_copy_datagram_msg(skb, sizeof(struct udphdr),
                                            msg, copied);
        else {
@@ -1395,6 +1469,8 @@ void udp_lib_unhash(struct sock *sk)
                hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);
 
                spin_lock_bh(&hslot->lock);
+               if (rcu_access_pointer(sk->sk_reuseport_cb))
+                       reuseport_detach_sock(sk);
                if (sk_nulls_del_node_init_rcu(sk)) {
                        hslot->count--;
                        inet_sk(sk)->inet_num = 0;
@@ -1422,22 +1498,28 @@ void udp_lib_rehash(struct sock *sk, u16 newhash)
                hslot2 = udp_hashslot2(udptable, udp_sk(sk)->udp_portaddr_hash);
                nhslot2 = udp_hashslot2(udptable, newhash);
                udp_sk(sk)->udp_portaddr_hash = newhash;
-               if (hslot2 != nhslot2) {
+
+               if (hslot2 != nhslot2 ||
+                   rcu_access_pointer(sk->sk_reuseport_cb)) {
                        hslot = udp_hashslot(udptable, sock_net(sk),
                                             udp_sk(sk)->udp_port_hash);
                        /* we must lock primary chain too */
                        spin_lock_bh(&hslot->lock);
-
-                       spin_lock(&hslot2->lock);
-                       hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node);
-                       hslot2->count--;
-                       spin_unlock(&hslot2->lock);
-
-                       spin_lock(&nhslot2->lock);
-                       hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node,
-                                                &nhslot2->head);
-                       nhslot2->count++;
-                       spin_unlock(&nhslot2->lock);
+                       if (rcu_access_pointer(sk->sk_reuseport_cb))
+                               reuseport_detach_sock(sk);
+
+                       if (hslot2 != nhslot2) {
+                               spin_lock(&hslot2->lock);
+                               hlist_nulls_del_init_rcu(&udp_sk(sk)->udp_portaddr_node);
+                               hslot2->count--;
+                               spin_unlock(&hslot2->lock);
+
+                               spin_lock(&nhslot2->lock);
+                               hlist_nulls_add_head_rcu(&udp_sk(sk)->udp_portaddr_node,
+                                                        &nhslot2->head);
+                               nhslot2->count++;
+                               spin_unlock(&nhslot2->lock);
+                       }
 
                        spin_unlock_bh(&hslot->lock);
                }