sock: struct proto hash function may error
[cascardo/linux.git] / include / net / sock.h
index 14d3c07..255d3e0 100644 (file)
@@ -58,6 +58,8 @@
 #include <linux/memcontrol.h>
 #include <linux/static_key.h>
 #include <linux/sched.h>
+#include <linux/wait.h>
+#include <linux/cgroup-defs.h>
 
 #include <linux/filter.h>
 #include <linux/rculist_nulls.h>
 #include <net/tcp_states.h>
 #include <linux/net_tstamp.h>
 
-struct cgroup;
-struct cgroup_subsys;
-#ifdef CONFIG_NET
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss);
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg);
-#else
-static inline
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-       return 0;
-}
-static inline
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
-{
-}
-#endif
 /*
  * This structure really needs to be cleaned up.
  * Most of it is for TCP, and not used by any of
@@ -243,7 +229,6 @@ struct sock_common {
        /* public: */
 };
 
-struct cg_proto;
 /**
   *    struct sock - network layer representation of sockets
   *    @__sk_common: shared layout with inet_timewait_sock
@@ -287,7 +272,6 @@ struct cg_proto;
   *    @sk_ack_backlog: current listen backlog
   *    @sk_max_ack_backlog: listen backlog set in listen()
   *    @sk_priority: %SO_PRIORITY setting
-  *    @sk_cgrp_prioidx: socket group's priority map index
   *    @sk_type: socket type (%SOCK_STREAM, etc)
   *    @sk_protocol: which protocol this socket belongs in this network family
   *    @sk_peer_pid: &struct pid for this socket's peer
@@ -308,8 +292,8 @@ struct cg_proto;
   *    @sk_send_head: front of stuff to transmit
   *    @sk_security: used by security modules
   *    @sk_mark: generic packet mark
-  *    @sk_classid: this socket's cgroup classid
-  *    @sk_cgrp: this socket's cgroup-specific proto data
+  *    @sk_cgrp_data: cgroup data for this cgroup
+  *    @sk_memcg: this socket's memory cgroup association
   *    @sk_write_pending: a write to stream socket waits to start
   *    @sk_state_change: callback to indicate change in the state of the sock
   *    @sk_data_ready: callback to indicate there is data to be processed
@@ -317,6 +301,7 @@ struct cg_proto;
   *    @sk_error_report: callback to indicate errors (e.g. %MSG_ERRQUEUE)
   *    @sk_backlog_rcv: callback to process the backlog
   *    @sk_destruct: called at sock freeing time, i.e. when all refcnt == 0
+  *    @sk_reuseport_cb: reuseport group container
  */
 struct sock {
        /*
@@ -425,9 +410,7 @@ struct sock {
        u32                     sk_ack_backlog;
        u32                     sk_max_ack_backlog;
        __u32                   sk_priority;
-#if IS_ENABLED(CONFIG_CGROUP_NET_PRIO)
-       __u32                   sk_cgrp_prioidx;
-#endif
+       __u32                   sk_mark;
        struct pid              *sk_peer_pid;
        const struct cred       *sk_peer_cred;
        long                    sk_rcvtimeo;
@@ -445,11 +428,8 @@ struct sock {
 #ifdef CONFIG_SECURITY
        void                    *sk_security;
 #endif
-       __u32                   sk_mark;
-#ifdef CONFIG_CGROUP_NET_CLASSID
-       u32                     sk_classid;
-#endif
-       struct cg_proto         *sk_cgrp;
+       struct sock_cgroup_data sk_cgrp_data;
+       struct mem_cgroup       *sk_memcg;
        void                    (*sk_state_change)(struct sock *sk);
        void                    (*sk_data_ready)(struct sock *sk);
        void                    (*sk_write_space)(struct sock *sk);
@@ -457,6 +437,7 @@ struct sock {
        int                     (*sk_backlog_rcv)(struct sock *sk,
                                                  struct sk_buff *skb);
        void                    (*sk_destruct)(struct sock *sk);
+       struct sock_reuseport __rcu     *sk_reuseport_cb;
 };
 
 #define __sk_user_data(sk) ((*((void __rcu **)&(sk)->sk_user_data)))
@@ -778,9 +759,9 @@ static inline int sk_memalloc_socks(void)
 
 #endif
 
-static inline gfp_t sk_gfp_atomic(const struct sock *sk, gfp_t gfp_mask)
+static inline gfp_t sk_gfp_mask(const struct sock *sk, gfp_t gfp_mask)
 {
-       return GFP_ATOMIC | (sk->sk_allocation & __GFP_MEMALLOC);
+       return gfp_mask | (sk->sk_allocation & __GFP_MEMALLOC);
 }
 
 static inline void sk_acceptq_removed(struct sock *sk)
@@ -1003,7 +984,7 @@ struct proto {
        void            (*release_cb)(struct sock *sk);
 
        /* Keeping track of sk's, looking them up, and port selection methods. */
-       void                    (*hash)(struct sock *sk);
+       int                     (*hash)(struct sock *sk);
        void                    (*unhash)(struct sock *sk);
        void                    (*rehash)(struct sock *sk);
        int                     (*get_port)(struct sock *sk, unsigned short snum);
@@ -1055,18 +1036,7 @@ struct proto {
 #ifdef SOCK_REFCNT_DEBUG
        atomic_t                socks;
 #endif
-#ifdef CONFIG_MEMCG_KMEM
-       /*
-        * cgroup specific init/deinit functions. Called once for all
-        * protocols that implement it, from cgroups populate function.
-        * This function has to setup any files the protocol want to
-        * appear in the kmem cgroup filesystem.
-        */
-       int                     (*init_cgroup)(struct mem_cgroup *memcg,
-                                              struct cgroup_subsys *ss);
-       void                    (*destroy_cgroup)(struct mem_cgroup *memcg);
-       struct cg_proto         *(*proto_cgroup)(struct mem_cgroup *memcg);
-#endif
+       int                     (*diag_destroy)(struct sock *sk, int err);
 };
 
 int proto_register(struct proto *prot, int alloc_slab);
@@ -1097,23 +1067,6 @@ static inline void sk_refcnt_debug_release(const struct sock *sk)
 #define sk_refcnt_debug_release(sk) do { } while (0)
 #endif /* SOCK_REFCNT_DEBUG */
 
-#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_NET)
-extern struct static_key memcg_socket_limit_enabled;
-static inline struct cg_proto *parent_cg_proto(struct proto *proto,
-                                              struct cg_proto *cg_proto)
-{
-       return proto->proto_cgroup(parent_mem_cgroup(cg_proto->memcg));
-}
-#define mem_cgroup_sockets_enabled static_key_false(&memcg_socket_limit_enabled)
-#else
-#define mem_cgroup_sockets_enabled 0
-static inline struct cg_proto *parent_cg_proto(struct proto *proto,
-                                              struct cg_proto *cg_proto)
-{
-       return NULL;
-}
-#endif
-
 static inline bool sk_stream_memory_free(const struct sock *sk)
 {
        if (sk->sk_wmem_queued >= sk->sk_sndbuf)
@@ -1140,8 +1093,9 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
        if (!sk->sk_prot->memory_pressure)
                return false;
 
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return !!sk->sk_cgrp->memory_pressure;
+       if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+           mem_cgroup_under_socket_pressure(sk->sk_memcg))
+               return true;
 
        return !!*sk->sk_prot->memory_pressure;
 }
