datapath/linux/compat/stt.c  (cascardo/ovs.git, commit 0659c0b63a26533ba4d334bf5bf12010e42fa414)
1 /*
2  * Stateless TCP Tunnel (STT) vport.
3  *
4  * Copyright (c) 2015 Nicira, Inc.
5  *
6  * This program is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU General Public License
8  * as published by the Free Software Foundation; either version
9  * 2 of the License, or (at your option) any later version.
10  */
11
12 #include <asm/unaligned.h>
13
14 #include <linux/delay.h>
15 #include <linux/flex_array.h>
16 #include <linux/if.h>
17 #include <linux/if_vlan.h>
18 #include <linux/ip.h>
19 #include <linux/ipv6.h>
20 #include <linux/jhash.h>
21 #include <linux/list.h>
22 #include <linux/log2.h>
23 #include <linux/module.h>
24 #include <linux/net.h>
25 #include <linux/netfilter.h>
26 #include <linux/percpu.h>
27 #include <linux/skbuff.h>
28 #include <linux/tcp.h>
29 #include <linux/workqueue.h>
30
31 #include <net/icmp.h>
32 #include <net/inet_ecn.h>
33 #include <net/ip.h>
34 #include <net/ip6_checksum.h>
35 #include <net/net_namespace.h>
36 #include <net/netns/generic.h>
37 #include <net/sock.h>
38 #include <net/stt.h>
39 #include <net/tcp.h>
40 #include <net/udp.h>
41
42 #include "gso.h"
43
44 #ifdef OVS_STT
45 #define STT_VER 0
46
47 #define STT_CSUM_VERIFIED       BIT(0)
48 #define STT_CSUM_PARTIAL        BIT(1)
49 #define STT_PROTO_IPV4          BIT(2)
50 #define STT_PROTO_TCP           BIT(3)
51 #define STT_PROTO_TYPES         (STT_PROTO_IPV4 | STT_PROTO_TCP)
52
53 #define SUPPORTED_GSO_TYPES (SKB_GSO_TCPV4 | SKB_GSO_UDP | SKB_GSO_DODGY | \
54                              SKB_GSO_TCPV6)
55
56 /* The length and offset of a fragment are encoded in the sequence number.
57  * STT_SEQ_LEN_SHIFT is the left shift needed to store the length.
58  * STT_SEQ_OFFSET_MASK is the mask to extract the offset.
59  */
60 #define STT_SEQ_LEN_SHIFT 16
61 #define STT_SEQ_OFFSET_MASK (BIT(STT_SEQ_LEN_SHIFT) - 1)
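/* Illustrative note (not in the original source): the transmit path stores
 * the total encapsulated frame length in the upper 16 bits of the TCP
 * sequence number, with the fragment offset in the lower 16 bits; the
 * segmentation engine advancing the sequence number per segment is what
 * fills in the per-fragment offsets.  The receive path recovers both with:
 *
 *     tot_len = ntohl(tcph->seq) >> STT_SEQ_LEN_SHIFT;
 *     offset  = ntohl(tcph->seq) & STT_SEQ_OFFSET_MASK;
 *
 * (see __push_stt_header() and reassemble() below).
 */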
62
63 /* The maximum amount of memory used to store packets waiting to be reassembled
64  * on a given CPU.  Once this threshold is exceeded we will begin freeing the
65  * least recently used fragments.
66  */
67 #define REASM_HI_THRESH (4 * 1024 * 1024)
68 /* The target for the high memory evictor.  Once we have exceeded
69  * REASM_HI_THRESH, we will continue freeing fragments until we hit
70  * this limit.
71  */
72 #define REASM_LO_THRESH (3 * 1024 * 1024)
73 /* The length of time a given packet has to be reassembled from the time the
74  * first fragment arrives.  Once this limit is exceeded it becomes available
75  * for cleaning.
76  */
77 #define FRAG_EXP_TIME (30 * HZ)
78 /* Number of hash entries.  Each entry has only a single slot to hold a packet
79  * so if there are collisions, we will drop packets.  This is allocated
80  * per-cpu and each entry consists of struct pkt_frag.
81  */
82 #define FRAG_HASH_SHIFT         8
83 #define FRAG_HASH_ENTRIES       BIT(FRAG_HASH_SHIFT)
84 #define FRAG_HASH_SEGS          ((sizeof(u32) * 8) / FRAG_HASH_SHIFT)
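/* With FRAG_HASH_SHIFT == 8, FRAG_HASH_SEGS == 4: lookup_frag() consumes
 * the 32-bit key hash eight bits at a time, so each packet can probe up
 * to four buckets before an existing entry has to be evicted.
 */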
85
86 #define CLEAN_PERCPU_INTERVAL (30 * HZ)
87
88 struct pkt_key {
89         __be32 saddr;
90         __be32 daddr;
91         __be32 pkt_seq;
92         u32 mark;
93 };
94
95 struct pkt_frag {
96         struct sk_buff *skbs;
97         unsigned long timestamp;
98         struct list_head lru_node;
99         struct pkt_key key;
100 };
101
102 struct stt_percpu {
103         struct flex_array *frag_hash;
104         struct list_head frag_lru;
105         unsigned int frag_mem_used;
106
107         /* Protects the frag hash table, LRU list and memory accounting. */
108         spinlock_t lock;
109 };
110
111 struct first_frag {
112         struct sk_buff *last_skb;
113         unsigned int mem_used;
114         u16 tot_len;
115         u16 rcvd_len;
116         bool set_ecn_ce;
117 };
118
119 struct frag_skb_cb {
120         u16 offset;
121
122         /* Only valid for the first skb in the chain. */
123         struct first_frag first;
124 };
125
126 #define FRAG_CB(skb) ((struct frag_skb_cb *)(skb)->cb)
127
128 /* per-network namespace private data for this module */
129 struct stt_net {
130         struct list_head sock_list;
131 };
132
133 static int stt_net_id;
134
135 static struct stt_percpu __percpu *stt_percpu_data __read_mostly;
136 static u32 frag_hash_seed __read_mostly;
137
138 /* Protects sock_list and refcounts. */
139 static DEFINE_MUTEX(stt_mutex);
140
141 static int n_tunnels;
142 static DEFINE_PER_CPU(u32, pkt_seq_counter);
143
144 static void clean_percpu(struct work_struct *work);
145 static DECLARE_DELAYED_WORK(clean_percpu_wq, clean_percpu);
146
147 static struct stt_sock *stt_find_sock(struct net *net, __be16 port)
148 {
149         struct stt_net *sn = net_generic(net, stt_net_id);
150         struct stt_sock *stt_sock;
151
152         list_for_each_entry_rcu(stt_sock, &sn->sock_list, list) {
153                 if (inet_sk(stt_sock->sock->sk)->inet_sport == port)
154                         return stt_sock;
155         }
156         return NULL;
157 }
158
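/* Generates the value carried in the TCP ACK field of transmitted STT
 * frames: a per-CPU counter combined with the CPU id, so no two CPUs
 * produce the same value concurrently.  The receiver uses this field as
 * the pkt_seq part of the reassembly key (see struct pkt_key).
 */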
159 static __be32 ack_seq(void)
160 {
161 #if NR_CPUS <= 65536
162         u32 pkt_seq, ack;
163
164         pkt_seq = this_cpu_read(pkt_seq_counter);
165         ack = pkt_seq << ilog2(NR_CPUS) | smp_processor_id();
166         this_cpu_inc(pkt_seq_counter);
167
168         return (__force __be32)ack;
169 #else
170 #error "Support for greater than 64k CPUs not implemented"
171 #endif
172 }
173
174 static int clear_gso(struct sk_buff *skb)
175 {
176         struct skb_shared_info *shinfo = skb_shinfo(skb);
177         int err;
178
179         if (shinfo->gso_type == 0 && shinfo->gso_size == 0 &&
180             shinfo->gso_segs == 0)
181                 return 0;
182
183         err = skb_unclone(skb, GFP_ATOMIC);
184         if (unlikely(err))
185                 return err;
186
187         shinfo = skb_shinfo(skb);
188         shinfo->gso_type = 0;
189         shinfo->gso_size = 0;
190         shinfo->gso_segs = 0;
191         return 0;
192 }
193
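/* Recursively splices nested frag_lists into the top-level ->next chain,
 * deducting each child's len/data_len/truesize from its parent and
 * cloning any shared skbs along the way.  Returns the last skb of the
 * flattened chain, or an ERR_PTR() on failure.
 */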
194 static struct sk_buff *normalize_frag_list(struct sk_buff *head,
195                                            struct sk_buff **skbp)
196 {
197         struct sk_buff *skb = *skbp;
198         struct sk_buff *last;
199
200         do {
201                 struct sk_buff *frags;
202
203                 if (skb_shared(skb)) {
204                         struct sk_buff *nskb = skb_clone(skb, GFP_ATOMIC);
205
206                         if (unlikely(!nskb))
207                                 return ERR_PTR(-ENOMEM);
208
209                         nskb->next = skb->next;
210                         consume_skb(skb);
211                         skb = nskb;
212                         *skbp = skb;
213                 }
214
215                 if (head) {
216                         head->len -= skb->len;
217                         head->data_len -= skb->len;
218                         head->truesize -= skb->truesize;
219                 }
220
221                 frags = skb_shinfo(skb)->frag_list;
222                 if (frags) {
223                         int err;
224
225                         err = skb_unclone(skb, GFP_ATOMIC);
226                         if (unlikely(err))
227                                 return ERR_PTR(err);
228
229                         last = normalize_frag_list(skb, &frags);
230                         if (IS_ERR(last))
231                                 return last;
232
233                         skb_shinfo(skb)->frag_list = NULL;
234                         last->next = skb->next;
235                         skb->next = frags;
236                 } else {
237                         last = skb;
238                 }
239
240                 skbp = &skb->next;
241         } while ((skb = skb->next));
242
243         return last;
244 }
245
246 /* Takes a linked list of skbs, which potentially contain frag_list
247  * (whose members in turn potentially contain frag_lists, etc.) and
248  * converts them into a single linear linked list.
249  */
250 static int straighten_frag_list(struct sk_buff **skbp)
251 {
252         struct sk_buff *err_skb;
253
254         err_skb = normalize_frag_list(NULL, skbp);
255         if (IS_ERR(err_skb))
256                 return PTR_ERR(err_skb);
257
258         return 0;
259 }
260
261 static void copy_skb_metadata(struct sk_buff *to, struct sk_buff *from)
262 {
263         to->protocol = from->protocol;
264         to->tstamp = from->tstamp;
265         to->priority = from->priority;
266         to->mark = from->mark;
267         to->vlan_tci = from->vlan_tci;
268 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
269         to->vlan_proto = from->vlan_proto;
270 #endif
271         skb_copy_secmark(to, from);
272 }
273
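/* Fixes up one segment produced by skb_list_segment(): rewrites the IPv4
 * tot_len (or IPv6 payload_len), gives non-head segments their own TCP
 * sequence number, clears FIN/PSH on all but the last segment, adjusts
 * the TCP checksum incrementally for the new length, and drops the GSO
 * state once a segment fits in a single MSS.
 */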
274 static void update_headers(struct sk_buff *skb, bool head,
275                                unsigned int l4_offset, unsigned int hdr_len,
276                                bool ipv4, u32 tcp_seq)
277 {
278         u16 old_len, new_len;
279         __be32 delta;
280         struct tcphdr *tcph;
281         int gso_size;
282
283         if (ipv4) {
284                 struct iphdr *iph = (struct iphdr *)(skb->data + ETH_HLEN);
285
286                 old_len = ntohs(iph->tot_len);
287                 new_len = skb->len - ETH_HLEN;
288                 iph->tot_len = htons(new_len);
289
290                 ip_send_check(iph);
291         } else {
292                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(skb->data + ETH_HLEN);
293
294                 old_len = ntohs(ip6h->payload_len);
295                 new_len = skb->len - ETH_HLEN - sizeof(struct ipv6hdr);
296                 ip6h->payload_len = htons(new_len);
297         }
298
299         tcph = (struct tcphdr *)(skb->data + l4_offset);
300         if (!head) {
301                 tcph->seq = htonl(tcp_seq);
302                 tcph->cwr = 0;
303         }
304
305         if (skb->next) {
306                 tcph->fin = 0;
307                 tcph->psh = 0;
308         }
309
310         delta = htonl(~old_len + new_len);
311         tcph->check = ~csum_fold((__force __wsum)((__force u32)tcph->check +
312                                  (__force u32)delta));
313
314         gso_size = skb_shinfo(skb)->gso_size;
315         if (gso_size && skb->len - hdr_len <= gso_size)
316                 BUG_ON(clear_gso(skb));
317 }
318
319 static bool can_segment(struct sk_buff *head, bool ipv4, bool tcp, bool csum_partial)
320 {
321         /* If no offloading is in use then we don't have enough information
322          * to process the headers.
323          */
324         if (!csum_partial)
325                 goto linearize;
326
327         /* Handling UDP packets requires IP fragmentation, which means that
328          * the L4 checksum can no longer be calculated by hardware (since the
329          * fragments are in different packets).  If we have to compute the
330          * checksum it's faster just to linearize, and large UDP packets are
331          * pretty uncommon anyway, so it's not worth dealing with for now.
332          */
333         if (!tcp)
334                 goto linearize;
335
336         if (ipv4) {
337                 struct iphdr *iph = (struct iphdr *)(head->data + ETH_HLEN);
338
339                 /* It's difficult to get the IP IDs exactly right here due to
340                  * varying segment sizes and potentially multiple layers of
341                  * segmentation.  IP ID isn't important when DF is set and DF
342                  * is generally set for TCP packets, so just linearize if it's
343                  * not.
344                  */
345                 if (!(iph->frag_off & htons(IP_DF)))
346                         goto linearize;
347         } else {
348                 struct ipv6hdr *ip6h = (struct ipv6hdr *)(head->data + ETH_HLEN);
349
350                 /* Jumbograms require more processing to update and we'll
351                  * probably never see them, so just linearize.
352                  */
353                 if (ip6h->payload_len == 0)
354                         goto linearize;
355         }
356         return true;
357
358 linearize:
359         return false;
360 }
361
362 static int copy_headers(struct sk_buff *head, struct sk_buff *frag,
363                             int hdr_len)
364 {
365         u16 csum_start;
366
367         if (skb_cloned(frag) || skb_headroom(frag) < hdr_len) {
368                 int extra_head = hdr_len - skb_headroom(frag);
369
370                 extra_head = extra_head > 0 ? extra_head : 0;
371                 if (unlikely(pskb_expand_head(frag, extra_head, 0,
372                                               GFP_ATOMIC)))
373                         return -ENOMEM;
374         }
375
376         memcpy(__skb_push(frag, hdr_len), head->data, hdr_len);
377
378         csum_start = head->csum_start - skb_headroom(head);
379         frag->csum_start = skb_headroom(frag) + csum_start;
380         frag->csum_offset = head->csum_offset;
381         frag->ip_summed = head->ip_summed;
382
383         skb_shinfo(frag)->gso_size = skb_shinfo(head)->gso_size;
384         skb_shinfo(frag)->gso_type = skb_shinfo(head)->gso_type;
385         skb_shinfo(frag)->gso_segs = 0;
386
387         copy_skb_metadata(frag, head);
388         return 0;
389 }
390
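/* Software segmentation of a frame with a frag_list: turns each member of
 * the head's frag_list into a standalone packet chained via ->next,
 * copying the Ethernet/IP/TCP headers from the head onto every fragment
 * and updating the per-segment headers.  Only used when can_segment()
 * decided the headers can be processed without linearizing.
 */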
391 static int skb_list_segment(struct sk_buff *head, bool ipv4, int l4_offset)
392 {
393         struct sk_buff *skb;
394         struct tcphdr *tcph;
395         int seg_len;
396         int hdr_len;
397         int tcp_len;
398         u32 seq;
399
400         if (unlikely(!pskb_may_pull(head, l4_offset + sizeof(*tcph))))
401                 return -ENOMEM;
402
403         tcph = (struct tcphdr *)(head->data + l4_offset);
404         tcp_len = tcph->doff * 4;
405         hdr_len = l4_offset + tcp_len;
406
407         if (unlikely((tcp_len < sizeof(struct tcphdr)) ||
408                      (head->len < hdr_len)))
409                 return -EINVAL;
410
411         if (unlikely(!pskb_may_pull(head, hdr_len)))
412                 return -ENOMEM;
413
414         tcph = (struct tcphdr *)(head->data + l4_offset);
415         /* Update header of each segment. */
416         seq = ntohl(tcph->seq);
417         seg_len = skb_pagelen(head) - hdr_len;
418
419         skb = skb_shinfo(head)->frag_list;
420         skb_shinfo(head)->frag_list = NULL;
421         head->next = skb;
422         for (; skb; skb = skb->next) {
423                 int err;
424
425                 head->len -= skb->len;
426                 head->data_len -= skb->len;
427                 head->truesize -= skb->truesize;
428
429                 seq += seg_len;
430                 seg_len = skb->len;
431                 err = copy_headers(head, skb, hdr_len);
432                 if (err)
433                         return err;
434                 update_headers(skb, false, l4_offset, hdr_len, ipv4, seq);
435         }
436         update_headers(head, true, l4_offset, hdr_len, ipv4, 0);
437         return 0;
438 }
439
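/* Flattens the skb chain with straighten_frag_list(), merges neighboring
 * buffers with skb_try_coalesce() where possible, and reattaches whatever
 * could not be merged as the head's frag_list so that a single skb is
 * handed back through *headp.
 */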
440 static int coalesce_skb(struct sk_buff **headp)
441 {
442         struct sk_buff *frag, *head, *prev;
443         int err;
444
445         err = straighten_frag_list(headp);
446         if (unlikely(err))
447                 return err;
448         head = *headp;
449
450         /* Coalesce frag list. */
451         prev = head;
452         for (frag = head->next; frag; frag = frag->next) {
453                 bool headstolen;
454                 int delta;
455
456                 if (unlikely(skb_unclone(prev, GFP_ATOMIC)))
457                         return -ENOMEM;
458
459                 if (!skb_try_coalesce(prev, frag, &headstolen, &delta)) {
460                         prev = frag;
461                         continue;
462                 }
463
464                 prev->next = frag->next;
465                 frag->len = 0;
466                 frag->data_len = 0;
467                 frag->truesize -= delta;
468                 kfree_skb_partial(frag, headstolen);
469                 frag = prev;
470         }
471
472         if (!head->next)
473                 return 0;
474
475         for (frag = head->next; frag; frag = frag->next) {
476                 head->len += frag->len;
477                 head->data_len += frag->len;
478                 head->truesize += frag->truesize;
479         }
480
481         skb_shinfo(head)->frag_list = head->next;
482         head->next = NULL;
483         return 0;
484 }
485
486 static int __try_to_segment(struct sk_buff *skb, bool csum_partial,
487                             bool ipv4, bool tcp, int l4_offset)
488 {
489         if (can_segment(skb, ipv4, tcp, csum_partial))
490                 return skb_list_segment(skb, ipv4, l4_offset);
491         else
492                 return skb_linearize(skb);
493 }
494
495 static int try_to_segment(struct sk_buff *skb)
496 {
497         struct stthdr *stth = stt_hdr(skb);
498         bool csum_partial = !!(stth->flags & STT_CSUM_PARTIAL);
499         bool ipv4 = !!(stth->flags & STT_PROTO_IPV4);
500         bool tcp = !!(stth->flags & STT_PROTO_TCP);
501         int l4_offset = stth->l4_offset;
502
503         return __try_to_segment(skb, csum_partial, ipv4, tcp, l4_offset);
504 }
505
506 static int segment_skb(struct sk_buff **headp, bool csum_partial,
507                        bool ipv4, bool tcp, int l4_offset)
508 {
509         int err;
510
511         err = coalesce_skb(headp);
512         if (err)
513                 return err;
514
515         if (skb_shinfo(*headp)->frag_list)
516                 return __try_to_segment(*headp, csum_partial,
517                                         ipv4, tcp, l4_offset);
518         return 0;
519 }
520
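/* Builds the encapsulation for a single packet: pushes the STT header and
 * the "fake" outer TCP header, records the inner offload state (checksum
 * partial, L4 offset, protocol, MSS) in the STT flags, encodes the frame
 * length in the TCP sequence number and enables TSO whenever the frame
 * exceeds the tunnel MSS derived from the route MTU.
 */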
521 static int __push_stt_header(struct sk_buff *skb, __be64 tun_id,
522                              __be16 s_port, __be16 d_port,
523                              __be32 saddr, __be32 dst,
524                              __be16 l3_proto, u8 l4_proto,
525                              int dst_mtu)
526 {
527         int data_len = skb->len + sizeof(struct stthdr) + STT_ETH_PAD;
528         unsigned short encap_mss;
529         struct tcphdr *tcph;
530         struct stthdr *stth;
531
532         skb_push(skb, STT_HEADER_LEN);
533         skb_reset_transport_header(skb);
534         tcph = tcp_hdr(skb);
535         memset(tcph, 0, STT_HEADER_LEN);
536         stth = stt_hdr(skb);
537
538         if (skb->ip_summed == CHECKSUM_PARTIAL) {
539                 stth->flags |= STT_CSUM_PARTIAL;
540
541                 stth->l4_offset = skb->csum_start -
542                                         (skb_headroom(skb) +
543                                         STT_HEADER_LEN);
544
545                 if (l3_proto == htons(ETH_P_IP))
546                         stth->flags |= STT_PROTO_IPV4;
547
548                 if (l4_proto == IPPROTO_TCP)
549                         stth->flags |= STT_PROTO_TCP;
550
551                 stth->mss = htons(skb_shinfo(skb)->gso_size);
552         } else if (skb->ip_summed == CHECKSUM_UNNECESSARY) {
553                 stth->flags |= STT_CSUM_VERIFIED;
554         }
555
556         stth->vlan_tci = htons(skb->vlan_tci);
557         skb->vlan_tci = 0;
558         put_unaligned(tun_id, &stth->key);
559
560         tcph->source    = s_port;
561         tcph->dest      = d_port;
562         tcph->doff      = sizeof(struct tcphdr) / 4;
563         tcph->ack       = 1;
564         tcph->psh       = 1;
565         tcph->window    = htons(USHRT_MAX);
566         tcph->seq       = htonl(data_len << STT_SEQ_LEN_SHIFT);
567         tcph->ack_seq   = ack_seq();
568         tcph->check     = ~tcp_v4_check(skb->len, saddr, dst, 0);
569
570         skb->csum_start = skb_transport_header(skb) - skb->head;
571         skb->csum_offset = offsetof(struct tcphdr, check);
572         skb->ip_summed = CHECKSUM_PARTIAL;
573
574         encap_mss = dst_mtu - sizeof(struct iphdr) - sizeof(struct tcphdr);
575         if (data_len > encap_mss) {
576                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
577                         return -EINVAL;
578
579                 skb_shinfo(skb)->gso_type = SKB_GSO_TCPV4;
580                 skb_shinfo(skb)->gso_size = encap_mss;
581                 skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(data_len, encap_mss);
582         } else {
583                 if (unlikely(clear_gso(skb)))
584                         return -EINVAL;
585         }
586         return 0;
587 }
588
589 static struct sk_buff *push_stt_header(struct sk_buff *head, __be64 tun_id,
590                                        __be16 s_port, __be16 d_port,
591                                        __be32 saddr, __be32 dst,
592                                        __be16 l3_proto, u8 l4_proto,
593                                        int dst_mtu)
594 {
595         struct sk_buff *skb;
596
597         if (skb_shinfo(head)->frag_list) {
598                 bool ipv4 = (l3_proto == htons(ETH_P_IP));
599                 bool tcp = (l4_proto == IPPROTO_TCP);
600                 bool csum_partial = (head->ip_summed == CHECKSUM_PARTIAL);
601                 int l4_offset = skb_transport_offset(head);
602
603         /* Need to call skb_orphan() to report the correct truesize.
604          * Calling skb_orphan() in this layer is odd, but an SKB with a
605          * frag_list should not be associated with any socket, so
606          * skb_orphan() should be a no-op. */
607                 skb_orphan(head);
608                 if (unlikely(segment_skb(&head, csum_partial,
609                                          ipv4, tcp, l4_offset)))
610                         goto error;
611         }
612
613         for (skb = head; skb; skb = skb->next) {
614                 if (__push_stt_header(skb, tun_id, s_port, d_port, saddr, dst,
615                                       l3_proto, l4_proto, dst_mtu))
616                         goto error;
617         }
618
619         return head;
620 error:
621         kfree_skb_list(head);
622         return NULL;
623 }
624
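/* Decides whether the inner packet's offload state can be carried across
 * the tunnel in the STT header, seeding a pseudo-header checksum for GSO
 * packets that arrived without CHECKSUM_PARTIAL.  Returns 1 if it can,
 * 0 if the packet must first go through handle_offloads(), or a negative
 * errno on failure.
 */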
625 static int stt_can_offload(struct sk_buff *skb, __be16 l3_proto, u8 l4_proto)
626 {
627         if (skb_is_gso(skb) && skb->ip_summed != CHECKSUM_PARTIAL) {
628                 int csum_offset;
629                 __sum16 *csum;
630                 int len;
631
632                 if (l4_proto == IPPROTO_TCP)
633                         csum_offset = offsetof(struct tcphdr, check);
634                 else if (l4_proto == IPPROTO_UDP)
635                         csum_offset = offsetof(struct udphdr, check);
636                 else
637                         return 0;
638
639                 len = skb->len - skb_transport_offset(skb);
640                 csum = (__sum16 *)(skb_transport_header(skb) + csum_offset);
641
642                 if (unlikely(!pskb_may_pull(skb, skb_transport_offset(skb) +
643                                                  csum_offset + sizeof(*csum))))
644                         return -EINVAL;
645
646                 if (l3_proto == htons(ETH_P_IP)) {
647                         struct iphdr *iph = ip_hdr(skb);
648
649                         *csum = ~csum_tcpudp_magic(iph->saddr, iph->daddr,
650                                                    len, l4_proto, 0);
651                 } else if (l3_proto == htons(ETH_P_IPV6)) {
652                         struct ipv6hdr *ip6h = ipv6_hdr(skb);
653
654                         *csum = ~csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
655                                                  len, l4_proto, 0);
656                 } else {
657                         return 0;
658                 }
659                 skb->csum_start = skb_transport_header(skb) - skb->head;
660                 skb->csum_offset = csum_offset;
661                 skb->ip_summed = CHECKSUM_PARTIAL;
662         }
663
664         if (skb->ip_summed == CHECKSUM_PARTIAL) {
665                 /* Assume the receiver can only offload TCP/UDP over
666                  * IPv4/IPv6, and require 802.1Q VLANs to be accelerated.
667                  */
668                 if (l3_proto != htons(ETH_P_IP) &&
669                     l3_proto != htons(ETH_P_IPV6))
670                         return 0;
671
672                 if (l4_proto != IPPROTO_TCP && l4_proto != IPPROTO_UDP)
673                         return 0;
674
675                 /* L4 offset must fit in a 1-byte field. */
676                 if (skb->csum_start - skb_headroom(skb) > 255)
677                         return 0;
678
679                 if (skb_shinfo(skb)->gso_type & ~SUPPORTED_GSO_TYPES)
680                         return 0;
681         }
682         /* Total size of encapsulated packet must fit in 16 bits. */
683         if (skb->len + STT_HEADER_LEN + sizeof(struct iphdr) > 65535)
684                 return 0;
685
686 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
687         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q))
688                 return 0;
689 #endif
690         return 1;
691 }
692
693 static bool need_linearize(const struct sk_buff *skb)
694 {
695         struct skb_shared_info *shinfo = skb_shinfo(skb);
696         int i;
697
698         if (unlikely(shinfo->frag_list))
699                 return true;
700
701         /* Generally speaking we should linearize if there are paged frags.
702          * However, if all of the refcounts are 1 we know nobody else can
703          * change them from underneath us and we can skip the linearization.
704          */
705         for (i = 0; i < shinfo->nr_frags; i++)
706                 if (unlikely(page_count(skb_frag_page(&shinfo->frags[i])) > 1))
707                         return true;
708
709         return false;
710 }
711
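/* Software fallback used when stt_can_offload() returns 0: pushes any
 * accelerated VLAN tag back into the packet, segments GSO packets in
 * software (possibly returning a list of skbs) and computes checksums
 * that cannot be left to the hardware.
 */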
712 static struct sk_buff *handle_offloads(struct sk_buff *skb, int min_headroom)
713 {
714         int err;
715
716 #if LINUX_VERSION_CODE >= KERNEL_VERSION(3,10,0)
717         if (skb_vlan_tag_present(skb) && skb->vlan_proto != htons(ETH_P_8021Q)) {
718
719                 min_headroom += VLAN_HLEN;
720                 if (skb_headroom(skb) < min_headroom) {
721                         int head_delta = SKB_DATA_ALIGN(min_headroom -
722                                                         skb_headroom(skb) + 16);
723
724                         err = pskb_expand_head(skb, max_t(int, head_delta, 0),
725                                                0, GFP_ATOMIC);
726                         if (unlikely(err))
727                                 goto error;
728                 }
729
730                 skb = __vlan_hwaccel_push_inside(skb);
731                 if (!skb) {
732                         err = -ENOMEM;
733                         goto error;
734                 }
735         }
736 #endif
737
738         if (skb_is_gso(skb)) {
739                 struct sk_buff *nskb;
740                 char cb[sizeof(skb->cb)];
741
742                 memcpy(cb, skb->cb, sizeof(cb));
743
744                 nskb = __skb_gso_segment(skb, 0, false);
745                 if (IS_ERR(nskb)) {
746                         err = PTR_ERR(nskb);
747                         goto error;
748                 }
749
750                 consume_skb(skb);
751                 skb = nskb;
752                 while (nskb) {
753                         memcpy(nskb->cb, cb, sizeof(cb));
754                         nskb = nskb->next;
755                 }
756         } else if (skb->ip_summed == CHECKSUM_PARTIAL) {
757                 /* Pages aren't locked and could change at any time.
758                  * If this happens after we compute the checksum, the
759                  * checksum will be wrong.  We linearize now to avoid
760                  * this problem.
761                  */
762                 if (unlikely(need_linearize(skb))) {
763                         err = __skb_linearize(skb);
764                         if (unlikely(err))
765                                 goto error;
766                 }
767
768                 err = skb_checksum_help(skb);
769                 if (unlikely(err))
770                         goto error;
771         }
772         skb->ip_summed = CHECKSUM_NONE;
773
774         return skb;
775 error:
776         kfree_skb(skb);
777         return ERR_PTR(err);
778 }
779
780 static int skb_list_xmit(struct rtable *rt, struct sk_buff *skb, __be32 src,
781                          __be32 dst, __u8 tos, __u8 ttl, __be16 df)
782 {
783         int len = 0;
784
785         while (skb) {
786                 struct sk_buff *next = skb->next;
787
788                 if (next)
789                         dst_clone(&rt->dst);
790
791                 skb_clear_ovs_gso_cb(skb);
792                 skb->next = NULL;
793                 len += iptunnel_xmit(NULL, rt, skb, src, dst, IPPROTO_TCP,
794                                      tos, ttl, df, false);
795
796                 skb = next;
797         }
798         return len;
799 }
800
801 static u8 parse_ipv6_l4_proto(struct sk_buff *skb)
802 {
803         unsigned int nh_ofs = skb_network_offset(skb);
804         int payload_ofs;
805         struct ipv6hdr *nh;
806         uint8_t nexthdr;
807         __be16 frag_off;
808
809         if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct ipv6hdr))))
810                 return 0;
811
812         nh = ipv6_hdr(skb);
813         nexthdr = nh->nexthdr;
814         payload_ofs = (u8 *)(nh + 1) - skb->data;
815
816         payload_ofs = ipv6_skip_exthdr(skb, payload_ofs, &nexthdr, &frag_off);
817         if (unlikely(payload_ofs < 0))
818                 return 0;
819
820         return nexthdr;
821 }
822
823 static u8 skb_get_l4_proto(struct sk_buff *skb, __be16 l3_proto)
824 {
825         if (l3_proto == htons(ETH_P_IP)) {
826                 unsigned int nh_ofs = skb_network_offset(skb);
827
828                 if (unlikely(!pskb_may_pull(skb, nh_ofs + sizeof(struct iphdr))))
829                         return 0;
830
831                 return ip_hdr(skb)->protocol;
832         } else if (l3_proto == htons(ETH_P_IPV6)) {
833                 return parse_ipv6_l4_proto(skb);
834         }
835         return 0;
836 }
837
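/* Transmit entry point: makes sure the packet has enough headroom,
 * resolves offloads, encapsulates with push_stt_header() (which may
 * produce a list of packets) and hands each resulting packet to the IP
 * layer over the supplied route.  Returns the total length handed to the
 * IP layer, or a negative errno on setup failure.
 */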
838 int rpl_stt_xmit_skb(struct sk_buff *skb, struct rtable *rt,
839                  __be32 src, __be32 dst, __u8 tos,
840                  __u8 ttl, __be16 df, __be16 src_port, __be16 dst_port,
841                  __be64 tun_id)
842 {
843         struct ethhdr *eh = eth_hdr(skb);
844         int ret = 0, min_headroom;
845         __be16 inner_l3_proto;
846         u8 inner_l4_proto;
847
848         inner_l3_proto = eh->h_proto;
849         inner_l4_proto = skb_get_l4_proto(skb, inner_l3_proto);
850
851         min_headroom = LL_RESERVED_SPACE(rt->dst.dev) + rt->dst.header_len
852                         + STT_HEADER_LEN + sizeof(struct iphdr);
853
854         if (skb_headroom(skb) < min_headroom || skb_header_cloned(skb)) {
855                 int head_delta = SKB_DATA_ALIGN(min_headroom -
856                                                 skb_headroom(skb) +
857                                                 16);
858
859                 ret = pskb_expand_head(skb, max_t(int, head_delta, 0),
860                                        0, GFP_ATOMIC);
861                 if (unlikely(ret))
862                         goto err_free_rt;
863         }
864
865         ret = stt_can_offload(skb, inner_l3_proto, inner_l4_proto);
866         if (ret < 0)
867                 goto err_free_rt;
868         if (!ret) {
869                 skb = handle_offloads(skb, min_headroom);
870                 if (IS_ERR(skb)) {
871                         ret = PTR_ERR(skb);
872                         skb = NULL;
873                         goto err_free_rt;
874                 }
875         }
876
877         ret = 0;
878         while (skb) {
879                 struct sk_buff *next_skb = skb->next;
880
881                 skb->next = NULL;
882
883                 if (next_skb)
884                         dst_clone(&rt->dst);
885
886                 /* Push STT and TCP header. */
887                 skb = push_stt_header(skb, tun_id, src_port, dst_port, src,
888                                       dst, inner_l3_proto, inner_l4_proto,
889                                       dst_mtu(&rt->dst));
890                 if (unlikely(!skb)) {
891                         ip_rt_put(rt);
892                         goto next;
893                 }
894
895                 /* Push IP header. */
896                 ret += skb_list_xmit(rt, skb, src, dst, tos, ttl, df);
897
898 next:
899                 skb = next_skb;
900         }
901
902         return ret;
903
904 err_free_rt:
905         ip_rt_put(rt);
906         kfree_skb(skb);
907         return ret;
908 }
909 EXPORT_SYMBOL_GPL(rpl_stt_xmit_skb);
910
911 static void free_frag(struct stt_percpu *stt_percpu,
912                       struct pkt_frag *frag)
913 {
914         stt_percpu->frag_mem_used -= FRAG_CB(frag->skbs)->first.mem_used;
915         kfree_skb_list(frag->skbs);
916         list_del(&frag->lru_node);
917         frag->skbs = NULL;
918 }
919
920 static void evict_frags(struct stt_percpu *stt_percpu)
921 {
922         while (!list_empty(&stt_percpu->frag_lru) &&
923                stt_percpu->frag_mem_used > REASM_LO_THRESH) {
924                 struct pkt_frag *frag;
925
926                 frag = list_first_entry(&stt_percpu->frag_lru,
927                                         struct pkt_frag,
928                                         lru_node);
929                 free_frag(stt_percpu, frag);
930         }
931 }
932
933 static bool pkt_key_match(struct net *net,
934                           const struct pkt_frag *a, const struct pkt_key *b)
935 {
936         return a->key.saddr == b->saddr && a->key.daddr == b->daddr &&
937                a->key.pkt_seq == b->pkt_seq && a->key.mark == b->mark &&
938                net_eq(dev_net(a->skbs->dev), net);
939 }
940
941 static u32 pkt_key_hash(const struct net *net, const struct pkt_key *key)
942 {
943         u32 initval = frag_hash_seed ^ (u32)(unsigned long)net ^ key->mark;
944
945         return jhash_3words((__force u32)key->saddr, (__force u32)key->daddr,
946                             (__force u32)key->pkt_seq, initval);
947 }
948
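/* Finds the reassembly slot for a key by probing FRAG_HASH_SEGS buckets
 * taken from successive slices of the hash.  If none matches, the best
 * victim among the probed buckets (an empty one if possible, otherwise
 * the oldest) is freed and handed back for reuse.
 */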
949 static struct pkt_frag *lookup_frag(struct net *net,
950                                     struct stt_percpu *stt_percpu,
951                                     const struct pkt_key *key, u32 hash)
952 {
953         struct pkt_frag *frag, *victim_frag = NULL;
954         int i;
955
956         for (i = 0; i < FRAG_HASH_SEGS; i++) {
957                 frag = flex_array_get(stt_percpu->frag_hash,
958                                       hash & (FRAG_HASH_ENTRIES - 1));
959
960                 if (frag->skbs &&
961                     time_before(jiffies, frag->timestamp + FRAG_EXP_TIME) &&
962                     pkt_key_match(net, frag, key))
963                         return frag;
964
965                 if (!victim_frag ||
966                     (victim_frag->skbs &&
967                      (!frag->skbs ||
968                       time_before(frag->timestamp, victim_frag->timestamp))))
969                         victim_frag = frag;
970
971                 hash >>= FRAG_HASH_SHIFT;
972         }
973
974         if (victim_frag->skbs)
975                 free_frag(stt_percpu, victim_frag);
976
977         return victim_frag;
978 }
979
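/* Per-CPU reassembly of segmented STT frames.  The fragment offset and
 * total frame length are taken from the TCP sequence number; fragments
 * are queued in offset order, overlaps are dropped, and memory use is
 * bounded by evicting old entries.  Returns the head of the completed
 * frame, or NULL if the fragment was queued (or freed).
 */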
980 static struct sk_buff *reassemble(struct sk_buff *skb)
981 {
982         struct iphdr *iph = ip_hdr(skb);
983         struct tcphdr *tcph = tcp_hdr(skb);
984         u32 seq = ntohl(tcph->seq);
985         struct stt_percpu *stt_percpu;
986         struct sk_buff *last_skb;
987         struct pkt_frag *frag;
988         struct pkt_key key;
989         int tot_len;
990         u32 hash;
991
992         tot_len = seq >> STT_SEQ_LEN_SHIFT;
993         FRAG_CB(skb)->offset = seq & STT_SEQ_OFFSET_MASK;
994
995         if (unlikely(skb->len == 0))
996                 goto out_free;
997
998         if (unlikely(FRAG_CB(skb)->offset + skb->len > tot_len))
999                 goto out_free;
1000
1001         if (tot_len == skb->len)
1002                 goto out;
1003
1004         key.saddr = iph->saddr;
1005         key.daddr = iph->daddr;
1006         key.pkt_seq = tcph->ack_seq;
1007         key.mark = skb->mark;
1008         hash = pkt_key_hash(dev_net(skb->dev), &key);
1009
1010         stt_percpu = per_cpu_ptr(stt_percpu_data, smp_processor_id());
1011
1012         spin_lock(&stt_percpu->lock);
1013
1014         if (unlikely(stt_percpu->frag_mem_used + skb->truesize > REASM_HI_THRESH))
1015                 evict_frags(stt_percpu);
1016
1017         frag = lookup_frag(dev_net(skb->dev), stt_percpu, &key, hash);
1018         if (!frag->skbs) {
1019                 frag->skbs = skb;
1020                 frag->key = key;
1021                 frag->timestamp = jiffies;
1022                 FRAG_CB(skb)->first.last_skb = skb;
1023                 FRAG_CB(skb)->first.mem_used = skb->truesize;
1024                 FRAG_CB(skb)->first.tot_len = tot_len;
1025                 FRAG_CB(skb)->first.rcvd_len = skb->len;
1026                 FRAG_CB(skb)->first.set_ecn_ce = false;
1027                 list_add_tail(&frag->lru_node, &stt_percpu->frag_lru);
1028                 stt_percpu->frag_mem_used += skb->truesize;
1029
1030                 skb = NULL;
1031                 goto unlock;
1032         }
1033
1034         /* Optimize for the common case where fragments are received in-order
1035          * and not overlapping.
1036          */
1037         last_skb = FRAG_CB(frag->skbs)->first.last_skb;
1038         if (likely(FRAG_CB(last_skb)->offset + last_skb->len ==
1039                    FRAG_CB(skb)->offset)) {
1040                 last_skb->next = skb;
1041                 FRAG_CB(frag->skbs)->first.last_skb = skb;
1042         } else {
1043                 struct sk_buff *prev = NULL, *next;
1044
1045                 for (next = frag->skbs; next; next = next->next) {
1046                         if (FRAG_CB(next)->offset >= FRAG_CB(skb)->offset)
1047                                 break;
1048                         prev = next;
1049                 }
1050
1051                 /* Overlapping fragments aren't allowed.  We shouldn't start
1052                  * before the end of the previous fragment.
1053                  */
1054                 if (prev &&
1055                     FRAG_CB(prev)->offset + prev->len > FRAG_CB(skb)->offset)
1056                         goto unlock_free;
1057
1058                 /* We also shouldn't end after the beginning of the next
1059                  * fragment.
1060                  */
1061                 if (next &&
1062                     FRAG_CB(skb)->offset + skb->len > FRAG_CB(next)->offset)
1063                         goto unlock_free;
1064
1065                 if (prev) {
1066                         prev->next = skb;
1067                 } else {
1068                         FRAG_CB(skb)->first = FRAG_CB(frag->skbs)->first;
1069                         frag->skbs = skb;
1070                 }
1071
1072                 if (next)
1073                         skb->next = next;
1074                 else
1075                         FRAG_CB(frag->skbs)->first.last_skb = skb;
1076         }
1077
1078         FRAG_CB(frag->skbs)->first.set_ecn_ce |= INET_ECN_is_ce(iph->tos);
1079         FRAG_CB(frag->skbs)->first.rcvd_len += skb->len;
1080         FRAG_CB(frag->skbs)->first.mem_used += skb->truesize;
1081         stt_percpu->frag_mem_used += skb->truesize;
1082
1083         if (FRAG_CB(frag->skbs)->first.tot_len ==
1084             FRAG_CB(frag->skbs)->first.rcvd_len) {
1085                 struct sk_buff *frag_head = frag->skbs;
1086
1087                 frag_head->tstamp = skb->tstamp;
1088                 if (FRAG_CB(frag_head)->first.set_ecn_ce)
1089                         INET_ECN_set_ce(frag_head);
1090
1091                 list_del(&frag->lru_node);
1092                 stt_percpu->frag_mem_used -= FRAG_CB(frag_head)->first.mem_used;
1093                 frag->skbs = NULL;
1094                 skb = frag_head;
1095         } else {
1096                 list_move_tail(&frag->lru_node, &stt_percpu->frag_lru);
1097                 skb = NULL;
1098         }
1099
1100         goto unlock;
1101
1102 unlock_free:
1103         kfree_skb(skb);
1104         skb = NULL;
1105 unlock:
1106         spin_unlock(&stt_percpu->lock);
1107         return skb;
1108 out_free:
1109         kfree_skb(skb);
1110         skb = NULL;
1111 out:
1112         return skb;
1113 }
1114
1115 static bool validate_checksum(struct sk_buff *skb)
1116 {
1117         struct iphdr *iph = ip_hdr(skb);
1118
1119         if (skb_csum_unnecessary(skb))
1120                 return true;
1121
1122         if (skb->ip_summed == CHECKSUM_COMPLETE &&
1123             !tcp_v4_check(skb->len, iph->saddr, iph->daddr, skb->csum))
1124                 return true;
1125
1126         skb->csum = csum_tcpudp_nofold(iph->saddr, iph->daddr, skb->len,
1127                                        IPPROTO_TCP, 0);
1128
1129         return __tcp_checksum_complete(skb) == 0;
1130 }
1131
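/* Translates the received STT header back into skb metadata: restores the
 * accelerated VLAN tag and, for STT_CSUM_PARTIAL frames, rebuilds the
 * inner checksum offload and GSO parameters from the header's flags,
 * l4_offset and MSS.
 */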
1132 static bool set_offloads(struct sk_buff *skb)
1133 {
1134         struct stthdr *stth = stt_hdr(skb);
1135         unsigned short gso_type;
1136         int l3_header_size;
1137         int l4_header_size;
1138         u16 csum_offset;
1139         u8 proto_type;
1140
1141         if (stth->vlan_tci)
1142                 __vlan_hwaccel_put_tag(skb, htons(ETH_P_8021Q),
1143                                        ntohs(stth->vlan_tci));
1144
1145         if (!(stth->flags & STT_CSUM_PARTIAL)) {
1146                 if (stth->flags & STT_CSUM_VERIFIED)
1147                         skb->ip_summed = CHECKSUM_UNNECESSARY;
1148                 else
1149                         skb->ip_summed = CHECKSUM_NONE;
1150
1151                 return clear_gso(skb) == 0;
1152         }
1153
1154         proto_type = stth->flags & STT_PROTO_TYPES;
1155
1156         switch (proto_type) {
1157         case (STT_PROTO_IPV4 | STT_PROTO_TCP):
1158                 /* TCP/IPv4 */
1159                 csum_offset = offsetof(struct tcphdr, check);
1160                 gso_type = SKB_GSO_TCPV4;
1161                 l3_header_size = sizeof(struct iphdr);
1162                 l4_header_size = sizeof(struct tcphdr);
1163                 skb->protocol = htons(ETH_P_IP);
1164                 break;
1165         case STT_PROTO_TCP:
1166                 /* TCP/IPv6 */
1167                 csum_offset = offsetof(struct tcphdr, check);
1168                 gso_type = SKB_GSO_TCPV6;
1169                 l3_header_size = sizeof(struct ipv6hdr);
1170                 l4_header_size = sizeof(struct tcphdr);
1171                 skb->protocol = htons(ETH_P_IPV6);
1172                 break;
1173         case STT_PROTO_IPV4:
1174                 /* UDP/IPv4 */
1175                 csum_offset = offsetof(struct udphdr, check);
1176                 gso_type = SKB_GSO_UDP;
1177                 l3_header_size = sizeof(struct iphdr);
1178                 l4_header_size = sizeof(struct udphdr);
1179                 skb->protocol = htons(ETH_P_IP);
1180                 break;
1181         default:
1182                 /* UDP/IPv6 */
1183                 csum_offset = offsetof(struct udphdr, check);
1184                 gso_type = SKB_GSO_UDP;
1185                 l3_header_size = sizeof(struct ipv6hdr);
1186                 l4_header_size = sizeof(struct udphdr);
1187                 skb->protocol = htons(ETH_P_IPV6);
1188         }
1189
1190         if (unlikely(stth->l4_offset < ETH_HLEN + l3_header_size))
1191                 return false;
1192
1193         if (unlikely(!pskb_may_pull(skb, stth->l4_offset + l4_header_size)))
1194                 return false;
1195
1196         stth = stt_hdr(skb);
1197
1198         skb->csum_start = skb_headroom(skb) + stth->l4_offset;
1199         skb->csum_offset = csum_offset;
1200         skb->ip_summed = CHECKSUM_PARTIAL;
1201
1202         if (stth->mss) {
1203                 if (unlikely(skb_unclone(skb, GFP_ATOMIC)))
1204                         return false;
1205
1206                 skb_shinfo(skb)->gso_type = gso_type | SKB_GSO_DODGY;
1207                 skb_shinfo(skb)->gso_size = ntohs(stth->mss);
1208                 skb_shinfo(skb)->gso_segs = 0;
1209         } else {
1210                 if (unlikely(clear_gso(skb)))
1211                         return false;
1212         }
1213
1214         return true;
1215 }
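
/* Receive path for one STT frame: validates the outer TCP checksum,
 * reassembles segmented frames, strips the STT header, restores the inner
 * offload state and hands the decapsulated frame to the stt_sock's rcv
 * callback.  Anything malformed is freed here.
 */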
1216 static void stt_rcv(struct stt_sock *stt_sock, struct sk_buff *skb)
1217 {
1218         int err;
1219
1220         if (unlikely(!validate_checksum(skb)))
1221                 goto drop;
1222
1223         skb = reassemble(skb);
1224         if (!skb)
1225                 return;
1226
1227         if (skb->next && coalesce_skb(&skb))
1228                 goto drop;
1229
1230         err = iptunnel_pull_header(skb,
1231                                    sizeof(struct stthdr) + STT_ETH_PAD,
1232                                    htons(ETH_P_TEB));
1233         if (unlikely(err))
1234                 goto drop;
1235
1236         if (unlikely(stt_hdr(skb)->version != 0))
1237                 goto drop;
1238
1239         if (unlikely(!set_offloads(skb)))
1240                 goto drop;
1241
1242         if (skb_shinfo(skb)->frag_list && try_to_segment(skb))
1243                 goto drop;
1244
1245         stt_sock->rcv(stt_sock, skb);
1246         return;
1247 drop:
1248         /* Consume bad packet */
1249         kfree_skb_list(skb);
1250 }
1251
1252 static void tcp_sock_release(struct socket *sock)
1253 {
1254         kernel_sock_shutdown(sock, SHUT_RDWR);
1255         sock_release(sock);
1256 }
1257
1258 static int tcp_sock_create4(struct net *net, __be16 port,
1259                             struct socket **sockp)
1260 {
1261         struct sockaddr_in tcp_addr;
1262         struct socket *sock = NULL;
1263         int err;
1264
1265         err = sock_create_kern(net, AF_INET, SOCK_STREAM, IPPROTO_TCP, &sock);
1266         if (err < 0)
1267                 goto error;
1268
1269         memset(&tcp_addr, 0, sizeof(tcp_addr));
1270         tcp_addr.sin_family = AF_INET;
1271         tcp_addr.sin_addr.s_addr = htonl(INADDR_ANY);
1272         tcp_addr.sin_port = port;
1273         err = kernel_bind(sock, (struct sockaddr *)&tcp_addr,
1274                           sizeof(tcp_addr));
1275         if (err < 0)
1276                 goto error;
1277
1278         *sockp = sock;
1279         return 0;
1280
1281 error:
1282         if (sock)
1283                 tcp_sock_release(sock);
1284         *sockp = NULL;
1285         return err;
1286 }
1287
1288 static void schedule_clean_percpu(void)
1289 {
1290         schedule_delayed_work(&clean_percpu_wq, CLEAN_PERCPU_INTERVAL);
1291 }
1292
1293 static void clean_percpu(struct work_struct *work)
1294 {
1295         int i;
1296
1297         for_each_possible_cpu(i) {
1298                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1299                 int j;
1300
1301                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1302                         struct pkt_frag *frag;
1303
1304                         frag = flex_array_get(stt_percpu->frag_hash, j);
1305                         if (!frag->skbs ||
1306                             time_before(jiffies, frag->timestamp + FRAG_EXP_TIME))
1307                                 continue;
1308
1309                         spin_lock_bh(&stt_percpu->lock);
1310
1311                         if (frag->skbs &&
1312                             time_after(jiffies, frag->timestamp + FRAG_EXP_TIME))
1313                                 free_frag(stt_percpu, frag);
1314
1315                         spin_unlock_bh(&stt_percpu->lock);
1316                 }
1317         }
1318         schedule_clean_percpu();
1319 }
1320
1321 #ifdef HAVE_NF_HOOKFN_ARG_OPS
1322 #define FIRST_PARAM const struct nf_hook_ops *ops
1323 #else
1324 #define FIRST_PARAM unsigned int hooknum
1325 #endif
1326
1327 #if LINUX_VERSION_CODE >= KERNEL_VERSION(4,1,0)
1328 #define LAST_PARAM const struct nf_hook_state *state
1329 #else
1330 #define LAST_PARAM const struct net_device *in, const struct net_device *out, int (*okfn)(struct sk_buff *)
1331 #endif
1332
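/* Netfilter LOCAL_IN hook; this is how STT frames are received.  TCP
 * segments whose destination port matches a registered stt_sock are
 * pulled past the IP/TCP headers and stolen into stt_rcv(); all other
 * traffic is accepted unmodified.
 */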
1333 static unsigned int nf_ip_hook(FIRST_PARAM, struct sk_buff *skb, LAST_PARAM)
1334 {
1335         struct stt_sock *stt_sock;
1336         int ip_hdr_len;
1337
1338         if (ip_hdr(skb)->protocol != IPPROTO_TCP)
1339                 return NF_ACCEPT;
1340
1341         ip_hdr_len = ip_hdrlen(skb);
1342         if (unlikely(!pskb_may_pull(skb, ip_hdr_len + sizeof(struct tcphdr))))
1343                 return NF_ACCEPT;
1344
1345         skb_set_transport_header(skb, ip_hdr_len);
1346
1347         stt_sock = stt_find_sock(dev_net(skb->dev), tcp_hdr(skb)->dest);
1348         if (!stt_sock)
1349                 return NF_ACCEPT;
1350
1351         __skb_pull(skb, ip_hdr_len + sizeof(struct tcphdr));
1352         stt_rcv(stt_sock, skb);
1353         return NF_STOLEN;
1354 }
1355
1356 static struct nf_hook_ops nf_hook_ops __read_mostly = {
1357         .hook           = nf_ip_hook,
1358         .owner          = THIS_MODULE,
1359         .pf             = NFPROTO_IPV4,
1360         .hooknum        = NF_INET_LOCAL_IN,
1361         .priority       = INT_MAX,
1362 };
1363
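/* Refcounted global initialization shared by all STT sockets: on first
 * use it allocates the per-CPU fragment hash tables, registers the
 * netfilter hook and schedules the periodic fragment cleaner.
 */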
1364 static int stt_start(void)
1365 {
1366         int err;
1367         int i;
1368
1369         if (n_tunnels) {
1370                 n_tunnels++;
1371                 return 0;
1372         }
1373         get_random_bytes(&frag_hash_seed, sizeof(u32));
1374
1375         stt_percpu_data = alloc_percpu(struct stt_percpu);
1376         if (!stt_percpu_data) {
1377                 err = -ENOMEM;
1378                 goto error;
1379         }
1380
1381         for_each_possible_cpu(i) {
1382                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1383                 struct flex_array *frag_hash;
1384
1385                 spin_lock_init(&stt_percpu->lock);
1386                 INIT_LIST_HEAD(&stt_percpu->frag_lru);
1387                 get_random_bytes(&per_cpu(pkt_seq_counter, i), sizeof(u32));
1388
1389                 frag_hash = flex_array_alloc(sizeof(struct pkt_frag),
1390                                              FRAG_HASH_ENTRIES,
1391                                              GFP_KERNEL | __GFP_ZERO);
1392                 if (!frag_hash) {
1393                         err = -ENOMEM;
1394                         goto free_percpu;
1395                 }
1396                 stt_percpu->frag_hash = frag_hash;
1397
1398                 err = flex_array_prealloc(stt_percpu->frag_hash, 0,
1399                                           FRAG_HASH_ENTRIES,
1400                                           GFP_KERNEL | __GFP_ZERO);
1401                 if (err)
1402                         goto free_percpu;
1403         }
1404         err = nf_register_hook(&nf_hook_ops);
1405         if (err)
1406                 goto free_percpu;
1407
1408         schedule_clean_percpu();
1409         n_tunnels++;
1410         return 0;
1411
1412 free_percpu:
1413         for_each_possible_cpu(i) {
1414                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1415
1416                 if (stt_percpu->frag_hash)
1417                         flex_array_free(stt_percpu->frag_hash);
1418         }
1419
1420         free_percpu(stt_percpu_data);
1421
1422 error:
1423         return err;
1424 }
1425
1426 static void stt_cleanup(void)
1427 {
1428         int i;
1429
1430         n_tunnels--;
1431         if (n_tunnels)
1432                 return;
1433
1434         cancel_delayed_work_sync(&clean_percpu_wq);
1435         nf_unregister_hook(&nf_hook_ops);
1436
1437         for_each_possible_cpu(i) {
1438                 struct stt_percpu *stt_percpu = per_cpu_ptr(stt_percpu_data, i);
1439                 int j;
1440
1441                 for (j = 0; j < FRAG_HASH_ENTRIES; j++) {
1442                         struct pkt_frag *frag;
1443
1444                         frag = flex_array_get(stt_percpu->frag_hash, j);
1445                         kfree_skb_list(frag->skbs);
1446                 }
1447
1448                 flex_array_free(stt_percpu->frag_hash);
1449         }
1450
1451         free_percpu(stt_percpu_data);
1452 }
1453
1454 static struct stt_sock *stt_socket_create(struct net *net, __be16 port,
1455                                           stt_rcv_t *rcv, void *data)
1456 {
1457         struct stt_net *sn = net_generic(net, stt_net_id);
1458         struct stt_sock *stt_sock;
1459         struct socket *sock;
1460         int err;
1461
1462         stt_sock = kzalloc(sizeof(*stt_sock), GFP_KERNEL);
1463         if (!stt_sock)
1464                 return ERR_PTR(-ENOMEM);
1465
1466         err = tcp_sock_create4(net, port, &sock);
1467         if (err) {
1468                 kfree(stt_sock);
1469                 return ERR_PTR(err);
1470         }
1471
1472         stt_sock->sock = sock;
1473         stt_sock->rcv = rcv;
1474         stt_sock->rcv_data = data;
1475
1476         list_add_rcu(&stt_sock->list, &sn->sock_list);
1477
1478         return stt_sock;
1479 }
1480
1481 static void __stt_sock_release(struct stt_sock *stt_sock)
1482 {
1483         list_del_rcu(&stt_sock->list);
1484         tcp_sock_release(stt_sock->sock);
1485         kfree_rcu(stt_sock, rcu);
1486 }
1487
1488 struct stt_sock *rpl_stt_sock_add(struct net *net, __be16 port,
1489                               stt_rcv_t *rcv, void *data)
1490 {
1491         struct stt_sock *stt_sock;
1492         int err;
1493
1494         err = stt_start();
1495         if (err)
1496                 return ERR_PTR(err);
1497
1498         mutex_lock(&stt_mutex);
1499         rcu_read_lock();
1500         stt_sock = stt_find_sock(net, port);
1501         rcu_read_unlock();
1502         if (stt_sock)
1503                 stt_sock = ERR_PTR(-EBUSY);
1504         else
1505                 stt_sock = stt_socket_create(net, port, rcv, data);
1506
1507         mutex_unlock(&stt_mutex);
1508
1509         if (IS_ERR(stt_sock))
1510                 stt_cleanup();
1511
1512         return stt_sock;
1513 }
1514 EXPORT_SYMBOL_GPL(rpl_stt_sock_add);
1515
1516 void rpl_stt_sock_release(struct stt_sock *stt_sock)
1517 {
1518         mutex_lock(&stt_mutex);
1519         if (stt_sock) {
1520                 __stt_sock_release(stt_sock);
1521                 stt_cleanup();
1522         }
1523         mutex_unlock(&stt_mutex);
1524 }
1525 EXPORT_SYMBOL_GPL(rpl_stt_sock_release);
1526
1527 static int stt_init_net(struct net *net)
1528 {
1529         struct stt_net *sn = net_generic(net, stt_net_id);
1530
1531         INIT_LIST_HEAD(&sn->sock_list);
1532         return 0;
1533 }
1534
1535 static struct pernet_operations stt_net_ops = {
1536         .init = stt_init_net,
1537         .id   = &stt_net_id,
1538         .size = sizeof(struct stt_net),
1539 };
1540
1541 int ovs_stt_init_module(void)
1542 {
1543         return register_pernet_subsys(&stt_net_ops);
1544 }
1545 EXPORT_SYMBOL_GPL(ovs_stt_init_module);
1546
1547 void ovs_stt_cleanup_module(void)
1548 {
1549         unregister_pernet_subsys(&stt_net_ops);
1550 }
1551 EXPORT_SYMBOL_GPL(ovs_stt_cleanup_module);
1552 #endif