netfilter: nfnetlink: use list_for_each_entry_safe to delete all objects
[cascardo/linux.git] / net / netfilter / nfnetlink_cttimeout.c
index 4cdcd96..f74fee1 100644 (file)
@@ -330,16 +330,16 @@ static int ctnl_timeout_try_del(struct net *net, struct ctnl_timeout *timeout)
 {
        int ret = 0;
 
-       /* we want to avoid races with nf_ct_timeout_find_get. */
-       if (atomic_dec_and_test(&timeout->refcnt)) {
+       /* We want to avoid races with ctnl_timeout_put. So only when the
+        * current refcnt is 1, we decrease it to 0.
+        */
+       if (atomic_cmpxchg(&timeout->refcnt, 1, 0) == 1) {
                /* We are protected by nfnl mutex. */
                list_del_rcu(&timeout->head);
                nf_ct_l4proto_put(timeout->l4proto);
                ctnl_untimeout(net, timeout);
                kfree_rcu(timeout, rcu_head);
        } else {
-               /* still in use, restore reference counter. */
-               atomic_inc(&timeout->refcnt);
                ret = -EBUSY;
        }
        return ret;
@@ -350,12 +350,13 @@ static int cttimeout_del_timeout(struct net *net, struct sock *ctnl,
                                 const struct nlmsghdr *nlh,
                                 const struct nlattr * const cda[])
 {
-       struct ctnl_timeout *cur;
+       struct ctnl_timeout *cur, *tmp;
        int ret = -ENOENT;
        char *name;
 
        if (!cda[CTA_TIMEOUT_NAME]) {
-               list_for_each_entry(cur, &net->nfct_timeout_list, head)
+               list_for_each_entry_safe(cur, tmp, &net->nfct_timeout_list,
+                                        head)
                        ctnl_timeout_try_del(net, cur);
 
                return 0;
@@ -543,7 +544,9 @@ err:
 
 static void ctnl_timeout_put(struct ctnl_timeout *timeout)
 {
-       atomic_dec(&timeout->refcnt);
+       if (atomic_dec_and_test(&timeout->refcnt))
+               kfree_rcu(timeout, rcu_head);
+
        module_put(THIS_MODULE);
 }
 #endif /* CONFIG_NF_CONNTRACK_TIMEOUT */
@@ -591,7 +594,9 @@ static void __net_exit cttimeout_net_exit(struct net *net)
        list_for_each_entry_safe(cur, tmp, &net->nfct_timeout_list, head) {
                list_del_rcu(&cur->head);
                nf_ct_l4proto_put(cur->l4proto);
-               kfree_rcu(cur, rcu_head);
+
+               if (atomic_dec_and_test(&cur->refcnt))
+                       kfree_rcu(cur, rcu_head);
        }
 }