Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next
[cascardo/linux.git] / net / netfilter / nf_log.c
index 6e3b911..43c926c 100644 (file)
@@ -19,6 +19,9 @@
 static struct nf_logger __rcu *loggers[NFPROTO_NUMPROTO][NF_LOG_TYPE_MAX] __read_mostly;
 static DEFINE_MUTEX(nf_log_mutex);
 
+#define nft_log_dereference(logger) \
+       rcu_dereference_protected(logger, lockdep_is_held(&nf_log_mutex))
+
 static struct nf_logger *__find_logger(int pf, const char *str_logger)
 {
        struct nf_logger *log;
@@ -28,8 +31,7 @@ static struct nf_logger *__find_logger(int pf, const char *str_logger)
                if (loggers[pf][i] == NULL)
                        continue;
 
-               log = rcu_dereference_protected(loggers[pf][i],
-                                               lockdep_is_held(&nf_log_mutex));
+               log = nft_log_dereference(loggers[pf][i]);
                if (!strncasecmp(str_logger, log->name, strlen(log->name)))
                        return log;
        }
@@ -45,8 +47,7 @@ void nf_log_set(struct net *net, u_int8_t pf, const struct nf_logger *logger)
                return;
 
        mutex_lock(&nf_log_mutex);
-       log = rcu_dereference_protected(net->nf.nf_loggers[pf],
-                                       lockdep_is_held(&nf_log_mutex));
+       log = nft_log_dereference(net->nf.nf_loggers[pf]);
        if (log == NULL)
                rcu_assign_pointer(net->nf.nf_loggers[pf], logger);
 
@@ -61,8 +62,7 @@ void nf_log_unset(struct net *net, const struct nf_logger *logger)
 
        mutex_lock(&nf_log_mutex);
        for (i = 0; i < NFPROTO_NUMPROTO; i++) {
-               log = rcu_dereference_protected(net->nf.nf_loggers[i],
-                               lockdep_is_held(&nf_log_mutex));
+               log = nft_log_dereference(net->nf.nf_loggers[i]);
                if (log == logger)
                        RCU_INIT_POINTER(net->nf.nf_loggers[i], NULL);
        }
@@ -75,6 +75,7 @@ EXPORT_SYMBOL(nf_log_unset);
 int nf_log_register(u_int8_t pf, struct nf_logger *logger)
 {
        int i;
+       int ret = 0;
 
        if (pf >= ARRAY_SIZE(init_net.nf.nf_loggers))
                return -EINVAL;
@@ -82,16 +83,25 @@ int nf_log_register(u_int8_t pf, struct nf_logger *logger)
        mutex_lock(&nf_log_mutex);
 
        if (pf == NFPROTO_UNSPEC) {
+               for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++) {
+                       if (rcu_access_pointer(loggers[i][logger->type])) {
+                               ret = -EEXIST;
+                               goto unlock;
+                       }
+               }
                for (i = NFPROTO_UNSPEC; i < NFPROTO_NUMPROTO; i++)
                        rcu_assign_pointer(loggers[i][logger->type], logger);
        } else {
-               /* register at end of list to honor first register win */
+               if (rcu_access_pointer(loggers[pf][logger->type])) {
+                       ret = -EEXIST;
+                       goto unlock;
+               }
                rcu_assign_pointer(loggers[pf][logger->type], logger);
        }
 
+unlock:
        mutex_unlock(&nf_log_mutex);
-
-       return 0;
+       return ret;
 }
 EXPORT_SYMBOL(nf_log_register);
 
@@ -144,8 +154,7 @@ int nf_logger_find_get(int pf, enum nf_log_type type)
        struct nf_logger *logger;
        int ret = -ENOENT;
 
-       logger = loggers[pf][type];
-       if (logger == NULL)
+       if (rcu_access_pointer(loggers[pf][type]) == NULL)
                request_module("nf-logger-%u-%u", pf, type);
 
        rcu_read_lock();
@@ -297,8 +306,7 @@ static int seq_show(struct seq_file *s, void *v)
        int i;
        struct net *net = seq_file_net(s);
 
-       logger = rcu_dereference_protected(net->nf.nf_loggers[*pos],
-                                          lockdep_is_held(&nf_log_mutex));
+       logger = nft_log_dereference(net->nf.nf_loggers[*pos]);
 
        if (!logger)
                seq_printf(s, "%2lld NONE (", *pos);
@@ -312,8 +320,7 @@ static int seq_show(struct seq_file *s, void *v)
                if (loggers[*pos][i] == NULL)
                        continue;
 
-               logger = rcu_dereference_protected(loggers[*pos][i],
-                                          lockdep_is_held(&nf_log_mutex));
+               logger = nft_log_dereference(loggers[*pos][i]);
                seq_printf(s, "%s", logger->name);
                if (i == 0 && loggers[*pos][i + 1] != NULL)
                        seq_printf(s, ",");
@@ -387,8 +394,7 @@ static int nf_log_proc_dostring(struct ctl_table *table, int write,
                mutex_unlock(&nf_log_mutex);
        } else {
                mutex_lock(&nf_log_mutex);
-               logger = rcu_dereference_protected(net->nf.nf_loggers[tindex],
-                                                  lockdep_is_held(&nf_log_mutex));
+               logger = nft_log_dereference(net->nf.nf_loggers[tindex]);
                if (!logger)
                        table->data = "NONE";
                else