mlock: only hold mmap_sem in shared mode when faulting in pages
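
Currently mlock() faults in the requested pages while holding mmap_sem for
writing, so a large mlock can block every other mmap_sem user in the process
for the whole duration. Split the work instead: hold mmap_sem exclusively
only while the VM_LOCKED flags are being updated, then fault the pages in
through a new helper, do_mlock_pages(), which runs with mmap_sem held in
shared (read) mode. mlockall(MCL_CURRENT) uses the same helper with errors
ignored. While at it, stop write-faulting shared mappings in
__mlock_vma_pages_range(): shared mappings do not COW, so dirtying their
pages buys nothing.

---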
diff --git a/mm/mlock.c b/mm/mlock.c
index b70919c..67b3dd8 100644
--- a/mm/mlock.c
+++ b/mm/mlock.c
@@ -171,7 +171,12 @@ static long __mlock_vma_pages_range(struct vm_area_struct *vma,
        VM_BUG_ON(!rwsem_is_locked(&mm->mmap_sem));
 
        gup_flags = FOLL_TOUCH | FOLL_GET;
-       if (vma->vm_flags & VM_WRITE)
+       /*
+        * We want to touch writable mappings with a write fault in order
+        * to break COW, except for shared mappings because these don't COW
+        * and we would not want to dirty them for nothing.
+        */
+       if ((vma->vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE)
                gup_flags |= FOLL_WRITE;
 
        /* We don't try to access the guard page of a stack vma */
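
The new test reads as "writable and not shared". A minimal userspace sketch
of the predicate, with flag values mirroring include/linux/mm.h of this era
(illustrative only, not part of the patch):

    /* Which mappings get FOLL_WRITE under the new test. */
    #include <stdio.h>

    #define VM_WRITE  0x00000002UL   /* values mirror the kernel's */
    #define VM_SHARED 0x00000008UL

    static int wants_write_fault(unsigned long vm_flags)
    {
            /* Private writable mappings are write-faulted to break COW;
             * shared mappings don't COW, so don't dirty them for nothing. */
            return (vm_flags & (VM_WRITE | VM_SHARED)) == VM_WRITE;
    }

    int main(void)
    {
            printf("private writable: %d\n", wants_write_fault(VM_WRITE));
            printf("shared writable:  %d\n",
                   wants_write_fault(VM_WRITE | VM_SHARED));
            printf("read-only:        %d\n", wants_write_fault(0));
            return 0;
    }
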
@@ -372,17 +377,9 @@ static int mlock_fixup(struct vm_area_struct *vma, struct vm_area_struct **prev,
        int ret = 0;
        int lock = newflags & VM_LOCKED;
 
-       if (newflags == vma->vm_flags ||
-                       (vma->vm_flags & (VM_IO | VM_PFNMAP)))
-               goto out;       /* don't set VM_LOCKED,  don't count */
-
-       if ((vma->vm_flags & (VM_DONTEXPAND | VM_RESERVED)) ||
-                       is_vm_hugetlb_page(vma) ||
-                       vma == get_gate_vma(current)) {
-               if (lock)
-                       make_pages_present(start, end);
+       if (newflags == vma->vm_flags || (vma->vm_flags & VM_SPECIAL) ||
+           is_vm_hugetlb_page(vma) || vma == get_gate_vma(current))
                goto out;       /* don't set VM_LOCKED,  don't count */
-       }
 
        pgoff = vma->vm_pgoff + ((start - vma->vm_start) >> PAGE_SHIFT);
        *prev = vma_merge(mm, *prev, start, end, newflags, vma->anon_vma,
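
For reference, VM_SPECIAL in include/linux/mm.h of this era expands to

    #define VM_SPECIAL (VM_IO | VM_DONTEXPAND | VM_RESERVED | VM_PFNMAP)

so the single test covers both of the removed checks. The
make_pages_present() call for the skipped VMA types is not lost either: it
moves into do_mlock_pages() below, which faults in non-VM_LOCKED ranges as
well.
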
@@ -419,14 +416,10 @@ success:
         * set VM_LOCKED, __mlock_vma_pages_range will bring it back.
         */
 
-       if (lock) {
+       if (lock)
                vma->vm_flags = newflags;
-               ret = __mlock_vma_pages_range(vma, start, end);
-               if (ret < 0)
-                       ret = __mlock_posix_error_return(ret);
-       } else {
+       else
                munlock_vma_pages_range(vma, start, end);
-       }
 
 out:
        *prev = vma;
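
With the faulting moved out, mlock_fixup() under its callers' down_write()
only updates vm_flags and the VMA layout; the actual page-in work is deferred
to do_mlock_pages() below, which runs under down_read().
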
@@ -439,7 +432,8 @@ static int do_mlock(unsigned long start, size_t len, int on)
        struct vm_area_struct * vma, * prev;
        int error;
 
-       len = PAGE_ALIGN(len);
+       VM_BUG_ON(start & ~PAGE_MASK);
+       VM_BUG_ON(len != PAGE_ALIGN(len));
        end = start + len;
        if (end < start)
                return -EINVAL;
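
The explicit PAGE_ALIGN() turns into assertions: the syscall entry points
already page-align their arguments before calling in (sys_mlock() in this
tree does "len = PAGE_ALIGN(len + (start & ~PAGE_MASK)); start &= PAGE_MASK;"),
and the new do_mlock_pages() below asserts the same invariant, so a
misaligned range here would indicate a kernel bug rather than something to
silently fix up.
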
@@ -482,6 +476,58 @@ static int do_mlock(unsigned long start, size_t len, int on)
        return error;
 }
 
+static int do_mlock_pages(unsigned long start, size_t len, int ignore_errors)
+{
+       struct mm_struct *mm = current->mm;
+       unsigned long end, nstart, nend;
+       struct vm_area_struct *vma = NULL;
+       int ret = 0;
+
+       VM_BUG_ON(start & ~PAGE_MASK);
+       VM_BUG_ON(len != PAGE_ALIGN(len));
+       end = start + len;
+
+       down_read(&mm->mmap_sem);
+       for (nstart = start; nstart < end; nstart = nend) {
+               /*
+                * We want to fault in pages for [nstart; end) address range.
+                * Find first corresponding VMA.
+                */
+               if (!vma)
+                       vma = find_vma(mm, nstart);
+               else
+                       vma = vma->vm_next;
+               if (!vma || vma->vm_start >= end)
+                       break;
+               /*
+                * Set [nstart; nend) to intersection of desired address
+                * range with the first VMA. Also, skip undesirable VMA types.
+                */
+               nend = min(end, vma->vm_end);
+               if (vma->vm_flags & (VM_IO | VM_PFNMAP))
+                       continue;
+               if (nstart < vma->vm_start)
+                       nstart = vma->vm_start;
+               /*
+                * Now fault in a range of pages within the first VMA.
+                */
+               if (vma->vm_flags & VM_LOCKED) {
+                       ret = __mlock_vma_pages_range(vma, nstart, nend);
+                       if (ret < 0 && ignore_errors) {
+                               ret = 0;
+                               continue;       /* continue at next VMA */
+                       }
+                       if (ret) {
+                               ret = __mlock_posix_error_return(ret);
+                               break;
+                       }
+               } else
+                       make_pages_present(nstart, nend);
+       }
+       up_read(&mm->mmap_sem);
+       return ret;     /* 0 or negative error code */
+}
+
 SYSCALL_DEFINE2(mlock, unsigned long, start, size_t, len)
 {
        unsigned long locked;
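
The do_mlock_pages() loop above walks the VMA list exactly once, clipping
[nstart; end) to each VMA and stepping over holes; find_vma() is only needed
on the first iteration because mmap_sem is held (in read mode) across the
whole walk, keeping the list stable. A standalone sketch of the same interval
walk over a generic range list (names and types are illustrative, not kernel
code):

    #include <stdio.h>

    struct range { unsigned long start, end; struct range *next; };

    /* Analogue of find_vma(): first range with end > addr. */
    static struct range *find_range(struct range *head, unsigned long addr)
    {
            while (head && head->end <= addr)
                    head = head->next;
            return head;
    }

    /* Visit the intersection of [start; end) with each range, the way the
     * do_mlock_pages() loop visits VMAs. */
    static void walk(struct range *head, unsigned long start, unsigned long end)
    {
            struct range *r = NULL;
            unsigned long nstart, nend;

            for (nstart = start; nstart < end; nstart = nend) {
                    r = r ? r->next : find_range(head, nstart);
                    if (!r || r->start >= end)
                            break;
                    nend = r->end < end ? r->end : end;
                    if (nstart < r->start)
                            nstart = r->start;   /* skip the hole before r */
                    printf("visit [%#lx; %#lx)\n", nstart, nend);
            }
    }

    int main(void)
    {
            struct range r2 = { 0x4000, 0x6000, NULL };
            struct range r1 = { 0x1000, 0x3000, &r2 };

            /* Visits [0x2000; 0x3000) and [0x4000; 0x5000). */
            walk(&r1, 0x2000, 0x5000);
            return 0;
    }
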
@@ -507,6 +553,8 @@ SYSCALL_DEFINE2(mlock, unsigned long, start, size_t, len)
        if ((locked <= lock_limit) || capable(CAP_IPC_LOCK))
                error = do_mlock(start, len, 1);
        up_write(&current->mm->mmap_sem);
+       if (!error)
+               error = do_mlock_pages(start, len, 0);
        return error;
 }
 
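The pages are now faulted in only after mmap_sem has been dropped and
re-taken for reading by do_mlock_pages(). As before (the old code also set
VM_LOCKED before faulting), a failure partway through leaves the flag set on
the affected VMAs; the error value itself is still produced by
__mlock_posix_error_return() inside the helper, so what userspace sees is
unchanged.
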
@@ -571,6 +619,10 @@ SYSCALL_DEFINE1(mlockall, int, flags)
            capable(CAP_IPC_LOCK))
                ret = do_mlockall(flags);
        up_write(&current->mm->mmap_sem);
+       if (!ret && (flags & MCL_CURRENT)) {
+               /* Ignore errors */
+               do_mlock_pages(0, TASK_SIZE, 1);
+       }
 out:
        return ret;
 }
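
For mlockall(MCL_CURRENT) the faulting errors are deliberately ignored, which
matches the existing per-VMA behaviour (do_mlockall() already calls
mlock_fixup() under an "Ignore errors" comment). A userspace sketch
exercising both paths; nothing here is specific to this patch, since only the
locking changes:

    #include <stdio.h>
    #include <stdlib.h>
    #include <sys/mman.h>
    #include <unistd.h>

    int main(void)
    {
            long page = sysconf(_SC_PAGESIZE);
            size_t len = 64 * (size_t)page;   /* keep under RLIMIT_MEMLOCK */
            void *p = mmap(NULL, len, PROT_READ | PROT_WRITE,
                           MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);

            if (p == MAP_FAILED || mlock(p, len)) {
                    perror("mmap/mlock");   /* ENOMEM/EPERM if over the limit */
                    return EXIT_FAILURE;
            }
            /* mlockall(MCL_CURRENT) takes the ignore_errors path added above. */
            if (mlockall(MCL_CURRENT))
                    perror("mlockall");
            munlockall();
            munmap(p, len);
            return EXIT_SUCCESS;
    }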