vxlan: avoid using stale vxlan socket.
[cascardo/linux.git] / drivers / net / vxlan.c
index 6e65832..f3c2fa3 100644 (file)
@@ -27,7 +27,6 @@
 #include <net/net_namespace.h>
 #include <net/netns/generic.h>
 #include <net/vxlan.h>
-#include <net/protocol.h>
 
 #if IS_ENABLED(CONFIG_IPV6)
 #include <net/ip6_tunnel.h>
@@ -288,7 +287,7 @@ static int vxlan_fdb_info(struct sk_buff *skb, struct vxlan_dev *vxlan,
 
        if (!net_eq(dev_net(vxlan->dev), vxlan->net) &&
            nla_put_s32(skb, NDA_LINK_NETNSID,
-                       peernet2id_alloc(dev_net(vxlan->dev), vxlan->net)))
+                       peernet2id(dev_net(vxlan->dev), vxlan->net)))
                goto nla_put_failure;
 
        if (send_eth && nla_put(skb, NDA_LLADDR, ETH_ALEN, &fdb->eth_addr))
@@ -584,7 +583,7 @@ static struct sk_buff **vxlan_gro_receive(struct sock *sk,
                }
        }
 
-       pp = eth_gro_receive(head, skb);
+       pp = call_gro_receive(eth_gro_receive, head, skb);
        flush = 0;
 
 out:
@@ -861,20 +860,20 @@ out:
 /* Dump forwarding table */
 static int vxlan_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb,
                          struct net_device *dev,
