oom_reaper: close race with exiting task
[cascardo/linux.git] / mm / oom_kill.c
index 415f7eb..dfb1ab6 100644 (file)
@@ -174,8 +174,13 @@ unsigned long oom_badness(struct task_struct *p, struct mem_cgroup *memcg,
        if (!p)
                return 0;
 
+       /*
+        * Do not even consider tasks which are explicitly marked oom
+        * unkillable or have already been oom reaped.
+        */
        adj = (long)p->signal->oom_score_adj;
-       if (adj == OOM_SCORE_ADJ_MIN) {
+       if (adj == OOM_SCORE_ADJ_MIN ||
+                       test_bit(MMF_OOM_REAPED, &p->mm->flags)) {
                task_unlock(p);
                return 0;
        }
@@ -278,12 +283,8 @@ enum oom_scan_t oom_scan_process_thread(struct oom_control *oc,
         * This task already has access to memory reserves and is being killed.
         * Don't allow any other task to have access to the reserves.
         */
-       if (test_tsk_thread_flag(task, TIF_MEMDIE)) {
-               if (!is_sysrq_oom(oc))
-                       return OOM_SCAN_ABORT;
-       }
-       if (!task->mm)
-               return OOM_SCAN_CONTINUE;
+       if (!is_sysrq_oom(oc) && atomic_read(&task->signal->oom_victims))
+               return OOM_SCAN_ABORT;
 
        /*
         * If task is allocating a lot of memory and has been marked to be
@@ -302,12 +303,12 @@ enum oom_scan_t oom_scan_process_thread(struct oom_control *oc,
 static struct task_struct *select_bad_process(struct oom_control *oc,
                unsigned int *ppoints, unsigned long totalpages)
 {
-       struct task_struct *g, *p;
+       struct task_struct *p;
        struct task_struct *chosen = NULL;
        unsigned long chosen_points = 0;
 
        rcu_read_lock();
-       for_each_process_thread(g, p) {
+       for_each_process(p) {
                unsigned int points;
 
                switch (oom_scan_process_thread(oc, p, totalpages)) {
@@ -326,9 +327,6 @@ static struct task_struct *select_bad_process(struct oom_control *oc,
                points = oom_badness(p, NULL, oc->nodemask, totalpages);
                if (!points || points < chosen_points)
                        continue;
-               /* Prefer thread group leaders for display purposes */
-               if (points == chosen_points && thread_group_leader(chosen))
-                       continue;
 
                chosen = p;
                chosen_points = points;
@@ -441,17 +439,32 @@ static DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
 static struct task_struct *oom_reaper_list;
 static DEFINE_SPINLOCK(oom_reaper_lock);
 
-
 static bool __oom_reap_task(struct task_struct *tsk)
 {
        struct mmu_gather tlb;
        struct vm_area_struct *vma;
-       struct mm_struct *mm;
+       struct mm_struct *mm = NULL;
        struct task_struct *p;
        struct zap_details details = {.check_swap_entries = true,
                                      .ignore_dirty = true};
        bool ret = true;
 
+       /*
+        * We have to make sure to not race with the victim exit path
+        * and cause premature new oom victim selection:
+        * __oom_reap_task              exit_mm
+        *   atomic_inc_not_zero
+        *                                mmput
+        *                                  atomic_dec_and_test
+        *                                exit_oom_victim
+        *                              [...]
+        *                              out_of_memory
+        *                                select_bad_process
+        *                                  # no TIF_MEMDIE task selects new victim
+        *  unmap_page_range # frees some memory
+        */
+       mutex_lock(&oom_lock);
+
        /*
         * Make sure we find the associated mm_struct even when the particular
         * thread has already terminated and cleared its mm.
@@ -460,19 +473,19 @@ static bool __oom_reap_task(struct task_struct *tsk)
         */
        p = find_lock_task_mm(tsk);
        if (!p)
-               return true;
+               goto unlock_oom;
 
        mm = p->mm;
        if (!atomic_inc_not_zero(&mm->mm_users)) {
                task_unlock(p);
-               return true;
+               goto unlock_oom;
        }
 
        task_unlock(p);
 
        if (!down_read_trylock(&mm->mmap_sem)) {
                ret = false;
-               goto out;
+               goto unlock_oom;
        }
 
        tlb_gather_mmu(&tlb, mm, 0, -1);
@@ -513,9 +526,16 @@ static bool __oom_reap_task(struct task_struct *tsk)
         * This task can be safely ignored because we cannot do much more
         * to release its memory.
         */
-       tsk->signal->oom_score_adj = OOM_SCORE_ADJ_MIN;
-out:
-       mmput(mm);
+       set_bit(MMF_OOM_REAPED, &mm->flags);
+unlock_oom:
+       mutex_unlock(&oom_lock);
+       /*
+        * Drop our reference but make sure the mmput slow path is called from a
+        * different context because we shouldn't risk getting stuck there and
+        * putting the oom_reaper out of the way.
+        */
+       if (mm)
+               mmput_async(mm);
        return ret;
 }
 
@@ -609,8 +629,6 @@ void try_oom_reaper(struct task_struct *tsk)
 
                        if (!process_shares_mm(p, mm))
                                continue;
-                       if (same_thread_group(p, tsk))
-                               continue;
                        if (fatal_signal_pending(p))
                                continue;
 
@@ -664,6 +682,7 @@ void mark_oom_victim(struct task_struct *tsk)
        /* OOM killer might race with memcg OOM */
        if (test_and_set_tsk_thread_flag(tsk, TIF_MEMDIE))
                return;
+       atomic_inc(&tsk->signal->oom_victims);
        /*
         * Make sure that the task is woken up from uninterruptible sleep
         * if it is frozen because OOM killer wouldn't be able to free
@@ -681,6 +700,7 @@ void exit_oom_victim(struct task_struct *tsk)
 {
        if (!test_and_clear_tsk_thread_flag(tsk, TIF_MEMDIE))
                return;
+       atomic_dec(&tsk->signal->oom_victims);
 
        if (!atomic_dec_return(&oom_victims))
                wake_up_all(&oom_victims_wait);