KVM: MMU: coalesce more page zapping in mmu_sync_children
[cascardo/linux.git] / arch / x86 / kvm / mmu.c
index 95a955d..754d2c4 100644 (file)
@@ -41,6 +41,7 @@
 #include <asm/cmpxchg.h>
 #include <asm/io.h>
 #include <asm/vmx.h>
+#include <asm/kvm_page_track.h>
 
 /*
  * When setting this variable to true it enables Two-Dimensional-Paging
@@ -776,62 +777,85 @@ static struct kvm_lpage_info *lpage_info_slot(gfn_t gfn,
        return &slot->arch.lpage_info[level - 2][idx];
 }
 
+static void update_gfn_disallow_lpage_count(struct kvm_memory_slot *slot,
+                                           gfn_t gfn, int count)
+{
+       struct kvm_lpage_info *linfo;
+       int i;
+
+       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
+               linfo = lpage_info_slot(gfn, slot, i);
+               linfo->disallow_lpage += count;
+               WARN_ON(linfo->disallow_lpage < 0);
+       }
+}
+
+void kvm_mmu_gfn_disallow_lpage(struct kvm_memory_slot *slot, gfn_t gfn)
+{
+       update_gfn_disallow_lpage_count(slot, gfn, 1);
+}
+
+void kvm_mmu_gfn_allow_lpage(struct kvm_memory_slot *slot, gfn_t gfn)
+{
+       update_gfn_disallow_lpage_count(slot, gfn, -1);
+}
+
 static void account_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
-       struct kvm_lpage_info *linfo;
        gfn_t gfn;
-       int i;
 
+       kvm->arch.indirect_shadow_pages++;
        gfn = sp->gfn;
        slots = kvm_memslots_for_spte_role(kvm, sp->role);
        slot = __gfn_to_memslot(slots, gfn);
-       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
-               linfo = lpage_info_slot(gfn, slot, i);
-               linfo->write_count += 1;
-       }
-       kvm->arch.indirect_shadow_pages++;
+
+       /* the non-leaf shadow pages are keeping readonly. */
+       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+               return kvm_slot_page_track_add_page(kvm, slot, gfn,
+                                                   KVM_PAGE_TRACK_WRITE);
+
+       kvm_mmu_gfn_disallow_lpage(slot, gfn);
 }
 
 static void unaccount_shadowed(struct kvm *kvm, struct kvm_mmu_page *sp)
 {
        struct kvm_memslots *slots;
        struct kvm_memory_slot *slot;
-       struct kvm_lpage_info *linfo;
        gfn_t gfn;
-       int i;
 
+       kvm->arch.indirect_shadow_pages--;
        gfn = sp->gfn;
        slots = kvm_memslots_for_spte_role(kvm, sp->role);
        slot = __gfn_to_memslot(slots, gfn);
-       for (i = PT_DIRECTORY_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
-               linfo = lpage_info_slot(gfn, slot, i);
-               linfo->write_count -= 1;
-               WARN_ON(linfo->write_count < 0);
-       }
-       kvm->arch.indirect_shadow_pages--;
+       if (sp->role.level > PT_PAGE_TABLE_LEVEL)
+               return kvm_slot_page_track_remove_page(kvm, slot, gfn,
+                                                      KVM_PAGE_TRACK_WRITE);
+
+       kvm_mmu_gfn_allow_lpage(slot, gfn);
 }
 
-static int __has_wrprotected_page(gfn_t gfn, int level,
-                                 struct kvm_memory_slot *slot)
+static bool __mmu_gfn_lpage_is_disallowed(gfn_t gfn, int level,
+                                         struct kvm_memory_slot *slot)
 {
        struct kvm_lpage_info *linfo;
 
        if (slot) {
                linfo = lpage_info_slot(gfn, slot, level);
-               return linfo->write_count;
+               return !!linfo->disallow_lpage;
        }
 
-       return 1;
+       return true;
 }
 
-static int has_wrprotected_page(struct kvm_vcpu *vcpu, gfn_t gfn, int level)
+static bool mmu_gfn_lpage_is_disallowed(struct kvm_vcpu *vcpu, gfn_t gfn,
+                                       int level)
 {
        struct kvm_memory_slot *slot;
 
        slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
-       return __has_wrprotected_page(gfn, level, slot);
+       return __mmu_gfn_lpage_is_disallowed(gfn, level, slot);
 }
 
 static int host_mapping_level(struct kvm *kvm, gfn_t gfn)
