Merge branch 'for-3.9-fixes' of git://git.kernel.org/pub/scm/linux/kernel/git/tj/wq
[cascardo/linux.git] / include / linux / netfilter / ipset / ip_set_ahash.h
1 #ifndef _IP_SET_AHASH_H
2 #define _IP_SET_AHASH_H
3
4 #include <linux/rcupdate.h>
5 #include <linux/jhash.h>
6 #include <linux/netfilter/ipset/ip_set_timeout.h>
7
/* Token pasting helpers: the including source file defines TYPE and PF
 * and these generate the type- and family-specific identifiers
 * (e.g. hash_ip4_elem). */
#define CONCAT(a, b, c)         a##b##c
#define TOKEN(a, b, c)          CONCAT(a, b, c)

/* Type of the temporary "next element to try" cookie embedded in
 * struct ip_set_hash; expands to the same struct as type_pf_elem. */
#define type_pf_next            TOKEN(TYPE, PF, _elem)
12
13 /* Hashing which uses arrays to resolve clashing. The hash table is resized
14  * (doubled) when searching becomes too long.
15  * Internally jhash is used with the assumption that the size of the
16  * stored data is a multiple of sizeof(u32). If storage supports timeout,
17  * the timeout field must be the last one in the data structure - that field
18  * is ignored when computing the hash key.
19  *
20  * Readers and resizing
21  *
22  * Resizing can be triggered by userspace command only, and those
23  * are serialized by the nfnl mutex. During resizing the set is
24  * read-locked, so the only possible concurrent operations are
25  * the kernel side readers. Those must be protected by proper RCU locking.
26  */
27
28 /* Number of elements to store in an initial array block */
29 #define AHASH_INIT_SIZE                 4
30 /* Max number of elements to store in an array block */
31 #define AHASH_MAX_SIZE                  (3*AHASH_INIT_SIZE)
32
33 /* Max number of elements can be tuned */
34 #ifdef IP_SET_HASH_WITH_MULTI
35 #define AHASH_MAX(h)                    ((h)->ahash_max)
36
37 static inline u8
38 tune_ahash_max(u8 curr, u32 multi)
39 {
40         u32 n;
41
42         if (multi < curr)
43                 return curr;
44
45         n = curr + AHASH_INIT_SIZE;
46         /* Currently, at listing one hash bucket must fit into a message.
47          * Therefore we have a hard limit here.
48          */
49         return n > curr && n <= 64 ? n : curr;
50 }
51 #define TUNE_AHASH_MAX(h, multi)        \
52         ((h)->ahash_max = tune_ahash_max((h)->ahash_max, multi))
53 #else
54 #define AHASH_MAX(h)                    AHASH_MAX_SIZE
55 #define TUNE_AHASH_MAX(h, multi)
56 #endif
57
/* A hash bucket: clashing elements are kept in a dynamically sized
 * array, grown and shrunk in AHASH_INIT_SIZE steps up to AHASH_MAX(h). */
struct hbucket {
        void *value;            /* the array of the values */
        u8 size;                /* size of the array */
        u8 pos;                 /* position of the first free entry */
};
64
/* The hash table: the table size stored here in order to make resizing easy.
 * Allocated as one block of sizeof(struct htable) plus
 * jhash_size(htable_bits) buckets (see htable_size() and the
 * ip_set_alloc() calls in the resize routines). */
struct htable {
        u8 htable_bits;         /* size of hash table == 2^htable_bits */
        struct hbucket bucket[0]; /* hashtable buckets */
};
70
71 #define hbucket(h, i)           (&((h)->bucket[i]))
72
/* Book-keeping of the prefixes added to the set: the array is kept
 * sorted by decreasing cidr value (add_cidr() inserts larger cidr
 * first), so lookups try the more specific network sizes first. */
