tun: Don't assume type tun in tun_device_event
[cascardo/linux.git] / drivers / net / tun.c
index e16487c..9c8b5bc 100644 (file)
@@ -71,6 +71,7 @@
 #include <net/sock.h>
 #include <linux/seq_file.h>
 #include <linux/uio.h>
+#include <linux/skb_array.h>
 
 #include <asm/uaccess.h>
 
@@ -167,6 +168,7 @@ struct tun_file {
        };
        struct list_head next;
        struct tun_struct *detached;
+       struct skb_array tx_array;
 };
 
 struct tun_flow_entry {
@@ -515,7 +517,11 @@ static struct tun_struct *tun_enable_queue(struct tun_file *tfile)
 
 static void tun_queue_purge(struct tun_file *tfile)
 {
-       skb_queue_purge(&tfile->sk.sk_receive_queue);
+       struct sk_buff *skb;
+
+       while ((skb = skb_array_consume(&tfile->tx_array)) != NULL)
+               kfree_skb(skb);
+
        skb_queue_purge(&tfile->sk.sk_error_queue);
 }
 
@@ -560,6 +566,8 @@ static void __tun_detach(struct tun_file *tfile, bool clean)
                            tun->dev->reg_state == NETREG_REGISTERED)
                                unregister_netdevice(tun->dev);
                }
+               if (tun)
+                       skb_array_cleanup(&tfile->tx_array);
                sock_put(&tfile->sk);
        }
 }