@@ -1155,15 +1109,6 @@ static inline void sk_leave_memory_pressure(struct sock *sk)
 
        if (*memory_pressure)
                *memory_pressure = 0;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-               struct proto *prot = sk->sk_prot;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       cg_proto->memory_pressure = 0;
-       }
-
 }
 
 static inline void sk_enter_memory_pressure(struct sock *sk)
@@ -1171,116 +1116,46 @@ static inline void sk_enter_memory_pressure(struct sock *sk)
        if (!sk->sk_prot->enter_memory_pressure)
                return;
 
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-               struct proto *prot = sk->sk_prot;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       cg_proto->memory_pressure = 1;
-       }
-
        sk->sk_prot->enter_memory_pressure(sk);
 }
 
 static inline long sk_prot_mem_limits(const struct sock *sk, int index)
 {
-       long *prot = sk->sk_prot->sysctl_mem;
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               prot = sk->sk_cgrp->sysctl_mem;
-       return prot[index];
-}
-
-static inline void memcg_memory_allocated_add(struct cg_proto *prot,
-                                             unsigned long amt,
-                                             int *parent_status)
-{
-       page_counter_charge(&prot->memory_allocated, amt);
-
-       if (page_counter_read(&prot->memory_allocated) >
-           prot->memory_allocated.limit)
-               *parent_status = OVER_LIMIT;
-}
-
-static inline void memcg_memory_allocated_sub(struct cg_proto *prot,
-                                             unsigned long amt)
-{
-       page_counter_uncharge(&prot->memory_allocated, amt);
+       return sk->sk_prot->sysctl_mem[index];
 }
 
 static inline long
 sk_memory_allocated(const struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return page_counter_read(&sk->sk_cgrp->memory_allocated);
-
-       return atomic_long_read(prot->memory_allocated);
+       return atomic_long_read(sk->sk_prot->memory_allocated);
 }
 
 static inline long