@@ -897,7 +921,7 @@ static int mapping_level(struct kvm_vcpu *vcpu, gfn_t large_gfn,
        max_level = min(kvm_x86_ops->get_lpage_level(), host_level);
 
        for (level = PT_DIRECTORY_LEVEL; level <= max_level; ++level)
-               if (__has_wrprotected_page(large_gfn, level, slot))
+               if (__mmu_gfn_lpage_is_disallowed(large_gfn, level, slot))
                        break;
 
        return level - 1;
@@ -1323,23 +1347,29 @@ void kvm_arch_mmu_enable_log_dirty_pt_masked(struct kvm *kvm,
                kvm_mmu_write_protect_pt_masked(kvm, slot, gfn_offset, mask);
 }
 
-static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
+bool kvm_mmu_slot_gfn_write_protect(struct kvm *kvm,
+                                   struct kvm_memory_slot *slot, u64 gfn)
 {
-       struct kvm_memory_slot *slot;
        struct kvm_rmap_head *rmap_head;
        int i;
        bool write_protected = false;
 
-       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
-
        for (i = PT_PAGE_TABLE_LEVEL; i <= PT_MAX_HUGEPAGE_LEVEL; ++i) {
                rmap_head = __gfn_to_rmap(gfn, i, slot);
-               write_protected |= __rmap_write_protect(vcpu->kvm, rmap_head, true);
+               write_protected |= __rmap_write_protect(kvm, rmap_head, true);
        }
 
        return write_protected;
 }
 
+static bool rmap_write_protect(struct kvm_vcpu *vcpu, u64 gfn)
+{
+       struct kvm_memory_slot *slot;
+
+       slot = kvm_vcpu_gfn_to_memslot(vcpu, gfn);
+       return kvm_mmu_slot_gfn_write_protect(vcpu->kvm, slot, gfn);
+}
+
 static bool kvm_zap_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
 {
        u64 *sptep;
@@ -1754,7 +1784,7 @@ static void mark_unsync(u64 *spte)
 static int nonpaging_sync_page(struct kvm_vcpu *vcpu,
                               struct kvm_mmu_page *sp)
 {
-       return 1;
+       return 0;
 }
 
 static void nonpaging_invlpg(struct kvm_vcpu *vcpu, gva_t gva)
@@ -1840,13 +1870,16 @@ static int __mmu_unsync_walk(struct kvm_mmu_page *sp,
        return nr_unsync_leaf;
 }
 
+#define INVALID_INDEX (-1)
+
 static int mmu_unsync_walk(struct kvm_mmu_page *sp,
                           struct kvm_mmu_pages *pvec)
 {
+       pvec->nr = 0;
        if (!sp->unsync_children)
                return 0;
 
-       mmu_pages_add(pvec, sp, 0);
+       mmu_pages_add(pvec, sp, INVALID_INDEX);
        return __mmu_unsync_walk(sp, pvec);
 }
 
@@ -1883,37 +1916,35 @@ static void kvm_mmu_commit_zap_page(struct kvm *kvm,
                if ((_sp)->role.direct || (_sp)->role.invalid) {} else
 
 /* @sp->gfn should be write-protected at the call site */
-static int __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
-                          struct list_head *invalid_list, bool clear_unsync)
+static bool __kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+                           struct list_head *invalid_list)
 {
        if (sp->role.cr4_pae != !!is_pae(vcpu)) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
-               return 1;
+               return false;
        }
 
-       if (clear_unsync)
-               kvm_unlink_unsync_page(vcpu->kvm, sp);
-
-       if (vcpu->arch.mmu.sync_page(vcpu, sp)) {
+       if (vcpu->arch.mmu.sync_page(vcpu, sp) == 0) {
                kvm_mmu_prepare_zap_page(vcpu->kvm, sp, invalid_list);
-               return 1;
+               return false;
        }
 
-       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
-       return 0;
+       return true;
 }
 
-static int kvm_sync_page_transient(struct kvm_vcpu *vcpu,
-                                  struct kvm_mmu_page *sp)
+static void kvm_mmu_flush_or_zap(struct kvm_vcpu *vcpu,
+                                struct list_head *invalid_list,
+                                bool remote_flush, bool local_flush)
 {
-       LIST_HEAD(invalid_list);
-       int ret;
-
-       ret = __kvm_sync_page(vcpu, sp, &invalid_list, false);
-       if (ret)
-               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       if (!list_empty(invalid_list)) {
+               kvm_mmu_commit_zap_page(vcpu->kvm, invalid_list);
+               return;
+       }
 
-       return ret;
+       if (remote_flush)
+               kvm_flush_remote_tlbs(vcpu->kvm);
+       else if (local_flush)
+               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
 }
 
 #ifdef CONFIG_KVM_MMU_AUDIT
@@ -1923,46 +1954,38 @@ static void kvm_mmu_audit(struct kvm_vcpu *vcpu, int point) { }
 static void mmu_audit_disable(void) { }
 #endif
 
-static int kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
+static bool kvm_sync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp,
                         struct list_head *invalid_list)
 {
-       return __kvm_sync_page(vcpu, sp, invalid_list, true);
+       kvm_unlink_unsync_page(vcpu->kvm, sp);
+       return __kvm_sync_page(vcpu, sp, invalid_list);
 }
 
 /* @gfn should be write-protected at the call site */