struct ip_set_hash_nets {
        u8 cidr;                /* the different cidr values in the set */
        u32 nets;               /* number of elements per cidr */
};
78
/* The generic ip_set hash structure, stored as set->data */
struct ip_set_hash {
        struct htable *table;   /* the hash table */
        u32 maxelem;            /* max elements in the hash */
        u32 elements;           /* current element (vs timeout) */
        u32 initval;            /* random jhash init value */
        u32 timeout;            /* timeout value, if enabled */
        struct timer_list gc;   /* garbage collection when timeout enabled */
        struct type_pf_next next; /* temporary storage for uadd */
#ifdef IP_SET_HASH_WITH_MULTI
        u8 ahash_max;           /* max elements in an array block */
#endif
#ifdef IP_SET_HASH_WITH_NETMASK
        u8 netmask;             /* netmask value for subnets to store */
#endif
#ifdef IP_SET_HASH_WITH_RBTREE
        struct rb_root rbtree;  /* NOTE(review): maintained by the concrete
                                 * type; only destroyed in this header */
#endif
#ifdef IP_SET_HASH_WITH_NETS
        struct ip_set_hash_nets nets[0]; /* book-keeping of prefixes */
#endif
};
101
102 static size_t
103 htable_size(u8 hbits)
104 {
105         size_t hsize;
106
107         /* We must fit both into u32 in jhash and size_t */
108         if (hbits > 31)
109                 return 0;
110         hsize = jhash_size(hbits);
111         if ((((size_t)-1) - sizeof(struct htable))/sizeof(struct hbucket)
112             < hsize)
113                 return 0;
114
115         return hsize * sizeof(struct hbucket) + sizeof(struct htable);
116 }
117
118 /* Compute htable_bits from the user input parameter hashsize */
119 static u8
120 htable_bits(u32 hashsize)
121 {
122         /* Assume that hashsize == 2^htable_bits */
123         u8 bits = fls(hashsize - 1);
124         if (jhash_size(bits) != hashsize)
125                 /* Round up to the first 2^n value */
126                 bits = fls(hashsize);
127
128         return bits;
129 }
130
131 #ifdef IP_SET_HASH_WITH_NETS
132 #ifdef IP_SET_HASH_WITH_NETS_PACKED
133 /* When cidr is packed with nomatch, cidr - 1 is stored in the entry */
134 #define CIDR(cidr)      (cidr + 1)
135 #else
136 #define CIDR(cidr)      (cidr)
137 #endif
138
139 #define SET_HOST_MASK(family)   (family == AF_INET ? 32 : 128)
140 #ifdef IP_SET_HASH_WITH_MULTI
141 #define NETS_LENGTH(family)     (SET_HOST_MASK(family) + 1)
142 #else
143 #define NETS_LENGTH(family)     SET_HOST_MASK(family)
144 #endif
145
146 /* Network cidr size book keeping when the hash stores different
147  * sized networks */
148 static void
149 add_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
150 {
151         int i, j;
152
153         /* Add in increasing prefix order, so larger cidr first */
154         for (i = 0, j = -1; i < nets_length && h->nets[i].nets; i++) {
155                 if (j != -1)
156                         continue;
157                 else if (h->nets[i].cidr < cidr)
158                         j = i;
159                 else if (h->nets[i].cidr == cidr) {
160                         h->nets[i].nets++;
161                         return;
162                 }
163         }
164         if (j != -1) {
165                 for (; i > j; i--) {
166                         h->nets[i].cidr = h->nets[i - 1].cidr;
167                         h->nets[i].nets = h->nets[i - 1].nets;
168                 }
169         }
170         h->nets[i].cidr = cidr;
171         h->nets[i].nets = 1;
172 }
173
174 static void
175 del_cidr(struct ip_set_hash *h, u8 cidr, u8 nets_length)
176 {
177         u8 i, j;
178
179         for (i = 0; i < nets_length - 1 && h->nets[i].cidr != cidr; i++)
180                 ;
181         h->nets[i].nets--;
182
183         if (h->nets[i].nets != 0)
184                 return;
185
186         for (j = i; j < nets_length - 1 && h->nets[j].nets; j++) {
187                 h->nets[j].cidr = h->nets[j + 1].cidr;
188                 h->nets[j].nets = h->nets[j + 1].nets;
189         }
190 }
191 #else
192 #define NETS_LENGTH(family)             0
193 #endif
194
195 /* Destroy the hashtable part of the set */
196 static void
197 ahash_destroy(struct htable *t)
198 {
199         struct hbucket *n;
200         u32 i;
201
202         for (i = 0; i < jhash_size(t->htable_bits); i++) {
203                 n = hbucket(t, i);
204                 if (n->size)
205                         /* FIXME: use slab cache */
206                         kfree(n->value);
207         }
208
209         ip_set_free(t);
210 }
211
212 /* Calculate the actual memory size of the set data */
213 static size_t
214 ahash_memsize(const struct ip_set_hash *h, size_t dsize, u8 nets_length)
215 {
216         u32 i;
217         struct htable *t = h->table;
218         size_t memsize = sizeof(*h)
219                          + sizeof(*t)
220 #ifdef IP_SET_HASH_WITH_NETS
221                          + sizeof(struct ip_set_hash_nets) * nets_length
222 #endif
223                          + jhash_size(t->htable_bits) * sizeof(struct hbucket);
224
225         for (i = 0; i < jhash_size(t->htable_bits); i++)
226                         memsize += t->bucket[i].size * dsize;
227
228         return memsize;
229 }
230
231 /* Flush a hash type of set: destroy all elements */
232 static void
233 ip_set_hash_flush(struct ip_set *set)
234 {
235         struct ip_set_hash *h = set->data;
236         struct htable *t = h->table;
237         struct hbucket *n;
238         u32 i;
239
240         for (i = 0; i < jhash_size(t->htable_bits); i++) {
241                 n = hbucket(t, i);
242                 if (n->size) {
243                         n->size = n->pos = 0;
244                         /* FIXME: use slab cache */
245                         kfree(n->value);
246                 }
247         }
248 #ifdef IP_SET_HASH_WITH_NETS
249         memset(h->nets, 0, sizeof(struct ip_set_hash_nets)
250                            * NETS_LENGTH(set->family));
251 #endif
252         h->elements = 0;
253 }
254
255 /* Destroy a hash type of set */
256 static void
257 ip_set_hash_destroy(struct ip_set *set)
258 {
259         struct ip_set_hash *h = set->data;
260
261         if (with_timeout(h->timeout))
262                 del_timer_sync(&h->gc);
263
264         ahash_destroy(h->table);
265 #ifdef IP_SET_HASH_WITH_RBTREE
266         rbtree_destroy(&h->rbtree);
267 #endif
268         kfree(h);
269
270         set->data = NULL;
271 }
272
273 #endif /* _IP_SET_AHASH_H */
274
/* Number of bytes of the element hashed into the key; a type whose
 * element carries non-key trailing data overrides this before
 * including the header. */
#ifndef HKEY_DATALEN
#define HKEY_DATALEN    sizeof(struct type_pf_elem)
#endif

/* Bucket index of an element: jhash2 over the first HKEY_DATALEN
 * bytes (must be a multiple of sizeof(u32)), masked to the table size. */
#define HKEY(data, initval, htable_bits)                        \
(jhash2((u32 *)(data), HKEY_DATALEN/sizeof(u32), initval)       \
        & jhash_mask(htable_bits))
