datapath: STT: Fix checksum handling.
[cascardo/ovs.git] / datapath / linux / compat / stt.c
index dcd70ff..eb397e8 100644 (file)
@@ -62,6 +62,7 @@ struct stt_dev {
        struct net_device       *dev;
        struct net              *net;
        struct list_head        next;
+       struct list_head        up_next;
        struct socket           *sock;
        __be16                  dst_port;
 };
@@ -150,7 +151,11 @@ struct frag_skb_cb {
 /* per-network namespace private data for this module */
 struct stt_net {
        struct list_head stt_list;
+       struct list_head stt_up_list;   /* Devices which are in IFF_UP state. */
        int n_tunnels;
+#ifdef HAVE_NF_REGISTER_NET_HOOK
+       bool nf_hook_reg_done;
+#endif
 };
 
 static int stt_net_id;
@@ -167,13 +172,13 @@ static DEFINE_PER_CPU(u32, pkt_seq_counter);
 static void clean_percpu(struct work_struct *work);
 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
 
-static struct stt_dev *stt_find_sock(struct net *net, __be16 port)
+static struct stt_dev *stt_find_up_dev(struct net *net, __be16 port)
 {
        struct stt_net *sn = net_generic(net, stt_net_id);
        struct stt_dev *stt_dev;
 
-       list_for_each_entry_rcu(stt_dev, &sn->stt_list, next) {
-               if (inet_sk(stt_dev->sock->sk)->inet_sport == port)
+       list_for_each_entry_rcu(stt_dev, &sn->stt_up_list, up_next) {
+               if (stt_dev->dst_port == port)
                        return stt_dev;
        }
        return NULL;
@@ -934,7 +939,7 @@ netdev_tx_t ovs_stt_xmit(struct sk_buff *skb)
        struct net_device *dev = skb->dev;
        struct stt_dev *stt_dev = netdev_priv(dev);
        struct net *net = stt_dev->net;
-       __be16 dport = inet_sk(stt_dev->sock->sk)->inet_sport;
+       __be16 dport = stt_dev->dst_port;
        struct ip_tunnel_key *tun_key;
        struct ip_tunnel_info *tun_info;
        struct rtable *rt;
@@ -1342,6 +1347,7 @@ static void stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
        if (unlikely(!validate_checksum(skb)))
                goto drop;
 
+       __skb_pull(skb, sizeof(struct tcphdr));
        skb = reassemble(skb);
        if (!skb)
                return;
@@ -1481,11 +1487,11 @@ static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
 
        skb_set_transport_header(skb, ip_hdr_len);
 
-       stt_dev = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
+       stt_dev = stt_find_up_dev(dev_net(skb->dev), tcp_hdr(skb)->dest);
        if (!stt_dev)
                return NF_ACCEPT;
 
-       __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
+       __skb_pull(skb, ip_hdr_len);
        stt_rcv(stt_dev, skb);
        return NF_STOLEN;
 }
@@ -1551,15 +1557,28 @@ static int stt_start(struct net *net)
         * rtnl-lock, which results in dead lock in stt-dev-create. Therefore
         * use this new API.
         */
+
+       if (sn->nf_hook_reg_done)
+               goto out;
+
        err = nf_register_net_hook(net, &nf_hook_ops);
+       if (!err)
+               sn->nf_hook_reg_done = true;
 #else
+       /* Register STT only on very first STT device addition. */
+       if (!list_empty(&nf_hook_ops.list))
+               goto out;
+
        err = nf_register_hook(&nf_hook_ops);
 #endif
        if (err)
-               goto free_percpu;
+               goto dec_n_tunnel;
+out:
        sn->n_tunnels++;
        return 0;
 
+dec_n_tunnel:
+       n_tunnels--;
 free_percpu:
        for_each_possible_cpu(i) {
                struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
@@ -1582,12 +1601,6 @@ static void stt_cleanup(struct net *net)
        sn->n_tunnels--;
        if (sn->n_tunnels)
                goto out;
-#ifdef HAVE_NF_REGISTER_NET_HOOK
-       nf_unregister_net_hook(net, &nf_hook_ops);
-#else
-       nf_unregister_hook(&nf_hook_ops);
-#endif
-
 out:
        n_tunnels--;
        if (n_tunnels)
@@ -1644,6 +1657,7 @@ static int stt_open(struct net_device *dev)
 {
        struct stt_dev *stt = netdev_priv(dev);
        struct net *net = stt->net;
+       struct stt_net *sn = net_generic(net, stt_net_id);
        int err;
 
        err = stt_start(net);
@@ -1653,6 +1667,7 @@ static int stt_open(struct net_device *dev)
        err = tcp_sock_create4(net, stt->dst_port, &stt->sock);
        if (err)
                return err;
+       list_add_rcu(&stt->up_next, &sn->stt_up_list);
        return 0;
 }
 
@@ -1661,12 +1676,38 @@ static int stt_stop(struct net_device *dev)
        struct stt_dev *stt_dev = netdev_priv(dev);
        struct net *net = stt_dev->net;
 
+       list_del_rcu(&stt_dev->up_next);
+       synchronize_net();
        tcp_sock_release(stt_dev->sock);
        stt_dev->sock = NULL;
        stt_cleanup(net);
        return 0;
 }
 
+static int __stt_change_mtu(struct net_device *dev, int new_mtu, bool strict)
+{
+       int max_mtu = IP_MAX_MTU - STT_HEADER_LEN - sizeof(struct iphdr)
+                     - dev->hard_header_len;
+
+       if (new_mtu < 68)
+               return -EINVAL;
+
+       if (new_mtu > max_mtu) {
+               if (strict)
+                       return -EINVAL;
+
+               new_mtu = max_mtu;
+       }
+
+       dev->mtu = new_mtu;
+       return 0;
+}
+
+static int stt_change_mtu(struct net_device *dev, int new_mtu)
+{
+       return __stt_change_mtu(dev, new_mtu, true);
+}
+
 static const struct net_device_ops stt_netdev_ops = {
        .ndo_init               = stt_init,
        .ndo_uninit             = stt_uninit,
@@ -1674,7 +1715,7 @@ static const struct net_device_ops stt_netdev_ops = {
        .ndo_stop               = stt_stop,
        .ndo_start_xmit         = stt_dev_xmit,
        .ndo_get_stats64        = ip_tunnel_get_stats64,
-       .ndo_change_mtu         = eth_change_mtu,
+       .ndo_change_mtu         = stt_change_mtu,
        .ndo_validate_addr      = eth_validate_addr,
        .ndo_set_mac_address    = eth_mac_addr,
 };
@@ -1766,11 +1807,15 @@ static int stt_configure(struct net *net, struct net_device *dev,
        if (find_dev(net, dst_port))
                return -EBUSY;
 
+       err = __stt_change_mtu(dev, IP_MAX_MTU, false);
+       if (err)
+               return err;
+
        err = register_netdevice(dev);
        if (err)
                return err;
 
-       list_add_rcu(&stt->next, &sn->stt_list);
+       list_add(&stt->next, &sn->stt_list);
        return 0;
 }
 
@@ -1789,7 +1834,7 @@ static void stt_dellink(struct net_device *dev, struct list_head *head)
 {
        struct stt_dev *stt = netdev_priv(dev);
 
-       list_del_rcu(&stt->next);
+       list_del(&stt->next);
        unregister_netdevice_queue(dev, head);
 }
 
@@ -1851,6 +1896,10 @@ static int stt_init_net(struct net *net)
        struct stt_net *sn = net_generic(net, stt_net_id);
 
        INIT_LIST_HEAD(&sn->stt_list);
+       INIT_LIST_HEAD(&sn->stt_up_list);
+#ifdef HAVE_NF_REGISTER_NET_HOOK
+       sn->nf_hook_reg_done = false;
+#endif
        return 0;
 }
 
@@ -1861,6 +1910,14 @@ static void stt_exit_net(struct net *net)
        struct net_device *dev, *aux;
        LIST_HEAD(list);
 
+#ifdef HAVE_NF_REGISTER_NET_HOOK
+       /* Ideally this should be done from stt_stop(), But on some kernels
+        * nf-unreg operation needs RTNL-lock, which can cause deallock.
+        * So it is done from here. */
+       if (sn->nf_hook_reg_done)
+               nf_unregister_net_hook(net, &nf_hook_ops);
+#endif
+
        rtnl_lock();
 
        /* gather any stt devices that were moved into this ns */
@@ -1900,6 +1957,7 @@ int stt_init_module(void)
        if (rc)
                goto out2;
 
+       INIT_LIST_HEAD(&nf_hook_ops.list);
        pr_info("STT tunneling driver\n");
        return 0;
 out2:
@@ -1910,6 +1968,10 @@ out1:
 
 void stt_cleanup_module(void)
 {
+#ifndef HAVE_NF_REGISTER_NET_HOOK
+       if (!list_empty(&nf_hook_ops.list))
+               nf_unregister_hook(&nf_hook_ops);
+#endif
        rtnl_link_unregister(&stt_link_ops);
        unregister_pernet_subsys(&stt_net_ops);
 }