-static void kvm_sync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+static bool kvm_sync_pages(struct kvm_vcpu *vcpu, gfn_t gfn,
+                          struct list_head *invalid_list)
 {
        struct kvm_mmu_page *s;
-       LIST_HEAD(invalid_list);
-       bool flush = false;
+       bool ret = false;
 
        for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
                if (!s->unsync)
                        continue;
 
                WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
-               kvm_unlink_unsync_page(vcpu->kvm, s);
-               if ((s->role.cr4_pae != !!is_pae(vcpu)) ||
-                       (vcpu->arch.mmu.sync_page(vcpu, s))) {
-                       kvm_mmu_prepare_zap_page(vcpu->kvm, s, &invalid_list);
-                       continue;
-               }
-               flush = true;
+               ret |= kvm_sync_page(vcpu, s, invalid_list);
        }
 
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
-       if (flush)
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+       return ret;
 }
 
 struct mmu_page_path {
-       struct kvm_mmu_page *parent[PT64_ROOT_LEVEL-1];
-       unsigned int idx[PT64_ROOT_LEVEL-1];
+       struct kvm_mmu_page *parent[PT64_ROOT_LEVEL];
+       unsigned int idx[PT64_ROOT_LEVEL];
 };
 
 #define for_each_sp(pvec, sp, parents, i)                      \
