bpf: try harder on clones when writing into skb
[cascardo/linux.git] / net / core / filter.c
index 94d2620..f031b82 100644 (file)
@@ -1181,7 +1181,7 @@ static int __reuseport_attach_prog(struct bpf_prog *prog, struct sock *sk)
        if (bpf_prog_size(prog->len) > sysctl_optmem_max)
                return -ENOMEM;
 
-       if (sk_unhashed(sk)) {
+       if (sk_unhashed(sk) && sk->sk_reuseport) {
                err = reuseport_alloc(sk);
                if (err)
                        return err;
@@ -1333,15 +1333,22 @@ int sk_reuseport_attach_bpf(u32 ufd, struct sock *sk)
        return 0;
 }
 
-#define BPF_LDST_LEN 16U
+struct bpf_scratchpad {
+       union {
+               __be32 diff[MAX_BPF_STACK / sizeof(__be32)];
+               u8     buff[MAX_BPF_STACK];
+       };
+};
+
+static DEFINE_PER_CPU(struct bpf_scratchpad, bpf_sp);
 
 static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
 {
+       struct bpf_scratchpad *sp = this_cpu_ptr(&bpf_sp);
        struct sk_buff *skb = (struct sk_buff *) (long) r1;
        int offset = (int) r2;
        void *from = (void *) (long) r3;
        unsigned int len = (unsigned int) r4;
-       char buf[BPF_LDST_LEN];
        void *ptr;
 
        if (unlikely(flags & ~(BPF_F_RECOMPUTE_CSUM)))
@@ -1355,14 +1362,12 @@ static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
         *
         * so check for invalid 'offset' and too large 'len'
         */
-       if (unlikely((u32) offset > 0xffff || len > sizeof(buf)))
+       if (unlikely((u32) offset > 0xffff || len > sizeof(sp->buff)))
                return -EFAULT;
-
-       if (unlikely(skb_cloned(skb) &&
-                    !skb_clone_writable(skb, offset + len)))
+       if (unlikely(skb_try_make_writable(skb, offset + len)))
                return -EFAULT;
 
-       ptr = skb_header_pointer(skb, offset, len, buf);
+       ptr = skb_header_pointer(skb, offset, len, sp->buff);
        if (unlikely(!ptr))
                return -EFAULT;
 
@@ -1371,7 +1376,7 @@ static u64 bpf_skb_store_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 flags)
 
        memcpy(ptr, from, len);
 
-       if (ptr == buf)
+       if (ptr == sp->buff)
                /* skb_store_bits cannot return -EFAULT here */
                skb_store_bits(skb, offset, ptr, len);
 
@@ -1400,7 +1405,7 @@ static u64 bpf_skb_load_bytes(u64 r1, u64 r2, u64 r3, u64 r4, u64 r5)
        unsigned int len = (unsigned int) r4;
        void *ptr;
 
-       if (unlikely((u32) offset > 0xffff || len > BPF_LDST_LEN))
+       if (unlikely((u32) offset > 0xffff || len > MAX_BPF_STACK))
                return -EFAULT;
 
        ptr = skb_header_pointer(skb, offset, len, to);
@@ -1432,9 +1437,7 @@ static u64 bpf_l3_csum_replace(u64 r1, u64 r2, u64 from, u64 to, u64 flags)
                return -EINVAL;
        if (unlikely((u32) offset > 0xffff))
                return -EFAULT;
-
-       if (unlikely(skb_cloned(skb) &&
-                    !skb_clone_writable(skb, offset + sizeof(sum))))
+       if (unlikely(skb_try_make_writable(skb, offset + sizeof(sum))))
                return -EFAULT;
 
        ptr = skb_header_pointer(skb, offset, sizeof(sum), &sum);
@@ -1481,9 +1484,7 @@ static u64 bpf_l4_csum_replace(u64 r1, u64 r2, u64 from, u64 to, u64 flags)
                return -EINVAL;
        if (unlikely((u32) offset > 0xffff))
                return -EFAULT;
-
-       if (unlikely(skb_cloned(skb) &&
-                    !skb_clone_writable(skb, offset + sizeof(sum))))
+       if (unlikely(skb_try_make_writable(skb, offset + sizeof(sum))))
                return -EFAULT;
 
        ptr = skb_header_pointer(skb, offset, sizeof(sum), &sum);
@@ -1491,6 +1492,12 @@ static u64 bpf_l4_csum_replace(u64 r1, u64 r2, u64 from, u64 to, u64 flags)
                return -EFAULT;
 
        switch (flags & BPF_F_HDR_FIELD_MASK) {
+       case 0:
+               if (unlikely(from != 0))
+                       return -EINVAL;
+
+               inet_proto_csum_replace_by_diff(ptr, skb, to, is_pseudo);
+               break;
        case 2:
                inet_proto_csum_replace2(ptr, skb, from, to, is_pseudo);
                break;
@@ -1519,6 +1526,45 @@ const struct bpf_func_proto bpf_l4_csum_replace_proto = {
        .arg5_type      = ARG_ANYTHING,
 };
 
+static u64 bpf_csum_diff(u64 r1, u64 from_size, u64 r3, u64 to_size, u64 seed)
+{
+       struct bpf_scratchpad *sp = this_cpu_ptr(&bpf_sp);
+       u64 diff_size = from_size + to_size;
+       __be32 *from = (__be32 *) (long) r1;
+       __be32 *to   = (__be32 *) (long) r3;
+       int i, j = 0;
+
+       /* This is quite flexible, some examples:
+        *
+        * from_size == 0, to_size > 0,  seed := csum --> pushing data
+        * from_size > 0,  to_size == 0, seed := csum --> pulling data
+        * from_size > 0,  to_size > 0,  seed := 0    --> diffing data
+        *
+        * Even for diffing, from_size and to_size don't need to be equal.
+        */
+       if (unlikely(((from_size | to_size) & (sizeof(__be32) - 1)) ||
+                    diff_size > sizeof(sp->diff)))
+               return -EINVAL;
+
+       for (i = 0; i < from_size / sizeof(__be32); i++, j++)
+               sp->diff[j] = ~from[i];
+       for (i = 0; i <   to_size / sizeof(__be32); i++, j++)
+               sp->diff[j] = to[i];
+
+       return csum_partial(sp->diff, diff_size, seed);
+}
+
+const struct bpf_func_proto bpf_csum_diff_proto = {
+       .func           = bpf_csum_diff,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_STACK,
+       .arg2_type      = ARG_CONST_STACK_SIZE_OR_ZERO,
+       .arg3_type      = ARG_PTR_TO_STACK,
+       .arg4_type      = ARG_CONST_STACK_SIZE_OR_ZERO,
+       .arg5_type      = ARG_ANYTHING,
+};
+
 static u64 bpf_clone_redirect(u64 r1, u64 ifindex, u64 flags, u64 r4, u64 r5)
 {
        struct sk_buff *skb = (struct sk_buff *) (long) r1, *skb2;
@@ -1682,6 +1728,13 @@ bool bpf_helper_changes_skb_data(void *func)
                return true;
        if (func == bpf_skb_vlan_pop)
                return true;
+       if (func == bpf_skb_store_bytes)
+               return true;
+       if (func == bpf_l3_csum_replace)
+               return true;
+       if (func == bpf_l4_csum_replace)
+               return true;
+
        return false;
 }
 
@@ -1849,6 +1902,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id)
                return &bpf_skb_store_bytes_proto;
        case BPF_FUNC_skb_load_bytes:
                return &bpf_skb_load_bytes_proto;
+       case BPF_FUNC_csum_diff:
+               return &bpf_csum_diff_proto;
        case BPF_FUNC_l3_csum_replace:
                return &bpf_l3_csum_replace_proto;
        case BPF_FUNC_l4_csum_replace: