net: tcp_memcontrol: sanitize tcp memory accounting callbacks
[cascardo/linux.git] / net / core / sock.c
index 6c5dab0..89ae859 100644 (file)
@@ -2084,27 +2084,27 @@ int __sk_mem_schedule(struct sock *sk, int size, int kind)
        struct proto *prot = sk->sk_prot;
        int amt = sk_mem_pages(size);
        long allocated;
-       int parent_status = UNDER_LIMIT;
 
        sk->sk_forward_alloc += amt * SK_MEM_QUANTUM;
 
-       allocated = sk_memory_allocated_add(sk, amt, &parent_status);
+       allocated = sk_memory_allocated_add(sk, amt);
+
+       if (mem_cgroup_sockets_enabled && sk->sk_cgrp &&
+           !mem_cgroup_charge_skmem(sk->sk_cgrp, amt))
+               goto suppress_allocation;
 
        /* Under limit. */
-       if (parent_status == UNDER_LIMIT &&
-                       allocated <= sk_prot_mem_limits(sk, 0)) {
+       if (allocated <= sk_prot_mem_limits(sk, 0)) {
                sk_leave_memory_pressure(sk);
                return 1;
        }
 
-       /* Under pressure. (we or our parents) */
-       if ((parent_status > SOFT_LIMIT) ||
-                       allocated > sk_prot_mem_limits(sk, 1))
+       /* Under pressure. */
+       if (allocated > sk_prot_mem_limits(sk, 1))
                sk_enter_memory_pressure(sk);
 
-       /* Over hard limit (we or our parents) */
-       if ((parent_status == OVER_LIMIT) ||
-                       (allocated > sk_prot_mem_limits(sk, 2)))
+       /* Over hard limit. */
+       if (allocated > sk_prot_mem_limits(sk, 2))
                goto suppress_allocation;
 
        /* guarantee minimum buffer size under pressure */
@@ -2153,6 +2153,9 @@ suppress_allocation:
 
        sk_memory_allocated_sub(sk, amt);
 
+       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+               mem_cgroup_uncharge_skmem(sk->sk_cgrp, amt);
+
        return 0;
 }
 EXPORT_SYMBOL(__sk_mem_schedule);
@@ -2168,6 +2171,9 @@ void __sk_mem_reclaim(struct sock *sk, int amount)
        sk_memory_allocated_sub(sk, amount);
        sk->sk_forward_alloc -= amount << SK_MEM_QUANTUM_SHIFT;
 
+       if (mem_cgroup_sockets_enabled && sk->sk_cgrp)
+               mem_cgroup_uncharge_skmem(sk->sk_cgrp, amount);
+
        if (sk_under_memory_pressure(sk) &&
            (sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
                sk_leave_memory_pressure(sk);