sock: struct proto hash function may error
[cascardo/linux.git] / include / net / sock.h
index e830c10..255d3e0 100644 (file)
 #include <net/tcp_states.h>
 #include <linux/net_tstamp.h>
 
-struct cgroup;
-struct cgroup_subsys;
-#ifdef CONFIG_NET
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss);
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg);
-#else
-static inline
-int mem_cgroup_sockets_init(struct mem_cgroup *memcg, struct cgroup_subsys *ss)
-{
-       return 0;
-}
-static inline
-void mem_cgroup_sockets_destroy(struct mem_cgroup *memcg)
-{
-}
-#endif
 /*
  * This structure really needs to be cleaned up.
  * Most of it is for TCP, and not used by any of
@@ -245,7 +229,6 @@ struct sock_common {
        /* public: */
 };
 
-struct cg_proto;
 /**
   *    struct sock - network layer representation of sockets
   *    @__sk_common: shared layout with inet_timewait_sock
@@ -310,7 +293,7 @@ struct cg_proto;
   *    @sk_security: used by security modules
   *    @sk_mark: generic packet mark
   *    @sk_cgrp_data: cgroup data for this cgroup
-  *    @sk_cgrp: this socket's cgroup-specific proto data
+  *    @sk_memcg: this socket's memory cgroup association
   *    @sk_write_pending: a write to stream socket waits to start
   *    @sk_state_change: callback to indicate change in the state of the sock
   *    @sk_data_ready: callback to indicate there is data to be processed
@@ -446,7 +429,7 @@ struct sock {
        void                    *sk_security;
 #endif
        struct sock_cgroup_data sk_cgrp_data;
-       struct cg_proto         *sk_cgrp;
+       struct mem_cgroup       *sk_memcg;
        void                    (*sk_state_change)(struct sock *sk);
        void                    (*sk_data_ready)(struct sock *sk);
        void                    (*sk_write_space)(struct sock *sk);
@@ -1001,7 +984,7 @@ struct proto {
        void            (*release_cb)(struct sock *sk);
 
        /* Keeping track of sk's, looking them up, and port selection methods. */
-       void                    (*hash)(struct sock *sk);
+       int                     (*hash)(struct sock *sk);
        void                    (*unhash)(struct sock *sk);
        void                    (*rehash)(struct sock *sk);
        int                     (*get_port)(struct sock *sk, unsigned short snum);
@@ -1052,18 +1035,6 @@ struct proto {
        struct list_head        node;
 #ifdef SOCK_REFCNT_DEBUG
        atomic_t                socks;
-#endif
-#ifdef CONFIG_MEMCG_KMEM
-       /*
-        * cgroup specific init/deinit functions. Called once for all
-        * protocols that implement it, from cgroups populate function.
-        * This function has to setup any files the protocol want to
-        * appear in the kmem cgroup filesystem.
-        */
-       int                     (*init_cgroup)(struct mem_cgroup *memcg,
-                                              struct cgroup_subsys *ss);
-       void                    (*destroy_cgroup)(struct mem_cgroup *memcg);
-       struct cg_proto         *(*proto_cgroup)(struct mem_cgroup *memcg);
 #endif
        int                     (*diag_destroy)(struct sock *sk, int err);
 };
@@ -1096,23 +1067,6 @@ static inline void sk_refcnt_debug_release(const struct sock *sk)
 #define sk_refcnt_debug_release(sk) do { } while (0)
 #endif /* SOCK_REFCNT_DEBUG */
 
-#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_NET)
-extern struct static_key memcg_socket_limit_enabled;
-static inline struct cg_proto *parent_cg_proto(struct proto *proto,
-                                              struct cg_proto *cg_proto)
-{
-       return proto->proto_cgroup(parent_mem_cgroup(cg_proto->memcg));
-}
-#define mem_cgroup_sockets_enabled static_key_false(&memcg_socket_limit_enabled)
-#else
-#define mem_cgroup_sockets_enabled 0
-static inline struct cg_proto *parent_cg_proto(struct proto *proto,
-                                              struct cg_proto *cg_proto)
-{
-       return NULL;
-}
-#endif
-
 static inline bool sk_stream_memory_free(const struct sock *sk)
 {
        if (sk->sk_wmem_queued >= sk->sk_sndbuf)
@@ -1139,8 +1093,9 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
        if (!sk->sk_prot->memory_pressure)
                return false;
 
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return !!sk->sk_cgrp->memory_pressure;
+       if (mem_cgroup_sockets_enabled && sk->sk_memcg &&
+           mem_cgroup_under_socket_pressure(sk->sk_memcg))
+               return true;
 
        return !!*sk->sk_prot->memory_pressure;
 }