-               for (i = mmu_pages_next(&pvec, &parents, -1),   \
-                       sp = pvec.page[i].sp;                   \
+               for (i = mmu_pages_first(&pvec, &parents);      \
                        i < pvec.nr && ({ sp = pvec.page[i].sp; 1;});   \
                        i = mmu_pages_next(&pvec, &parents, i))
 
@@ -1974,19 +1997,43 @@ static int mmu_pages_next(struct kvm_mmu_pages *pvec,
 
        for (n = i+1; n < pvec->nr; n++) {
                struct kvm_mmu_page *sp = pvec->page[n].sp;
+               unsigned idx = pvec->page[n].idx;
+               int level = sp->role.level;
 
-               if (sp->role.level == PT_PAGE_TABLE_LEVEL) {
-                       parents->idx[0] = pvec->page[n].idx;
-                       return n;
-               }
+               parents->idx[level-1] = idx;
+               if (level == PT_PAGE_TABLE_LEVEL)
+                       break;
 
-               parents->parent[sp->role.level-2] = sp;
-               parents->idx[sp->role.level-1] = pvec->page[n].idx;
+               parents->parent[level-2] = sp;
        }
 
        return n;
 }
 
+static int mmu_pages_first(struct kvm_mmu_pages *pvec,
+                          struct mmu_page_path *parents)
+{
+       struct kvm_mmu_page *sp;
+       int level;
+
+       if (pvec->nr == 0)
+               return 0;
+
+       WARN_ON(pvec->page[0].idx != INVALID_INDEX);
+
+       sp = pvec->page[0].sp;
+       level = sp->role.level;
+       WARN_ON(level == PT_PAGE_TABLE_LEVEL);
+
+       parents->parent[level-2] = sp;
+
+       /* Also set up a sentinel.  Further entries in pvec are all
+        * children of sp, so this element is never overwritten.
+        */
+       parents->parent[level-1] = NULL;
+       return mmu_pages_next(pvec, parents, 0);
+}
+
 static void mmu_pages_clear_parents(struct mmu_page_path *parents)
 {
        struct kvm_mmu_page *sp;
@@ -1994,22 +2041,14 @@ static void mmu_pages_clear_parents(struct mmu_page_path *parents)
 
        do {
                unsigned int idx = parents->idx[level];
-
                sp = parents->parent[level];
                if (!sp)
                        return;
 
+               WARN_ON(idx == INVALID_INDEX);
                clear_unsync_child_bit(sp, idx);
                level++;
-       } while (level < PT64_ROOT_LEVEL-1 && !sp->unsync_children);
-}
-
-static void kvm_mmu_pages_init(struct kvm_mmu_page *parent,
-                              struct mmu_page_path *parents,
-                              struct kvm_mmu_pages *pvec)
-{
-       parents->parent[parent->role.level-1] = NULL;
-       pvec->nr = 0;
+       } while (!sp->unsync_children);
 }
 
 static void mmu_sync_children(struct kvm_vcpu *vcpu,
@@ -2020,30 +2059,36 @@ static void mmu_sync_children(struct kvm_vcpu *vcpu,
        struct mmu_page_path parents;
        struct kvm_mmu_pages pages;
        LIST_HEAD(invalid_list);
+       bool flush = false;
 
-       kvm_mmu_pages_init(parent, &parents, &pages);
        while (mmu_unsync_walk(parent, &pages)) {
                bool protected = false;
 
                for_each_sp(pages, sp, parents, i)
                        protected |= rmap_write_protect(vcpu, sp->gfn);
 
-               if (protected)
+               if (protected) {
                        kvm_flush_remote_tlbs(vcpu->kvm);
+                       flush = false;
+               }
 
                for_each_sp(pages, sp, parents, i) {
-                       kvm_sync_page(vcpu, sp, &invalid_list);
+                       flush |= kvm_sync_page(vcpu, sp, &invalid_list);
                        mmu_pages_clear_parents(&parents);
                }
-               kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
-               cond_resched_lock(&vcpu->kvm->mmu_lock);
-               kvm_mmu_pages_init(parent, &parents, &pages);
+               if (need_resched() || spin_needbreak(&vcpu->kvm->mmu_lock)) {
+                       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
+                       cond_resched_lock(&vcpu->kvm->mmu_lock);
+                       flush = false;
+               }
        }
+
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
 }
 
 static void __clear_sp_write_flooding_count(struct kvm_mmu_page *sp)
 {
-       sp->write_flooding_count = 0;
+       atomic_set(&sp->write_flooding_count,  0);
 }
 
 static void clear_sp_write_flooding_count(u64 *spte)
@@ -2069,6 +2114,8 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        unsigned quadrant;
        struct kvm_mmu_page *sp;
        bool need_sync = false;
+       bool flush = false;
+       LIST_HEAD(invalid_list);
 
        role = vcpu->arch.mmu.base_role;
        role.level = level;
@@ -2092,8 +2139,16 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
                if (sp->role.word != role.word)
                        continue;
 
-               if (sp->unsync && kvm_sync_page_transient(vcpu, sp))
-                       break;
+               if (sp->unsync) {
+                       /* The page is good, but __kvm_sync_page might still end
+                        * up zapping it.  If so, break in order to rebuild it.
+                        */
+                       if (!__kvm_sync_page(vcpu, sp, &invalid_list))
+                               break;
+
+                       WARN_ON(!list_empty(&invalid_list));
+                       kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
+               }
 
                if (sp->unsync_children)
                        kvm_make_request(KVM_REQ_MMU_SYNC, vcpu);
@@ -2112,16 +2167,24 @@ static struct kvm_mmu_page *kvm_mmu_get_page(struct kvm_vcpu *vcpu,
        hlist_add_head(&sp->hash_link,
                &vcpu->kvm->arch.mmu_page_hash[kvm_page_table_hashfn(gfn)]);
        if (!direct) {
-               if (rmap_write_protect(vcpu, gfn))
+               /*
+                * we should do write protection before syncing pages
+                * otherwise the content of the synced shadow page may
+                * be inconsistent with guest page table.
+                */
+               account_shadowed(vcpu->kvm, sp);
+               if (level == PT_PAGE_TABLE_LEVEL &&
+                     rmap_write_protect(vcpu, gfn))
                        kvm_flush_remote_tlbs(vcpu->kvm);
-               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
-                       kvm_sync_pages(vcpu, gfn);
 
-               account_shadowed(vcpu->kvm, sp);
+               if (level > PT_PAGE_TABLE_LEVEL && need_sync)
+                       flush |= kvm_sync_pages(vcpu, gfn, &invalid_list);
        }
        sp->mmu_valid_gen = vcpu->kvm->arch.mmu_valid_gen;
        clear_page(sp->spt);
        trace_kvm_mmu_get_page(sp, true);
+
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, false, flush);
        return sp;
 }
 
