soreuseport: Prep for fast reuseport TCP socket selection
[cascardo/linux.git] / net/core/sock.c
index 0d91f7d..46dc8ad 100644
 #include <linux/sock_diag.h>
 
 #include <linux/filter.h>
+#include <net/sock_reuseport.h>
 
 #include <trace/events/sock.h>
 
@@ -194,44 +195,6 @@ bool sk_net_capable(const struct sock *sk, int cap)
 }
 EXPORT_SYMBOL(sk_net_capable);
 
-
-#ifdef CONFIG_MEMCG_KMEM
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-       struct proto *proto;
-       int ret = 0;
-
-       mutex_lock(&proto_list_mutex);
-       list_for_each_entry(proto, &proto_list, node) {
-               if (proto->init_cgroup) {
-                       ret = proto->init_cgroup(memcg, ss);
-                       if (ret)
-                               goto out;
-               }
-       }
-
-       mutex_unlock(&proto_list_mutex);
-       return ret;
-out:
-       list_for_each_entry_continue_reverse(proto, &proto_list, node)
-               if (proto->destroy_cgroup)
-                       proto->destroy_cgroup(memcg);
-       mutex_unlock(&proto_list_mutex);
-       return ret;
-}
-
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
-{
-       struct proto *proto;
-
-       mutex_lock(&proto_list_mutex);
-       list_for_each_entry_reverse(proto, &proto_list, node)
-               if (proto->destroy_cgroup)
-                       proto->destroy_cgroup(memcg);
-       mutex_unlock(&proto_list_mutex);
-}
-#endif
-
 /*
  * Each address family might have different locking rules, so we have
  * one slock key per address family:
@@ -239,11 +202,6 @@ void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
 static struct lock_class_key af_family_keys[AF_MAX];
 static struct lock_class_key af_family_slock_keys[AF_MAX];
 
-#if defined(CONFIG_MEMCG_KMEM)
-struct static_key memcg_socket_limit_enabled;
-EXPORT_SYMBOL(memcg_socket_limit_enabled);
-#endif
-
 /*
  * Make lock validator output more readable. (we pre-construct these
  * strings build-time, so that runtime initialization of socket
@@ -932,6 +890,32 @@ set_rcvbuf:
                }
                break;
 
+       case SO_ATTACH_REUSEPORT_CBPF:
+               ret = -EINVAL;
+               if (optlen == sizeof(struct sock_fprog)) {
+                       struct sock_fprog fprog;
+
+                       ret = -EFAULT;
+                       if (copy_from_user(&fprog, optval, sizeof(fprog)))
+                               break;
+
+                       ret = sk_reuseport_attach_filter(&fprog, sk);
+               }
+               break;
+
+       case SO_ATTACH_REUSEPORT_EBPF:
+               ret = -EINVAL;
+               if (optlen == sizeof(u32)) {
+                       u32 ufd;
+
+                       ret = -EFAULT;
+                       if (copy_from_user(&ufd, optval, sizeof(ufd)))
+                               break;
+
+                       ret = sk_reuseport_attach_bpf(ufd, sk);
+               }
+               break;
+
        case SO_DETACH_FILTER:
                ret = sk_detach_filter(sk);
                break;
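
For context, a minimal userspace sketch of how the new option is exercised (not part of the patch; the helper name is illustrative and error handling is omitted). In this implementation the socket generally must already be bound with SO_REUSEPORT set, since attaching to a nonexistent reuseport group fails with ENOENT. The classic-BPF program returns an index into the reuseport group; an out-of-range result falls back to the default hash-based selection.

#include <linux/filter.h>
#include <sys/socket.h>

#ifndef SO_ATTACH_REUSEPORT_CBPF
#define SO_ATTACH_REUSEPORT_CBPF 51	/* value from asm-generic/socket.h */
#endif

static int attach_reuseport_cbpf(int fd)
{
	/* Trivial selector: always pick socket index 0 in the group. */
	struct sock_filter code[] = {
		BPF_STMT(BPF_RET | BPF_K, 0),
	};
	struct sock_fprog fprog = {
		.len	= sizeof(code) / sizeof(code[0]),
		.filter	= code,
	};

	return setsockopt(fd, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
			  &fprog, sizeof(fprog));
}

The SO_ATTACH_REUSEPORT_EBPF case is analogous, except the option value is the u32 file descriptor of a BPF_PROG_TYPE_SOCKET_FILTER program loaded with bpf(BPF_PROG_LOAD, ...).
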
@@ -1362,6 +1346,7 @@ static struct sock *sk_prot_alloc(struct proto *prot, gfp_t priority,
                if (!try_module_get(prot->owner))
                        goto out_free_sec;
                sk_tx_queue_clear(sk);
+               cgroup_sk_alloc(&sk->sk_cgrp_data);
        }
 
        return sk;
@@ -1384,6 +1369,7 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
        owner = prot->owner;
        slab = prot->slab;
 
+       cgroup_sk_free(&sk->sk_cgrp_data);
        security_sk_free(sk);
        if (slab != NULL)
                kmem_cache_free(slab, sk);
@@ -1392,17 +1378,6 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
        module_put(owner);
 }
 
-#if IS_ENABLED(CONFIG_CGROUP_NET_PRIO)
-void sock_update_netprioidx(struct sock *sk)
-{
-       if (in_interrupt())
-               return;
-
-       sk->sk_cgrp_prioidx = task_netprioidx(current);
-}
-EXPORT_SYMBOL_GPL(sock_update_netprioidx);
-#endif
-
 /**
  *     sk_alloc - All socket objects are allocated here
  *     @net: the applicable net namespace
@@ -1431,8 +1406,8 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
                sock_net_set(sk, net);
                atomic_set(&sk->sk_wmem_alloc, 1);
 
-               sock_update_classid(sk);
-               sock_update_netprioidx(sk);
+               sock_update_classid(&sk->sk_cgrp_data);
+               sock_update_netprioidx(&sk->sk_cgrp_data);
        }
 
        return sk;
@@ -1452,6 +1427,8 @@ void sk_destruct(struct sock *sk)
                sk_filter_uncharge(sk, filter);
                RCU_INIT_POINTER(sk->sk_filter, NULL);
        }
+       if (rcu_access_pointer(sk->sk_reuseport_cb))
+               reuseport_detach_sock(sk);
 
        sock_disable_timestamp(sk, SK_FLAGS_TIMESTAMP);
 
@@ -1487,12 +1464,6 @@ void sk_free(struct sock *sk)
 }
 EXPORT_SYMBOL(sk_free);
 
-static void sk_update_clone(const struct sock *sk, struct sock *newsk)
-{
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               sock_update_memcg(newsk);
-}
-
 /**
  *     sk_clone_lock - clone a socket, and lock its clone
  *     @sk: the socket to clone
@@ -1560,6 +1531,7 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
                        newsk = NULL;
                        goto out;
                }
+               RCU_INIT_POINTER(newsk->sk_reuseport_cb, NULL);
 
                newsk->sk_err      = 0;
                newsk->sk_priority = 0;
@@ -1587,7 +1559,8 @@ struct sock *sk_clone_lock(const struct sock *sk, const gfp_t priority)
                sk_set_socket(newsk, NULL);
                newsk->sk_wq = NULL;
 
-               sk_update_clone(sk, newsk);
+               if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+                       sock_update_memcg(newsk);
 
                if (newsk->sk_prot->sockets_allocated)
                        sk_sockets_allocated_inc(newsk);
@@ -2069,27 +2042,27 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
        struct proto *prot = sk->sk_prot;
        int amt = sk_mem_pages(size);
        long allocated;
-       int parent_status = UNDER_LIMIT;
 
        sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
 
-       allocated = sk_memory_allocated_add(sk, amt, &parent_status);
+       allocated = sk_memory_allocated_add(sk, amt);
+
+       if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+           !mem_cgroup_charge_skmem(sk->sk_memcg, amt))
+               goto suppress_allocation;
 
        /* Under limit. */
-       if (parent_status == UNDER_LIMIT &&
-                       allocated <= sk_prot_mem_limits(sk, 0)) {
+       if (allocated <= sk_prot_mem_limits(sk, 0)) {
                sk_leave_memory_pressure(sk);
                return 1;
        }
 
-       /* Under pressure. (we or our parents) */
-       if ((parent_status > SOFT_LIMIT) ||
-                       allocated > sk_prot_mem_limits(sk, 1))
+       /* Under pressure. */
+       if (allocated > sk_prot_mem_limits(sk, 1))
                sk_enter_memory_pressure(sk);
 
-       /* Over hard limit (we or our parents) */
-       if ((parent_status == OVER_LIMIT) ||
-                       (allocated > sk_prot_mem_limits(sk, 2)))
+       /* Over hard limit. */
+       if (allocated > sk_prot_mem_limits(sk, 2))
                goto suppress_allocation;
 
        /* guarantee minimum buffer size under pressure */
@@ -2138,6 +2111,9 @@ suppress_allocation:
 
        sk_memory_allocated_sub(sk, amt);
 
+       if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+               mem_cgroup_uncharge_skmem(sk->sk_memcg, amt);
+
        return 0;
 }
 EXPORT_SYMBOL(__sk_mem_schedule);
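
A note on the units in the accounting above, as a worked example (assuming SK_MEM_QUANTUM == PAGE_SIZE == 4096, per its definition in include/net/sock.h):

/*
 * size = 6000 bytes
 * amt  = sk_mem_pages(6000) = (6000 + 4095) >> SK_MEM_QUANTUM_SHIFT
 *      = 2 pages
 *
 * sk->sk_forward_alloc grows by amt * SK_MEM_QUANTUM = 8192 bytes,
 * while sk_memory_allocated_add() and mem_cgroup_charge_skmem() both
 * account the same 2 pages.  The suppress_allocation path and
 * __sk_mem_reclaim() below undo the global and memcg charges with the
 * same page count, keeping the two ledgers paired.
 */
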
@@ -2153,6 +2129,9 @@ void __sk_mem_reclaim(struct sock *sk, int amount)
        sk_memory_allocated_sub(sk, amount);
        sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
 
+       if (mem_cgroup_sockets_enabled && sk->sk_memcg)
+               mem_cgroup_uncharge_skmem(sk->sk_memcg, amount);
+
        if (sk_under_memory_pressure(sk) &&
            (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
                sk_leave_memory_pressure(sk);
@@ -2281,7 +2260,7 @@ static void sock_def_wakeup(struct sock *sk)
 
        rcu_read_lock();
        wq = rcu_dereference(sk->sk_wq);
-       if (wq_has_sleeper(wq))
+       if (skwq_has_sleeper(wq))
                wake_up_interruptible_all(&wq->wait);
        rcu_read_unlock();
 }
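
For reference, the helper these wakeup paths switch to is, roughly, the following NULL-safe wrapper (added to include/net/sock.h by the companion patch that generalised wq_has_sleeper()):

static inline bool skwq_has_sleeper(struct socket_wq *wq)
{
	/* NULL-check the socket wait queue, then defer to the generic
	 * wq_has_sleeper(), which issues the memory barrier pairing
	 * with the sleeper side before testing waitqueue_active().
	 */
	return wq && wq_has_sleeper(&wq->wait);
}
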
@@ -2292,7 +2271,7 @@ static void sock_def_error_report(struct sock *sk)
 
        rcu_read_lock();
        wq = rcu_dereference(sk->sk_wq);
-       if (wq_has_sleeper(wq))
+       if (skwq_has_sleeper(wq))
                wake_up_interruptible_poll(&wq->wait, POLLERR);
        sk_wake_async(sk, SOCK_WAKE_IO, POLL_ERR);
        rcu_read_unlock();
@@ -2304,7 +2283,7 @@ static void sock_def_readable(struct sock *sk)
 
        rcu_read_lock();
        wq = rcu_dereference(sk->sk_wq);
-       if (wq_has_sleeper(wq))
+       if (skwq_has_sleeper(wq))
                wake_up_interruptible_sync_poll(&wq->wait, POLLIN | POLLPRI |
                                                POLLRDNORM | POLLRDBAND);
        sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
@@ -2322,7 +2301,7 @@ static void sock_def_write_space(struct sock *sk)
         */
        if ((atomic_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf) {
                wq = rcu_dereference(sk->sk_wq);
-               if (wq_has_sleeper(wq))
+               if (skwq_has_sleeper(wq))
                        wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
                                                POLLWRNORM | POLLWRBAND);