datapath: stt compatibility for RHEL7
cascardo/ovs.git: datapath/linux/compat/stt.c
1 /*
2  * Stateless TCP Tunnel (STT) vport.
3  *
4  * Copyright (c) 2015 Nicira, Inc.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #include <asm/unaligned.h>
13
14 #include <linux/delay.h>
15 #include <linux/flex_array.h>
16 #include <linux/if.h>
17 #include <linux/if_vlan.h>
18 #include <linux/ip.h>
19 #include <linux/ipv6.h>
20 #include <linux/jhash.h>
21 #include <linux/list.h>
22 #include <linux/log2.h>
23 #include <linux/module.h>
24 #include <linux/netfilter.h>
25 #include <linux/percpu.h>
26 #include <linux/skbuff.h>
27 #include <linux/tcp.h>
28 #include <linux/workqueue.h>
29
30 #include <net/icmp.h>
31 #include <net/inet_ecn.h>
32 #include <net/ip.h>
33 #include <net/net_namespace.h>
34 #include <net/netns/generic.h>
35 #include <net/sock.h>
36 #include <net/stt.h>
37 #include <net/tcp.h>
38 #include <net/udp.h>
39
40 #include "gso.h"
41
42 #ifdef OVS_STT
43 #define STT_VER 0
44
45 #define STT_CSUM_VERIFIED       BIT(0)
46 #define STT_CSUM_PARTIAL        BIT(1)
47 #define STT_PROTO_IPV4          BIT(2)
48 #define STT_PROTO_TCP           BIT(3)
49 #define STT_PROTO_TYPES         (STT_PROTO_IPV4 | STT_PROTO_TCP)
50
51 #define SUPPORTED_GSO_TYPES (SKB_GSO_TCPV4 | SKB_GSO_UDP | SKB_GSO_DODGY | \
52                              SKB_GSO_TCPV6)
53
54 /* The length and offset of a fragment are encoded in the sequence number.
55  * STT_SEQ_LEN_SHIFT is the left shift needed to store the length.
56  * STT_SEQ_OFFSET_MASK is the mask to extract the offset.
57  */
58 #define STT_SEQ_LEN_SHIFT 16
59 #define STT_SEQ_OFFSET_MASK (BIT(STT_SEQ_LEN_SHIFT) - 1)
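
/* For example: the head of a 3000-byte STT frame goes out with
 * seq = (3000 << STT_SEQ_LEN_SHIFT) | 0 (see __push_stt_header()), and TCP
 * segmentation advances the sequence number by each segment's payload, so a
 * later segment starting at byte 1400 carries
 * seq = (3000 << STT_SEQ_LEN_SHIFT) | 1400.  reassemble() then recovers
 * tot_len = seq >> STT_SEQ_LEN_SHIFT and offset = seq & STT_SEQ_OFFSET_MASK.
 */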
60
61 /* The maximum amount of memory used to store packets waiting to be reassembled
62  * on a given CPU.  Once this threshold is exceeded we will begin freeing the
63  * least recently used fragments.
64  */
65 #define REASM_HI_THRESH (4 * 1024 * 1024)
66 /* The target for the high memory evictor.  Once we have exceeded
67  * REASM_HI_THRESH, we will continue freeing fragments until we hit
68  * this limit.
69  */
70 #define REASM_LO_THRESH (3 * 1024 * 1024)
71 /* The length of time a given packet has to be reassembled from the time the
72  * first fragment arrives.  Once this limit is exceeded it becomes available
73  * for cleaning.
74  */
75 #define FRAG_EXP_TIME (30 * HZ)
76 /* Number of hash entries.  Each entry has only a single slot to hold a packet,
77  * so if there are collisions, we will drop packets.  This is allocated
78  * per-cpu and each entry consists of a struct pkt_frag.
79  */
80 #define FRAG_HASH_SHIFT         8
81 #define FRAG_HASH_ENTRIES       BIT(FRAG_HASH_SHIFT)
82 #define FRAG_HASH_SEGS          ((sizeof(u32) * 8) / FRAG_HASH_SHIFT)
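
/* With FRAG_HASH_SHIFT = 8 this gives 256 buckets per CPU and
 * FRAG_HASH_SEGS = 32 / 8 = 4, so lookup_frag() probes up to four buckets,
 * indexing each probe with the next 8-bit slice of the 32-bit jhash value and,
 * when no probe matches, reusing an empty bucket or evicting the oldest
 * candidate.
 */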
83
84 #define CLEAN_PERCPU_INTERVAL (30 * HZ)
85
86 struct pkt_key {
87         __be32 saddr;
88         __be32 daddr;
89         __be32 pkt_seq;
90         u32 mark;
91 };
92
93 struct pkt_frag {
94         struct sk_buff *skbs;
95         unsigned long timestamp;
96         struct list_head lru_node;
97         struct pkt_key key;
98 };
99
100 struct stt_percpu {
101         struct flex_array *frag_hash;
102         struct list_head frag_lru;
103         unsigned int frag_mem_used;
104
105         /* Protect frags table. */
106         spinlock_t lock;
107 };
108
109 struct first_frag {
110         struct sk_buff *last_skb;
111         unsigned int mem_used;
112         u16 tot_len;
113         u16 rcvd_len;
114         bool set_ecn_ce;
115 };
116
117 struct frag_skb_cb {
118         u16 offset;
119
120         /* Only valid for the first skb in the chain. */
121         struct first_frag first;
122 };
123
124 #define FRAG_CB(skb) ((struct frag_skb_cb *)(skb)->cb)
125
126 /* per-network namespace private data for this module */
127 struct stt_net {
128         struct list_head sock_list;
129 };
130
131 static int stt_net_id;
132
133 static struct stt_percpu __percpu *stt_percpu_data __read_mostly;
134 static u32 frag_hash_seed __read_mostly;
135
136 /* Protects sock-hash and refcounts. */
137 static DEFINE_MUTEX(stt_mutex);
138
139 static int n_tunnels;
140 static DEFINE_PER_CPU(u32, pkt_seq_counter);
141
142 static void clean_percpu(struct work_struct *work);
143 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
144
145 static struct stt_sock *stt_find_sock(struct net *net, __be16 port)
146 {
147         struct stt_net *sn = net_generic(net, stt_net_id);
148         struct stt_sock *stt_sock;
149
150         list_for_each_entry_rcu(stt_sock, &sn->sock_list, list) {
151                 if (inet_sk(stt_sock->sock->sk)->inet_sport == port)
152                         return stt_sock;
153         }
154         return NULL;
155 }
156
157 static __be32 ack_seq(void)
158 {
159 #if NR_CPUS <= 65536
160         u32 pkt_seq, ack;
161
162         pkt_seq = this_cpu_read(pkt_seq_counter);
163         ack = pkt_seq << ilog2(NR_CPUS) | smp_processor_id();
164         this_cpu_inc(pkt_seq_counter);
165
166         return (__force __be32)ack;
167 #else
168 #error "Support for greater than 64k CPUs not implemented"
169 #endif
170 }
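
/* The resulting ack value places the CPU id in the low ilog2(NR_CPUS) bits and
 * a per-CPU counter in the remaining high bits, so it stays distinct across
 * CPUs and (until the counter wraps) across packets sent from the same CPU.
 * The receive side uses it as pkt_key.pkt_seq in reassemble() to help keep
 * fragments of different STT frames apart.
 */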
171
172 static int clear_gso(struct sk_buff *skb)
173 {
174         struct skb_shared_info *shinfo = skb_shinfo(skb);
175         int err;
176
177         if (shinfo->gso_type == 0 && shinfo->gso_size == 0 &&
178             shinfo->gso_segs == 0)
179                 return 0;
180
181         err = skb_unclone(skb, GFP_ATOMIC);
182         if (unlikely(err))
183                 return err;
184
185         shinfo = skb_shinfo(skb);
186         shinfo->gso_type = 0;
187         shinfo->gso_size = 0;
188         shinfo->gso_segs = 0;
189         return 0;
190 }
191
192 static struct sk_buff *normalize_frag_list(struct sk_buff *head,
193                                            struct sk_buff **skbp)
194 {
195         struct sk_buff *skb = *skbp;
196         struct sk_buff *last;
197
198         do {
199                 struct sk_buff *frags;
200
201                 if (skb_shared(skb)) {
202                         struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC);
203
204                         if (unlikely(!nskb))
205                                 return ERR_PTR(-ENOMEM);
206
207                         nskb->next = skb->next;
208                         consume_skb(skb);
209                         skb = nskb;
210                         *skbp = skb;
211                 }
212
213                 if (head) {
214                         head->len -= skb->len;
215                         head->data_len -= skb->len;
216                         head->truesize -= skb->truesize;
217                 }
218
219                 frags = skb_shinfo(skb)->frag_list;
220                 if (frags) {
221                         int err;
222
223                         err = skb_unclone(skb, GFP_ATOMIC);
224                         if (unlikely(err))
225                                 return ERR_PTR(err);
226
227                         last = normalize_frag_list(skb, &frags);
228                         if (IS_ERR(last))
229                                 return last;
230
231                         skb_shinfo(skb)->frag_list = NULL;
232                         last->next = skb->next;
233                         skb->next = frags;
234                 } else {
235                         last = skb;
236                 }
237
238                 skbp = &skb->next;
239         } while ((skb = skb->next));
240
241         return last;
242 }
243
244 /* Takes a linked list of skbs, each of which may contain a frag_list
245  * (whose members may in turn contain frag_lists, etc.), and
246  * converts them into a single linear linked list.
247  */
248 static int straighten_frag_list(struct sk_buff **skbp)
249 {
250         struct sk_buff *err_skb;
251
252         err_skb = normalize_frag_list(NULL, skbp);
253         if (IS_ERR(err_skb))
254                 return PTR_ERR(err_skb);
255
256         return 0;
257 }
258
259 static void copy_skb_metadata(struct sk_buff *to, struct sk_buff *from)
260 {
261         to->protocol = from->protocol;
262         to->tstamp = from->tstamp;
263         to->priority = from->priority;
264         to->mark = from->mark;
265         to->vlan_tci = from->vlan_tci;
266 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
267         to->vlan_proto = from->vlan_proto;
268 #endif
269         skb_copy_secmark(to, from);
270 }
271
272 static void update_headers(struct sk_buff *skb, bool head,
273                                unsigned int l4_offset, unsigned int hdr_len,
274                                bool ipv4, u32 tcp_seq)
275 {
276         u16 old_len, new_len;
277         __be32 delta;
278         struct tcphdr *tcph;
279         int gso_size;
280
281         if (ipv4) {
282                 struct iphdr *iph = (struct iphdr *)(skb->data + ETH_HLEN);
283
284                 old_len = ntohs(iph->tot_len);
285                 new_len = skb->len - ETH_HLEN;
286                 iph->tot_len = htons(new_len);
287
288                 ip_send_check(iph);
289         } else {
290                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(skb->data + ETH_HLEN);
291
292                 old_len = ntohs(ip6h->payload_len);
293                 new_len = skb->len - ETH_HLEN - sizeof(struct ipv6hdr);
294                 ip6h->payload_len = htons(new_len);
295         }
296
297         tcph = (struct tcphdr *)(skb->data + l4_offset);
298         if (!head) {
299                 tcph->seq = htonl(tcp_seq);
300                 tcph->cwr = 0;
301         }
302
303         if (skb->next) {
304                 tcph->fin = 0;
305                 tcph->psh = 0;
306         }
307
308         delta = htonl(~old_len + new_len);
309         tcph->check = ~csum_fold((__force __wsum)((__force u32)tcph->check +
310                                  (__force u32)delta));
311
312         gso_size = skb_shinfo(skb)->gso_size;
313         if (gso_size && skb->len - hdr_len <= gso_size)
314                 BUG_ON(clear_gso(skb));
315 }
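
/* The checksum update above is incremental: only the length field changed, so
 * folding the one's-complement difference (~old_len + new_len) into the
 * existing TCP checksum, in the spirit of RFC 1624, avoids recomputing it over
 * the whole payload.
 */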
316
317 static bool can_segment(struct sk_buff *head, bool ipv4, bool tcp, bool csum_partial)
318 {
319         /* If no offloading is in use then we don't have enough information
320          * to process the headers.
321          */
322         if (!csum_partial)
323                 goto linearize;
324
325         /* Handling UDP packets requires IP fragmentation, which means that
326          * the L4 checksum can no longer be calculated by hardware (since the
327          * fragments are in different packets).  If we have to compute the
328          * checksum it's faster just to linearize, and large UDP packets are
329          * pretty uncommon anyway, so it's not worth dealing with for now.
330          */
331         if (!tcp)
332                 goto linearize;
333
334         if (ipv4) {
335                 struct iphdr *iph = (struct iphdr *)(head->data + ETH_HLEN);
336
337                 /* It's difficult to get the IP IDs exactly right here due to
338                  * varying segment sizes and potentially multiple layers of
339                  * segmentation.  IP ID isn't important when DF is set and DF
340                  * is generally set for TCP packets, so just linearize if it's
341                  * not.
342                  */
343                 if (!(iph->frag_off & htons(IP_DF)))
344                         goto linearize;
345         } else {
346                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(head->data + ETH_HLEN);
347
348                 /* Jumbograms require more processing to update and we'll
349                  * probably never see them, so just linearize.
350                  */
351                 if (ip6h->payload_len == 0)
352                         goto linearize;
353         }
354         return true;
355
356 linearize:
357         return false;
358 }
359
360 static int copy_headers(struct sk_buff *head, struct sk_buff *frag,
361                             int hdr_len)
362 {
363         u16 csum_start;
364
365         if (skb_cloned(frag) || skb_headroom(frag) < hdr_len) {
366                 int extra_head = hdr_len - skb_headroom(frag);
367
368                 extra_head = extra_head > 0 ? extra_head : 0;
369                 if (unlikely(pskb_expand_head(frag, extra_head, 0,
370                                               GFP_ATOMIC)))
371                         return -ENOMEM;
372         }
373
374         memcpy(__skb_push(frag, hdr_len), head->data, hdr_len);
375
376         csum_start = head->csum_start - skb_headroom(head);
377         frag->csum_start = skb_headroom(frag) + csum_start;
378         frag->csum_offset = head->csum_offset;
379         frag->ip_summed = head->ip_summed;
380
381         skb_shinfo(frag)->gso_size = skb_shinfo(head)->gso_size;
382         skb_shinfo(frag)->gso_type = skb_shinfo(head)->gso_type;
383         skb_shinfo(frag)->gso_segs = 0;
384
385         copy_skb_metadata(frag, head);
386         return 0;
387 }
388
389 static int skb_list_segment(struct sk_buff *head, bool ipv4, int l4_offset)
390 {
391         struct sk_buff *skb;
392         struct tcphdr *tcph;
393         int seg_len;
394         int hdr_len;
395         int tcp_len;
396         u32 seq;
397
398         if (unlikely(!pskb_may_pull(head, l4_offset + sizeof(*tcph))))
399                 return -ENOMEM;
400
401         tcph = (struct tcphdr *)(head->data + l4_offset);
402         tcp_len = tcph->doff * 4;
403         hdr_len = l4_offset + tcp_len;
404
405         if (unlikely((tcp_len < sizeof(struct tcphdr)) ||
406                      (head->len < hdr_len)))
407                 return -EINVAL;
408
409         if (unlikely(!pskb_may_pull(head, hdr_len)))
410                 return -ENOMEM;
411
412         tcph = (struct tcphdr *)(head->data + l4_offset);
413         /* Update header of each segment. */
414         seq = ntohl(tcph->seq);
415         seg_len = skb_pagelen(head) - hdr_len;
416
417         skb = skb_shinfo(head)->frag_list;
418         skb_shinfo(head)->frag_list = NULL;
419         head->next = skb;
420         for (; skb; skb = skb->next) {
421                 int err;
422
423                 head->len -= skb->len;
424                 head->data_len -= skb->len;
425                 head->truesize -= skb->truesize;
426
427                 seq += seg_len;
428                 seg_len = skb->len;
429                 err = copy_headers(head, skb, hdr_len);
430                 if (err)
431                         return err;
432                 update_headers(skb, false, l4_offset, hdr_len, ipv4, seq);
433         }
434         update_headers(head, true, l4_offset, hdr_len, ipv4, 0);
435         return 0;
436 }
437
438 static int coalesce_skb(struct sk_buff **headp)
439 {
440         struct sk_buff *frag, *head, *prev;
441         int err;
442
443         err = straighten_frag_list(headp);
444         if (unlikely(err))
445                 return err;
446         head = *headp;
447
448         /* Coalesce frag list. */
449         prev = head;
450         for (frag = head->next; frag; frag = frag->next) {
451                 bool headstolen;
452                 int delta;
453
454                 if (unlikely(skb_unclone(prev, GFP_ATOMIC)))
455                         return -ENOMEM;
456
457                 if (!skb_try_coalesce(prev, frag, &headstolen, &delta)) {
458                         prev = frag;
459                         continue;
460                 }
461
462                 prev->next = frag->next;
463                 frag->len = 0;
464                 frag->data_len = 0;
465                 frag->truesize -= delta;
466                 kfree_skb_partial(frag, headstolen);
467                 frag = prev;
468         }
469
470         if (!head->next)
471                 return 0;
472
473         for (frag = head->next; frag; frag = frag->next) {
474                 head->len += frag->len;
475                 head->data_len += frag->len;
476                 head->truesize += frag->truesize;
477         }
478
479         skb_shinfo(head)->frag_list = head->next;
480         head->next = NULL;
481         return 0;
482 }
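
/* coalesce_skb() thus works in passes: straighten_frag_list() flattens any
 * nested frag_lists into one linear chain, skb_try_coalesce() merges
 * neighbouring skbs where possible, and whatever still remains after the head
 * is re-attached as the head's frag_list, with head->len, data_len and
 * truesize grown to cover the whole chain again.
 */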
483
484 static int __try_to_segment(struct sk_buff *skb, bool csum_partial,
485                             bool ipv4, bool tcp, int l4_offset)
486 {
487         if (can_segment(skb, ipv4, tcp, csum_partial))
488                 return skb_list_segment(skb, ipv4, l4_offset);
489         else
490                 return skb_linearize(skb);
491 }
492
493 static int try_to_segment(struct sk_buff *skb)
494 {
495         struct stthdr *stth = stt_hdr(skb);
496         bool csum_partial = !!(stth->flags & STT_CSUM_PARTIAL);
497         bool ipv4 = !!(stth->flags & STT_PROTO_IPV4);
498         bool tcp = !!(stth->flags & STT_PROTO_TCP);
499         int l4_offset = stth->l4_offset;
500
501         return __try_to_segment(skb, csum_partial, ipv4, tcp, l4_offset);
502 }
503
504 static int segment_skb(struct sk_buff **headp, bool csum_partial,
505                        bool ipv4, bool tcp, int l4_offset)
506 {
507         int err;
508
509         err = coalesce_skb(headp);
510         if (err)
511                 return err;
512
513         if (skb_shinfo(*headp)->frag_list)
514                 return __try_to_segment(*headp, csum_partial,
515                                         ipv4, tcp, l4_offset);
516         return 0;
517 }
518
519 static int __push_stt_header(struct sk_buff *skb, __be64 tun_id,
520                              __be16 s_port, __be16 d_port,
521                              __be32 saddr, __be32 dst,
522                              __be16 l3_proto, u8 l4_proto,
523                              int dst_mtu)
524 {
525         int data_len = skb->len + sizeof(struct stthdr) + STT_ETH_PAD;
526         unsigned short encap_mss;
527         struct tcphdr *tcph;
528         struct stthdr *stth;
529
530         skb_push(skb, STT_HEADER_LEN);
531         skb_reset_transport_header(skb);
532         tcph = tcp_hdr(skb);
533         memset(tcph, 0, STT_HEADER_LEN);
534         stth = stt_hdr(skb);
535
536         if (skb->ip_summed == CHECKSUM_PARTIAL) {
537                 stth->flags |= STT_CSUM_PARTIAL;
538
539                 stth->l4_offset = skb->csum_start -
540                                         (skb_headroom(skb) +
541                                         STT_HEADER_LEN);
542
543                 if (l3_proto == htons(ETH_P_IP))
544                         stth->flags |= STT_PROTO_IPV4;
545
546                 if (l4_proto == IPPROTO_TCP)
547                         stth->flags |= STT_PROTO_TCP;
548
549                 stth->mss = htons(skb_shinfo(skb)->gso_size);
550         } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
551                 stth->flags |= STT_CSUM_VERIFIED;
552         }
553
554         stth->vlan_tci = htons(skb->vlan_tci);
555         skb->vlan_tci = 0;
556         put_unaligned(tun_id, &stth->key);
557
558         tcph->source    = s_port;
559         tcph->dest      = d_port;
560         tcph->doff      = sizeof(struct tcphdr) / 4;
561         tcph->ack       = 1;
562         tcph->psh       = 1;
563         tcph->window    = htons(USHRT_MAX);
564         tcph->seq       = htonl(data_len << STT_SEQ_LEN_SHIFT);
565         tcph->ack_seq   = ack_seq();
566         tcph->check     = ~tcp_v4_check(skb->len, saddr, dst, 0);
567
568         skb->csum_start = skb_transport_header(skb) - skb->head;
569         skb->csum_offset = offsetof(struct tcphdr, check);
570         skb->ip_summed = CHECKSUM_PARTIAL;
571
572         encap_mss = dst_mtu - sizeof(struct iphdr) - sizeof(struct tcphdr);
573         if (data_len > encap_mss) {
574                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
575                         return -EINVAL;
576
577                 skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
578                 skb_shinfo(skb)->gso_size = encap_mss;
579                 skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(data_len, encap_mss);
580         } else {
581                 if (unlikely(clear_gso(skb)))
582                         return -EINVAL;
583         }
584         return 0;
585 }
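
/* Rough example of the GSO setup above: with dst_mtu = 1500, encap_mss is
 * 1500 - 20 - 20 = 1460 bytes, so a frame whose data_len works out to 9000
 * bytes is marked SKB_GSO_TCPV4 with gso_size = 1460 and
 * gso_segs = DIV_ROUND_UP(9000, 1460) = 7, while frames that already fit in a
 * single segment get their GSO state cleared instead.
 */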
586
587 static struct sk_buff *push_stt_header(struct sk_buff *head, __be64 tun_id,
588                                        __be16 s_port, __be16 d_port,
589                                        __be32 saddr, __be32 dst,
590                                        __be16 l3_proto, u8 l4_proto,
591                                        int dst_mtu)
592 {
593         struct sk_buff *skb;
594
595         if (skb_shinfo(head)->frag_list) {
596                 bool ipv4 = (l3_proto == htons(ETH_P_IP));
597                 bool tcp = (l4_proto == IPPROTO_TCP);
598                 bool csum_partial = (head->ip_summed == CHECKSUM_PARTIAL);
599                 int l4_offset = skb_transport_offset(head);
600
601                 /* Need to call skb_orphan() to report the correct truesize.
602                  * Calling skb_orphan() at this layer is odd, but an SKB with
603                  * a frag-list should not be associated with any socket, so
604                  * skb_orphan() should be a no-op. */
605                 skb_orphan(head);
606                 if (unlikely(segment_skb(&head, csum_partial,
607                                          ipv4, tcp, l4_offset)))
608                         goto error;
609         }
610
611         for (skb = head; skb; skb = skb->next) {
612                 if (__push_stt_header(skb, tun_id, s_port, d_port, saddr, dst,
613                                       l3_proto, l4_proto, dst_mtu))
614                         goto error;
615         }
616
617         return head;
618 error:
619         kfree_skb_list(head);
620         return NULL;
621 }
622
623 static int stt_can_offload(struct sk_buff *skb, __be16 l3_proto, u8 l4_proto)
624 {
625         if (skb_is_gso(skb) && skb->ip_summed != CHECKSUM_PARTIAL) {
626                 int csum_offset;
627                 __sum16 *csum;
628                 int len;
629
630                 if (l4_proto == IPPROTO_TCP)
631                         csum_offset = offsetof(struct tcphdr, check);
632                 else if (l4_proto == IPPROTO_UDP)
633                         csum_offset = offsetof(struct udphdr, check);
634                 else
635                         return 0;
636
637                 len = skb->len - skb_transport_offset(skb);
638                 csum = (__sum16 *)(skb_transport_header(skb) + csum_offset);
639
640                 if (unlikely(!pskb_may_pull(skb, skb_transport_offset(skb) +
641                                                  csum_offset + sizeof(*csum))))
642                         return -EINVAL;
643
644                 if (l3_proto == htons(ETH_P_IP)) {
645                         struct iphdr *iph = ip_hdr(skb);
646
647                         *csum = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
648                                                    len, l4_proto, 0);
649                 } else if (l3_proto == htons(ETH_P_IPV6)) {
650                         struct ipv6hdr *ip6h = ipv6_hdr(skb);
651
652                         *csum = ~csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
653                                                  len, l4_proto, 0);
654                 } else {
655                         return 0;
656                 }
657                 skb->csum_start = skb_transport_header(skb) - skb->head;
658                 skb->csum_offset = csum_offset;
659                 skb->ip_summed = CHECKSUM_PARTIAL;
660         }
661
662         if (skb->ip_summed == CHECKSUM_PARTIAL) {
663                 /* Assume receiver can only offload TCP/UDP over IPv4/6,
664                  * and require 802.1Q VLANs to be accelerated.
665                  */
666                 if (l3_proto != htons(ETH_P_IP) &&
667                     l3_proto != htons(ETH_P_IPV6))
668                         return 0;
669
670                 if (l4_proto != IPPROTO_TCP && l4_proto != IPPROTO_UDP)
671                         return 0;
672
673                 /* L4 offset must fit in a 1-byte field. */
674                 if (skb->csum_start - skb_headroom(skb) > 255)
675                         return 0;
676
677                 if (skb_shinfo(skb)->gso_type & ~SUPPORTED_GSO_TYPES)
678                         return 0;
679         }
680         /* Total size of encapsulated packet must fit in 16 bits. */
681         if (skb->len + STT_HEADER_LEN + sizeof(struct iphdr) > 65535)
682                 return 0;
683
684 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
685         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q))
686                 return 0;
687 #endif
688         return 1;
689 }
690
691 static bool need_linearize(const struct sk_buff *skb)
692 {
693         struct skb_shared_info *shinfo = skb_shinfo(skb);
694         int i;
695
696         if (unlikely(shinfo->frag_list))
697                 return true;
698
699         /* Generally speaking we should linearize if there are paged frags.
700          * However, if all of the refcounts are 1 we know nobody else can
701          * change them from underneath us and we can skip the linearization.
702          */
703         for (i = 0; i < shinfo->nr_frags; i++)
704                 if (unlikely(page_count(skb_frag_page(&shinfo->frags[i])) > 1))
705                         return true;
706
707         return false;
708 }
709
710 static struct sk_buff *handle_offloads(struct sk_buff *skb, int min_headroom)
711 {
712         int err;
713
714 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
715         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q)) {
716
717                 min_headroom += VLAN_HLEN;
718                 if (skb_headroom(skb) < min_headroom) {
719                         int head_delta = SKB_DATA_ALIGN(min_headroom -
720                                                         skb_headroom(skb) + 16);
721
722                         err = pskb_expand_head(skb, max_t(int, head_delta, 0),
723                                                0, GFP_ATOMIC);
724                         if (unlikely(err))
725                                 goto error;
726                 }
727
728                 skb = __vlan_hwaccel_push_inside(skb);
729                 if (!skb) {
730                         err = -ENOMEM;
731                         goto error;
732                 }
733         }
734 #endif
735
736         if (skb_is_gso(skb)) {
737                 struct sk_buff *nskb;
738                 char cb[sizeof(skb->cb)];
739
740                 memcpy(cb, skb->cb, sizeof(cb));
741
742                 nskb = __skb_gso_segment(skb, 0, false);
743                 if (IS_ERR(nskb)) {
744                         err = PTR_ERR(nskb);
745                         goto error;
746                 }
747
748                 consume_skb(skb);
749                 skb = nskb;
750                 while (nskb) {
751                         memcpy(nskb->cb, cb, sizeof(cb));
752                         nskb = nskb->next;
753                 }
754         } else if (skb->ip_summed == CHECKSUM_PARTIAL) {
755                 /* Pages aren't locked and could change at any time.
756                  * If this happens after we compute the checksum, the
757                  * checksum will be wrong.  We linearize now to avoid
758                  * this problem.
759                  */
760                 if (unlikely(need_linearize(skb))) {
761                         err = __skb_linearize(skb);
762                         if (unlikely(err))
763                                 goto error;
764                 }
765
766                 err = skb_checksum_help(skb);
767                 if (unlikely(err))
768                         goto error;
769         }
770         skb->ip_summed = CHECKSUM_NONE;
771
772         return skb;
773 error:
774         kfree_skb(skb);
775         return ERR_PTR(err);
776 }
777
778 static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
779                          __be32 dst, __u8 tos, __u8 ttl, __be16 df)
780 {
781         int len = 0;
782
783         while (skb) {
784                 struct sk_buff *next = skb->next;
785
786                 if (next)
787                         dst_clone(&rt->dst);
788
789                 skb_clear_ovs_gso_cb(skb);
790                 skb->next = NULL;
791                 len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
792                                      tos, ttl, df, false);
793
794                 skb = next;
795         }
796         return len;
797 }
798
799 static u8 parse_ipv6_l4_proto(struct sk_buff *skb)
800 {
801         unsigned int nh_ofs = skb_network_offset(skb);
802         int payload_ofs;
803         struct ipv6hdr *nh;
804         uint8_t nexthdr;
805         __be16 frag_off;
806
807         if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct ipv6hdr))))
808                 return 0;
809
810         nh = ipv6_hdr(skb);
811         nexthdr = nh->nexthdr;
812         payload_ofs = (u8 *)(nh + 1) - skb->data;
813
814         payload_ofs = ipv6_skip_exthdr(skb, payload_ofs, &nexthdr, &frag_off);
815         if (unlikely(payload_ofs < 0))
816                 return 0;
817
818         return nexthdr;
819 }
820
821 static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
822 {
823         if (l3_proto == htons(ETH_P_IP)) {
824                 unsigned int nh_ofs = skb_network_offset(skb);
825
826                 if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct iphdr))))
827                         return 0;
828
829                 return ip_hdr(skb)->protocol;
830         } else if (l3_proto == htons(ETH_P_IPV6)) {
831                 return parse_ipv6_l4_proto(skb);
832         }
833         return 0;
834 }
835
836 int rpl_stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
837                  __be32 src, __be32 dst, __u8 tos,
838                  __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
839                  __be64 tun_id)
840 {
841         struct ethhdr *eh = eth_hdr(skb);
842         int ret = 0, min_headroom;
843         __be16 inner_l3_proto;
844         u8 inner_l4_proto;
845
846         inner_l3_proto = eh->h_proto;
847         inner_l4_proto = skb_get_l4_proto(skb, inner_l3_proto);
848
849         min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len
850                         + STT_HEADER_LEN + sizeof(struct iphdr);
851
852         if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) {
853                 int head_delta = SKB_DATA_ALIGN(min_headroom -
854                                                 skb_headroom(skb) +
855                                                 16);
856
857                 ret = pskb_expand_head(skb, max_t(int, head_delta, 0),
858                                        0, GFP_ATOMIC);
859                 if (unlikely(ret))
860                         goto err_free_rt;
861         }
862
863         ret = stt_can_offload(skb, inner_l3_proto, inner_l4_proto);
864         if (ret < 0)
865                 goto err_free_rt;
866         if (!ret) {
867                 skb = handle_offloads(skb, min_headroom);
868                 if (IS_ERR(skb)) {
869                         ret = PTR_ERR(skb);
870                         skb = NULL;
871                         goto err_free_rt;
872                 }
873         }
874
875         ret = 0;
876         while (skb) {
877                 struct sk_buff *next_skb = skb->next;
878
879                 skb->next = NULL;
880
881                 if (next_skb)
882                         dst_clone(&rt->dst);
883
884                 /* Push STT and TCP header. */
885                 skb = push_stt_header(skb, tun_id, src_port, dst_port, src,
886                                       dst, inner_l3_proto, inner_l4_proto,
887                                       dst_mtu(&rt->dst));
888                 if (unlikely(!skb)) {
889                         ip_rt_put(rt);
890                         goto next;
891                 }
892
893                 /* Push IP header. */
894                 ret += skb_list_xmit(rt, skb, src, dst, tos, ttl, df);
895
896 next:
897                 skb = next_skb;
898         }
899
900         return ret;
901
902 err_free_rt:
903         ip_rt_put(rt);
904         kfree_skb(skb);
905         return ret;
906 }
907 EXPORT_SYMBOL_GPL(rpl_stt_xmit_skb);
908
909 static void free_frag(struct stt_percpu *stt_percpu,
910                       struct pkt_frag *frag)
911 {
912         stt_percpu->frag_mem_used -= FRAG_CB(frag->skbs)->first.mem_used;
913         kfree_skb_list(frag->skbs);
914         list_del(&frag->lru_node);
915         frag->skbs = NULL;
916 }
917
918 static void evict_frags(struct stt_percpu *stt_percpu)
919 {
920         while (!list_empty(&stt_percpu->frag_lru) &&
921                stt_percpu->frag_mem_used > REASM_LO_THRESH) {
922                 struct pkt_frag *frag;
923
924                 frag = list_first_entry(&stt_percpu->frag_lru,
925                                         struct pkt_frag,
926                                         lru_node);
927                 free_frag(stt_percpu, frag);
928         }
929 }
930
931 static bool pkt_key_match(struct net *net,
932                           const struct pkt_frag *a, const struct pkt_key *b)
933 {
934         return a->key.saddr == b->saddr && a->key.daddr == b->daddr &&
935                a->key.pkt_seq == b->pkt_seq && a->key.mark == b->mark &&
936                net_eq(dev_net(a->skbs->dev), net);
937 }
938
939 static u32 pkt_key_hash(const struct net *net, const struct pkt_key *key)
940 {
941         u32 initval = frag_hash_seed ^ (u32)(unsigned long)net ^ key->mark;
942
943         return jhash_3words((__force u32)key->saddr, (__force u32)key->daddr,
944                             (__force u32)key->pkt_seq, initval);
945 }
946
947 static struct pkt_frag *lookup_frag(struct net *net,
948                                     struct stt_percpu *stt_percpu,
949                                     const struct pkt_key *key, u32 hash)
950 {
951         struct pkt_frag *frag, *victim_frag = NULL;
952         int i;
953
954         for (i = 0; i < FRAG_HASH_SEGS; i++) {
955                 frag = flex_array_get(stt_percpu->frag_hash,
956                                       hash & (FRAG_HASH_ENTRIES - 1));
957
958                 if (frag->skbs &&
959                     time_before(jiffies, frag->timestamp + FRAG_EXP_TIME) &&
960                     pkt_key_match(net, frag, key))
961                         return frag;
962
963                 if (!victim_frag ||
964                     (victim_frag->skbs &&
965                      (!frag->skbs ||
966                       time_before(frag->timestamp, victim_frag->timestamp))))
967                         victim_frag = frag;
968
969                 hash >>= FRAG_HASH_SHIFT;
970         }
971
972         if (victim_frag->skbs)
973                 free_frag(stt_percpu, victim_frag);
974
975         return victim_frag;
976 }
977
978 static struct sk_buff *reassemble(struct sk_buff *skb)
979 {
980         struct iphdr *iph = ip_hdr(skb);
981         struct tcphdr *tcph = tcp_hdr(skb);
982         u32 seq = ntohl(tcph->seq);
983         struct stt_percpu *stt_percpu;
984         struct sk_buff *last_skb;
985         struct pkt_frag *frag;
986         struct pkt_key key;
987         int tot_len;
988         u32 hash;
989
990         tot_len = seq >> STT_SEQ_LEN_SHIFT;
991         FRAG_CB(skb)->offset = seq & STT_SEQ_OFFSET_MASK;
992
993         if (unlikely(skb->len == 0))
994                 goto out_free;
995
996         if (unlikely(FRAG_CB(skb)->offset + skb->len > tot_len))
997                 goto out_free;
998
999         if (tot_len == skb->len)
1000                 goto out;
1001
1002         key.saddr = iph->saddr;
1003         key.daddr = iph->daddr;
1004         key.pkt_seq = tcph->ack_seq;
1005         key.mark = skb->mark;
1006         hash = pkt_key_hash(dev_net(skb->dev), &key);
1007
1008         stt_percpu = per_cpu_ptr(stt_percpu_data, smp_processor_id());
1009
1010         spin_lock(&stt_percpu->lock);
1011
1012         if (unlikely(stt_percpu->frag_mem_used + skb->truesize > REASM_HI_THRESH))
1013                 evict_frags(stt_percpu);
1014
1015         frag = lookup_frag(dev_net(skb->dev), stt_percpu, &key, hash);
1016         if (!frag->skbs) {
1017                 frag->skbs = skb;
1018                 frag->key = key;
1019                 frag->timestamp = jiffies;
1020                 FRAG_CB(skb)->first.last_skb = skb;
1021                 FRAG_CB(skb)->first.mem_used = skb->truesize;
1022                 FRAG_CB(skb)->first.tot_len = tot_len;
1023                 FRAG_CB(skb)->first.rcvd_len = skb->len;
1024                 FRAG_CB(skb)->first.set_ecn_ce = false;
1025                 list_add_tail(&frag->lru_node, &stt_percpu->frag_lru);
1026                 stt_percpu->frag_mem_used += skb->truesize;
1027
1028                 skb = NULL;
1029                 goto unlock;
1030         }
1031
1032         /* Optimize for the common case where fragments are received in-order
1033          * and not overlapping.
1034          */
1035         last_skb = FRAG_CB(frag->skbs)->first.last_skb;
1036         if (likely(FRAG_CB(last_skb)->offset + last_skb->len ==
1037                    FRAG_CB(skb)->offset)) {
1038                 last_skb->next = skb;
1039                 FRAG_CB(frag->skbs)->first.last_skb = skb;
1040         } else {
1041                 struct sk_buff *prev = NULL, *next;
1042
1043                 for (next = frag->skbs; next; next = next->next) {
1044                         if (FRAG_CB(next)->offset >= FRAG_CB(skb)->offset)
1045                                 break;
1046                         prev = next;
1047                 }
1048
1049                 /* Overlapping fragments aren't allowed.  We shouldn't start
1050                  * before the end of the previous fragment.
1051                  */
1052                 if (prev &&
1053                     FRAG_CB(prev)->offset + prev->len > FRAG_CB(skb)->offset)
1054                         goto unlock_free;
1055
1056                 /* We also shouldn't end after the beginning of the next
1057                  * fragment.
1058                  */
1059                 if (next &&
1060                     FRAG_CB(skb)->offset + skb->len > FRAG_CB(next)->offset)
1061                         goto unlock_free;
1062
1063                 if (prev) {
1064                         prev->next = skb;
1065                 } else {
1066                         FRAG_CB(skb)->first = FRAG_CB(frag->skbs)->first;
1067                         frag->skbs = skb;
1068                 }
1069
1070                 if (next)
1071                         skb->next = next;
1072                 else
1073                         FRAG_CB(frag->skbs)->first.last_skb = skb;
1074         }
1075
1076         FRAG_CB(frag->skbs)->first.set_ecn_ce |= INET_ECN_is_ce(iph->tos);
1077         FRAG_CB(frag->skbs)->first.rcvd_len += skb->len;
1078         FRAG_CB(frag->skbs)->first.mem_used += skb->truesize;
1079         stt_percpu->frag_mem_used += skb->truesize;
1080
1081         if (FRAG_CB(frag->skbs)->first.tot_len ==
1082             FRAG_CB(frag->skbs)->first.rcvd_len) {
1083                 struct sk_buff *frag_head = frag->skbs;
1084
1085                 frag_head->tstamp = skb->tstamp;
1086                 if (FRAG_CB(frag_head)->first.set_ecn_ce)
1087                         INET_ECN_set_ce(frag_head);
1088
1089                 list_del(&frag->lru_node);
1090                 stt_percpu->frag_mem_used -= FRAG_CB(frag_head)->first.mem_used;
1091                 frag->skbs = NULL;
1092                 skb = frag_head;
1093         } else {
1094                 list_move_tail(&frag->lru_node, &stt_percpu->frag_lru);
1095                 skb = NULL;
1096         }
1097
1098         goto unlock;
1099
1100 unlock_free:
1101         kfree_skb(skb);
1102         skb = NULL;
1103 unlock:
1104         spin_unlock(&stt_percpu->lock);
1105         return skb;
1106 out_free:
1107         kfree_skb(skb);
1108         skb = NULL;
1109 out:
1110         return skb;
1111 }
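
/* reassemble() returns the original skb when the STT frame arrived in a single
 * TCP segment, NULL when the fragment was queued (or dropped), and the head of
 * the completed fragment chain once the final piece arrives; stt_rcv()
 * continues processing only when it gets a non-NULL skb back.
 */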
1112
1113 static bool validate_checksum(struct sk_buff *skb)
1114 {
1115         struct iphdr *iph = ip_hdr(skb);
1116
1117         if (skb_csum_unnecessary(skb))
1118                 return true;
1119
1120         if (skb->ip_summed == CHECKSUM_COMPLETE &&
1121             !tcp_v4_check(skb->len, iph->saddr, iph->daddr, skb->csum))
1122                 return true;
1123
1124         skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, skb->len,
1125                                        IPPROTO_TCP, 0);
1126
1127         return __tcp_checksum_complete(skb) == 0;
1128 }
1129
1130 static bool set_offloads(struct sk_buff *skb)
1131 {
1132         struct stthdr *stth = stt_hdr(skb);
1133         unsigned short gso_type;
1134         int l3_header_size;
1135         int l4_header_size;
1136         u16 csum_offset;
1137         u8 proto_type;
1138
1139         if (stth->vlan_tci)
1140                 __vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q),
1141                                        ntohs(stth->vlan_tci));
1142
1143         if (!(stth->flags & STT_CSUM_PARTIAL)) {
1144                 if (stth->flags & STT_CSUM_VERIFIED)
1145                         skb->ip_summed = CHECKSUM_UNNECESSARY;
1146                 else
1147                         skb->ip_summed = CHECKSUM_NONE;
1148
1149                 return clear_gso(skb) == 0;
1150         }
1151
1152         proto_type = stth->flags & STT_PROTO_TYPES;
1153
1154         switch (proto_type) {
1155         case (STT_PROTO_IPV4 | STT_PROTO_TCP):
1156                 /* TCP/IPv4 */
1157                 csum_offset = offsetof(struct tcphdr, check);
1158                 gso_type = SKB_GSO_TCPV4;
1159                 l3_header_size = sizeof(struct iphdr);
1160                 l4_header_size = sizeof(struct tcphdr);
1161                 skb->protocol = htons(ETH_P_IP);
1162                 break;
1163         case STT_PROTO_TCP:
1164                 /* TCP/IPv6 */
1165                 csum_offset = offsetof(struct tcphdr, check);
1166                 gso_type = SKB_GSO_TCPV6;
1167                 l3_header_size = sizeof(struct ipv6hdr);
1168                 l4_header_size = sizeof(struct tcphdr);
1169                 skb->protocol = htons(ETH_P_IPV6);
1170                 break;
1171         case STT_PROTO_IPV4:
1172                 /* UDP/IPv4 */
1173                 csum_offset = offsetof(struct udphdr, check);
1174                 gso_type = SKB_GSO_UDP;
1175                 l3_header_size = sizeof(struct iphdr);
1176                 l4_header_size = sizeof(struct udphdr);
1177                 skb->protocol = htons(ETH_P_IP);
1178                 break;
1179         default:
1180                 /* UDP/IPv6 */
1181                 csum_offset = offsetof(struct udphdr, check);
1182                 gso_type = SKB_GSO_UDP;
1183                 l3_header_size = sizeof(struct ipv6hdr);
1184                 l4_header_size = sizeof(struct udphdr);
1185                 skb->protocol = htons(ETH_P_IPV6);
1186         }
1187
1188         if (unlikely(stth->l4_offset < ETH_HLEN + l3_header_size))
1189                 return false;
1190
1191         if (unlikely(!pskb_may_pull(skb, stth->l4_offset + l4_header_size)))
1192                 return false;
1193
1194         stth = stt_hdr(skb);
1195
1196         skb->csum_start = skb_headroom(skb) + stth->l4_offset;
1197         skb->csum_offset = csum_offset;
1198         skb->ip_summed = CHECKSUM_PARTIAL;
1199
1200         if (stth->mss) {
1201                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
1202                         return false;
1203
1204                 skb_shinfo(skb)->gso_type = gso_type | SKB_GSO_DODGY;
1205                 skb_shinfo(skb)->gso_size = ntohs(stth->mss);
1206                 skb_shinfo(skb)->gso_segs = 0;
1207         } else {
1208                 if (unlikely(clear_gso(skb)))
1209                         return false;
1210         }
1211
1212         return true;
1213 }
1214 static void stt_rcv(struct stt_sock *stt_sock, struct sk_buff *skb)
1215 {
1216         int err;
1217
1218         if (unlikely(!validate_checksum(skb)))
1219                 goto drop;
1220
1221         skb = reassemble(skb);
1222         if (!skb)
1223                 return;
1224
1225         if (skb->next && coalesce_skb(&skb))
1226                 goto drop;
1227
1228         err = iptunnel_pull_header(skb,
1229                                    sizeof(struct stthdr) + STT_ETH_PAD,
1230                                    htons(ETH_P_TEB));
1231         if (unlikely(err))
1232                 goto drop;
1233
1234         if (unlikely(stt_hdr(skb)->version != 0))
1235                 goto drop;
1236
1237         if (unlikely(!set_offloads(skb)))
1238                 goto drop;
1239
1240         if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
1241                 goto drop;
1242
1243         stt_sock->rcv(stt_sock, skb);
1244         return;
1245 drop:
1246         /* Consume bad packet */
1247         kfree_skb_list(skb);
1248 }
1249
1250 static void tcp_sock_release(struct socket *sock)
1251 {
1252         kernel_sock_shutdown(sock, SHUT_RDWR);
1253         sk_release_kernel(sock->sk);
1254 }
1255
1256 static int tcp_sock_create4(struct net *net, __be16 port,
1257                             struct socket **sockp)
1258 {
1259         struct sockaddr_in tcp_addr;
1260         struct socket *sock = NULL;
1261         int err;
1262
1263         err = sock_create_kern(AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
1264         if (err < 0)
1265                 goto error;
1266
1267         sk_change_net(sock->sk, net);
1268
1269         memset(&tcp_addr, 0, sizeof(tcp_addr));
1270         tcp_addr.sin_family = AF_INET;
1271         tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
1272         tcp_addr.sin_port = port;
1273         err = kernel_bind(sock, (struct sockaddr *)&tcp_addr,
1274                           sizeof(tcp_addr));
1275         if (err < 0)
1276                 goto error;
1277
1278         *sockp = sock;
1279         return 0;
1280
1281 error:
1282         if (sock)
1283                 tcp_sock_release(sock);
1284         *sockp = NULL;
1285         return err;
1286 }
1287
1288 static void schedule_clean_percpu(void)
1289 {
1290         schedule_delayed_work(&clean_percpu_wq, CLEAN_PERCPU_INTERVAL);
1291 }
1292
1293 static void clean_percpu(struct work_struct *work)
1294 {
1295         int i;
1296
1297         for_each_possible_cpu(i) {
1298                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1299                 int j;
1300
1301                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1302                         struct pkt_frag *frag;
1303
1304                         frag = flex_array_get(stt_percpu->frag_hash, j);
1305                         if (!frag->skbs ||
1306                             time_before(jiffies, frag->timestamp + FRAG_EXP_TIME))
1307                                 continue;
1308
1309                         spin_lock_bh(&stt_percpu->lock);
1310
1311                         if (frag->skbs &&
1312                             time_after(jiffies, frag->timestamp + FRAG_EXP_TIME))
1313                                 free_frag(stt_percpu, frag);
1314
1315                         spin_unlock_bh(&stt_percpu->lock);
1316                 }
1317         }
1318         schedule_clean_percpu();
1319 }
1320
1321 #ifdef HAVE_NF_HOOKFN_ARG_OPS
1322 #define FIRST_PARAM const struct nf_hook_ops *ops,
1323 #else
1324 #define FIRST_PARAM unsigned int hooknum,
1325 #endif
1326
1327 static unsigned int nf_ip_hook(FIRST_PARAM
1328                                struct sk_buff *skb,
1329                                const struct net_device *in,
1330                                const struct net_device *out,
1331                                int (*okfn)(struct sk_buff *))
1332 {
1333         struct stt_sock *stt_sock;
1334         int ip_hdr_len;
1335
1336         if (ip_hdr(skb)->protocol != IPPROTO_TCP)
1337                 return NF_ACCEPT;
1338
1339         ip_hdr_len = ip_hdrlen(skb);
1340         if (unlikely(!pskb_may_pull(skb, ip_hdr_len + sizeof(struct tcphdr))))
1341                 return NF_ACCEPT;
1342
1343         skb_set_transport_header(skb, ip_hdr_len);
1344
1345         stt_sock = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
1346         if (!stt_sock)
1347                 return NF_ACCEPT;
1348
1349         __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
1350         stt_rcv(stt_sock, skb);
1351         return NF_STOLEN;
1352 }
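
/* Receive path summary: the kernel TCP socket bound in tcp_sock_create4()
 * effectively just claims the port.  Delivery happens through this
 * NF_INET_LOCAL_IN hook: any local-in TCP segment whose destination port
 * matches a registered stt_sock is handed to stt_rcv() and stolen (NF_STOLEN)
 * before the regular TCP receive path sees it; everything else is accepted
 * untouched.
 */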
1353
1354 static struct nf_hook_ops nf_hook_ops __read_mostly = {
1355         .hook           = nf_ip_hook,
1356         .owner          = THIS_MODULE,
1357         .pf             = NFPROTO_IPV4,
1358         .hooknum        = NF_INET_LOCAL_IN,
1359         .priority       = INT_MAX,
1360 };
1361
1362 static int stt_start(void)
1363 {
1364         int err;
1365         int i;
1366
1367         if (n_tunnels) {
1368                 n_tunnels++;
1369                 return 0;
1370         }
1371         get_random_bytes(&frag_hash_seed, sizeof(u32));
1372
1373         stt_percpu_data = alloc_percpu(struct stt_percpu);
1374         if (!stt_percpu_data) {
1375                 err = -ENOMEM;
1376                 goto error;
1377         }
1378
1379         for_each_possible_cpu(i) {
1380                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1381                 struct flex_array *frag_hash;
1382
1383                 spin_lock_init(&stt_percpu->lock);
1384                 INIT_LIST_HEAD(&stt_percpu->frag_lru);
1385                 get_random_bytes(&per_cpu(pkt_seq_counter, i), sizeof(u32));
1386
1387                 frag_hash = flex_array_alloc(sizeof(struct pkt_frag),
1388                                              FRAG_HASH_ENTRIES,
1389                                              GFP_KERNEL | __GFP_ZERO);
1390                 if (!frag_hash) {
1391                         err = -ENOMEM;
1392                         goto free_percpu;
1393                 }
1394                 stt_percpu->frag_hash = frag_hash;
1395
1396                 err = flex_array_prealloc(stt_percpu->frag_hash, 0,
1397                                           FRAG_HASH_ENTRIES,
1398                                           GFP_KERNEL | __GFP_ZERO);
1399                 if (err)
1400                         goto free_percpu;
1401         }
1402         err = nf_register_hook(&nf_hook_ops);
1403         if (err)
1404                 goto free_percpu;
1405
1406         schedule_clean_percpu();
1407         n_tunnels++;
1408         return 0;
1409
1410 free_percpu:
1411         for_each_possible_cpu(i) {
1412                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1413
1414                 if (stt_percpu->frag_hash)
1415                         flex_array_free(stt_percpu->frag_hash);
1416         }
1417
1418         free_percpu(stt_percpu_data);
1419
1420 error:
1421         return err;
1422 }
1423
1424 static void stt_cleanup(void)
1425 {
1426         int i;
1427
1428         n_tunnels--;
1429         if (n_tunnels)
1430                 return;
1431
1432         cancel_delayed_work_sync(&clean_percpu_wq);
1433         nf_unregister_hook(&nf_hook_ops);
1434
1435         for_each_possible_cpu(i) {
1436                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1437                 int j;
1438
1439                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1440                         struct pkt_frag *frag;
1441
1442                         frag = flex_array_get(stt_percpu->frag_hash, j);
1443                         kfree_skb_list(frag->skbs);
1444                 }
1445
1446                 flex_array_free(stt_percpu->frag_hash);
1447         }
1448
1449         free_percpu(stt_percpu_data);
1450 }
1451
1452 static struct stt_sock *stt_socket_create(struct net *net, __be16 port,
1453                                           stt_rcv_t *rcv, void *data)
1454 {
1455         struct stt_net *sn = net_generic(net, stt_net_id);
1456         struct stt_sock *stt_sock;
1457         struct socket *sock;
1458         int err;
1459
1460         stt_sock = kzalloc(sizeof(*stt_sock), GFP_KERNEL);
1461         if (!stt_sock)
1462                 return ERR_PTR(-ENOMEM);
1463
1464         err = tcp_sock_create4(net, port, &sock);
1465         if (err) {
1466                 kfree(stt_sock);
1467                 return ERR_PTR(err);
1468         }
1469
1470         stt_sock->sock = sock;
1471         stt_sock->rcv = rcv;
1472         stt_sock->rcv_data = data;
1473
1474         list_add_rcu(&stt_sock->list, &sn->sock_list);
1475
1476         return stt_sock;
1477 }
1478
1479 static void __stt_sock_release(struct stt_sock *stt_sock)
1480 {
1481         list_del_rcu(&stt_sock->list);
1482         tcp_sock_release(stt_sock->sock);
1483         kfree_rcu(stt_sock, rcu);
1484 }
1485
1486 struct stt_sock *rpl_stt_sock_add(struct net *net, __be16 port,
1487                               stt_rcv_t *rcv, void *data)
1488 {
1489         struct stt_sock *stt_sock;
1490         int err;
1491
1492         err = stt_start();
1493         if (err)
1494                 return ERR_PTR(err);
1495
1496         mutex_lock(&stt_mutex);
1497         rcu_read_lock();
1498         stt_sock = stt_find_sock(net, port);
1499         rcu_read_unlock();
1500         if (stt_sock)
1501                 stt_sock = ERR_PTR(-EBUSY);
1502         else
1503                 stt_sock = stt_socket_create(net, port, rcv, data);
1504
1505         mutex_unlock(&stt_mutex);
1506
1507         if (IS_ERR(stt_sock))
1508                 stt_cleanup();
1509
1510         return stt_sock;
1511 }
1512 EXPORT_SYMBOL_GPL(rpl_stt_sock_add);
1513
1514 void rpl_stt_sock_release(struct stt_sock *stt_sock)
1515 {
1516         mutex_lock(&stt_mutex);
1517         if (stt_sock) {
1518                 __stt_sock_release(stt_sock);
1519                 stt_cleanup();
1520         }
1521         mutex_unlock(&stt_mutex);
1522 }
1523 EXPORT_SYMBOL_GPL(rpl_stt_sock_release);
1524
1525 static int stt_init_net(struct net *net)
1526 {
1527         struct stt_net *sn = net_generic(net, stt_net_id);
1528
1529         INIT_LIST_HEAD(&sn->sock_list);
1530         return 0;
1531 }
1532
1533 static struct pernet_operations stt_net_ops = {
1534         .init = stt_init_net,
1535         .id   = &stt_net_id,
1536         .size = sizeof(struct stt_net),
1537 };
1538
1539 int ovs_stt_init_module(void)
1540 {
1541         return register_pernet_subsys(&stt_net_ops);
1542 }
1543 EXPORT_SYMBOL_GPL(ovs_stt_init_module);
1544
1545 void ovs_stt_cleanup_module(void)
1546 {
1547         unregister_pernet_subsys(&stt_net_ops);
1548 }
1549 EXPORT_SYMBOL_GPL(ovs_stt_cleanup_module);
1550 #endif