@@ -2269,7 +2332,6 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
        if (parent->role.level == PT_PAGE_TABLE_LEVEL)
                return 0;
 
-       kvm_mmu_pages_init(parent, &parents, &pages);
        while (mmu_unsync_walk(parent, &pages)) {
                struct kvm_mmu_page *sp;
 
@@ -2278,7 +2340,6 @@ static int mmu_zap_unsync_children(struct kvm *kvm,
                        mmu_pages_clear_parents(&parents);
                        zapped++;
                }
-               kvm_mmu_pages_init(parent, &parents, &pages);
        }
 
        return zapped;
@@ -2354,8 +2415,8 @@ static bool prepare_zap_oldest_mmu_page(struct kvm *kvm,
        if (list_empty(&kvm->arch.active_mmu_pages))
                return false;
 
-       sp = list_entry(kvm->arch.active_mmu_pages.prev,
-                       struct kvm_mmu_page, link);
+       sp = list_last_entry(&kvm->arch.active_mmu_pages,
+                            struct kvm_mmu_page, link);
        kvm_mmu_prepare_zap_page(kvm, sp, invalid_list);
 
        return true;
@@ -2408,7 +2469,7 @@ int kvm_mmu_unprotect_page(struct kvm *kvm, gfn_t gfn)
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_unprotect_page);
 
-static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
+static void kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
 {
        trace_kvm_mmu_unsync_page(sp);
        ++vcpu->kvm->stat.mmu_unsync;
@@ -2417,37 +2478,26 @@ static void __kvm_unsync_page(struct kvm_vcpu *vcpu, struct kvm_mmu_page *sp)
        kvm_mmu_mark_parents_unsync(sp);
 }
 
-static void kvm_unsync_pages(struct kvm_vcpu *vcpu,  gfn_t gfn)
+static bool mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
+                                  bool can_unsync)
 {
-       struct kvm_mmu_page *s;
-
-       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
-               if (s->unsync)
-                       continue;
-               WARN_ON(s->role.level != PT_PAGE_TABLE_LEVEL);
-               __kvm_unsync_page(vcpu, s);
-       }
-}
+       struct kvm_mmu_page *sp;
 
-static int mmu_need_write_protect(struct kvm_vcpu *vcpu, gfn_t gfn,
-                                 bool can_unsync)
-{
-       struct kvm_mmu_page *s;
-       bool need_unsync = false;
+       if (kvm_page_track_is_active(vcpu, gfn, KVM_PAGE_TRACK_WRITE))
+               return true;
 
-       for_each_gfn_indirect_valid_sp(vcpu->kvm, s, gfn) {
+       for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (!can_unsync)
-                       return 1;
+                       return true;
 
-               if (s->role.level != PT_PAGE_TABLE_LEVEL)
-                       return 1;
+               if (sp->unsync)
+                       continue;
 
-               if (!s->unsync)
-                       need_unsync = true;
+               WARN_ON(sp->role.level != PT_PAGE_TABLE_LEVEL);
+               kvm_unsync_page(vcpu, sp);
        }
-       if (need_unsync)
-               kvm_unsync_pages(vcpu, gfn);
-       return 0;
+
+       return false;
 }
 
 static bool kvm_is_mmio_pfn(kvm_pfn_t pfn)