-sk_memory_allocated_add(struct sock *sk, int amt, int *parent_status)
+sk_memory_allocated_add(struct sock *sk, int amt)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               memcg_memory_allocated_add(sk->sk_cgrp, amt, parent_status);
-               /* update the root cgroup regardless */
-               atomic_long_add_return(amt, prot->memory_allocated);
-               return page_counter_read(&sk->sk_cgrp->memory_allocated);
-       }
-
-       return atomic_long_add_return(amt, prot->memory_allocated);
+       return atomic_long_add_return(amt, sk->sk_prot->memory_allocated);
 }
 
 static inline void
 sk_memory_allocated_sub(struct sock *sk, int amt)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               memcg_memory_allocated_sub(sk->sk_cgrp, amt);
-
-       atomic_long_sub(amt, prot->memory_allocated);
+       atomic_long_sub(amt, sk->sk_prot->memory_allocated);
 }
 
 static inline void sk_sockets_allocated_dec(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       percpu_counter_dec(&cg_proto->sockets_allocated);
-       }
-
-       percpu_counter_dec(prot->sockets_allocated);
+       percpu_counter_dec(sk->sk_prot->sockets_allocated);
 }
 
 static inline void sk_sockets_allocated_inc(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       percpu_counter_inc(&cg_proto->sockets_allocated);
-       }
-
-       percpu_counter_inc(prot->sockets_allocated);
+       percpu_counter_inc(sk->sk_prot->sockets_allocated);
 }
 
 static inline int
 sk_sockets_allocated_read_positive(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return percpu_counter_read_positive(&sk->sk_cgrp->sockets_allocated);
-
-       return percpu_counter_read_positive(prot->sockets_allocated);
+       return percpu_counter_read_positive(sk->sk_prot->sockets_allocated);
 }
 
 static inline int
@@ -1319,10 +1194,10 @@ static inline void sock_prot_inuse_add(struct net *net, struct proto *prot,
 /* With per-bucket locks this operation is not-atomic, so that
  * this version is not worse.
  */
-static inline void __sk_prot_rehash(struct sock *sk)
+static inline int __sk_prot_rehash(struct sock *sk)
 {
        sk->sk_prot->unhash(sk);
-       sk->sk_prot->hash(sk);
+       return sk->sk_prot->hash(sk);
 }
 
 void sk_prot_clear_portaddr_nulls(struct sock *sk, int size);
@@ -1798,6 +1673,15 @@ static inline void sk_nocaps_add(struct sock *sk, netdev_features_t flags)
        sk->sk_route_caps &= ~flags;
 }
 
+static inline bool sk_check_csum_caps(struct sock *sk)
+{
+       return (sk->sk_route_caps & NETIF_F_HW_CSUM) ||
+              (sk->sk_family == PF_INET &&
+               (sk->sk_route_caps & NETIF_F_IP_CSUM)) ||
+              (sk->sk_family == PF_INET6 &&
+               (sk->sk_route_caps & NETIF_F_IPV6_CSUM));
+}
+
 static inline int skb_do_copy_data_nocache(struct sock *sk, struct sk_buff *skb,
                                           struct iov_iter *from, char *to,
                                           int copy, int offset)
@@ -1883,12 +1767,12 @@ static inline bool sk_has_allocations(const struct sock *sk)
 }
 
 /**
- * wq_has_sleeper - check if there are any waiting processes
+ * skwq_has_sleeper - check if there are any waiting processes
  * @wq: struct socket_wq
  *
  * Returns true if socket_wq has waiting processes
  *
- * The purpose of the wq_has_sleeper and sock_poll_wait is to wrap the memory
+ * The purpose of the skwq_has_sleeper and sock_poll_wait is to wrap the memory
  * barrier call. They were added due to the race found within the tcp code.
  *
  * Consider following tcp code paths:
@@ -1914,15 +1798,9 @@ static inline bool sk_has_allocations(const struct sock *sk)
  * data on the socket.
  *
  */
-static inline bool wq_has_sleeper(struct socket_wq *wq)
+static inline bool skwq_has_sleeper(struct socket_wq *wq)
 {
-       /* We need to be sure we are in sync with the
-        * add_wait_queue modifications to the wait queue.
-        *
-        * This memory barrier is paired in the sock_poll_wait.
-        */
-       smp_mb();
-       return wq && waitqueue_active(&wq->wait);
+       return wq && wq_has_sleeper(&wq->wait);
 }
 
 /**