packet: make packet_snd fail on len smaller than l2 header
[cascardo/linux.git] / net / packet / af_packet.c
index 87d20f4..58af580 100644 (file)
@@ -2095,6 +2095,18 @@ static void tpacket_destruct_skb(struct sk_buff *skb)
        sock_wfree(skb);
 }
 
+static bool ll_header_truncated(const struct net_device *dev, int len)
+{
+       /* net device doesn't like empty head */
+       if (unlikely(len <= dev->hard_header_len)) {
+               net_warn_ratelimited("%s: packet size is too short (%d < %d)\n",
+                                    current->comm, len, dev->hard_header_len);
+               return true;
+       }
+
+       return false;
+}
+
 static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
                void *frame, struct net_device *dev, int size_max,
                __be16 proto, unsigned char *addr, int hlen)
@@ -2170,12 +2182,8 @@ static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
                if (unlikely(err < 0))
                        return -EINVAL;
        } else if (dev->hard_header_len) {
-               /* net device doesn't like empty head */
-               if (unlikely(tp_len <= dev->hard_header_len)) {
-                       pr_err("packet size is too short (%d < %d)\n",
-                              tp_len, dev->hard_header_len);
+               if (ll_header_truncated(dev, tp_len))
                        return -EINVAL;
-               }
 
                skb_push(skb, dev->hard_header_len);
                err = skb_store_bits(skb, 0, data,
@@ -2500,9 +2508,14 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        skb_set_network_header(skb, reserve);
 
        err = -EINVAL;
-       if (sock->type == SOCK_DGRAM &&
-           (offset = dev_hard_header(skb, dev, ntohs(proto), addr, NULL, len)) < 0)
-               goto out_free;
+       if (sock->type == SOCK_DGRAM) {
+               offset = dev_hard_header(skb, dev, ntohs(proto), addr, NULL, len);
+               if (unlikely(offset) < 0)
+                       goto out_free;
+       } else {
+               if (ll_header_truncated(dev, len))
+                       goto out_free;
+       }
 
        /* Returns -EFAULT on error */
        err = skb_copy_datagram_from_iovec(skb, offset, msg->msg_iov, 0, len);
@@ -2953,7 +2966,7 @@ static int packet_recvmsg(struct kiocb *iocb, struct socket *sock,
                msg->msg_flags |= MSG_TRUNC;
        }
 
-       err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, copied);
+       err = skb_copy_datagram_msg(skb, 0, msg, copied);
        if (err)
                goto out_free;