Merge tag 'for-linus-20130301' of git://git.infradead.org/linux-mtd
[cascardo/linux.git] / mm / ksm.c
index df05299..85bfd4c 100644 (file)
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -87,6 +87,9 @@
  *    take 10 attempts to find a page in the unstable tree, once it is found,
  *    it is secured in the stable tree.  (When we scan a new page, we first
  *    compare it against the stable tree, and then against the unstable tree.)
+ *
+ * If the merge_across_nodes tunable is unset, then KSM maintains multiple
+ * stable trees and multiple unstable trees: one of each for each NUMA node.
  */
 
 /**
@@ -122,36 +125,50 @@ struct ksm_scan {
 /**
  * struct stable_node - node of the stable rbtree
  * @node: rb node of this ksm page in the stable tree
+ * @head: (overlaying parent) &migrate_nodes indicates temporarily on that list
+ * @list: linked into migrate_nodes, pending placement in the proper node tree
  * @hlist: hlist head of rmap_items using this ksm page
- * @kpfn: page frame number of this ksm page
+ * @kpfn: page frame number of this ksm page (perhaps temporarily on wrong nid)
+ * @nid: NUMA node id of stable tree in which linked (may not match kpfn)
  */
 struct stable_node {
-       struct rb_node node;
+       union {
+               struct rb_node node;    /* when node of stable tree */
+               struct {                /* when listed for migration */
+                       struct list_head *head;
+                       struct list_head list;
+               };
+       };
        struct hlist_head hlist;
        unsigned long kpfn;
+#ifdef CONFIG_NUMA
+       int nid;
+#endif
 };
 
 /**
  * struct rmap_item - reverse mapping item for virtual addresses
  * @rmap_list: next rmap_item in mm_slot's singly-linked rmap_list
  * @anon_vma: pointer to anon_vma for this mm,address, when in stable tree
+ * @nid: NUMA node id of unstable tree in which linked (may not match page)
  * @mm: the memory structure this rmap_item is pointing into
  * @address: the virtual address this rmap_item tracks (+ flags in low bits)
  * @oldchecksum: previous checksum of the page at that virtual address
- * @nid: NUMA node id of unstable tree in which linked (may not match page)
  * @node: rb node of this rmap_item in the unstable tree
  * @head: pointer to stable_node heading this list in the stable tree
  * @hlist: link into hlist of rmap_items hanging off that stable_node
  */
 struct rmap_item {
        struct rmap_item *rmap_list;
-       struct anon_vma *anon_vma;      /* when stable */
+       union {
+               struct anon_vma *anon_vma;      /* when stable */
+#ifdef CONFIG_NUMA
+               int nid;                /* when node of unstable tree */
+#endif
+       };
        struct mm_struct *mm;
        unsigned long address;          /* + low bits used for flags below */
        unsigned int oldchecksum;       /* when unstable */
-#ifdef CONFIG_NUMA
-       int nid;
-#endif
        union {
                struct rb_node node;    /* when node of unstable tree */
                struct {                /* when listed from stable tree */
@@ -166,8 +183,13 @@ struct rmap_item {
 #define STABLE_FLAG    0x200   /* is listed from the stable tree */
 
 /* The stable and unstable tree heads */
-static struct rb_root root_unstable_tree[MAX_NUMNODES];
-static struct rb_root root_stable_tree[MAX_NUMNODES];
+static struct rb_root one_stable_tree[1] = { RB_ROOT };
+static struct rb_root one_unstable_tree[1] = { RB_ROOT };
+static struct rb_root *root_stable_tree = one_stable_tree;
+static struct rb_root *root_unstable_tree = one_unstable_tree;
+
+/* Recently migrated nodes of stable tree, pending proper placement */
+static LIST_HEAD(migrate_nodes);
 
 #define MM_SLOTS_HASH_BITS 10
 static DEFINE_HASHTABLE(mm_slots_hash, MM_SLOTS_HASH_BITS);
@@ -204,14 +226,18 @@ static unsigned int ksm_thread_sleep_millisecs = 20;
 #ifdef CONFIG_NUMA
 /* Zeroed when merging across nodes is not allowed */
 static unsigned int ksm_merge_across_nodes = 1;
+static int ksm_nr_node_ids = 1;
 #else
 #define ksm_merge_across_nodes 1U
+#define ksm_nr_node_ids                1
 #endif
 
 #define KSM_RUN_STOP   0
 #define KSM_RUN_MERGE  1
 #define KSM_RUN_UNMERGE        2
-static unsigned int ksm_run = KSM_RUN_STOP;
+#define KSM_RUN_OFFLINE        4
+static unsigned long ksm_run = KSM_RUN_STOP;
+static void wait_while_offlining(void);
 
 static DECLARE_WAIT_QUEUE_HEAD(ksm_thread_wait);
 static DEFINE_MUTEX(ksm_thread_mutex);
@@ -294,10 +320,9 @@ static inline void free_mm_slot(struct mm_slot *mm_slot)
 
 static struct mm_slot *get_mm_slot(struct mm_struct *mm)
 {
-       struct hlist_node *node;
        struct mm_slot *slot;
 
-       hash_for_each_possible(mm_slots_hash, slot, node, link, (unsigned long)mm)
+       hash_for_each_possible(mm_slots_hash, slot, link, (unsigned long)mm)
                if (slot->mm == mm)
                        return slot;
 
@@ -311,11 +336,6 @@ static void insert_to_mm_slots_hash(struct mm_struct *mm,
        hash_add(mm_slots_hash, &mm_slot->link, (unsigned long)mm);
 }
 
-static inline int in_stable_tree(struct rmap_item *rmap_item)
-{
-       return rmap_item->address & STABLE_FLAG;
-}
-
 /*
  * ksmd, and unmerge_and_remove_all_rmap_items(), must not touch an mm's
  * page tables after it has passed through ksm_exit() - which, if necessary,
@@ -347,7 +367,7 @@ static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
 
        do {
                cond_resched();
-               page = follow_page(vma, addr, FOLL_GET);
+               page = follow_page(vma, addr, FOLL_GET | FOLL_MIGRATION);
                if (IS_ERR_OR_NULL(page))
                        break;
                if (PageKsm(page))
@@ -475,10 +495,8 @@ static inline int get_kpfn_nid(unsigned long kpfn)
 static void remove_node_from_stable_tree(struct stable_node *stable_node)
 {
        struct rmap_item *rmap_item;
-       struct hlist_node *hlist;
-       int nid;
 
-       hlist_for_each_entry(rmap_item, hlist, &stable_node->hlist, hlist) {
+       hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
                if (rmap_item->hlist.next)
                        ksm_pages_sharing--;
                else
@@ -488,8 +506,11 @@ static void remove_node_from_stable_tree(struct stable_node *stable_node)
                cond_resched();
        }
 
-       nid = get_kpfn_nid(stable_node->kpfn);
-       rb_erase(&stable_node->node, &root_stable_tree[nid]);
+       if (stable_node->head == &migrate_nodes)
+               list_del(&stable_node->list);
+       else
+               rb_erase(&stable_node->node,
+                        root_stable_tree + NUMA(stable_node->nid));
        free_stable_node(stable_node);
 }
 
@@ -512,7 +533,7 @@ static void remove_node_from_stable_tree(struct stable_node *stable_node)
  * a page to put something that might look like our key in page->mapping.
  * is on its way to being freed; but it is an anomaly to bear in mind.
  */
-static struct page *get_ksm_page(struct stable_node *stable_node, bool locked)
+static struct page *get_ksm_page(struct stable_node *stable_node, bool lock_it)
 {
        struct page *page;
        void *expected_mapping;
@@ -561,7 +582,7 @@ again:
                goto stale;
        }
 
-       if (locked) {
+       if (lock_it) {
                lock_page(page);
                if (ACCESS_ONCE(page->mapping) != expected_mapping) {
                        unlock_page(page);
@@ -625,7 +646,7 @@ static void remove_rmap_item_from_tree(struct rmap_item *rmap_item)
                BUG_ON(age > 1);
                if (!age)
                        rb_erase(&rmap_item->node,
-                                &root_unstable_tree[NUMA(rmap_item->nid)]);
+                                root_unstable_tree + NUMA(rmap_item->nid));
                ksm_pages_unshared--;
                rmap_item->address &= PAGE_MASK;
        }
@@ -691,10 +712,17 @@ static int remove_stable_node(struct stable_node *stable_node)
                return 0;
        }
 
-       if (WARN_ON_ONCE(page_mapped(page)))
+       if (WARN_ON_ONCE(page_mapped(page))) {
+               /*
+                * This should not happen: but if it does, just refuse to let
+                * merge_across_nodes be switched - there is no need to panic.
+                */
                err = -EBUSY;
-       else {
+       } else {
                /*
+                * The stable node did not yet appear stale to get_ksm_page(),
+                * since that allows for an unmapped ksm page to be recognized
+                * right up until it is freed; but the node is safe to remove.
                 * This page might be in a pagevec waiting to be freed,
                 * or it might be PageSwapCache (perhaps under writeback),
                 * or it might have been removed from swapcache a moment ago.
@@ -712,10 +740,11 @@ static int remove_stable_node(struct stable_node *stable_node)
 static int remove_all_stable_nodes(void)
 {
        struct stable_node *stable_node;
+       struct list_head *this, *next;
        int nid;
        int err = 0;
 
-       for (nid = 0; nid < nr_node_ids; nid++) {
+       for (nid = 0; nid < ksm_nr_node_ids; nid++) {
                while (root_stable_tree[nid].rb_node) {
                        stable_node = rb_entry(root_stable_tree[nid].rb_node,
                                                struct stable_node, node);
@@ -726,6 +755,12 @@ static int remove_all_stable_nodes(void)
                        cond_resched();
                }
        }
+       list_for_each_safe(this, next, &migrate_nodes) {
+               stable_node = list_entry(this, struct stable_node, list);
+               if (remove_stable_node(stable_node))
+                       err = -EBUSY;
+               cond_resched();
+       }
        return err;
 }
 
@@ -1063,6 +1098,9 @@ static int try_to_merge_with_ksm_page(struct rmap_item *rmap_item,
        if (err)
                goto out;
 
+       /* Unstable nid is in union with stable anon_vma: remove first */
+       remove_rmap_item_from_tree(rmap_item);
+
        /* Must get reference to anon_vma while still holding mmap_sem */
        rmap_item->anon_vma = vma->anon_vma;
        get_anon_vma(vma->anon_vma);
@@ -1113,25 +1151,32 @@ static struct page *try_to_merge_two_pages(struct rmap_item *rmap_item,
  */
 static struct page *stable_tree_search(struct page *page)
 {
-       struct rb_node *node;
-       struct stable_node *stable_node;
        int nid;
+       struct rb_root *root;
+       struct rb_node **new;
+       struct rb_node *parent;
+       struct stable_node *stable_node;
+       struct stable_node *page_node;
 
-       stable_node = page_stable_node(page);
-       if (stable_node) {                      /* ksm page forked */
+       page_node = page_stable_node(page);
+       if (page_node && page_node->head != &migrate_nodes) {
+               /* ksm page forked */
                get_page(page);
                return page;
        }
 
        nid = get_kpfn_nid(page_to_pfn(page));
-       node = root_stable_tree[nid].rb_node;
+       root = root_stable_tree + nid;
+again:
+       new = &root->rb_node;
+       parent = NULL;
 
-       while (node) {
+       while (*new) {
                struct page *tree_page;
                int ret;
 
                cond_resched();
-               stable_node = rb_entry(node, struct stable_node, node);
+               stable_node = rb_entry(*new, struct stable_node, node);
                tree_page = get_ksm_page(stable_node, false);
                if (!tree_page)
                        return NULL;
@@ -1139,10 +1184,11 @@ static struct page *stable_tree_search(struct page *page)
                ret = memcmp_pages(page, tree_page);
                put_page(tree_page);
 
+               parent = *new;
                if (ret < 0)
-                       node = node->rb_left;
+                       new = &parent->rb_left;
                else if (ret > 0)
-                       node = node->rb_right;
+                       new = &parent->rb_right;
                else {
                        /*
                         * Lock and unlock the stable_node's page (which
@@ -1152,13 +1198,48 @@ static struct page *stable_tree_search(struct page *page)
                         * than kpage, but that involves more changes.
                         */
                        tree_page = get_ksm_page(stable_node, true);
-                       if (tree_page)
+                       if (tree_page) {
                                unlock_page(tree_page);
-                       return tree_page;
+                               if (get_kpfn_nid(stable_node->kpfn) !=
+                                               NUMA(stable_node->nid)) {
+                                       put_page(tree_page);
+                                       goto replace;
+                               }
+                               return tree_page;
+                       }
+                       /*
+                        * There is now a place for page_node, but the tree may
+                        * have been rebalanced, so re-evaluate parent and new.
+                        */
+                       if (page_node)
+                               goto again;
+                       return NULL;
                }
        }
 
-       return NULL;
+       if (!page_node)
+               return NULL;
+
+       list_del(&page_node->list);
+       DO_NUMA(page_node->nid = nid);
+       rb_link_node(&page_node->node, parent, new);
+       rb_insert_color(&page_node->node, root);
+       get_page(page);
+       return page;
+
+replace:
+       if (page_node) {
+               list_del(&page_node->list);
+               DO_NUMA(page_node->nid = nid);
+               rb_replace_node(&stable_node->node, &page_node->node, root);
+               get_page(page);
+       } else {
+               rb_erase(&stable_node->node, root);
+               page = NULL;
+       }
+       stable_node->head = &migrate_nodes;
+       list_add(&stable_node->list, stable_node->head);
+       return page;
 }
 
 /*
@@ -1172,13 +1253,15 @@ static struct stable_node *stable_tree_insert(struct page *kpage)
 {
        int nid;
        unsigned long kpfn;
+       struct rb_root *root;
        struct rb_node **new;
        struct rb_node *parent = NULL;
        struct stable_node *stable_node;
 
        kpfn = page_to_pfn(kpage);
        nid = get_kpfn_nid(kpfn);
-       new = &root_stable_tree[nid].rb_node;
+       root = root_stable_tree + nid;
+       new = &root->rb_node;
 
        while (*new) {
                struct page *tree_page;
@@ -1215,8 +1298,9 @@ static struct stable_node *stable_tree_insert(struct page *kpage)
        INIT_HLIST_HEAD(&stable_node->hlist);
        stable_node->kpfn = kpfn;
        set_page_stable_node(kpage, stable_node);
+       DO_NUMA(stable_node->nid = nid);
        rb_link_node(&stable_node->node, parent, new);
-       rb_insert_color(&stable_node->node, &root_stable_tree[nid]);
+       rb_insert_color(&stable_node->node, root);
 
        return stable_node;
 }
@@ -1246,7 +1330,7 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
        int nid;
 
        nid = get_kpfn_nid(page_to_pfn(page));
-       root = &root_unstable_tree[nid];
+       root = root_unstable_tree + nid;
        new = &root->rb_node;
 
        while (*new) {
@@ -1268,16 +1352,6 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
                        return NULL;
                }
 
-               /*
-                * If tree_page has been migrated to another NUMA node, it
-                * will be flushed out and put into the right unstable tree
-                * next time: only merge with it if merge_across_nodes.
-                */
-               if (!ksm_merge_across_nodes && page_to_nid(tree_page) != nid) {
-                       put_page(tree_page);
-                       return NULL;
-               }
-
                ret = memcmp_pages(page, tree_page);
 
                parent = *new;
@@ -1287,6 +1361,15 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
                } else if (ret > 0) {
                        put_page(tree_page);
                        new = &parent->rb_right;
+               } else if (!ksm_merge_across_nodes &&
+                          page_to_nid(tree_page) != nid) {
+                       /*
+                        * If tree_page has been migrated to another NUMA node,
+                        * it will be flushed out and put in the right unstable
+                        * tree next time: only merge with it when across_nodes.
+                        */
+                       put_page(tree_page);
+                       return NULL;
                } else {
                        *tree_pagep = tree_page;
                        return tree_rmap_item;
@@ -1311,11 +1394,6 @@ struct rmap_item *unstable_tree_search_insert(struct rmap_item *rmap_item,
 static void stable_tree_append(struct rmap_item *rmap_item,
                               struct stable_node *stable_node)
 {
-       /*
-        * Usually rmap_item->nid is already set correctly,
-        * but it may be wrong after switching merge_across_nodes.
-        */
-       DO_NUMA(rmap_item->nid = get_kpfn_nid(stable_node->kpfn));
        rmap_item->head = stable_node;
        rmap_item->address |= STABLE_FLAG;
        hlist_add_head(&rmap_item->hlist, &stable_node->hlist);
@@ -1344,10 +1422,29 @@ static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
        unsigned int checksum;
        int err;
 
-       remove_rmap_item_from_tree(rmap_item);
+       stable_node = page_stable_node(page);
+       if (stable_node) {
+               if (stable_node->head != &migrate_nodes &&
+                   get_kpfn_nid(stable_node->kpfn) != NUMA(stable_node->nid)) {
+                       rb_erase(&stable_node->node,
+                                root_stable_tree + NUMA(stable_node->nid));
+                       stable_node->head = &migrate_nodes;
+                       list_add(&stable_node->list, stable_node->head);
+               }
+               if (stable_node->head != &migrate_nodes &&
+                   rmap_item->head == stable_node)
+                       return;
+       }
 
        /* We first start with searching the page inside the stable tree */
        kpage = stable_tree_search(page);
+       if (kpage == page && rmap_item->head == stable_node) {
+               put_page(kpage);
+               return;
+       }
+
+       remove_rmap_item_from_tree(rmap_item);
+
        if (kpage) {
                err = try_to_merge_with_ksm_page(rmap_item, page, kpage);
                if (!err) {
@@ -1381,14 +1478,11 @@ static void cmp_and_merge_page(struct page *page, struct rmap_item *rmap_item)
                kpage = try_to_merge_two_pages(rmap_item, page,
                                                tree_rmap_item, tree_page);
                put_page(tree_page);
-               /*
-                * As soon as we merge this page, we want to remove the
-                * rmap_item of the page we have merged with from the unstable
-                * tree, and insert it instead as new node in the stable tree.
-                */
                if (kpage) {
-                       remove_rmap_item_from_tree(tree_rmap_item);
-
+                       /*
+                        * The pages were successfully merged: insert new
+                        * node in the stable tree and add both rmap_items.
+                        */
                        lock_page(kpage);
                        stable_node = stable_tree_insert(kpage);
                        if (stable_node) {
@@ -1464,7 +1558,28 @@ static struct rmap_item *scan_get_next_rmap_item(struct page **page)
                 */
                lru_add_drain_all();
 
-               for (nid = 0; nid < nr_node_ids; nid++)
+               /*
+                * Whereas stale stable_nodes on the stable_tree itself
+                * get pruned in the regular course of stable_tree_search(),
+                * those moved out to the migrate_nodes list can accumulate:
+                * so prune them once before each full scan.
+                */
+               if (!ksm_merge_across_nodes) {
+                       struct stable_node *stable_node;
+                       struct list_head *this, *next;
+                       struct page *page;
+
+                       list_for_each_safe(this, next, &migrate_nodes) {
+                               stable_node = list_entry(this,
+                                               struct stable_node, list);
+                               page = get_ksm_page(stable_node, false);
+                               if (page)
+                                       put_page(page);
+                               cond_resched();
+                       }
+               }
+
+               for (nid = 0; nid < ksm_nr_node_ids; nid++)
                        root_unstable_tree[nid] = RB_ROOT;
 
                spin_lock(&ksm_mmlist_lock);
@@ -1586,8 +1701,7 @@ static void ksm_do_scan(unsigned int scan_npages)
                rmap_item = scan_get_next_rmap_item(&page);
                if (!rmap_item)
                        return;
-               if (!PageKsm(page) || !in_stable_tree(rmap_item))
-                       cmp_and_merge_page(page, rmap_item);
+               cmp_and_merge_page(page, rmap_item);
                put_page(page);
        }
 }
@@ -1604,6 +1718,7 @@ static int ksm_scan_thread(void *nothing)
 
        while (!kthread_should_stop()) {
                mutex_lock(&ksm_thread_mutex);
+               wait_while_offlining();
                if (ksmd_should_run())
                        ksm_do_scan(ksm_thread_pages_to_scan);
                mutex_unlock(&ksm_thread_mutex);
@@ -1781,7 +1896,6 @@ int page_referenced_ksm(struct page *page, struct mem_cgroup *memcg,
 {
        struct stable_node *stable_node;
        struct rmap_item *rmap_item;
-       struct hlist_node *hlist;
        unsigned int mapcount = page_mapcount(page);
        int referenced = 0;
        int search_new_forks = 0;
@@ -1793,7 +1907,7 @@ int page_referenced_ksm(struct page *page, struct mem_cgroup *memcg,
        if (!stable_node)
                return 0;
 again:
-       hlist_for_each_entry(rmap_item, hlist, &stable_node->hlist, hlist) {
+       hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
                struct anon_vma *anon_vma = rmap_item->anon_vma;
                struct anon_vma_chain *vmac;
                struct vm_area_struct *vma;
@@ -1835,7 +1949,6 @@ out:
 int try_to_unmap_ksm(struct page *page, enum ttu_flags flags)
 {
        struct stable_node *stable_node;
-       struct hlist_node *hlist;
        struct rmap_item *rmap_item;
        int ret = SWAP_AGAIN;
        int search_new_forks = 0;
@@ -1847,7 +1960,7 @@ int try_to_unmap_ksm(struct page *page, enum ttu_flags flags)
        if (!stable_node)
                return SWAP_FAIL;
 again:
-       hlist_for_each_entry(rmap_item, hlist, &stable_node->hlist, hlist) {
+       hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
                struct anon_vma *anon_vma = rmap_item->anon_vma;
                struct anon_vma_chain *vmac;
                struct vm_area_struct *vma;
@@ -1888,7 +2001,6 @@ int rmap_walk_ksm(struct page *page, int (*rmap_one)(struct page *,
                  struct vm_area_struct *, unsigned long, void *), void *arg)
 {
        struct stable_node *stable_node;
-       struct hlist_node *hlist;
        struct rmap_item *rmap_item;
        int ret = SWAP_AGAIN;
        int search_new_forks = 0;
@@ -1900,7 +2012,7 @@ int rmap_walk_ksm(struct page *page, int (*rmap_one)(struct page *,
        if (!stable_node)
                return ret;
 again:
-       hlist_for_each_entry(rmap_item, hlist, &stable_node->hlist, hlist) {
+       hlist_for_each_entry(rmap_item, &stable_node->hlist, hlist) {
                struct anon_vma *anon_vma = rmap_item->anon_vma;
                struct anon_vma_chain *vmac;
                struct vm_area_struct *vma;
@@ -1960,15 +2072,32 @@ void ksm_migrate_page(struct page *newpage, struct page *oldpage)
 #endif /* CONFIG_MIGRATION */
 
 #ifdef CONFIG_MEMORY_HOTREMOVE
+static int just_wait(void *word)
+{
+       schedule();
+       return 0;
+}
+
+static void wait_while_offlining(void)
+{
+       while (ksm_run & KSM_RUN_OFFLINE) {
+               mutex_unlock(&ksm_thread_mutex);
+               wait_on_bit(&ksm_run, ilog2(KSM_RUN_OFFLINE),
+                               just_wait, TASK_UNINTERRUPTIBLE);
+               mutex_lock(&ksm_thread_mutex);
+       }
+}
+
 static void ksm_check_stable_tree(unsigned long start_pfn,
                                  unsigned long end_pfn)
 {
        struct stable_node *stable_node;
+       struct list_head *this, *next;
        struct rb_node *node;
        int nid;
 
-       for (nid = 0; nid < nr_node_ids; nid++) {
-               node = rb_first(&root_stable_tree[nid]);
+       for (nid = 0; nid < ksm_nr_node_ids; nid++) {
+               node = rb_first(root_stable_tree + nid);
                while (node) {
                        stable_node = rb_entry(node, struct stable_node, node);
                        if (stable_node->kpfn >= start_pfn &&
@@ -1978,12 +2107,19 @@ static void ksm_check_stable_tree(unsigned long start_pfn,
                                 * which is why we keep kpfn instead of page*
                                 */
                                remove_node_from_stable_tree(stable_node);
-                               node = rb_first(&root_stable_tree[nid]);
+                               node = rb_first(root_stable_tree + nid);
                        } else
                                node = rb_next(node);
                        cond_resched();
                }
        }
+       list_for_each_safe(this, next, &migrate_nodes) {
+               stable_node = list_entry(this, struct stable_node, list);
+               if (stable_node->kpfn >= start_pfn &&
+                   stable_node->kpfn < end_pfn)
+                       remove_node_from_stable_tree(stable_node);
+               cond_resched();
+       }
 }
 
 static int ksm_memory_callback(struct notifier_block *self,
@@ -1994,15 +2130,15 @@ static int ksm_memory_callback(struct notifier_block *self,
        switch (action) {
        case MEM_GOING_OFFLINE:
                /*
-                * Keep it very simple for now: just lock out ksmd and
-                * MADV_UNMERGEABLE while any memory is going offline.
-                * mutex_lock_nested() is necessary because lockdep was alarmed
-                * that here we take ksm_thread_mutex inside notifier chain
-                * mutex, and later take notifier chain mutex inside
-                * ksm_thread_mutex to unlock it.   But that's safe because both
-                * are inside mem_hotplug_mutex.
+                * Prevent ksm_do_scan(), unmerge_and_remove_all_rmap_items()
+                * and remove_all_stable_nodes() while memory is going offline:
+                * it is unsafe for them to touch the stable tree at this time.
+                * But unmerge_ksm_pages(), rmap lookups and other entry points
+                * which do not need the ksm_thread_mutex are all safe.
                 */
-               mutex_lock_nested(&ksm_thread_mutex, SINGLE_DEPTH_NESTING);
+               mutex_lock(&ksm_thread_mutex);
+               ksm_run |= KSM_RUN_OFFLINE;
+               mutex_unlock(&ksm_thread_mutex);
                break;
 
        case MEM_OFFLINE:
@@ -2018,11 +2154,20 @@ static int ksm_memory_callback(struct notifier_block *self,
                /* fallthrough */
 
        case MEM_CANCEL_OFFLINE:
+               mutex_lock(&ksm_thread_mutex);
+               ksm_run &= ~KSM_RUN_OFFLINE;
                mutex_unlock(&ksm_thread_mutex);
+
+               smp_mb();       /* wake_up_bit advises this */
+               wake_up_bit(&ksm_run, ilog2(KSM_RUN_OFFLINE));
                break;
        }
        return NOTIFY_OK;
 }
+#else
+static void wait_while_offlining(void)
+{
+}
 #endif /* CONFIG_MEMORY_HOTREMOVE */
 
 #ifdef CONFIG_SYSFS
@@ -2085,7 +2230,7 @@ KSM_ATTR(pages_to_scan);
 static ssize_t run_show(struct kobject *kobj, struct kobj_attribute *attr,
                        char *buf)
 {
-       return sprintf(buf, "%u\n", ksm_run);
+       return sprintf(buf, "%lu\n", ksm_run);
 }
 
 static ssize_t run_store(struct kobject *kobj, struct kobj_attribute *attr,
@@ -2108,6 +2253,7 @@ static ssize_t run_store(struct kobject *kobj, struct kobj_attribute *attr,
         */
 
        mutex_lock(&ksm_thread_mutex);
+       wait_while_offlining();
        if (ksm_run != flags) {
                ksm_run = flags;
                if (flags & KSM_RUN_UNMERGE) {
@@ -2150,11 +2296,35 @@ static ssize_t merge_across_nodes_store(struct kobject *kobj,
                return -EINVAL;
 
        mutex_lock(&ksm_thread_mutex);
+       wait_while_offlining();
        if (ksm_merge_across_nodes != knob) {
                if (ksm_pages_shared || remove_all_stable_nodes())
                        err = -EBUSY;
-               else
+               else if (root_stable_tree == one_stable_tree) {
+                       struct rb_root *buf;
+                       /*
+                        * This is the first time that we switch away from the
+                        * default of merging across nodes: must now allocate
+                        * a buffer to hold as many roots as may be needed.
+                        * Allocate stable and unstable together:
+                        * MAXSMP NODES_SHIFT 10 will use 16kB.
+                        */
+                       buf = kcalloc(nr_node_ids + nr_node_ids,
+                               sizeof(*buf), GFP_KERNEL | __GFP_ZERO);
+                       /* Let us assume that RB_ROOT is NULL is zero */
+                       if (!buf)
+                               err = -ENOMEM;
+                       else {
+                               root_stable_tree = buf;
+                               root_unstable_tree = buf + nr_node_ids;
+                               /* Stable tree is empty but not the unstable */
+                               root_unstable_tree[0] = one_unstable_tree[0];
+                       }
+               }
+               if (!err) {
                        ksm_merge_across_nodes = knob;
+                       ksm_nr_node_ids = knob ? 1 : nr_node_ids;
+               }
        }
        mutex_unlock(&ksm_thread_mutex);
 
@@ -2233,15 +2403,11 @@ static int __init ksm_init(void)
 {
        struct task_struct *ksm_thread;
        int err;
-       int nid;
 
        err = ksm_slab_init();
        if (err)
                goto out;
 
-       for (nid = 0; nid < nr_node_ids; nid++)
-               root_stable_tree[nid] = RB_ROOT;
-
        ksm_thread = kthread_run(ksm_scan_thread, NULL, "ksmd");
        if (IS_ERR(ksm_thread)) {
                printk(KERN_ERR "ksm: creating kthread failed\n");
@@ -2262,10 +2428,7 @@ static int __init ksm_init(void)
 #endif /* CONFIG_SYSFS */
 
 #ifdef CONFIG_MEMORY_HOTREMOVE
-       /*
-        * Choose a high priority since the callback takes ksm_thread_mutex:
-        * later callbacks could only be taking locks which nest within that.
-        */
+       /* There is no significance to this priority 100 */
        hotplug_memory_notifier(ksm_memory_callback, 100);
 #endif
        return 0;