282
283 /* Type/family dependent function prototypes */
284
285 #define type_pf_data_equal      TOKEN(TYPE, PF, _data_equal)
286 #define type_pf_data_isnull     TOKEN(TYPE, PF, _data_isnull)
287 #define type_pf_data_copy       TOKEN(TYPE, PF, _data_copy)
288 #define type_pf_data_zero_out   TOKEN(TYPE, PF, _data_zero_out)
289 #define type_pf_data_netmask    TOKEN(TYPE, PF, _data_netmask)
290 #define type_pf_data_list       TOKEN(TYPE, PF, _data_list)
291 #define type_pf_data_tlist      TOKEN(TYPE, PF, _data_tlist)
292 #define type_pf_data_next       TOKEN(TYPE, PF, _data_next)
293 #define type_pf_data_flags      TOKEN(TYPE, PF, _data_flags)
294 #ifdef IP_SET_HASH_WITH_NETS
295 #define type_pf_data_match      TOKEN(TYPE, PF, _data_match)
296 #else
297 #define type_pf_data_match(d)   1
298 #endif
299
300 #define type_pf_elem            TOKEN(TYPE, PF, _elem)
301 #define type_pf_telem           TOKEN(TYPE, PF, _telem)
302 #define type_pf_data_timeout    TOKEN(TYPE, PF, _data_timeout)
303 #define type_pf_data_expired    TOKEN(TYPE, PF, _data_expired)
304 #define type_pf_data_timeout_set TOKEN(TYPE, PF, _data_timeout_set)
305
306 #define type_pf_elem_add        TOKEN(TYPE, PF, _elem_add)
307 #define type_pf_add             TOKEN(TYPE, PF, _add)
308 #define type_pf_del             TOKEN(TYPE, PF, _del)
309 #define type_pf_test_cidrs      TOKEN(TYPE, PF, _test_cidrs)
310 #define type_pf_test            TOKEN(TYPE, PF, _test)
311
312 #define type_pf_elem_tadd       TOKEN(TYPE, PF, _elem_tadd)
313 #define type_pf_del_telem       TOKEN(TYPE, PF, _ahash_del_telem)
314 #define type_pf_expire          TOKEN(TYPE, PF, _expire)
315 #define type_pf_tadd            TOKEN(TYPE, PF, _tadd)
316 #define type_pf_tdel            TOKEN(TYPE, PF, _tdel)
317 #define type_pf_ttest_cidrs     TOKEN(TYPE, PF, _ahash_ttest_cidrs)
318 #define type_pf_ttest           TOKEN(TYPE, PF, _ahash_ttest)
319
320 #define type_pf_resize          TOKEN(TYPE, PF, _resize)
321 #define type_pf_tresize         TOKEN(TYPE, PF, _tresize)
322 #define type_pf_flush           ip_set_hash_flush
323 #define type_pf_destroy         ip_set_hash_destroy
324 #define type_pf_head            TOKEN(TYPE, PF, _head)
325 #define type_pf_list            TOKEN(TYPE, PF, _list)
326 #define type_pf_tlist           TOKEN(TYPE, PF, _tlist)
327 #define type_pf_same_set        TOKEN(TYPE, PF, _same_set)
328 #define type_pf_kadt            TOKEN(TYPE, PF, _kadt)
329 #define type_pf_uadt            TOKEN(TYPE, PF, _uadt)
330 #define type_pf_gc              TOKEN(TYPE, PF, _gc)
331 #define type_pf_gc_init         TOKEN(TYPE, PF, _gc_init)
332 #define type_pf_variant         TOKEN(TYPE, PF, _variant)
333 #define type_pf_tvariant        TOKEN(TYPE, PF, _tvariant)
334
335 /* Flavour without timeout */
336
337 /* Get the ith element from the array block n */
338 #define ahash_data(n, i)        \
339         ((struct type_pf_elem *)((n)->value) + (i))
340
341 /* Add an element to the hash table when resizing the set:
342  * we spare the maintenance of the internal counters. */
343 static int
344 type_pf_elem_add(struct hbucket *n, const struct type_pf_elem *value,
345                  u8 ahash_max, u32 cadt_flags)
346 {
347         struct type_pf_elem *data;
348
349         if (n->pos >= n->size) {
350                 void *tmp;
351
352                 if (n->size >= ahash_max)
353                         /* Trigger rehashing */
354                         return -EAGAIN;
355
356                 tmp = kzalloc((n->size + AHASH_INIT_SIZE)
357                               * sizeof(struct type_pf_elem),
358                               GFP_ATOMIC);
359                 if (!tmp)
360                         return -ENOMEM;
361                 if (n->size) {
362                         memcpy(tmp, n->value,
363                                sizeof(struct type_pf_elem) * n->size);
364                         kfree(n->value);
365                 }
366                 n->value = tmp;
367                 n->size += AHASH_INIT_SIZE;
368         }
369         data = ahash_data(n, n->pos++);
370         type_pf_data_copy(data, value);
371 #ifdef IP_SET_HASH_WITH_NETS
372         /* Resizing won't overwrite stored flags */
373         if (cadt_flags)
374                 type_pf_data_flags(data, cadt_flags);
375 #endif
376         return 0;
377 }
378
/* Resize a hash: create a new hash table with doubling the hashsize
 * and inserting the elements to it. Repeat until we succeed or
 * fail due to memory pressures.
 *
 * Runs from userspace context only (serialized by the nfnl mutex), so
 * the only concurrency is with kernel-side readers: the table pointer
 * is switched with rcu_assign_pointer() and the old table is destroyed
 * only after synchronize_rcu_bh().
 * NOTE(review): the retried argument is unused in this flavour
 * (cf. type_pf_tresize, which uses it to expire entries first). */
static int
type_pf_resize(struct ip_set *set, bool retried)
{
        struct ip_set_hash *h = set->data;
        struct htable *t, *orig = h->table;
        u8 htable_bits = orig->htable_bits;
        const struct type_pf_elem *data;
        struct hbucket *n, *m;
        u32 i, j;
        int ret;

retry:
        ret = 0;
        htable_bits++;
        pr_debug("attempt to resize set %s from %u to %u, t %p\n",
                 set->name, orig->htable_bits, htable_bits, orig);
        if (!htable_bits) {
                /* In case we have plenty of memory :-) */
                pr_warning("Cannot increase the hashsize of set %s further\n",
                           set->name);
                return -IPSET_ERR_HASH_FULL;
        }
        t = ip_set_alloc(sizeof(*t)
                         + jhash_size(htable_bits) * sizeof(struct hbucket));
        if (!t)
                return -ENOMEM;
        t->htable_bits = htable_bits;

        /* Block userspace writers while rehashing into the new table */
        read_lock_bh(&set->lock);
        for (i = 0; i < jhash_size(orig->htable_bits); i++) {
                n = hbucket(orig, i);
                for (j = 0; j < n->pos; j++) {
                        data = ahash_data(n, j);
                        m = hbucket(t, HKEY(data, h->initval, htable_bits));
                        ret = type_pf_elem_add(m, data, AHASH_MAX(h), 0);
                        if (ret < 0) {
                                read_unlock_bh(&set->lock);
                                ahash_destroy(t);
                                if (ret == -EAGAIN)
                                        /* A bucket of the new table got
                                         * full: double the size again */
                                        goto retry;
                                return ret;
                        }
                }
        }

        rcu_assign_pointer(h->table, t);
        read_unlock_bh(&set->lock);

        /* Give time to other readers of the set */
        synchronize_rcu_bh();

        pr_debug("set %s resized from %u (%p) to %u (%p)\n", set->name,
                 orig->htable_bits, orig, t->htable_bits, t);
        ahash_destroy(orig);

        return 0;
}
439
440 static inline void
441 type_pf_data_next(struct ip_set_hash *h, const struct type_pf_elem *d);
442
/* Add an element to a hash and update the internal counters when succeeded,
 * otherwise report the proper error code.
 *
 * The table pointer is fetched under rcu_read_lock_bh() because a
 * concurrent resize may swap it. On -EAGAIN (bucket full) the clashing
 * element is remembered via type_pf_data_next() in h->next so the
 * caller can resize and retry. */