@@ -613,6 +621,7 @@ static void tun_detach_all(struct net_device *dev)
 static int tun_attach(struct tun_struct *tun, struct file *file, bool skip_filter)
 {
        struct tun_file *tfile = file->private_data;
+       struct net_device *dev = tun->dev;
        int err;
 
        err = security_tun_dev_attach(tfile->socket.sk, tun->security);
@@ -642,6 +651,13 @@ static int tun_attach(struct tun_struct *tun, struct file *file, bool skip_filte
                if (!err)
                        goto out;
        }
+
+       if (!tfile->detached &&
+           skb_array_init(&tfile->tx_array, dev->tx_queue_len, GFP_KERNEL)) {
+               err = -ENOMEM;
+               goto out;
+       }
+
        tfile->queue_index = tun->numqueues;
        tfile->socket.sk->sk_shutdown &= ~RCV_SHUTDOWN;
        rcu_assign_pointer(tfile->tun, tun);
@@ -891,8 +907,8 @@ static netdev_tx_t tun_net_xmit(struct sk_buff *skb, struct net_device *dev)
 
        nf_reset(skb);
 
-       /* Enqueue packet */
-       skb_queue_tail(&tfile->socket.sk->sk_receive_queue, skb);
+       if (skb_array_produce(&tfile->tx_array, skb))
+               goto drop;
 
        /* Notify and wake up reader process */
        if (tfile->flags & TUN_FASYNC)
@@ -1107,7 +1123,7 @@ static unsigned int tun_chr_poll(struct file *file, poll_table *wait)
 
        poll_wait(file, sk_sleep(sk), wait);
 
-       if (!skb_queue_empty(&sk->sk_receive_queue))
+       if (!skb_array_empty(&tfile->tx_array))
                mask |= POLLIN | POLLRDNORM;
 
        if (sock_writeable(sk) ||
@@ -1254,13 +1270,11 @@ static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
                return -EFAULT;
        }
 
-       if (gso.flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) {
-               if (!skb_partial_csum_set(skb, tun16_to_cpu(tun, gso.csum_start),
-                                         tun16_to_cpu(tun, gso.csum_offset))) {
-                       this_cpu_inc(tun->pcpu_stats->rx_frame_errors);
-                       kfree_skb(skb);
-                       return -EINVAL;
-               }
+       err = virtio_net_hdr_to_skb(skb, &gso, tun_is_little_endian(tun));
+       if (err) {
+               this_cpu_inc(tun->pcpu_stats->rx_frame_errors);
+               kfree_skb(skb);
+               return -EINVAL;
        }
 
        switch (tun->flags & TUN_TYPE_MASK) {
@@ -1289,39 +1303,6 @@ static ssize_t tun_get_user(struct tun_struct *tun, struct tun_file *tfile,
                break;
        }
 
-       if (gso.gso_type != VIRTIO_NET_HDR_GSO_NONE) {
-               pr_debug("GSO!\n");
-               switch (gso.gso_type & ~VIRTIO_NET_HDR_GSO_ECN) {
-               case VIRTIO_NET_HDR_GSO_TCPV4:
-                       skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
-                       break;
-               case VIRTIO_NET_HDR_GSO_TCPV6:
-                       skb_shinfo(skb)->gso_type = SKB_GSO_TCPV6;
-                       break;
-               case VIRTIO_NET_HDR_GSO_UDP:
-                       skb_shinfo(skb)->gso_type = SKB_GSO_UDP;
-                       break;
-               default:
-                       this_cpu_inc(tun->pcpu_stats->rx_frame_errors);
-                       kfree_skb(skb);
-                       return -EINVAL;
-               }
-
-               if (gso.gso_type & VIRTIO_NET_HDR_GSO_ECN)
-                       skb_shinfo(skb)->gso_type |= SKB_GSO_TCP_ECN;
-
-               skb_shinfo(skb)->gso_size = tun16_to_cpu(tun, gso.gso_size);
-               if (skb_shinfo(skb)->gso_size == 0) {
-                       this_cpu_inc(tun->pcpu_stats->rx_frame_errors);
-                       kfree_skb(skb);
-                       return -EINVAL;
-               }
-
-               /* Header must be checked, and gso_segs computed. */
-               skb_shinfo(skb)->gso_type |= SKB_GSO_DODGY;
-               skb_shinfo(skb)->gso_segs = 0;
-       }
-
        /* copy skb_ubuf_info for callback when skb has no error */
        if (zerocopy) {
                skb_shinfo(skb)->destructor_arg = msg_control;
@@ -1399,46 +1380,26 @@ static ssize_t tun_put_user(struct tun_struct *tun,
 
        if (vnet_hdr_sz) {
                struct virtio_net_hdr gso = { 0 }; /* no info leak */
+               int ret;
+
                if (iov_iter_count(iter) < vnet_hdr_sz)
                        return -EINVAL;
 
-               if (skb_is_gso(skb)) {
+               ret = virtio_net_hdr_from_skb(skb, &gso,
+                                             tun_is_little_endian(tun));
+               if (ret) {
                        struct skb_shared_info *sinfo = skb_shinfo(skb);
-
-                       /* This is a hint as to how much should be linear. */
-                       gso.hdr_len = cpu_to_tun16(tun, skb_headlen(skb));
-                       gso.gso_size = cpu_to_tun16(tun, sinfo->gso_size);
-                       if (sinfo->gso_type & SKB_GSO_TCPV4)
-                               gso.gso_type = VIRTIO_NET_HDR_GSO_TCPV4;
-                       else if (sinfo->gso_type & SKB_GSO_TCPV6)
-                               gso.gso_type = VIRTIO_NET_HDR_GSO_TCPV6;
-                       else if (sinfo->gso_type & SKB_GSO_UDP)
-                               gso.gso_type = VIRTIO_NET_HDR_GSO_UDP;
-                       else {
-                               pr_err("unexpected GSO type: "
-                                      "0x%x, gso_size %d, hdr_len %d\n",
-                                      sinfo->gso_type, tun16_to_cpu(tun, gso.gso_size),
-                                      tun16_to_cpu(tun, gso.hdr_len));
-                               print_hex_dump(KERN_ERR, "tun: ",
-                                              DUMP_PREFIX_NONE,
-                                              16, 1, skb->head,
-                                              min((int)tun16_to_cpu(tun, gso.hdr_len), 64), true);
-                               WARN_ON_ONCE(1);
-                               return -EINVAL;
-                       }
-                       if (sinfo->gso_type & SKB_GSO_TCP_ECN)
-                               gso.gso_type |= VIRTIO_NET_HDR_GSO_ECN;
-               } else
-                       gso.gso_type = VIRTIO_NET_HDR_GSO_NONE;
-
-               if (skb->ip_summed == CHECKSUM_PARTIAL) {
-                       gso.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;
-                       gso.csum_start = cpu_to_tun16(tun, skb_checksum_start_offset(skb) +
-                                                     vlan_hlen);
-                       gso.csum_offset = cpu_to_tun16(tun, skb->csum_offset);
-               } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
-                       gso.flags = VIRTIO_NET_HDR_F_DATA_VALID;
-               } /* else everything is zero */
+                       pr_err("unexpected GSO type: "
+                              "0x%x, gso_size %d, hdr_len %d\n",
+                              sinfo->gso_type, tun16_to_cpu(tun, gso.gso_size),
+                              tun16_to_cpu(tun, gso.hdr_len));
+                       print_hex_dump(KERN_ERR, "tun: ",
+                                      DUMP_PREFIX_NONE,
+                                      16, 1, skb->head,
+                                      min((int)tun16_to_cpu(tun, gso.hdr_len), 64), true);
+                       WARN_ON_ONCE(1);
+                       return -EINVAL;
+               }
 
                if (copy_to_iter(&gso, sizeof(gso), iter) != sizeof(gso))
                        return -EFAULT;
@@ -1481,22 +1442,63 @@ done:
        return total;
 }
 
+/*
+ * Take one skb off the file's tx ring, sleeping until one shows up
+ * unless @noblock is set.
+ *
+ * Returns the skb with *@err set to 0, or NULL with *@err set to:
+ *   -EAGAIN       ring empty and @noblock set
+ *   -ERESTARTSYS  a signal is pending for the current task
+ *   -EFAULT       RCV_SHUTDOWN is set on the file's socket
+ */
+static struct sk_buff *tun_ring_recv(struct tun_file *tfile, int noblock,
+                                    int *err)
+{
+       DECLARE_WAITQUEUE(wait, current);
+       struct sk_buff *skb = NULL;
+       int error = 0;
+
+       /* Fast path: something is already queued, no need to wait. */
+       skb = skb_array_consume(&tfile->tx_array);
+       if (skb)
+               goto out;
+       if (noblock) {
+               error = -EAGAIN;
+               goto out;
+       }
+
+       add_wait_queue(&tfile->wq.wait, &wait);
+
+       while (1) {
+               /*
+                * Re-arm the sleep state on every pass.  schedule() returns
+                * with the task in TASK_RUNNING, so setting the state only
+                * once before the loop would turn every iteration after the
+                * first wakeup into a busy loop.  Setting it before the ring
+                * check also means a wakeup arriving between the check and
+                * schedule() is not lost.
+                */
+               set_current_state(TASK_INTERRUPTIBLE);
+               skb = skb_array_consume(&tfile->tx_array);
+               if (skb)
+                       break;
+               if (signal_pending(current)) {
+                       error = -ERESTARTSYS;
+                       break;
+               }
+               if (tfile->socket.sk->sk_shutdown & RCV_SHUTDOWN) {
+                       error = -EFAULT;
+                       break;
+               }
+
+               schedule();
+       }
+
+       __set_current_state(TASK_RUNNING);
+       remove_wait_queue(&tfile->wq.wait, &wait);
+
+out:
+       *err = error;
+       return skb;
+}
+
 static ssize_t tun_do_read(struct tun_struct *tun, struct tun_file *tfile,
                           struct iov_iter *to,
                           int noblock)
 {
        struct sk_buff *skb;
        ssize_t ret;
-       int peeked, err, off = 0;
+       int err;
 
        tun_debug(KERN_INFO, tun, "tun_do_read\n");
 
        if (!iov_iter_count(to))
                return 0;
 
-       /* Read frames from queue */
-       skb = __skb_recv_datagram(tfile->socket.sk, noblock ? MSG_DONTWAIT : 0,
-                                 &peeked, &off, &err);
+       /* Read frames from ring */
+       skb = tun_ring_recv(tfile, noblock, &err);
        if (!skb)
                return err;
 
@@ -1629,8 +1631,25 @@ out:
        return ret;
 }
 
+/*
+ * ->peek_len hook for the tun socket: returns skb_array_peek_len() of the
+ * file's tx ring, or 0 when __tun_get() yields no tun for this file.
+ */
+static int tun_peek_len(struct socket *sock)
+{
+       struct tun_file *tfile = container_of(sock, struct tun_file, socket);
+       struct tun_struct *tun = __tun_get(tfile);
+       int len;
+
+       if (!tun)
+               return 0;
+
+       len = skb_array_peek_len(&tfile->tx_array);
+       /* Drop the reference taken by __tun_get() above. */
+       tun_put(tun);
+
+       return len;
+}
+
 /* Ops structure to mimic raw sockets with tun */
 static const struct proto_ops tun_socket_ops = {
+       .peek_len = tun_peek_len,       /* skb_array_peek_len() of the tx ring */
        .sendmsg = tun_sendmsg,
        .recvmsg = tun_recvmsg,
 };
@@ -2452,6 +2471,56 @@ static const struct ethtool_ops tun_ethtool_ops = {
        .get_ts_info    = ethtool_op_get_ts_info,
 };
 
+/*
+ * Resize the tx ring of every queue of @tun, enabled and disabled alike,
+ * to the device's current tx_queue_len.
+ *
+ * tun->tfiles[] is read with rtnl_dereference(), so this is expected to
+ * run with RTNL held.
+ *
+ * Returns -ENOMEM if the temporary pointer table cannot be allocated,
+ * otherwise the result of skb_array_resize_multiple().
+ */
+static int tun_queue_resize(struct tun_struct *tun)
+{
+       struct net_device *dev = tun->dev;
+       struct tun_file *tfile;
+       struct skb_array **arrays;
+       int n = tun->numqueues + tun->numdisabled;
+       int ret, i;
+
+       /* kmalloc_array() fails cleanly if n * sizeof(*arrays) would overflow. */
+       arrays = kmalloc_array(n, sizeof(*arrays), GFP_KERNEL);
+       if (!arrays)
+               return -ENOMEM;
+
+       /* Enabled queues first ... */
+       for (i = 0; i < tun->numqueues; i++) {
+               tfile = rtnl_dereference(tun->tfiles[i]);
+               arrays[i] = &tfile->tx_array;
+       }
+       /*
+        * ... then the disabled ones; i carries on from numqueues.
+        * NOTE(review): relies on tun->numdisabled matching the length of
+        * tun->disabled so the table is filled exactly.
+        */
+       list_for_each_entry(tfile, &tun->disabled, next)
+               arrays[i++] = &tfile->tx_array;
+
+       ret = skb_array_resize_multiple(arrays, n,
+                                       dev->tx_queue_len, GFP_KERNEL);
+
+       /* Only the pointer table is temporary; the rings belong to the tfiles. */
+       kfree(arrays);
+       return ret;
+}
+
+/*
+ * Netdevice notifier callback.
+ *
+ * Devices whose rtnl_link_ops is not tun_link_ops are ignored.  For tun
+ * devices a NETDEV_CHANGE_TX_QUEUE_LEN event resizes the tx rings and
+ * yields NOTIFY_BAD if that fails; everything else yields NOTIFY_DONE.
+ */
+static int tun_device_event(struct notifier_block *unused,
+                           unsigned long event, void *ptr)
+{
+       struct net_device *netdev = netdev_notifier_info_to_dev(ptr);
+
+       /* Not one of ours: the private area is not a struct tun_struct. */
+       if (netdev->rtnl_link_ops != &tun_link_ops)
+               return NOTIFY_DONE;
+
+       if (event == NETDEV_CHANGE_TX_QUEUE_LEN &&
+           tun_queue_resize(netdev_priv(netdev)))
+               return NOTIFY_BAD;
+
+       return NOTIFY_DONE;
+}
+
+static struct notifier_block tun_notifier_block __read_mostly = {
+       .notifier_call  = tun_device_event,     /* registered in tun_init(), removed in tun_cleanup() */
+};
 
 static int __init tun_init(void)
 {
@@ -2471,6 +2540,8 @@ static int __init tun_init(void)
                pr_err("Can't register misc device %d\n", TUN_MINOR);
                goto err_misc;
        }
+
+       register_netdevice_notifier(&tun_notifier_block);
        return  0;
 err_misc:
        rtnl_link_unregister(&tun_link_ops);
@@ -2482,6 +2553,7 @@ static void tun_cleanup(void)
 {
        misc_deregister(&tun_miscdev);
        rtnl_link_unregister(&tun_link_ops);
+       unregister_netdevice_notifier(&tun_notifier_block);
 }
 
 /* Get an underlying socket object from tun file.  Returns error unless file is