Merge branch 'for-linus' of git://git.kernel.org/pub/scm/linux/kernel/git/mason/linux...
[cascardo/linux.git] / net / netfilter / ipset / ip_set_hash_gen.h
index fee7c64..974ff38 100644 (file)
@@ -147,16 +147,22 @@ hbucket_elem_add(struct hbucket *n, u8 ahash_max, size_t dsize)
 #else
 #define __CIDR(cidr, i)                (cidr)
 #endif
+
+/* cidr + 1 is stored in net_prefixes to support /0 */
+#define SCIDR(cidr, i)         (__CIDR(cidr, i) + 1)
+
 #ifdef IP_SET_HASH_WITH_NETS_PACKED
-/* When cidr is packed with nomatch, cidr - 1 is stored in the entry */
-#define CIDR(cidr, i)          (__CIDR(cidr, i) + 1)
+/* When cidr is packed with nomatch, cidr - 1 is stored in the data entry */
+#define GCIDR(cidr, i)         (__CIDR(cidr, i) + 1)
+#define NCIDR(cidr)            (cidr)
 #else
-#define CIDR(cidr, i)          (__CIDR(cidr, i))
+#define GCIDR(cidr, i)         (__CIDR(cidr, i))
+#define NCIDR(cidr)            (cidr - 1)
 #endif
 
 #define SET_HOST_MASK(family)  (family == AF_INET ? 32 : 128)
 
-#ifdef IP_SET_HASH_WITH_MULTI
+#ifdef IP_SET_HASH_WITH_NET0
 #define NLEN(family)           (SET_HOST_MASK(family) + 1)
 #else
 #define NLEN(family)           SET_HOST_MASK(family)
@@ -292,24 +298,22 @@ mtype_add_cidr(struct htype *h, u8 cidr, u8 nets_length, u8 n)
        int i, j;
 
        /* Add in increasing prefix order, so larger cidr first */
-       for (i = 0, j = -1; i < nets_length && h->nets[i].nets[n]; i++) {
+       for (i = 0, j = -1; i < nets_length && h->nets[i].cidr[n]; i++) {
                if (j != -1)
                        continue;
                else if (h->nets[i].cidr[n] < cidr)
                        j = i;
                else if (h->nets[i].cidr[n] == cidr) {
-                       h->nets[i].nets[n]++;
+                       h->nets[cidr - 1].nets[n]++;
                        return;
                }
        }
        if (j != -1) {
-               for (; i > j; i--) {
+               for (; i > j; i--)
                        h->nets[i].cidr[n] = h->nets[i - 1].cidr[n];
-                       h->nets[i].nets[n] = h->nets[i - 1].nets[n];
-               }
        }
        h->nets[i].cidr[n] = cidr;
-       h->nets[i].nets[n] = 1;
+       h->nets[cidr - 1].nets[n] = 1;
 }
 
 static void
@@ -320,16 +324,12 @@ mtype_del_cidr(struct htype *h, u8 cidr, u8 nets_length, u8 n)
        for (i = 0; i < nets_length; i++) {
                if (h->nets[i].cidr[n] != cidr)
                        continue;
-                if (h->nets[i].nets[n] > 1 || i == net_end ||
-                    h->nets[i + 1].nets[n] == 0) {
-                        h->nets[i].nets[n]--;
+               h->nets[cidr -1].nets[n]--;
+               if (h->nets[cidr -1].nets[n] > 0)
                         return;
-                }
-                for (j = i; j < net_end && h->nets[j].nets[n]; j++) {
+               for (j = i; j < net_end && h->nets[j].cidr[n]; j++)
                        h->nets[j].cidr[n] = h->nets[j + 1].cidr[n];
-                       h->nets[j].nets[n] = h->nets[j + 1].nets[n];
-                }
-                h->nets[j].nets[n] = 0;
+               h->nets[j].cidr[n] = 0;
                 return;
        }
 }
@@ -486,7 +486,7 @@ mtype_expire(struct ip_set *set, struct htype *h, u8 nets_length, size_t dsize)
                                pr_debug("expired %u/%u\n", i, j);
 #ifdef IP_SET_HASH_WITH_NETS
                                for (k = 0; k < IPSET_NET_COUNT; k++)
-                                       mtype_del_cidr(h, CIDR(data->cidr, k),
+                                       mtype_del_cidr(h, SCIDR(data->cidr, k),
                                                       nets_length, k);
 #endif
                                ip_set_ext_destroy(set, data);
@@ -633,29 +633,6 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
        bool flag_exist = flags & IPSET_FLAG_EXIST;
        u32 key, multi = 0;
 
-       if (h->elements >= h->maxelem && SET_WITH_FORCEADD(set)) {
-               rcu_read_lock_bh();
-               t = rcu_dereference_bh(h->table);
-               key = HKEY(value, h->initval, t->htable_bits);
-               n = hbucket(t,key);
-               if (n->pos) {
-                       /* Choosing the first entry in the array to replace */
-                       j = 0;
-                       goto reuse_slot;
-               }
-               rcu_read_unlock_bh();
-       }
-       if (SET_WITH_TIMEOUT(set) && h->elements >= h->maxelem)
-               /* FIXME: when set is full, we slow down here */
-               mtype_expire(set, h, NLEN(set->family), set->dsize);
-
-       if (h->elements >= h->maxelem) {
-               if (net_ratelimit())
-                       pr_warn("Set %s is full, maxelem %u reached\n",
-                               set->name, h->maxelem);
-               return -IPSET_ERR_HASH_FULL;
-       }
-
        rcu_read_lock_bh();
        t = rcu_dereference_bh(h->table);
        key = HKEY(value, h->initval, t->htable_bits);
@@ -680,15 +657,32 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                    j != AHASH_MAX(h) + 1)
                        j = i;
        }