static int
type_pf_add(struct ip_set *set, void *value, u32 timeout, u32 flags)
{
        struct ip_set_hash *h = set->data;
        struct htable *t;
        const struct type_pf_elem *d = value;
        struct hbucket *n;
        int i, ret = 0;
        u32 key, multi = 0;
        u32 cadt_flags = flags >> 16;   /* upper 16 bits: per-element flags */

        /* Refuse new elements when the set is full */
        if (h->elements >= h->maxelem) {
                if (net_ratelimit())
                        pr_warning("Set %s is full, maxelem %u reached\n",
                                   set->name, h->maxelem);
                return -IPSET_ERR_HASH_FULL;
        }

        rcu_read_lock_bh();
        t = rcu_dereference_bh(h->table);
        key = HKEY(value, h->initval, t->htable_bits);
        n = hbucket(t, key);
        for (i = 0; i < n->pos; i++)
                if (type_pf_data_equal(ahash_data(n, i), d, &multi)) {
#ifdef IP_SET_HASH_WITH_NETS
                        if (flags & IPSET_FLAG_EXIST)
                                /* Support overwriting just the flags */
                                type_pf_data_flags(ahash_data(n, i),
                                                   cadt_flags);
#endif
                        ret = -IPSET_ERR_EXIST;
                        goto out;
                }
        /* Possibly enlarge the maximal array block size
         * (no-op unless IP_SET_HASH_WITH_MULTI is defined) */
        TUNE_AHASH_MAX(h, multi);
        ret = type_pf_elem_add(n, value, AHASH_MAX(h), cadt_flags);
        if (ret != 0) {
                if (ret == -EAGAIN)
                        /* Save the element for the retry after resizing */
                        type_pf_data_next(h, d);
                goto out;
        }

#ifdef IP_SET_HASH_WITH_NETS
        add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
#endif
        h->elements++;
out:
        rcu_read_unlock_bh();
        return ret;
}
494
495 /* Delete an element from the hash: swap it with the last element
496  * and free up space if possible.
497  */
498 static int
499 type_pf_del(struct ip_set *set, void *value, u32 timeout, u32 flags)
500 {
501         struct ip_set_hash *h = set->data;
502         struct htable *t = h->table;
503         const struct type_pf_elem *d = value;
504         struct hbucket *n;
505         int i;
506         struct type_pf_elem *data;
507         u32 key, multi = 0;
508
509         key = HKEY(value, h->initval, t->htable_bits);
510         n = hbucket(t, key);
511         for (i = 0; i < n->pos; i++) {
512                 data = ahash_data(n, i);
513                 if (!type_pf_data_equal(data, d, &multi))
514                         continue;
515                 if (i != n->pos - 1)
516                         /* Not last one */
517                         type_pf_data_copy(data, ahash_data(n, n->pos - 1));
518
519                 n->pos--;
520                 h->elements--;
521 #ifdef IP_SET_HASH_WITH_NETS
522                 del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
523 #endif
524                 if (n->pos + AHASH_INIT_SIZE < n->size) {
525                         void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
526                                             * sizeof(struct type_pf_elem),
527                                             GFP_ATOMIC);
528                         if (!tmp)
529                                 return 0;
530                         n->size -= AHASH_INIT_SIZE;
531                         memcpy(tmp, n->value,
532                                n->size * sizeof(struct type_pf_elem));
533                         kfree(n->value);
534                         n->value = tmp;
535                 }
536                 return 0;
537         }
538
539         return -IPSET_ERR_EXIST;
540 }
541
542 #ifdef IP_SET_HASH_WITH_NETS
543
544 /* Special test function which takes into account the different network
545  * sizes added to the set */
546 static int
547 type_pf_test_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
548 {
549         struct ip_set_hash *h = set->data;
550         struct htable *t = h->table;
551         struct hbucket *n;
552         const struct type_pf_elem *data;
553         int i, j = 0;
554         u32 key, multi = 0;
555         u8 nets_length = NETS_LENGTH(set->family);
556
557         pr_debug("test by nets\n");
558         for (; j < nets_length && h->nets[j].nets && !multi; j++) {
559                 type_pf_data_netmask(d, h->nets[j].cidr);
560                 key = HKEY(d, h->initval, t->htable_bits);
561                 n = hbucket(t, key);
562                 for (i = 0; i < n->pos; i++) {
563                         data = ahash_data(n, i);
564                         if (type_pf_data_equal(data, d, &multi))
565                                 return type_pf_data_match(data);
566                 }
567         }
568         return 0;
569 }
570 #endif
571
572 /* Test whether the element is added to the set */
573 static int
574 type_pf_test(struct ip_set *set, void *value, u32 timeout, u32 flags)
575 {
576         struct ip_set_hash *h = set->data;
577         struct htable *t = h->table;
578         struct type_pf_elem *d = value;
579         struct hbucket *n;
580         const struct type_pf_elem *data;
581         int i;
582         u32 key, multi = 0;
583
584 #ifdef IP_SET_HASH_WITH_NETS
585         /* If we test an IP address and not a network address,
586          * try all possible network sizes */
587         if (CIDR(d->cidr) == SET_HOST_MASK(set->family))
588                 return type_pf_test_cidrs(set, d, timeout);
589 #endif
590
591         key = HKEY(d, h->initval, t->htable_bits);
592         n = hbucket(t, key);
593         for (i = 0; i < n->pos; i++) {
594                 data = ahash_data(n, i);
595                 if (type_pf_data_equal(data, d, &multi))
596                         return type_pf_data_match(data);
597         }
598         return 0;
599 }
600
/* Reply a HEADER request: fill out the header part of the set.
 * Emits a nested IPSET_ATTR_DATA attribute with the hash parameters;
 * returns 0 on success or -EMSGSIZE when the message ran out of room. */
