datapath: stt: Use RCU API to update stt-dev list.
[cascardo/ovs.git] / datapath / linux / compat / stt.c
index b44f470..dcd70ff 100644 (file)
@@ -9,6 +9,7 @@
  * 2 of the License, or (at your option) any later version.
  */
 
+#define pr_fmt(fmt) KBUILD_MODNAME ": " fmt
 #include <asm/unaligned.h>
 
 #include <linux/delay.h>
 #include <linux/list.h>
 #include <linux/log2.h>
 #include <linux/module.h>
+#include <linux/net.h>
 #include <linux/netfilter.h>
 #include <linux/percpu.h>
 #include <linux/skbuff.h>
 #include <linux/tcp.h>
 #include <linux/workqueue.h>
 
+#include <net/dst_metadata.h>
 #include <net/icmp.h>
 #include <net/inet_ecn.h>
 #include <net/ip.h>
+#include <net/ip_tunnels.h>
+#include <net/ip6_checksum.h>
 #include <net/net_namespace.h>
 #include <net/netns/generic.h>
 #include <net/sock.h>
 #include <net/udp.h>
 
 #include "gso.h"
+#include "compat.h"
+
+#define STT_NETDEV_VER "0.1"
+#define STT_DST_PORT 7471
 
 #ifdef OVS_STT
 #define STT_VER 0
 
+/* @list: Per-net list of STT ports.
+ * @rcv: The callback is called on STT packet recv, STT reassembly can generate
+ * multiple packets, in this case first packet has tunnel outer header, rest
+ * of the packets are inner packet segments with no stt header.
+ * @rcv_data: user data.
+ * @sock: Fake TCP socket for the STT port.
+ */
+struct stt_dev {
+       struct net_device       *dev;
+       struct net              *net;
+       struct list_head        next;
+       struct socket           *sock;
+       __be16                  dst_port;
+};
+
 #define STT_CSUM_VERIFIED      BIT(0)
 #define STT_CSUM_PARTIAL       BIT(1)
 #define STT_PROTO_IPV4         BIT(2)
@@ -125,7 +149,8 @@ struct frag_skb_cb {
 
 /* per-network namespace private data for this module */
 struct stt_net {
-       struct list_head sock_list;
+       struct list_head stt_list;
+       int n_tunnels;
 };
 
 static int stt_net_id;
@@ -142,14 +167,14 @@ static DEFINE_PER_CPU(u32, pkt_seq_counter);
 static void clean_percpu(struct work_struct *work);
 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
 
-static struct stt_sock *stt_find_sock(struct net *net, __be16 port)
+static struct stt_dev *stt_find_sock(struct net *net, __be16 port)
 {
        struct stt_net *sn = net_generic(net, stt_net_id);
-       struct stt_sock *stt_sock;
+       struct stt_dev *stt_dev;
 
-       list_for_each_entry_rcu(stt_sock, &sn->sock_list, list) {
-               if (inet_sk(stt_sock->sock->sk)->inet_sport == port)
-                       return stt_sock;
+       list_for_each_entry_rcu(stt_dev, &sn->stt_list, next) {
+               if (inet_sk(stt_dev->sock->sk)->inet_sport == port)
+                       return stt_dev;
        }
        return NULL;
 }
@@ -786,7 +811,6 @@ static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
                if (next)
                        dst_clone(&rt->dst);
 
-               skb_clear_ovs_gso_cb(skb);
                skb->next = NULL;
                len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
                                     tos, ttl, df, false);
@@ -833,7 +857,7 @@ static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
        return 0;
 }
 
-int rpl_stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
+static int stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
                 __be32 src, __be32 dst, __u8 tos,
                 __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
                 __be64 tun_id)
@@ -904,7 +928,57 @@ err_free_rt:
        kfree_skb(skb);
        return ret;
 }