-                         struct net_device *filter_dev, int idx)
+                         struct net_device *filter_dev, int *idx)
 {
        struct vxlan_dev *vxlan = netdev_priv(dev);
        unsigned int h;
+       int err = 0;
 
        for (h = 0; h < FDB_HASH_SIZE; ++h) {
                struct vxlan_fdb *f;
-               int err;
 
                hlist_for_each_entry_rcu(f, &vxlan->fdb_head[h], hlist) {
                        struct vxlan_rdst *rd;
 
                        list_for_each_entry_rcu(rd, &f->remotes, list) {
-                               if (idx < cb->args[0])
+                               if (*idx < cb->args[2])
                                        goto skip;
 
                                err = vxlan_fdb_info(skb, vxlan, f,
@@ -882,17 +881,15 @@ static int vxlan_fdb_dump(struct sk_buff *skb, struct netlink_callback *cb,
                                                     cb->nlh->nlmsg_seq,
                                                     RTM_NEWNEIGH,
                                                     NLM_F_MULTI, rd);
-                               if (err < 0) {
-                                       cb->args[1] = err;
+                               if (err < 0)
                                        goto out;
-                               }
 skip:
-                               ++idx;
+                               *idx += 1;
                        }
                }
        }
 out:
-       return idx;
+       return err;
 }
 
 /* Watch incoming packets to learn mapping between Ethernet address
@@ -946,17 +943,20 @@ static bool vxlan_snoop(struct net_device *dev,
 static bool vxlan_group_used(struct vxlan_net *vn, struct vxlan_dev *dev)
 {
        struct vxlan_dev *vxlan;
+       struct vxlan_sock *sock4;
+       struct vxlan_sock *sock6 = NULL;
        unsigned short family = dev->default_dst.remote_ip.sa.sa_family;
 
+       sock4 = rtnl_dereference(dev->vn4_sock);
+
        /* The vxlan_sock is only used by dev, leaving group has
         * no effect on other vxlan devices.
         */
-       if (family == AF_INET && dev->vn4_sock &&
-           atomic_read(&dev->vn4_sock->refcnt) == 1)
+       if (family == AF_INET && sock4 && atomic_read(&sock4->refcnt) == 1)
                return false;
 #if IS_ENABLED(CONFIG_IPV6)
-       if (family == AF_INET6 && dev->vn6_sock &&
-           atomic_read(&dev->vn6_sock->refcnt) == 1)
+       sock6 = rtnl_dereference(dev->vn6_sock);
+       if (family == AF_INET6 && sock6 && atomic_read(&sock6->refcnt) == 1)
                return false;
 #endif
 
@@ -964,10 +964,12 @@ static bool vxlan_group_used(struct vxlan_net *vn, struct vxlan_dev *dev)
                if (!netif_running(vxlan->dev) || vxlan == dev)
                        continue;
 
-               if (family == AF_INET && vxlan->vn4_sock != dev->vn4_sock)
+               if (family == AF_INET &&
+                   rtnl_dereference(vxlan->vn4_sock) != sock4)
                        continue;
 #if IS_ENABLED(CONFIG_IPV6)
-               if (family == AF_INET6 && vxlan->vn6_sock != dev->vn6_sock)
+               if (family == AF_INET6 &&
+                   rtnl_dereference(vxlan->vn6_sock) != sock6)
                        continue;
 #endif
 
@@ -1008,22 +1010,25 @@ static bool __vxlan_sock_release_prep(struct vxlan_sock *vs)
 
 static void vxlan_sock_release(struct vxlan_dev *vxlan)
 {
-       bool ipv4 = __vxlan_sock_release_prep(vxlan->vn4_sock);
+       struct vxlan_sock *sock4 = rtnl_dereference(vxlan->vn4_sock);
 #if IS_ENABLED(CONFIG_IPV6)
-       bool ipv6 = __vxlan_sock_release_prep(vxlan->vn6_sock);
+       struct vxlan_sock *sock6 = rtnl_dereference(vxlan->vn6_sock);
+
+       rcu_assign_pointer(vxlan->vn6_sock, NULL);
 #endif
 
+       rcu_assign_pointer(vxlan->vn4_sock, NULL);
        synchronize_net();
 
-       if (ipv4) {
-               udp_tunnel_sock_release(vxlan->vn4_sock->sock);
-               kfree(vxlan->vn4_sock);
+       if (__vxlan_sock_release_prep(sock4)) {
+               udp_tunnel_sock_release(sock4->sock);
+               kfree(sock4);
        }
 
 #if IS_ENABLED(CONFIG_IPV6)
-       if (ipv6) {
-               udp_tunnel_sock_release(vxlan->vn6_sock->sock);
-               kfree(vxlan->vn6_sock);
+       if (__vxlan_sock_release_prep(sock6)) {
+               udp_tunnel_sock_release(sock6->sock);
+               kfree(sock6);
        }
 #endif
 }
@@ -1039,18 +1044,21 @@ static int vxlan_igmp_join(struct vxlan_dev *vxlan)
        int ret = -EINVAL;
 
        if (ip->sa.sa_family == AF_INET) {
+               struct vxlan_sock *sock4 = rtnl_dereference(vxlan->vn4_sock);
                struct ip_mreqn mreq = {
                        .imr_multiaddr.s_addr   = ip->sin.sin_addr.s_addr,
                        .imr_ifindex            = ifindex,
                };
 
-               sk = vxlan->vn4_sock->sock->sk;
+               sk = sock4->sock->sk;
                lock_sock(sk);
                ret = ip_mc_join_group(sk, &mreq);
                release_sock(sk);
 #if IS_ENABLED(CONFIG_IPV6)
        } else {
-               sk = vxlan->vn6_sock->sock->sk;
+               struct vxlan_sock *sock6 = rtnl_dereference(vxlan->vn6_sock);
+
+               sk = sock6->sock->sk;
                lock_sock(sk);
                ret = ipv6_stub->ipv6_sock_mc_join(sk, ifindex,
                                                   &ip->sin6.sin6_addr);
@@ -1070,18 +1078,21 @@ static int vxlan_igmp_leave(struct vxlan_dev *vxlan)
        int ret = -EINVAL;
 
        if (ip->sa.sa_family == AF_INET) {
+               struct vxlan_sock *sock4 = rtnl_dereference(vxlan->vn4_sock);
                struct ip_mreqn mreq = {
                        .imr_multiaddr.s_addr   = ip->sin.sin_addr.s_addr,
                        .imr_ifindex            = ifindex,
                };
 
-               sk = vxlan->vn4_sock->sock->sk;
+               sk = sock4->sock->sk;
                lock_sock(sk);
                ret = ip_mc_leave_group(sk, &mreq);
                release_sock(sk);
 #if IS_ENABLED(CONFIG_IPV6)
        } else {
-               sk = vxlan->vn6_sock->sock->sk;
+               struct vxlan_sock *sock6 = rtnl_dereference(vxlan->vn6_sock);
+
+               sk = sock6->sock->sk;
                lock_sock(sk);
                ret = ipv6_stub->ipv6_sock_mc_drop(sk, ifindex,
                                                   &ip->sin6.sin6_addr);
@@ -1294,7 +1305,7 @@ static int vxlan_rcv(struct sock *sk, struct sk_buff *skb)
                struct metadata_dst *tun_dst;
 
                tun_dst = udp_tun_rx_dst(skb, vxlan_get_sk_family(vs), TUNNEL_KEY,
-                                        vxlan_vni_to_tun_id(vni), sizeof(*md));
+                                        key32_to_tunnel_id(vni), sizeof(*md));
 
                if (!tun_dst)
                        goto drop;
@@ -1831,11 +1842,15 @@ static struct dst_entry *vxlan6_get_route(struct vxlan_dev *vxlan,
                                          struct dst_cache *dst_cache,
                                          const struct ip_tunnel_info *info)
 {
+       struct vxlan_sock *sock6 = rcu_dereference(vxlan->vn6_sock);
        bool use_cache = ip_tunnel_dst_cache_usable(skb, info);
        struct dst_entry *ndst;
        struct flowi6 fl6;
        int err;
 
+       if (!sock6)
+               return ERR_PTR(-EIO);
+
        if (tos && !info)
                use_cache = false;
        if (use_cache) {
@@ -1853,7 +1868,7 @@ static struct dst_entry *vxlan6_get_route(struct vxlan_dev *vxlan,
        fl6.flowi6_proto = IPPROTO_UDP;
 
        err = ipv6_stub->ipv6_dst_lookup(vxlan->net,
-                                        vxlan->vn6_sock->sock->sk,
+                                        sock6->sock->sk,
                                         &ndst, &fl6);
        if (err < 0)
                return ERR_PTR(err);
@@ -1948,7 +1963,7 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
                        goto drop;
                }
                dst_port = info->key.tp_dst ? : vxlan->cfg.dst_port;
-               vni = vxlan_tun_id_to_vni(info->key.tun_id);
+               vni = tunnel_id_to_key32(info->key.tun_id);
                remote_ip.sa.sa_family = ip_tunnel_info_af(info);
                if (remote_ip.sa.sa_family == AF_INET) {
                        remote_ip.sin.sin_addr.s_addr = info->key.u.ipv4.dst;
@@ -1998,9 +2013,11 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
        }
 
        if (dst->sa.sa_family == AF_INET) {
-               if (!vxlan->vn4_sock)
+               struct vxlan_sock *sock4 = rcu_dereference(vxlan->vn4_sock);
+
+               if (!sock4)
                        goto drop;
-               sk = vxlan->vn4_sock->sock->sk;
+               sk = sock4->sock->sk;
 
                rt = vxlan_get_route(vxlan, skb,
                                     rdst ? rdst->remote_ifindex : 0, tos,
@@ -2053,12 +2070,13 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
                                    src_port, dst_port, xnet, !udp_sum);
 #if IS_ENABLED(CONFIG_IPV6)
        } else {
+               struct vxlan_sock *sock6 = rcu_dereference(vxlan->vn6_sock);
                struct dst_entry *ndst;
                u32 rt6i_flags;
 
-               if (!vxlan->vn6_sock)
+               if (!sock6)
                        goto drop;
-               sk = vxlan->vn6_sock->sock->sk;
+               sk = sock6->sock->sk;
 
                ndst = vxlan6_get_route(vxlan, skb,
                                        rdst ? rdst->remote_ifindex : 0, tos,
@@ -2106,6 +2124,7 @@ static void vxlan_xmit_one(struct sk_buff *skb, struct net_device *dev,
                                      vni, md, flags, udp_sum);
                if (err < 0) {
                        dst_release(ndst);
+                       dev->stats.tx_errors++;
                        return;
                }
                udp_tunnel6_xmit_skb(ndst, sk, skb, dev,
@@ -2417,9 +2436,10 @@ static int vxlan_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb)
        dport = info->key.tp_dst ? : vxlan->cfg.dst_port;
 
        if (ip_tunnel_info_af(info) == AF_INET) {
+               struct vxlan_sock *sock4 = rcu_dereference(vxlan->vn4_sock);
                struct rtable *rt;
 
-               if (!vxlan->vn4_sock)
+               if (!sock4)
                        return -EINVAL;
                rt = vxlan_get_route(vxlan, skb, 0, info->key.tos,
                                     info->key.u.ipv4.dst,
@@ -2431,8 +2451,6 @@ static int vxlan_fill_metadata_dst(struct net_device *dev, struct sk_buff *skb)
 #if IS_ENABLED(CONFIG_IPV6)
                struct dst_entry *ndst;
 
-               if (!vxlan->vn6_sock)
-                       return -EINVAL;
                ndst = vxlan6_get_route(vxlan, skb, 0, info->key.tos,
                                        info->key.label, &info->key.u.ipv6.dst,
                                        &info->key.u.ipv6.src, NULL, info);
@@ -2742,10 +2760,10 @@ static int __vxlan_sock_add(struct vxlan_dev *vxlan, bool ipv6)
                return PTR_ERR(vs);
 #if IS_ENABLED(CONFIG_IPV6)
        if (ipv6)
-               vxlan->vn6_sock = vs;
+               rcu_assign_pointer(vxlan->vn6_sock, vs);
        else
 #endif
-               vxlan->vn4_sock = vs;
+               rcu_assign_pointer(vxlan->vn4_sock, vs);
        vxlan_vs_add_dev(vs, vxlan);
        return 0;
 }
@@ -2756,9 +2774,9 @@ static int vxlan_sock_add(struct vxlan_dev *vxlan)
        bool metadata = vxlan->flags & VXLAN_F_COLLECT_METADATA;
        int ret = 0;
 
-       vxlan->vn4_sock = NULL;
+       RCU_INIT_POINTER(vxlan->vn4_sock, NULL);
 #if IS_ENABLED(CONFIG_IPV6)
-       vxlan->vn6_sock = NULL;
+       RCU_INIT_POINTER(vxlan->vn6_sock, NULL);
        if (ipv6 || metadata)
                ret = __vxlan_sock_add(vxlan, true);
 #endif