static int
type_pf_head(struct ip_set *set, struct sk_buff *skb)
{
        const struct ip_set_hash *h = set->data;
        struct nlattr *nested;
        size_t memsize;

        /* Element size depends on whether the set stores timeouts */
        read_lock_bh(&set->lock);
        memsize = ahash_memsize(h, with_timeout(h->timeout)
                                        ? sizeof(struct type_pf_telem)
                                        : sizeof(struct type_pf_elem),
                                NETS_LENGTH(set->family));
        read_unlock_bh(&set->lock);

        nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
        if (!nested)
                goto nla_put_failure;
        if (nla_put_net32(skb, IPSET_ATTR_HASHSIZE,
                          htonl(jhash_size(h->table->htable_bits))) ||
            nla_put_net32(skb, IPSET_ATTR_MAXELEM, htonl(h->maxelem)))
                goto nla_put_failure;
#ifdef IP_SET_HASH_WITH_NETMASK
        /* Report the netmask only when it differs from the default */
        if (h->netmask != HOST_MASK &&
            nla_put_u8(skb, IPSET_ATTR_NETMASK, h->netmask))
                goto nla_put_failure;
#endif
        /* NOTE(review): ref - 1 presumably hides the reference taken
         * for this request itself — confirm against the core */
        if (nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref - 1)) ||
            nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
            (with_timeout(h->timeout) &&
             nla_put_net32(skb, IPSET_ATTR_TIMEOUT, htonl(h->timeout))))
                goto nla_put_failure;
        ipset_nest_end(skb, nested);

        return 0;
nla_put_failure:
        return -EMSGSIZE;
}
639
/* Reply a LIST/SAVE request: dump the elements of the specified set.
 * cb->args[2] carries the bucket index to resume from across netlink
 * dump callbacks; it is reset to zero when the set got fully listed. */
static int
type_pf_list(const struct ip_set *set,
             struct sk_buff *skb, struct netlink_callback *cb)
{
        const struct ip_set_hash *h = set->data;
        const struct htable *t = h->table;
        struct nlattr *atd, *nested;
        const struct hbucket *n;
        const struct type_pf_elem *data;
        u32 first = cb->args[2];        /* bucket this call started at */
        /* We assume that one hash bucket fills into one page */
        void *incomplete;
        int i;

        atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
        if (!atd)
                return -EMSGSIZE;
        pr_debug("list hash set %s\n", set->name);
        for (; cb->args[2] < jhash_size(t->htable_bits); cb->args[2]++) {
                /* Remember the tail so a partially dumped bucket can
                 * be trimmed off again on overflow */
                incomplete = skb_tail_pointer(skb);
                n = hbucket(t, cb->args[2]);
                pr_debug("cb->args[2]: %lu, t %p n %p\n", cb->args[2], t, n);
                for (i = 0; i < n->pos; i++) {
                        data = ahash_data(n, i);
                        pr_debug("list hash %lu hbucket %p i %u, data %p\n",
                                 cb->args[2], n, i, data);
                        nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
                        if (!nested) {
                                if (cb->args[2] == first) {
                                        /* Not even the first bucket of
                                         * this call fits: give up */
                                        nla_nest_cancel(skb, atd);
                                        return -EMSGSIZE;
                                } else
                                        goto nla_put_failure;
                        }
                        if (type_pf_data_list(skb, data))
                                goto nla_put_failure;
                        ipset_nest_end(skb, nested);
                }
        }
        ipset_nest_end(skb, atd);
        /* Set listing finished */
        cb->args[2] = 0;

        return 0;

nla_put_failure:
        /* Drop the half-dumped bucket; it is re-dumped next call */
        nlmsg_trim(skb, incomplete);
        ipset_nest_end(skb, atd);
        if (unlikely(first == cb->args[2])) {
                pr_warning("Can't list set %s: one bucket does not fit into "
                           "a message. Please report it!\n", set->name);
                cb->args[2] = 0;
                return -EMSGSIZE;
        }
        return 0;
}
697
698 static int
699 type_pf_kadt(struct ip_set *set, const struct sk_buff *skb,
700              const struct xt_action_param *par,
701              enum ipset_adt adt, const struct ip_set_adt_opt *opt);
702 static int
703 type_pf_uadt(struct ip_set *set, struct nlattr *tb[],
704              enum ipset_adt adt, u32 *lineno, u32 flags, bool retried);
705
/* Operation table of the flavour without timeout support.
 * NOTE(review): installed into struct ip_set by the concrete set type,
 * outside this header. */