-EXPORT_SYMBOL_GPL(rpl_stt_xmit_skb);
+
+netdev_tx_t ovs_stt_xmit(struct sk_buff *skb)
+{
+       struct net_device *dev = skb->dev;
+       struct stt_dev *stt_dev = netdev_priv(dev);
+       struct net *net = stt_dev->net;
+       __be16 dport = inet_sk(stt_dev->sock->sk)->inet_sport;
+       struct ip_tunnel_key *tun_key;
+       struct ip_tunnel_info *tun_info;
+       struct rtable *rt;
+       struct flowi4 fl;
+       __be16 sport;
+       __be16 df;
+       int err;
+
+       tun_info = skb_tunnel_info(skb);
+       if (unlikely(!tun_info)) {
+               err = -EINVAL;
+               goto error;
+       }
+
+       tun_key = &tun_info->key;
+
+       /* Route lookup */
+       memset(&fl, 0, sizeof(fl));
+       fl.daddr = tun_key->u.ipv4.dst;
+       fl.saddr = tun_key->u.ipv4.src;
+       fl.flowi4_tos = RT_TOS(tun_key->tos);
+       fl.flowi4_mark = skb->mark;
+       fl.flowi4_proto = IPPROTO_TCP;
+       rt = ip_route_output_key(net, &fl);
+       if (IS_ERR(rt)) {
+               err = PTR_ERR(rt);
+               goto error;
+       }
+
+       df = tun_key->tun_flags & TUNNEL_DONT_FRAGMENT ? htons(IP_DF) : 0;
+       sport = udp_flow_src_port(net, skb, 1, USHRT_MAX, true);
+       skb->ignore_df = 1;
+
+       err = stt_xmit_skb(skb, rt, fl.saddr, tun_key->u.ipv4.dst,
+                           tun_key->tos, tun_key->ttl,
+                           df, sport, dport, tun_key->tun_id);
+       iptunnel_xmit_stats(err, &dev->stats, (struct pcpu_sw_netstats __percpu *)dev->tstats);
+       return NETDEV_TX_OK;
+error:
+       kfree_skb(skb);
+       dev->stats.tx_errors++;
+       return NETDEV_TX_OK;
+}
+EXPORT_SYMBOL(ovs_stt_xmit);
 
 static void free_frag(struct stt_percpu *stt_percpu,
                      struct pkt_frag *frag)
@@ -1211,7 +1285,57 @@ static bool set_offloads(struct sk_buff *skb)
 
        return true;
 }
