locking/static_keys: Rework update logic
[cascardo/linux.git] / kernel / jump_label.c
index 72707e4..2e7cc1e 100644 (file)
@@ -54,12 +54,7 @@ jump_label_sort_entries(struct jump_entry *start, struct jump_entry *stop)
        sort(start, size, sizeof(struct jump_entry), jump_label_cmp, NULL);
 }
 
-static void jump_label_update(struct static_key *key, int enable);
-
-static inline bool static_key_type(struct static_key *key)
-{
-       return (unsigned long)key->entries & JUMP_TYPE_MASK;
-}
+static void jump_label_update(struct static_key *key);
 
 void static_key_slow_inc(struct static_key *key)
 {
@@ -68,13 +63,8 @@ void static_key_slow_inc(struct static_key *key)
                return;
 
        jump_label_lock();
-       if (atomic_read(&key->enabled) == 0) {
-               if (!static_key_type(key))
-                       jump_label_update(key, JUMP_LABEL_JMP);
-               else
-                       jump_label_update(key, JUMP_LABEL_NOP);
-       }
-       atomic_inc(&key->enabled);
+       if (atomic_inc_return(&key->enabled) == 1)
+               jump_label_update(key);
        jump_label_unlock();
 }
 EXPORT_SYMBOL_GPL(static_key_slow_inc);
@@ -92,10 +82,7 @@ static void __static_key_slow_dec(struct static_key *key,
                atomic_inc(&key->enabled);
                schedule_delayed_work(work, rate_limit);
        } else {
-               if (!static_key_type(key))
-                       jump_label_update(key, JUMP_LABEL_NOP);
-               else
-                       jump_label_update(key, JUMP_LABEL_JMP);
+               jump_label_update(key);
        }
        jump_label_unlock();
 }
@@ -154,7 +141,7 @@ static int __jump_label_text_reserved(struct jump_entry *iter_start,
        return 0;
 }
 
-/* 
+/*
  * Update code which is definitely not currently executing.
  * Architectures which need heavyweight synchronization to modify
  * running code can override this to make the non-live update case
@@ -163,29 +150,17 @@ static int __jump_label_text_reserved(struct jump_entry *iter_start,
 void __weak __init_or_module arch_jump_label_transform_static(struct jump_entry *entry,
                                            enum jump_label_type type)
 {
-       arch_jump_label_transform(entry, type); 
+       arch_jump_label_transform(entry, type);
 }
 
-static void __jump_label_update(struct static_key *key,
-                               struct jump_entry *entry,
-                               struct jump_entry *stop, int enable)
+static inline struct jump_entry *static_key_entries(struct static_key *key)
 {
-       for (; (entry < stop) &&
-             (entry->key == (jump_label_t)(unsigned long)key);
-             entry++) {
-               /*
-                * entry->code set to 0 invalidates module init text sections
-                * kernel_text_address() verifies we are not in core kernel
-                * init code, see jump_label_invalidate_module_init().
-                */
-               if (entry->code && kernel_text_address(entry->code))
-                       arch_jump_label_transform(entry, enable);
-       }
+       return (struct jump_entry *)((unsigned long)key->entries & ~JUMP_TYPE_MASK);
 }
 
-static inline struct jump_entry *static_key_entries(struct static_key *key)
+static inline bool static_key_type(struct static_key *key)
 {
-       return (struct jump_entry *)((unsigned long)key->entries & ~JUMP_TYPE_MASK);
+       return (unsigned long)key->entries & JUMP_TYPE_MASK;
 }
 
 static inline struct static_key *jump_entry_key(struct jump_entry *entry)
@@ -193,14 +168,30 @@ static inline struct static_key *jump_entry_key(struct jump_entry *entry)
        return (struct static_key *)((unsigned long)entry->key);
 }
 
-static enum jump_label_type jump_label_type(struct static_key *key)
+static enum jump_label_type jump_label_type(struct jump_entry *entry)
 {
+       struct static_key *key = jump_entry_key(entry);
        bool enabled = static_key_enabled(key);
        bool type = static_key_type(key);
 
        return enabled ^ type;
 }
 
+static void __jump_label_update(struct static_key *key,
+                               struct jump_entry *entry,
+                               struct jump_entry *stop)
+{
+       for (; (entry < stop) && (jump_entry_key(entry) == key); entry++) {
+               /*
+                * entry->code set to 0 invalidates module init text sections
+                * kernel_text_address() verifies we are not in core kernel
+                * init code, see jump_label_invalidate_module_init().
+                */
+               if (entry->code && kernel_text_address(entry->code))
+                       arch_jump_label_transform(entry, jump_label_type(entry));
+       }
+}
+
 void __init jump_label_init(void)
 {
        struct jump_entry *iter_start = __start___jump_table;
@@ -214,8 +205,8 @@ void __init jump_label_init(void)
        for (iter = iter_start; iter < iter_stop; iter++) {
                struct static_key *iterk;
 
+               arch_jump_label_transform_static(iter, jump_label_type(iter));
                iterk = jump_entry_key(iter);
-               arch_jump_label_transform_static(iter, jump_label_type(iterk));
                if (iterk == key)
                        continue;
 
@@ -255,17 +246,15 @@ static int __jump_label_mod_text_reserved(void *start, void *end)
                                start, end);
 }
 
-static void __jump_label_mod_update(struct static_key *key, int enable)
+static void __jump_label_mod_update(struct static_key *key)
 {
-       struct static_key_mod *mod = key->next;
+       struct static_key_mod *mod;
 
-       while (mod) {
+       for (mod = key->next; mod; mod = mod->next) {
                struct module *m = mod->mod;
 
                __jump_label_update(key, mod->entries,
-                                   m->jump_entries + m->num_jump_entries,
-                                   enable);
-               mod = mod->next;
+                                   m->jump_entries + m->num_jump_entries);
        }
 }
 
@@ -287,9 +276,8 @@ void jump_label_apply_nops(struct module *mod)
        if (iter_start == iter_stop)
                return;
 
-       for (iter = iter_start; iter < iter_stop; iter++) {
+       for (iter = iter_start; iter < iter_stop; iter++)
                arch_jump_label_transform_static(iter, JUMP_LABEL_NOP);
-       }
 }
 
 static int jump_label_add_module(struct module *mod)
@@ -330,8 +318,8 @@ static int jump_label_add_module(struct module *mod)
                jlm->next = key->next;
                key->next = jlm;
 
-               if (jump_label_type(key) == JUMP_LABEL_JMP)
-                       __jump_label_update(key, iter, iter_stop, JUMP_LABEL_JMP);
+               if (jump_label_type(iter) == JUMP_LABEL_JMP)
+                       __jump_label_update(key, iter, iter_stop);
        }
 
        return 0;
@@ -451,14 +439,14 @@ int jump_label_text_reserved(void *start, void *end)
        return ret;
 }
 
-static void jump_label_update(struct static_key *key, int enable)
+static void jump_label_update(struct static_key *key)
 {
        struct jump_entry *stop = __stop___jump_table;
        struct jump_entry *entry = static_key_entries(key);
 #ifdef CONFIG_MODULES
        struct module *mod;
 
-       __jump_label_mod_update(key, enable);
+       __jump_label_mod_update(key);
 
        preempt_disable();
        mod = __module_address((unsigned long)key);
@@ -468,7 +456,7 @@ static void jump_label_update(struct static_key *key, int enable)
 #endif
        /* if there are no users, entry can be NULL */
        if (entry)
-               __jump_label_update(key, entry, stop, enable);
+               __jump_label_update(key, entry, stop);
 }
 
 #endif