static const struct ip_set_type_variant type_pf_variant = {
        .kadt   = type_pf_kadt,
        .uadt   = type_pf_uadt,
        .adt    = {
                [IPSET_ADD] = type_pf_add,
                [IPSET_DEL] = type_pf_del,
                [IPSET_TEST] = type_pf_test,
        },
        .destroy = type_pf_destroy,
        .flush  = type_pf_flush,
        .head   = type_pf_head,
        .list   = type_pf_list,
        .resize = type_pf_resize,
        .same_set = type_pf_same_set,
};
721
722 /* Flavour with timeout support */
723
724 #define ahash_tdata(n, i) \
725         (struct type_pf_elem *)((struct type_pf_telem *)((n)->value) + (i))
726
727 static inline u32
728 type_pf_data_timeout(const struct type_pf_elem *data)
729 {
730         const struct type_pf_telem *tdata =
731                 (const struct type_pf_telem *) data;
732
733         return tdata->timeout;
734 }
735
736 static inline bool
737 type_pf_data_expired(const struct type_pf_elem *data)
738 {
739         const struct type_pf_telem *tdata =
740                 (const struct type_pf_telem *) data;
741
742         return ip_set_timeout_expired(tdata->timeout);
743 }
744
745 static inline void
746 type_pf_data_timeout_set(struct type_pf_elem *data, u32 timeout)
747 {
748         struct type_pf_telem *tdata = (struct type_pf_telem *) data;
749
750         tdata->timeout = ip_set_timeout_set(timeout);
751 }
752
753 static int
754 type_pf_elem_tadd(struct hbucket *n, const struct type_pf_elem *value,
755                   u8 ahash_max, u32 cadt_flags, u32 timeout)
756 {
757         struct type_pf_elem *data;
758
759         if (n->pos >= n->size) {
760                 void *tmp;
761
762                 if (n->size >= ahash_max)
763                         /* Trigger rehashing */
764                         return -EAGAIN;
765
766                 tmp = kzalloc((n->size + AHASH_INIT_SIZE)
767                               * sizeof(struct type_pf_telem),
768                               GFP_ATOMIC);
769                 if (!tmp)
770                         return -ENOMEM;
771                 if (n->size) {
772                         memcpy(tmp, n->value,
773                                sizeof(struct type_pf_telem) * n->size);
774                         kfree(n->value);
775                 }
776                 n->value = tmp;
777                 n->size += AHASH_INIT_SIZE;
778         }
779         data = ahash_tdata(n, n->pos++);
780         type_pf_data_copy(data, value);
781         type_pf_data_timeout_set(data, timeout);
782 #ifdef IP_SET_HASH_WITH_NETS
783         /* Resizing won't overwrite stored flags */
784         if (cadt_flags)
785                 type_pf_data_flags(data, cadt_flags);
786 #endif
787         return 0;
788 }
789
790 /* Delete expired elements from the hashtable */
791 static void
792 type_pf_expire(struct ip_set_hash *h, u8 nets_length)
793 {
794         struct htable *t = h->table;
795         struct hbucket *n;
796         struct type_pf_elem *data;
797         u32 i;
798         int j;
799
800         for (i = 0; i < jhash_size(t->htable_bits); i++) {
801                 n = hbucket(t, i);
802                 for (j = 0; j < n->pos; j++) {
803                         data = ahash_tdata(n, j);
804                         if (type_pf_data_expired(data)) {
805                                 pr_debug("expired %u/%u\n", i, j);
806 #ifdef IP_SET_HASH_WITH_NETS
807                                 del_cidr(h, CIDR(data->cidr), nets_length);
808 #endif
809                                 if (j != n->pos - 1)
810                                         /* Not last one */
811                                         type_pf_data_copy(data,
812                                                 ahash_tdata(n, n->pos - 1));
813                                 n->pos--;
814                                 h->elements--;
815                         }
816                 }
817                 if (n->pos + AHASH_INIT_SIZE < n->size) {
818                         void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
819                                             * sizeof(struct type_pf_telem),
820                                             GFP_ATOMIC);
821                         if (!tmp)
822                                 /* Still try to delete expired elements */
823                                 continue;
824                         n->size -= AHASH_INIT_SIZE;
825                         memcpy(tmp, n->value,
826                                n->size * sizeof(struct type_pf_telem));
827                         kfree(n->value);
828                         n->value = tmp;
829                 }
830         }
831 }
832
833 static int
834 type_pf_tresize(struct ip_set *set, bool retried)
835 {
836         struct ip_set_hash *h = set->data;
837         struct htable *t, *orig = h->table;
838         u8 htable_bits = orig->htable_bits;
839         const struct type_pf_elem *data;
840         struct hbucket *n, *m;
841         u32 i, j;
842         int ret;
843
844         /* Try to cleanup once */
845         if (!retried) {
846                 i = h->elements;
847                 write_lock_bh(&set->lock);
848                 type_pf_expire(set->data, NETS_LENGTH(set->family));
849                 write_unlock_bh(&set->lock);
850                 if (h->elements <  i)
851                         return 0;
852         }
853
854 retry:
855         ret = 0;
856         htable_bits++;
857         pr_debug("attempt to resize set %s from %u to %u, t %p\n",
858                  set->name, orig->htable_bits, htable_bits, orig);
859         if (!htable_bits) {
860                 /* In case we have plenty of memory :-) */
861                 pr_warning("Cannot increase the hashsize of set %s further\n",
862                            set->name);
863                 return -IPSET_ERR_HASH_FULL;
864         }
865         t = ip_set_alloc(sizeof(*t)
866                          + jhash_size(htable_bits) * sizeof(struct hbucket));
867         if (!t)
868                 return -ENOMEM;
869         t->htable_bits = htable_bits;
870
871         read_lock_bh(&set->lock);
872         for (i = 0; i < jhash_size(orig->htable_bits); i++) {
873                 n = hbucket(orig, i);
874                 for (j = 0; j < n->pos; j++) {
875                         data = ahash_tdata(n, j);
876                         m = hbucket(t, HKEY(data, h->initval, htable_bits));
877                         ret = type_pf_elem_tadd(m, data, AHASH_MAX(h), 0,
878                                                 ip_set_timeout_get(type_pf_data_timeout(data)));
879                         if (ret < 0) {
880                                 read_unlock_bh(&set->lock);
881                                 ahash_destroy(t);
882                                 if (ret == -EAGAIN)
883                                         goto retry;
884                                 return ret;
885                         }
886                 }
887         }
888
889         rcu_assign_pointer(h->table, t);
890         read_unlock_bh(&set->lock);
891
892         /* Give time to other readers of the set */
893         synchronize_rcu_bh();
894
895         ahash_destroy(orig);
896
897         return 0;
898 }
899
900 static int
901 type_pf_tadd(struct ip_set *set, void *value, u32 timeout, u32 flags)
902 {
903         struct ip_set_hash *h = set->data;
904         struct htable *t = h->table;
905         const struct type_pf_elem *d = value;
906         struct hbucket *n;
907         struct type_pf_elem *data;
908         int ret = 0, i, j = AHASH_MAX(h) + 1;
909         bool flag_exist = flags & IPSET_FLAG_EXIST;
910         u32 key, multi = 0;
911         u32 cadt_flags = flags >> 16;
912
913         if (h->elements >= h->maxelem)
914                 /* FIXME: when set is full, we slow down here */
915                 type_pf_expire(h, NETS_LENGTH(set->family));
916         if (h->elements >= h->maxelem) {
917                 if (net_ratelimit())
918                         pr_warning("Set %s is full, maxelem %u reached\n",
919                                    set->name, h->maxelem);
920                 return -IPSET_ERR_HASH_FULL;
921         }
922
923         rcu_read_lock_bh();
924         t = rcu_dereference_bh(h->table);
925         key = HKEY(d, h->initval, t->htable_bits);
926         n = hbucket(t, key);
927         for (i = 0; i < n->pos; i++) {
928                 data = ahash_tdata(n, i);
929                 if (type_pf_data_equal(data, d, &multi)) {
930                         if (type_pf_data_expired(data) || flag_exist)
931                                 /* Just timeout value may be updated */
932                                 j = i;
933                         else {
934                                 ret = -IPSET_ERR_EXIST;
935                                 goto out;
936                         }
937                 } else if (j == AHASH_MAX(h) + 1 &&
938                            type_pf_data_expired(data))
939                         j = i;
940         }
941         if (j != AHASH_MAX(h) + 1) {
942                 data = ahash_tdata(n, j);
943 #ifdef IP_SET_HASH_WITH_NETS
944                 del_cidr(h, CIDR(data->cidr), NETS_LENGTH(set->family));
945                 add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
946 #endif
947                 type_pf_data_copy(data, d);
948                 type_pf_data_timeout_set(data, timeout);
949 #ifdef IP_SET_HASH_WITH_NETS
950                 type_pf_data_flags(data, cadt_flags);
951 #endif
952                 goto out;
953         }
954         TUNE_AHASH_MAX(h, multi);
955         ret = type_pf_elem_tadd(n, d, AHASH_MAX(h), cadt_flags, timeout);
956         if (ret != 0) {
957                 if (ret == -EAGAIN)
958                         type_pf_data_next(h, d);
959                 goto out;
960         }
961
962 #ifdef IP_SET_HASH_WITH_NETS
963         add_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
964 #endif
965         h->elements++;
966 out:
967         rcu_read_unlock_bh();
968         return ret;
969 }
970
971 static int
972 type_pf_tdel(struct ip_set *set, void *value, u32 timeout, u32 flags)
973 {
974         struct ip_set_hash *h = set->data;
975         struct htable *t = h->table;
976         const struct type_pf_elem *d = value;
977         struct hbucket *n;
978         int i;
979         struct type_pf_elem *data;
980         u32 key, multi = 0;
981
982         key = HKEY(value, h->initval, t->htable_bits);
983         n = hbucket(t, key);
984         for (i = 0; i < n->pos; i++) {
985                 data = ahash_tdata(n, i);
986                 if (!type_pf_data_equal(data, d, &multi))
987                         continue;
988                 if (type_pf_data_expired(data))
989                         return -IPSET_ERR_EXIST;
990                 if (i != n->pos - 1)
991                         /* Not last one */
992                         type_pf_data_copy(data, ahash_tdata(n, n->pos - 1));
993
994                 n->pos--;
995                 h->elements--;
996 #ifdef IP_SET_HASH_WITH_NETS
997                 del_cidr(h, CIDR(d->cidr), NETS_LENGTH(set->family));
998 #endif
999                 if (n->pos + AHASH_INIT_SIZE < n->size) {
1000                         void *tmp = kzalloc((n->size - AHASH_INIT_SIZE)
1001                                             * sizeof(struct type_pf_telem),
1002                                             GFP_ATOMIC);
1003                         if (!tmp)
1004                                 return 0;
1005                         n->size -= AHASH_INIT_SIZE;
1006                         memcpy(tmp, n->value,
1007                                n->size * sizeof(struct type_pf_telem));
1008                         kfree(n->value);
1009                         n->value = tmp;
1010                 }
1011                 return 0;
1012         }
1013
1014         return -IPSET_ERR_EXIST;
1015 }
1016
1017 #ifdef IP_SET_HASH_WITH_NETS
1018 static int
1019 type_pf_ttest_cidrs(struct ip_set *set, struct type_pf_elem *d, u32 timeout)
1020 {
1021         struct ip_set_hash *h = set->data;
1022         struct htable *t = h->table;
1023         struct type_pf_elem *data;
1024         struct hbucket *n;
1025         int i, j = 0;
1026         u32 key, multi = 0;
1027         u8 nets_length = NETS_LENGTH(set->family);
1028
1029         for (; j < nets_length && h->nets[j].nets && !multi; j++) {
1030                 type_pf_data_netmask(d, h->nets[j].cidr);
1031                 key = HKEY(d, h->initval, t->htable_bits);
1032                 n = hbucket(t, key);
1033                 for (i = 0; i < n->pos; i++) {
1034                         data = ahash_tdata(n, i);
1035 #ifdef IP_SET_HASH_WITH_MULTI
1036                         if (type_pf_data_equal(data, d, &multi)) {
1037                                 if (!type_pf_data_expired(data))
1038                                         return type_pf_data_match(data);
1039                                 multi = 0;
1040                         }
1041 #else
1042                         if (type_pf_data_equal(data, d, &multi) &&
1043                             !type_pf_data_expired(data))
1044                                 return type_pf_data_match(data);
1045 #endif
1046                 }
1047         }
1048         return 0;
1049 }
1050 #endif
1051
1052 static int
1053 type_pf_ttest(struct ip_set *set, void *value, u32 timeout, u32 flags)
1054 {
1055         struct ip_set_hash *h = set->data;
1056         struct htable *t = h->table;
1057         struct type_pf_elem *data, *d = value;
1058         struct hbucket *n;
1059         int i;
1060         u32 key, multi = 0;
1061
1062 #ifdef IP_SET_HASH_WITH_NETS
1063         if (CIDR(d->cidr) == SET_HOST_MASK(set->family))
1064                 return type_pf_ttest_cidrs(set, d, timeout);
1065 #endif
1066         key = HKEY(d, h->initval, t->htable_bits);
1067         n = hbucket(t, key);
1068         for (i = 0; i < n->pos; i++) {
1069                 data = ahash_tdata(n, i);
1070                 if (type_pf_data_equal(data, d, &multi) &&
1071                     !type_pf_data_expired(data))
1072                         return type_pf_data_match(data);
1073         }
1074         return 0;
1075 }
1076
1077 static int
1078 type_pf_tlist(const struct ip_set *set,
1079               struct sk_buff *skb, struct netlink_callback *cb)
1080 {
1081         const struct ip_set_hash *h = set->data;
1082         const struct htable *t = h->table;
1083         struct nlattr *atd, *nested;
1084         const struct hbucket *n;
1085         const struct type_pf_elem *data;
1086         u32 first = cb->args[2];
1087         /* We assume that one hash bucket fills into one page */
1088         void *incomplete;
1089         int i;
1090
1091         atd = ipset_nest_start(skb, IPSET_ATTR_ADT);
1092         if (!atd)
1093                 return -EMSGSIZE;
1094         for (; cb->args[2] < jhash_size(t->htable_bits); cb->args[2]++) {
1095                 incomplete = skb_tail_pointer(skb);
1096                 n = hbucket(t, cb->args[2]);
1097                 for (i = 0; i < n->pos; i++) {
1098                         data = ahash_tdata(n, i);
1099                         pr_debug("list %p %u\n", n, i);
1100                         if (type_pf_data_expired(data))
1101                                 continue;
1102                         pr_debug("do list %p %u\n", n, i);
1103                         nested = ipset_nest_start(skb, IPSET_ATTR_DATA);
1104                         if (!nested) {
1105                                 if (cb->args[2] == first) {
1106                                         nla_nest_cancel(skb, atd);
1107                                         return -EMSGSIZE;
1108                                 } else
1109                                         goto nla_put_failure;
1110                         }
1111                         if (type_pf_data_tlist(skb, data))
1112                                 goto nla_put_failure;
1113                         ipset_nest_end(skb, nested);
1114                 }
1115         }
1116         ipset_nest_end(skb, atd);
1117         /* Set listing finished */
1118         cb->args[2] = 0;
1119
1120         return 0;
1121
1122 nla_put_failure:
1123         nlmsg_trim(skb, incomplete);
1124         ipset_nest_end(skb, atd);
1125         if (unlikely(first == cb->args[2])) {
1126                 pr_warning("Can't list set %s: one bucket does not fit into "
1127                            "a message. Please report it!\n", set->name);
1128                 cb->args[2] = 0;
1129                 return -EMSGSIZE;
1130         }
1131         return 0;
1132 }
1133
1134 static const struct ip_set_type_variant type_pf_tvariant = {
1135         .kadt   = type_pf_kadt,
1136         .uadt   = type_pf_uadt,
1137         .adt    = {
1138                 [IPSET_ADD] = type_pf_tadd,
1139                 [IPSET_DEL] = type_pf_tdel,
1140                 [IPSET_TEST] = type_pf_ttest,
1141         },
1142         .destroy = type_pf_destroy,
1143         .flush  = type_pf_flush,
1144         .head   = type_pf_head,
1145         .list   = type_pf_tlist,
1146         .resize = type_pf_tresize,
1147         .same_set = type_pf_same_set,
1148 };
1149
1150 static void
1151 type_pf_gc(unsigned long ul_set)
1152 {
1153         struct ip_set *set = (struct ip_set *) ul_set;
1154         struct ip_set_hash *h = set->data;
1155
1156         pr_debug("called\n");
1157         write_lock_bh(&set->lock);
1158         type_pf_expire(h, NETS_LENGTH(set->family));
1159         write_unlock_bh(&set->lock);
1160
1161         h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;
1162         add_timer(&h->gc);
1163 }
1164
1165 static void
1166 type_pf_gc_init(struct ip_set *set)
1167 {
1168         struct ip_set_hash *h = set->data;
1169
1170         init_timer(&h->gc);
1171         h->gc.data = (unsigned long) set;
1172         h->gc.function = type_pf_gc;
1173         h->gc.expires = jiffies + IPSET_GC_PERIOD(h->timeout) * HZ;
1174         add_timer(&h->gc);
1175         pr_debug("gc initialized, run in every %u\n",
1176                  IPSET_GC_PERIOD(h->timeout));
1177 }
1178
1179 #undef HKEY_DATALEN
1180 #undef HKEY
1181 #undef type_pf_data_equal
1182 #undef type_pf_data_isnull
1183 #undef type_pf_data_copy
1184 #undef type_pf_data_zero_out
1185 #undef type_pf_data_netmask
1186 #undef type_pf_data_list
1187 #undef type_pf_data_tlist
1188 #undef type_pf_data_next
1189 #undef type_pf_data_flags
1190 #undef type_pf_data_match
1191
1192 #undef type_pf_elem
1193 #undef type_pf_telem
1194 #undef type_pf_data_timeout
1195 #undef type_pf_data_expired
1196 #undef type_pf_data_timeout_set
1197
1198 #undef type_pf_elem_add
1199 #undef type_pf_add
1200 #undef type_pf_del
1201 #undef type_pf_test_cidrs
1202 #undef type_pf_test
1203
1204 #undef type_pf_elem_tadd
1205 #undef type_pf_del_telem
1206 #undef type_pf_expire
1207 #undef type_pf_tadd
1208 #undef type_pf_tdel
1209 #undef type_pf_ttest_cidrs
1210 #undef type_pf_ttest
1211
1212 #undef type_pf_resize
1213 #undef type_pf_tresize
1214 #undef type_pf_flush
1215 #undef type_pf_destroy
1216 #undef type_pf_head
1217 #undef type_pf_list
1218 #undef type_pf_tlist
1219 #undef type_pf_same_set
1220 #undef type_pf_kadt
1221 #undef type_pf_uadt
1222 #undef type_pf_gc
1223 #undef type_pf_gc_init
1224 #undef type_pf_variant
1225 #undef type_pf_tvariant