/*
 * (C) 1999-2001 Paul `Rusty' Russell
 * (C) 2002-2006 Netfilter Core Team <coreteam@netfilter.org>
 * (C) 2011 Patrick McHardy <kaber@trash.net>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 2 as
 * published by the Free Software Foundation.
 */

#include <linux/module.h>
#include <linux/types.h>
#include <linux/timer.h>
#include <linux/skbuff.h>
#include <linux/gfp.h>
#include <net/xfrm.h>
#include <linux/jhash.h>
#include <linux/rtnetlink.h>

#include <net/netfilter/nf_conntrack.h>
#include <net/netfilter/nf_conntrack_core.h>
#include <net/netfilter/nf_nat.h>
#include <net/netfilter/nf_nat_l3proto.h>
#include <net/netfilter/nf_nat_l4proto.h>
#include <net/netfilter/nf_nat_core.h>
#include <net/netfilter/nf_nat_helper.h>
#include <net/netfilter/nf_conntrack_helper.h>
#include <net/netfilter/nf_conntrack_seqadj.h>
#include <net/netfilter/nf_conntrack_l3proto.h>
#include <net/netfilter/nf_conntrack_zones.h>
#include <linux/netfilter/nf_nat.h>

static DEFINE_SPINLOCK(nf_nat_lock);

static DEFINE_MUTEX(nf_nat_proto_mutex);
static const struct nf_nat_l3proto __rcu *nf_nat_l3protos[NFPROTO_NUMPROTO]
                                                __read_mostly;
static const struct nf_nat_l4proto __rcu **nf_nat_l4protos[NFPROTO_NUMPROTO]
                                                __read_mostly;

static struct hlist_head *nf_nat_bysource __read_mostly;
static unsigned int nf_nat_htable_size __read_mostly;
static unsigned int nf_nat_hash_rnd __read_mostly;

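/* Look up the registered L3 (address family) and L4 (transport) NAT
 * protocol handlers.  Callers must hold rcu_read_lock().
 */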
inline const struct nf_nat_l3proto *
__nf_nat_l3proto_find(u8 family)
{
        return rcu_dereference(nf_nat_l3protos[family]);
}

inline const struct nf_nat_l4proto *
__nf_nat_l4proto_find(u8 family, u8 protonum)
{
        return rcu_dereference(nf_nat_l4protos[family][protonum]);
}
EXPORT_SYMBOL_GPL(__nf_nat_l4proto_find);

#ifdef CONFIG_XFRM
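/* Adjust the flow key used for the XFRM policy lookup to account for any
 * NAT mangling of this connection, via the l3proto's decode_session callback.
 */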
static void __nf_nat_decode_session(struct sk_buff *skb, struct flowi *fl)
{
        const struct nf_nat_l3proto *l3proto;
        const struct nf_conn *ct;
        enum ip_conntrack_info ctinfo;
        enum ip_conntrack_dir dir;
        unsigned long statusbit;
        u8 family;

        ct = nf_ct_get(skb, &ctinfo);
        if (ct == NULL)
                return;

        family = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple.src.l3num;
        rcu_read_lock();
        l3proto = __nf_nat_l3proto_find(family);
        if (l3proto == NULL)
                goto out;

        dir = CTINFO2DIR(ctinfo);
        if (dir == IP_CT_DIR_ORIGINAL)
                statusbit = IPS_DST_NAT;
        else
                statusbit = IPS_SRC_NAT;

        l3proto->decode_session(skb, ct, dir, statusbit, fl);
out:
        rcu_read_unlock();
}

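/* Re-route the packet through XFRM after NAT has changed the addresses:
 * perform a fresh policy lookup, install the new dst and make sure there
 * is enough headroom for the (possibly different) output device.
 */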
int nf_xfrm_me_harder(struct net *net, struct sk_buff *skb, unsigned int family)
{
        struct flowi fl;
        unsigned int hh_len;
        struct dst_entry *dst;
        int err;

        err = xfrm_decode_session(skb, &fl, family);
        if (err < 0)
                return err;

        dst = skb_dst(skb);
        if (dst->xfrm)
                dst = ((struct xfrm_dst *)dst)->route;
        dst_hold(dst);

        dst = xfrm_lookup(net, dst, &fl, skb->sk, 0);
        if (IS_ERR(dst))
                return PTR_ERR(dst);

        skb_dst_drop(skb);
        skb_dst_set(skb, dst);

        /* Change in oif may mean change in hh_len. */
        hh_len = skb_dst(skb)->dev->hard_header_len;
        if (skb_headroom(skb) < hh_len &&
            pskb_expand_head(skb, hh_len - skb_headroom(skb), 0, GFP_ATOMIC))
                return -ENOMEM;
        return 0;
}
EXPORT_SYMBOL(nf_xfrm_me_harder);
#endif /* CONFIG_XFRM */

/* We keep an extra hash for each conntrack, for fast searching. */
static inline unsigned int
hash_by_src(const struct net *n, const struct nf_conntrack_tuple *tuple)
{
        unsigned int hash;

        get_random_once(&nf_nat_hash_rnd, sizeof(nf_nat_hash_rnd));

        /* Original src, to ensure we map it consistently if poss. */
        hash = jhash2((u32 *)&tuple->src, sizeof(tuple->src) / sizeof(u32),
                      tuple->dst.protonum ^ nf_nat_hash_rnd ^ net_hash_mix(n));

        return reciprocal_scale(hash, nf_nat_htable_size);
}

/* Is this tuple already taken? (not by us) */
int
nf_nat_used_tuple(const struct nf_conntrack_tuple *tuple,
                  const struct nf_conn *ignored_conntrack)
{
        /* Conntrack tracking doesn't keep track of outgoing tuples; only
         * incoming ones.  NAT means they don't have a fixed mapping,
         * so we invert the tuple and look for the incoming reply.
         *
         * We could keep a separate hash if this proves too slow.
         */
        struct nf_conntrack_tuple reply;

        nf_ct_invert_tuplepr(&reply, tuple);
        return nf_conntrack_tuple_taken(&reply, ignored_conntrack);
}
EXPORT_SYMBOL(nf_nat_used_tuple);

/* If we source map this tuple so reply looks like reply_tuple, will
 * that meet the constraints of range?
 */
static int in_range(const struct nf_nat_l3proto *l3proto,
                    const struct nf_nat_l4proto *l4proto,
                    const struct nf_conntrack_tuple *tuple,
                    const struct nf_nat_range *range)
{
        /* If we are supposed to map IPs, then we must be in the
         * range specified, otherwise let this drag us onto a new src IP.
         */
        if (range->flags & NF_NAT_RANGE_MAP_IPS &&
            !l3proto->in_range(tuple, range))
                return 0;

        if (!(range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) ||
            l4proto->in_range(tuple, NF_NAT_MANIP_SRC,
                              &range->min_proto, &range->max_proto))
                return 1;

        return 0;
}

static inline int
same_src(const struct nf_conn *ct,
         const struct nf_conntrack_tuple *tuple)
{
        const struct nf_conntrack_tuple *t;

        t = &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple;
        return (t->dst.protonum == tuple->dst.protonum &&
                nf_inet_addr_cmp(&t->src.u3, &tuple->src.u3) &&
                t->src.u.all == tuple->src.u.all);
}

/* Only called for SRC manip */
static int
find_appropriate_src(struct net *net,
                     const struct nf_conntrack_zone *zone,
                     const struct nf_nat_l3proto *l3proto,
                     const struct nf_nat_l4proto *l4proto,
                     const struct nf_conntrack_tuple *tuple,
                     struct nf_conntrack_tuple *result,
                     const struct nf_nat_range *range)
{
        unsigned int h = hash_by_src(net, tuple);
        const struct nf_conn_nat *nat;
        const struct nf_conn *ct;

        hlist_for_each_entry_rcu(nat, &nf_nat_bysource[h], bysource) {
                ct = nat->ct;
                if (same_src(ct, tuple) &&
                    net_eq(net, nf_ct_net(ct)) &&
                    nf_ct_zone_equal(ct, zone, IP_CT_DIR_ORIGINAL)) {
                        /* Copy source part from reply tuple. */
                        nf_ct_invert_tuplepr(result,
                                       &ct->tuplehash[IP_CT_DIR_REPLY].tuple);
                        result->dst = tuple->dst;

                        if (in_range(l3proto, l4proto, result, range))
                                return 1;
                }
        }
        return 0;
}

/* For [FUTURE] fragmentation handling, we want the least-used
 * src-ip/dst-ip/proto triple.  Fairness doesn't come into it.  Thus
 * if the range specifies 1.2.3.4 ports 10000-10005 and 1.2.3.5 ports
 * 1-65535, we don't do pro-rata allocation based on ports; we choose
 * the ip with the lowest src-ip/dst-ip/proto usage.
 */
static void
find_best_ips_proto(const struct nf_conntrack_zone *zone,
                    struct nf_conntrack_tuple *tuple,
                    const struct nf_nat_range *range,
                    const struct nf_conn *ct,
                    enum nf_nat_manip_type maniptype)
{
        union nf_inet_addr *var_ipp;
        unsigned int i, max;
        /* Host order */
        u32 minip, maxip, j, dist;
        bool full_range;

        /* No IP mapping?  Do nothing. */
        if (!(range->flags & NF_NAT_RANGE_MAP_IPS))
                return;

        if (maniptype == NF_NAT_MANIP_SRC)
                var_ipp = &tuple->src.u3;
        else
                var_ipp = &tuple->dst.u3;

        /* Fast path: only one choice. */
        if (nf_inet_addr_cmp(&range->min_addr, &range->max_addr)) {
                *var_ipp = range->min_addr;
                return;
        }

        if (nf_ct_l3num(ct) == NFPROTO_IPV4)
                max = sizeof(var_ipp->ip) / sizeof(u32) - 1;
        else
                max = sizeof(var_ipp->ip6) / sizeof(u32) - 1;

        /* Hashing source and destination IPs gives a fairly even
         * spread in practice (if there are a small number of IPs
         * involved, there usually aren't that many connections
         * anyway).  The consistency means that servers see the same
         * client coming from the same IP (some Internet Banking sites
         * like this), even across reboots.
         */
        j = jhash2((u32 *)&tuple->src.u3, sizeof(tuple->src.u3) / sizeof(u32),
                   range->flags & NF_NAT_RANGE_PERSISTENT ?
                        0 : (__force u32)tuple->dst.u3.all[max] ^ zone->id);

        full_range = false;
        for (i = 0; i <= max; i++) {
                /* If first bytes of the address are at the maximum, use the
                 * distance. Otherwise use the full range.
                 */
                if (!full_range) {
                        minip = ntohl((__force __be32)range->min_addr.all[i]);
                        maxip = ntohl((__force __be32)range->max_addr.all[i]);
                        dist  = maxip - minip + 1;
                } else {
                        minip = 0;
                        dist  = ~0;
                }

                var_ipp->all[i] = (__force __u32)
                        htonl(minip + reciprocal_scale(j, dist));
                if (var_ipp->all[i] != range->max_addr.all[i])
                        full_range = true;

                if (!(range->flags & NF_NAT_RANGE_PERSISTENT))
                        j ^= (__force u32)tuple->dst.u3.all[i];
        }
}

/* Manipulate the tuple into the range given. For NF_INET_POST_ROUTING,
 * we change the source to map into the range. For NF_INET_PRE_ROUTING
 * and NF_INET_LOCAL_OUT, we change the destination to map into the
 * range. It might not be possible to get a unique tuple, but we try.
 * At worst (or if we race), we will end up with a final duplicate in
 * __nf_conntrack_confirm and drop the packet. */
static void
get_unique_tuple(struct nf_conntrack_tuple *tuple,
                 const struct nf_conntrack_tuple *orig_tuple,
                 const struct nf_nat_range *range,
                 struct nf_conn *ct,
                 enum nf_nat_manip_type maniptype)
{
        const struct nf_conntrack_zone *zone;
        const struct nf_nat_l3proto *l3proto;
        const struct nf_nat_l4proto *l4proto;
        struct net *net = nf_ct_net(ct);

        zone = nf_ct_zone(ct);

        rcu_read_lock();
        l3proto = __nf_nat_l3proto_find(orig_tuple->src.l3num);
        l4proto = __nf_nat_l4proto_find(orig_tuple->src.l3num,
                                        orig_tuple->dst.protonum);

        /* 1) If this srcip/proto/src-proto-part is currently mapped,
         * and that same mapping gives a unique tuple within the given
         * range, use that.
         *
         * This is only required for source (ie. NAT/masq) mappings.
         * So far, we don't do local source mappings, so multiple
         * manips not an issue.
         */
        if (maniptype == NF_NAT_MANIP_SRC &&
            !(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) {
                /* try the original tuple first */
                if (in_range(l3proto, l4proto, orig_tuple, range)) {
                        if (!nf_nat_used_tuple(orig_tuple, ct)) {
                                *tuple = *orig_tuple;
                                goto out;
                        }
                } else if (find_appropriate_src(net, zone, l3proto, l4proto,
                                                orig_tuple, tuple, range)) {
                        pr_debug("get_unique_tuple: Found current src map\n");
                        if (!nf_nat_used_tuple(tuple, ct))
                                goto out;
                }
        }

        /* 2) Select the least-used IP/proto combination in the given range */
        *tuple = *orig_tuple;
        find_best_ips_proto(zone, tuple, range, ct, maniptype);

        /* 3) The per-protocol part of the manip is made to map into
         * the range to make a unique tuple.
         */

        /* Only bother mapping if it's not already in range and unique */
        if (!(range->flags & NF_NAT_RANGE_PROTO_RANDOM_ALL)) {
                if (range->flags & NF_NAT_RANGE_PROTO_SPECIFIED) {
                        if (l4proto->in_range(tuple, maniptype,
                                              &range->min_proto,
                                              &range->max_proto) &&
                            (range->min_proto.all == range->max_proto.all ||
                             !nf_nat_used_tuple(tuple, ct)))
                                goto out;
                } else if (!nf_nat_used_tuple(tuple, ct)) {
                        goto out;
                }
        }

        /* Last chance: get protocol to try to obtain unique tuple. */
        l4proto->unique_tuple(l3proto, tuple, range, maniptype, ct);
out:
        rcu_read_unlock();
}

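/* Add the NAT extension to a conntrack if it is not already present.
 * The extension may only be added while the conntrack is unconfirmed.
 */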
struct nf_conn_nat *nf_ct_nat_ext_add(struct nf_conn *ct)
{
        struct nf_conn_nat *nat = nfct_nat(ct);
        if (nat)
                return nat;

        if (!nf_ct_is_confirmed(ct))
                nat = nf_ct_ext_add(ct, NF_CT_EXT_NAT, GFP_ATOMIC);

        return nat;
}
EXPORT_SYMBOL_GPL(nf_ct_nat_ext_add);

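/* Set up an SRC or DST NAT binding for @ct according to @range: pick a
 * unique tuple, update the conntrack's reply tuple accordingly and, for
 * source NAT, hash the conntrack into the bysource table.
 */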
unsigned int
nf_nat_setup_info(struct nf_conn *ct,
                  const struct nf_nat_range *range,
                  enum nf_nat_manip_type maniptype)
{
        struct net *net = nf_ct_net(ct);
        struct nf_conntrack_tuple curr_tuple, new_tuple;
        struct nf_conn_nat *nat;

        /* nat helper or nfctnetlink may also set up the binding */
        nat = nf_ct_nat_ext_add(ct);
        if (nat == NULL)
                return NF_ACCEPT;

        NF_CT_ASSERT(maniptype == NF_NAT_MANIP_SRC ||
                     maniptype == NF_NAT_MANIP_DST);
        BUG_ON(nf_nat_initialized(ct, maniptype));

        /* What we've got will look like inverse of reply. Normally
         * this is what is in the conntrack, except for prior
         * manipulations (future optimization: if num_manips == 0,
         * orig_tp = ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple)
         */
        nf_ct_invert_tuplepr(&curr_tuple,
                             &ct->tuplehash[IP_CT_DIR_REPLY].tuple);

        get_unique_tuple(&new_tuple, &curr_tuple, range, ct, maniptype);

        if (!nf_ct_tuple_equal(&new_tuple, &curr_tuple)) {
                struct nf_conntrack_tuple reply;

                /* Alter conntrack table so will recognize replies. */
                nf_ct_invert_tuplepr(&reply, &new_tuple);
                nf_conntrack_alter_reply(ct, &reply);

                /* Non-atomic: we own this at the moment. */
                if (maniptype == NF_NAT_MANIP_SRC)
                        ct->status |= IPS_SRC_NAT;
                else
                        ct->status |= IPS_DST_NAT;

                if (nfct_help(ct))
                        nfct_seqadj_ext_add(ct);
        }

        if (maniptype == NF_NAT_MANIP_SRC) {
                unsigned int srchash;

                srchash = hash_by_src(net,
                                      &ct->tuplehash[IP_CT_DIR_ORIGINAL].tuple);
                spin_lock_bh(&nf_nat_lock);
                /* nf_conntrack_alter_reply might re-allocate extension area */
                nat = nfct_nat(ct);
                nat->ct = ct;
                hlist_add_head_rcu(&nat->bysource,
                                   &nf_nat_bysource[srchash]);
                spin_unlock_bh(&nf_nat_lock);
        }

        /* It's done. */
        if (maniptype == NF_NAT_MANIP_DST)
                ct->status |= IPS_DST_NAT_DONE;
        else
                ct->status |= IPS_SRC_NAT_DONE;

        return NF_ACCEPT;
}
EXPORT_SYMBOL(nf_nat_setup_info);

static unsigned int
__nf_nat_alloc_null_binding(struct nf_conn *ct, enum nf_nat_manip_type manip)
{
        /* Force range to this IP; let proto decide mapping for
         * per-proto parts (hence not NF_NAT_RANGE_PROTO_SPECIFIED).
         * Use reply in case it's already been mangled (e.g. local packet).
         */
        union nf_inet_addr ip =
                (manip == NF_NAT_MANIP_SRC ?
                ct->tuplehash[IP_CT_DIR_REPLY].tuple.dst.u3 :
                ct->tuplehash[IP_CT_DIR_REPLY].tuple.src.u3);
        struct nf_nat_range range = {
                .flags          = NF_NAT_RANGE_MAP_IPS,
                .min_addr       = ip,
                .max_addr       = ip,
        };
        return nf_nat_setup_info(ct, &range, manip);
}

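/* Set up an identity ("null") binding for the hook's manip type, so the
 * conntrack is marked as NAT-handled without changing its addresses.
 */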
unsigned int
nf_nat_alloc_null_binding(struct nf_conn *ct, unsigned int hooknum)
{
        return __nf_nat_alloc_null_binding(ct, HOOK2MANIP(hooknum));
}
EXPORT_SYMBOL_GPL(nf_nat_alloc_null_binding);

/* Do packet manipulations according to nf_nat_setup_info. */
unsigned int nf_nat_packet(struct nf_conn *ct,
                           enum ip_conntrack_info ctinfo,
                           unsigned int hooknum,
                           struct sk_buff *skb)
{
        const struct nf_nat_l3proto *l3proto;
        const struct nf_nat_l4proto *l4proto;
        enum ip_conntrack_dir dir = CTINFO2DIR(ctinfo);
        unsigned long statusbit;
        enum nf_nat_manip_type mtype = HOOK2MANIP(hooknum);

        if (mtype == NF_NAT_MANIP_SRC)
                statusbit = IPS_SRC_NAT;
        else
                statusbit = IPS_DST_NAT;

        /* Invert if this is reply dir. */
        if (dir == IP_CT_DIR_REPLY)
                statusbit ^= IPS_NAT_MASK;

        /* Non-atomic: these bits don't change. */
        if (ct->status & statusbit) {
                struct nf_conntrack_tuple target;

                /* We are aiming to look like inverse of other direction. */
                nf_ct_invert_tuplepr(&target, &ct->tuplehash[!dir].tuple);

                l3proto = __nf_nat_l3proto_find(target.src.l3num);
                l4proto = __nf_nat_l4proto_find(target.src.l3num,
                                                target.dst.protonum);
                if (!l3proto->manip_pkt(skb, 0, l4proto, &target, mtype))
                        return NF_DROP;
        }
        return NF_ACCEPT;
}
EXPORT_SYMBOL_GPL(nf_nat_packet);

struct nf_nat_proto_clean {
        u8      l3proto;
        u8      l4proto;
};

/* kill conntracks with affected NAT section */
static int nf_nat_proto_remove(struct nf_conn *i, void *data)
{
        const struct nf_nat_proto_clean *clean = data;
        struct nf_conn_nat *nat = nfct_nat(i);

        if (!nat)
                return 0;

        if ((clean->l3proto && nf_ct_l3num(i) != clean->l3proto) ||
            (clean->l4proto && nf_ct_protonum(i) != clean->l4proto))
                return 0;

        return i->status & IPS_NAT_MASK ? 1 : 0;
}

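/* Per-netns cleanup: kill conntracks that were actually NATed, and unhash
 * the remaining null bindings so they do not point into the soon-to-be-freed
 * bysource table.
 */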
static int nf_nat_proto_clean(struct nf_conn *ct, void *data)
{
        struct nf_conn_nat *nat = nfct_nat(ct);

        if (nf_nat_proto_remove(ct, data))
                return 1;

        if (!nat || !nat->ct)
                return 0;

        /* This netns is being destroyed, and conntrack has nat null binding.
         * Remove it from bysource hash, as the table will be freed soon.
         *
         * Else, when the conntrack is destroyed, nf_nat_cleanup_conntrack()
         * would delete the entry from the already-freed table.
         */
        if (!del_timer(&ct->timeout))
                return 1;

        spin_lock_bh(&nf_nat_lock);
        hlist_del_rcu(&nat->bysource);
        ct->status &= ~IPS_NAT_DONE_MASK;
        nat->ct = NULL;
        spin_unlock_bh(&nf_nat_lock);

        add_timer(&ct->timeout);

        /* don't delete conntrack.  Although that would make things a lot
         * simpler, we'd end up flushing all conntracks on nat rmmod.
         */
        return 0;
}

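/* Drop all conntracks that were NATed with the given L3/L4 protocol pair,
 * in every network namespace.
 */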
static void nf_nat_l4proto_clean(u8 l3proto, u8 l4proto)
{
        struct nf_nat_proto_clean clean = {
                .l3proto = l3proto,
                .l4proto = l4proto,
        };
        struct net *net;

        rtnl_lock();
        for_each_net(net)
                nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean, 0, 0);
        rtnl_unlock();
}

static void nf_nat_l3proto_clean(u8 l3proto)
{
        struct nf_nat_proto_clean clean = {
                .l3proto = l3proto,
        };
        struct net *net;

        rtnl_lock();

        for_each_net(net)
                nf_ct_iterate_cleanup(net, nf_nat_proto_remove, &clean, 0, 0);
        rtnl_unlock();
}

/* Protocol registration. */
int nf_nat_l4proto_register(u8 l3proto, const struct nf_nat_l4proto *l4proto)
{
        const struct nf_nat_l4proto **l4protos;
        unsigned int i;
        int ret = 0;

        mutex_lock(&nf_nat_proto_mutex);
        if (nf_nat_l4protos[l3proto] == NULL) {
                l4protos = kmalloc(IPPROTO_MAX * sizeof(struct nf_nat_l4proto *),
                                   GFP_KERNEL);
                if (l4protos == NULL) {
                        ret = -ENOMEM;
                        goto out;
                }

                for (i = 0; i < IPPROTO_MAX; i++)
                        RCU_INIT_POINTER(l4protos[i], &nf_nat_l4proto_unknown);

                /* Before making proto_array visible to lockless readers,
                 * we must make sure its content is committed to memory.
                 */
                smp_wmb();

                nf_nat_l4protos[l3proto] = l4protos;
        }

        if (rcu_dereference_protected(
                        nf_nat_l4protos[l3proto][l4proto->l4proto],
                        lockdep_is_held(&nf_nat_proto_mutex)
                        ) != &nf_nat_l4proto_unknown) {
                ret = -EBUSY;
                goto out;
        }
        RCU_INIT_POINTER(nf_nat_l4protos[l3proto][l4proto->l4proto], l4proto);
 out:
        mutex_unlock(&nf_nat_proto_mutex);
        return ret;
}
EXPORT_SYMBOL_GPL(nf_nat_l4proto_register);

/* No one stores the protocol anywhere; simply delete it. */
void nf_nat_l4proto_unregister(u8 l3proto, const struct nf_nat_l4proto *l4proto)
{
        mutex_lock(&nf_nat_proto_mutex);
        RCU_INIT_POINTER(nf_nat_l4protos[l3proto][l4proto->l4proto],
                         &nf_nat_l4proto_unknown);
        mutex_unlock(&nf_nat_proto_mutex);
        synchronize_rcu();

        nf_nat_l4proto_clean(l3proto, l4proto->l4proto);
}
EXPORT_SYMBOL_GPL(nf_nat_l4proto_unregister);

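/* Register an L3 (address family) NAT protocol and wire up the built-in
 * TCP and UDP L4 handlers for that family.
 */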
int nf_nat_l3proto_register(const struct nf_nat_l3proto *l3proto)
{
        int err;

        err = nf_ct_l3proto_try_module_get(l3proto->l3proto);
        if (err < 0)
                return err;

        mutex_lock(&nf_nat_proto_mutex);
        RCU_INIT_POINTER(nf_nat_l4protos[l3proto->l3proto][IPPROTO_TCP],
                         &nf_nat_l4proto_tcp);
        RCU_INIT_POINTER(nf_nat_l4protos[l3proto->l3proto][IPPROTO_UDP],
                         &nf_nat_l4proto_udp);
        mutex_unlock(&nf_nat_proto_mutex);

        RCU_INIT_POINTER(nf_nat_l3protos[l3proto->l3proto], l3proto);
        return 0;
}
EXPORT_SYMBOL_GPL(nf_nat_l3proto_register);

void nf_nat_l3proto_unregister(const struct nf_nat_l3proto *l3proto)
{
        mutex_lock(&nf_nat_proto_mutex);
        RCU_INIT_POINTER(nf_nat_l3protos[l3proto->l3proto], NULL);
        mutex_unlock(&nf_nat_proto_mutex);
        synchronize_rcu();

        nf_nat_l3proto_clean(l3proto->l3proto);
        nf_ct_l3proto_module_put(l3proto->l3proto);
}
EXPORT_SYMBOL_GPL(nf_nat_l3proto_unregister);

/* No one is using the conntrack by the time this is called. */
static void nf_nat_cleanup_conntrack(struct nf_conn *ct)
{
        struct nf_conn_nat *nat = nf_ct_ext_find(ct, NF_CT_EXT_NAT);

        if (nat == NULL || nat->ct == NULL)
                return;

        NF_CT_ASSERT(nat->ct->status & IPS_SRC_NAT_DONE);

        spin_lock_bh(&nf_nat_lock);
        hlist_del_rcu(&nat->bysource);
        spin_unlock_bh(&nf_nat_lock);
}

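/* Called when conntrack extension storage is reallocated: repoint the
 * bysource hash entry at the new nf_conn_nat location.
 */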
static void nf_nat_move_storage(void *new, void *old)
{
        struct nf_conn_nat *new_nat = new;
        struct nf_conn_nat *old_nat = old;
        struct nf_conn *ct = old_nat->ct;

        if (!ct || !(ct->status & IPS_SRC_NAT_DONE))
                return;

        spin_lock_bh(&nf_nat_lock);
        hlist_replace_rcu(&old_nat->bysource, &new_nat->bysource);
        spin_unlock_bh(&nf_nat_lock);
}

static struct nf_ct_ext_type nat_extend __read_mostly = {
        .len            = sizeof(struct nf_conn_nat),
        .align          = __alignof__(struct nf_conn_nat),
        .destroy        = nf_nat_cleanup_conntrack,
        .move           = nf_nat_move_storage,
        .id             = NF_CT_EXT_NAT,
        .flags          = NF_CT_EXT_F_PREALLOC,
};

#if IS_ENABLED(CONFIG_NF_CT_NETLINK)

#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nfnetlink_conntrack.h>

static const struct nla_policy protonat_nla_policy[CTA_PROTONAT_MAX+1] = {
        [CTA_PROTONAT_PORT_MIN] = { .type = NLA_U16 },
        [CTA_PROTONAT_PORT_MAX] = { .type = NLA_U16 },
};

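/* Parse the nested CTA_PROTONAT attribute into the per-protocol part of
 * the NAT range (e.g. port ranges).
 */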
static int nfnetlink_parse_nat_proto(struct nlattr *attr,
                                     const struct nf_conn *ct,
                                     struct nf_nat_range *range)
{
        struct nlattr *tb[CTA_PROTONAT_MAX+1];
        const struct nf_nat_l4proto *l4proto;
        int err;

        err = nla_parse_nested(tb, CTA_PROTONAT_MAX, attr, protonat_nla_policy);
        if (err < 0)
                return err;

        l4proto = __nf_nat_l4proto_find(nf_ct_l3num(ct), nf_ct_protonum(ct));
        if (l4proto->nlattr_to_range)
                err = l4proto->nlattr_to_range(tb, range);

        return err;
}

static const struct nla_policy nat_nla_policy[CTA_NAT_MAX+1] = {
        [CTA_NAT_V4_MINIP]      = { .type = NLA_U32 },
        [CTA_NAT_V4_MAXIP]      = { .type = NLA_U32 },
        [CTA_NAT_V6_MINIP]      = { .len = sizeof(struct in6_addr) },
        [CTA_NAT_V6_MAXIP]      = { .len = sizeof(struct in6_addr) },
        [CTA_NAT_PROTO]         = { .type = NLA_NESTED },
};

static int
nfnetlink_parse_nat(const struct nlattr *nat,
                    const struct nf_conn *ct, struct nf_nat_range *range,
                    const struct nf_nat_l3proto *l3proto)
{
        struct nlattr *tb[CTA_NAT_MAX+1];
        int err;

        memset(range, 0, sizeof(*range));

        err = nla_parse_nested(tb, CTA_NAT_MAX, nat, nat_nla_policy);
        if (err < 0)
                return err;

        err = l3proto->nlattr_to_range(tb, range);
        if (err < 0)
                return err;

        if (!tb[CTA_NAT_PROTO])
                return 0;

        return nfnetlink_parse_nat_proto(tb[CTA_NAT_PROTO], ct, range);
}

/* This function is called under rcu_read_lock() */
static int
nfnetlink_parse_nat_setup(struct nf_conn *ct,
                          enum nf_nat_manip_type manip,
                          const struct nlattr *attr)
{
        struct nf_nat_range range;
        const struct nf_nat_l3proto *l3proto;
        int err;

        /* Should not happen, restricted to creating new conntracks
         * via ctnetlink.
         */
        if (WARN_ON_ONCE(nf_nat_initialized(ct, manip)))
                return -EEXIST;

        /* Make sure that L3 NAT is there by the time we call nf_nat_setup_info
         * to attach the null binding, otherwise this may oops.
         */
        l3proto = __nf_nat_l3proto_find(nf_ct_l3num(ct));
        if (l3proto == NULL)
                return -EAGAIN;

        /* No NAT information has been passed, allocate the null-binding */
        if (attr == NULL)
                return __nf_nat_alloc_null_binding(ct, manip);

        err = nfnetlink_parse_nat(attr, ct, &range, l3proto);
        if (err < 0)
                return err;

        return nf_nat_setup_info(ct, &range, manip);
}
#else
static int
nfnetlink_parse_nat_setup(struct nf_conn *ct,
                          enum nf_nat_manip_type manip,
                          const struct nlattr *attr)
{
        return -EOPNOTSUPP;
}
#endif

static void __net_exit nf_nat_net_exit(struct net *net)
{
        struct nf_nat_proto_clean clean = {};

        nf_ct_iterate_cleanup(net, nf_nat_proto_clean, &clean, 0, 0);
}

static struct pernet_operations nf_nat_net_ops = {
        .exit = nf_nat_net_exit,
};

static struct nf_ct_helper_expectfn follow_master_nat = {
        .name           = "nat-follow-master",
        .expectfn       = nf_nat_follow_master,
};

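/* Module init: allocate the bysource hash, register the conntrack NAT
 * extension and per-netns hooks, and install the ctnetlink/XFRM callbacks.
 */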
static int __init nf_nat_init(void)
{
        int ret;

        /* Leave them the same for the moment. */
        nf_nat_htable_size = nf_conntrack_htable_size;

        nf_nat_bysource = nf_ct_alloc_hashtable(&nf_nat_htable_size, 0);
        if (!nf_nat_bysource)
                return -ENOMEM;

        ret = nf_ct_extend_register(&nat_extend);
        if (ret < 0) {
                nf_ct_free_hashtable(nf_nat_bysource, nf_nat_htable_size);
                printk(KERN_ERR "nf_nat_core: Unable to register extension\n");
                return ret;
        }

        ret = register_pernet_subsys(&nf_nat_net_ops);
        if (ret < 0)
                goto cleanup_extend;

        nf_ct_helper_expectfn_register(&follow_master_nat);

        /* Initialize fake conntrack so that NAT will skip it */
        nf_ct_untracked_status_or(IPS_NAT_DONE_MASK);

        BUG_ON(nfnetlink_parse_nat_setup_hook != NULL);
        RCU_INIT_POINTER(nfnetlink_parse_nat_setup_hook,
                           nfnetlink_parse_nat_setup);
#ifdef CONFIG_XFRM
        BUG_ON(nf_nat_decode_session_hook != NULL);
        RCU_INIT_POINTER(nf_nat_decode_session_hook, __nf_nat_decode_session);
#endif
        return 0;

 cleanup_extend:
        nf_ct_free_hashtable(nf_nat_bysource, nf_nat_htable_size);
        nf_ct_extend_unregister(&nat_extend);
        return ret;
}

static void __exit nf_nat_cleanup(void)
{
        unsigned int i;

        unregister_pernet_subsys(&nf_nat_net_ops);
        nf_ct_extend_unregister(&nat_extend);
        nf_ct_helper_expectfn_unregister(&follow_master_nat);
        RCU_INIT_POINTER(nfnetlink_parse_nat_setup_hook, NULL);
#ifdef CONFIG_XFRM
        RCU_INIT_POINTER(nf_nat_decode_session_hook, NULL);
#endif
        for (i = 0; i < NFPROTO_NUMPROTO; i++)
                kfree(nf_nat_l4protos[i]);
        synchronize_net();
        nf_ct_free_hashtable(nf_nat_bysource, nf_nat_htable_size);
}

MODULE_LICENSE("GPL");

module_init(nf_nat_init);
module_exit(nf_nat_cleanup);