soreuseport: setsockopt SO_ATTACH_REUSEPORT_[CE]BPF
[cascardo/linux.git] / net / core / sock_reuseport.c
index 963c8d5..ae0969c 100644 (file)
@@ -1,10 +1,12 @@
 /*
  * To speed up listener socket lookup, create an array to store all sockets
  * listening on the same port.  This allows a decision to be made after finding
- * the first socket.
+ * the first socket.  An optional BPF program can also be configured for
+ * selecting the socket index from the array of available sockets.
  */
 
 #include <net/sock_reuseport.h>
+#include <linux/bpf.h>
 #include <linux/rcupdate.h>
 
 #define INIT_SOCKS 128
@@ -22,6 +24,7 @@ static struct sock_reuseport *__reuseport_alloc(u16 max_socks)
 
        reuse->max_socks = max_socks;
 
+       RCU_INIT_POINTER(reuse->prog, NULL);
        return reuse;
 }
 
@@ -67,6 +70,7 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
 
        more_reuse->max_socks = more_socks_size;
        more_reuse->num_socks = reuse->num_socks;
+       more_reuse->prog = reuse->prog;
 
        memcpy(more_reuse->socks, reuse->socks,
               reuse->num_socks * sizeof(struct sock *));
@@ -75,6 +79,10 @@ static struct sock_reuseport *reuseport_grow(struct sock_reuseport *reuse)
                rcu_assign_pointer(reuse->socks[i]->sk_reuseport_cb,
                                   more_reuse);
 
+       /* Note: we use kfree_rcu here instead of reuseport_free_rcu so
+        * that reuse and more_reuse can temporarily share a reference
+        * to prog.
+        */
        kfree_rcu(reuse, rcu);
        return more_reuse;
 }
@@ -116,6 +124,16 @@ int reuseport_add_sock(struct sock *sk, const struct sock *sk2)
 }
 EXPORT_SYMBOL(reuseport_add_sock);
 
+static void reuseport_free_rcu(struct rcu_head *head)
+{
+       struct sock_reuseport *reuse;
+
+       reuse = container_of(head, struct sock_reuseport, rcu);
+       if (reuse->prog)
+               bpf_prog_destroy(reuse->prog);
+       kfree(reuse);
+}
+
 void reuseport_detach_sock(struct sock *sk)
 {
        struct sock_reuseport *reuse;
@@ -131,7 +149,7 @@ void reuseport_detach_sock(struct sock *sk)
                        reuse->socks[i] = reuse->socks[reuse->num_socks - 1];
                        reuse->num_socks--;
                        if (reuse->num_socks == 0)
-                               kfree_rcu(reuse, rcu);
+                               call_rcu(&reuse->rcu, reuseport_free_rcu);
                        break;
                }
        }
@@ -139,15 +157,53 @@ void reuseport_detach_sock(struct sock *sk)
 }
 EXPORT_SYMBOL(reuseport_detach_sock);
 
+static struct sock *run_bpf(struct sock_reuseport *reuse, u16 socks,
+                           struct bpf_prog *prog, struct sk_buff *skb,
+                           int hdr_len)
+{
+       struct sk_buff *nskb = NULL;
+       u32 index;
+
+       if (skb_shared(skb)) {
+               nskb = skb_clone(skb, GFP_ATOMIC);
+               if (!nskb)
+                       return NULL;
+               skb = nskb;
+       }
+
+       /* temporarily advance data past protocol header */
+       if (!pskb_pull(skb, hdr_len)) {
+               consume_skb(nskb);
+               return NULL;
+       }
+       index = bpf_prog_run_save_cb(prog, skb);
+       __skb_push(skb, hdr_len);
+
+       consume_skb(nskb);
+
+       if (index >= socks)
+               return NULL;
+
+       return reuse->socks[index];
+}
+
 /**
  *  reuseport_select_sock - Select a socket from an SO_REUSEPORT group.
  *  @sk: First socket in the group.
- *  @hash: Use this hash to select.
+ *  @hash: When no BPF filter is available, use this hash to select.
+ *  @skb: skb to run through BPF filter.
+ *  @hdr_len: BPF filter expects skb data pointer at payload data.  If
+ *    the skb does not yet point at the payload, this parameter represents
+ *    how far the pointer needs to advance to reach the payload.
  *  Returns a socket that should receive the packet (or NULL on error).
  */
-struct sock *reuseport_select_sock(struct sock *sk, u32 hash)
+struct sock *reuseport_select_sock(struct sock *sk,
+                                  u32 hash,
+                                  struct sk_buff *skb,
+                                  int hdr_len)
 {
        struct sock_reuseport *reuse;
+       struct bpf_prog *prog;
        struct sock *sk2 = NULL;
        u16 socks;
 
@@ -158,12 +214,16 @@ struct sock *reuseport_select_sock(struct sock *sk, u32 hash)
        if (!reuse)
                goto out;
 
+       prog = rcu_dereference(reuse->prog);
        socks = READ_ONCE(reuse->num_socks);
        if (likely(socks)) {
                /* paired with smp_wmb() in reuseport_add_sock() */
                smp_rmb();
 
-               sk2 = reuse->socks[reciprocal_scale(hash, socks)];
+               if (prog && skb)
+                       sk2 = run_bpf(reuse, socks, prog, skb, hdr_len);
+               else
+                       sk2 = reuse->socks[reciprocal_scale(hash, socks)];
        }
 
 out:
@@ -171,3 +231,21 @@ out:
        return sk2;
 }
 EXPORT_SYMBOL(reuseport_select_sock);
+
+struct bpf_prog *
+reuseport_attach_prog(struct sock *sk, struct bpf_prog *prog)
+{
+       struct sock_reuseport *reuse;
+       struct bpf_prog *old_prog;
+
+       spin_lock_bh(&reuseport_lock);
+       reuse = rcu_dereference_protected(sk->sk_reuseport_cb,
+                                         lockdep_is_held(&reuseport_lock));
+       old_prog = rcu_dereference_protected(reuse->prog,
+                                            lockdep_is_held(&reuseport_lock));
+       rcu_assign_pointer(reuse->prog, prog);
+       spin_unlock_bh(&reuseport_lock);
+
+       return old_prog;
+}
+EXPORT_SYMBOL(reuseport_attach_prog);