@@ -2503,7 +2553,7 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
                 * be fixed if guest refault.
                 */
                if (level > PT_PAGE_TABLE_LEVEL &&
-                   has_wrprotected_page(vcpu, gfn, level))
+                   mmu_gfn_lpage_is_disallowed(vcpu, gfn, level))
                        goto done;
 
                spte |= PT_WRITABLE_MASK | SPTE_MMU_WRITEABLE;
@@ -2768,7 +2818,7 @@ static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
        if (!is_error_noslot_pfn(pfn) && !kvm_is_reserved_pfn(pfn) &&
            level == PT_PAGE_TABLE_LEVEL &&
            PageTransCompound(pfn_to_page(pfn)) &&
-           !has_wrprotected_page(vcpu, gfn, PT_DIRECTORY_LEVEL)) {
+           !mmu_gfn_lpage_is_disallowed(vcpu, gfn, PT_DIRECTORY_LEVEL)) {
                unsigned long mask;
                /*
                 * mmu_notifier_retry was successful and we hold the
@@ -2796,20 +2846,16 @@ static void transparent_hugepage_adjust(struct kvm_vcpu *vcpu,
 static bool handle_abnormal_pfn(struct kvm_vcpu *vcpu, gva_t gva, gfn_t gfn,
                                kvm_pfn_t pfn, unsigned access, int *ret_val)
 {
-       bool ret = true;
-
        /* The pfn is invalid, report the error! */
        if (unlikely(is_error_pfn(pfn))) {
                *ret_val = kvm_handle_bad_page(vcpu, gfn, pfn);
-               goto exit;
+               return true;
        }
 
        if (unlikely(is_noslot_pfn(pfn)))
                vcpu_cache_mmio_info(vcpu, gva, gfn, access);
 
-       ret = false;
-exit:
-       return ret;
+       return false;
 }
 
 static bool page_fault_can_be_fast(u32 error_code)
@@ -3273,7 +3319,7 @@ static bool is_shadow_zero_bits_set(struct kvm_mmu *mmu, u64 spte, int level)
        return __is_rsvd_bits_set(&mmu->shadow_zero_check, spte, level);
 }
 
-static bool quickly_check_mmio_pf(struct kvm_vcpu *vcpu, u64 addr, bool direct)
+static bool mmio_info_in_cache(struct kvm_vcpu *vcpu, u64 addr, bool direct)
 {
        if (direct)
                return vcpu_match_mmio_gpa(vcpu, addr);
@@ -3332,7 +3378,7 @@ int handle_mmio_page_fault(struct kvm_vcpu *vcpu, u64 addr, bool direct)
        u64 spte;
        bool reserved;
 
-       if (quickly_check_mmio_pf(vcpu, addr, direct))
+       if (mmio_info_in_cache(vcpu, addr, direct))
                return RET_MMIO_PF_EMULATE;
 
        reserved = walk_shadow_page_get_mmio_spte(vcpu, addr, &spte);
@@ -3362,20 +3408,53 @@ int handle_mmio_page_fault(struct kvm_vcpu *vcpu, u64 addr, bool direct)
 }
 EXPORT_SYMBOL_GPL(handle_mmio_page_fault);
 