@@ -1154,15 +1109,6 @@ static inline void sk_leave_memory_pressure(struct sock *sk)
 
        if (*memory_pressure)
                *memory_pressure = 0;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-               struct proto *prot = sk->sk_prot;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       cg_proto->memory_pressure = 0;
-       }
-
 }
 
 static inline void sk_enter_memory_pressure(struct sock *sk)
@@ -1170,116 +1116,46 @@ static inline void sk_enter_memory_pressure(struct sock *sk)
        if (!sk->sk_prot->enter_memory_pressure)
                return;
 
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-               struct proto *prot = sk->sk_prot;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       cg_proto->memory_pressure = 1;
-       }
-
        sk->sk_prot->enter_memory_pressure(sk);
 }
 
 static inline long sk_prot_mem_limits(const struct sock *sk, int index)
 {
-       long *prot = sk->sk_prot->sysctl_mem;
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               prot = sk->sk_cgrp->sysctl_mem;
-       return prot[index];
-}
-
-static inline void memcg_memory_allocated_add(struct cg_proto *prot,
-                                             unsigned long amt,
-                                             int *parent_status)
-{
-       page_counter_charge(&prot->memory_allocated, amt);
-
-       if (page_counter_read(&prot->memory_allocated) >
-           prot->memory_allocated.limit)
-               *parent_status = OVER_LIMIT;
-}
-
-static inline void memcg_memory_allocated_sub(struct cg_proto *prot,
-                                             unsigned long amt)
-{
-       page_counter_uncharge(&prot->memory_allocated, amt);
+       return sk->sk_prot->sysctl_mem[index];
 }
 
 static inline long
 sk_memory_allocated(const struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return page_counter_read(&sk->sk_cgrp->memory_allocated);
-
-       return atomic_long_read(prot->memory_allocated);
+       return atomic_long_read(sk->sk_prot->memory_allocated);
 }
 
 static inline long
-sk_memory_allocated_add(struct sock *sk, int amt, int *parent_status)
+sk_memory_allocated_add(struct sock *sk, int amt)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               memcg_memory_allocated_add(sk->sk_cgrp, amt, parent_status);
-               /* update the root cgroup regardless */
-               atomic_long_add_return(amt, prot->memory_allocated);
-               return page_counter_read(&sk->sk_cgrp->memory_allocated);
-       }
-
-       return atomic_long_add_return(amt, prot->memory_allocated);
+       return atomic_long_add_return(amt, sk->sk_prot->memory_allocated);
 }
 
 static inline void
 sk_memory_allocated_sub(struct sock *sk, int amt)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               memcg_memory_allocated_sub(sk->sk_cgrp, amt);
-
-       atomic_long_sub(amt, prot->memory_allocated);
+       atomic_long_sub(amt, sk->sk_prot->memory_allocated);
 }
 
 static inline void sk_sockets_allocated_dec(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       percpu_counter_dec(&cg_proto->sockets_allocated);
-       }
-
-       percpu_counter_dec(prot->sockets_allocated);
+       percpu_counter_dec(sk->sk_prot->sockets_allocated);
 }
 
 static inline void sk_sockets_allocated_inc(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp) {
-               struct cg_proto *cg_proto = sk->sk_cgrp;
-
-               for (; cg_proto; cg_proto = parent_cg_proto(prot, cg_proto))
-                       percpu_counter_inc(&cg_proto->sockets_allocated);
-       }
-
-       percpu_counter_inc(prot->sockets_allocated);
+       percpu_counter_inc(sk->sk_prot->sockets_allocated);
 }
 
 static inline int
 sk_sockets_allocated_read_positive(struct sock *sk)
 {
-       struct proto *prot = sk->sk_prot;
-
-       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
-               return percpu_counter_read_positive(&sk->sk_cgrp->sockets_allocated);
-
-       return percpu_counter_read_positive(prot->sockets_allocated);
+       return percpu_counter_read_positive(sk->sk_prot->sockets_allocated);
 }
 
 static inline int
@@ -1318,10 +1194,10 @@ static inline void sock_prot_inuse_add(struct net *net, struct proto *prot,
 /* With per-bucket locks this operation is not-atomic, so that
  * this version is not worse.
  */
-static inline void __sk_prot_rehash(struct sock *sk)
+static inline int __sk_prot_rehash(struct sock *sk)
 {
        sk->sk_prot->unhash(sk);
-       sk->sk_prot->hash(sk);
+       return sk->sk_prot->hash(sk);
 }
 
 void sk_prot_clear_portaddr_nulls(struct sock *sk, int size);