-static void stt_rcv(struct stt_sock *stt_sock, struct sk_buff *skb)
+
+static void rcv_list(struct net_device *dev, struct sk_buff *skb,
+                    struct metadata_dst *tun_dst)
+{
+       struct sk_buff *next;
+
+       do {
+               next = skb->next;
+               skb->next = NULL;
+               if (next) {
+                       ovs_dst_hold((struct dst_entry *)tun_dst);
+                       ovs_skb_dst_set(next, (struct dst_entry *)tun_dst);
+               }
+               ovs_ip_tunnel_rcv(dev, skb, tun_dst);
+       } while ((skb = next));
+}
+
+#ifndef HAVE_METADATA_DST
+static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
+{
+       struct metadata_dst tun_dst;
+
+       ovs_ip_tun_rx_dst(&tun_dst.u.tun_info, skb, TUNNEL_KEY | TUNNEL_CSUM,
+                         get_unaligned(&stt_hdr(skb)->key), 0);
+       tun_dst.u.tun_info.key.tp_src = tcp_hdr(skb)->source;
+       tun_dst.u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
+
+       rcv_list(stt_dev->dev, skb, &tun_dst);
+       return 0;
+}
+#else
+static int __stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
+{
+       struct metadata_dst *tun_dst;
+       __be16 flags;
+       __be64 tun_id;
+
+       flags = TUNNEL_KEY | TUNNEL_CSUM;
+       tun_id = get_unaligned(&stt_hdr(skb)->key);
+       tun_dst = ip_tun_rx_dst(skb, flags, tun_id, 0);
+       if (!tun_dst)
+               return -ENOMEM;
+       tun_dst->u.tun_info.key.tp_src = tcp_hdr(skb)->source;
+       tun_dst->u.tun_info.key.tp_dst = tcp_hdr(skb)->dest;
+
+       rcv_list(stt_dev->dev, skb, tun_dst);
+       return 0;
+}
+#endif
+
+static void stt_rcv(struct stt_dev *stt_dev, struct sk_buff *skb)
 {
        int err;
 
@@ -1240,17 +1364,20 @@ static void stt_rcv(struct stt_sock *stt_sock, struct sk_buff *skb)
        if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
                goto drop;
 
-       stt_sock->rcv(stt_sock, skb);
+       err = __stt_rcv(stt_dev, skb);
+       if (err)
+               goto drop;
        return;
 drop:
        /* Consume bad packet */
        kfree_skb_list(skb);
+       stt_dev->dev->stats.rx_errors++;
 }
 
 static void tcp_sock_release(struct socket *sock)
 {
        kernel_sock_shutdown(sock, SHUT_RDWR);
-       sk_release_kernel(sock->sk);
+       sock_release(sock);
 }
 
 static int tcp_sock_create4(struct net *net, __be16 port,
@@ -1260,12 +1387,10 @@ static int tcp_sock_create4(struct net *net, __be16 port,
        struct socket *sock = NULL;
        int err;
 
-       err = sock_create_kern(AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
+       err = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
        if (err < 0)
                goto error;
 
-       sk_change_net(sock->sk, net);
-
        memset(&tcp_addr, 0, sizeof(tcp_addr));
        tcp_addr.sin_family = AF_INET;
        tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
@@ -1319,18 +1444,32 @@ static void clean_percpu(struct work_struct *work)
 }
 
 #ifdef HAVE_NF_HOOKFN_ARG_OPS
-#define FIRST_PARAM const struct nf_hook_ops *ops,
+#define FIRST_PARAM const struct nf_hook_ops *ops
 #else
-#define FIRST_PARAM unsigned int hooknum,
+#define FIRST_PARAM unsigned int hooknum
 #endif
 
-static unsigned int nf_ip_hook(FIRST_PARAM
-                              struct sk_buff *skb,
-                              const struct net_device *in,
-                              const struct net_device *out,
-                              int (*okfn)(struct sk_buff *))
+#ifdef HAVE_NF_HOOK_STATE
+#if RHEL_RELEASE_CODE > RHEL_RELEASE_VERSION(7,0)
+/* RHEL nfhook hacks. */
+#ifndef __GENKSYMS__
+#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
+                  const struct nf_hook_state *state
+#else
+#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
+                  int (*okfn)(struct sk_buff *)
+#endif
+#else
+#define LAST_PARAM const struct nf_hook_state *state
+#endif
+#else
+#define LAST_PARAM const struct net_device *in, const struct net_device *out, \
+                  int (*okfn)(struct sk_buff *)
+#endif
+
+static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
 {
-       struct stt_sock *stt_sock;
+       struct stt_dev *stt_dev;
        int ip_hdr_len;
 
        if (ip_hdr(skb)->protocol != IPPROTO_TCP)
@@ -1342,12 +1481,12 @@ static unsigned int nf_ip_hook(FIRST_PARAM
 
        skb_set_transport_header(skb, ip_hdr_len);
 
-       stt_sock = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
-       if (!stt_sock)
+       stt_dev = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
+       if (!stt_dev)
                return NF_ACCEPT;
 
        __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
-       stt_rcv(stt_sock, skb);
+       stt_rcv(stt_dev, skb);
        return NF_STOLEN;
 }
 
@@ -1359,8 +1498,9 @@ static struct nf_hook_ops nf_hook_ops __read_mostly = {
        .priority       = INT_MAX,
 };
 
-static int stt_start(void)
+static int stt_start(struct net *net)
 {
+       struct stt_net *sn = net_generic(net, stt_net_id);
        int err;
        int i;
 
@@ -1399,12 +1539,25 @@ static int stt_start(void)
                if (err)
                        goto free_percpu;
        }
+       schedule_clean_percpu();
+       n_tunnels++;
+
+       if (sn->n_tunnels) {
+               sn->n_tunnels++;
+               return 0;
+       }
+#ifdef HAVE_NF_REGISTER_NET_HOOK
+       /* On kernel which support per net nf-hook, nf_register_hook() takes
+        * rtnl-lock, which results in dead lock in stt-dev-create. Therefore
+        * use this new API.
+        */
+       err = nf_register_net_hook(net, &nf_hook_ops);
+#else
        err = nf_register_hook(&nf_hook_ops);
+#endif
        if (err)
                goto free_percpu;
-
-       schedule_clean_percpu();
-       n_tunnels++;
+       sn->n_tunnels++;
        return 0;
 
 free_percpu:
@@ -1421,17 +1574,26 @@ error:
        return err;
 }
 
-static void stt_cleanup(void)
+static void stt_cleanup(struct net *net)
 {
+       struct stt_net *sn = net_generic(net, stt_net_id);
        int i;
 
+       sn->n_tunnels--;
+       if (sn->n_tunnels)
+               goto out;
+#ifdef HAVE_NF_REGISTER_NET_HOOK
+       nf_unregister_net_hook(net, &nf_hook_ops);
+#else
+       nf_unregister_hook(&nf_hook_ops);
+#endif
+
+out:
        n_tunnels--;
        if (n_tunnels)
                return;
 
        cancel_delayed_work_sync(&clean_percpu_wq);
-       nf_unregister_hook(&nf_hook_ops);
-
        for_each_possible_cpu(i) {
                struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
                int j;
@@ -1449,102 +1611,306 @@ static void stt_cleanup(void)
        free_percpu(stt_percpu_data);
 }
 
-static struct stt_sock *stt_socket_create(struct net *net, __be16 port,
-                                         stt_rcv_t *rcv, void *data)
+static netdev_tx_t stt_dev_xmit(struct sk_buff *skb, struct net_device *dev)
 {
-       struct stt_net *sn = net_generic(net, stt_net_id);
-       struct stt_sock *stt_sock;
-       struct socket *sock;
+#ifdef HAVE_METADATA_DST
+       return ovs_stt_xmit(skb);
+#else
+       /* Drop All packets coming from networking stack. OVS-CB is
+        * not initialized for these packets.
+        */
+       dev_kfree_skb(skb);
+       dev->stats.tx_dropped++;
+       return NETDEV_TX_OK;
+#endif
+}
+
+/* Setup stats when device is created */
+static int stt_init(struct net_device *dev)
+{
+       dev->tstats = (typeof(dev->tstats)) netdev_alloc_pcpu_stats(struct pcpu_sw_netstats);
+       if (!dev->tstats)
+               return -ENOMEM;
+
+       return 0;
+}
+
+static void stt_uninit(struct net_device *dev)
+{
+       free_percpu(dev->tstats);
+}
+
+static int stt_open(struct net_device *dev)
+{
+       struct stt_dev *stt = netdev_priv(dev);
+       struct net *net = stt->net;
        int err;
 
-       stt_sock = kzalloc(sizeof(*stt_sock), GFP_KERNEL);
-       if (!stt_sock)
-               return ERR_PTR(-ENOMEM);
+       err = stt_start(net);
+       if (err)
+               return err;
+
+       err = tcp_sock_create4(net, stt->dst_port, &stt->sock);
+       if (err)
+               return err;
+       return 0;
+}
 
-       err = tcp_sock_create4(net, port, &sock);
-       if (err) {
-               kfree(stt_sock);
-               return ERR_PTR(err);
-       }
+static int stt_stop(struct net_device *dev)
+{
+       struct stt_dev *stt_dev = netdev_priv(dev);
+       struct net *net = stt_dev->net;
+
+       tcp_sock_release(stt_dev->sock);
+       stt_dev->sock = NULL;
+       stt_cleanup(net);
+       return 0;
+}
+
+static const struct net_device_ops stt_netdev_ops = {
+       .ndo_init               = stt_init,
+       .ndo_uninit             = stt_uninit,
+       .ndo_open               = stt_open,
+       .ndo_stop               = stt_stop,
+       .ndo_start_xmit         = stt_dev_xmit,
+       .ndo_get_stats64        = ip_tunnel_get_stats64,
+       .ndo_change_mtu         = eth_change_mtu,
+       .ndo_validate_addr      = eth_validate_addr,
+       .ndo_set_mac_address    = eth_mac_addr,
+};
+
+static void stt_get_drvinfo(struct net_device *dev,
+               struct ethtool_drvinfo *drvinfo)
+{
+       strlcpy(drvinfo->version, STT_NETDEV_VER, sizeof(drvinfo->version));
+       strlcpy(drvinfo->driver, "stt", sizeof(drvinfo->driver));
+}
+
+static const struct ethtool_ops stt_ethtool_ops = {
+       .get_drvinfo    = stt_get_drvinfo,
+       .get_link       = ethtool_op_get_link,
+};
+
+/* Info for udev, that this is a virtual tunnel endpoint */
+static struct device_type stt_type = {
+       .name = "stt",
+};
 
-       stt_sock->sock = sock;
-       stt_sock->rcv = rcv;
-       stt_sock->rcv_data = data;
+/* Initialize the device structure. */
+static void stt_setup(struct net_device *dev)
+{
+       ether_setup(dev);
+
+       dev->netdev_ops = &stt_netdev_ops;
+       dev->ethtool_ops = &stt_ethtool_ops;
+       dev->destructor = free_netdev;
+
+       SET_NETDEV_DEVTYPE(dev, &stt_type);
+
+       dev->features    |= NETIF_F_LLTX | NETIF_F_NETNS_LOCAL;
+       dev->features    |= NETIF_F_SG | NETIF_F_HW_CSUM;
+       dev->features    |= NETIF_F_RXCSUM;
+       dev->features    |= NETIF_F_GSO_SOFTWARE;
+
+       dev->hw_features |= NETIF_F_SG | NETIF_F_HW_CSUM | NETIF_F_RXCSUM;
+       dev->hw_features |= NETIF_F_GSO_SOFTWARE;
+
+#ifdef HAVE_METADATA_DST
+       netif_keep_dst(dev);
+#endif
+       dev->priv_flags |= IFF_LIVE_ADDR_CHANGE | IFF_NO_QUEUE;
+       eth_hw_addr_random(dev);
+}
+
+static const struct nla_policy stt_policy[IFLA_STT_MAX + 1] = {
+       [IFLA_STT_PORT]              = { .type = NLA_U16 },
+};
+
+static int stt_validate(struct nlattr *tb[], struct nlattr *data[])
+{
+       if (tb[IFLA_ADDRESS]) {
+               if (nla_len(tb[IFLA_ADDRESS]) != ETH_ALEN)
+                       return -EINVAL;
 
-       list_add_rcu(&stt_sock->list, &sn->sock_list);
+               if (!is_valid_ether_addr(nla_data(tb[IFLA_ADDRESS])))
+                       return -EADDRNOTAVAIL;
+       }
 
-       return stt_sock;
+       return 0;
 }
 
-static void __stt_sock_release(struct stt_sock *stt_sock)
+static struct stt_dev *find_dev(struct net *net, __be16 dst_port)
 {
-       list_del_rcu(&stt_sock->list);
-       tcp_sock_release(stt_sock->sock);
-       kfree_rcu(stt_sock, rcu);
+       struct stt_net *sn = net_generic(net, stt_net_id);
+       struct stt_dev *dev;
+
+       list_for_each_entry(dev, &sn->stt_list, next) {
+               if (dev->dst_port == dst_port)
+                       return dev;
+       }
+       return NULL;
 }
 
-struct stt_sock *rpl_stt_sock_add(struct net *net, __be16 port,
-                             stt_rcv_t *rcv, void *data)
+static int stt_configure(struct net *net, struct net_device *dev,
+                         __be16 dst_port)
 {
-       struct stt_sock *stt_sock;
+       struct stt_net *sn = net_generic(net, stt_net_id);
+       struct stt_dev *stt = netdev_priv(dev);
        int err;
 
-       err = stt_start();
+       stt->net = net;
+       stt->dev = dev;
+
+       stt->dst_port = dst_port;
+
+       if (find_dev(net, dst_port))
+               return -EBUSY;
+
+       err = register_netdevice(dev);
        if (err)
-               return ERR_PTR(err);
+               return err;
 
-       mutex_lock(&stt_mutex);
-       rcu_read_lock();
-       stt_sock = stt_find_sock(net, port);
-       rcu_read_unlock();
-       if (stt_sock)
-               stt_sock = ERR_PTR(-EBUSY);
-       else
-               stt_sock = stt_socket_create(net, port, rcv, data);
+       list_add_rcu(&stt->next, &sn->stt_list);
+       return 0;
+}
+
+static int stt_newlink(struct net *net, struct net_device *dev,
+               struct nlattr *tb[], struct nlattr *data[])
+{
+       __be16 dst_port = htons(STT_DST_PORT);
+
+       if (data[IFLA_STT_PORT])
+               dst_port = nla_get_be16(data[IFLA_STT_PORT]);
+
+       return stt_configure(net, dev, dst_port);
+}
 
-       mutex_unlock(&stt_mutex);
+static void stt_dellink(struct net_device *dev, struct list_head *head)
+{
+       struct stt_dev *stt = netdev_priv(dev);
 
-       if (IS_ERR(stt_sock))
-               stt_cleanup();
+       list_del_rcu(&stt->next);
+       unregister_netdevice_queue(dev, head);
+}
 
-       return stt_sock;
+static size_t stt_get_size(const struct net_device *dev)
+{
+       return nla_total_size(sizeof(__be32));  /* IFLA_STT_PORT */
 }
-EXPORT_SYMBOL_GPL(rpl_stt_sock_add);
 
-void rpl_stt_sock_release(struct stt_sock *stt_sock)
+static int stt_fill_info(struct sk_buff *skb, const struct net_device *dev)
 {
-       mutex_lock(&stt_mutex);
-       if (stt_sock) {
-               __stt_sock_release(stt_sock);
-               stt_cleanup();
+       struct stt_dev *stt = netdev_priv(dev);
+
+       if (nla_put_be16(skb, IFLA_STT_PORT, stt->dst_port))
+               goto nla_put_failure;
+
+       return 0;
+
+nla_put_failure:
+       return -EMSGSIZE;
+}
+
+static struct rtnl_link_ops stt_link_ops __read_mostly = {
+       .kind           = "stt",
+       .maxtype        = IFLA_STT_MAX,
+       .policy         = stt_policy,
+       .priv_size      = sizeof(struct stt_dev),
+       .setup          = stt_setup,
+       .validate       = stt_validate,
+       .newlink        = stt_newlink,
+       .dellink        = stt_dellink,
+       .get_size       = stt_get_size,
+       .fill_info      = stt_fill_info,
+};
+
+struct net_device *ovs_stt_dev_create_fb(struct net *net, const char *name,
+                                     u8 name_assign_type, u16 dst_port)
+{
+       struct nlattr *tb[IFLA_MAX + 1];
+       struct net_device *dev;
+       int err;
+
+       memset(tb, 0, sizeof(tb));
+       dev = rtnl_create_link(net, (char *) name, name_assign_type,
+                       &stt_link_ops, tb);
+       if (IS_ERR(dev))
+               return dev;
+
+       err = stt_configure(net, dev, htons(dst_port));
+       if (err) {
+               free_netdev(dev);
+               return ERR_PTR(err);
        }
-       mutex_unlock(&stt_mutex);
+       return dev;
 }
-EXPORT_SYMBOL_GPL(rpl_stt_sock_release);
+EXPORT_SYMBOL_GPL(ovs_stt_dev_create_fb);
 
 static int stt_init_net(struct net *net)
 {
        struct stt_net *sn = net_generic(net, stt_net_id);
 
-       INIT_LIST_HEAD(&sn->sock_list);
+       INIT_LIST_HEAD(&sn->stt_list);
        return 0;
 }
 
+static void stt_exit_net(struct net *net)
+{
+       struct stt_net *sn = net_generic(net, stt_net_id);
+       struct stt_dev *stt, *next;
+       struct net_device *dev, *aux;
+       LIST_HEAD(list);
+
+       rtnl_lock();
+
+       /* gather any stt devices that were moved into this ns */
+       for_each_netdev_safe(net, dev, aux)
+               if (dev->rtnl_link_ops == &stt_link_ops)
+                       unregister_netdevice_queue(dev, &list);
+
+       list_for_each_entry_safe(stt, next, &sn->stt_list, next) {
+               /* If stt->dev is in the same netns, it was already added
+                * to the stt by the previous loop.
+                */
+               if (!net_eq(dev_net(stt->dev), net))
+                       unregister_netdevice_queue(stt->dev, &list);
+       }
+
+       /* unregister the devices gathered above */
+       unregister_netdevice_many(&list);
+       rtnl_unlock();
+}
+
 static struct pernet_operations stt_net_ops = {
        .init = stt_init_net,
+       .exit = stt_exit_net,
        .id   = &stt_net_id,
        .size = sizeof(struct stt_net),
 };
 
-int ovs_stt_init_module(void)
+int stt_init_module(void)
 {
-       return register_pernet_subsys(&stt_net_ops);
+       int rc;
+
+       rc = register_pernet_subsys(&stt_net_ops);
+       if (rc)
+               goto out1;
+
+       rc = rtnl_link_register(&stt_link_ops);
+       if (rc)
+               goto out2;
+
+       pr_info("STT tunneling driver\n");
+       return 0;
+out2:
+       unregister_pernet_subsys(&stt_net_ops);
+out1:
+       return rc;
 }
-EXPORT_SYMBOL_GPL(ovs_stt_init_module);
 
-void ovs_stt_cleanup_module(void)
+void stt_cleanup_module(void)
 {
+       rtnl_link_unregister(&stt_link_ops);
        unregister_pernet_subsys(&stt_net_ops);
 }
-EXPORT_SYMBOL_GPL(ovs_stt_cleanup_module);
 #endif