+static bool page_fault_handle_page_track(struct kvm_vcpu *vcpu,
+                                        u32 error_code, gfn_t gfn)
+{
+       if (unlikely(error_code & PFERR_RSVD_MASK))
+               return false;
+
+       if (!(error_code & PFERR_PRESENT_MASK) ||
+             !(error_code & PFERR_WRITE_MASK))
+               return false;
+
+       /*
+        * guest is writing the page which is write tracked which can
+        * not be fixed by page fault handler.
+        */
+       if (kvm_page_track_is_active(vcpu, gfn, KVM_PAGE_TRACK_WRITE))
+               return true;
+
+       return false;
+}
+
+static void shadow_page_table_clear_flood(struct kvm_vcpu *vcpu, gva_t addr)
+{
+       struct kvm_shadow_walk_iterator iterator;
+       u64 spte;
+
+       if (!VALID_PAGE(vcpu->arch.mmu.root_hpa))
+               return;
+
+       walk_shadow_page_lockless_begin(vcpu);
+       for_each_shadow_entry_lockless(vcpu, addr, iterator, spte) {
+               clear_sp_write_flooding_count(iterator.sptep);
+               if (!is_shadow_present_pte(spte))
+                       break;
+       }
+       walk_shadow_page_lockless_end(vcpu);
+}
+
 static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
                                u32 error_code, bool prefault)
 {
-       gfn_t gfn;
+       gfn_t gfn = gva >> PAGE_SHIFT;
        int r;
 
        pgprintk("%s: gva %lx error %x\n", __func__, gva, error_code);
 
-       if (unlikely(error_code & PFERR_RSVD_MASK)) {
-               r = handle_mmio_page_fault(vcpu, gva, true);
-
-               if (likely(r != RET_MMIO_PF_INVALID))
-                       return r;
-       }
+       if (page_fault_handle_page_track(vcpu, error_code, gfn))
+               return 1;
 
        r = mmu_topup_memory_caches(vcpu);
        if (r)
@@ -3383,7 +3462,6 @@ static int nonpaging_page_fault(struct kvm_vcpu *vcpu, gva_t gva,
 
        MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
-       gfn = gva >> PAGE_SHIFT;
 
        return nonpaging_map(vcpu, gva & PAGE_MASK,
                             error_code, gfn, prefault);
@@ -3460,12 +3538,8 @@ static int tdp_page_fault(struct kvm_vcpu *vcpu, gva_t gpa, u32 error_code,
 
        MMU_WARN_ON(!VALID_PAGE(vcpu->arch.mmu.root_hpa));
 
-       if (unlikely(error_code & PFERR_RSVD_MASK)) {
-               r = handle_mmio_page_fault(vcpu, gpa, true);
-
-               if (likely(r != RET_MMIO_PF_INVALID))
-                       return r;
-       }
+       if (page_fault_handle_page_track(vcpu, error_code, gfn))
+               return 1;
 
        r = mmu_topup_memory_caches(vcpu);
        if (r)
@@ -4125,18 +4199,6 @@ static bool need_remote_flush(u64 old, u64 new)
        return (old & ~new & PT64_PERM_MASK) != 0;
 }
 
-static void mmu_pte_write_flush_tlb(struct kvm_vcpu *vcpu, bool zap_page,
-                                   bool remote_flush, bool local_flush)
-{
-       if (zap_page)
-               return;
-
-       if (remote_flush)
-               kvm_flush_remote_tlbs(vcpu->kvm);
-       else if (local_flush)
-               kvm_make_request(KVM_REQ_TLB_FLUSH, vcpu);
-}
-
 static u64 mmu_pte_write_fetch_gpte(struct kvm_vcpu *vcpu, gpa_t *gpa,
                                    const u8 *new, int *bytes)
 {
@@ -4186,7 +4248,8 @@ static bool detect_write_flooding(struct kvm_mmu_page *sp)
        if (sp->role.level == PT_PAGE_TABLE_LEVEL)
                return false;
 
-       return ++sp->write_flooding_count >= 3;
+       atomic_inc(&sp->write_flooding_count);
+       return atomic_read(&sp->write_flooding_count) >= 3;
 }
 
 /*
@@ -4248,15 +4311,15 @@ static u64 *get_written_sptes(struct kvm_mmu_page *sp, gpa_t gpa, int *nspte)
        return spte;
 }
 
-void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
-                      const u8 *new, int bytes)
+static void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
+                             const u8 *new, int bytes)
 {
        gfn_t gfn = gpa >> PAGE_SHIFT;
        struct kvm_mmu_page *sp;
        LIST_HEAD(invalid_list);
        u64 entry, gentry, *spte;
        int npte;
-       bool remote_flush, local_flush, zap_page;
+       bool remote_flush, local_flush;
        union kvm_mmu_page_role mask = { };
 
        mask.cr0_wp = 1;
@@ -4273,7 +4336,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        if (!ACCESS_ONCE(vcpu->kvm->arch.indirect_shadow_pages))
                return;
 
-       zap_page = remote_flush = local_flush = false;
+       remote_flush = local_flush = false;
 
        pgprintk("%s: gpa %llx bytes %d\n", __func__, gpa, bytes);
 
@@ -4293,8 +4356,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
        for_each_gfn_indirect_valid_sp(vcpu->kvm, sp, gfn) {
                if (detect_write_misaligned(sp, gpa, bytes) ||
                      detect_write_flooding(sp)) {
-                       zap_page |= !!kvm_mmu_prepare_zap_page(vcpu->kvm, sp,
-                                                    &invalid_list);
+                       kvm_mmu_prepare_zap_page(vcpu->kvm, sp, &invalid_list);
                        ++vcpu->kvm->stat.mmu_flooded;
                        continue;
                }
@@ -4316,8 +4378,7 @@ void kvm_mmu_pte_write(struct kvm_vcpu *vcpu, gpa_t gpa,
                        ++spte;
                }
        }
-       mmu_pte_write_flush_tlb(vcpu, zap_page, remote_flush, local_flush);
-       kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
+       kvm_mmu_flush_or_zap(vcpu, &invalid_list, remote_flush, local_flush);
        kvm_mmu_audit(vcpu, AUDIT_POST_PTE_WRITE);
        spin_unlock(&vcpu->kvm->mmu_lock);
 }
@@ -4354,32 +4415,34 @@ static void make_mmu_pages_available(struct kvm_vcpu *vcpu)
        kvm_mmu_commit_zap_page(vcpu->kvm, &invalid_list);
 }
 
-static bool is_mmio_page_fault(struct kvm_vcpu *vcpu, gva_t addr)
-{
-       if (vcpu->arch.mmu.direct_map || mmu_is_nested(vcpu))
-               return vcpu_match_mmio_gpa(vcpu, addr);
-
-       return vcpu_match_mmio_gva(vcpu, addr);
-}
-
 int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code,
                       void *insn, int insn_len)
 {
        int r, emulation_type = EMULTYPE_RETRY;
        enum emulation_result er;
+       bool direct = vcpu->arch.mmu.direct_map || mmu_is_nested(vcpu);
+
+       if (unlikely(error_code & PFERR_RSVD_MASK)) {
+               r = handle_mmio_page_fault(vcpu, cr2, direct);
+               if (r == RET_MMIO_PF_EMULATE) {
+                       emulation_type = 0;
+                       goto emulate;
+               }
+               if (r == RET_MMIO_PF_RETRY)
+                       return 1;
+               if (r < 0)
+                       return r;
+       }
 
        r = vcpu->arch.mmu.page_fault(vcpu, cr2, error_code, false);
        if (r < 0)
-               goto out;
-
-       if (!r) {
-               r = 1;
-               goto out;
-       }
+               return r;
+       if (!r)
+               return 1;
 
-       if (is_mmio_page_fault(vcpu, cr2))
+       if (mmio_info_in_cache(vcpu, cr2, direct))
                emulation_type = 0;
-
+emulate:
        er = x86_emulate_instruction(vcpu, cr2, emulation_type, insn, insn_len);
 
        switch (er) {
@@ -4393,8 +4456,6 @@ int kvm_mmu_page_fault(struct kvm_vcpu *vcpu, gva_t cr2, u32 error_code,
        default:
                BUG();
        }
-out:
-       return r;
 }
 EXPORT_SYMBOL_GPL(kvm_mmu_page_fault);
 
@@ -4463,6 +4524,21 @@ void kvm_mmu_setup(struct kvm_vcpu *vcpu)
        init_kvm_mmu(vcpu);
 }
 
+void kvm_mmu_init_vm(struct kvm *kvm)
+{
+       struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
+
+       node->track_write = kvm_mmu_pte_write;
+       kvm_page_track_register_notifier(kvm, node);
+}
+
+void kvm_mmu_uninit_vm(struct kvm *kvm)
+{
+       struct kvm_page_track_notifier_node *node = &kvm->arch.mmu_sp_tracker;
+
+       kvm_page_track_unregister_notifier(kvm, node);
+}
+
 /* The return value indicates if tlb flush on all vcpus is needed. */
 typedef bool (*slot_level_handler) (struct kvm *kvm, struct kvm_rmap_head *rmap_head);