soreuseport: setsockopt SO_ATTACH_REUSEPORT_[CE]BPF
[cascardo/linux.git] / net / ipv6 / udp.c
index 9da3287..56fcb55 100644 (file)
@@ -47,6 +47,7 @@
 #include <net/xfrm.h>
 #include <net/inet6_hashtables.h>
 #include <net/busy_poll.h>
+#include <net/sock_reuseport.h>
 
 #include <linux/proc_fs.h>
 #include <linux/seq_file.h>
@@ -76,7 +77,14 @@ static u32 udp6_ehashfn(const struct net *net,
                               udp_ipv6_hash_secret + net_hash_mix(net));
 }
 
-int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
+/* match_wildcard == true:  IPV6_ADDR_ANY matches any IPv6 address, and also
+ *                          any IPv4 address unless the socket is IPv6-only
+ * match_wildcard == false: addresses must be exactly the same, i.e.
+ *                          IPV6_ADDR_ANY matches only IPV6_ADDR_ANY,
+ *                          and 0.0.0.0 matches only 0.0.0.0
+ */
+int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2,
+                        bool match_wildcard)
 {
        const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
        int sk2_ipv6only = inet_v6_ipv6only(sk2);
@@ -84,16 +92,24 @@ int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
        int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
 
        /* if both are mapped, treat as IPv4 */
-       if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED)
-               return (!sk2_ipv6only &&
-                       (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr ||
-                         sk->sk_rcv_saddr == sk2->sk_rcv_saddr));
+       if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED) {
+               if (!sk2_ipv6only) {
+                       if (sk->sk_rcv_saddr == sk2->sk_rcv_saddr)
+                               return 1;
+                       if (!sk->sk_rcv_saddr || !sk2->sk_rcv_saddr)
+                               return match_wildcard;
+               }
+               return 0;
+       }
 
-       if (addr_type2 == IPV6_ADDR_ANY &&
+       if (addr_type == IPV6_ADDR_ANY && addr_type2 == IPV6_ADDR_ANY)
+               return 1;
+
+       if (addr_type2 == IPV6_ADDR_ANY && match_wildcard &&
            !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
                return 1;
 
-       if (addr_type == IPV6_ADDR_ANY &&
+       if (addr_type == IPV6_ADDR_ANY && match_wildcard &&
            !(ipv6_only_sock(sk) && addr_type2 == IPV6_ADDR_MAPPED))
                return 1;
 
@@ -253,8 +269,14 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
+                               struct sock *sk2;
                                hash = udp6_ehashfn(net, daddr, hnum,
                                                    saddr, sport);
+                               sk2 = reuseport_select_sock(sk, hash, NULL, 0);
+                               if (sk2) {
+                                       result = sk2;
+                                       goto found;
+                               }
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -273,6 +295,7 @@ begin:
                goto begin;
 
        if (result) {
+found:
                if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
                        result = NULL;
                else if (unlikely(compute_score2(result, net, saddr, sport,
@@ -287,7 +310,8 @@ begin:
 struct sock *__udp6_lib_lookup(struct net *net,
                                      const struct in6_addr *saddr, __be16 sport,
                                      const struct in6_addr *daddr, __be16 dport,
-                                     int dif, struct udp_table *udptable)
+                                     int dif, struct udp_table *udptable,
+                                     struct sk_buff *skb)
 {
        struct sock *sk, *result;
        struct hlist_nulls_node *node;
@@ -332,8 +356,15 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
+                               struct sock *sk2;
                                hash = udp6_ehashfn(net, daddr, hnum,
                                                    saddr, sport);
+                               sk2 = reuseport_select_sock(sk, hash, skb,
+                                                       sizeof(struct udphdr));
+                               if (sk2) {
+                                       result = sk2;
+                                       goto found;
+                               }
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -352,6 +383,7 @@ begin:
                goto begin;
 
        if (result) {
+found:
                if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
                        result = NULL;
                else if (unlikely(compute_score(result, net, hnum, saddr, sport,
@@ -377,13 +409,13 @@ static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
                return sk;
        return __udp6_lib_lookup(dev_net(skb_dst(skb)->dev), &iph->saddr, sport,
                                 &iph->daddr, dport, inet6_iif(skb),
-                                udptable);
+                                udptable, skb);
 }
 
 struct sock *udp6_lib_lookup(struct net *net, const struct in6_addr *saddr, __be16 sport,
                             const struct in6_addr *daddr, __be16 dport, int dif)
 {
-       return __udp6_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table);
+       return __udp6_lib_lookup(net, saddr, sport, daddr, dport, dif, &udp_table, NULL);
 }
 EXPORT_SYMBOL_GPL(udp6_lib_lookup);
 
@@ -402,6 +434,7 @@ int udpv6_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
        int peeked, off = 0;
        int err;
        int is_udplite = IS_UDPLITE(sk);
+       bool checksum_valid = false;
        int is_udp4;
        bool slow;
 
@@ -433,11 +466,12 @@ try_again:
         */
 
        if (copied < ulen || UDP_SKB_CB(skb)->partial_cov) {
-               if (udp_lib_checksum_complete(skb))
+               checksum_valid = !udp_lib_checksum_complete(skb);
+               if (!checksum_valid)
                        goto csum_copy_err;
        }
 
-       if (skb_csum_unnecessary(skb))
+       if (checksum_valid || skb_csum_unnecessary(skb))
                err = skb_copy_datagram_msg(skb, sizeof(struct udphdr),
                                            msg, copied);
        else {
@@ -547,8 +581,8 @@ void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
        int err;
        struct net *net = dev_net(skb->dev);
 
-       sk = __udp6_lib_lookup(net, daddr, uh->dest,
-                              saddr, uh->source, inet6_iif(skb), udptable);
+       sk = __udp6_lib_lookup(net, daddr, uh->dest, saddr, uh->source,
+                              inet6_iif(skb), udptable, skb);
        if (!sk) {
                ICMP6_INC_STATS_BH(net, __in6_dev_get(skb->dev),
                                   ICMP6_MIB_INERRORS);