+       if (h->elements >= h->maxelem && SET_WITH_FORCEADD(set) && n->pos) {
+               /* Choosing the first entry in the array to replace */
+               j = 0;
+               goto reuse_slot;
+       }
+       if (SET_WITH_TIMEOUT(set) && h->elements >= h->maxelem)
+               /* FIXME: when set is full, we slow down here */
+               mtype_expire(set, h, NLEN(set->family), set->dsize);
+
+       if (h->elements >= h->maxelem) {
+               if (net_ratelimit())
+                       pr_warn("Set %s is full, maxelem %u reached\n",
+                               set->name, h->maxelem);
+               ret = -IPSET_ERR_HASH_FULL;
+               goto out;
+       }
+
 reuse_slot:
        if (j != AHASH_MAX(h) + 1) {
                /* Fill out reused slot */
                data = ahash_data(n, j, set->dsize);
 #ifdef IP_SET_HASH_WITH_NETS
                for (i = 0; i < IPSET_NET_COUNT; i++) {
-                       mtype_del_cidr(h, CIDR(data->cidr, i),
+                       mtype_del_cidr(h, SCIDR(data->cidr, i),
                                       NLEN(set->family), i);
-                       mtype_add_cidr(h, CIDR(d->cidr, i),
+                       mtype_add_cidr(h, SCIDR(d->cidr, i),
                                       NLEN(set->family), i);
                }
 #endif
@@ -705,7 +699,7 @@ reuse_slot:
                data = ahash_data(n, n->pos++, set->dsize);
 #ifdef IP_SET_HASH_WITH_NETS
                for (i = 0; i < IPSET_NET_COUNT; i++)
-                       mtype_add_cidr(h, CIDR(d->cidr, i), NLEN(set->family),
+                       mtype_add_cidr(h, SCIDR(d->cidr, i), NLEN(set->family),
                                       i);
 #endif
                h->elements++;
@@ -766,7 +760,7 @@ mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                h->elements--;
 #ifdef IP_SET_HASH_WITH_NETS
                for (j = 0; j < IPSET_NET_COUNT; j++)
-                       mtype_del_cidr(h, CIDR(d->cidr, j), NLEN(set->family),
+                       mtype_del_cidr(h, SCIDR(d->cidr, j), NLEN(set->family),
                                       j);
 #endif
                ip_set_ext_destroy(set, data);
@@ -827,15 +821,15 @@ mtype_test_cidrs(struct ip_set *set, struct mtype_elem *d,
        u8 nets_length = NLEN(set->family);
 
        pr_debug("test by nets\n");
-       for (; j < nets_length && h->nets[j].nets[0] && !multi; j++) {
+       for (; j < nets_length && h->nets[j].cidr[0] && !multi; j++) {
 #if IPSET_NET_COUNT == 2
                mtype_data_reset_elem(d, &orig);
-               mtype_data_netmask(d, h->nets[j].cidr[0], false);
-               for (k = 0; k < nets_length && h->nets[k].nets[1] && !multi;
+               mtype_data_netmask(d, NCIDR(h->nets[j].cidr[0]), false);
+               for (k = 0; k < nets_length && h->nets[k].cidr[1] && !multi;
                     k++) {
-                       mtype_data_netmask(d, h->nets[k].cidr[1], true);
+                       mtype_data_netmask(d, NCIDR(h->nets[k].cidr[1]), true);
 #else
-               mtype_data_netmask(d, h->nets[j].cidr[0]);
+               mtype_data_netmask(d, NCIDR(h->nets[j].cidr[0]));
 #endif
                key = HKEY(d, h->initval, t->htable_bits);
                n = hbucket(t, key);
@@ -883,7 +877,7 @@ mtype_test(struct ip_set *set, void *value, const struct ip_set_ext *ext,
        /* If we test an IP address and not a network address,
         * try all possible network sizes */
        for (i = 0; i < IPSET_NET_COUNT; i++)
-               if (CIDR(d->cidr, i) != SET_HOST_MASK(set->family))
+               if (GCIDR(d->cidr, i) != SET_HOST_MASK(set->family))
                        break;
        if (i == IPSET_NET_COUNT) {
                ret = mtype_test_cidrs(set, d, ext, mext, flags);
@@ -1107,8 +1101,7 @@ IPSET_TOKEN(HTYPE, _create)(struct net *net, struct ip_set *set,
 
        hsize = sizeof(*h);
 #ifdef IP_SET_HASH_WITH_NETS
-       hsize += sizeof(struct net_prefixes) *
-               (set->family == NFPROTO_IPV4 ? 32 : 128);
+       hsize += sizeof(struct net_prefixes) * NLEN(set->family);
 #endif
        h = kzalloc(hsize, GFP_KERNEL);
        if (!h)